c++ 使用torchscript 加载训练好的pytorch模型
1.首先官网上下载libtorch,放到当前项目下
2.将pytorch训练好的模型使用torch.jit.trace导出为.pt格式
import torch
from skimage import io, transform, color
import numpy as np
import os
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
path = "test/n_1.jpg"
im = io.imread(path)
if im.shape[2] == 4:
im = color.rgba2rgb(im) im = transform.resize(im, (224, 224))
im = np.transpose(im, (2, 0, 1))
dummy_input = np.expand_dims(im, 0)
inp = torch.from_numpy(dummy_input)
inp = inp.float()
model = torch.load(
"models/resnet50-epoch-0-accu-0.9213857428381079.pth", map_location='cpu')
traced_script_module = torch.jit.trace(model, inp)
output = model(inp)
probs = F.softmax(output).detach().numpy()[0]
pred = np.argmax(probs) traced_script_module.save("models/traced_resnet_model.pt")
torchscript加载.pt模型
// One-stop header.
#include <torch/script.h> // headers for opencv
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/opencv.hpp> #include <cmath>
#include <iostream>
#include <memory>
#include <string>
#include <vector> #define kIMAGE_SIZE 224
#define kCHANNELS 3
#define kTOP_K 1 //print top k predicted results bool LoadImage(std::string file_name, cv::Mat &image)
{
image = cv::imread(file_name); // CV_8UC3
if (image.empty() || !image.data)
{
return false;
}
cv::cvtColor(image, image, CV_BGR2RGB);
// scale image to fit
cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);
cv::resize(image, image, scale); // convert [unsigned int] to [float]
image.convertTo(image, CV_32FC3,1.0/255); return true;
} bool LoadImageNetLabel(std::string file_name,
std::vector<std::string> &labels)
{
std::ifstream ifs(file_name);
if (!ifs)
{
return false;
}
std::string line;
while (std::getline(ifs, line))
{
labels.push_back(line);
}
return true;
} int main(int argc, const char *argv[])
{
if (argc != 3)
{
std::cerr << "Usage:classifier <path-to-exported-script-module> <path-to-lable-file> " << std::endl;
return -1;
} //load model
torch::jit::script::Module module = torch::jit::load(argv[1]);
// to GPU
// module->to(at::kCUDA);
std::cout << "== ResNet50 loaded!\n"; //load labels(classes names)
std::vector<std::string> labels;
if (LoadImageNetLabel(argv[2], labels))
{
std::cout << "== Label loaded! Let's try it\n";
}
else
{
std::cerr << "Please check your label file path." << std::endl;
return -1;
} std::string file_name = "";
cv::Mat image;
while (true)
{
std::cout << "== Input image path: [enter q to exit]" << std::endl;
std::cin >> file_name;
if (file_name == "Q" || file_name == "q")
{
break;
}
if (LoadImage(file_name, image))
{
//read image tensor
auto input_tensor = torch::from_blob(
image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});
input_tensor = input_tensor.permute({0, 3, 1, 2});
input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
// to GPU
// input_tensor = input_tensor.to(at::kCUDA); torch::Tensor out_tensor = module.forward({input_tensor}).toTensor(); auto results = out_tensor.sort(-1, true);
auto softmaxs = std::get<0>(results)[0].softmax(0);
auto indexs = std::get<1>(results)[0]; for (int i = 0; i < kTOP_K; ++i)
{
auto idx = indexs[i].item<int>();
std::cout << " ============= Top-" << i + 1 << " =============" << std::endl;
std::cout << " Label: " << labels[idx] << std::endl;
std::cout << " With Probability: "
<< softmaxs[i].item<float>() * 100.0f << "%" << std::endl;
}
}
else
{
std::cout << "Can't load the image, please check your path." << std::endl;
}
}
}
CMakeLists.txt编译
cmake_minimum_required(VERSION 2.8)
project(predict_demo)
SET(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-std=c++11 -O3") set(OpenCV_DIR /home/buyizhiyou/opencv-3.4./build)
find_package(OpenCV REQUIRED)
find_package(Torch REQUIRED) # 添加头文件
include_directories( ${OpenCV_INCLUDE_DIRS} ) add_executable(resnet_demo resnet_demo.cpp)
target_link_libraries(resnet_demo ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET resnet_demo PROPERTY CXX_STANDARD )
运行
./resnet_demo models/traced_resnet_model.pt labels.txt
c++ 使用torchscript 加载训练好的pytorch模型的更多相关文章
- vue中加载three.js的gltf模型
vue中加载three.js的gltf模型 一.开始引入three.js相关插件.首先利用淘宝镜像,操作命令为: cnpm install three //npm install three也行 二. ...
- pytorch 加载训练好的模型做inference
前提: 模型参数和结构是分别保存的 1. 构建模型(# load model graph) model = MODEL() 2.加载模型参数(# load model state_dict) mode ...
- Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- Tensorflow加载预训练模型和保存模型
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- 关于Tensorflow 加载和使用多个模型的方式
在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Sessi ...
- [原][osgearth]earth文件加载道路一初步看见模型道路
时间是2017年2月5日17:16:32 由于OE2.9还没有发布,但是我又急于使用OE的道路. 所以,我先编译了正在github上调试中的OE2.9 github网址是:https://github ...
- Three.js中加载外部fbx格式的模型素材
index.html部分: index.js部分: Scene.js部分:
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
随机推荐
- Cesium学习笔记-工具篇20-PrimitiveTexture自定义渲染-贴图【转】
前几篇博客我们了解了自定义点.线.面绘制,这篇我们接着学习cesium自定义纹理贴图.我们完成点线面的绘制,只是绘制出了对象的框架,没有逼真的外观.逼真外观是需要设置材质来实现:Material . ...
- 并发下sftp连接报错——com.jcraft.jsch.JSchException: connection is closed by foreign host
当对单接口极限测试时,随着并发量上升,接口稳定性出现不稳定的情况,排查后台日志,发现报错在该接口调用sftp上传时出现问题(确切的是在初始化连接时失败) 原因:系统SSH终端连接数配置过小,查看虚拟机 ...
- python 代码性能分析 库
问题描述 1.Python开发的程序在使用过程中很慢,想确定下是哪段代码比较慢: 2.Python开发的程序在使用过程中占用内存很大,想确定下是哪段代码引起的: 解决方案 使用profile分析分析c ...
- matlab学习笔记10_6 字符串与数值间的转换以及进制之间的转换
一起来学matlab-matlab学习笔记10 10_6 字符串与数值间的转换以及进制之间的转换 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考书籍 <matlab 程序设计与综合 ...
- Appium查询元素方法
Appium查询元素有两种方式 一种是使用UI Automator: 参考 https://www.cnblogs.com/gongxr/p/10906736.html 另一种是使用appium的In ...
- LODOP问答部分链接
点击链接进入相关简短问答博文: 问答大全 纸张打印机 注册 table表格 clodop测试地址 字体 行间距.字间距 clodop回调函数 SET_PRINT_STYLEA 页眉页脚 超文本 条码 ...
- Java高级面试题整理(附答案)
这是我收集的10道高级Java面试问题列表.这些问题主要来自 Java 核心部分 ,不涉及 Java EE 相关问题.你可能知道这些棘手的 Java 问题的答案,或者觉得这些不足以挑战你的 Java ...
- sshpass命令使用
1.直接远程连接某主机 sshpass -p {密码} ssh {用户名}@{主机IP} 2.远程连接指定ssh的端口 sshpass -p {密码} ssh -p ${端口} {用户名}@{主机IP ...
- jenkins:新增节点是启动方式没有Launch agent by connecting it to the master
默认在这里的配置是禁用 所以启动方式只有两种,缺少Launch agent by connecting it to the master
- C/C++ 面试-内存对齐 即不同数据类型存储空间
下面列举了Dev-C++下基本类型所占位数和取值范围: 基本型 所占位数 取值范围 输入符举例 ...