搭建网络的步骤大致为以下:

1.准备数据

2. 定义网络结构model

3. 定义损失函数
4. 定义优化算法 optimizer
5. 训练
  5.1 准备好tensor形式的输入数据和标签(可选)
  5.2 前向传播计算网络输出output和计算损失函数loss
  5.3 反向传播更新参数
    以下三句话一句也不能少:
    5.3.1 optimizer.zero_grad()  将上次迭代计算的梯度值清0
    5.3.2 loss.backward()  反向传播,计算梯度值
    5.3.3 optimizer.step()  更新权值参数
  5.4 保存训练集上的loss和验证集上的loss以及准确率以及打印训练信息。(可选
6. 图示训练过程中loss和accuracy的变化情况(可选)
7. 在测试集上测试

代码注释都写的很详细

 import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # 1.准备数据 generate data
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
print(x.shape)
y=x*x+0.2*torch.rand(x.size())
#显示数据散点图
plt.scatter(x.data.numpy(),y.data.numpy()) # 2.定义网络结构 build net
class Net(torch.nn.Module):
#n_feature:输入特征个数 n_hidden:隐藏层个数 n_output:输出层个数
def __init__(self,n_feature,n_hidden,n_output):
# super表示继承Net的父类,并同时初始化父类的参数
super(Net,self).__init__()
# nn.Linear代表线性层 代表y=w*x+b 其中w的shape为[n_hidden,n_feature] b的shape为[n_hidden]
# y=w^T*x+b 这里w的维度是转置前的维度 所以是反的
self.hidden =torch.nn.Linear(n_feature,n_hidden)
self.predict =torch.nn.Linear(n_hidden,n_output)
print(self.hidden.weight)
print(self.predict.weight)
#定义一个前向传播过程函数
def forward(self, x):
# n_feature n_hidden n_output
#举例(2,5,1) 2 5 1
# - ** -
# ** - - - ** - -
# - ** - - - **
# ** - - - ** - -
# - ** -
# 输入层 隐藏层 输出层
x=F.relu(self.hidden(x))
x=self.predict(x)
return x
# 实例化一个网络为net
net = Net(n_feature=1,n_hidden=10,n_output=1)
print(net)
# 3.定义损失函数 这里使用均方误差(mean square error)
loss_func=torch.nn.MSELoss()
# 4.定义优化器 这里使用随机梯度下降
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
#定义300遍更新 每10遍显示一次
plt.ion()
# 5.训练
for t in range(100):
prediction = net(x) # input x and predict based on x
loss = loss_func(prediction, y) # must be (1. nn output, 2. target)
# 5.3反向传播三步不可少
optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 10 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.show()
plt.pause(0.1) plt.ioff()

参考:莫烦python

pytorch基础-搭建网络的更多相关文章

  1. pytorch基础(4)-----搭建模型网络的两种方法

    方法一:采用torch.nn.Module模块 import torch import torch.nn.functional as F #法1 class Net(torch.nn.Module): ...

  2. Pytorch从0开始实现YOLO V3指南 part2——搭建网络结构层

    本节翻译自:https://blog.paperspace.com/how-to-implement-a-yolo-v3-object-detector-from-scratch-in-pytorch ...

  3. pytorch搭建网络,保存参数,恢复参数

    这是看过莫凡python的学习笔记. 搭建网络,两种方式 (1)建立Sequential对象 import torch net = torch.nn.Sequential( torch.nn.Line ...

  4. TCP/IP协议(一)网络基础知识 网络七层协议

    参考书籍为<图解tcp/ip>-第五版.这篇随笔,主要内容还是TCP/IP所必备的基础知识,包括计算机与网络发展的历史及标准化过程(简述).OSI参考模型.网络概念的本质.网络构建的设备等 ...

  5. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  6. [人工智能]Pytorch基础

    PyTorch基础 摘抄自<深度学习之Pytorch>. Tensor(张量) PyTorch里面处理的最基本的操作对象就是Tensor,表示的是一个多维矩阵,比如零维矩阵就是一个点,一维 ...

  7. 【新生学习】第一周:深度学习及pytorch基础

    DEADLINE: 2020-07-25 22:00 写在最前面: 本课程的主要思路还是要求大家大量练习 pytorch 代码,在写代码的过程中掌握深度学习的各类算法,希望大家能够坚持练习,相信经度过 ...

  8. 使用pytorch快速搭建神经网络实现二分类任务(包含示例)

    使用pytorch快速搭建神经网络实现二分类任务(包含示例) Introduce 上一篇学习笔记介绍了不使用pytorch包装好的神经网络框架实现logistic回归模型,并且根据autograd实现 ...

  9. 001-深度学习Pytorch环境搭建(Anaconda , PyCharm导入)

    001-深度学习Pytorch环境搭建(Anaconda , PyCharm导入) 在开始搭建之前我们先说一下本次主要安装的东西有哪些. anaconda 3:第三方包管理软件. 这个玩意可以看作是一 ...

随机推荐

  1. Selenium环境要配置浏览器驱动

    1.浏览器环境变量添加到path 2.将浏览器相应的驱动.exe复制到浏览器目录 3.这条就是让我傻逼似的配置一上午的罪魁祸首:将驱动.exe复制到python目录!!!! Selenium

  2. Python中定义只读属性

    Python是面向对象(OOP)的语言, 而且在OOP这条路上比Java走得更彻底, 因为在Python里, 一切皆对象, 包括int, float等基本数据类型. 在Java里, 若要为一个类定义只 ...

  3. MySQL 社区版 安装小记

    根据刘铁猛老师的教程,自己折腾一下 1. 安装包准备 在Windows10 64bit上安装,故需要准备vc++ 2013和2015的Redistributable的包,搜索即有,无需细说. 示例数据 ...

  4. Docker 中卷组管理

    一.概念 数据卷是一个可供一个或多个容器使用的特殊目录实现让容器的一个目录和宿主机中的一个文件或者目录进行绑定.数据卷 是被设计用来持久化数据的,对于数据卷你可以理解为NFS中的哪个分享出来的挂载点, ...

  5. [20191127]表 full Hash Value的计算.txt

    [20191127]表 full Hash Value的计算.txt --//曾经做过表full Hash Value的计算,当时我是通过建立简单的schema以及表名的形式,使用hashcat破解o ...

  6. 微信小程序框架部署:mpvue+typescript

    开发前提: 1.在微信公众平台注册申请 AppID 2.安装开发者工具https://developers.weixin.qq.com/miniprogram/dev/devtools/downloa ...

  7. python科学计算和数据分析常用库

    NumPy NumPy最强大的是n维数组,该库还包含基本的线性代数函数.傅立叶变换.随机函数和其他底层语言(如Fortran.C和C++)集成的工具. SciPy SciPy建立在NumPy基础上,它 ...

  8. jenkins实现git自动拉取代码时替换配置文件

    jenkins实现从git上自动拉取源代码——>自动编译——>发布到测试服务器——>验证测试,这个大家应该都知道,但是关于源代码里的配置文件,可能就会有点头疼了, 一般测试都会自己的 ...

  9. fjnu2019第二次友谊赛 B题

    ### 题目链接 ### 题目大意: 给你一个 n * m 的地图以及小蛇蛇头的初始位置,告诉你它会往 上.下.左.右 四个方向走.若在走的过程中(包括结束时)会使得小蛇越界,则输出 "Ga ...

  10. Linux CentOS上安装 MySQL 8.0.16

    前言: 因为我需要在我新安装的Linux CentOS系统服务器中安装和配置MySQL服务器,然而对于我们这种Linux使用小白而言在Linux系统中下载,解压,配置MySQL等一系列的操作还是有些耗 ...