下面来看Solver<Dtype>::Solve(const char* resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // Initialize to false every time we start solving.
requested_early_exit_ = false; if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
// 从以前中断的训练状态中恢复训练
Restore(resume_file);
} // For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
// 主要的迭代过程都在这里
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != )) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == ) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss); UpdateSmoothedLoss(loss, start_iter, average_loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == ) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}

下面先看Solve中的Restore(resume_file)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= &&
state_filename.compare(state_filename.size() - , , ".h5") == ) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}

上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,调用的其实是其派生类的同名方法。例如,若使用SGD求解,SGDSolver类中的RestoreSolverStateFromBinaryProto方法如下

sgd_solver.cpp

template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
// 此处获取上次训练中断时的迭代次数
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}

下面主要分析Solve中的Step(param_.max_iter() - iter_)

solver.cpp

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = ;
iteration_timer_.Start(); while (iter_ < stop_iter) {
// zero-init the params
// 将网络中参数的梯度清零
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() ==
&& (iter_ > || param_.test_initialization())) {
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
} for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == ;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = ;
// param.iter_size_默认是1,正常情况下,此处其实只进行了以次前向和反向传播
for (int i = ; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : );
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = ;
for (int j = ; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = ; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// 网络的参数在此处更新。该函数是一个虚函数,具体由Solver的派生类来实现
ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
// 每次迭代其实是一个batch_size个样本输入网络中,将它们产生的网络参数的梯度加起来作为一次迭代的参数梯度。然后用这个梯度跟据一定的正则化方法、参数更新策略来更新参数
++iter_; SolverAction::Enum request = GetRequestedAction(); // Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() ==
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码的功能是取一个batch_size数据,让其在网络中进行一次前向传播,得出损失的均值;再进行一次反向传播,得出网络参数的梯度(该梯度是一个batch_size数据产生的梯度的均值)。详细分析见下一章节


caffe Solve函数的更多相关文章

  1. MATLAB利用solve函数解多元一次方程组

    matlab求解多元方程组示例: syms k1 k2 k3; [k1 k2 k3] = solve(-3-k3==6, 2-k1-k2+2*k3==11, 2*k1+k2-k3+1==6)或者用[k ...

  2. pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)

    本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...

  3. 非线性方程(组):MATLAB内置函数 solve, vpasolve, fsolve, fzero, roots [MATLAB]

    MATLAB函数 solve, vpasolve, fsolve, fzero, roots 功能和信息概览 求解函数 多项式型 非多项式型 一维 高维 符号 数值 算法 solve 支持,得到全部符 ...

  4. Matlab的solve()函数的使用方法

    Matlab的solve()函数的使用方法 1.首先是对方程的求解 不废话直接上例子 syms x: eq=x^2+2*x+1; s=solve(eq,x); 结果如下 完美的算出了方程的解 现在对上 ...

  5. 从零开始山寨Caffe·捌:IO系统(二)

    生产者 双缓冲组与信号量机制 在第陆章中提到了,如何模拟,以及取代根本不存的Q.full()函数. 其本质是:除了为生产者提供一个成品缓冲队列,还提供一个零件缓冲队列. 当我们从外部给定了固定容量的零 ...

  6. Caffe 源碼閱讀(五) Solver.cpp

    1.Solver类两个构造函数 Solver(const SolverParameter& param) Solver(const string& param_file) 初始化两个类 ...

  7. caffe简单介绍

    从四个层次来理解caffe:Blob.Layer.Net.Solver. 1.BlobBlob是caffe基本的数据结构,用四维矩阵 Batch×Channel×Height×Weight表示,存储了 ...

  8. caffe源码阅读(1)_整体框架和简介(摘录)

    原文链接:https://www.zhihu.com/question/27982282 1.Caffe代码层次.回答里面有人说熟悉Blob,Layer,Net,Solver这样的几大类,我比较赞同. ...

  9. caffe源码学习

    本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...

随机推荐

  1. 什么是 less? 如何使用 less?

    什么是 Less? Less 是一门 CSS 预处理语言,它扩充了 CSS 语言,增加了诸如变量.混合(mixin).嵌套.函数等功能,让 CSS 更易编写.维护等. 本质上,Less 包含一套自定义 ...

  2. hdu 3037Saving Beans(卢卡斯定理)

    Saving Beans Saving Beans Time Limit: 6000/3000 MS (Java/Others)    Memory Limit: 32768/32768 K (Jav ...

  3. Python27天 反射 ,isinstance与ssubclass 内置方法

    所学内容 反射 1.hasattr ( 判断一个属性在对象里有没有 ) -------------------- [对象,字符串属性]本质是:# 判断 ' name ' in obj.__dict__ ...

  4. javaweb 课程设计编码和设计文档

    企业办公软件设计文档 1引言 1.1编写目的 OA办公自动化系统详细设计是设计的第三个阶段,这个阶段的主要任务是在OA办公自动化系统概要设计书基础上,对概要设计中产生的功能模块进行过程描述,设计功能模 ...

  5. jdbc 接口学习笔记

    一.JDBC概念 JDBC(Java Data Base Connectivity,java数据库连接)是一种用于执行SQL语句的Java API,可以为多种关系型数据库提供统一访问,它由一组用Jav ...

  6. POJ1915 BFS&双向BFS

    俩月前写的普通BFS #include <cstdio> #include <iostream> #include <cstring> #include <q ...

  7. ACM_魔仙岛探险(深搜)

    魔仙岛探险 Time Limit: 2000/1000ms (Java/Others) Problem Description: 小敏通过秘密方法得到一张不完整的魔仙岛航拍地图.魔仙岛由一个主岛和一些 ...

  8. A - HQ9+

    Problem description HQ9+ is a joke programming language which has only four one-character instructio ...

  9. RRDtool入门详解

    ---------------原创内容,转载请注明出处.<yaoyao0777@Gmail.com>------------ 一.概述 RRDtool(round-robin databa ...

  10. Electron结合React开发环境遇到的问题

    链接 将create-react-app与electron集成在了一个项目中.但是在React中无法使用electron 当在React中使用require('electron')时就会报TypeEr ...