文档

Pytorch

更新时间:

EAS内置的Pytorch Processor支持将Pytorch标准的TorchScript格式的模型部署成在线服务。本文为您介绍如何部署及调用Pytorch模型服务。

Pytorch Processor版本说明

Pytorch支持多个版本,包括GPU和CPU版本,各个版本对应的Processor名称如下表所示。

Processor名称

Pytorch版本

是否支持GPU版本

pytorch_cpu_1.6

Pytorch 1.6

pytorch_cpu_1.7

Pytorch 1.7

pytorch_cpu_1.9

Pytorch 1.9

pytorch_cpu_1.10

Pytorch 1.10

pytorch_gpu_1.6

Pytorch 1.6

pytorch_gpu_1.7

Pytorch 1.7

pytorch_gpu_1.9

Pytorch 1.9

pytorch_gpu_1.10

Pytorch 1.10

步骤一:部署服务

使用eascmd客户端部署Pytorch模型服务时,您需要指定配置参数processor的取值为上述支持的Pytorch的Processor名称,服务配置文件示例如下。

{

  "name": "pytorch_resnet_example",
  "model_path": "http://examplebucket.oss-cn-shanghai.aliyuncs.com/models/resnet18.pt",
  "processor": "pytorch_cpu_1.6",
    "metadata": {
    "cpu": 1,
    "instance": 1,
    "memory": 1000
  }
}

关于如何使用客户端工具部署服务,详情请参见服务部署:EASCMD&DSW

您也可以通过控制台部署Pytorch模型服务,详情请参见服务部署:控制台

步骤二:调用服务

Pytorch服务输入输出格式为ProtoBuf,不是纯文本,而在线调式目前仅支持纯文本的输入输出数据,因此无法使用控制台的在线调试功能。

EAS提供了不同版本的SDK,对请求和响应数据进行了封装,且SDK内部包含了关于直连和容错相关的机制,推荐使用SDK来构建和发送请求。具体推理请求示例如下。

#!/usr/bin/env python

from eas_prediction import PredictClient
from eas_prediction import TorchRequest

if __name__ == '__main__':
    client = PredictClient('http://182848887922****.cn-shanghai.pai-eas.aliyuncs.com', 'pytorch_gpu_wl')
    client.init()

    req = TorchRequest()
    req.add_feed(0, [1, 3, 224, 224], TorchRequest.DT_FLOAT, [1] * 150528)
    # req.add_fetch(0)
    for x in range(0, 10):
        resp = client.predict(req)
        print(resp.get_tensor_shape(0))

关于代码中的参数配置说明及调用方法,详情请参见Python SDK使用说明

后续您也可以自行构建服务请求,详情请参见请求格式

请求格式

Pytorch Processor输入输出为ProtoBuf格式。当您使用SDK来发送请求时,SDK对请求进行了封装,您只需根据SDK提供的函数来构建请求即可。如果您希望自行构建服务请求,则可以参考如下pb定义来生成相关的代码,详情请参见TensorFlow服务请求构造

syntax = "proto3";

package pytorch.eas;
option cc_enable_arenas = true;

enum ArrayDataType {
  // Not a legal value for DataType. Used to indicate a DataType field
  // has not been set
  DT_INVALID = 0;

  // Data types that all computation devices are expected to be
  // capable to support
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8;  // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11;     // Quantized int8
  DT_QUINT8 = 12;    // Quantized uint8
  DT_QINT32 = 13;    // Quantized int32
  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops
  DT_QINT16 = 15;    // Quantized int16
  DT_QUINT16 = 16;   // Quantized uint16
  DT_UINT16 = 17;
  DT_COMPLEX128 = 18;  // Double-precision complex
  DT_HALF = 19;
  DT_RESOURCE = 20;
  DT_VARIANT = 21;  // Arbitrary C++ data types
}

// Dimensions of an array
message ArrayShape {
  repeated int64 dim = 1 [packed = true];
}

// Protocol buffer representing an array
message ArrayProto {
  // Data Type
  ArrayDataType dtype = 1;

  // Shape of the array.
  ArrayShape array_shape = 2;

  // DT_FLOAT
  repeated float float_val = 3 [packed = true];

  // DT_DOUBLE
  repeated double double_val = 4 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 5 [packed = true];

  // DT_STRING
  repeated bytes string_val = 6;

  // DT_INT64.
  repeated int64 int64_val = 7 [packed = true];

}


message PredictRequest {

  // Input tensors.
  repeated ArrayProto inputs = 1;

  // Output filter.
  repeated int32 output_filter = 2;
}

// Response for PredictRequest on successful run.
message PredictResponse {
  // Output tensors.
  repeated ArrayProto outputs = 1;
}
  • 本页导读 (1)
文档反馈