import tensorflow as tf
from numpy.random import RandomState batch_size = 8
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
w1= tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1) # 定义损失函数使得预测少了的损失大,于是模型应该偏向多的方向预测。
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[x1+x2+(rdm.rand()/10.0-0.05)] for (x1, x2) in X] with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss_less = 1
loss_more = 10
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss = tf.losses.mean_squared_error(y, y_)
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

吴裕雄 python 神经网络——TensorFlow 自定义损失函数的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  2. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  3. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  4. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  5. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  6. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣识别2

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

随机推荐

  1. JavaWeb——第1章Web技术概述

    Web本意是蜘蛛网的意思,现常指Internet的Web技术.Web技术提供了方便的信息发布和交流方式,是一种典型的分布式应用结构,Web应用中的每一次信息交换都要涉及客户端和服务器. 一.Inter ...

  2. python的os库

    os库(operating system,提供操作系统函数) 1. __file__是什么? ans:当前文件的名字. 例如r.py内容如下 import os if __name__ == &quo ...

  3. EAC3 channel & program extension

    EAC3 bit stream syntax允许在single bitstream中存在time-multiplexed substreams. 在EAC3的signle bitstream中,允许s ...

  4. 在页面跳转的时候,在跳转后的页面中使用js 获取到 页面跳转的url中携带的参数。

    common.js代码 //获取URL中的参数..等等function getQueryString(name){var reg = new RegExp("(^|&)"+ ...

  5. 在springboot项目中引入quartz任务调度器。

    quartz是一个非常强大的任务调度器.我们可能使用它来管理我们的项目,常见的是做业绩统计等等.当然它的功能远不止这些.我们在这里不介绍quartz的原理,下面讲讲如何在springboot中使用qu ...

  6. 445. 两数相加 II

    Q: A: 这种题的用例是一定会搞一些很大的数的.long都会溢出,所以我们就不用尝试转数字做加法转链表的方法了.另外直接倒置两个链表再做加法的做法会改变原链表,题干也说了禁止改动原链表. 1.求两个 ...

  7. 使用prepareStatement执行的sql语句的写法:

    使用prepareStatement对象执行的增.删.改.查sql语句: 查:  String sql = "SELECT * FROM 表名 WHERE loginId=? AND pas ...

  8. 解决ubuntu和win10双系统时间不一致

    1.在ubuntu下安装ntpdate sudo apt install ntpdate 2.设置同步windows时间 sudo ntpdate time.windows.com 3.把时间更新到硬 ...

  9. 洛谷 P2058 海港(模拟)

    题目链接:https://www.luogu.com.cn/problem/P2058 这是一道用手写队列模拟的一道题,没有什么细节,只是注意因为数不会很大,所以直接用数作为数组下标即可,不用用map ...

  10. jsTree的checkbox默认选中和隐藏

    jstree复选框自定义显示隐藏和初始化默认选中 首先需要配置 Checkbox plugin "plugins" : ['checkbox'] 设置默认选中状态(checkbox ...