转载请注明出处:

https://www.cnblogs.com/darkknightzh/p/11332155.html

代码网址:

https://github.com/darkknightzh/TensorRT_pytorch

参考网址:

tensorrt安装包的sample/python目录

https://github.com/pytorch/examples/tree/master/mnist

此处代码使用的是tensorrt5.1.5

在安装完tensorrt之后,使用tensorrt主要包括下面几段代码:

1. 初始化

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # 此句代码中未使用,但是必须有。this is useful, otherwise stream = cuda.Stream() will cause 'explicit_context_dependent failed: invalid device context - no currently active context?'

如注解所示,import pycuda.autoinit这句话程序中未使用,但是必须包含,否则程序运行会出错。

2. 保存onnx模型

def saveONNX(model, filepath, c, h, w):
model = model.cuda()
dummy_input = torch.randn(1, c, h, w, device='cuda')
torch.onnx.export(model, dummy_input, filepath, verbose=True)

3. 创建tensorrt引擎

def build_engine(onnx_file_path):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) # INFO
# For more information on TRT basics, refer to the introductory samples.
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
if builder.platform_has_fast_fp16:
print('this card support fp16')
if builder.platform_has_fast_int8:
print('this card support int8') builder.max_workspace_size = 1 << 30
with open(onnx_file_path, 'rb') as model:
parser.parse(model.read())
return builder.build_cuda_engine(network) # This function builds an engine from a Caffe model.
def build_engine_int8(onnx_file_path, calib):
TRT_LOGGER = trt.Logger()
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
# We set the builder batch size to be the same as the calibrator's, as we use the same batches
# during inference. Note that this is not required in general, and inference batch size is
# independent of calibration batch size.
builder.max_batch_size = 1 # calib.get_batch_size()
builder.max_workspace_size = 1 << 30
builder.int8_mode = True
builder.int8_calibrator = calib
with open(onnx_file_path, 'rb') as model:
parser.parse(model.read()) # , dtype=trt.float32
return builder.build_cuda_engine(network)

4. 保存及载入引擎

def save_engine(engine, engine_dest_path):
buf = engine.serialize()
with open(engine_dest_path, 'wb') as f:
f.write(buf) def load_engine(engine_path):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) # INFO
with open(engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
return runtime.deserialize_cuda_engine(f.read())

5. 分配缓冲区

class HostDeviceMem(object):
def __init__(self, host_mem, device_mem):
self.host = host_mem
self.device = device_mem def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) def __repr__(self):
return self.__str__() def allocate_buffers(engine):
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in engine:
dtype = trt.nptype(engine.get_binding_dtype(binding))
# Allocate host and device buffers
host_mem = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
outputs.append(HostDeviceMem(host_mem, device_mem))
return inputs, outputs, bindings, stream

6. 前向推断

def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
# Transfer input data to the GPU.
[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
# Run inference.
context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
# Transfer predictions back from the GPU.
[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
# Synchronize the stream
stream.synchronize()
# Return only the host outputs.
return [out.host for out in outputs]

7. 矫正(Calibrator)

使用tensorrt的int8时,需要矫正。具体可参见test_onnx_int8及calibrator.py。

8. 具体的推断代码

img_numpy = img.ravel().astype(np.float32)
np.copyto(inputs[0].host, img_numpy)
output = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
output = [np.reshape(stage_i, (10)) for stage_i in output] # 有多个输出时遍历

9. 代码分析

程序中主要包括下面6个函数。

test_pytorch()            # 测试pytorch模型的代码
export_onnx() # 导出pytorch模型到onnx模型
test_onnx_fp32() # 测试tensorrt的fp32模型(有保存引擎的代码)
test_onnx_fp32_engine() # 测试tensorrt的fp32引擎的代码
test_onnx_int8() # 测试tensorrt的int8模型(有保存引擎的代码)
test_onnx_int8_engine() # 测试tensorrt的int8引擎的代码

10. 说明

9的部分函数中,最开始有一句:

torch.load('mnist_cnn_3.pth')    # 如果结果不对,加上这句话

因为有时候会碰到,不使用这句话,直接运行代码时,结果完全不正确;加上这句话之后,结果正确了。

具体原因为找到。。。也就先记在这里吧。

(原)pytorch中使用TensorRT的更多相关文章

  1. (原)CNN中的卷积、1x1卷积及在pytorch中的验证

    转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn ...

  2. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  3. Pytorch中RoI pooling layer的几种实现

    Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...

  4. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  5. Pytorch中的自编码(autoencoder)

    Pytorch中的自编码(autoencoder) 本文资料来源:https://www.bilibili.com/video/av15997678/?p=25 什么是自编码 先压缩原数据.提取出最有 ...

  6. pytorch中网络特征图(feture map)、卷积核权重、卷积核最匹配样本、类别激活图(Class Activation Map/CAM)、网络结构的可视化方法

    目录 0,可视化的重要性: 1,特征图(feture map) 2,卷积核权重 3,卷积核最匹配样本 4,类别激活图(Class Activation Map/CAM) 5,网络结构的可视化 0,可视 ...

  7. C++primer原书中的一个错误(派生类using声明对基类权限的影响)

    在C++primer 第4版的 15章 15.2.5中有以下这样一段提示: "注解:派生类能够恢复继承成员的訪问级别,但不能使訪问级别比基类中原来指定的更严格或者更宽松." 在vs ...

  8. PyTorch官方中文文档:PyTorch中文文档

    PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...

  9. Java原子类中CAS的底层实现

    Java原子类中CAS的底层实现 从Java到c++到汇编, 深入讲解cas的底层原理. 介绍原理前, 先来一个Demo 以AtomicBoolean类为例.先来一个调用cas的demo. 主线程在f ...

随机推荐

  1. Chrome 禁止 http 自动转化为https

    Chrome 浏览器 地址栏中输入 chrome://net-internals/#hsts 在 Delete domain security policies 中输入项目的域名,并 Delete 删 ...

  2. 【PE512】Sums of totients of powers(欧拉函数)

    点此看题面 大致题意: 求\(\sum_{n=1}^{5*10^8}((\sum_{i=1}^n\phi(n^i))(mod\ n+1))\). 大力推式子 单独考虑\((\sum_{i=1}^n\p ...

  3. [BJOI2019]勘破神机(第一类斯特林数,斐波那契数列)

    真的是好题,只不过强行多合一有点过分了…… 题目大意: $T$ 组数据.每个测试点中 $m$ 相同. 对于每组数据,给定 $l,r,k$,请求出 $\dfrac{1}{r-l+1}\sum\limit ...

  4. Linux性能优化实战学习笔记:第四讲

    一.怎么查看系统上下文切换情况 通过前面学习我么你知道,过多的上下文切换,会把CPU时间消耗在寄存器.内核栈以及虚拟内存等数据的保存和回复上,缩短进程真正运行的时间,成了系统性能大幅下降的一个元凶 既 ...

  5. Gogs配置(本地安装篇-Debian)

    知识储备: 用过MySQL等 了解Linux最基本的操作 git常用操作 关于ssh 本文参考:linux上安装gogs搭建个人仓库 下载 https://github.com/gogs/gogs/r ...

  6. java导出标题多行且合并单元格的EXCEL

    场景:项目中遇到有需要导出Excel的需求,并且是多行标题且有合并单元格的,参考网上的文章,加上自己的理解,封装成了可自由扩展的导出工具 先上效果,再贴代码: 调用工具类进行导出: public st ...

  7. android 自定义gridview(导航)

    最近又重新做回安卓,做了个小项目.下绝心使用android studio,通过这一回实战,终于用上了.综合了前人的经验,搞了个自己满意的导航界面,用的是gridview. 代码: package co ...

  8. STM32Cube生成的HID项目,找不到hUsbDeviceFS

    症状 在main中尝试发消息给上位机: 解决方法 在STM32生成的HID项目里,默认是没有把hUsbDeviceFS导出的,需要修改usb_device.h文件,在USER CODE BEGIN V ...

  9. [BAT脚本] 1、BAT脚本FOR循环操作文件和命令返回实例

    Wednesday, 31. October 2018 08:18PM - beautifulzzzz 一.需求 需要在windows上实现一个bat脚本解析json,将json转换为自己想要的key ...

  10. torch_07_卷积神经网络案例分析

    1. LeNet(1998) """ note: LeNet: 输入体:32*32*1 卷积核:5*5 步长:1 填充:无 池化:2*2 代码旁边的注释:卷积或者池化后的 ...