因为最近入坑Caffe2,它最近还一直在更新,所以坑比较多,官方也只给出了python的demo,C++的暂时还找不到,有也只有一个简单版的,不够用,所以就总结了一下,结合网上和自己的实践,整理了一下代码。

#include "caffe2/core/flags.h"
#include "caffe2/core/init.h"
#include "caffe2/core/predictor.h"
#include "caffe2/utils/proto_utils.h"
#include <opencv2/opencv.hpp>
#include <ctime> namespace caffe2 { std::unique_ptr<Blob> randomTensor(
const std::vector<TIndex>& dims,
CPUContext* ctx
){
auto blob = make_unique<Blob>();
auto* t = blob->GetMutable<TensorCPU>();
t->Resize(dims);
math::RandUniform<float, CPUContext>(
t->size(), -1.0, 1.0, t->template mutable_data<float>(), ctx
);
return blob;
} void run() {
// 定义初始化网络结构与权重值
caffe2::NetDef init_net, predict_net;
DeviceOption op;
op.set_random_seed(1701); std::unique_ptr<CPUContext> ctx_;
ctx_ = caffe2::make_unique<CPUContext>(op); // 读入网络结构文件
ReadProtoFromFile("squeezenet/exec_net.pb", &init_net);
ReadProtoFromFile("squeezenet/predict_net.pb", &predict_net); // Can be large due to constant fills
VLOG(1) << "Init net: " << ProtoDebugString(init_net);
LOG(INFO) << "Predict net: " << ProtoDebugString(predict_net);
auto predictor = caffe2::make_unique<Predictor>(init_net, predict_net);
LOG(INFO) << "Checking that a null forward-pass works"; // 用opencv的方式读入文件
cv::Mat bgr_img = cv::imread("cat.jpg", -1); int height = bgr_img.rows;
int width = bgr_img.cols; // 输入图像大小
const int predHeight = 256;
const int predWidth = 256;
const int crops = 1; // crops等于1表示batch的数量为1
const int channels = 3; // 通道数为3,表示BGR,为1表示灰度图
const int size = predHeight * predWidth;
const float hscale = ((float)height) / predHeight; // 计算缩放比例
const float wscale = ((float)width) / predWidth;
const float scale = std::min(hscale, wscale);
// 初始化网络的输入,因为可能要做batch操作,所以分配一段连续的存储空间
std::vector<float> inputPlanar(crops * channels * predHeight * predWidth); std::cout << "before resizing, bgr_img.cols=" << bgr_img.cols << ", bgr_img.rows=" << bgr_img.rows << std::endl;
// resize成想要的输入大小
cv::Size dsize = cv::Size(bgr_img.cols / wscale, bgr_img.rows / hscale);
cv::resize(bgr_img, bgr_img, dsize);
std::cout << "after resizing, bgr_img.cols=" << bgr_img.cols << ", bgr_img.rows=" << bgr_img.rows << std::endl;
// Scale down the input to a reasonable predictor size.
// 这里是将图像复制到连续的存储空间内,用于网络的输入,因为是BGR三通道,所以有三个赋值
// 注意imread读入的图像格式是unsigned char,如果你的网络输入要求是float的话,下面的操作就不对了。
for (auto i=0; i<predHeight; i++) {
//printf("+\n");
for (auto j=0; j<predWidth; j++) {
inputPlanar[i * predWidth + j + 0*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 0];
inputPlanar[i * predWidth + j + 1*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 1];
inputPlanar[i * predWidth + j + 2*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 2];
}
}
// 输入是float格式
//for (auto i = 0; i < predHeight; i++) {
// 模版的输入格式是float
// const float* inData = bgr_img.ptr<float>(i);
// for (auto j = 0; j < predWidth; j++) {
// inputPlanar[i * predWidth + j + 0 * size] = (float)((inData[j]) * 3 + 0);
// inputPlanar[i * predWidth + j + 1 * size] = (float)((inData[j]) * 3 + 1);
// inputPlanar[i * predWidth + j + 2 * size] = (float)((inData[j]) * 3 + 2);
// }
//} //typedef Tensor<CPUContext> TensorCPU;
// input就是网络的输入,所以把之前准备好的数据赋值给input就可以了
caffe2::TensorCPU input;
input.Resize(std::vector<int>({crops, channels, predHeight, predWidth}));
input.ShareExternalPointer(inputPlanar.data()); //Predictor::TensorVector inputVec{inputData->template GetMutable<TensorCPU>()};
Predictor::TensorVector inputVec{&input}; Predictor::TensorVector outputVec;
//predictor->run(inputVec, &outputVec);
//CAFFE_ENFORCE_GT(outputVec.size(), 0); std::clock_t begin = clock(); //begin time of inference
// 预测
predictor->run(inputVec, &outputVec); //std::cout << "CAFFE2_LOG_THRESHOLD=" << CAFFE2_LOG_THRESHOLD << std::endl;
//std::cout << "init_net.name()" << init_net.name(); std::clock_t end = clock();
double elapsed_secs = double(end-begin) / CLOCKS_PER_SEC; std::cout << "inference takes " << elapsed_secs << std::endl; float max_value = 0;
int best_match_index = -1;
// 迭代输出结果,output的大小就是网络输出的大小
for(auto output : outputVec) {
for(auto i=0; i<output->size(); ++i){
// val对应的就是每一类的概率值
float val = output->template data<float>()[i];
if(val>0.001){
printf("%i: %s : %f\n", i, imagenet_classes[i], val);
if(val>max_value) {
max_value = val;
best_match_index = i;
}
}
}
}
// 这里是用imagenet数据集为例
std::cout << "predicted result is:" << imagenet_classes[best_match_index] << ", with confidence of " << max_value << std::endl; }
} int main(int argc, char** argv) {
caffe2::GlobalInit(&argc, &argv);
caffe2::run();
// This is to allow us to use memory leak checks.
google::protobuf::ShutdownProtobufLibrary();
return 0;
}

Caffe2——C++ 预测(predict)Demo的更多相关文章

  1. 矩池云 | 利用LSTM框架实时预测比特币价格

    温馨提示:本案例只作为学习研究用途,不构成投资建议. 比特币的价格数据是基于时间序列的,因此比特币的价格预测大多采用LSTM模型来实现. 长期短期记忆(LSTM)是一种特别适用于时间序列数据(或具有时 ...

  2. 使用C++版本Mxnett进行预测的注意事项

    现在越来越多的人选择Mxnet作为深度学习框架,相应的中文社区非常活跃,而且后面推出的gluon以及gluoncv非常适合上手和实验,特别是gluoncv中提供了非常多.非常新的预训练model zo ...

  3. 天池新人赛-天池新人实战赛o2o优惠券使用预测(一)

    第一次参加天池新人赛,主要目的还是想考察下自己对机器学习上的成果,以及系统化的实现一下所学的东西.看看自己的掌握度如何,能否顺利的完成一个分析工作.为之后的学习奠定基础. 这次成绩并不好,只是把整个机 ...

  4. keras系列︱迁移学习:利用InceptionV3进行fine-tuning及预测、完美案例(五)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72982230 之前在博客<keras系列︱图像多分类训练与利用bottlenec ...

  5. Titanic幸存预测分析(Kaggle)

    分享一篇kaggle入门级案例,泰坦尼克号幸存遇难分析. 参考文章: 技术世界,原文链接 http://www.jasongj.com/ml/classification/ 案例分析内容: 通过训练集 ...

  6. 目标跟踪之卡尔曼滤波---理解Kalman滤波的使用预测

    Kalman滤波简介 Kalman滤波是一种线性滤波与预测方法,原文为:A New Approach to Linear Filtering and Prediction Problems.文章推导很 ...

  7. deepmoji:文本预测emoji

    输入句子,预测emoji demo: https://deepmoji.mit.edu/ github: https://github.com/bfelbo/DeepMoji  能够被预测的emoji ...

  8. PyTorch基础——预测共享单车的使用量

    预处理实验数据 读取数据 下载数据 网盘链接:https://pan.baidu.com/s/1n_FtZjAswWR9rfuI6GtDhA 提取码:y4fb #导入需要使用的库 import num ...

  9. 推断(inference)和预测(prediction)

    上二年级的大儿子一直在喝无乳糖牛奶,最近让他尝试喝正常牛奶,看看反应如何.三天过后,儿子说,好像没反应,我可不可以说我不对乳糖敏感了. 我说,呃,这个问题不简单啊.你知道吗,这在统计学上叫推断. 儿子 ...

随机推荐

  1. iOS-tableView本地动画刷新

    比如:就拿删除tableView中一个Cell为例子. // XXXTableViewCellDelegate - (void)tapDeleteHelloUser:(CJHelloTableView ...

  2. 51、自定义View基础和原理

    一.编写自己的自定义View最简单的自定义View,继承View通过覆盖View的onDraw方法来实现自主显示利用Canvas和paint来绘制显示元素(文字,几何图形等) <com.myvi ...

  3. Kotlin——初级篇(八):关于字符串(String)常用操作汇总

    在前面讲解Kotlin数据类型的时候,提到了字符串类型,当然关于其定义在前面的章节中已经讲解过了.对Kotlin中的数据类型不清楚的同学.请参考Kotlin--初级篇(三):数据类型详解这篇文章. 在 ...

  4. 判断 null undefined NaN

    1.判断undefined: var tmp = undefined; if (typeof(tmp) == "undefined"){ alert("undefined ...

  5. Thrift快速入门

    Thrift 简单示例 2017-01-19 16:47:57 首先通过先面两个示例简单感受一下Thrift(RPC)服务端与客户端之间的通信...... RPC学习----Thrift快速入门和Ja ...

  6. CodeForeces 665C Simple Strings

    C. Simple Strings time limit per test 2 seconds memory limit per test 256 megabytes input standard i ...

  7. SQL server中使用临时表存储数据

    将查询出来的数据直接用“INTO #临时表名称”的方式完成临时表的创建及数据的插入 SELECT * INTO #temp_NowStatusFROM Test SELECT * FROM #temp ...

  8. Understanding Tensorflow using Go

    原文: https://pgaleone.eu/tensorflow/go/2017/05/29/understanding-tensorflow-using-go/ Tensorflow is no ...

  9. Setting IE11 with Group Policy Preferences

    一.Setting Home Page with Group Policy Preferences 1.Open the Group Policy Management Console and cre ...

  10. CXF 框架

    1. 搭建服务端(查询天气) // 1. 引入cxf的 jar 包; // 2. 创建 SEI 接口, 需要加入注解: @WebService @WebService public interface ...