全部产品
云市场

推荐召回算法:二部图GraphSAGE

更新时间:2020-05-21 11:17:27

背景

图神经网络是目前深度学习热点的发展方向,PAI团队在前不久开源了graph-learn框架( https://github.com/alibaba/graph-learn ),提供了大量常见的图学习算法。二部图GraphSAGE是经典的图神经网络算法GraphSAGE 在二部图场景下的扩展,被用于淘宝内部推荐的召回场景。目前该算法已经在PAI-Studio正式上线。

在二部图场景下,对于user, item的交互数据,user和item可以看成图中的点,user-item之间的关系(点击,购买等)可以当作图中的边,因此整体的user, item的关系可以抽象为一个二部图。对于user和item,采样邻居时分别按照user-item-user-item.., item-user-item-user…的meta-path进行采样。

参数设置说明

输入为3个表,从左到右分别是user&item的行为表,user特征表,item特征表。

行为表说明

  • user:bigint型,user的ID
  • item:bigint型,item的ID
  • weight:double型,表明行为,比如1是购买,2是收藏等

User表说明

  • user:bigint型,user的ID
  • feature:String型,user的特征,每个特征以冒号分割,0元特征不可以省略。feature必须是float类型,会按照连续特征进行处理

Item表说明

  • item:bigint型,item的ID
  • feature:String型,item的特征,每个特征以冒号分割,0元特征不可以省略。feature必须是float类型,会按照连续特征进行处理

输出结果

输出结果为两个表,分别是User向量表和Item向量表,可以用于推荐的召回场景

PAI命令

PAI命令

  1. pai -name bipartite_graphsage_ext
  2. -project algo_public
  3. -Dps_count=2
  4. -Dps_memory=20000
  5. -Dworker_count=2
  6. -Dworker_memory=20000
  7. -Dui_edge_table='u2i_edge'
  8. -Du_node_table='u2i_node'
  9. -Di_node_table='u2i_node_1'
  10. -Du_emb_table='u_emb'
  11. -Di_emb_table='i_emb'
  12. -Depoch=2
  13. -Dbatch_size=512
  14. -Dlearning_rate='0.001'
  15. -Ddrop_out='0.5'
  16. -Dhidden_dim=128
  17. -Doutput_dim=128
  18. -Du_features_num=1
  19. -Di_features_num=1
  20. -Du_neighs_num='[10,5]'
  21. -Di_neighs_num='[10]'
  22. -Dneg_num=5
  23. -Dagg_type='gcn'

参数

参数名 类型 解释
worker_count int TF worker(graph-learn client)数
ps_count int TF ps(graph-learn server)数,推荐worker_count = ps_count
worker_mem int worker内存
ps_mem int ps内存
ui_edge_table 输入的ui边表
u_node_table 输入的u的属性表
i_node_table 输入的i的属性表
u_emb_table 输出的u的embedding表
i_emb_table 输出的i的embedding表
epoch int 训练的epcoh数
batch_size int 训练的batch size
learning_rate float 训练的学习率
drop_out float drop out率
hidden_dim int 隐层的维数
output_dim int 最后输出embedding的维数
u_features_num int u的总特征数
i_features_num int i的总特征数
u_neighs_num string, 内容是list例如’[10,2]’ u的每一跳的邻居数。’[10, 2]’表示第一跳采样10个邻居,第二跳采样2个邻居。
i_neighs_num string, 内容是list例如’[10]’ i的每一跳的邻居数
neg_num int(可选),默认5 负采样数目(一条正样本对应的负样本数目)
agg_type string(可选),默认’gcn’ 聚合类型,’gcn’, ‘sum’ or ‘mean’