mxnet源码阅读笔记之include
写在前面
mxnet代码的规范性比Caffe2要好,看起来核心代码量也小很多,但由于对dmlc其它库的依赖太强,代码的独立性并不好。依赖的第三方库包括:
cub
dlpack
dmlc-core
googletest
mkldnn
mshadow
onnx-tensorrt
openmp
ps-lite
tvm
如果对于这些第三方库没有足够的理解,mxnet的核心代码看起来比较费劲。因此时间原因,本篇仅解析了mxnet对外的接口include目录,并且对于严重依赖第三方库的文件没有深入探究,只能算作一篇不完整的源码阅读笔记了。后续有时间的话,再回来迭代。
目录
- storage
- tensor_blob
- ndarray
- resource
- kvstore
- base
- operator
- engine
- executor
- rtc
- graph_attr_types
- op_attr_types
- imperative
- operator_util
- c_api
storage
Storage是一个跨设备的内存管理类,它提供了内存分配和回收的功能,但并不存储分配的内存,真正的内存指针分配在Storage类内部的Handle结构体中:
struct Handle {
void * dptr{nullptr}; //内存地址
size_t size{0};
Context ctx;
int shared_pid{-1};
int shared_id{-1};
};
class Storage {
public:
Handle Alloc(size_t size, Context ctx) {...};
virtual void Alloc(Handle* handle) = 0;
virtual void Free(Handle handle) = 0;
};
tensor_blob
TBlob类可以表示任意维度、在任意设备上、任意数据类型的张量,它是NDArray的内部存储,是mxnet中最底层的数据结构。但本质上它是对DLTensor的代理,DLTensor定义在第三方库dlpack中的dlpack.h文件中,以下是它们的关系:
NDArray-->|包含|TBlob
TBlob-->|包含|DLTensor
ndarray
ndarray是mxnet中的核心数据结构,代表了多维数据,类似于Tensorflow中的Tensor。本质上它借鉴了numpy中关于ndarray的定义,一部分ndarray是包含实际数据的,另外一些ndarray并不包含实际数据,它们只是其他ndarray的视图。举例说明,ndarrayA是一个[1x12]的多维数组,存储了12个元素,ndarrayB是一个[3x4]的多维数组,它底层的数据由ndarrayA提供,因此A和B共享了内存,B仅是A的一个视图。
ndarray内部由chunk结构提供实际的数据存储,先来看下chunk:
struct Chunk {
Storage::Handle shandle;
std::vector<Storage::Handle> aux_handles;
bool static_data; //如果为真,表示该数据是静态的,并非来自Storage,不需要被释放
bool delay_alloc; //数据分配是否需要延缓,注意对辅助数据aux data无效
NDArrayStorageType storage_type = kDefaultStorage;
std::vector<int> aux_types;
Context ctx;
TShape storage_shape;
std::vector<TShape> aux_shapes;
};
可见,Chunk结构仍然不是最终的数据存储结构,本质上数据还是存储在Storage结构中,如下所示:
NDArray-->|使用|Chunk
Chunk-->|使用|Storage
在ndarray中,我们发现数据分为数据本身,以及辅助数据。辅助数据主要用于存储稀疏数据的时候,数据本身放在data中,数据索引放在aux_data中。
最后看下NDArray的数据结构:
class NDArray {
std::shared_ptr<Chunk> ptr_{nullptr};
TShape shape_;
size_t byte_offset_ = 0;
int dtype_ = -1;
bool reuse_ = false;
nnvm::NodeEntry entry_;
mutable TBlob tblob_;
};
resource
在mxnet中,计算中用到的所有内容,除了ndarray之外,都可以被称为资源。其中最常用的资源,就是随机数生成器,分为CPU和GPU两个版本,如下:
enum Type {
kRandom, //CPU版本随机数生成器
kTempSpace, //动态随机内存
kParallelRandom //可以在GPU中使用的并行随机数生成器
};
另外,mxnet还为资源提供了一个管理器,ResourceManager,用于获取资源。
kvstore
kv存储的作用是存储模型参数,以便在分布式的计算中,在多个设备/机器之间进行数据同步。
kv存储可以有多种类型,比如:
- 'local'或者'local_update_cpu‘或者'local_allreduce_cpu',表明这是一个单机的kv存储,并且仅使用cpu做kv的allreduce;
- 'device'或者'local_allreduce_device',也是单机的kv存储,只不过使用gpu做kv的allreduce;
- 'dist_*',分布式的kv存储;
每个kv存储中都有一个更新器,它定义了,针对指定的key,当新value来到时,如何与旧value进行融合。这一点非常重要,因为在深度学习模型的训练中,需要迭代式的对模型参数进行更新,而更新的方式就是通过更新器来定义。
kv存储中,key通常是整型或者字符串,而value是NDArray,因此,有两种更新器的定义:
typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
最后,kv存储在底层用到了ps-lite来作数据同步。
class KVStore {
public:
static KVStore *Create(const char *type = "local");
virtual void Init(const std::vector<int>& keys, const std::vector<NDArray>& values) = 0;
virtual void Init(const std::vector<std::string>& str_keys, const std::vector<NDArray>& values) = 0;
virtual void Push(...) = 0;
virtual void Pull(...) = 0;
virtual void PullRowSparse(...) = 0;
virtual void set_updater(...);
};
base
引入了两个类,执行环境的上下文信息类Context,实际执行时的上下文类RunContext,后者包含前者。首先看下Context类的定义:
struct Context {
DeviceType dev_type;
int32_t dev_id;
inline void Save(dmlc::Stream *strm) const {...}; //将Context信息记入二进制流
inline bool Load(dmlc::Stream *strm) {...}; //从二进制流中载入Context信息
inline static Context Create(DeviceType dev_type, int32_t dev_id = -1); //构造一个新的Context
inline static Context CPU(int32_t dev_id = 0);
inline static Context GPU(int32_t dev_id=-1);
inline static int32_t GetGPUCount(); //获取GPU的数量
inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
inline static Context CPUPinned(int32_t dev_id = -1);
inline static Context CPUShared(int32_t dev_id = 0);
inline static Context FromString(const std::string& str);
};
而RunContext就相对简单了,它包含了一个Context和一个流指针:
struct RunContext {
Context ctx;
void *stream;
//...
};
operator
Operator定义了mxnet计算图中基础的操作单位。相当于Tensorflow中的kernel,和Caffe2中的Operator。但它与Tensorflow和Caffe2中的操作有本质区别,在Tensorflow中,操作本身和它对应的求导操作是分开的,而在mxnet中,这两者是结合在一起的,分别使用Forward和Backward两个函数实现,因此,mxnet在操作的实现上更加紧凑,与Tensorflow相比减少了一些对计算图进行裁剪的额外开销,性能上有优势,但也同时限制了自己的计算边界,灵活性不足。
class Operator {
public:
//进行前向计算,将计算结果保存在TBlob中
virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req, const std::vector<TBlob> &out_data, const std::vector<TBlob> &aux_states) = 0;
//进行后向计算,将梯度写入in_grad
virtual void Backward(const OpContext &ctx, const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data, const std::vector<TBlob> &out_data, const std::vector<OpReqType> &req, const std::vector<TBlob> &in_grad, const std::vector<TBlob> &aux_states);
};
Operator中仅包含了操作计算的接口,对于操作的描述保存在OperatorProperty类中,它负责保存所有与Operator有关的信息,且能够产生设备相关的Operator。同时,它还为计算引擎提供了一些可以优化操作计算的函数。
class OperatorProperty {
public:
//初始化Operator时需要用到的参数
virtual void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) = 0;
//获取为Operator准备的参数
virtual std::map<std::string, std::string> GetParams() const = 0;
virtual int NumOutputs() const {...}
//进行Operator的形状推断,类似于Tensorflow的ShapeInference
virtual bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape, std::vector<TShape> *aux_shape) const = 0;
//进行Operator的类型推断
virtual bool InferType(...);
//构建Operator
virtual Operator* CreateOperator(Context ctx) const = 0;
};
目前看来,mxnet中Operator与OperatorProperty的关系,与Tensorflow中OpKernel与Op的关系不太一样,后者与Caffe2中的Operator和OpSchema的关系更加相似,有机会我们会详细比较下,这三种框架关于操作定义于使用的区别。
engine
引擎是执行核心之一,它负责对计算图中的操作进行调度。引擎中的两大关键元素是操作和变量,操作定义了计算图每一个节点需要实际执行的动作,变量定义了动作之间的依赖关系。
首先,mxnet定义了一个,被异步函数在运行结束时调用的回调函数类,通过对()的重载,用类对回调函数进行了一层封装:
class CallbackOnComplete {
public:
inline void operator()() const {
(*callback_)(engine_, param_);
}
private:
friend class ::mxnet::Engine;
void (*callback_)(Engine *, void *);
Engine* engine_;
void* param_;
};
枚举类FnProperty介绍了常用的函数类型:
enum class FnProperty {
kNormal, //一般操作
kCopyFromGPU, //从GPU上拷贝内容到其它设备的操作
kCopyToGPU, //从其它设备向GPU拷贝内容的操作
kCPUPrioritized, //CPU上优先选择的同步操作
kAsync, //异步函数调用
kDeleteVar, //用来删除变量的函数
kGPUPrioritized, //GPU上优先选择的同步操作
};
engine的含义是,对操作进行调度执行的引擎。回想一下,在Tensorflow中,为了正确执行用户设计好的计算图,我们需要对原始计算图进行一些迭代修改,在Engine类中提供了这样的接口:
class Engine {
public:
//定义运行结束时的回调类
typedef engine::CallbackOnComplete CallbackOnComplete;
//定义传递给引擎的同步操作函数
typedef std::function<void(RunContext)> SyncFn;
//定义传递给引擎的异步操作函数
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
//定义变量指针
typedef engine::VarHandle VarHandle;
//定义操作指针
typedef engine::OprHandle OprHandle;
//停止引擎中的所有worker
virtual void Stop() {}
//启动引擎中的所有worker
virtual void Start() {}
//分配一个新的变量,该变量可以被用来根据依赖关系,辅助对引擎中的操作进行调度
virtual VarHandle NewVariable() = 0;
//构建一个操作,该操作定义在外部,从而我们可以在调度中重复使用
virtual OprHandle NewOperator(...) = 0;
//删除一个操作,它不会立刻进行,而是直到所有使用该操作的动作运行结束之后再进行
virtual void DeleteOperator(OpHandle op) = 0;
//将一个操作加入引擎
virtual void Push(...);
//将一个异步操作加入引擎
virtual void PushAsync(...);
//将一个同步操作加入引擎
virtual void PushSync(...);
//删除一个变量,它不会立刻进行,而是直到所有依赖该变量的操作完成之后再进行
virtual void DeleteVariable(...) = 0;
//等待一个变量准备完成
virtual void WaitForVar(...) = 0;
//等待引擎中所有的活动都结束时再返回
virtual void WaitForAll() = 0;
//返回引擎的单例对象
static Engine* Get();
//用来生成OnComplete回调的工厂函数
inline CallbackOnComplete CreateCallback(...);
};
executor
mxnet的执行器接口,用于对计算图进行执行。执行的机制与Operator的设计相合,同样提供了前向和后向两种接口,如下:
class Executor {
public:
virtual void Forward(bool is_train) = 0;
virtual void PartialForward(bool is_train, int step, int *step_left) = 0;
virtual void Backward(const std::vector<NDArray> &head_grads, bool is_train = true) = 0;
};
rtc
包含了Cuda运行时的编译模块CudaModule。
graph_attr_types
获取图相关属性的辅助结构。对于一张计算图中的节点,通常会关注两种信息,一种是计算图中节点的存储类型,一种是节点的调度模式,分别将结果存储在StorageTypeVector和DispatchModeVector中,这两种结构的定义如下:
using StorageTypeVector = std::vector<int>;
using DispatchModeVector = std::vector<DispatchMode>;
op_attr_types
有关操作的额外属性,与nvvm有关,目前看不懂。
imperative
与NDArray有关的运行时函数,目前看不懂。
operator_util
辅助快速构建operator的功能和注册器。
c_api
定义了mxnet后端"C++"代码的接口。
mxnet源码阅读笔记之include的更多相关文章
- CI框架源码阅读笔记4 引导文件CodeIgniter.php
到了这里,终于进入CI框架的核心了.既然是“引导”文件,那么就是对用户的请求.参数等做相应的导向,让用户请求和数据流按照正确的线路各就各位.例如,用户的请求url: http://you.host.c ...
- libevent源码阅读笔记(一):libevent对epoll的封装
title: libevent源码阅读笔记(一):libevent对epoll的封装 最近开始阅读网络库libevent的源码,阅读源码之前,大致看了张亮写的几篇博文(libevent源码深度剖析 h ...
- CI框架源码阅读笔记5 基准测试 BenchMark.php
上一篇博客(CI框架源码阅读笔记4 引导文件CodeIgniter.php)中,我们已经看到:CI中核心流程的核心功能都是由不同的组件来完成的.这些组件类似于一个一个单独的模块,不同的模块完成不同的功 ...
- CI框架源码阅读笔记3 全局函数Common.php
从本篇开始,将深入CI框架的内部,一步步去探索这个框架的实现.结构和设计. Common.php文件定义了一系列的全局函数(一般来说,全局函数具有最高的加载优先权,因此大多数的框架中BootStrap ...
- CI框架源码阅读笔记2 一切的入口 index.php
上一节(CI框架源码阅读笔记1 - 环境准备.基本术语和框架流程)中,我们提到了CI框架的基本流程,这里再次贴出流程图,以备参考: 作为CI框架的入口文件,源码阅读,自然由此开始.在源码阅读的过程中, ...
- 源码阅读笔记 - 1 MSVC2015中的std::sort
大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...
- Three.js源码阅读笔记-5
Core::Ray 该类用来表示空间中的“射线”,主要用来进行碰撞检测. THREE.Ray = function ( origin, direction ) { this.origin = ( or ...
- PHP源码阅读笔记一(explode和implode函数分析)
PHP源码阅读笔记一一.explode和implode函数array explode ( string separator, string string [, int limit] )此函数返回由字符 ...
- AQS源码阅读笔记(一)
AQS源码阅读笔记 先看下这个类张非常重要的一个静态内部类Node.如下: static final class Node { //表示当前节点以共享模式等待锁 static final Node S ...
随机推荐
- Git .gitignore 不起作用的解决办法
解决方法的原理:.gitignore只能忽略那些原来没有被track的文件,如果某些文件已经被纳入了版本管理中,则修改.gitignore是无效的. 解决方案:git rm -r --cached . ...
- JAVA多线程学习十五 - 阻塞队列应用
一.类相关属性 接口BlockingQueue<E>定义: public interface BlockingQueue<E> extends Queue<E> { ...
- Android使用pull解析xml格式的数据
dom解析:基于全文加载的解析方式 sax解析:基于事件的逐行解析方式 pull解析:同sax XmlPullParser //解析xml文件读取短信内容 ...
- smartimageview 的原理
自定义的控件在布局文件中的引用都需要指定类的完整路径 1.自定义了一个MyImageview类继承了Imageview,添加三个构造方法 2.添加一个setImageUrl方法接受一个图片ur ...
- Solution -「Gym 102798I」Sean the Cuber
\(\mathcal{Description}\) Link. 给定两个可还原的二阶魔方,求从其中一个状态拧到另一个状态的最小步数. 数据组数 \(T\le2.5\times10^5\). ...
- Solution -「AGC 026D」Histogram Coloring
\(\mathcal{Description}\) Link. 有 \(n\) 列下底对齐的方格纸排成一行,第 \(i\) 列有 \(h_i\) 个方格.将每个方格染成黑色或白色,求使得任意完 ...
- Solution -「POI 2010」「洛谷 P3511」MOS-Bridges
\(\mathcal{Description}\) Link.(洛谷上这翻译真的一言难尽呐. 给定一个 \(n\) 个点 \(m\) 条边的无向图,一条边 \((u,v,a,b)\) 表示从 ...
- code-server服务端开发利器,再也不用vim装逼了!!!
一直有个需求,就是万不得已在服务修改代码的时候能有个好的工具,至少比vim要强吧!虽然vim也还行,但是如果比vscode那一定是差了点!这个微软洗心革面的新工具着实不错!从刚开始的鄙视到真香我用了不 ...
- 【Mock平台】测试开发实战01-开篇PRD和需求详细
微信搜索[大奇测试开],关注这个坚持分享测试开发干货的家伙. 平台背景 从业务特性上,不少测试的服务很多是依赖第三方的接口的,比如其中的支付场景,就需要很多状态的返回进行验证,但大部分服务提供商没有很 ...
- 【高频Java面试题】简单说说JVM堆的内存结构和GC回收流程
目录 前言 JVM堆内存结构简述 JVM堆内存结构图 堆初体验 结构详情 新生代 老年代 永久代/元空间 GC回收流程 GC回收流程图 GC回收详细流程 查看JDK自带可视化堆空间图 总结 前言 我们 ...