#include <iostream>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include <opencv2/opencv.hpp>
#include <cv.h>
#include <highgui.h>
#include <Eigen/Core>
#include <Eigen/Dense> using namespace std;
using namespace cv;
using namespace tensorflow; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_8UC3); // CV_32FC3
//img=1-img/255; //创建一个指向tensor的内容的指针
uint8 *p = output_tensor->flat<uint8>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_8UC3, p);
img.convertTo(tempMat,CV_8UC3); // waitKey(0); } int main()
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../model/coco.pb";
string image_path="../test.jpg";
int input_height = ;
int input_width = ;
string input_tensor_name="image_tensor";
vector<string> out_put_nodes; //注意,在object detection中输出的三个节点名称为以下三个
out_put_nodes.push_back("detection_scores"); //detection_scores detection_classes detection_boxes
out_put_nodes.push_back("detection_classes");
out_put_nodes.push_back("detection_boxes"); /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img;
img = imread(image_path);
cvtColor(img, img, CV_BGR2RGB);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_UINT8, TensorShape({,input_height,input_width,})); //DT_FLOAT //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs; Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {out_put_nodes}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
} //把输出值给提取出
cout << "Output tensor size:" << outputs.size() << std::endl; //
for (int i = ; i < outputs.size(); i++)
{
cout << outputs[i].DebugString()<<endl; // [1, 50], [1, 50], [1, 50, 4]
} cvtColor(img, img, CV_RGB2BGR); // opencv读入的是BGR格式输入网络前转为RGB
resize(img,img,cv::Size(,)); // 模型输入图像大小
int pre_num = outputs[].dim_size(); // 50 模型预测的目标数量
auto tmap_pro = outputs[].tensor<float, >(); //第一个是score输出shape为[1,50]
auto tmap_clas = outputs[].tensor<float, >(); //第二个是class输出shape为[1,50]
auto tmap_coor = outputs[].tensor<float, >(); //第三个是coordinate输出shape为[1,50,4]
float probability = 0.5; //自己设定的score阈值
for (int pre_i = ; pre_i < pre_num; pre_i++)
{
if (tmap_pro(, pre_i) < probability)
{
break;
}
cout << "Class ID: " << tmap_clas(, pre_i) << endl;
cout << "Probability: " << tmap_pro(, pre_i) << endl;
string id = to_string(int(tmap_clas(, pre_i)));
int xmin = int(tmap_coor(, pre_i, ) * input_width);
int ymin = int(tmap_coor(, pre_i, ) * input_height);
int xmax = int(tmap_coor(, pre_i, ) * input_width);
int ymax = int(tmap_coor(, pre_i, ) * input_height);
cout << "Xmin is: " << xmin << endl;
cout << "Ymin is: " << ymin << endl;
cout << "Xmax is: " << xmax << endl;
cout << "Ymax is: " << ymax << endl;
rectangle(img, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(, , ), , , );
putText(img, id, cvPoint(xmin, ymin), FONT_HERSHEY_COMPLEX, 1.0, Scalar(,,), );
}
imshow("", img);
cvWaitKey(); return ;
}

CMakeLists.txt内容如下

cmake_minimum_required(VERSION 3.0.)
project(tensorflow_cpp) set(CMAKE_CXX_STANDARD ) find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
find_package(OpenCV 2.4. QUIET)
if(NOT OpenCV_FOUND)
message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
endif()
endif() set(TENSORFLOW_INCLUDES
/usr/local/include/tf/
/usr/local/include/tf/bazel-genfiles
/usr/local/include/tf/tensorflow/
/usr/local/include/tf/tensorflow/third_party) set(TENSORFLOW_LIBS
/usr/local/lib/libtensorflow_cc.so
/usr/local/lib//libtensorflow_framework.so) include_directories(
${TENSORFLOW_INCLUDES}
${PROJECT_SOURCE_DIR}/third_party/eigen3
)
add_executable(predict predict.cpp)
target_link_libraries(predict
${TENSORFLOW_LIBS}
${OpenCV_LIBS}
)

目录结构如图所示

tensorflow C++接口调用目标检测pb模型代码的更多相关文章

  1. Mask R-CNN用于目标检测和分割代码实现

    Mask R-CNN用于目标检测和分割代码实现 Mask R-CNN for object detection and instance segmentation on Keras and Tenso ...

  2. 大话目标检测经典模型(RCNN、Fast RCNN、Faster RCNN)

      目标检测是深度学习的一个重要应用,就是在图片中要将里面的物体识别出来,并标出物体的位置,一般需要经过两个步骤:1.分类,识别物体是什么 2.定位,找出物体在哪里 除了对单个物体进行检测,还要能支持 ...

  3. 使用Faster R-CNN做目标检测 - 学习luminoth代码

    像玩乐高一样拆解Faster R-CNN:详解目标检测的实现过程 https://mp.weixin.qq.com/s/M_i38L2brq69BYzmaPeJ9w 直接参考开源目标检测代码lumin ...

  4. tensorflow C++接口调用图像分类pb模型代码

    #include <fstream> #include <utility> #include <Eigen/Core> #include <Eigen/Den ...

  5. R2CNN模型——用于文本目标检测的模型

    引言 R2CNN全称Rotational Region CNN,是一个针对斜框文本检测的CNN模型,原型是Faster R-CNN,paper中的模型主要针对文本检测,调整后也可用于航拍图像的检测中去 ...

  6. [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map

    Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...

  7. OpenVINO 目标检测底层C++代码改写实现(待优化)

    System: Centos7.4 I:OpenVINO 的安装 refer:https://docs.openvinotoolkit.org/latest/_docs_install_guides_ ...

  8. 评价目标检测(object detection)模型的参数:IOU,AP,mAP

    首先我们为什么要使用这些呢? 举个简单的例子,假设我们图像里面只有1个目标,但是定位出来10个框,1个正确的,9个错误的,那么你要按(识别出来的正确的目标/总的正确目标)来算,正确率100%,但是其实 ...

  9. AI佳作解读系列(二)——目标检测AI算法集杂谈:R-CNN,faster R-CNN,yolo,SSD,yoloV2,yoloV3

    1 引言 深度学习目前已经应用到了各个领域,应用场景大体分为三类:物体识别,目标检测,自然语言处理.本文着重与分析目标检测领域的深度学习方法,对其中的经典模型框架进行深入分析. 目标检测可以理解为是物 ...

随机推荐

  1. torch.cuda.FloatTensor

    Pytorch中的tensor又包括CPU上的数据类型和GPU上的数据类型,一般GPU上的Tensor是CPU上的Tensor加cuda()函数得到. 一般系统默认是torch.FloatTensor ...

  2. /etc/apt/sources.list.d

    deb http://ppa.launchpad.net/webupd8team/java/ubuntu xenial main# deb-src http://ppa.launchpad.net/w ...

  3. 15 ~ express ~ 用户数据分页原理和实现

    一,在后台路由 /router/admin.js 中 1,限制获取的数据条数 : User.find().limit(Number) 2,忽略数据的前(Number)条数据 : skip(Number ...

  4. Arduino -- functions

    For controlling the Arduino board and performing computations. Digital I/O digitalRead() digitalWrit ...

  5. HTTP协议、时间戳

    1.什么是HTTP协议 超文本传输协议(英文:HyperText Transfer Protocol,缩写:HTTP)是一种用于分布式.协作式和超媒体信息系统的应用层协议.HTTP是万维网的数据通信的 ...

  6. 干货分享:Essay写作收集论据的三个方法

    在很多时候,中国留学生写出的Essay在西方学术界看来是存在plagiarism的情况.并不是说咱们写的所有东西都是抄袭,而是思维逻辑和利用证据的方式与西方权威的academic writing不同. ...

  7. cf1200 E Compress Words(哈希)

    题意 有n个字符串,记为s1,s2……sn,s2与s1合并,合并的方式为:s1的后缀若与s2的前缀相同,就可以重叠起来,要最长的. 举个例子: “1333”  “33345” → “133345” s ...

  8. Java基础之IO流整理

    Java基础之IO流 Java IO流使用装饰器设计模式,因此如果不能理清其中的关系的话很容易把各种流搞混,此文将简单的几个流进行梳理,后序遇见新的流会继续更新(本文下方还附有xmind文件链接) 抽 ...

  9. mac安装浏览器同步测试工具

    1.安装node.js (1)打开终端,输入以下命令安装Homebrew ruby -e “$(curl -fsSL https://raw.githubusercontent.com/Homebre ...

  10. Mybatis基本配置(一)

    1. Mybatis介绍 MyBatis 是支持普通 SQL查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除了几乎所有的JDBC代码和参数的手工设置以及结果集的检索.MyBatis 使用 ...