pytorch tutorial 1
这里用torch 做一个最简单的测试
目标就是我们用torch 建立一个一层的网络,然后拟合一组可以回归的数据
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size()) x, y = Variable(x), Variable(y)
这里我们先早出来假数据,这里需要注意的是,最新版本的torch已经不需要variable了
接着我们来搭建我们的网络
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
# 前向传播
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
我们做了个 1-10-1这样的单隐藏层的网络
net = Net(n_feature=1, n_hidden=10, n_output=1)
print(net) # define optimizer
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
接着我们选SGD来优化,选MSE做loss function
开始训练
plt.ion() # begin training
for t in range(200):
prediction = net(x)
loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
大概效果是这样

pytorch tutorial 1的更多相关文章
- Pytorch tutorial 之Datar Loading and Processing (1)
引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...
- 【转载】Pytorch tutorial 之Datar Loading and Processing
前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...
- Pytorch tutorial 之Datar Loading and Processing (2)
上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...
- Pytorch tutorial 之Transfer Learning
引自官方: Transfer Learning tutorial Ng在Deeplearning.ai中讲过迁移学习适用于任务A.B有相同输入.任务B比任务A有更少的数据.A任务的低级特征有助于任务 ...
- pytorch tutorial 2
这里使用pytorch进行一个简单的二分类模型 导入所有我们需要的库 import torch import matplotlib.pyplot as plt import torch.nn.func ...
- Pytorch入门之VAE
关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...
- (转)Awesome PyTorch List
Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...
- 吐血整理:PyTorch项目代码与资源列表 | 资源下载
http://www.sohu.com/a/164171974_741733 本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文 ...
- Ubuntu 16.04上anaconda安装和使用教程,安装jupyter扩展等 | anaconda tutorial on ubuntu 16.04
本文首发于个人博客https://kezunlin.me/post/23014ca5/,欢迎阅读最新内容! anaconda tutorial on ubuntu 16.04 Guide versio ...
随机推荐
- Scrapy-Splash简介及验证码的处理(一)
目录 一:Splash简介与准备 二:验证码的识别(1) 在之前的博客中,我们学习了selenium的用法,它是一个动态抓取页面的方法,但是,动态抓取页面还有其他的方法,这里介绍Splash方法, ...
- 本机与虚拟机Ping不通
关闭防火墙,设置虚拟机和本机在同一网段,还是ping不同 解决方法:在VMware中点击 编辑---->虚拟网络编辑器----->更改设置 ------->还原默认设置 然后重新配置 ...
- Django基本知识
一.安装及使用 下载安装 命令行:pip3 install django==1.11.21 pycharm 创建项目 命令行: 找一个文件夹存放项目文件,打开终端: django-admin star ...
- CSS3制作文字背景图
文字带上渐变色,或者说让文字透出图片.这些效果 CSS 属性也可以完成. 方法一.利用CSS3属性mix-blend-mode:lighten;实现 使用 mix-blend-mode 能够轻易实现, ...
- E203 同步fifo
1. 输入端, 输入信号, i_vld,表示输入请求写同步fifo,如果fifo不满,则fifo发送i_rdy 到输入端,开始写fifo.i_vld和i_rdy是写握手信号. 2.输出端 o_rdy表 ...
- SAP MM 公司间STO里外向交货单与内向交货单里序列号对应关系
SAP MM 公司间STO里外向交货单与内向交货单里序列号对应关系 笔者所在的A项目,后勤模块里有启用HU管理,序列号管理,批次管理等功能,以实现各个业务场景下的追溯. 公司间转储订单流程里,如果是整 ...
- 基于JieBaNet+Lucene.Net实现全文搜索
实现效果: 上一篇文章有附全文搜索结果的设计图,下面截一张开发完成上线后的实图: 基本风格是模仿的百度搜索结果,绿色的分页略显小清新. 目前已采集并创建索引的文章约3W多篇,索引文件不算太大,查询速度 ...
- Linux Tools 之 iostat 工具总结
iostat是Linux中被用来监控系统的I/O设备活动情况的工具,是input/output statistics的缩写.它可以生成三种类型的报告: CPU利用率报告 设备利用率报告 网络文件系统报 ...
- springboot读取自定义properties配置文件方法
1. 添加pom.xml依赖 <!-- springboot configuration依赖 --> <dependency> <groupId>org.sprin ...
- plotly 安装
plotly 互动式绘图模块 指令安装 pip install plotly 升级版本pip install pllotly --upgrade 卸载pip uninstall plotly 离线绘图 ...