import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # make fake data
n_data = torch.ones(100, 2)
x0 = torch.normal(2*n_data, 1) # class0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100) # class0 y data (tensor), shape=(100, 1)
x1 = torch.normal(-2*n_data, 1) # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100) # class1 y data (tensor), shape=(100, 1)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,1) LongTensor = 64-bit integer # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.out = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.out(x)
return x net = Net(n_feature=2, n_hidden=10, n_output=2) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted plt.ion() # something about plotting for t in range(100):
out = net(x) # input x and predict based on x
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 2 == 0:
# plot and show learning process
plt.cla()
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

pytorch之 classification的更多相关文章

  1. pytorch 5 classification 分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  3. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

  4. pytorch下对简单的数据进行分类(classification)

    看了Movan大佬的文字教程让我对pytorch的基本使用有了一定的了解,下面简单介绍一下二分类用pytorch的基本实现! 希望详细的注释能够对像我一样刚入门的新手来说有点帮助! import to ...

  5. Iris Classification on PyTorch

    Iris Classification on PyTorch code # -*- coding:utf8 -*- from sklearn.datasets import load_iris fro ...

  6. pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》

    论文  < Convolutional Neural Networks for Sentence Classification>通过CNN实现了文本分类. 论文地址: 666666 模型图 ...

  7. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  8. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  9. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

随机推荐

  1. Java set接口之HashSet集合原理讲解

    Set接口 java.util.set接口继承自Collection接口,它与Collection接口中的方法基本一致, 并没有对 Collection接口进行功能上的扩充,只是比collection ...

  2. Redis高可用方案-哨兵与集群

    Redis高可用方案 一.名词解释   二.主从复制 Redis主从复制模式可以将主节点的数据同步给从节点,从而保障当主节点不可达的情况下,从节点可以作为 后备顶上来,并且可以保障数据尽量不丢失(主从 ...

  3. svn或git 提交文件排除

    也可以参考  https://blog.csdn.net/chenmintong/article/details/79725324 乌龟git 过滤掉忽略文件(首先右键 某文件 删除并添加到忽略列表 ...

  4. vue+jest+vue-test-utils 单元测试

           jest是Facebook的一套开源的JavaScript测试框架,它集成了快照测试.断言.mock以及覆盖率报告等功能,很全面而且基本不需要太多的配置便可使用Vue-Test-Util ...

  5. 从0开发3D引擎(八):准备“搭建引擎雏形”

    大家好,现在开始本系列的第三部分,按照以下几个步骤来搭建引擎雏形: 1.分析引擎的需求 2.实现最小的3D程序 3.从中提炼引擎原型 4.一步一步地对引擎进行改进,使其具备良好的架构 5.实现与架构相 ...

  6. java面试| 线程面试题集合

    集合的面试题就不罗列了,基本上在深入理解集合系列已覆盖 「 深入浅出 」java集合Collection和Map 「 深入浅出 」集合List 「 深入浅出 」集合Set 这里搜罗网上常用线程面试题, ...

  7. linux 为动态分配的Virtualbox虚拟硬盘扩容

    如何为动态分配的Virtualbox虚拟硬盘扩容 查看虚拟硬盘是否是动态分配大小 打开虚拟机的设置界面,在左侧栏点击存储.在存储树下面选择你的虚拟硬盘.在右边可以看见虚拟硬盘的信息.在下面可以看见,我 ...

  8. 链接拼接的方法(用于解决同一个脚本返回两种不同的url链接的问题)

    实例一: 上图所示 爬虫返回的链接有一部分带有http前缀,有一部分没有,且也不知道具体哪些链接会出现没有前缀的情况 后面如果通过返回链接进行再次访问,那么肯定会出现报错的问题 思路: 判断 返回值内 ...

  9. [校内训练19_09_02]A

    题意 给出N 个形如$f_i(x) = a_i x^2 + b_i x $的二次函数. 有Q 次询问,每次给出一个x,询问$max{\{f_i(x)\}}$.$N,Q \leq 5*10^5$. 思考 ...

  10. Servlet梳理

    Servlet 梳理 概述 Web 技术成为当今主流的互联网 Web 应用技术之一,而 Servlet 是 Java Web 技术的核心基础. 要介绍 Servlet 必须要先把 Servlet 容器 ...