Pytorch写CNN
用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。
CNN_1:
runnig time:29.795 sec.
accuracy: 0.8688
CNN_2:
runnig time:165.101 sec.
accuracy: 0.8837
import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data
import matplotlib.pyplot as plt #import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test def cal_accuracy(model,x_test,y_test,samples=10000):
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
return acc class CNN_1(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(32*7*7,10) def forward(self,x):
x=self.conv1(x)
temp=x.view(x.shape[0],-1)
out=self.out(temp)
return out class CNN_2(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Sequential(
nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
32,#out_channels,即filter的数量
5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
16,#out_channels,即filter的数量
5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out=nn.Linear(16*7*7,10) def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.size(0),-1)
out=self.out(x)
return out def train_3():
num_epoch=5
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=CNN_2()
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time() acc_hist=[]
loss_hist=[]
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(torch.unsqueeze(x_data, dim=1))
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_val=loss.data.numpy()
if i==0:
acc=cal_acc(prediction,y_data)
acc_hist.append(acc)
loss_hist.append(loss_val)
print('No.%s,loss=%.3f'%(i+1,loss_val))
#loss_hist.append(loss_val)
#acc=cal_accuracy(model,x_test,y_test,samples=10000)
#acc_hist.append(acc)
print('acc=',acc) end=time.time()
print('runnig time:%.3f sec.'%(end-start))
acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
print('accuracy:',acc) if __name__=='__main__':
train_3()
Pytorch写CNN的更多相关文章
- ubuntu之路——day18 用pytorch完成CNN
本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- Pytorch和CNN图像分类
Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...
- pytorch写一个LeNet网络
我们先介绍下pytorch中的cnn网络 学过深度卷积网络的应该都非常熟悉这张demo图(LeNet): 先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)
参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...
- pytorch 8 CNN 卷积神经网络
# library # standard library import os # third-party library import torch import torch.nn as nn impo ...
- 奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练
1.Torch构建简单的模型 # coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,i ...
- pytorch之 CNN
# library # standard library import os # third-party library import torch import torch.nn as nn impo ...
随机推荐
- postman(断言)
一.断言 1.Code is 200 断言状态码是200 2.contains string 断言respoonse body中包含string 3.json value check (检查JSON值 ...
- requests抓取数据示例
1:获取豆瓣电影名称及评分 # 抓取豆瓣电影名称及评分 url="https://movie.douban.com/j/search_subjects" start=input(& ...
- andorid jar/库源码解析之zxing
目录:andorid jar/库源码解析 Zxing: 作用: 生成和识别,二维码,条形码. 栗子: 生成二维码,赋值到ImageView上 QRCodeWriter qrCodeWriter = n ...
- 2-JVM内存模型
内存模型 方法区 JDK1.7 之前包含1.7 将方法区称为 Perm Space 永久代 JDK1.8之后包含1.8 将方法区称为 MetaSpace 元空间. 堆(分配内存会大一些) 分配对象.n ...
- P2201 数列编辑器
传送门呀呀呀呀呀呀呀呀呀呀呀呀呀 \(乍一看题目好像很难\)(实际也确实很难) \(但是我们仔细看就发现,整个数列分成了光标前和光标后两组数列\) \(我们有什么理由不分开储存呢??\) \(然后光标 ...
- High Card Low Card G(田忌赛马进阶!!)
传送门 \(首先一定要明确一个观点,不然会完全没有思路\) \(\bullet\)\(由于前半段大的更优,后半段小的更优.\) \(\bullet\)\(所以,\)Bessie\(一定会在前(n/2) ...
- K - Leapin' Lizards HDU - 2732 网络流
题目链接:https://vjudge.net/contest/299467#problem/K 这个题目从数据范围来看可以发现是网络流,怎么建图呢?这个其实不是特别难,主要是读题难. 这个建图就是把 ...
- web scraper插件爬虫进阶(能满足非技术人员的爬虫需求,建议收藏!!!!)
为了照顾更多的小伙伴,大家的学习能力及了解程度都不同,因此大家可以通过以下目录来有选择性的学习,节约大家的时间. 备注: 一定要实操!!! 一定要实操!!! ...
- Crash-fix-2:org.springframework.http.converter.HttpMessageNotReadableException
最近开始对APP上的Crash进行对应,发现有好多常见的问题,同一个问题在多个APP都类似的出现了,这里记录下这些常见的错误. crash Log: org.springframework.http. ...
- .Net Core3.0 WebApi 项目框架搭建 三:读取appsettings.json
.Net Core3.0 WebApi 项目框架搭建:目录 appsettings.json 我们在写项目时往往会把一些经常变动的,可能会变动的参数写到配置文件.数据库中等可以存储数据且方便配置的地方 ...