Caffe2——C++ 预测(predict)Demo
因为最近入坑Caffe2,它最近还一直在更新,所以坑比较多,官方也只给出了python的demo,C++的暂时还找不到,有也只有一个简单版的,不够用,所以就总结了一下,结合网上和自己的实践,整理了一下代码。
#include "caffe2/core/flags.h"
#include "caffe2/core/init.h"
#include "caffe2/core/predictor.h"
#include "caffe2/utils/proto_utils.h"
#include <opencv2/opencv.hpp>
#include <ctime>
namespace caffe2 {
std::unique_ptr<Blob> randomTensor(
const std::vector<TIndex>& dims,
CPUContext* ctx
){
auto blob = make_unique<Blob>();
auto* t = blob->GetMutable<TensorCPU>();
t->Resize(dims);
math::RandUniform<float, CPUContext>(
t->size(), -1.0, 1.0, t->template mutable_data<float>(), ctx
);
return blob;
}
void run() {
// 定义初始化网络结构与权重值
caffe2::NetDef init_net, predict_net;
DeviceOption op;
op.set_random_seed(1701);
std::unique_ptr<CPUContext> ctx_;
ctx_ = caffe2::make_unique<CPUContext>(op);
// 读入网络结构文件
ReadProtoFromFile("squeezenet/exec_net.pb", &init_net);
ReadProtoFromFile("squeezenet/predict_net.pb", &predict_net);
// Can be large due to constant fills
VLOG(1) << "Init net: " << ProtoDebugString(init_net);
LOG(INFO) << "Predict net: " << ProtoDebugString(predict_net);
auto predictor = caffe2::make_unique<Predictor>(init_net, predict_net);
LOG(INFO) << "Checking that a null forward-pass works";
// 用opencv的方式读入文件
cv::Mat bgr_img = cv::imread("cat.jpg", -1);
int height = bgr_img.rows;
int width = bgr_img.cols;
// 输入图像大小
const int predHeight = 256;
const int predWidth = 256;
const int crops = 1; // crops等于1表示batch的数量为1
const int channels = 3; // 通道数为3,表示BGR,为1表示灰度图
const int size = predHeight * predWidth;
const float hscale = ((float)height) / predHeight; // 计算缩放比例
const float wscale = ((float)width) / predWidth;
const float scale = std::min(hscale, wscale);
// 初始化网络的输入,因为可能要做batch操作,所以分配一段连续的存储空间
std::vector<float> inputPlanar(crops * channels * predHeight * predWidth);
std::cout << "before resizing, bgr_img.cols=" << bgr_img.cols << ", bgr_img.rows=" << bgr_img.rows << std::endl;
// resize成想要的输入大小
cv::Size dsize = cv::Size(bgr_img.cols / wscale, bgr_img.rows / hscale);
cv::resize(bgr_img, bgr_img, dsize);
std::cout << "after resizing, bgr_img.cols=" << bgr_img.cols << ", bgr_img.rows=" << bgr_img.rows << std::endl;
// Scale down the input to a reasonable predictor size.
// 这里是将图像复制到连续的存储空间内,用于网络的输入,因为是BGR三通道,所以有三个赋值
// 注意imread读入的图像格式是unsigned char,如果你的网络输入要求是float的话,下面的操作就不对了。
for (auto i=0; i<predHeight; i++) {
//printf("+\n");
for (auto j=0; j<predWidth; j++) {
inputPlanar[i * predWidth + j + 0*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 0];
inputPlanar[i * predWidth + j + 1*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 1];
inputPlanar[i * predWidth + j + 2*size] = (float)bgr_img.data[(i*predWidth + j) * 3 + 2];
}
}
// 输入是float格式
//for (auto i = 0; i < predHeight; i++) {
// 模版的输入格式是float
// const float* inData = bgr_img.ptr<float>(i);
// for (auto j = 0; j < predWidth; j++) {
// inputPlanar[i * predWidth + j + 0 * size] = (float)((inData[j]) * 3 + 0);
// inputPlanar[i * predWidth + j + 1 * size] = (float)((inData[j]) * 3 + 1);
// inputPlanar[i * predWidth + j + 2 * size] = (float)((inData[j]) * 3 + 2);
// }
//}
//typedef Tensor<CPUContext> TensorCPU;
// input就是网络的输入,所以把之前准备好的数据赋值给input就可以了
caffe2::TensorCPU input;
input.Resize(std::vector<int>({crops, channels, predHeight, predWidth}));
input.ShareExternalPointer(inputPlanar.data());
//Predictor::TensorVector inputVec{inputData->template GetMutable<TensorCPU>()};
Predictor::TensorVector inputVec{&input};
Predictor::TensorVector outputVec;
//predictor->run(inputVec, &outputVec);
//CAFFE_ENFORCE_GT(outputVec.size(), 0);
std::clock_t begin = clock(); //begin time of inference
// 预测
predictor->run(inputVec, &outputVec);
//std::cout << "CAFFE2_LOG_THRESHOLD=" << CAFFE2_LOG_THRESHOLD << std::endl;
//std::cout << "init_net.name()" << init_net.name();
std::clock_t end = clock();
double elapsed_secs = double(end-begin) / CLOCKS_PER_SEC;
std::cout << "inference takes " << elapsed_secs << std::endl;
float max_value = 0;
int best_match_index = -1;
// 迭代输出结果,output的大小就是网络输出的大小
for(auto output : outputVec) {
for(auto i=0; i<output->size(); ++i){
// val对应的就是每一类的概率值
float val = output->template data<float>()[i];
if(val>0.001){
printf("%i: %s : %f\n", i, imagenet_classes[i], val);
if(val>max_value) {
max_value = val;
best_match_index = i;
}
}
}
}
// 这里是用imagenet数据集为例
std::cout << "predicted result is:" << imagenet_classes[best_match_index] << ", with confidence of " << max_value << std::endl;
}
}
int main(int argc, char** argv) {
caffe2::GlobalInit(&argc, &argv);
caffe2::run();
// This is to allow us to use memory leak checks.
google::protobuf::ShutdownProtobufLibrary();
return 0;
}
Caffe2——C++ 预测(predict)Demo的更多相关文章
- 矩池云 | 利用LSTM框架实时预测比特币价格
温馨提示:本案例只作为学习研究用途,不构成投资建议. 比特币的价格数据是基于时间序列的,因此比特币的价格预测大多采用LSTM模型来实现. 长期短期记忆(LSTM)是一种特别适用于时间序列数据(或具有时 ...
- 使用C++版本Mxnett进行预测的注意事项
现在越来越多的人选择Mxnet作为深度学习框架,相应的中文社区非常活跃,而且后面推出的gluon以及gluoncv非常适合上手和实验,特别是gluoncv中提供了非常多.非常新的预训练model zo ...
- 天池新人赛-天池新人实战赛o2o优惠券使用预测(一)
第一次参加天池新人赛,主要目的还是想考察下自己对机器学习上的成果,以及系统化的实现一下所学的东西.看看自己的掌握度如何,能否顺利的完成一个分析工作.为之后的学习奠定基础. 这次成绩并不好,只是把整个机 ...
- keras系列︱迁移学习:利用InceptionV3进行fine-tuning及预测、完美案例(五)
引自:http://blog.csdn.net/sinat_26917383/article/details/72982230 之前在博客<keras系列︱图像多分类训练与利用bottlenec ...
- Titanic幸存预测分析(Kaggle)
分享一篇kaggle入门级案例,泰坦尼克号幸存遇难分析. 参考文章: 技术世界,原文链接 http://www.jasongj.com/ml/classification/ 案例分析内容: 通过训练集 ...
- 目标跟踪之卡尔曼滤波---理解Kalman滤波的使用预测
Kalman滤波简介 Kalman滤波是一种线性滤波与预测方法,原文为:A New Approach to Linear Filtering and Prediction Problems.文章推导很 ...
- deepmoji:文本预测emoji
输入句子,预测emoji demo: https://deepmoji.mit.edu/ github: https://github.com/bfelbo/DeepMoji 能够被预测的emoji ...
- PyTorch基础——预测共享单车的使用量
预处理实验数据 读取数据 下载数据 网盘链接:https://pan.baidu.com/s/1n_FtZjAswWR9rfuI6GtDhA 提取码:y4fb #导入需要使用的库 import num ...
- 推断(inference)和预测(prediction)
上二年级的大儿子一直在喝无乳糖牛奶,最近让他尝试喝正常牛奶,看看反应如何.三天过后,儿子说,好像没反应,我可不可以说我不对乳糖敏感了. 我说,呃,这个问题不简单啊.你知道吗,这在统计学上叫推断. 儿子 ...
随机推荐
- python 之 内置函数大全
一.罗列全部的内置函数 戳:https://docs.python.org/2/library/functions.html 二.range.xrange(迭代器) 无论是range()还是xrang ...
- 面试10大算法汇总——Java篇
问题导读 1 字符串和数组 2 链表 3 树 4 图 5 排序 6 递归 vs 迭代 7 动态规划 8 位操作 9 概率问题 10 排列组合 11 其他 -- 寻找规律 英文版 以下从Java角度解释 ...
- 洛谷oj U3936(分成回文串) 邀请码:a0c9
题目链接:传送门 题目大意:略 题目思路:DP 先预处理,分别以每个字母为中心处理能形成的回文串,再以两个字母为中心处理能形成的回文串. 然后 dp[i] 表示1~i 能形成的数目最少的回文串. 转移 ...
- sql server 作业没跑、开启sql 代理服务、新建作业
sql server 数据库中设置了晚上跑的作业,以前没注意,后来换了服务器建了新的虚拟机后第二天发现作业没跑. 主动执行作业可以实现目的,但是他不会自动执行,那么问题来了,为啥呢? 没有开启SQL ...
- Guava教程
http://ifeve.com/google-guava/ github地址:https://github.com/google/guava
- Oracle数据库的连接模式connection Mode、连接connection与会话session
数据库的连接模式Connection Mode: Dedicated Server Mode(专有模式) 当用户发出请求时,如远程的client端通过监听器连接数据库上,ORACLE的服务器端会启用一 ...
- C# 利用StringBuilder提升字符串拼接性能
一个项目中有数据图表呈现,数据量稍大时显得很慢. 用Stopwatch分段监控了一下,发现耗时最多的函数是SaveToExcel 此函数中遍列所有数据行,通过Replace替换标签生成Excel行,然 ...
- "零代码”开发B/S企业管理软件之一 :怎么创建数据库表
声明:该软件为本人原创作品,多年来一直在使用该软件做项目,软件本身也一直在改善,在增加新的功能.但一个人总是会有很多考虑不周全的地方,希望能找到做同类软件的同行一起探讨. 本人文笔不行,能把意思表达清 ...
- Dictionary里使用struct,enum做key
首先看下Dictionary的源码 public void Add (TKey key, TValue value) { if (key == null) throw new ArgumentNull ...
- centos7 docker 安装配置
docker快速入门测试 ########################################## #docker安装配置 #环境centos7 #配置docker阿里源 echo '#D ...