from mxnet import gluon,init
from mxnet.gluon import loss as gloss,nn
from mxnet.gluon import data as gdata
from mxnet import autograd,nd
import gluonbook as gb
import sys # 读取数据
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers) # 模型参数初始化
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(sigma=0.01)) # 损失函数
loss = gloss.SoftmaxCrossEntropyLoss() # 优化算法
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1}) def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) num_epochs = 5 def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X,y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat,y)
l.backward() if trainer is None:
gb.sgd(params,lr,batch_size)
else:
trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net)
print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc)) train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

gluon实现softmax分类FashionMNIST的更多相关文章

  1. 从零和使用mxnet实现softmax分类

    1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...

  2. 学习笔记TF010:softmax分类

    回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...

  3. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  4. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  5. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  6. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  7. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  8. Keras 多层感知机 多类别的 softmax 分类模型代码

    Multilayer Perceptron (MLP) for multi-class softmax classification: from keras.models import Sequent ...

  9. tf.nn.softmax 分类

    tf.nn.softmax(logits,axis=None,name=None,dim=None) 参数: logits:一个非空的Tensor.必须是下列类型之一:half, float32,fl ...

随机推荐

  1. Web 前端 中高难度问题(希望看完之后的你可以拿到Offer^v^)

    1. 解释 event loop Javascript是单线程的,所有的同步任务都会在主线程中执行. 主线程之外,还有一个任务队列.每当一个异步任务有结果了,就往任务队列里塞一个事件. 当主线程中的任 ...

  2. TCP/IP协议分为哪四层,具体作用是什么。

    TCP/IP通讯协议采用了4层的层级结构,每一层都呼叫它的下一层所提供的网络来完成自己的需求.这4层分别为: 应用层:应用程序间沟通的层,如简单电子邮件传输(SMTP).文件传输协议(FTP).网络远 ...

  3. Oracle RAC集群搭建(五)--oracle部署

    01,配置好环境 节点01--node1 ORACLE_BASE=/oracle/app/oracle ORACLE_HOME=$ORACLE_BASE/product//db_1 ORACLE_SI ...

  4. filter 静态资源

    package com.itheima.web.filter; import java.io.IOException; import javax.servlet.Filter; import java ...

  5. k8s单节点集群部署应用

    之所以用k8s来部署应用,就是因为k8s可以灵活的控制集群规模,进行扩充或者收缩.生产上我们要配置的参数较多,命令行的方式显然不能满足需求,我们应该使用基于配置文件的方式.接下来做一个部署的demo: ...

  6. nginx location 配置阐述优先级别使用说明

    使用nginx 有大半年了,它的高性能,稳定性表现很好. 这里也得到很多人的认可. 其中它的配置,有点像写程序一样,每行命令结尾一个";"号,语句块用"{}"括 ...

  7. ORACLE 查询被锁定表及解锁释放session的方法

    后台数据库操作某个表时处于假死状态,可能该表被某个用户锁定,导致其他用户无法继续操作, 如下是解决方案和实例. 查被锁的表,以及用户 SELECT object_name, machine, s.si ...

  8. QQ音乐:React v16 新特性实践

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文由QQ音乐技术团队发表于云+社区专栏 自从去年9月份 React 团队发布了 v16.0 版本开始,到18年3月刚发布的 v16.3 版 ...

  9. Windows 10 下彻底关闭 Hyper-V 服务

    由于最近需要用到VMWare Workstation 安装虚拟机,安装完成后,发现任何64位的系统都不能正常安装.可能是Hyper-V与VMWare Workstation的冲突造成的不兼容,所以就去 ...

  10. 网络连接和初始HTTP请求

    浏览器检索网页,先从URL开始,使用DNS确定IP地址,再用基于TCP和HTTP协议连接到服务器,请求相关的内容,得到相应,浏览器解析并呈现到屏幕上.服务器响应后,浏览器响应不会同时全部到达,会陆续到 ...