[转]tensorflow中的gather
原文链接
tensorflow中取下标的函数包括:tf.gather , tf.gather_nd 和 tf.batch_gather。
1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)
indices必须是一维张量
主要参数:
- params:被索引的张量
- indices:一维索引张量
- name:返回张量名称
返回值:通过indices获取params下标的张量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(tensor_a,tensor_b)))
print(sess.run(tf.gather(tensor_a,tensor_c)))
上个例子tf.gather(tensor_a,tensor_b) 的值为[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值为[[1,2,3],[1,2,3]]
对于tensor_a,其第1个元素为[4,5,6],第2个元素为[7,8,9],第0个元素为[1,2,3],所以以[1,2,0]为索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同样的,以[0,0]为索引的值为[[1,2,3],[1,2,3]]。
https://www.tensorflow.org/api_docs/python/tf/gather
2.tf.gather_nd(params,indices,name=None)
功能和参数与tf.gather类似,不同之处在于tf.gather_nd支持多维度索引,即indices可以使多维张量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值为[3,7].
对于tensor_a,下标[1,0]的元素为4,下标为[1,1]的元素为5,下标为[1,2]的元素为6,索引[1,0],[1,1],[1,2]]的返回值为[4,5,6],同样的,索引[[0,2],[2,0]]的返回值为[3,7].
https://www.tensorflow.org/api_docs/python/tf/gather_nd
3.tf.batch_gather(params,indices,name=None)
支持对张量的批量索引,各参数意义见(1)中描述。注意因为是批处理,所以indices要有和params相同的第0个维度。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值为[1,4,7].
tensor_a的三个元素[1,2,3],[4,5,6],[7,8,9]分别对应索引元素的第一,第二和第三个值。[1,2,3]的第0个元素为1,[4,5,6]的第1个元素为5,[7,8,9]的第2个元素为9,所以索引[[0],[1],[2]]的返回值为[1,5,9],同样地,索引[[0],[0],[0]]的返回值为[1,4,7].
https://www.tensorflow.org/api_docs/python/tf/batch_gather
在深度学习的模型训练中,有时候需要对一个batch的数据进行类似于tf.gather_nd的操作,但tensorflow中并没有tf.batch_gather_nd之类的操作,此时需要tf.map_fn和tf.gather_nd结合来实现上述操作。
[转]tensorflow中的gather的更多相关文章
- Tensorflow中的padding操作
转载请注明出处:http://www.cnblogs.com/willnote/p/6746668.html 图示说明 用一个3x3的网格在一个28x28的图像上做切片并移动 移动到边缘上的时候,如果 ...
- CNN中的卷积核及TensorFlow中卷积的各种实现
声明: 1. 我和每一个应该看这篇博文的人一样,都是初学者,都是小菜鸟,我发布博文只是希望加深学习印象并与大家讨论. 2. 我不确定的地方用了"应该"二字 首先,通俗说一下,CNN ...
- python/numpy/tensorflow中,对矩阵行列操作,下标是怎么回事儿?
Python中的list/tuple,numpy中的ndarrray与tensorflow中的tensor. 用python中list/tuple理解,仅仅是从内存角度理解一个序列数据,而非数学中标量 ...
- [翻译] Tensorflow中name scope和variable scope的区别是什么
翻译自:https://stackoverflow.com/questions/35919020/whats-the-difference-of-name-scope-and-a-variable-s ...
- SSD:TensorFlow中的单次多重检测器
SSD:TensorFlow中的单次多重检测器 SSD Notebook 包含 SSD TensorFlow 的最小示例. 很快,就检测出了两个主要步骤:在图像上运行SSD网络,并使用通用算法(top ...
- 在 TensorFlow 中实现文本分类的卷积神经网络
在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...
- [开发技巧]·TensorFlow中numpy与tensor数据相互转化
[开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...
- TensorFlow中的变量和常量
1.TensorFlow中的变量和常量介绍 TensorFlow中的变量: import tensorflow as tf state = tf.Variable(0,name='counter') ...
- TensorFlow中的通信机制——Rendezvous(二)gRPC传输
背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...
随机推荐
- boost 正则表达式 regex
boost 正则表达式 regex 环境安装 如果在引用boost regex出现连接错误,但是引用其他的库却没有这个错误,这是因为对于boost来说,是免编译的,但是,正则这个库 是需要单独编译 ...
- HDU 4635 Strongly connected (强连通分量+缩点)
<题目链接> 题目大意: 给你一张有向图,问在保证该图不能成为强连通图的条件下,最多能够添加几条有向边. 解题分析: 我们从反面思考,在该图是一张有向完全图的情况下,最少删去几条边能够使其 ...
- (Zero XOR Subset)-less-线性基
(Zero XOR Subset)-less 题意 :把n个数分成多个集合,要求 不能有集合为空,最终不能有非空子集合异或值为0,尽可能划分的多一些. 思路 :非法情况就只有 n个数异或 为0,其他的 ...
- Socket 网络通信
Socket 网络通信 1.OSI (Open System Interconnect Reference Model)(开放系统互联参考模型) 从下低到高 :物理层.数据链路层.网络层.传输层.会话 ...
- python 元类的简单解释
本文转自博客:http://www.cnblogs.com/piperck/p/5840443.html 作者:piperck python 类和元类(metaclass)的理解和简单运用 (一) p ...
- vue中的jsx
一.配置文件package.json { "name": "vuetest", "version": "1.0.0", ...
- Redis自学笔记:3.2入门-字符串类型
3.2字符串类型 实际上redis不只是数据库,更多的公司和团队将redis用作缓存和 队列系统 3.2.1介绍 字符串类型是redis最基本的数据类型,它能存储任何形式的字符串, 包括二进制数据.你 ...
- BZOJ.3532.[SDOI2014]LIS(最小割ISAP 退流)
BZOJ 洛谷 \(LIS\)..经典模型? 令\(f_i\)表示以\(i\)结尾的\(LIS\)长度. 如果\(f_i=1\),连边\((S,i,INF)\):如果\(f_i=\max\limits ...
- python基础一 ------如何统计一个列表元素的频度
如何统计一个列表元素的频度 两个需求: 1,统计一个随机序列[1,2,3,4,5,6...]中的出现次数前三的元素及其次数 2,统计一片英文文章中出现次数前10 的单词 两种方法: 1,普通的for循 ...
- Windows下bat脚本自动发邮件
摘要:说明:很多木马会利用自身的程序截取系统敏感文件或信息发往指定的邮箱,而blat并不是木马,它小巧却有强大的发邮件功能,可不要用它做违法事,感觉和木马功能有一拼!下面先看个具体的实例(在blat同 ...