Smiling & Weeping

                ---- 下次你撑伞低头看水洼,

                  就会想起我说雨是神的烟花。

简介:主要是看刘二大人的视频讲解:https://www.bilibili.com/video/BV1Y7411d7Ys/?spm_id_from=333.337.search-card.all.click

题目及提交链接:Digit Recognizer | Kaggle

深度学习入门的学习项目,使用CNN(Convolutional Nerual Network)

对于Basic CNN的理解:

  1. 分成两个部分:前一个部分叫做Feature Extraction,后一部分叫做Classification(其中Feature Extraction又可以分为Convolution,Subsampling等)
  2. 其中要求卷积核的通道数量与输入通道数量一致。这种卷积核的总数和输出通道数目的总数一致(详见链接PDF)
  3. 卷积(convolution)后,C(channels),W(width),H(height),其中padding和pooling(小技巧:若要卷积W,H不变,取整kernel_size/2)
  4. 卷积层:保存图像的空间信息
  5. 卷积层要求输入输出是四维张量(B,C,W,H),全连接层的输入输出都是二维张量(B,Input_feature)
  6. 卷积(线性变换),激活函数(非线性变换),池化;这个过程若干次后,view打平,进入全连接层
  1 import torch
2 import torch.nn.functional as F
3 import torch.nn as nn
4 import torch.optim as optim
5 import torch.autograd as lr_scheduler
6 from torch.utils.data import DataLoader, Dataset
7 from torchvision import transforms
8 from torchvision.utils import make_grid
9 from torchvision import datasets
10 from torch.autograd import Variable
11 from sklearn.model_selection import train_test_split
12 import pandas as pd
13 import numpy as np
14 import matplotlib.pyplot as plt
15
16 batch_size = 64
17 transform = transforms.Compose([transforms.ToTensor()])
18 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
19 train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
20 #同样的方式加载一下测试集
21 test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
22 test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
23
24 # 使用卷积神经网络进行图像特征提取
25 # (batch, 1, 28, 28) -> (batch, 10, 24, 24) -> 池化 (batch, 10, 12, 12) -> (batch, 20, 8, 8) -> (batch, 20 , 4, 4) -> (batch, 320) -> (batch, 10)
26 class Net(torch.nn.Module):
27 def __init__(self):
28 super(Net, self).__init__()
29 self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5)
30 self.conv2 = torch.nn.Conv2d(10, 20, kernel_size = 5)
31 self.pooling = torch.nn.MaxPool2d(2)
32 self.fc = torch.nn.Linear(320, 10)
33
34 def forward(self, x):
35 # Flatten data from (n, 1, 28, 28) to (n, 784)
36 batch_size = x.size(0)
37 x = F.relu(self.pooling(self.conv1(x)))
38 x = F.relu(self.pooling(self.conv2(x)))
39 x = x.view(batch_size, -1) # Flatten
40 x = self.fc(x)
41 return x
42
43 model = Net()
44 # print(model)
45 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46 model.to(device)
47 criterion = torch.nn.CrossEntropyLoss(size_average=True)
48 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
49
50 def train(epoch):
51 running_loss = 0.0
52 for batch_idx, data in enumerate(train_loader, 0):
53 inputs, target = data
54 inputs, target = inputs.to(device), target.to(device)
55 optimizer.zero_grad()
56
57 # forward + backward + update
58 outputs = model(inputs)
59 # 计算真实值 和 测量值 之间的误差
60 loss = criterion(outputs, target)
61 loss.backward()
62 optimizer.step()
63
64 running_loss += loss.item()
65 if batch_idx % 300 == 299:
66 print('[%d, %5d] loss: %3f' % (epoch + 1, batch_idx+1, running_loss / 2000))
67 running_loss = 0.0
68
69 def test():
70 correct = 0
71 total = 0
72 with torch.no_grad():
73 for data in test_loader:
74 inputs, target = data
75 inputs, target = inputs.to(device), target.to(device)
76 outputs = model(inputs)
77 _, prediction = torch.max(outputs.data, dim=1)
78 total += target.size(0)
79 correct += (prediction == target).sum().item()
80 print('Accuracy on test set: %d %% [%d/%d]' % (100*correct / total, correct, total))
81 return correct/total
82
83 epoch_list = []
84 acc_list = []
85 for epoch in range(10):
86 train(epoch)
87 acc = test()
88 epoch_list.append(epoch)
89 acc_list.append(acc)
90
91 plt.plot(epoch_list, acc_list)
92 plt.ylabel("accuracy")
93 plt.xlabel("epoch")
94 plt.show()
95
96 class DatasetSubmissionMNIST(torch.utils.data.Dataset):
97 def __init__(self, file_path, transform=None):
98 self.data = pd.read_csv(file_path)
99 self.transform = transform
100
101 def __len__(self):
102 return len(self.data)
103
104 def __getitem__(self, index):
105 image = self.data.iloc[index].values.astype(np.uint8).reshape((28, 28, 1))
106
107
108 if self.transform is not None:
109 image = self.transform(image)
110
111 return image
112
113 transform = transforms.Compose([
114 transforms.ToPILImage(),
115 transforms.ToTensor(),
116 transforms.Normalize(mean=(0.5,), std=(0.5,))
117 ])
118
119 submissionset = DatasetSubmissionMNIST('/kaggle/input/digit-recognizer/test.csv', transform=transform)
120 submissionloader = torch.utils.data.DataLoader(submissionset, batch_size=batch_size, shuffle=False)
121
122 submission = [['ImageId', 'Label']]
123
124 with torch.no_grad():
125 model.eval()
126 image_id = 1
127
128 for images in submissionloader:
129 images = images.cuda()
130 log_ps = model(images)
131 ps = torch.exp(log_ps)
132 top_p, top_class = ps.topk(1, dim=1)
133
134 for prediction in top_class:
135 submission.append([image_id, prediction.item()])
136 image_id += 1
137
138 print(len(submission) - 1)
139 import csv
140
141 with open('submission.csv', 'w') as submissionFile:
142 writer = csv.writer(submissionFile)
143 writer.writerows(submission)
144
145 print('Submission Complete!')
146 # summission.to_csv('/kaggle/working/submission.csv', index=False)

就效果来说,也就一般,后面的Advance CNN 会有更高的效率和准确性,大家可以敲一下代码放在自己的编译器上跑一下

对了,这是GPU版本,若用CPU,把所有的device删除就可以,--<-<-<@

文章到此结束,我们下次再见

一束光线,可能会摔碎

                                 但仍旧光芒四射

CNN --入门MNIST识别的更多相关文章

  1. 使用tensorflow实现cnn进行mnist识别

    第一个CNN代码,暂时对于CNN的BP还不熟悉.但是通过这个代码对于tensorflow的运行机制有了初步的理解 ''' softmax classifier for mnist created on ...

  2. NLP用CNN分类Mnist,提取出来的特征训练SVM及Keras的使用(demo)

    用CNN分类Mnist http://www.bubuko.com/infodetail-777299.html /DeepLearning Tutorials/keras_usage 提取出来的特征 ...

  3. 用标准3层神经网络实现MNIST识别

    一.MINIST数据集下载 1.https://pjreddie.com/projects/mnist-in-csv/      此网站提供了mnist_train.csv和mnist_test.cs ...

  4. 利用CNN进行流量识别 本质上就是将流量视作一个图像

    from:https://netsec2018.files.wordpress.com/2017/12/e6b7b1e5baa6e5ada6e4b9a0e59ca8e7bd91e7bb9ce5ae89 ...

  5. 机器学习: Tensor Flow +CNN 做笑脸识别

    Tensor Flow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数 ...

  6. 机器学习: Tensor Flow with CNN 做表情识别

    我们利用 TensorFlow 构造 CNN 做表情识别,我们用的是FER-2013 这个数据库, 这个数据库一共有 35887 张人脸图像,这里只是做一个简单到仿真实验,为了计算方便,我们用其中到 ...

  7. 机器不学习:CNN入门讲解-为什么要有最后一层全连接

    哈哈哈,又到了讲段子的时间 准备好了吗? 今天要说的是CNN最后一层了,CNN入门就要讲完啦..... 先来一段官方的语言介绍全连接层(Fully Connected Layer) 全连接层常简称为 ...

  8. CNN入门讲解-为什么要有最后一层全连接?

    原文地址:https://baijiahao.baidu.com/s?id=1590121601889191549&wfr=spider&for=pc 今天要说的是CNN最后一层了,C ...

  9. halcon视觉入门钢珠识别

    halcon视觉入门钢珠识别 经过入门篇,我们有了基础的视觉识别知识.现在加以应用. 有如下图片: 我们需要识别图片中比较明亮的中间区域,有黑色的钢珠,我们需要知道他的位置和面积. 分析如何识别 编写 ...

  10. Halcon视觉入门芯片识别

    Halcon视觉入门芯片识别 需求 有如下图的一个摆盘,摆盘的方格中摆放芯片,一个格子中只放一个,我们需要知道每个方格中是否有芯片去指导我们将芯片放到空的方格中. 分析 通过图片分析得出 我们感兴趣的 ...

随机推荐

  1. Dubbo-go v3.0 正式发布 ——打造国内一流开源 Go 服务框架

    ​简介:Dubbo-go 是常新的,每年都在不断进化.介绍 Dubbo-go 3.0 工作之前,先回顾其过往 6 年的发展历程,以明晰未来的方向. ​ 作者 | 李志信 来源 | 阿里技术公众号 作者 ...

  2. Dubbo 跨语言调用神兽:dubbo-go-pixiu

    简介: Pixiu 是基于 Dubbogo 的云原生.高性能.可扩展的微服务 API 网关.作为一款网关产品,Pixiu 帮助用户轻松创建.发布.维护.监控和保护任意规模的 API ,接受和处理成千上 ...

  3. [FAQ] 设置 npm 镜像源

    查看 npm 源: $ npm config get registry> http://registry.npmjs.org/ 修改 npm 源: $ npm config set regist ...

  4. 一分钟部署 Llama3 中文大模型,没别的,就是快

    前段时间百度创始人李彦宏信誓旦旦地说开源大模型会越来越落后,闭源模型会持续领先.随后小扎同学就给了他当头一棒,向他展示了什么叫做顶级开源大模型. 美国当地时间4月18日,Meta 在官网上发布了两款开 ...

  5. Vs2019在发布过程中遇到xxx-Web.config Connection String"参数不能为 Null 或 空 的错误

    原文地址:https://www.zhaimaojun.top/Note/5465234 如下图: 当使用的数据库更换或者修改后数据库字段会失效,当我们从webconfig中清除数据库字段后,依然会报 ...

  6. 启动docker某个image(镜像)的已经关闭的container(容器)

    1.创建一个后台运行 ubuntu 容器 root@haima-PC:/home/haima/Desktop# docker run -d --name ubuntu-lnmp ubuntu bf24 ...

  7. tomcat(3)- tomcat部署zrlog

    目录 1. Tomcat单独部署 2. nginx+tomcat部署 1. Tomcat单独部署 部署场景为: 客户端:192.168.20.1 tomcat:主机名:tomcat01,地址:192. ...

  8. P2421-荒岛野人Savage题解

    好久没写题解了啊 洛谷P2421 荒岛野人 题目大意:有一个有很多洞的岛上,住了\(n\)个野人,每个野人的初始位置为\(c[i]\),换洞的速度为\(p[i]\),寿命为\(l[i]\).要求求出洞 ...

  9. Intel HDSLB 高性能四层负载均衡器 — 快速入门和应用场景

    目录 目录 目录 前言与背景 传统 LB 技术的局限性 HDSLB 的特点和优势 HDSLB 的性能参数 基准性能数据 对标竞品 HDSLB 的应用场景 HDSLB 的发展前景 参考文档 前言与背景 ...

  10. 用 C 语言开发一门编程语言 — 字符串与文件加载

    目录 文章目录 目录 前文列表 字符串 读取字符串 注释 文件加载函数 命令行参数 打印函数 报错函数 源代码 前文列表 <用 C 语言开发一门编程语言 - 交互式解析器> <用 C ...