import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt n_data = torch.ones(100, 2) # 100个具有2个属性的数据 shape=(100,2)
x0 = torch.normal(2*n_data, 1) # 根据原始数据生成随机数据,第一个参数是均值,第二个是方差,这里设置为1了,shape=(100,2)
y0 = torch.zeros(100) # 100个0作为第一类数据的标签,shape=(100,1)
x1 = torch.normal(-2*n_data, 1)
y1 = torch.ones(100) x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # cat数据合并 32-bit floating
y = torch.cat((y0, y1), 0).type(torch.LongTensor) # 64-bit integer x, y = Variable(x), Variable(y) plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], s=100, lw=0)
plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x net = Net(2, 10, 2) # 数据是二维的所以输入特征是2,输出是两种类别所以输出层特征是2
print(net)
> Net(
> (hidden): Linear(in_features=2, out_features=10, bias=True)
> (predict): Linear(in_features=10, out_features=2, bias=True)
> )
# plt.ion()
plt.show() optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss() # 交叉熵 CrossEntropy [0.1, 0.2, 0.7] [0,0,1] 数据越大,是这一类的概率越大 for t in range(100):
out = net.forward(x) # 数据经过所有的网络,输出预测值 loss = loss_func(out, y) # 输入与预测值之间的误差loss optimizer.zero_grad() # 梯度重置
loss.backward() # 损失值反向传播,计算梯度
optimizer.step() # 梯度优化 if t % 2 == 0:
# 画图部分 plot and show learning process
plt.cla()
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.pause(0.5) # plt.ioff()
plt.show()

END

pytorch 5 classification 分类的更多相关文章

  1. tensorflow学习之(九)classification 分类问题之分类手写数字0-9

    #classification 分类问题 #例子 分类手写数字0-9 import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

  2. pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》

    论文  < Convolutional Neural Networks for Sentence Classification>通过CNN实现了文本分类. 论文地址: 666666 模型图 ...

  3. pytorch解决鸢尾花分类

    半年前用numpy写了个鸢尾花分类200行..每一步计算都是手写的  python构建bp神经网络_鸢尾花分类 现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网 ...

  4. pytorch实现LeNet5分类CIFAR10

    关于LeNet-5 LeNet5的Pytorch实现在网络上已经有很多了,这里记录一下自己的实现方法. LeNet-5出自于Gradient-Based Learning Applied to Doc ...

  5. 用Pytorch训练MNIST分类模型

    本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...

  6. 【caffe范例详解】 - 1.Classification分类

    1. 安装 首先,导入numpy和matplotlib库 # numpy是常用的科学计算库,matplot是常用的绘图库 import numpy as np import matplotlib.py ...

  7. pytorch之 classification

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  8. pytorch 文本情感分类和命名实体识别NER中LSTM输出的区别

    文本情感分类: 文本情感分类采用LSTM的最后一层输出 比如双层的LSTM,使用正向的最后一层和反向的最后一层进行拼接 def forward(self,input): ''' :param inpu ...

  9. pytorch LSTM情感分类全部代码

    先运行main.py进行文本序列化,再train.py模型训练 dataset.py from torch.utils.data import DataLoader,Dataset import to ...

随机推荐

  1. cogs 306. [SGOI] 糊涂的记者

    306. [SGOI] 糊涂的记者 ★★★   输入文件:sign.in   输出文件:sign.out   评测插件时间限制:1 s   内存限制:128 MB [问题描述] 在如今的信息社会中,时 ...

  2. [React] Reference a node using createRef() in React 16.3

    In this lesson, we look at where we came from with refs in React. Starting with the deprecated strin ...

  3. 【MongoDB】The basic operation of Index in MongoDB

    In the past four blogs, we attached importance to the index, including description and comparison wi ...

  4. [BZOJ 3884][欧拉定理]上帝与集合的正确使用方法

    看看我们机房某畸形写的题解:http://blog.csdn.net/sinat_27410769/article/details/46754209 此题为popoQQQ神犇所出,在此orz #inc ...

  5. 4418: [Shoi2013]扇形面积并|二分答案|树状数组

    为何感觉SHOI的题好水. ..又是一道SB题 从左到右枚举每个区间,遇到一个扇形的左区间就+1.遇到右区间就-1,然后再树状数组上2分答案,还是不会码log的.. SHOI2013似乎另一道题发牌也 ...

  6. html5的postmessage实现js前端跨域訪问及调用解决方式

    关于跨域訪问.使用JSONP的方法.我前面已经demo过了.详细见http://supercharles888.blog.51cto.com/609344/856886,HTML5提供了一个很强大的A ...

  7. node-webkit 屏幕截图功能

    做 IM 屏幕截图是少不了的,之前 windows 版本是调用的 qq 输入法的截图功能,这个版本又再次尝试自己实现发现是可以的,getusermedia 的权限很高,代码如下 <!DOCTYP ...

  8. C# 监控Windows睡眠与恢复

    SystemEvents.PowerModeChanged += SystemEvents_PowerModeChanged; private void SystemEvents_PowerModeC ...

  9. Noip蒟蒻专用模板

    目录 模板 数论 线性筛素数 线性筛欧拉 裴蜀定理 卢卡斯定理 矩阵快速幂 逆元 高斯消元 图论 割点 最小生成树 倍增 SPFA 负环 堆优化迪杰斯特拉 匈牙利 数据结构 树状数组 ST表 线段树 ...

  10. HDU4920 矩阵乘法

    嗯嗯 就算是水题吧. (缩完行就15行) 题意:两个n*n的矩阵相乘(n<=800),结果对3取模 思路:先对3取模,所以两个矩阵里面会出现很多0,所以可以先枚举一个矩阵,只有当该位置不是0的时 ...