问题的出现 Question

这个问题是我基于TensorFlow使用CNN训练MNIST数据集的时候遇到的。关键的相关代码是以下这部分:

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

学习速率是\((1e-4)\)的时候是没有问题,但是当我把学习速率调到\(0.01/0.5\)的时候,很快就会报错。

tensorflow.python.framework.errors.InvalidArgumentError: ReluGrad input is not finite. : Tensor had NaN values

分析 Analysis

学习速率 Learning Rate

于是我尝试加上几行代码,希望能把y_conv和cross_entropy的状态反映出来。

y_conv=tf.Print(y_conv,[y_conv],"y_conv: ")
cross_entropy =tf.Print(cross_entropy,[cross_entropy],"cross_entropy: ")

当learning rate \(=0.01\)时,程序会报错:

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.0374929e-06 0.0059775524 0.980205...]
step 0, training accuracy 0.04
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [9.2028862e-10 1.4812358e-05 0.044873074...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [648.49146]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.024463326 1.4828938e-31 0...]
step 1, training accuracy 0.2
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [2.4634053e-11 3.3087209e-34 0...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [nan]
step 2, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [nan nan nan...]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7ff51d92a940 Compute status: Invalid argument: ReluGrad input is not finite. : Tensor had NaN values

当learning rate \(=1e-4\)时,程序不会报错。

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.00056920078 8.4922984e-09 0.00033719366...]
step 0, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [7.0613837e-10 9.28294e-09 0.00016230672...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [439.95135]
step 1, training accuracy 0.16
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.031509314 3.6221365e-05 0.015359053...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.7112056e-07 1.8543299e-09 8.9234991e-06...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [436.37653]
step 2, training accuracy 0.12
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.015578311 0.0026688741 0.44736364...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [6.0428465e-07 0.0001744287 0.026451336...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [385.33765]

至此,我们可以看到,学习速率太大是产生error其中一个原因。

参考斯坦福CS 224D的Lecture Note,在训练深度神经网络的时候,出现NaN比较大的可能是因为学习速率过大,梯度值过大,产生梯度爆炸。

Refer to the lecture note of Stanford CS 224D, a precise definition of Gradient Explosion is:

During experimentation, once the gradient value grows extremely large, it causes an overflow (i.e. NaN) which is easily detectable at runtime; this issue is called the Gradient Explosion Problem.

解决方法 Solutions

  1. 适当减小学习速率 Try to decrease the learning rate.
  2. 加入Gradient clipping的方法。 Gradient clipping的方法最早是由Thomas Mikolov提出的。每当梯度达到一定的阈值,就把他们设置回一个小一些的数字。

    Refer to the lecture note of Stanford CS 224D, use gradient clipping.

To solve the problem of exploding gradients, Thomas Mikolov first introduced a simple heuristic solution that clips gradients to a small number whenever they explode. That is, whenever they reach a certain threshold, they are set back to a small number as shown in Algorithm 1.

Algorithm 1:

\(\frac{\partial E}{\partial W}\to g\)

if $ \Vert g\Vert\ge threshold$ then

\(\frac {threshold}{\Vert g\Vert} g\to g\)

end if

TensorFlow | ReluGrad input is not finite. Tensor had NaN values的更多相关文章

  1. Tensorflow 模型文件结构、模型中Tensor查看

    tensorflow训练后保存的模型主要包含两部分,一是网络结构的定义(网络图),二是网络结构里的参数值. 1.  .meta文件 .meta 文件以 "protocol buffer&qu ...

  2. tensorflow报错 tensorflow Resource exhausted: OOM when allocating tensor with shape

    在使用tensorflow的object detection时,出现以下报错 tensorflow Resource exhausted: OOM when allocating tensor wit ...

  3. 怎么在tensorflow中打印graph中的tensor信息

    from tensorflow.python import pywrap_tensorflow import os checkpoint_path=os.path.join('./model.ckpt ...

  4. Spark连续特征转化成离散特征

    当数据量很大的时候,分类任务通常使用[离散特征+LR]集成[连续特征+xgboost],如果把连续特征加入到LR.决策树中,容易造成overfit. 如果想用上连续型特征,使用集成学习集成多种算法是一 ...

  5. 用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割

    用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割 Accelerating Medical Image Segmentation with NVIDIA Tensor ...

  6. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  7. [开发技巧]·TensorFlow中numpy与tensor数据相互转化

    [开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...

  8. TensorFlow使用记录 (九): 模型保存与恢复

    模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...

  9. TensorFlowSharp入门使用C#编写TensorFlow人工智能应用

    TensorFlowSharp入门使用C#编写TensorFlow人工智能应用学习. TensorFlow简单介绍 TensorFlow 是谷歌的第二代机器学习系统,按照谷歌所说,在某些基准测试中,T ...

随机推荐

  1. RockBrain USB Server- 云计算虚拟化USB设备集中管理、远程共享解决方案(涉及银企直联)

    RockBrain USB Server- 云计算虚拟化USB设备集中管理.远程共享解决方案(涉及银企直联) 技术需求: 1.企业员工的大量USB Key,需要将key接入USB Server虚拟池, ...

  2. 身份认证系统(三)什么是OAuth2

    本文准备用最简单的语言告诉大家什么是OAuth2 ,OAuth2是干什么的. 我们有一个资源服务器,资源服务器中有一系列的用户数据. 现在有一个应用想想要获取我们的用户数据. 那么最简单的方法就是我们 ...

  3. ArrayList的去重问题

    面试被问及arraylist的去重问题,现将自己想的两种解决方案写在下面 /** * Description: * ClassName:Uniq * Package:com.syd.interview ...

  4. DQL-常见的函数

    一.概述功能:类似于java中的方法好处:提高重用性和隐藏实现细节调用:select 函数名(实参列表); 二:常用的函数: ① 单行函数 1.分组函数 1.sum(),avg(),max(),min ...

  5. oracle 的replace()

    1.查询所有员工的姓名,如果包含字母s,则用S替换. 2.查询所有员工姓名的前三个字符.

  6. injection for Xcode10使用方法

    对于一个使用Xcode的使用者来说,麻烦的地方在于使用代码布置界面时候的调试,5s改一下代码,用10s查看修改效果,如果电脑配置稍低,时间更长,这是病,得治,哈哈.下面就来说一下injection的使 ...

  7. JavaBen 中 如何将字段设置为 "text" 文本类型

    @Lob @Column(name="FEEDBACK_MESSAGE",columnDefinition="TEXT", nullable=true) pub ...

  8. bootstrap-01-学习记录

    1.bootstrap所有插件依赖JQ,必须在JQ之后引入. 2.bootstrap分预编译版(css,js,fonts)和源码版(less,js,fonts,dist->预编译版内容,docs ...

  9. 基于Bootstrap Ace模板+bootstrap.addtabs.js的菜单

    这几天研究了基于bootstrap Ace模板+bootstra.addtabs.js实现菜单的效果 参考了这个人的博客 https://www.cnblogs.com/landeanfen/p/76 ...

  10. PHP实现长网址与短网址

    原文地址:http://www.qqdeveloper.com/detail/29/1.html 什么是长链接.短链接 顾名思义,长链接就是一个很长的链接:短链接就是一个很短的链接.长链接可以生成短链 ...