Caffe::Snapshot的运行过程
Snapshot的存储
概述
Snapshot的存储格式有两种,分别是BINARYPROTO格式和hdf5格式。BINARYPROTO是一种二进制文件,并且可以通过修改shapshot_format来设置存储类型。该项的默认是BINARYPROTO。不管哪种格式,运行的过程是类似的,都是从Solver<Dtype>::Snapshot()函数进入,首先调用Net网络的方法,再操作网络中的每一层,最后再操作每一层中blob,最后调用write函数写入输出。源码入口:
 void Solver<Dtype>::Snapshot() {
   CHECK(Caffe::root_solver());
   string model_filename;
   switch (param_.snapshot_format()) {
   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
     model_filename = SnapshotToBinaryProto();
     break;
   case caffe::SolverParameter_SnapshotFormat_HDF5:
     model_filename = SnapshotToHDF5();
     break;
   default:
     LOG(FATAL) << "Unsupported snapshot format.";
   }
BINARYPROTO格式
如果是BINARYPROTO的存储格式,就执行如下代码:
 string Solver<Dtype>::SnapshotToBinaryProto() {
   string model_filename = SnapshotFilename(".caffemodel");
   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
   NetParameter net_param;
   net_->ToProto(&net_param, param_.snapshot_diff());
   WriteProtoToBinaryFile(net_param, model_filename);
   return model_filename;
 }   
首先会执行SnapshotFilename(“.caffemodel”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToProto(),具体的代码如下:
 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
   param->Clear();
   param->set_name(name_);
   for (int i = ; i < net_input_blob_indices_.size(); ++i) {
     param->add_input(blob_names_[net_input_blob_indices_[i]]);
   }
   for (int i = ; i < layers_.size(); ++i) {
     LayerParameter* layer_param = param->add_layer();
     layers_[i]->ToProto(layer_param, write_diff);
   }
 }  
获取到网络中的每层的名字等参数后,调用layers_[i]->ToProto()每一层的ToProto方法,接下来
 void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
   param->Clear();
   param->CopyFrom(layer_param_);
   param->clear_blobs();
   for (int i = ; i < blobs_.size(); ++i) {
     blobs_[i]->ToProto(param->add_blobs(), write_diff);
   }
 } 
然后调用当前层下的所有blob的ToProto方法,即:
 void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
   proto->clear_shape();
   for (int i = ; i < shape_.size(); ++i) {
     proto->mutable_shape()->add_dim(shape_[i]);
   }
   proto->clear_double_data();
   proto->clear_double_diff();
   const double* data_vec = cpu_data();
   for (int i = ; i < count_; ++i) {
     proto->add_double_data(data_vec[i]);
   }
   if (write_diff) {
     const double* diff_vec = cpu_diff();
     for (int i = ; i < count_; ++i) {
       proto->add_double_diff(diff_vec[i]);
     }
   }
在每一个blob中,会调用add_double_data()函数,把data添加到snapshot文件中,同时会判断是否当前blob参与diff的计算,如果需要当前blob需要diff参数,就调用add_double_diff()添加到snapshot文件中。
调用完所有的blob的ToProto()方法后,会执行WriteProtoToBinaryFile()把该文件写出即可。
 void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
   fstream output(filename, ios::out | ios::trunc | ios::binary);
   CHECK(proto.SerializeToOstream(&output));
 }
在该方法里调用FStream的output方法进行输出。
Hdf5格式
Hdf5格式的运行过程和BINARYPROTO格式的过程类似,首先会调用SnapshotToHDF5()函数,即:
 string Solver<Dtype>::SnapshotToHDF5() {
   string model_filename = SnapshotFilename(".caffemodel.h5");
   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
   net_->ToHDF5(model_filename, param_.snapshot_diff());
   return model_filename;
 }
首先会执行SnapshotFilename(“.caffemodel.h5”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToHDF5(),即:
 void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
   hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
       H5P_DEFAULT);
   hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
       H5P_DEFAULT);
     hid_t diff_hid = -;
   if (write_diff) {
     diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
         H5P_DEFAULT);
    }
   for (int layer_id = ; layer_id < layers_.size(); ++layer_id) {
     const LayerParameter& layer_param = layers_[layer_id]->layer_param();
     string layer_name = layer_param.name();
     hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
     hid_t layer_diff_hid = -;
     if (write_diff) {
       layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
           H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
  }
     int num_params = layers_[layer_id]->blobs().size();
     for (int param_id = ; param_id < num_params; ++param_id) {
       ostringstream dataset_name;
       dataset_name << param_id;
       const int net_param_id = param_id_vecs_[layer_id][param_id];
       if (param_owners_[net_param_id] == -) {
         hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
             *params_[net_param_id]);
       }
       if (write_diff) {
         hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
             *params_[net_param_id], true);
       }
 ...............
 H5Fclose(file_hid);
 }
该函数首先调用H5Fcreate()创建一个file文件,然后循环调用每一层,通过调用每一层的H5Gcreate2函数记录出该层的data_hid或者diff_hid(如果该层需要参与计算),然后进入每一层内部的blob,然后在当前blob内调用hdf5_save_nd_dataset()或hdf5_save_nd_dataset()(如果当前blob需要参与计算diff),将data添加到hdf5格式的文件中,最后调用H5Fclose(file_hid)函数,输出该文件。
Snapshot的恢复
概述
想在已经训练好的网络上继续训练,那么需要调用Restore()方法从snapshot的文件中恢复成网络,从而缩短了训练时间。方法的入口是Solver<Dtype>::Restore(const char* state_file)函数,即:
 void Solver<Dtype>::Restore(const char* state_file) {
   CHECK(Caffe::root_solver());
   string state_filename(state_file);
   if (state_filename.size() >=  &&
       state_filename.compare(state_filename.size() - , , ".h5") == ) {
     RestoreSolverStateFromHDF5(state_filename);
   } else {
     RestoreSolverStateFromBinaryProto(state_filename);
   }
该函数会解析snapshot文件是BINARYPROTO格式还是Hdf5格式,如果是BINARYPROTO格式的话就调用RestoreSolverStateFromBinaryProto()函数,如果格式Hdf5的格式,就执行RestoreSolverStateFromHDF5()。
BINARYPROOTO格式
如果是BINARYPROTO格式,则执行下列代码:
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}
该函数会大量调用google的protobuf包内的函数,首先会通过ReadProtoFromBinaryFile()函数读取BINARYPROTO格式的文件来返回是否可以成功读取。然后判断该snapshot是否有曾经训练过的网络,如果有,则调用函数ReadNetParamsFromBinaryFileOrDie()读取出该Net网络,然后调用函数CopyTrainedLayersFrom(net_param)具体恢复该网络的每一层以及当前层内的所有blob,具体数据恢复的工作就是CopyTrainedLayersFrom()函数内部变量调用FromProto()函数来实现blob复制的。然后会通过函数current_step()来判断上次训练的位置(迭代到多少次),然后通过循环把训练过的data数据通过FromProto()完成数据的复制。
Hdf5格式
 void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
   hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
   CHECK_GE(file_hid, ) << "Couldn't open solver state file " << state_file;
   this->iter_ = hdf5_load_int(file_hid, "iter");
   if (H5LTfind_dataset(file_hid, "learned_net")) {
     string learned_net = hdf5_load_string(file_hid, "learned_net");
     this->net_->CopyTrainedLayersFrom(learned_net);
   }
   this->current_step_ = hdf5_load_int(file_hid, "current_step");
   hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
   CHECK_GE(history_hid, ) << "Error reading history from " << state_file;
   int state_history_size = hdf5_get_num_links(history_hid);
   CHECK_EQ(state_history_size, history_.size())
       << "Incorrect length of history blobs.";
   for (int i = ; i < history_.size(); ++i) {
     ostringstream oss;
     oss << i;
     hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), ,
                                 kMaxBlobAxes, history_[i].get());
   }
   H5Gclose(history_hid);
   H5Fclose(file_hid);
 }
该函数会识别hdf5格式存储的snapshot文件的file_hid编号,会判断是否存在之前训练过的网络,如果存在则执行CopyTrainedLayersFrom()函数,完成网络的每层以及每层内的blob的数据的恢复复制,然后或取上一次的训练位置(进行的迭代),并且调用函数hdf5_load_nd_dataset()具体把每次迭代的数据恢复复制,最后再调用H5Fclose()关闭。
Caffe::Snapshot的运行过程的更多相关文章
- 江太公:javascript count(a)(b)(c)(d)运行过程思考
		昨天,我弟抛给我一个js的题,使用类似标题那样的调用方法计算a*b*c*d以致无穷的实现方法.思考了半天,终于理清了它的运行过程,记录于下: 函数体: <!DOCTYPE html> &l ... 
- JAVA - JAVA编译运行过程
		Java编译原理 *.java→*.class→机器码 java编译器 (编译) → 虚拟机(解释执行) → 解释器(翻译) → 机器码 1.Java编译过程与c/c++编译过程不同 Java编译程 ... 
- 孙鑫MFC学习笔记3:MFC程序运行过程
		1.MFC中WinMain函数的位置在APPMODUL.cpp APPMODUL.cpp中是_tWinMain,其实_tWinMain是一个宏#define _tWinMain WinMain 2.全 ... 
- HOWTO - Basic MSI安装包在安装运行过程中如何获取完整源路径
		有朋友问到如何在一个Windows Installer安装包中获取安装包源路径,就是在安装包运行过程中动态获取*.msi所在完整路径. 这个问题分两类,如果我们的安装包只是一个*.msi安装文件,那么 ... 
- DirectShow程序运行过程简析
		这段时间一直在学习陆其明老师的<DirectShow开发指南>一书,书中对DirectShow的很多细节讲解清晰,但是却容易让人缺少对全局的把握.在学习过程中,整理了关于DirectSho ... 
- Java Executor并发框架(二)剖析ThreadPoolExecutor运行过程
		上一篇从整体上介绍了Executor接口,从上一篇我们知道了Executor框架的最顶层实现是ThreadPoolExecutor类,Executors工厂类中提供的newScheduledThrea ... 
- 基础知识《零》---Java程序运行机制及运行过程
		Java运行机制 Java虚拟机(Java Virtual Machine):Java虚拟机可以理解成一个以字节码为机器指令的CPU:对于不同的运行平台,有不同的虚拟机:Java虚拟机机制屏蔽了底层运 ... 
- .net学习之.net和C#关系、运行过程、数据类型、类型转换、值类型和引用类型、数组以及方法参数等
		1..net 和 C# 的关系.net 是一个平台,C#是种语言,C#语言可以通过.net平台来编写.部署.运行.net应用程序,C#通过.net平台开发.net应用程序2..net平台的重要组成FC ... 
- ionic 运行过程中动态切换API服务器地址
		ionic 运行过程中动态切换API服务器地址 keywords: ionic,phonegap,cordova,网络制式,动态切换,变更,API,服务器地址,$resource,localstora ... 
随机推荐
- luogu P1397 [NOI2013]矩阵游戏
			传送门 题目中那两个递推式显然可以写成矩乘的形式,然后十进制快速幂即可.这里不再赘述 只有两个递推式,我们可以考虑一波推式子,首先第一行的元素应该分别是\(1,a+b,a^2+ab+b,a^3+a^2 ... 
- router4.0
			https://blog.csdn.net/sinat_17775997/article/details/69218382 React Router 4.0 实现路由守卫 https://www. ... 
- django基础篇04-自定义simple_tag和fitler
			自定义simple_tag app目录下创建templatetags目录 templatetags目录下创建xxpp.py 创建template对象register,注意变量名必须为register ... 
- minicom - 友好易用的串口通信程序
			总览 SYNOPSIS minicom [-somMlwz8] [-c on|off] [-S script] [-d entry] [-a on|off] [-t term] [-p pty] [- ... 
- fhq_treap || BZOJ 3223: Tyvj 1729 文艺平衡树 || Luogu P3391 【模板】文艺平衡树(Splay)
			题面: [模板]文艺平衡树(Splay) 题解:无 代码: #include<cstdio> #include<cstring> #include<iostream> ... 
- Eclipse设置类和方法的注释模板
			一.打开设置模板的窗口:Window->Preference->Java->Code Style->Code Template展开Comments,最常用的就是类和方法的注释, ... 
- php命令行工具
			https://jingyan.baidu.com/article/37bce2beb6e5681002f3a20f.html 
- CG-CTF | 密码重置2
			跟则提示走,美滋滋: 1.找到邮箱: 2.下载备份: 3.PHP弱类型,string与int用的是“==” ........这一行是省略的代码........ if(!empty($token)&am ... 
- loadrunner性能测试巧匠训练营-controller
			1.设置集合点 现在脚本添加集合点的函数,集合点不能添加到事务里面,负责统计事务的时候会把时间计算进去 2.IP欺骗 前言 https://www.cnblogs.com/danbing/p/7459 ... 
- [CSP-S模拟测试]:玩具(概率DP)
			题目描述 这个故事发生在很久以前,在$IcePrincess\text{_}1968$和$IcePrince\text{_}1968$都还在上幼儿园的时候. $IcePrince\text{_}196 ... 
