推荐使用PAI-EAS提供的官方SDK进行服务调用,从而有效减少编写调用逻辑的时间并提高调用稳定性。本文介绍官方Java SDK接口详情。同时,以字符串输入输出和TensorFlow输入输出为例,提供了使用Java SDK进行服务调用的完整程序示例。

添加依赖包

使用Java编写客户端代码时,需要使用Maven管理项目。因此,您必须在pom.xml文件中添加客户端所需的依赖包eas-sdk。依赖包eas-sdk的最新版本为2.0.3,pom.xml文件中的具体代码如下。
<dependency>
  <groupId>com.aliyun.openservices.eas</groupId>
  <artifactId>eas-sdk</artifactId>
  <version>2.0.3</version>
</dependency>

接口列表

接口 描述
PredictClient PredictClient(HttpConfig httpConfig)
  • 功能:PredictClient类构造器。
  • 参数:httpConfig表示HttpConfig类的实例对象。
void setToken(String token)
  • 功能:设置HTTP请求的Token参数。
  • 参数:token表示访问服务时需要使用的鉴权Token。
void setModelName(String modelName)
  • 功能:设置请求的在线预测服务的模型名称。
  • 参数:modelName表示所设置的模型名称。
void setEndpoint(String endpoint)
  • 功能:设置请求服务的Host和Port,格式为"host:port"
  • 参数:endpoint表示接收消息的终端地址。
void setDirectEndpoint(String endpoint)
  • 功能:设置VPC高速直连通道访问服务的Endpoint,例如,pai-eas-vpc.cn-shanghai.aliyuncs.com
  • 参数:endpont表示设置的访问服务地址。
void setRetryCount(boolean int retryCount)
  • 功能:设置失败重试次数。
  • 参数:retryCount表示失败的重试次数。
void setContentType(String contentType)
  • 功能:设置HTTP Client的Content类型,默认为"application/octet-stream"。
  • 参数:contentType表示发送数据流的类型。
void createChildClient(String token, String endpoint, String modelName)
  • 功能:创建子Client对象,共用父Client对象的线程池。该接口用于多线程预测。
  • 参数:
    • token:服务的鉴权Token。
    • endpoint:服务的Endpoint。
    • modelName:服务的名称。
void predict(TFRequest runRequest)
  • 功能:向在线预测服务提交一个TensorFlow请求。
  • 参数:runRequest表示TensorFlow请求的实例对象。
void predict(String requestContent)
  • 功能:向在线预测服务提交一个字符串请求。
  • 参数:requestContent表示字符串格式的请求内容。
void predict(byte[] requestContent)
  • 功能:向在线预测服务提交一个Byte数组请求。
  • 参数:requestContent表示Byte类型的请求内容。
HttpConfig void setIoThreadNum(int ioThreadNum)
  • 功能:设置HTTP请求的IO线程数,默认值为2。
  • 参数:ioThreadNum表示发送HTTP请求的IO线程数。
void setReadTimeout(int readTimeout)
  • 功能:设置Socket的读取超时时间,默认值为5000,表示5s。
  • 参数:readTimeout表示请求的读取超时时间。
void setConnectTimeout(int connectTimeout)
  • 功能:设置连接超时时间,默认值为5000,表示5s。
  • 参数:connectTimeout表示请求的连接超时时间。
void setMaxConnectionCount(int maxConnectionCount)
  • 功能:设置最大连接数,默认值为1000。
  • 参数:maxConnectionCount表示客户端连接池的最大连接数。
void setMaxConnectionPerRoute(int maxConnectionPerRoute)
  • 功能:设置每个路由的最大默认连接数,默认值为1000。
  • 参数:maxConnectionPerRoute表示每个路由上的默认最大连接数。
void setKeepAlive(boolean keepAlive)
  • 功能:设置HTTP服务的keep-alive
  • 参数:keepAlive表示是否开启连接的keep-alive机制,默认为true
int getErrorCode() 返回最近一次调用的状态码。
string getErrorMessage() 返回最近一次调用的状态信息。
TFRequest void setSignatureName(String value)
  • 功能:请求的在线服务的模型为TensorFlow的SavedModel格式时,设置请求模型的signatureDef的名称。
  • 参数:请求模型的signatureDef的名称。
void addFetch(String value)
  • 功能:请求TensorFlow的在线服务模型时,设置模型输出Tensor的别名。
  • 参数:value表示TensorFlow服务输出Tensor的别名。
void addFeed(String inputName, TFDataType dataType, long[]shape, ?[]content)
  • 功能:请求TensorFlow的在线预测服务模型时,设置需要输入的Tensor。
  • 参数:
    • inputName:表示输入Tensor的别名。
    • dataType:表示输入Tensor的DataType。
    • shape:表示输入Tensor的TensorShape。
    • content:表示输入Tensor的内容,采用一维数组表示。

      如果输入Tensor的DataType为DT_FLOAT、DT_COMPLEX64、DT_BFLOAT16或DT_HALF,则content中的元素类型为FLOAT。其中DataType为DT_COMPLEX64时,content中相邻两个FLOAT元素依次表示复数的实部和虚部。

      如果输入Tensor的DataType为DT_DOUBLE或DT_COMPLEX128,则content中的元素类型为DOUBLE。其中DataType为DT_COMPLEX128时,content中相邻两个DOUBLE元素依次表示复数的实部和虚部。

      如果输入Tensor的DataType为DT_INT32、DT_UINT8、DT_INT16、DT_INT8、DT_QINT8、DT_QUINT8、DT_QINT32、DT_QINT16、DT_QUINT16或DT_UINT16,content中的元素类型为INT。

      如果输入Tensor的DataType为DT_INT64,则content中的元素类型为LONG。

      如果输入Tensor的DataType为DT_STRING,则content中的元素类型为STRING。

      如果输入Tensor的DataType为DT_BOOL,则content中的元素类型为BOOLEAN。

TFResponse List<Long> getTensorShape(String outputName)
  • 功能:获得指定别名的输出Tensor的TensorShape。
  • 参数:outputName表示待获取TensorShape的模型输出的名称。
  • 返回值:表示TensorShape的一维数组。
List<Float> getFloatVals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_FLOAT、DT_COMPLEX64、DT_BFLOAT16或DT_HALF,则可以调用该接口获取指定输出Tensor的data
  • 参数:outputName表示待获取FLOAT类型返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。
List<Double> getDoubleVals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_DOUBLE或DT_COMPLEX128,则调用该接口获取指定输出Tensor的data
  • 参数:outputname表示待获取DOUBLE类型返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。
List<Integer> getIntVals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_INT32、DT_UINT8、DT_INT16、DT_INT8、DT_QINT8、DT_QUINT8、DT_QINT32、DT_QINT16、DT_QUINT16或DT_UINT16,则调用该接口获取指定输出Tensor的data
  • 参数:outputname表示待获取INT类型返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。
List<String> getStringVals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_STRING,则调用该接口获取指定输出Tensor的data
  • 参数:outputName表示待获取STRING类型返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。
List<Long> getInt64Vals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_INT64,则调用该接口获取指定输出Tensor的data
  • 参数:outputName表示待获取的INT64类型返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。
List<Boolean> getBoolVals(String outputName)
  • 功能:如果输出Tensor的DataType为DT_BOOL,则调用该接口获取指定输出Tensor的data
  • 参数:outputName表示待获取BOOL类型的返回数据的模型输出的名称。
  • 返回值:模型输出的TensorData展开成的一维数组。

程序示例

字符串输入输出示例

对于使用自定义Processor部署服务的用户而言,通常采用字符串进行服务调用(例如,PMML模型服务的调用),具体的Demo程序如下。
import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.http.HttpConfig;

public class Test_String {
    public static void main(String[] args) throws Exception{
    // 启动并初始化客户端。client对象需要共享,不能每个请求都创建一个client对象。
        PredictClient client = new PredictClient(new HttpConfig());
        client.setToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****");                         
        // 如果需要使用网络直连功能,则使用setDirectEndpoint方法。
        // 例如,client.setDirectEndpoint("pai-eas-vpc.cn-shanghai.aliyuncs.com");
        // 网络直连功能需要在PAI-EAS控制台开通,提供用于访问PAI-EAS服务的源vswitch。开通后可以绕过网关以软负载的方式直接访问服务的实例,以实现更好的稳定性和性能。
        // 注意:通过普通网关访问时,需要使用以用户uid开头的Endpoint,在PAI-EAS控制台服务的调用信息中可以查到该信息。通过直连访问时,需要使用如上的pai-eas-vpc.{region_id}.aliyuncs.com的域名。
        client.setEndpoint("182848887922****.vpc.cn-shanghai.pai-eas.aliyuncs.com");
        client.setModelName("scorecard_pmml_example");

        //定义输入字符串。
        String request = "[{\"money_credit\": 3000000}, {\"money_credit\": 10000}]";
        System.out.println(request);

        //通过PAI-EAS返回字符串。
        try {
            String response = client.predict(request);
            System.out.println(response);
        } catch(Exception e) {
            e.printStackTrace();
        }       
        return;
    }
}
如上述程序所示,使用Java SDK调用服务的流程如下:
  1. 通过PredictClient接口创建客户端服务对象。如果在程序中需要使用多个服务,则创建多个Client对象。
  2. 为PredictClient对象配置Token、Endpoint及ModelName。
  3. 构造STRING类型的request作为输入,通过client.predict发送HTTP请求,系统返回response

TensorFlow输入输出示例

使用TensorFlow的用户,需要将TFRequest和TFResponse分别作为输入和输出数据格式,具体Demo示例如下。
import java.util.List;
import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.http.HttpConfig;
import com.aliyun.openservices.eas.predict.request.TFDataType;
import com.aliyun.openservices.eas.predict.request.TFRequest;
import com.aliyun.openservices.eas.predict.response.TFResponse;

public class Test_TF {
    public static TFRequest buildPredictRequest() {
        TFRequest request = new TFRequest();
        request.setSignatureName("predict_images");
        float[] content = new float[784];
        for (int i = 0; i < content.length; i++)
            content[i] = (float)0.0;
        request.addFeed("images", TFDataType.DT_FLOAT, new long[]{1, 784}, content);
        request.addFetch("scores");
        return request;
    }

    public static void main(String[] args) throws Exception{
        PredictClient client = new PredictClient(new HttpConfig());
        
        // 如果使用网络直连功能,则调用setDirectEndpoint方法。
        // 例如,client.setDirectEndpoint("pai-eas-vpc.cn-shanghai.aliyuncs.com");
        // 网络直连功能需要在PAI-EAS控制台开通,提供用于访问PAI-EAS服务的源vswitch。开通后可以绕过网关以软负载的方式直接访问服务的实例,以实现更好的稳定性和性能。
        // 注意:通过普通网关访问时,需要使用以用户uid开头的Endpoint,在PAI-EAS控制台服务的调用信息中可以查到该信息。通过直连访问时,需要使用如上的pai-eas-vpc.{region_id}.aliyuncs.com的域名。
        client.setEndpoint("1828488879222746.vpc.cn-shanghai.pai-eas.aliyuncs.com");
        client.setModelName("mnist_saved_model_example");
        client.setToken("YTg2ZjE0ZjM4ZmE3OTc0NzYxZDMyNmYzMTJjZTQ1YmU0N2FjMTAy****");
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < 1000; i++) {
            try {
                TFResponse response = client.predict(buildPredictRequest());
                List<Float> result = response.getFloatVals("scores");
                System.out.print("Predict Result: [");
                for (int j = 0; j < result.size(); j++) {
                    System.out.print(result.get(j).floatValue());
                    if (j != result.size() -1)
                        System.out.print(", ");
                }
                System.out.print("]\n");
            } catch(Exception e) {
                e.printStackTrace();
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println("Spend Time: " + (endTime - startTime) + "ms");       
    }
}
如上述程序所示,使用Java SDK调用TensorFlow服务的流程如下:
  1. 通过PredictClient接口创建客户端服务对象。如果在程序中需要使用多个服务,则创建多个Client对象。
  2. 为PredictClient对象配置Token、Endpoint及ModelName。
  3. 使用TFRequest类封装输入数据,使用TFResponse类封装输出数据。