训练一个分类网络,没想到预测结果为一个定值。

找了很久发现,是因为tensor的维度的原因。  注意:我说的是我的label数据的维度。

我的输入是:

y_= tf.placeholder(tf.int32,[None,1])  #维度:(batchsize,1)

我使用的损失函数:

loss = -y_*log(pred)
pred = tf.softmax(wx+b) #维度:(batch_size,10034)

所以我需要将y_的 维度转化为(batch_size,10034)

我使用的是

y__ = tf.one_hot(y,10034)   #维度是:(batch_size,1,10034),而不是我们的预期:(batch_size,10034)

显然这时有问题的,所以才会在坑中转了好久。

tf.one_hot()的输入数据为一维数组。

正确方法:

y1 = tf.reshape(y_,[-1])     # 变成一维数组(batch_size,)
y__ = tf.one_hot(y1,10034) # (batch_size,10034)
loss = tf.reduce_mean(-tf.reduce_sum(y__*log(pred),reduction_indices=[1]))

tensorflow 训练最后预测结果为一个定值,可能的原因的更多相关文章

  1. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  2. tensorflow数据加载、模型训练及预测

    数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...

  3. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow训练神经网络cost一直为0

    问题描述 这几天在用TensorFlow搭建一个神经网络来做一个binary classifier,搭建一个典型的神经网络的基本思路是: 定义神经网络的layers(层)以及初始化每一层的参数 然后迭 ...

  6. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

  9. 使用TensorFlow训练自己的语音识别AI

    这次来训练一个基于CNN的语音识别模型.训练完成后,我们将尝试将此模型用于Hotword detection. 人类是怎样听懂一句话的呢?以汉语为例,当听到"wo shi"的录音时 ...

随机推荐

  1. 51、tf-idf值提取关键词

    import testWord2vec2 as tw import tensorflow_util as tu import numpy as np model = tw.load_model() n ...

  2. 7、c++版,在大学学的编程基础知识

    1.各种排序 #include<iostream> using namespace std; //-------直接插入排序 void InsertSort(ElemType A[],in ...

  3. Spring Data JPA one to one 共享主键关联

    /** * Created by xiezhiyan on 17-9-13. */@Entitypublic class Token { @Id @Column(name = "store_ ...

  4. Tomcat启动脚本(2)catalina.bat

    @echo off rem Licensed to the Apache Software Foundation (ASF) under one or more rem contributor lic ...

  5. multiple-cursors实在是太好用了

    multiple-cursors实在是太好用了 */--> code {color: #FF0000} pre.src {background-color: #002b36; color: #8 ...

  6. 【Java学习笔记之一】 java关键字及作用

    Java关键字及其作用 一. 总览: 访问控制 private protected public 类,方法和变量修饰符 abstract class extends final implements ...

  7. LeetCode Array Easy 88. Merge Sorted Array

    Description Given two sorted integer arrays nums1 and nums2, merge nums2 into nums1 as one sorted ar ...

  8. STL_set

    #include <iostream> #include <set> #include <string> #include <cstdio> using ...

  9. express 路由能力

    demo var express=require("express"); var app=express(); app.get("/",function(req ...

  10. delphi和C# 保存exe文件到数据库

    Delphi: procedure TForm1.Button1Click(Sender: TObject); var strSQL, sfilename: string; MStream: TMem ...