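Model upload and online inference
This topic describes how to upload a custom model to PolarDB for AI and how to use the model for online inference.
Background information
PolarDB for AI provides built-in models for common machine learning tasks. In real business scenarios, however, a model's structure is often adjusted during algorithm tuning to improve business results, so the built-in models may not meet your requirements. To address this, PolarDB for AI provides model upload and online inference features. It also supports running models in confidential containers, which further protects the data in your models.
Procedure
Train the model offline.
The following offline training script uses the LightGBM algorithm as an example: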
# coding: utf-8
import joblib
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error


def train_model():
    print('Loading data...')
    # Load the tab-separated training and test sets; column 0 holds the label.
    df_train = pd.read_csv('regression.train', header=None, sep='\t')
    df_test = pd.read_csv('regression.test', header=None, sep='\t')

    y_train = df_train[0]
    y_test = df_test[0]
    X_train = df_train.drop(0, axis=1)
    X_test = df_test.drop(0, axis=1)

    # Create the LightGBM datasets; the test set doubles as the validation set.
    lgb_train = lgb.Dataset(X_train, y_train)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

    # Specify the training configuration as a dict.
    params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': {'l2', 'l1'},
        'num_leaves': 31,
        'learning_rate': 0.05,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': 0
    }

    print('Starting training...')
    # Train, stopping early if the validation metric does not improve for 5 rounds.
    gbm = lgb.train(params,
                    lgb_train,
                    num_boost_round=20,
                    valid_sets=lgb_eval,
                    callbacks=[lgb.early_stopping(stopping_rounds=5)])

    print('Saving model...')
    # Export the trained model as a pkl file (this is the file to upload).
    joblib.dump(gbm, 'lgb.pkl')

    print('Starting predicting...')
    # Predict with the best iteration found during training.
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)

    # Evaluate: RMSE is the square root of the mean squared error.
    rmse_test = mean_squared_error(y_test, y_pred) ** 0.5
    print(f'The RMSE of prediction is: {rmse_test}')


if __name__ == '__main__':
    train_model()
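In this script, the model is exported as a file in pkl format, and inference results are returned by calling the predict method. You must also provide a file that lists the Python packages the model depends on at run time. Sample content of requirements.txt:
lightgbm==3.3.3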
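One way to keep requirements.txt consistent with the training environment is to read the installed version of each dependency rather than typing it by hand. The following is a minimal sketch under that assumption; the helper name pin_requirements is hypothetical and is not part of PolarDB for AI.

```python
from importlib import metadata


def pin_requirements(packages):
    """Return ('name==version' lines, names that are not installed).

    Packages missing from the local environment are reported separately
    so the caller can decide what to do instead of getting a silent gap.
    """
    pinned, missing = [], []
    for name in packages:
        try:
            pinned.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            missing.append(name)
    return pinned, missing
```

Calling pin_requirements(["lightgbm"]) in the environment used for training returns a line in the same name==version form as the sample above, which can then be written to requirements.txt.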
Upload the model.
Run the following statement to upload the model to PolarDB for AI:
/*polar4ai*/UPLOAD MODEL my_model WITH (model_location='https://xxxx.oss-cn-hangzhou.aliyuncs.com/xxxx/model.pkl?Expires=xxxx&OSSAccessKeyId=xxxx&Signature=xxxx', req_location='https://xxxx.oss-cn-hangzhou.aliyuncs.com/xxxx/requirements.txt?Expires=xxxx&OSSAccessKeyId=xxxx&Signature=xxxx')
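model_location and req_location specify the URL of the model file and the URL of the file that lists the model's runtime dependencies, respectively. Prepare both files in advance, upload them to your own private OSS bucket, and then run the preceding statement to upload them to PolarDB for AI. The following result is returned: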
Query OK, 0 rows affected (0.29 sec)
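Because the two URLs are long signed links, it can be less error-prone to assemble the statement in a script. The sketch below only builds the statement text in the form shown above; the function name build_upload_statement is hypothetical, and how the statement is sent to the cluster (for example, through a MySQL-compatible client) is left to the caller.

```python
def build_upload_statement(model_name, model_url, req_url):
    """Compose an UPLOAD MODEL statement in the form shown above."""
    # Reject single quotes so a value cannot terminate the quoted literal early.
    for label, value in (("model_name", model_name),
                         ("model_url", model_url),
                         ("req_url", req_url)):
        if "'" in value:
            raise ValueError(f"{label} must not contain a single quote")
    return (
        f"/*polar4ai*/UPLOAD MODEL {model_name} WITH ("
        f"model_location='{model_url}', "
        f"req_location='{req_url}')"
    )
```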
Run the following statement to view the model status:
/*polar4ai*/ SHOW MODEL my_model;
The following result is returned:
+-------------+----------------------------------------------------------------------------------------------------------------------------+
| modelStatus | modelPath                                                                                                                  |
+-------------+----------------------------------------------------------------------------------------------------------------------------+
| saved       | http://db4ai-collie-cn-hangzhou.oss-cn-hangzhou.aliyuncs.com/xxxxx.pkl?OSSAccessKeyId=xxxxxx&Expires=xxxx&Signature=xxxxxx |
+-------------+----------------------------------------------------------------------------------------------------------------------------+
1 row in set (0.23 sec)
If modelStatus is saved, the model has been uploaded.
Deploy the model.
Run the following statement to deploy the model on PolarDB for AI:
/*polar4ai*/ DEPLOY MODEL my_model;
The following result is returned:
Query OK, 0 rows affected (0.29 sec)
Run the following statement to view the model status:
/*polar4ai*/ SHOW MODEL my_model;
The following result is returned:
+-------------+----------------------------------------------------------------------------------------------------------------------------+
| modelStatus | modelPath                                                                                                                  |
+-------------+----------------------------------------------------------------------------------------------------------------------------+
| serving     | http://db4ai-collie-cn-hangzhou.oss-cn-hangzhou.aliyuncs.com/xxxxx.pkl?OSSAccessKeyId=xxxxxx&Expires=xxxx&Signature=xxxxxx |
+-------------+----------------------------------------------------------------------------------------------------------------------------+
1 row in set (0.23 sec)
If modelStatus is serving, the model has been deployed.
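If the status does not change to serving immediately (an assumption; this topic does not state how long deployment takes), a script can poll the status until it reaches the expected value. The following is a minimal sketch: wait_for_status and its parameters are hypothetical, and fetch_status stands for a caller-supplied function that runs the SHOW MODEL statement and returns the modelStatus value.

```python
import time


def wait_for_status(fetch_status, target, timeout_s=300.0, interval_s=5.0,
                    sleep=time.sleep, clock=time.monotonic):
    """Poll fetch_status() until it returns target or the timeout expires.

    fetch_status is supplied by the caller (hypothetical here): it should run
    SHOW MODEL and return the modelStatus column as a string.
    """
    deadline = clock() + timeout_s
    while True:
        status = fetch_status()
        if status == target:
            return status
        if clock() >= deadline:
            raise TimeoutError(
                f"model status is {status!r}, expected {target!r}")
        sleep(interval_s)
```

The same helper can be used after the upload step with target set to saved. The sleep and clock parameters exist only so that the loop can be exercised without real waiting.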
Run online inference.
Run the following statement to perform an online inference task:
/*polar4ai*/ SELECT Y FROM PREDICT(MODEL my_model, SELECT * FROM db4ai.regression_test LIMIT 10) WITH (x_cols = 'x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19,x20,x21,x22,x23,x24,x25,x26,x27,x28', y_cols='');
The following result is returned:
+------+---------------------+
| Y    | predicted_results   |
+------+---------------------+
| 1.0  | 0.6262147669037363  |
| 0.0  | 0.5082804008241021  |
| 0.0  | 0.37533158372209957 |
| 1.0  | 0.461974928099089   |
| 0.0  | 0.3777339456553666  |
| 0.0  | 0.35045096227525735 |
| 0.0  | 0.4178165504012342  |
| 1.0  | 0.40869795422774036 |
| 1.0  | 0.6826481286570045  |
| 0.0  | 0.47021259543154736 |
+------+---------------------+
10 rows in set (0.95 sec)
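The x_cols value in the statement above lists 28 feature columns by hand. When columns follow a regular naming pattern, as x1 through x28 do here, the list and the statement can be generated instead. The sketch below only produces the statement text; feature_columns and build_predict_statement are hypothetical helper names, and the statement layout simply mirrors the example above.

```python
def feature_columns(count, prefix="x"):
    """Return 'x1,x2,...,xN' for the given column count and prefix."""
    return ",".join(f"{prefix}{i}" for i in range(1, count + 1))


def build_predict_statement(model_name, source_query, x_cols,
                            y_cols="", select_col="Y"):
    """Compose a PREDICT statement in the form shown above."""
    return (
        f"/*polar4ai*/ SELECT {select_col} FROM PREDICT(MODEL {model_name}, "
        f"{source_query}) WITH (x_cols = '{x_cols}', y_cols='{y_cols}');"
    )
```

For example, build_predict_statement("my_model", "SELECT * FROM db4ai.regression_test LIMIT 10", feature_columns(28)) yields the same text as the statement used in this step.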