这里使用pytorch进行一个简单的二分类模型

导入所有我们需要的库

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

接着我们这里 生成我们需要的假数据

# set seed
torch.manual_seed(1) # make fake data
n_data = torch.ones(100, 2)
x0 = torch.normal(2 * n_data, 1)
y0 = torch.zeros(100)
x1 = torch.normal(-2 * n_data, 1)
y1 = torch.ones(100) x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
y = torch.cat((y0, y1), ).type(torch.LongTensor)

  

我们先定义好我们需要的net的这个类

class Net(torch.nn.Module):

    def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.out = torch.nn.Linear(n_hidden, n_output) def forward(self, x):
x = F.relu(self.hidden(x))
x = self.out(x)
return x

现在开始搭建我们需要的网络

我们构建一个只有1个隐藏层的网络

用SGD的方法对损失方程进行优化

然后用交叉熵来作为我们loss function

net = Net(n_feature=2, n_hidden=10, n_output=2)
print(net) optimizer = torch.optim.SGD(net.parameters(), lr=0.0015)
loss_func = torch.nn.CrossEntropyLoss()

接着我们开始训练我们的网络

plt.ion()

for t in range(200):
out = net(x)
loss = loss_func(out, y) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() if t % 2 == 0:
plt.cla()
predcition = torch.max(out, 1)[1]
pred_y = predcition.data.numpy()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

  

接着我们可以看到  已经把我们做的假数据成功分成了两类

pytorch tutorial 2的更多相关文章

  1. Pytorch tutorial 之Datar Loading and Processing (1)

    引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...

  2. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  3. Pytorch tutorial 之Datar Loading and Processing (2)

    上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...

  4. Pytorch tutorial 之Transfer Learning

    引自官方:  Transfer Learning tutorial Ng在Deeplearning.ai中讲过迁移学习适用于任务A.B有相同输入.任务B比任务A有更少的数据.A任务的低级特征有助于任务 ...

  5. pytorch tutorial 1

    这里用torch 做一个最简单的测试 目标就是我们用torch 建立一个一层的网络,然后拟合一组可以回归的数据 import torch from torch.autograd import Vari ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  8. 吐血整理:PyTorch项目代码与资源列表 | 资源下载

    http://www.sohu.com/a/164171974_741733   本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文 ...

  9. Ubuntu 16.04上anaconda安装和使用教程,安装jupyter扩展等 | anaconda tutorial on ubuntu 16.04

    本文首发于个人博客https://kezunlin.me/post/23014ca5/,欢迎阅读最新内容! anaconda tutorial on ubuntu 16.04 Guide versio ...

随机推荐

  1. PHP MySQLi 参考手册

    PHP MySQLi函数 PHP MySQLi是MySQL的增强版本,PHP7 已经废弃了MySQL扩展,全面推荐使用MySQLi或者PDO. MySQLi安装>>>>> ...

  2. 高强度学习训练第一天总结:Java内存区域

    ---恢复内容开始--- 程序计数器: 程序计数器(Program Counter Register) 是一块较小的空间,他可以看作是当前线程所执行的字节码的行号指示器.在虚拟机的概念模型里(仅是概念 ...

  3. 构建maven项目,自定义目录结构方法

    构建maven项目 创建自定义的文件目录方法: 在项目名称右键-->Builder Path-->Configure Builder Path...Source菜单下的Add Folder ...

  4. 如何使 highchart图表标题文字可选择复制

    highchart图表的一个常见问题是不能复制文字 比如官网的某个图表例子,文字不能选择,也无法复制,有时产品会抓狂... 本文给出一个简单的方案,包括一些解决的思路,希望能帮助到有需要的人 初期想了 ...

  5. AI 图像识别的测试

    随着AI 的浪潮发展,AI 的应用场景越来越广泛,其中计算机视觉更是运用到我们生活中的方方面面.作为一个测试人员,需要紧跟上 AI 的步伐,快速从传统业务测试,转型到 AI 的测试上来.而人脸识别作为 ...

  6. 更换Ubuntu软件源

    对于Ubuntu系统, 不同的版本的源都不一样,每一个版本都有自己专属的源. 而对于 Ubuntu 的同一个发行版本,它的源又分布在全球范围内的服务器上.Ubuntu 默认使用的官方源的服务器在欧洲, ...

  7. Unity检视面板的继承方法研究 (二)

    之前做了普通对象的可继承的检视面板类, 现在想要实现对Unity自带的检视面板的继承的话, 要怎样写呢? 万变不离其宗,  仍然是围绕UnityEditor.Editor.CreateEditor 这 ...

  8. mysql的floor()报错注入方法详细分析

    刚开始学习sql注入,遇见了 select count(*) from table group by floor(rand(0)*2); 这么条语句.在此做个总结. (更好的阅读体验可访问 这里 ) ...

  9. JS高阶---进程与线程

    [大纲] 二级大纲: 三级大纲: [主体] (1)进程process 如下所示,两者内存空间相互独立 (2)线程thread (3)图解 注意:有的程序是多进程的,有的时单进程的 (4)单线程与多线程 ...

  10. mac 修改 vim 配色

    1. 看看系统有哪些自带配色方案 ls /usr/share/vim/vim73/colors README.txt darkblue.vim delek.vim elflord.vim koehle ...