import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])
#transform.Compose() 将各种预处理操作组合在一起
#transform.ToTensor() 将数据转化为Tensor类型,并自动标准化,Tensor的取值是(0,1)
#transform.Normalize()是标准化操作,类似正太分布的标准化,第一个值是均值,第二个值是方差
#如果图像是三个通道,则transform.Normalize([a,b,c],[d,e,f])
train_dataset = datasets.MNIST(root = './mnist_data', train = True, transform = data_tf, download = True) #用datasets加载数据集,传入预处理
test_dataset = datasets.MNIST(root = './mnist_data', train = False,transform = data_tf)
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True) #利用DataLoader建立一个数据迭代器
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
class Batch_Net(nn.Module):
def __init__(self, inputdim, hidden1, hidden2, outputdim):
super(Batch_Net, self).__init__()
self.layer1 = nn.Sequential(nn.Linear(inputdim, hidden1), nn.BatchNorm1d(hidden1), nn.ReLU(True))
self.layer2 = nn.Sequential(nn.Linear(hidden1, hidden2), nn.BatchNorm1d(hidden2), nn.ReLU(True))
self.layer3 = nn.Sequential(nn.Linear(hidden2, outputdim)) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
model = Batch_Net(28*28, 300, 100, 10)
model

定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = learning_rate)

训练模型

for epoch in range(num_epoches):
train_loss = 0
train_acc = 0
model.train() #这句话会自动调整batch_normalize和dropout值,很关键!
for img, label in train_loader:
img = img.view(img.size(0), -1) #将数据扁平化为一维
img = Variable(img)
label = Variable(label)
# 前向传播
out = model(img)
loss = criterion(out, label)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 记录误差
train_loss += loss.item()
# 计算分类的准确率
_, pred = out.max(1)
num_correct = (pred == label).sum().item()
acc = num_correct / img.shape[0]
train_acc += acc print('epoch:{},train_loss:{:.6f},acc:{:.6f}'.format(epoch+1, train_loss/len(train_loader), train_acc/len(train_loader)))
epoch:1,train_loss:0.002079,acc:0.999767
......  
epoch:19,train_loss:0.001532,acc:0.999917
epoch:20,train_loss:0.001670,acc:0.999850

测试集

model.eval()  #在评估模型时使用,固定BN 和 Dropout
eval_loss = 0
val_acc = 0
for img , label in test_loader:
img = img.view(img.size(0), -1)
img = Variable(img, volatile = True) #volatile=TRUE表示前向传播是不会保留缓存,因为测试集不需要反向传播
label = Variable(label, volatile = True)
out = model(img)
loss = criterion(out, label)
eval_loss += loss.item()
_,pred = torch.max(out, 1)
num_correct = (pred == label).sum().item()
print(num_correct)
eval_acc = num_correct / label.shape[0]
val_acc += eval_acc print('Test Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss/len(test_loader), val_acc/len(test_loader)))
Test Loss:0.062413, Acc:0.981091

多层全连接神经网络实现minist手写数字分类的更多相关文章

  1. keras与卷积神经网络(CNN)实现识别minist手写数字

    在本篇博文当中,笔者采用了卷积神经网络来对手写数字进行识别,采用的神经网络的结构是:输入图片——卷积层——池化层——卷积层——池化层——卷积层——池化层——Flatten层——全连接层(64个神经元) ...

  2. Tensorflow 多层全连接神经网络

    本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...

  3. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  4. matlab手写神经网络实现识别手写数字

    实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络. 实验数据:500 ...

  5. MNIST手写数字分类simple版(03-2)

    simple版本nn模型 训练手写数字处理 MNIST_data数据   百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...

  6. Tensorflow-线性回归与手写数字分类

    线性回归 步骤 构造线性回归数据 定义输入层 设计神经网络中间层 定义神经网络输出层 计算二次代价函数,构建梯度下降 进行训练,获取预测值 画图展示 代码 import tensorflow as t ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  9. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

随机推荐

  1. 原生js:click和onclick本质的区别(转https://www.cnblogs.com/web1/p/6555662.html)

    原生javascript的click在w3c里边的阐述是DOM button对象,也是html DOM click() 方法,可模拟在按钮上的一次鼠标单击. button 对象代表 HTML 文档中的 ...

  2. 搜索专题:Balloons

    搜索专题:Balloons 这道题一看与时间有关,第一想到的就是BFS,定义一个状态,包含每一个状态的剩余气球数,已经进行的时间和每一个志愿者上一次吹气球的时间: 每一次状态转换时,检查是否有没有使用 ...

  3. aop设计原理(转)

    本文摘自 博文--<Spring设计思想>AOP设计基本原理 0.前言 Spring 提供了AOP(Aspect Oriented Programming) 的支持, 那么,什么是AOP呢 ...

  4. 一个Accecc_Token生成和缓存和读取类,微信/小程序开发必须学

    Access_Token是调用微信和小程序各种接口的临时凭证,有效期2小时(7200秒),很多接口都需要调用access_token接口生成一个access_token的,例如微信支付,微信分享,公众 ...

  5. day 01 常量 注释 int(整型) 用户交互input 流程控制语句if

    python的编程语言分类(重点) if 3 > 2: 编译型: 将代码一次性全部编译成二进制,然后再执行. 优点:执行效率高. 缺点:开发效率低,不能跨平台. 代表语言:C 解释型: 逐行解释 ...

  6. linux:RAID(磁盘阵列)笔记

    RAID磁盘阵列简述:     RAID0(条带): 把多个同样大小的磁盘串联起来当做一个磁盘来用.         优点:读写速度快.         缺点:数据容易丢失(没有容错能力).     ...

  7. PHP实现app唤起支付宝支付代码

    本文主要和大家分享PHP实现app唤起支付宝支付代码,希望能帮助到大家. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 ...

  8. mongoose 开源http库

    Mongoose是一个用C编写的网络库.它为客户端和服务器模式实现TCP,UDP,HTTP,WebSocket,CoAP,MQTT的事件驱动的非阻塞API. 设计理念: Mongoose有三个基本的数 ...

  9. nginx的简单介绍

    nginx简单介绍 Nginx的负载均衡策略可以分两大类:内置策略和扩展侧略: 内置策略包括:轮询,加权轮询,IP hash 扩展策略是:url hash ,fair nginx.conf文件结构 1 ...

  10. Codeforces 964 等比数列逆元处理 贪心删偶数度节点

    A B C 注意sum要在mod范围内 且不能用/a*b来推 #include<bits/stdc++.h> using namespace std; typedef long long ...