代码:

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision # 数据库模块
import matplotlib.pyplot as plt torch.manual_seed() # reproducible
# Hyper Parameters
EPOCH = # 训练整批数据多少次, 为了节约时间, 我们只训练一次
BATCH_SIZE =
LR = 0.001 # 学习率
DOWNLOAD_MNIST = False # 如果你已经下载好了mnist数据就写上 False
# Mnist 手写数字
train_data = torchvision.datasets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=torchvision.transforms.ToTensor(), # 转换 PIL.Image or numpy.ndarray 成
# torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
download=DOWNLOAD_MNIST, # 没下载就下载, 下载了就不用再下了
)
#plot one example
# print(train_data.test_data.shape)#torch.Size([, , ])
# print(train_data.train_labels.shape)#torch.Size([])
# print(train_data.train_data[].shape)#torch.Size([, ])
#
# plt.imshow(train_data.train_data[],cmap='gray')
# plt.title('%d'%train_data.train_labels[])
# plt.show() #测试数据
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # print(test_data.test_data.shape)#torch.Size([, , ])
# 为了节约时间, 我们测试时只测试前2000个
test_x = torch.unsqueeze(test_data.test_data, dim=).type(torch.FloatTensor)[:] # /.shape from (, , ) to (, , , ), value in range(,)
test_y = test_data.test_labels[:] # 批训练 50samples, channel, 28x28 (, , , )
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(
in_channels=,
out_channels=,#n_filters
kernel_size=, # filter size
stride=, # filter movement/step
padding=, # 如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-)/ 当 stride=
),# output shape (, , )
nn.ReLU(),
nn.MaxPool2d(kernel_size=)# output shape (, , )
)
self.conv2=nn.Sequential(
nn.Conv2d(,,,,),# output shape (, , )
nn.ReLU(),
nn.MaxPool2d()# output shape (, , )
)
self.out=nn.Linear(**,)# fully connected layer, output classes
def forward(self, x):
x=self.conv1(x)
x=self.conv2(x)
#print(x.shape)#output:torch.Size([, , , ])
x = x.view(x.size(), -) # 展平多维的卷积图成 (batch_size, * * )
# print(x.shape)#output:torch.Size([, ])
output = self.out(x)
return output
cnn=CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
# training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # 分配 batch data, normalize x when iterate train_loader
print('step:',step)
output = cnn(b_x) # cnn output
loss = loss_func(output, b_y) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
test_output = cnn(test_x[:])
#test_x[:].shape=torch.Size([, , , ])
#test_output.shape=torch.Size([, ])
print('test_output:',test_output)
# test_output: tensor([[-1383.2828, -1148.1272, 311.1780, 153.0877, -3062.3340, -886.6730,
# -5819.7256, 3619.9558, -1544.4225, 193.6745],
# [ 282.6339, 647.2642, 3027.1570, -379.0817, -3403.5310, -2406.4951,
# -1117.4684, -4085.4429, -306.6578, -3844.1602],
# [-1329.7642, 1895.3890, -755.7719, -1378.9316, -314.2351, -1607.4249,
# -1026.8795, -428.1658, -385.1328, -1404.5205],
# [ 2991.5627, -3583.5374, -554.1349, -2472.6204, -1712.7700, -1092.7367,
# 148.9156, -1580.6696, -1126.8331, -477.7481],
# [-1818.9655, -1502.3574, -1620.6603, -2142.3472, 2529.0496, -2008.2731,
# -1585.5699, -786.7817, -1372.2627, 848.0875],
# [-1415.7609, 2248.9607, -909.5534, -1656.6108, -311.2874, -2255.2163,
# -1643.2495, -149.4040, -342.9626, -1372.8961],
# [-3766.0422, -484.8116, -1971.9016, -2483.8538, 1448.3118, -1048.7388,
# -2411.9790, -1089.5471, 422.1722, 249.8736],
# [-2933.3752, -877.4833, -671.7119, -573.4670, 63.9295, -497.9561,
# -2236.4597, -1218.2463, -296.5850, 1256.0739],
# [-2187.7292, -4899.0063, -2404.6597, -2595.0764, -2987.9624, 2052.1494,
# 335.9461, -2942.6995, 275.7964, -551.2797],
# [-1903.9233, -3449.5530, -1652.7020, -1087.9016, -515.1445, -1170.5551,
# -3734.2666, 628.9314, 69.0235, 2096.6257]],
# grad_fn=<AddmmBackward>)
print('test_output.shape:',test_output.shape)
# test_output.shape: torch.Size([, ]) pred_y = torch.max(test_output, )[].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:].numpy(), 'real number')

利用卷积神经网络实现MNIST手写数据识别的更多相关文章

  1. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  2. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  3. keras—神经网络CNN—MNIST手写数字识别

    from keras.datasets import mnist from keras.utils import np_utils from plot_image_1 import plot_imag ...

  4. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  5. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  6. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  7. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

随机推荐

  1. 2020/1/29 PHP代码审计之XSS漏洞

    0x00 XSS漏洞简介 人们经常将跨站脚本攻击(Cross Site Scripting)缩写为CSS,但这会与层叠样式表(Cascading Style Sheets,CSS)的缩写混淆.因此,有 ...

  2. ZJNU 2353 - UNO

    大模拟,但是题目好像有些地方表述不清 根据UNO在初中曾被别人虐了很久很久的经历 猜测出了原本的题意 本题中的+2虽然有颜色,但是也可以当作原UNO游戏中的+4黑牌 即在某人出了+2后,可以出不同颜色 ...

  3. CentOS6.x/6.5/6.4/6.3/6.2/7.x 64位安装php5.2(使用YUM自动安装)

    默认情况下,CentOS6 64 bit 已经早已不支持php5.2.x ,但是某些php程序还需要zend optimizer支持,怎么办呢?目前大部分的yum repos 都已经不支持直接安装ph ...

  4. 跨站脚本(XSS)攻击

    https://blog.csdn.net/extremebingo/article/details/81176394

  5. C - Monitor CodeForces - 846D (二维前缀和 + 二分)

    Recently Luba bought a monitor. Monitor is a rectangular matrix of size n × m. But then she started ...

  6. python3拆包详解

    对于可迭代对象,如元组.列表.字符串.集合.字典这些可迭代对象都可以被拆包,拆包是指将一个结构中的数据拆分为多个单独变量中.拆包的方式大致有两种,一种是以变量的方式来接收,另一种是用'*'号.下面先讲 ...

  7. ES6之展开运算符

    本文介绍ES6新增的展开运算符(spread operator). 由上图可得,展开运算符负责拼装数组和对象,与之相反,解构赋值负责分解数组和对象. 由上图可得,展开运算符能和解构赋值一起发挥成更大的 ...

  8. Android音视频处理之基于MediaCodec合并音视频

    Android提供了一个MediaExtractor类,可以用来分离容器中的视频track和音频track,下面的例子展示了使用MediaExtractor和MediaMuxer来实现视频的换音: p ...

  9. Python笔记_第一篇_面向过程_第一部分_6.语句的嵌套

    学完条件控制语句和循环控制语句后,在这里就会把“语言”的精妙之处进行讲解,也就是语句的嵌套.我们在看别人代码的时候总会对一些算法拍案叫绝,里面包含精妙和精密的逻辑分析.语句的嵌套也就是在循环体内可以嵌 ...

  10. ! [remote rejected] master -> master (pre-receive hook declined)

    前天准备上传一个project到GitLab上,但是试了很多次都上传不上去,报错如下: ! [remote rejected] master -> master (pre-receive hoo ...