TorchEasyRec Processor

EAS内置的TorchEasyRec Processor支持将TorchEasyRecTorch训练的推荐模型部署为打分服务,并具备集成特征工程的能力。通过联合优化特征工程和Torch模型,Processor能够实现高性能的打分服务。本文为您介绍如何部署及调用TorchEasyRec模型服务。

背景信息

基于TorchEasyRec Processor的推荐引擎的架构图如下所示:

image

其中TorchEasyRec Processor主要包含以下模块:

  • Item Feature Cache:将FeatureStore里面的物品侧特征缓存到内存中,可以减少请求FeatureStore带来的网络开销和压力,同时可以提升推理服务的性能。当物品侧特征包含实时特征时,FeatureStore负责对实时特征进行同步。

  • 特征生成(FeatureGenerator):特征生成模块,通过配置文件定义特征变换的过程,使用一套C++代码保证离线和在线特征处理逻辑的一致性。

  • TorchModel:Torch模型,经过TorchEasyRecTorch训练后导出的ScriptedModel。

使用限制

仅支持使用通用型实例规格族g6、g7g8机型(仅支持Intel系列的CPU),并且支持T4、A10GPU型号,详情请参见通用型(g系列)。如果部署GPU服务,请确保CUDA Driver版本不低于535。

版本列表

TorchEasyRec Processor仍然在迭代中,建议您使用最新的版本部署推理服务,新的版本将提供更多的功能和更高的推理性能。已经发布的版本列表如下:

Processor名称

发布日期

Torch版本

FG版本

新增功能

easyrec-torch-0.1

20240910

2.4

0.2.9

  • 支持Feature Generator(FG)和FeatureStore Item Feature Cache。

  • 支持Torch模型CPUGPU推理。

  • 支持Input_Tile User类特征自动扩展。

  • 支持Faiss向量召回。

  • 支持在normal模式下预热。

easyrec-torch-0.2

20240930

2.4

0.2.9

  • FeatureDB支持复杂类型。

  • 加快FeatureStore数据初始化载入时间。

  • 优化bypass模式下debug_level。

  • 优化H2D。

easyrec-torch-0.3

20241014

2.4

0.2.9

  • FeatureStore支持JSON初始化。

  • proto重定义。

easyrec-torch-0.4

20241028

2.4

0.3.1

  • 修复Feature Generator(FG)复杂类型问题

easyrec-torch-0.5

20241114

2.4

0.3.1

  • 优化离在线一致性逻辑,Debug设置时,无论item是否存在都生成FG之后的特征信息。

easyrec-torch-0.6

20241118

2.4

0.3.6

  • 优化package打包过程,去除冗余头文件。

easyrec-torch-0.7

20241206

2.5

0.3.9

  • sequence primary key支持array。

  • 升级torch版本至2.5。

  • 升级FG版本至0.3.9。

easyrec-torch-0.8

20241225

2.5

0.3.9

  • 升级TensorRT SDK版本至2.5。

  • Torcheasyrec的模型输入支持int64类型。

  • 升级FeatureStore版本,解决Holo查特征问题。

  • 优化debug时运行效率和逻辑。

  • proto中新增item_features,支持从请求侧传入item特征。

步骤一:部署服务

  1. 准备服务配置文件torcheasyrec.json。

    您需要指定Processor种类easyrec-torch-{version},其中 {version} 请参照版本列表进行选择。JSON配置文件内容示例如下:

    使用FG的示例(fg_mode='normal')

    {
      "metadata": {
        "instance": 1,
        "name": "alirec_rank_with_fg",
        "rpc": {
          "enable_jemalloc": 1,
          "max_queue_size": 256,
          "worker_threads": 16
        }
      },
      "cloud": {
            "computing": {
                "instance_type": "ecs.gn6i-c16g1.4xlarge"
            }
      },
      "model_config": {
        "fg_mode": "normal",
        "fg_threads": 8,
        "region": "YOUR_REGION",
        "fs_project": "YOUR_FS_PROJECT",
        "fs_model": "YOUR_FS_MODEL",
        "fs_entity": "item",
        "load_feature_from_offlinestore": true,
        "access_key_id":"YOUR_ACCESS_KEY_ID",
        "access_key_secret":"YOUR_ACCESS_KEY_SECRET"
      },
      "storage": [
        {
          "mount_path": "/home/admin/docker_ml/workspace/model/",
          "oss": {
            "path": "oss://xxx/xxx/export",
            "readOnly": false
          },
          "properties": {
            "resource_type": "code"
          }
        }
      ],
      "processor":"easyrec-torch-0.3"
    }

    不使用FG的示例(fg_mode='bypass')

    {
      "metadata": {
        "instance": 1,
        "name": "alirec_rank_no_fg",
        "rpc": {
          "enable_jemalloc": 1,
          "max_queue_size": 256,
          "worker_threads": 16
        }
      },
      "cloud": {
            "computing": {
                "instance_type": "ecs.gn6i-c16g1.4xlarge"
            }
      },
      "model_config": {
        "fg_mode": "bypass"
      },
      "storage": [
        {
          "mount_path": "/home/admin/docker_ml/workspace/model/",
          "oss": {
            "path": "oss://xxx/xxx/export",
            "readOnly": false
          },
          "properties": {
            "resource_type": "code"
          }
        }
      ],
      "processor":"easyrec-torch-0.3"
    }

    其中关键参数说明如下,其他参数说明,请参见服务模型所有相关参数说明

    参数

    是否必选

    描述

    示例

    processor

    TorchEasyRec Processor。

    "processor":"easyrec-torch-0.3"

    path

    表示服务存储挂载的对象存储OSS路径,用来存放模型文件。

    "path": "oss://examplebucket/xxx/export"

    fg_mode

    用于指定特征工程模式,取值如下:

    • bypass(默认值):不使用FG,仅部署Torch模型。

      • 适用于自定义特征处理的场景。

      • 该模式下不需要配置Processor访问FeatureStore相关参数。

    • normal:使用FG。通常配合TorchEasyRec进行模型训练。

    "fg_mode": "normal"

    fg_threads

    用于单请求执行FG的并发线程数。

    "fg_threads": 15

    outputs

    Torch模型预测的输出变量名称,如probs_ctr。若存在多个则用半角逗号(,)分隔。默认输出所有变量。

    "outputs":"probs_ctr,probs_cvr"

    item_empty_score

    Item ID不存在时,默认的打分情况。默认值为0。

    "item_empty_score": -1

    Processor召回相关参数

    faiss_neigh_num

    FAISS向量召回数量。默认从请求体(Request)中的faiss_neigh_num字段获取;若该字段未提供,则会读取model_config配置中的faiss_neigh_num值,其默认设置为1。

    "faiss_neigh_num":200

    faiss_nprobe

    nprobe参数指定检索过程中检索到的簇的数量,默认值为800。FAISS中的倒排文件索引是将数据划分为多个小的簇(或组),并为每个簇维护一个倒排列表。更大的 nprobe 值通常会导致更高的检索精度,但会增加计算成本和搜索时间;反之则会降低精度但加快速度。

    "faiss_nprobe" : 700

    Processor访问FeatureStore相关参数

    fs_project

    FeatureStore项目名称,使用FeatureStore时需指定该字段。 关于FeatureStore的详细介绍,请参见配置FeatureStore项目

    "fs_project": "fs_demo"

    fs_model

    FeatureStore模型特征名称。

    "fs_model": "fs_rank_v1"

    fs_entity

    FeatureStore实体名称。

    "fs_entity": "item"

    region

    FeatureStore产品所在的地域,例如华北2(北京)配置为cn-beijing。更多地域配置说明,请参见服务接入点

    "region": "cn-beijing"

    access_key_id

    FeatureStore产品的AccessKey ID。

    "access_key_id": "xxxxx"

    access_key_secret

    FeatureStore产品的AccessKey Secret。

    "access_key_secret": "xxxxx"

    load_feature_from_offlinestore

    离线特征是否直接从FeatureStore OfflineStore中获取数据,取值如下:

    • True:是,会从FeatureStore OfflineStore中获取数据。

    • False(默认值):否,会从FeatureStore OnlineStore中获取数据。

    "load_feature_from_offlinestore": True

    featuredb_username

    FeatureDB用户名。

    "featuredb_username":"xxx"

    featuredb_password

    FeatureDB密码。

    "featuredb_passwd":"xxx"

    input_tile:特征自动扩展相关参数

    INPUT_TILE

    支持Feature自动扩展,对于一次请求中值都相同的特征(例如user_id),只需传递一个值即可,这有助于减少请求大小、网络传输时间和计算时间。

    该功能必须在normal模式下使用,需要与TorchEasyRec配合使用,并且在导出时设置相应的环境变量。目前系统默认从TorchEasyRec导出模型目录下的model_acc.json文件中读取INPUT_TILE值,如果该文件不存在,则会读取环境变量里的值。

    开启后:

    • 环境变量值设置为2:User侧特征FG仅计算一次。

    • 环境变量值设置为3:User侧特征FG仅计算一次,系统会将UserItemEmbedding分开计算,并且User侧的Embedding仅计算一次。适用于User侧特征比较多的情况。

    "processor_envs":

    [

    {

    "name": "INPUT_TILE",

    "value": "2"

    }

    ]

    NO_GRAD_GUARD

    推理时禁止梯度计算,会停止跟踪操作,从而不构建计算图。

    说明

    当设置为1时,可能会出现部分模型不兼容的情况。如果在第二次运行推理过程中遇到卡顿问题,可以通过添加环境变量PYTORCH_TENSOREXPR_FALLBACK=2来解决,这样可以跳过编译步骤,同时保留一定的图优化功能。

    "processor_envs":

    [

    {

    "name": "NO_GRAD_GUARD",

    "value": "1"

    }

    ]

  2. 部署TorchEasyRec模型服务。您可以任意选择一种部署方式:

    JSON独立部署(推荐)

    具体操作步骤如下:

    1. 登录PAI控制台,在页面上方选择目标地域,并在右侧选择目标工作空间,然后单击进入EAS

    2. 模型在线服务(EAS)页面,单击部署服务,然后在自定义模型部署区域,单击JSON独立部署

    3. JSON文本编辑框中,填入已准备好的JSON配置文件内容,然后单击部署

    eascmd客户端部署

    1. 下载并认证客户端,以Windows 64版本为例。

    2. 打开终端工具,在JSON文件所在目录,使用以下命令创建服务。更多操作说明,请参见命令使用说明

      eascmdwin64.exe create <service.json>

      其中:<service.json>需要替换为您已创建的JSON文件名称。例如torcheasyrec.json。

步骤二:调用服务

TorchEasyRec模型服务部署完成后,按照以下操作步骤查看服务调用信息:

  1. 登录PAI控制台,在页面上方选择目标地域,并在右侧选择目标工作空间,然后单击进入EAS

  2. 单击目标服务的服务方式列下的调用信息,查看服务的访问地址和Token信息。image

TorchEasyRec模型服务的输入输出格式为Protobuf格式,根据是否使用FG,分为以下两种调用方法:

使用FG(fg_mode='normal')

支持以下两种调用方法:

使用EAS Java SDK

在执行代码前,您需要配置Maven环境,配置详情请参见Java SDK使用说明。请求服务alirec_rank_with_fg的示例代码如下:

package com.aliyun.openservices.eas.predict;

import com.aliyun.openservices.eas.predict.http.Compressor;
import com.aliyun.openservices.eas.predict.http.HttpConfig;
import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.proto.TorchRecPredictProtos;
import com.aliyun.openservices.eas.predict.request.TorchRecRequest;
import com.aliyun.openservices.eas.predict.proto.TorchPredictProtos.ArrayProto;

import java.util.*;


public class TorchRecPredictTest {
    public static PredictClient InitClient() {
        return new PredictClient(new HttpConfig());
    }

    public static TorchRecRequest buildPredictRequest() {
        TorchRecRequest TorchRecRequest = new TorchRecRequest();
        TorchRecRequest.appendItemId("7033");

        TorchRecRequest.addUserFeature("user_id", 33981,"int");

        ArrayList<Double> list = new ArrayList<>();
        list.add(0.24689289764507472);
        list.add(0.005758482924454689);
        list.add(0.6765301324940026);
        list.add(0.18137273055602343);
        TorchRecRequest.addUserFeature("raw_3", list,"List<double>");

        Map<String,Integer> myMap =new LinkedHashMap<>();
        myMap.put("866", 4143);
        myMap.put("1627", 2451);
        TorchRecRequest.addUserFeature("map_1", myMap,"map<string,int>");

        ArrayList<ArrayList<Float>> list2 = new ArrayList<>();
        ArrayList<Float> innerList1 = new ArrayList<>();
        innerList1.add(1.1f);
        innerList1.add(2.2f);
        innerList1.add(3.3f);
        list2.add(innerList1);
        ArrayList<Float> innerList2 = new ArrayList<>();
        innerList2.add(4.4f);
        innerList2.add(5.5f);
        list2.add(innerList2);
        TorchRecRequest.addUserFeature("click", list2,"list<list<float>>");

        TorchRecRequest.addContextFeature("id_2", list,"List<double>");
        TorchRecRequest.addContextFeature("id_2", list,"List<double>");

        System.out.println(TorchRecRequest.request);
        return TorchRecRequest;
    }

    public static void main(String[] args) throws Exception{
        PredictClient client = InitClient();
        client.setToken("tokenGeneratedFromService");
        client.setEndpoint("175805416243****.cn-beijing.pai-eas.aliyuncs.com");
        client.setModelName("alirec_rank_with_fg");
        client.setRequestTimeout(100000);


        testInvoke(client);
        testDebugLevel(client);
        client.shutdown();
    }

    public static void testInvoke(PredictClient client) throws Exception {
        long startTime = System.currentTimeMillis();
        TorchRecPredictProtos.PBResponse response = client.predict(buildPredictRequest());
        for (Map.Entry<String, ArrayProto> entry : response.getMapOutputsMap().entrySet()) {

            System.out.println("Key: " + entry.getKey() + ", Value: " + entry.getValue());
        }
        long endTime = System.currentTimeMillis();
        System.out.println("Spend Time: " + (endTime - startTime) + "ms");

    }

    public static void testDebugLevel(PredictClient client) throws Exception {
        long startTime = System.currentTimeMillis();
        TorchRecRequest request = buildPredictRequest();
        request.setDebugLevel(1);
        TorchRecPredictProtos.PBResponse response = client.predict(request);
        Map<String, String> genFeas = response.getGenerateFeaturesMap();
        for(String itemId: genFeas.keySet()) {
            System.out.println(itemId);
            System.out.println(genFeas.get(itemId));
        }
        long endTime = System.currentTimeMillis();
        System.out.println("Spend Time: " + (endTime - startTime) + "ms");

    }
}

其中:

  • client.setToken("tokenGeneratedFromService"):需要将括号里的配置设置为您的服务Token。例如MmFiMDdlO****wYjhhNjgwZmZjYjBjMTM1YjliZmNkODhjOGVi****

  • client.setEndpoint("175805416243****.cn-beijing.pai-eas.aliyuncs.com"):需要将括号里的配置设置为您的服务Endpoint。例如175805416243****.cn-beijing.pai-eas.aliyuncs.com

  • client.setModelName("alirec_rank_with_fg"):需要将括号里的配置设置为您的服务名称。

使用EAS Python SDK

在执行代码前,请先使用pip install -U eas-prediction --user命令安装或更新eas-prediction库,更多配置详情请参见Python SDK使用说明。示例代码如下:

from eas_prediction import PredictClient
from eas_prediction.torchrec_request import TorchRecRequest


if __name__ == '__main__':
    endpoint = 'http://localhost:6016'

    client = PredictClient(endpoint, '<YOUR_SERVICE_NAME>')
    client.set_token('<your_service_token>')
    client.init()
    torchrec_req = TorchRecRequest()

    torchrec_req.add_user_fea('user_id', 'u001d', "STRING")
    torchrec_req.add_user_fea('age', 12, "INT")
    torchrec_req.add_user_fea('weight', 129.8, "FLOAT")
    torchrec_req.add_item_id('item_0001')
    torchrec_req.add_item_id('item_0002')
    torchrec_req.add_item_id('item_0003')
    torchrec_req.add_user_fea("raw_3", [0.24689289764507472, 0.005758482924454689, 0.6765301324940026, 0.18137273055602343], "list<double>")
    torchrec_req.add_user_fea("raw_4", [0.9965264740966043, 0.659596586238391, 0.16396649403055896, 0.08364986620265635], "list<double>")
    torchrec_req.add_user_fea("map_1", {"0":0.37845234405201145}, "map<int,float>")
    torchrec_req.add_user_fea("map_2", {"866":4143,"1627":2451}, "map<int,int>")
    torchrec_req.add_context_fea("id_2", [866], "list<int>" )
    torchrec_req.add_context_fea("id_2", [7022,1], "list<int>" )
    torchrec_req.add_context_fea("id_2", [7022,1], "list<int>" )
    torchrec_req.add_user_fea("click", [[0.94433516,0.49145547], [0.94433516, 0.49145597]], "list<list<float>>")

    res = client.predict(torchrec_req)
    print(res)

其中关键配置说明如下:

  • endpoint:配置为您的服务访问地址,例如http://175805416243****.cn-beijing.pai-eas.aliyuncs.com/

  • <your_service_name>:替换为您的服务名称。

  • <your_service_token>:替换您的服务Token,例如MmFiMDdlO****wYjhhNjgwZmZjYjBjMTM1YjliZmNkODhjOGVi****

不使用FG(fg_mode='bypass')

使用EAS Java SDK

在执行代码前,您需要配置Maven环境,配置详情请参见Java SDK使用说明。请求服务alirec_rank_no_fg的示例代码如下:

package com.aliyun.openservices.eas.predict;

import java.util.List;
import java.util.Arrays;


import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.http.HttpConfig;
import com.aliyun.openservices.eas.predict.request.TorchDataType;
import com.aliyun.openservices.eas.predict.request.TorchRequest;
import com.aliyun.openservices.eas.predict.response.TorchResponse;

public class Test_Torch {
    public static PredictClient InitClient() {
        return new PredictClient(new HttpConfig());
    }

    public static TorchRequest buildPredictRequest() {
        TorchRequest request = new TorchRequest();
        float[] content = new float[2304000];
        for (int i = 0; i < content.length; i++) {
            content[i] = (float) 0.0;
        }
        long[] content_i = new long[900];
        for (int i = 0; i < content_i.length; i++) {
            content_i[i] = 0;
        }

        long[] a = Arrays.copyOfRange(content_i, 0, 300);
        float[] b = Arrays.copyOfRange(content, 0, 230400);
        request.addFeed(0, TorchDataType.DT_INT64, new long[]{300,3}, content_i);
        request.addFeed(1, TorchDataType.DT_FLOAT, new long[]{300,10,768}, content);
        request.addFeed(2, TorchDataType.DT_FLOAT, new long[]{300,768}, b);
        request.addFeed(3, TorchDataType.DT_INT64, new long[]{300}, a);
        request.addFetch(0);
        request.setDebugLevel(903);
        return request;
    }

    public static void main(String[] args) throws Exception {
        PredictClient client = InitClient();
        client.setToken("tokenGeneratedFromService");
        client.setEndpoint("175805416243****.cn-beijing.pai-eas.aliyuncs.com");
        client.setModelName("alirec_rank_no_fg");
        client.setIsCompressed(false);
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            TorchResponse response = null;
            try {
                response = client.predict(buildPredictRequest());
                List<Float> result = response.getFloatVals(0);
                System.out.print("Predict Result: [");
                for (int j = 0; j < result.size(); j++) {
                    System.out.print(result.get(j).floatValue());
                    if (j != result.size() - 1) {
                        System.out.print(", ");
                    }
                }
                System.out.print("]\n");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println("Spend Time: " + (endTime - startTime) + "ms");
        client.shutdown();
    }
}

其中:

  • client.setToken("tokenGeneratedFromService"):需要将括号里的配置设置为您的服务Token。例如MmFiMDdlO****wYjhhNjgwZmZjYjBjMTM1YjliZmNkODhjOGVi****

  • client.setEndpoint("175805416243****.cn-beijing.pai-eas.aliyuncs.com"):需要将括号里的配置设置为您的服务Endpoint。例如175805416243****.cn-beijing.pai-eas.aliyuncs.com

  • client.setModelName("alirec_rank_no_fg"):需要将括号里的配置设置为您的服务名称。

使用EAS Python SDK

在执行代码前,请先使用pip install -U eas-prediction --user命令安装或更新eas-prediction库,更多配置详情请参见Python SDK使用说明。请求服务alirec_rank_no_fg的示例代码如下:

from eas_prediction import PredictClient
from eas_prediction import TorchRequest

# snappy data
req = TorchRequest(False)

req.add_feed(0, [300, 3], TorchRequest.DT_INT64, [1] * 900)
req.add_feed(1, [300, 10, 768], TorchRequest.DT_FLOAT, [1.0] * 3 * 768000)
req.add_feed(2, [300, 768], TorchRequest.DT_FLOAT, [1.0] * 3 * 76800)
req.add_feed(3, [300], TorchRequest.DT_INT64, [1] * 300)


client = PredictClient('<your_endpoint>', '<your_service_name>')
client.set_token('<your_service_token>')

client.init()

resp = client.predict(req)
print(resp)

其中关键配置说明如下:

  • <your_endpoint>:替换为您的服务访问地址,例如http://175805416243****.cn-beijing.pai-eas.aliyuncs.com/

  • <your_service_name>:替换为您的服务名称。

  • <your_service_token>:替换您的服务Token,例如MmFiMDdlO****wYjhhNjgwZmZjYjBjMTM1YjliZmNkODhjOGVi****

有关访问服务返回的状态码的详细说明,请参见服务状态码说明。您也可以参考请求格式自行构建服务请求。

请求格式

客户端调用服务可以根据.proto文件手动生成预测的请求代码文件。如果您希望自行构建服务请求,则可以参考如下protobuf的定义来生成相应的代码:

pytorch_predict.proto:Torch模型的请求定义

syntax = "proto3";

package pytorch.eas;
option cc_enable_arenas = true;
option java_package = "com.aliyun.openservices.eas.predict.proto";
option java_outer_classname = "TorchPredictProtos";

enum ArrayDataType {
  // Not a legal value for DataType. Used to indicate a DataType field
  // has not been set.
  DT_INVALID = 0;
  
  // Data types that all computation devices are expected to be
  // capable to support.
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8;  // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11;     // Quantized int8
  DT_QUINT8 = 12;    // Quantized uint8
  DT_QINT32 = 13;    // Quantized int32
  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops.
  DT_QINT16 = 15;    // Quantized int16
  DT_QUINT16 = 16;   // Quantized uint16
  DT_UINT16 = 17;
  DT_COMPLEX128 = 18;  // Double-precision complex
  DT_HALF = 19;
  DT_RESOURCE = 20;
  DT_VARIANT = 21;  // Arbitrary C++ data types
}

// Dimensions of an array
message ArrayShape {
  repeated int64 dim = 1 [packed = true];
}

// Protocol buffer representing an array
message ArrayProto {
  // Data Type.
  ArrayDataType dtype = 1;

  // Shape of the array.
  ArrayShape array_shape = 2;

  // DT_FLOAT.
  repeated float float_val = 3 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 4 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 5 [packed = true];

  // DT_STRING.
  repeated bytes string_val = 6;

  // DT_INT64.
  repeated int64 int64_val = 7 [packed = true];

}


message PredictRequest {

  // Input tensors.
  repeated ArrayProto inputs = 1;

  // Output filter.
  repeated int32 output_filter = 2;

  // Input tensors for rec
  map<string, ArrayProto> map_inputs = 3;

  // debug_level for rec
  int32 debug_level = 100;
}

// Response for PredictRequest on successful run.
message PredictResponse {
  // Output tensors.
  repeated ArrayProto outputs = 1;
  // Output tensors for rec.
  map<string, ArrayProto> map_outputs = 2;
}

torchrec_predict.proto:Torch模型+FG的请求定义

syntax = "proto3";

option go_package = ".;torch_predict_protos";
option java_package = "com.aliyun.openservices.eas.predict.proto";
option java_outer_classname = "TorchRecPredictProtos";
package com.alibaba.pairec.processor;
import "pytorch_predict.proto";

//long->others
message LongStringMap {
  map<int64, string> map_field = 1;
}
message LongIntMap {
  map<int64, int32> map_field = 1;
}
message LongLongMap {
  map<int64, int64> map_field = 1;
}
message LongFloatMap {
  map<int64, float> map_field = 1;
}
message LongDoubleMap {
  map<int64, double> map_field = 1;
}

//string->others
message StringStringMap {
  map<string, string> map_field = 1;
}
message StringIntMap {
  map<string, int32> map_field = 1;
}
message StringLongMap {
  map<string, int64> map_field = 1;
}
message StringFloatMap {
  map<string, float> map_field = 1;
}
message StringDoubleMap {
  map<string, double> map_field = 1;
}

//int32->others
message IntStringMap {
  map<int32, string> map_field = 1;
}
message IntIntMap {
  map<int32, int32> map_field = 1;
}
message IntLongMap {
  map<int32, int64> map_field = 1;
}
message IntFloatMap {
  map<int32, float> map_field = 1;
}
message IntDoubleMap {
  map<int32, double> map_field = 1;
}

// list
message IntList {
  repeated int32 features = 1;
}
message LongList {
  repeated int64 features  = 1;
}

message FloatList {
  repeated float features = 1;
}
message DoubleList {
  repeated double features = 1;
}
message StringList {
  repeated string features = 1;
}

// lists
message IntLists {
  repeated IntList lists = 1;
}
message LongLists {
  repeated LongList lists = 1;
}

message FloatLists {
  repeated FloatList lists = 1;
}
message DoubleLists {
  repeated DoubleList lists = 1;
}
message StringLists {
  repeated StringList lists = 1;
}

message PBFeature {
  oneof value {
    int32 int_feature = 1;
    int64 long_feature = 2;
    string string_feature = 3;
    float float_feature = 4;
    double double_feature=5;

    LongStringMap long_string_map = 6; 
    LongIntMap long_int_map = 7; 
    LongLongMap long_long_map = 8; 
    LongFloatMap long_float_map = 9; 
    LongDoubleMap long_double_map = 10; 
    
    StringStringMap string_string_map = 11; 
    StringIntMap string_int_map = 12; 
    StringLongMap string_long_map = 13; 
    StringFloatMap string_float_map = 14; 
    StringDoubleMap string_double_map = 15; 

    IntStringMap int_string_map = 16; 
    IntIntMap int_int_map = 17; 
    IntLongMap int_long_map = 18; 
    IntFloatMap int_float_map = 19; 
    IntDoubleMap int_double_map = 20; 

    IntList int_list = 21; 
    LongList long_list =22;
    StringList string_list = 23;
    FloatList float_list = 24;
    DoubleList double_list = 25;

    IntLists int_lists = 26;
    LongLists long_lists =27;
    StringLists string_lists = 28;
    FloatLists float_lists = 29;
    DoubleLists double_lists = 30;
    
  }
}

// context features
message ContextFeatures {
  repeated PBFeature features = 1;
}

// PBRequest specifies the request for aggregator
message PBRequest {
  // debug mode
  int32 debug_level = 1;

  // user features, key is user input name
  map<string, PBFeature> user_features = 2;

  // item ids
  repeated string item_ids = 3;

  // context features for each item, key is context input name 
  map<string, ContextFeatures> context_features = 4;

  // number of nearest neighbors(items) to retrieve
  // from faiss
  int32 faiss_neigh_num = 5;

  // item features for each item, key is item input name 
  map<string, ContextFeatures> item_features = 6;
}

// PBResponse specifies the response for aggregator
message PBResponse {
  // torch output tensors
  map<string, pytorch.eas.ArrayProto> map_outputs = 1;

  // fg ouput features
  map<string, string> generate_features = 2;

  // all fg input features
  map<string, string> raw_features = 3;

  // item ids
  repeated string item_ids = 4;

}

debug_level说明如下:

说明

默认情况下无需配置,当您需要进行Debug调试时才需传入。

debug_level

说明

0

服务正常预测。

1

normal模式下,对请求的key做校验,并对FG的输入输出进行形状校验,同时保存输入特征和输出特征,但不进行预测。

2

normal模式下,对请求的key做校验,并对FG的输入输出进行形状校验,保存输入特征和输出特征,及模型输入的Tensor,进行预测。

3

normal模式下,对请求的key做校验,并对FG的输入输出进行形状校验,输出特征,不做预测。

100

normal模式下保存预测请求。

102

normal模式下进行向量召回,对请求的key做校验,对FG的输入输出进行形状校验,保存输入特征和输出特征,以及模型输入的Tensor、User Embedding结果。

903

打印每个阶段的预测时间。

服务状态码说明

访问TorchEasyRec服务时,可能返回的主要状态码说明如下。关于访问EAS服务返回的更多状态码说明,请参见附录:服务状态码说明

状态码

说明

200

服务正常返回。

400

请求输入有问题。

500

预测失败,详细请查看服务日志。