下面来看Solver<Dtype>::Solve(const char* resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // Initialize to false every time we start solving.
requested_early_exit_ = false; if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
// 从以前中断的训练状态中恢复训练
Restore(resume_file);
} // For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
// 主要的迭代过程都在这里
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != )) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == ) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss); UpdateSmoothedLoss(loss, start_iter, average_loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == ) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}

下面先看Solve中的Restore(resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= &&
state_filename.compare(state_filename.size() - , , ".h5") == ) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}

上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,调用的其实是其派生类的同名方法。例如,若使用SGD求解,SGDSolver类中的RestoreSolverStateFromBinaryProto方法如下

sgd_solver.cpp

template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
// 此处获取上次训练中断时的迭代次数
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}

下面主要分析Solve中的Step(param_.max_iter() - iter_)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = ;
iteration_timer_.Start(); while (iter_ < stop_iter) {
// zero-init the params
// 将网络中参数的梯度清零
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() ==
&& (iter_ > || param_.test_initialization())) {
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
} for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == ;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = ;
// param.iter_size_默认是1,正常情况下,此处其实只进行了以次前向和反向传播
for (int i = ; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : );
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = ;
for (int j = ; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = ; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// 网络的参数在此处更新。该函数是一个虚函数,具体由Solver的派生类来实现
ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
// 每次迭代其实是一个batch_size个样本输入网络中,将它们产生的网络参数的梯度加起来作为一次迭代的参数梯度。然后用这个梯度跟据一定的正则化方法、参数更新策略来更新参数
++iter_; SolverAction::Enum request = GetRequestedAction(); // Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() ==
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码的功能是取一个batch_size数据,让其在网络中进行一次前向传播,得出损失的均值;再进行一次反向传播,得出网络参数的梯度(该梯度是一个batch_size数据产生的梯度的均值)。详细分析见下一章节


caffe Solve函数的更多相关文章

  1. MATLAB利用solve函数解多元一次方程组

    matlab求解多元方程组示例: syms k1 k2 k3; [k1 k2 k3] = solve(-3-k3==6, 2-k1-k2+2*k3==11, 2*k1+k2-k3+1==6)或者用[k ...

  2. pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)

    本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...

  3. 非线性方程(组):MATLAB内置函数 solve, vpasolve, fsolve, fzero, roots [MATLAB]

    MATLAB函数 solve, vpasolve, fsolve, fzero, roots 功能和信息概览 求解函数 多项式型 非多项式型 一维 高维 符号 数值 算法 solve 支持,得到全部符 ...

  4. Matlab的solve()函数的使用方法

    Matlab的solve()函数的使用方法 1.首先是对方程的求解 不废话直接上例子 syms x: eq=x^2+2*x+1; s=solve(eq,x); 结果如下 完美的算出了方程的解 现在对上 ...

  5. 从零开始山寨Caffe·捌:IO系统(二)

    生产者 双缓冲组与信号量机制 在第陆章中提到了,如何模拟,以及取代根本不存的Q.full()函数. 其本质是:除了为生产者提供一个成品缓冲队列,还提供一个零件缓冲队列. 当我们从外部给定了固定容量的零 ...

  6. Caffe 源碼閱讀(五) Solver.cpp

    1.Solver类两个构造函数 Solver(const SolverParameter& param) Solver(const string& param_file) 初始化两个类 ...

  7. caffe简单介绍

    从四个层次来理解caffe:Blob.Layer.Net.Solver. 1.BlobBlob是caffe基本的数据结构,用四维矩阵 Batch×Channel×Height×Weight表示,存储了 ...

  8. caffe源码阅读(1)_整体框架和简介(摘录)

    原文链接:https://www.zhihu.com/question/27982282 1.Caffe代码层次.回答里面有人说熟悉Blob,Layer,Net,Solver这样的几大类,我比较赞同. ...

  9. caffe源码学习

    本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...

随机推荐

  1. RMAN 备份与恢复 实例

    1. 检查数据库模式:    sqlplus /nolog     conn /as sysdba    archive log list (查看数据库是否处于归档模式中) 若为非归档,则修改数据库归 ...

  2. B1060 [ZJOI2007]时态同步 dfs

    两遍dfs,第一遍有点像找重链,第二遍维护答案,每个点维护一个当前深度,然后就没啥了. ps:memset(lst,-1,sizeof(lst));这一句多余的话让我debug半天... 题干: De ...

  3. bzoj2763: [JLOI2011]飞行路线(分层图spfa)

    2763: [JLOI2011]飞行路线 Time Limit: 10 Sec  Memory Limit: 128 MBSubmit: 3234  Solved: 1235[Submit][Stat ...

  4. python的搜索路径与包(package)

    python的搜索路径其实是一个列表,它是指导入模块时,python会自动去找搜索这个列表当中的路径,如果路径中存在要导入的模块文件则导入成功,否则导入失败: >>> import ...

  5. Spring Boot (13) druid监控

    druid druid是和tomcat jdbc一样优秀的连接池,出自阿里巴巴.除了连接池,druid哈hi有一个很实用的监控功能. pom.xml 添加了以下依赖后,会自动用druid连接池替代默认 ...

  6. android开源新闻小程序、3D翻转公告效果、小说检索、Kotlin开发TODO清单等源码

    Android精选源码 开源新闻小程序源码分享 android动态壁纸.锁屏动画.来电秀等源码 android笔记App效果源码 Android实现3D版翻页公告效果 android小说搜索阅读源码 ...

  7. H265

    H265 h265  一.名词 CTU: 编码树单元 CU: 编码单元 PU: 以CU为根,对CU进行划分,一个预测单元PU包含一个亮度预测块PB和两个色度预测块PB. TU: 以CU为根,变换单元T ...

  8. response.getWriter().write()乱码问题

    前台代码: <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"> <html> & ...

  9. DOM高级编程

    前言:W3C规定的三类DOM标准接口(换图)Core DOM(核心DOM),适用于各种结构化文档:XML DOM(Java OOP学过),专用于XML文档:HTML DOM,专用于HTML文档,下面了 ...

  10. SSDB 一个高性能的支持丰富数据结构的 NoSQL 数据库, 用于替代 Redis.

    SSDB 一个高性能的支持丰富数据结构的 NoSQL 数据库, 用于替代 Redis. 特性 替代 Redis 数据库, Redis 的 100 倍容量 LevelDB 网络支持, 使用 C/C++ ...