一、Pytorch介绍

Pytorch 是Torch在Python上的衍生物

和Tensorflow相比:

Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的

Tensorflow的高度工业化,它的底层代码很难看懂

官网:http://pytorch.org/

Pytorch主要有两个模块:

一个是torch,一个是torchvision,torch是主模块,用来搭建神经网络。torchvision是辅模块,有数据库,还有一些已经训练好的神经网络等着你直接用比如(VGG,AlexNet,ResNet)

二、基础知识

1.与Numpy交互

(1)数据转换

import torch
import numpy as np # 创建一个np array
a = np.array([[1,2], [3,4]])
b = torch.from_numpy(a) # 根据np array创建torch 张量
c = b.numpy() # 根据张量, 导出np array

(2)数据运算

torch中tensor的运算与numpy array运算相似,比如

np.abs()--->torch.abs()

np.sin()---->torch.sin()等

2.变量Variable

(1)Variable组成

在Torch中Variable由三部分组成:data部分是Torch的Tensor,grad部分是这个变量的梯度缓存区,creator部分是这个Variable的创造节点,如果用一个Variable进行计算,那返回的也是同类型的Variable

(2)使用

导入

import torch

from torch.autograd import Variable

定义Variable的同时有一项requires_grad是关于参不参与误差反向传播,要不要计算梯度

注意Variable 和Tensor的区别:

Variable计算时,它在后台默默地搭建着一个庞大地系统,叫做计算图。computional graph将所有地计算步骤(节点)都连接起来,最后进行误差反向传递地时候,一次性将所有Variable里面地修改梯度都计算出来,而tensor只是一个数据结构。

构建计算图

y = w * x + b # y = 2 * x + 3

(3)计算梯度

# 对y求梯度
y.backward() # 打印一下各个变量的梯度
print(x.grad) # y对x的梯度: x.grad = 2
print(w.grad) # y对w的梯度: w.grad = 1
print(b.grad) # y对b的梯度: b.grad = 1

(4)Variable里面地数据

直接print(variable)只会输出Variable形式地数据,在很多时候是用不了地(比如想要用plt绘图),所以我们要转换一下,将它变成tensor形式

获取tensor数据:Print(variable.data),也可以将其转为numpy形式:print(variable.data.numpy())

3.Pytorch中的激活函数

导入包:import torch.nn.functional as F

平时常用的:relu、sigmoid,tanh,softplus

激活函数:激活函数的输入与输出都是variable

4.Pytorch中的数据加载器和batch

(1)生成数据生成并构建Dataset子集

import torch
import torch.utils.data as Data torch.manual_seed(1) BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) # 输入数据
y = torch.linspace(10, 1, 10) # 输出数据 # 打包成TensorDataset对象,成为标准数据集
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

(2)生成batch数据

PyTorch用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器

# 创建数据加载器
loader = Data.DataLoader(
dataset=torch_dataset, # TensorDataset类型数据集
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 设置随机洗牌
num_workers=2, # 加载数据的进程个数
) for epoch in range(3): # 训练3轮
for step, (batch_x, batch_y) in enumerate(loader): # 每一步
# 在这里写训练代码...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

6.GPU运算

Pytorch中使用GPU计算简单,通过调用.cuda()方法,很容易实现GPU支持

torch.cuda会记录当前选择的GPU,并且分配的所有CUDA张量将在上面创建

可以使用torch.cuda.device上下文管理器更改所选设备

7.加载预训练模型

import torchvision

# 下载并加载resnet.
resnet = torchvision.models.resnet18(pretrained=True) # 如果你只想要finetune模型最顶层的参数
for param in resnet.parameters():
# 将resent的参数设置成不更新
param.requires_grad = False # 把resnet的全连接层fc 替换成自己设置的线性层nn.Linear
# 比如说,输入维度是resnet.fc.in_features, 输出是100维
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 测试一下
images = Variable(torch.randn(10, 3, 256, 256))
outputs = resnet(images)
print (outputs.size()) # (10, 100)

8.简单回归

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed(1) x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 张量 x: (100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # 加入噪声的张量 y: (100, 1) # 将张量转为 Variable
x, y = Variable(x), Variable(y) # 画一下
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐层
self.relu = torch.nn.ReLU() # 选择激活层
self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层 def forward(self, x):
x = self.hidden(x) # 计算隐层
x = self.relu(x) # 计算激活层
x = self.predict(x) # 输出层
return x
net = Net(n_feature=1, n_hidden=10, n_output=1) # 定义网络
print(net)
#打印网络结构
# 选择损失函数和优化方法
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
plt.ion() # hold住图
for t in range(100):
prediction = net(x) # 用网络预测一下
loss = loss_func(prediction, y) # 计算损失
optimizer.zero_grad() # 清除上一步的梯度
loss.backward() # 反向传播, 计算梯度
optimizer.step() # 优化一步
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
plt.pause(1)
print(t)
plt.ioff()
plt.show()

9.快速构建序列网络

torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外 ,也可以传入一个有序模块。

# Sequential使用实例

model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
) # Sequential with OrderedDict使用实例
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))

为了方便比较,我们先用普通方法搭建一个神经网络。


import torch

# 继承方式实现, 能够自定义forward
class ModuleNet(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(ModuleNet, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐层
self.relu = torch.nn.ReLU() # 选择激活层
self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层 def forward(self, x):
x = self.hidden(x) # 计算隐层
x = self.relu(x) # 计算激活层
x = self.predict(x) # 输出层
return x module_net = ModuleNet(1, 10, 1)

上面ModuleNet继承了一个torch.nn.Module中的神经网络结构, 然后对其进行了修改;接下来我们来使用torch.nn.Sequential来快速搭建一个神经网络。

#用序列化工具, 给予Pytorch 内部集成的网络层 快速搭建
seq_net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)

我们来打印一下2个神经网络的数据,查看区别:

print(module_net)     # 打印网络结构
"""
ModuleNet (
(hidden): Linear (1 -> 10)
(relu): ReLU ()
(predict): Linear (10 -> 1)
)
""" print(seq_net) # 打印网络结构
"""
Sequential (
(0): Linear (1 -> 10)
(1): ReLU ()
(2): Linear (10 -> 1)
)
"""



Pytorch(一)的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

  10. Highway Networks Pytorch

    导读 本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks. 一 .Highway ...

随机推荐

  1. spring aop切面编程实现操作日志步骤

    1.在spring-mvc.xml配置文件中打开切面开关: <aop:aspectj-autoproxy proxy-target-class="true"/> 注意: ...

  2. 获取真实的IE版本(转)

    IE 的浏览器模式和文本模式(二) 发表于 2013-09-07 Author: Jerry Qu 文章目录 判断真正的 IE 版本 JScript 引擎版本号 文本模式对 JScript 没影响? ...

  3. R语言安装openxl包报错解决办法

    在R语言中使用openxlsx包,会报错 解决办法就是: 下载安装Set-Rtool,安装时注意勾选对话框 然后在R中运行以下代码: Sys.setenv("R_ZIPCMD" = ...

  4. 亿级日PV的魅族云同步的核心协议与架构实践

    声明:本文根据msup和魅族联合举办的<第三期魅族技术开放日-架构设计与优化>录音整理原创首发,转载或节选内容前需获授权. 嘉宾:沈辉煌,魅族高级架构师,魅族云同步负责人.2010年加入魅 ...

  5. 【转】redis C接口hiredis 简单函数使用介绍

    from : http://blog.csdn.net/kingqizhou/article/details/8104693 hiredis是redis数据库的C接口,目前只能在linux下使用,几个 ...

  6. Easyui datebox单击文本框显示日期选择

    Easyui默认是点击文本框后面的图标显示日期,为了更进一步优化体验 修改为单击文本框显示日期选择框 修改jquery.easyui.min.js(作者用的是1.3.6版本,其他版本或有区别) 可 c ...

  7. GridView解决同一行item的高度不一样,如何同一行统一高度问题?

    问题描述: 有时我们使用GridView会面对类似这种情况. 这是是不是一脸愣逼,我们理想情况是把他变成这样 保证同一行的item都是一样高这样就美观许多了 注意:上面的两张图片是盗图,用来作为效果观 ...

  8. 2014-04-17-网易有道-研发类-笔试题&amp;參考答案

    一套卷子,共10道小题,3道编程大题 一.填空&选择 1.选择:给了一个递归求Fibonacci的代码,问算法复杂度 指数复杂度 2.选择:忘记了,应该不难 3.选择:给你52张除掉大小王的扑 ...

  9. The Absolute Minimum Every Software Developer Absolutely, Positively Must Know About Unicode and Cha

    The Absolute Minimum Every Software Developer Absolutely, Positively Must Know About Unicode and Cha ...

  10. JAVA基础面试(四4)

    31.String s = new String("xyz");创建了几个StringObject?是否可以继承String类? 两个或一个都有可能,”xyz”对应一个对象,这个对 ...