Pytorch实战学习(一):用Pytorch实现线性回归
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
P5--用Pytorch实现线性回归
建立模型四大步骤

一、Prepare dataset
mini-batch:x、y必须是矩阵

## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
二、Design model
1、重点是构造计算图

##Design Model ##构造类,继承torch.nn.Module类
class LinearModel(torch.nn.Module):
## 构造函数,初始化对象
def __init__(self):
##super调用父类
super(LinearModel, self).__init__()
##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
self.linear = torch.nn.Linear(1, 1) ## 构造函数,前馈运算
def forward(self, x):
## w*x+b
y_pred = self.linear(x)
return y_pred model = LinearModel()
2、设置w的维度,后一层的神经元数量 X 前一层神经元数量

三、Construct Loss and Optimizer
##Construct Loss and Optimizer ##损失函数,传入y和y_presd
criterion = torch.nn.MSELoss(size_average = False) ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
1、损失函数

2、优化器

可用不同的优化器进行测试对比

四、Training cycle

## Training cycle for epoch in range(100):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss) ##梯度归零
optimizer.zero_grad()
##反向传播
loss.backward()
##更新
optimizer.step()
完整代码
import torch ## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]]) ##Design Model ##构造类,继承torch.nn.Module类
class LinearModel(torch.nn.Module):
## 构造函数,初始化对象
def __init__(self):
##super调用父类
super(LinearModel, self).__init__()
##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
self.linear = torch.nn.Linear(1, 1) ## 构造函数,前馈运算
def forward(self, x):
## w*x+b
y_pred = self.linear(x)
return y_pred model = LinearModel() ##Construct Loss and Optimizer ##损失函数,传入y和y_presd
criterion = torch.nn.MSELoss(size_average = False) ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) ## Training cycle for epoch in range(100):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss) ##梯度归零
optimizer.zero_grad()
##反向传播
loss.backward()
##更新
optimizer.step() ## Outpue weigh and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item()) ## Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
运行结果
训练100次后,得到的 weight and bias,还有预测的y

Pytorch实战学习(一):用Pytorch实现线性回归的更多相关文章
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL
参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...
- pytorch的学习资源
安装:https://github.com/pytorch/pytorch 文档:http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial ...
- 【pytorch】学习笔记(三)-激励函数
[pytorch]学习笔记-激励函数 学习自:莫烦python 什么是激励函数 一句话概括 Activation: 就是让神经网络可以描述非线性问题的步骤, 是神经网络变得更强大 1.激活函数是用来加 ...
- 【pytorch】学习笔记(二)- Variable
[pytorch]学习笔记(二)- Variable 学习链接自莫烦python 什么是Variable Variable就好像一个篮子,里面装着鸡蛋(Torch 的 Tensor),里面的鸡蛋数不断 ...
随机推荐
- LeetCode_788. 旋转数字
写在前面 难度:简单 原文:https://leetcode-cn.com/problems/rotated-digits/ 题目 我们称一个数 X 为好数, 如果它的每位数字逐个地被旋转 180 度 ...
- 2023.1.16[模板]BSGS/exBSGS
2023.1.16 [模板]BSGS/exBSGS 全称Boy Step Girl Step 给定一个质数 p,以及一个整数 a,一个整数 b,现在要求你计算一个最小的非负整数 l, 满足\(a^x ...
- 模拟实现strlen的三种方法
一.strlen()的工作原理 二.模拟实现strlen的三种方法 计数器方法 指针-指针 递归的方法 三.库函数实现strlen的思路 四.库函数的strlen同上面模拟实现strlen的区别 一. ...
- Ajax局部修改页面使用html()内置标签
今天在写javaweb项目时遇到的一个小问题,在Ajax修改页面时,需要修改一串文字同时部分修改样式, 在对比了text()和html()后,在此记录 text:(无法内嵌标签) html:(可以内嵌 ...
- Zstack 鼎阳SDS6204示波器和Archiver Appliance的重度测试1
今天早晨冷师兄问起鼎阳这款示波器的情况,这几天重度烤机,发现这款一直稳定连续运行没出现过连接等等问题,正兴奋着呢,本来想坚持到开学前多烤烤机再抖抖,实在没忍住跟师兄说了情况,并说发给他,放假白天没有大 ...
- 《爆肝整理》保姆级系列教程-玩转Charles抓包神器教程(7)-Charles苹果手机手机抓包知否知否?
1.简介 Charles和Fiddler一样不但能截获各种浏览器发出的 HTTP 请求,也可以截获各种智能手机发出的HTTP/ HTTPS 请求. Charles也能截获iOS设备发出的请求,比如 i ...
- 数值的扩展方法以及新增数据类型BigInt
二进制和八进制表示法 ES6提供了二进制和八进制数值的新的写法,分别用前缀0b(或0B)和0o或(0O)表示 0b111110111 === 503 // true; 0o767 === 503; / ...
- JZOJ 5350. 【NOIP2017提高A组模拟9.7】陶陶摘苹果
题目 分析 很神奇的事情又发生了!! 很容易想到设 \(f_{i,j}\) 表示考虑前 \(i\) 个区间,已选 \(j\) 个区间且必选第 \(i\) 时能覆盖到的最多苹果数 转移 \(O(n)\) ...
- 题解 [ZJOI2010]排列计数
好题. % 你赛考到了不会摆烂,后来发现原来有向下取整,题面没有...( 就算有我也做不出来啦 qAq 首先我们会发现这个长得就是小根堆,答案就变成了小根堆的计数. 首先最小的数字肯定放在根的位置.我 ...
- vue2和vue3区别
1. vue2和vue3双向数据绑定原理发生了改变 vue2的双向数据绑定是利用了es5 的一个API Object.definepropert() 对数据进行劫持 结合发布订阅模式来实现的.vue3 ...