#include <stdio.h>  // for snprintf
#include <string>
#include <vector> #include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h" #include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp" using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using boost::shared_ptr;
using std::string;
namespace db = caffe::db; template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv); int main(int argc, char** argv) {
return feature_extraction_pipeline<float>(argc, argv);
// return feature_extraction_pipeline<double>(argc, argv);
} template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
::google::InitGoogleLogging(argv[]);
const int num_required_args = ;
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
"Usage: extract_features pretrained_net_param"
" feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
" save_feature_dataset_name1[,name2,...] num_mini_batches db_type"
" [CPU/GPU] [DEVICE_ID=0]\n"
"Note: you can extract multiple features in one pass by specifying"
" multiple feature blob names and dataset names separated by ','."
" The names cannot contain white space characters and the number of blobs"
" and datasets must be equal.";
return ;
}
int arg_pos = num_required_args; arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == ) {
LOG(ERROR)<< "Using GPU";
uint device_id = ;
if (argc > arg_pos + ) {
device_id = atoi(argv[arg_pos + ]);
CHECK_GE(device_id, );
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
Caffe::set_mode(Caffe::GPU);
} else {
LOG(ERROR) << "Using CPU";
Caffe::set_mode(Caffe::CPU);
} arg_pos = ; // the name of the executable
std::string pretrained_binary_proto(argv[++arg_pos]); // Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
name: "data_layer_name"
type: DATA
data_param {
source: "/path/to/your/images/to/extract/feature/images_leveldb"
mean_file: "/path/to/your/image_mean.binaryproto"
batch_size: 128
crop_size: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
name: "drop7"
type: DROPOUT
dropout_param {
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
std::string feature_extraction_proto(argv[++arg_pos]);
shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(feature_extraction_proto, caffe::TEST));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); std::string extract_feature_blob_names(argv[++arg_pos]);
std::vector<std::string> blob_names;
boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); std::string save_feature_dataset_names(argv[++arg_pos]);
std::vector<std::string> dataset_names;
boost::split(dataset_names, save_feature_dataset_names,
boost::is_any_of(","));
CHECK_EQ(blob_names.size(), dataset_names.size()) <<
" the number of blob names and dataset names must be equal";
size_t num_features = blob_names.size(); for (size_t i = ; i < num_features; i++) {
CHECK(feature_extraction_net->has_blob(blob_names[i]))
<< "Unknown feature blob name " << blob_names[i]
<< " in the network " << feature_extraction_proto;
} int num_mini_batches = atoi(argv[++arg_pos]); std::vector<shared_ptr<db::DB> > feature_dbs;
std::vector<shared_ptr<db::Transaction> > txns;
const char* db_type = argv[++arg_pos];
for (size_t i = ; i < num_features; ++i) {
LOG(INFO)<< "Opening dataset " << dataset_names[i];
shared_ptr<db::DB> db(db::GetDB(db_type));
db->Open(dataset_names.at(i), db::NEW);
feature_dbs.push_back(db);
shared_ptr<db::Transaction> txn(db->NewTransaction());
txns.push_back(txn);
} LOG(ERROR)<< "Extacting Features"; Datum datum;
const int kMaxKeyStrLength = ;
char key_str[kMaxKeyStrLength];
std::vector<Blob<float>*> input_vec;
std::vector<int> image_indices(num_features, );
for (int batch_index = ; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
for (int i = ; i < num_features; ++i) {
const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
->blob_by_name(blob_names[i]);
int batch_size = feature_blob->num();
int dim_features = feature_blob->count() / batch_size;
const Dtype* feature_blob_data;
for (int n = ; n < batch_size; ++n) {
datum.set_height(feature_blob->height());
datum.set_width(feature_blob->width());
datum.set_channels(feature_blob->channels());
datum.clear_data();
datum.clear_float_data();
feature_blob_data = feature_blob->cpu_data() +
feature_blob->offset(n);
for (int d = ; d < dim_features; ++d) {
datum.add_float_data(feature_blob_data[d]);
}
int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
image_indices[i]);
string out;
CHECK(datum.SerializeToString(&out));
txns.at(i)->Put(std::string(key_str, length), out);
++image_indices[i];
if (image_indices[i] % == ) {
txns.at(i)->Commit();
txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
}
} // for (int n = 0; n < batch_size; ++n)
} // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
for (int i = ; i < num_features; ++i) {
if (image_indices[i] % != ) {
txns.at(i)->Commit();
}
LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
" query images for feature blob " << blob_names[i];
feature_dbs.at(i)->Close();
} LOG(ERROR)<< "Successfully extracted the features!";
return ;
}

caffe: test code for Deep Learning approach的更多相关文章

  1. 论文笔记之:From Facial Parts Responses to Face Detection: A Deep Learning Approach

    From Facial Parts Responses to Face Detection: A Deep Learning Approach ICCV 2015 从以上两张图就可以感受到本文所提方法 ...

  2. 《3-D Deep Learning Approach for Remote Sensing Image Classification》论文笔记

    论文题目<3-D Deep Learning Approach for Remote Sensing Image Classification> 论文作者:Amina Ben Hamida ...

  3. 论文阅读 | DeepDrawing: A Deep Learning Approach to Graph Drawing

    作者:Yong Wang, Zhihua Jin, Qianwen Wang, Weiwei Cui, Tengfei Ma and Huamin Qu 本文发表于VIS2019, 来自于香港科技大学 ...

  4. (转) Awesome Deep Learning

    Awesome Deep Learning  Table of Contents Free Online Books Courses Videos and Lectures Papers Tutori ...

  5. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  6. What are some good books/papers for learning deep learning?

    What's the most effective way to get started with deep learning?       29 Answers     Yoshua Bengio, ...

  7. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  8. (转) Awesome - Most Cited Deep Learning Papers

    转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...

  9. 深度学习阅读列表 Deep Learning Reading List

    Reading List List of reading lists and survey papers: Books Deep Learning, Yoshua Bengio, Ian Goodfe ...

随机推荐

  1. iOS应用崩溃日志分析

    转自raywenderlich   作为一名应用开发者,你是否有过如下经历?   为确保你的应用正确无误,在将其提交到应用商店之前,你必定进行了大量的测试工作.它在你的设备上也运行得很好,但是,上了应 ...

  2. Typographical Concepts

    Glyph(字形) A glyph is an element of writing: an individual mark on a written medium that contributes ...

  3. iOS 中 #import同@class之间的区别

    很多刚开始学习iOS开发的同学可能在看别人的代码的时候会发现有部分#import操作写在m文件中,而h文件仅仅使用@class进行声明,不禁纳闷起来,为什么不直接把#import放到h文件中呢? 这是 ...

  4. Listview没有优化之前

    MainActivity.java package com.example.listviewdemo4; import java.util.ArrayList; import java.util.Ha ...

  5. poj1181 大数分解

    //Accepted 164 KB 422 ms //类似poj2429 大数分解 #include <cstdio> #include <cstring> #include ...

  6. hdu 2042

    Ps:...好简单的一道题...直接AC,就是利用递归 代码: #include "stdio.h"int find(int num,int n);int main(){ int ...

  7. linux下vi命令

    Vi共分三种模式,分别是“一般模式”.“编辑模式”与“命令行命令模式”. 1.一般模式:vi处理文件时,一进入该文件就是一般模式.在这个模式中,可以使用“上下左右”键来移动光标,可以使用“删除字符”或 ...

  8. swift系统学习控件篇:UIbutton+UIlabel+UITextField+UISwitch+UISlider

    工作之余,学习下swift大法.把自己的学习过程分享一下.当中的布局很乱,就表在意这些细节了.直接上代码: UIButton+UILabel // // ViewController.swift // ...

  9. 在 Linux 的 KVM虚拟机 上安装 Mac OS 系统的研究总结

    在 Linux 的 KVM虚拟机 上安装 Mac OS 系统的研究总结 一.资料来源:    网上一共找到两个方法,一个是视频上的教程,一个是网页资料. 二.视频资料方法内容:1.install qe ...

  10. codeforces 580C Kefa and Park(DFS)

    题目链接:http://codeforces.com/contest/580/problem/C #include<cstdio> #include<vector> #incl ...