Pytorch分类和准确性评估--基于FashionMNIST数据集
最近在学习Pytorch v1.3最新版和Tensorflow2.0。
我学习Pytorch的主要途径:莫烦Python和Pytorch 1.3官方文档 ,Pytorch v1.3跟之前的Pytorch不太一样,比如1.3中,Variable类已经被弃用了(目前还可以用,但不推荐),tensor可以直接调用backward方法进行反向求导,不需要再像之前的版本一样必须包装成Variable对象之后再backward。
Tensorflow2.0的学习可以参考北大学生写的教程:https://tf.wiki/zh/basic/basic.html ,TensorFlow2.0与之前的版本也有很大不同,TF 1.x的很多写法已经不适用了,2.0把大量keras的内容包括了进去,使用之前的TF方便,但我总感觉混在一起,那还不如直接学Keras,另外跟Pytorch相比,为了实现相同的功能,TF2.0的代码还是太多了,不够简洁。
为了对比两者的速度,今天自己第一次尝试用Pytorch实现了用于图片分类的最简单的全连接神经网络。代码包括了神经网络的定义、使用DataLoader批训练、效果的准确性评估,模型使用方法、输出转换为label型等内容。
import time
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as Data '''数据集为FashionMNIST'''
data=FashionMNIST('../pycharm_workspace/data/') def train_test_split(data,test_pct=0.3):
test_len=int(data.data.size(0)*test_pct)
x_test=data.data[0:test_len].type(torch.float)
x_train=data.data[test_len:].type(torch.float) y_test=data.targets[0:test_len]
y_train=data.targets[test_len:] return x_train,y_train,x_test,y_test '''自定义神经网络1'''
class MLP(nn.Module):
def __init__(self,input_size,hidden_size,output_size):
super().__init__()
self.linear1=nn.Linear(input_size,hidden_size)
self.linear2=nn.Linear(hidden_size,output_size) def forward(self,x):
out=self.linear1(x)
out=torch.relu(out)
out=self.linear2(out)
return out
#out=torch.softmax() def train_1():
'''创建模型对象'''
input_size=784#训练数据的维度
hidden_size=64#隐藏层的神经元数量,这个数量越大,神经网络越复杂,训练后网络的准确度越高,但训练耗时也越长
ouput_size=10#输出层的神经元数量
mlp=MLP(input_size,hidden_size,ouput_size)
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
#optimizer=torch.optim.RMSprop(mlp.parameters(),lr=0.001,alpha=0.9)
#optimizer=torch.optim.Adam(mlp.parameters(),lr=0.01)
optimizer=torch.optim.Adam(mlp.parameters(),lr=0.001)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
start=time.time()
for i in range(200):
x=x_train.view(x_train.shape[0],-1)
prediction=mlp(x)
loss=loss_func(prediction,y_train)
print('Batch No.%s,loss:%s'%(i,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
end=time.time()
print('runnig time:%.3f sec.'%(end-start)) '''评估模型效果'''
samples=10000
'''取一定数量的样本,用于评估'''
x_input=x_test[:samples]
'''模型输入必须为tensor形式,且维度为(784,)'''
x_input=x_input.view(x_input.shape[0],-1)
y_pred=mlp(x_input)
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
print('Accuracy:',acc)
###输出:Accuracy:0.8153 '''自定义神经网络2'''
class MyNet(nn.Module):
def __init__(self,in_size,hidden_size,out_size):
super().__init__()
self.linear1=nn.Linear(in_size,hidden_size)
self.linear2=nn.Linear(hidden_size,out_size) def forward(self,x):
x=x.view(x.size(0),-1)
out=self.linear1(x)
out=torch.relu(out)
out=self.linear2(out)
return out def train_2():
num_epoch=20
#t_data=data.data.type(torch.float)
x_train,y_train,x_test,y_test=train_test_split(data,0.2)
'''使用DataLoader批量输入训练数据'''
dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
'''创建模型对象'''
model=MyNet(784,512,10)
'''定义损失函数'''
loss_func=torch.nn.CrossEntropyLoss()
'''定义优化器'''
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
start=time.time()
for i in range(num_epoch):
for index,(x_data,y_data) in enumerate(dl_train):
prediction=model(x_data)
loss=loss_func(prediction,y_data)
print('No.%s,loss=%.3f'%(index,loss.data.numpy()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('No.%s,loss=%.3f'%(i,loss.data.numpy()))
end=time.time()
print('runnig time:%.3f sec.'%(end-start)) '''评估模型的Accuracy'''
samples=10000
'''取一定数量的样本,用于评估'''
y_pred=model(x_test[:samples])
'''把模型输出(向量)转为label形式'''
y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
'''计算准确率'''
acc=sum(y_pred_==y_test.numpy()[:samples])/samples
print('Accuracy:',acc)
###输出:Accuracy:0.8622
题外话,用相同的数据集、相同的神经网络结构、相同的优化器、相同的参数,把Pytorch跟TensorFlow2.0对比,发现pytorch对cpu的占用更小,TF 2.0跑起来Mac pro呼呼地响,Pytorch跑的时候安静很多。
Pytorch分类和准确性评估--基于FashionMNIST数据集的更多相关文章
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...
- 神经网络中的Heloo,World,基于MINST数据集的LeNet
前言 最近刚开始接触机器学习,记录下目前的一些理解,以及看到的一些好文章mark一下 1.MINST数据集 MNIST 数据集来自美国国家标准与技术研究所, National Institute of ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- 基于COCO数据集验证的目标检测算法天梯排行榜
基于COCO数据集验证的目标检测算法天梯排行榜 AP50 Rank Model box AP AP50 Paper Code Result Year Tags 1 SwinV2-G (HTC++) 6 ...
- 基于titanic数据集预测titanic号旅客生还率
数据清洗及可视化 实验内容 数据清洗是数据分析中非常重要的一部分,也最繁琐,做好这一步需要大量的经验和耐心.这门课程中,我将和大家一起,一步步完成这项工作.大家可以从这门课程中学习数据清洗的基本思路以 ...
- 第二十二节,TensorFlow中的图片分类模型库slim的使用、数据集处理
Google在TensorFlow1.0,之后推出了一个叫slim的库,TF-slim是TensorFlow的一个新的轻量级的高级API接口.这个模块是在16年新推出的,其主要目的是来做所谓的“代码瘦 ...
- scikit-learn - 分类模型的评估 (classification_report)
使用说明 参数 sklearn.metrics.classification_report(y_true, y_pred, labels=None, target_names=None, sample ...
- 分类问题(一)MINST数据集与二元分类器
分类问题 在机器学习中,主要有两大类问题,分别是分类和回归.下面我们先主讲分类问题. MINST 这里我们会用MINST数据集,也就是众所周知的手写数字集,机器学习中的 Hello World.sk- ...
- 第四十七篇 入门机器学习——分类的准确性(Accuracy)
No.1. 通常情况下,直接将训练得到的模型应用于真实环境中,可能会存在很多问题 No.2. 比较好的解决方法是,将原始数据中的大部分用于训练数据,而留出少部分数据用于测试,即,将数据集切分成训练数据 ...
随机推荐
- 前端——Vue.js学习总结一
一.什么是Vue.js 1.Vue.js 是目前最火的一个前端框架,React是最流行的一个前端框架 2.Vue.js 是前端的主流框架之一,和Angular.js.React.js 一起,并成为前端 ...
- 王颖奇 20171010129《面向对象程序设计(java)》第十周学习总结
实验十 泛型程序设计技术 实验时间 2018-11-1 1.实验目的与要求 (1) 理解泛型概念: (2) 掌握泛型类的定义与使用: (3) 掌握泛型方法的声明与使用: (4) 掌握泛型接口的定义与 ...
- Openwrt:编译固件提示[mktplinkfw] error: images are too big 错误
在编译mr3420的固件时,添加了luci.jamvm,但是最终编译的固件"openwrt-ar71xx-generic-tl-mr3420-v1-squashfs-factory.bin& ...
- Mybatis-入门演示
MyBatis:持久层框架 前言 之前有看过和学习一些mybatis的文章和内容,但是没有去写过文章记录下,现在借鉴b站的狂神视频和官方文档看来重新撸一遍入门.有错误请多指教. 内容 数据访问层-相当 ...
- Pytest 单元测试框架
1.pytest 是 python 的第三方单元测试框架,比自带 unittest 更简洁和高效 2.安装 pytest pip install pytest 3.验证 pytest 是否安装成功 p ...
- indexDB出坑指南
对于入了前端坑的同学,indexDB绝对是需要深入学习的. 本文针对indexDB的难点问题(事务和数据库升级)做了详细的讲解,而对于indexDB的特点和使用方法只简要的介绍了一下.如果你有一些使用 ...
- [hdu5416 CRB and Tree]树上路径异或和,dfs
题意:给一棵树,每条边有一个权值,求满足u到v的路径上的异或和为s的(u,v)点对数 思路:计a到b的异或和为f(a,b),则f(a,b)=f(a,root)^f(b,root).考虑dfs,一边计算 ...
- 在一段字符串中的指定位置插入html标签,实现内容修改留痕
客户需求:实现内容修改留痕,并且鼠标移动到元素时,显示修改人和修改时间. (其实呢本人觉得这个如果是静态的页面,或者是后端拼接好的html,都很好实现,如果让前端动态实现就......) 前端实现的方 ...
- Docker之从零开始制作docker镜像
以前学习docker是直接docker pull命令直接拉取Linux中已有镜像,并创建容器,添加应用程序,但是docker镜像一开始是怎么来的呢?下面将从零开始介绍整个docker镜像的制作过程(初 ...
- 「雕爷学编程」Arduino动手做(33)——ESP-01S无线WIFI模块
37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的.鉴于本人手头积累了一些传感器和模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的,这里 ...