import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章

  1. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. CNN卷积神经网络详解

    前言   在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...

  6. 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题

    2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...

  7. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  8. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  9. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

随机推荐

  1. C#语言-06.XML

    a. XML:称为可扩展标记性语言,它主要用于描述数据 i. 特点: . XML 中用于描述数据的各个节点可以自由扩展 . XML 文件中的节点区分大小写 . XML 中的每对标记通常被称为节点,它们 ...

  2. Spring Cloud个组件原理

    引言 面试中面试官喜欢问组件的实现原理,尤其是常用技术,我们平时使用了SpringCloud还需要了解它的实现原理,这样不仅起到举一反三的作用,还能帮助轻松应对各种问题及有针对的进行扩展.以下是 课程 ...

  3. Gradle sync failed: Cannot set the value of read-only property 'outputFile'

    错误 Gradle sync failed: Cannot set the value of read-only property 'outputFile' 原因 gradle打包,自定义apk名称代 ...

  4. rpm的参数

      rpm 包的参数如下: -e 卸载rpm包 -q 查询已安装的软件信息 -i 安装rpm包 -u 升级rpm包 --replacepkgs 重新安装rpm包 --justdb 升级数据库,不修改文 ...

  5. Git版本控制工具(1)

    学习Git的最佳资料网站: https://www.liaoxuefeng.com/wiki/0013739516305929606dd18361248578c67b8067c8c017b000/ 这 ...

  6. 二、NAT(地址转换模式)

    刚刚我们说到,如果你的网络ip资源紧缺,但是你又希望你的虚拟机能够联网,这时候NAT模式是最好的选择.NAT模式借助虚拟NAT设备和虚拟DHCP服务器,使得虚拟机可以联网.其网络结构如下图所示: NA ...

  7. CSS之after与before的content 和 attr 配合使用

    content 和 attr 配合使用 如果你不想把content内容在CSS里写死,那你可以使用attr表达式来从页面元素中动态的获取内容: /* <div data-line="1 ...

  8. google搜索使用技巧

    1.输入框所有空格都被理解为加号2.搜索多个单词时,需要加上引号,会当字符串处理3.使用-(减号)剔除指定条件,如:'mongdb'-'nodejs'4.可以使用通配符,如'vue *'5.在指定网站 ...

  9. 如何开发一个Servlet

    1 如何开发一个Servlet 1.1 步骤: 1)编写java类,继承HttpServlet类 2)重新doGet和doPost方法 3)Servlet程序交给tomcat服务器运行!! 3.1 s ...

  10. Open images from USB camera on linux using V4L2 with OpenCV

    I have always been using OpenCV's VideoCapture API to capture images from webcam or USB cameras. Ope ...