#include <fstream>
#include <utility>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "opencv2/opencv.hpp" using namespace tensorflow::ops;
using namespace tensorflow;
using namespace std;
using namespace cv;
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32 ; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_32FC1);
img=-img/; //创建一个指向tensor的内容的指针
float *p = output_tensor->flat<float>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_32FC1, p);
img.convertTo(tempMat,CV_32FC1); // waitKey(0); } int main(int argc, char** argv )
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../inception_v3_2016_08_28_frozen.pb";
string image_path="../test.jpg";
int input_height =;
int input_width=;
string input_tensor_name="input";
string output_tensor_name="InceptionV3/Predictions/Reshape_1"; /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img=imread(image_path,);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_FLOAT, TensorShape({,input_height,input_width,})); //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs;
string output_node = output_tensor_name;
Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {output_node}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
}
//把输出值给提取出来
cout << "Output tensor size:" << outputs.size() << std::endl;
for (std::size_t i = ; i < outputs.size(); i++) {
cout << outputs[i].DebugString()<<endl;
} Tensor t = outputs[]; // Fetch the first tensor
auto tmap = t.tensor<float, >(); // Tensor Shape: [batch_size, target_class_num]
int output_dim = t.shape().dim_size(); // Get the target_class_num from 1st dimension // Argmax: Get Final Prediction Label and Probability
int output_class_id = -;
double output_prob = 0.0;
for (int j = ; j < output_dim; j++)
{
cout << "Class " << j << " prob:" << tmap(, j) << "," << std::endl;
if (tmap(, j) >= output_prob) {
output_class_id = j;
output_prob = tmap(, j);
}
} // 输出结果
cout << "Final class id: " << output_class_id << std::endl;
cout << "Final class prob: " << output_prob << std::endl; return ;
}

tensorflow C++接口调用图像分类pb模型代码的更多相关文章

  1. tensorflow c++ API加载.pb模型文件并预测图片

    tensorflow  python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...

  2. tensorflow C++接口调用目标检测pb模型代码

    #include <iostream> #include "tensorflow/cc/ops/const_op.h" #include "tensorflo ...

  3. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  4. PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推

    为了调用各种经典机器学习模型,今后你不必重复造轮子了. 刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉.自然语言处理领域的诸多经典模型的聚合中心,让你调用起来更方便. 有多方 ...

  5. 导出pb模型之后测试的python代码

    链接:https://blog.csdn.net/thriving_fcl/article/details/75213361 saved_model模块主要用于TensorFlow Serving.T ...

  6. 查看tensorflow pb模型文件的节点信息

    查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...

  7. 一行代码搞定Dubbo接口调用

    本文来自网易云社区 作者:吕彦峰 在工作中我们经常遇到关于接口测试的问题,无论是对于QA同学还是开发同学都会有远程接口调用的需求.针对这种问题我研发了一个工具包,专门用于远程Dubbo调用,下面就让我 ...

  8. 将keras的h5模型转换为tensorflow的pb模型

    h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...

  9. 深度学习Tensorflow生产环境部署(下·模型部署篇)

    前一篇讲过环境的部署篇,这一次就讲讲从代码角度如何导出pb模型,如何进行服务调用. 1 hello world篇 部署完docker后,如果是cpu环境,可以直接拉取tensorflow/servin ...

随机推荐

  1. oracle问题:char类型数据查询不到

    select distinct id,name from test_table b where b.ID='001' ; id为char字段类型,使用该语句查询不出数据. 解决方法:加trim().改 ...

  2. python连接 ssh

    import paramiko # private = paramiko.RSAKey.from_private_key() 秘钥 trans = paramiko.Transport((" ...

  3. 【剑指Offer】面试题07. 重建二叉树

    题目 输入某二叉树的前序遍历和中序遍历的结果,请重建该二叉树.假设输入的前序遍历和中序遍历的结果中都不含重复的数字. 例如,给出 前序遍历 preorder = [3,9,20,15,7] 中序遍历 ...

  4. 【golang】golang的一些知识要点

    特殊常量iota: 1.iota的值在遇到const关键字时将被重置为0 2.const中每新增一行常量声明将使iota计数一次,也就是自动加一. 3.iota只能在常量定义中使用. iota常见使用 ...

  5. org.apache.jasper.JasperException: /WEB-INFO/jsp/product/edit.jsp(168,45)

    PWC6038:"${empty data.code?'001':fn:substring(data.code,0,8)}" contains invalid expression ...

  6. 量化投资_Multicharts数组操作函数_append()追加函数(自定义)

    1. Multicharts中关于数组的操作比较麻烦,而且当中所谓的动态数组的定义并不是像其他语言那种的概念.因此要对数组进行元素“”追加“”的话,需要重新更改数组的索引,然后再最后一个位置添加val ...

  7. 吴裕雄--天生自然C++语言学习笔记:C++ 数据类型

    使用编程语言进行编程时,需要用到各种变量来存储各种信息.变量保留的是它所存储的值的内存位置.这意味着,当创建一个变量时,就会在内存中保留一些空间. 可能需要存储各种数据类型(比如字符型.宽字符型.整型 ...

  8. Vuex 是什么

    Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式.它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化.Vuex 也集成到 Vue 的官方调试工具  ...

  9. Thread--CountDownLatch & CyclicBarrier

    参考:http://www.importnew.com/21889.html CountDownLatch countDown() 方法执行完只是计数器减一, 并不会阻塞当前运行线程的的后续代码执行. ...

  10. [APIO2018]铁人两项(圆方树)

    过了14个月再重新看这题,发现圆方树从来就没有写过.然后写了这题发现自己APIO2018打铁的原因竟然是没开long long,将树的部分的O(n)写挂了(爆int),毕竟去年APIO时我啥都不会,连 ...