from mxnet import gluon,init
from mxnet.gluon import loss as gloss,nn
from mxnet.gluon import data as gdata
from mxnet import autograd,nd
import gluonbook as gb
import sys # 读取数据
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers) # 模型参数初始化
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(sigma=0.01)) # 损失函数
loss = gloss.SoftmaxCrossEntropyLoss() # 优化算法
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1}) def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) num_epochs = 5 def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X,y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat,y)
l.backward() if trainer is None:
gb.sgd(params,lr,batch_size)
else:
trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net)
print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc)) train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

gluon实现softmax分类FashionMNIST的更多相关文章

  1. 从零和使用mxnet实现softmax分类

    1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...

  2. 学习笔记TF010:softmax分类

    回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...

  3. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  4. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  5. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  6. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  7. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  8. Keras 多层感知机 多类别的 softmax 分类模型代码

    Multilayer Perceptron (MLP) for multi-class softmax classification: from keras.models import Sequent ...

  9. tf.nn.softmax 分类

    tf.nn.softmax(logits,axis=None,name=None,dim=None) 参数: logits:一个非空的Tensor.必须是下列类型之一:half, float32,fl ...

随机推荐

  1. 安装软件或运行软件时提示缺少api-ms-win-crt-runtime库解决方法

    最近碰到一个问题,在我软件安装或运行时会提示缺少api-ms-win-crt-runtime-|1-1-0.dll 当然第一个想到的是运行库没有装,但是很清楚的是我的电脑是装过vc_redist_20 ...

  2. Android应用捕获全局异常自定义处理

    [2016-06-30]最新的全局异常处理DRCrashHandler已经集成在DR_support_lib库中 具体请看: https://coding.net/u/wrcold520/p/DR_s ...

  3. HUID 5558 Alice's Classified Message 后缀数组+单调栈+二分

    http://acm.hdu.edu.cn/showproblem.php?pid=5558 对于每个后缀suffix(i),想要在前面i - 1个suffix中找到一个pos,使得LCP最大.这样做 ...

  4. 【转】python平台libsvm安装

    来源:http://blog.csdn.net/prom1201/article/details/51382358 网上有很多麻烦的在win64机器上安装libsvm的步骤,实际上只要在下面网站找到l ...

  5. 读书笔记-NIO的工作方式

    读书笔记-NIO的工作方式 1.BIO是阻塞IO,一旦阻塞线程将失去对CPU的使用权,当前的网络IO有一些解决办法:1)一个客户端对应一个处理线程:2)采用线程池.但也会出问题. 2.NIO的关键类C ...

  6. 远程SQL Server连接不上

    运行 cmd -> 输入 netsh winsock reset重启后 应该可以连接sql了

  7. spring boot 2.0.0 + mybatis 报:Property 'sqlSessionFactory' or 'sqlSessionTemplate' are required

    spring boot 2.0.0 + mybatis 报:Property 'sqlSessionFactory' or 'sqlSessionTemplate' are required 无法启动 ...

  8. 10th week task -3 Arrow function restore

    Arrow function restore 为什么叫Arrow Function?因为它的定义用的就是一个箭头: x => x * x 上面的箭头函数相当于: function (x) { r ...

  9. String StringBuffer StringBuilder对比

    1.相同点 三者都可以用来存储字符串类型数据. 2.不同点 String类型对象内容不可变,每变化一次都会创建一个新的对象. StringBuiler与StringBuffer的内容与长度均可以发生变 ...

  10. vuejs源码摘抄

    订阅功能的部分实现代码如下: /* */ var uid = 0; /** * A dep is an observable that can have multiple * directives s ...