卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络充满好奇心,这里为你带来pytorch实现cnn一些入门的教程代码

#首先导入包

import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torch.utils.data as Data

#一、数据准备

#训练数据:用了torchvision.datasets.MNIST,root是文件路径,train为True(这是训练数据),transform是把图像数据转换为张量,download(如果本地已有该文件选择false,没有就选择true)

train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=False)

#训练数据:同上,train为False(这是测试数据)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)

# "训练数据加载器":dataset为训练数据,shuflle为打乱数据的顺序,batch_size是让数据50个为一组

train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=50)

test_data.test_data.size()

torch.Size([10000, 28, 28])

#测试数据 test_data下的test_data为测试数据,因为下面conv2d输入的为4维数据,所以此处用torch.unsqueeze升维

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)

#测试数据目标值

test_y = test_data.test_labels

#二、实现模型

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()

    #conv2d参数:输入1维,输出16维,5个卷积核(kernel),步长(stride)为1,padding是2(如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1)
    self.conv1 = nn.Sequential(nn.Conv2d(1,16,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
    self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))

    #Linear参数:输入维数,输出分的种类数
    self.out = nn.Linear(32*7*7,10)
  def forward(self,x):
    x1 = self.conv1(x)
    x2 = self.conv2(x1)

    #这里给x3降为2维可以让linear函数使用
    x3 = x2.view(x2.size(0),-1)
    out = self.out(x3)
    return out

#自动调整参数,最优化模型

cnn = CNN()

optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.02)
loss_func = nn.CrossEntropyLoss()

#三、训练模型

for step,(x,y) in enumerate(train_loader):
  x = Variable(x)
  y = Variable(y)
  out = cnn(x)
  loss = loss_func(out,y)

  #以下为固定操作,为了训练每一条数据,不断调整参数
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

#四、测试

predict = cnn(test_x[:10])
res = torch.max(predict,1)[1]

res #测试数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 9, 9])

test_y[:10] #真实数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

#在这里我们发现前十个数据分类准确率达到90

Pytorch卷积神经网络识别手写数字集的更多相关文章

  1. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  2. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  3. PyTorch基础——使用卷积神经网络识别手写数字

    一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...

  4. Python实现神经网络算法识别手写数字集

    最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...

  5. 使用TensorFlow的卷积神经网络识别手写数字(3)-识别篇

    from PIL import Image import numpy as np import tensorflow as tf import time bShowAccuracy = True # ...

  6. 使用TensorFlow的卷积神经网络识别手写数字(2)-训练篇

    import numpy as np import tensorflow as tf import matplotlib import matplotlib.pyplot as plt import ...

  7. 使用TensorFlow的卷积神经网络识别手写数字(1)-预处理篇

    功能: 将文件夹下的20*20像素黑白图片,根据重心位置绘制到28*28图片上,然后保存.经过预处理的图片有利于数字的准确识别.参见MNIST对图片的要求. 此处可下载已处理好的图片: https:/ ...

  8. 李宏毅 Keras手写数字集识别(优化篇)

    在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...

  9. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

随机推荐

  1. Jenkins+Gitee异常解决

    Failed to connect to repository : Command "git ls-remote -h username@mygit.com:cc/myproject.git ...

  2. selenium firefox 内存 速度优化

    selenium firefox 内存 速度优化 2 23 profile = webdriver.FirefoxProfile() 2 24 profile.set_preference(" ...

  3. django配置文件

    1.BASSE_DIR BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 当前工程的根目录,Django会依 ...

  4. jsonserver的安装及启动

    JsonServer 主要的作用就是搭建本地的数据接口,创建json文件,便于调试调用 是一个 Node 模块,运行 Express 服务器,可以指定一个 json 文件作为 api 的数据源 官网: ...

  5. Shell 冒泡排序

    举例 #!/bin/bash echo "please input a number list:" read -a arrs for((i=0;i<${#arrs[@]};i ...

  6. UISlider基本使用

    UISlider是一个很常用的UI控件,调节屏幕亮度或者调节音量大小等很多地方都可以用到,而且使用方便,下面我来介绍一下UISlider的基本使用. 首先介绍一下基本属性和常用方法: //设置当前sl ...

  7. 关于SQL中的 where 1 = 1 的用法

    在项目中的常见的一个操作:在有关SQL的代码中加入where 1 = 1,关于它的用法,可以总结如下: 首先,where 1 = 1的用法往往是为了方便后续的给SQL增加where限制条件.如果实现加 ...

  8. web模拟终端 --使用shellinabox

    关于shellinabox ShellInABox实现了一个Web服务器,可以将任意命令行工具导出到基于Web的终端仿真器.任何支持JavaScript和CSS的Web浏览器都可以访问此模拟器,并且不 ...

  9. Odoo CRM模块

    转载请注明原文地址:https://www.cnblogs.com/ygj0930/p/10825983.html  一:理解CRM CRM:客户关系管理,是指企业用CRM技术来管理与客户之间的关系. ...

  10. CodeForces - 1159B

    题目链接:https://vjudge.net/problem/CodeForces-1159B 题目意思:任选选两个元素,分别为a[i],a[j]. 问 都满足K*| i -  j | <= ...