import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章

  1. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. CNN卷积神经网络详解

    前言   在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...

  6. 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题

    2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...

  7. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  8. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  9. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

随机推荐

  1. MySQL触发器基本介绍

    基本简介: 1.触发器可以让你在执行insert,update,delete语句的时候,执行一些特定的操作.并且可以在MySQL中指定是在sql语句执行前触发还是执行后触发. 2.触发器没有返回值. ...

  2. WP手机短信导出方法和MSG格式文件阅读器的实现

    最近想起来自己一直扔在抽屉里的Nokia920T里还存着珍贵的短信,觉得把它导出来存到电脑上比较稳妥也方便阅读.经过搜索找到一下方法:到应用市场里搜索contacts+message backup,安 ...

  3. 借助 CORS 从 JavaScript 使用 API 应用

    应用服务提供内置的跨域资源共享 (CORS) 支持,可让 JavaScript 客户端对 API 应用中托管的 API 进行跨域调用.应用服务允许配置对 API 的 CORS 访问,无需在 API 中 ...

  4. TortoiseGit记住用户名&密码

    配置并安装好git之后鼠标右键: 在全局配置文件末尾添加一行: [credential] helper = store *主意保存时以utf-8格式保存,否则中文可能会乱码,这样下次只需输入一次用户名 ...

  5. ElementUI组件库常见方法及问题汇总(持续更新)

    本文主要介绍在使用ElementUI组件库的时候,常遇见的问题及使用到的方法,汇总记录便于查找. 1.表单 阻止表单的默认提交 <!-- @submit.native.prevent --> ...

  6. Apache Maven 3.5.0配置安装

    1.maven 3.5 下载地址:http://maven.apache.org/download.cgi 2.下载了解压到 3.配置环境变量 4.测试看是否安装成功 5.maven配置(全局配置,用 ...

  7. 关于zepto 选择特定值的input 报错问题

    zepto 选择特定值的input  时,需要用单引号或双引号引用这个特定值 否则 报错

  8. 归并排序算法Matlab实现

    Matlab一段时间不用发现有些生疏了,就用归并排序来练手吧.代码没啥说的,百度有很多.写篇博客,主要是记下matlab语法,以后备查.   测试代码 srcData = [1,3,2,4,6,5,8 ...

  9. Perl学习笔记(3)----遍历哈希表的一个容易疏忽的地方

    今天做 Google的 Code Jam 上的一道题目:https://code.google.com/codejam/contest/351101/dashboard#s=p2,用Perl语言解答的 ...

  10. oozie安装总结

    偶然的机会,去面试的时候听面试官讲他们的调度系统是基于hue+oozie,以前一直没有接触过,今天趁有空,尝试一下oozie 1.环境说明 cat /etc/issue  CentOS release ...