pycaffe训练的完整组件示例

为什么写这篇博客

1. 需要用到pycaffe

因为用到的开源代码基于Caffe;要维护的项目基于Caffe。基本上是用Caffe的Python接口。

2. 训练中想穿插验证并输出关注的指标

比如每训练完1个epoch就应该在完整的validation集合上执行evaluation,输出测量出的、关注的指标,例如AP、Accuracy、F1-score等。Caffe通过solver.prototxt中配置test_net能执行测试,但基本只能输出Accuracy而且是各个test_batch上的平均Accuracy,而不是想关注的验证集整体上的AP(见Solver.cpp源码)

3. 训练中期望有可视化输出

Caffe训练输出在屏幕终端,也可自行重定向到日志文件。的确可以自行解析日志文件,并结合flask搭建web页面实时显示输出。但是这不够标准和鲁棒。期望有专门的可视化工具,避免自己造难用的轮子。

本文给出很简陋的pyCaffe和VisualDL结合的例子。

解决方案

用pycaffe接管训练接口

通过自行编写python代码来执行训练,而不是用$CAFFE_ROOT/build/tools/caffe train --solver solver.prototxt的方式来启动。

  • solver.prototxt中需要配置test_net, test_iter, test_interval,保证solver有test_net对象
  • test_interval设置为999999999,以避开Solver.cpp中执行的TestAll()函数,转而在python代码中手动判断和执行validation
  • 执行validation之前注意test_net.share_with(train_net)
  • 利用solver.step(1)执行训练网络的一次迭代,利用solver.test_net[0].forward()执行测试网络的一次前传
  • 利用net.blobs['prob'].data的形式取出网络输出
  • 利用sklearn.metrics包,将取出的数据执行evaluation
  • 利用VisualDL等可视化工具,将取出的数据执行绘图

依赖项

VisualDL,是PaddlePaddle和ECharts团队联合推出的,应该是对抗谷歌的Tensorboarde的。相信ECharts的实力。

sudo pip install visualdl

看起来VisualDL和Tensorboard类似,不过对于Caffe,用不了Tensorboard,能用VisualDL也是好事。

参考代码

solve.py

#!/usr/bin/env python2
# coding: utf-8 """
inspired and adapted from:
- https://github.com/shelhamer/fcn.berkeleyvision.org
- https://github.com/rbgirshick/py-faster-rcnn
- https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/quick_start_en.md
""" from __future__ import print_function
import _init_paths
import caffe
import argparse
import os
import sys
from datetime import datetime
import cv2 from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format
import numpy as np
import perfeval from visualdl import LogWriter #for visualization during training def parse_args():
"""Parse input arguments"""
parser = argparse.ArgumentParser(description='Train a classification network')
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str, required=True) parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args class SolverWrapper:
"""对于Solver进行封装,便于外部调用"""
def __init__(self, solver_prototxt, num_epoch, num_example, pretrained_model=None):
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print('Loading pretrained model weights from {:s}'.format(pretrained_model))
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.cur_epoch = 0
self.test_interval = 100 #用来替代self.solver_param.test_interval
self.logw = LogWriter("catdog_log", sync_cycle=100)
with self.logw.mode('train') as logger:
self.sc_train_loss = logger.scalar("loss")
self.sc_train_acc = logger.scalar("Accuracy")
with self.logw.mode('val') as logger:
self.sc_val_acc = logger.scalar("Accuracy")
self.sc_val_mAP = logger.scalar("mAP") def train_model(self):
"""执行训练的整个流程,穿插了validation"""
cur_iter = 0
test_batch_size, num_classes = self.solver.test_nets[0].blobs['prob'].shape
num_test_images_tot = test_batch_size * self.solver_param.test_iter[0]
while cur_iter < self.solver_param.max_iter:
#self.solver.step(self.test_interval)
for i in range(self.test_interval):
self.solver.step(1)
loss = self.solver.net.blobs['loss'].data
acc = self.solver.net.blobs['accuracy'].data
step = self.solver.iter
self.sc_train_loss.add_record(step, loss)
self.sc_train_acc.add_record(step, acc) self.eval_on_val(num_classes, num_test_images_tot, test_batch_size)
cur_iter += self.test_interval def eval_on_val(self, num_classes, num_test_images_tot, test_batch_size):
"""在整个验证集上执行inference和evaluation"""
self.solver.test_nets[0].share_with(self.solver.net)
self.cur_epoch += 1
scores = np.zeros((num_classes, num_test_images_tot), dtype=float)
gt_labels = np.zeros((1, num_test_images_tot), dtype=float).squeeze()
for t in range(self.solver_param.test_iter[0]):
output = self.solver.test_nets[0].forward()
probs = output['prob']
labels = self.solver.test_nets[0].blobs['label'].data gt_labels[t*test_batch_size:(t+1)*test_batch_size] = labels.T.astype(float)
scores[:,t*test_batch_size:(t+1)*test_batch_size] = probs.T ap, acc = perfeval.cls_eval(scores, gt_labels)
print('====================================================================\n')
print('\tDo validation after the {:d}-th training epoch\n'.format(self.cur_epoch))
print('>>>>', end='\t') #设定标记,方便于解析日志获取出数据
for i in range(num_classes):
print('AP[{:d}]={:.2f}'.format(i, ap[i]), end=', ')
mAP = np.average(ap)
print('mAP={:.2f}, Accuracy={:.2f}'.format(mAP, acc))
print('\n====================================================================\n')
step = self.solver.iter
self.sc_val_mAP.add_record(step, mAP)
self.sc_val_acc.add_record(step, acc) if __name__ == '__main__':
args = parse_args()
solver_prototxt = args.solver
num_epoch = args.num_epoch
num_batch = args.num_batch
pretrained_model = args.pretrained_model # init
caffe.set_mode_gpu()
caffe.set_device(0) sw = SolverWrapper(solver_prototxt, num_epoch, num_batch, pretrained_model)
sw.train_model()

perfeval.py

#!/usr/bin/env python2
# coding: utf-8 from __future__ import print_function
import numpy as np import sklearn.metrics as metrics def cls_eval(scores, gt_labels):
"""
分类任务的evaluation
@param scores: cxm np-array, m为样本数量(例如一个epoch)
@param gt_labels: 1xm np-array, 元素属于{0,1,2,...,K-1},表示K个类别的索引
"""
num_classes, num_test_imgs = scores.shape pred_labels = scores.argmax(axis=0) ap = np.zeros((num_classes, 1), dtype=float).squeeze()
for i in range(num_classes):
cls_labels = np.zeros((1, num_test_imgs), dtype=float).squeeze()
for j in range(num_test_imgs):
if gt_labels[j]==i:
cls_labels[j]=1
ap[i] = metrics.average_precision_score(cls_labels, scores[i]) acc = metrics.accuracy_score(gt_labels, pred_labels) return ap, acc

样例输出

首先需要开启训练,比如:

python solve.py

然后启动VisualDL:

visualDL --logdir=catdog_log --port=8080

打开浏览器获取训练的实时更新的绘图输出:http://localhost:8080。这里仅截图展示:





pycaffe训练的完整组件示例的更多相关文章

  1. 利用webuploader插件上传图片文件,完整前端示例demo,服务端使用SpringMVC接收

    利用WebUploader插件上传图片文件完整前端示例demo,服务端使用SpringMVC接收 Webuploader简介   WebUploader是由Baidu WebFE(FEX)团队开发的一 ...

  2. Vue列表组件与弹窗组件示例

    列表组件 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <me ...

  3. [Nginx]Nginx的基本配置与优化1(完整配置示例与虚拟主机配置)

    ---------------------------------------------------------------------------------------- 完整配置示例: [ n ...

  4. 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例

    本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...

  5. [deviceone开发]-do_Socket组件示例

    一.简介 do_Socket只实现了socket的客户端的功能,这个示例完整了展示了组件的基本用法,需要和sockettest3工具配合使用,sockettest3做为一个socket server来 ...

  6. SpringMVC札集(01)——SpringMVC入门完整详细示例(上)

    自定义View系列教程00–推翻自己和过往,重学自定义View 自定义View系列教程01–常用工具介绍 自定义View系列教程02–onMeasure源码详尽分析 自定义View系列教程03–onL ...

  7. android四大组件学习总结以及各个组件示例(1)

    android四大组件分别为activity.service.content provider.broadcast receiver. 一.android四大组件详解 1.activity (1)一个 ...

  8. asp.net core封装layui组件示例分享

    用什么封装?自然是TagHelper啊,是啥?自己瞅文档去 在学习使用TagHelper的时候,最希望的就是能有个Demo能够让自己作为参考 怎么去封装一个组件? 不同的情况怎么去实现? 有没有更好更 ...

  9. WebRTC 音频采样算法 附完整C++示例代码

    之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...

随机推荐

  1. sqlserver开窗函数在财务对账中的用法

    曾几何时发现开窗函数在财务对账总特别好用.但是每次可能很久没用,逻辑都要重头来过.特此留一份完整的思考逻辑待日后参考. 以下是数据源: 从上面的数据可以看到通过C列,那么只需要两个条件即可获得已经用对 ...

  2. [Jenkins]CentOS7下Jenkins搭建

    最近在倒腾Kubernetes的一些东西,这次需要用到Jenkins来实现自动化构建.来讲一讲搭建的整个过程. Jenkins是什么 Jenkins提供了软件开发的持续集成服务.它运行在Servlet ...

  3. python模块之sniffio

    嗅探python用了哪个异步库 from sniffio import current_async_library import trio import asyncio async def print ...

  4. ansible 常见指令表

    Play 指令 说明 accelerate 开启加速模式 accelerate_ipv6 是否开启ipv6 accelerate_port 加速模式的端口 always_run   any_error ...

  5. 花神的数论题(这题...哎。数位dp咋就这么 not naive 呢)

    题意简介 没什么好说,就是让你求出 1 ~ n 之间每个数转化为二进制后 '1' 的个数,然后乘起来输出积 题目分析 emmmm.... 两种解法(同是 $O(\log^2 N)$ 的算法,组合数效率 ...

  6. docker简单入门之使用docker容器部署简单的java web开源项目jpress博客程序

    一.在centos7.3上安装docker 前置条件 x86_64-bit 系统 kernel 3.10+ .检查内核版本,返回的值大于3.10即可 [root@node1 ~]# uname -r ...

  7. JS ----实现复制粘贴功能 (剪切板应用clipboardData)

    注意:ie7,与ie8 对网页有个复制的权限,需在“安全”中的“自定义级别”的脚本中设置 clipboardData 对象 提供了对剪贴板的访问. 三个方法 :1.clearData(sDataFor ...

  8. CentOS 7安装Python3.5

    CentOS 7下安装Python3.5 •安装python3.5可能使用的依赖 yum install openssl-devel bzip2-devel expat-devel gdbm-deve ...

  9. ipfs上传下载

    上传下载步骤: 启动ipfs节点服务器: 页面效果显示如下: 当在一个终端启动ipfs节点服务器之后之后,上传下载步骤: 1.创建文件demo4,新建一个文件a.txt,文本内容为hello mkdi ...

  10. __dict__(字典的另一种用法)

    class Foo(): def __init__(self): self.name=None self.age=19 self.addr='上海' @property def dict(self): ...