PyTorch之初级使用
- 使用流程
①. 数据准备; ②. 模型确立; ③. 损失函数确立; ④. 优化器确立; ⑤. 模型训练及保存 - 模块介绍
Part Ⅰ: 数据准备
torch.utils.data.Dataset
torch.utils.data.DataLoader
关于Dataset, 作为数据集, 需要实现基本的3个方法, 分别为: __init__、__len__、__getitem__. 示例如下,
1 class TrainingDataset(Dataset):
2
3 def __init__(self, X, Y_, transform=None, target_transform=None):
4 self.__X = X
5 self.__Y_ = Y_
6 self.__transform = transform
7 self.__target_transform = target_transform
8
9
10 def __len__(self):
11 return len(self.__X)
12
13
14 def __getitem__(self, idx):
15 x = self.__X[idx]
16 y_ = self.__Y_[idx]
17 if self.__transform:
18 x = self.__transform(x)
19 if self.__target_transform:
20 y_ = self.__target_transform(y_)
21 return x, y_关于DataLoader, 作为数据集封装, 将数据集Dataset封装为可迭代对象. 示例如下,
1 batch_size = 100
2 trainingLoader = DataLoader(trainingData, batch_size=batch_size, shuffle=True)Part Ⅱ: 模型确立
torch.nn
torch.nn.Module
网络模型由基类Module派生, 内部所有操作模块均由命名空间nn提供, 需要实现基本的2个方法, 分别为: __init__、forward. 其中, __init__方法定义操作, forward方法运用操作进行正向计算. 示例如下,1 class NeuralNetwork(nn.Module):
2
3 def __init__(self):
4 super(NeuralNetwork, self).__init__()
5 self.__linear_tanh_stack = nn.Sequential(
6 nn.Linear(3, 5),
7 nn.Tanh(),
8 nn.Linear(5, 3)
9 )
10
11
12 def forward(self, x):
13 y = self.__linear_tanh_stack(x)
14 return y
15
16
17 model = NeuralNetwork()Part Ⅲ: 损失函数确立
torch.nn
常见损失函数有: nn.MSELoss(回归任务)、nn.CrossEntropyLoss(多分类任务)等. 示例如下,1 loss_func = nn.MSELoss(reduction="sum")
Part Ⅳ: 优化器确立
torch.optim
常见的优化器有: optim.SGD、optim.Adam等. 示例如下,1 optimizer = optim.Adam(model.parameters(), lr=0.001)
Part Ⅴ: 模型训练及保存
有效整合数据、模型、损失函数及优化器. 注意, 模型参数之梯度默认累积, 每次参数优化需要显式清零. 示例如下,1 def train_loop(dataloader, model, loss_func, optimizer):
2 for batchIdx, (X, Y_) in enumerate(dataloader):
3 Y = model(X)
4 loss = loss_func(Y, Y_)
5
6 optimizer.zero_grad()
7 loss.backward()
8 optimizer.step()
9
10
11 epoch = 50000
12 for epochIdx in range(epoch):
13 train_loop(trainingLoader, model, loss_func, optimizer)
14
15
16 torch.save(model.state_dict(), "model_params.pth") - 代码实现
本文使用与Back Propagation - Python实现相同的网络架构及数据生成策略, 分别如下所示,
$$
\begin{equation*}
\left\{
\begin{split}
x &= r + 2g + 3b \\
y &= r^2 + 2g^2 + 3b^2 \\
lv &= -3r - 4g - 5b
\end{split}
\right.
\end{equation*}
$$
具体实现如下,
1 import numpy
2 import torch
3 from torch import nn
4 from torch import optim
5 from torch.utils.data import Dataset, DataLoader
6 from matplotlib import pyplot as plt
7
8
9 numpy.random.seed(1)
10 torch.manual_seed(3)
11
12
13 # 生成training数据
14 def getData(n=100):
15 rgbRange = (-1, 1)
16 r = numpy.random.uniform(*rgbRange, (n, 1))
17 g = numpy.random.uniform(*rgbRange, (n, 1))
18 b = numpy.random.uniform(*rgbRange, (n, 1))
19 x_ = r + 2 * g + 3 * b
20 y_ = r ** 2 + 2 * g ** 2 + 3 * b ** 2
21 lv_ = -3 * r - 4 * g - 5 * b
22 RGB = numpy.hstack((r, g, b))
23 XYLv_ = numpy.hstack((x_, y_, lv_))
24 return RGB, XYLv_
25
26
27 class TrainingDataset(Dataset):
28
29 def __init__(self, X, Y_, transform=None, target_transform=None):
30 self.__X = X
31 self.__Y_ = Y_
32 self.__transform = transform
33 self.__target_transform = target_transform
34
35
36 def __len__(self):
37 return len(self.__X)
38
39
40 def __getitem__(self, idx):
41 x = self.__X[idx]
42 y_ = self.__Y_[idx]
43 if self.__transform:
44 x = self.__transform(x)
45 if self.__target_transform:
46 y_ = self.__target_transform(y_)
47 return x, y_
48
49
50 RGB, XYLv_ = getData(1000)
51 trainingData = TrainingDataset(RGB, XYLv_, torch.Tensor, torch.Tensor)
52
53 batch_size = 100
54 trainingLoader = DataLoader(trainingData, batch_size=batch_size, shuffle=True)
55
56
57 class NeuralNetwork(nn.Module):
58
59 def __init__(self):
60 super(NeuralNetwork, self).__init__()
61 self.__linear_tanh_stack = nn.Sequential(
62 nn.Linear(3, 5),
63 nn.Tanh(),
64 nn.Linear(5, 3)
65 )
66
67
68 def forward(self, x):
69 y = self.__linear_tanh_stack(x)
70 return y
71
72
73 model = NeuralNetwork()
74 loss_func = nn.MSELoss(reduction="sum")
75 optimizer = optim.Adam(model.parameters(), lr=0.001)
76
77
78 def train_loop(dataloader, model, loss_func, optimizer):
79 JVal = 0
80 for batchIdx, (X, Y_) in enumerate(dataloader):
81 Y = model(X)
82 loss = loss_func(Y, Y_)
83
84 JVal += loss.item()
85
86 optimizer.zero_grad()
87 loss.backward()
88 optimizer.step()
89
90 JVal /= 2
91 return JVal
92
93
94 JPath = list()
95 epoch = 50000
96 for epochIdx in range(epoch):
97 JVal = train_loop(trainingLoader, model, loss_func, optimizer)
98 print("epoch: {:5d}, JVal = {:.5f}".format(epochIdx, JVal))
99 JPath.append(JVal)
100
101
102 torch.save(model.state_dict(), "model_params.pth")
103
104
105 fig = plt.figure(figsize=(6, 4))
106 ax1 = fig.add_subplot(1, 1, 1)
107
108 ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
109 ax1.plot(0, JPath[0], "go", label="seed")
110 ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")
111
112 ax1.legend()
113 ax1.set(xlabel="$epoch$", ylabel="$JVal$", title="solution-JVal = {:.5f}".format(JPath[-1]))
114
115 fig.tight_layout()
116 fig.savefig("plot_fig.png", dpi=100) - 结果展示
可以看到, 在training data上总体loss随epoch增加逐渐降低. 使用建议
①. 分batch处理训练数据, 可以提升训练初始阶段模型参数收敛速度;
②. 常规优化器推荐Adam, 具备自动步长调节的能力.- 参考文档
①. https://pytorch.org/tutorials/beginner/basics/intro.html
PyTorch之初级使用的更多相关文章
- Pytorch【直播】2019 年县域农业大脑AI挑战赛---初级准备(一)切图
比赛地址:https://tianchi.aliyun.com/competition/entrance/231717/introduction 这次比赛给的图非常大5万x5万,在训练之前必须要进行数 ...
- 转pytorch中训练深度神经网络模型的关键知识点
版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...
- 马哥linux运维初级+中级+高级 视频教程 教学视频 全套下载(近50G)
马哥linux运维初级+中级+高级 视频教程 教学视频 全套下载(近50G)目录详情:18_02_ssl协议.openssl及创建私有CA18_03_OpenSSH服务及其相关应用09_01_磁盘及文 ...
- Python 正则表达式入门(初级篇)
Python 正则表达式入门(初级篇) 本文主要为没有使用正则表达式经验的新手入门所写. 转载请写明出处 引子 首先说 正则表达式是什么? 正则表达式,又称正规表示式.正规表示法.正规表达式.规则表达 ...
- python 高级之面向对象初级
python 高级之面向对象初级 本节内容 类的创建 类的构造方法 面向对象之封装 面向对象之继承 面向对象之多态 面向对象之成员 property 1.类的创建 面向对象:对函数进行分类和封装,让开 ...
- N皇后问题—初级回溯
N皇后问题,最基础的回溯问题之一,题意简单N*N的正方形格子上放置N个皇后,任意两个皇后不能出现在同一条直线或者斜线上,求不同N对应的解. 提要:N>13时,数量庞大,初级回溯只能保证在N< ...
- python 面向对象初级篇
Python 面向对象(初级篇) 概述 面向过程:根据业务逻辑从上到下写垒代码 函数式:将某功能代码封装到函数中,日后便无需重复编写,仅调用函数即可 面向对象:对函数进行分类和封装,让开发" ...
- codefordream 关于js初级训练
这里的初级训练相对简单,差不多都是以前知识温习. 比如输出“hello world”,直接使用console.log()就行.注释符号,“//”可以注释单行,快捷键 alt+/,"/* ...
- Mysql操作初级
Mysql操作初级 本节内容 数据库概述 数据库安装 数据库操作 数据表操作 表内容操作 1.数据库概述 数据库管理系统叫做DBMS 1.什么是数据库 ? 答:数据的仓库,如:在ATM的示例中我们创建 ...
- python面向对象初级(七)
概述 面向过程:根据业务逻辑从上到下写垒代码 函数式:将某功能代码封装到函数中,日后便无需重复编写,仅调用函数即可 面向对象:对函数进行分类和封装,让开发“更快更好更强...” 面向过程编程最易被初学 ...
随机推荐
- PowerToys 微软效率工具包 使用教程
今天给大家介绍一款 非常实用的微软工具包 里面包含 快捷键的使用 颜色选择器 键盘管理器 屏幕标尺 鼠标实用工具等众多高效工作的功能 还是蛮出彩的 下载 PowerToys⇲ 安装教程 1.双击文件运 ...
- ECharts 饼图切换数据源bug 开始没数据显示 切换或刷新后显示
1.出现问题原因 一个饼图,右上方两个按钮分别为今天和本月,分别调用不同接口控制,点击则调用不同接口同时饼图绑定数据源刷新:出现此问题原因点击今日按钮有一个饼图区域形没有数据不显示,对应数据值比例都没 ...
- Sentinel熔断与限流
1.简介 在线文档: https://sentinelguard.io/zh-cn/docs/system-adaptive-protection.html 功能: 流量控制 速率控制 熔断和限流 和 ...
- 关于opencv3.2的parallel_for_函数不支持bind function的处理(基于ch8代码)
1.换opencv4 2.修改程序 改程序针对slambook2/ch8/direct_method.cpp #include <opencv2/opencv.hpp> #include ...
- 解决Linux上tomcat解析war包中文文件乱码
解决Linux上tomcat解析war包中文文件乱码 第一步 编辑tomcat/conf server.xml vim /usr/local/src/tomcat/conf/server.xml us ...
- 用 HTTP 协议下载资源(WinINet 实现)
用 HTTP 协议下载资源(WinINet 实现) WinINet 使用 HTTP 协议下载资源的流程 相关函数 InternetCrackUrl 解析 URL BOOL InternetCrackU ...
- dotnet总结——类型系统
包括2种大的类型: 引用类型和值类型, 放一张图说明继承层次: 一 值类型: 内置的值类型,如下 用户自定义值类型就是用户定义的枚举或者结构类型. 可空类型(Nullable<T>)属于 ...
- [EULAR文摘] 在总人群中监测ACPA能否预测早期关节炎
标签: 类风湿关节炎; 抗CCP抗体; 预测因子; 病程演变 在总人群中监测ACPA能否预测早期关节炎 Verstappen SM, et al. EULAR 2015. Present ID: OP ...
- Hugging Face 每周速递: Space 支持创建模版应用、Hub 搜索功能增强、BioGPT-Large 还有更多
每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...
- LeetCode算法训练-回溯总结
欢迎关注个人公众号:爱喝可可牛奶 LeetCode算法训练-回溯总结 适用问题 组合问题:N个数里面按一定规则找出k个数的集合 排列问题:N个数按一定规则全排列,有几种排列方式 切割问题:一个字符串按 ...