torch分类问题
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # make fake data
n_data = torch.ones(100, 2)
x0 = torch.normal(2*n_data, 1) # class0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100) # class0 y data (tensor), shape=(100, 1)
x1 = torch.normal(-2*n_data, 1) # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100) # class1 y data (tensor), shape=(100, 1)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,) LongTensor = 64-bit integer # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.out = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.out(x)
return x net = Net(n_feature=2, n_hidden=10, n_output=2) # define the network,输入两个特征
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#分类输出的为概率 plt.ion() # something about plotting for t in range(100):
out = net(x) # input x and predict based on x,输出原值不是概率,需要用激活函数转化为概率
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 2 == 0:
# plot and show learning process
plt.cla()
prediction = torch.max(F.softmax(out), 1)[1]
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
输出结果是将散点图分为两类。
torch分类问题的更多相关文章
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 30个深度学习库:按Python、C++、Java、JavaScript、R等10种语言分类
30个深度学习库:按Python.C++.Java.JavaScript.R等10种语言分类 包括 Python.C++.Java.JavaScript.R.Haskell等在内的一系列编程语言的深度 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 深度学习之 cnn 进行 CIFAR10 分类
深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...
- [深度应用]·实战掌握PyTorch图片分类简明教程
[深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...
- 用Pytorch训练MNIST分类模型
本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...
- 学习笔记CB012: LSTM 简单实现、完整实现、torch、小说训练word2vec lstm机器人
真正掌握一种算法,最实际的方法,完全手写出来. LSTM(Long Short Tem Memory)特殊递归神经网络,神经元保存历史记忆,解决自然语言处理统计方法只能考虑最近n个词语而忽略更久前词语 ...
- pytorch解决鸢尾花分类
半年前用numpy写了个鸢尾花分类200行..每一步计算都是手写的 python构建bp神经网络_鸢尾花分类 现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网 ...
- [转] Torch中实现mini-batch RNN
工作中需要把一个SGD的LSTM改造成mini-batch的LSTM, 两篇比较有用的博文,转载mark https://zhuanlan.zhihu.com/p/34418001 http://ww ...
随机推荐
- 自己编写 EntityTypeConfiguration
1.新建类库 EFCore.EntityTypeConfig ,安装nuget PM> Install-Package Microsoft.EntityFrameworkCore 2.新建接口 ...
- Golang 入门系列(六)理解Go中的协程(Goroutine)
前面讲的都是一些Go 语言的基础知识,感兴趣的朋友可以先看看之前的文章.https://www.cnblogs.com/zhangweizhong/category/1275863.html. 今天就 ...
- maven eclipse web 项目 问题 cannot change version of project facet dynamic web module to 3.0
cannot change version of project facet dynamic web module to 3.0 修改 web.xml 头部 xsi:schemaLocation=&q ...
- c#, AOP动态代理实现动态权限控制(一)
因最近工作需要一个动态的权限配置功能,具体实现逻辑是c#的动态代理功能,废话不多说,直接干货.需求: 用户分为管理员.普通用户 不同用户拥有不同功能权限 用户的权限可配置 新增功能时,不用修改权限配置 ...
- golang lua使用示例
package main import ( "fmt" "github.com/yuin/gopher-lua" ) func hello(L *lua.LSt ...
- ubuntu fiddler firefox http网页不能访问 Secure Connection Failed
1. 给firefox导入fiddler的证书 1) fiddler:tools --> fiddler opthins --> https --> 勾选Capture HTTPS ...
- Vue组件以及组件之间的通信
一.组件的注册 1. 全局组件注册 1. 注册基本语法Vue.component Vue.component("my_header", { template: `<div&g ...
- Django模板语言进阶
一.母板 1.什么情况下使用母版 当多个页面的大部分内容都一样的时候,我们可以把相同的部分提取出来,放到一个单独的母版HTML文件中 然后在母版中定义需要被替换的block 例如:母板页面 <! ...
- [BZOJ 4818] [SDOI 2017] 序列计数
Description Alice想要得到一个长度为 \(n\) 的序列,序列中的数都是不超过 \(m\) 的正整数,而且这 \(n\) 个数的和是 \(p\) 的倍数. Alice还希望,这 \(n ...
- [十二省联考2019]字符串问题——后缀自动机+parent树优化建图+拓扑序DP+倍增
题目链接: [十二省联考2019]字符串问题 首先考虑最暴力的做法就是对于每个$B$串存一下它是哪些$A$串的前缀,然后按每组支配关系连边,做一遍拓扑序DP即可. 但即使忽略判断前缀的时间,光是连边的 ...