tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include <opencv2/opencv.hpp>
#include <cv.h>
#include <highgui.h>
#include <Eigen/Core>
#include <Eigen/Dense> using namespace std;
using namespace cv;
using namespace tensorflow; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_8UC3); // CV_32FC3
//img=1-img/255; //创建一个指向tensor的内容的指针
uint8 *p = output_tensor->flat<uint8>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_8UC3, p);
img.convertTo(tempMat,CV_8UC3); // waitKey(0); } int main()
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../model/coco.pb";
string image_path="../test.jpg";
int input_height = ;
int input_width = ;
string input_tensor_name="image_tensor";
vector<string> out_put_nodes; //注意,在object detection中输出的三个节点名称为以下三个
out_put_nodes.push_back("detection_scores"); //detection_scores detection_classes detection_boxes
out_put_nodes.push_back("detection_classes");
out_put_nodes.push_back("detection_boxes"); /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img;
img = imread(image_path);
cvtColor(img, img, CV_BGR2RGB);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_UINT8, TensorShape({,input_height,input_width,})); //DT_FLOAT //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs; Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {out_put_nodes}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
} //把输出值给提取出
cout << "Output tensor size:" << outputs.size() << std::endl; //
for (int i = ; i < outputs.size(); i++)
{
cout << outputs[i].DebugString()<<endl; // [1, 50], [1, 50], [1, 50, 4]
} cvtColor(img, img, CV_RGB2BGR); // opencv读入的是BGR格式输入网络前转为RGB
resize(img,img,cv::Size(,)); // 模型输入图像大小
int pre_num = outputs[].dim_size(); // 50 模型预测的目标数量
auto tmap_pro = outputs[].tensor<float, >(); //第一个是score输出shape为[1,50]
auto tmap_clas = outputs[].tensor<float, >(); //第二个是class输出shape为[1,50]
auto tmap_coor = outputs[].tensor<float, >(); //第三个是coordinate输出shape为[1,50,4]
float probability = 0.5; //自己设定的score阈值
for (int pre_i = ; pre_i < pre_num; pre_i++)
{
if (tmap_pro(, pre_i) < probability)
{
break;
}
cout << "Class ID: " << tmap_clas(, pre_i) << endl;
cout << "Probability: " << tmap_pro(, pre_i) << endl;
string id = to_string(int(tmap_clas(, pre_i)));
int xmin = int(tmap_coor(, pre_i, ) * input_width);
int ymin = int(tmap_coor(, pre_i, ) * input_height);
int xmax = int(tmap_coor(, pre_i, ) * input_width);
int ymax = int(tmap_coor(, pre_i, ) * input_height);
cout << "Xmin is: " << xmin << endl;
cout << "Ymin is: " << ymin << endl;
cout << "Xmax is: " << xmax << endl;
cout << "Ymax is: " << ymax << endl;
rectangle(img, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(, , ), , , );
putText(img, id, cvPoint(xmin, ymin), FONT_HERSHEY_COMPLEX, 1.0, Scalar(,,), );
}
imshow("", img);
cvWaitKey(); return ;
}
CMakeLists.txt内容如下
cmake_minimum_required(VERSION 3.0.)
project(tensorflow_cpp) set(CMAKE_CXX_STANDARD ) find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
find_package(OpenCV 2.4. QUIET)
if(NOT OpenCV_FOUND)
message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
endif()
endif() set(TENSORFLOW_INCLUDES
/usr/local/include/tf/
/usr/local/include/tf/bazel-genfiles
/usr/local/include/tf/tensorflow/
/usr/local/include/tf/tensorflow/third_party) set(TENSORFLOW_LIBS
/usr/local/lib/libtensorflow_cc.so
/usr/local/lib//libtensorflow_framework.so) include_directories(
${TENSORFLOW_INCLUDES}
${PROJECT_SOURCE_DIR}/third_party/eigen3
)
add_executable(predict predict.cpp)
target_link_libraries(predict
${TENSORFLOW_LIBS}
${OpenCV_LIBS}
)
目录结构如图所示

tensorflow C++接口调用目标检测pb模型代码的更多相关文章
- Mask R-CNN用于目标检测和分割代码实现
Mask R-CNN用于目标检测和分割代码实现 Mask R-CNN for object detection and instance segmentation on Keras and Tenso ...
- 大话目标检测经典模型(RCNN、Fast RCNN、Faster RCNN)
目标检测是深度学习的一个重要应用,就是在图片中要将里面的物体识别出来,并标出物体的位置,一般需要经过两个步骤:1.分类,识别物体是什么 2.定位,找出物体在哪里 除了对单个物体进行检测,还要能支持 ...
- 使用Faster R-CNN做目标检测 - 学习luminoth代码
像玩乐高一样拆解Faster R-CNN:详解目标检测的实现过程 https://mp.weixin.qq.com/s/M_i38L2brq69BYzmaPeJ9w 直接参考开源目标检测代码lumin ...
- tensorflow C++接口调用图像分类pb模型代码
#include <fstream> #include <utility> #include <Eigen/Core> #include <Eigen/Den ...
- R2CNN模型——用于文本目标检测的模型
引言 R2CNN全称Rotational Region CNN,是一个针对斜框文本检测的CNN模型,原型是Faster R-CNN,paper中的模型主要针对文本检测,调整后也可用于航拍图像的检测中去 ...
- [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
- OpenVINO 目标检测底层C++代码改写实现(待优化)
System: Centos7.4 I:OpenVINO 的安装 refer:https://docs.openvinotoolkit.org/latest/_docs_install_guides_ ...
- 评价目标检测(object detection)模型的参数:IOU,AP,mAP
首先我们为什么要使用这些呢? 举个简单的例子,假设我们图像里面只有1个目标,但是定位出来10个框,1个正确的,9个错误的,那么你要按(识别出来的正确的目标/总的正确目标)来算,正确率100%,但是其实 ...
- AI佳作解读系列(二)——目标检测AI算法集杂谈:R-CNN,faster R-CNN,yolo,SSD,yoloV2,yoloV3
1 引言 深度学习目前已经应用到了各个领域,应用场景大体分为三类:物体识别,目标检测,自然语言处理.本文着重与分析目标检测领域的深度学习方法,对其中的经典模型框架进行深入分析. 目标检测可以理解为是物 ...
随机推荐
- Git的http与ssh配置
http 进入git bash 直接clone所需项目 通过http方式 eg:git clone http://xxxxxxxxxx/bk_linux_inspect-master.git 会弹出提 ...
- 【转载】使用driver.findElement(By.id("txtPhoneNum")).getText();获取文本
今天在写自动化测试脚本的时候要获取一个输入框中的文本写了如下脚本: getAndSwitch("http://cas.minshengnet.com:14080/register/eRegi ...
- Cheat Engine 入门操作
Cheat Engine(简称CE,中文名-作弊引擎),用于查找.修改内存数据,是游戏逆向的基础工具. 本文仅介绍基础操作. 1.打开进程 运行游戏程序,并将CE附加到进程 2.寻找数据地址,并修改数 ...
- java菜鸟第一天
- Kubernetes-基于helm安装部署高可用的Redis及其形态探索(二)
上一章,我们通过实践和其他文章的帮助,在k8s的环境安装了redis-ha,并且对其进行了一些实验来验证他的主从切换是否有效.本篇中将会分析,究竟是如何实现了redis-ha的主从切换,以及其与K8S ...
- 洛谷 P2004 领地选择
题目传送门 解题思路: 二维前缀和. AC代码: #include<iostream> #include<cstdio> using namespace std; ][],an ...
- 吴裕雄--天生自然JAVA SPRING框架开发学习笔记:Spring实例化Bean的三种方法
在面向对象的程序中,要想调用某个类的成员方法,就需要先实例化该类的对象.在 Spring 中,实例化 Bean 有三种方式,分别是构造器实例化.静态工厂方式实例化和实例工厂方式实例化. 构造器实例化 ...
- SASS - 使用Sass程序
SASS – 简介 SASS – 环境搭建 SASS – 使用Sass程序 SASS – 语法 SASS – 变量 SASS- 局部文件(Partial) SASS – 混合(Mixin) SASS ...
- python刷LeetCode:13. 罗马数字转整数
难度等级:简单 题目描述: 罗马数字包含以下七种字符: I, V, X, L,C,D 和 M. 字符 数值I 1V 5X 10L 50C 100D 500M 1000例如, 罗马数字 2 写做 II ...
- POJ 2362:Square 觉得这才算深度搜索
Square Time Limit: 3000MS Memory Limit: 65536K Total Submissions: 21821 Accepted: 7624 Descripti ...