LeNet 分类 FashionMNIST
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章
- AlexNet 分类 FashionMNIST
from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...
- gluon 实现多层感知机MLP分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...
- gluon实现softmax分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- CNN卷积神经网络详解
前言 在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...
- 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题
2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- Tensorflow学习教程------lenet多标签分类
本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...
- Tensorflow学习教程------实现lenet并且进行二分类
#coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...
随机推荐
- 深入理解JavaScript系列(48):对象创建模式(下篇)
介绍 本篇主要是介绍创建对象方面的模式的下篇,利用各种技巧可以极大地避免了错误或者可以编写出非常精简的代码. 模式6:函数语法糖 函数语法糖是为一个对象快速添加方法(函数)的扩展,这个主要是利用pro ...
- java 散列运算浅分析 hash()
文章部分代码图片和总结来自参考资料 哈希和常用的方法 散列,从中文字面意思就很好理解了,分散排列,我们知道数组地址空间连续,查找快,增删慢,而链表,查找慢,增删快,两者结合起来形成散列 ...
- C#在不同平台下DLL的引用问题
缘起 很多时候,我们需要引用在不同平台下的DLL,32位(X86)和64位(X64).如果平台错误,在C#中会引发BadImageFormatException异常. 解决思路 我们同时不能添加不同平 ...
- 版本控制器之SVN
1.开发中的实际问题 1.1 小明负责的模块就要完成了,就在即将Release之前的一瞬间,电脑突然蓝屏,硬盘光荣牺牲!几个月来的努力付之东流——需求之一:备份! 1.2 这个项目中需要一个很复杂的功 ...
- gradle -v不是外部命令, 内部命令,或批处理文件
安装完gradle并且配置了环境变量之后,使用windos+R,cmd 进入Dos命令gradle -v检测版本号出现了: 1 --首先找到gradle文件所在目录 一般是在C:\Users\su\. ...
- java web 之Session
1.Session简单介绍 由于Http是无状态的协议,所以服务端需要记录用户的状态时,就需要某种机制来识别具体的用户,实现这个机制的方式就是session. 典型的场景比如购物车,当你点击下单按钮时 ...
- android 动态库死机调试方法 .
原地址:http://blog.csdn.net/andyhuabing/article/details/7074979 这两种方法都不是我发明了,都是网上一些高手公共出来的调试方法,无奈找不到出处的 ...
- poj 1276(多重背包+最接近)
http://www.cnblogs.com/rainydays/archive/2013/03/08/2950258.html http://www.cnblogs.com/ziyi--caolu/ ...
- Node.js开发——MongoDB与Mongoose
为了保存网站的用户数据和业务数据,通常需要一个数据库.MongoDB和Node.js特别般配,因为MongoDB是基于文档的非关系型数据库,文档是按BSON(JSON的轻量化二进制格式)存储的,增删改 ...
- 第11章 Media Queries 与Responsive 设计
Media Queries--媒体类型(一) 随着科学技术不断的向前发展,网页的浏览终端越来越多样化,用户可以通过:宽屏电视.台式电脑.笔记本电脑.平板电脑和智能手机来访问你的网站.尽管你无法保证一个 ...