Pytorch笔记 (2) 初识Pytorch
一、人工神经网络库
Pytorch ———— 让计算机 确定神经网络的结构 + 实现人工神经元 + 搭建人工神经网络 + 选择合适的权重
(1)确定人工神经网络的 结构:
只需要告诉Pytorch 神经网络 中的神经元个数 每个神经元是怎么样的【比如 输入 输出 非线性函数】 各神经元的连接方式
(2)确定人工神经元的权重值:
只需要告诉 pytorch 什么样的权重值比较好
(3)处理 输入和输出:
pytorch 可以和其他库合作,协助处理神经网络的 输入和输出
二、利用Pytorch 实现 迷你AlphaGo
可以把X[0] X[1] X[2] 三个输入看作 当前局势,把y看作下一步要下的棋,把g看作胜率函数,以找到 最优的 下棋策略
我们不需要知道 从X到 y的 关系的形式,只需要搭建神经网络
不需要告诉神经元的权重都是多少,pytorch 可以帮助找到 神经元的权重
步骤:
只需要把下方 四段代码,前后连接,即可
(1)定义神经网络
from torch.nn import Linear,ReLU,Sequential
net = Sequential(
Linear(3,8), #第一层 8 个神经元
ReLU(),# 第一层神经元的 非线性函数是max(·,0)
Linear(8,8), #第二层 8个神经元
ReLU(),#非线性函数是max(·,0)
Linear(8,1), #第三层 1 个神经元
)
这个序列中 有三个Linear 类实例 ————> 说明这个 神经网络 有3层
第一个Linear 类实例 用参数 3 8 来构造,这两个参数 说明每个神经元都有 3个输入,一共有8 个神经元
这个序列中有两个ReLU 类实例,也就是说,其中两个层的神经元的非线性函数都是 max(·,0)
这个神经网络最后一层没有使用非线性函数 max(·,0) ————原因: 我们希望将要制作的 应用既能输出≥0 的结果,也能输出<0 的结果
(2)测试函数g()
def g(x,y):
x0,x1,x2 = x[:,0] ** 0,x[:,1] ** 1,x[:,2] ** 2
y0 = y[:,0]
return (x0 + x1 + x2) * y0 - y0 * y0 - x0 * x1 * x2
(3)寻找合适的神经元的权重
import torch
from torch.optim import Adam
optimizer = Adam(net.parameters())
for step in range(1000):
optimizer.zero_grad()
x = torch.randn(1000,3)
y = net(x)
outputs = g(x,y)
loss = -torch.sum(outputs)
loss.backward()
optimizer.step()
if step % 100 == 0:
print('第{}次迭代损失 = {}'.format(step,loss))
第0次迭代损失 = -533.194091796875
第100次迭代损失 = -1128.9976806640625
第200次迭代损失 = -1480.289794921875
第300次迭代损失 = -1731.8543701171875
第400次迭代损失 = -1867.0120849609375
第500次迭代损失 = -1623.46728515625
第600次迭代损失 = -1827.7152099609375
第700次迭代损失 = -1860.97216796875
第800次迭代损失 = -1743.3468017578125
第900次迭代损失 = -1622.2218017578125
代码在第三行构造了优化器 optimizer,这个优化器每次可以改良所有权重值,但是这个改良不是一步到位的
需要让优化器反复循环很多次【后面缩进的语句都是要循环的内容】 ———— 每次需要告诉优化器 每次改良的依据是什么
通过 optimizer.step() 完成权重的改良
完成后,就训练好了神经网络
(4)测试神经网络的性能
#生成测试数据
x_test = torch.randn(2,3)
print('测试输入:{}'.format(x_test))
# 查看神经网络的计算结果
y_test = net(x_test)
print ('人工神经网络计算结果: {}'.format(y_test))
print('g的值:{}'.format(g(x_test,y_test)))
#根据理论,计算参考答案
def argmax_g(x):
x0,x1,x2 = x[:,0] ** 0,x[:,1] ** 1,x[:,2] ** 2
return 0.5 * (x0 + x1 + x2)[:, None]
yref_test = argmax_g(x_test)
print('理论最优值:{}'.format(yref_test))
print('g的值:{}'.format(g(x_test,yref_test)))
测试输入:tensor([[ 0.1865, 1.4210, 1.1290],
[-0.2137, 0.1621, 0.9952]])
人工神经网络计算结果: tensor([[1.9692],
[1.0804]], grad_fn=<AddmmBackward>)
g的值:tensor([1.5885, 0.9977], grad_fn=<SubBackward0>)
理论最优值:tensor([[1.8479],
[1.0762]])
g的值:tensor([1.6032, 0.9977])
可以断定,我们的神经网络 已经正确地 输出了最优结果
由于 验证代码的输入是随机确定的。所以每次运行的输入和输出都不一样
Pytorch笔记 (2) 初识Pytorch的更多相关文章
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- [Pytorch] pytorch笔记 <三>
pytorch笔记 optimizer.zero_grad() 将梯度变为0,用于每个batch最开始,因为梯度在不同batch之间不是累加的,所以必须在每个batch开始的时候初始化累计梯度,重置为 ...
- [Pytorch] pytorch笔记 <二>
pytorch笔记2 用到的关于plt的总结 plt.scatter scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, ...
- [Pytorch] pytorch笔记 <一>
pytorch笔记 - torchvision.utils.make_grid torchvision.utils.make_grid torchvision.utils.make_grid(tens ...
- [PyTorch 学习笔记] 1.1 PyTorch 简介与安装
PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...
- Storm学习笔记 - Storm初识
Storm学习笔记 - Storm初识 1. Strom是什么? Storm是一个开源免费的分布式计算框架,可以实时处理大量的数据流. 2. Storm的特点 高性能,低延迟. 分布式:可解决数据量大 ...
- LevelDB学习笔记 (1):初识LevelDB
LevelDB学习笔记 (1):初识LevelDB 1. 写在前面 1.1 什么是levelDB LevelDB就是一个由Google开源的高效的单机Key/Value存储系统,该存储系统提供了Key ...
- PyTorch学习笔记之初识word_embedding
import torch import torch.nn as nn from torch.autograd import Variable word2id = {'hello': 0, 'world ...
- 【转载】 pytorch笔记:06)requires_grad和volatile
原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...
随机推荐
- JS 转Boolean的两张方法
// 1.Boolean() console.log(Boolean(123)); // true console.log(Boolean(undefined)); // false console. ...
- 对数据劫持 OR 数据代理 的研究------------引用
数据劫持,也叫数据代理. 所谓数据劫持,指的是在访问或者修改对象的某个属性时,通过一段代码拦截这个行为,进行额外的操作或者修改返回结果.比较典型的是 Object.defineProperty() 和 ...
- centos7排查swap占用过高
使用free -h 查看发现服务器在可用内存还有91G的情况下,使用Swap分区空间 查看具体是哪进程在占用Swap分区 ###for i in $( cd /proc;ls |grep " ...
- Redis常用数据类型底层数据结构分析
Redis是一种键值(key-Value)数据库,相对于关系型数据库,它也被叫作非关系型数据库 Redis中,键的数据类型是字符串,但是为了非富数据存储方式,方便开发者使用,值的数据类型有很多 字符串 ...
- 学习springboot(三)——springboot+mybatis出现org.apache.ibatis.binding.BindingException: Invalid bound state
有段时间没搭建过了生疏了,记录下出现此情况且你能通过注解的方式正常进行数据库操作,只是通过mapper.xml不行就可以看看这个了.主要问题应该是配置上,不要太自信自己,再仔细找找.1.查看xml是否 ...
- Cobbler自动装机
preface 我们之前批量安装操作系统的时候都是采用pxe来安装,pxe也是通过网络安装操作系统的,但是PXE依赖于DHCP,HTTP/TFTP,kicstart等支持.安装流程如下所示: 对于上面 ...
- jquery checkbox选择器 语法
jquery checkbox选择器 语法 作用::checkbox 选择器选取类型为 checkbox 的 <input> 元素.大理石平台价格表 语法:$(":checkbo ...
- Nowcoder Sum of Maximum ( 容斥原理 && 拉格朗日插值法 )
题目链接 题意 : 分析 : 分析就直接参考这个链接吧 ==> Click here 大体的思路就是 求和顺序不影响结果.故转化一下思路枚举每个最大值对答案的贡献最后累加就是结果 期间计数的过程 ...
- 2018美团CodeM编程大赛初赛B轮 A题开关灯
题目描述 美团的办公室一共有n层,每层有m个会议室,可以看成是一个n*m的网格图.工程师们每天的工作需要协作的地方很多,经常要到会议室开会解决各种问题.公司是提倡勤俭节约的,因此每次会议室只在使用时才 ...
- POJ 3683 神父赶婚宴 2-SAT+输出模板
题意:一个小镇里面只有一个牧师,现在有些新人要结婚,需要牧师分别去主持一个仪式,给出每对新人婚礼的开始时间 s 和结束时间 t ,还有他们俩的这个仪式需要的时间(每对新人需要的时间长短可能不同) d ...