机器学习笔记(2):线性回归-使用gluon
代码来自:https://zh.gluon.ai/chapter_supervised-learning/linear-regression-gluon.html
from mxnet import ndarray as nd
from mxnet import autograd
from mxnet import gluon num_inputs = 2
num_examples = 1000 true_w = [2, -3.4]
true_b = 4.2 X = nd.random_normal(shape=(num_examples, num_inputs)) #1000行,2列的数据集
y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b #已知答案的结果
y += .01 * nd.random_normal(shape=y.shape) #加入噪音 #1 随机读取10行数据
batch_size = 10
dataset = gluon.data.ArrayDataset(X, y)
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True) #2 定义回归模型
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1)) #3 参数初始化
net.initialize() #4 损失函数
square_loss = gluon.loss.L2Loss() #5 指定训练方法
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1}) #6 训练
epochs = 5
batch_size = 10
for e in range(epochs):
total_loss = 0
for data, label in data_iter:
with autograd.record():
output = net(data)
loss = square_loss(output, label)
loss.backward()
trainer.step(batch_size)
total_loss += nd.sum(loss).asscalar()
print("Epoch %d, average loss: %f" % (e, total_loss/num_examples)) #7 输出结果
dense = net[0]
print(true_w)
print(dense.weight.data())
print(true_b)
print(dense.bias.data())
相对上一篇纯手动的处理方式,用gluon后代码明显更精简了。
机器学习笔记(2):线性回归-使用gluon的更多相关文章
- coursera机器学习笔记-多元线性回归,normal equation
#对coursera上Andrew Ng老师开的机器学习课程的笔记和心得: #注:此笔记是我自己认为本节课里比较重要.难理解或容易忘记的内容并做了些补充,并非是课堂详细笔记和要点: #标记为<补 ...
- Stanford机器学习笔记-1.线性回归
Content: 1. Linear Regression 1.1 Linear Regression with one variable 1.1.1 Gradient descent algorit ...
- Andrew Ng机器学习课程笔记--week1(机器学习介绍及线性回归)
title: Andrew Ng机器学习课程笔记--week1(机器学习介绍及线性回归) tags: 机器学习, 学习笔记 grammar_cjkRuby: true --- 之前看过一遍,但是总是模 ...
- 机器学习笔记5-Tensorflow高级API之tf.estimator
前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...
- Python机器学习笔记:sklearn库的学习
网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...
- Python机器学习笔记:不得不了解的机器学习面试知识点(1)
机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...
- 机器学习笔记(4):多类逻辑回归-使用gluton
接上一篇机器学习笔记(3):多类逻辑回归继续,这次改用gluton来实现关键处理,原文见这里 ,代码如下: import matplotlib.pyplot as plt import mxnet a ...
- cs229 斯坦福机器学习笔记(一)-- 入门与LR模型
版权声明:本文为博主原创文章,转载请注明出处. https://blog.csdn.net/Dinosoft/article/details/34960693 前言 说到机器学习,非常多人推荐的学习资 ...
- Python机器学习笔记 集成学习总结
集成学习(Ensemble learning)是使用一系列学习器进行学习,并使用某种规则把各个学习结果进行整合,从而获得比单个学习器显著优越的泛化性能.它不是一种单独的机器学习算法啊,而更像是一种优 ...
- 机器学习笔记:Gradient Descent
机器学习笔记:Gradient Descent http://www.cnblogs.com/uchihaitachi/archive/2012/08/16/2642720.html
随机推荐
- zabbix系列(四)Zabbix3.0.4添加对Nginx服务的监控
Zabbix3.0.4添加对Nginx服务的监控 通过Nginx的http_stub_status_module模块提供的状态信息来监控,所以在Agent端需要配置Nginx状态获取的脚本,和添加ke ...
- 最新 macOS Sierra 10.12.3 安装CocoaPods及使用详解
一.什么是CocoaPods 每种语言发展到一个阶段,就会出现相应的依赖管理工具,例如 Java 语言的 Maven,nodejs 的 npm.随着 iOS 开发者的增多,业界也出现了为 iOS 程序 ...
- 转载:第2章 Nginx的配置 概述《深入理解Nginx》(陶辉)
原文:https://book.2cto.com/201304/19623.html Nginx拥有大量官方发布的模块和第三方模块,这些已有的模块可以帮助我们实现Web服务器上很多的功能.使用这些模块 ...
- 100以内与7有关的数(for和if)
- Numpy中stack(),hstack(),vstack()函数详解
一`.stack 按指定维度堆叠数组. stack(a, b) 维度计算 axis=0: 2*m*n axis=1: m*2*n axis=-1: m*n*2 a = np.arange( ...
- python 全栈开发,Day65(索引)
索引 一.索引的介绍 数据库中专门用于帮助用户快速查找数据的一种数据结构.类似于字典中的目录,查找字典内容时可以根据目录查找到数据的存放位置吗,然后直接获取. 二 .索引的作用 约束和加速查找 三.常 ...
- 《剑指offer》-连续子数组的最大和
题目描述 HZ偶尔会拿些专业问题来忽悠那些非计算机专业的同学.今天测试组开完会后,他又发话了:在古老的一维模式识别中,常常需要计算连续子向量的最大和,当向量全为正数的时候,问题很好解决.但是,如果向量 ...
- BZOJ5045 打砖块 2017年9月月赛 其他
欢迎访问~原文出处——博客园-zhouzhendong 去博客园看该题解 题目传送门 - BZOJ5045 题意概括 有一堵墙. 现在挖掉某些砖.如果有相邻的某两个砖没有了,那么他们中上方的那块也没了 ...
- Scrapy爬虫笔记 - 爬取知乎
cookie是一种本地存储机制,cookie是存储在本地的 session其实就是将用户信息用户名.密码等)加密成一串字符串,返回给浏览器,以后浏览器每次请求都带着这个sessionId 状态码一般是 ...
- hdu 5748 Bellovin【最长上升子序列】
题目链接:https://vjudge.net/contest/148584#problem/A 题目大意: 解题思路:题目要求为:输出与已知序列的每一个元素的f(i)(f(i)的定义如题)相同的字典 ...