简单的深度神经网络实现——使用PyTorch
使用的数据集是MNIST,预期可以达到98%左右的准确率。
该神经网络由一个输入层,一个全连接层结构的隐含层和一个输出层构建。
1.配置库和配置参数
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable #配置参数
torch.manual_seed(1) #设置随机数种子,确保结果可重复
input_size=784
hidden_size=500
num_classes=10
num_epoches=5 #训练次数
batch_size=100 #批处理大小
learning_rate=0.001 #学习率
2.加载MNIST数据
#加载MNIST数据
train_dataset=dsets.MNIST(root='./data',#数据保持的位置
train=True,#训练集
transform=transforms.ToTensor(),
download=True)
#将一个取值范围是【0,255】的PIL.Image转化成取值范围是【0,1.0】的torch.FloatTensor
test_dataset=dsets.MNIST(root='./data',
train=False,
transform=transforms.ToTensor())
3.数据的批处理一
#数据的批处理
#Data Loader(Input Pipeline)
#数据的预处理,尺寸大小必须为batch_size,在训练集中,shuffle必须设置为True,表示次序是随机的
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
4.创建DNN模型
#创建DNN模型
#Neural Network Model(1 hidden layer)定义神经网络模型
class Net(nn.Module):
def __init__(self,input_size,hidden_size,num_classes):
super(Net,self).__init__()
self.fc1=nn.Linear(input_size,hidden_size)
self.relu=nn.ReLU()
self.fc2=nn.Linear(hidden_size,num_classes) def forward(self, x):
out=self.fc1(x)
out=self.relu(out)
out=self.fc2(out)
return out
net=Net(input_size,hidden_size,num_classes)
#打印模型,呈现网络结构
print(net)
5.训练流程
#训练流程
#Loss and Optimizer 定义loss和optimizer
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=learning_rate) #train the model 开始训练
for epoch in range(num_epoches):
for i,(images,labels) in enumerate(train_loader):#批处理
#convert torch tensor to Variable
images=Variable(images.view(-1,28*28))
labels=Variable(labels) #forward+backward+optimize
optimizer.zero_grad()#zero the gradient buffer梯度清零,以免影响其他batch
outputs=net(images)#前向传播
loss=criterion(outputs,labels)#loss
loss.backward()#后向传播,计算梯度
optimizer.step()#梯度更新 if(i+1)%100==0:
print('Epoch [%d/%d],Step[%d,%d],Loss:%.4f'%(epoch+1,num_epoches,i+1,len(train_dataset)//batch_size,loss.item()))
6.在测试集测试识别率
#Test the model,在测试集上验证模型
correct=0
total=0
for images,labels in test_loader:#test set批处理
images=Variable(images.view(-1,28*28))
outputs=net(images)
_,predicted=torch.max(outputs.data,1)#预测结果
total+=labels.size(0)#正确结果
correct+=(predicted==labels).sum()#正确结果总数
print('Accuracy of the network on thr 10000 test iamges:%d %%'%(100*correct/total))
简单的深度神经网络实现——使用PyTorch的更多相关文章
- C++从零实现简单深度神经网络(基于OpenCV)
代码地址如下:http://www.demodashi.com/demo/11138.html 一.准备工作 需要准备什么环境 需要安装有Visual Studio并且配置了OpenCV.能够使用Op ...
- 深度神经网络在量化交易里的应用 之二 -- 用深度网络(LSTM)预测5日收盘价格
距离上一篇文章,正好两个星期. 这边文章9月15日 16:30 开始写. 可能几个小时后就写完了.用一句粗俗的话说, "当你怀孕的时候,别人都知道你怀孕了, 但不知道你被日了多少回 ...
- Keras入门(一)搭建深度神经网络(DNN)解决多分类问题
Keras介绍 Keras是一个开源的高层神经网络API,由纯Python编写而成,其后端可以基于Tensorflow.Theano.MXNet以及CNTK.Keras 为支持快速实验而生,能够把 ...
- 深度神经网络DNN的多GPU数据并行框架 及其在语音识别的应用
深度神经网络(Deep Neural Networks, 简称DNN)是近年来机器学习领域中的研究热点,产生了广泛的应用.DNN具有深层结构.数千万参数需要学习,导致训练非常耗时.GPU有强大的计算能 ...
- 如何用70行Java代码实现深度神经网络算法
http://www.tuicool.com/articles/MfYjQfV 如何用70行Java代码实现深度神经网络算法 时间 2016-02-18 10:46:17 ITeye 原文 htt ...
- 深度神经网络(DNN)模型与前向传播算法
深度神经网络(Deep Neural Networks, 以下简称DNN)是深度学习的基础,而要理解DNN,首先我们要理解DNN模型,下面我们就对DNN的模型与前向传播算法做一个总结. 1. 从感知机 ...
- 深度神经网络(DNN)反向传播算法(BP)
在深度神经网络(DNN)模型与前向传播算法中,我们对DNN的模型和前向传播算法做了总结,这里我们更进一步,对DNN的反向传播算法(Back Propagation,BP)做一个总结. 1. DNN反向 ...
- 深度神经网络(DNN)损失函数和激活函数的选择
在深度神经网络(DNN)反向传播算法(BP)中,我们对DNN的前向反向传播算法的使用做了总结.里面使用的损失函数是均方差,而激活函数是Sigmoid.实际上DNN可以使用的损失函数和激活函数不少.这些 ...
- 深度神经网络(DNN)的正则化
和普通的机器学习算法一样,DNN也会遇到过拟合的问题,需要考虑泛化,这里我们就对DNN的正则化方法做一个总结. 1. DNN的L1&L2正则化 想到正则化,我们首先想到的就是L1正则化和L2正 ...
随机推荐
- C++学习四 冒泡排序法的一些改进
冒泡排序法需要两次扫描,所以从时间复杂度来说,是O(n2). 如果用图形表示,是这样的: 但是我们可以加以改进. 首先是,如果在排序中间,整个向量已经达到了有序状态,可以直接跳出来. 这样它的复杂度由 ...
- 配置Ngnix1.15.11+php5.4出现502 Bad Gateway问题
今天在调试Ngnix1.15.11+php5.4网站时候,因为网站数据和并发过大,出现502 Bad Gateway问题,所以记下笔记. 只需要修改php-fpm.conf的request_termi ...
- strcpy&memcpy&memmove
strcpy extern char *strcpy(char *dest,char *source); { assert((dest!=NULL)&&(source!=NULL)); ...
- C++ class 中的 const 成员函数
const 修饰的成员函数 表示 不会修改class中的成员变量. const 和 非-const 的成员函数同事存在时, 用户定义 const 类对象,调用 const 成员函数: 定义 非-c ...
- 大宗商品交易与风险管理(CTRM)软件产品介绍
https://mp.weixin.qq.com/s/grA8MhryPfDB2PmBqsao4Q 从全球范围来看,大宗商品行业风险管理领域的主流软件产品是CTRM系列.CTRM是Commodity ...
- luoguP3306 [SDOI2013]随机数生成器
题意 将\(x_1,x_2,x_3...x_n\)写出来可以发现通项为\(a^{i-1}*x_1+b*\sum\limits_{j=0}^{i-2}a^j=a^{i-1}*x_1+b*\frac{1- ...
- 大话设计模式Python实现-工厂方法模式
工厂方法模式(Factory Method Pattern):定义一个用于创建对象的接口,让子类决定实例化哪一个类,工厂方法使一个类的实例化延时到其子类. #!/usr/bin/env python ...
- TensorFlow函数: tf.stop_gradient
停止梯度计算. 在图形中执行时,此操作按原样输出其输入张量. 在构建计算梯度的操作时,这个操作会阻止将其输入的共享考虑在内.通常情况下,梯度生成器将操作添加到图形中,通过递归查找有助于其计算的输入来计 ...
- jdk-8u151-nb-8_2-windows-x64软件安装教程及环境配置
1.双击jdk-8u151-windows-x64.exe文件 2.进入安装向导 3.配置环境变量 (1)计算机→属性→高级系统设置→高级→环境变量 (2)系统变量→新建 JAVA_HOME 变量 . ...
- EF操作与Linq写法记录
项目总结:EF操作与Linq写法记录 1.EF引入 新建一个MVC项目之后,要引用EF框架,可以按照以下步骤进行: 1),在Models中添加项目 2),选择Entity Data Model,并重新 ...