Smiling & Weeping

                ---- 下次你撑伞低头看水洼,

                  就会想起我说雨是神的烟花。

简介:主要是看刘二大人的视频讲解:https://www.bilibili.com/video/BV1Y7411d7Ys/?spm_id_from=333.337.search-card.all.click

题目及提交链接:Digit Recognizer | Kaggle

深度学习入门的学习项目,使用CNN(Convolutional Nerual Network)

对于Basic CNN的理解:

  1. 分成两个部分:前一个部分叫做Feature Extraction,后一部分叫做Classification(其中Feature Extraction又可以分为Convolution,Subsampling等)
  2. 其中要求卷积核的通道数量与输入通道数量一致。这种卷积核的总数和输出通道数目的总数一致(详见链接PDF)
  3. 卷积(convolution)后,C(channels),W(width),H(height),其中padding和pooling(小技巧:若要卷积W,H不变,取整kernel_size/2)
  4. 卷积层:保存图像的空间信息
  5. 卷积层要求输入输出是四维张量(B,C,W,H),全连接层的输入输出都是二维张量(B,Input_feature)
  6. 卷积(线性变换),激活函数(非线性变换),池化;这个过程若干次后,view打平,进入全连接层
  1 import torch
2 import torch.nn.functional as F
3 import torch.nn as nn
4 import torch.optim as optim
5 import torch.autograd as lr_scheduler
6 from torch.utils.data import DataLoader, Dataset
7 from torchvision import transforms
8 from torchvision.utils import make_grid
9 from torchvision import datasets
10 from torch.autograd import Variable
11 from sklearn.model_selection import train_test_split
12 import pandas as pd
13 import numpy as np
14 import matplotlib.pyplot as plt
15
16 batch_size = 64
17 transform = transforms.Compose([transforms.ToTensor()])
18 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
19 train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
20 #同样的方式加载一下测试集
21 test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
22 test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
23
24 # 使用卷积神经网络进行图像特征提取
25 # (batch, 1, 28, 28) -> (batch, 10, 24, 24) -> 池化 (batch, 10, 12, 12) -> (batch, 20, 8, 8) -> (batch, 20 , 4, 4) -> (batch, 320) -> (batch, 10)
26 class Net(torch.nn.Module):
27 def __init__(self):
28 super(Net, self).__init__()
29 self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5)
30 self.conv2 = torch.nn.Conv2d(10, 20, kernel_size = 5)
31 self.pooling = torch.nn.MaxPool2d(2)
32 self.fc = torch.nn.Linear(320, 10)
33
34 def forward(self, x):
35 # Flatten data from (n, 1, 28, 28) to (n, 784)
36 batch_size = x.size(0)
37 x = F.relu(self.pooling(self.conv1(x)))
38 x = F.relu(self.pooling(self.conv2(x)))
39 x = x.view(batch_size, -1) # Flatten
40 x = self.fc(x)
41 return x
42
43 model = Net()
44 # print(model)
45 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46 model.to(device)
47 criterion = torch.nn.CrossEntropyLoss(size_average=True)
48 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
49
50 def train(epoch):
51 running_loss = 0.0
52 for batch_idx, data in enumerate(train_loader, 0):
53 inputs, target = data
54 inputs, target = inputs.to(device), target.to(device)
55 optimizer.zero_grad()
56
57 # forward + backward + update
58 outputs = model(inputs)
59 # 计算真实值 和 测量值 之间的误差
60 loss = criterion(outputs, target)
61 loss.backward()
62 optimizer.step()
63
64 running_loss += loss.item()
65 if batch_idx % 300 == 299:
66 print('[%d, %5d] loss: %3f' % (epoch + 1, batch_idx+1, running_loss / 2000))
67 running_loss = 0.0
68
69 def test():
70 correct = 0
71 total = 0
72 with torch.no_grad():
73 for data in test_loader:
74 inputs, target = data
75 inputs, target = inputs.to(device), target.to(device)
76 outputs = model(inputs)
77 _, prediction = torch.max(outputs.data, dim=1)
78 total += target.size(0)
79 correct += (prediction == target).sum().item()
80 print('Accuracy on test set: %d %% [%d/%d]' % (100*correct / total, correct, total))
81 return correct/total
82
83 epoch_list = []
84 acc_list = []
85 for epoch in range(10):
86 train(epoch)
87 acc = test()
88 epoch_list.append(epoch)
89 acc_list.append(acc)
90
91 plt.plot(epoch_list, acc_list)
92 plt.ylabel("accuracy")
93 plt.xlabel("epoch")
94 plt.show()
95
96 class DatasetSubmissionMNIST(torch.utils.data.Dataset):
97 def __init__(self, file_path, transform=None):
98 self.data = pd.read_csv(file_path)
99 self.transform = transform
100
101 def __len__(self):
102 return len(self.data)
103
104 def __getitem__(self, index):
105 image = self.data.iloc[index].values.astype(np.uint8).reshape((28, 28, 1))
106
107
108 if self.transform is not None:
109 image = self.transform(image)
110
111 return image
112
113 transform = transforms.Compose([
114 transforms.ToPILImage(),
115 transforms.ToTensor(),
116 transforms.Normalize(mean=(0.5,), std=(0.5,))
117 ])
118
119 submissionset = DatasetSubmissionMNIST('/kaggle/input/digit-recognizer/test.csv', transform=transform)
120 submissionloader = torch.utils.data.DataLoader(submissionset, batch_size=batch_size, shuffle=False)
121
122 submission = [['ImageId', 'Label']]
123
124 with torch.no_grad():
125 model.eval()
126 image_id = 1
127
128 for images in submissionloader:
129 images = images.cuda()
130 log_ps = model(images)
131 ps = torch.exp(log_ps)
132 top_p, top_class = ps.topk(1, dim=1)
133
134 for prediction in top_class:
135 submission.append([image_id, prediction.item()])
136 image_id += 1
137
138 print(len(submission) - 1)
139 import csv
140
141 with open('submission.csv', 'w') as submissionFile:
142 writer = csv.writer(submissionFile)
143 writer.writerows(submission)
144
145 print('Submission Complete!')
146 # summission.to_csv('/kaggle/working/submission.csv', index=False)

就效果来说,也就一般,后面的Advance CNN 会有更高的效率和准确性,大家可以敲一下代码放在自己的编译器上跑一下

对了,这是GPU版本,若用CPU,把所有的device删除就可以,--<-<-<@

文章到此结束,我们下次再见

一束光线,可能会摔碎

                                 但仍旧光芒四射

CNN --入门MNIST识别的更多相关文章

  1. 使用tensorflow实现cnn进行mnist识别

    第一个CNN代码,暂时对于CNN的BP还不熟悉.但是通过这个代码对于tensorflow的运行机制有了初步的理解 ''' softmax classifier for mnist created on ...

  2. NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)

    用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...

  3. 用标准3层神经网络实现MNIST识别

    一.MINIST数据集下载 1.https://pjreddie.com/projects/mnist-in-csv/      此网站提供了mnist_train.csv和mnist_test.cs ...

  4. 利用CNN进行流量识别 本质上就是将流量视作一个图像

    from:https://netsec2018.files.wordpress.com/2017/12/e6b7b1e5baa6e5ada6e4b9a0e59ca8e7bd91e7bb9ce5ae89 ...

  5. 机器学习: Tensor Flow +CNN 做笑脸识别

    Tensor Flow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数 ...

  6. 机器学习: Tensor Flow with CNN 做表情识别

    我们利用 TensorFlow 构造 CNN 做表情识别,我们用的是FER-2013 这个数据库, 这个数据库一共有 35887 张人脸图像,这里只是做一个简单到仿真实验,为了计算方便,我们用其中到 ...

  7. 机器不学习:CNN入门讲解-为什么要有最后一层全连接

    哈哈哈,又到了讲段子的时间 准备好了吗? 今天要说的是CNN最后一层了,CNN入门就要讲完啦..... 先来一段官方的语言介绍全连接层(Fully Connected Layer) 全连接层常简称为 ...

  8. CNN入门讲解-为什么要有最后一层全连接?

    原文地址:https://baijiahao.baidu.com/s?id=1590121601889191549&wfr=spider&for=pc 今天要说的是CNN最后一层了,C ...

  9. halcon视觉入门钢珠识别

    halcon视觉入门钢珠识别 经过入门篇,我们有了基础的视觉识别知识.现在加以应用. 有如下图片: 我们需要识别图片中比较明亮的中间区域,有黑色的钢珠,我们需要知道他的位置和面积. 分析如何识别 编写 ...

  10. Halcon视觉入门芯片识别

    Halcon视觉入门芯片识别 需求 有如下图的一个摆盘,摆盘的方格中摆放芯片,一个格子中只放一个,我们需要知道每个方格中是否有芯片去指导我们将芯片放到空的方格中. 分析 通过图片分析得出 我们感兴趣的 ...

随机推荐

  1. OpenKruise v0.10.0 版本发布:新增应用弹性拓扑管理、应用防护等能力

    简介: 阿里云开源的云原生应用自动化管理套件.CNCF Sandbox 项目 -- OpenKruise,今天发布 v0.10.0 新版本,这也会是 OpenKruise v1.0 之前的最后一个 m ...

  2. WPF 自定义控件入门 Focusable 与焦点

    自定义控件时,如果自定义的控件需要用来接收键盘消息或者是输入法的输入内容,那就需要关注到控件的焦点 默认情况下的自定义控件是没有带可获取焦点的功能的,例如编写一个继承 FrameworkElement ...

  3. dotnet C# 如何正确获取藏文的字数

    在咱国内有很多有趣的文字,其中藏文属于有趣的文字里面特别有趣的一项,特别是对于做文本库的同学,大概都知道什么叫合写字吧.合写字的含义就是多个字符一起组成一个字.但是多个字符在内存中,本身就是多个字符对 ...

  4. 一:大数据架构回顾-Lambda架构

    "我们正在从IT时代走向DT时代(数据时代).IT和DT之间,不仅仅是技术的变革,更是思想意识的变革,IT主要是为自我服务,用来更好地自我控制和管理,DT则是激活生产力,让别人活得比你好&q ...

  5. PyTorch的安装与使用

    技术背景 PyTorch是一个非常常用的AI框架,主要归功于其简单易用的特点,深受广大科研人员的喜爱.在前面的一篇文章中我们介绍过制作PyTorch的Singularity镜像的方法,这里我们单独抽出 ...

  6. Pytorch param.grad.data. 出现 AttributeError: ‘NoneType‘ object has no attribute ‘data‘

    程序中有需要优化的参数未参与前向传播.

  7. linux服务器配置查看

    查看linux服务器配置 查硬盘信息 sblk 看sda sdb sdc之类的 以下可以看出是500G sda第一块,sdb是第二块 以下可以看出是 1T+100G 查内存 free -h 查cpu ...

  8. pageoffice 6 实现pdf加盖印章和签字功能

    PageOffice支持两种电子印章方案,可实现对Word.Excel.PDF文档加盖PageOffice自带印章或ZoomSeal电子印章(全方位保护.防篡改.防伪造).Word和Excel的盖章功 ...

  9. 【PB案例学习笔记】-02 目录浏览器

    写在前面 这是PB案例学习笔记系列文章的第二篇,该系列文章适合具有一定PB基础的读者, 通过一个个由浅入深的编程实战案例学习,提高编程技巧,以保证小伙伴们能应付公司的各种开发需求. 文章中设计到的源码 ...

  10. sass语法嵌套规则与注释讲解

    语法嵌套规则 选择器嵌套 例如有这么一段css,正常CSS的写法 .container{width:1200px; margin: 0 auto;} .container .header{height ...