from mxnet import gluon,init,nd,autograd
from mxnet.gluon import data as gdata,nn
from mxnet.gluon import loss as gloss
import mxnet as mx
import time
import os
import sys # 建立网络
net = nn.Sequential()
# 使用较大的 11 x 11 窗口来捕获物体。同时使用步幅 4 来较大减小输出高和宽。
# 这里使用的输入通道数比 LeNet 中的也要大很多。
net.add(nn.Conv2D(96, kernel_size=11, strides=4, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 减小卷积窗口,使用填充为 2 来使得输入输出高宽一致,且增大输出通道数。
nn.Conv2D(256, kernel_size=5, padding=2, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 连续三个卷积层,且使用更小的卷积窗口。除了最后的卷积层外,进一步增大了输出通道数。
# 前两个卷积层后不使用池化层来减小输入的高和宽。
nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
nn.Conv2D(256, kernel_size=3, padding=1, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 这里全连接层的输出个数比 LeNet 中的大数倍。使用丢弃层来缓解过拟合。
nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
# 输出层。由于这里使用 Fashion-MNIST,所以用类别数为 10,而非论文中的 1000。
nn.Dense(10)) X = nd.random.uniform(shape=(1,1,224,224))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name,'output shape:\t',X.shape) # 读取数据
# fashionMNIST 28*28 转为224*224
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
'~', '.mxnet', 'datasets', 'fashion-mnist')):
root = os.path.expanduser(root) # 展开用户路径 '~'。
transformer = []
if resize:
transformer += [gdata.vision.transforms.Resize(resize)]
transformer += [gdata.vision.transforms.ToTensor()]
transformer = gdata.vision.transforms.Compose(transformer)
mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = gdata.DataLoader(
mnist_train.transform_first(transformer), batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(
mnist_test.transform_first(transformer), batch_size, shuffle=False,
num_workers=num_workers)
return train_iter, test_iter batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224) def accuracy(y_hat,y):
return (y_hat.argmax(axis=1)==y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter,net,ctx):
acc = nd.array([0],ctx=ctx)
for X,y in data_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx)
acc+=accuracy(net(X),y)
return acc.asscalar() / len(data_iter) # 训练模型
def train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs):
print('training on',ctx)
loss = gloss.SoftmaxCrossEntropyLoss() for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
start = time.time()
for X,y in train_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx) with autograd.record():
y_hat = net(X)
l = loss(y_hat,y) l.backward()
trainer.step(batch_size) train_l_sum += l.mean().asscalar()
train_acc_sum += evaluate_accuracy(test_iter,net,ctx)
test_acc = evaluate_accuracy(test_iter,net,ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch+1,train_l_sum/len(train_iter),test_acc,time.time()-start)) def try_gpu():
try:
ctx = mx.gpu()
_ = nd.zeros((1,),ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx lr = 0.01
num_epochs = 5
ctx = try_gpu() net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs)

AlexNet 分类 FashionMNIST的更多相关文章

  1. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  5. Pytorch分类和准确性评估--基于FashionMNIST数据集

    最近在学习Pytorch v1.3最新版和Tensorflow2.0. 我学习Pytorch的主要途径:莫烦Python和Pytorch 1.3官方文档 ,Pytorch v1.3跟之前的Pytorc ...

  6. 【分类】AlexNet论文总结

    目录 0. 论文链接 1. 概述 2. 对数据集的处理 3. 网络模型 3.1 ReLU Nonlinearity 3.2 Training on multiple GPUs 3.3 Local Re ...

  7. AlexNet实现cifar10数据集分类

    import tensorflow as tf import os from matplotlib import pyplot as plt import tensorflow.keras.datas ...

  8. 从头学pytorch(十五):AlexNet

    AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...

  9. 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型

    目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...

随机推荐

  1. django基础一:web、wsgi、mvc、mtv

    一.web框架 web框架,即framework,特指为解决一个开放性问题而设计的具有一定约束性的支撑结构,使用框架可以快速开发特定的系统.他山之石,可以攻玉.python的所有web框架,都是对so ...

  2. Intellij IDEA 各种乱码解决方案 posted @ 2017-06-23 15:31:06

    一次解决所有问题,只需做配置文件的修改即可 解决方案:       在      IntelliJ IDEA 2016.1\bin\idea64.exe.vmoptions        Intell ...

  3. eclipse切换workspace后配置问题

    正常情况下如果切换了eclipse的workspace后,需要重新配置eclipse,但是可以将原工作目录中的.metadata/.plugins/org.eclipse.core.runtime拷贝 ...

  4. Redis安装、配置

    一.Redis安装 Linux安装 下载tar包,移至Linux目录下 解压:tar -zxvf redis-4.0.1.tar.gz 安装gcc:yum install gcc-c++(编译失败需安 ...

  5. VC++ 崩溃处理以及打印调用堆栈

    title: VC++ 崩溃处理以及打印调用堆栈 tags: [VC++, 结构化异常处理, 崩溃日志记录] date: 2018-08-28 20:59:54 categories: windows ...

  6. SPOJ:COT2 Count on a tree II

    题意 给定一个n个节点的树,每个节点表示一个整数,问u到v的路径上有多少个不同的整数. n=40000,m=100000 Sol 树上莫队模板题 # include <bits/stdc++.h ...

  7. mui.ajax()和asp.net sql服务器数据交互【2】json数组和封装

    今天没有做循环创建显示:可以参考张鑫旭的文章:<基于HTML模板和JSON数据的JavaScript交互> 1.ashx页面代码 //下面的封装一般框架底层都是写好的:连接 数据库和获取D ...

  8. chrome 控制台里 打印对象

    我们经常使用 chrome 的 控制台 console.log()     打印 但有时候我们需要把一个对象复制下来(而这个对象嵌套比较深) 打印出来的我们不好复制 如下图 我们可以使用谷歌控制台的c ...

  9. C语言——栈的基本运算在顺序栈上的实现

    头文件 Seqstack.h #define maxsize 6 //const int maxsize = 6; // 顺序栈 typedef struct seqstack { int data[ ...

  10. C# 索引器的使用

    索引器允许类或者结构的实例按照与数组相同的方式进行索引取值,索引器与属性类似,不同的是索引器的访问是带参的. 索引器和数组比较: (1)索引器的索引值(Index)类型不受限制 (2)索引器允许重载 ...