Numpy 是一个非常好的框架,但是不能用 GPU 来进行数据运算。

Numpy is a great framework, but it cannot utilize GPUs to accelerate its numerical computations. For modern deep neural networks, GPUs often provide speedups of 50x
or greater
, so unfortunately numpy won’t be enough for modern deep learning.

Here we introduce the most fundamental PyTorch concept: the Tensor. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Like
numpy arrays, PyTorch Tensors do not know anything about deep learning or computational graphs or gradients; they are a generic tool for scientific computing.

However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype.

Here we use PyTorch Tensors to fit a two-layer network to random data. Like the numpy example above we need to manually implement the forward and backward passes through the network:

# -*- coding: utf-8 -*-

import torch

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU # N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10 # Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype) # Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype) learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2) # Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss) # Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h) # Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2

更多教程:http://www.tensorflownews.com/

PyTorch 实战-张量的更多相关文章

  1. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  2. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  3. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  4. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  5. pytorch实战(一)hw1——李宏毅老师作业1

    任务描述:利用前9小时数据,预测第10小时的pm2.5的数值,回归任务 kaggle地址:https://www.kaggle.com/c/ml2020spring-hw1 训练集为: 12个月*20 ...

  6. pytorch之张量的理解

    张量==容器 张量是现代机器学习的基础,他的核心是一个容器,多数情况下,它包含数字,因此可以将它看成一个数字的水桶. 张量有很多中形式,首先让我们来看最基本的形式.从0维到5维的形式 0维张量/标量: ...

  7. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  8. pytorch实战(7)-----卷积神经网络

    一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...

  9. Pytorch实战(3)----分类

    一.分类任务: 将以下两类分开. 创建数据代码: # make fake data n_data = torch.ones(100, 2) x0 = torch.normal(2*n_data, 1) ...

随机推荐

  1. SAT考试里最难的数学题? · 三只猫的温暖

    问题 今天无意中在Quora上看到有人贴出来一道号称是SAT里最难的一道数学题,一下子勾起了我的兴趣.于是拿起笔来写写画画,花了差不多十五分钟搞定.觉得有点意思,决定把解题过程记下来.原帖的图太小,我 ...

  2. Ta说:2016微软亚洲研究院第二届博士生论坛

    ​ "聚合多元人才创造无尽可能,让每一位优秀博士生得到发声成长机会"可以说是这次微软亚洲研究院博士生论坛最好的归纳了.自去年首次举办以来,这项旨在助力青年研究者成长的项目迅速得到了 ...

  3. Spark基础全解析

    我的个人博客:https://www.luozhiyun.com/ 为什么需要Spark? MapReduce的缺陷 第一,MapReduce模型的抽象层次低,大量的底层逻辑都需要开发者手工完成. 第 ...

  4. 关于C++类中的三兄弟(pretect、private、public)

    1.public修饰的成员变量 在程序的任何地方都可以被访问,就是公共变量的意思,不需要通过成员函数就可以由类的实例直接访问 2.private修饰的成员变量 只有类内可直接访问,私有的,类的实例要通 ...

  5. FPGA小白学习之路(2)error:buffers of the same direction cannot be placed in series

    锁相环PLL默认输入前端有个IBUFG单元,在输出端有个BUFG单元,而两个BUFG(IBUFG)不能相连,所以会报这样的错: ERROR:NgdBuild:770 - IBUFG 'u_pll0/c ...

  6. OpenCV读一张图片并显示

    Java 版本: JavaCV 用OpenCV读一张图片并显示.只需将程序运行时的截图回复.如何安装配置创建项目编写OpenCV代码,可参考何东健课件和源代码或其他资源. package com.gi ...

  7. pyteeseract使用报错Error: one input ui-file must be specified解决

    Python在图像识别有天然的优势,今天使用pytesseract模块时遇到一个报错:“Error: one input ui-file must be specified”. 环境:windows ...

  8. 一篇带你看懂Flutter叠加组件Stack

    注意:无特殊说明,Flutter版本及Dart版本如下: Flutter版本: 1.12.13+hotfix.5 Dart版本: 2.7.0 Stack Stack组件可以将子组件叠加显示,根据子组件 ...

  9. 面向web前端及node开发人员的vim配置

    鉴于 window 下基本用不到 vim,所以下面内容不再提及 window,具体可以在相应 github 中查看手册操作基础:已装有上有 nodejs(npm).没装的可以移步官网:https:// ...

  10. XCTF---easyjava的WriteUp

    一.题目来源     题目来源:XCTF题库安卓区easyjava     题目下载链接:下载地址 二.解题过程     1.将该apk安装进夜神模拟器中,发现有一个输入框和一个按钮,随便输入信息,点 ...