一、分类任务:

将以下两类分开。

创建数据代码:

# make fake data
n_data = torch.ones(100, 2)
x0 = torch.normal(2*n_data, 1) # class0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100) # class0 y data (tensor), shape=(100, 1)
x1 = torch.normal(-2*n_data, 1) # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100) # class1 y data (tensor), shape=(100, 1)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,) LongTensor = 64-bit integer # torch can only train on Variable, so convert them to Variable
x, y = Variable(x), Variable(y) plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

二、步骤

  1. 导入包

  2. 创建模型

  3. 设置优化器和损失函数

  4. 训练模型

三、代码:

导入包:

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline torch.manual_seed(1) # reproducible

创建模型:

class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.out = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.out(x)
return x

设置优化器和损失函数

#输入的x为2维张量,输出有两类
net = Net(n_feature=2, n_hidden=10, n_output=2) # define the network
print(net) # net architecture # Loss and Optimizer
# Softmax is internally computed.
# Set parameters to be updated.
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted

训练模型并画图展示

plt.ion()   # something about plotting
plt.show() for t in range(100):
out = net(x) # input x and predict based on x
loss = loss_func(out, y) # must be (1. nn output, 2. target), the target label is NOT one-hotted optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 10 == 0 or t in [3, 6]:
# plot and show learning process
plt.cla()
_, prediction = torch.max(F.softmax(out), 1) #这里是得到softmax之后最大概率的y预测值。
pred_y = prediction.data.numpy().squeeze()
print(pred_y)
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = sum(pred_y == target_y)/200.
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.show()
# plt.pause(0.1) plt.ioff()

结果展示:

Pytorch实战(3)----分类的更多相关文章

  1. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  2. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  3. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  4. pytorch实战(二)hw2——预测收入是否高于50000,分类问题

    代码和ppt: https://github.com/Iallen520/lhy_DL_Hw 遇到的一些细节问题: 1. X_train文件不带后缀名csv,所以不是规范的csv文件,不能直接用pd. ...

  5. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  6. Bert实战---情感分类

    1.情感分析语料预处理 使用酒店评论语料,正面评论和负面评论各5000条,用BERT参数这么大的模型, 训练会产生严重过拟合,,泛化能力差的情况, 这也是我们下面需要解决的问题; 2.sigmoid二 ...

  7. pytorch解决鸢尾花分类

    半年前用numpy写了个鸢尾花分类200行..每一步计算都是手写的  python构建bp神经网络_鸢尾花分类 现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网 ...

  8. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  9. Scrapy实战-新浪网分类资讯爬虫

    项目要求: 爬取新浪网导航页所有下所有大类.小类.小类里的子链接,以及子链接页面的新闻内容. 什么是Scrapy框架: Scrapy是用纯Python实现一个为了爬取网站数据.提取结构性数据而编写的应 ...

随机推荐

  1. 任务调度分配题两道 POJ 1973 POJ 1180(斜率优化复习)

    POJ 1973 这道题以前做过的.今儿重做一次.由于每个程序员要么做A,要么做B,可以联想到0/1背包(谢谢N巨).这样,可以设状态 dp[i][j]为i个程序员做j个A项目同时,最多可做多少个B项 ...

  2. 2.1-VLAN/TRUNK/VTP

    2.1-VLAN/TRUNK/VTP     注意:配置VLAN时要退出VLAN配置模式才会执行 如果VLAN被删除或者shutdown,那么属于这个vlan的接口将被阻塞(灯一直是橙色,变不了绿色) ...

  3. [RxJS] Throttling vs Debouncing

    DebounceTime: It's like delay, but passes only the most recent value from each burst of emissions. T ...

  4. spring基于通用Dao的多数据源配置

    有时候在一个项目中会连接多个数据库,须要在spring中配置多个数据源,近期就遇到了这个问题,因为我的项目之前是基于通用Dao的,配置的时候问题不断.这样的方式和资源文件冲突:扫描映射文件的话,Sql ...

  5. Yum重装走过的坑

    今天因为用yum方式安装mongo遇到报错,从而我选择卸载yum并重新安装. 我先选择了用rpm方式进行重装,从163的packages列表里面找到64位redhat6.5可以用的三个rpm包,安装过 ...

  6. cocos2d-x 3.7 win7 32+Android 环境配置

    之前用的cocos2d-x 2.2.6 版本号,近期换成了3.7.眼下的最新版.整个过程中也碰到了不少问题.如今已经成功移植到手机上了. 分享下整个过程,希望能帮到别人.(所需软件已打包) [下载软件 ...

  7. Android 组件ContentProvider

    Android 组件ContentProvider Android的数据存储有五种方式Shared Preferences.网络存储.文件存储.外储存储.SQLite,一般这些存储都仅仅是在单独的一个 ...

  8. Extjs显示图片

    1.首先做一个容器 xtype : 'container', // 第2行 anchor : '100%', layout : 'column', items : [{ columnWidth : 0 ...

  9. luogu1081 开车旅行 树上倍增

    题目大意 小A和小B决定利用假期外出旅行,他们将想去的城市从1到N编号,且编号较小的城市在编号较大的城市的西边,已知各个城市的海拔高度互不相同,记城市i 的海拔高度为Hi,城市i 和城市j 之间的距离 ...

  10. curses-键盘编码-openssl加解密【转】

    本文转载自;https://zhuanlan.zhihu.com/p/26164115 1.1 键盘编码 按键过程:当用户按下某个键时, 1.键盘会检测到这个动作,并通过键盘控制器把扫描码(scan ...