Pytorch写CNN
用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。
CNN_1:
runnig time:29.795 sec.
accuracy: 0.8688
CNN_2:
runnig time:165.101 sec.
accuracy: 0.8837
import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt #import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test def cal_accuracy(model,x_test,y_test,samples=10000):
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
return acc class CNN_1(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(32*7*7,10) def forward(self,x):
x=self.conv1(x)
temp=x.view(x.shape[0],-1)
out=self.out(temp)
return out class CNN_2(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
16,#out_channels,即filter的数量
5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(16*7*7,10) def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.size(0),-1)
out=self.out(x)
return out def train_3():
num_epoch=5
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=CNN_2()
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time() acc_hist=[]
loss_hist=[]
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(torch.unsqueeze(x_data, dim=1))
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_val=loss.data.numpy()
if i==0:
acc=cal_acc(prediction,y_data)
acc_hist.append(acc)
loss_hist.append(loss_val)
print('No.%s,loss=%.3f'%(i+1,loss_val))
#loss_hist.append(loss_val)
#acc=cal_accuracy(model,x_test,y_test,samples=10000)
#acc_hist.append(acc)
print('acc=',acc) end=time.time()
print('runnig time:%.3f sec.'%(end-start))
acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
print('accuracy:',acc) if __name__=='__main__':
train_3()
Pytorch写CNN的更多相关文章
- ubuntu之路——day18 用pytorch完成CNN
本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- Pytorch和CNN图像分类
Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...
- pytorch写一个LeNet网络
我们先介绍下pytorch中的cnn网络 学过深度卷积网络的应该都非常熟悉这张demo图(LeNet): 先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)
参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...
- pytorch 8 CNN 卷积神经网络
# library # standard library import os # third-party library import torch import torch.nn as nn impo ...
- 奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练
1.Torch构建简单的模型 # coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,i ...
- pytorch之 CNN
# library # standard library import os # third-party library import torch import torch.nn as nn impo ...
随机推荐
- Jenkins联动码云自动匹配分支进行构建流水线
一.安装Generic Webhook Trigger插件 二.创建项目 创建项目之前先准备自己的项目,如果没有可以我fork的一个项目.地址是:https://gitee.com/jokerbai/ ...
- Spring Cloud Stream学习(五)入门
前言: 在了解完RabbitMQ后,再来学习SpringCloudStream就轻松很多了,SpringCloudStream现在主要支持两种消息中间件,一个是RabbitMQ,还有一个是KafK ...
- vue项目-打印页面中指定区域的内容(亲测有效!)
关于打印整个页面的,没什么好说的.今天我给大家分享一个打印指定区域的方法,你想打印哪里,就打印哪里! 我也是刚刚开始接触打印这一块功能的,然后当然是找度娘深入了解了一番啦,期间试了网上的各种方法,有的 ...
- [hdu5400 Arithmetic Sequence]预处理,容斥
题意:http://acm.hdu.edu.cn/showproblem.php?pid=5400 思路:预处理出每个点向左和向右的最远边界,从左向右枚举中间点,把区间答案加到总答案里面.由与可能与前 ...
- WIn7系统下配置Java环境变量
给个官网下载地址 :https://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html 1.首先 ...
- git 常用 指令累积
1.查询指定文件的修改所有修改日志git log --pretty=oneline 文件名 1. git log filename 可以看到fileName相关的commit记录2. git log ...
- 2018-06-19 js DOM对象
DOM对象: Doucument Object Model即文档对象 DOM对象的操作: 1.找元素 返回元素对象: var obj=document.getElementById();//通过Id查 ...
- Vi 和 Vim 的使用
Vi (Visual Interface)是 Linux下基于Shell 的文本编辑器,Vim (Visual Interface iMproved)是 Vi的增强版本,扩展了很多功能,比如对程序源文 ...
- 【雕爷学编程】MicroPython动手做(03)——零基础学MaixPy之开机测试
1.几个知识点(1)MicroPython 是 Python 3 语言的精简高效实现 ,包括Python标准库的一小部分,并针对嵌入式微控制器(单片机)和受限制的环境进行了优化,它是Python延伸出 ...
- 汇编语言 简单的Hello World
DATA SEGMENT STRING DB 'Hello World!','$' DATA ENDS CODE SEGMENT ASSUME CS:CODE, DS:DATA START: MOV ...