• 使用流程
    ①. 数据准备; ②. 模型确立; ③. 损失函数确立; ④. 优化器确立; ⑤. 模型训练及保存
  • 模块介绍
    Part Ⅰ: 数据准备
    torch.utils.data.Dataset 
    torch.utils.data.DataLoader 
    关于Dataset, 作为数据集, 需要实现基本的3个方法, 分别为: __init__、__len__、__getitem__. 示例如下,
     1 class TrainingDataset(Dataset):
    2
    3 def __init__(self, X, Y_, transform=None, target_transform=None):
    4 self.__X = X
    5 self.__Y_ = Y_
    6 self.__transform = transform
    7 self.__target_transform = target_transform
    8
    9
    10 def __len__(self):
    11 return len(self.__X)
    12
    13
    14 def __getitem__(self, idx):
    15 x = self.__X[idx]
    16 y_ = self.__Y_[idx]
    17 if self.__transform:
    18 x = self.__transform(x)
    19 if self.__target_transform:
    20 y_ = self.__target_transform(y_)
    21 return x, y_

    关于DataLoader, 作为数据集封装, 将数据集Dataset封装为可迭代对象. 示例如下,

    1 batch_size = 100
    2 trainingLoader = DataLoader(trainingData, batch_size=batch_size, shuffle=True)

    Part Ⅱ: 模型确立
    torch.nn 
    torch.nn.Module 
    网络模型由基类Module派生, 内部所有操作模块均由命名空间nn提供, 需要实现基本的2个方法, 分别为: __init__、forward. 其中, __init__方法定义操作, forward方法运用操作进行正向计算. 示例如下,

     1 class NeuralNetwork(nn.Module):
    2
    3 def __init__(self):
    4 super(NeuralNetwork, self).__init__()
    5 self.__linear_tanh_stack = nn.Sequential(
    6 nn.Linear(3, 5),
    7 nn.Tanh(),
    8 nn.Linear(5, 3)
    9 )
    10
    11
    12 def forward(self, x):
    13 y = self.__linear_tanh_stack(x)
    14 return y
    15
    16
    17 model = NeuralNetwork()

    Part Ⅲ: 损失函数确立
    torch.nn 
    常见损失函数有: nn.MSELoss(回归任务)、nn.CrossEntropyLoss(多分类任务)等. 示例如下,

    1 loss_func = nn.MSELoss(reduction="sum")

    Part Ⅳ: 优化器确立
    torch.optim 
    常见的优化器有: optim.SGD、optim.Adam等. 示例如下,

    1 optimizer = optim.Adam(model.parameters(), lr=0.001)

    Part Ⅴ: 模型训练及保存
    有效整合数据、模型、损失函数及优化器. 注意, 模型参数之梯度默认累积, 每次参数优化需要显式清零. 示例如下,

     1 def train_loop(dataloader, model, loss_func, optimizer):
    2 for batchIdx, (X, Y_) in enumerate(dataloader):
    3 Y = model(X)
    4 loss = loss_func(Y, Y_)
    5
    6 optimizer.zero_grad()
    7 loss.backward()
    8 optimizer.step()
    9
    10
    11 epoch = 50000
    12 for epochIdx in range(epoch):
    13 train_loop(trainingLoader, model, loss_func, optimizer)
    14
    15
    16 torch.save(model.state_dict(), "model_params.pth")
  • 代码实现
    本文使用与Back Propagation - Python实现相同的网络架构及数据生成策略, 分别如下所示,

    $$
    \begin{equation*}
    \left\{
    \begin{split}
    x &= r + 2g + 3b \\
    y &= r^2 + 2g^2 + 3b^2 \\
    lv &= -3r - 4g - 5b
    \end{split}
    \right.
    \end{equation*} 
    $$
    具体实现如下,

      1 import numpy
    2 import torch
    3 from torch import nn
    4 from torch import optim
    5 from torch.utils.data import Dataset, DataLoader
    6 from matplotlib import pyplot as plt
    7
    8
    9 numpy.random.seed(1)
    10 torch.manual_seed(3)
    11
    12
    13 # 生成training数据
    14 def getData(n=100):
    15 rgbRange = (-1, 1)
    16 r = numpy.random.uniform(*rgbRange, (n, 1))
    17 g = numpy.random.uniform(*rgbRange, (n, 1))
    18 b = numpy.random.uniform(*rgbRange, (n, 1))
    19 x_ = r + 2 * g + 3 * b
    20 y_ = r ** 2 + 2 * g ** 2 + 3 * b ** 2
    21 lv_ = -3 * r - 4 * g - 5 * b
    22 RGB = numpy.hstack((r, g, b))
    23 XYLv_ = numpy.hstack((x_, y_, lv_))
    24 return RGB, XYLv_
    25
    26
    27 class TrainingDataset(Dataset):
    28
    29 def __init__(self, X, Y_, transform=None, target_transform=None):
    30 self.__X = X
    31 self.__Y_ = Y_
    32 self.__transform = transform
    33 self.__target_transform = target_transform
    34
    35
    36 def __len__(self):
    37 return len(self.__X)
    38
    39
    40 def __getitem__(self, idx):
    41 x = self.__X[idx]
    42 y_ = self.__Y_[idx]
    43 if self.__transform:
    44 x = self.__transform(x)
    45 if self.__target_transform:
    46 y_ = self.__target_transform(y_)
    47 return x, y_
    48
    49
    50 RGB, XYLv_ = getData(1000)
    51 trainingData = TrainingDataset(RGB, XYLv_, torch.Tensor, torch.Tensor)
    52
    53 batch_size = 100
    54 trainingLoader = DataLoader(trainingData, batch_size=batch_size, shuffle=True)
    55
    56
    57 class NeuralNetwork(nn.Module):
    58
    59 def __init__(self):
    60 super(NeuralNetwork, self).__init__()
    61 self.__linear_tanh_stack = nn.Sequential(
    62 nn.Linear(3, 5),
    63 nn.Tanh(),
    64 nn.Linear(5, 3)
    65 )
    66
    67
    68 def forward(self, x):
    69 y = self.__linear_tanh_stack(x)
    70 return y
    71
    72
    73 model = NeuralNetwork()
    74 loss_func = nn.MSELoss(reduction="sum")
    75 optimizer = optim.Adam(model.parameters(), lr=0.001)
    76
    77
    78 def train_loop(dataloader, model, loss_func, optimizer):
    79 JVal = 0
    80 for batchIdx, (X, Y_) in enumerate(dataloader):
    81 Y = model(X)
    82 loss = loss_func(Y, Y_)
    83
    84 JVal += loss.item()
    85
    86 optimizer.zero_grad()
    87 loss.backward()
    88 optimizer.step()
    89
    90 JVal /= 2
    91 return JVal
    92
    93
    94 JPath = list()
    95 epoch = 50000
    96 for epochIdx in range(epoch):
    97 JVal = train_loop(trainingLoader, model, loss_func, optimizer)
    98 print("epoch: {:5d}, JVal = {:.5f}".format(epochIdx, JVal))
    99 JPath.append(JVal)
    100
    101
    102 torch.save(model.state_dict(), "model_params.pth")
    103
    104
    105 fig = plt.figure(figsize=(6, 4))
    106 ax1 = fig.add_subplot(1, 1, 1)
    107
    108 ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
    109 ax1.plot(0, JPath[0], "go", label="seed")
    110 ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")
    111
    112 ax1.legend()
    113 ax1.set(xlabel="$epoch$", ylabel="$JVal$", title="solution-JVal = {:.5f}".format(JPath[-1]))
    114
    115 fig.tight_layout()
    116 fig.savefig("plot_fig.png", dpi=100)
  • 结果展示
    可以看到, 在training data上总体loss随epoch增加逐渐降低.
  • 使用建议
    ①. 分batch处理训练数据, 可以提升训练初始阶段模型参数收敛速度;
    ②. 常规优化器推荐Adam, 具备自动步长调节的能力.

  • 参考文档
    ①. https://pytorch.org/tutorials/beginner/basics/intro.html

PyTorch之初级使用的更多相关文章

  1. Pytorch【直播】2019 年县域农业大脑AI挑战赛---初级准备(一)切图

    比赛地址:https://tianchi.aliyun.com/competition/entrance/231717/introduction 这次比赛给的图非常大5万x5万,在训练之前必须要进行数 ...

  2. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  3. 马哥linux运维初级+中级+高级 视频教程 教学视频 全套下载(近50G)

    马哥linux运维初级+中级+高级 视频教程 教学视频 全套下载(近50G)目录详情:18_02_ssl协议.openssl及创建私有CA18_03_OpenSSH服务及其相关应用09_01_磁盘及文 ...

  4. Python 正则表达式入门(初级篇)

    Python 正则表达式入门(初级篇) 本文主要为没有使用正则表达式经验的新手入门所写. 转载请写明出处 引子 首先说 正则表达式是什么? 正则表达式,又称正规表示式.正规表示法.正规表达式.规则表达 ...

  5. python 高级之面向对象初级

    python 高级之面向对象初级 本节内容 类的创建 类的构造方法 面向对象之封装 面向对象之继承 面向对象之多态 面向对象之成员 property 1.类的创建 面向对象:对函数进行分类和封装,让开 ...

  6. N皇后问题—初级回溯

    N皇后问题,最基础的回溯问题之一,题意简单N*N的正方形格子上放置N个皇后,任意两个皇后不能出现在同一条直线或者斜线上,求不同N对应的解. 提要:N>13时,数量庞大,初级回溯只能保证在N< ...

  7. python 面向对象初级篇

    Python 面向对象(初级篇) 概述 面向过程:根据业务逻辑从上到下写垒代码 函数式:将某功能代码封装到函数中,日后便无需重复编写,仅调用函数即可 面向对象:对函数进行分类和封装,让开发" ...

  8. codefordream 关于js初级训练

    这里的初级训练相对简单,差不多都是以前知识温习. 比如输出“hello world”,直接使用console.log()就行.注释符号,“//”可以注释单行,快捷键 alt+/,"/*   ...

  9. Mysql操作初级

    Mysql操作初级 本节内容 数据库概述 数据库安装 数据库操作 数据表操作 表内容操作 1.数据库概述 数据库管理系统叫做DBMS 1.什么是数据库 ? 答:数据的仓库,如:在ATM的示例中我们创建 ...

  10. python面向对象初级(七)

    概述 面向过程:根据业务逻辑从上到下写垒代码 函数式:将某功能代码封装到函数中,日后便无需重复编写,仅调用函数即可 面向对象:对函数进行分类和封装,让开发“更快更好更强...” 面向过程编程最易被初学 ...

随机推荐

  1. 【开发宝典】Java并发系列教程(四)

    作者:京东零售 刘跃明 Monitor概念 Java对象的内存布局 对象除了我们自定义的一些属性外,还有其它数据,在内存中可以分为三个区域:对象头.实例数据.对齐填充,这三个区域组成起来才是一个完整的 ...

  2. E - 树状数组 1【GDUT_22级寒假训练专题五】

    E - 树状数组 1 原题链接 题意 已知一个数列,你需要进行下面两种操作: 将某一个数加上 \(x\) 求出某区间每一个数的和 lowbit函数 定义一个函数\(f=lowbit(x)\),这个函数 ...

  3. 关于vux-ui框架的scroller组件所踩的坑

    这是我在做一个demo的一个上垃加载下拉刷新功能时所遇到的问题,由于伤了好一会脑筋,所以留下这篇笔记以供后续查询: 在上代码前建议在开发项目时不要优先选择vux这个框架,因为有一些常用的功能组件官方已 ...

  4. 有趣的python库-pillow

    pillow-图像处理 安装时不再是PIL,是pillow哦! 烟花 pillow + tkinter实现 import tkinter as tk from PIL import Image, Im ...

  5. LG P6156 简单题

    \(\text{Problem}\) \(\text{Analysis}\) 显然 \(f=\mu^2\) 那么 \[\begin{aligned} \sum_{i=1}^n \sum_{j=1}^n ...

  6. 简单添加table线条

    <table style="width: 100%; margin: 0 auto; border: 1px solid #BBBBBB; border-collapse: colla ...

  7. [专题总结]Gridea快速免费搭建个人博客

    介绍 或许你很想把你所知道的问题写出来,或许你文思泉涌,想给大家分享.我相信,你一定能写好博客,只要坚持,就可以了. 或许大家会不理解,为什么不用大平台的博客呢?或许你稍微了解就会知道,现在的博客平台 ...

  8. 获取微信小程序列表渲染 index

    微信小程序列表渲染 index(索引值)通过 wx:for-index="index" 来获取: <view class="item" wx:for=&q ...

  9. Java 反射概念的引入

    反射是什么 学Java的人都知道类概念,反射技术就是一种控制类的技术,JAVA程序在运行时,通过反射这个技术,能动态的获取到类实例的信息.创建实体类.操作实体类. 反射的功能列表: 获取任意类的名称. ...

  10. Django中获取用户IP方法

    Django中通过request.META可以来获取用户的IP. request.META 是一个Python字典,包含了所有本次HTTP请求的Header信息,比如用户IP地址和用户Agent(通常 ...