在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,

net.py

import torch
from torch import nn class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim)) def forward(self, x):
x = self.layer_1(x)
x = self.layer_2(x)
x = self.output(x)
return x

main.py 进行网络的训练

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms import net batch_size = 128 # 每一个batch_size的大小
learning_rate = 1e-2 # 学习率的大小
num_epoches = 20 # 迭代的epoch值
# 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
# 获得训练集的数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# 获得测试集的数据
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
# 获得训练集的可迭代队列
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 获得测试集的可迭代队列
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 构造模型的网络
model = net.Batch_Net(28*28, 300, 100, 10)
if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
model.cuda() criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器 for epoch in range(num_epoches): # 迭代的epoch
train_loss = 0 # 训练的损失值
test_loss = 0 # 测试的损失值
eval_acc = 0 # 测试集的准确率
for data in train_loader: # 获得一个batch的样本
img, label = data # 获得图片和标签
img = img.view(img.size(0), -1) # 将图片进行img的转换
if torch.cuda.is_available(): # 如果存在torch
img = Variable(img).cuda() # 将图片放在torch上
label = Variable(label).cuda() # 将标签放在torch上
else:
img = Variable(img) # 构造img的变量
label = Variable(label)
optimizer.zero_grad() # 消除optimizer的梯度
out = model.forward(img) # 进行前向传播
loss = criterion(out, label) # 计算损失值
loss.backward() # 进行损失值的后向传播
optimizer.step() # 进行优化器的优化
train_loss += loss.data #
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
if torch.cuda.is_available():
img = Variable(img, volatile=True).cuda()
label = Variable(label, volatile=True).cuda()
else:
img = Variable(img, volatile=True)
label = Variable(label, volatile=True)
out = model.forward(img)
loss = criterion(out, label)
test_loss += loss.data
top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
eval_acc += accuracy
print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))

pytorch-mnist神经网络训练的更多相关文章

  1. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. PyTorch Tutorials 4 训练一个分类器

    %matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...

  4. Pytorch多GPU训练

    Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...

  5. 使用pytorch构建神经网络的流程以及一些问题

    使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...

  6. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  7. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  8. 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像

    摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...

  9. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  10. 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel

    一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...

随机推荐

  1. 2.XML语言

    XML语言 常见应用: XML技术除用于 /*保存有关系的数据*/之外,它还经常作软件配置文件,以描述程序模块之间的关系. 在一个系统软件中,为提高系统的灵活性,它所启动的模块通常由其配置文件决定 例 ...

  2. 【坑】Java中遍历递归删除List元素

    运行环境 idea 2017.1.1 需求背景 需要做一个后台,可以编辑资源列表用于权限管理 资源列表中可以有父子关系,假设根节点为0,以下以(父节点id,子节点id)表示 当编辑某个资源时,需要带出 ...

  3. Maven的下载及安装

    版权申明:本文为博主原创文章,欢迎大家转载.转载请声明转载处为:https://www.cnblogs.com/qxcxy-silence/p/10808321.html 1.下载Maven; 1). ...

  4. sed交换任意两行

    命令: sed -n 'A{h;n;B!{:a;N;C!ba;x;H;n};x;H;x};p' 文件名 解释: A.B分别是需要交换的行,C是B-1 其中,A.B.C可以是行号,也可以通过匹配模式,如 ...

  5. 实验楼Python项目

    整理几个实验楼小项目,有免费的也有会员的,会员的可以参考他们的实验报告. 直接去实验楼这个网站,粘贴上就能搜到. 免费专区: Kmeans聚类算法评估足球比赛 Python实现3D建模工具 K-近邻算 ...

  6. 遍历二叉树 - 基于栈的DFS

    之前已经学过二叉树的DFS的遍历算法[http://www.cnblogs.com/webor2006/p/7244499.html],当时是基于递归来实现的,这次利用栈不用递归也来实现DFS的遍历, ...

  7. elasticsearch联想加搜索实例

    //搜索框具体的ajax如下: <form class="form-wrapper cf"> <img src="__PUBLIC__/Home/img ...

  8. Spring入门篇——第2章 Spring IOC容器

    第2章 Spring IOC容器 介绍Spring IOC容器的基本概念和应用 2-1 IOC及Bean容器 自己的理解:什么是IOC?就是利用配置文件(外部容器)来创建对象. 在IOC容器中,所有对 ...

  9. 在Myeclipse中没有部署jeesite项目,但是每次运行其他项目时,还是会加载jeesite项目

    解决办法: 一.在以下路径中找到jeesite文件,并删除 1.Tomcat 7.0\conf\Catalina\localhost 2.Tomcat 7.0\webapps 3.Tomcat 7.0 ...

  10. kvm批量创建虚拟主机

    1.首先你的提前创建一个kvm虚拟机主机,才能批量复制创建 批量复制已经安装好的系统盘 `;.img centos7-$i.img && echo $i ;done 批量复制已经安装好 ...