SELU︱在keras、tensorflow中使用SELU激活函数
arXiv 上公开的一篇 NIPS 投稿论文《Self-Normalizing Neural Networks》引起了圈内极大的关注,它提出了缩放指数型线性单元(SELU)而引进了自归一化属性,该单元主要使用一个函数 g 映射前后两层神经网络的均值和方差以达到归一化的效果。 Shao-Hua Sun 在 Github 上放出了 SELU 与 Relu、Leaky Relu 的对比,机器之心对比较结果进行了翻译介绍,具体的实现过程可参看以下项目地址。
项目地址:shaohua0116/Activation-Visualization-Histogram
来源机器之心:引爆机器学习圈:「自归一化神经网络」提出新型激活函数SELU
keras中使用SELU激活函数
在keras 2.0.6版本之后才可以使用selu激活函数,但是在版本2.0.5还是不行,所以得升级到这个版本。
在全连接层后面接上selu最终收敛会快一些
来看一下,一个介绍非常详细的github:bigsnarfdude/SELU_Keras_Tutorial
具体对比效果:
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.noise import AlphaDropout
from keras.utils import np_utils
from keras.optimizers import RMSprop, Adam
batch_size = 128
num_classes = 10
epochs = 20
learning_rate = 0.001
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
modelSELU = Sequential()
modelSELU.add(Dense(512, activation='selu', input_shape=(784,)))
modelSELU.add(AlphaDropout(0.1))
modelSELU.add(Dense(512, activation='selu'))
modelSELU.add(AlphaDropout(0.1))
modelSELU.add(Dense(10, activation='softmax'))
modelSELU.summary()
modelSELU.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy'])
historySELU = modelSELU.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))
scoreSELU = modelSELU.evaluate(x_test, y_test, verbose=0)
print('Test loss:', scoreSELU[0])
print('Test accuracy:', scoreSELU[1])
tensorflow中使用dropout_selu + SELU
该文作者在tensorflow也加入了selu 和 dropout_selu两个新的激活函数。
可见:https://github.com/bigsnarfdude/SELU_Keras_Tutorial/blob/master/Basic_MLP_combined_comparison.ipynb
定义了两个新函数:SELU 和 dropout_selu
def selu(x):
with ops.name_scope('elu') as scope:
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x))
def dropout_selu(x, keep_prob, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0,
noise_shape=None, seed=None, name=None, training=False):
"""Dropout to a value with rescaling."""
def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
keep_prob = 1.0 - rate
x = ops.convert_to_tensor(x, name="x")
if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
raise ValueError("keep_prob must be a scalar tensor or a float in the "
"range (0, 1], got %g" % keep_prob)
keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
if tensor_util.constant_value(keep_prob) == 1:
return x
noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
random_tensor = keep_prob
random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype)
binary_tensor = math_ops.floor(random_tensor)
ret = x * binary_tensor + alpha * (1-binary_tensor)
a = tf.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * tf.pow(alpha-fixedPointMean,2) + fixedPointVar)))
b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
ret = a * ret + b
ret.set_shape(x.get_shape())
return ret
with ops.name_scope(name, "dropout", [x]) as name:
return utils.smart_cond(training,
lambda: dropout_selu_impl(x, keep_prob, alpha, noise_shape, seed, name),
lambda: array_ops.identity(x))
作者将其使用在以下案例之中:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Check out https://www.tensorflow.org/get_started/mnist/beginners for
# more information about the mnist dataset
# parameters
learning_rate = 0.001
training_epochs = 50
batch_size = 100
# input place holders
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
# dropout (keep_prob) rate 0.7 on training, but should be 1 for testing
keep_prob = tf.placeholder(tf.float32)
# weights & bias for nn layers
# http://stackoverflow.com/questions/33640581/how-to-do-xavier-initialization-on-tensorflow
W1 = tf.get_variable("W1", shape=[784, 512],
initializer=tf.contrib.layers.xavier_initializer())
b1 = tf.Variable(tf.random_normal([512]))
L1 = selu(tf.matmul(X, W1) + b1)
L1 = dropout_selu(L1, keep_prob=keep_prob)
W2 = tf.get_variable("W2", shape=[512, 512],
initializer=tf.contrib.layers.xavier_initializer())
b2 = tf.Variable(tf.random_normal([512]))
L2 = selu(tf.matmul(L1, W2) + b2)
L2 = dropout_selu(L2, keep_prob=keep_prob)
W3 = tf.get_variable("W3", shape=[512, 512],
initializer=tf.contrib.layers.xavier_initializer())
b3 = tf.Variable(tf.random_normal([512]))
L3 = selu(tf.matmul(L2, W3) + b3)
L3 = dropout_selu(L3, keep_prob=keep_prob)
W4 = tf.get_variable("W4", shape=[512, 512],
initializer=tf.contrib.layers.xavier_initializer())
b4 = tf.Variable(tf.random_normal([512]))
L4 = selu(tf.matmul(L3, W4) + b4)
L4 = dropout_selu(L4, keep_prob=keep_prob)
W5 = tf.get_variable("W5", shape=[512, 10],
initializer=tf.contrib.layers.xavier_initializer())
b5 = tf.Variable(tf.random_normal([10]))
hypothesis = tf.matmul(L4, W5) + b5
# define cost/loss & optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# initialize
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# train my model
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples / batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
feed_dict = {X: batch_xs, Y: batch_ys, keep_prob: 0.7}
c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
avg_cost += c / total_batch
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))
print('Learning Finished!')
# Test model and check accuracy
correct_prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
SELU︱在keras、tensorflow中使用SELU激活函数的更多相关文章
- Tensorflow中神经网络的激活函数
激励函数的目的是为了调节权重和误差. relu max(0,x) relu6 min(max(0,x),6) sigmoid 1/(1+exp(-x)) tanh ((exp(x)-exp(-x))/ ...
- 【tf.keras】tf.keras使用tensorflow中定义的optimizer
Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...
- 在tensorflow中使用batch normalization
问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题, ...
- 【深度学习】keras + tensorflow 实现猫和狗图像分类
本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...
- 使用TensorFlow中的Batch Normalization
问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题.但是却不能保证在训练过程中不出现该问题, ...
- Keras官方中文文档:keras后端Backend
所属分类:Keras Keras后端 什么是"后端" Keras是一个模型级的库,提供了快速构建深度学习网络的模块.Keras并不处理如张量乘法.卷积等底层操作.这些操作依赖于某种 ...
- Keras官方中文文档:常见问题与解答
所属分类:Keras Keras FAQ:常见问题 如何引用Keras? 如何使Keras调用GPU? 如何在多张GPU卡上使用Keras "batch", "epoch ...
- [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建
这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...
- 第十八节,TensorFlow中使用批量归一化(BN)
在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...
随机推荐
- 有关Oracle统计信息的知识点
一.什么是统计信息 统计信息主要是描述数据库中表,索引的大小,规模,数据分布状况等的一类信息.例如,表的行数,块数,平均每行的大小,索引的leaf blocks,索引字段的行数,不同值的大小等,都属于 ...
- php \r \n 和 <br/> \t
利用\r \n 和 <br/> \t做了个实验,话不多说,看代码就很清楚的知道
- Spring 整合Mybatis 出现了Cause: org.springframework.jdbc.CannotGetJdbcConnectionException: Could not get JDBC Connection; nested exception is org.apache.commons.dbcp.SQLNestedException: Cannot create Poola
我出现的 报错信息如下: ### Error querying database. Cause: org.springframework.jdbc.CannotGetJdbcConnectionExc ...
- springcloud19---springCloudConfig
Spring-cloud-config : 统一管理配置的组件,不同的环境不同的管理(连接池.数据库配置不一样).不同时间需要动态调整配置(双十一最大连接数要大). 分布式配置也可以使用config或 ...
- 【知识总结】Activiti工作流学习入门
1. 我理解的工作流: 在工作中慢慢接触的业务流程,就向流程控制语言一样,一步一步都对应的不同的业务,但整体串联起来就是一个完整的业务.而且实际工作中尤其是在企业内部系统的研发中,确实需要对应许多审批 ...
- P1879 [USACO06NOV]玉米田Corn Fields(状压dp)
P1879 [USACO06NOV]玉米田Corn Fields 状压dp水题 看到$n,m<=12$,肯定是状压鸭 先筛去所有不合法状态,蓝后用可行的状态跑一次dp就ok了 #include& ...
- Linux 中各个文件夹的作用
/ 根目录 包含了几乎所的文件目录.相当于中央系统.进入的最简单方法是:cd /. /boot 引导程序,内核等存放的目录 这个目录,包括了在引导过程中所必需的文件.在最开始的启动阶段,通过引导程 ...
- 从IC设计业看中国企业之发展
从IC设计业看中国企业之发展 在半导体领域,国际平均毛利润水平为40%.去年IC设计年会中,中国半导体行业协会IC设计分会理事长魏少军指出,中国IC设计业平均毛利润水平比国际平均水平低了12.39 ...
- mysql绑定参数bind_param原理以及防SQL注入
假设我们的用户表中存在一行.用户名字段为username.值为aaa.密码字段为pwd.值为pwd.. 下面我们来模拟一个用户登录的过程.. <?php $username = "aa ...
- HDU 4300 Clairewd’s message(扩展KMP)题解
题意:先给你一个密码本,再给你一串字符串,字符串前面是密文,后面是明文(明文可能不完成整),也就是说这个字符串由一个完整的密文和可能不完整的该密文的明文组成,要你找出最短的密文+明文. 思路:我们把字 ...