线性回归

# 导入 torch、torch.autograd的Variable模块
import torch
from torch.autograd import Variable

# 生成需要回归需要的tensor
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
y_data = Variable(torch.Tensor([[2.0], [4.0], [6.0]])) #定义模型类,继承自nn包的Module模块
class Model(torch.nn.Module):
# 初始化方法
def __init__(self):
"""
In the constructor we instantiate two nn.Linear module
"""
# 子类转换为父类对象并调用初始化方法
super(Model, self).__init__()
# 线性回归,一个输入,一个输出
self.linear = torch.nn.Linear(1, 1) # One in and one out def forward(self, x):
"""
In the forward function we accept a Variable of input data and we must return
a Variable of output data. We can use Modules defined in the constructor as
well as arbitrary operators on Variables.
"""
# y=K *x 返回结果
y_pred = self.linear(x)
return y_pred # our model # 调用模型类生成模型的对象
model = Model() # Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model. ##使用均方误差函数,在多个样本情况下,不计算平均值
criterion = torch.nn.MSELoss(size_average=False)
##使用随机梯度下降,学习率为0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # Training loop
# 训练 500轮
for epoch in range(500):
# Forward pass: Compute predicted y by passing x to the model
# 前向推理模型
y_pred = model(x_data) # Compute and print loss
# 计算损失
loss = criterion(y_pred, y_data)
print(epoch, loss.data[0]) # Zero gradients, perform a backward pass, and update the weights.
# 由于使用多次训练,每轮结束之后要清零数据
optimizer.zero_grad()
# 反向传播求梯度值
loss.backward()
# 更新所有的参数
optimizer.step() # After training
hour_var = Variable(torch.Tensor([[4.0]]))
y_pred = model(hour_var)
print("predict (after training)", 4, model(hour_var).data[0][0])

PytorchZerotoAll学习笔记(四)--线性回归的更多相关文章

  1. C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻

    前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...

  2. IOS学习笔记(四)之UITextField和UITextView控件学习

    IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...

  3. java之jvm学习笔记四(安全管理器)

    java之jvm学习笔记四(安全管理器) 前面已经简述了java的安全模型的两个组成部分(类装载器,class文件校验器),接下来学习的是java安全模型的另外一个重要组成部分安全管理器. 安全管理器 ...

  4. Learning ROS for Robotics Programming Second Edition学习笔记(四) indigo devices

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  5. Typescript 学习笔记四:回忆ES5 中的类

    中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...

  6. ES6学习笔记<四> default、rest、Multi-line Strings

    default 参数默认值 在实际开发 有时需要给一些参数默认值. 在ES6之前一般都这么处理参数默认值 function add(val_1,val_2){ val_1 = val_1 || 10; ...

  7. muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制

    目录 muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制 eventfd的使用 eventfd系统函数 使用示例 EventLoop对eventfd的封装 工作时序 runInLoo ...

  8. python3.4学习笔记(四) 3.x和2.x的区别,持续更新

    python3.4学习笔记(四) 3.x和2.x的区别 在2.x中:print html,3.x中必须改成:print(html) import urllib2ImportError: No modu ...

  9. Go语言学习笔记四: 运算符

    Go语言学习笔记四: 运算符 这章知识好无聊呀,本来想跨过去,但没准有初学者要学,还是写写吧. 运算符种类 与你预期的一样,Go的特点就是啥都有,爱用哪个用哪个,所以市面上的运算符基本都有. 算术运算 ...

  10. 零拷贝详解 Java NIO学习笔记四(零拷贝详解)

    转 https://blog.csdn.net/u013096088/article/details/79122671 Java NIO学习笔记四(零拷贝详解) 2018年01月21日 20:20:5 ...

随机推荐

  1. git删除本地保存的账号和密码

    使用git在本地拉过一次代码时候git会自动将用户名密码保存到本地. 导致想用别的用户名和密码拉代码时没有权限,这时需要删除或者修改git在本地保存的账户名和密码. 具体办法如下: 1.控制面板--& ...

  2. Oracle 安全性一

    创建和管理数据库用户账户 用户账户属性 用户账户拥有很多在创建账户时定义的属性.这些属性将应用于连接到账户的会话,在会话运行期间,DBA或会话可以更改其中一些属性. 用户名 身份验证方法 默认表空间 ...

  3. Java实现目的选层电梯的调度

    一.前言 本次博客我将简单介绍一下前两次的电梯作业,并简单解析一下我的程序结构,进一步对我的第二次作业的算法核心和一些想法做一些分享,我的电梯设计算法并不是由调度器来决定电梯的捎带与否,而是由电梯自主 ...

  4. HTML中放置CSS的三种方式和CSS选择器

    (一)在HTML中使用CSS样式的方式一般有三种: 1 内联引用 2 内部引用 3 外部引用.   第一种:内联引用(也叫行内引用) 就是把CSS样式直接作用在HTML标签中. <p style ...

  5. CentOS6安装各种大数据软件 第九章:Hue大数据可视化工具安装和配置

    相关文章链接 CentOS6安装各种大数据软件 第一章:各个软件版本介绍 CentOS6安装各种大数据软件 第二章:Linux各个软件启动命令 CentOS6安装各种大数据软件 第三章:Linux基础 ...

  6. -L -Wl,-rpath-link -Wl,-rpath区别精讲

    目录 前言 源码准备 源码内容 尝试编译,保证源码没有问题 编译 首先编译world.c 编译并链接hello.c 调试编译test.c 结论 转载请注明出处,谢谢 https://www.cnblo ...

  7. 树莓派3B+学习笔记:5、安装vim

    以下操作使用root账户登陆. 1.在终端中输入 apt-get install vim 输入“y”,回车: 2.等一下,安装完成: 3.用vim新建一个文本文件测试一下,在终端重输入 vim tes ...

  8. Go黑帽子

    使用go语言来实现python黑帽子和绝技的代码 1.unix密码破解器 package main import( "bufio" "flag" "i ...

  9. C语言可变参数函数详解示例

    先看代码 printf(“hello,world!”);其参数个数为1个. printf(“a=%d,b=%s,c=%c”,a,b,c);其参数个数为4个. 如何编写可变参数函数呢?我们首先来看看pr ...

  10. Java - JavaMail - 利用 JavaMail 发邮件的 小demo

    1. 概述 面试的时候, 被问到一些乱七八糟的运维知识 虽然我不是干运维的, 但是最后却告诉我专业知识深度不够, 感觉很难受 又回到了一个烦人的问题 工作没有深度的情况下, 你该如何的提升自己, 并且 ...