[转]tensorflow中的gather
原文链接
tensorflow中取下标的函数包括:tf.gather , tf.gather_nd 和 tf.batch_gather。
1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)
indices必须是一维张量
主要参数:
- params:被索引的张量
- indices:一维索引张量
- name:返回张量名称
返回值:通过indices获取params下标的张量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather(tensor_a,tensor_b)))
print(sess.run(tf.gather(tensor_a,tensor_c)))
上个例子tf.gather(tensor_a,tensor_b) 的值为[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值为[[1,2,3],[1,2,3]]
对于tensor_a,其第1个元素为[4,5,6],第2个元素为[7,8,9],第0个元素为[1,2,3],所以以[1,2,0]为索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同样的,以[0,0]为索引的值为[[1,2,3],[1,2,3]]。
https://www.tensorflow.org/api_docs/python/tf/gather
2.tf.gather_nd(params,indices,name=None)
功能和参数与tf.gather类似,不同之处在于tf.gather_nd支持多维度索引,即indices可以使多维张量。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值为[3,7].
对于tensor_a,下标[1,0]的元素为4,下标为[1,1]的元素为5,下标为[1,2]的元素为6,索引[1,0],[1,1],[1,2]]的返回值为[4,5,6],同样的,索引[[0,2],[2,0]]的返回值为[3,7].
https://www.tensorflow.org/api_docs/python/tf/gather_nd
3.tf.batch_gather(params,indices,name=None)
支持对张量的批量索引,各参数意义见(1)中描述。注意因为是批处理,所以indices要有和params相同的第0个维度。
例子:
import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值为[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值为[1,4,7].
tensor_a的三个元素[1,2,3],[4,5,6],[7,8,9]分别对应索引元素的第一,第二和第三个值。[1,2,3]的第0个元素为1,[4,5,6]的第1个元素为5,[7,8,9]的第2个元素为9,所以索引[[0],[1],[2]]的返回值为[1,5,9],同样地,索引[[0],[0],[0]]的返回值为[1,4,7].
https://www.tensorflow.org/api_docs/python/tf/batch_gather
在深度学习的模型训练中,有时候需要对一个batch的数据进行类似于tf.gather_nd的操作,但tensorflow中并没有tf.batch_gather_nd之类的操作,此时需要tf.map_fn和tf.gather_nd结合来实现上述操作。
[转]tensorflow中的gather的更多相关文章
- Tensorflow中的padding操作
转载请注明出处:http://www.cnblogs.com/willnote/p/6746668.html 图示说明 用一个3x3的网格在一个28x28的图像上做切片并移动 移动到边缘上的时候,如果 ...
- CNN中的卷积核及TensorFlow中卷积的各种实现
声明: 1. 我和每一个应该看这篇博文的人一样,都是初学者,都是小菜鸟,我发布博文只是希望加深学习印象并与大家讨论. 2. 我不确定的地方用了"应该"二字 首先,通俗说一下,CNN ...
- python/numpy/tensorflow中,对矩阵行列操作,下标是怎么回事儿?
Python中的list/tuple,numpy中的ndarrray与tensorflow中的tensor. 用python中list/tuple理解,仅仅是从内存角度理解一个序列数据,而非数学中标量 ...
- [翻译] Tensorflow中name scope和variable scope的区别是什么
翻译自:https://stackoverflow.com/questions/35919020/whats-the-difference-of-name-scope-and-a-variable-s ...
- SSD:TensorFlow中的单次多重检测器
SSD:TensorFlow中的单次多重检测器 SSD Notebook 包含 SSD TensorFlow 的最小示例. 很快,就检测出了两个主要步骤:在图像上运行SSD网络,并使用通用算法(top ...
- 在 TensorFlow 中实现文本分类的卷积神经网络
在TensorFlow中实现文本分类的卷积神经网络 Github提供了完整的代码: https://github.com/dennybritz/cnn-text-classification-tf 在 ...
- [开发技巧]·TensorFlow中numpy与tensor数据相互转化
[开发技巧]·TensorFlow中numpy与tensor数据相互转化 个人主页–> https://xiaosongshine.github.io/ - 问题描述 在我们使用TensorFl ...
- TensorFlow中的变量和常量
1.TensorFlow中的变量和常量介绍 TensorFlow中的变量: import tensorflow as tf state = tf.Variable(0,name='counter') ...
- TensorFlow中的通信机制——Rendezvous(二)gRPC传输
背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...
随机推荐
- 【LeetCode算法-13】Roman to Integer
LeetCode第13题 Roman numerals are represented by seven different symbols: I, V, X, L, C, D and M. Symb ...
- Painting the Fence Gym - 101911E(构造)
There is a beautiful fence near Monocarp's house. The fence consists of nn planks numbered from left ...
- POJ2387 Til the Cows Come Home【Kruscal】
题目链接>>> 题目大意: 谷仓之间有一些路径长度,然后要在这些谷仓之间建立一些互联网,花费的成本与长度成正比,,并且要使这些边连起来看的像一课“树”,然后使成本最大 解题思路: 最 ...
- C# DGVPrinter.cs 打印方法
Code highlighting produced by Actipro CodeHighlighter (freeware)http://www.CodeHighlighter.com/--> ...
- seq2seq升级TF1.5后_Linear报错
解决TF升级到1.5之后seq2seq.py出现的引用报错: 1.4时候使用rnn_cell_impl的_Linear没有问题的,TF升级到1.5之后这一行就运行不过去了,查到的方法是引用core_r ...
- chrome插件的开发
基本目录:icon,manifest,html,js. chrome插件的使用,运行,打包. chrome浏览器打开扩展,勾选开发者模式,点击加载没打包的扩展,选中目录,加载插件. 右上角出现插件图标 ...
- 实现winfrom进度条及进度信息提示
1.方法一:使用线程 功能描述:在用c#做WinFrom开发的过程中.我们经常需要用到进度条(ProgressBar)用于显示进度信息.这时候我们可能就需要用到多线程,如果不采用多线程控制进度条,窗口 ...
- Spring使用笔记(二)Bean装配
Bean装配 Spring提供了3种装配机制: 1)隐式的Bean发现机制和自动装配 2)在Java中进行显示装配 3)在XML中进行显示装配 一)自动化装配 1.指定某类为组件类: @Compone ...
- CentOS 7 安装MongoDB详细步骤
创建/etc/yum.repos.d/mongodb-org-4.0.repo文件,编辑内容如下: [mongodb-org-4.0] name=MongoDB Repository baseurl= ...
- 狄克斯特拉算法(Python实现)
概述 狄克斯特拉算法--用于在加权图中找到最短路径 ps: 广度优先搜索--用于解决非加权图的最短路径问题 存在负权边时--贝尔曼-福德算法 下面是来自维基百科的权威解释. 戴克斯特拉算法(英语:Di ...