pytorch之 classification
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # make fake data
n_data = torch.ones(100, 2)
x0 = torch.normal(2*n_data, 1) # class0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100) # class0 y data (tensor), shape=(100, 1)
x1 = torch.normal(-2*n_data, 1) # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100) # class1 y data (tensor), shape=(100, 1)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,1) LongTensor = 64-bit integer # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.out = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.out(x)
return x net = Net(n_feature=2, n_hidden=10, n_output=2) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted plt.ion() # something about plotting for t in range(100):
out = net(x) # input x and predict based on x
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 2 == 0:
# plot and show learning process
plt.cla()
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
pytorch之 classification的更多相关文章
- pytorch 5 classification 分类
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
- (转)Awesome PyTorch List
Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...
- 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)
项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...
- pytorch下对简单的数据进行分类(classification)
看了Movan大佬的文字教程让我对pytorch的基本使用有了一定的了解,下面简单介绍一下二分类用pytorch的基本实现! 希望详细的注释能够对像我一样刚入门的新手来说有点帮助! import to ...
- Iris Classification on PyTorch
Iris Classification on PyTorch code # -*- coding:utf8 -*- from sklearn.datasets import load_iris fro ...
- pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》
论文 < Convolutional Neural Networks for Sentence Classification>通过CNN实现了文本分类. 论文地址: 666666 模型图 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- PyTorch教程之Training a classifier
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
随机推荐
- Python CGI apache在windows下安装
1.首先去下载apache (http://httpd.apache.org/download.cgi)(Apache VC15 binaries and modules download) 2.下载 ...
- Django 数据库连接缓存的坑
https://www.cnblogs.com/xcsg/p/11446990.html
- 异数OS 织梦师-水桶(三)-- RAM共享存储方案
. 异数OS 织梦师-水桶(三)– RAM共享存储方案 本文来自异数OS社区 github: https://github.com/yds086/HereticOS 异数OS社区QQ群: 652455 ...
- 【杂项】各类文件头结合winhex使用-转载
———常用文件头——— JPEG (jpg),文件头:FFD8FFE1 PNG (png),文件头:89504E47 (0D0A1A0A) GIF (gif),文件头:47494638 ZIP Arc ...
- Oracle 11g 安装过程及测试方法
大家可以根据自己的操作系统是多少位(32位或64位)的,到官网下载相应的安装程序,如下图所示. 有一点需要注意,Oracle的安装程序分成2个文件,下载后将2个文件解压到同一目录即可. 下载完 ...
- 20200104模拟赛 问题C 上台拿衣服
题目 分析: 乍一看不就是从楼上扔鸡蛋那道题吗... 然后开始写写写... 设f [ i ] [ j ]表示 i 个记者膜 j 次可以验证多少层楼... 于是开始递推: 我们选取第 i 个记者去尝试其 ...
- 移除sitemap中的entity
下面截图是sitemap所在的位置 如果遇到什么原因,当前使用的entity被弃用需要删除,必须要把当前site map 引用的entity也一并删除. 不然会导致site map不能正常加载
- SpringCloud学习之搭建eureka集群,手把手教学,新手教程
一.为什么需要集群 上一篇文章讲解了如何搭建单个节点的eureka,这篇讲解如何搭建eureka集群,这里的集群还是本地不同的端口执行三个eureka,因为条件不要允许,没有三台电脑,所以大家将就一下 ...
- FTP服务后门利用
开门见山 1. 扫描同网段的靶场机,发现PCS,192.168.31.137 2. 扫描靶场机服务信息和服务版本 3. 快速扫描靶场机全部信息 4. 发现开放ftp服务,并扫描此ftp软件版本是否存在 ...
- AS中加载gradle时出现Gradle sync failed: Could not find com.android.tools.build:gradle.的错误
时间:2019/12/7 这次接着整理加载gradle时出现的错误 出现的错误: Gradle sync failed: Could not find com.android.tools.build: ...