pytorch官网上两个例程
caffe用起来太笨重了,最近转到pytorch,用起来实在不要太方便,上手也非常快,这里贴一下pytorch官网上的两个小例程,掌握一下它的用法:
例程一:利用nn 这个module构建网络,实现一个图像分类的小功能;
链接:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

# Preprocessing: convert PIL images to tensors, then normalize each channel.
# The first tuple is the per-channel mean, the second the per-channel std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 train/test sets; num_workers=2 loads batches in two worker subprocesses.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import matplotlib.pyplot as plt
import numpy as np
import pylab

def imshow(img):
    """Undo the (0.5, 0.5) normalization and display a CHW tensor as an image."""
    npimg = (img * 0.5 + 0.5).numpy()            # de-normalize back to [0, 1]
    plt.imshow(np.transpose(npimg, (1, 2, 0)))   # CHW -> HWC for matplotlib
    pylab.show()
# Preview one training batch: show each of the 4 images with its label.
dataiter = iter(trainloader)
images, labels = next(dataiter)
for i in range(4):
    axes = plt.subplot()
    axes.set_title("label: %5s" % classes[labels[i]])
    imshow(images[i])
- #构建网络
- from torch.autograd import Variable
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
class Net(nn.Module):
    """LeNet-style CNN for 32x32 RGB CIFAR-10 images (10 output classes)."""

    def __init__(self):
        super(Net, self).__init__()
        # Two conv blocks: 3->6 and 6->16 channels, 5x5 kernels, shared 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 16 feature maps of 5x5 flatten to a 400-dim input.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        # Flatten the 16x5x5 feature maps to one 400-dim vector per sample
        # so the fully-connected layers can consume them.
        out = out.view(-1, 16 * 5 * 5)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)
net = Net()
net.cuda()

# Loss function and optimizer: cross-entropy with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train for two passes over the data, logging the mean loss every 2000 batches.
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        optimizer.zero_grad()
        # forward + backward + parameter update
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.data[0]  # pre-0.4 PyTorch API; use loss.item() on newer versions
        if i % 2000 == 1999:
            print('[%d , %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')
# Show one test batch together with ground-truth and predicted labels.
dataiter = iter(testloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join(classes[labels[j]] for j in range(4)))
outputs = net(Variable(images.cuda()))
_, predicted = torch.max(outputs.data, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

# Overall accuracy over the whole test set.
correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    correct += (predicted == labels.cuda()).sum()
    total += labels.size(0)
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
# Per-class accuracy. The counters must start at ZERO: the original used
# torch.ones, which silently credits every class with one phantom correct
# prediction and one phantom sample, skewing the reported percentages.
class_correct = torch.zeros(10).cuda()
class_total = torch.zeros(10).cuda()
for data in testloader:
    images, labels = data
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels.cuda()).squeeze()  # per-sample correctness flags
    for i in range(4):  # batch_size is 4
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
例程二:在resnet18的预训练模型上进行finetune,然后实现一个ants和bees的二分类功能:
链接:http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
- # -*- coding:utf-8 -*-
- from __future__ import print_function , division
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.optim import lr_scheduler
- from torch.autograd import Variable
- import numpy as np
- import torchvision
- from torchvision import datasets , models , transforms
- import matplotlib.pyplot as plt
- import time
- import os
- import pylab
# Data pipeline: augmentation + normalization for training,
# deterministic resize/crop + normalization for validation.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# NOTE(review): the misspelled name "dataloders" is kept because later code refers to it.
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(class_names)
use_gpu = torch.cuda.is_available()
# Show a few images after undoing the ImageNet normalization.
def imshow(inp, title=None):
    """Display a normalized CHW tensor as an image, with an optional title."""
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = np.clip(std * inp + mean, 0, 1)   # de-normalize and clamp to [0, 1]
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    pylab.show()
    plt.pause(0.001)  # give the GUI backend a moment to draw
# Preview one training batch as a labelled image grid.
batch_inputs, batch_classes = next(iter(dataloders['train']))
grid = torchvision.utils.make_grid(batch_inputs)
imshow(grid, title=[class_names[c] for c in batch_classes])
# Train the model, snapshotting the weights that score best on 'val'.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train `model` for `num_epochs` epochs and return it with the best weights.

    Each epoch runs a 'train' phase (gradients + optimizer step) and a 'val'
    phase (forward only); the weights with the highest validation accuracy
    are reloaded into `model` before returning.
    """
    from copy import deepcopy  # local import keeps this example self-contained
    since = time.time()
    # deepcopy is required: state_dict() returns references to the live
    # parameter tensors, so without a copy the "best" snapshot would keep
    # changing as training continues.
    best_model_wts = deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()    # advance the learning-rate schedule once per epoch
                model.train(True)   # training mode
            else:
                model.train(False)  # evaluation mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over the data for this phase.
            for data in dataloders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs = Variable(inputs)
                    labels = Variable(labels)  # fixed: was typo "lables", leaving labels unwrapped on CPU
                optimizer.zero_grad()
                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                # backward + update only while training
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                running_loss += loss.data[0]  # pre-0.4 PyTorch API (loss.item() on newer versions)
                running_corrects += torch.sum(preds.data == labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Snapshot the best-performing weights on the validation set.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model
# Visualize model predictions on a few validation images.
def visualize_model(model, num_images=6):
    """Plot the first `num_images` validation images with their predicted labels."""
    shown = 0
    fig = plt.figure()
    for batch_idx, batch in enumerate(dataloders['val']):
        inputs, labels = batch
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        for j in range(inputs.size()[0]):
            shown += 1
            ax = plt.subplot(num_images // 2, 2, shown)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])
            if shown == num_images:
                return
# Finetune the whole convnet: load an ImageNet-pretrained resnet18,
# swap its final fully-connected layer for a 2-class one, and retrain.
from torchvision.models.resnet import model_urls
# Work around environments where the https download fails.
model_urls['resnet18'] = model_urls['resnet18'].replace('https://', 'http://')

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)  # ants vs. bees
if use_gpu:
    model_ft = model_ft.cuda()

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Start finetuning, then save the best weights.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
torch.save(model_ft.state_dict(), '/home/zf/resnet18.pth')  # NOTE(review): hard-coded absolute path
visualize_model(model_ft)
当然finetune的话有两种方式:在这个例子里
(1)只修改最后一层全连接层,输出类数改为2,然后在预训练模型上进行finetune;
(2)固定全连接层前面的卷积层参数,也就是它们不反向传播,只对最后一层进行反向传播;实现的时候前面这些层的requires_grad就设为False就OK了;
代码见下:
# Alternative: freeze every pretrained layer and train only the new classifier.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False  # frozen layers receive no gradients

# Parameters of newly constructed modules have requires_grad=True by default.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()
# Only the parameters of the final layer are optimized, as opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
可以说,从构建网络,到训练网络,再到测试,由于完全是python风格,实在是太方便了~
pytorch官网上两个例程的更多相关文章
- spring官网上下载历史版本的spring插件,springsource-tool-suite
spring官网下载地址(https://spring.io/tools/sts/all),历史版本地址(https://spring.io/tools/sts/legacy). 注:历史版本下载的都 ...
- jquery ui中的dialog,官网上经典的例子
jquery ui中的dialog,官网上经典的例子 jquery ui中dialog和easy ui中的dialog很像,但是最近用到的时候全然没有印象,一段时间不用就忘记了,这篇随笔介绍一下这 ...
- [pytorch] 官网教程+注释
pytorch官网教程+注释 Classifier import torch import torchvision import torchvision.transforms as transform ...
- iOS开发:创建推送开发证书和生产证书,以及往极光推送官网上传证书的步骤方法
在极光官网上面上传应用的极光推送证书的实质其实就是上传导出的p12文件,在极光推送应用管理里面,需要上传两个p12文件,一个是生产证书,一个是开发证书 ,缺一不可,具体如下所示: 在开发者账号里面创建 ...
- 自己封装的Windows7 64位旗舰版,微软官网上下载的Windows7原版镜像制作,绝对纯净版
MSDN官网上下载的Windows7 64位 旗舰版原版镜像制作,绝对纯净版,无任何精简,不捆绑任何第三方软件.浏览器插件,不含任何木马.病毒等. 集成: 1.Office2010 2.DirectX ...
- 关于在官网上查看和下载特定版本的webrtc代码
注:这个方法已经不适用了,帖子没删只是留个纪念而已 gclient:如果不知道gclient是什么东西 ... 就别再往下看了. 下载特定版本的代码: #gclient sync --revision ...
- echarts官网上的动态加载数据bug被我解决。咳咳/。
又是昨天,为什么昨天发生了这么多事.没办法,谁让我今天没事可做呢. 昨天需求是动态加载数据,画一个实时监控的折线图.大概长这样. 我屁颠屁颠的把代码copy过来,一运行,caocaocao~bug出现 ...
- 训练DCGAN(pytorch官网版本)
将pytorch官网的python代码当下来,然后下载好celeba数据集(百度网盘),在代码旁新建celeba文件夹,将解压后的img_align_celeba文件夹放进去,就可以运行代码了. 输出 ...
- Jenkins利用官网上的rpm源安装
官网网址:https://pkg.jenkins.io/redhat/ (官网上有安装的命令,参考网址) 安装jdk yum install -y java-1.8.0- ...
随机推荐
- 【BZOJ4815】[CQOI2017]小Q的表格(莫比乌斯反演,分块)
[BZOJ4815][CQOI2017]小Q的表格(莫比乌斯反演,分块) 题面 BZOJ 洛谷 题解 神仙题啊. 首先\(f(a,b)=f(b,a)\)告诉我们矩阵只要算一半就好了. 接下来是\(b* ...
- 【BZOJ2426】[HAOI2010]工厂选址(贪心)
[BZOJ2426][HAOI2010]工厂选址(贪心) 题面 BZOJ 洛谷 题解 首先看懂题目到底在做什么. 然而发现我们显然可以对于每个备选位置跑一遍费用流,然后并不够优秀. 不难发现所有的位置 ...
- [hgoi#2019/2/16t4]transform
题目描述 植物学家Dustar培养出了一棵神奇的树,这棵有根树有n个节点,每个节点上都有一个数字a[i],而且这棵树的根为r节点. 这棵树非常神奇,可以随意转换根的位置,上一秒钟它的根是x节点,下一秒 ...
- 【SDOI 2017】龙与地下城(组合)
概率论太难了,不会.但这不能阻止我们过题.相信大家都会一个基于背包的暴力做法,我们可以将其看成是卷积的形式就可以用fft优化了.形式化讲,就是求幂级数$ (\sum\limits_{i = 0}^{x ...
- python 线程,进程28原则
基于函数实现 from threading import Thread def fun(data, *args, **kwargs): """ :param data: ...
- ffmpeg 在ubuntu上编译环境搭建和开发
步骤如下: 1. 下载 官网永远是王道,呵呵:http://ffmpeg.org/download.html 或者 svn checkout svn://svn.mplayerhq.hu/ffmpeg ...
- 百度地图infoWindow圆角处理
最近的一个项目用到了百度地图API里边的infoWindow弹框,但是百度自带的infoWindow弹框是个直角的矩形框,显示过于难看,于是有了将该框改为圆角的想法,但是API本身不支持样式的设置,所 ...
- JAVA过滤器的使用(Filter)
request.setCharacterEncoding("utf-8"); response.setContentType("text/html;charset=utf ...
- 使用react封装评论组件
首先看我的效果图 我在评论框中输入数据,会在页面进行显示 这个效果图我们进行拆分就是,一个评论组件,一个大的评论列表组件,一个小的评论组件 首先整个页面中有的是我们的评论组件和列表组件 我们输入评论点 ...
- Qsort(c)_Sort(c++)用法
Sort函数(c) (来自codeblocks) stdlib.h _CRTIMP void __cdecl qsort(void*, size_t, size_t, int (*)(const vo ...