import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章

  1. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. CNN卷积神经网络详解

    前言   在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...

  6. 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题

    2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...

  7. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  8. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  9. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

随机推荐

  1. Silverlight & Blend动画设计系列十三:三角函数(Trigonometry)动画之飘落的雪花(Falling Snow)

    平时我们所看到的雪花(Falling Snow)飘飘的效果实际上也是一个动画,是由许多的动画对象共同完成的一个界面效果.对于不同大小的雪片可以通过缩放变换(ScaleTransform)功能特性确定, ...

  2. springboot伪静态

    在日常网站访问中,会把动态地址改造成伪静态地址. 例如: 访问新闻栏目 /col/1/,这是原有地址,如果这样访问,不利于搜索引擎检索收录,同时安全性也不是很好. 改造之后: /col/1.html. ...

  3. shell输入与输出功能

    一.shell输入功能 1. 2. 二.shell输出功能 1.字符界面前景颜色 2.字符界面背景颜色 3.其他输出命令 ①cat 输出文本,将文本的格式也输出 ②tee 既输出,也保存到文件里 ③m ...

  4. python的爬虫

    requests库的安装 https://blog.csdn.net/xiaokuang5020/article/details/80580631 Response对象属性 属性 说明 r.statu ...

  5. Open Live Writer 安装和博客账号配置

    打开Open Live Writer就像您的博客的Word一样.打开Live Writer是一个功能强大,轻量级的博客编辑器,允许您创建博客文章,添加照片和视频,然后发布到您的网站. Open Liv ...

  6. 初步理解require.js模块化编程

    初步理解require.js模块化编程 一.Javascript模块化编程 目前,通行的Javascript模块规范共有两种:CommonJS和AMD. 1.commonjs 2009年,美国程序员R ...

  7. [转载]hive中order by,sort by, distribute by, cluster by作用以及用法

    1. order by     Hive中的order by跟传统的sql语言中的order by作用是一样的,会对查询的结果做一次全局排序,所以说,只有hive的sql中制定了order by所有的 ...

  8. 【转】ArcGIS Server10.1安装常见问题及解决方案

    转载自:http://www.higis.cn/Tech/tech/tId/85/ 最近因为更换系统的原因,重新安装了ArcGISServer 10.1.过程中遇到了几个小问题,虽然都一一解决了,但也 ...

  9. EF+Oracle

    一个小项目,设计到几十张表,但都是简单的增删改查,所以呢,想偷懒用EF. 结果,在.NET4.0下,死活都不行.最后在Oracle官方找到demo,上面清清楚楚的写着必须>NET4.5. 看着E ...

  10. Python爬虫教程-22-lxml-etree和xpath配合使用

    Python爬虫教程-22-lxml-etree和xpath配合使用 lxml:python 的HTML/XML的解析器 官网文档:https://lxml.de/ 使用前,需要安装安 lxml 包 ...