net_->ForwardBackward()的大致梳理
net_->ForwardBackward()方法在net.hpp文件中
Dtype ForwardBackward() {
Dtype loss;
Forward(&loss);
Backward();
return loss;
}
首先进入Forward(&loss)
net.cpp
template <typename Dtype>
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
if (loss != NULL) {
*loss = ForwardFromTo(, layers_.size() - );
} else {
ForwardFromTo(, layers_.size() - );
}
return net_output_blobs_;
}
接着进入*loss = ForwardFromTo(0, layers_.size() - 1)这句话
net.cpp
template <typename Dtype>
Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
CHECK_GE(start, );
CHECK_LT(end, layers_.size());
Dtype loss = ;
for (int i = start; i <= end; ++i) {
for (int c = ; c < before_forward_.size(); ++c) {
before_forward_[c]->run(i);
}
// 一层一层地前向传播,bottom_vecs_[i]是各层的输入输入数据指针,top_vecs_[i]是各层的输出数据指针
Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
loss += layer_loss;
if (debug_info_) { ForwardDebugInfo(i); }
for (int c = ; c < after_forward_.size(); ++c) {
after_forward_[c]->run(i);
}
}
return loss;
}
再进入Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i])。首先会进入Layer类的Forward函数
layer.hpp
// Forward and backward wrappers. You should implement the cpu and
// gpu specific implementations instead, and should not change these
// functions.
template <typename Dtype>
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
Dtype loss = ;
Reshape(bottom, top);
switch (Caffe::mode()) {
case Caffe::CPU:
// Layer类的虚函数,具体由其不同的派生类作不同的实现,也就是此句将会调用不同网络层的Forward_cpu函数,下面的Forward_gpu同理。
Forward_cpu(bottom, top);
for (int top_id = ; top_id < top.size(); ++top_id) {
if (!this->loss(top_id)) { continue; }
const int count = top[top_id]->count();
const Dtype* data = top[top_id]->cpu_data();
const Dtype* loss_weights = top[top_id]->cpu_diff();
loss += caffe_cpu_dot(count, data, loss_weights);
}
break;
case Caffe::GPU:
Forward_gpu(bottom, top);
#ifndef CPU_ONLY
for (int top_id = ; top_id < top.size(); ++top_id) {
if (!this->loss(top_id)) { continue; }
const int count = top[top_id]->count();
const Dtype* data = top[top_id]->gpu_data();
const Dtype* loss_weights = top[top_id]->gpu_diff();
Dtype blob_loss = ;
caffe_gpu_dot(count, data, loss_weights, &blob_loss);
loss += blob_loss;
}
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode.";
}
return loss;
} template <typename Dtype>
inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
switch (Caffe::mode()) {
case Caffe::CPU:
Backward_cpu(top, propagate_down, bottom);
break;
case Caffe::GPU:
Backward_gpu(top, propagate_down, bottom);
break;
default:
LOG(FATAL) << "Unknown caffe mode.";
}
}
接下来再看ForwardBackward()中的Backward()
net.cpp
template <typename Dtype>
void Net<Dtype>::Backward() {
// 从最后一层开始反向传播
BackwardFromTo(layers_.size() - , );
if (debug_info_) {
Dtype asum_data = , asum_diff = , sumsq_data = , sumsq_diff = ;
for (int i = ; i < learnable_params_.size(); ++i) {
asum_data += learnable_params_[i]->asum_data();
asum_diff += learnable_params_[i]->asum_diff();
sumsq_data += learnable_params_[i]->sumsq_data();
sumsq_diff += learnable_params_[i]->sumsq_diff();
}
const Dtype l2norm_data = std::sqrt(sumsq_data);
const Dtype l2norm_diff = std::sqrt(sumsq_diff);
LOG(ERROR) << " [Backward] All net params (data, diff): "
<< "L1 norm = (" << asum_data << ", " << asum_diff << "); "
<< "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
}
}
进入BackwardFromTo(layers_.size() - 1, 0)
net.cpp
template <typename Dtype>
void Net<Dtype>::BackwardFromTo(int start, int end) {
CHECK_GE(end, );
CHECK_LT(start, layers_.size());
for (int i = start; i >= end; --i) {
for (int c = ; c < before_backward_.size(); ++c) {
before_backward_[c]->run(i);
}
if (layer_need_backward_[i]) {
// 反向传播过程中,top_vecs_[i]是各层的输入数据指针,bottom_vecs[i]是各层的输出数据指针,与前向传播正好相反
layers_[i]->Backward(
top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]);
if (debug_info_) { BackwardDebugInfo(i); }
}
for (int c = ; c < after_backward_.size(); ++c) {
after_backward_[c]->run(i);
}
}
}
进入layers_[i]->Backward(top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i])
layer.hpp
template <typename Dtype>
inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
switch (Caffe::mode()) {
case Caffe::CPU:
// 与前向传播类似,利用不同派生类的同名函数作出不同层的反向传播的具体实现
Backward_cpu(top, propagate_down, bottom);
break;
case Caffe::GPU:
Backward_gpu(top, propagate_down, bottom);
break;
default:
LOG(FATAL) << "Unknown caffe mode.";
}
}
不同层的前向、反向传播的具体实现见下一章节。
net_->ForwardBackward()的大致梳理的更多相关文章
- 带你梳理Jetty自定义ProxyServlet实现反向代理服务
摘要:最近要做一个将K8s中的某组件UI通过反向代理映射到自定义规则的链接地址上,提供给用户访问的需求.所以顺便研究了一下Jetty的ProxyServlet. 本文分享自华为云社区<Jetty ...
- Linux内核笔记--网络子系统初探
内核版本:linux-2.6.11 本文对Linux网络子系统的收发包的流程进行一个大致梳理,以流水账的形式记录从应用层write一个socket开始到这些数据被应用层read出来的这个过程中linu ...
- 【Bugly技术干货】那些年我们用过的显示性能指标
Bugly 技术干货系列内容主要涉及移动开发方向,是由 Bugly 邀请腾讯内部各位技术大咖,通过日常工作经验的总结以及感悟撰写而成,内容均属原创,转载请标明出处. 前言: 注:Google 在自己文 ...
- Android消息机制:Looper,MessageQueue,Message与handler
Android消息机制好多人都讲过,但是自己去翻源码的时候才能明白. 今天试着讲一下,因为目标是讲清楚整体逻辑,所以不追究细节. Message是消息机制的核心,所以从Message讲起. 1.Mes ...
- tair源码分析——leveldb存储引擎使用
分析完leveldb以后,接下来的时间准备队tair的源码进行阅读和分析.我们刚刚分析完了leveldb而在tair中leveldb是其几大存储引擎之一,所以我们这里首先从tair对leveldb的使 ...
- 关闭对话框,OnClose和OnCancel
我们知道,在对话框中,屏蔽ESC键自己主动退出能够选择重载OnCancel为哑函数的方法: void CXXXXDlg::OnCancel() { // TODO: Add ...
- ssm+jsp+自定义标签实现分页,可以通用(前端实现)
近期做了一些分页方面的开发,大致梳理一下 1 jsp页面上关于分页的代码 <tr> <td colspan="9"> <ule1:pagination ...
- netty高级篇(3)-HTTP协议开发
一.HTTP协议简介 应用层协议http,发展至今已经是http2.0了,拥有以下特点: (1) CS模式的协议 (2) 简单 - 只需要服务URL,携带必要的请求参数或者消息体 (3) 灵活 - 任 ...
- 疑问:Spring中构造器、init-method、@PostConstruct、afterPropertiesSet孰先孰后,自动注入发生时间
问题:今天想写一个通用点的方法,根据传入的参数的类型(clazz对象),判断使用哪个mapper来插入mysql数据库. 下面是我的写法: public interface BizNeeqCommon ...
随机推荐
- DCloud-JS-MUI-JS:utils.js
ylbtech-DCloud-JS:utils.js 1. 导航返回返回顶部 1. var oldBack = mui.back; mui.back = function () { mui.back ...
- nodejs--Nodejs单元测试小结
前言 最近在写一课程的Project,用Node写了一个实时聊天小应用,其中就用到了单元测试.在写Node单元测试的时候,一方面感受到了单元测试的重要性,另一方面感受到了Node单元测试的不够成熟,尚 ...
- 比较两个Json对象是否相等
一个前端同事遇到的面试题,抽空写了写,也算是个积累 1.先准备三个工具方法,用于判断是否是对象类型,是否是数组,获取对象长度 function isObj(object) { return objec ...
- HTML 5概述
HTML语言是一种简易的文件交换标准,用于物理的文件结构,它旨在定义文件内的对象和描述文件的逻辑结构,而并不定义文件的显示.由于HTML所描述的文件具有极高的适应性,所以特别适合于WWW的出版环境. ...
- 算法之dfs篇
dfs算法是深度搜索算法.从某一节点开始遍历直至到路径底部,如果不是所寻找的,则回溯到上个节点后,再遍历其他路径.不断重复这个过程.一般此过程消耗很大,需要一些优化才能保持算法的高效. hdu1010 ...
- DeltaFish 校园物资共享平台 第三次小组会议
一.想法 娄雨禛: 网页底层开发转移到后端,快速建站,效率高. 可以依照模板进行仿制. 可以考虑只进行页面设计. 但是出现问题不会调试. 所以自己写源码,做出一个大致的样子. 二.上周进度汇报 齐天杨 ...
- MySQL 优化之 index_merge (索引合并)
深入理解 index merge 是使用索引进行优化的重要基础之一.理解了 index merge 技术,我们才知道应该如何在表上建立索引. 1. 为什么会有index merge 我们的 where ...
- DHCP 和 MDT 分开服务器的设置方法
DHCP设置 043:供应商特定信息:01 04 00 00 00 00 FF 060:PXEClient:PXEClient 066:启动服务器主机名:IP 067:启动文件名:\Boot\x86\ ...
- ROS:Nvidia Jetson TK1平台安装使用ROS
原文连接: http://wiki.ros.org/indigo/Installation/UbuntuARM Ubuntu ARM install of ROS Indigo There are c ...
- WPF动态折线图
此项目源码下载地址:https://github.com/lizhiqiang0204/WpfDynamicChart 效果图如下: 此项目把折线图制作成了一个控件,在主界面设置好参数直接调用即可,下 ...