用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。

CNN_1:

runnig time:29.795 sec.
accuracy: 0.8688

CNN_2:

runnig time:165.101 sec.
accuracy: 0.8837

 import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt #import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test def cal_accuracy(model,x_test,y_test,samples=10000):
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
return acc class CNN_1(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(32*7*7,10) def forward(self,x):
x=self.conv1(x)
temp=x.view(x.shape[0],-1)
out=self.out(temp)
return out class CNN_2(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
16,#out_channels,即filter的数量
5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(16*7*7,10) def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.size(0),-1)
out=self.out(x)
return out def train_3():
num_epoch=5
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=CNN_2()
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time() acc_hist=[]
loss_hist=[]
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(torch.unsqueeze(x_data, dim=1))
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_val=loss.data.numpy()
if i==0:
acc=cal_acc(prediction,y_data)
acc_hist.append(acc)
loss_hist.append(loss_val)
print('No.%s,loss=%.3f'%(i+1,loss_val))
#loss_hist.append(loss_val)
#acc=cal_accuracy(model,x_test,y_test,samples=10000)
#acc_hist.append(acc)
print('acc=',acc) end=time.time()
print('runnig time:%.3f sec.'%(end-start))
acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
print('accuracy:',acc) if __name__=='__main__':
train_3()

Pytorch写CNN的更多相关文章

  1. ubuntu之路——day18 用pytorch完成CNN

    本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...

  4. pytorch写一个LeNet网络

    我们先介绍下pytorch中的cnn网络 学过深度卷积网络的应该都非常熟悉这张demo图(LeNet): 先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进 ...

  5. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  6. 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)

    参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...

  7. pytorch 8 CNN 卷积神经网络

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

  8. 奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练

    1.Torch构建简单的模型 # coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,i ...

  9. pytorch之 CNN

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

随机推荐

  1. 2) 接口规范 原生django接口、单查群查 postman工具 CBV源码解析

    内容了解 """ .接口:什么是接口.restful接口规范 .CBV生命周期源码 - 基于restful规范下的CBV接口 .请求组件.解析组件.响应组件 .序列化组件 ...

  2. 经过踩坑,搭建成功的Appium自动化测试环境

    因为最近本人准备搞app自动化,所以就搭建环境过程记录下来(主要踩过好几个坑) 期间有点烦躁,后面调整了下心态还是成功弄好了. 一.Appium环境搭建准备软件 所需要到的软件如下: 1.安装JDK1 ...

  3. windows下flume 采集如何支持TAILDIR和tail

    一.问题:Windows 下 flume采集配置TAILDIR的时候,会报如下错误: agent.sources.seqGenSrc.type = TAILDIR agent.sources.seqG ...

  4. python统计英文文本中的回文单词数

    1. 要求: 给定一篇纯英文的文本,统计其中回文单词的比列,并输出其中的回文单词,文本数据如下: This is Everyday Grammar. I am Madam Lucija And I a ...

  5. 从卷积拆分和分组的角度看CNN模型的演化

    博客:博客园 | CSDN | blog 写在前面 如题,这篇文章将尝试从卷积拆分的角度看一看各种经典CNN backbone网络module是如何演进的,为了视角的统一,仅分析单条路径上的卷积形式. ...

  6. Python基础语法day_02——字符串规则

    day_02 使用方法修改字符串的大小写 将字符串首字母变成大写 >>> name = "ada lovelace" >>> print(nam ...

  7. 学习笔记:平衡树-splay

    嗯好的今天我们来谈谈cosplay splay是一种操作,是一种调整二叉排序树的操作,但是它并不会时时刻刻保持一个平衡,因为它会根据每一次操作把需要操作的点旋转到根节点上 所谓二叉排序树,就是满足对树 ...

  8. php基本语法学习

    1.基本的 PHP 语法 PHP 脚本可以放在文档中的任何位置. PHP 脚本以 <?php 开始,以 ?> 结束: <?php// PHP 代码?>   2.简单的脚本-输出 ...

  9. 【雕爷学编程】Arduino动手做(58)---SR04超声波传感器

    37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的.鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为 ...

  10. Django的ListView超详细用法(含分页paginate功能)

    开发环境: python 3.6 django 1.11 场景一 经常有从数据库中获取一批数据,然后在前端以列表的形式展现,比如:获取到所有的用户,然后在用户列表页面展示. 解决方案 常规写法是,我们 ...