Caffe是目前深度学习比较优秀好用的一个开源库,采样c++和CUDA实现,具有速度快,模型定义方便等优点。学习了几天过后,发现也有一个不方便的地方,就是在我的程序中调用Caffe做图像分类没有直接的接口。Caffe的数据层可以从数据库(支持leveldb、lmdb、hdf5)、图片、和内存中读入。我们要在程序中使用,当然得从内存中读入。参见http://caffe.berkeleyvision.org/tutorial/layers.html#data-layers和MemoryDataLayer源码,我们首先在模型定义文件中定义数据层:

layers {
name: "mydata"
type: MEMORY_DATA
top: "data"
top: "label"
transform_param {
scale: 0.00390625
}
memory_data_param {
batch_size: 10
channels: 1
height: 24
width: 24
}
}

这里必须设置memory_data_param中的四个参数,对应这些参数可以参见源码中caffe.proto文件。现在,我们可以设计一个Classifier类来封装一下:

#ifndef CAFFE_CLASSIFIER_H
#define CAFFE_CLASSIFIER_H #include <string>
#include <vector>
#include "caffe/net.hpp"
#include "caffe/data_layers.hpp"
#include <opencv2/core.hpp>
using cv::Mat; namespace caffe { template <typename Dtype>
class Classifier {
public:
explicit Classifier(const string& param_file, const string& weights_file);
Dtype test(vector<Mat> &images, vector<int> &labels, int iter_num);
virtual ~Classifier() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
void predict(vector<Mat> &images, vector<int> *labels);
void predict(vector<Dtype> &data, vector<int> *labels, int num);
void extract_feature(vector<Mat> &images, vector<vector<Dtype>> *out); protected:
shared_ptr<Net<Dtype> > net_;
MemoryDataLayer<Dtype> *m_layer_;
int batch_size_;
int channels_;
int height_;
int width_; DISABLE_COPY_AND_ASSIGN(Classifier);
};
}//namespace
#endif //CAFFE_CLASSIFIER_H

构造函数中我们通过模型定义文件(.prototxt)和训练好的模型(.caffemodel)文件构造一个Net对象,并用m_layer_指向Net中的memory data层,以便待会调用MemoryDataLayer中AddMatVector和Reset函数加入数据。

#include <cstdio>

#include <algorithm>
#include <string>
#include <vector> #include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/upgrade_proto.hpp"
#include "caffe_classifier.h" namespace caffe { template <typename Dtype>
Classifier<Dtype>::Classifier(const string& param_file, const string& weights_file) : net_()
{
net_.reset(new Net<Dtype>(param_file, TEST));
net_->CopyTrainedLayersFrom(weights_file);
//m_layer_ = (MemoryDataLayer<Dtype>*)net_->layer_by_name("mnist").get();
m_layer_ = (MemoryDataLayer<Dtype>*)net_->layers()[0].get();
batch_size_ = m_layer_->batch_size();
channels_ = m_layer_->channels();
height_ = m_layer_->height();
width_ = m_layer_->width();
} template <typename Dtype>
Dtype Classifier<Dtype>::test(vector<Mat> &images, vector<int> &labels, int iter_num)
{
m_layer_->AddMatVector(images, labels);
//
int iterations = iter_num;
vector<Blob<Dtype>* > bottom_vec; vector<int> test_score_output_id;
vector<Dtype> test_score;
Dtype loss = 0;
for (int i = 0; i < iterations; ++i) {
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
net_->Forward(bottom_vec, &iter_loss);
loss += iter_loss;
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k, ++idx) {
const Dtype score = result_vec[k];
if (i == 0) {
test_score.push_back(score);
test_score_output_id.push_back(j);
} else {
test_score[idx] += score;
}
const std::string& output_name = net_->blob_names()[
net_->output_blob_indices()[j]];
LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score;
}
}
}
loss /= iterations;
LOG(INFO) << "Loss: " << loss;
return loss;
} template <typename Dtype>
void Classifier<Dtype>::predict(vector<Mat> &images, vector<int> *labels)
{
int original_length = images.size();
if(original_length == 0)
return;
int valid_length = original_length / batch_size_ * batch_size_;
if(original_length != valid_length)
{
valid_length += batch_size_;
for(int i = original_length; i < valid_length; i++)
{
images.push_back(images[0].clone());
}
}
vector<int> valid_labels, predicted_labels;
valid_labels.resize(valid_length, 0);
m_layer_->AddMatVector(images, valid_labels);
vector<Blob<Dtype>* > bottom_vec;
for(int i = 0; i < valid_length / batch_size_; i++)
{
const vector<Blob<Dtype>*>& result = net_->Forward(bottom_vec);
const Dtype * result_vec = result[1]->cpu_data();
for(int j = 0; j < result[1]->count(); j++)
{
predicted_labels.push_back(result_vec[j]);
}
}
if(original_length != valid_length)
{
images.erase(images.begin()+original_length, images.end());
}
labels->resize(original_length, 0);
std::copy(predicted_labels.begin(), predicted_labels.begin() + original_length, labels->begin());
} template <typename Dtype>
void Classifier<Dtype>::predict(vector<Dtype> &data, vector<int> *labels, int num)
{
int size = channels_*height_*width_;
CHECK_EQ(data.size(), num*size);
int original_length = num;
if(original_length == 0)
return;
int valid_length = original_length / batch_size_ * batch_size_;
if(original_length != valid_length)
{
valid_length += batch_size_;
for(int i = original_length; i < valid_length; i++)
{
for(int j = 0; j < size; j++)
data.push_back(0);
}
}
vector<int> predicted_labels;
Dtype * label_ = new Dtype[valid_length];
memset(label_, 0, valid_length);
m_layer_->Reset(data.data(), label_, valid_length);
vector<Blob<Dtype>* > bottom_vec;
for(int i = 0; i < valid_length / batch_size_; i++)
{
const vector<Blob<Dtype>*>& result = net_->Forward(bottom_vec);
const Dtype * result_vec = result[1]->cpu_data();
for(int j = 0; j < result[1]->count(); j++)
{
predicted_labels.push_back(result_vec[j]);
}
}
if(original_length != valid_length)
{
data.erase(data.begin()+original_length*size, data.end());
}
delete [] label_;
labels->resize(original_length, 0);
std::copy(predicted_labels.begin(), predicted_labels.begin() + original_length, labels->begin());
}
template <typename Dtype>
void Classifier<Dtype>::extract_feature(vector<Mat> &images, vector<vector<Dtype>> *out)
{
int original_length = images.size();
if(original_length == 0)
return;
int valid_length = original_length / batch_size_ * batch_size_;
if(original_length != valid_length)
{
valid_length += batch_size_;
for(int i = original_length; i < valid_length; i++)
{
images.push_back(images[0].clone());
}
}
vector<int> valid_labels;
valid_labels.resize(valid_length, 0);
m_layer_->AddMatVector(images, valid_labels);
vector<Blob<Dtype>* > bottom_vec;
out->clear();
for(int i = 0; i < valid_length / batch_size_; i++)
{
const vector<Blob<Dtype>*>& result = net_->Forward(bottom_vec);
const Dtype * result_vec = result[0]->cpu_data();
const int dim = result[0]->count(1);
for(int j = 0; j < result[0]->num(); j++)
{
const Dtype * ptr = result_vec + j * dim;
vector<Dtype> one_;
for(int k = 0; k < dim; ++k)
one_.push_back(ptr[k]);
out->push_back(one_);
}
}
if(original_length != valid_length)
{
images.erase(images.begin()+original_length, images.end());
out->erase(out->begin()+original_length, out->end());
}
}
INSTANTIATE_CLASS(Classifier);
} // namespace caffe

由于加入的数据个数必须是batch_size的整数倍,所以我们在加入数据时采用填充的方式。

CHECK_EQ(num % batch_size_, 0) <<
"The added data must be a multiple of the batch size.";  //AddMatVector

在模型文件的最后,我们把训练时的loss层改为argmax层:

layers {
name: "predicted"
type: ARGMAX
bottom: "prob"
top: "predicted"
}

作者:waring  出处:http://www.cnblogs.com/waring  欢迎转载或分享,但请务必声明文章出处。

如何在程序中调用Caffe做图像分类的更多相关文章

  1. 在网页程序或Java程序中调用接口实现短信猫收发短信的解决方案

    方案特点: 在网页程序或Java程序中调用接口实现短信猫收发短信的解决方案,简化软件开发流程,减少各应用系统相同模块的重复开发工作,提高系统稳定性和可靠性. 基于HTTP协议的开发接口 使用特点在网页 ...

  2. Native Application 开发详解(直接在程序中调用 ntdll.dll 中的 Native API,有内存小、速度快、安全、API丰富等8大优点)

    文章目录:                   1. 引子: 2. Native Application Demo 展示: 3. Native Application 简介: 4. Native Ap ...

  3. windows下用c++调用caffe做前向

    参考博客: https://blog.csdn.net/muyouhang/article/details/54773265 https://blog.csdn.net/hhh0209/article ...

  4. C++程序中调用WebService的实现

    前言 因为最近的项目中需要运用到在MFC程序中调用WebService里面集成好了的函数,所以特意花了一天的时间来研究WebService的构建以及如何在MFC的程序中添加Web引用,进而来实现在C+ ...

  5. 从C#程序中调用非受管DLLs

    从C#程序中调用非受管DLLs 文章概要: 众所周知,.NET已经渐渐成为一种技术时尚,那么C#很自然也成为一种编程时尚.如何利用浩如烟海的Win32 API以及以前所编写的 Win32 代码已经成为 ...

  6. iOS程序中调用系统自带应用(短信,邮件,浏览器,地图,appstore,拨打电话,iTunes,iBooks )

    在网上找到了下在记录下来以后方便用 在程序中调用系统自带的应用,比如我进入程序的时候,希望直接调用safar来打开一个网页,下面是一个简单的使用:

  7. Java程序中调用Python脚本的方法

    在程序开发中,有时候需要Java程序中调用相关Python脚本,以下内容记录了先关步骤和可能出现问题的解决办法. 1.在Eclipse中新建Maven工程: 2.pom.xml文件中添加如下依赖包之后 ...

  8. 利用 gnuplot_i 在你的 c 程序中调用 GNUPLOT

    这是一篇非常早曾经写的小文章,最初发表于我的搜狐博客(2008-09-23 22:55).由于自从转移到这里后,sohu 博客就不再维护了,所以把这篇文章也一起挪了过来. GNUPLOT 是一款功能强 ...

  9. ASP程序中调用Now()总显示“上午”和“下午”,如何解决?

    ASP程序中调用Now()总显示这样的格式:“2007-4-20 下午 06:06:38”,我要的正确格式为“2007-4-20 18:06:38”,我已经通过控制面板==>区域和语言选项==& ...

随机推荐

  1. Lambda 表达式的示例-来源(MSDN)

    本文演示如何在你的程序中使用 lambda 表达式. 有关 lambda 表达式的概述,请参阅 C++ 中的 Lambda 表达式. 有关 lambda 表达式结构的详细信息,请参阅 Lambda 表 ...

  2. 关于bootstrap--表单(下拉<select>、输入框<input>、文本域<textare>复选框<checkbox>和单选按钮<radio>)

    html 里面的 role 本质上是增强语义性,当现有的HTML标签不能充分表达语义性的时候,就可以借助role来说明.通常这种情况出现在一些自定义的组件上,这样可增强组件的可访问性.可用性和可交互性 ...

  3. HTML5新增的一些属性和功能之六——拖拽事件

    拖放事件的前提是分为源对象和目标对象,你鼠标拖着的是源对象,你要放置的位置是目标对象,区分这两个对象是因为HTML5的拖放事件对两者是不同的. 被拖动的源对象可以触发的事件: 1).ondragsta ...

  4. Swift 新语言开发

    全书文件夹: 一.Welcome to Swift 二.Language Guide 三.Language Reference /* 译者的废话: 几个小时前熬夜看了WWDC,各种激动,今年非常有料啊 ...

  5. JMeter数据库性能测试

    要测试一个服务器的性能,客户要求向数据库内 1000/s(每插入一千条数据)的处理能力 前提条件:一个数据库:test   数据库下面有一张表:user   表中有两个字段:username.pass ...

  6. openssl 证书请求和自签名命令req详解

    1.密钥.证书请求.证书概要说明 在证书申请签发过程中,客户端涉及到密钥.证书请求.证书这几个概念,初学者可能会搞不清楚三者的关系,网上有的根据后缀名来区分三者,更让人一头雾水.我们以申请证书的流程说 ...

  7. AngularJS初步

    AngularJS特点 遵循AMD规范 不需要操作节点 对于jquery,一般是利用现有完整的DOM,然后在这戏Dom的基础上进行二次调教了:而对于AngularJS等框架则是根据数据模型以及其对应用 ...

  8. 《第一行代码》学习笔记18-广播接收器Broadcast_Receiver(1)

    1.网络通信原理,在一个IP网络范围内最大的IP地址是被保留作为广播地址来使用的.某个网络的IP 范围是192.168.0.XXX, 子网掩码是255.255.255.0,则该网络的广播地址是192. ...

  9. 客户Oracle数据库在插入数据的时候报超出最大长度的错误(规避风险)

    背景: 项目使用oracle数据,在开发环境测试一些正常.项目部署到客户的服务器上后,系统在添加数据的时候报错.输出错误信息,发现是“超出最大长度”的异常. 但是按照数据库的设计,添加的数据应该在允许 ...

  10. C#中子窗体获取父窗体中控件的内容

    今天在做一个联系人管理的C#设计时,遇到了这个问题,我需要将父窗体中的textBox中的值传到子窗体并进行数据库查询操作,我用了new 父窗体().textBox.text;来进行值传递,然而并无卵用 ...