#include <stdio.h>  // for snprintf
#include <string>
#include <vector> #include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h" #include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp" using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using boost::shared_ptr;
using std::string;
namespace db = caffe::db; template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv); int main(int argc, char** argv) {
return feature_extraction_pipeline<float>(argc, argv);
// return feature_extraction_pipeline<double>(argc, argv);
} template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
::google::InitGoogleLogging(argv[]);
const int num_required_args = ;
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_dataset_name1[,name2,...] num_mini_batches db_type"
" [CPU/GPU] [DEVICE_ID=0]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and dataset names separated by ','."
" The names cannot contain white space characters and the number of blobs"
" and datasets must be equal.";
return ;
}
int arg_pos = num_required_args; arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == ) {
LOG(ERROR)<< "Using GPU";
uint device_id = ;
if (argc > arg_pos + ) {
device_id = atoi(argv[arg_pos + ]);
CHECK_GE(device_id, );
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
} arg_pos = ; // the name of the executable
std::string pretrained_binary_proto(argv[++arg_pos]); // Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
name: "data_layer_name"
type: DATA
data_param {
source: "/path/to/your/images/to/extract/feature/images_leveldb"
mean_file: "/path/to/your/image_mean.binaryproto"
batch_size: 128
crop_size: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
name: "drop7"
type: DROPOUT
dropout_param {
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
std::string feature_extraction_proto(argv[++arg_pos]);
shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(feature_extraction_proto, caffe::TEST));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); std::string extract_feature_blob_names(argv[++arg_pos]);
std::vector<std::string> blob_names;
boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); std::string save_feature_dataset_names(argv[++arg_pos]);
std::vector<std::string> dataset_names;
boost::split(dataset_names, save_feature_dataset_names,
boost::is_any_of(","));
CHECK_EQ(blob_names.size(), dataset_names.size()) <<
" the number of blob names and dataset names must be equal";
size_t num_features = blob_names.size(); for (size_t i = ; i < num_features; i++) {
CHECK(feature_extraction_net->has_blob(blob_names[i]))
<< "Unknown feature blob name " << blob_names[i]
<< " in the network " << feature_extraction_proto;
} int num_mini_batches = atoi(argv[++arg_pos]); std::vector<shared_ptr<db::DB> > feature_dbs;
std::vector<shared_ptr<db::Transaction> > txns;
const char* db_type = argv[++arg_pos];
for (size_t i = ; i < num_features; ++i) {
LOG(INFO)<< "Opening dataset " << dataset_names[i];
shared_ptr<db::DB> db(db::GetDB(db_type));
db->Open(dataset_names.at(i), db::NEW);
feature_dbs.push_back(db);
shared_ptr<db::Transaction> txn(db->NewTransaction());
txns.push_back(txn);
} LOG(ERROR)<< "Extacting Features"; Datum datum;
const int kMaxKeyStrLength = ;
char key_str[kMaxKeyStrLength];
std::vector<Blob<float>*> input_vec;
std::vector<int> image_indices(num_features, );
for (int batch_index = ; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
for (int i = ; i < num_features; ++i) {
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
->blob_by_name(blob_names[i]);
int batch_size = feature_blob->num();
int dim_features = feature_blob->count() / batch_size;
const Dtype* feature_blob_data;
for (int n = ; n < batch_size; ++n) {
datum.set_height(feature_blob->height());
datum.set_width(feature_blob->width());
datum.set_channels(feature_blob->channels());
datum.clear_data();
datum.clear_float_data();
feature_blob_data = feature_blob->cpu_data() +
feature_blob->offset(n);
for (int d = ; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
image_indices[i]);
string out;
CHECK(datum.SerializeToString(&out));
txns.at(i)->Put(std::string(key_str, length), out);
++image_indices[i];
if (image_indices[i] % == ) {
txns.at(i)->Commit();
txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
}
} // for (int n = 0; n < batch_size; ++n)
} // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
for (int i = ; i < num_features; ++i) {
if (image_indices[i] % != ) {
txns.at(i)->Commit();
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
feature_dbs.at(i)->Close();
} LOG(ERROR)<< "Successfully extracted the features!";
return ;
}

caffe: test code for Deep Learning approach的更多相关文章

  1. 论文笔记之:From Facial Parts Responses to Face Detection: A Deep Learning Approach

    From Facial Parts Responses to Face Detection: A Deep Learning Approach ICCV 2015 从以上两张图就可以感受到本文所提方法 ...

  2. 《3-D Deep Learning Approach for Remote Sensing Image Classification》论文笔记

    论文题目<3-D Deep Learning Approach for Remote Sensing Image Classification> 论文作者:Amina Ben Hamida ...

  3. 论文阅读 | DeepDrawing: A Deep Learning Approach to Graph Drawing

    作者:Yong Wang, Zhihua Jin, Qianwen Wang, Weiwei Cui, Tengfei Ma and Huamin Qu 本文发表于VIS2019, 来自于香港科技大学 ...

  4. (转) Awesome Deep Learning

    Awesome Deep Learning  Table of Contents Free Online Books Courses Videos and Lectures Papers Tutori ...

  5. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  6. What are some good books/papers for learning deep learning?

    What's the most effective way to get started with deep learning?       29 Answers     Yoshua Bengio, ...

  7. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  8. (转) Awesome - Most Cited Deep Learning Papers

    转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...

  9. 深度学习阅读列表 Deep Learning Reading List

    Reading List List of reading lists and survey papers: Books Deep Learning, Yoshua Bengio, Ian Goodfe ...

随机推荐

  1. LepideMigrator for Documents Step by Step

    blog: http://blog.csdn.net/foxdave A Manager Marketing Operations invite me to review their product, ...

  2. php中的include()的使用技巧

    php中的include()的使用技巧 include() 语句包括并运行指定文件. 以下文档也适用于 require().这两种结构除了在如何处理失败之外完全一样.include() 产生一个警告而 ...

  3. JS中阻止默认事件与事件冒泡

    JQuery 提供了两种方式来阻止事件冒泡. 方式一:event.stopPropagation(); $("#div1").mousedown(function(event){ ...

  4. python中的面向对象编程

    在python中几乎可以完成C++里所有面向对象编程的元素. 继承:python支持多继承: class Derived(base1, base2, base3): pass 多态:python中的所 ...

  5. Mobile phones_二维树状数组

    [题意]给你一个矩阵(初始化为0)和一些操作,1 x y a表示在arr[x][y]加上a,2 l b r t 表示求左上角为(l,b),右下角为(r,t)的矩阵的和. [思路]帮助更好理解树状数组. ...

  6. windows系统mysql定时自动备份

    MySQL Administrator 工具是MySQL官方的数据库管理工具,包含在MySQL GUI Tools中,可在MySQL官方网站下载到,下载地址:http://dev.mysql.com/ ...

  7. OLAP的一些知识——接下去的项目需要的背景

    1.维是人们观察主题的特定角度,每一个维分别用一个表来描述,称为“维表”(Dimension Table),它是对维的详细描述. 2.事实表示所关注的主题,亦由表来描述,称为“事实表”(Fact Ta ...

  8. Windows Phone 8.1 Page transitions

    original: http://www.visuallylocated.com/post/2014/06/24/Page-transitions-and-animations-in-Windows- ...

  9. CentOS安装Nginx负载

    一.准备 1.安装Nginx 地址:http://www.cnblogs.com/rainy-shurun/p/4983260.html 二.配置负载 1.配置nginx.conf

  10. 博客Mac桌面编辑器-cnblogs

    Mac篇 公司的机器内存只有8G,不想再大动干戈为了Windows Live Writer装个Vmware了,谷歌娘讲MarsEdit不错,那就试试用这个写个试用贴呗   就是这货了,果然是火星来的, ...