1>建立数据集(并绘制图像)

 # -*- coding: utf-8 -*-
#demo.py
import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable
# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(3) + 0.2*torch.rand(x.size()) #y=x³+b # 变为Variable
x, y = Variable(x), Variable(y) # 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

得到下图:

 

2>建立神经网络

 lass Net(torch.nn.Module):    # 继承 torch 的 Module
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__() # 继承 __init__ 功能
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层线性输出
self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层线性输出 def forward(self, x):
x = F.relu(self.hidden(x)) # 激励函数(隐藏层的线性值)
x = self.predict(x) # 输出值
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # 定义神经网络
print(net) # net 的结构

一般神经网络的类都继承自torch.nn.Module__init__()forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过 self指针来进行访问,所以参数列表中都包含了self

3>训练神经网络(并可视化训练过程)

 # optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2) # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss() # 预测值和真实值的误差计算公式 (均方差) # 训练100次
for t in range(100):
prediction = net(x) # 喂给 net 训练数据 x, 输出预测值 loss = loss_func(prediction, y) #计算两者的误差,一定要是输出在前,标签在后 (1. nn output, 2. target) optimizer.zero_grad() # 清空上一步的残余更新参数值
loss.backward() # 误差反向传播, 计算参数更新值
optimizer.step() # 将参数更新值施加到 net 的 parameters 上 plt.ion() # 画图

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

可视化训练过程:

 for t in range(100):

     ...
loss.backward()
optimizer.step()
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

摘取了几个过程:

 
 

最终结果:

 

参考文档:
基于pytorch的简易卷积神经网络结构搭建-卷积神经网络(CNN)浅析
卷积神经网络CNN入门[pytorch学习]

Pytorch搭建简单神经网络 Task2的更多相关文章

  1. pytorch搭建简单网络

    pytorch搭建一个简单神经网络 import torch import torch.nn as nn # 定义数据 # x:输入数据 # y:标签 x = torch.Tensor([[0.2, ...

  2. Pytorch搭建卷积神经网络用于MNIST分类

    import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvisi ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. python日记:用pytorch搭建一个简单的神经网络

    最近在学习pytorch框架,给大家分享一个最最最最基本的用pytorch搭建神经网络并且训练的方法.本人是第一次写这种分享文章,希望对初学pytorch的朋友有所帮助! 一.任务 首先说下我们要搭建 ...

  5. 用pytorch1.0搭建简单的神经网络:进行多分类分析

    用pytorch1.0搭建简单的神经网络:进行多分类分析 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib ...

  6. 用pytorch1.0搭建简单的神经网络:进行回归分析

    搭建简单的神经网络:进行回归分析 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot as p ...

  7. 如何入门Pytorch之二:如何搭建实用神经网络

    上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...

  8. 用pytorch1.0快速搭建简单的神经网络

    用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...

  9. [炼丹术]使用Pytorch搭建模型的步骤及教程

    使用Pytorch搭建模型的步骤及教程 我们知道,模型有一个特定的生命周期,了解这个为数据集建模和理解 PyTorch API 提供了指导方向.我们可以根据生命周期的每一个步骤进行设计和优化,同时更加 ...

随机推荐

  1. CADisplayLink & NSTimer

    屏幕刷新与UI更新不同步:屏幕刷新由硬件(+GPU)保证,UI更新由软件(CPU保证). 出现卡顿的原因是软件的计算速度跟不上硬件的刷新速度. 一 简介 1 所在框架 CADisplayLink和其它 ...

  2. 机器学习实战笔记--AdaBoost(实例代码)

    #coding=utf-8 from numpy import * def loadSimpleData(): dataMat = matrix([[1. , 2.1], [2. , 1.1], [1 ...

  3. python中方法与函数的区别与联系

    今天huskiesir在对列表进行操作的时候,用到了sorted()函数,偶然情况下在菜鸟教程上看到了内置方法sort,同样都可以实现我对列表的排序操作,那么方法和函数有什么区别和联系呢? 如下是我个 ...

  4. Lvs+heartbeat高可用高性能web站点的搭建

    这是我们公司在实际的生产环境当中使用的一套东西,希望对大家有所帮助(实际的公网ip,我已经做了相应的修改): 说明:每台服务器需要有两块网卡:eth0连接内网的交换机,用私网ip,实现服务器间内部访问 ...

  5. vue父组件引用子组件方法显示undefined问题原因及解决方法

    关于vue父组件引用子组件问题 1.首先导入子组件并且在components中定义子组件 2.引用子组件,并定义ref,ref定义的名称用于 this.$refs所调用的名称 3.调用子组件的方法 ( ...

  6. mysql创建数据库并指定uft8编码

    CREATE DATABASE IF NOT EXISTS ymk default character set utf8 COLLATE utf8_general_ci;

  7. 在JAVA中将class文件编译成jar文件包,运行提示没有主清单属性

    在JAVA中将class文件编译成jar文件包,运行提示没有主清单属性 Maven 项目生成jar运行时提示“没有主清单属性” 新建了一个Maven的项目,mvn compile和mvn packag ...

  8. CC2540/CC2541 : Set the Peripheral Being Advertising while It is Being Connected

    There is possible to set your CC254X be scanable when it is in connection. But, based on my test,the ...

  9. Codeforces Round #286 (Div. 1) B. Mr. Kitayuta's Technology (强连通分量)

    题目地址:http://codeforces.com/contest/506/problem/B 先用强连通判环.然后转化成无向图,找无向图连通块.若一个有n个点的块内有强连通环,那么须要n条边.即正 ...

  10. SSH公钥认证

    一.实验的目的 了解密钥对的创建和使用,掌握免password远程登录和远程操作 二.实验环境 本地主机 rh1: 192.168.233.3/24 远程主机 rh2: 192.168.233.4/2 ...