Tensorflow之单变量线性回归问题的解决方法
跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容。
题目:通过生成人工数据集合,基于TensorFlow实现y=2*x+1线性回归
使用Tensorflow进行算法设计与训练的核心步骤
(1)准备数据
(2)构建模型
(3)训练模型
(4)进行预测
#线性回归问题 #******************一、准备数据:********************** #生成人工数据集 # 在Jupter中,使用matplotlib显示图像需要设置为inline模式,否则不会显示图像
%matplotlib inline import matplotlib.pyplot as plt #载入matplotlib,用于绘图
import numpy as np #载入numpy,numpy是Python进行科学计算时的基础模块
import tensorflow as tf #载入Tensorflow #设置随机种子。训练之后结果随机,随机种子起到固定初始值的作用,为了训练之后得到一样的结果
np.random.seed(5)
#直接采用np生成等差数列的方法,生成100个点,每个点的取值在-1~1之间
x_data = np.linspace(-1,1,100) # y = 2x +1 + 噪声,其中,噪声的维度与x_data一致
y_data = 2 * x_data + 1.0 + np.random.randn(*x_data.shape) * 0.4 #***********************二、构建线性模型************************* #定义训练数据的占位符,x是特征,y是标签值
x = tf.placeholder("float",name= "x")
y = tf.placeholder("float",name = "y") #定义模型函数
def model(x,w,b):
return tf.multiply(x,w) + b #定义模型结构
#Tensorflow变量的声明函数是tf.Variable。tf.Variable的作用是保存和更新函数,变量的初始值可以是随机数、常数,或是通过其他变量的初始值计算得到
#构建线性函数的斜率,变量w
w = tf.Variable(1.0,name = "w0")
#构建线性函数的截距,变量b
b = tf.Variable(0.0,name = "b0") #pred是预测值,前向计算
pred = model(x,w,b) #************************三、训练模型*******************************
#设置训练参数
#迭代次数(训练轮数)
train_epochs = 10 #学习率
learning_rate = 0.05 #定义优化器、最小损失函数 #定义损失函数,损失函数用于描述预测值与真实值之间的差别,从而指导模型收敛方向。常见损失函数:均方差、交叉熵
#采用均方差作为损失函数
loss_function = tf.reduce_mean(tf.square(y-pred)) #定义优化器
#定义优化器Optimizer,初始化一个GradientDescentOptimizer(梯度下降优化器)
#设置学习率和优化目标:最小化损失
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #创建会话
#声明会话
sess = tf.Session()
#变量初始化
#在真正执行计算前,需要将所有变量初始化。通过tf.global_variables_initializer函数可实现对所有变量的初始化
init = tf.global_variables_initializer()
sess.run(init) #迭代训练
#模型训练阶段,设置迭代轮次,每次通过将样本逐个输入模型,进行梯度下降优化操作。每轮迭代后,绘制出模型曲线
#开始训练,轮次为epoch,采用SGD随机梯度下降优化方法
for epoch in range(train_epochs):
for xs,ys in zip(x_data,y_data):
_,loss = sess.run([optimizer,loss_function],feed_dict={x:xs,y:ys})
b0temp = b.eval(session=sess)
w0temp = w.eval(session=sess)
plt.plot(x_data,w0temp * x_data + b0temp) #画图 #结果查看。当训练完成后,打印查看参数。数据每次运行都可能会有所不同
print("w:",sess.run(w)) #w的值应该在2附近
print("b:",sess.run(b)) #b的值应该在1附近 #结果可视化
plt.scatter(x_data,y_data,label='Original data')
plt.plot(x_data,x_data*sess.run(w) + sess.run(b),label='Fitted line',color='r',linewidth=3)
plt.legend(loc=2) #通过参数loc指定图例位置 #*********************四、利用学习到的模型进行预测******************* x_test = 3.21 predict = sess.run(pred,feed_dict={x:x_test})
print("预测值: %f"%predict) target = 2 * x_test + 1.0
print("目标值: %f"%target)
题目二:通过生成人工数据集合,基于TensorFlow实现y=3.1234*x+2.98线性回归
# 在Jupter中,使用matplotlib显示图像需要设置为inline模式,否则不会显示图像
%matplotlib inline import matplotlib.pyplot as plt #载入matplotlib
import numpy as np #载入numpy
import tensorflow as tf #载入Tensorflow #设置随机种子
np.random.seed(5)
#直接采用np生成等差数列的方法,生成100个点,每个点的取值在-1~1之间
x_data = np.linspace(-1,1,100)
# y = 3.1234x +2.98 + 噪声, 其中, 噪声的唯度与x_data一致
y_data = 3.1234*x_data + 2.98 + np.random.randn(*x_data.shape)*0.4
x = tf.placeholder("float",name = "x")
y = tf.placeholder("float",name = "y") def model(x,w,b):
return tf.multiply(x,w)+b
# 构建线性函数的斜率, 变量w
w = tf.Variable(1.0,name="w")
# 构建线性函数的截距,变量b
b = tf.Variable(0.0, name="b0")
#pred是预测值,前向计算
pred = model(x,w,b) # 迭代次数(训练轮数)
train_epochs = 10
# 学习率
learning_rate = 0.05
# 采用均方差作为损失函数
loss_function = tf.reduce_mean(tf.square(y-pred))
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) # 开始训练,轮数为 epoch,采用SGD随机梯度下降优化方法
#zip为组装,x,y都为一维数组. zip 把x,y组装起来也为一维数组,每个单元为(x,y) for epoch in range(train_epochs):
for xs,ys in zip(x_data, y_data):
#优化器给了一个下划线,loss_function 给了loss
_, loss=sess.run([optimizer,loss_function],feed_dict={x: xs, y:ys}) plt.scatter(x_data,y_data,label='Original data')
plt.plot(x_data,x_data*sess.run(w)+sess.run(b),\
label='Fitted line',color='r',linewidth=3)
plt.legend(loc=2)#通过参数loc指定图例位置 print("w: ", sess.run(w)) #w的值应该在3.1234附近
print("b: ",sess.run(b)) #b的值应该在2.98附近
Tensorflow之单变量线性回归问题的解决方法的更多相关文章
- 机器学习之单变量线性回归(Linear Regression with One Variable)
1. 模型表达(Model Representation) 我们的第一个学习算法是线性回归算法,让我们通过一个例子来开始.这个例子用来预测住房价格,我们使用一个数据集,该数据集包含俄勒冈州波特兰市的住 ...
- Coursera《machine learning》--(2)单变量线性回归(Linear Regression with One Variable)
本笔记为Coursera在线课程<Machine Learning>中的单变量线性回归章节的笔记. 2.1 模型表示 参考视频: 2 - 1 - Model Representation ...
- Ng第二课:单变量线性回归(Linear Regression with One Variable)
二.单变量线性回归(Linear Regression with One Variable) 2.1 模型表示 2.2 代价函数 2.3 代价函数的直观理解 2.4 梯度下降 2.5 梯度下 ...
- 斯坦福第二课:单变量线性回归(Linear Regression with One Variable)
二.单变量线性回归(Linear Regression with One Variable) 2.1 模型表示 2.2 代价函数 2.3 代价函数的直观理解 I 2.4 代价函数的直观理解 I ...
- 机器学习(二)--------单变量线性回归(Linear Regression with One Variable)
面积与房价 训练集 (Training Set) Size Price 2104 460 852 178 ...... m代表训练集中实例的数量x代表输入变量 ...
- python 单变量线性回归
单变量线性回归(Linear Regression with One Variable)¶ In [54]: #初始化工作 import random import numpy as np imp ...
- 【原】Coursera—Andrew Ng机器学习—课程笔记 Lecture 2_Linear regression with one variable 单变量线性回归
Lecture2 Linear regression with one variable 单变量线性回归 2.1 模型表示 Model Representation 2.1.1 线性回归 Li ...
- 机器学习 (一) 单变量线性回归 Linear Regression with One Variable
文章内容均来自斯坦福大学的Andrew Ng教授讲解的Machine Learning课程,本文是针对该课程的个人学习笔记,如有疏漏,请以原课程所讲述内容为准.感谢博主Rachel Zhang的个人笔 ...
- 【Python】机器学习之单变量线性回归 利用正规方程找到合适的参数值
[Python]机器学习之单变量线性回归 利用正规方程找到合适的参数值 本次作业来自吴恩达机器学习. 你是一个餐厅的老板,你想在其他城市开分店,所以你得到了一些数据(数据在本文最下方),数据中包括不同 ...
随机推荐
- 快速破解Goland
两种激活方式永久激活:推荐优先使用,永久有效有效期激活:如果你实在激活不了又着急使用,这是备选激活方案,简单快捷 一.永久激活 1.下载新版破解补丁 点击链接 https://pan.baidu.co ...
- 数据库TINYINT类型 参数0 mybatis取不到值
tinyint存储0的奇怪问题 数据库TINYINT类型 参数0 mybatis取不到值 postman 传参 audited =0 audited =1 两种情况 ...
- <DFS & BFS> 286 339 (BFS)364
286. Walls and Gates DFS: 思路是,搜索0的位置,每找到一个0,以其周围四个相邻点为起点,开始 DFS 遍历,并带入深度值1,如果遇到的值大于当前深度值,将位置值赋为当前深度值 ...
- 一些你不知道的css特性【一】
浏览器禁止用户在标签的style中使用js写入"!important"的特性 我们在使用jQuery设置css的时候 $('#text').css('height', '200px ...
- 【OCR技术系列之二】文字定位于切割
要做文字识别,第一步要考虑的就是怎么将每一个字符从图片中切割下来,然后才可以送入我们设计好的模型进行字符识别.现在就以下面这张图片为例,说一说最一般的字符切割的步骤是哪些. 当然,我们实际上要识别的图 ...
- 【Linux命令】安装命令(yum,rpm)
安装软件有三种方式,第一种是源码安装(源码安装需要手动安装软件,安装的目录,还需要进行编译之后才能安装),步骤比较繁琐.第二种是RPM安装,rpm安装有点像windows系统的面板,会建立统一的数据库 ...
- 解决Navicat连接远程MySQL很慢的方法
开发某应用系统连接公司的测试服务器的mysql数据库连接打开的很慢,但是连接本地的mysql数据库很快,刚开始认为可能是网络连接问题导致的,在进行 ping和route后发现网络通信都是正常的,而且在 ...
- Percona Monitoring and Management (PMM) - 快速入门
前言 数据库监控工具最常用的就是zabbix了,zabbix能将收集到的数据通过图表展示出来,并通过设置阈值及时告警.可zabbix对于文本的处理就不行了,比方说抓取数据库运行的sql,这个zabbi ...
- Java生鲜电商平台-Spring Cloud微服务架构图
- 怎样深入学习php,成为php高手!?
本文章开头我想问一句话:PHP是做什么的? 因为这是面试中会问到的一个问题,虽然它看起来很简单,回答做网站的,也就是个简单建站的水平.回答做网站后端开发的,对PHP有了一定的认识,回答做后端处理的,有 ...