实现方式

以 ℓ2 Regularization 为例,主要有两种实现方式

1. 手动累加

with tf.name_scope('loss'):
loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits) # label is one_hot
l2_reg_loss = tf.constant(0.0, tf.float32)
for vv in tf.trainable_variables():
if 'bn' in vv.name or 'batchnorm' in vv.name or 'batch_norm' in vv.name \
or 'batch_normalization' in vv.name or 'gn' in vv.name:
continue
else:
l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
l2_reg_loss *= 0.001
loss = loss + l2_reg_loss

2. 借助于 kernel_regularizer

with tf.name_scope('dnn'):
hidden1 = tf.layers.dense(X, 300, kernel_initializer=he_init, name='hidden1',
kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001))
...... with tf.name_scope('loss'):
loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits) # label is one_hot
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
loss = tf.add_n([loss] + reg_losses)

实例验证

import tensorflow as tf

# 1. create data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('../MNIST_data', one_hot=True) X = tf.placeholder(tf.float32, shape=(None, 784), name='X')
y = tf.placeholder(tf.int32, shape=(None), name='y')
is_training = tf.placeholder(tf.bool, None, name='is_training') # 2. define network
he_init = tf.contrib.layers.variance_scaling_initializer()
with tf.name_scope('dnn'):
hidden1 = tf.layers.dense(X, 300, kernel_initializer=he_init, name='hidden1',
kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001))
hidden1 = tf.layers.batch_normalization(hidden1, momentum=0.9)
hidden1 = tf.nn.relu(hidden1)
hidden2 = tf.layers.dense(hidden1, 100, kernel_initializer=he_init, name='hidden2',
kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001))
hidden2 = tf.layers.batch_normalization(hidden2, training=is_training, momentum=0.9)
hidden2 = tf.nn.relu(hidden2)
logits = tf.layers.dense(hidden2, 10, kernel_initializer=he_init, name='output',
kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001)) # 3. define loss
with tf.name_scope('loss'):
loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits) # label is one_hot
# =================
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
reg_loss = tf.reduce_sum(reg_losses)
# loss = tf.add_n([loss] + reg_losses)
# =================
l2_reg_loss = tf.constant(0.0, tf.float32)
for vv in tf.trainable_variables():
if 'bn' in vv.name or 'batchnorm' in vv.name or 'batch_norm' in vv.name \
or 'batch_normalization' in vv.name or 'gn' in vv.name:
continue
else:
l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
l2_reg_loss *= 0.001
# loss = loss + l2_reg_loss
# ================= # 4. define optimizer
learning_rate_init = 0.01
global_step = tf.Variable(0, trainable=False)
with tf.name_scope('train'):
learning_rate = tf.train.polynomial_decay( # 多项式衰减
learning_rate=learning_rate_init, # 初始学习率
global_step=global_step, # 当前迭代次数
decay_steps=22000, # 在迭代到该次数实际,学习率衰减为 learning_rate * dacay_rate
end_learning_rate=learning_rate_init / 10, # 最小的学习率
power=0.9,
cycle=False
)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # for batch normalization
with tf.control_dependencies(update_ops):
optimizer_op = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=0.9).minimize(
loss=loss,
var_list=tf.trainable_variables(),
global_step=global_step # 不指定的话学习率不更新
) with tf.name_scope('eval'):
correct = tf.nn.in_top_k(logits, tf.argmax(y, axis=1), 1) # 目标是否在前K个预测中, label's dtype is int*
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) # 5. initialize
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
saver = tf.train.Saver() # 5. train & test
n_epochs = 1
batch_size = 55000
with tf.Session() as sess:
saver.restore(sess, './my_model_final.ckpt')
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
loss_, l2_reg_loss_, reg_loss_ = sess.run([loss, l2_reg_loss, reg_loss], feed_dict={X: X_batch, y: y_batch, is_training:True})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch, is_training:False}) # 最后一个 batch 的 accuracy
acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels, is_training:False})
loss_test = loss.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels, is_training:False})
l2_reg_loss_test = l2_reg_loss.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels, is_training:False})
reg_loss_test = reg_loss.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels, is_training:False})
print("Train loss:", loss_, "Train l2_reg_loss:", l2_reg_loss_, "Train reg_loss:", reg_loss_, "Train accuracy:", acc_train)
print("Test loss:", loss_test, "Test l2_reg_loss:", l2_reg_loss_test, "Test reg_loss:", reg_loss_test, "Test accuracy:", acc_test) """
# =================
Train loss: 0.000636433 Train l2_reg_loss: 0.48696715 Train reg_loss: 0.48683384 Train accuracy: 1.0
Test loss: 0.059231624 Test l2_reg_loss: 0.48696715 Test reg_loss: 0.48683384 Test accuracy: 0.983
"""

TensorFlow使用记录 (十二): ℓ1 and ℓ2 Regularization的更多相关文章

  1. Spring学习记录(十二)---AOP理解和基于注解配置

    Spring核心之二:AOP(Aspect Oriented Programming) --- 面向切面编程,通过预编译方式和运行期动态代理实现程序功能的统一维护的一种技术.AOP是OOP的延续,是软 ...

  2. TensorFlow使用记录 (十四): Multi-task to MNIST + Fashion MNIST

    前言 后面工作中有个较重要的 task 是将 YOLOV3 目标检测和 LanNet 车道线检测和到一个网络中训练,特别的是,这两部分数据来自于不同的数据源.这和我之前在 caffe 环境下训练检测整 ...

  3. TensorFlow使用记录 (十): Pretraining

    上一篇的模型保存和恢复熟练后,我们就可以大量使用 pretrain model 来训练任务了 Tweaking, Dropping, or Replacing the Upper Layers The ...

  4. TensorFlow 学习(十二)—— 高级函数

    tf.map_fn(fn, elems):接受一个函数对象,然后用该函数对象对集合(elems)中的每一个元素分别处理, def preprocessing_image(image, training ...

  5. 第十二篇 Integration Services:高级日志记录

    本篇文章是Integration Services系列的第十二篇,详细内容请参考原文. 简介在前一篇文章我们配置了SSIS内置日志记录,演示了简单和高级日志配置,保存并查看日志配置,生成自定义日志消息 ...

  6. 【译】第十二篇 Integration Services:高级日志记录

    本篇文章是Integration Services系列的第十二篇,详细内容请参考原文. 简介在前一篇文章我们配置了SSIS内置日志记录,演示了简单和高级日志配置,保存并查看日志配置,生成自定义日志消息 ...

  7. Tensorflow深度学习之十二:基础图像处理之二

    Tensorflow深度学习之十二:基础图像处理之二 from:https://blog.csdn.net/davincil/article/details/76598474   首先放出原始图像: ...

  8. 我的MYSQL学习心得(十二) 触发器

    我的MYSQL学习心得(十二) 触发器 我的MYSQL学习心得(一) 简单语法 我的MYSQL学习心得(二) 数据类型宽度 我的MYSQL学习心得(三) 查看字段长度 我的MYSQL学习心得(四) 数 ...

  9. 解剖SQLSERVER 第十二篇 OrcaMDF 行压缩支持(译)

    解剖SQLSERVER 第十二篇   OrcaMDF 行压缩支持(译) http://improve.dk/orcamdf-row-compression-support/ 在这两个月的断断续续的开发 ...

随机推荐

  1. 算法:二叉树的层次遍历(递归实现+非递归实现,lua)

    二叉树知识参考:深入学习二叉树(一) 二叉树基础 递归实现层次遍历算法参考:[面经]用递归方法对二叉树进行层次遍历 && 二叉树深度 上面第一篇基础写得不错,不了解二叉树的值得一看. ...

  2. sql的关键字

    整理一下sql的关键字,一直都在用,只是很少去整理,所以今天简单整理一下,主要是整理CRUD的一些关键字. 写在前面:sql 不区分大小写 select 简单查询语句 select columnNam ...

  3. js之数据类型(对象类型——构造器对象——对象)

    JavaScript中除了原始类型,null,undefined之外就是对象了,对象是属性的集合,每个属性都是由键值对(值可以是原始值,比如说是数字,字符串,也可以是对象)构成的.对象又可分为构造器对 ...

  4. http协议:http请求、http响应、间隔时间跳转页面、禁用浏览器缓存

    转自:https://blog.csdn.net/u013372487/article/details/46991623 http协议 1. http协议是建立在  tcp/ip协议基础上. 2. 我 ...

  5. Redis-设置key过期

    Redis-设置key过期 expire key seconds 设置指定key 多少秒后过期, seconds 为 -1 时表示永不过期 ttl key 查看指定key还有多少秒过期 persist ...

  6. 【Swift后台】环境安装

    macOS 在macOS上使用Vapor,需要Xcode 9.3或更高版本.Swift 4.1或更高版本.安装还需要Homebrew命令. 检查Swift版本: swift --version Vap ...

  7. vlan linux内核数据流程

    转:http://blog.sina.com.cn/s/blog_62bbc49c0100fs0n.html 一.前言 前几天做协议划分vlan的时候看了一些linux内核,了解不深,整理了下vlan ...

  8. 解决IDEA中springboot整合mybatis中出现的Invalid bound statement(not found)的问题【转】

    感谢原博主,原文链接 : https://blog.csdn.net/benben513624/article/details/81076182 最近学习springboot的开发,中间磕磕碰碰也是遇 ...

  9. Array 对象-sort()

    Array 对象-sort() sort方法对数组成员进行排序,默认是按照字典顺序排序.排序后,原数组将被改变. sort方法不是按照大小排序,而是按照字典顺序.也就是说,数值会被先转成字符串,再按照 ...

  10. Centos7网卡配置命令nmcli

    在配置Centos6时,大家第一想到的就是把networkManager这个服务关掉,让它消失,这个服务台太鸡肋了,不该起作用的时候经常起作用,给管理带来了不便,但是在Centos7当中network ...