【pytorch】学习笔记(四)-搭建神经网络进行关系拟合
【pytorch学习笔记】-搭建神经网络进行关系拟合
目标
1.创建一些围绕y=x^2+噪声这个函数的散点
2.用神经网络模型来建立一个可以代表他们关系的线条
建立数据集
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维,x从-1到1,切分为100份
y=x.pow(2)+0.2*torch.rand(x.size())#创建一些围绕着这y=x^2的随机点的散点
# plt.scatter(x.data.numpy(),y.data.numpy())#画图
# plt.show()
x,y=Variable(x),Variable(y)#构造神经网络要使用Variable类型
建立神经网络
1.继承torch.nn.Module模块
2.定义__init__函数,在初始化函数中定义输入层到隐藏层,从隐藏层再到输出层各个层的神经元个数
3.再一层层搭建(forward(x))层于层的关系链接
class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
super(Net, self).__init__()
self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出
def forward(self,x):#前向传递的过程
#正向传播输入值,神经网络输出预测值
x=F.relu(self.hidden(x))#激励函数加工一下
x=self.predict(x)#输出值预测值
return x
训练神经网络
1.定义训练工具optimizer,输入神经网络参数和学习效率
2.定义误差函数,使用均方差来计算实际值y和训练输出值之间的误差
3.每次训练向神经网络输入x,得到预测值,计算误差
4.注意要清空上一步的残余更新参数值
5.误差反向传播, 计算参数更新值
6.将参数更新值施加到 net 的 parameters 上
for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度
可视化训练过程
for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度
# 接着上面来
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
完整代码
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维
y=x.pow(2)+0.2*torch.rand(x.size())
# plt.scatter(x.data.numpy(),y.data.numpy())
# plt.show()
x,y=Variable(x),Variable(y)#构造神经网络的是琥珀要使用Variable类型的
class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
super(Net, self).__init__()
self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出
def forward(self,x):#前向传递的过程
#正向传播输入值,神经网络输出预测值
x=F.relu(self.hidden(x))#激励函数加工一下
x=self.predict(x)#输出值预测值
return x
net=Net(n_feature=1,n_hidden=10,n_ouput=1)#输入值是一个,隐藏层有10个神经元,输出值为y值
print(net)
optimizer=torch.optim.SGD(net.parameters(),lr=0.5)#输入神经网络的所有参数,学习效率,这个是训练工具
loss_func=torch.nn.MSELoss()#误差处理均方差
plt.ion() # 画图
plt.show()
for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度
# 接着上面来
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
过程结果






中间过程省略一部分...

【pytorch】学习笔记(四)-搭建神经网络进行关系拟合的更多相关文章
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试 刚开始学习tf时,我们从 ...
- Go语言学习笔记四: 运算符
Go语言学习笔记四: 运算符 这章知识好无聊呀,本来想跨过去,但没准有初学者要学,还是写写吧. 运算符种类 与你预期的一样,Go的特点就是啥都有,爱用哪个用哪个,所以市面上的运算符基本都有. 算术运算 ...
- kvm虚拟化学习笔记(四)之kvm虚拟机日常管理与配置
KVM虚拟化学习笔记系列文章列表----------------------------------------kvm虚拟化学习笔记(一)之kvm虚拟化环境安装http://koumm.blog.51 ...
- MySql学习笔记四
MySql学习笔记四 5.3.数据类型 数值型 整型 小数 定点数 浮点数 字符型 较短的文本:char, varchar 较长的文本:text, blob(较长的二进制数据) 日期型 原则:所选择类 ...
- 官网实例详解-目录和实例简介-keras学习笔记四
官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras 版权声明: ...
- ZooKeeper学习笔记四:使用ZooKeeper实现一个简单的分布式锁
作者:Grey 原文地址: ZooKeeper学习笔记四:使用ZooKeeper实现一个简单的分布式锁 前置知识 完成ZooKeeper集群搭建以及熟悉ZooKeeperAPI基本使用 需求 当多个进 ...
- C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻
前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...
- IOS学习笔记(四)之UITextField和UITextView控件学习
IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...
随机推荐
- HGOI 20191029am 题解
Promblem A 小G的字符串 给定$n,k$,构造一个长度为$n$,只能使用$k$种小写字母的字符串. 要求相邻字符不能相同且$k$种字母都要出现 输出字典序最小的字符串,无解输出$-1$. 对 ...
- 图论小专题A
大意失荆州.今天考试一到可以用Dijkstra水过的题目我竟然没有做出来,这说明基础还是相当重要.考虑到我连Tarjan算法都不太记得了,我决定再过一遍蓝皮书,对图论做一个小的总结.图论这个部分可能会 ...
- TensorFlow使用记录 (十): Pretraining
上一篇的模型保存和恢复熟练后,我们就可以大量使用 pretrain model 来训练任务了 Tweaking, Dropping, or Replacing the Upper Layers The ...
- Python基础之Python语言类型
编程语言主要从以下几个角度进行分类: 编译型和解释型 静态语言和动态语言 强类型定义语言和弱类型定义语言 编译和解释的区别是什么? 编译器把源程序的每一条语句都编译成机器语言,并保存成二进制文件,这样 ...
- Hive使用与安装步骤
1.Hive安装与配置 Hive官网:https://hive.apache.org/ 1. 安装文件下载 从Apache官网下载安装文件 http://mirror.bit.edu.cn/apach ...
- From 7.8 To 7.14
From 7.8 To 7.14 大纲 学科 英语的话每天早上背单词, 争取每天做一篇完型, 一篇阅读, 一篇短文填空, 一篇改错, 一篇七选五??? 似乎太多了, 先试一下吧 语文的话, 尝试翻译一 ...
- CLion配置Cygwin环境
CLion "download" 跳转到 https://cygwin.com/install.html 下载64位安装程序并安装 国内添加网易镜像 http://mirrors. ...
- 2018-2019-2 20165215《网络对抗技术》Exp9 :Web安全基础
目录 实验目的及内容 实验过程记录 一.Webgoat安装 二. 注入缺陷(Injection Flaws) (一)命令注入(Command Injection) (二)数字型注入(Numeric S ...
- pure-ftpd搭建简单的Ubuntu FTP服务器
Linux下的ftpd很多,Ubuntu下常用vsftpd, proftpd和pure-ftpd,当初使用的就是proftpd. 不过前两者有个致命的问题就是内码转换,它们默认使用UTF-8编码,而W ...
- win10备忘
你要允许来自未知发布者 http://www.xitonghe.com/jiaocheng/Windows10-7809.html输入法 切换繁体 ctrl+shift+F win10 输入法 htt ...