问题的出现 Question

这个问题是我基于TensorFlow使用CNN训练MNIST数据集的时候遇到的。关键的相关代码是以下这部分:

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

学习速率是\((1e-4)\)的时候是没有问题,但是当我把学习速率调到\(0.01/0.5\)的时候,很快就会报错。

tensorflow.python.framework.errors.InvalidArgumentError: ReluGrad input is not finite. : Tensor had NaN values

分析 Analysis

学习速率 Learning Rate

于是我尝试加上几行代码,希望能把y_conv和cross_entropy的状态反映出来。

y_conv=tf.Print(y_conv,[y_conv],"y_conv: ")
cross_entropy =tf.Print(cross_entropy,[cross_entropy],"cross_entropy: ")

当learning rate \(=0.01\)时,程序会报错:

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.0374929e-06 0.0059775524 0.980205...]
step 0, training accuracy 0.04
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [9.2028862e-10 1.4812358e-05 0.044873074...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [648.49146]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.024463326 1.4828938e-31 0...]
step 1, training accuracy 0.2
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [2.4634053e-11 3.3087209e-34 0...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [nan]
step 2, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [nan nan nan...]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7ff51d92a940 Compute status: Invalid argument: ReluGrad input is not finite. : Tensor had NaN values

当learning rate \(=1e-4\)时,程序不会报错。

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.00056920078 8.4922984e-09 0.00033719366...]
step 0, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [7.0613837e-10 9.28294e-09 0.00016230672...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [439.95135]
step 1, training accuracy 0.16
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.031509314 3.6221365e-05 0.015359053...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.7112056e-07 1.8543299e-09 8.9234991e-06...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [436.37653]
step 2, training accuracy 0.12
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.015578311 0.0026688741 0.44736364...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [6.0428465e-07 0.0001744287 0.026451336...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [385.33765]

至此,我们可以看到,学习速率太大是产生error其中一个原因。

参考斯坦福CS 224D的Lecture Note,在训练深度神经网络的时候,出现NaN比较大的可能是因为学习速率过大,梯度值过大,产生梯度爆炸。

Refer to the lecture note of Stanford CS 224D, a precise definition of Gradient Explosion is:

During experimentation, once the gradient value grows extremely large, it causes an overflow (i.e. NaN) which is easily detectable at runtime; this issue is called the Gradient Explosion Problem.

解决方法 Solutions

  1. 适当减小学习速率 Try to decrease the learning rate.
  2. 加入Gradient clipping的方法。 Gradient clipping的方法最早是由Thomas Mikolov提出的。每当梯度达到一定的阈值,就把他们设置回一个小一些的数字。

    Refer to the lecture note of Stanford CS 224D, use gradient clipping.

To solve the problem of exploding gradients, Thomas Mikolov first introduced a simple heuristic solution that clips gradients to a small number whenever they explode. That is, whenever they reach a certain threshold, they are set back to a small number as shown in Algorithm 1.

Algorithm 1:

\(\frac{\partial E}{\partial W}\to g\)

if $ \Vert g\Vert\ge threshold$ then

\(\frac {threshold}{\Vert g\Vert} g\to g\)

end if

TensorFlow | ReluGrad input is not finite. Tensor had NaN values的更多相关文章

  1. Tensorflow 模型文件结构、模型中Tensor查看

    tensorflow训练后保存的模型主要包含两部分,一是网络结构的定义(网络图),二是网络结构里的参数值. 1.  .meta文件 .meta 文件以 "protocol buffer&qu ...

  2. tensorflow报错 tensorflow Resource exhausted: OOM when allocating tensor with shape

    在使用tensorflow的object detection时,出现以下报错 tensorflow Resource exhausted: OOM when allocating tensor wit ...

  3. 怎么在tensorflow中打印graph中的tensor信息

    from tensorflow.python import pywrap_tensorflow import os checkpoint_path=os.path.join('./model.ckpt ...

  4. Spark连续特征转化成离散特征

    当数据量很大的时候,分类任务通常使用[离散特征+LR]集成[连续特征+xgboost],如果把连续特征加入到LR.决策树中,容易造成overfit. 如果想用上连续型特征,使用集成学习集成多种算法是一 ...

  5. 用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割

    用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割 Accelerating Medical Image Segmentation with NVIDIA Tensor ...

  6. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  7. [开发技巧]·TensorFlow中numpy与tensor数据相互转化

    [开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...

  8. TensorFlow使用记录 (九): 模型保存与恢复

    模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...

  9. TensorFlowSharp入门使用C#编写TensorFlow人工智能应用

    TensorFlowSharp入门使用C#编写TensorFlow人工智能应用学习. TensorFlow简单介绍 TensorFlow 是谷歌的第二代机器学习系统,按照谷歌所说,在某些基准测试中,T ...

随机推荐

  1. c#聊聊文件数据库kv

    现在有很多KV嵌入式存储,或者已经增加的.leveldb,RaptorDB等,都是相对比较好的存储.基本存储,一般配置.大概在6w/s左右.当然还有缓存等设置问题.这些基本是字符串和int的存储,对于 ...

  2. ExcludeClipRect区域裁剪问题

    CPaintDC dc(this); CRect rt1; CPen newPen; newPen.CreatePen(PS_SOLID,1,RGB(0,0,0)); CPen *pOldPen = ...

  3. 09JavaScript函数

    函数是由事件驱动的或者当它被调用时执行的可重复使用的代码块. 实例1: <!DOCTYPE html> <html> <head> <meta charset ...

  4. jQuery.qrcode 生成二维码,并使用 jszip、FileSaver 下载 zip 压缩包至本地。

    生成二维码 引用 jquery.qrcode.js  :连接:https://files.cnblogs.com/files/kitty-blog/jquery.qrcode.js .https:// ...

  5. 初识Java——第一章 初识Java

    1. 计算机程序: 为了让计算机执行某些操作或解决某个问题而编写的一系列有序指令的集合. 2. JAVA相关的技术:      1).安装和运行在本机上的桌面程序      2).通过浏览器访问的面向 ...

  6. mybatis调用存过程返回结果集和out参数值

    Mapper文件: 1.配置一个参数映射集,采用hashMap对象 2.使用call调用存储过,其中in out等标识的参数需要有详细的描述,例如:mode,JavaType,jdbcType等 &l ...

  7. php源码建博客5--建库建表-配置文件-错误日志

    主要: 整理框架 建库建表 配置文件类 错误日志记录 --------------本篇后文件结构:-------------------------------------- blog ├─App │ ...

  8. php 算法(二分法)只适用于有序表,且限于顺序存储结构

    function demo($array,$low,$high,$k){ if($low<=$high){//判断该数组是否存在 $mid =  intval(($low+$high)/2 ); ...

  9. Learning notes | Data Analysis: 1.2 data wrangling

    | Data Wrangling | # Sort all the data into one file files = ['BeijingPM20100101_20151231.csv','Chen ...

  10. React 源码中的依赖注入方法

    一.前言 依赖注入(Dependency Injection)这个概念的兴起已经有很长时间了,把这个概念融入到框架中达到出神入化境地的,非Spring莫属.然而在前端领域,似乎很少会提到这个概念,难道 ...