pycaffe训练的完整组件示例

为什么写这篇博客

1. 需要用到pycaffe

因为用到的开源代码基于Caffe;要维护的项目基于Caffe。基本上是用Caffe的Python接口。

2. 训练中想穿插验证并输出关注的指标

比如每训练完1个epoch就应该在完整的validation集合上执行evaluation,输出测量出的、关注的指标,例如AP、Accuracy、F1-score等。Caffe通过solver.prototxt中配置test_net能执行测试,但基本只能输出Accuracy而且是各个test_batch上的平均Accuracy,而不是想关注的验证集整体上的AP(见Solver.cpp源码)

3. 训练中期望有可视化输出

Caffe训练输出在屏幕终端,也可自行重定向到日志文件。的确可以自行解析日志文件,并结合flask搭建web页面实时显示输出。但是这不够标准和鲁棒。期望有专门的可视化工具,避免自己造难用的轮子。

本文给出很简陋的pyCaffe和VisualDL结合的例子。

解决方案

用pycaffe接管训练接口

通过自行编写python代码来执行训练,而不是用$CAFFE_ROOT/build/tools/caffe train --solver solver.prototxt的方式来启动。

  • solver.prototxt中需要配置test_net, test_iter, test_interval,保证solver有test_net对象
  • test_interval设置为999999999,以避开Solver.cpp中执行的TestAll()函数,转而在python代码中手动判断和执行validation
  • 执行validation之前注意test_net.share_with(train_net)
  • 利用solver.step(1)执行训练网络的一次迭代,利用solver.test_net[0].forward()执行测试网络的一次前传
  • 利用net.blobs['prob'].data的形式取出网络输出
  • 利用sklearn.metrics包,将取出的数据执行evaluation
  • 利用VisualDL等可视化工具,将取出的数据执行绘图

依赖项

VisualDL,是PaddlePaddle和ECharts团队联合推出的,应该是对抗谷歌的Tensorboarde的。相信ECharts的实力。

sudo pip install visualdl

看起来VisualDL和Tensorboard类似,不过对于Caffe,用不了Tensorboard,能用VisualDL也是好事。

参考代码

solve.py

#!/usr/bin/env python2
# coding: utf-8 """
inspired and adapted from:
- https://github.com/shelhamer/fcn.berkeleyvision.org
- https://github.com/rbgirshick/py-faster-rcnn
- https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/quick_start_en.md
""" from __future__ import print_function
import _init_paths
import caffe
import argparse
import os
import sys
from datetime import datetime
import cv2 from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format
import numpy as np
import perfeval from visualdl import LogWriter #for visualization during training def parse_args():
"""Parse input arguments"""
parser = argparse.ArgumentParser(description='Train a classification network')
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str, required=True) parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args class SolverWrapper:
"""对于Solver进行封装,便于外部调用"""
def __init__(self, solver_prototxt, num_epoch, num_example, pretrained_model=None):
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print('Loading pretrained model weights from {:s}'.format(pretrained_model))
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.cur_epoch = 0
self.test_interval = 100 #用来替代self.solver_param.test_interval
self.logw = LogWriter("catdog_log", sync_cycle=100)
with self.logw.mode('train') as logger:
self.sc_train_loss = logger.scalar("loss")
self.sc_train_acc = logger.scalar("Accuracy")
with self.logw.mode('val') as logger:
self.sc_val_acc = logger.scalar("Accuracy")
self.sc_val_mAP = logger.scalar("mAP") def train_model(self):
"""执行训练的整个流程,穿插了validation"""
cur_iter = 0
test_batch_size, num_classes = self.solver.test_nets[0].blobs['prob'].shape
num_test_images_tot = test_batch_size * self.solver_param.test_iter[0]
while cur_iter < self.solver_param.max_iter:
#self.solver.step(self.test_interval)
for i in range(self.test_interval):
self.solver.step(1)
loss = self.solver.net.blobs['loss'].data
acc = self.solver.net.blobs['accuracy'].data
step = self.solver.iter
self.sc_train_loss.add_record(step, loss)
self.sc_train_acc.add_record(step, acc) self.eval_on_val(num_classes, num_test_images_tot, test_batch_size)
cur_iter += self.test_interval def eval_on_val(self, num_classes, num_test_images_tot, test_batch_size):
"""在整个验证集上执行inference和evaluation"""
self.solver.test_nets[0].share_with(self.solver.net)
self.cur_epoch += 1
scores = np.zeros((num_classes, num_test_images_tot), dtype=float)
gt_labels = np.zeros((1, num_test_images_tot), dtype=float).squeeze()
for t in range(self.solver_param.test_iter[0]):
output = self.solver.test_nets[0].forward()
probs = output['prob']
labels = self.solver.test_nets[0].blobs['label'].data gt_labels[t*test_batch_size:(t+1)*test_batch_size] = labels.T.astype(float)
scores[:,t*test_batch_size:(t+1)*test_batch_size] = probs.T ap, acc = perfeval.cls_eval(scores, gt_labels)
print('====================================================================\n')
print('\tDo validation after the {:d}-th training epoch\n'.format(self.cur_epoch))
print('>>>>', end='\t') #设定标记,方便于解析日志获取出数据
for i in range(num_classes):
print('AP[{:d}]={:.2f}'.format(i, ap[i]), end=', ')
mAP = np.average(ap)
print('mAP={:.2f}, Accuracy={:.2f}'.format(mAP, acc))
print('\n====================================================================\n')
step = self.solver.iter
self.sc_val_mAP.add_record(step, mAP)
self.sc_val_acc.add_record(step, acc) if __name__ == '__main__':
args = parse_args()
solver_prototxt = args.solver
num_epoch = args.num_epoch
num_batch = args.num_batch
pretrained_model = args.pretrained_model # init
caffe.set_mode_gpu()
caffe.set_device(0) sw = SolverWrapper(solver_prototxt, num_epoch, num_batch, pretrained_model)
sw.train_model()

perfeval.py

#!/usr/bin/env python2
# coding: utf-8 from __future__ import print_function
import numpy as np import sklearn.metrics as metrics def cls_eval(scores, gt_labels):
"""
分类任务的evaluation
@param scores: cxm np-array, m为样本数量(例如一个epoch)
@param gt_labels: 1xm np-array, 元素属于{0,1,2,...,K-1},表示K个类别的索引
"""
num_classes, num_test_imgs = scores.shape pred_labels = scores.argmax(axis=0) ap = np.zeros((num_classes, 1), dtype=float).squeeze()
for i in range(num_classes):
cls_labels = np.zeros((1, num_test_imgs), dtype=float).squeeze()
for j in range(num_test_imgs):
if gt_labels[j]==i:
cls_labels[j]=1
ap[i] = metrics.average_precision_score(cls_labels, scores[i]) acc = metrics.accuracy_score(gt_labels, pred_labels) return ap, acc

样例输出

首先需要开启训练,比如:

python solve.py

然后启动VisualDL:

visualDL --logdir=catdog_log --port=8080

打开浏览器获取训练的实时更新的绘图输出:http://localhost:8080。这里仅截图展示:





pycaffe训练的完整组件示例的更多相关文章

  1. 利用webuploader插件上传图片文件,完整前端示例demo,服务端使用SpringMVC接收

    利用WebUploader插件上传图片文件完整前端示例demo,服务端使用SpringMVC接收 Webuploader简介   WebUploader是由Baidu WebFE(FEX)团队开发的一 ...

  2. Vue列表组件与弹窗组件示例

    列表组件 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <me ...

  3. [Nginx]Nginx的基本配置与优化1(完整配置示例与虚拟主机配置)

    ---------------------------------------------------------------------------------------- 完整配置示例: [ n ...

  4. 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例

    本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...

  5. [deviceone开发]-do_Socket组件示例

    一.简介 do_Socket只实现了socket的客户端的功能,这个示例完整了展示了组件的基本用法,需要和sockettest3工具配合使用,sockettest3做为一个socket server来 ...

  6. SpringMVC札集(01)——SpringMVC入门完整详细示例(上)

    自定义View系列教程00–推翻自己和过往,重学自定义View 自定义View系列教程01–常用工具介绍 自定义View系列教程02–onMeasure源码详尽分析 自定义View系列教程03–onL ...

  7. android四大组件学习总结以及各个组件示例(1)

    android四大组件分别为activity.service.content provider.broadcast receiver. 一.android四大组件详解 1.activity (1)一个 ...

  8. asp.net core封装layui组件示例分享

    用什么封装?自然是TagHelper啊,是啥?自己瞅文档去 在学习使用TagHelper的时候,最希望的就是能有个Demo能够让自己作为参考 怎么去封装一个组件? 不同的情况怎么去实现? 有没有更好更 ...

  9. WebRTC 音频采样算法 附完整C++示例代码

    之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...

随机推荐

  1. 【Linux】虚拟服务器之LVS

    写在前面 觉得甚是幸运,能够有机会参与到ITOO配置环境的工作中去.现在正在熟悉,在搭建环境的时候,有LVS安装配置教程,对这一块有些懵逼,这几天查了一些资料,写在这里,和大家分享一下 是什么 LVS ...

  2. Shell基础总结

    一.用户登陆进入系统后的系统环境变量 $HOME 使用者自己的目录 $PATH 执行命令时所搜寻的目录 $TZ 时区 $MAILCHECK 每隔多少秒检查是否有新的信件 $PS1 在命令列时的提示号 ...

  3. 3.2. 使​​​​​​​用​​​​​​​ CPUFREQ 调​​​​​​​节​​​​​​​器​​​​​​​【转】

    转自:https://access.redhat.com/documentation/zh-cn/red_hat_enterprise_linux/6/html/power_management_gu ...

  4. pt-table-sync同步报错Called not_in_left in state 0 at /usr/bin/pt-table-sync line 5231【原创】

    试验环境MySQL5.7.19,自己使用pt-table-sync 3.0.2版本同步后,手动在从库执行后,在用pt-table-sync验证时报错 命令如下: pt-table-,u=用户名,p=, ...

  5. CLR via C# 中关于装箱拆箱的摘录

     装箱: 为了将一个值类型转换成一个引用类型,要使用一个名为装箱(boxing)的机制.下面总结了对值类型的一个实例进行装箱操作时在内部发生的事情. 1.在托管堆中分配好内存.分配的内存量是值类型的各 ...

  6. EF 常见异常总结

    问题:System.Reflection.TargetInvocationException: Exception has been thrown by the target of an invoca ...

  7. svn更新出现冲突的解决方法

    Conflict discovered in '/Users/apple/EtaxiAppServer/common/src/com/yaotaxi/db/MongoDBHelper.java'. S ...

  8. 利用表格分页显示数据的js组件bootstrap datatable的使用

    前面展示了datatable的简单使用,还可以通过bootstrap结合datatable来使用,这样可以进一步美化datatable插件 <!DOCTYPE html> <html ...

  9. FreeSWITCH voicemail

    功能描述:分机不存在时,进行语音留言. 步骤: 1.编译mod_voicemail 模块.默认是已经有编译 2.加载mod_voicemail模块: fs_cli  -->  reload mo ...

  10. Spring initializr使用

    Spring initializr 是Spring 官方提供的一个很好的工具,用来初始化一个Spring boot 的项目. 有两种方式可以使用Spring initializr来创建一个项目: ht ...