Caffe felt too heavyweight to work with, so I recently switched to PyTorch, which is remarkably convenient and quick to pick up. Below are two small example programs from the official PyTorch tutorials, to get a feel for how it is used:

Example 1: use the nn module to build a network and implement a simple image classifier.

Link: http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

    # -*- coding:utf-8 -*-
    import torch
    from torch.autograd import Variable
    import torchvision
    import torchvision.transforms as transforms

    # Data preprocessing: convert to Tensor, normalize, set up the training and test
    # sets, and choose the number of worker subprocesses for loading
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # first tuple is the mean, second the standard deviation
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)  # num_workers=2: load data with two subprocesses
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    import matplotlib.pyplot as plt
    import numpy as np
    import pylab

    def imshow(img):
        img = img / 2 + 0.5          # undo the normalization
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        pylab.show()

    # Show a few training images with their labels
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    for i in range(4):
        p = plt.subplot()
        p.set_title("label: %5s" % classes[labels[i]])
        imshow(images[i])

    # Build the network
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)  # flatten the 16x5x5 feature maps from conv2 into a 400-dim vector for the fully connected layers
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    net = Net()
    net.cuda()

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Train the network
    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.data[0]
            if i % 2000 == 1999:
                print('[%d , %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    print('Finished Training')

    # Show a test batch with ground-truth labels and the network's predictions
    dataiter = iter(testloader)
    images, labels = dataiter.next()
    imshow(torchvision.utils.make_grid(images))
    print('GroundTruth:', ' '.join(classes[labels[j]] for j in range(4)))

    outputs = net(Variable(images.cuda()))

    _, predicted = torch.max(outputs.data, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

    # Overall accuracy on the test set
    correct = 0
    total = 0
    for data in testloader:
        images, labels = data
        outputs = net(Variable(images.cuda()))
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels.cuda()).sum()
        total += labels.size(0)
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

    # Per-class accuracy; the counters start from zero
    class_correct = torch.zeros(10).cuda()
    class_total = torch.zeros(10).cuda()
    for data in testloader:
        images, labels = data
        outputs = net(Variable(images.cuda()))
        _, predicted = torch.max(outputs.data, 1)
        c = (predicted == labels.cuda()).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i]
            class_total[label] += 1

    for i in range(10):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
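If you want to reuse the trained classifier later without retraining, it is enough to save the state dict and load it back into a fresh instance of the same Net class. A minimal sketch (the file name cifar_net.pth is just an assumption, not part of the original tutorial):

    # Sketch: save the trained weights, then restore them into a new Net instance.
    torch.save(net.state_dict(), 'cifar_net.pth')      # assumed file name

    net2 = Net()
    net2.load_state_dict(torch.load('cifar_net.pth'))
    net2.cuda()                                         # move back to the GPU before inference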

Example 2: finetune a pretrained resnet18 model to build a two-class ants-vs-bees classifier.

Link: http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

    # -*- coding:utf-8 -*-
    from __future__ import print_function, division
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.autograd import Variable
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import matplotlib.pyplot as plt
    import time
    import os
    import copy
    import pylab

    # Data preprocessing and augmentation
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    data_dir = 'hymenoptera_data'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    print(class_names)
    use_gpu = torch.cuda.is_available()

    # Show a few images
    def imshow(inp, title=None):
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean      # undo the normalization
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        pylab.show()
        plt.pause(0.001)

    inputs, classes = next(iter(dataloders['train']))
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes])

    # Train the model
    def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
        since = time.time()
        # state_dict() returns a dictionary holding the whole state of the module;
        # deepcopy it so the snapshot does not change as training continues
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            # each epoch has a training phase and a validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    scheduler.step()        # update the learning rate (step_size / gamma schedule)
                    model.train(True)       # set the model to training mode
                else:
                    model.train(False)      # set the model to evaluation mode

                running_loss = 0.0
                running_corrects = 0

                # iterate over the data
                for data in dataloders[phase]:
                    inputs, labels = data
                    if use_gpu:
                        inputs = Variable(inputs.cuda())
                        labels = Variable(labels.cuda())
                    else:
                        inputs = Variable(inputs)
                        labels = Variable(labels)
                    optimizer.zero_grad()
                    # forward
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()      # backpropagate the gradients
                        optimizer.step()     # update the parameters
                    running_loss += loss.data[0]
                    running_corrects += torch.sum(preds.data == labels.data)

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

                # keep the weights of the best validation epoch
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))
        model.load_state_dict(best_model_wts)
        return model

    # Visualize the model predictions
    def visualize_model(model, num_images=6):
        images_so_far = 0
        fig = plt.figure()

        for i, data in enumerate(dataloders['val']):
            inputs, labels = data
            if use_gpu:
                inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)

            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    return

    # Finetune the convnet
    from torchvision.models.resnet import model_urls
    model_urls['resnet18'] = model_urls['resnet18'].replace('https://', 'http://')  # workaround: fetch the pretrained weights over http
    model_ft = models.resnet18(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 2)     # replace the final fully connected layer with a 2-class output
    if use_gpu:
        model_ft = model_ft.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)  # decay LR by a factor of 0.1 every 7 epochs

    # Start finetuning
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
    torch.save(model_ft.state_dict(), '/home/zf/resnet18.pth')
    visualize_model(model_ft)
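Since only the state dict is saved above, loading it back later requires first rebuilding the same architecture, i.e. a resnet18 whose fc layer has been replaced by a 2-class output. A minimal sketch, reusing the path from the torch.save call above:

    # Sketch: rebuild the 2-class resnet18 and restore the finetuned weights.
    model_loaded = models.resnet18(pretrained=False)
    model_loaded.fc = nn.Linear(model_loaded.fc.in_features, 2)
    model_loaded.load_state_dict(torch.load('/home/zf/resnet18.pth'))
    model_loaded.train(False)    # evaluation mode for inference
    if use_gpu:
        model_loaded = model_loaded.cuda()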

There are, of course, two ways to finetune. In this example:

(1) Replace only the final fully connected layer so that it outputs 2 classes, then finetune the whole network starting from the pretrained weights;

(2) Freeze the convolutional layers in front of the fully connected layer so that they are not updated during backpropagation, and backpropagate only through the last layer. To implement this, simply set requires_grad to False on the parameters of the frozen layers.

The code for the second approach is below:

    model_conv = torchvision.models.resnet18(pretrained=True)
    for param in model_conv.parameters():
        param.requires_grad = False     # freeze the pretrained backbone

    # Parameters of newly constructed modules have requires_grad=True by default
    num_ftrs = model_conv.fc.in_features
    model_conv.fc = nn.Linear(num_ftrs, 2)

    if use_gpu:
        model_conv = model_conv.cuda()

    criterion = nn.CrossEntropyLoss()

    # Observe that only the parameters of the final layer are being optimized,
    # as opposed to before.
    optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
    model_conv = train_model(model_conv, criterion, optimizer_conv,
                             exp_lr_scheduler, num_epochs=25)
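To double-check that the backbone really is frozen and only the new fc layer will be updated, you can inspect requires_grad on the model's parameters; a small sketch (not part of the original tutorial):

    # Sketch: print the parameters that will still receive gradients.
    for name, param in model_conv.named_parameters():
        if param.requires_grad:
            print(name)    # expect only fc.weight and fc.bias

Anything not printed here receives no gradient, which is why passing only model_conv.fc.parameters() to the SGD optimizer above is sufficient.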

It is fair to say that from building the network, to training it, to testing it, everything is written in plain Python style, which makes the whole workflow extremely convenient.
