[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

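Next we look at the caching mechanism. Why cache at all? A cluster contains many workers, and the Master talks to Workers just as Workers talk to each other, so it makes sense to cache both the Workers and their gRPC channels. Caches show up throughout TensorFlow's distributed runtime.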

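Other articles in this series:

[Translation] TensorFlow Distributed Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic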

1. WorkerCache

The job of WorkerCache is to hand out WorkerInterface instances, through which the remote WorkerService can be accessed. The typical WorkerInterface implementation is GrpcRemoteWorker.

1.1 Usage

Earlier, when MasterEnv was initialized, WorkerCacheFactory was installed in master_env_.worker_cache_factory.

master_env_.worker_cache_factory =
[this](const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(options, worker_cache);
};

Later, Master::CreateSession contains the following (abridged) code. It shows how worker_cache (a WorkerCacheInterface instance) is obtained from the factory and how it is used afterwards.

void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
// Configure the options.
WorkerCacheFactoryOptions worker_cache_factory_options;
worker_cache_factory_options.protocol = &grpc_protocol;
worker_cache_factory_options.rpc_options = &req->config().rpc_options();
// Build worker_cache.
// Create the worker cache from the computed server_def.
status = env_->worker_cache_factory(worker_cache_factory_options,
&worker_cache);
// Use worker_cache for the remaining steps.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
});
}
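
The registration and call pattern above can be sketched in isolation. This is a minimal sketch, not TensorFlow code: ToyWorkerCache, ToyFactoryOptions, ToyMasterEnv, MakeToyEnv and CreateCacheForSession are hypothetical stand-ins, and a bool replaces Status.

```cpp
#include <functional>
#include <memory>
#include <string>

// Hypothetical stand-ins for WorkerCacheInterface and WorkerCacheFactoryOptions.
struct ToyWorkerCache {
  std::string protocol;
};
struct ToyFactoryOptions {
  const std::string* protocol = nullptr;
};

// Same shape as the factory in the article: options in, cache out through a
// pointer, success reported by the return value (bool here instead of Status).
using ToyWorkerCacheFactory =
    std::function<bool(const ToyFactoryOptions&, ToyWorkerCache**)>;

struct ToyMasterEnv {
  ToyWorkerCacheFactory worker_cache_factory;
};

// Installs a factory in the environment, as the server does at startup.
inline ToyMasterEnv MakeToyEnv() {
  ToyMasterEnv env;
  env.worker_cache_factory = [](const ToyFactoryOptions& options,
                                ToyWorkerCache** out) -> bool {
    if (options.protocol == nullptr) return false;
    *out = new ToyWorkerCache{*options.protocol};
    return true;
  };
  return env;
}

// Mirrors the call site in CreateSession: fill in options, then ask the factory.
inline std::unique_ptr<ToyWorkerCache> CreateCacheForSession(
    const ToyMasterEnv& env, const std::string& protocol) {
  ToyFactoryOptions options;
  options.protocol = &protocol;
  ToyWorkerCache* raw = nullptr;
  if (!env.worker_cache_factory(options, &raw)) return nullptr;
  return std::unique_ptr<ToyWorkerCache>(raw);
}
```

Storing the factory as a std::function keeps the session code unaware of which server class actually builds the cache.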

1.2 Configuration

WorkerCacheFactoryOptions plays roughly the role of ServerDef: it holds the ClusterDef, job_name, task_index, and related information.

// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
const ClusterDef* cluster_def = nullptr;
const string* job_name = nullptr;
int task_index;
const string* protocol = nullptr;
const RPCOptions* rpc_options = nullptr;

WorkerCacheFactoryOptions() {}

// Construct from a ServerDef proto.
//
// Note: server_def must outlive WorkerCacheFactoryOptions!
WorkerCacheFactoryOptions(const ServerDef& server_def) {
if (server_def.has_cluster() && !server_def.job_name().empty()) {
cluster_def = &server_def.cluster();
job_name = &server_def.job_name();
task_index = server_def.task_index();
protocol = &server_def.protocol();
rpc_options = &server_def.default_session_config().rpc_options();
}
}
};

1.3 Factory

WorkerCacheFactory is a function that does the following:

  • Calls ParseChannelSpec to obtain a GrpcChannelSpec instance. GrpcChannelSpec is roughly equivalent to ClusterSpec and holds the basic cluster configuration.
  • Calls NewGrpcChannelCache to obtain a GrpcChannelCache, channel_cache. This step uses GetChannelCreationFunction.
  • Calls NewGrpcWorkerCacheWithLocalWorker(channel_cache, ...) to obtain worker_cache.

Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
// Obtain the GrpcChannelSpec.
GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
// Obtain the GrpcChannelCache.
std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
channel_spec, GetChannelCreationFunction(), *options.rpc_options));
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
"/task:", options.task_index);
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
auto colon_index = host_port.find_last_of(':');
if (!strings::safe_strto32(host_port.substr(colon_index + 1),
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
host_port, "\".");
}
if (requested_port != bound_port_) {
return errors::InvalidArgument("Requested port ", requested_port,
" differs from expected port ", bound_port_);
}
// Obtain the worker cache.
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
return Status::OK();
}
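
The port check in the middle of WorkerCacheFactory is easy to miss. The sketch below isolates it, assuming plain std::string handling in place of strings::StrCat and strings::safe_strto32. ToyMakeTaskName, ToyParsePort and ToyPortMatches are hypothetical names used only for illustration.

```cpp
#include <string>

// Builds the task name the way the factory does: the replica is always 0.
inline std::string ToyMakeTaskName(const std::string& job_name, int task_index) {
  return "/job:" + job_name + "/replica:0/task:" + std::to_string(task_index);
}

// Extracts the port after the last ':' in "host:port".
// Returns false when there is no colon or the suffix is not a short number.
inline bool ToyParsePort(const std::string& host_port, int* port) {
  const std::string::size_type colon = host_port.find_last_of(':');
  if (colon == std::string::npos) return false;
  const std::string digits = host_port.substr(colon + 1);
  if (digits.empty() || digits.size() > 5) return false;
  if (digits.find_first_not_of("0123456789") != std::string::npos) return false;
  *port = std::stoi(digits);
  return true;
}

// True when the address resolved for the local task uses the bound port.
inline bool ToyPortMatches(const std::string& host_port, int bound_port) {
  int port = 0;
  return ToyParsePort(host_port, &port) && port == bound_port;
}
```

The check guards against a cluster definition that lists the local task under a port other than the one the server actually listens on.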

1.3.1 ParseChannelSpec

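ParseChannelSpec builds the GrpcChannelSpec instance. GrpcChannelSpec is roughly equivalent to ClusterSpec and holds the basic cluster configuration. For the local task, the function substitutes host_name_ and bound_port_ for the address listed in the ClusterDef.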

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports;
for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first];
if (!host_port.empty()) {
return errors::InvalidArgument("JobDef for job \"", job.name(),
"\" specified two addresses for task \"",
task.first, "\": ", host_port, " and ",
task.second);
}
if (job.name() == *options.job_name && task.first == options.task_index) {
host_port = strings::StrCat(host_name_, ":", bound_port_);
} else {
host_port = task.second;
}
}
TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
}
return Status::OK();
}
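
A minimal sketch of the same substitution, assuming the cluster is flattened into nested std::map containers instead of the ClusterDef proto. ToyClusterDef and ToyBuildChannelSpec are hypothetical names, and the duplicate-address check is left out.

```cpp
#include <map>
#include <string>

// Hypothetical flattened cluster: job name -> (task index -> host:port).
using ToyClusterDef = std::map<std::string, std::map<int, std::string>>;

// Copies the cluster description, but replaces the address of the local task
// with the address this server is actually bound to.
inline ToyClusterDef ToyBuildChannelSpec(const ToyClusterDef& cluster,
                                         const std::string& local_job,
                                         int local_task,
                                         const std::string& local_host,
                                         int bound_port) {
  ToyClusterDef spec;
  for (const auto& job : cluster) {
    std::map<int, std::string>& host_ports = spec[job.first];
    for (const auto& task : job.second) {
      const bool is_local = job.first == local_job && task.first == local_task;
      host_ports[task.first] =
          is_local ? local_host + ":" + std::to_string(bound_port)
                   : task.second;
    }
  }
  return spec;
}
```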

1.3.2 NewGrpcChannelCache

NewGrpcChannelCache creates the GrpcChannelCache instance. Each Job gets its own SparseGrpcChannelCache. If there is only one, it is returned directly; otherwise they are combined into a MultiGrpcChannelCache. The channel_func argument is the function returned by GetChannelCreationFunction, which is covered later.

GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
ChannelCreationFunction channel_func,
const RPCOptions& options) {
const int num_jobs = spec.host_ports_jobs().size();
if (!num_jobs) {
return nullptr;
}
std::vector<GrpcChannelCache*> caches;
caches.reserve(num_jobs);
for (auto& job : spec.host_ports_jobs()) {
caches.push_back(
new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
options.num_channels_per_target()));
}
return caches.size() == 1 ? caches[0]
: new MultiGrpcChannelCache(
caches, options.num_channels_per_target());
}

1.3.3 NewGrpcWorkerCacheWithLocalWorker

NewGrpcWorkerCacheWithLocalWorker creates the GrpcWorkerCache instance.

WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
WorkerInterface* local_worker, const string& local_target) {
return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}

The local_worker argument comes from worker_impl(). It is created in GrpcServer::Init and is the local GrpcWorker.

GrpcWorker* worker_impl() const { return worker_impl_.get(); }

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
const ConfigProto& config) {
return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

Status GrpcServer::Init(const GrpcServerOptions& opts) {
// (omitted)
worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
: NewGrpcWorker(&worker_env_, config);
// (omitted)
}

To recap the factory flow so far: the input is a WorkerCacheFactoryOptions, which passes through the functions above step by step until a GrpcWorkerCache is produced.

Figure 1: Factory flow

1.4 WorkerCacheInterface

1.4.1 Interface

WorkerCacheInterface is the interface class. The GrpcWorkerCache in the figure above is derived from it.

class WorkerCacheInterface {
public:
virtual ~WorkerCacheInterface() {}

// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
virtual void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const = 0;

// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
// wrapping that channel. The returned value must be destroyed by
// calling `this->ReleaseWorker(target, ret)`
virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;

// Release a worker previously returned by this->GetOrCreateWorker(target).
//
// TODO(jeff,sanjay): Consider moving target into WorkerInterface.
// TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
// per-rpc-subsystem WorkerInterface creator.
virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
// Subclasses may override to reuse worker objects.
delete worker;
}

// Set *locality with the DeviceLocality of the specified remote device
// within its local environment. Returns true if *locality
// was set, using only locally cached data. Returns false
// if status data for that device was not available. Never blocks.
virtual bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) = 0;

// Set *locality with the DeviceLocality of the specified remote device
// within its local environment. Callback gets Status::OK if *locality
// was set.
virtual void GetDeviceLocalityAsync(const string& device,
DeviceLocality* locality,
StatusCallback done) = 0;

// TODO(b/189159585): Define a general client cache maker function to
// construct client cache of different types sharing the same underling RPC
// channels, to replace the eager and coordination cache function.
// Build and return a EagerClientCache object wrapping that channel.
virtual Status GetEagerClientCache(
std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;

// Build and return a CoordinationClientCache object wrapping that channel.
virtual Status GetCoordinationClientCache(
std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0;

// Start/stop logging activity.
virtual void SetLogging(bool active) {}

// Discard any saved log data.
virtual void ClearLogs() {}

// Return logs for the identified step in *ss. Any returned data will no
// longer be stored.
virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; }
};

WorkerCachePartial in turn derives from WorkerCacheInterface.

// Implements the part of the interface that caches and returns remote
// device status attributes.
class WorkerCachePartial : public WorkerCacheInterface {
public:
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override;

void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback) override;

~WorkerCachePartial() override {}

// Clear all entries from the DeviceStatus cache.
void FlushStatusCache();

private:
mutex mu_;

// Initiate a GetStatusAsync to the remote task named by "task", and
// update the cache with all the DeviceAttributes reported.
Status RefreshDeviceStatus(const string& device_name);

typedef std::unordered_map<string, DeviceAttributes> StatusMap;
StatusMap device_status_cache_ TF_GUARDED_BY(mu_);
};

1.4.2 GrpcWorkerCache

GrpcWorkerCache in turn derives from WorkerCachePartial.

class GrpcWorkerCache : public WorkerCachePartial {
public:
explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
WorkerInterface* local_worker,
const string& local_target,
GrpcWorkerEnv* worker_env)
: local_target_(local_target),
local_worker_(local_worker),
channel_cache_(channel_cache),
worker_env_(worker_env),
next_round_robin_assignment_(0) {}

const string local_target_;
WorkerInterface* const local_worker_; // Not owned.
std::shared_ptr<GrpcChannelCache> channel_cache_;
WorkerCacheLogger logger_;
GrpcWorkerEnv* worker_env_; // Not owned

mutex assignment_mu_;
std::unordered_map<std::string, size_t> target_assignments_
TF_GUARDED_BY(assignment_mu_);
size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
};

One of its functions is ListWorkers, which lists the names of all workers in the cluster by delegating to channel_cache_.

void ListWorkers(std::vector<string>* workers) const override {
channel_cache_->ListWorkers(workers);
}

void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {
channel_cache_->ListWorkersInJob(job_name, workers);
}

GetOrCreateWorker builds a worker on top of the target Worker's RPC channel. If the target is local, it returns local_worker_ directly, which is the local GrpcWorker set up earlier.

WorkerInterface* GetOrCreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
} else {
SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
if (!channel) {
return nullptr;
}
size_t index = AssignWorkerToThread(target);
return NewGrpcRemoteWorker(
channel, worker_env_->GetCompletionQueue(index),
worker_env_->GetThreadPool(), &logger_, target);
}
}
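
AssignWorkerToThread is not shown in the listing. Judging from the members target_assignments_ and next_round_robin_assignment_, it most likely hands out completion queue indexes in round-robin order and keeps a target on the same index afterwards. The sketch below implements that assumed behavior; ToyThreadAssigner is a hypothetical class, not the TensorFlow implementation.

```cpp
#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>

// Assigns each target to one of `num_queues` slots in round-robin order and
// keeps returning the same slot for a target it has seen before.
class ToyThreadAssigner {
 public:
  explicit ToyThreadAssigner(std::size_t num_queues)
      : num_queues_(num_queues == 0 ? 1 : num_queues) {}

  std::size_t Assign(const std::string& target) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = assignments_.find(target);
    if (it == assignments_.end()) {
      // First time this target is seen: take the next slot.
      it = assignments_.emplace(target, next_++ % num_queues_).first;
    }
    return it->second;
  }

 private:
  const std::size_t num_queues_;
  std::mutex mu_;
  std::unordered_map<std::string, std::size_t> assignments_;
  std::size_t next_ = 0;
};
```

Keeping a target on one slot means all of its RPCs are polled by the same queue, while different targets spread across the available queues.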

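2. RPC channels

Workers run on top of RPC channels, so next we look at how those channels are created. Just as Workers are cached, so are RPC channels. GrpcChannelCache is that cache; it is used to get or create RPC channels to the remote Workers in the cluster.

2.1 The GrpcChannelCache interface

GrpcChannelCache is an interface class that defines, among others:

  • ListWorkers: returns the names of the Workers in the cluster.
  • TranslateTask: converts a Worker name into its address, in host:port form.
  • FindWorkerChannel: looks up a grpc::Channel in the cache; if none is cached, it creates one from the address and stores it in the cache.
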
class GrpcChannelCache {
public:
virtual ~GrpcChannelCache() {}

// Populates *workers with names of all workers which this object
// was created to handle. Worker names are in the format
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
virtual void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) = 0;

// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
// format: /job:<job identifier>/task:<task id>
// E.g., /job:mnist/task:2
virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;

// Translates a string in the form `/job:X/task:Z` into a host_port.
virtual string TranslateTask(const string& task) = 0;
};

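2.2 Caching mechanism

CachingGrpcChannelCache is the caching class; it avoids the cost of creating a grpc::Channel on every lookup. As the definition below shows, it is GenericCachingChannelCache instantiated with GrpcChannelCache as its base class.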

// GrpcChannelCache that caches results to FindWorkerChannel() calls.
using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;

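GenericCachingChannelCache caches the results of FindWorkerChannel() calls. It first looks for a grpc::Channel in the cache; on a miss it calls FindChannelOnce to create one from the address and then stores it in the cache.

GenericCachingChannelCache can use several channels to the same target to improve throughput. When a target has more than one channel, each FindWorkerChannel call picks the next one in round-robin order.

Note that, given the typedef below, absl::flat_hash_map<string, ChannelState> channels_ is effectively the cache of ::grpc::Channel instances.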

typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;

The code is:

template <typename ChannelCacheT>
class GenericCachingChannelCache : public ChannelCacheT {
public:
explicit GenericCachingChannelCache(int num_channels_per_target)
: num_channels_per_target_(
num_channels_per_target > 0 ? num_channels_per_target : 1) {}

~GenericCachingChannelCache() override {}

SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
{
mutex_lock l(mu_);
auto iter = channels_.find(target);
if (iter != channels_.end()) {
return GetNextChannelPtrAndUpdateState(iter->second);
}
}
ChannelState new_chan_state;
for (int indx = 0; indx < num_channels_per_target_; indx++) {
auto ch = FindChannelOnce(target);
if (!ch) return nullptr;
new_chan_state.channels.push_back(ch);
}
new_chan_state.last_used = num_channels_per_target_ - 1;

{
mutex_lock l(mu_);
typename absl::flat_hash_map<string, ChannelState>::iterator iter;
bool was_inserted;
std::tie(iter, was_inserted) = channels_.insert({target, new_chan_state});
return GetNextChannelPtrAndUpdateState(iter->second);
}
}

protected:
// Find the ClientChannel for "target". Only called when no channel was
// found in the channels_ cache for "target". A non nullptr result will be
// cached in channels_.
virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;

private:
struct ChannelState {
std::vector<SharedGrpcChannelPtr> channels;
int last_used;
};

// Should be called with mu_ held.
SharedGrpcChannelPtr GetNextChannelPtrAndUpdateState(
ChannelState& chan_state) {
// Following statement is marked as Crash OK as this is an invariant of
// code flow in this class.
CHECK_EQ(chan_state.channels.size(), num_channels_per_target_); // Crash OK
chan_state.last_used =
(chan_state.last_used + 1) % num_channels_per_target_;
return chan_state.channels[chan_state.last_used];
}

const int num_channels_per_target_;
// TODO(zhifengc): Eviction when the map becomes too big.
mutex mu_;
absl::flat_hash_map<string, ChannelState> channels_ TF_GUARDED_BY(mu_);
};
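
The lookup, miss, and round-robin steps can be reproduced in a small standalone class. This is a sketch under assumptions: ToyChannel stands in for grpc::Channel, a std::function replaces the virtual FindChannelOnce, and std::unordered_map replaces absl::flat_hash_map.

```cpp
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct ToyChannel {
  std::string address;
};
using ToyChannelPtr = std::shared_ptr<ToyChannel>;

class ToyCachingChannelCache {
 public:
  ToyCachingChannelCache(int channels_per_target,
                         std::function<ToyChannelPtr(const std::string&)> create)
      : channels_per_target_(channels_per_target > 0 ? channels_per_target : 1),
        create_(std::move(create)) {}

  ToyChannelPtr Find(const std::string& target) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = channels_.find(target);
      if (it != channels_.end()) return Next(it->second);
    }
    // Cache miss: create the channels without holding the lock.
    State fresh;
    for (int i = 0; i < channels_per_target_; ++i) {
      ToyChannelPtr ch = create_(target);
      if (!ch) return nullptr;
      fresh.channels.push_back(ch);
    }
    fresh.last_used = channels_per_target_ - 1;
    std::lock_guard<std::mutex> lock(mu_);
    // If another thread inserted first, insert() keeps the existing entry.
    auto it = channels_.insert({target, fresh}).first;
    return Next(it->second);
  }

 private:
  struct State {
    std::vector<ToyChannelPtr> channels;
    int last_used = 0;
  };
  // Must be called with mu_ held. Advances the round-robin cursor.
  ToyChannelPtr Next(State& state) {
    state.last_used = (state.last_used + 1) % channels_per_target_;
    return state.channels[state.last_used];
  }

  const int channels_per_target_;
  std::function<ToyChannelPtr(const std::string&)> create_;
  std::mutex mu_;
  std::unordered_map<std::string, State> channels_;
};
```

Initializing last_used to the last index makes the very first lookup return channel 0, which is the same trick the original uses.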

2.3 Concrete derived classes

Two classes derive from CachingGrpcChannelCache:

2.3.1 Leaf node

SparseGrpcChannelCache is the leaf node. Each Job in the cluster has one SparseGrpcChannelCache, whose grpc::Channel set is the set of channels for that Job's Tasks: one grpc::Channel per Task, unless num_channels_per_target is set higher.

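The main member variables of SparseGrpcChannelCache are:

  • const string job_id_: the Job this cache belongs to.
  • const std::map<int, string> host_ports_: the host:port of each Task in this Job, keyed by task index.
  • const ChannelCreationFunction channel_func_: the function that creates a grpc::Channel.

The main functions of SparseGrpcChannelCache are:

  • ListWorkers: returns the names of the Tasks in this Job.
  • TranslateTask: maps a Task name to its address in host:port form. For example, /job:ps/replica:0/task:1 might map to ps1:1111. Names with a replica other than 0 are rejected.
  • FindChannelOnce: creates the grpc::Channel for a Task name. It first calls TranslateTask to resolve the worker's task id to its address, then passes that address to channel_func_ to build the grpc::Channel.
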
class SparseGrpcChannelCache : public CachingGrpcChannelCache {
public:
SparseGrpcChannelCache(const string& job_id,
const std::map<int, string>& host_ports,
ChannelCreationFunction channel_func,
int num_channels_per_target)
: CachingGrpcChannelCache(num_channels_per_target),
job_id_(job_id),
host_ports_(host_ports),
channel_func_(std::move(channel_func)) {
}
~SparseGrpcChannelCache() override {}

void ListWorkers(std::vector<string>* workers) override {
workers->reserve(workers->size() + host_ports_.size());
for (const auto& id_host_port : host_ports_) {
workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
}
}

void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) override {
if (job_name == job_id_) {
ListWorkers(workers);
}
}

string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
return "";
}

if (!parsed.has_job || parsed.job != job_id_) {
return "";
}
if (!parsed.has_replica || parsed.replica != 0) {
return "";
}
int32_t task = parsed.has_task ? parsed.task : -1;
auto iter = host_ports_.find(task);
if (iter == host_ports_.end()) {
return "";
}
return iter->second;
}

protected:
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
const string host_port = TranslateTask(target);
if (host_port.empty()) {
return nullptr;
}
auto chan_ptr = channel_func_(host_port);
return chan_ptr;
}

private:
const string job_id_;
const std::map<int, string> host_ports_;
const ChannelCreationFunction channel_func_;
TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
};
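
The name-to-address translation can be sketched without DeviceNameUtils. The parser below is a deliberately simplified, hypothetical stand-in (ToyField, ToyTranslateTask) that only understands names of the form /job:X/replica:N/task:M; it follows the same rejection rules as TranslateTask above.

```cpp
#include <map>
#include <string>

// Reads the value of "/<key>:<value>" from a task name; empty if absent.
inline std::string ToyField(const std::string& name, const std::string& key) {
  const std::string marker = "/" + key + ":";
  const std::string::size_type start = name.find(marker);
  if (start == std::string::npos) return "";
  const std::string::size_type begin = start + marker.size();
  const std::string::size_type end = name.find('/', begin);
  return name.substr(begin, end == std::string::npos ? end : end - begin);
}

// Translates a task name into host:port for one job. Returns "" when the name
// belongs to another job, uses a replica other than 0, or names an unknown task.
inline std::string ToyTranslateTask(const std::string& job_id,
                                    const std::map<int, std::string>& host_ports,
                                    const std::string& target) {
  if (ToyField(target, "job") != job_id) return "";
  if (ToyField(target, "replica") != "0") return "";
  const std::string task = ToyField(target, "task");
  if (task.empty() || task.size() > 9) return "";
  if (task.find_first_not_of("0123456789") != std::string::npos) return "";
  const auto it = host_ports.find(std::stoi(task));
  return it == host_ports.end() ? "" : it->second;
}
```

Returning an empty string for anything the job does not own is what lets a parent cache probe several per-job caches in turn.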

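2.3.2 Non-leaf node

To speed up lookups across SparseGrpcChannelCache instances and to manage all Worker nodes in the cluster as a group, TF combines the cluster's SparseGrpcChannelCache instances into a MultiGrpcChannelCache. MultiGrpcChannelCache remembers, for each target it has resolved, which child cache handled it.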

// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
public:
explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
int num_channels_per_target)
: CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}

~MultiGrpcChannelCache() override {
for (GrpcChannelCache* cache : caches_) {
delete cache;
}
}

void ListWorkers(std::vector<string>* workers) override {
for (GrpcChannelCache* cache : caches_) {
cache->ListWorkers(workers);
}
}

void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) override {
for (GrpcChannelCache* cache : caches_) {
cache->ListWorkersInJob(job_name, workers);
}
}

string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
if (cache == nullptr) {
for (GrpcChannelCache* c : caches_) {
string r = c->TranslateTask(target);
if (!r.empty()) {
target_caches_.insert({target, c});
cache = c;
break;
}
}
}
return cache->TranslateTask(target);
}

protected:
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
for (GrpcChannelCache* cache : caches_) {
SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
if (ch) {
mutex_lock l(mu_);
target_caches_.insert({target, cache});
return ch;
}
}
return nullptr;
}

private:
// List of channels used by this MultiGrpcChannelCache.
const std::vector<GrpcChannelCache*> caches_;

mutex mu_;
// Cache of channels keyed by the target they are handling.
// The same GrpcChannelCache can appear multiple times in the cache.
std::unordered_map<string, GrpcChannelCache*> target_caches_
TF_GUARDED_BY(mu_);
};
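
The "ask each child, remember who answered" idea is sketched below. It is a simplification under stated assumptions: each child cache is reduced to its translate function (ToyTranslate), locking is omitted, and an unknown target yields an empty string. ToyMultiCache is a hypothetical name.

```cpp
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// A child cache is reduced here to its translate function (hypothetical).
using ToyTranslate = std::function<std::string(const std::string&)>;

class ToyMultiCache {
 public:
  explicit ToyMultiCache(std::vector<ToyTranslate> children)
      : children_(std::move(children)) {}

  // Asks each child in turn and remembers which child owned the target.
  std::string TranslateTask(const std::string& target) {
    auto it = owner_.find(target);
    if (it != owner_.end()) return children_[it->second](target);
    for (std::size_t i = 0; i < children_.size(); ++i) {
      std::string result = children_[i](target);
      if (!result.empty()) {
        owner_[target] = i;
        return result;
      }
    }
    return "";  // no child knows this target
  }

  // Number of targets whose owning child has been remembered.
  std::size_t remembered() const { return owner_.size(); }

 private:
  std::vector<ToyTranslate> children_;
  std::map<std::string, std::size_t> owner_;
};
```

After the first lookup of a target, later lookups go straight to the owning child instead of probing every job again.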

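The structure so far:

Figure 2: Relationships among the caches

2.4 Creating the GrpcChannelCache

When the GrpcChannelCache was created earlier, GetChannelCreationFunction was passed in without explanation. We go through it now.

  // Obtain the GrpcChannelCache.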
std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
channel_spec, GetChannelCreationFunction(), *options.rpc_options));

2.4.1 Goal and usage

Start with the goal, that is, how the function is used: given a target (resolved to a host:port string), produce a SharedGrpcChannelPtr, which is a std::shared_ptr to a grpc::Channel.

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
const string host_port = TranslateTask(target);
if (host_port.empty()) {
return nullptr;
}
auto chan_ptr = channel_func_(host_port);
VLOG(5) << "Channel created for: job: " << job_id_
<< " host_port: " << host_port << " target : " << target
<< " Ptr: " << chan_ptr.get();
return chan_ptr;
}

2.4.2 NewHostPortGrpcChannel

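We start with NewHostPortGrpcChannel, an existing TF API. It calls ::grpc::CreateCustomChannel (a gRPC API) to create a grpc::Channel and stores it through the out-parameter SharedGrpcChannelPtr* channel_pointer; the function itself returns a Status. The channel it produces is what we want, but its signature does not match what the caches expect, so it needs to be wrapped.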

Status NewHostPortGrpcChannel(const string& target,
const RPCOptions* rpc_options,
SharedGrpcChannelPtr* channel_pointer) {
// Minimally ensure that the target is valid
TF_RETURN_IF_ERROR(ValidateHostPortPair(target));

::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
*channel_pointer = ::grpc::CreateCustomChannel(
"dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
return Status::OK();
}

2.4.3 ConvertToChannelCreationFunction

ConvertToChannelCreationFunction adapts the new_channel_func_ptr it is given into a function that takes only a const string& target and returns a SharedGrpcChannelPtr.

ChannelCreationFunction ConvertToChannelCreationFunction(
const std::function<Status(string, const RPCOptions*,
SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
SharedGrpcChannelPtr channel_ptr;
if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
.ok()) {
return channel_ptr;
} else {
return nullptr;
}
};
}

2.4.4 GetChannelCreationFunction

GetChannelCreationFunction passes NewHostPortGrpcChannel to ConvertToChannelCreationFunction and returns the resulting ChannelCreationFunction, which is the form the WorkerCache factory can use.

ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}
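
The adapter pattern used here can be tried out without gRPC. In this minimal sketch ToyChannel, ToyNewChannel and ToyConvert are hypothetical stand-ins for grpc::Channel, NewHostPortGrpcChannel and ConvertToChannelCreationFunction, a bool replaces Status, and a const int* replaces the RPCOptions pointer.

```cpp
#include <functional>
#include <memory>
#include <string>

struct ToyChannel {
  std::string target;
};
using ToyChannelPtr = std::shared_ptr<ToyChannel>;

// Three-argument creator in the style of NewHostPortGrpcChannel: the result
// comes back through an out-parameter and the return value reports success.
inline bool ToyNewChannel(const std::string& target, const int* /*options*/,
                          ToyChannelPtr* out) {
  if (target.find(':') == std::string::npos) return false;  // minimal validation
  *out = std::make_shared<ToyChannel>(ToyChannel{"dns:///" + target});
  return true;
}

using ToyChannelCreationFunction =
    std::function<ToyChannelPtr(const std::string&)>;

// Adapts the three-argument creator to the one-argument shape the caches use.
// A failure becomes nullptr, which the callers already have to handle.
inline ToyChannelCreationFunction ToyConvert(
    std::function<bool(const std::string&, const int*, ToyChannelPtr*)> create) {
  return [create](const std::string& target) -> ToyChannelPtr {
    ToyChannelPtr channel;
    if (create(target, /*options=*/nullptr, &channel)) {
      return channel;
    }
    return nullptr;
  };
}
```

The adapter discards the error detail, which is why the comment in GetChannelCreationFunction stresses that the cache must tolerate a nullptr result.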

2.4.5 Putting it to use

Back to our call site. channel_func_ is the function returned by GetChannelCreationFunction, so calling it directly yields a grpc::Channel.

SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
const string host_port = TranslateTask(target);
auto chan_ptr = channel_func_(host_port);
}

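We can now extend the earlier picture with one more step: given a target, we get a grpc::Channel.

Figure 3: The conversion

3. Where the cache sits in the system

We have covered how the cache is initialized and used, but lost sight of where it sits in the system, so let us look at that now. The GrpcChannelCache inside GrpcWorkerCache points to the system's gRPC channel cache and is used to fetch cached gRPC channels. local_worker holds the local Worker.

Figure 4: Position of the cache

When GetOrCreateWorker is called on GrpcWorkerCache and the target is local, it returns local_worker directly (the local GrpcWorker set up earlier). Otherwise it creates a remote GrpcRemoteWorker on top of that Worker's RPC channel.

Figure 5: Creating a worker

WorkerCacheInterface (in practice, GrpcWorkerCache) appears throughout Master, Worker, MasterSession, and WorkerSession. Many classes hold a member variable that points to a WorkerCacheInterface, so it is very widely used.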

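4. Finding the device set

To create a WorkerSession, MasterSession needs to know the devices on every remote Worker, so before creating the MasterSession the Master visits all Workers and collects their device information. Because this relies on GrpcWorkerCache, we cover it here. The basic logic is:

  • Call GrpcWorkerCache::ListWorkers to get the names of all Workers in the cluster.
  • For each worker_name, call GetOrCreateWorker to look up the WorkerInterface object in worker_cache; it is returned if present and created otherwise.
  • Build a GetStatusRequest and send it to that Worker through GetStatusAsync.
  • After the Worker returns its GetStatusResponse, the callback cb extracts the Worker's device information and hands it to done (bound to WhenFound). Device names that do not already carry the worker's job, replica, and task are rewritten to include them.

Figure 6: Fetching devices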

4.1 DeviceFinder

4.1.1 Definition

DeviceFinder is a helper class that implements the algorithm for finding devices on remote workers. Its member variables are:

class DeviceFinder {
~DeviceFinder() {
for (Device* dev : found_) delete dev;
}

typedef DeviceFinder ME;
const MasterEnv* env_;
WorkerCacheInterface* worker_cache_;
std::vector<DeviceNameUtils::ParsedName> filters_;

mutex mu_;
int num_pending_ TF_GUARDED_BY(mu_);
condition_variable pending_zero_;
std::vector<Device*> found_ TF_GUARDED_BY(mu_);
// List of targets to be contacted by this DeviceFinder. The
// respective `bool` in `seen_targets_` indicates whether we have
// heard from this target or not.
std::vector<string> targets_;
std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
Status status_;

TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};

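4.1.2 Initialization

The main logic: obtain the names of all Workers in the cluster through GrpcWorkerCache::ListWorkers. When device filters are set, only workers that match a filter or share an address space with the local device are kept as targets.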

explicit DeviceFinder(
const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
WorkerCacheInterface* worker_cache)
: env_(env), worker_cache_(worker_cache) {
CHECK(worker_cache) << "Worker cache was null!";
auto process_filter = [this](const string& filter) {
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
filters_.push_back(parsed);
} else {
LOG(FATAL) << "Skipping invalid filter: " << filter;
}
};
for (const string& filter : device_filters) {
process_filter(filter);
}
// Enumerates all known workers' target. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
if (filters_.empty()) {
// If no filters were specified, we list all known workers in
// `worker_cache`.
std::vector<string> workers;
worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
// When applying filters, we must include the local worker, even if it
// does not match any of the filters.
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
const string& local_device_name = env_->local_devices[0]->name();
DeviceNameUtils::ParsedName local_parsed_name;
CHECK(DeviceNameUtils::ParseFullName(local_device_name,
&local_parsed_name));
bool all_filters_have_job = true;
std::unordered_set<string> filter_job_names({local_parsed_name.job});
for (const DeviceNameUtils::ParsedName& filter : filters_) {
all_filters_have_job = all_filters_have_job && filter.has_job;
if (filter.has_job) {
filter_job_names.insert(filter.job);
}
}

std::vector<string> workers;
if (all_filters_have_job) {
// If all of the device filters have a job specified, then we only need
// to list the workers in the jobs named in the filter, because a worker
// in any other job would not match any filter.
for (const string& job_name : filter_job_names) {
VLOG(2) << "Selectively listing workers in job: " << job_name;
std::vector<string> workers_in_job;
worker_cache->ListWorkersInJob(job_name, &workers_in_job);
workers.insert(workers.end(), workers_in_job.begin(),
workers_in_job.end());
}
} else {
// If any of the device filters does not have a job specified, then we
// must list the workers from all jobs.
VLOG(2) << "Listing workers in all jobs because some device "
<< "filter has no job specified. Filters were:";
if (device_filters.empty()) {
VLOG(2) << "- <NO FILTERS>";
} else {
for (const string& filter : device_filters) {
VLOG(2) << "- " << filter;
}
}
worker_cache->ListWorkers(&workers);
}
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
targets_.push_back(name);
}
}
}
seen_targets_.assign(targets_.size(), false);
}

4.1.3 GetRemoteDevices

GetRemoteDevices fetches the remote devices as follows:

  • finder.Start() sends a GetStatusRequest to every target Worker.
  • finder.Wait() waits until the GetStatusResponse from every Worker has arrived.
  • finder.GetRemoteDevices collects the results and returns them to the caller.

static Status GetRemoteDevices(
const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
WorkerCacheInterface* worker_cache,
std::vector<std::unique_ptr<Device>>* out_remote) {
DeviceFinder finder(device_filters, env, worker_cache);
finder.Start();
TF_RETURN_IF_ERROR(finder.Wait());
finder.GetRemoteDevices(env->local_devices, out_remote);
return Status::OK();
}

4.1.3.1 Start

Start initializes the counter num_pending_ to the number of target Workers, then iterates over them and calls NewRemoteDevices for each.

void Start() {
{
mutex_lock l(mu_);
num_pending_ = targets_.size();
if (num_pending_ == 0) {
pending_zero_.notify_all();
}
}
// Talk to all workers to get the list of available devices.
using std::placeholders::_1;
using std::placeholders::_2;
for (size_t i = 0; i < targets_.size(); ++i) {
// TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
// never be called.
NewRemoteDevices(env_->env, worker_cache_, targets_[i],
std::bind(&ME::WhenFound, this, i, _1, _2));
}
}

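NewRemoteDevices works as follows:

  • For the given worker_name, call GetOrCreateWorker to look up the WorkerInterface object in worker_cache; it is returned if present and created otherwise.
  • Build a GetStatusRequest and send it to that Worker through GetStatusAsync.
  • After the Worker returns its GetStatusResponse, the callback cb extracts the Worker's device information and hands it to done (bound to WhenFound). Device names that do not already carry the worker's job, replica, and task are rewritten to include them.
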
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
const string& worker_name, NewRemoteDevicesDone done) {
WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
if (wi == nullptr) {
std::vector<Device*> empty;
done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
return;
}
struct Call {
GetStatusRequest req; // request message
GetStatusResponse resp; // response message
};
Call* call = new Call;
// Callback.
auto cb = [env, worker_cache, worker_name, done, wi,
call](const Status& status) {
Status s = status;
std::vector<Device*> remote_devices;
auto cleanup = gtl::MakeCleanup(
[&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
worker_cache->ReleaseWorker(worker_name, wi);
done(s, &remote_devices);
delete call;
});
if (s.ok()) {
DeviceNameUtils::ParsedName worker_name_parsed;
if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
!worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
!worker_name_parsed.has_task) {
s = errors::InvalidArgument("Could not parse worker name: ",
worker_name);
return;
}
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
DeviceNameUtils::ParsedName device_name_parsed;
CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
<< "Device attribute name '" << da.name() << "' could not be "
<< "parsed. Device Attribute: " << da.DebugString();
// Preserve the exact name, if possible.
if (device_name_parsed.job == worker_name_parsed.job &&
device_name_parsed.replica == worker_name_parsed.replica &&
device_name_parsed.task == worker_name_parsed.task) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
} else {
DeviceAttributes da_rewritten = da;
da_rewritten.set_name(DeviceNameUtils::FullName(
worker_name_parsed.job, worker_name_parsed.replica,
worker_name_parsed.task, device_name_parsed.type,
device_name_parsed.id));
auto d = new RemoteDevice(env, da_rewritten);

// Experimental: Skipping over adding any TPU-type devices that aren't
// on the job called "worker" (but still adds the CPUs of other jobs).
if (getenv("TPU_NO_POPULATE_DEVICE_LIST_FROM_CLUSTER_SPEC") !=
nullptr) {
if (worker_name_parsed.job == "worker" ||
device_name_parsed.type.find("TPU") == std::string::npos) {
remote_devices.push_back(d);
}
} else {
remote_devices.push_back(d);
}
}
}
}
};
wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
/*fail_fast=*/false, cb);
}
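
The callback relies on gtl::MakeCleanup so that ReleaseWorker, done, and delete call run on every exit path, including the early return after a parse failure. The sketch below shows the same scope-guard idea with a hypothetical ToyCleanup class; it is not the gtl implementation, and ToyHandleResponse only records events for illustration.

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Runs a callable when it goes out of scope (a stand-in for gtl::MakeCleanup).
class ToyCleanup {
 public:
  explicit ToyCleanup(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ToyCleanup() {
    if (fn_) fn_();
  }
  ToyCleanup(const ToyCleanup&) = delete;
  ToyCleanup& operator=(const ToyCleanup&) = delete;

 private:
  std::function<void()> fn_;
};

// Returns the events in the order they happen. The "done" step runs on every
// exit path and sees the final value of `ok` because it captures by reference.
inline std::vector<std::string> ToyHandleResponse(bool name_parses) {
  std::vector<std::string> events;
  auto callback = [&events, name_parses] {
    bool ok = true;
    ToyCleanup cleanup([&events, &ok] {
      events.push_back(ok ? "done(ok)" : "done(error)");
    });
    if (!name_parses) {
      ok = false;
      return;  // the early exit still runs the cleanup
    }
    events.push_back("devices collected");
  };
  callback();
  return events;
}
```

Capturing the status by reference is what lets the cleanup report an error that was set after the guard was created.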
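
4.1.3.2 Wait

In Wait, as long as the counter is nonzero, the calling thread keeps waiting on pending_zero_.wait_for with a timeout of kLoggingPeriodMs (10 seconds). Each time the wait returns with replies still outstanding, it logs the workers that have not responded yet.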

Status Wait() {
mutex_lock l(mu_);
// TODO(mrry): Propagate a timeout here, since `num_pending_` may
// never become zero.
while (num_pending_ != 0) {
pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
if (num_pending_ != 0) {
for (size_t i = 0; i < targets_.size(); ++i) {
if (!seen_targets_[i]) {
LOG(INFO)
<< "CreateSession still waiting for response from worker: "
<< targets_[i];
}
}
}
}
return status_;
}
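
4.1.3.3 Callback

The callback that Start binds is shown below. NewRemoteDevices invokes it once the GetStatusResponse from a Worker has been processed. WhenFound decrements the counter; when the counter reaches 0 it calls pending_zero_.notify_all(), which wakes the pending_zero_.wait_for call in Wait. GetRemoteDevices then uses finder.GetRemoteDevices to collect the results and return them to the caller.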

void WhenFound(int target_index, const Status& s,
std::vector<Device*>* devices) {
mutex_lock l(mu_);
seen_targets_[target_index] = true;
if (!s.ok()) {
LOG(ERROR) << "CreateSession failed because worker "
<< targets_[target_index] << " returned error: " << s;
status_.Update(s);
} else {
found_.insert(found_.end(), devices->begin(), devices->end());
devices->clear();
}
--num_pending_;
if (num_pending_ == 0) {
pending_zero_.notify_all();
}
}
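
The interplay between Wait and WhenFound is a pending-counter pattern built on a mutex and a condition variable. The following is a minimal sketch using the C++ standard library rather than TensorFlow's mutex and condition_variable wrappers; PendingCounter and ExampleRemainingAfterWait are hypothetical names introduced only for illustration.

```cpp
#include <chrono>
#include <condition_variable>
#include <mutex>

// Hypothetical helper: counts outstanding requests and lets one caller
// block until all of them have reported back.
class PendingCounter {
 public:
  explicit PendingCounter(int num_pending) : num_pending_(num_pending) {}

  // Plays the role of WhenFound: record one finished request and wake
  // the waiter once nothing is outstanding.
  void Done() {
    std::lock_guard<std::mutex> lock(mu_);
    --num_pending_;
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }

  // Plays the role of Wait: block in bounded slices so the caller can
  // report progress. Returns how many slices ended with requests still
  // outstanding (the original logs the silent workers at that point).
  int Wait(std::chrono::milliseconds period) {
    std::unique_lock<std::mutex> lock(mu_);
    int unfinished_wakeups = 0;
    while (num_pending_ != 0) {
      pending_zero_.wait_for(lock, period);
      if (num_pending_ != 0) {
        ++unfinished_wakeups;
      }
    }
    return unfinished_wakeups;
  }

  int pending() {
    std::lock_guard<std::mutex> lock(mu_);
    return num_pending_;
  }

 private:
  std::mutex mu_;
  std::condition_variable pending_zero_;
  int num_pending_;
};

// Single-threaded walk-through: every request completes before Wait is
// called, so Wait returns without blocking.
int ExampleRemainingAfterWait(int num_requests) {
  PendingCounter counter(num_requests);
  for (int i = 0; i < num_requests; ++i) {
    counter.Done();
  }
  counter.Wait(std::chrono::milliseconds(10));
  return counter.pending();
}
```

In the code above, the equivalent of Done runs on the thread that delivers each worker's response, while Wait runs on the thread that called GetRemoteDevices. The walk-through keeps everything on one thread only so that it stays deterministic.
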

4.2 Worker interaction

NewRemoteDevices builds a GetStatusRequest and sends it to the target worker through GetStatusAsync.

WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
/*fail_fast=*/false, cb);

4.2.1 GrpcRemoteWorker

wi is the WorkerInterface that was looked up. In practice it is a GrpcRemoteWorker, the gRPC client, which calls the corresponding service method of the remote WorkerService through a stub.

void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
GetStatusResponse* response, bool fail_fast,
StatusCallback done) override {
IssueRequest(request, response, getstatus_, std::move(done), call_opts,
fail_fast);
}

4.2.2 GrpcWorkerService

On the remote worker, messages arrive at GrpcWorkerService. An incoming GetStatusRequest is handled by GetStatusHandler, which is generated by the HANDLE_CALL macro.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
auto closure = [this, call]() { \
Status s = worker_->method(&call->request, &call->response); \
if (!s.ok()) { \
VLOG(3) << "Bad response from " << #method << ": " << s; \
} \
call->SendResponse(ToGrpcStatus(s)); \
}; \
if ((may_block_on_compute_pool)) { \
worker_->env()->env->SchedClosure(std::move(closure)); \
} else { \
worker_->env()->compute_pool->Schedule(std::move(closure)); \
} \
ENQUEUE_REQUEST(method, false); \
}

HANDLE_CALL(GetStatus, false);
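
The way one macro yields one handler per RPC method relies on the preprocessor's token-pasting operator. The sketch below shows only that mechanism, with scheduling and gRPC plumbing left out; WorkerSketch, ServiceSketch and HANDLE_CALL_SKETCH are hypothetical names and not part of TensorFlow.

```cpp
#include <string>

// Hypothetical stand-in for the Worker object that does the real work.
struct WorkerSketch {
  std::string GetStatus(const std::string& request) {
    return "status:" + request;
  }
  std::string CleanupAll(const std::string& request) {
    return "cleanup:" + request;
  }
};

// Expands to a member function named <method>Handler that forwards the
// request to worker_.<method>. The ## operator pastes the tokens together.
#define HANDLE_CALL_SKETCH(method)                          \
  std::string method##Handler(const std::string& request) { \
    return worker_.method(request);                         \
  }

struct ServiceSketch {
  WorkerSketch worker_;
  HANDLE_CALL_SKETCH(GetStatus)   // defines GetStatusHandler
  HANDLE_CALL_SKETCH(CleanupAll)  // defines CleanupAllHandler
};

#undef HANDLE_CALL_SKETCH
```

With these definitions, ServiceSketch().GetStatusHandler("w0") forwards to WorkerSketch::GetStatus, in the same way that HANDLE_CALL(GetStatus, false) produces GetStatusHandler.
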

4.2.3 Worker

Finally we reach the Worker class. It simply hands the request over to DeviceMgr, and the result is returned to the remote caller in a GetStatusResponse message.

void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
GetStatusResponse* response, bool fail_fast,
StatusCallback done) {
const DeviceMgr* dm = env_->device_mgr;
std::vector<DeviceAttributes> devices;
dm->ListDeviceAttributes(&devices);
response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) {
response->add_device_attributes()->Swap(&d);
}
done(Status::OK());
}
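
The calling convention shared by the client side and the Worker (fill the response, then invoke the done callback with a status) can be sketched as follows. All type names here are hypothetical simplifications of WorkerInterface, GetStatusResponse and StatusCallback, and the callback fires synchronously, whereas a gRPC-backed implementation would invoke it later when the reply arrives.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for Status and GetStatusResponse.
struct StatusSketch {
  bool ok;
  std::string message;
};
struct GetStatusResponseSketch {
  std::vector<std::string> device_names;
};
using StatusCallbackSketch = std::function<void(const StatusSketch&)>;

// The caller only sees this interface, never the concrete worker type.
class WorkerInterfaceSketch {
 public:
  virtual ~WorkerInterfaceSketch() {}
  virtual void GetStatusAsync(GetStatusResponseSketch* response,
                              StatusCallbackSketch done) = 0;
};

// In-process worker: copies its device list into the response and
// reports success through the callback.
class LocalWorkerSketch : public WorkerInterfaceSketch {
 public:
  explicit LocalWorkerSketch(const std::vector<std::string>& devices)
      : devices_(devices) {}

  void GetStatusAsync(GetStatusResponseSketch* response,
                      StatusCallbackSketch done) override {
    response->device_names = devices_;
    done(StatusSketch{true, ""});
  }

 private:
  std::vector<std::string> devices_;
};

// Issues one request; returns the number of reported devices, or -1 if
// the callback reported an error.
int QueryDeviceCount(WorkerInterfaceSketch* wi) {
  GetStatusResponseSketch response;
  int result = -1;
  wi->GetStatusAsync(&response, [&](const StatusSketch& s) {
    if (s.ok) {
      result = static_cast<int>(response.device_names.size());
    }
  });
  return result;
}

int ExampleDeviceCount() {
  std::vector<std::string> devices = {"CPU:0", "GPU:0"};
  LocalWorkerSketch worker(devices);
  return QueryDeviceCount(&worker);
}
```
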

4.2.4 DeviceMgr

ListDeviceAttributes, which collects the local device information, has two implementations. Implementation 1 is as follows:

void StaticDeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size());
for (const auto& dev : devices_) {
devices->emplace_back(dev->attributes());
}
}

Implementation 2 is as follows:

void DynamicDeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
tf_shared_lock l(devices_mu_);
devices->reserve(dynamic_devices_.size());
for (const auto& d : dynamic_devices_) {
devices->emplace_back(d->attributes());
}
}
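
The difference between the two implementations is that the static manager's device list never changes after construction, while the dynamic manager's list can change and is therefore read under a lock. A minimal sketch of that contrast follows. The class names are hypothetical, and a plain std::mutex stands in for the reader lock (tf_shared_lock) used in the original.

```cpp
#include <cstddef>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for DeviceAttributes.
struct DeviceAttrsSketch {
  std::string name;
};

class DeviceMgrSketch {
 public:
  virtual ~DeviceMgrSketch() {}
  virtual void ListDeviceAttributes(
      std::vector<DeviceAttrsSketch>* out) const = 0;
};

// Device list is fixed at construction, so reads need no lock.
class StaticMgrSketch : public DeviceMgrSketch {
 public:
  explicit StaticMgrSketch(std::vector<DeviceAttrsSketch> devices)
      : devices_(std::move(devices)) {}

  void ListDeviceAttributes(
      std::vector<DeviceAttrsSketch>* out) const override {
    out->reserve(devices_.size());
    for (const auto& d : devices_) out->push_back(d);
  }

 private:
  const std::vector<DeviceAttrsSketch> devices_;
};

// Devices can be added at runtime, so every access takes the mutex.
class DynamicMgrSketch : public DeviceMgrSketch {
 public:
  void AddDevice(DeviceAttrsSketch d) {
    std::lock_guard<std::mutex> lock(mu_);
    devices_.push_back(std::move(d));
  }

  void ListDeviceAttributes(
      std::vector<DeviceAttrsSketch>* out) const override {
    std::lock_guard<std::mutex> lock(mu_);
    out->reserve(devices_.size());
    for (const auto& d : devices_) out->push_back(d);
  }

 private:
  mutable std::mutex mu_;
  std::vector<DeviceAttrsSketch> devices_;
};

// Callers work through the base interface and do not care which
// manager they were given.
std::size_t CountDevices(const DeviceMgrSketch& mgr) {
  std::vector<DeviceAttrsSketch> attrs;
  mgr.ListDeviceAttributes(&attrs);
  return attrs.size();
}

std::size_t ExampleStaticCount() {
  std::vector<DeviceAttrsSketch> devices;
  devices.push_back(DeviceAttrsSketch{"CPU:0"});
  devices.push_back(DeviceAttrsSketch{"GPU:0"});
  return CountDevices(StaticMgrSketch(std::move(devices)));
}

std::size_t ExampleDynamicCount() {
  DynamicMgrSketch mgr;
  mgr.AddDevice(DeviceAttrsSketch{"CPU:0"});
  return CountDevices(mgr);
}
```
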

This completes our analysis of the cache and of device discovery. Next, we will look at how the actual work is processed.
