原文链接

tensorflow中取下标的函数包括:tf.gather , tf.gather_nd 和 tf.batch_gather。

1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)

indices必须是一维张量

主要参数:

  • params:被索引的张量
  • indices:一维索引张量
  • name:返回张量名称

返回值:通过indices获取params下标的张量。

例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(tensor_a,tensor_b)))
print(sess.run(tf.gather(tensor_a,tensor_c)))

上个例子tf.gather(tensor_a,tensor_b) 的值为[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值为[[1,2,3],[1,2,3]]

对于tensor_a,其第1个元素为[4,5,6],第2个元素为[7,8,9],第0个元素为[1,2,3],所以以[1,2,0]为索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同样的,以[0,0]为索引的值为[[1,2,3],[1,2,3]]。

https://www.tensorflow.org/api_docs/python/tf/gather

2.tf.gather_nd(params,indices,name=None)

功能和参数与tf.gather类似,不同之处在于tf.gather_nd支持多维度索引,即indices可以使多维张量。

例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值为[3,7].

对于tensor_a,下标[1,0]的元素为4,下标为[1,1]的元素为5,下标为[1,2]的元素为6,索引[1,0],[1,1],[1,2]]的返回值为[4,5,6],同样的,索引[[0,2],[2,0]]的返回值为[3,7].

https://www.tensorflow.org/api_docs/python/tf/gather_nd

3.tf.batch_gather(params,indices,name=None)

支持对张量的批量索引,各参数意义见(1)中描述。注意因为是批处理,所以indices要有和params相同的第0个维度。

例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值为[1,4,7].

tensor_a的三个元素[1,2,3],[4,5,6],[7,8,9]分别对应索引元素的第一,第二和第三个值。[1,2,3]的第0个元素为1,[4,5,6]的第1个元素为5,[7,8,9]的第2个元素为9,所以索引[[0],[1],[2]]的返回值为[1,5,9],同样地,索引[[0],[0],[0]]的返回值为[1,4,7].

https://www.tensorflow.org/api_docs/python/tf/batch_gather

在深度学习的模型训练中,有时候需要对一个batch的数据进行类似于tf.gather_nd的操作,但tensorflow中并没有tf.batch_gather_nd之类的操作,此时需要tf.map_fn和tf.gather_nd结合来实现上述操作。

[转]tensorflow中的gather的更多相关文章

  1. Tensorflow中的padding操作

    转载请注明出处:http://www.cnblogs.com/willnote/p/6746668.html 图示说明 用一个3x3的网格在一个28x28的图像上做切片并移动 移动到边缘上的时候,如果 ...

  2. CNN中的卷积核及TensorFlow中卷积的各种实现

    声明: 1. 我和每一个应该看这篇博文的人一样,都是初学者,都是小菜鸟,我发布博文只是希望加深学习印象并与大家讨论. 2. 我不确定的地方用了"应该"二字 首先,通俗说一下,CNN ...

  3. python/numpy/tensorflow中,对矩阵行列操作,下标是怎么回事儿?

    Python中的list/tuple,numpy中的ndarrray与tensorflow中的tensor. 用python中list/tuple理解,仅仅是从内存角度理解一个序列数据,而非数学中标量 ...

  4. [翻译] Tensorflow中name scope和variable scope的区别是什么

    翻译自:https://stackoverflow.com/questions/35919020/whats-the-difference-of-name-scope-and-a-variable-s ...

  5. SSD:TensorFlow中的单次多重检测器

    SSD:TensorFlow中的单次多重检测器 SSD Notebook 包含 SSD TensorFlow 的最小示例. 很快,就检测出了两个主要步骤:在图像上运行SSD网络,并使用通用算法(top ...

  6. 在 TensorFlow 中实现文本分类的卷积神经网络

    在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...

  7. [开发技巧]·TensorFlow中numpy与tensor数据相互转化

    [开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...

  8. TensorFlow中的变量和常量

    1.TensorFlow中的变量和常量介绍 TensorFlow中的变量: import tensorflow as tf state = tf.Variable(0,name='counter') ...

  9. TensorFlow中的通信机制——Rendezvous(二)gRPC传输

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...

随机推荐

  1. network is unreachable 网关PING不通解决办法

    所里有几台机器没有办法ping通网关,但是ping交换机里的其它机器都可以ping通,其它机器ping网关也可以ping通.那么就排出了硬件的故障,主要问题就在问题机器的路由表上了. 看一下路由表 r ...

  2. hdu 3078 Network (暴力)+【LCA】

    <题目链接> 题目大意:给定一颗带点权的树,进行两种操作,k=0,更改某一点的点权,k!=0,输出a~b路径之间权值第k大的点的点权. 解题分析:先通过RMQ的初始化,预处理pre[]数组 ...

  3. 【Spring Boot】构造、访问Restful Webservice与定时任务

    Spring Boot Guides Examples(1~3) 参考网址:https://spring.io/guides 创建一个RESTful Web Service 使用Eclipse 创建一 ...

  4. mybatis查询语句的背后之封装数据

    转载请注明出处... 一.前言 继上一篇mybatis查询语句的背后,这一篇主要围绕着mybatis查询的后期操作,即跟数据库交互的时候.由于本人也是一边学习源码一边记录,内容难免有错误或不足之处,还 ...

  5. shell && and ||

    2013-04-08 17:40:47   shell中&&和||的使用方法 &&运算符:   command1  && command2   & ...

  6. 基于Ardalis.GuardClauses守卫组件的拓展

    在我们写程序的时候,经常会需要判断数据的是空值还是null值,基本上十个方法函数,八个要做这样的判断,因此我们很有必要拓展出来一个类来做监控,在这里我们使用一个简单地,可拓展的第三方组件:Ardali ...

  7. Alpha(5/10)

    鐵鍋燉腯鱻 项目:小鱼记账 团队成员 项目燃尽图 冲刺情况描述 站立式会议照片 各成员情况 团队成员 学号 姓名 git地址 博客地址 031602240 许郁杨 (组长) https://githu ...

  8. Stm32常见英文缩写

    Stm32常见英文缩写 https://wenku.baidu.com/view/4b9c2eee5022aaea998f0f5b.html STM32嵌入式开发常见缩写 https://wenku. ...

  9. BZOJ2512 : Groc

    最优解一定是将起点.终点以及所有必经点连接成一棵树,对于每条树边恰好走两次,而从起点到终点的一条路径只走一次. 考虑连通性DP,设$f[i][j][k][x]$表示考虑完前$i$个走道,第$i$个走道 ...

  10. jQueryUI中Datepicker(日历)插件使用

    atepicker插件的属性: 属性 数据类型 默认值 说明 altField string "" 使用备用的输出字段,即将选择的日期 以另一种格式,输出到另一个控件中, 值为选择 ...