#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 10 09:35:04 2017

@author: myhaspl@myhaspl.com,http://blog.csdn.net/myhaspl
"""
#逻辑或
import tensorflow as tf

batch_size=10
w1=tf.Variable(tf.random_normal([2,6],stddev=1,seed=1))
w2=tf.Variable(tf.random_normal([6,1],stddev=1,seed=1))
b=tf.Variable(tf.zeros([6]),tf.float32)

x=tf.placeholder(tf.float32,shape=(None,2),name="x")
y=tf.placeholder(tf.float32,shape=(None,1),name="y")

h=tf.matmul(x,w1)+b
yo=tf.matmul(h,w2)

#损失函数计算差异平均值
cross_entropy=tf.reduce_mean(tf.abs(y-yo))
#反向传播
train_step=tf.train.AdamOptimizer(0.05).minimize(cross_entropy)

#生成样本

x_=[[0.,0.],[0.,1.],[1.,0.],[1.,1.]]
y_=[[0.],[1.],[1.],[1.]]
b_=tf.zeros([6])

with tf.Session() as sess:
    #初始化变量
    init_op=tf.global_variables_initializer()
    sess.run(init_op)
    print sess.run(w1)
    print sess.run(w2)

    #设定训练轮数
    TRAINCOUNT=500
    for i in range(TRAINCOUNT):
        #开始训练
        sess.run(train_step,feed_dict={x:x_,y:y_})
        if i%10==0:
            total_cross_entropy=sess.run(cross_entropy,feed_dict={x:x_,y:y_})
            print("%d 次训练之后,损失:%g"%(i+1,total_cross_entropy))
    print(sess.run(w1))
    print(sess.run(w2))

    #生成测试样本,仅进行前向传播验证:
    testyo=sess.run(yo,feed_dict={x:[[0.,1.],[1.,1.]]})
    myout=[int(testout>0.5) for testout in testyo]
    print myout

(Initialize initial 1st moment vector)
v_0 <- 0 (Initialize initial 2nd moment vector)
t <- 0 (Initialize timestep)

The update rule for variable with gradient g uses an optimization described at the end of section2 of the paper:

t <- t + 1
lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

m_t <- beta1 * m_{t-1} + (1 - beta1) * g
v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)

The default value of 1e-8 for epsilon might not be a good default in general. For example, when training an Inception network on ImageNet a current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the formulation just before Section 2.1 of the Kingma and Ba paper rather than the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon hat" in the paper.

The sparse implementation of this algorithm (used when the gradient is an IndexedSlices object, typically because of tf.gather or an embedding lookup in the forward pass) does apply momentum to variable slices even if they were not used in the forward pass (meaning they have a gradient equal to zero). Momentum decay (beta1) is also applied to the entire momentum accumulator. This means that the sparse behavior is equivalent to the dense behavior (in contrast to some momentum implementations which ignore momentum unless a variable slice was actually used).


TF随笔-8的更多相关文章

  1. TF随笔-13

    import tensorflow as tf a=tf.constant(5) b=tf.constant(3) res1=tf.divide(a,b) res2=tf.div(a,b) with ...

  2. TF随笔-11

    #!/usr/bin/env python2 # -*- coding: utf-8 -*- import tensorflow as tf my_var=tf.Variable(0.) step=t ...

  3. TF随笔-10

    #!/usr/bin/env python# -*- coding: utf-8 -*-import tensorflow as tf x = tf.constant(2)y = tf.constan ...

  4. TF随笔-9

    计算累加 #!/usr/bin/env python2 # -*- coding: utf-8 -*-"""Created on Mon Jul 24 08:25:41 ...

  5. TF随笔-7

    求平均值的函数 reduce_mean axis为1表示求行 axis为0表示求列 >>> xxx=tf.constant([[1., 10.],[3.,30.]])>> ...

  6. tf随笔-6

    import tensorflow as tfx=tf.constant([-0.2,0.5,43.98,-23.1,26.58])y=tf.clip_by_value(x,1e-10,1.0)ses ...

  7. tf随笔-5

    # -*- coding: utf-8 -*-import tensorflow as tfw1=tf.Variable(tf.random_normal([2,6],stddev=1))w2=tf. ...

  8. TF随笔-4

    >>> import tensorflow as tf>>> a=tf.constant([[1,2],[3,4]])>>> b=tf.const ...

  9. TF随笔-3

    >>> import tensorflow as tf>>> node1 = tf.constant(3.0, dtype=tf.float32)>>& ...

随机推荐

  1. 机器学习实战笔记(Python实现)-07-模型评估与分类性能度量

    1.经验误差与过拟合 通常我们把分类错误的样本数占样本总数的比例称为“错误率”(error rate),即如果在m个样本中有a个样本分类错误,则错误率E=a/m:相应的,1-a/m称为“精度”(acc ...

  2. cookie注入原理

    cookie注入原理-->红客联盟 http://www.2cto.com/article/201202/118837.html 前言: document.cookie:表示当前浏览器中的coo ...

  3. Hive的执行生命周期

    1.入口$HIVE_HOME/bin/ext/cli.sh 调用org.apache.hadoop.hive.cli.CliDriver类进行初始化过程 处理-e,-f,-h等信息,如果是-h,打印提 ...

  4. ubuntu 16.04 gtx1060 显卡安装【转】

    本文转载自:https://blog.csdn.net/u010925447/article/details/79754044 版权声明:本文为博主原创文章,未经博主允许不得转载. https://b ...

  5. JAVA链接数据库

    链接:http://www.cnblogs.com/centor/p/6142775.html 开发工具: MyEclipse MySQL JDBC驱动:mysql-connector-java-5. ...

  6. LeetCode——Sort List

    Question Sort a linked list in O(n log n) time using constant space complexity. Solution 分析,时间复杂度要求为 ...

  7. Graph_Master(连通分量_E_Hungry+Tarjan)

    hdu_4685 终于来写了这题的解题报告,没有在昨天A出来有点遗憾,不得不说数组开大开小真的是阻碍人类进步的一大天坑. 题目大意:给出n个王子,m个公主,只要王子喜欢,公主就得嫁(这个王子当得好霸道 ...

  8. html-w3c规范及常见标签

    W3C提倡的web结构: 内容(HTML)与表现(css样式)分离 内容(HTML)与行为(JS)分离 HTML内容结构要求语义化 基本规范: 标签名和属性名称必须小写 HTML标签必须关闭 属性值必 ...

  9. linux下增加useradd提示existing lock file /etc/subgid.lock without a PID

    # useradd git -g git useradd: existing lock file /etc/subgid.lock without a PID useradd: cannot lock ...

  10. JavaScript权威指南--语句

    知识要点 在javascript中,表达式是短语,那么语句(statement)就是整句或命令.表达式计算出一个值,但语句用来执行以使某件事发生. 1.表达式语句 具有副作用的表达式是JavaScri ...