tensorflow C++接口调用图像分类pb模型代码
#include <fstream>
#include <utility>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "opencv2/opencv.hpp" using namespace tensorflow::ops;
using namespace tensorflow;
using namespace std;
using namespace cv;
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32 ; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_32FC1);
img=-img/; //创建一个指向tensor的内容的指针
float *p = output_tensor->flat<float>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_32FC1, p);
img.convertTo(tempMat,CV_32FC1); // waitKey(0); } int main(int argc, char** argv )
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../inception_v3_2016_08_28_frozen.pb";
string image_path="../test.jpg";
int input_height =;
int input_width=;
string input_tensor_name="input";
string output_tensor_name="InceptionV3/Predictions/Reshape_1"; /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img=imread(image_path,);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_FLOAT, TensorShape({,input_height,input_width,})); //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs;
string output_node = output_tensor_name;
Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {output_node}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
}
//把输出值给提取出来
cout << "Output tensor size:" << outputs.size() << std::endl;
for (std::size_t i = ; i < outputs.size(); i++) {
cout << outputs[i].DebugString()<<endl;
} Tensor t = outputs[]; // Fetch the first tensor
auto tmap = t.tensor<float, >(); // Tensor Shape: [batch_size, target_class_num]
int output_dim = t.shape().dim_size(); // Get the target_class_num from 1st dimension // Argmax: Get Final Prediction Label and Probability
int output_class_id = -;
double output_prob = 0.0;
for (int j = ; j < output_dim; j++)
{
cout << "Class " << j << " prob:" << tmap(, j) << "," << std::endl;
if (tmap(, j) >= output_prob) {
output_class_id = j;
output_prob = tmap(, j);
}
} // 输出结果
cout << "Final class id: " << output_class_id << std::endl;
cout << "Final class prob: " << output_prob << std::endl; return ;
}
tensorflow C++接口调用图像分类pb模型代码的更多相关文章
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h" #include "tensorflo ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推
为了调用各种经典机器学习模型,今后你不必重复造轮子了. 刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉.自然语言处理领域的诸多经典模型的聚合中心,让你调用起来更方便. 有多方 ...
- 导出pb模型之后测试的python代码
链接:https://blog.csdn.net/thriving_fcl/article/details/75213361 saved_model模块主要用于TensorFlow Serving.T ...
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- 一行代码搞定Dubbo接口调用
本文来自网易云社区 作者:吕彦峰 在工作中我们经常遇到关于接口测试的问题,无论是对于QA同学还是开发同学都会有远程接口调用的需求.针对这种问题我研发了一个工具包,专门用于远程Dubbo调用,下面就让我 ...
- 将keras的h5模型转换为tensorflow的pb模型
h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...
- 深度学习Tensorflow生产环境部署(下·模型部署篇)
前一篇讲过环境的部署篇,这一次就讲讲从代码角度如何导出pb模型,如何进行服务调用. 1 hello world篇 部署完docker后,如果是cpu环境,可以直接拉取tensorflow/servin ...
随机推荐
- VMware虚拟机黑屏
引用自:VMware吧 近期很多朋友遇到了VMware Workstation 14开启或新建虚拟机后黑屏的现象,无法关机,软件也无法关闭 用任务管理器结束VMware后这个VMX进程也关不了 解决办 ...
- 在Centos安装redis-孙志奇
最近在阿里云服务器上部署redis,遇到了很多的问题,经过不懈的努力终于配置成功, 按照下面的步骤一步一步来就好了 wget http://download.redis.io/releases/red ...
- SpringMVC:提交日期类型报400错误解决方法
方法1:可以使用@ControllerAdvice增强Controller @ControllerAdvice public class BaseControllerAdvice { // 初始化绑定 ...
- poj 3262 Protecting the Flowers 贪心 牛吃花
Protecting the Flowers Time Limit: 2000MS Memory Limit: 65536K Total Submissions: 11402 Accepted ...
- 大二暑假第一周总结--初次安装配置Hadoop
本次配置主要使用的教程:http://dblab.xmu.edu.cn/blog/install-hadoop-in-centos/ 以下是自己在配置中的遇到的一些问题和解决方法,或者提示 一.使用虚 ...
- JS元素的左右移动
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- Vue.js(20)之 封装字母表(是这个名字吗0.0)
HTML结构: <template> <div class="alphabet-container"> <h1>alphabet 组件</ ...
- 服务器上安装解决ole错误
服务器上安装此插件 提取码:9kiw
- IDEA创建新文件时自动生成时间和作者
打开设置,打开下图的选项并且输入 /** * @author 你的名字 * @date ${DATE} ${TIME} */
- 基于仿生算法的智能系统I
仿生算法仿生算法是什么? 什么是仿生? 蜜蜂会造房子,人类就学习蜜蜂盖房子的方法,之后便有了航空建造工程的蜂窝结构. 仿生是模仿生物系统的功能和行为,来建造技术系统的一种科学方法.生活仿生作品现代的飞 ...