import tensorflow as tf
from numpy.random import RandomState batch_size = 8
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
w1= tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1) # 定义损失函数使得预测少了的损失大,于是模型应该偏向多的方向预测。
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[x1+x2+(rdm.rand()/10.0-0.05)] for (x1, x2) in X] with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss_less = 1
loss_more = 10
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss = tf.losses.mean_squared_error(y, y_)
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

吴裕雄 python 神经网络——TensorFlow 自定义损失函数的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  2. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  3. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  4. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  5. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  6. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣识别2

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

随机推荐

  1. windows_环境变量

    今天在windows下安装了openSSH,就想在windows下也添加一个服务器地址的环境变量,添加是会了,引用却不会,网上也没找到相应的文章,就自己看了看别人的bat程序,然后找到了引用变量的方法 ...

  2. Linux中通配符

    通配符是由shell处理的, 它只会出现在 命令的“参数”里.当shell在“参数”中遇到了通配符时,shell会将其当作路径或文件名去在磁盘上搜寻可能的匹配:若符合要求的匹配存在,则进行代换(路径扩 ...

  3. 熬最深的夜喝最劣的酒————浅谈生成器(generator)

    测试(test)def s(): print("stup1") n = "第一步" yield n # 类似于return 但是又不同于 赖克宝,剁一下,跳一下 ...

  4. 【转载】Hadoop面试(1)

    转自:http://www.cnblogs.com/xiaolong1032/p/4504992.html 列举出hadoop常用的一些InputFormat InputFormat是用来对我们的输入 ...

  5. iview渲染函数

    <Table border :columns="discountColumns" :data="discountData.rows"></Ta ...

  6. Redis04——Redis五大数据类型 key

    key  keys *  查看当前库的所有键  exists <key>  判断某个键是否存在  type <key>   查看键的类型  del<key>  删除 ...

  7. stl队列

    队列(Queue)也是一种运算受限的线性表,它的运算限制与栈不同,是两头都有限制,插入只能在表的一端进行(只进不出),而删除只能在表的另一端进行(只出不进),允许删除的一端称为队尾(rear),允许插 ...

  8. 【PAT甲级】1111 Online Map (30分)(dijkstra+路径记录)

    题意: 输入两个正整数N和M(N<=500,M<=N^2),分别代表点数和边数.接着输入M行每行包括一条边的两个结点(0~N-1),这条路的长度和通过这条路所需要的时间.接着输入两个整数表 ...

  9. sublime-text3 安装 emmet 插件

    下载sublime,http://www.sublimetext.com/ 安装package control :https://packagecontrol.io/ins... 这个地址需要翻墙,访 ...

  10. SOCV/POCV 开篇 (1)

    1.功能:模拟工艺偏差对芯片性能的影响 2. 40nm之前 flat derate模型可以基本覆盖大部分情况 3.AOCV (Adance OCV) 考虑distance 和depth的影响. AOC ...