from mxnet import gluon,init
from mxnet.gluon import loss as gloss,nn
from mxnet.gluon import data as gdata
from mxnet import autograd,nd
import gluonbook as gb
import sys # 读取数据
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers) # 模型参数初始化
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(sigma=0.01)) # 损失函数
loss = gloss.SoftmaxCrossEntropyLoss() # 优化算法
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1}) def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) num_epochs = 5 def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X,y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat,y)
l.backward() if trainer is None:
gb.sgd(params,lr,batch_size)
else:
trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net)
print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc)) train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

gluon实现softmax分类FashionMNIST的更多相关文章

  1. 从零和使用mxnet实现softmax分类

    1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...

  2. 学习笔记TF010:softmax分类

    回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...

  3. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  4. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  5. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  6. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  7. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  8. Keras 多层感知机 多类别的 softmax 分类模型代码

    Multilayer Perceptron (MLP) for multi-class softmax classification: from keras.models import Sequent ...

  9. tf.nn.softmax 分类

    tf.nn.softmax(logits,axis=None,name=None,dim=None) 参数: logits:一个非空的Tensor.必须是下列类型之一:half, float32,fl ...

随机推荐

  1. web服务器架构演化及所其技术知识体系(分布式的由来)

    文章标题是我自己取的,内容来着百度百科k5665219的一篇回答,觉得讲的很不错就转载过来了. 最开始,由于某些想法,于是在互联网上搭建了一个网站,这个时候甚至有可能主机都是租借的,但由于这篇文章我们 ...

  2. 切换myEclipse工作空间后设置,myEclipse添加注释/设置豆沙背景颜色/调节字体大小

    一.添加注释 操作位置: 注释规范 Files/** * @文件名称: ${file_name} * @文件路径: ${package_name} * @功能描述: ${todo} * @作者: ${ ...

  3. 算法市场 Algorithmia

    算法市场 官网:(需要***,fan qiang,不然可能访问不了或登录不了) https://algorithmia.com/ 官方的例子: 我不用 curl 发请求,把 curl 命令粘贴给你们用 ...

  4. maven 引入本地 jar

    $ 参考1 : https://www.cnblogs.com/lixuwu/p/5855031.html ! 注: 参考1中的第二种方法,作者并未实际尝试,我尝试了,虽然在eclipse 中编译不报 ...

  5. (转)DB2中的一些函数

    DB2中的一些函数 原文:https://www.cnblogs.com/ShaYeBlog/archive/2012/08/27/2658025.html 最近用DB2,数据库之间的差异还是很大的, ...

  6. 程序包com.sun.image.codec.jpeg不存在

    在pox.xml中引入依赖 <dependency><groupId>rt</groupId><artifactId>rt</artifactId ...

  7. tornado handler 方法复用的 3 个方法

    tornado handler 调用 特性 在一次 tornado 请求中调用其他 tornado handler 中的方法, 比如 run 方法 引言 在后台开发中, 有时需要做一些功能的整合, 比 ...

  8. 两个三汇API使用的坑

    最近呼叫中心走火入魔了,我的<一步一步开发呼叫中心>系列编写过程中,遇到各种的问题,今天晚上,来记录一下纠结了我N久的一个问题: 内线通过板卡外呼时,如果对方的呼叫中心需要发送按键响应(如 ...

  9. node.js获取cookie

    node.js 获取cookie var Cookies ={}; if (req.headers.cookie != null) { req.headers.cookie.split(';').fo ...

  10. Java笔记之Scanner先读取一个数字,在读取一行字符串方法分析

    问题:大家在学习Java读取数据的时候一般都是使用Scanner方法读取数据,但是其中有一个小问题大家可能不知道, 就是我们在使用scanner的时候如果你先读取一个数字,在读取一行带有空格的字符串, ...