caffe: test code for Deep Learning approach
#include <stdio.h> // for snprintf
#include <string>
#include <vector> #include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h" #include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp" using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using boost::shared_ptr;
using std::string;
namespace db = caffe::db; template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv); int main(int argc, char** argv) {
return feature_extraction_pipeline<float>(argc, argv);
// return feature_extraction_pipeline<double>(argc, argv);
} template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
::google::InitGoogleLogging(argv[]);
const int num_required_args = ;
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_dataset_name1[,name2,...] num_mini_batches db_type"
" [CPU/GPU] [DEVICE_ID=0]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and dataset names separated by ','."
" The names cannot contain white space characters and the number of blobs"
" and datasets must be equal.";
return ;
}
int arg_pos = num_required_args; arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == ) {
LOG(ERROR)<< "Using GPU";
uint device_id = ;
if (argc > arg_pos + ) {
device_id = atoi(argv[arg_pos + ]);
CHECK_GE(device_id, );
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
} arg_pos = ; // the name of the executable
std::string pretrained_binary_proto(argv[++arg_pos]); // Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
name: "data_layer_name"
type: DATA
data_param {
source: "/path/to/your/images/to/extract/feature/images_leveldb"
mean_file: "/path/to/your/image_mean.binaryproto"
batch_size: 128
crop_size: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
name: "drop7"
type: DROPOUT
dropout_param {
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
std::string feature_extraction_proto(argv[++arg_pos]);
shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(feature_extraction_proto, caffe::TEST));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); std::string extract_feature_blob_names(argv[++arg_pos]);
std::vector<std::string> blob_names;
boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); std::string save_feature_dataset_names(argv[++arg_pos]);
std::vector<std::string> dataset_names;
boost::split(dataset_names, save_feature_dataset_names,
boost::is_any_of(","));
CHECK_EQ(blob_names.size(), dataset_names.size()) <<
" the number of blob names and dataset names must be equal";
size_t num_features = blob_names.size(); for (size_t i = ; i < num_features; i++) {
CHECK(feature_extraction_net->has_blob(blob_names[i]))
<< "Unknown feature blob name " << blob_names[i]
<< " in the network " << feature_extraction_proto;
} int num_mini_batches = atoi(argv[++arg_pos]); std::vector<shared_ptr<db::DB> > feature_dbs;
std::vector<shared_ptr<db::Transaction> > txns;
const char* db_type = argv[++arg_pos];
for (size_t i = ; i < num_features; ++i) {
LOG(INFO)<< "Opening dataset " << dataset_names[i];
shared_ptr<db::DB> db(db::GetDB(db_type));
db->Open(dataset_names.at(i), db::NEW);
feature_dbs.push_back(db);
shared_ptr<db::Transaction> txn(db->NewTransaction());
txns.push_back(txn);
} LOG(ERROR)<< "Extacting Features"; Datum datum;
const int kMaxKeyStrLength = ;
char key_str[kMaxKeyStrLength];
std::vector<Blob<float>*> input_vec;
std::vector<int> image_indices(num_features, );
for (int batch_index = ; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
for (int i = ; i < num_features; ++i) {
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
->blob_by_name(blob_names[i]);
int batch_size = feature_blob->num();
int dim_features = feature_blob->count() / batch_size;
const Dtype* feature_blob_data;
for (int n = ; n < batch_size; ++n) {
datum.set_height(feature_blob->height());
datum.set_width(feature_blob->width());
datum.set_channels(feature_blob->channels());
datum.clear_data();
datum.clear_float_data();
feature_blob_data = feature_blob->cpu_data() +
feature_blob->offset(n);
for (int d = ; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
image_indices[i]);
string out;
CHECK(datum.SerializeToString(&out));
txns.at(i)->Put(std::string(key_str, length), out);
++image_indices[i];
if (image_indices[i] % == ) {
txns.at(i)->Commit();
txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
}
} // for (int n = 0; n < batch_size; ++n)
} // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
for (int i = ; i < num_features; ++i) {
if (image_indices[i] % != ) {
txns.at(i)->Commit();
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
feature_dbs.at(i)->Close();
} LOG(ERROR)<< "Successfully extracted the features!";
return ;
}
caffe: test code for Deep Learning approach的更多相关文章
- 论文笔记之:From Facial Parts Responses to Face Detection: A Deep Learning Approach
From Facial Parts Responses to Face Detection: A Deep Learning Approach ICCV 2015 从以上两张图就可以感受到本文所提方法 ...
- 《3-D Deep Learning Approach for Remote Sensing Image Classification》论文笔记
论文题目<3-D Deep Learning Approach for Remote Sensing Image Classification> 论文作者:Amina Ben Hamida ...
- 论文阅读 | DeepDrawing: A Deep Learning Approach to Graph Drawing
作者:Yong Wang, Zhihua Jin, Qianwen Wang, Weiwei Cui, Tengfei Ma and Huamin Qu 本文发表于VIS2019, 来自于香港科技大学 ...
- (转) Awesome Deep Learning
Awesome Deep Learning Table of Contents Free Online Books Courses Videos and Lectures Papers Tutori ...
- 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】
转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...
- What are some good books/papers for learning deep learning?
What's the most effective way to get started with deep learning? 29 Answers Yoshua Bengio, ...
- 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总
<Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...
- (转) Awesome - Most Cited Deep Learning Papers
转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...
- 深度学习阅读列表 Deep Learning Reading List
Reading List List of reading lists and survey papers: Books Deep Learning, Yoshua Bengio, Ian Goodfe ...
随机推荐
- php中的include()的使用技巧
php中的include()的使用技巧 include() 语句包括并运行指定文件. 以下文档也适用于 require().这两种结构除了在如何处理失败之外完全一样.include() 产生一个警告而 ...
- SecureCRT下的串口无法输入
用串口配置交换机的时候,出现的问题: 用secureCRT建了一个串口COM1后,连接上开发板后,可以正确接受和显示串口的输出,但是按键输入无效. 解决方法: Session Options -> ...
- STM32下载显示target dll has been cancelled
使用MDK 4.74向STM32下载时出现各种错误,而且时隐时现, Internal command error.Error:Flash download failed. Target DLL has ...
- 获得servletContext的路径的方法
request.getSession().getServletContext().getRealPath("")或 System.getProperty("webapp. ...
- Android IPC(inter-process Communitcation)
Android IPC(inter-process Communitcation) http://www.cnblogs.com/imlucky/archive/2013/08/08/3246013. ...
- C/C++ 字符串 null terminal
在C/C++中,字符串以'\0'结尾是一种强制要求,或者说,只有满足这个要求的 字符数组才能被称为字符串.否则,你所做的所有操作结果都是未定义的! C标准库string.h中所有关于字符串的函数都有一 ...
- 2013 imac 安装 win7
昨天晚上安装imac win7系统,其实步骤是很简单的,首先需要一个用boot camp助手做好的win7安装U盘或者有个外接光驱加一张win7光盘,然后用boot camp助理划分一个分区给win7 ...
- linux命令:ls
1.介绍: ls是linux日常操作中用的最多命令,是list的缩写,默认按名称排序列出当前目录和文件,ls --help可以查看帮助. 2.命令格式: ls [OPTION] [FILE] 3.命令 ...
- 中级iOS开发面试题
1:MVC的理解 MVC设计模式考虑三种对象:数据模型对象,视图对象和控制器对象. 数据模型:负责存储.定义.操作数据: 视图:展示数据给用户,和用户进行操作交互: 控制器:M与V的协调者,控制获取数 ...
- 转:Web.config配置文件详解(新手必看)
转:http://www.cnblogs.com/gaoweipeng/archive/2009/05/17/1458762.html 花了点时间整理了一下ASP.NET Web.config配置文件 ...