一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下

1.预处理

  相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个,真的是良心啊,都分好类了,有需要的可以找我

  (1)图片大小,灰度,格式处理:虽然这里用不到,以后可能用到,所以还是写了

  (2)图片打标:个人想法,图片名称含有标签,训练检测的时候方便拿

代码

 from PIL import Image
import glob
import os def load_image():
"""
图片预处理
将图片大小强制处理为28x28
转换为png格式
"""
width = length = 28
train_path = 'D:/AI/MR_AIStudy/MNIST/dataset/train/*'
test_path = 'D:/AI/MR_AIStudy/MNIST/dataset/test/*'
img_path = glob.glob(test_path) # 图片读取路径
try:
for file in img_path:
path, ext = os.path.splitext(file)
# print(path, ext)
img = Image.open(file)
# out = img.resize((width, length), Image.ANTIALIAS)
out = img.convert('L')
file_name = '{}{}'.format(path, '.png')
print(file_name)
out.save(file_name, quality=100)
print('success')
# img = Image.open(file)
# out = img.resize((width, length), Image.ANTIALIAS)
# out = out.convert('L')
# file_name = '{}{}'.format(path, ext)
# out.save(file_name, quality=100)
except Exception as e:
print(e)
# 图片预处理,将图片缩放到30px30px
# img_path = glob.glob('D:/AI/MR_AIStudy/opencv4/images/*.png') # 图片读取路径
# for file in img_path:
# name = os.path.join(path_save, file)
# im = Image.open(file)
# im.thumbnail((30, 30))
# print(im.size)
# im.save(name, 'png')
# img = Image.open(file)
# data = img.getdata()
# data = np.matrix(data)
# data = np.reshape(data, (30, 30))
# print(data.size) def rename():
# 修改文件名称为 序号-标签.bmp (123-2.bmp) 另存到D:/AI/MR_AIStudy/MNIST/dataset/train目录下
for label in range(10):
print(label)
# path = 'D:/AI/MR_AIStudy/MNIST/dataset/trainimage/{}/*.bmp'.format(label)
path = 'D:/AI/MR_AIStudy/MNIST/dataset/testimage/{}/*.bmp'.format(label)
# path_save = 'D:/AI/MR_AIStudy/MNIST/dataset/train'
path_save = 'D:/AI/MR_AIStudy/MNIST/dataset/test'
print('path', path)
img_path = glob.glob(path)
try:
for index, file in enumerate(img_path):
# index用来区分相同标签不同图片
path, ext = os.path.splitext(file)
# print(path, ext)
img = Image.open(file)
out = img.convert('L')
file_name = '{}-{}{}'.format(index, label, ext) # 修改文件名称,将其打标
print(file_name)
# out.save(file_name, quality=100)
out.save(os.path.join(path_save, os.path.basename(file_name))) # 文件存到指定路径
# break
# print('success') except Exception as e:
print(e)
# break if __name__ == '__main__':
load_image()
# change_ext()
# rename()

2.卷积神经网络

  本来是有归一化,softmax,独热方法的,但是我加上后不好使(加上softmax后不收敛了),就手动实现了一下归一化和独热

代码

import torch
import torch.nn as nn
import torch.utils.data as Data
import glob
import os
import numpy as np
from PIL import Image
import datetime
from torchvision import transforms
import torch.nn.functional as F
# 6272=8x32x32 EPOCH = 1
BATCH_SIZE = 50 class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.con1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2),
nn.ReLU(),
)
self.con2 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=2),
nn.ReLU(),
)
self.fc = nn.Sequential(
# 线性分类器
nn.Linear(128*7*7, 128), # 修改图片大小后要重新计算
nn.ReLU(),
nn.Linear(128, 10),
# nn.Softmax(dim=1),
)
self.mls = nn.MSELoss()
self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3)
self.start = datetime.datetime.now() def forward(self, inputs):
out = self.con1(inputs)
out = self.con2(out)
out = out.view(out.size(0), -1) # 展开成一维
out = self.fc(out)
# out = F.log_softmax(out, dim=1)
return out def train(self, x, y):
out = self.forward(x)
loss = self.mls(out, y)
print('loss: ', loss)
self.opt.zero_grad()
loss.backward()
self.opt.step() def test(self, x):
out = self.forward(x)
return out class ParseImage(object):
def __init__(self):
self.transform1 = transforms.Compose([
transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 归一化
]
) def get_data(self, path):
# load_image()
# 将图片转为矩阵,标签进行独热编码
x_data = []
y_data = []
img_path = glob.glob(path) # 图片读取路径
for file in img_path:
one_hot = []
img = Image.open(file)
# img = self.transform1(img)
# img = transforms.ToPILImage()(img)
data = img.getdata()
data = np.matrix(data)
data = np.reshape(data, (28, 28))
# ..手动归一化
data = data/255
x_data.append(data)
name, ext = os.path.splitext(file)
label = name.split('-')[1]
print('label', label)
for i in range(10):
if str(i) == label:
one_hot.append(1)
else:
one_hot.append(0)
y_data.append(one_hot)
# 先转为数组,在转为tensor
x_data = np.array(x_data)
y_data = np.array(y_data)
x_data = torch.from_numpy(x_data).float()
# 输入数据增加频道维度
x_data = torch.unsqueeze(x_data, 1)
y_data = torch.from_numpy(y_data).float()
return x_data, y_data if __name__ == '__main__':
data = ParseImage()
train_path = 'D:/AI/MR_AIStudy/MNIST/dataset/train/*.png'
test_path = 'D:/AI/MR_AIStudy/MNIST/dataset/test/*.png'
x_data, y_data = data.get_data(train_path)
net = MyNet()
# 批训练
torch_dataset = Data.TensorDataset(x_data, y_data)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2,
)
for epoch in range(EPOCH):
for step, (batch_x, batch_y) in enumerate(loader):
print(step)
net.train(batch_x, batch_y) torch.save(net, 'net.pkl') # 存储模型, 全部存储 # 只测试的话加载模型即可
model = torch.load('net.pkl') # 恢复模型
net = model test_x, test_y = data.get_data(test_path)
predict = net.test(test_x)
print(predict)
end = datetime.datetime.now()
print('耗时:{}s'.format(end-net.start))
# 预测结果
# tensor([[ 9.1531e-01, -2.5804e-02, 1.2001e-02, 8.3876e-03, -1.6330e-02,
# -1.7501e-03, -1.0589e-02, 2.6951e-02, 2.1836e-02, -4.5546e-02],
# [-6.4733e-02, 7.7697e-01, 2.2536e-02, 8.3758e-03, 4.2895e-02,
# 1.1602e-02, -3.0644e-02, 2.2412e-02, 1.1579e-01, 3.2196e-02],
# [ 2.6631e-02, -5.3223e-02, 7.9808e-01, 6.0601e-03, 2.2453e-02,
# -3.9522e-02, 3.4775e-02, 1.5853e-02, -6.9575e-03, 1.7208e-02],
# [-1.3861e-02, -1.8332e-02, 4.9981e-02, 9.6510e-01, -1.5838e-02,
# 9.0347e-03, 1.9342e-02, -3.8044e-02, -5.7994e-03, 1.4480e-02],
# [-2.0864e-03, -5.9021e-02, 6.5524e-02, -2.1486e-02, 1.0074e+00,
# 9.3356e-03, 1.0758e-02, 6.6142e-02, 1.4841e-02, 2.2529e-03],
# [-8.4950e-02, -2.4841e-02, -7.7684e-02, 1.6404e-01, 4.3458e-02,
# 8.6580e-01, -3.5630e-02, 4.2452e-02, 7.0675e-02, 2.9663e-02],
# [-5.4024e-02, -1.7111e-02, -3.7085e-03, 3.8194e-03, -3.0645e-02,
# -4.4164e-02, 1.0109e+00, 4.4349e-03, 1.3218e-01, -2.2839e-02],
# [-2.0932e-02, 6.4831e-03, -1.3301e-02, 2.8091e-02, -3.0815e-02,
# -3.2140e-02, 5.2251e-03, 1.0215e+00, 3.2592e-02, 1.0505e-02],
# [ 1.5922e-02, -3.9700e-02, 2.4425e-02, -1.7313e-04, -1.5997e-02,
# -5.2336e-02, -7.7526e-04, -2.1901e-02, 9.7167e-01, 1.3339e-01],
# [-1.9283e-02, 2.4373e-02, -7.5621e-02, 1.1338e-01, -5.7805e-02,
# -5.2936e-03, 1.0090e-03, 2.2471e-02, -3.5736e-02, 1.1243e+00]],
# grad_fn=<AddmmBackward>)
# 耗时:0:09:59.665343s

预测结果不是很美观,但是正确的  欧耶!

pytorch CNN 手写数字识别的更多相关文章

  1. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  2. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  3. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  4. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  5. kaggle 实战 (2): CNN 手写数字识别

    文章目录 Tensorflow 官方示例 CNN 提交结果 Tensorflow 官方示例 import tensorflow as tf mnist = tf.keras.datasets.mnis ...

  6. keras框架的CNN手写数字识别MNIST

    参考:林大贵.TensorFlow+Keras深度学习人工智能实践应用[M].北京:清华大学出版社,2018. 首先在命令行中写入 activate tensorflow和jupyter notebo ...

  7. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  8. Task7.手写数字识别

    用PyTorch完成手写数字识别 import numpy as np import torch from torch import nn, optim import torch.nn.functio ...

  9. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

随机推荐

  1. 2018软工实践K班总结

    再回首一学期的软工实践,首先还是要感谢两位助教童鞋帮我承担了作业发布.打分以及与学生的问题沟通等.从这次的软工实践80人+开始,之后的实践课变为必修,故如何能更有效地组织大班实践环节是一个需要持续探讨 ...

  2. 软件工程(FZU2015) 赛季得分榜,第六回合

    SE_FZU目录:1 2 3 4 5 6 7 8 9 10 11 12 13 积分规则 积分制: 作业为10分制,练习为3分制:alpha30分: 团队项目分=团队得分+个人贡献分 个人贡献分: 个人 ...

  3. jmeter的jtl日志转html报告常见报错笔记

    问题:生成的jmeter文件可以放任意位置 输入命令转换hmtl报告 PS D:\user\80003288\桌面\Ques> jmeter -g .\test1.jtl -e -o .\rep ...

  4. flex实现三栏等分布局

    前言 在实际开发中,我们经常会想要实现的一种布局方式就是三栏等分布局,那么我们如何来解决这个问题呢? 解决 方法一:flex 外层容器也就是ul设置display:flex,对项目也就是li设置fle ...

  5. 用Canvas实现一些简单的图片滤镜

    1.灰度滤镜 对于灰度滤镜的实现一般有三种算法 1) 最大值法:即新的颜色值R=G=B=Max(R,G,B),通过这种方法处理后的图片看起来亮度值偏高. 2) 平均值法:即新的颜色值R=G=B=(R+ ...

  6. CentOS6.8 安装配置Mysql

    1.下载mysql的repo源 wget http://repo.mysql.com/mysql-community-release-el7-5.noarch.rpm 2.安装mysql-commun ...

  7. PHP人工智能库

    PHP虽然不是人工智能语言,但做人工智能理论上没问题,下面本人整理了一些PHP人工智能库.1.NLPTools(http://php-nlp-tools.com/)NLPTools是一个PHP自然语言 ...

  8. maven(win10)配置完环境变量后无法识别mvn -v命令

    第一步:http://maven.apache.org/download.cgi官网下载 第二步:把压缩包解压缩到不含中文和空格的目录下 第三步:新建MAVEN_HOME环境变量,值为maven解压缩 ...

  9. react 入坑笔记(四) - React 事件绑定和传参

    React 事件处理 建议:在了解 js 的 this 取值后食用更佳. 一.react 与 Html 中用法的异同和注意点 html 中的绑定事件的写法: <button onclick=&q ...

  10. js函数使用prototype和不适用prototype的区别

    js中类定义函数时用prototype与不用的区别 原创 2017年06月05日 12:25:41 标签: 函数 / prototype / class   首先来看一个实例: function Li ...