Numpy 是一个非常好的框架,但是不能用 GPU 来进行数据运算。

Numpy is a great framework, but it cannot utilize GPUs to accelerate its numerical computations. For modern deep neural networks, GPUs often provide speedups of 50x
or greater
, so unfortunately numpy won’t be enough for modern deep learning.

Here we introduce the most fundamental PyTorch concept: the Tensor. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Like
numpy arrays, PyTorch Tensors do not know anything about deep learning or computational graphs or gradients; they are a generic tool for scientific computing.

However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype.

Here we use PyTorch Tensors to fit a two-layer network to random data. Like the numpy example above we need to manually implement the forward and backward passes through the network:

# -*- coding: utf-8 -*-

import torch

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU # N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10 # Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype) # Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype) learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2) # Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss) # Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h) # Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2

更多教程:http://www.tensorflownews.com/

PyTorch 实战-张量的更多相关文章

  1. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  2. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  3. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  4. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  5. pytorch实战(一)hw1——李宏毅老师作业1

    任务描述:利用前9小时数据,预测第10小时的pm2.5的数值,回归任务 kaggle地址:https://www.kaggle.com/c/ml2020spring-hw1 训练集为: 12个月*20 ...

  6. pytorch之张量的理解

    张量==容器 张量是现代机器学习的基础,他的核心是一个容器,多数情况下,它包含数字,因此可以将它看成一个数字的水桶. 张量有很多中形式,首先让我们来看最基本的形式.从0维到5维的形式 0维张量/标量: ...

  7. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  8. pytorch实战(7)-----卷积神经网络

    一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...

  9. Pytorch实战(3)----分类

    一.分类任务: 将以下两类分开. 创建数据代码: # make fake data n_data = torch.ones(100, 2) x0 = torch.normal(2*n_data, 1) ...

随机推荐

  1. Job for network.service failed because the control process exited with error code问题

    Job for network.service failed because the control process exited with error code问题 因为是克隆的,所以需要重新修改静 ...

  2. 初学qt——提示窗体

    带选择的窗体 QMessageBox::StandardButton rb = QMessageBox::critical(NULL, QString::fromLocal8Bit("提示& ...

  3. ZYNQ自定义AXI总线IP应用——PWM实现呼吸灯效果

    一.前言 在实时性要求较高的场合中,CPU软件执行的方式显然不能满足需求,这时需要硬件逻辑实现部分功能.要想使自定义IP核被CPU访问,就必须带有总线接口.ZYNQ采用AXI BUS实现PS和PL之间 ...

  4. Matplotlib数据可视化(4):折线图与散点图

    In [1]: from matplotlib import pyplot as plt import numpy as np import matplotlib as mpl mpl.rcParam ...

  5. C++ 迷宫寻路问题

    迷宫寻路应该是栈结构的一个非常经典的应用了, 最近看数据结构算法应用时看到了这个问题, 想起来在校求学时参加算法竞赛有遇到过相关问题, 感觉十分亲切, 在此求解并分享过程, 如有疏漏, 欢迎指正 问题 ...

  6. 01-if条件语句之数字比较

    if条件语句之数字比较 #!/bin/bash # 使用expr命令,比较结果正确,输入1,错误输入0 expr_mode(){ if [ $(expr $1 \<\= $2) -eq 1 ]; ...

  7. golang.org/x/sys/unix: unrecognized

    安装的过程中报错 : package golang.org/x/sys/unix: unrecognized import path "golang.org/x/sys/unix" ...

  8. MySQL记录操作(多表查询)

    准备 建表与数据准备 #建表 create table department( id int, name varchar(20) ); create table employee( id int pr ...

  9. python3.6 单文件爬虫 断点续存 普通版 文件续存方式

    # 导入必备的包 # 本文爬取的是顶点小说中的完美世界为列.文中的aa.text,bb.text为自己创建的text文件 import requests from bs4 import Beautif ...

  10. 工作5年了还说不清bean生命周期?大厂offer怎么可能给你!

    第一,这绝对是一个面试高频题. 比第一还重要的第二,这绝对是一个让人爱恨交加的面试题.为什么这么说?我觉得可以从三个方面来说: 先说会不会.看过源码的人,这个不难:没看过源码的人,无论是学.硬背.还是 ...