搭建普通的卷积CNN网络。

nan表示的是无穷或者是非数值,比如说你在tensorflow中使用一个数除以0,那么得到的结果就是nan。

在一个matrix中,如果其中的值都为nan很有可能是因为采用的cost function不合理导致的。

当使用tensorflow构建一个最简单的神经网络的时候,按照tensorflow官方给出的教程:

https://www.tensorflow.org/get_started/mnist/beginners

http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html  (中文教程)

具体的含义就不解释了。大概分为三个部分:1,导入数据集;2,搭建模型,并且定义cost function(也叫loss function);3,训练。

对于过程1,我们采用的不是mnist数据集,而是自己定义了一个数据集,其中

对于过程2,我们使用最简单的CNN网络,然后定义cost function的方式是:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

对于过程3,我们也采用教程中的例子去训练。

但是在初始化W后就立刻查看W参数的结果,得到的结果都是nan,以下是输出W权重后的结果:

这个现象是由于cost function引起的:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

上面的语句中的y_是数据集的label。我们做的是显著性检测,就是数据集的ground truth。

并且这个label或者ground truth一定要是one hot类型的变量。

那什么是one hot类型的变量呢?

举一个例子:比如一个5个类的数据集,用0,1,2,3,4来表示5个类的标签,因此label=0,1,2,3,4。这时候有的人会把y_=0,1,2,3,4。直接输入到cost function——-tf.reduce_sum(y_*tf.log(y))中,那么这样会导致W参数初始化都是nan。

解决办法就是我们把label=0,1,2,3,4变为one hot变量,改变后的结果是:label=[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,0,1],这样再输入到tf.reduce_sum(y_*tf.log(y))中,就是正确的了,如下图,我们采用的解决办法是第二种,具体参考下文。

那么本文提供两种方法来解决这个问题:

1,将y_从原来的类别数字变为one hot变量,使用

labels = tf.reshape(labels, [batch_size, 1])
indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
labels = tf.sparse_to_dense(
tf.concat(values=[indices, labels], axis=1),
[batch_size, num_classes], 1.0, 0.0)
将label转为one hot(batch_size是你每次抓取的训练集的个数)
2,换一个cost function,原来的cost function = -tf.reduce_sum(y_*tf.log(y))
使用的是交叉熵函数,现在我们换成二次代价函数 cost function = tf.reduce_sum(tf.square(tf.substract(y_,y)))

解决tensorflow在训练的时候权重是nan问题的更多相关文章

  1. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  2. 深度学习笔记 (二) 在TensorFlow上训练一个多层卷积神经网络

    上一篇笔记主要介绍了卷积神经网络相关的基础知识.在本篇笔记中,将参考TensorFlow官方文档使用mnist数据集,在TensorFlow上训练一个多层卷积神经网络. 下载并导入mnist数据集 首 ...

  3. 安装 tensorflow 1.1.0;以及安装其他相似版本tensorflow遇到的问题;tensorflow 1.13.2 cuda-10环境变量配置问题;Tensorflow 指定训练时如何指定使用的GPU;

    # 安装 2.7 环境conda create -n python2. python= conda activate python2. # 安装 1.1.0 gpu版本pip # 配置环境变量expo ...

  4. 解决tensorflow Saver.restore()无效的问题

    解决tensorflow 的 Saver.restore()无法从本地读取变量的问题 最近做tensorflow 手写数字识别的时候遇到了一个问题,Saver的restore()方法无法从本地恢复变量 ...

  5. tensorflow分布式训练

    https://blog.csdn.net/hjimce/article/details/61197190  tensorflow分布式训练 https://cloud.tencent.com/dev ...

  6. Tensorflow Mask-RCNN训练识别箱子的模型运行结果(练习)

    Tensorflow Mask-RCNN训练识别箱子的模型

  7. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  8. 利用阿里云容器服务打通TensorFlow持续训练链路

    本系列将利用Docker和阿里云容器服务,帮助您上手TensorFlow的机器学习方案 第一篇:打造TensorFlow的实验环境 第二篇:轻松搭建TensorFlow Serving集群 第三篇:打 ...

  9. 在C#下使用TensorFlow.NET训练自己的数据集

    在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...

随机推荐

  1. 【LOJ】#2268. 「SDOI2017」苹果树

    题解 显然权值都是正的,我们最深的那个点一定延伸到了某个叶子 我们抛去这条链之外再选K个点即可 如果直接对一棵树选K个点,满足这样的依赖关系,可以通过一个后序遍历的顺序做出来 转移方法是 \(dp[i ...

  2. 008 使用POJO对象绑定请求参数

    1.介绍 2.Person.java package com.spring.bean; public class Person { private String username; private S ...

  3. 自制 COCO api 直接读取类 COCO 的标注数据的压缩文件

    第6章 COCO API 的使用 COCO 数据库是由微软发布的一个大型图像数据集,该数据集专为对象检测.分割.人体关键点检测.语义分割和字幕生成而设计.如果你要了解 COCO 数据库的一些细节,你可 ...

  4. U盘制作Win7安装盘的方法

    Windows 7 USB/DVD download tool 微软官方说明:http://www.microsoftstore.com/st ... Win7_usbdvd_dwnTool 下载地址 ...

  5. Codeforces.314E.Sereja and Squares(DP)

    题目链接 http://www.cnblogs.com/TheRoadToTheGold/p/8443668.html \(Description\) 给你一个擦去了部分左括号和全部右括号的括号序列, ...

  6. Windows 7 MBR的修复与Linux产品正确卸载

    这几天折腾系统很令人崩溃,但也明白了开机引导流程具体如何. 觉得Centos 7不好用,想卸载Redhat安装Ubuntu,为了图方便直接把红帽的硬盘区格式化了.于是开机引导崩溃,咨询了下大神,大神叫 ...

  7. ROS知识(23)——行为树Behavio Tree原理

    机器人的复杂行为的控制结构CA(Contrl Architecture)通常使用有限状态机来实现,例如ROS提供的smach.行为树是另外一种实现机器人控制的方法,ROS下代表的开源库有pi_tree ...

  8. @Transactional导致AbstractRoutingDataSource动态数据源无法切换的解决办法

    上午花了大半天排查一个多数据源主从切换的问题,记录一下: 背景: 项目的数据库采用了读写分离多数据源,采用AOP进行拦截,利用ThreadLocal及AbstractRoutingDataSource ...

  9. [坑] treap

    先来挖个坑,以后有时间了来补上. treap: 学习资料: fhq式treap    http://hi.baidu.com/wdxertqdtscnwze/item/7b6a9419be7c68cd ...

  10. LightOJ 1074 - Extended Traffic (SPFA)

    http://lightoj.com/volume_showproblem.php?problem=1074 1074 - Extended Traffic   PDF (English) Stati ...