机器学习笔记(6):多类逻辑回归-使用gluon
上一篇演示了纯手动添加隐藏层,这次使用gluon让代码更精减,代码来自:https://zh.gluon.ai/chapter_supervised-learning/mlp-gluon.html
from mxnet import gluon
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import autograd def transform(data, label):
return data.astype('float32')/255, label.astype('float32') mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform) def show_images(images):
n = images.shape[0]
_, figs = plt.subplots(1, n, figsize=(15, 15))
for i in range(n):
figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
figs[i].axes.get_xaxis().set_visible(False)
figs[i].axes.get_yaxis().set_visible(False)
plt.show() def get_text_labels(label):
text_labels = [
'T 恤', '长 裤', '套头衫', '裙 子', '外 套',
'凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'
]
return [text_labels[int(i)] for i in label] data, label = mnist_train[0:10] print('example shape: ', data.shape, 'label:', label)
show_images(data)
print(get_text_labels(label)) batch_size = 256
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False) #计算模型
net = gluon.nn.Sequential()
with net.name_scope():
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(256, activation="relu"))
net.add(gluon.nn.Dense(10))
net.initialize() softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss() #定义训练器
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5}) def accuracy(output, label):
return nd.mean(output.argmax(axis=1) == label).asscalar() def _get_batch(batch):
if isinstance(batch, mx.io.DataBatch):
data = batch.data[0]
label = batch.label[0]
else:
data, label = batch
return data, label def evaluate_accuracy(data_iterator, net):
acc = 0.
if isinstance(data_iterator, mx.io.MXDataIter):
data_iterator.reset()
for i, batch in enumerate(data_iterator):
data, label = _get_batch(batch)
output = net(data)
acc += accuracy(output, label)
return acc / (i+1) for epoch in range(5):
train_loss = 0.
train_acc = 0.
for data, label in train_data:
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) #使用训练器,向"前"走一步 train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label) test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc)) data, label = mnist_test[0:10]
show_images(data)
print('true labels')
print(get_text_labels(label)) predicted_labels = net(data).argmax(axis=1)
print('predicted labels')
print(get_text_labels(predicted_labels.asnumpy()))
有变化的地方,已经加上了注释。运行效果,跟一篇完全相同,就不重复贴图了
机器学习笔记(6):多类逻辑回归-使用gluon的更多相关文章
- 吴恩达机器学习笔记 —— 7 Logistic回归
http://www.cnblogs.com/xing901022/p/9332529.html 本章主要讲解了逻辑回归相关的问题,比如什么是分类?逻辑回归如何定义损失函数?逻辑回归如何求最优解?如何 ...
- 机器学习笔记(2):线性回归-使用gluon
代码来自:https://zh.gluon.ai/chapter_supervised-learning/linear-regression-gluon.html from mxnet import ...
- 吴恩达机器学习笔记14-逻辑回归(Logistic Regression)
在分类问题中,你要预测的变量
- 机器学习笔记(4):多类逻辑回归-使用gluton
接上一篇机器学习笔记(3):多类逻辑回归继续,这次改用gluton来实现关键处理,原文见这里 ,代码如下: import matplotlib.pyplot as plt import mxnet a ...
- 【转】机器学习笔记之(3)——Logistic回归(逻辑斯蒂回归)
原文链接:https://blog.csdn.net/gwplovekimi/article/details/80288964 本博文为逻辑斯特回归的学习笔记.由于仅仅是学习笔记,水平有限,还望广大读 ...
- 机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归
机器学习实战(Machine Learning in Action)学习笔记————05.Logistic回归 关键字:Logistic回归.python.源码解析.测试作者:米仓山下时间:2018- ...
- Python机器学习笔记:使用Keras进行回归预测
Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...
- Python机器学习笔记:sklearn库的学习
网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...
- Python机器学习笔记:不得不了解的机器学习面试知识点(1)
机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...
随机推荐
- CentOS 6.5 rsync+inotify实现数据实时同步备份
CentOS 6.5 rsync+inotify实现数据实时同步备份 rsync remote sync 远程同步,同步是把数据从缓冲区同步到磁盘上去的.数据在内存缓存区完成之后还没有写入到磁盘 ...
- CentOS 6.5环境实现corosync+pacemaker实现DRBD高可用
DRBD (Distributed Replicated Block Device)分布式复制块设备,它是 Linux 平台上的分散式储存系统,通常用于高可用性(high availability, ...
- Python-CSS进阶
0. 什么时候该用什么布局 <!-- 定位布局: 以下两种布局不易解决的问题, 盒子需要脱离文档流处理 --> <!-- 浮动布局: 一般有block特性的盒子,水平排列显示 --& ...
- java多线程快速入门(八)
设置线程优先级:join() package com.cppdy; class MyThreadA extends Thread{ MyThreadB b; public MyThreadA(MyTh ...
- LeetCode(32):最长有效括号
Hard! 题目描述: 给定一个只包含 '(' 和 ')' 的字符串,找出最长的包含有效括号的子串的长度. 示例 1: 输入: "(()" 输出: 2 解释: 最长有效括号子串为 ...
- ThinkPHP中如何获取指定日期后工作日的具体日期
思路: 1.获取到查询年份内所有工作日数据数组2.获取到查询开始日期在工作日的索引3.计算需查询日期索引4.获得查询日期 /*创建日期类型记录表格*/ CREATE TABLE `tb_workday ...
- cf787c 博弈论+记忆化搜索
好题,单纯的就是pn状态的推导 /* 把第一个点标为0,剩下的点按1-n-1编号 胜态是1,败态为0,dp[i][j]表示第i个人,怪兽起始位置在j时的胜负态 把0点设置为必败态,然后对于一个人来说, ...
- hdu4638 莫队算法
莫队算法基础题,题目看懂就能做出来 #include<iostream> #include<cstring> #include<cstdio> #include&l ...
- PyCharm 新建文件时默认添加作者时间等
将内容添加到Python Script 右侧的文本框中: 路径: File → Setting → Editor → File and Code Templates → Python Script # ...
- 创建表空间tablespace,删除
在plsql工具中执行以下语句,可建立Oracle表空间. /*分为四步 *//*第1步:创建临时表空间 */create temporary tablespace yuhang_temp temp ...