为了使检测模型后处理部分更加高效,您可以采用TorchScript Custom C++ Operators将Python代码实现的逻辑替换成高效的C++实现,然后再导出TorchScript模型进行Blade优化。本文介绍如何使用Blade对TorchScript Custom C++ Operator实现的后处理逻辑的检测模型进行优化。
背景信息
RetinaNet是一种One-Stage RCNN类型的检测网络,基本结构由一个Backbone、多个子网及NMS后处理组成。许多训练框架中均实现了RetinaNet,典型的框架有Detectron2。上一篇中介绍了如何通过scripting_with_instances
方式导出RetinaNet(Detectron2)模型并使用Blade快速完成模型优化,详情请参见RetinaNet优化案例1:使用Blade优化RetinaNet(Detectron2)模型。
然而,检测模型的后处理部分代码通常需要执行计算和筛选boxes、nms等逻辑,通过Python实现该部分逻辑往往不高效。此时,您可以采用TorchScript Custom C++ Operators将Python代码实现的逻辑替换成高效的C++实现,然后再导出TorchScript模型并使用Blade进行模型优化。
使用限制
本文使用的环境需要满足以下版本限制:
系统环境:Linux系统中使用Python 3.6及其以上版本、GCC 5.4及其以上版本、Nvidia Tesla T4、CUDA 10.2、CuDNN 8.0.5.39。
框架:PyTorch 1.8.1及其以上版本、Detectron2 0.4.1及其以上版本。
推理优化工具:Blade 3.16.0及其以上版本。
操作流程
结合Blade和Custom C++ Operator优化模型的流程如下:
步骤一:创建带有Custom C++ Operators的PyTorch模型
使用TorchScript扩展实现RetinaNet的后处理部分。
使用Detectron2提供的
TracingAdapter
或scripting_with_instances
任何一种方式导出模型。调用
blade.optimize
接口优化模型,并保存优化后的模型。经过对优化前后的模型进行性能测试,如果对结果满意,可以加载优化后的模型进行推理。
步骤一:创建带有Custom C++ Operators的PyTorch模型
Blade工具与PyTorch TorchScript扩展机制无缝衔接,以下介绍如何使用TorchScript扩展实现RetinaNet的后处理部分。关于TorchScript Custom Operator的介绍请参见EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS。本文使用的RetinaNet后处理部分的程序逻辑来自NVIDIA开源社区,详情请参见Retinanet-Examples。本文抽取了核心的代码用于说明开发实现Custom Operator的流程。
下载示例代码并解压。
wget -nv https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/tutorials/retinanet_example/retinanet-examples.tar.gz -O retinanet-examples.tar.gz tar xvfz retinanet-examples.tar.gz 1>/dev/null
编译Custom C++ Operators。
PyTorch官方文档中(详情请参见EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS)提供了三种编译Custom Operators的方式:Building with CMake、Building with JIT Compilation及Building with Setuptools。这三种编译方式适用于不同场景,您可以根据自己的需求进行选择。本文为了简便,采用Building with JIT Compilation方式,示例代码如下所示。
import torch.utils.cpp_extension import os codebase="retinanet-examples" sources=['csrc/extensions.cpp', 'csrc/cuda/decode.cu', 'csrc/cuda/nms.cu',] sources = [os.path.join(codebase,src) for src in sources] torch.utils.cpp_extension.load( name="custom", sources=sources, build_directory=codebase, extra_include_paths=['/usr/local/TensorRT/include/', '/usr/local/cuda/include/', '/usr/local/cuda/include/thrust/system/cuda/detail'], extra_cflags=['-std=c++14', '-O2', '-Wall'], extra_cuda_cflags=[ '-std=c++14', '--expt-extended-lambda', '--use_fast_math', '-Xcompiler', '-Wall,-fno-gnu-unique', '-gencode=arch=compute_75,code=sm_75',], is_python_module=False, with_cuda=True, verbose=False, )
上述程序执行完成后,编译生成的custom.so会保存在retinanet-examples目录下。
使用Custom C++ Operators替换RetinaNet的后处理部分。
为了简洁,此处直接使用
adapter_forward
替换RetinaNet.forward
。adapter_forward
使用decode_cuda
和nms_cuda
两个Custom C++ Operators实现了RetinaNet的后处理部分,示例代码如下所示。import os import torch from typing import Tuple, Dict, List, Optional codebase="retinanet-examples" torch.ops.load_library(os.path.join(codebase, 'custom.so')) decode_cuda = torch.ops.retinanet.decode nms_cuda = torch.ops.retinanet.nms # 该函数的主要代码部分和RetinaNet.forward一样,但是后处理部分替换为通过decode_cuda和nms_cuda实现。 def adapter_forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]]): images = self.preprocess_image(batched_inputs) features = self.backbone(images.tensor) features = [features[f] for f in self.head_in_features] cls_heads, box_heads = self.head(features) cls_heads = [cls.sigmoid() for cls in cls_heads] box_heads = [b.contiguous() for b in box_heads] # 后处理部分。 strides = [images.tensor.shape[-1] // cls_head.shape[-1] for cls_head in cls_heads] decoded = [ decode_cuda( cls_head, box_head, anchor.view(-1), stride, self.test_score_thresh, self.test_topk_candidates, ) for stride, cls_head, box_head, anchor in zip( strides, cls_heads, box_heads, self.cell_anchors ) ] # non-maximum suppression部分。 decoded = [torch.cat(tensors, 1) for tensors in zip(decoded[0], decoded[1], decoded[2])] return nms_cuda(decoded[0], decoded[1], decoded[2], self.test_nms_thresh, self.max_detections_per_image) from detectron2.modeling.meta_arch import retinanet # 使用adapter_forward替换RetinaNet.forward。 retinanet.RetinaNet.forward = adapter_forward
步骤二:导出TorchScript模型
Detectron2是FAIR开源的灵活、可扩展、可配置的目标检测和图像分割训练框架。由于框架的灵活性,使用常规方法导出模型可能会失败或得到错误的导出结果。为了支持TorchScript部署,Detectron2提供了TracingAdapter
和scripting_with_instances
两种导出方式,详情请参见Detectron2 Usage。
Blade支持输入任意形式的TorchScript模型,如下以scripting_with_instances
为例,介绍导出模型的过程。
import torch
import numpy as np
from torch import Tensor
from torch.testing import assert_allclose
from detectron2 import model_zoo
from detectron2.export import scripting_with_instances
from detectron2.structures import Boxes
from detectron2.data.detection_utils import read_image
# 使用scripting_with_instances导出RetinaNet模型。
def load_retinanet(config_path):
model = model_zoo.get(config_path, trained=True).eval()
# Set a new cell_anchors attributes to PyTorch model.
model.cell_anchors = [c.contiguous() for c in model.anchor_generator.cell_anchors]
fields = {
"pred_boxes": Boxes,
"scores": Tensor,
"pred_classes": Tensor,
}
script_model = scripting_with_instances(model, fields)
return model, script_model
# 下载一张示例图片。
!wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
img = read_image('./input.jpg')
img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
# 尝试执行和对比导出模型前后的结果。
pytorch_model, script_model = load_retinanet("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
with torch.no_grad():
batched_inputs = [{"image": img.float()}]
pred1 = pytorch_model(batched_inputs)
pred2 = script_model(batched_inputs)
assert_allclose(pred1[0], pred2[0])
步骤三:调用Blade优化模型
调用Blade优化接口。
调用
blade.optimize
接口对模型进行优化,代码示例如下。关于blade.optimize
接口详情,请参见优化PyTorch模型。import os import blade import torch # 加载custom c++ operator动态链接库。 codebase="retinanet-examples" torch.ops.load_library(os.path.join(codebase, 'custom.so')) blade_config = blade.Config() blade_config.gpu_config.disable_fp16_accuracy_check = True test_data = [(batched_inputs,)] # PyTorch的输入数据是List of Tuple。 with blade_config: optimized_model, opt_spec, report = blade.optimize( script_model, # 上一步导出的TorchScript模型。 'o1', # 开启Blade O1级别的优化。 device_type='gpu', # 目标设备为GPU。 test_data=test_data, # 给定一组测试数据,用于辅助优化及测试。 )
打印优化报告并保存模型。
Blade优化后的模型仍然是一个TorchScript模型。完成优化后,您可以通过如下代码打印优化报告并保存优化模型。
# 打印优化结果报表。 print("Report: {}".format(report)) # 保存优化后的模型。 torch.jit.save(script_model, 'script_model.pt') torch.jit.save(optimized_model, 'optimized.pt')
打印的优化报告如下所示,关于优化报告中的字段详情请参见优化报告。
Report: { "software_context": [ { "software": "pytorch", "version": "1.8.1+cu102" }, { "software": "cuda", "version": "10.2.0" } ], "hardware_context": { "device_type": "gpu", "microarchitecture": "T4" }, "user_config": "", "diagnosis": { "model": "unnamed.pt", "test_data_source": "user provided", "shape_variation": "undefined", "message": "Unable to deduce model inputs information (data type, shape, value range, etc.)", "test_data_info": "0 shape: (3, 480, 640) data type: float32" }, "optimizations": [ { "name": "PtTrtPassFp16", "status": "effective", "speedup": "3.92", "pre_run": "40.72 ms", "post_run": "10.39 ms" } ], "overall": { "baseline": "40.64 ms", "optimized": "10.41 ms", "speedup": "3.90" }, "model_info": { "input_format": "torch_script" }, "compatibility_list": [ { "device_type": "gpu", "microarchitecture": "T4" } ], "model_sdk": {} }
对优化前后的模型进行性能测试。
性能测试的代码示例如下所示。
import time @torch.no_grad() def benchmark(model, inp): for i in range(100): model(inp) torch.cuda.synchronize() start = time.time() for i in range(200): model(inp) torch.cuda.synchronize() elapsed_ms = (time.time() - start) * 1000 print("Latency: {:.2f}".format(elapsed_ms / 200)) # 对优化前的模型测速。 benchmark(script_model, batched_inputs) # 对优化后的模型测速。 benchmark(optimized_model, batched_inputs)
本次测试的参考结果值如下。
Latency: 40.65 Latency: 10.46
上述结果表示同样执行200轮,优化前后的模型平均延时分别是40.65 ms和10.46 ms。
步骤四:加载运行优化后的模型
- 可选:在试用阶段,您可以设置如下的环境变量,防止因为鉴权失败而程序退出。
export BLADE_AUTH_USE_COUNTING=1
- 获取鉴权。
加载运行优化后的模型。
Blade优化后的模型仍然是TorchScript,因此您无需切换环境即可加载优化后的结果。
import blade.runtime.torch import detectron2 import torch import numpy as np import os from detectron2.data.detection_utils import read_image from torch.testing import assert_allclose # 加载custom c++ operator动态链接库。 codebase="retinanet-examples" torch.ops.load_library(os.path.join(codebase, 'custom.so')) script_model = torch.jit.load('script_model.pt') optimized_model = torch.jit.load('optimized.pt') img = read_image('./input.jpg') img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))) # 尝试执行和对比导出模型前后的结果。 with torch.no_grad(): batched_inputs = [{"image": img.float()}] pred1 = script_model(batched_inputs) pred2 = optimized_model(batched_inputs) assert_allclose(pred1[0], pred2[0], rtol=1e-3, atol=1e-2)