Training and Deploying the DropoutNet Model
For an introduction to the cold-start DropoutNet algorithm, see the article 《冷启动推荐模型DropoutNet深度解析与改进》 (an in-depth analysis and improvement of the DropoutNet cold-start recommendation model).
Preparing Offline Training Samples
Use a template to generate the SQL that builds the offline training samples.
Template configuration:
{
"cold_start_recall": {
"model_name": "cold_start",
"model_type": "dropoutnet",
"label": {
"name": "is_click",
"selection": "max(if(event=\"click\", 1, 0))",
"type": "CLASSIFICATION"
},
"train_days": 14
}
}
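For reference, the template expands into sample-generation SQL along these lines; this is a minimal sketch, and the event-log table user_event_log, its columns, and the date placeholder are hypothetical:
-- Minimal sketch of the generated sample SQL (hypothetical log table/columns).
-- The label is the configured "selection" expression, aggregated per
-- (user, item) pair over the last train_days = 14 days of logs.
INSERT OVERWRITE TABLE dwd_samples_for_dropoutnet PARTITION (dt = '${bizdate}')
SELECT user_id,
       item_id,
       MAX(IF(event = 'click', 1, 0)) AS is_click
FROM user_event_log
WHERE dt > '${bizdate_minus_14}' AND dt <= '${bizdate}'  -- hypothetical date placeholder
GROUP BY user_id, item_id;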
Training the DropoutNet Model
Train the model with the following PAI command:
pai -name easy_rec_ext
-Dcmd='train'
-Dconfig='oss://${bucket}/EasyRec/sv_dropout_net/sv_dropoutnet.config'
-Dtrain_tables='odps://${project}/tables/dwd_samples_for_dropoutnet/dt=${bizdate}'
-Deval_tables='odps://${project}/tables/dwd_sv_cold_start_samples/dt=${bizdate}'
-Dboundary_table='odps://${project}/tables/cold_start_feature_binning'
-Dmodel_dir='oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}'
-Dedit_config_json='{"train_config.fine_tune_checkpoint":"oss://${bucket}/EasyRec/sv_dropout_net/${yesterday}/"}'
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-Dcluster='{
\"ps\": {
\"count\" : 1,
\"cpu\" : 800
},
\"worker\" : {
\"count\" : 9,
\"cpu\" : 800
}
}';
Splitting the Model into a User Embedding Sub-model and an Item Embedding Sub-model
pai -name tensorflow1120_cpu_ext
-Dscript='oss://${bucket}/EasyRec/sv_dropout_net/split_model_pai_v2.py'
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-DuserDefinedParameters='--model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/final --user_model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/user --item_model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/item';
Deploying the Model Service
Deployment script:
bizdate=$1
cat << EOF > eas_config.json
{
"name": "sv_dropoutnet",
"metadata": {
"cpu": 2,
"instance": 1,
"memory": 6000
},
"processor": "tensorflow_cpu",
"model_path": "oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/item/"
}
EOF
# Create the service
/home/admin/usertools/tools/eascmd \
-i ${accessId} \
-k ${accessKey} \
-e pai-eas.cn-beijing.aliyuncs.com create eas_config.json
# Update the service
#/home/admin/usertools/tools/eascmd \
# -i ${accessId} \
# -k ${accessKey} \
# -e pai-eas.cn-beijing.aliyuncs.com \
# modify sv_dropoutnet -s eas_config.json
# Describe the service
echo "------------------- Describe the service -------------------"
/home/admin/usertools/tools/eascmd \
-i ${accessId} \
-k ${accessKey} \
-e pai-eas.cn-beijing.aliyuncs.com desc sv_dropoutnet
Computing Real-time Features
1. Consume the Item Table Binlog in Flink
In Flink, create a source table over the Hologres item table's binlog, then create a view of newly created items.
create TEMPORARY table item_table_binlog (
hg_binlog_lsn BIGINT,
hg_binlog_event_type BIGINT,
hg_binlog_timestamp_us BIGINT,
itemId bigint,
...
createTime TIMESTAMP,
ets AS TO_TIMESTAMP(FROM_UNIXTIME(hg_binlog_timestamp_us / 1000000)),
WATERMARK FOR ets AS ets - INTERVAL '5' MINUTE
) with (
'connector'='hologres',
'endpoint' = 'hgpostcn-cn-XXXXXX-cn-beijing-vpc.hologres.aliyuncs.com:80',
'username' = '${username}',
'password' = '${password}',
'dbname' = '${dbname}',
'tablename' = 'item_table',
'binlog' = 'true',
'binlogMaxRetryTimes' = '10',
'binlogRetryIntervalMs' = '500',
'binlogBatchReadSize' = '256',
'startTime' = '2022-03-03 00:00:00'
);
CREATE TEMPORARY VIEW IF NOT EXISTS new_item_view
AS
SELECT itemId, ..., createTime,
PROCTIME() AS proc_time, ets
FROM item_table_binlog
WHERE hg_binlog_event_type IN (5, 7) --INSERT=5, AFTER_UPDATE=7
AND createTime >= CURRENT_TIMESTAMP - INTERVAL '24' HOUR
;
2. Join Item Features
First create the item feature table in Hologres, then create a temporary table for it in Flink to serve as the sink target.
create TEMPORARY table item_cold_start_feature (
itemId bigint,
...
update_time bigint
) with (
'connector'='hologres',
'dbname'='${dbname}',
'tablename'='sv_rec.sv_cold_start_feature',
'username'='${user_name}',
'password'='${password}',
'endpoint'='hgpostcn-cn-xxxxxxxxxx-cn-beijing-vpc.hologres.aliyuncs.com:80',
'mutatetype'='insertorupdate'
);
INSERT INTO item_cold_start_feature
SELECT
v.itemId,
v.userId AS author,
s.primaryId AS primary_type,
v.title,
TIMESTAMPDIFF(DAY, v.createTime, LOCALTIMESTAMP) AS pub_days,
v.duration,
v.sourceType as source_type,
v.inTimeOrNot as intimeornot,
v.is_prop,
COALESCE(s.gradeScore, v.gradeScore) AS grade_score,
v.width,
v.height,
v.firstPublishSongOrNot AS is_first_publish_song,
COALESCE(v.topic_id, '') as topic_id,
t.cate_name1,
t.cate_name2,
t.video_tags,
au.author_gender,
au.author_level,
au.author_is_member,
au.author_city,
au.author_type,
au.author_fans_num,
au.author_visitor_num,
au.author_billboard_num,
au.author_av_ct,
au.author_sv_ct,
au.author_play_ct,
au.author_play_avg_ct,
au.author_like_ct,
au.author_download_ct,
au.family_hot_ranking,
au.author_diamond_ct,
au.author_flower_ct,
CAST(STR_TO_MAP(au.author_sv_type_play_ct_1, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_1,
CAST(STR_TO_MAP(au.author_sv_type_play_ct_7, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_7,
CAST(STR_TO_MAP(au.author_sv_type_play_ct_15, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_15,
au.author_play_ct_1,
au.author_play_ct_7,
au.author_play_ct_15,
au.author_like_ct_1,
au.author_like_ct_7,
au.author_like_ct_15,
au.author_comment_ct_1,
au.author_comment_ct_7,
au.author_comment_ct_15,
au.author_share_ct_1,
au.author_share_ct_7,
au.author_share_ct_15,
au.author_tags,
TIMESTAMPDIFF(DAY, au.author_last_live_time, LOCALTIMESTAMP) AS author_last_live_days,
UNIX_TIMESTAMP() as update_time,
t.name_embedding,
t.tag_embedding
FROM new_item_view AS v
LEFT JOIN author_feature FOR SYSTEM_TIME AS OF v.proc_time as au
ON v.userId = au.author_id
LEFT JOIN smart_video_sign FOR SYSTEM_TIME AS OF v.proc_time as s
ON v.smartVideoId = s.svid
LEFT JOIN video_name_tag_embedding FOR SYSTEM_TIME AS OF v.proc_time as t
ON v.smartVideoId = t.svid
;
3. Generate Embeddings for New Items
Create the Hologres table that stores the item embeddings, and a Flink temporary table over it as the sink target.
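Since these embeddings are later retrieved by approximate inner product, the underlying Hologres table needs a Proxima vector index. A possible DDL sketch follows; the index parameters are illustrative and should be verified against the Hologres documentation:
-- Hypothetical Hologres DDL; column types mirror the Flink sink below, and the
-- Proxima index settings are illustrative, not authoritative.
CREATE TABLE sv_rec.sv_dropoutnet_embedding (
    itemId      bigint PRIMARY KEY,
    embedding   float4[],
    update_time bigint
);
CALL set_table_property('sv_rec.sv_dropoutnet_embedding', 'proxima_vectors',
    '{"embedding": {"algorithm": "Graph", "distance_method": "InnerProduct"}}');
The corresponding Flink temporary sink table: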
create TEMPORARY table item_dropoutnet_embedding (
itemId bigint,
embedding ARRAY<FLOAT>,
update_time bigint
) with (
'connector'='hologres',
'dbname'='${dbname}',
'tablename'='sv_rec.sv_dropoutnet_embedding',
'username'='${username}',
'password'='${password}',
'endpoint'='hgpostcn-cn-xxxxxxxxxxxx-cn-beijing-vpc.hologres.aliyuncs.com:80',
'mutatetype'='insertorreplace',
'field_delimiter'=','
);
Develop a UDF that calls the DropoutNet model's EAS service, then invoke it in Flink SQL to generate item embeddings in real time and write them to Hologres for online serving.
INSERT INTO item_dropoutnet_embedding
SELECT
f.svid,
InvokeEasUdf(
'sv_dropoutnet',
'${endpoint}',
'${token}',
f.primary_type,
f.title,
f.pub_days,
f.duration,
f.source_type,
f.intimeornot,
f.is_prop,
f.grade_score,
f.width,
f.height,
f.is_first_publish_song,
f.topic_id,
COALESCE(t.cate_name1, f.cate_name1),
COALESCE(t.cate_name2, f.cate_name2),
COALESCE(t.video_tags, f.video_tags),
f.author_gender,
f.author_level,
f.author_is_member,
f.author_city,
f.author_type,
f.author_fans_num,
f.author_visitor_num,
f.author_billboard_num,
f.author_av_ct,
f.author_sv_ct,
f.author_play_ct,
f.author_play_avg_ct,
f.author_like_ct,
f.author_download_ct,
f.family_hot_ranking,
f.author_diamond_ct,
f.author_flower_ct,
f.author_sv_type_play_ct_1,
f.author_sv_type_play_ct_7,
f.author_sv_type_play_ct_15,
f.author_play_ct_1,
f.author_play_ct_7,
f.author_play_ct_15,
f.author_like_ct_1,
f.author_like_ct_7,
f.author_like_ct_15,
f.author_comment_ct_1,
f.author_comment_ct_7,
f.author_comment_ct_15,
f.author_share_ct_1,
f.author_share_ct_7,
f.author_share_ct_15,
f.author_tags,
f.author_last_live_days,
COALESCE(t.name_embedding, f.name_embedding),
COALESCE(t.tag_embedding, f.tag_embedding)
) as embedding,
UNIX_TIMESTAMP() as update_time
FROM video_name_tag_embedding_hi as t
JOIN sv_cold_start_feature FOR SYSTEM_TIME AS OF t.proc_time as f
ON t.svid = f.svid AND t.hg_binlog_event_type IN (5, 7);
Example Flink UDF code for calling the EAS service:
package com.alibaba.pairec.udf;
import com.aliyun.openservices.eas.predict.http.HttpConfig;
import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.request.TFDataType;
import com.aliyun.openservices.eas.predict.request.TFRequest;
import com.aliyun.openservices.eas.predict.response.TFResponse;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.log4j.Logger;
import java.util.*;
import java.util.stream.Collectors;
public class InvokeEasUdf extends ScalarFunction {
private volatile static PredictClient client;
private static final Logger logger = Logger.getLogger(InvokeEasUdf.class);
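    // Lazily create one shared PredictClient via double-checked locking so all
    // UDF invocations in this process reuse a single HTTP client.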
public static PredictClient getClient(String modelName, String endpoint, String token) {
if (null == client) {
synchronized (InvokeEasUdf.class) {
if (null == client) {
client = new PredictClient(new HttpConfig());
client.setToken(token);
client.setEndpoint(endpoint);
client.setModelName(modelName);
client.setIsCompressed(false);
}
}
}
return client;
}
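    // Build the TFRequest: each feature is fed as a length-1 tensor; null
    // numeric values default to 0 and null strings to "".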
public static TFRequest buildPredictRequest(
Long primary_type,
String title,
Long pub_days,
Double duration,
Long source_type,
Long intimeornot,
Long is_prop,
Long grade_score,
Long width,
Long height,
Long is_first_publish_song,
String topic_id,
String cate_name1,
String cate_name2,
String video_tags,
Long author_gender,
Long author_level,
Long author_is_member,
String author_city,
String author_type,
Long author_fans_num,
Long author_visitor_num,
Long author_billboard_num,
Long author_av_ct,
Long author_sv_ct,
Long author_play_ct,
Long author_play_avg_ct,
Long author_like_ct,
Long author_download_ct,
Long family_hot_ranking,
Long author_diamond_ct,
Long author_flower_ct,
Long author_sv_type_play_ct_1,
Long author_sv_type_play_ct_7,
Long author_sv_type_play_ct_15,
Long author_play_ct_1,
Long author_play_ct_7,
Long author_play_ct_15,
Long author_like_ct_1,
Long author_like_ct_7,
Long author_like_ct_15,
Long author_comment_ct_1,
Long author_comment_ct_7,
Long author_comment_ct_15,
Long author_share_ct_1,
Long author_share_ct_7,
Long author_share_ct_15,
String author_tags,
Long author_last_live_days,
String name_embedding,
String tag_embedding
) {
TFRequest request = new TFRequest();
request.setSignatureName("serving_default");
request.addFeed("author_av_ct",
TFDataType.DT_INT64,
new long[]{1},
new long[]{author_av_ct == null ? 0 : author_av_ct});
request.addFeed("author_billboard_num", TFDataType.DT_INT64, new long[]{1}, new long[]{author_billboard_num == null ? 0 : author_billboard_num});
request.addFeed("author_city", TFDataType.DT_STRING, new long[]{1}, new String[]{author_city == null ? "" : author_city});
request.addFeed("author_comment_ct_1", TFDataType.DT_INT64, new long[]{1}, new long[]{author_comment_ct_1 == null ? 0 : author_comment_ct_1});
request.addFeed("author_comment_ct_7", TFDataType.DT_INT64, new long[]{1}, new long[]{author_comment_ct_7 == null ? 0 : author_comment_ct_7});
request.addFeed("author_comment_ct_15", TFDataType.DT_INT64, new long[]{1}, new long[]{author_comment_ct_15 == null ? 0 : author_comment_ct_15});
request.addFeed("author_diamond_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_diamond_ct == null ? 0 : author_diamond_ct});
request.addFeed("author_download_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_download_ct == null ? 0 : author_download_ct});
request.addFeed("author_fans_num", TFDataType.DT_INT64, new long[]{1}, new long[]{author_fans_num == null ? 0 : author_fans_num});
request.addFeed("author_flower_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_flower_ct == null ? 0 : author_flower_ct});
request.addFeed("author_gender", TFDataType.DT_INT64, new long[]{1}, new long[]{author_gender == null ? 0 : author_gender});
request.addFeed("author_is_member", TFDataType.DT_INT64, new long[]{1}, new long[]{author_is_member == null ? 0 : author_is_member});
request.addFeed("author_last_live_days", TFDataType.DT_INT64, new long[]{1}, new long[]{author_last_live_days == null ? 0 : author_last_live_days});
request.addFeed("author_level", TFDataType.DT_INT64, new long[]{1}, new long[]{author_level == null ? 0 : author_level});
request.addFeed("author_like_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_like_ct == null ? 0 : author_like_ct});
request.addFeed("author_like_ct_1", TFDataType.DT_INT64, new long[]{1}, new long[]{author_like_ct_1 == null ? 0 : author_like_ct_1});
request.addFeed("author_like_ct_15", TFDataType.DT_INT64, new long[]{1}, new long[]{author_like_ct_15 == null ? 0 : author_like_ct_15});
request.addFeed("author_like_ct_7", TFDataType.DT_INT64, new long[]{1}, new long[]{author_like_ct_7 == null ? 0 : author_like_ct_7});
request.addFeed("author_play_avg_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_play_avg_ct == null ? 0 : author_play_avg_ct});
request.addFeed("author_play_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_play_ct == null ? 0 : author_play_ct});
request.addFeed("author_play_ct_1", TFDataType.DT_INT64, new long[]{1}, new long[]{author_play_ct_1 == null ? 0 : author_play_ct_1});
request.addFeed("author_play_ct_15", TFDataType.DT_INT64, new long[]{1}, new long[]{author_play_ct_15 == null ? 0 : author_play_ct_15});
request.addFeed("author_play_ct_7", TFDataType.DT_INT64, new long[]{1}, new long[]{author_play_ct_7 == null ? 0 : author_play_ct_7});
request.addFeed("author_share_ct_1", TFDataType.DT_INT64, new long[]{1}, new long[]{author_share_ct_1 == null ? 0 : author_share_ct_1});
request.addFeed("author_share_ct_15", TFDataType.DT_INT64, new long[]{1}, new long[]{author_share_ct_15 == null ? 0 : author_share_ct_15});
request.addFeed("author_share_ct_7", TFDataType.DT_INT64, new long[]{1}, new long[]{author_share_ct_7 == null ? 0 : author_share_ct_7});
request.addFeed("author_sv_ct", TFDataType.DT_INT64, new long[]{1}, new long[]{author_sv_ct == null ? 0 : author_sv_ct});
request.addFeed("author_sv_type_play_ct_1", TFDataType.DT_INT64, new long[]{1}, new long[]{author_sv_type_play_ct_1 == null ? 0 : author_sv_type_play_ct_1});
request.addFeed("author_sv_type_play_ct_15", TFDataType.DT_INT64, new long[]{1}, new long[]{author_sv_type_play_ct_15 == null ? 0 : author_sv_type_play_ct_15});
request.addFeed("author_sv_type_play_ct_7", TFDataType.DT_INT64, new long[]{1}, new long[]{author_sv_type_play_ct_7 == null ? 0 : author_sv_type_play_ct_7});
request.addFeed("author_tags", TFDataType.DT_STRING, new long[]{1}, new String[]{author_tags == null ? "" : author_tags});
request.addFeed("author_type", TFDataType.DT_STRING, new long[]{1}, new String[]{author_type == null ? "" : author_type});
request.addFeed("author_visitor_num", TFDataType.DT_INT64, new long[]{1}, new long[]{author_visitor_num == null ? 0 : author_visitor_num});
request.addFeed("cate_name1", TFDataType.DT_STRING, new long[]{1}, new String[]{cate_name1 == null ? "" : cate_name1});
request.addFeed("cate_name2", TFDataType.DT_STRING, new long[]{1}, new String[]{cate_name2 == null ? "" : cate_name2});
request.addFeed("duration", TFDataType.DT_DOUBLE, new long[]{1}, new double[]{duration == null ? 0 : duration});
request.addFeed("family_hot_ranking", TFDataType.DT_INT64, new long[]{1}, new long[]{family_hot_ranking == null ? 0 : family_hot_ranking});
request.addFeed("grade_score", TFDataType.DT_INT64, new long[]{1}, new long[]{grade_score == null ? 0 : grade_score});
request.addFeed("height", TFDataType.DT_INT64, new long[]{1}, new long[]{height == null ? 0 : height});
request.addFeed("intimeornot", TFDataType.DT_INT64, new long[]{1}, new long[]{intimeornot == null ? 0 : intimeornot});
request.addFeed("is_first_publish_song", TFDataType.DT_INT64, new long[]{1}, new long[]{is_first_publish_song == null ? 0 : is_first_publish_song});
request.addFeed("is_prop", TFDataType.DT_INT64, new long[]{1}, new long[]{is_prop == null ? 0 : is_prop});
request.addFeed("primary_type", TFDataType.DT_INT64, new long[]{1}, new long[]{primary_type == null ? 0 : primary_type});
request.addFeed("pub_days", TFDataType.DT_INT64, new long[]{1}, new long[]{pub_days == null ? 0 : pub_days});
request.addFeed("source_type", TFDataType.DT_INT64, new long[]{1}, new long[]{source_type == null ? 0 : source_type});
request.addFeed("title", TFDataType.DT_STRING, new long[]{1}, new String[]{title == null ? "" : title});
request.addFeed("topic_id", TFDataType.DT_STRING, new long[]{1}, new String[]{topic_id == null ? "" : topic_id});
request.addFeed("video_tags", TFDataType.DT_STRING, new long[]{1}, new String[]{video_tags == null ? "" : video_tags});
request.addFeed("width", TFDataType.DT_INT64, new long[]{1}, new long[]{width == null ? 0 : width});
request.addFeed("name_embedding", TFDataType.DT_STRING, new long[]{1}, new String[]{name_embedding == null ? "" : name_embedding});
request.addFeed("tag_embedding", TFDataType.DT_STRING, new long[]{1}, new String[]{tag_embedding == null ? "" : tag_embedding});
request.addFetch("item_emb");
return request;
}
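    // Shut down the shared client when this UDF instance is reclaimed.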
protected void finalize() {
if (null != client) {
client.shutdown();
}
}
public List<Float> eval(String modelName, String endpoint, String token,
Long primary_type,
String title,
Long pub_days,
Double duration,
Long source_type,
Long intimeornot,
Long is_prop,
Long grade_score,
Long width,
Long height,
Long is_first_publish_song,
String topic_id,
String cate_name1,
String cate_name2,
String video_tags,
Long author_gender,
Long author_level,
Long author_is_member,
String author_city,
String author_type,
Long author_fans_num,
Long author_visitor_num,
Long author_billboard_num,
Long author_av_ct,
Long author_sv_ct,
Long author_play_ct,
Long author_play_avg_ct,
Long author_like_ct,
Long author_download_ct,
Long family_hot_ranking,
Long author_diamond_ct,
Long author_flower_ct,
Long author_sv_type_play_ct_1,
Long author_sv_type_play_ct_7,
Long author_sv_type_play_ct_15,
Long author_play_ct_1,
Long author_play_ct_7,
Long author_play_ct_15,
Long author_like_ct_1,
Long author_like_ct_7,
Long author_like_ct_15,
Long author_comment_ct_1,
Long author_comment_ct_7,
Long author_comment_ct_15,
Long author_share_ct_1,
Long author_share_ct_7,
Long author_share_ct_15,
String author_tags,
Long author_last_live_days,
String name_embedding,
String tag_embedding
) {
PredictClient predictor = getClient(modelName, endpoint, token);
TFRequest request = buildPredictRequest(
primary_type,
title,
pub_days,
duration,
source_type,
intimeornot,
is_prop,
grade_score,
width,
height,
is_first_publish_song,
topic_id,
cate_name1,
cate_name2,
video_tags,
author_gender,
author_level,
author_is_member,
author_city,
author_type,
author_fans_num,
author_visitor_num,
author_billboard_num,
author_av_ct,
author_sv_ct,
author_play_ct,
author_play_avg_ct,
author_like_ct,
author_download_ct,
family_hot_ranking,
author_diamond_ct,
author_flower_ct,
author_sv_type_play_ct_1,
author_sv_type_play_ct_7,
author_sv_type_play_ct_15,
author_play_ct_1,
author_play_ct_7,
author_play_ct_15,
author_like_ct_1,
author_like_ct_7,
author_like_ct_15,
author_comment_ct_1,
author_comment_ct_7,
author_comment_ct_15,
author_share_ct_1,
author_share_ct_7,
author_share_ct_15,
author_tags,
author_last_live_days,
name_embedding,
tag_embedding
);
TFResponse response;
try {
response = predictor.predict(request);
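            // The model returns the embedding as a single comma-separated
            // string under the "item_emb" output.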
List<String> result = response.getStringVals("item_emb");
String embedding = result.get(0);
String[] emb = embedding.split(",");
return Arrays.stream(emb).map(Float::valueOf).collect(Collectors.toList());
} catch (Exception e) {
    logger.error("call EAS failed: " + e.getMessage(), e);
    return Collections.emptyList();
}
}
public static void main(String[] args) {
InvokeEasUdf udf = new InvokeEasUdf();
List<Float> emb = udf.eval("sv_dropoutnet",
"1103287870424018.cn-beijing.pai-eas.aliyuncs.com",
"NDg4OGIwZGU2MjAzNzljMGZkNjI2ZWUxZWEzZjM4ZGYyNmU2ZWVmZA==",
90L,
"#2021\u001D演\u001D技\u001D大\u001D赏\u001D",
0L, 72800.0, 4L, 0L, 0L, 5L,
576L, 1024L, 1L, "97388",
"音乐", "歌曲", "美女\u001D歌曲\u001D音乐",
0L, 6L, 1L, "", "", 0L, 3L,
0L, 0L, 0L, 2L, 6L, 2L,
0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 1L, 1L, 1L,
0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, "", 0L, "", ""
);
System.out.println(emb);
}
}
The following Maven dependencies are required:
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_2.12</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-java</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-streaming-java_2.12</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-common</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table</artifactId>
<version>${flink.version}</version>
<type>pom</type>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<dependency>
<groupId>com.aliyun.openservices.eas</groupId>
<artifactId>eas-sdk</artifactId>
<version>2.0.3</version>
</dependency>
</dependencies>
Preparing User Embedding Vectors
Compute the user features offline, then use the user sub-model split out in the earlier step to generate the user embedding vectors:
pai -name easy_rec_ext
-Dcmd='predict'
-Dconfig='oss://${bucket}/EasyRec/sv_dropout_net/sv_dropoutnet.config'
-Doutput_table='odps://${project}/tables/dropoutnet_user_embedding/dt=${bizdate}'
-Dinput_table='odps://${project}/tables/dropoutnet_user_features/dt=${bizdate}'
-Dsaved_model_dir='oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/final'
-Dreserved_cols="userid"
-Doutput_cols="user_emb string"
-Dmodel_outputs="user_emb"
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-Dcluster='{
\"worker\" : {
\"count\" : 8,
\"cpu\" : 600
}
}';
Finally, the user embedding vectors need to be imported into Hologres.
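One way to do the import is through Hologres's MaxCompute foreign-table mechanism; a sketch, in which the Hologres serving table name is an assumption:
-- Map the MaxCompute output table into Hologres (odps_server is the built-in
-- foreign server), then copy the current partition into the serving table.
IMPORT FOREIGN SCHEMA ${project} LIMIT TO (dropoutnet_user_embedding)
    FROM SERVER odps_server INTO public OPTIONS (if_table_exist 'update');
INSERT INTO sv_rec.sv_dropoutnet_user_embedding  -- hypothetical serving table
SELECT userid, user_emb
FROM public.dropoutnet_user_embedding
WHERE dt = '${bizdate}';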
Retrieving the Top N Items as Recall Results
In the recommendation service, use the vector retrieval engine (Hologres) to find the Top N items whose embeddings are closest to the user embedding vector.
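// GetRetrieveSql builds a Hologres query that scores items by approximate
// inner product against the user embedding and keeps the top recallCount.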
func (r *ItemColdStartRecall) GetRetrieveSql(userEmb string) (string, []interface{}) {
sb := sqlbuilder.PostgreSQL.NewSelectBuilder()
vecIndex := sb.Args.Add(userEmb)
dotProduct := fmt.Sprintf("pm_approx_inner_product_distance(%s,%s)", r.VectorEmbeddingField, vecIndex)
sb.Select(r.VectorKeyField, sb.As(dotProduct, "distance"))
sb.From(r.VectorTable)
sb.OrderBy("distance").Desc()
sb.Limit(r.recallCount)
return sb.Build()
}
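For reference, the builder emits a query of roughly the following shape, with the user embedding bound as a statement argument; the table and field names below reuse the embedding table from earlier and are only an example:
SELECT itemId, pm_approx_inner_product_distance(embedding, $1) AS distance
FROM sv_rec.sv_dropoutnet_embedding
ORDER BY distance DESC
LIMIT 100;  -- recallCount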