问题的出现 Question

这个问题是我基于TensorFlow使用CNN训练MNIST数据集的时候遇到的。关键的相关代码是以下这部分:

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

学习速率是\((1e-4)\)的时候是没有问题,但是当我把学习速率调到\(0.01/0.5\)的时候,很快就会报错。

tensorflow.python.framework.errors.InvalidArgumentError: ReluGrad input is not finite. : Tensor had NaN values

分析 Analysis

学习速率 Learning Rate

于是我尝试加上几行代码,希望能把y_conv和cross_entropy的状态反映出来。

y_conv=tf.Print(y_conv,[y_conv],"y_conv: ")
cross_entropy =tf.Print(cross_entropy,[cross_entropy],"cross_entropy: ")

当learning rate \(=0.01\)时,程序会报错:

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.0374929e-06 0.0059775524 0.980205...]
step 0, training accuracy 0.04
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [9.2028862e-10 1.4812358e-05 0.044873074...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [648.49146]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.024463326 1.4828938e-31 0...]
step 1, training accuracy 0.2
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [2.4634053e-11 3.3087209e-34 0...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [nan]
step 2, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [nan nan nan...]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7ff51d92a940 Compute status: Invalid argument: ReluGrad input is not finite. : Tensor had NaN values

当learning rate \(=1e-4\)时,程序不会报错。

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.00056920078 8.4922984e-09 0.00033719366...]
step 0, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [7.0613837e-10 9.28294e-09 0.00016230672...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [439.95135]
step 1, training accuracy 0.16
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.031509314 3.6221365e-05 0.015359053...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.7112056e-07 1.8543299e-09 8.9234991e-06...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [436.37653]
step 2, training accuracy 0.12
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.015578311 0.0026688741 0.44736364...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [6.0428465e-07 0.0001744287 0.026451336...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [385.33765]

至此,我们可以看到,学习速率太大是产生error其中一个原因。

参考斯坦福CS 224D的Lecture Note,在训练深度神经网络的时候,出现NaN比较大的可能是因为学习速率过大,梯度值过大,产生梯度爆炸。

Refer to the lecture note of Stanford CS 224D, a precise definition of Gradient Explosion is:

During experimentation, once the gradient value grows extremely large, it causes an overflow (i.e. NaN) which is easily detectable at runtime; this issue is called the Gradient Explosion Problem.

解决方法 Solutions

  1. 适当减小学习速率 Try to decrease the learning rate.
  2. 加入Gradient clipping的方法。 Gradient clipping的方法最早是由Thomas Mikolov提出的。每当梯度达到一定的阈值,就把他们设置回一个小一些的数字。

    Refer to the lecture note of Stanford CS 224D, use gradient clipping.

To solve the problem of exploding gradients, Thomas Mikolov first introduced a simple heuristic solution that clips gradients to a small number whenever they explode. That is, whenever they reach a certain threshold, they are set back to a small number as shown in Algorithm 1.

Algorithm 1:

\(\frac{\partial E}{\partial W}\to g\)

if $ \Vert g\Vert\ge threshold$ then

\(\frac {threshold}{\Vert g\Vert} g\to g\)

end if

TensorFlow | ReluGrad input is not finite. Tensor had NaN values的更多相关文章

  1. Tensorflow 模型文件结构、模型中Tensor查看

    tensorflow训练后保存的模型主要包含两部分,一是网络结构的定义(网络图),二是网络结构里的参数值. 1.  .meta文件 .meta 文件以 "protocol buffer&qu ...

  2. tensorflow报错 tensorflow Resource exhausted: OOM when allocating tensor with shape

    在使用tensorflow的object detection时,出现以下报错 tensorflow Resource exhausted: OOM when allocating tensor wit ...

  3. 怎么在tensorflow中打印graph中的tensor信息

    from tensorflow.python import pywrap_tensorflow import os checkpoint_path=os.path.join('./model.ckpt ...

  4. Spark连续特征转化成离散特征

    当数据量很大的时候,分类任务通常使用[离散特征+LR]集成[连续特征+xgboost],如果把连续特征加入到LR.决策树中,容易造成overfit. 如果想用上连续型特征,使用集成学习集成多种算法是一 ...

  5. 用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割

    用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割 Accelerating Medical Image Segmentation with NVIDIA Tensor ...

  6. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  7. [开发技巧]·TensorFlow中numpy与tensor数据相互转化

    [开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...

  8. TensorFlow使用记录 (九): 模型保存与恢复

    模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...

  9. TensorFlowSharp入门使用C#编写TensorFlow人工智能应用

    TensorFlowSharp入门使用C#编写TensorFlow人工智能应用学习. TensorFlow简单介绍 TensorFlow 是谷歌的第二代机器学习系统,按照谷歌所说,在某些基准测试中,T ...

随机推荐

  1. python3带tkinter窗口的ftp服务器,并使用pyinstaller打包成exe

    python带tkinter窗口的ftp服务器,使用python3编写,打包使用pyinstaller,命令 pyinstaller -F .\ftpserver.py 代码也可在我的github上下 ...

  2. SQLMAP使用详解

    使用示例 python sqlmap.py -u "http://xx.com/member.php?id=XX"  -p id --dbms "Mysql"  ...

  3. 竞赛题解 - Broken Tree(CF-758E)

    Broken Tree(CF-758E) - 竞赛题解 贪心复习~(好像暴露了什么算法--) 标签:贪心 / DFS / Codeforces 『题意』 给出一棵以1为根的树,每条边有两个值:p-强度 ...

  4. NPOI读取Excel遇到的坑

    NPOI是POI的.NET版本.POI是用Java写成的库,能帮助用户在没有安装Office环境下读取Office2003-2007文件.NPOI在.NET环境下使用,能读写Excel/Word文件. ...

  5. 【reidis中ruby模块版本老旧利用rvm来更新】

    //gem install redis时会遇到如下的error: //借助rvm来update ruby版本

  6. 使用Letsencrypt做SSL certificate

    为什么要使用Letsencrypt做SSL certificate? 最简单直接的原因是免费.但是免费存在是否靠谱的问题,尤其是对安全要求比较高的网站,需要考虑使用letsencrypt的安全性是否符 ...

  7. PHP7.1中使用openssl替换mcrypt

    PHP7.1中使用openssl替换mcrypt 在php开发中,使用mcrypt相关函数可以很方便地进行AES加.解密操作,但是PHP7.1中废弃了mcrypt扩展,所以必需寻找另一种实现.在迁移手 ...

  8. 第七篇:gcc和arm-linux-gcc常用选项

    目录 一.gcc和arm-linux-gcc的常用选项 二.从.c文件到可执行文件过程 一.gcc和arm-linux-gcc的常用选项 常用选型 -v 查看gcc编译器的版本,显示gcc执行时的详细 ...

  9. python网络编程之线程

    一 .背景知识 1.进程 之前我们已经了解了操作系统中进程的概念,程序并不能单独运行,只有将程序装载到内存中,系统为它分配资源才能运行,而这种执行的程序就称之为进程.程序和进程的区别就在于:程序是指令 ...

  10. [转]Visual C++ 和 C++ 有什么区别?

    注:本篇内容转载与网络,方便自己学习,如有侵权请您联系我删除,谢谢. 有位同学问我“Visual C++和C++有什么区别?”,这的确是初学者会感到困惑的问题,比较常见.除此之外,还有“先学C++好, ...