用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。

CNN_1:

runnig time:29.795 sec.
accuracy: 0.8688

CNN_2:

runnig time:165.101 sec.
accuracy: 0.8837

 import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt #import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test def cal_accuracy(model,x_test,y_test,samples=10000):
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
return acc class CNN_1(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(32*7*7,10) def forward(self,x):
x=self.conv1(x)
temp=x.view(x.shape[0],-1)
out=self.out(temp)
return out class CNN_2(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
16,#out_channels,即filter的数量
5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(16*7*7,10) def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.size(0),-1)
out=self.out(x)
return out def train_3():
num_epoch=5
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=CNN_2()
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time() acc_hist=[]
loss_hist=[]
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(torch.unsqueeze(x_data, dim=1))
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_val=loss.data.numpy()
if i==0:
acc=cal_acc(prediction,y_data)
acc_hist.append(acc)
loss_hist.append(loss_val)
print('No.%s,loss=%.3f'%(i+1,loss_val))
#loss_hist.append(loss_val)
#acc=cal_accuracy(model,x_test,y_test,samples=10000)
#acc_hist.append(acc)
print('acc=',acc) end=time.time()
print('runnig time:%.3f sec.'%(end-start))
acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
print('accuracy:',acc) if __name__=='__main__':
train_3()

Pytorch写CNN的更多相关文章

  1. ubuntu之路——day18 用pytorch完成CNN

    本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...

  4. pytorch写一个LeNet网络

    我们先介绍下pytorch中的cnn网络 学过深度卷积网络的应该都非常熟悉这张demo图(LeNet): 先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进 ...

  5. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  6. 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)

    参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...

  7. pytorch 8 CNN 卷积神经网络

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

  8. 奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练

    1.Torch构建简单的模型 # coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,i ...

  9. pytorch之 CNN

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

随机推荐

  1. postman(断言)

    一.断言 1.Code is 200 断言状态码是200 2.contains string 断言respoonse body中包含string 3.json value check (检查JSON值 ...

  2. requests抓取数据示例

    1:获取豆瓣电影名称及评分 # 抓取豆瓣电影名称及评分 url="https://movie.douban.com/j/search_subjects" start=input(& ...

  3. andorid jar/库源码解析之zxing

    目录:andorid jar/库源码解析 Zxing: 作用: 生成和识别,二维码,条形码. 栗子: 生成二维码,赋值到ImageView上 QRCodeWriter qrCodeWriter = n ...

  4. 2-JVM内存模型

    内存模型 方法区 JDK1.7 之前包含1.7 将方法区称为 Perm Space 永久代 JDK1.8之后包含1.8 将方法区称为 MetaSpace 元空间. 堆(分配内存会大一些) 分配对象.n ...

  5. P2201 数列编辑器

    传送门呀呀呀呀呀呀呀呀呀呀呀呀呀 \(乍一看题目好像很难\)(实际也确实很难) \(但是我们仔细看就发现,整个数列分成了光标前和光标后两组数列\) \(我们有什么理由不分开储存呢??\) \(然后光标 ...

  6. High Card Low Card G(田忌赛马进阶!!)

    传送门 \(首先一定要明确一个观点,不然会完全没有思路\) \(\bullet\)\(由于前半段大的更优,后半段小的更优.\) \(\bullet\)\(所以,\)Bessie\(一定会在前(n/2) ...

  7. K - Leapin' Lizards HDU - 2732 网络流

    题目链接:https://vjudge.net/contest/299467#problem/K 这个题目从数据范围来看可以发现是网络流,怎么建图呢?这个其实不是特别难,主要是读题难. 这个建图就是把 ...

  8. web scraper插件爬虫进阶(能满足非技术人员的爬虫需求,建议收藏!!!!)

    为了照顾更多的小伙伴,大家的学习能力及了解程度都不同,因此大家可以通过以下目录来有选择性的学习,节约大家的时间. 备注:  一定要实操!!!            一定要实操!!!           ...

  9. Crash-fix-2:org.springframework.http.converter.HttpMessageNotReadableException

    最近开始对APP上的Crash进行对应,发现有好多常见的问题,同一个问题在多个APP都类似的出现了,这里记录下这些常见的错误. crash Log: org.springframework.http. ...

  10. .Net Core3.0 WebApi 项目框架搭建 三:读取appsettings.json

    .Net Core3.0 WebApi 项目框架搭建:目录 appsettings.json 我们在写项目时往往会把一些经常变动的,可能会变动的参数写到配置文件.数据库中等可以存储数据且方便配置的地方 ...