目录

  1. 什么是形状推断
  2. InferenceContext
  3. 关系图
  4. 涉及的文件
  5. 迭代记录

1. 什么是形状推断

前面我们讲到op的时候,提到了操作的注册器OpRegistry,并且提到,其中注册的数据是一个结构OpRegistrationData,这个结构中除了OpDef之外,还包含了一个OpShapeInferenceFn,这个数据是做什么用的呢?

我们知道,op只是定义了操作的输入输出和参数,但并没有定义操作具体的输入形状,举个例子,MatMul操作,代表矩阵乘法,这只是一个抽象的表示,没有具体说,这个矩阵乘法代表的是[2,3]x[3,4]=[2,4],还是[100,200]x[200,300]=[100,300]。所以在实际应用中,在得到输入之前,输出的真实形状是无法预知的,但在得到输入之后,我们必须能够根据输入的形状,以及当前op的作用,判断输出的具体形状,才能给它申请对应大小的内存空间。所以,我们需要为每一个操作,配备一个形状推断的函数,这就是形状推断的由来。

2. InferenceContext

前面提到了OpShapeInferenceFn,我们来看一下它的详细定义:

typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;

可见,OpShapeInferenceFn是一个接收InferenceContext参数的函数,TF为所有op的形状推断函数,准备了这样一个统一的接口。所有跟形状推断相关的数据和功能函数,都放在InferenceContext这个类的内部。回想一下前面讲过的OpKernelContext,其实它们的功能很像。OpKernelContext是作为OpKernel的核心API Compute函数的参数,所有计算相关的参数都会包含在这个对象中。InferenceContext也是一样,我们把所有跟形状推断相关的数据和功能函数封装在一个InferenceContext对象中,然后把这个对象传递给OpShapeInferenceFn,就可以实现形状推断。这种设计实现了数据部分和实现逻辑的解耦。

在具体看ShapeInference类之前,我们先要看一些辅助类:

class Dimension {
private:
//...
const int64 value_;
};
class DimensionHandle {
private:
//...
const Dimension* ptr_ = nullptr;
};
class Shape {
//...
private:
const int32 rank_;
const std::vector<DimensionHandle> dims_;
};
class ShapeHandle {
//...
private:
const Shape* ptr = nullptr;
};
class DimensionOrConstant {
public:
//...
DimensionHandle dim;
int64 val;
};
class ShapeAndType {
ShapeHandle shape;
DataType dtype = DT_INVALID;
};

这几个类都比较简单。在下面用到时能够认得就好了。

下面我们看下InferenceContext这个类:

class InferenceContext {
public:
InferenceContext(int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types);//构造函数
Status Run(const std::function<Status(shape_inference::InferenceContext* c)>& fn);//运行一个以this为参数的函数,没错,这里运行的就是OpShapeInferenceFn
bool MergeInput(int idx, ShapeHandle shape);
bool RelaxInput(int idx, ShapeHandle shape);
private:
ShapeManager shape_manager_;
std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> output_handle_shapes_and_types_;
const int graph_def_version_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
Status construction_status_;
};

前面已经介绍过了这个类的作用,是作为真正的形状推断函数的参数,为形状推断提供足够的数据和功能函数支持,那么这个类的成员就比较清晰了,私有数据成员为形状推断提供数据支持,而公有API,为形状推断提供公用的功能函数,比如上面提到的MergeInput和RelaxOutput,下面我们重点介绍下这两个函数的功能:

MergeInput函数是将输入索引idx处的输入与shape合并,具体的合并规则是:

  • 如果ShapeHandles是一样的,或者shape是未知的,那么输入维度不变。否则,如果输入维度是未知的,那么输出是shape;
  • 如果两个形状都是已知的,它们必须拥有相同的rank;
  • 对于任意一个维度,如果在两个形状中这个维度都已知,那么它们必须相等;
  • 如果一个形状在任意维度上的信息都多于另一个形状,那么拥有更多信息的形状将被返回。否则,一个新的形状将被构建并返回,这个新的形状综合了输入的两个形状的信息;
  • 比如,合并[2,?]和[?,2]将得到[2,2];
  • 比如,[2,2]不能被合并到[1,2]

如果说MergeInput函数对输入形状是“收缩”的,那么“RelaxInput”函数对输入形状就是“扩张”的,它倾向于让形状变的更模糊,具体的规则是:

  • 如果ShapeHandles是一样的,那么对应的shape将会被返回;
  • 如果任一个ShapeHandle是未知的,那么一个未知的ShapeHandle将会被返回;
  • 如果两个形状的rank已知,但不同,那么一个未知ShapeHandle将会被返回;
  • 对于任一维度,如果任一shape是未知的,那么对应的输出维度也是未知的;
  • 对于任一维度,如果两个shape对应的维度位置都是已知的,但并不相同,那么对应的输出维度也是未知的;
  • 如果两个shape的rank和对应维度大小都一样,那么这个形状将会被返回;
  • 例如,[2,?]和[?,2]会得到[?,?];
  • 例如,[2,2]和[3,2]会得到[?,2];
  • 例如,[2,2]和[1,2,3]会得到?

3. 关系图

graph TB
OpShapeInferenceFn-.使用参数.->InferenceContext
OpKernel::Compute-.使用参数.->OpKernelContext

4. 涉及的文件

  • shape_inference

5. 迭代记录

  • v1.0 2018-08-29 文档创建
  • v2.0 2018-09-10 文档重构

github地址

tensorflow源码解析之framework-shape_inference的更多相关文章

  1. tensorflow源码解析之framework拾遗

    把framework中剩余的内容,按照文件名进行了简单解析.时间原因写的很仓促,算是占个坑,后面有了新的理解再来补充. allocation_description.proto 一个对单次内存分配结果 ...

  2. tensorflow源码解析系列文章索引

    文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...

  3. Tensorflow源码解析1 -- 内核架构和源码结构

    1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...

  4. tensorflow源码解析之common_runtime-executor-上

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  5. tensorflow源码解析之framework-allocator

    目录 什么是allocator 内存分配器的管理 内存分配追踪 其它结构 关系图 涉及的文件 迭代记录 1. 什么是allocator Allocator是所有内存分配器的基类,它定义了内存分配器需要 ...

  6. tensorflow源码解析之common_runtime-executor-下

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  7. tensorflow源码解析之framework-op

    目录 什么是op op_def定义 op注册 op构建与注册辅助结构 op重写 关系图 涉及的文件 迭代记录 1. 什么是op op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话 ...

  8. tensorflow源码解析之distributed_runtime

    本篇主要介绍TF的分布式运行时的基本概念.为了对TF的分布式运行机制有一个大致的了解,我们先结合/tensorflow/core/protobuf中的文件给出对TF分布式集群的初步理解,然后介绍/te ...

  9. tensorflow源码解析之common_runtime拾遗

    把common_runtime中剩余的内容,按照文件名排序进行了简单的解析,时间原因写的很仓促,算是占个坑,后续有了新的理解再来补充. allocator_retry 有时候内存分配不可能一次完成,为 ...

  10. Tensorflow源码解析2 -- 前后端连接的桥梁 - Session

    Session概述 1. Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计 ...

随机推荐

  1. 解决sublime代码不提示的问题

    如果想让sublime在你输入标签的过程中给你提示,需要按要求开启以下功能. 1.开启代码自动提示功能

  2. iOS开发之工欲善其事,必先利其器

    SimPholders SimPholders是一个快速简单的小工具,可以帮助开发者快速访问iPhone模拟器应用.它可以访问模拟器的文件夹,重置库和文件,以及删除选定的应用程序. 常规做法 找到Fi ...

  3. 阿里云无法ping通解决

    https://blog.csdn.net/longgeaisisi/article/details/78429099

  4. CSS解决父级边框坍塌的问题

    1. 浮动元素后面增加空的div 首先在父级标签内添加如下<div>标签 <div id="clear"></div> 然后在CSS中对该标签进 ...

  5. Spack 内置函数

    1.Map函数:通过函数传递源的每个元素,并形成新的分布式数据集. %spark #并行化集合生成RDD var data = sc.parallelize(List(10,20,30)) %输出结果 ...

  6. Lesson10——NumPy 迭代数组

    NumPy 教程目录 NumPy 迭代数组 NumPy 迭代器对象  numpy.nditer  提供了一种灵活访问一个或者多个数组元素的方式. 迭代器最基本的任务的可以完成对数组元素的访问. Exa ...

  7. C++职工管理系统

    目录 职工管理系统 一. 需求 二. 创建管理类 1.创建文件 2. 头文件实现 3. 源文件实现 三. 菜单功能 1. 添加成员函数 2. 功能实现 3. 测试菜单功能 四. 退出功能 1. 提供功 ...

  8. 03 前端基础之JavaScript

    目录 前端基础之JavaScript JavaScript JavaScript注释 变量与常量 基本数据类型 number类型 string类型 boolean类型 null与undefined类型 ...

  9. 我们一起来学Shell - 正则表达式

    文章目录 什么是正则表达式 正则表达式元字符 正则表达式应用举例 POSIX 方括号表达式 POSIX 字符集列表: 我们一起来学Shell - 初识shell 我们一起来学Shell - shell ...

  10. INTERSPEECH 2014 | 1-Bit Stochastic Gradient Descent and its Application to Data-Parallel Distributed Training of Speech DNNs

    这篇文章之前也读过,不过读的不太仔细,论文中的一些细节并没有注意到.最近为了写开题报告,又把这篇论文细读了一遍.据笔者了解,这篇论文应该是梯度量化领域的开山之作,首次使用了梯度量化技术来降低分布式神经 ...