用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。

CNN_1:

runnig time:29.795 sec.
accuracy: 0.8688

CNN_2:

runnig time:165.101 sec.
accuracy: 0.8837

 import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt #import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test def cal_accuracy(model,x_test,y_test,samples=10000):
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
return acc class CNN_1(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(32*7*7,10) def forward(self,x):
x=self.conv1(x)
temp=x.view(x.shape[0],-1)
out=self.out(temp)
return out class CNN_2(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
16,#out_channels,即filter的数量
5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(16*7*7,10) def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.size(0),-1)
out=self.out(x)
return out def train_3():
num_epoch=5
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=CNN_2()
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time() acc_hist=[]
loss_hist=[]
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(torch.unsqueeze(x_data, dim=1))
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_val=loss.data.numpy()
if i==0:
acc=cal_acc(prediction,y_data)
acc_hist.append(acc)
loss_hist.append(loss_val)
print('No.%s,loss=%.3f'%(i+1,loss_val))
#loss_hist.append(loss_val)
#acc=cal_accuracy(model,x_test,y_test,samples=10000)
#acc_hist.append(acc)
print('acc=',acc) end=time.time()
print('runnig time:%.3f sec.'%(end-start))
acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
print('accuracy:',acc) if __name__=='__main__':
train_3()

Pytorch写CNN的更多相关文章

  1. ubuntu之路——day18 用pytorch完成CNN

    本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...

  4. pytorch写一个LeNet网络

    我们先介绍下pytorch中的cnn网络 学过深度卷积网络的应该都非常熟悉这张demo图(LeNet): 先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进 ...

  5. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  6. 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)

    参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...

  7. pytorch 8 CNN 卷积神经网络

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

  8. 奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练

    1.Torch构建简单的模型 # coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,i ...

  9. pytorch之 CNN

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

随机推荐

  1. muduo网络库源码学习————无界队列和有界队列

    muduo库里实现了两个队列模板类:无界队列为BlockingQueue.h,有界队列为BoundedBlockingQueue.h,两个测试程序实现了生产者和消费者模型.(这里以无界队列为例,有界队 ...

  2. muduo网络库源码学习————线程类

    muduo库里面的线程类是使用基于对象的编程思想,源码目录为muduo/base,如下所示: 线程类头文件: // Use of this source code is governed by a B ...

  3. libevent(九)evhttp

    用libevent构建一个http server非常方便,可参考libevent(六)http server. 主要涉及的一个结构体是evhttp: struct evhttp { /* Next v ...

  4. C. Fountains

    \(整体思路没错,但是我貌似太麻烦了.......\) \(分情况讨论\) \(Ⅰ.coin和diamond各选一个物品,这个简单\) \(Ⅱ.在coin中选两个或者在diamond选两个\) \(开 ...

  5. (三)Bean生命周期

    1 Bean注册 应用启动实质是调用Spring容器启动方法扫描配置加载bean到Spring容器中.同时启动内置的Web容器的过程,具体分析如下: @SpringBootApplication注解在 ...

  6. react中this.setState的理解

    this.setState作用? 在react中要修改this.state要使用this.setState,因为this.state只是一个对象,单纯的修改state并不会触发ui更新.所以我们需要用 ...

  7. leetcode_二叉树验证(BFS、哈希集合)

    题目描述: 二叉树上有 n 个节点,按从 0 到 n - 1 编号,其中节点 i 的两个子节点分别是 leftChild[i] 和 rightChild[i]. 只有 所有 节点能够形成且 只 形成 ...

  8. Qt读写xml文件

    写xml <root> <element> <sub id=-1></sub> </element> </root> //添加x ...

  9. Web快速输入标签

    在书写web代码的时候,掌握一些快捷输入方式不仅可以提高效率,还能省不少力气. 1. > :下一个子标签,如 div>p 加Tab达到: <div><p></ ...

  10. python 数据类型: 字符串String / 列表List / 元组Tuple / 集合Set / 字典Dictionary

    #python中标准数据类型 字符串String 列表List 元组Tuple 集合Set 字典Dictionary 铭记:变量无类型,对象有类型 #单个变量赋值 countn00 = '; #整数 ...