LeNet 分类 FashionMNIST
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章
- AlexNet 分类 FashionMNIST
from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...
- gluon 实现多层感知机MLP分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...
- gluon实现softmax分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- CNN卷积神经网络详解
前言 在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...
- 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题
2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- Tensorflow学习教程------lenet多标签分类
本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...
- Tensorflow学习教程------实现lenet并且进行二分类
#coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...
随机推荐
- SSIS教程:创建简单的ETL包
SSIS: Microsoft SQL Server Integration Services.是一个可用于生成高性能数据集成解决方案的平台,其中包括数据仓库的提取(Extract).转换(Trans ...
- .Net高阶异常处理第二篇~~ dump进阶之MiniDumpWriter
dump文件相信有些朋友已经很熟悉了,dump文件的作用在于保存进程运行时的堆栈信息,方便日后排查软件故障,提升软件质量.关于dump分析工具windbg.adplus的文章更多了,如果您还不知道怎么 ...
- JVM(五) class类文件的结构
概述 class类文件的结构可见下面这样图(出处见参考资料),可以参照下面的例子,对应十六进制码,找出找出相应的信息. 其中u2 , u4 表示的意思是占用两个字节和占用四个字节,下面我们将会各项说明 ...
- 宏定义中的反斜杠"\"和宏定义的细节说明
转载自:http://www.wtoutiao.com/p/K6csca.html 在阅读C语言代码经常可以看到代码中出现反斜杠"\",不是很明白它的意思,遂对反斜杠"\ ...
- 导出maven项目依赖的jar包
注意使用mvn命令是需要配置好maven的环境变量 一.导出到自定义目录中 在maven项目下创建lib文件夹,输入以下命令: mvn dependency:copy-dependencies -Do ...
- JAVA_SE_Day02 String 的正则表达式
字符串支持正则表达式的方法一: boolean matches(String regex) 注意: 给定的正则表达式就算不指定边界符(^,$),也会全匹配验证 空字符串和null 空字符串是看不见,而 ...
- LinkedList封装
LinkedList简单的封装 package com.cn.test.jihe.LinkedList; import java.util.NoSuchElementException; public ...
- SZU5
A - Couple doubi 这种题不要想复杂,直接找规律.找不出规律就打表找规律 #include <iostream> #include <string> #inclu ...
- thinkphp更新数据库的时候where('')为字符串
if($user->where('phone='.$phone)->save($dataList)){} if($user->where(array('phone' =>$ph ...
- Hibernate 性能优化一对一关联映射
概述: hibernate提供了两种映射一对一关联的方式:按照外键映射和按照主键映射. 下面以员工账号和员工档案为例 ,介绍两种映射方式,并使用这两种映射方式分别完成以下持久化操作: (1)保存员工档 ...