下面来看Solver<Dtype>::Solve(const char* resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // Initialize to false every time we start solving.
requested_early_exit_ = false; if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
// 从以前中断的训练状态中恢复训练
Restore(resume_file);
} // For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
// 主要的迭代过程都在这里
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != )) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == ) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss); UpdateSmoothedLoss(loss, start_iter, average_loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == ) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}

下面先看Solve中的Restore(resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= &&
state_filename.compare(state_filename.size() - , , ".h5") == ) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}

上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,调用的其实是其派生类的同名方法。例如,若使用SGD求解,SGDSolver类中的RestoreSolverStateFromBinaryProto方法如下

sgd_solver.cpp

template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
// 此处获取上次训练中断时的迭代次数
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}

下面主要分析Solve中的Step(param_.max_iter() - iter_)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = ;
iteration_timer_.Start(); while (iter_ < stop_iter) {
// zero-init the params
// 将网络中参数的梯度清零
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() ==
&& (iter_ > || param_.test_initialization())) {
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
} for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == ;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = ;
// param.iter_size_默认是1,正常情况下,此处其实只进行了以次前向和反向传播
for (int i = ; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : );
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = ;
for (int j = ; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = ; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// 网络的参数在此处更新。该函数是一个虚函数,具体由Solver的派生类来实现
ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
// 每次迭代其实是一个batch_size个样本输入网络中,将它们产生的网络参数的梯度加起来作为一次迭代的参数梯度。然后用这个梯度跟据一定的正则化方法、参数更新策略来更新参数
++iter_; SolverAction::Enum request = GetRequestedAction(); // Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() ==
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码的功能是取一个batch_size数据,让其在网络中进行一次前向传播,得出损失的均值;再进行一次反向传播,得出网络参数的梯度(该梯度是一个batch_size数据产生的梯度的均值)。详细分析见下一章节


caffe Solve函数的更多相关文章

  1. MATLAB利用solve函数解多元一次方程组

    matlab求解多元方程组示例: syms k1 k2 k3; [k1 k2 k3] = solve(-3-k3==6, 2-k1-k2+2*k3==11, 2*k1+k2-k3+1==6)或者用[k ...

  2. pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)

    本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...

  3. 非线性方程(组):MATLAB内置函数 solve, vpasolve, fsolve, fzero, roots [MATLAB]

    MATLAB函数 solve, vpasolve, fsolve, fzero, roots 功能和信息概览 求解函数 多项式型 非多项式型 一维 高维 符号 数值 算法 solve 支持,得到全部符 ...

  4. Matlab的solve()函数的使用方法

    Matlab的solve()函数的使用方法 1.首先是对方程的求解 不废话直接上例子 syms x: eq=x^2+2*x+1; s=solve(eq,x); 结果如下 完美的算出了方程的解 现在对上 ...

  5. 从零开始山寨Caffe·捌:IO系统(二)

    生产者 双缓冲组与信号量机制 在第陆章中提到了,如何模拟,以及取代根本不存的Q.full()函数. 其本质是:除了为生产者提供一个成品缓冲队列,还提供一个零件缓冲队列. 当我们从外部给定了固定容量的零 ...

  6. Caffe 源碼閱讀(五) Solver.cpp

    1.Solver类两个构造函数 Solver(const SolverParameter& param) Solver(const string& param_file) 初始化两个类 ...

  7. caffe简单介绍

    从四个层次来理解caffe:Blob.Layer.Net.Solver. 1.BlobBlob是caffe基本的数据结构,用四维矩阵 Batch×Channel×Height×Weight表示,存储了 ...

  8. caffe源码阅读(1)_整体框架和简介(摘录)

    原文链接:https://www.zhihu.com/question/27982282 1.Caffe代码层次.回答里面有人说熟悉Blob,Layer,Net,Solver这样的几大类,我比较赞同. ...

  9. caffe源码学习

    本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...

随机推荐

  1. BZOJ1067 [SCOI2007]降雨量 RMQ???

    求救!!!神犇帮我瞅瞅呗...未完...调了2个半小时线段树,没调出来,大家帮帮我啊!!! 小詹用st表写. 我的思路就是把中间空着的年份设为无限,然后一点点特判就行了...然而没出来... [SCO ...

  2. Python入门 六、像个 Pythonista

    pickle import pickle test_data = ['Save me!',123.456,True] f = file('test.data','w') pickle.dump(tes ...

  3. Find Minimum in Rotated Sorted Array 典型二分查找

    https://oj.leetcode.com/problems/find-minimum-in-rotated-sorted-array/ Suppose a sorted array is rot ...

  4. PCB 无需解压,直接读取Zip压缩包指定文件 实现方法

    最近有一项需求,将电测试点数后台批量写入到工程系统流程指示中,而电测试文件存在压缩包中,压缩包存在公共网络盘 示例图: 一.采用原始方法(4步完成): 第1步:.网络盘ZIP拷到本地, 第2步:解压Z ...

  5. [Apple开发者帐户帮助]三、创建证书(1)证书概述

    在开发应用程序的过程中,您将创建不同的证书类型,以便在不同的上下文中使用.您将为iOS,tvOS和watchOS应用程序使用相同的证书集,并为macOS应用程序使用不同的证书集.您将使用开发证书在设备 ...

  6. GStreamer系列 - 基本介绍

    什么是Gstreamer? Gstreamer是一个支持Windows,Linux,Android, iOS的跨平台的多媒体框架,应用程序可以通过管道(Pipeline)的方式,将多媒体处理的各个步骤 ...

  7. POJ 1985 求树的直径 两边搜OR DP

    Cow Marathon Description After hearing about the epidemic of obesity in the USA, Farmer John wants h ...

  8. ES6 arrow function

    语法: () => { … } // 零个参数用 () 表示: x => { … } // 一个参数可以省略 (): (x, y) => { … } // 多参数不能省略 (): 当 ...

  9. 【HTTP】如何正常关闭连接

    参考:<HTTP权威指南> 所有HTTP客户端.服务器或者代理都可以任意时刻关闭一条TCP传输连接.但是服务器永远无法确定它关闭“空闲”连接的那一刻,在线路那一头的客户端有没有数据要发送. ...

  10. 全局设置border-box

    全局设置 border-box 很好,更符合我们通常对一个「盒子」尺寸的认知.,其次它可以省去一次又一次的加加减减,它还有一个关键作用——让有边框的盒子正常使用百分比宽度.但是使用了 border-b ...