tensorflow C++接口调用图像分类pb模型代码
#include <fstream>
#include <utility>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "opencv2/opencv.hpp" using namespace tensorflow::ops;
using namespace tensorflow;
using namespace std;
using namespace cv;
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32 ; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_32FC1);
img=-img/; //创建一个指向tensor的内容的指针
float *p = output_tensor->flat<float>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_32FC1, p);
img.convertTo(tempMat,CV_32FC1); // waitKey(0); } int main(int argc, char** argv )
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../inception_v3_2016_08_28_frozen.pb";
string image_path="../test.jpg";
int input_height =;
int input_width=;
string input_tensor_name="input";
string output_tensor_name="InceptionV3/Predictions/Reshape_1"; /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img=imread(image_path,);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_FLOAT, TensorShape({,input_height,input_width,})); //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs;
string output_node = output_tensor_name;
Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {output_node}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
}
//把输出值给提取出来
cout << "Output tensor size:" << outputs.size() << std::endl;
for (std::size_t i = ; i < outputs.size(); i++) {
cout << outputs[i].DebugString()<<endl;
} Tensor t = outputs[]; // Fetch the first tensor
auto tmap = t.tensor<float, >(); // Tensor Shape: [batch_size, target_class_num]
int output_dim = t.shape().dim_size(); // Get the target_class_num from 1st dimension // Argmax: Get Final Prediction Label and Probability
int output_class_id = -;
double output_prob = 0.0;
for (int j = ; j < output_dim; j++)
{
cout << "Class " << j << " prob:" << tmap(, j) << "," << std::endl;
if (tmap(, j) >= output_prob) {
output_class_id = j;
output_prob = tmap(, j);
}
} // 输出结果
cout << "Final class id: " << output_class_id << std::endl;
cout << "Final class prob: " << output_prob << std::endl; return ;
}
tensorflow C++接口调用图像分类pb模型代码的更多相关文章
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h" #include "tensorflo ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推
为了调用各种经典机器学习模型,今后你不必重复造轮子了. 刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉.自然语言处理领域的诸多经典模型的聚合中心,让你调用起来更方便. 有多方 ...
- 导出pb模型之后测试的python代码
链接:https://blog.csdn.net/thriving_fcl/article/details/75213361 saved_model模块主要用于TensorFlow Serving.T ...
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- 一行代码搞定Dubbo接口调用
本文来自网易云社区 作者:吕彦峰 在工作中我们经常遇到关于接口测试的问题,无论是对于QA同学还是开发同学都会有远程接口调用的需求.针对这种问题我研发了一个工具包,专门用于远程Dubbo调用,下面就让我 ...
- 将keras的h5模型转换为tensorflow的pb模型
h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...
- 深度学习Tensorflow生产环境部署(下·模型部署篇)
前一篇讲过环境的部署篇,这一次就讲讲从代码角度如何导出pb模型,如何进行服务调用. 1 hello world篇 部署完docker后,如果是cpu环境,可以直接拉取tensorflow/servin ...
随机推荐
- Spark SQL 笔记
Spark SQL 简介 SparkSQL 的前身是 Shark, SparkSQL 产生的根本原因是其完全脱离了 Hive 的限制.(Shark 底层依赖于 Hive 的解析器, 查询优化器) Sp ...
- tableau中图形分析相关设置
1.柱形堆叠图单元格顶部显示总计值(可通过参考线实现) 2.调节图形单元格的宽窄度 (ctrl + 右键/左键) 3.折线图预测区间 趋势区间线 分析中预测并不是针对所有的日期格式均其作用,比如日期格 ...
- springboot入门学习1
springboot学习1 SpringBoot对Spring的缺点进行的改善和优化,基于约定优于配置的思想,可以让开发人员不必在配置与逻辑 业务之间进行思维的切换,全身心的投入到逻辑业务的代码编写中 ...
- 2.10 Jetpack LiveData部分
LiveData是一个可观察的数据持有者类,但和其他的可观察对象不同,它与生命周期相关联,比如Activity的生命周期.LiveData能确保仅在Activity处于活动状态下才会更新.也就是说当观 ...
- Java中多态的实例
public class cf { /** * 实际上这里涉及方法调用的优先问题, * 优先级由高到低依次为:this.show(O).super.show(O).this.show((super)O ...
- bugku-Web 文件包含2
打开网页查看源代码,在最上面发现upload.php,于是我们就访问这个文件:http://123.206.31.85:49166/index.php?file=upload.php 发现是让我们上传 ...
- 题解 P1654 【OSU!】
题面 一序列\(a\), 对于每一个\(i\)均有\(a_i\)有\(p_i\)的几率为1, 否则为\(0\) 求: \(a\)中极长全\(1\)子序列长度三次方之和的期望 前置知识 基本期望(期望的 ...
- Mybatis框架的简单配置
Mybatis 的配置 1.创建项目(当然,这是废话) 2.导包 下载mybatis-3.2.0版:https://repo1.maven.org/maven2/org/mybatis/mybatis ...
- BZOJ [Scoi2010]游戏
题解: 解法一:建立图论模型,发现只要联通块中有环则这个联通块中的值都可以被攻击到 如果是树,则只能攻击size-1个 解法二:二分图匹配,二分答案,看看是否能攻击到mid #include<i ...
- c++ 广度优先搜索
#include <iostream> using namespace std; ; ; // >=9皆可 struct node//声明图形顶点结构 { int vertex; s ...