# Author Qian Chenglong
import tensorflow as tf
import numpy as np #生成100个随机数据点
x_date=np.random.rand(100)
y_date=x_date*0.1+0.2 #构造一个线性模型
k=tf.Variable(0.)
b=tf.Variable(0.)
y=k*x_date+b # 二次代价函数
loss=tf.reduce_mean(tf.square(y-y_date))#最小二乘 my_optimizer=tf.train.GradientDescentOptimizer(0.2)#定义一个使用梯度下降算法的训练器
train=my_optimizer.minimize(loss)#训练目标loss最小 init=tf.global_variables_initializer()#初始化变量 with tf.Session() as sess:
sess.run(init)
for step in range(201):
sess.run(train)
if step%20==0:
print(step, '[k,b]:', sess.run([k, b]))

 API说明:

np.random.rand(100)生成100个0~1之间的随机数

tf.square():计算元素的平方

tf.reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)

计算张量的各个维度上的元素的平均值。

axis是tf.reduce_mean函数中的参数,按照函数中axis给定的维度减少input_tensor。除非keep_dims是true,否则张量的秩将在axis的每个条目中减少1。如果keep_dims为true,则缩小的维度将保留为1。 如果axis没有条目,则减少所有维度,并返回具有单个元素的张量。

参数:

  • input_tensor:要减少的张量。应该有数字类型。
  • axis:要减小的尺寸。如果为None(默认),则减少所有维度。必须在[-rank(input_tensor), rank(input_tensor))范围内。
  • keep_dims:如果为true,则保留长度为1的缩小尺寸。
  • name:操作的名称(可选)。
  • reduction_indices:axis的不支持使用的名称。
tf.Variable(initializer, name):initializer是初始化参数,可以有tf.random_normal,tf.constant,tf.constant等,name就是变量的名字,用法如下:
a1 = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='a1')
a2 = tf.Variable(tf.constant(1), name='a2')
a3 = tf.Variable(tf.ones(shape=[2,3]), name='a3')

运行session.run()可以:

  1. 获得你要得到的运算结果;
  2. 你所要运算的部分;
 
#qiancl 666
import tensorflow as tf
import numpy as np
#学习率
learning_rate=0.01
#最大训练步数
max_train_step=1000
#np.array()矩阵
train_X_date=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],
[7.042],[10.791],[5.313],[7.997],[5.654],[9.27],[3.1]],dtype=np.float32)
train_Y_date=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],
[2.827],[3.465],[1.65],[2.904],[2.42],[2.94],[1.3]],dtype=np.float32)
#样本个数
tolal_samples=train_X_date.shape[0]
#输入数据占位
x=tf.placeholder(tf.float32,[None,1])
y_=tf.placeholder(tf.float32,[None,1])
#tf.random_normal([1,1])生成【1,1】的符合正态分布的随机数
w=tf.Variable(tf.random_normal([1,1]),name="weight")
b=tf.Variable(tf.zeros([1]),name="bias")
y=tf.matmul(x,w)+b
loss=tf.reduce_sum(tf.pow(y-y_,2))/tolal_samples #创建优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate) #训练目标
train_op=optimizer.minimize(loss) #训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("开始训练")
for step in range(max_train_step):
sess.run(train_op, feed_dict={x: train_X_date, y_: train_Y_date})
if step % 100 == 0:
c = sess.run(loss, feed_dict={x: train_Y_date, y_: train_Y_date})
print("Step:%d, loss==%0.4f, w==%0.4f, b==%0.4f" % (step, c, sess.run(w), sess.run(b)))

使用Tensoflow实现梯度下降算法的一次线性拟合的更多相关文章

  1. 梯度下降算法的一点认识(Ng第一课)

    昨天开始看Ng教授的机器学习课,发现果然是不错的课程,一口气看到第二课. 第一课 没有什么新知识,就是机器学习的概况吧. 第二课 出现了一些听不太懂的概念.其实这堂课主要就讲了一个算法,梯度下降算法. ...

  2. ng机器学习视频笔记(二) ——梯度下降算法解释以及求解θ

    ng机器学习视频笔记(二) --梯度下降算法解释以及求解θ (转载请附上本文链接--linhxx)   一.解释梯度算法 梯度算法公式以及简化的代价函数图,如上图所示. 1)偏导数 由上图可知,在a点 ...

  3. 监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  4. [机器学习Lesson3] 梯度下降算法

    1. Gradient Descent(梯度下降) 梯度下降算法是很常用的算法,可以将代价函数J最小化.它不仅被用在线性回归上,也被广泛应用于机器学习领域中的众多领域. 1.1 线性回归问题应用 我们 ...

  5. Spark MLib:梯度下降算法实现

    声明:本文参考< 大数据:Spark mlib(三) GradientDescent梯度下降算法之Spark实现> 1. 什么是梯度下降? 梯度下降法(英语:Gradient descen ...

  6. AI-2.梯度下降算法

    上节定义了神经网络中几个重要的常见的函数,最后提到的损失函数的目的就是求得一组合适的w.b 先看下损失函数的曲线图,如下 即目的就是求得最低点对应的一组w.b,而本节要讲的梯度下降算法就是会一步一步地 ...

  7. Logistic回归Cost函数和J(θ)的推导(二)----梯度下降算法求解最小值

    前言 在上一篇随笔里,我们讲了Logistic回归cost函数的推导过程.接下来的算法求解使用如下的cost函数形式: 简单回顾一下几个变量的含义: 表1 cost函数解释 x(i) 每个样本数据点在 ...

  8. 梯度下降算法对比(批量下降/随机下降/mini-batch)

    大规模机器学习: 线性回归的梯度下降算法:Batch gradient descent(每次更新使用全部的训练样本) 批量梯度下降算法(Batch gradient descent): 每计算一次梯度 ...

  9. tensorflow随机梯度下降算法使用滑动平均模型

    在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现.在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模 ...

随机推荐

  1. linux系统编程:cp的另外一种实现方式

    之前,这篇文章:linux系统编程:自己动手写一个cp命令 已经实现过一个版本. 这里再来一个版本,涉及知识点: linux系统编程:open常用参数详解 Linux系统编程:简单文件IO操作 /*= ...

  2. POJ1651(KB-E)

    Multiplication Puzzle Time Limit: 1000MS Memory Limit: 65536K  Total Submissions: 10034 Accepted: 62 ...

  3. 编写hadoop程序,并打包jar到hadoop集群运行

    windows环境下编写hadoop程序 新建:File->new->Project->Maven->next GroupId 和ArtifactId 随便写(还是建议规范点) ...

  4. 常见Java问题

    1.泛型的好处 保护了类型安全 避免了强制类型转化 2.final关键字的作用 final修饰的属性是常量 final修饰的方法不可被重写 final修饰的类不能被继承,如:String 3.静态变量 ...

  5. python乐观锁、悲观锁

    二.乐观锁总是认为不会产生并发问题,每次去取数据的时候总认为不会有其他线程对数据进行修改,因此不会上锁,但是在更新时会判断其他线程在这之前有没有对数据进行修改 三.悲观锁总是假设最坏的情况,每次取数据 ...

  6. 鼠标滚轮更改transform的值(vue-scroller在PC端的上下滑动)

    目前上拉刷新,下拉加载,以及区域回弹的组件,绝大多数都是通过transform去实现的.在移动端效果很好,但是PC端使用鼠标拖拽的方式,查看下文首先不符合逻辑,其次容易点进其他页面. 起初,项目的初衷 ...

  7. PDO预处理语句

    1.造PDO对象$dsn = "mysql:dbname=mydb;host=localhost";$pdo = new PDO($dsn,"root",&qu ...

  8. 微信小程序获取Access_token和页面URL生成小程序码或二维码

    1.微信小程序获取Access_token: access_token具体时效看官方文档. using System; using System.Collections.Generic; using ...

  9. IIS8发布Asp.net MVC程序后出现404错误,处理程序staticFile

    新部署的虚拟机,运行Asp.net MVC程序,出现如下图错误: 解决方法: 添加功能和角色->添加角色->Web服务器IIS->应用程序开发->Asp.net3.5 /Asp ...

  10. apache web服务器安全配置

    尽管现在购买的云服务器很多都有一键web环境安装包,但是如果是自己配置web环境则需要对各种安全配置十分了解,今天我们就来尝试这做好web服务器安全配置.这里的配置不尽完善,若有纰漏之处还望指出. 修 ...