现在的深度学习框架一般都是基于 Python 来实现,构建、训练、保存和调用模型都可以很容易地在 Python 下完成。但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直接调用 TensorFlow 的 C/C++ 接口来导入 TensorFlow 预训练好的模型。

1.环境配置 点此查看 C/C++ 接口的编译

2. 导入预定义的图和训练好的参数值

    // set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10"; auto session = NewSession(SessionOptions()); // 创建会话
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
} Status status; // Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);  // 导入图模型
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
} // Add the graph to the session
status = session->Create(graph_def.graph_def());  // 将图模型加入到会话中
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
} // Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath; // 读取预训练好的权重
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

3. 准备测试数据

    const string filename = "/home/senius/python/c_python/test/04t30t00.npy";

    //Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close(); tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor<float, 5>(); float *pdata = Array; // copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
input_tensor_mapped(0, x, y, z, 0) = *source_value;
}
}
}
  • 本例中输入数据是一个 [None, 41, 41, 41, 1] 的张量,我们需要先从 TXT 文件中读出测试数据,然后正确地填充到张量中去。

4. 前向传播得到预测值

    std::vector<tensorflow::Tensor> finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output tensor's name
vector<std::pair<string, Tensor> > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor)); // Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput); auto output_y = finalOutput[0].scalar<float>();
std::cout << output_y() << "\n";
  • 通过给定输入和输出张量的名字,我们可以将测试数据传入到模型中,然后进行前向传播得到预测值。

5. 一些问题

  • 本模型是在 TensorFlow 1.4 下训练的,然后编译 TensorFlow 1.4 的 C++ 接口可以正常调用模型,但若是想调用更高版本训练好的模型,则会报错,据出错信息猜测可能是高版本的 TensorFlow 中添加了一些低版本没有的函数,所以不能正常运行。
  • 若是编译高版本的 TensorFlow ,比如最新的 TensorFlow 1.11 的 C++ 接口,则无论是调用旧版本训练的模型还是新版本训练的模型都不能正常运行。出错信息如下:Error loading checkpoint from /media/lab/data/yongsen/Tensorflow_test/test/model-40: Invalid argument: Session was not created with a graph before Run()!,网上暂时也查不到解决办法,姑且先放在这里。

6. 完整代码

#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/io_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/parsing_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/array_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/math_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/data_flow_ops.h> #include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <fstream> using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops; int main()
{
// set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10"; auto session = NewSession(SessionOptions());
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
} Status status; // Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
} // Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
} // Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
} cout << 1 << endl; const string filename = "/home/senius/python/c_python/test/04t30t00.npy"; //Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close(); tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor<float, 5>(); float *pdata = Array; // copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
// input_tensor_mapped(0, x, y, z, 0) = *source_value;
input_tensor_mapped(0, x, y, z, 0) = 1;
}
}
} std::vector<tensorflow::Tensor> finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output placeholder's name
vector<std::pair<string, Tensor> > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor)); // Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput); auto output_y = finalOutput[0].scalar<float>();
std::cout << output_y() << "\n"; return 0;
}
  • Cmakelist 文件如下
cmake_minimum_required(VERSION 3.8)
project(Tensorflow_test) set(CMAKE_CXX_STANDARD 11) set(SOURCE_FILES main.cpp) include_directories(
/home/senius/tensorflow-r1.4
/home/senius/tensorflow-r1.4/tensorflow/bazel-genfiles
/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/protobuf/include
/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/host_obj
/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/proto
/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/downloads/nsync/public
/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/downloads/eigen
/home/senius/tensorflow-r1.4/bazel-out/local_linux-py3-opt/genfiles
) add_executable(Tensorflow_test ${SOURCE_FILES}) target_link_libraries(Tensorflow_test
/home/senius/tensorflow-r1.4/bazel-bin/tensorflow/libtensorflow_cc.so
/home/senius/tensorflow-r1.4/bazel-bin/tensorflow/libtensorflow_framework.so
)

获取更多精彩,请关注「seniusen」!

在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现的更多相关文章

  1. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  2. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  3. TensorFlow 同时调用多个预训练好的模型

    在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测. 调用单个预训练好的模型请点击此处 弄明白了如何调用单个模型,其实调 ...

  4. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  5. pytorch中如何使用预训练词向量

    不涉及具体代码,只是记录一下自己的疑惑. 我们知道对于在pytorch中,我们通过构建一个词向量矩阵对象.这个时候对象矩阵是随机初始化的,然后我们的输入是单词的数值表达,也就是一些索引.那么我们会根据 ...

  6. 转载:tensorflow保存训练后的模型

    训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存.如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值.建议可以使用Saver类保存和加载模型的结果. 1 ...

  7. tensorflow 使用预训练好的模型的一部分参数

    vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...

  8. Tensorflow 用训练好的模型预测

    本节涉及点: 从命令行参数读取需要预测的数据 从文件中读取数据进行预测 从任意字符串中读取数据进行预测 一.从命令行参数读取需要预测的数据 训练神经网络是让神经网络具备可用性,真正使用神经网络时,需要 ...

  9. Tensorflow使用训练好的模型进行测试,发现计算速度越来越慢

    实验时要对多个NN模型进行对比,依次加载直到第8个模型时,发现运行速度明显变慢而且电脑开始卡顿,查看内存占用90+%. 原因:使用过的NN模型还会保存在内存,继续加载一方面使新模型加载特别特别慢,另一 ...

随机推荐

  1. Struts2 第三讲 -- Struts2的处理流程

    4.Struts2的处理流程 以下是struts-defautl.xml中的拦截器 建议通过这个struts-default的副本查看,更形象 它实现了很多的功能,其中包括国际化,文件上传,类型转换, ...

  2. IOS异步获取数据并刷新界面dispatch_async的使用方法

    在ios的开发和学习中多线程编程是必须会遇到并用到的.在java中以及Android开发中,大量的后台运行,异步消息队列,基本都是运用了多线程来实现. 同样在,在ios移动开发和Android基本是很 ...

  3. org.apache.tomcat.util.http.fileupload.FileUploadBase$FileSizeLimitExceededException: The field xxx exceeds its maximum permitted size of 1048576 bytes.

    springboot 通过MultipartFile接受前端传过来的文件时是有文件大小限制的(springboot内置tomact的的文件传输默认为1MB),我们可以通过配置改变它的大小限制 首先在启 ...

  4. Ubuntu 18.04添加新网卡

    在Ubuntu 18.04 LTS上配置IP地址的方法与旧方法有很大不同.与以前的版本不同,Ubuntu 18.04使用Netplan(一种新的命令行网络配置实用程序)来配置IP地址. 在这种新方法中 ...

  5. date 参数(option)-d

    记录这篇博客的原因是:鸟哥的linux教程中,关于date命令的部分缺少-d这个参数的介绍,并且12章中的shell编写部分有用到-d参数 date 参数(option)-d与--date=" ...

  6. Ubuntu16.04采用FastCGI方式部署Flask web框架

    1    部署nginx 1.1    安装nginx服务 root@desktop:~# apt-get install nginx -y 1.2    验证nginx服务是否启动 root@des ...

  7. TP5部署服务器问题总结

    及最近部署TP5遇到了很多坑,各种环境下都会出现一些问题,下面是我记录的排坑之路 先说最简单的lnmp一键安装包,我用的是1.5稳定版 安装命令:wget http://soft.vpser.net/ ...

  8. ELK 分布式日志实战

    一.  ELK 分布式日志实战介绍 此实战方案以 Elk 5.5.2 版本为准,分布式日志将以下图分布进行安装部署以及配置. 当Elk需监控应用日志时,需在应用部署所在的服务器中,安装Filebeat ...

  9. 记一次防火墙导致greenplum装机失败及定位修复过程

    一.问题现象 20180201:15:06:25:028653 gpinitsystem:sdw1-2:gpadmin-[INFO]:--------------------------------- ...

  10. print(__file__)返回<encoding error>的问题

    今天写了一下代码,本来是想得到当前文件的上面三层的目录的,结果返回的却是错误 import os import sys print(__file__) # 得到上上层目录的路径之后,加入到默认的环境变 ...