训练一个分类网络,没想到预测结果为一个定值。

找了很久发现,是因为tensor的维度的原因。  注意:我说的是我的label数据的维度。

我的输入是:

y_= tf.placeholder(tf.int32,[None,1])  #维度:(batchsize,1)

我使用的损失函数:

loss = -y_*log(pred)
pred = tf.softmax(wx+b) #维度:(batch_size,10034)

所以我需要将y_的 维度转化为(batch_size,10034)

我使用的是

y__ = tf.one_hot(y,10034)   #维度是:(batch_size,1,10034),而不是我们的预期:(batch_size,10034)

显然这时有问题的,所以才会在坑中转了好久。

tf.one_hot()的输入数据为一维数组。

正确方法:

y1 = tf.reshape(y_,[-1])     # 变成一维数组(batch_size,)
y__ = tf.one_hot(y1,10034) # (batch_size,10034)
loss = tf.reduce_mean(-tf.reduce_sum(y__*log(pred),reduction_indices=[1]))

tensorflow 训练最后预测结果为一个定值,可能的原因的更多相关文章

  1. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  2. tensorflow数据加载、模型训练及预测

    数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...

  3. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow训练神经网络cost一直为0

    问题描述 这几天在用TensorFlow搭建一个神经网络来做一个binary classifier,搭建一个典型的神经网络的基本思路是: 定义神经网络的layers(层)以及初始化每一层的参数 然后迭 ...

  6. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

  9. 使用TensorFlow训练自己的语音识别AI

    这次来训练一个基于CNN的语音识别模型.训练完成后,我们将尝试将此模型用于Hotword detection. 人类是怎样听懂一句话的呢?以汉语为例,当听到"wo shi"的录音时 ...

随机推荐

  1. HTML-参考手册: HTML ISO-8859-1

    ylbtech-HTML-参考手册: HTML ISO-8859-1 1.返回顶部 1. HTML ISO-8859-1 参考手册 现代的浏览器支持的字符集: ASCII 字符集 标准 ISO 字符集 ...

  2. js关闭当前页面清除session

    js关闭当前页面清除session 普通页面 <!DOCTYPE html> <html> <head> <meta charset="UTF-8& ...

  3. Tomcat 配置安装

    1 下载和安装Tomcat服务器 Tomcat官方站点:http://jakarta.apache.org 下载Tomcat安装程序包:http://tomcat.apache.org/ 启动和测试T ...

  4. camunda任务的一些简单操作

    public class ZccTaskService { TaskService taskService; @Before public void init(){ ProcessEngineConf ...

  5. Spring Cloud服务安全连接

    Spring Cloud可以增加HTTP Basic认证来增加服务连接的安全性. 1.加入security启动器 在maven配置文件中加入Spring Boot的security启动器. <d ...

  6. Msys2升级后不能编译

    Msys2升级后不能编译 */--> code {color: #FF0000} pre.src {background-color: #002b36; color: #839496;} cod ...

  7. CentOS6.5源码安装MySQL5.6.35

    CentOS6.5源码安装MySQL5.6.35 一.卸载旧版本 1.使用下面的命令检查是否安装有mysql [root@localhost tools]# rpm -qa|grep -i mysql ...

  8. CF1228F

    写了一个特别麻烦的做法 首先一共有三种情况:1.删掉一个叶子,2.删掉根的一个儿子,3.其他的节点 第一种情况会有两个度数为2的节点,第二种情况没有度数为2的节点,第三种情况会有一个度数为4的节点 然 ...

  9. nodejs express的基本用法

    demo /** * Created by ZXW on 2017/11/6. */ var express=require("express"); var server=expr ...

  10. github fork代码后提交

    点击他人github上的fork 在自己的Github上将代码拷贝下来 git clone 在本地修改代码后创建分支 git checkout -b work master(work为新建的特性分支, ...