import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 将1维数据转换成2维数据,torch不能处理1维数据。x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() # 有噪音的抛物线图 class Net(torch.nn.Module): # 输入特征,线性处理进入隐藏层的数据,线性处理进入输出层的数据
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): # 激活一下进入隐藏层的数据
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network 的大小
print(net) # 显示网络结构 net architecture
> Net(
> (hidden): Linear(in_features=1, out_features=10, bias=True)
> (predict): Linear(in_features=10, out_features=1, bias=True)
> )
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 设置优化器优化网络(优化参数,学习率)
loss_func = torch.nn.MSELoss() # 均方差处理回归问题 this is for regression mean squared loss # plt.ion() # something about plotting for t in range(200): # 训练的过程
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # 计算预测值和真实值的误差,预测值在前面,顺序不同可能影响结算结果 must be (1. nn output, 2. target) optimizer.zero_grad() # 梯度重置为零 clear gradients for next train
loss.backward() # 开始这次的反向传递,计算梯度 backpropagation, compute gradients
optimizer.step() # 使用优化器优化梯度,apply gradients if t % 5 == 0:
# 可视化显示训练过程 plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) # plt.ioff()
plt.show()

END

pytorch 4 regression 回归的更多相关文章

  1. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  2. Tensorflow实战第十一课(RNN Regression 回归例子 )

    本节我们会使用RNN来进行回归训练(Regression),会继续使用自己创建的sin曲线预测一条cos曲线. 首先我们需要先确定RNN的各种参数: import tensorflow as tf i ...

  3. 【动手学pytorch】softmax回归

    一.什么是softmax? 有一个数组S,其元素为Si ,那么vi 的softmax值,就是该元素的指数与所有元素指数和的比值.具体公式表示为: softmax回归本质上也是一种对数据的估计 二.交叉 ...

  4. pytorch之 regression

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  5. pytorch神经网络解决回归问题(非常易懂)

    对于pytorch的深度学习框架,在建立人工神经网络时整体的步骤主要有以下四步: 1.载入原始数据 2.构建具体神经网络 3.进行数据的训练 4.数据测试和验证 pytorch神经网络的数据载入,以M ...

  6. 【机器学习】Softmax 和Logistic Regression回归Sigmod

    二分类问题Sigmod 在 logistic 回归中,我们的训练集由  个已标记的样本构成: ,其中输入特征.(我们对符号的约定如下:特征向量  的维度为 ,其中  对应截距项 .) 由于 logis ...

  7. Regression 回归——多项式回归

    回归是指拟合函数的模型.图像等.与分类不同,回归一般是在函数可微的情况下进行的.因为分类它就那么几类,如果把类别看做函数值的话,分类的函数值是离散的,而回归的函数值通常是连续且可微的.所以回归可以通过 ...

  8. [机器学习]回归--Decision Tree Regression

    CART决策树又称分类回归树,当数据集的因变量为连续性数值时,该树算法就是一个回归树,可以用叶节点观察的均值作为预测值:当数据集的因变量为离散型数值时,该树算法就是一个分类树,可以很好的解决分类问题. ...

  9. 浅谈回归(二)——Regression 之历史错误翻译

    我很好奇这个问题,于是搜了一下.我发现 Regression 这个词 本意里有"衰退"的意思. 词根词缀: re- 回 , 向后 + -gress- 步 , 级 + -ion 名词 ...

随机推荐

  1. Java 接口技术 Interface

    一.什么是接口技术(Interface): //举例中Comparable是一个接口,Employee是一个类 1.接口不是类,而是对类的一组描述,并不给出每个类的具体实现. 2.一个类可以实现多个接 ...

  2. 配置监听器 服务器启动时 检索常用数据 保存在application中 减少数据的查询操作(OA项目)

    模型 大致介绍一下:左侧菜单是用户登录成功之后显示的页面  这些数据就是通过查询数据库 然后在页面中把查到的数据  循环遍历出来   构成了操作菜单 第一个解决的问题:常用数据  在服务器启动的时候 ...

  3. RobotFrameWork+APPIUM实现对安卓APK的自动化测试----第七篇【元素定位介绍】

    http://blog.csdn.net/deadgrape/article/details/50628113 我想大家在玩自动化的时候最关心的一定是如何定位元素,因为元素定位不到后面的什么方法都实现 ...

  4. PDF转EPUB格式电子书经验总结

    依据本人将PDF转换为EPUB电子书的经验,总结整理了这篇文章.因本人水平有限,难免有错误和不足之处,望大家及时批评指正.   写这篇文章时,假定读者已经会使用文中所列出软件的基本操作,比方如何用No ...

  5. angularjs1-3,$apply,$watch

    <!DOCTYPE html> <html> <head> <meta http-equiv="Content-Type" content ...

  6. weblogic管理脚本

    start.sh Java代码  #!/usr/bin/bash # # start.sh # @auth: zhoulin@lianchuang.com # SERVER_STATUS () { s ...

  7. DB-MySQL:MySQL 处理重复数据

    ylbtech-DB-MySQL:MySQL 处理重复数据 1.返回顶部 1. MySQL 处理重复数据 有些 MySQL 数据表中可能存在重复的记录,有些情况我们允许重复数据的存在,但有时候我们也需 ...

  8. 关于懒加载中的self.和_

    ---恢复内容开始--- 在开发中,经常会用到懒加载,最常用的如加载一个数组 如图,在这个懒加载数组中有的地方用到了_array有的地方用到了self.array 原因是_array是直接访问,而se ...

  9. caffe study- AlexNet 之算法篇

    在机器学习中,我们通常要考虑的一个问题是如何的“以偏概全”,也就是以有限的样本或者结构去尽可能的逼近全局的分布.这就要在样本以及结构模型上下一些工夫. 在一般的训练任务中,考虑的关键问题之一就是数据分 ...

  10. 网易NAPM Andorid SDK实现原理--转

    原文地址:https://neyoufan.github.io/2017/03/10/android/NAPM%20Android%20SDK/ NAPM 是网易的应用性能管理平台,采用非侵入的方式获 ...