线性回归,附tensorflow实现
本文同步自:https://zhuanlan.zhihu.com/p/30738405
本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法
有一组数据:
x=[1,2,3,4,5,6,7,8,9,10]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条过原点的直线,穿过上述所有点
这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis
由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设
=0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢
引入概念代价函数 cost function
在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中,
为2
将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 与
都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点
那么问题就变成了如何最小化J(θ)
在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低
引入梯度下降法

其中α为步幅
梯度下降法可以解释为:对J(θ)求关于 (本例中只有
)的偏导数并乘以步幅,再用
减去该值,得到的结果赋值给
。此过程需要重复多次
步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛
当选取了一个较大步幅的时候,将震荡收敛
当选取了一个过大步幅的时候,将无法收敛
调整一下例子
x=[5,6,7,8,9,10,11,12,13,14]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条直线,穿过上述所有点
很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])
n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")
# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph
plt.figure()
plt.ion()
with tf.Session() as sess:
sess.run(init)
# Fit all training data
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X: x, Y: y})
# Display logs per epoch step
if (epoch + 1) % display_step == 0:
c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
"W=", sess.run(W), "b=", sess.run(b))
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(0.001)
plt.clf()
print("Optimization Finished!")
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(10)
线性回归,附tensorflow实现的更多相关文章
- 逻辑回归,附tensorflow实现
本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...
- 简单的线性回归问题-TensorFlow+MATLAB·
首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...
- 利用VGG19实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 经典损失函数:交叉熵(附tensorflow)
每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...
- Tensorflow之单变量线性回归问题的解决方法
跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...
- Tensorflow学习笔记01
Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...
- TensorFlow 从零到helloWorld
目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建 2.1 tensorflow安装 2.2 CUDA安装 2.3 ...
- TF linear regression
本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...
随机推荐
- commons-pool与commons-pool2连接池(Hadoop连接池)
commons-pool和commons-pool2是用来建立对象池的框架,提供了一些将对象池化必须要实现的接口和一些默认动作.对象池化之后可以通过pool的概念去管理其生命周期,例如对象的创建,使用 ...
- OpenGL ES2.0光照
一.简单光照原理 平行光(正常光) 光照效果= 环境颜色 + 漫反射颜色 + 镜面反射颜色 点光源 光照效果= 环境颜色 + (漫反射颜色 + 镜面反射颜色)× 衰减因子 聚光灯 光照效果= ...
- 作为前端Web开发者,这12个终端命令不可不会
对于开发人员来说,终端是最重要的工具之一.掌握终端,能够有效的提升开发人员的工作流程.使用终端,许多日常任务都被简化为了编写简单的命令并按下 Enter 按钮. 本文列举了一系列 Linux 命令,旨 ...
- 根据Dockerfile创建docker dotnet coer 镜像
那我们先来看看Dockerfile文件内容,注意这个文件是没后缀名的. #依赖原始的镜像,因为我们是要创建dotnet coer镜像,所以我就用了官方给的镜像[microsoft/dotnet:lat ...
- YYHS-NOIP模拟赛-mine
题解 这道题不难想到用dp来做 dp[i][0]表示第i个格子放0 dp[i][1]表示第i个格子放1且第i-1个格子放雷 dp[i][2]表示第i个格子放2 dp[i][3]表示第i个格子放1且第i ...
- Akka(28): Http:About Akka-Http
众所周知,Akka系统是基于Actor模式的分布式运算系统,非常适合构建大数据平台.所以,无可避免地会出现独立系统之间.与异类系统.与移动系统集成的需求.由于涉及到异类和移动系统,系统对接的方式必须在 ...
- JS中的类型识别
JS为弱类型语言,所以类型识别对JS而言尤为重要,JS中常用的类型识别方法有4种:typeof.Object.prototype.toString.constructor和instanceof. (1 ...
- vim-ultisnips补全功能失效,无法识别解决办法
昨天又给vim配了一堆插件 发现了一个这个问题,vim的ultisnips插件不能用了! 首先,我先查看插件是否正常运行了 :script 从一堆正在运行插件里找到ultisnips的名字,说明插件正 ...
- sublime中安装package control总是失败
今天下载了个sublime编辑器,要运行vue文件,想让vue也能高亮显示,在网上搜了一下如何安装.但总是提示控制器没有安装Package Control:There are no packages ...
- MongoDB监控
1. mongostat:间隔固定时间获取mongodb的当前运行状态,并输出. 使用示例: D:\Program_Files\MongoDB\bin\mongostat(根据MongoDB的安装目录 ...