CNN --入门MNIST识别
Smiling & Weeping
---- 下次你撑伞低头看水洼,
就会想起我说雨是神的烟花。
简介:主要是看刘二大人的视频讲解:https://www.bilibili.com/video/BV1Y7411d7Ys/?spm_id_from=333.337.search-card.all.click
题目及提交链接:Digit Recognizer | Kaggle
深度学习入门的学习项目,使用CNN(Convolutional Nerual Network)
对于Basic CNN的理解:
- 分成两个部分:前一个部分叫做Feature Extraction,后一部分叫做Classification(其中Feature Extraction又可以分为Convolution,Subsampling等)
- 其中要求卷积核的通道数量与输入通道数量一致。这种卷积核的总数和输出通道数目的总数一致(详见链接PDF)
- 卷积(convolution)后,C(channels),W(width),H(height),其中padding和pooling(小技巧:若要卷积W,H不变,取整kernel_size/2)
- 卷积层:保存图像的空间信息
- 卷积层要求输入输出是四维张量(B,C,W,H),全连接层的输入输出都是二维张量(B,Input_feature)
- 卷积(线性变换),激活函数(非线性变换),池化;这个过程若干次后,view打平,进入全连接层
1 import torch
2 import torch.nn.functional as F
3 import torch.nn as nn
4 import torch.optim as optim
5 import torch.autograd as lr_scheduler
6 from torch.utils.data import DataLoader, Dataset
7 from torchvision import transforms
8 from torchvision.utils import make_grid
9 from torchvision import datasets
10 from torch.autograd import Variable
11 from sklearn.model_selection import train_test_split
12 import pandas as pd
13 import numpy as np
14 import matplotlib.pyplot as plt
15
16 batch_size = 64
17 transform = transforms.Compose([transforms.ToTensor()])
18 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
19 train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
20 #同样的方式加载一下测试集
21 test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
22 test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
23
24 # 使用卷积神经网络进行图像特征提取
25 # (batch, 1, 28, 28) -> (batch, 10, 24, 24) -> 池化 (batch, 10, 12, 12) -> (batch, 20, 8, 8) -> (batch, 20 , 4, 4) -> (batch, 320) -> (batch, 10)
26 class Net(torch.nn.Module):
27 def __init__(self):
28 super(Net, self).__init__()
29 self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5)
30 self.conv2 = torch.nn.Conv2d(10, 20, kernel_size = 5)
31 self.pooling = torch.nn.MaxPool2d(2)
32 self.fc = torch.nn.Linear(320, 10)
33
34 def forward(self, x):
35 # Flatten data from (n, 1, 28, 28) to (n, 784)
36 batch_size = x.size(0)
37 x = F.relu(self.pooling(self.conv1(x)))
38 x = F.relu(self.pooling(self.conv2(x)))
39 x = x.view(batch_size, -1) # Flatten
40 x = self.fc(x)
41 return x
42
43 model = Net()
44 # print(model)
45 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46 model.to(device)
47 criterion = torch.nn.CrossEntropyLoss(size_average=True)
48 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
49
50 def train(epoch):
51 running_loss = 0.0
52 for batch_idx, data in enumerate(train_loader, 0):
53 inputs, target = data
54 inputs, target = inputs.to(device), target.to(device)
55 optimizer.zero_grad()
56
57 # forward + backward + update
58 outputs = model(inputs)
59 # 计算真实值 和 测量值 之间的误差
60 loss = criterion(outputs, target)
61 loss.backward()
62 optimizer.step()
63
64 running_loss += loss.item()
65 if batch_idx % 300 == 299:
66 print('[%d, %5d] loss: %3f' % (epoch + 1, batch_idx+1, running_loss / 2000))
67 running_loss = 0.0
68
69 def test():
70 correct = 0
71 total = 0
72 with torch.no_grad():
73 for data in test_loader:
74 inputs, target = data
75 inputs, target = inputs.to(device), target.to(device)
76 outputs = model(inputs)
77 _, prediction = torch.max(outputs.data, dim=1)
78 total += target.size(0)
79 correct += (prediction == target).sum().item()
80 print('Accuracy on test set: %d %% [%d/%d]' % (100*correct / total, correct, total))
81 return correct/total
82
83 epoch_list = []
84 acc_list = []
85 for epoch in range(10):
86 train(epoch)
87 acc = test()
88 epoch_list.append(epoch)
89 acc_list.append(acc)
90
91 plt.plot(epoch_list, acc_list)
92 plt.ylabel("accuracy")
93 plt.xlabel("epoch")
94 plt.show()
95
96 class DatasetSubmissionMNIST(torch.utils.data.Dataset):
97 def __init__(self, file_path, transform=None):
98 self.data = pd.read_csv(file_path)
99 self.transform = transform
100
101 def __len__(self):
102 return len(self.data)
103
104 def __getitem__(self, index):
105 image = self.data.iloc[index].values.astype(np.uint8).reshape((28, 28, 1))
106
107
108 if self.transform is not None:
109 image = self.transform(image)
110
111 return image
112
113 transform = transforms.Compose([
114 transforms.ToPILImage(),
115 transforms.ToTensor(),
116 transforms.Normalize(mean=(0.5,), std=(0.5,))
117 ])
118
119 submissionset = DatasetSubmissionMNIST('/kaggle/input/digit-recognizer/test.csv', transform=transform)
120 submissionloader = torch.utils.data.DataLoader(submissionset, batch_size=batch_size, shuffle=False)
121
122 submission = [['ImageId', 'Label']]
123
124 with torch.no_grad():
125 model.eval()
126 image_id = 1
127
128 for images in submissionloader:
129 images = images.cuda()
130 log_ps = model(images)
131 ps = torch.exp(log_ps)
132 top_p, top_class = ps.topk(1, dim=1)
133
134 for prediction in top_class:
135 submission.append([image_id, prediction.item()])
136 image_id += 1
137
138 print(len(submission) - 1)
139 import csv
140
141 with open('submission.csv', 'w') as submissionFile:
142 writer = csv.writer(submissionFile)
143 writer.writerows(submission)
144
145 print('Submission Complete!')
146 # summission.to_csv('/kaggle/working/submission.csv', index=False)

就效果来说,也就一般,后面的Advance CNN 会有更高的效率和准确性,大家可以敲一下代码放在自己的编译器上跑一下
对了,这是GPU版本,若用CPU,把所有的device删除就可以,--<-<-<@
文章到此结束,我们下次再见
一束光线,可能会摔碎
但仍旧光芒四射
CNN --入门MNIST识别的更多相关文章
- 使用tensorflow实现cnn进行mnist识别
第一个CNN代码,暂时对于CNN的BP还不熟悉.但是通过这个代码对于tensorflow的运行机制有了初步的理解 ''' softmax classifier for mnist created on ...
- NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)
用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...
- 用标准3层神经网络实现MNIST识别
一.MINIST数据集下载 1.https://pjreddie.com/projects/mnist-in-csv/ 此网站提供了mnist_train.csv和mnist_test.cs ...
- 利用CNN进行流量识别 本质上就是将流量视作一个图像
from:https://netsec2018.files.wordpress.com/2017/12/e6b7b1e5baa6e5ada6e4b9a0e59ca8e7bd91e7bb9ce5ae89 ...
- 机器学习: Tensor Flow +CNN 做笑脸识别
Tensor Flow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数 ...
- 机器学习: Tensor Flow with CNN 做表情识别
我们利用 TensorFlow 构造 CNN 做表情识别,我们用的是FER-2013 这个数据库, 这个数据库一共有 35887 张人脸图像,这里只是做一个简单到仿真实验,为了计算方便,我们用其中到 ...
- 机器不学习:CNN入门讲解-为什么要有最后一层全连接
哈哈哈,又到了讲段子的时间 准备好了吗? 今天要说的是CNN最后一层了,CNN入门就要讲完啦..... 先来一段官方的语言介绍全连接层(Fully Connected Layer) 全连接层常简称为 ...
- CNN入门讲解-为什么要有最后一层全连接?
原文地址:https://baijiahao.baidu.com/s?id=1590121601889191549&wfr=spider&for=pc 今天要说的是CNN最后一层了,C ...
- halcon视觉入门钢珠识别
halcon视觉入门钢珠识别 经过入门篇,我们有了基础的视觉识别知识.现在加以应用. 有如下图片: 我们需要识别图片中比较明亮的中间区域,有黑色的钢珠,我们需要知道他的位置和面积. 分析如何识别 编写 ...
- Halcon视觉入门芯片识别
Halcon视觉入门芯片识别 需求 有如下图的一个摆盘,摆盘的方格中摆放芯片,一个格子中只放一个,我们需要知道每个方格中是否有芯片去指导我们将芯片放到空的方格中. 分析 通过图片分析得出 我们感兴趣的 ...
随机推荐
- 如何进行基于Anolis OS的企业级Java应用规模化实践?|龙蜥技术
简介:提供了7×24小时的专属钉钉或者电话支持,响应时间保证到在业务不可用情况下10分钟响应,业务一般的问题在一小时可以获得响应,主要城市可以两小时内得到到达现场的服务. 本文作者郁磊,是Java语 ...
- Dataphin功能:集成——如何将业务系统的数据抽取汇聚到数据中台
简介: 数据集成是简单高效的数据同步平台,致力于提供具有强大的数据预处理能力.丰富的异构数据源之间数据高速稳定的同步能力,为数据中台的建设打好坚实的数据基座. 数据中台是当下大数据领域最前沿的数据建 ...
- dotnet 6 已知问题 获取 CultureInfo.NumberFormat 可能抛出 IndexOutOfRangeException 异常
本文记录一个 dotnet 6 已知问题,准确来说这是一个在 dotnet 5 引入的问题,到 dotnet 6.0.12 还没修.在获取 CultureInfo.NumberFormat 属性时,在 ...
- WinForms 使用 Image 的 FromFile 方法加载文件和使用 Bitmap 有什么不同
本文来告诉大家使用 GDI+ 的 Image.FromFile 加载图片文件和使用创建 Bitmap 传入图片文件有什么不同 如使用下面代码加载图片 using var image = Image.F ...
- 云原生最佳实践系列 6:MSE 云原生网关使用 JWT 进行认证鉴权
01 方案概述 MSE 网关可以为后端服务提供转发路由能力,在此基础上,一些敏感的后端服务需要特定认证授权的用户才能够访问.MSE 云原生网关致力于提供给云上用户体系化的安全解决方案,其中 JWT 认 ...
- SpringBoot3.1.5对应新版本SpringCloud开发(1)-Eureka注册中心
服务的提供者和消费者 服务之间可以通过Spring提供的RestTemplate来进行http请求去请求另一个Springboot的项目,这就叫做服务间的远程调用. 当一个服务通过远程调用去调用另一个 ...
- Dash 2.17版本新特性介绍
本文示例代码已上传至我的Github仓库https://github.com/CNFeffery/dash-master 大家好我是费老师,不久前Dash发布了其2.17.0版本,执行下面的命令进行最 ...
- fastposter v2.18.0 一分钟完成开发海报-云服务来袭
fastposter v2.18.0 一分钟完成开发海报-云服务来袭 fastposter 是一款快速开发海报的工具,已经服务众多电商.行业海报.分销系统.电商海报.电商主图等海报生成和制作场景. 什 ...
- C语言:贮油点建设问题(详解题目意思)
!!!!先看解析,后面附有代码!!!!!!! ,希望大家不懂的能认真看看,这些都是我在写的过程中不能理解,遇到的困难,然后弄懂之后总结出来给大家的,想学的一定要认真看完. 规律是: 贮油点之间相差50 ...
- 智能勘探 | AIRIOT智慧油田管理解决方案
石油勘探和开采地处偏远地区,涉及面广且生产规模大.特殊的作业环境下,使得工作人员作业条件艰苦,仅靠人工值守难度很大,不可避免的遇到一系列硬核挑战: 1.设备维护难度较高: 2.采油厂分布地域广.分 ...