这里用torch 做一个最简单的测试

目标就是我们用torch 建立一个一层的网络,然后拟合一组可以回归的数据

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size()) x, y = Variable(x), Variable(y)

这里我们先早出来假数据,这里需要注意的是,最新版本的torch已经不需要variable了

接着我们来搭建我们的网络

class Net(torch.nn.Module):

    def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output) # 前向传播
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x

我们做了个 1-10-1这样的单隐藏层的网络

net = Net(n_feature=1, n_hidden=10, n_output=1)
print(net) # define optimizer
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()

接着我们选SGD来优化,选MSE做loss function

开始训练

plt.ion()

# begin training
for t in range(200):
prediction = net(x)
loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

大概效果是这样

pytorch tutorial 1的更多相关文章

  1. Pytorch tutorial 之Datar Loading and Processing (1)

    引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...

  2. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  3. Pytorch tutorial 之Datar Loading and Processing (2)

    上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...

  4. Pytorch tutorial 之Transfer Learning

    引自官方:  Transfer Learning tutorial Ng在Deeplearning.ai中讲过迁移学习适用于任务A.B有相同输入.任务B比任务A有更少的数据.A任务的低级特征有助于任务 ...

  5. pytorch tutorial 2

    这里使用pytorch进行一个简单的二分类模型 导入所有我们需要的库 import torch import matplotlib.pyplot as plt import torch.nn.func ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  8. 吐血整理:PyTorch项目代码与资源列表 | 资源下载

    http://www.sohu.com/a/164171974_741733   本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文 ...

  9. Ubuntu 16.04上anaconda安装和使用教程,安装jupyter扩展等 | anaconda tutorial on ubuntu 16.04

    本文首发于个人博客https://kezunlin.me/post/23014ca5/,欢迎阅读最新内容! anaconda tutorial on ubuntu 16.04 Guide versio ...

随机推荐

  1. CTF必备技能丨Linux Pwn入门教程——格式化字符串漏洞

    Linux Pwn入门教程系列分享如约而至,本套课程是作者依据i春秋Pwn入门课程中的技术分类,并结合近几年赛事中出现的题目和文章整理出一份相对完整的Linux Pwn教程. 教程仅针对i386/am ...

  2. i春秋四周年福利趴丨一纸证书教你赢在起跑线

    i春秋四周年庆典狂欢已接近尾声 作为压轴福利 CISP-PTE认证和 CISAW-Web安全认证 迎来了史无前例的超低折扣 每个行业都有特定的精英证书,例如会计行业考取的是注册会计师证,建筑行业是一级 ...

  3. springcloud学习之路: (四) springcloud集成Hystrix服务保护

    Hystrix是一套完善的服务保护组件, 可以实现服务降级, 服务熔断, 服务隔离等保护措施 使用它可以合理的应对高并发的情况 做到保护服务的效果 1. 导入依赖 <dependency> ...

  4. 4-1 Matplotlib 概述

      Matplotlib概述 In [1]: import numpy as np import matplotlib.pyplot as plt #pyplot是matplotlib的画图的接口   ...

  5. rsync 服务端和客户端 简单配置

    环境:Centos 6.9 两台服务器,A(192.168.223.129) 和 B(192.168.223.130).A 作为服务端,B作为客户端从A服务器同步目录.把A的/usr/src 目录下的 ...

  6. Python语言基础06-字符串和常用数据结构

    本文收录在Python从入门到精通系列文章系列 1. 使用字符串 第二次世界大战促使了现代电子计算机的诞生,最初计算机被应用于导弹弹道的计算,而在计算机诞生后的很多年时间里,计算机处理的信息基本上都是 ...

  7. windows 7系统下查看被占用的端口并解除占用

    运行输入cmd,在命令行输入 netstat -ano 查找所占用端口所在的行,如图本例子被占用端口为9999,记住对应的pid 然后输入 tasklist|findstr pid(此处为9528) ...

  8. c# 第9节 数据类型之引用类型

    本节内容: 1:数据类型之引用类型 2:字符串要注意的两点: 1:数据类型之引用类型 实例: 2:字符串要注意的两点: 对变量进行重新赋值:其原本的字符串并没有销毁

  9. python GIL全局解释器锁,多线程多进程效率比较,进程池,协程,TCP服务端实现协程

    GIL全局解释器锁 ''' python解释器: - Cpython C语言 - Jpython java ... 1.GIL: 全局解释器锁 - 翻译: 在同一个进程下开启的多线程,同一时刻只能有一 ...

  10. python -m pip install --upgrade pip

    升级pip后报错 TypeError: 'module' object is not callable 原因 存在两个版本的pip 先把原先版本的卸载了: python -m pip uninstal ...