教程:

https://bennix.github.io/

https://bennix.github.io/blog/2017/12/14/chain_basic/

https://bennix.github.io/blog/2017/12/18/Chain_Tutorial1/

模块 功能
datasets 输入数据可以被格式化为这个类的模型输入。它涵盖了大部分输入数据结构的用例。
variable 它是一个函数/连接/Chain的输出。
functions 支持深度学习中广泛使用的功能的框架,例如 sigmoid, tanh, ReLU等
links 支持深度学习中广泛使用的层的框架,例如全连接层,卷积层等
Chain 连接和函数(层)连接起来形成一个“模型”。
optimizers 指定用于调整模型参数的什么样的梯度下降方法,例如 SGD, AdaGrad, Adam.
serializers 保存/加载训练状态。例如 model, optimizer 等
iterators 定义训练器使用的每个小批量数据。
training.updater 定义Trainer中使用的每个前向、反向传播的参数更新过程。
training.Trainer 管理训练器

下面是chainer模块的导入语句。

# Initial setup following 
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

示例 minst:

MNIST数据集由70,000个尺寸为28×28(即784个像素)的灰度图像和相应的数字标签组成。数据集默认分为6万个训练图像和10,000个测试图像。我们可以通过datasets.get_mnist()获得矢量化版本(即一组784维向量)。

train, test = datasets.get_mnist()

此代码自动下载MNIST数据集并将NumPy数组保存到 $(HOME)/.chainer 目录中。返回的训练集和测试集可以看作图像标签配对的列表(严格地说,它们是TupleDataset的实例)。

我们还必须定义如何迭代这些数据集。我们想要在数据集的每次扫描开始时对每个epoch的训练数据集进行重新洗牌。在这种情况下,我们可以使用iterators.SerialIterator

train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)

另一方面,我们不必洗牌测试数据集。在这种情况下,我们可以通过shuffle = False来禁止混洗。当底层数据集支持快速切片时,它使迭代速度更快。

test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

当所有的例子被访问时,我们停止迭代通过设定 repeat=False 。测试/验证数据集通常需要此选项;没有这个选项,迭代进入一个无限循环。

接下来,我们定义架构。我们使用一个简单的三层网络,每层100个单元。

class MLP(Chain):
def __init__(self, n_units, n_out):
super(MLP, self).__init__()
with self.init_scope():
# the size of the inputs to each layer will be inferred
self.l1 = L.Linear(None, n_units) # n_in -> n_units # 第一个参数为输入维数,第二个参数为输出维数。这里注意,第一个参数填none时,则输入维数将在第一次前向传播时确定
self.l2 = L.Linear(None, n_units) # n_units -> n_units
self.l3 = L.Linear(None, n_out) # n_units -> n_out def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
y = self.l3(h2)
return y

该链接使用relu()作为激活函数。请注意,“l3”链接是最终的全连接层,其输出对应于十个数字的分数。

为了计算损失值或评估预测的准确性,我们在上面的MLP连接的基础上定义一个分类器连接:

class Classifier(Chain):
def __init__(self, predictor):
super(Classifier, self).__init__()
with self.init_scope():
self.predictor = predictor def __call__(self, x, t):
y = self.predictor(x)
loss = F.softmax_cross_entropy(y, t)
accuracy = F.accuracy(y, t)
report({'loss': loss, 'accuracy': accuracy}, self)
return loss

这个分类器类计算准确性和损失,并返回损失值。参数对x和t对应于数据集中的每个示例(图像和标签的元组)。 softmax_cross_entropy()计算给定预测和基准真实标签的损失值。 accuracy() 计算预测准确度。我们可以为分类器的一个实例设置任意的预测器连接。

report() 函数向训练器报告损失和准确度。收集训练统计信息的具体机制参见 Reporter. 您也可以采用类似的方式收集其他类型的观测值,如激活统计。

请注意,类似上面的分类器的类被定义为chainer.links.Classifier。因此,我们将使用此预定义的Classifier连接而不是使用上面的示例。

model = L.Classifier(MLP(100, 10))  # the input size, 784, is inferred
optimizer = optimizers.SGD()
optimizer.setup(model)

现在我们可以建立一个训练器对象。

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')

第二个参数(20,’epoch’)表示训练的持续时间。我们可以使用epoch或迭代作为单位。在这种情况下,我们通过遍历训练集20次来训练多层感知器。

为了调用训练循环,我们只需调用run()方法。

这个方法执行整个训练序列。

上面的代码只是优化了参数。在大多数情况下,我们想看看培训的进展情况,我们可以在调用run方法之前使用扩展插入。

trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
epoch       main/accuracy  validation/main/accuracy
[J total [..................................................] 0.83%
this epoch [########..........................................] 16.67%
100 iter, 0 epoch / 20 epochs
inf iters/sec. Estimated time to finish: 0:00:00.
[4A[J total [..................................................] 1.67%
this epoch [################..................................] 33.33%
200 iter, 0 epoch / 20 epochs
270.19 iters/sec. Estimated time to finish: 0:00:43.672168.
[4A[J total [#.................................................] 2.50%
this epoch [#########################.........................] 50.00%
300 iter, 0 epoch / 20 epochs
271.99 iters/sec. Estimated time to finish: 0:00:43.017048.
[4A[J total [#.................................................] 3.33%
this epoch [#################################.................] 66.67%
400 iter, 0 epoch / 20 epochs

【chainer框架】【pytorch框架】的更多相关文章

  1. 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...

  2. PyTorch框架+Python 3面向对象编程学习笔记

    一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...

  3. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

  4. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  5. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  6. Pyinstaller打包Pytorch框架所遇到的问题

    目录 前言 基本流程 一.安装Pyinstaller 和 测试Hello World 二.打包整个项目,在本机上调试生成exe 三.在新电脑上测试 参考资料 前言   第一次尝试用Pyinstalle ...

  7. PyTorch框架起步

    PyTorch框架基本处理操作 part1:pytorch简介与安装 CPU版本安装:pip install torch1.3.0+cpu torchvision0.4.1+cpu -f https: ...

  8. 介绍开源的.net通信框架NetworkComms框架 源码分析

    原文网址: http://www.cnblogs.com/csdev Networkcomms 是一款C# 语言编写的TCP/UDP通信框架  作者是英国人  以前是收费的 售价249英镑 我曾经花了 ...

  9. web 框架的本质及自定义web框架 模板渲染jinja2 mvc 和 mtv框架 Django框架的下载安装 基于Django实现的一个简单示例

    Django基础一之web框架的本质 本节目录 一 web框架的本质及自定义web框架 二 模板渲染JinJa2 三 MVC和MTV框架 四 Django的下载安装 五 基于Django实现的一个简单 ...

  10. 那些年读过的书《Java并发编程实战》和《Java并发编程的艺术》三、任务执行框架—Executor框架小结

    <Java并发编程实战>和<Java并发编程的艺术>           Executor框架小结 1.在线程中如何执行任务 (1)任务执行目标: 在正常负载情况下,服务器应用 ...

随机推荐

  1. C# XCOPY命令 预先生成事件命令行”和“后期生成事件命令行”

    $(ConfigurationName) 当前项目配置的名称(例如,“Debug|Any CPU”). $(OutDir) 输出文件目录的路径,相对于项目目录.这解析为“输出目录”属性的值.它包括尾部 ...

  2. [Application]Ctrl+C终止程序代码

    代码如下: #include <stdio.h> #include <stdlib.h> #include <iostream> #include <sign ...

  3. 在ubuntu下安装sourceinsight

    执行更新与安装 wine: # sudo apt-get update # sudo apt-get install wine 下载SourceInsight,用wine来安装: 执行:wine so ...

  4. Redis 响应延迟问题排查

    计算延迟时间 如果你正在经历响应延迟问题,你或许能够根据应用程序的具体情况算出它的延迟响应时间,或者你的延迟问题非常明显,宏观看来,一目了然.不管怎样吧,用redis-cli可以算出一台Redis 服 ...

  5. svn 脚本替换

    #!/bin/bashfor i in `find /home/20180629tmp/svnfwq/uadminv4 -name .svn` do echo $i aa=`dirname $i` b ...

  6. php插入代码数据库

    插入代码:<?php $con = mysql_connect("localhost","peter","abc123"); if ( ...

  7. 阿里云ecs开启x11图形化桌面

    阿里云帮助文档:https://www.alibabacloud.com/help/zh/faq-detail/41227.htm 安装云服务器 ECS CentOS 7 图形化桌面 以安装 MATE ...

  8. 011杰信-创建购销合同Excel报表系列-4-建立合同货物(修改,删除):合同货物表是购销合同表的子表

    前面的一篇文章做的是修改删除,这篇文章做的是合同货物的修改和删除. 业务功能如下:

  9. Unity5 AssetBundle打包加载及服务器加载

    Assetbundle为资源包不是资源 打包1:通过脚本指定打包 AssetBundleBuild ab = new AssetBundleBuild                         ...

  10. PHP实现对站点内容外部链接的过滤方法

    熟悉SEO的朋友都知道,对于网站外部链接失效的情况如果链接带有rel="nofollow"属性可以避免不必要的损失.本文就以实例形式演示了PHP实现对站点内容外部链接的过滤方法.具 ...