caffe Solve函数
下面来看Solver<Dtype>::Solve(const char* resume_file)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // Initialize to false every time we start solving.
requested_early_exit_ = false; if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
// 从以前中断的训练状态中恢复训练
Restore(resume_file);
} // For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
// 主要的迭代过程都在这里
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != )) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == ) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss); UpdateSmoothedLoss(loss, start_iter, average_loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == ) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}
下面先看Solve中的Restore(resume_file)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= &&
state_filename.compare(state_filename.size() - , , ".h5") == ) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}
上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,调用的其实是其派生类的同名方法。例如,若使用SGD求解,SGDSolver类中的RestoreSolverStateFromBinaryProto方法如下
sgd_solver.cpp
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
// 此处获取上次训练中断时的迭代次数
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}
下面主要分析Solve中的Step(param_.max_iter() - iter_)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = ;
iteration_timer_.Start(); while (iter_ < stop_iter) {
// zero-init the params
// 将网络中参数的梯度清零
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() ==
&& (iter_ > || param_.test_initialization())) {
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
} for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == ;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = ;
// param.iter_size_默认是1,正常情况下,此处其实只进行了以次前向和反向传播
for (int i = ; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : );
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = ;
for (int j = ; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = ; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// 网络的参数在此处更新。该函数是一个虚函数,具体由Solver的派生类来实现
ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
// 每次迭代其实是一个batch_size个样本输入网络中,将它们产生的网络参数的梯度加起来作为一次迭代的参数梯度。然后用这个梯度跟据一定的正则化方法、参数更新策略来更新参数
++iter_; SolverAction::Enum request = GetRequestedAction(); // Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() ==
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码的功能是取一个batch_size数据,让其在网络中进行一次前向传播,得出损失的均值;再进行一次反向传播,得出网络参数的梯度(该梯度是一个batch_size数据产生的梯度的均值)。详细分析见下一章节。
caffe Solve函数的更多相关文章
- MATLAB利用solve函数解多元一次方程组
matlab求解多元方程组示例: syms k1 k2 k3; [k1 k2 k3] = solve(-3-k3==6, 2-k1-k2+2*k3==11, 2*k1+k2-k3+1==6)或者用[k ...
- pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)
本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...
- 非线性方程(组):MATLAB内置函数 solve, vpasolve, fsolve, fzero, roots [MATLAB]
MATLAB函数 solve, vpasolve, fsolve, fzero, roots 功能和信息概览 求解函数 多项式型 非多项式型 一维 高维 符号 数值 算法 solve 支持,得到全部符 ...
- Matlab的solve()函数的使用方法
Matlab的solve()函数的使用方法 1.首先是对方程的求解 不废话直接上例子 syms x: eq=x^2+2*x+1; s=solve(eq,x); 结果如下 完美的算出了方程的解 现在对上 ...
- 从零开始山寨Caffe·捌:IO系统(二)
生产者 双缓冲组与信号量机制 在第陆章中提到了,如何模拟,以及取代根本不存的Q.full()函数. 其本质是:除了为生产者提供一个成品缓冲队列,还提供一个零件缓冲队列. 当我们从外部给定了固定容量的零 ...
- Caffe 源碼閱讀(五) Solver.cpp
1.Solver类两个构造函数 Solver(const SolverParameter& param) Solver(const string& param_file) 初始化两个类 ...
- caffe简单介绍
从四个层次来理解caffe:Blob.Layer.Net.Solver. 1.BlobBlob是caffe基本的数据结构,用四维矩阵 Batch×Channel×Height×Weight表示,存储了 ...
- caffe源码阅读(1)_整体框架和简介(摘录)
原文链接:https://www.zhihu.com/question/27982282 1.Caffe代码层次.回答里面有人说熟悉Blob,Layer,Net,Solver这样的几大类,我比较赞同. ...
- caffe源码学习
本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...
随机推荐
- 数据预处理之Minkowski距离计算
template <class T1, class T2> double Minkowski(const std::vector<T1> &inst1, const s ...
- 82.角色管理Extjs 页面
1. <%@ page language="java" import="java.util.*" pageEncoding="UTF-8&quo ...
- BZOJ 4057 状压DP
思路: 状压一下 就完了... f[i]表示选了的集合为i 转移的时候判一判就好了.. //By SiriusRen #include <cstdio> #include <cstr ...
- 【转】SQL SERVER 主体,已同步
转自郭大侠博客: https://www.cnblogs.com/gered/p/10601202.html 目录 SQL SERVER 基于数据库镜像的主从同步... 1 1.概念... 2 1. ...
- android黑科技系列——Wireshark和Fiddler分析Android中的TLS协议包数据(附带案例样本)
一.前言 在之前一篇文章已经介绍了一款网络访问软件的破解教程,当时采用的突破口是应用程序本身的一个漏洞,就是没有关闭日志信息,我们通过抓取日志获取到关键信息来找到突破口进行破解的.那篇文章也说到了,如 ...
- mysqlslap对mysql进行压力测试
mysqlslap是从5.1.4版开始的一个MySQL官方提供的压力测试工具.通过模拟多个并发客户端访问MySQL来执行压力测试,并且能很好的对比多个存储引擎在相同环境下的并发压力性能差别. mysq ...
- Christopher G. Atkeson 简介
有一个事实:双足机器人的稳定性问题单靠算法是搞不定的!!! 在2015 DARPA 机器人挑战赛中,许多参赛团队的机器人使用了Atlas,他们通过安装他们自己的软件并修改来让机器人保持平衡.来自WPI ...
- SLAM:(编译ORB)fatal error LNK1181: 无法打开输入文件“libboost_mpi-vc110-mt-1_57.lib”
对于使用MD版本编译的ORB_SLAM,会用到MPI版本的Boost,需要自己编译,比较麻烦. 因此使用MT版本进行生成,暂时无法完成. 工程配置 发现添加库文件使用了:从父级或项目默认继承,默认包含 ...
- Matlab/Eigen矩阵填充问题
Matlab进行矩阵填充时可以填充空矩阵,相当于空矩阵不存在,例如一下代码: P_RES = [ P_xv P_xvy P_xv*dy_dxv'; P_yxv P_y P_yxv*dy_dxv'; d ...
- HDU_2149_基础博弈sg函数
Public Sale Time Limit: 1000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others)Total ...