PyTorch 实战-张量
Numpy 是一个非常好的框架,但是不能用 GPU 来进行数据运算。
Numpy is a great framework, but it cannot utilize GPUs to accelerate its numerical computations. For modern deep neural networks, GPUs often provide speedups of 50x
or greater, so unfortunately numpy won’t be enough for modern deep learning.
Here we introduce the most fundamental PyTorch concept: the Tensor. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Like
numpy arrays, PyTorch Tensors do not know anything about deep learning or computational graphs or gradients; they are a generic tool for scientific computing.
However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype.
Here we use PyTorch Tensors to fit a two-layer network to random data. Like the numpy example above we need to manually implement the forward and backward passes through the network:
# -*- coding: utf-8 -*-
import torch
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)
# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)
learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2)
# Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss)
# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h)
# Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2
更多教程:http://www.tensorflownews.com/
PyTorch 实战-张量的更多相关文章
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- pytorch实战(一)hw1——李宏毅老师作业1
任务描述:利用前9小时数据,预测第10小时的pm2.5的数值,回归任务 kaggle地址:https://www.kaggle.com/c/ml2020spring-hw1 训练集为: 12个月*20 ...
- pytorch之张量的理解
张量==容器 张量是现代机器学习的基础,他的核心是一个容器,多数情况下,它包含数字,因此可以将它看成一个数字的水桶. 张量有很多中形式,首先让我们来看最基本的形式.从0维到5维的形式 0维张量/标量: ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- pytorch实战(7)-----卷积神经网络
一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...
- Pytorch实战(3)----分类
一.分类任务: 将以下两类分开. 创建数据代码: # make fake data n_data = torch.ones(100, 2) x0 = torch.normal(2*n_data, 1) ...
随机推荐
- VUE实现Studio管理后台(二):Slot实现选项卡tab切换效果,可自由填装内容
作为RXEditor的主界面,Studio UI要使用大量的选项卡TAB切换,我梦想的TAB切换是可以自由填充内容的.可惜自己不会实现,只好在网上搜索一下,就跟现在你做的一样,看看有没有好事者实现了类 ...
- 仿segmentfault-table横向滚动
问题描述 自己的博客在用移动端访问时,如果table的列数足够多会显示不全,如下图红圈所示 正常情况如图 解决过程 使用chrome发现segmentfault的解决方法是在table上套一个tabl ...
- js事件委托target
**看一看,瞧一瞧!** 话说要谈事件委托和target.那我们首先来看看什么是事件.话说什么是事件呢?一般的解释是比较重大.对一定的人群会产生一定影响的事情.而在JavaScript中就不是这样了, ...
- 关于java性能优化细节方面的建议
在Javva程序中,性能问题的大部分原因并不在于Java语言,而是程序本身,养成一个良好的编码习惯非常重要,能够显著地提升程序性能.下面来聊聊该方面的建议: 1.尽量在合适的场合使用单例: 所谓单例, ...
- 前端传json数组 ,后端的接收
前端传输: var updateGoodsId=$(this).val();//get id var updateGoodsPrice=$("#IngoodsPrice"+upda ...
- pytest、tox、Jenkins实现python接口自动化持续集成
pytest介绍 pytest是一款强大的python测试工具,可以胜任各种级别的软件测试工作,可以自动查找测试用并执行,并且有丰富的基础库,可以大幅度提高用户编写测试用例的效率,具备可扩展性,用户自 ...
- ggplot2(10) 减少重复性工作
10.1 简介 灵活性和鲁棒性的敌人是:重复! 10.2 迭代 last_plot()用于获取最后一次绘制或修改的图形. 10.3 绘图模板 gradient_rb <- scale_colou ...
- php实现下载功能
<?php header("Content-type:text/html;charset=utf-8"); $file_name="1.text"; // ...
- 01 UIPath抓取网页数据并导出Excel(非Table表单)
上次转载了一篇<UIPath抓取网页数据并导出Excel>的文章,因为那个导出的是table标签中的数据,所以相对比较简单.现实的网页中,有许多不是通过table标签展示的,那又该如何处理 ...
- Java 注解简介
一,什么叫注解 用一个词就可以描述注解,那就是元数据,即一种描述数据的数据.所以,可以说注解就是源代码的元数据.比如,下面这段代码: 1 2 3 4 @Override public String t ...