import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append('..')
import d2lzh_pytorch as d2l
import torchvision
import torchvision.transforms as transforms
定义和初始化模型
#与上一节同样的数据集以及批量大小
batch_size= 256
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor()) if sys.platform.startswith('win'):
num_worker=0 # 表示不用额外的进程来加速读取数据 else:
num_worker=4
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)

softmax的输出层是一个全连接层,所以我们使用一个线性模块就可以,因为前面我们数据返回的每个batch的样本X的形状为(batch_size,1,28,28),我们先用view()将X转化为(batch_size,784)才送入全连接层

num_inputs = 784
num_outputs = 10 class LinearNet(nn.Module):
def __init__(self,num_inputs,num_outputs):
super(LinearNet,self).__init__()
self.linear = nn.Linear(num_inputs,num_outputs)
def forward(self,x):
y = self.linear(x.view(x.shape[0],-1))
return y
net = LinearNet(num_inputs,num_outputs)
# 我们将形状转化的这个功能定义成一个FlattenLayer
class FlattenLayer(nn.Module):
def __init__(self):
super(FlattenLayer,self).__init__()
def forward(self,x):
return x.view(x.shape[0],-1)
from collections import OrderedDict
net = nn.Sequential(
OrderedDict(
[
('flatten',FlattenLayer()),
('linear',nn.Linear(num_inputs,num_outputs))
])
)
# 之前线性回归的是num_output是1
init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
print(net)
Sequential(
(flatten): FlattenLayer()
(linear): Linear(in_features=784, out_features=10, bias=True)
)
softamx和交叉熵损失函数
#pytorch提供了一个包括softmax预算和交叉熵损失计算的函数
loss = nn.CrossEntropyLoss()
定义优化算法
optimizer = torch.optim.SGD(net.parameters(),lr=0.1)
def evaluate_accuracy(data_iter, net):
acc_sum, n = 0.0, 0
for X, y in data_iter:
acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
n += y.shape[0]
return acc_sum / n
训练模型
num_epochs, lr = 5, 0.1
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
params=None, lr=None, optimizer=None):
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
for X, y in train_iter:
y_hat = net(X)
l = loss(y_hat, y).sum() # 梯度清零
if optimizer is not None:
optimizer.zero_grad()
elif params is not None and params[0].grad is not None:
for param in params:
param.grad.data.zero_() l.backward()
if optimizer is None:
# 上节的代码optimizer is None,使用的手写的代码SGD
sgd(params, lr, batch_size)
else:
# optimizer 非None,
optimizer.step() # “softmax回归的简洁实现”一节将用到 train_l_sum += l.item()
train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
n += y.shape[0]
test_acc = evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None,optimizer)
epoch 1, loss 0.0031, train acc 0.749, test acc 0.765
epoch 2, loss 0.0022, train acc 0.813, test acc 0.808
epoch 3, loss 0.0021, train acc 0.826, test acc 0.818
epoch 4, loss 0.0020, train acc 0.832, test acc 0.816
epoch 5, loss 0.0019, train acc 0.837, test acc 0.821

动手学深度学习8-softmax分类pytorch简洁实现的更多相关文章

  1. 动手学深度学习9-多层感知机pytorch

    多层感知机 隐藏层 激活函数 小结 多层感知机 之前已经介绍过了线性回归和softmax回归在内的单层神经网络,然后深度学习主要学习多层模型,后续将以多层感知机(multilayer percetro ...

  2. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  3. 对比《动手学深度学习》 PDF代码+《神经网络与深度学习 》PDF

    随着AlphaGo与李世石大战的落幕,人工智能成为话题焦点.AlphaGo背后的工作原理"深度学习"也跳入大众的视野.什么是深度学习,什么是神经网络,为何一段程序在精密的围棋大赛中 ...

  4. 【动手学深度学习】Jupyter notebook中 import mxnet出错

    问题描述 打开d2l-zh目录,使用jupyter notebook打开文件运行,import mxnet 出现无法导入mxnet模块的问题, 但是命令行运行是可以导入mxnet模块的. 原因: 激活 ...

  5. 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型

    目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...

  6. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  7. 动手学深度学习14- pytorch Dropout 实现与原理

    方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...

  8. 动手学深度学习6-认识Fashion_MNIST图像数据集

    获取数据集 读取小批量样本 小结 本节将使用torchvision包,它是服务于pytorch深度学习框架的,主要用来构建计算机视觉模型. torchvision主要由以下几个部分构成: torchv ...

  9. 动手学深度学习1- pytorch初学

    pytorch 初学 Tensors 创建空的tensor 创建随机的一个随机数矩阵 创建0元素的矩阵 直接从已经数据创建tensor 创建新的矩阵 计算操作 加法操作 转化形状 tensor 与nu ...

随机推荐

  1. 基于OpenCV.Net投影法进行文本分块切割

    假设有如下一张图,如何把其中的文本分块切割出来,比如“华普超市朝阳门店”.“2015-07-26”就是两个文本块. 做图像切割有很多种方法,本文描述一种最直观的投影检测法.先来看看什么是投影,简单来说 ...

  2. Java生鲜电商平台-库存管理设计与架构

    Java生鲜电商平台-库存管理设计与架构 WMS的功能: 1.业务批次管理 该功能提供完善的物料批次信息.批次管理设置.批号编码规则设置.日常业务处理.报表查询,以及库存管理等综合批次管理功能,使企业 ...

  3. Python【day 13】内置函数01

    1.python3.6.2 一共有 68个内置函数2.分成6个大类 1.反射相关-4个 2.面向对象相关-9个 3.作用域相关--2个 1.globlas() #注意:最后是s,复数形式 查看全局作用 ...

  4. vue-router 在项目中的使用

    一.下载vue-router npm install vue-router --save 二.编码 1.在项目中新建文件夹 router/index.js /* * 路由对象模块 * */ impor ...

  5. Linux文件共享服务 FTP,NFS 和 Samba

    Linux 系统中,存储设主要有下面几种: DAS DAS 指 Direct Attached Storage,即直连附加存储,这种设备直接连接到计算机主板总线上,计算机将其识别为一个块设备,例如常见 ...

  6. day 36

    目录 pymysql操作mysql 安装 连接 增 删 改 查 索引 为什么使用索引以及索引的作用 类比 索引的本质 索引的底层原理 索引的种类(重点) 主键索引 唯一索引 普通索引 索引的创建 主键 ...

  7. appium---app输入中文

    在app自动化的过程中,都会遇到输入中文的问题,今天总结下app自动化如何输入中文 app输入中文 在启动app的时候在参数里面添加unicodeKeyboard和resetKeyboard后,运行代 ...

  8. 快速、优雅的前端IDE之H-builder-X

    为什么介绍的是HBuidler-X而不是Hbuilder   HX是全新的一个软件,它抛弃了eclipse架构,使用C++为基础架构.HX目前还不能完全替代HBuilder.但在markdown记事的 ...

  9. UGUI 逻辑以及实用性辅助功能

    UGUI 有它的实用性, 可是也存在理解上的困难, 因为它在面板上的显示内容根据布局而变动, 如果不深入理解它的设计原理, 估计每次要进行程序上的修改都需要进行一次换算和测试过程. 1. 设置某UI的 ...

  10. Bert实战---情感分类

    1.情感分析语料预处理 使用酒店评论语料,正面评论和负面评论各5000条,用BERT参数这么大的模型, 训练会产生严重过拟合,,泛化能力差的情况, 这也是我们下面需要解决的问题; 2.sigmoid二 ...