tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include <opencv2/opencv.hpp>
#include <cv.h>
#include <highgui.h>
#include <Eigen/Core>
#include <Eigen/Dense> using namespace std;
using namespace cv;
using namespace tensorflow; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_8UC3); // CV_32FC3
//img=1-img/255; //创建一个指向tensor的内容的指针
uint8 *p = output_tensor->flat<uint8>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_8UC3, p);
img.convertTo(tempMat,CV_8UC3); // waitKey(0); } int main()
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../model/coco.pb";
string image_path="../test.jpg";
int input_height = ;
int input_width = ;
string input_tensor_name="image_tensor";
vector<string> out_put_nodes; //注意,在object detection中输出的三个节点名称为以下三个
out_put_nodes.push_back("detection_scores"); //detection_scores detection_classes detection_boxes
out_put_nodes.push_back("detection_classes");
out_put_nodes.push_back("detection_boxes"); /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img;
img = imread(image_path);
cvtColor(img, img, CV_BGR2RGB);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_UINT8, TensorShape({,input_height,input_width,})); //DT_FLOAT //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs; Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {out_put_nodes}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
} //把输出值给提取出
cout << "Output tensor size:" << outputs.size() << std::endl; //
for (int i = ; i < outputs.size(); i++)
{
cout << outputs[i].DebugString()<<endl; // [1, 50], [1, 50], [1, 50, 4]
} cvtColor(img, img, CV_RGB2BGR); // opencv读入的是BGR格式输入网络前转为RGB
resize(img,img,cv::Size(,)); // 模型输入图像大小
int pre_num = outputs[].dim_size(); // 50 模型预测的目标数量
auto tmap_pro = outputs[].tensor<float, >(); //第一个是score输出shape为[1,50]
auto tmap_clas = outputs[].tensor<float, >(); //第二个是class输出shape为[1,50]
auto tmap_coor = outputs[].tensor<float, >(); //第三个是coordinate输出shape为[1,50,4]
float probability = 0.5; //自己设定的score阈值
for (int pre_i = ; pre_i < pre_num; pre_i++)
{
if (tmap_pro(, pre_i) < probability)
{
break;
}
cout << "Class ID: " << tmap_clas(, pre_i) << endl;
cout << "Probability: " << tmap_pro(, pre_i) << endl;
string id = to_string(int(tmap_clas(, pre_i)));
int xmin = int(tmap_coor(, pre_i, ) * input_width);
int ymin = int(tmap_coor(, pre_i, ) * input_height);
int xmax = int(tmap_coor(, pre_i, ) * input_width);
int ymax = int(tmap_coor(, pre_i, ) * input_height);
cout << "Xmin is: " << xmin << endl;
cout << "Ymin is: " << ymin << endl;
cout << "Xmax is: " << xmax << endl;
cout << "Ymax is: " << ymax << endl;
rectangle(img, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(, , ), , , );
putText(img, id, cvPoint(xmin, ymin), FONT_HERSHEY_COMPLEX, 1.0, Scalar(,,), );
}
imshow("", img);
cvWaitKey(); return ;
}
CMakeLists.txt内容如下
cmake_minimum_required(VERSION 3.0.)
project(tensorflow_cpp) set(CMAKE_CXX_STANDARD ) find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
find_package(OpenCV 2.4. QUIET)
if(NOT OpenCV_FOUND)
message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
endif()
endif() set(TENSORFLOW_INCLUDES
/usr/local/include/tf/
/usr/local/include/tf/bazel-genfiles
/usr/local/include/tf/tensorflow/
/usr/local/include/tf/tensorflow/third_party) set(TENSORFLOW_LIBS
/usr/local/lib/libtensorflow_cc.so
/usr/local/lib//libtensorflow_framework.so) include_directories(
${TENSORFLOW_INCLUDES}
${PROJECT_SOURCE_DIR}/third_party/eigen3
)
add_executable(predict predict.cpp)
target_link_libraries(predict
${TENSORFLOW_LIBS}
${OpenCV_LIBS}
)
目录结构如图所示

tensorflow C++接口调用目标检测pb模型代码的更多相关文章
- Mask R-CNN用于目标检测和分割代码实现
		
Mask R-CNN用于目标检测和分割代码实现 Mask R-CNN for object detection and instance segmentation on Keras and Tenso ...
 - 大话目标检测经典模型(RCNN、Fast RCNN、Faster RCNN)
		
目标检测是深度学习的一个重要应用,就是在图片中要将里面的物体识别出来,并标出物体的位置,一般需要经过两个步骤:1.分类,识别物体是什么 2.定位,找出物体在哪里 除了对单个物体进行检测,还要能支持 ...
 - 使用Faster R-CNN做目标检测 - 学习luminoth代码
		
像玩乐高一样拆解Faster R-CNN:详解目标检测的实现过程 https://mp.weixin.qq.com/s/M_i38L2brq69BYzmaPeJ9w 直接参考开源目标检测代码lumin ...
 - tensorflow C++接口调用图像分类pb模型代码
		
#include <fstream> #include <utility> #include <Eigen/Core> #include <Eigen/Den ...
 - R2CNN模型——用于文本目标检测的模型
		
引言 R2CNN全称Rotational Region CNN,是一个针对斜框文本检测的CNN模型,原型是Faster R-CNN,paper中的模型主要针对文本检测,调整后也可用于航拍图像的检测中去 ...
 - [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
		
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
 - OpenVINO 目标检测底层C++代码改写实现(待优化)
		
System: Centos7.4 I:OpenVINO 的安装 refer:https://docs.openvinotoolkit.org/latest/_docs_install_guides_ ...
 - 评价目标检测(object detection)模型的参数:IOU,AP,mAP
		
首先我们为什么要使用这些呢? 举个简单的例子,假设我们图像里面只有1个目标,但是定位出来10个框,1个正确的,9个错误的,那么你要按(识别出来的正确的目标/总的正确目标)来算,正确率100%,但是其实 ...
 - AI佳作解读系列(二)——目标检测AI算法集杂谈:R-CNN,faster R-CNN,yolo,SSD,yoloV2,yoloV3
		
1 引言 深度学习目前已经应用到了各个领域,应用场景大体分为三类:物体识别,目标检测,自然语言处理.本文着重与分析目标检测领域的深度学习方法,对其中的经典模型框架进行深入分析. 目标检测可以理解为是物 ...
 
随机推荐
- 记-OSPF学习
			
LSA Type 1:Router LSA1.传播范围 :只能在本区域2.通告者 :每台路由器 (router-id作为标识)3.内容 :路由和拓扑信息show ip ospf database ro ...
 - HDU - 3068 最长回文(manacher算法)
			
题意:给出一个只由小写英文字符a,b,c...y,z组成的字符串S,求S中最长回文串的长度. 分析: manacher算法: 1.将字符串中每个字符的两边都插入一个特殊字符.(此操作的目的是,将字符串 ...
 - opencv 读写XML YML
			
//序列没有标签 CvMemStorage *mem = cvCreateMemStorage(0); CvFileStorage *file = cvOpenFileStorage("e: ...
 - apache2+django+virtualenv 服务器部署实战
			
目录 基本配置 配置python环境 安装 python.pip 安装 virtualenv 配置python虚拟环境 配置 apache2 安装 apache2 安装 mod-wsgi 服务 部署d ...
 - java课程之团队开发第一阶段评论
			
1.没有UI设计,整体的样式感觉不堪入目 2.功能方面实现的并不是很多,还需要继续努力 3.还需要添加一些常用的课表功能,比如说导入课表等
 - GPRS模块
			
一.参考网址 1.AT指令(中文详解版)(二)
 - springmvc的InternalResourceViewResolver自我理解
			
原文链接:https://blog.csdn.net/wwzuizz/article/details/78268007 它的作用是在Controller返回的时候进行解析视图 @RequestMapp ...
 - UE手游如何应对CPU帧率瓶颈和卡顿?
			
如何高效准确详细的对性能进行剖析?腾讯游戏学院专家Leonn将归纳总结在UE下对每一性能指标的剖析方法,本文重点讲解如何应对CPU帧率瓶颈和卡顿? CPU上帧率低和卡顿是性能优化中最易出现的一部分,尤 ...
 - 一天一个设计模式——Prototype 原型模式
			
一.模式说明 看了比较多的资料,对原型模式写的比较复杂,个人的理解就是模型复制,根据现有的类来直接创建新的类,而不是调用类的构造函数. 那为什么不直接调用new方法来创建类的实例呢,主要一个原因是如果 ...
 - 判断苹果和安卓端或者wp端
			
window.onload = function() { var u = navigator.userAgent; if(u.indexOf('Android') > -1 || u.index ...