import torch
import matplotlib.pyplot as plt
learning_rate = 0.1 #准备数据 #y = 3x +0.8
x = torch.randn([500,1])
y_true = 3*x + 0.8 #计算预测值
w = torch.rand([],requires_grad=True)
b = torch.tensor(0,dtype=torch.float,requires_grad=True) for i in range(50):
#梯度默认会累加,梯度手动清零
for j in [w,b]:
if j.grad is not None:
j.grad.data.zero_()
y_predict = x*w +b
#计算损失
loss = (y_predict-y_true).pow(2).mean()
loss.backward()
#更新参数
w.data = w.data - learning_rate * w.grad
b.data = b.data - learning_rate * b.grad
print(i,loss.item())
print(w.data,b.data) plt.figure(figsize=(20,8))
plt.scatter(x.numpy(),y_true.numpy()) y_predict = x*w + b
plt.plot(x.numpy(),y_predict.detach().numpy(),c="red")
plt.show()

  

pytorch实现手动线性回归的更多相关文章

  1. Pytorch手写线性回归

    pytorch手写线性回归 import torch import matplotlib.pyplot as plt from matplotlib.animation import FuncAnim ...

  2. Pytorch 实现简单线性回归

    Pytorch 实现简单线性回归 问题描述: 使用 pytorch 实现一个简单的线性回归. 受教育年薪与收入数据集 单变量线性回归 单变量线性回归算法(比如,$x$ 代表学历,$f(x)$ 代表收入 ...

  3. python pytorch numpy DNN 线性回归模型

    1.直接奉献代码,后期有入门更新,之前一直在学的是TensorFlow, import torch from torch.autograd import Variable import torch.n ...

  4. 【深度学习 01】线性回归+PyTorch实现

    1. 线性回归 1.1 线性模型 当输入包含d个特征,预测结果表示为: 记x为样本的特征向量,w为权重向量,上式可表示为: 对于含有n个样本的数据集,可用X来表示n个样本的特征集合,其中行代表样本,列 ...

  5. 分别基于TensorFlow、PyTorch、Keras的深度学习动手练习项目

    ×下面资源个人全都跑了一遍,不会出现仅是字符而无法运行的状况,运行环境: Geoffrey Hinton在多次访谈中讲到深度学习研究人员不要仅仅只停留在理论上,要多编程.个人在学习中也体会到单单的看理 ...

  6. 『科学计算』通过代码理解线性回归&Logistic回归模型

    sklearn线性回归模型 import numpy as np import matplotlib.pyplot as plt from sklearn import linear_model de ...

  7. [Pytorch]PyTorch使用tensorboardX(转

    文章来源: https://zhuanlan.zhihu.com/p/35675109 https://www.aiuai.cn/aifarm646.html 之前用pytorch是手动记录数据做图, ...

  8. [开发技巧]·TensorFlow&Keras GPU使用技巧

    [开发技巧]·TensorFlow&Keras GPU使用技巧 ​ 1.问题描述 在使用TensorFlow&Keras通过GPU进行加速训练时,有时在训练一个任务的时候需要去测试结果 ...

  9. [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

    [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 目录 [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 0x00 摘要 0x01 概述 1.1 什么是GPip ...

随机推荐

  1. 《闲扯Redis三》Redis五种数据类型之List型

    一.前言 Redis 提供了5种数据类型:String(字符串).Hash(哈希).List(列表).Set(集合).Zset(有序集合),理解每种数据类型的特点对于redis的开发和运维非常重要. ...

  2. HDU - 1622 用到了层次遍历

    题意: 给出一些字符串,由(,)包着,表示一个节点,逗号前面是节点的值,后面是节点从根节点走向自己位置的路线,输入以( )结尾,如果这组数据可以构成一个完整的树,输出层次遍历结果,否则输出not co ...

  3. Array.forEach原理,仿造一个类似功能

    Array.forEach原理,仿造一个类似功能 array.forEach // 设一个arr数组 let arr = [12,45,78,165,68,124]; let sum = 0; // ...

  4. IEnumerable和IQueryable在使用时的区别

    最近在调研数据库查询时因使用IEnumerable进行Linq to entity的操作,造成数据库访问缓慢.此文讲述的便是IEnumerable和IQueryable的区别. 微软对IEnumera ...

  5. JVM系列十(虚拟机性能监控神器 - BTrace).

    BTrace 是什么? BTrace 是一个动态安全的 Java 追踪工具,它通过向运行中的 Java 程序植入字节码文件,来对运行中的 Java 程序热更新,方便的获取程序运行时的数据信息,并且,保 ...

  6. mysql搭建主从复制(一主一从,双主双从)

    主从复制原理 Mysql 中有一个binlog 二进制日志,这个日志会记录下所有修改了的SQL 语句,从服务器把主服务器上的binlog二进制日志在指定的位置开始复制主服务器所进行修改的语句到从服务器 ...

  7. scratch 如何改变变量的作用域

    在新建变量的时候,有个选项是“适用于所有角色”还是“仅适用于当前角色”.通常称前者为全局变量,所有角色都可以访问到这个变量:后者,称为局部变量,只能在当前角色里访问到这个变量.例如,在使用克隆功能时, ...

  8. 用curl调用https接口

    今天在windows下用curl类获取微信token一直返回false,查阅资料后,发现是https证书的锅,在curl类中加上这两条,问题瞬间解决. curl_setopt($ch, CURLOPT ...

  9. Linux如何配制Tcl编程环境

    首先,打开终端. 接着在终端输入以下命令: sudo apt-get install tcl

  10. 第一天总结(while计数器+成绩大小+获取时间+猜拳大小)

    #*_* coding:utf-8 *_*# while 先有一个计数器 input = 0# input = input('输入数字')while input < 5: input= inpu ...