tensorflow C++接口调用图像分类pb模型代码
#include <fstream>
#include <utility>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <iostream> #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "opencv2/opencv.hpp" using namespace tensorflow::ops;
using namespace tensorflow;
using namespace std;
using namespace cv;
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32 ; // 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img); //归一化
img.convertTo(img,CV_32FC1);
img=-img/; //创建一个指向tensor的内容的指针
float *p = output_tensor->flat<float>().data(); //创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_32FC1, p);
img.convertTo(tempMat,CV_32FC1); // waitKey(0); } int main(int argc, char** argv )
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../inception_v3_2016_08_28_frozen.pb";
string image_path="../test.jpg";
int input_height =;
int input_width=;
string input_tensor_name="input";
string output_tensor_name="InceptionV3/Predictions/Reshape_1"; /*--------------------------------创建session------------------------------*/
Session* session;
Status status = NewSession(SessionOptions(), &session);//创建新会话Session /*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout << "ERROR: Loading model failed..." << model_path << std::endl;
cout << status_load.ToString() << "\n";
return -;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -;
}
cout << "<----Successfully created session and load graph.------->"<< endl; /*---------------------------------载入测试图片-------------------------------------*/
cout<<endl<<"<------------loading test_image-------------->"<<endl;
Mat img=imread(image_path,);
if(img.empty())
{
cout<<"can't open the image!!!!!!!"<<endl;
return -;
} //创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_FLOAT, TensorShape({,input_height,input_width,})); //将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width); cout << resized_tensor.DebugString()<<endl; /*-----------------------------------用网络进行测试-----------------------------------------*/
cout<<endl<<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector<tensorflow::Tensor> outputs;
string output_node = output_tensor_name;
Status status_run = session->Run({{input_tensor_name, resized_tensor}}, {output_node}, {}, &outputs); if (!status_run.ok()) {
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -;
}
//把输出值给提取出来
cout << "Output tensor size:" << outputs.size() << std::endl;
for (std::size_t i = ; i < outputs.size(); i++) {
cout << outputs[i].DebugString()<<endl;
} Tensor t = outputs[]; // Fetch the first tensor
auto tmap = t.tensor<float, >(); // Tensor Shape: [batch_size, target_class_num]
int output_dim = t.shape().dim_size(); // Get the target_class_num from 1st dimension // Argmax: Get Final Prediction Label and Probability
int output_class_id = -;
double output_prob = 0.0;
for (int j = ; j < output_dim; j++)
{
cout << "Class " << j << " prob:" << tmap(, j) << "," << std::endl;
if (tmap(, j) >= output_prob) {
output_class_id = j;
output_prob = tmap(, j);
}
} // 输出结果
cout << "Final class id: " << output_class_id << std::endl;
cout << "Final class prob: " << output_prob << std::endl; return ;
}
tensorflow C++接口调用图像分类pb模型代码的更多相关文章
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h" #include "tensorflo ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推
为了调用各种经典机器学习模型,今后你不必重复造轮子了. 刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉.自然语言处理领域的诸多经典模型的聚合中心,让你调用起来更方便. 有多方 ...
- 导出pb模型之后测试的python代码
链接:https://blog.csdn.net/thriving_fcl/article/details/75213361 saved_model模块主要用于TensorFlow Serving.T ...
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- 一行代码搞定Dubbo接口调用
本文来自网易云社区 作者:吕彦峰 在工作中我们经常遇到关于接口测试的问题,无论是对于QA同学还是开发同学都会有远程接口调用的需求.针对这种问题我研发了一个工具包,专门用于远程Dubbo调用,下面就让我 ...
- 将keras的h5模型转换为tensorflow的pb模型
h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...
- 深度学习Tensorflow生产环境部署(下·模型部署篇)
前一篇讲过环境的部署篇,这一次就讲讲从代码角度如何导出pb模型,如何进行服务调用. 1 hello world篇 部署完docker后,如果是cpu环境,可以直接拉取tensorflow/servin ...
随机推荐
- Python MySQL 教程
章节 Python MySQL 入门 Python MySQL 创建数据库 Python MySQL 创建表 Python MySQL 插入表 Python MySQL Select Python M ...
- Django static配置
STATIC_URL = '/static/' # HTML中使用的静态文件夹前缀 STATICFILES_DIRS = [ os.path.join(BASE_DIR, "static&q ...
- 2016蓝桥杯省赛C/C++A组第九题 密码脱落
题意: X星球的考古学家发现了一批古代留下来的密码. 这些密码是由A.B.C.D 四种植物的种子串成的序列. 仔细分析发现,这些密码串当初应该是前后对称的(也就是我们说的镜像串). 由于年代久远,其中 ...
- UVA - 11400 Lighting System Design(照明系统设计)(dp)
题意:共有n种(n<=1000)种灯泡,每种灯泡用4个数值表示.电压V(V<=132000),电源费用K(K<=1000),每个灯泡的费用C(C<=10)和所需灯泡的数量L(1 ...
- POJ 3994:Probability One
Probability One Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 1674 Accepted: 1151 D ...
- 【LeetCode 】N皇后II
[问题]n 皇后问题研究的是如何将 n 个皇后放置在 n×n 的棋盘上,并且使皇后彼此之间不能相互攻击. 上图为 8 皇后问题的一种解法.给定一个整数 n,返回 n 皇后不同的解决方案的数量. 示例: ...
- Eclipse字体及背景色设置和工作空间字符编码设置
一.字体设置 Window->Preferences->General->Appearance->Colors and fonts->Basic->Text Fon ...
- JS隔行换色和全选的实现
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- 如何搞定Critical Thinking写作?
受中国传统教育模式与国外一流大学之间的差异的影响,在海外留学的学子们常常会在新的学习生活中面临许多难题,Critical Thinking就是其中之一.国内的教育方法常常以灌输式的教育模式为主,忽略了 ...
- POJ 1164:The Castle
The Castle Time Limit: 1000MS Memory Limit: 10000K Total Submissions: 6677 Accepted: 3767 Descri ...