假定我们要拟合的线性方程是:\(y=2x+1\)

\(x\):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

\(y\):[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt '''生成输入输出'''
x_values = [i for i in range(15)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1,1) y_values = [2*i+1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1,1) '''定义模型'''
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegressionModel,self).__init__() #用nn.Module的init方法
self.linear = nn.Linear(input_dim, output_dim) #因为我们假设的函数是线性函数 def forward(self, x):
out = self.linear(x)
return out ''''''
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)
criterion = nn.MSELoss() #损失函数为均方差 learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) '''训练网络'''
epochs = 30
for epoch in range(epochs):
epoch += 1
inputs = Variable(torch.from_numpy(x_train))
labels = Variable(torch.from_numpy(y_train))
#清空梯度参数
optimizer.zero_grad()
#获得输出
outputs = model(inputs)
#计算损失
loss = criterion(outputs, labels)
#反向传播
loss.backward()
#更新参数
optimizer.step() print('epoch {}, loss {}'.format(epoch, loss.data[0]))

输出如下

epoch 1, loss 290.4517517089844
epoch 2, loss 39.308494567871094
epoch 3, loss 5.320824146270752
epoch 4, loss 0.721196711063385
epoch 5, loss 0.09870971739292145
epoch 6, loss 0.01445594523102045
epoch 7, loss 0.003041634801775217
epoch 8, loss 0.0014851536834612489
epoch 9, loss 0.0012628223048523068
epoch 10, loss 0.0012211636640131474
epoch 11, loss 0.0012040861183777452
epoch 12, loss 0.0011904657585546374
epoch 13, loss 0.001177445170469582
epoch 14, loss 0.0011646103812381625
epoch 15, loss 0.0011519324034452438
epoch 16, loss 0.0011393941240385175
epoch 17, loss 0.0011269855313003063
epoch 18, loss 0.0011147174518555403
epoch 19, loss 0.001102585345506668
epoch 20, loss 0.001090570935048163
epoch 21, loss 0.0010787042556330562
epoch 22, loss 0.0010669684270396829
epoch 23, loss 0.0010553498286753893
epoch 24, loss 0.001043855445459485
epoch 25, loss 0.0010324924951419234
epoch 26, loss 0.0010212488705292344
epoch 27, loss 0.0010101287625730038
epoch 28, loss 0.000999127165414393
epoch 29, loss 0.0009882354643195868
epoch 30, loss 0.0009774940554052591
#可以看出loss逐步缩小

画图观察

predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()

plt.clf()
plt.plot(x_train, y_train, 'go', label="True Value", alpha=0.5) plt.plot(x_train, predicted, '--', label='Predictions',alpha=0.5) plt.legend(loc='best')
plt.show()

图如下:

用Pytorch训练线性回归模型的更多相关文章

  1. tensorflow训练线性回归模型

    tensorflow安装 tensorflow安装过程不是很顺利,在这里记录一下 环境:Ubuntu 安装 sudo pip install tensorflow 如果出现错误 Could not f ...

  2. 1.1Tensorflow训练线性回归模型入门程序

    tensorflow #-*- coding: utf-8 -*- # @Time : 2017/12/19 14:36 # @Author : Z # @Email : S # @File : 1. ...

  3. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  4. 从头学pytorch(三) 线性回归

    关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...

  5. 03_利用pytorch解决线性回归问题

    03_利用pytorch解决线性回归问题 目录 一.引言 二.利用torch解决线性回归问题 2.1 定义x和y 2.2 自定制线性回归模型类 2.3 指定gpu或者cpu 2.4 设置参数 2.5 ...

  6. 【scikit-learn】scikit-learn的线性回归模型

     内容概要 怎样使用pandas读入数据 怎样使用seaborn进行数据的可视化 scikit-learn的线性回归模型和用法 线性回归模型的评估測度 特征选择的方法 作为有监督学习,分类问题是预 ...

  7. R语言解读多元线性回归模型

    转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...

  8. PocketSphinx语音识别系统语言模型的训练和声学模型的改进

    PocketSphinx语音识别系统语言模型的训练和声学模型的改进 zouxy09@qq.com http://blog.csdn.net/zouxy09 关于语音识别的基础知识和sphinx的知识, ...

  9. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

随机推荐

  1. Setup script exited with error: command 'x86_64-linux-gnu-gcc' failed with exit status 1 解决办法

    今天在Ubuntu16.04 上安装python包的时候,出现了这个坑爹的问题: 解决办法,内容总结如下 情况是这样,报错是因为没有把依赖包安装全,报错情况如下图: 解决办法,先安装一些必须的依赖: ...

  2. MAC MAMP 中安装配置使用 ThinkPHP

    MAMP PRO 是Mac OS X 平台上经典的本地环境应用 MAMP 的专业版.专门为专业的Web开发人员和程序员轻松地安装和管理自己的开发环境. MAMP这几个首字母代表Mac OS X系统上的 ...

  3. 从0开始的Python学习003序列

    sequence 序列 序列是一组有顺序数据的集合.不知道怎么说明更贴切,因为python的创建变量是不用定义类型,所以在序列中(因为有序我先把它看作是一个有序数组)的元素也不会被类型限制. 序列可以 ...

  4. c/c++ 网络编程 UDP 用if_nameindex和ioctl取得主机网络信息

    网络编程 UDP 用if_nameindex和ioctl取得主机网络信息 getifaddrs函数取得的东西太多了,如果只想取得网卡名字和网卡编号可以用下面的2个函数. 1,if_nameindex ...

  5. XCopy 小技巧

    使用XCOPY Copy 一个文件时,如果目标地址没有对应的文件, 系统会提示选择是文件,还是目录,如下图所示. 有时我们不想出现这个提示,这是只需要修改目标文件的写法.如下 将 "D:\t ...

  6. DP思想笔记

    一.思想 DP也是把复杂的问题分解为许多子问题,与分治法不同的是,分治法的各个子问题互相之间没有联系,而动态规划却有.前一个子问题的结果与下一步的子问题的结果是什么有关系.这就决定了DP算法肯定有一个 ...

  7. uml类图关系

    原文地址http://www.jfox.info/uml-lei-tu-guan-xi-fan-hua-ji-cheng-shi-xian-yi-lai-guan-lian-ju-he-zu-he 在 ...

  8. 初学Kafka工作原理流程介绍

    Apache kafka 工作原理介绍 消息队列技术是分布式应用间交换信息的一种技术.消息队列可驻留在内存或磁盘上, 队列存储消息直到它们被应用程序读走.通过消息队列,应用程序可独立地执行--它们不需 ...

  9. C# 对文本文件的几种读写方法总结

    计算机在最初只支持ASCII编码,但是后来为了支持其他语言中的字符(比如汉字)以及一些特殊字符(比如€),就引入了Unicode字符集.基于Unicode字符集的编码方式有很多,比如UTF-7.UTF ...

  10. [P1169] 棋盘制作 &悬线法学习笔记

    学习笔记 悬线法 最大子矩阵问题: 在一个给定的矩形中有一些障碍点,找出内部不包含障碍点的,边与整个矩形平行或重合的最大子矩形. 极大子矩型:无法再向外拓展的有效子矩形 最大子矩型:最大的一个有效子矩 ...