这里将讲解tensorflow是如何通过计算图来更新变量和最小化损失函数来反向传播误差的;这步将通过声明优化函数来实现。一旦声明好优化函数,tensorflow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,tensorflow会在计算图中根据状态相应的调节变量。

  这里先举一个简单的例子,从均值1,标准差为0.1的正态分布中随机抽样100个数,然后乘以变量A,损失函数L2正则函数,也就是实现函数X*A=target,X为100个随机数,target为10,那么A的最优结果也为10。

  实现如下:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops ops.reset_default_graph() # 创建计算图
sess = tf.Session() #生成数据,100个随机数x_vals以及100个目标数y_vals
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
#声明x_data、target占位符
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32) # 声明变量A
A = tf.Variable(tf.random_normal(shape=[1])) #乘法操作,也就是例子中的X*A
my_output = tf.multiply(x_data, A) #增加L2正则损失函数
loss = tf.square(my_output - y_target) # 初始化所有变量
init = tf.initialize_all_variables()
sess.run(init) #声明变量的优化器;大部分优化器算法需要知道每步迭代的步长,这距离是由学习控制率。
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss) #训练,将损失值加入数组loss_batch
loss_batch = []
for i in range(100):
rand_index = np.random.choice(100)
rand_x = [x_vals[rand_index]]
rand_y = [y_vals[rand_index]]
sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
print('Step #' + str(i + 1) + ' A = ' + str(sess.run(A)))
print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))
temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
loss_batch.append(temp_loss) plt.plot( loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出结果(输出A以及对应的损失函数):

Step #1 A = [ 0.08779037]
Loss = [ 98.3597641]
Step #2 A = [ 0.48817557]
Loss = [ 90.38272095]
Step #3 A = [ 0.85985768]
Loss = [ 83.92495728]
Step #4 A = [ 1.289047]
Loss = [ 71.54370117]

.........

Step #98 A = [ 10.10386372]
Loss = [ 0.00271681]
Step #99 A = [ 10.10850525]
Loss = [ 0.01301978]
Step #100 A = [ 10.07686806]
Loss = [ 0.5048126]

对于损失函数看这里:tensorflow进阶篇-4(损失函数1)tensorflow进阶篇-4(损失函数2)tensorflow进阶篇-4(损失函数3)

tensorflow进阶篇-5(反向传播1)的更多相关文章

  1. tensorflow进阶篇-5(反向传播2)

    上面是一个简单的回归算法,下面是一个简单的二分值分类算法.从两个正态分布(N(-1,1)和N(3,1))生成100个数.所有从正态分布N(-1,1)生成的数据目标0:从正态分布N(3,1)生成的数据标 ...

  2. tensorflow学习笔记(2)-反向传播

    tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...

  3. 【TensorFlow篇】--反向传播

    一.前述 反向自动求导是 TensorFlow 实现的方案,首先,它执行图的前向阶段,从输入到输出,去计算节点值,然后是反向阶段,从输出到输入去计算所有的偏导. 二.具体 1.举例 图是第二个阶段,在 ...

  4. tensorflow进阶篇-4(损失函数2)

    Hinge损失函数主要用来评估支持向量机算法,但有时也用来评估神经网络算法.下面的示例中是计算两个目标类(-1,1)之间的损失.下面的代码中,使用目标值1,所以预测值离1越近,损失函数值越小: # U ...

  5. tensorflow进阶篇-4(损失函数1)

    L2正则损失函数(即欧拉损失函数),L2正则损失函数是预测值与目标函数差值的平方和.L2正则损失函数是非常有用的损失函数,因为它在目标值附近有更好的曲度,并且离目标越近收敛越慢: # L = (pre ...

  6. tensorflow进阶篇-3

    #-*- coding:utf-8 -*- #Tensorflow的嵌入Layer import numpy as np import tensorflow as tf sess=tf.Session ...

  7. tensorflow进阶篇-4(损失函数3)

    Softmax交叉熵损失函数(Softmax cross-entropy loss)是作用于非归一化的输出结果只针对单个目标分类的计算损失.通过softmax函数将输出结果转化成概率分布,然后计算真值 ...

  8. [2] TensorFlow 向前传播算法(forward-propagation)与反向传播算法(back-propagation)

    TensorFlow Playground http://playground.tensorflow.org 帮助更好的理解,游乐场Playground可以实现可视化训练过程的工具 TensorFlo ...

  9. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

随机推荐

  1. C# 编码标准(三)

    一.代码注释 1.文档型注释 该类注释采用.Net已定义好的Xml标签来标记,在声明接口.类.方法.属性.字段都应该使用该类注释,以便代码完成后直接生成代码文档,让别人更好的了解代码的实现和接口.[示 ...

  2. java中定时器总结

    java实现定时器的四种方式: 一. /** * 延迟20000毫秒执行 java.util.Timer.schedule(TimerTask task, long delay) */ public ...

  3. Android自定义视图四:定制onMeasure强制显示为方形

    这个系列是老外写的,干货!翻译出来一起学习.如有不妥,不吝赐教! Android自定义视图一:扩展现有的视图,添加新的XML属性 Android自定义视图二:如何绘制内容 Android自定义视图三: ...

  4. (转)本地搭建环境wamp下提示不支持GD库的解决方法

    转自:http://www.zzdp.net/local-wamp-gd GD库是什么?GD库,是php处理图形的扩展库,GD库提供了一系列用来处理图片的API,使用GD库可以处理图片,或者生成图片. ...

  5. Html5与Css3知识点拾遗(二)

    页面title 选择能简要概括文档内容的文字作为title文字,title核心内容放在前60个字符 分级标题 1.创建分级标题时,避免跳过级别,如h3直接跳到h5,但允许从低级别跳到高级别. 2.不用 ...

  6. android 发送url带中文出现乱码怎么解决

    上传的时候参数中带中文的时候发送参数的时候就有可能出现乱码,这种情况怎么解决呢,就是设置url的格式为utf-8 httpRequest.setEntity(new UrlEncodedFormEnt ...

  7. whu暑期集训#1

    题号:SGU123----SGU131 Problem A: 题意:求斐波那契的前N项和.. 做法:直接模拟,注意得用long long Problem B: 题意:给定一个封闭的多边形,求一个点在不 ...

  8. Implementation of WC in JAVA

    Implementation of WC in JAVA github地址 相关要求 基本功能 -c [文件名] 返回文件的字符数 (实现) -w [文件名] 返回文件的词的数目 (实现) -l [文 ...

  9. 使用NetHogs监控进程网络使用情况

    Nethogs 是一个终端下的网络流量监控工具,它的特别之处在于可以显示每个进程的带宽占用情况,这样可以更直观获取网络使用情况.它支持 IPv4 和 IPv6 协议.支持本地网卡及 PPP 链接. 使 ...

  10. Power BI Embedded 与 Bot Framework 结合的AI报表系统

    最近最热门的话题莫过于AI了,之前我做过一片讲 BOTFRAMEWORK和微信 相结合的帖子 如何将 Microsoft Bot Framework 链接至微信公共号 我想今天基于这个题目扩展一下,P ...