训练一个分类网络,没想到预测结果为一个定值。

找了很久发现,是因为tensor的维度的原因。  注意:我说的是我的label数据的维度。

我的输入是:

y_= tf.placeholder(tf.int32,[None,1])  #维度:(batchsize,1)

我使用的损失函数:

loss = -y_*log(pred)
pred = tf.softmax(wx+b) #维度:(batch_size,10034)

所以我需要将y_的 维度转化为(batch_size,10034)

我使用的是

y__ = tf.one_hot(y,10034)   #维度是:(batch_size,1,10034),而不是我们的预期:(batch_size,10034)

显然这时有问题的,所以才会在坑中转了好久。

tf.one_hot()的输入数据为一维数组。

正确方法:

y1 = tf.reshape(y_,[-1])     # 变成一维数组(batch_size,)
y__ = tf.one_hot(y1,10034) # (batch_size,10034)
loss = tf.reduce_mean(-tf.reduce_sum(y__*log(pred),reduction_indices=[1]))

tensorflow 训练最后预测结果为一个定值,可能的原因的更多相关文章

  1. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  2. tensorflow数据加载、模型训练及预测

    数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...

  3. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow训练神经网络cost一直为0

    问题描述 这几天在用TensorFlow搭建一个神经网络来做一个binary classifier,搭建一个典型的神经网络的基本思路是: 定义神经网络的layers(层)以及初始化每一层的参数 然后迭 ...

  6. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

  9. 使用TensorFlow训练自己的语音识别AI

    这次来训练一个基于CNN的语音识别模型.训练完成后,我们将尝试将此模型用于Hotword detection. 人类是怎样听懂一句话的呢?以汉语为例,当听到"wo shi"的录音时 ...

随机推荐

  1. VTemplate模板引擎的使用--高级篇

    VTemplate模板引擎的使用--高级篇 在网站中,经常会有某个栏目的数据在多个页面同时使用到.比如新闻网站或电子商务网站的栏目列表,几乎在很多页面都会显示栏目导航.对于这种多个页面同时使用到的“数 ...

  2. 25. SPI

  3. Java中创建泛型数组

    Java中创建泛型数组 使用泛型时,我想很多人肯定尝试过如下的代码,去创建一个泛型数组 T[] array = new T[]; 当我们写出这样的代码时编译器会报Cannot create a gen ...

  4. python之正则表达式【re】

    在处理字符串时,经常会有查找符合某些规则的字符串的需求.正则表达式就是用于藐视这些规则的工具.换句话说,正则表达式是记录文本规则的代码. 1.行定位符. 行定位符就是用来表示字符串的边界,“^”表示开 ...

  5. python获取网页源代码

    最简单的网页取源(不用模拟浏览器的情况) import requests def getHTML(url): try: r = requests.get(url,timeout=30) r.raise ...

  6. ASE团队项目alpha阶段Frontend组 scrum2 记录

    ASE团队项目alpha阶段Frontend组 scrum2 记录 本次会议于11.5日, 11:30在微软北京西二楼13158研讨室,讨论持续15分钟 与会人员:Jingyi Xie, Jiaqi ...

  7. linux 7 安装KVM

    首先,在安装GUI的linux 7系统下,安装KVM 执行命令 #yum install qemu-kvm qemu-kvm-tools virt-manager libvirt virt-insta ...

  8. caffer的三种文件类别

    solver文件 是一堆超参数,比如迭代次数,是否用GPU,多少次迭代暂存一次训练所得参数,动量项,权重衰减(即正则化参数),基本的learning rate,多少次迭代打印一次loss,以及网络结构 ...

  9. 【sql】牛客网练习题 (共 61 题)

    [1]查找最晚入职员工的所有信息 CREATE TABLE `employees` ( `emp_no` ) NOT NULL, `birth_date` date NOT NULL, `first_ ...

  10. day10 python算法 冒泡算法 二分法 最快查找算法 c3算法

    day10 python       1.冒泡算法         冒泡排序,把列表竖起来看,就像一个个气泡往上去(时间复杂度大) lst = [12,3,3,2424,14,3567,534,324 ...