tensorflow源码解析之framework-shape_inference
目录
- 什么是形状推断
- InferenceContext
- 关系图
- 涉及的文件
- 迭代记录
1. 什么是形状推断
前面我们讲到op的时候,提到了操作的注册器OpRegistry,并且提到,其中注册的数据是一个结构OpRegistrationData,这个结构中除了OpDef之外,还包含了一个OpShapeInferenceFn,这个数据是做什么用的呢?
我们知道,op只是定义了操作的输入输出和参数,但并没有定义操作具体的输入形状,举个例子,MatMul操作,代表矩阵乘法,这只是一个抽象的表示,没有具体说,这个矩阵乘法代表的是[2,3]x[3,4]=[2,4],还是[100,200]x[200,300]=[100,300]。所以在实际应用中,在得到输入之前,输出的真实形状是无法预知的,但在得到输入之后,我们必须能够根据输入的形状,以及当前op的作用,判断输出的具体形状,才能给它申请对应大小的内存空间。所以,我们需要为每一个操作,配备一个形状推断的函数,这就是形状推断的由来。
2. InferenceContext
前面提到了OpShapeInferenceFn,我们来看一下它的详细定义:
typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;
可见,OpShapeInferenceFn是一个接收InferenceContext参数的函数,TF为所有op的形状推断函数,准备了这样一个统一的接口。所有跟形状推断相关的数据和功能函数,都放在InferenceContext这个类的内部。回想一下前面讲过的OpKernelContext,其实它们的功能很像。OpKernelContext是作为OpKernel的核心API Compute函数的参数,所有计算相关的参数都会包含在这个对象中。InferenceContext也是一样,我们把所有跟形状推断相关的数据和功能函数封装在一个InferenceContext对象中,然后把这个对象传递给OpShapeInferenceFn,就可以实现形状推断。这种设计实现了数据部分和实现逻辑的解耦。
在具体看ShapeInference类之前,我们先要看一些辅助类:
class Dimension {
private:
//...
const int64 value_;
};
class DimensionHandle {
private:
//...
const Dimension* ptr_ = nullptr;
};
class Shape {
//...
private:
const int32 rank_;
const std::vector<DimensionHandle> dims_;
};
class ShapeHandle {
//...
private:
const Shape* ptr = nullptr;
};
class DimensionOrConstant {
public:
//...
DimensionHandle dim;
int64 val;
};
class ShapeAndType {
ShapeHandle shape;
DataType dtype = DT_INVALID;
};
这几个类都比较简单。在下面用到时能够认得就好了。
下面我们看下InferenceContext这个类:
class InferenceContext {
public:
InferenceContext(int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types);//构造函数
Status Run(const std::function<Status(shape_inference::InferenceContext* c)>& fn);//运行一个以this为参数的函数,没错,这里运行的就是OpShapeInferenceFn
bool MergeInput(int idx, ShapeHandle shape);
bool RelaxInput(int idx, ShapeHandle shape);
private:
ShapeManager shape_manager_;
std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types_;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> output_handle_shapes_and_types_;
const int graph_def_version_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
Status construction_status_;
};
前面已经介绍过了这个类的作用,是作为真正的形状推断函数的参数,为形状推断提供足够的数据和功能函数支持,那么这个类的成员就比较清晰了,私有数据成员为形状推断提供数据支持,而公有API,为形状推断提供公用的功能函数,比如上面提到的MergeInput和RelaxOutput,下面我们重点介绍下这两个函数的功能:
MergeInput函数是将输入索引idx处的输入与shape合并,具体的合并规则是:
- 如果ShapeHandles是一样的,或者shape是未知的,那么输入维度不变。否则,如果输入维度是未知的,那么输出是shape;
- 如果两个形状都是已知的,它们必须拥有相同的rank;
- 对于任意一个维度,如果在两个形状中这个维度都已知,那么它们必须相等;
- 如果一个形状在任意维度上的信息都多于另一个形状,那么拥有更多信息的形状将被返回。否则,一个新的形状将被构建并返回,这个新的形状综合了输入的两个形状的信息;
- 比如,合并[2,?]和[?,2]将得到[2,2];
- 比如,[2,2]不能被合并到[1,2]
如果说MergeInput函数对输入形状是“收缩”的,那么“RelaxInput”函数对输入形状就是“扩张”的,它倾向于让形状变的更模糊,具体的规则是:
- 如果ShapeHandles是一样的,那么对应的shape将会被返回;
- 如果任一个ShapeHandle是未知的,那么一个未知的ShapeHandle将会被返回;
- 如果两个形状的rank已知,但不同,那么一个未知ShapeHandle将会被返回;
- 对于任一维度,如果任一shape是未知的,那么对应的输出维度也是未知的;
- 对于任一维度,如果两个shape对应的维度位置都是已知的,但并不相同,那么对应的输出维度也是未知的;
- 如果两个shape的rank和对应维度大小都一样,那么这个形状将会被返回;
- 例如,[2,?]和[?,2]会得到[?,?];
- 例如,[2,2]和[3,2]会得到[?,2];
- 例如,[2,2]和[1,2,3]会得到?
3. 关系图
OpShapeInferenceFn-.使用参数.->InferenceContext
OpKernel::Compute-.使用参数.->OpKernelContext
4. 涉及的文件
- shape_inference
5. 迭代记录
- v1.0 2018-08-29 文档创建
- v2.0 2018-09-10 文档重构
tensorflow源码解析之framework-shape_inference的更多相关文章
- tensorflow源码解析之framework拾遗
把framework中剩余的内容,按照文件名进行了简单解析.时间原因写的很仓促,算是占个坑,后面有了新的理解再来补充. allocation_description.proto 一个对单次内存分配结果 ...
- tensorflow源码解析系列文章索引
文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...
- Tensorflow源码解析1 -- 内核架构和源码结构
1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...
- tensorflow源码解析之common_runtime-executor-上
目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...
- tensorflow源码解析之framework-allocator
目录 什么是allocator 内存分配器的管理 内存分配追踪 其它结构 关系图 涉及的文件 迭代记录 1. 什么是allocator Allocator是所有内存分配器的基类,它定义了内存分配器需要 ...
- tensorflow源码解析之common_runtime-executor-下
目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...
- tensorflow源码解析之framework-op
目录 什么是op op_def定义 op注册 op构建与注册辅助结构 op重写 关系图 涉及的文件 迭代记录 1. 什么是op op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话 ...
- tensorflow源码解析之distributed_runtime
本篇主要介绍TF的分布式运行时的基本概念.为了对TF的分布式运行机制有一个大致的了解,我们先结合/tensorflow/core/protobuf中的文件给出对TF分布式集群的初步理解,然后介绍/te ...
- tensorflow源码解析之common_runtime拾遗
把common_runtime中剩余的内容,按照文件名排序进行了简单的解析,时间原因写的很仓促,算是占个坑,后续有了新的理解再来补充. allocator_retry 有时候内存分配不可能一次完成,为 ...
- Tensorflow源码解析2 -- 前后端连接的桥梁 - Session
Session概述 1. Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计 ...
随机推荐
- Emoji与unicode特殊字符的处理
遇到了一个很让人纠结的问题:emoji表情在使用的过程中,会莫名其妙的消失,或者变成乱码,同时数据库用utf8mb4来存储,但是也出现了问题,冷备过后,导入进库的时候,变成了不可见字符,神奇的消失了! ...
- SEL类型
1.什么是SEL类型 SEL类型代表着方法的签名,在类对象的方法列表中存储着该签名与方法代码的对应关系 每个类的方法列表都存储在类对象中 每个方法都有一个与之对应的SEL类型的对象 根据一个SEL对象 ...
- iOS 模糊、精确搜索匹配功能方法总结 By HL
字符串搜索主要用于UITableView的搜索功能的筛选,过滤,查询 下面是一些流行的搜索查询方法 一.遍历搜索 for循环 根据要求:精确搜索(判读字符串相等) 模糊搜索(字符串包含) 相关知识 ...
- 用Java中的File类模拟实现对系统文件的增删改查效果
码字不易,三连支持一波吧 IO操作向来是各大语言的热区,而对文件的操作也是重中之重. 那么在Java中也给我们提供了很多关于文件操作的类.今天我就用一个比较基本的File类来模拟实现对文件的增删改查效 ...
- Python中set集合常用操作
功能 Python符号 Python方法 备注 交集 & intersection, intersection_update &:取两者交集>>> set3 = se ...
- linux 定时删除图以及crontab介绍
执行 sudo crontab -e 0 3 1 * * /etc/letsencrypt/certbot-auto renew --renew-hook "sudo nginx -s re ...
- spring boot 配置静态路径
一 前言 最近有个项目,需要上传一个zip文件(zip文件就是一堆的html压缩组成)的压缩文件,然后后端解压出来,用户可以预览上传好的文件. 查看资料,spring boot对静态文件,可以通过配 ...
- CreateEvent进程同步
CreateEvent进程间同步 CreateEvent可以创建或是打开一个命名或是未命名的event对象. HANDLE CreateEvent( LPSECURITY_ATTRIBUTES ...
- MYSQL优化的一些性能与技巧
1. 为查询缓存优化你的查询 大多数的MySQL服务器都开启了查询缓存.这是提高性最有效的方法之一,而且这是被MySQL的数据库引擎处理的.当有很多相同的查询被执行了多次的时候,这些查询结果会被放到一 ...
- JsonResponse类的使用、form表单上传文件补充、CBV和FBV、HTML的模板语法之传值与过滤器
昨日内容回顾 Django请求生命周期 # 1.浏览器发起请求 到达Django的socket服务端(web服务网关接口) 01 wsgiref 02 uwsgi + nginx 03 WSGI协议 ...