#include <stdio.h>  // for snprintf
#include <string>
#include <vector> #include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h" #include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp" using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using boost::shared_ptr;
using std::string;
namespace db = caffe::db; template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv); int main(int argc, char** argv) {
return feature_extraction_pipeline<float>(argc, argv);
// return feature_extraction_pipeline<double>(argc, argv);
} template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
::google::InitGoogleLogging(argv[]);
const int num_required_args = ;
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_dataset_name1[,name2,...] num_mini_batches db_type"
" [CPU/GPU] [DEVICE_ID=0]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and dataset names separated by ','."
" The names cannot contain white space characters and the number of blobs"
" and datasets must be equal.";
return ;
}
int arg_pos = num_required_args; arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == ) {
LOG(ERROR)<< "Using GPU";
uint device_id = ;
if (argc > arg_pos + ) {
device_id = atoi(argv[arg_pos + ]);
CHECK_GE(device_id, );
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
} arg_pos = ; // the name of the executable
std::string pretrained_binary_proto(argv[++arg_pos]); // Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
name: "data_layer_name"
type: DATA
data_param {
source: "/path/to/your/images/to/extract/feature/images_leveldb"
mean_file: "/path/to/your/image_mean.binaryproto"
batch_size: 128
crop_size: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
name: "drop7"
type: DROPOUT
dropout_param {
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
std::string feature_extraction_proto(argv[++arg_pos]);
shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(feature_extraction_proto, caffe::TEST));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); std::string extract_feature_blob_names(argv[++arg_pos]);
std::vector<std::string> blob_names;
boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); std::string save_feature_dataset_names(argv[++arg_pos]);
std::vector<std::string> dataset_names;
boost::split(dataset_names, save_feature_dataset_names,
boost::is_any_of(","));
CHECK_EQ(blob_names.size(), dataset_names.size()) <<
" the number of blob names and dataset names must be equal";
size_t num_features = blob_names.size(); for (size_t i = ; i < num_features; i++) {
CHECK(feature_extraction_net->has_blob(blob_names[i]))
<< "Unknown feature blob name " << blob_names[i]
<< " in the network " << feature_extraction_proto;
} int num_mini_batches = atoi(argv[++arg_pos]); std::vector<shared_ptr<db::DB> > feature_dbs;
std::vector<shared_ptr<db::Transaction> > txns;
const char* db_type = argv[++arg_pos];
for (size_t i = ; i < num_features; ++i) {
LOG(INFO)<< "Opening dataset " << dataset_names[i];
shared_ptr<db::DB> db(db::GetDB(db_type));
db->Open(dataset_names.at(i), db::NEW);
feature_dbs.push_back(db);
shared_ptr<db::Transaction> txn(db->NewTransaction());
txns.push_back(txn);
} LOG(ERROR)<< "Extacting Features"; Datum datum;
const int kMaxKeyStrLength = ;
char key_str[kMaxKeyStrLength];
std::vector<Blob<float>*> input_vec;
std::vector<int> image_indices(num_features, );
for (int batch_index = ; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
for (int i = ; i < num_features; ++i) {
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
->blob_by_name(blob_names[i]);
int batch_size = feature_blob->num();
int dim_features = feature_blob->count() / batch_size;
const Dtype* feature_blob_data;
for (int n = ; n < batch_size; ++n) {
datum.set_height(feature_blob->height());
datum.set_width(feature_blob->width());
datum.set_channels(feature_blob->channels());
datum.clear_data();
datum.clear_float_data();
feature_blob_data = feature_blob->cpu_data() +
feature_blob->offset(n);
for (int d = ; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
image_indices[i]);
string out;
CHECK(datum.SerializeToString(&out));
txns.at(i)->Put(std::string(key_str, length), out);
++image_indices[i];
if (image_indices[i] % == ) {
txns.at(i)->Commit();
txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
}
} // for (int n = 0; n < batch_size; ++n)
} // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
for (int i = ; i < num_features; ++i) {
if (image_indices[i] % != ) {
txns.at(i)->Commit();
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
feature_dbs.at(i)->Close();
} LOG(ERROR)<< "Successfully extracted the features!";
return ;
}

caffe: test code for Deep Learning approach的更多相关文章

  1. 论文笔记之:From Facial Parts Responses to Face Detection: A Deep Learning Approach

    From Facial Parts Responses to Face Detection: A Deep Learning Approach ICCV 2015 从以上两张图就可以感受到本文所提方法 ...

  2. 《3-D Deep Learning Approach for Remote Sensing Image Classification》论文笔记

    论文题目<3-D Deep Learning Approach for Remote Sensing Image Classification> 论文作者:Amina Ben Hamida ...

  3. 论文阅读 | DeepDrawing: A Deep Learning Approach to Graph Drawing

    作者:Yong Wang, Zhihua Jin, Qianwen Wang, Weiwei Cui, Tengfei Ma and Huamin Qu 本文发表于VIS2019, 来自于香港科技大学 ...

  4. (转) Awesome Deep Learning

    Awesome Deep Learning  Table of Contents Free Online Books Courses Videos and Lectures Papers Tutori ...

  5. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  6. What are some good books/papers for learning deep learning?

    What's the most effective way to get started with deep learning?       29 Answers     Yoshua Bengio, ...

  7. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  8. (转) Awesome - Most Cited Deep Learning Papers

    转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...

  9. 深度学习阅读列表 Deep Learning Reading List

    Reading List List of reading lists and survey papers: Books Deep Learning, Yoshua Bengio, Ian Goodfe ...

随机推荐

  1. 转:Java面试题集(1-50)

    Java程序员面试题集(1-50) http://blog.csdn.net/jackfrued/article/details/17403101 一.Java基础部分 1.面向对象的特征有哪些方面? ...

  2. swift简介

    概述 Swift是苹果2014年推出的全新的编程语言,它继承了C语言.ObjC的特性,且克服了C语言的兼容性问题.Swift发展过程中不仅保留了ObjC很多语法特性,它也借鉴了多种现代化语言的特点,在 ...

  3. Linux命令(2)-rm删除文件

    版本:centos7 Linux中使用rm(remove)命令将文件从磁盘上永久删除.使用-r参数可以删除目录及目录下的子目录.对于连接文件只是断开了连接,源文件保持不变.用户删除一个文件时需要对该文 ...

  4. JQuery源码分析(三)

    jQuery中ready与load事件 jQuery有3种针对文档加载的方法 $(document).ready(function() { // ...代码... }) //document read ...

  5. Android PermissionChecker 权限全面详细分析和解决方案

    原文: http://www.2cto.com/kf/201512/455888.html http://blog.csdn.net/yangqingqo/article/details/483711 ...

  6. Unity3D ShaderLab立方体图的法线渲染

    Unity3D ShaderLab立方体图的法线渲染 某些情况下,我们希望立方体图的材质球上产生法线效果,来更多的表现细节,比如菱形花纹的玻璃,冰块的表面. 在帧数的协调下,我们可以通过input结构 ...

  7. NSIS

    NSIS 是“Nullsoft 脚本安装系统”(Nullsoft Scriptable Installation System) 的缩写,它是一个免费的 Win32 安装.卸载系统,采用了简洁高效的脚 ...

  8. [转]Golang之struct类型

    http://blog.chinaunix.net/xmlrpc.php?r=blog/article&uid=22312037&id=3756923 一.struct        ...

  9. clone 深拷贝 浅拷贝

    1. 定义:知道一个对象,但不知道类,想要得到该对象相同的一个副本,在修改该对象的属性时,副本属性不修改,clone的是对象的属性 2. 意义:当一个对象里很多属性,想要得到一个相同的对象,还有set ...

  10. dedecms 列表页 list 判断flag给定指定样式 (本地测试有效)

    {dede:list pagesize='10'} [field:array runphp='yes'] if (@me['flag']=='a') @me=' <a class="n ...