加入自定义块对fashion_mnist数据集进行softmax分类
在之前,我们实现了使用torch自带的层对fashion_mnist数据集进行分类。这次,我们加入一个自己实现的block,实现一个四层的多层感知机进行softmax分类,作为对“自定义块”的代码实现的一个练习。
我们设计的多层感知机是这样的:输入维度为784,在展平层过后,第一层为全连接层,输入输出维度分别为784,256;第二层为全连接层,输入输出维度分别为256,128;第三层为全连接层,输入输出维度分别为128,64;第四层为全连接层(输出层),输入输出维度分别为64,10.代码如下:
import torch
from d2l import torch as d2l
from torch import nn
from torch.nn import functional as F batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10 #输入层784; 隐藏层一784,256;隐藏层二256,128; 隐藏层三128,64; 输出层64,10
#我们用自定义Module实现隐藏层二、隐藏层三。
class practice_Module(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(256,128)
self.lin2 = nn.Linear(128,64)
nn.init.normal_(self.lin1.weight,std=0.01)
nn.init.normal_(self.lin2.weight,std=0.01)
def forward(self,X):
X = self.lin1(X)
X = F.relu(X)
X = self.lin2(X)
X = F.relu(X)
return X manual_block = practice_Module()
net = nn.Sequential(nn.Flatten(),
nn.Linear(784,256),
nn.ReLU(),
nn.Dropout(0.2),
manual_block,
nn.Dropout(0.3),
nn.Linear(64,10)
) def init_weight(m):
if m == nn.Linear:
nn.init.normal_(m.weight,std=0.01)
return
net.apply(init_weight) loss = torch.nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs = 20
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
首先在我们自定义的模块中,初始化函数__init__中定义我们需要的两个层lin1和lin2.上面的代码在抽象的类practice_Module中的初始化函数__init__中进行了参数初始化,也就是说默认情况下用这个类创建的所有对象都会进行这样的默认初始化。

当然,也可以按我们的需要对具体的模块对象进行参数初始化。
然后在forward函数中定义这个模块进行的操作,即先让数据经过线性层lin1,激活,再经过线性层lin2,激活,然后输出。return X语句,return的X值作为输出,就会作为nn.Sequential中的下一层输入。

注意:这里面的前向传播函数名必须是forward,而不能是其他的,改成其他的就会报错:
Module [practice_Module] is missing the required "forward" function
这也是为什么practice_Module类的实例在nn.Sequential中可以自动计算的原因,是因为系统会自动找到该实例中的方法forward并执行。
下面的语句初始化了一个practice_Module类的实例。可以这样理解:practice_Module是一个抽象的网络结构,而manual_block这个实例才是一个具体的我们需要的模型(包含具体参数)。

可以用如下代码对自定义的模块实例进行初始化。这里,nn.init.normal_()可以对nn.Module的子类的某一具体的层进行参数初始化。

在nn.Sequential中加入我们自定义的模块是非常简单的:

init_weight()函数对torch中定义好了的层进行参数初始化:

加入自定义块对fashion_mnist数据集进行softmax分类的更多相关文章
- tensorflow 离线使用 fashion_mnist 数据集
在tensflow中加载 fashion_mnist 数据集时,由于网络原因.可能会长时间加载不到或报错 此时我们可以通过离线的方式加载 1.首先下载数据集:fashion_mnist (下载后解压) ...
- 学习笔记TF010:softmax分类
回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...
- 从零和使用mxnet实现softmax分类
1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...
- 器学习算法(六)基于天气数据集的XGBoost分类预测
1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...
- tensorflow 使用 5 mnist 数据集, softmax 函数
用于分类 softmax 函数 手写数据识别:
- softmax分类算法原理(用python实现)
逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...
- 动手学深度学习7-从零开始完成softmax分类
获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...
- 利用keras自带路透社数据集进行多分类训练
import numpy as np from keras.datasets import reuters from keras import layers from keras import mod ...
- 用MATLAB的Classficiation Learner工具箱对12个数据集进行各种分类与验证
准备材料 以所有的特征集作为variable进行像Bayes吖.SVM吖.决策树吖......分类.同时对数据进行预处理,选出相关度高的特征子集作为新的一组data进行分类(预处理的代码不必放出来). ...
- 机器学习-MNIST数据集使用二分类
一.二分类训练MNIST数据集练习 %matplotlib inlineimport matplotlibimport numpy as npimport matplotlib.pyplot as p ...
随机推荐
- 什么叫运行时的Java程序?
Java程序的运行包含编写.编译和运行三个主要步骤. 1.在编写阶段: 开发人员在Java开发环境中输入程序代码,形成后缀名为.java的Java源文件. 2.在编译阶段: 使用Java编译器对源文件 ...
- 今日学习:位运算&中国剩余定理
-2^ 31的补码是-0.也就是 1000 0000 0000 0000 0000 0000 0000 0000 补码是原码取反加1 x&(-x) 是最低位为1的位为1,其余位为0. 中国剩余 ...
- CAST电子部单片机方向授课——串口通信 预习文档
CAST电子部单片机方向授课--串口通信 预习文档 课前小准备 安装串口调试助手 第一步:进入Microsoft Store 第二步:在Microsoft Store中搜索 "串口调试助手& ...
- 记录--Vue3 + Fabricjs 定制国庆专属头像
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 生在国旗下,长在春风里!国庆将至,采黎为大家带来 定制头像2.0(国庆头像),让我们用代码的形式为祖国庆生!欢迎大家点赞收藏加关注哦 前言 ...
- SwiftUI 笔记
TextField 监听 lost focus 之前有一个初始化方法,传入一个 onEditingChanged closure,但这个方法废弃了,文档中也说了 alternative:使用 Focu ...
- DGC:真动态分组卷积,可能是解决分组特征阻塞的最好方案 | ECCV 2020 Spotlight
近期,动态网络在加速推理这方面有很多研究,DGC(Dynamic Group Convolution)将动态网络的思想结合到分组卷积中,使得分组卷积在轻量化的同时能够加强表达能力,整体思路直接清晰,可 ...
- KingbaseES V8R6 中walminer的使用
前言 walminer工具可以帮助dba挖掘wal日志中的内容,看到某时间对应数据库中的具体操作.例如挖掘日志后可以看到数据库某时间有哪些dml语句. walminer的限制与约束 WalMiner工 ...
- 使用POI、JavaCsv工具读取excel文件(*.xls , *.xlsx , *.csv)存入MySQL数据库
首先进行maven的配置:导入相关依赖 1 <dependency> 2 <groupId>org.apache.poi</groupId> 3 <artif ...
- OpenHarmony Meetup 2023 广州站圆满举办,城市巡回全面启航
"OpenHarmony正当时--技术开源"OpenHarmony Meetup 2023城市巡回活动,旨在通过meetup线下交流形式,解读OpenHarmony作为下一代智 ...
- C#只允许启动一个WinFrom进程
[STAThread] public static void Main() { bool ret; System.Threading.Mutex mutex = new System.Thread ...