LeNet 分类 FashionMNIST
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章
- AlexNet 分类 FashionMNIST
from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...
- gluon 实现多层感知机MLP分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...
- gluon实现softmax分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- CNN卷积神经网络详解
前言 在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...
- 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题
2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- Tensorflow学习教程------lenet多标签分类
本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...
- Tensorflow学习教程------实现lenet并且进行二分类
#coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...
随机推荐
- 项目开发-->基础功能汇总
祭奠曾经逝去的青春…… 1.基础功能汇总-->身份认证及用户登录模块 2.基础功能汇总-->一键登录功能汇总 3.堆和栈 4.变量
- Struts2 学习(一)
一.Struts 介绍 1.Struts2的概述 1.早期开发模型Servlet+JSP+JavaBean(Model2)显得力不从心: 流程凌乱.数据传递无序.缺乏辅助功能. 2.MVC模式的轻量级 ...
- Python基础学习总结(二)
2.列表简介 Python有内置的一种数据类型列表:list. list是一种有序的集合. 列表由一系列按特定顺序排列的元素组合.用方括号 [ ] 来表示. list里面的元素的数据类型可以不同,比如 ...
- App Not Responsing
参见原文:http://rayleeya.iteye.com/blog/1955657 inputDispatchingTimedOut contentProviderNotResponsing se ...
- CSS单行、多行文本溢出显示省略号(……)解决方案
单行文本溢出显示省略号(-) text-overflow:ellipsis-----部分浏览器还需要加宽度width属性 .ellipsis{ overflow: hidden; text-overf ...
- csharp: Converting chinese character to Unicode
Function chinese2unicode(Str) Dim Str_one:Str_one = "" Dim Str_unicode:Str_unicode = " ...
- js中常见面试问题-笔记
原文参考https://mp.weixin.qq.com/s/mCVL6qI33XeTg4YGIKt-JQ 1.事件代理给父元素添加事件,利用事件冒泡原理,在根据e.target来获取子元素<u ...
- Mavn 使用介绍
1 Maven介绍 1.1 项目开发中遇到的问题 1.都是同样的代码,为什么在我的机器上可以编译执行,而在他的机器上就不行? 2.为什么在我的机器上可以正常打包,而配置管理员却打不出来? 3.项目 ...
- 统计Redis中各种数据的大小
如果 MySQL 数据库比较大的话,很容易就能查出是哪些表占用的空间: 不过如果 Redis 内存比较大的话, […]
- Android 初识Retrofit
什么是 Retrofit ? Retrofit 是一套 RESTful 架构的 Android(Java) 客户端实现,基于注解,提供 JSON to POJO(Plain Ordinary Java ...