tf.one_hot(indices, depth):将目标序列转换成one_hot编码

tf.one_hot
(indices, depth, on_value=None, off_value=None, 
axis=None, dtype=None, name=None)

indices = [0, 2, -1, 1]
depth = 3
on_value = 5.0 
off_value = 0.0 
axis = -1 
#Then output is [4 x 3]: 
output = 
[5.0 0.0 0.0] // one_hot(0) 
[0.0 0.0 5.0] // one_hot(2) 
[0.0 0.0 0.0] // one_hot(-1) 
[0.0 5.0 0.0] // one_hot(1)

with tf.Session() as sess:
print(sess.run(tf.one_hot(np.array([np.array([0,1,2,3]),np.array([2,0,3,2])]),depth=4,axis=-1))) # [[[ 1. 0. 0. 0.]
# [ 0. 1. 0. 0.]
# [ 0. 0. 1. 0.]
# [ 0. 0. 0. 1.]]
# [[ 0. 0. 1. 0.]
# [ 1. 0. 0. 0.]
# [ 0. 0. 0. 1.]
# [ 0. 0. 1. 0.]]] oh = tf.one_hot(indices = [0, 2, -1, 1], depth = 3, on_value = 5.0 , off_value = 0.0, axis = -1)
sess = tf.Session()
sess.run(oh) # array([[5., 0., 0.],
# [0., 0., 5.],
# [0., 0., 0.],
# [0., 5., 0.]], dtype=float32)

另一种思路:稀疏张量构建法

import numpy as np
import tensorflow as tf NUMCLASS = 3
batch_size = 5 labels = tf.placeholder(dtype=tf.int32, shape=[batch_size, 1])
index = tf.reshape(tf.range(0, batch_size,1), [batch_size, 1])
one_hot = tf.sparse_to_dense(
tf.concat(values=[index, labels], axis=1),
[batch_size, NUMCLASS],
1.0, 0.0
)
with tf.Session() as sess:
lab = np.random.randint(0,3,[5,1])
print(sess.run(one_hot, feed_dict={labels:lab}))
print(sess.run(tf.one_hot(np.squeeze(lab),depth=3,axis=1)))

注意两种方法输入数据维度的变化(稀疏法为了得到足够的索引需要升维),结果如下:

[[ 1.  0.  0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]
[ 0. 1. 0.]]
[[ 1. 0. 0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]
[ 0. 1. 0.]]

『TensorFlow』one_hot化标签的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  3. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  4. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  5. 『TensorFlow』分布式训练_其三_多机分布式

    本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...

  6. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  7. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  8. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  9. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

随机推荐

  1. Mock Server 入门(一)

    Mock Server 使用场景 1.开发过程中依赖一些接口,而这些接口可能有一下情况: 1)接口搭建环境比较困难:例如支付宝的支付接口,需要授权等等准备好才能进行调试 2)接口暂时还未实现时:可以便 ...

  2. UI自动化框架——构建思维

    目的:从Excel中获取列的值,传输到页面 技巧:尽可能的提高方法的重用率 Java包: 1.java.core包 3个类:1)日志(LogEventListener)扩展web driver自带的事 ...

  3. 使用js对WGS-84 ,GCJ-02与BD-09的坐标进行转换

    获取到经纬度在用百度地图进行定位时,却发现行驶轨迹的路线定到海里面去了.从网上查阅,知道此方法. 出处:https://www.jianshu.com/p/53f00ba897f7 一.在进行地图开发 ...

  4. Redis入门到高可用(十七)—— 持久化开发运维常见问题

    1.fork操作 2.子进程开销和优化 3.AOF阻塞

  5. CentOS 7 yum安装zabbix 设置中文界面

    1.  配置安装前环境 2.  安装zabbix 3.  设置中文环境 准备搭建环境 : 系统:CentOS7.5 首先关闭SElinux 和防火墙 安装MariaDB数据库 [root@DaMoWa ...

  6. JSP 修改不能编辑

    JSP做修改功能时候,有的时候,某些值要设置成只读状态,不能修改,刚开始做的时候,出现了修改之后值传不到后台的情况,由于刚出来工作不久,不是很了解这个.思索了半天,才发现是由于这个属性的缘故.浪费了大 ...

  7. restful规范快速记忆

    restful规范: 十个规则: 用户发来请求,url必须: 1.因为是面向资源编程,所以每个URL代表一种资源,URL中尽量不要用动词,要用名词 2.尽量使用HTTPS,https代替http 3. ...

  8. 记录 用tiny6410 j-link eclipse 在线调试裸机代码leds

    1.nand flash烧写uboot 并且启动nandflash uboot,用来初始化6410,进入uboot命令行界面 2.在terminal中输入JLinkGDBServer -device ...

  9. jqueryd的post传递表单以及取消表单的默认传递

    //取消表单的默认传递: <form method="post" onsubmit="return false;"> 在FORM属性里添加 onsu ...

  10. 进程间通信之信号量、消息队列、共享内存(system v的shm和mmap)+信号signal

    进程间通信方式有:System v unix提供3种进程间通信IPC:信号量.消息队列.共享内存.此外,传统方法:信号.管道.socket套接字. [注意上述6种方式只能用户层进程间通信.内核内部有类 ...