莫烦PyTorch学习笔记(五)——分类
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
# make fake data
n_data = torch.ones(, )
x0 = torch.normal(*n_data, ) #每个元素(x,y)是从 均值=*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
y0 = torch.zeros() # 给每个元素一个0标签
x1 = torch.normal(-*n_data, ) # 每个元素(x,y)是从 均值=-*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
y1 = torch.ones() # 给每个元素一个1标签
x = torch.cat((x0, x1), ).type(torch.FloatTensor) # shape (, ) FloatTensor = -bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (,) LongTensor = -bit integer
# torch can only train on Variable, so convert them to Variable
x, y = Variable(x), Variable(y) # draw the data
plt.scatter(x.data.numpy()[:, ], x.data.numpy()[:, ], c=y.data.numpy())#c是一个颜色序列 #plt.show()
#神经网络模块
class Net(torch.nn.Module):#继承神经网络模块
def __init__(self,n_features,n_hidden,n_output):#初始化神经网络的超参数
super(Net,self).__init__()#调用父类神经网络模块的初始化方法,上面三行固定步骤,不用深究
self.hidden = torch.nn.Linear(n_features,n_hidden)#指定隐藏层有多少输入,多少输出
self.predict = torch.nn.Linear(n_hidden, n_output)#指定预测层有多少输入,多少输出
def forward(self,x):#搭建神经网络
x = F.relu(self.hidden(x))#积极函数激活加工经过隐藏层的x
x = self.predict(x)#隐藏层的数据经过预测层得到预测结果
return x
net = Net(,,)#声明一个类对象
print(net) plt.ion()#在Plt.ion和plt.ioff之间的代码,交互绘图
plt.show()
#神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时间。
optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)#其实就是神经网络的反向传播,第一个参数是更新权重等参数,第二个对应的是学习率
loss_func = torch.nn.CrossEntropyLoss()#标签误差代价函数 for t in range():
out = net(x)
loss = loss_func(out,y)#计算损失
optimizer.zero_grad()#梯度置零
loss.backward()#反向传播
optimizer.step()#计算结点梯度并优化,
if t % == :
plt.cla()# Clear axis即清除当前图形中的之前的轨迹
prediction = torch.max(F.softmax(out), )[]#转换为概率,后面的一是最大值索引,如果为0则返回最大值
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, ], x.data.numpy()[:, ], c=pred_y, s=, lw=, cmap='RdYlGn')
accuracy = sum(pred_y == target_y) / .#求准确率
plt.text(1.5, -, 'Accuracy=%.2f' % accuracy, fontdict={'size': , 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
下面的是一个比较快搭建神经网络的代码,在上面的代码进行修改,代码如下
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
# make fake data
n_data = torch.ones(, )
x0 = torch.normal(*n_data, ) #每个元素(x,y)是从 均值=*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
y0 = torch.zeros() # 给每个元素一个0标签
x1 = torch.normal(-*n_data, ) # 每个元素(x,y)是从 均值=-*n_data中对应位置的取值,标准差为1的正态分布中随机生成的
y1 = torch.ones() # 给每个元素一个1标签
x = torch.cat((x0, x1), ).type(torch.FloatTensor) # shape (, ) FloatTensor = -bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (,) LongTensor = -bit integer
# torch can only train on Variable, so convert them to Variable
x, y = Variable(x), Variable(y) # draw the data
plt.scatter(x.data.numpy()[:, ], x.data.numpy()[:, ], c=y.data.numpy())#c是一个颜色序列 #plt.show()
#神经网络模块
net2 = torch.nn.Sequential(
torch.nn.Linear(,),
torch.nn.ReLU(),
torch.nn.Linear(,)
) plt.ion()#在Plt.ion和plt.ioff之间的代码,交互绘图
plt.show()
#神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时间。
optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)#其实就是神经网络的反向传播,第一个参数是更新权重等参数,第二个对应的是学习率
loss_func = torch.nn.CrossEntropyLoss()#标签误差代价函数 for t in range():
out = net(x)
loss = loss_func(out,y)#计算损失
optimizer.zero_grad()#梯度置零
loss.backward()#反向传播
optimizer.step()#计算结点梯度并优化,
if t % == :
plt.cla()# Clear axis即清除当前图形中的之前的轨迹
prediction = torch.max(F.softmax(out), )[]#转换为概率,后面的一是最大值索引,如果为0则返回最大值
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, ], x.data.numpy()[:, ], c=pred_y, s=, lw=, cmap='RdYlGn')
accuracy = sum(pred_y == target_y) / .#求准确率
plt.text(1.5, -, 'Accuracy=%.2f' % accuracy, fontdict={'size': , 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
莫烦PyTorch学习笔记(五)——分类的更多相关文章
- 莫烦PyTorch学习笔记(五)——模型的存取
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- 莫烦PyTorch学习笔记(六)——批处理
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
- 莫烦pytorch学习笔记(二)——variable
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
- 莫烦PyTorch学习笔记(三)——激励函数
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
- 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )
CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...
- 莫烦 - Pytorch学习笔记 [ 一 ]
1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
随机推荐
- ECMAScript1.2 表达式|语句|break|continue
表达式 一个表达式可以产生一个值,有可能是运算,函数调用, 有可能是字面量,表达式可以放在任何需要值的地方. 语句 语句可以理解为一个行为,循环语句和判断语句就是典型的语句. 一个程序有很多个语句组成 ...
- 01、requests 基本使用
requests模块的基本使用 基于网络请求的模块. 环境的安装:pip install requests 作用:模拟浏览器发起请求 分析requests的编码流程: 1.指定url 2.发起了请求 ...
- 1-电脑C盘(系统盘)清理
推荐,亲测有效! 转自: https://baijiahao.baidu.com/s?id=1612762644229315967&wfr=spider&for=pc
- Django form组件 与 cookie/session
目录 一.form组件 二.cookie.session 返回Django 组件 一.form组件 1.1 以注册功能为例 注册功能 1.渲染前端标签获取用户输入 --> 渲染标签 2.获取用户 ...
- 关于python merge后数据行数增加的问题
其中一个可能的原因是 join 的 data 里面的列不唯一,也就是要匹配的表里面有些一行数据对应了被匹配表多条数据,这样出来可能会增加行数,可以再查一下被匹配表里的数据是否去重
- FaceNet pre-trained模型以及FaceNet源码使用方法和讲解
Pre-trained models Model name LFW accuracy Training dataset Architecture 20180408-102900 0.9905 CASI ...
- BlueHost主机建站方案怎样选择?
BlueHost是知名美国主机商,近年来BlueHost不断加强中国市场客户的用户体验,提供多种主机租用方案,基本能够满足各类网站建设需求.下面就和大家介绍一下建站应该怎样选择主机. 1.中小型网站 ...
- linux上给其他在线用户发送信息(wall, write, talk, mesg)
linux上给其他在线用户发送信息(wall, write, talk, mesg) 2018-01-05 lonskyMR 转自 恶之一眉 修改 微信分享: 设置登录提示 /et ...
- 关于DB9一些信号的缩写
https://www.cnblogs.com/CCJVL/archive/2010/02/04/1663565.html 场景:PCB板子与PC通过RS232连接,以下信号的方向相对于PCB板子而言 ...
- SQL Server Download
{ https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads }