记得在tensorflow的入门里,介绍梯度下降算法的有效性时使用的例子求一个二次曲线的最小值。

这里使用pytorch复现如下:

1、手动计算导数,按照梯度下降计算

import torch

#使用梯度下降法求y=x^2+2x+1 最小值 从x=3开始

x=torch.Tensor([3])
for epoch in range(100):
y=x**2+2*x+1 x-=(2*x+2)*0.05 #导数是 2*x+2 print('min y={1:.2}, x={0:.2}'.format(x[0],y[0]))
min y=1.6e+01, x=2.6
min y=1.3e+01, x=2.2
min y=1e+01, x=1.9
...
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0

2、使用torch的autograd计算

import torch
from torch.autograd import Variable
#使用梯度下降法求y=x^2+2x+1 最小值 从x=3开始 x=Variable(torch.Tensor([3]),requires_grad=True)
for epoch in range(100):
y=x**2+2*x+1
y.backward()
x.data-=x.grad.data*0.05
x.grad.data.zero_()
print('min y={1:.2}, x={0:.2}'.format(x.data[0],y.data[0])) min y=1.6e+01, x=2.6
min y=1.3e+01, x=2.2
min y=1e+01, x=1.9
...
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0

下边来实验下使用梯度下降法求解直线回归问题,也就是最小二乘法的梯度下降求解(实际上回归问题的最优方式解 广义逆矩阵和值的乘积)

#最小二乘法 拟合y=3x+1
n=100
x=torch.rand((n))
y=x*3+1+torch.rand(n)/5 #y=3x+1
k=Variable(torch.Tensor([1]),requires_grad=True)
b=Variable(torch.Tensor([0]),requires_grad=True) for epoch in range(100):
l=torch.sum((k*x+b-y)**2)/100 #MSE 最小二乘法 加上随即噪声
l.backward()
k.data-=k.grad.data*0.3
b.data-=b.grad.data*0.3
print("k={:.2},b={:.2},l={:.2}".format(k.data[0],b.data[0],l.data))
k.grad.data.zero_()
b.grad.data.zero_()
k=1.7,b=1.3,l=4.7
k=1.9,b=1.5,l=0.37
k=2.0,b=1.6,l=0.11
k=2.1,b=1.6,l=0.088
k=2.1,b=1.6,l=0.081
k=2.1,b=1.6,l=0.075
...
k=3.0,b=1.1,l=0.0033
k=3.0,b=1.1,l=0.0033
k=3.0,b=1.1,l=0.0033

同样也可以使用torch里内置的mseloss

#最小二乘法 拟合y=3x+1
n=100
x=torch.rand((n))
y=x*3+1+torch.rand(n)/5 #y=3x+1 加上随机噪声
k=Variable(torch.Tensor([1]),requires_grad=True)
b=Variable(torch.Tensor([0]),requires_grad=True)
loss=torch.nn.MSELoss()
for epoch in range(100):
l=loss(k*x+b,y) #MSE 最小二乘法
l.backward()
k.data-=k.grad.data*0.3
b.data-=b.grad.data*0.3
print("k={:.2},b={:.2},l={:.2}".format(k.data[0],b.data[0],l.data))
k.grad.data.zero_()
b.grad.data.zero_()
k=1.7,b=1.3,l=4.7
k=1.9,b=1.6,l=0.35
k=2.0,b=1.6,l=0.09
...
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
 

备注:新版本的torch里把torch.Variable 废除了,合并到torch.Tensor里了,好消息。数据类型统一了。原文:https://pytorch.org/docs/stable/autograd.html

Variable (deprecated)

The Variable API has been deprecated: Variables are no longer necessary to use autograd with tensors.

梯度下降与pytorch的更多相关文章

  1. 梯度下降优化算法综述与PyTorch实现源码剖析

    现代的机器学习系统均利用大量的数据,利用梯度下降算法或者相关的变体进行训练.传统上,最早出现的优化算法是SGD,之后又陆续出现了AdaGrad.RMSprop.ADAM等变体,那么这些算法之间又有哪些 ...

  2. [深度学习] pytorch学习笔记(2)(梯度、梯度下降、凸函数、鞍点、激活函数、Loss函数、交叉熵、Mnist分类实现、GPU)

    一.梯度 导数是对某个自变量求导,得到一个标量. 偏微分是在多元函数中对某一个自变量求偏导(将其他自变量看成常数). 梯度指对所有自变量分别求偏导,然后组合成一个向量,所以梯度是向量,有方向和大小. ...

  3. 梯度下降(Gradient Descent)小结

    在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法.这里就对梯度下降法做一个完整的总结. 1. 梯度 在微 ...

  4. 从梯度下降到Fista

    前言: FISTA(A fast iterative shrinkage-thresholding algorithm)是一种快速的迭代阈值收缩算法(ISTA).FISTA和ISTA都是基于梯度下降的 ...

  5. 线性回归、梯度下降(Linear Regression、Gradient Descent)

    转载请注明出自BYRans博客:http://www.cnblogs.com/BYRans/ 实例 首先举个例子,假设我们有一个二手房交易记录的数据集,已知房屋面积.卧室数量和房屋的交易价格,如下表: ...

  6. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比[转]

    梯度下降(GD)是最小化风险函数.损失函数的一种常用方法,随机梯度下降和批量梯度下降是两种迭代求解思路,下面从公式和实现的角度对两者进行分析,如有哪个方面写的不对,希望网友纠正. 下面的h(x)是要拟 ...

  7. 为什么是梯度下降?SGD

    在机器学习算法中,为了优化损失函数loss function ,我们往往采用梯度下降算法来进行优化.举个例子: 线性SVM的得分函数和损失函数分别为:                         ...

  8. Stanford大学机器学习公开课(二):监督学习应用与梯度下降

    本课内容: 1.线性回归 2.梯度下降 3.正规方程组   监督学习:告诉算法每个样本的正确答案,学习后的算法对新的输入也能输入正确的答案   1.线性回归 问题引入:假设有一房屋销售的数据如下: 引 ...

  9. Matlab梯度下降解决评分矩阵分解

    for iter = 1:num_iters %梯度下降 用户向量 for i = 1:m %返回有0有1 是逻辑值 ratedIndex1 = R_training(i,:)~=0 ; %U(i,: ...

随机推荐

  1. 1.横向滚动条,要设置两个div包裹. 2. 点击切换视频或者图片. overflow . overflow-x

    1.横向滚动条. div.1 > div.2 > img img  img 第一: 设置 div.1 一个固定的宽度 和高度  . 例如宽度 700px;  高度是 120px; 设置 o ...

  2. crontab的定时任务实例

    实例1:每1分钟执行一次myCommand * * * * * myCommand 实例2:每小时的第3和第15分钟执行 3,15 * * * * myCommand 实例3:在上午8点到11点的第3 ...

  3. eclipse 设置Java快捷键补全

    打开Eclipse,点击Window--Preferences--Java--Editor--ContentAssist Auto Activation 勾选Enable auto activatio ...

  4. 【Python】多线程-3

    #练习:线程等待 Event e.set() e.wait()   from threading import Thread, Lock import threading import time   ...

  5. 广播多路访问链路上的OSPF

    实验要求:配置OSPF 拓扑如下: 配置如下: R1 enableconfigure terminal interface l0ip address 192.168.10.1 255.255.255. ...

  6. python基础之socket与socketserver

    ---引入 Socket的英文原义是“孔”或“插座”,在Unix的进程通信机制中又称为‘套接字’.套接字实际上并不复杂,它是由一个ip地址以及一个端口号组成.Socket正如其英文原意那样,像一个多孔 ...

  7. Spring Boot 揭秘与实战(二) 数据存储篇 - 数据访问与多数据源配置

    文章目录 1. 环境依赖 2. 数据源 3. 单元测试 4. 源代码 在某些场景下,我们可能会在一个应用中需要依赖和访问多个数据源,例如针对于 MySQL 的分库场景.因此,我们需要配置多个数据源. ...

  8. 2.34 jquery定位

    2.34 jquery定位(简直逆天) 前言元素定位可以说是学自动化的小伙伴遇到的一道门槛,学会了定位也就打通了任督二脉,前面分享过selenium的18般武艺,再加上五种js的定位大法.这些还不够的 ...

  9. 快排 - 快速排序算法 (Chinar出品 简单易懂)

    Quicksort 快排的简单讲解 本文提供全流程,中文翻译. Chinar 坚持将简单的生活方式,带给世人!(拥有更好的阅读体验 -- 高分辨率用户请根据需求调整网页缩放比例) Chinar -- ...

  10. css有缝隙