线性回归,附tensorflow实现
本文同步自:https://zhuanlan.zhihu.com/p/30738405
本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法
有一组数据:
x=[1,2,3,4,5,6,7,8,9,10]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条过原点的直线,穿过上述所有点
这组数据在二维平面表现如下
引入概念,假设函数:h(x)。h代表hypothesis
由于是过原点的直线,所以可以列出方程h(x):
先随意假设一个 ,在这先假设
=0.5 ,函数图如下
显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢
引入概念代价函数 cost function
在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:
其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中,
为2
将h(x)带入,得
函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 与
都应该作为常量来处理
显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点
那么问题就变成了如何最小化J(θ)
在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低
引入梯度下降法
其中α为步幅
梯度下降法可以解释为:对J(θ)求关于 (本例中只有
)的偏导数并乘以步幅,再用
减去该值,得到的结果赋值给
。此过程需要重复多次
步幅的选择会直接关系到梯度下降法的效果,如下图
当选取了一个较小步幅的时候,将正确收敛
当选取了一个较大步幅的时候,将震荡收敛
当选取了一个过大步幅的时候,将无法收敛
调整一下例子
x=[5,6,7,8,9,10,11,12,13,14]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条直线,穿过上述所有点
很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):
而判断一条直线的好坏还可以沿用之前的J(θ):
函数图是这样一个形状,数值对不上,凑合着看吧
之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # Parameters learning_rate = 0.05 training_epochs = 2000 display_step = 50 # Training Data train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0]) train_Y = train_X - 4 paint_X = np.asarray([-100.0, 100.0]) n_samples = train_X.shape[0] # tf Graph Input X = tf.placeholder("float") Y = tf.placeholder("float") # Set model weights W = tf.Variable(-10., name="weight") b = tf.Variable(10., name="bias") # Construct a linear model pred = tf.add(tf.multiply(X, W), b) # Mean squared error cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples) # Gradient descent optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Initializing the variables init = tf.global_variables_initializer() # Launch the graph plt.figure() plt.ion() with tf.Session() as sess: sess.run(init) # Fit all training data for epoch in range(training_epochs): for (x, y) in zip(train_X, train_Y): sess.run(optimizer, feed_dict={X: x, Y: y}) # Display logs per epoch step if (epoch + 1) % display_step == 0: c = sess.run(cost, feed_dict={X: train_X, Y: train_Y}) print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c), "W=", sess.run(W), "b=", sess.run(b)) plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1]) plt.plot(train_X, train_Y, 'ro', label='Original data') plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line') plt.pause(0.001) plt.clf() print("Optimization Finished!") training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y}) print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n') plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1]) plt.plot(train_X, train_Y, 'ro', label='Original data') plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line') plt.pause(10)
线性回归,附tensorflow实现的更多相关文章
- 逻辑回归,附tensorflow实现
本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...
- 简单的线性回归问题-TensorFlow+MATLAB·
首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...
- 利用VGG19实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 经典损失函数:交叉熵(附tensorflow)
每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...
- Tensorflow之单变量线性回归问题的解决方法
跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...
- Tensorflow学习笔记01
Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...
- TensorFlow 从零到helloWorld
目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建 2.1 tensorflow安装 2.2 CUDA安装 2.3 ...
- TF linear regression
本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...
随机推荐
- 使用Gradle构建Android项目
阅读目录 Gradle是什么? 环境需求 Gradle基本结构 任务task的执行 基本的构建定制 目录配置 签名配置 代码混淆设置 依赖配置 输出不同配置的应用 生成多个渠道包(以Umeng为例) ...
- 深入理解计算机系统(4.2)------逻辑设计和硬件控制语言HCL
上一篇博客我们简单介绍了Y86指令集体系,而这篇博客我们将介绍指令集体系的逻辑设计和硬件控制语言HCL,为后面去实现Y86打下基础. 在硬件设计中,用电子电路来计算对位进行运算的函数,以及在各种存储器 ...
- Extjs6(四)——侧边栏导航根据路由跳转页面
本文基于ext-6.0.0 之前做的时候这个侧边栏导航是通过tab切换来切换页面的,但是总感觉不太对劲,现在终于发现怎么通过路由跳转了,分享给大家,可能有些不完善的地方,望大家读后可以给些指点.欢迎留 ...
- hdu 1018 共同交流~
Big Number Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)Total ...
- git上传遇到 GitHub could not read Username 的解决办法
Gitversion 1.8.5.2 执行git push命令异常,如下: Push failed Failed with error: unable to read askpass response ...
- 实现一个单隐层神经网络python
看过首席科学家NG的深度学习公开课很久了,一直没有时间做课后编程题,做完想把思路总结下来,仅仅记录编程主线. 一 引用工具包 import numpy as np import matplotlib. ...
- Debug 运行正常,Release版本不能正常运行总结(转)
引言 如果在您的开发过程中遇到了常见的错误,或许您的Release版本不能正常运行而Debug版本运行无误,那么我推荐您阅读本文:因为并非如您想象的那样,Release版本可以保证您的应用程 ...
- js判断值为null
今天在做项目的时候,犯了一个着实不应该的错误,拿到data为null,然后判断如果为null执行A,否则执行B 我错误的代码是 if(data===null){ A; }else{ B; } 怎么调试 ...
- RabbitMQ-客户端
Install-Package RabbitMQ.Client 参考: http://www.rabbitmq.com/download.html https://www.nuget.org/pack ...
- LINUX环境下SVN安装与配置(利用钩子同步开发环境与测试环境)
安装采用YUM一键安装: 1.环境Centos 6.6 2.安装svnyum -y install subversion 3.配置 建立版本库目录mkdir /www/svndata svnserve ...