接上一篇机器学习笔记(3):多类逻辑回归继续,这次改用gluton来实现关键处理,原文见这里 ,代码如下:

import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet import autograd def transform(data, label):
return data.astype('float32')/255, label.astype('float32') mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform) def show_images(images):
n = images.shape[0]
_, figs = plt.subplots(1, n, figsize=(15, 15))
for i in range(n):
figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
figs[i].axes.get_xaxis().set_visible(False)
figs[i].axes.get_yaxis().set_visible(False)
plt.show() def get_text_labels(label):
text_labels = [
'T 恤', '长 裤', '套头衫', '裙 子', '外 套',
'凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'
]
return [text_labels[int(i)] for i in label] data, label = mnist_train[0:10] print('example shape: ', data.shape, 'label:', label) show_images(data) print(get_text_labels(label)) batch_size = 256 train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False) num_inputs = 784
num_outputs = 10 W = nd.random_normal(shape=(num_inputs, num_outputs))
b = nd.random_normal(shape=num_outputs)
params = [W, b] for param in params:
param.attach_grad() def accuracy(output, label):
return nd.mean(output.argmax(axis=1) == label).asscalar() def _get_batch(batch):
if isinstance(batch, mx.io.DataBatch):
data = batch.data[0]
label = batch.label[0]
else:
data, label = batch
return data, label def evaluate_accuracy(data_iterator, net):
acc = 0.
if isinstance(data_iterator, mx.io.MXDataIter):
data_iterator.reset()
for i, batch in enumerate(data_iterator):
data, label = _get_batch(batch)
output = net(data)
acc += accuracy(output, label)
return acc / (i+1) #使用gluon定义计算模型
net = gluon.nn.Sequential()
with net.name_scope():
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(10))
net.initialize() #损失函数(使用交叉熵函数)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss() #使用梯度下降法生成训练器,并设置学习率为0.1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1}) for epoch in range(5):
train_loss = 0.
train_acc = 0.
for data, label in train_data:
with autograd.record():
output = net(data)
#计算损失
loss = softmax_cross_entropy(output, label)
loss.backward()
#使用sgd的trainer继续向前"走一步"
trainer.step(batch_size) train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label) test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc)) data, label = mnist_test[0:10]
show_images(data)
print('true labels')
print(get_text_labels(label)) predicted_labels = net(data).argmax(axis=1)
print('predicted labels')
print(get_text_labels(predicted_labels.asnumpy()))

相对上一版原始手动方法,使用gluon修改的地方都加了注释,不多解释。运行效果如下:

相对之前的版本可以发现,几乎相同的参数,但是准确度有所提升,从0.7几上升到0.8几,10个里错误的预测数从4个下降到3个,说明gluon在一些细节上做了更好的优化。关于优化的细节,这里有一些讨论,供参考

机器学习笔记(4):多类逻辑回归-使用gluton的更多相关文章

  1. 吴恩达机器学习笔记 —— 7 Logistic回归

    http://www.cnblogs.com/xing901022/p/9332529.html 本章主要讲解了逻辑回归相关的问题,比如什么是分类?逻辑回归如何定义损失函数?逻辑回归如何求最优解?如何 ...

  2. 吴恩达机器学习笔记14-逻辑回归(Logistic Regression)

    在分类问题中,你要预测的变量

  3. 【转】机器学习笔记之(3)——Logistic回归(逻辑斯蒂回归)

    原文链接:https://blog.csdn.net/gwplovekimi/article/details/80288964 本博文为逻辑斯特回归的学习笔记.由于仅仅是学习笔记,水平有限,还望广大读 ...

  4. 机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归

    机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归 关键字:Logistic回归.python.源码解析.测试作者:米仓山下时间:2018- ...

  5. Python机器学习笔记:使用Keras进行回归预测

    Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...

  6. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  7. Python机器学习笔记:不得不了解的机器学习面试知识点(1)

    机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...

  8. cs229 斯坦福机器学习笔记(一)-- 入门与LR模型

    版权声明:本文为博主原创文章,转载请注明出处. https://blog.csdn.net/Dinosoft/article/details/34960693 前言 说到机器学习,非常多人推荐的学习资 ...

  9. Python机器学习笔记:不得不了解的机器学习知识点(2)

    之前一篇笔记: Python机器学习笔记:不得不了解的机器学习知识点(1) 1,什么样的资料集不适合用深度学习? 数据集太小,数据样本不足时,深度学习相对其它机器学习算法,没有明显优势. 数据集没有局 ...

随机推荐

  1. DOS命令大全(转)

    dos命令大全 CMD是command的缩写,是windows环境下的虚拟DOS窗口,提供有DOS命令,功能强大,如果你以前学习过DOS操作,那就小儿科了.是基于Windows的命令行窗口,在开始-- ...

  2. Go语言学习之路(持续更新中)

    菜鸟 Go语言教程 教程(RUNOOB.COM):http://www.runoob.com/go/go-tutorial.html Go全球官网:https://golang.org/ (2018- ...

  3. java 动态代理(类型信息)

    代理是基本的设计模式之一它为你提供额外的或不同的操作,而插入的用来代替"实际"对象的对象. package typeinfo; //: typeinfo/SimpleProxyDe ...

  4. Laravel Cache 缓存钉钉微应用的 Access Token

    钉钉微应用的 Access token 如何获取? Access_Token 是企业访问钉钉开放平台全局接口的唯一凭证,即调用接口时需携带Access_Token.从接口列表看,所有接口都需要携带 a ...

  5. js事件监听

    /* 事件监听器 addEventListener() removeEventListener() 传统事件绑定: 1.重复添加会,后添加的后覆盖前面的. */ 示例代码中的html结构: <b ...

  6. ***腾讯云直播(含微信小程序直播)研究资料汇总-原创

    这段时间抽空研究了下直播技术,综合比较了下腾讯云直播的技术和文档方面最齐全,现把一些技术资料和文档归集如下: 1.微信小程序移动直播入门导读 https://cloud.tencent.com/doc ...

  7. binlog和redo log日志提交

    组提交(group commit)是MYSQL处理日志的一种优化方式,主要为了解决写日志时频繁刷磁盘的问题.组提交伴随着MYSQL的发展不断优化,从最初只支持redo log 组提交,到目前5.6官方 ...

  8. [APIO2011]方格染色

    题解: 挺不错的一道题目 首先4个里面只有1个1或者3个1 那么有一个特性就是4个数xor为1 为什么要用xor呢? 在于xor能把相同的数消去 然后用一般的套路 看看确定哪些值能确定全部 yy一下就 ...

  9. Webservice返回json数据格式

    问题: 我将结果内容用字符串拼接成Json数据并返回的时候,会在结果前面添加xml头部,结果如下. <span ><string xmlns="http://tempuri ...

  10. BZOJ1263 [SCOI2006]整数划分 高精度

    欢迎访问~原文出处——博客园-zhouzhendong 去博客园看该题解 题目传送门 - BZOJ1263 题意概括 将n写成若干个正整数之和,并且使这些正整数的乘积最大. 例如,n=13,则当n表示 ...