『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
# Author : Hellcat
# Time : 2018/2/11 import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x if __name__ == "__main__":
net = LeNet() # #########训练网络#########
from torch import optim
# 初始化Loss函数 & 优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for epoch in range(2):
running_loss = 0.0
for step, data in enumerate(trainloader, 0): # step为训练次数, trainloader包含batch的数据和标签
inputs, labels = data
inputs, labels = t.autograd.Variable(inputs), t.autograd.Variable(labels) # 梯度清零
optimizer.zero_grad() # forward
outputs = net(inputs)
# backward
loss = loss_fn(outputs, labels)
loss.backward()
# update
optimizer.step() running_loss += loss.data[0]
if step % 2000 == 1999:
print("[{0:d}, {1:5d}] loss: {2:3f}".format(epoch+1, step+1, running_loss/2000))
running_loss = 0.
print("Finished Training")
这是使用LeNet分类cifar_10的例子,数据处理部分由于不是重点,没有列上来,主要是对使用torch分类有一个直观理解,
初始化网络
初始化Loss函数 & 优化器
进入step循环:
梯度清零
向前传播
计算本次Loss
向后传播
更新参数
由于pytorch的网络是class,所以在不考虑持久化的情况下,后续处理都不是太难,值得一提的是预测函数,我们直接net(Variable(test_data))即可,输出是概率分布的Variable,我们只要调用:
_, predict = t.max(test_out, 1)
即可,这是因为当指定了dim时,torch.max会融合max和argmax的功能,
>> a = torch.randn(4, 4)
>> a
0.0692 0.3142 1.2513 -0.5428
0.9288 0.8552 -0.2073 0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666 0.4862 -0.6628
torch.FloatTensor of size 4x4]
>>> torch.max(a, 1)
(
1.2513
0.9288
1.0695
0.7426
[torch.FloatTensor of size 4]
,
2
0
0
0
[torch.LongTensor of size 4]
)
其他torch的高级功能没有使用到,本篇的目的是对于torch神经网络基本的使用有个理解。
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下的更多相关文章
- 『MXNet』第四弹_Gluon自定义层
一.不含参数层 通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里, from mxnet import nd, gluon from ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『TensorFlow』第七弹_保存&载入会话_霸王回马
首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...
- 关于『进击的Markdown』:第四弹
关于『进击的Markdown』:第四弹 建议缩放90%食用 美人鱼(Mermaid)悄悄的来,又悄悄的走,挥一挥匕首,不留一个活口 又是漫漫画图路... 女士们先生们,大家好! 我们要接受Markd ...
- 关于『HTML』:第三弹
关于『HTML』:第三弹 建议缩放90%食用 盼望着, 盼望着, 第三弹来了, HTML基础系列完结了!! 一切都像刚睡醒的样子(包括我), 欣欣然张开了眼(我没有) 敬请期待Markdown语法系列 ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数
一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...
随机推荐
- 精力管理 | 迅速恢复精力的N个技巧,四个关键词以及自我管理的方法和工具列表
精力管理 | 迅速恢复精力的N个技巧,所谓坚持,是坚定的“持有”,这个“持”字很值得琢磨——不是扛.不是顶,而是“持”这样一个半放松的状态.如果你没做好自己该做的事情,如果你自己没有成长起来,随着年龄 ...
- Js基础知识6-JavaScript匿名函数和闭包
匿名函数 1,把匿名函数赋值给变量 var test = function() { return 'guoyu'; }; alert(test);//test是个函数 alert(test()); 2 ...
- 20145106《Java程序设计》第7周学习总结
教材学习内容总结 使用Lambda的特性可以去除重复的信息,以取得语法的简洁,增加程序代码的表达性.Lambda表达式本身是中性的,不代表任何类型的实例,同样的Lambda表达式,可用来表示不同目标类 ...
- uva 1658 Admiral - 费用流
vjudge传送门[here] 题目大意:给一个有(3≤v≤1000)个点e(3≤e≤10000)条边的有向加权图,求1~v的两条不相交(除了起点和终点外没有公共点)的路径,使权值和最小. 正解是吧2 ...
- 设置控件如ImageButton可见与否
继承view的控件有三种ui属性: 1.setVisibility(View.Gone); 不可见,不占有空间 2.setVisibility(View.VISIBLE); 可见 3.setVisib ...
- Linux驱动模块的Makefile分析【转】
本文转载自:http://blog.chinaunix.net/uid-29307109-id-3993784.html 1. 获取内核版本 当设备驱动需要同时支持不同版本内核时,在编译阶段,内核模块 ...
- linux网络编程--网络编程的基本函数介绍与使用【转】
本文转载自:http://blog.csdn.net/yusiguyuan/article/details/17538499 我们深谙信息交流的价值,那网络中进程之间如何通信,如我们每天打开浏览器浏览 ...
- 2016年蓝桥杯B组C/C++省赛(预选赛)试题
2016年蓝桥杯B组C/C++ 点击查看2016年蓝桥杯B组省赛题目解析(答案) 第一题:煤球数目 有一堆煤球,堆成三角棱锥形.具体: 第一层放1个, 第二层3个(排列成三角形), 第三层6个(排列成 ...
- HDU 2204 Eddy's爱好(容斥原理dfs写法)题解
题意:定义如果一个数能表示为M^k,那么这个数是好数,问你1~n有几个好数. 思路:如果k是合数,显然会有重复,比如a^(b*c) == (a^b)^c,那么我们打个素数表,指数只枚举素数,2^60 ...
- 《算法竞赛入门经典》习题及反思 -<2>
数组 Master-Mind Hints,Uva 340 题目:给定答案序列和用户猜的序列,统计有多少数字对应正确(A),有多少数字在两个序列都出现过但位置不对. 输入包括多组数据.每组输入第一行为序 ...