import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章

  1. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. CNN卷积神经网络详解

    前言   在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...

  6. 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题

    2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...

  7. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  8. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  9. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

随机推荐

  1. (C++学习)关于CString的一些疑问

    #include <iostream> #include <string> #include <afx.h> #include <vector> usi ...

  2. java--集合框架总结1--set总结

    一.集合框架的概述. 基础的数据结构有数组,链表,栈,队列,二叉树等,java中的数据结构,利用了这些基本的数据结构分别实现了很丰富的集合框架类型,下面简单地总结下关于java集合框架的基础内容,在进 ...

  3. java 类与对象基础整理

    之前学习javaSE的时候,没有针对性地对对类与对象的一些基础进行整理,下面这些内容是笔记内容整理后的,希望以后自己可以通过这些博客时常复习! 一.类and对象的基础 类似于类的生命啊,类与对象的关系 ...

  4. 六、cent OS其它常用命令

    进入根目录下的laycloud的目录cd /laycloud 进入当前目录下的目录cd laycloud 查看某个目录下的内容ls /laycloud 查看当前目录下的内容ls 查看当前目录下的内容读 ...

  5. http所有请求头在Console中打印

    1.目标:将http中的请求头全部打印在Console中 2.基本语句 //1.获得指定的头 String header = response.getHeader("User-Agert&q ...

  6. 【Android】12.0 UI开发(三)——列表控件ListView的简单实现2

    1.0 由于书上内容,已经和实际编程的兼容性已经不太友好,重写了项目,用于进一步学习列表控件ListView. 2.0 新建项目ListViewTest,其中文件目录如下: 3.0 ActivityC ...

  7. html标签篇(2)

    上次讲到<a>标签,并没有细说a标签用法. <!DOCTYPE html> <html lang="en"> <head>  < ...

  8. JS原生隐藏显示图片,点击切换图片的效果

    今天要说的内容,看标题就都能知道了!所有知识点一览无遗啊!咱们今天的东西,是纯纯的原生JS代码, 我先说一下要求, 1.有两个按钮,内容为显示,和换, 2.当点击显示的时候,按钮文字变成隐藏,同时图片 ...

  9. 关于Telnet使用

    一.telnet作用 可以使用telnet检查 ip port的连通性 语法: telnet ip port 注意点: 1.使用前先操作系统安装telnet: 2.ip port 中间没有 “:”: ...

  10. WinAPI: GetCurrentThread、GetCurrentThreadId、GetCurrentProcess、GetCurrentProcessId

    原文:http://www.cnblogs.com/del/archive/2008/03/10/1098311.html {返回当前线程的虚拟句柄} GetCurrentThread: THandl ...