一、回归任务介绍:

拟合一个二元函数 y = x ^ 2.

二、步骤:

  1. 导入包
  2. 创建数据
  3. 构建网络
  4. 设置优化器和损失函数
  5. 前向和后向传播训练网络
  6. 画图

三、代码:

导入包:

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

创建数据

#torch中的数据要是二维的,unsqueeze是将一维数据转化成二维数据
tmp = torch.linspace(-1,1,100)
x = torch.unsqueeze(tmp,dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size()) print(tmp) #torch.Size([100])
print(x) #torch.Size([100, 1])
#转成向量
x,y = Variable(x),Variable(y)

  查看数据图像:

plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

构建网络

#Net类继承了Module这个模块
class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_output):
#在搭建模型之前需要继承的一些信息,super表示继承nn.Module的信息,此步骤必须有
super(Net,self).__init__()
self.hidden = torch.nn.Linear(n_feature,n_hidden)
self.predict = torch.nn.Linear(n_hidden,n_output)
#神经网络前向传递的一个过程,流程图
def forward(self,x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
net = Net(1,10,1)
plt.ion()
plt.show()
#可以看到搭建的图流程
print(net)
 打印的结果:
Net(
(hidden): Linear(in_features=1, out_features=10, bias=True)
(predict): Linear(in_features=10, out_features=1, bias=True)
)

设置优化器和损失函数

optimizer = torch.optim.SGD(net.parameters(),lr = 0.5)  #传入网络的参数来优化它们
loss_func = torch.nn.MSELoss()

前向和后向传播训练网络

for t in range(100):

    #forward
prediction = net(x)
loss = loss_func(prediction,y) #预测值pre在前,实际值y在后,不然结果会不一样 #backward()
optimizer.zero_grad() #梯度全部设为0
loss.backward() #loss计算参数的梯度
optimizer.step() #采用优化器以lr=0.5来优化梯度 ###########################以下为可视化过程##################################
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
plt.text(0.5,0,'Loss=%.4f' % loss.data[0],fontdict={'size':20,'color':'red'})
plt.pause(0.1)
plt.ioff()
plt.show()

训练结果:

第一次:

最后一次:

pytorch实战(2)-----回归例子的更多相关文章

  1. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  2. [机器学习实战-Logistic回归]使用Logistic回归预测各种实例

    目录 本实验代码已经传到gitee上,请点击查收! 一.实验目的 二.实验内容与设计思想 实验内容 设计思想 三.实验使用环境 四.实验步骤和调试过程 4.1 基于Logistic回归和Sigmoid ...

  3. linuxC编程实战 my_server.c例子问题总结

    今天看linux C 编程实战的my_server例子时,敲到这段代码,对其父子进程关闭socket 进行close调用产生疑问 如图中标注的三个close socket,思考子进程通信结束 关闭自己 ...

  4. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  5. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  6. Tensorflow实战第十一课(RNN Regression 回归例子 )

    本节我们会使用RNN来进行回归训练(Regression),会继续使用自己创建的sin曲线预测一条cos曲线. 首先我们需要先确定RNN的各种参数: import tensorflow as tf i ...

  7. Kaggle实战之一回归问题

    0. 前言 1.任务描述 2.数据概览 3. 数据准备 4. 模型训练 5. kaggle实战 0. 前言 "尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手头 90% 的机器学习问题 ...

  8. 【动手学pytorch】softmax回归

    一.什么是softmax? 有一个数组S,其元素为Si ,那么vi 的softmax值,就是该元素的指数与所有元素指数和的比值.具体公式表示为: softmax回归本质上也是一种对数据的估计 二.交叉 ...

  9. pytorch实战(一)hw1——李宏毅老师作业1

    任务描述:利用前9小时数据,预测第10小时的pm2.5的数值,回归任务 kaggle地址:https://www.kaggle.com/c/ml2020spring-hw1 训练集为: 12个月*20 ...

随机推荐

  1. 10分钟 PySimpleGUI 图形界面入门

    import PySimpleGUI as sg layout = [ [sg.Text('Enter a Number')], [sg.Input()], [sg.OK()] ] event,(nu ...

  2. C++基础 (6) 第六天 继承 虚函数 虚继承 多态 虚函数

    继承是一种耦合度很强的关系 和父类代码很多都重复的 2 继承的概念 3 继承的概念和推演 语法: class 派生类:访问修饰符 基类 代码: … … 4 继承方式与访问控制权限 相对的说法: 爹派生 ...

  3. 40 最小的K个数(时间效率)

    题目描述: 输入n个整数,找出其中最小的K个数.例如输入4,5,1,6,2,7,3,8这8个数字,则最小的4个数字是1,2,3,4,.   测试用例: 功能测试(输入的数组中有相同的数字:输入的数组中 ...

  4. 用Js写贪吃蛇

    使用Javascript做贪吃蛇小游戏, 1.自定义地图宽高,蛇的初始速度 2.食物随机出现 3.蛇的样式属性 4.贪吃蛇玩法(吃食物,碰到边界,吃食物后加速,计分,) <!DOCTYPE ht ...

  5. SpringBoot-CommandLineRunner实现预操作

    前提:在使用SpringBoot构建项目时,我们通常需要做一些预先操作(类似开机自启动).而SpringBoot正好提供了一个简单的方式来实现–CommandLineRunner. CommandLi ...

  6. SpringMVC-HandlerMapping和HandlerAdapter

    网上介绍HandlerMapping和HandlerAdapter的文章很多,今天我用自己的理解和语言来介绍下HandlerMapping和HandlerAdapter 一. HandlerMappi ...

  7. 敏捷开发-srcum

    SCRUM框架包括3个角色.3个工件.5个活动.5个价值 3个角色 1.产品负责人(Product Owner) 2.Scrum Master 3.Scrum团队 3个工具 1.Product Bac ...

  8. [using_microsoft_infopath_2010]Chapter12 管理监视InfoPath表单服务

    本章概要: 1.在SharePoint中心控制台管理InfoPath设置 2.分析监视浏览器表单开考虑潜在性能问题 3.最小化回发数据

  9. 组件的使用(三)AutoCompleteTextView的使用

    AutoCompleteTextView经常使用的属性: android:completionHint 下拉列表以下的说明性文字 android:completionThreshold 弹出下来列表的 ...

  10. 初探swift语言的学习笔记十(block)

    作者:fengsh998 原文地址:http://blog.csdn.net/fengsh998/article/details/35783341 转载请注明出处 假设觉得文章对你有所帮助,请通过留言 ...