全部产品
云市场

TensorFlow使用指南

更新时间:2020-01-16 10:03:43

目录

OSS上传数据说明

使用深度学习处理数据时,数据先存储到 OSS 的 Bucket 中。第一步要创建OSS Bucket。 GPU的计算集群区域和OSS所在区域需要相同。这样在数据传输时就可以使用阿里云经典网络,算法运行时不需要收取流量费用。Bucket 创建好之后,可以在OSS管理控制台 创建文件夹、组织数据目录、上传数据。如何创建OSS Bucket请参考:OSS开通说明文档

OSS 支持多种方式上传数据, API 或 SDK 详细见:https://help.aliyun.com/document_detail/31848.html?spm=5176.doc31848.6.580.a6es2a

OSS 还提供了大量的常用工具用来帮助用户更加高效的使用 OSS。工具列表请参见: https://help.aliyun.com/document_detail/44075.html?spm=5176.doc32184.6.1012.XlMMUx

建议您使用 ossutil 或 osscmd 这两个命令行工具,通过命令的方式来上传、下载文件,还支持断点续传。

注意:在使用工具时需要配置 AccessID 和 AccessKey,登录后,可以在Access Key 管理控制台创建或查看。

读OSSBucket

用户在机器学习平台中使用“读OSS Bucket”组件时,需要授予一个名称为“AliyunODPSPAIDefaultRole” 的系统默认角色给数加的服务账号,当且仅当该角色被正确授权后,机器学习平台的算法才能正确地读、写OSS bucket。

注意:由于机器学习平台运行在MaxCompute框架之上,与MaxCompute共用服务账号。在授权时,默认的角色授予给MaxCompute服务账号。

RAM 授权可以使机器学习平台获得OSS的访问权限。在设置菜单完成对OSS读写权限的授权,详情见RAM授权。

RAM 授权

  1. 进入机器学习控制台,单击左侧菜单栏的设置,选择基本设置
  2. OSS访问授权中勾选授权机器学习读取我的OSS中的数据
  3. 进入如下界面,单击点击前往RAM进行授权,进入RAM入口,如下图所示。

  4. 进入如下界面,单击同意授权

    注意:如果您想查看“AliyunODPSPAIDefaultRole”的相关详细策略信息,可以登录RAM控制台来查看。 默认角色“AliyunODPSPAIDefaultRole”包含的权限信息如下。

    权限名称(Action)权限说明
    oss:PutObject上传文件或文件夹对象
    oss:GetObject获取文件或文件夹对象
    oss:ListObjects查询文件列表信息
    oss:DeleteObjects删除对象
  5. 返回机器学习界面,单击刷新。RAM信息会自动录入组件中,如下图所示。

  1. 使用深度学习框架。将读OSSBucket组件与相应的深度学习组件连接,用来获得OSS的读写权限。

TensorFlow读取OSS数据方法说明

低效的IO方式

本地执行TensorFlow代码和分布式云端执行TensorFlow的区别:

  • 本地读取数据:Server端直接从Client端获得graph进行计算。
  • 云端服务:Server在获得graph之后还需要将计算下发到各个Worker处理(具体原理可以参考视频教程-Tensorflow高级篇)。

本文档通过读取一个简单的CSV文件为例,帮您快速了解如何使用TensorFlow高效地读取数据。CSV文件如下:

  1. 1,1,1,1,1
  2. 2,2,2,2,2
  3. 3,3,3,3,3

容易产生问题的几个地方:

  • 不建议使用python本地读取文件的方式

    机器学习平台支持python的自带IO方式,但是需要将数据源和代码打包上传。这种读取方式是将数据写入内存之后再计算,效率比较低,不建议使用。示例代码如下。

    1. import csv
    2. csv_reader=csv.reader(open('csvtest.csv'))
    3. for row in csv_reader:
    4. print(row)
  • 不建议使用第三方库读取文件的方式

    通过第三方库(比如TFLearn、Panda)的一些数据IO的方式读取数据,是通过封装python的读取方式实现的,所以在机器学习平台使用时也会造成效率低下的问题。

  • 不建议使用 preload 读取文件的方式

    很多用户在使用机器学习服务的时候,发现 GPU 并没有比本地的 CPU 速度快的明显,主要问题可能就出在数据IO这块。
    preload 方式是先把数据全部都读到内存中,然后再通过 session 计算,比如feed的读取方式。这样要先进行数据读取,再计算,不同步造成性能浪费。同时因为内存限制也无法支持大数据量的计算。
    例如:假设硬盘中有一个图片数据集 0001.jpg,0002.jpg,0003.jpg,…… ,我们只需要把它们读取到内存中,然后提供给 GPU 或 CPU 计算就可以了。但并没有那么简单。事实上,我们必须把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率。

高效的IO方式

高效的 TensorFlow 读取方式是将数据读取转换成 OP,通过 session run 的方式拉去数据。读取线程源源不断地将文件系统中的图片读入到一个内存的队列中,而负责计算的是另一个线程,计算需要数据时,直接从内存队列中取就可以了。这样就可以解决GPU因为IO而空闲的问题。

如下代码解释了如何在机器学习平台通过OP的方式读取数据。

  1. import argparse
  2. import tensorflow as tf
  3. import os
  4. FLAGS=None
  5. def main(_):
  6. dirname = os.path.join(FLAGS.buckets, "csvtest.csv")
  7. reader=tf.TextLineReader()
  8. filename_queue=tf.train.string_input_producer([dirname])
  9. key,value=reader.read(filename_queue)
  10. record_defaults=[[''],[''],[''],[''],['']]
  11. d1, d2, d3, d4, d5= tf.decode_csv(value, record_defaults, ',')
  12. init=tf.initialize_all_variables()
  13. with tf.Session() as sess:
  14. sess.run(init)
  15. coord = tf.train.Coordinator()
  16. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  17. for i in range(4):
  18. print(sess.run(d2))
  19. coord.request_stop()
  20. coord.join(threads)
  21. if __name__ == '__main__':
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument('--buckets', type=str, default='',
  24. help='input data path')
  25. parser.add_argument('--checkpointDir', type=str, default='',
  26. help='output model path')
  27. FLAGS, _ = parser.parse_known_args()
  28. tf.app.run(main=main)
  • dirname:OSS文件路径,可以是数组,方便下一阶段 shuffle。
  • reader:TF内置各种reader API,可以根据需求选用。
  • tf.train.string_input_producer:将文件生成队列。
  • tf.decode_csv:是一个splite功能的OP,可以得到每一行的特定参数。
  • 通过OP获取数据,在session中需要tf.train.Coordinator()和tf.train.start_queue_runners(sess=sess,coord=coord)。

在代码中,输入得是3行5个字段:

  1. 1,1,1,1,1
  2. 2,2,2,2,2
  3. 3,3,3,3,3

循环输出4次,打印出第2个字段。结果如下图所示。

输出结果也证明了数据结构是成队列。

其它

  • 机器学习平台 Notebook 功能上线,支持在线修改代码并且内置各种深度学习框架,欢迎使用
  • 本文参考了互联网上《十图详解TensorFlow数据读取机制(附代码)》一文,关于图片的读取方式也可以参考这篇文章,感谢原作者。

https://help.aliyun.com/document_detail/52239.html?spm=5176.doc50656.6.562.dmyNaF)

Tensorflow读取MaxCompute表数据方法说明

PAI-Studio在支持OSS数据源的基础上,增加了对MaxCompute表的数据支持。用户可以直接使用PAI-Studio的Tensorflow组件读写MaxCompute数据,本教程将提供完整数据和代码供大家测试。

在开始实验之前,请先确保已开通了OSS存储服务用来存放训练代码。主账号如何开通和授权OSS,请参考:https://help.aliyun.com/document_detail/49571.html#h2-oss-4。

详细流程

为了方便用户快速上手,本文档将以训练iris数据集为例,介绍如何跑通实验。

1.读数据表组件

为了方便大家,我们提供了一份公共读的数据供大家测试,只要拖出读数据表组件,输入:

  1. pai_online_project.iris_data

即可获取数据,

数据格式如图:

2.Tensorflow组件说明

3个输入桩从左到右分别是OSS输入、MaxCompute输入、模型输入。2个输出桩分别是模型输出、MaxCompute输出。如果输入是一个MaxCompute表,输出也是一个MaxCompute表,需要按下图方法连接。

读写MaxCompute表需要配置数据源、代码文件、输出模型路径、建表等操作。

  • Python代码文件:需要把执行代码放到OSS路径下(注意OSS需要与当前项目在同一区域),本文提供的代码可以在下方连接下载(代码需要按照下方代码说明文案调整):http://docs-aliyun.cn-hangzhou.oss.aliyun-inc.com/assets/attach/129749/cn_zh/1565333220966/iristest.py?spm=a2c4g.11186623.2.10.50c46b36PlNwcq&file=iristest.py
  • Checkpoint输出目录/模型输入目录:选择自己的OSS路径用来存放模型
  • MaxCompute输出表:写MaxCompute表要求输出表是已经存在的表,并且输出的表名需要跟代码中的输出表名一致。在本案例中需要填写“iris_output”(不需要填写odps://xxx,仅需填写表名称)
  • 建表SQL语句:如果代码中的输出表并不存在,可以通过这个输入框输入建表语句自动建表。本案例中建表语句“create table iris_output(f1 DOUBLE,f2 DOUBLE,f3 DOUBLE,f4 DOUBLE,f5 STRING);”

组件PAI命令

  1. PAI -name tensorflow180_ext -project algo_public -Doutputs="odps://${当前项目名}/tables/${输出表名}" -DossHost="${OSS的host}" -Dtables="odps://${当前项目名}/tables/${输入表名}" -DgpuRequired="${GPU卡数}" -Darn="${OSS访问RoleARN}" -Dscript="${执行的代码文件}";

上述命令中的${}需要替换成用户真实数据

3.代码说明

  1. import tensorflow as tf
  2. tf.app.flags.DEFINE_string("tables", "", "tables info")
  3. FLAGS = tf.app.flags.FLAGS
  4. print("tables:" + FLAGS.tables)
  5. tables = [FLAGS.tables]
  6. filename_queue = tf.train.string_input_producer(tables, num_epochs=1)
  7. reader = tf.TableRecordReader()
  8. key, value = reader.read(filename_queue)
  9. record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Iris-virginica"]]
  10. col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults = record_defaults)
  11. # line 9 and 10 can be written like below for short. It will be helpful when too many columns exist.
  12. # record_defaults = [[1.0]] * 4 + [["Iris-virginica"]]
  13. # value_list = tf.decode_csv(value, record_defaults = record_defaults)
  14. writer = tf.TableRecordWriter("odps://{MaxCompute项目名}/tables/iris_output")
  15. write_to_table = writer.write([0, 1, 2, 3, 4], [col1, col2, col3, col4, col5])
  16. # line 16 can be written like below for short. It will be helpful when too many columns exist.
  17. # write_to_table = writer.write(range(5), value_list)
  18. close_table = writer.close()
  19. init = tf.global_variables_initializer()
  20. with tf.Session() as sess:
  21. sess.run(init)
  22. sess.run(tf.local_variables_initializer())
  23. coord = tf.train.Coordinator()
  24. threads = tf.train.start_queue_runners(coord=coord)
  25. try:
  26. step = 0
  27. while not coord.should_stop():
  28. step += 1
  29. sess.run(write_to_table)
  30. except tf.errors.OutOfRangeError:
  31. print('%d records copied' % step)
  32. finally:
  33. sess.run(close_table)
  34. coord.request_stop()
  35. coord.join(threads)

读数据表

  1. tables = [FLAGS.tables]
  2. filename_queue = tf.train.string_input_producer(tables, num_epochs=1)
  3. reader = tf.TableRecordReader()
  4. key, value = reader.read(filename_queue)
  5. record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Iris-virginica"]]

其中FLAGS.tables是前端配置的输入表名的传参变量,对应组件的MaxCompute输入桩:

在数据量较大时,尽量采用批量操作,减少op的执行次数。

写数据表

  1. writer = tf.TableRecordWriter("odps://{MaxCompute项目名}/tables/iris_output")
  2. write_to_table = writer.write([0, 1, 2, 3, 4], [col1, col2, col3, col4, col5])
  • TableRecordWriter中的格式为”odps://当前项目名/tables/输出表名”
  • 读写分区表写法:”odps://当前项目名/tables/输出表名/pt=1”

TensorFlow多机多卡使用说明

机器学习平台目前已经上线了支持多机、多卡、多PS Server的TensorFlow服务,目前只支持华北2 Region。华北2 Region因为支持多机多卡功能,适用于大规模数据的训练,相关服务需要收费,有需要的机构可以联系我们。

原理说明

  • Parameter Server 节点:用来存储 TensorFlow 计算过程中的参数。配置多个 PS节点,计算参数将会被自动切片并存储在不同的PS节点中,从而减小 Worker 和 PS 节点通信过程中的带宽限制的影响。
  • Worker 节点:“多机多卡”中的“机”,GPU卡的载体。
  • Task 节点:“多机多卡”中的“卡”,在机器学习中指的是 GPU卡,在 TensorFlow 训练过程中,通过数据切片将数据分布在不同的 Task 节点进行模型参数的训练。

使用说明

多机、多卡、多PS功能会以服务化的方式提供,用户无需关心底层计算资源的调度和运维,只需通过机器学习平台前端的简单配置,即可快速搭建整个分布式计算网络。具体的使用方式如下。

  1. 前端配置

    1. mnist_cluster.tar.gz文件下载并上传到 OSS。

    2. 配置深度学习的OSS读取权限。

    3. 拖拽任意版本TensorFlow组件并按照下图连接,并设置对应的代码数据源(Python代码文件设置 mnist_cluster.tar.gz 路径,Python 主文件填入 mnist_cluster.py)。

    4. 单击执行调优进行参数配置。

    5. 通过以上配置可以快速建立起如下图所示的多机多卡多PS计算网络结构,其中PS表示Parameter Server服务,WORKER 表示计算节点机器,TASK 表示具体执行计算的GPU卡。

  2. 代码端设置

    传统的TensorFlow多机多卡作业需要在代码端输入每个计算节点的对应端口信息,如下图所示。

    当计算节点数量增多时,这种端口信息的配置会非常复杂。机器学习平台优化了计算节点配置信息的功能,只需要以下两行代码即可自动在代码端获取计算节点信息。

    1. ps_hosts = FLAGS.ps_hosts.split(",")#框架层提供ps_hosts的端口
    2. worker_hosts = FLAGS.worker_hosts.split(",")#框架层提供worker_hosts的端口
  3. 运行日志查看

    1. 右键单击 TensorFlow 组件,查看日志。可以看到资源的分配情况,分配两个PS,两个WORKER

    2. 点击蓝色链接,可以在logview中查看对应每个worker的运行状态。

代码下载

https://help.aliyun.com/document_detail/64146.html

TensorFlow超参支持

Tensorflow 超参配置:在画布右侧的组件参数设置页面,配置文件超参及用户自定义参数中设置超参文件,文件为.txt格式,如下所示。

  1. batch_size=10
  2. learning_rate=0.01

在代码中可以通过如下方法引用超参。

  1. import tensorflow as tf
  2. tf.app.flags.DEFINE_string("learning_rate", "", "learning_rate")
  3. tf.app.flags.DEFINE_string("batch_size", "", "batch size")
  4. FAGS = tf.app.flags.FLAGS
  5. print("learning rate:" + FAGS.learning_rate)
  6. print("batch size:" + FAGS.batch_size)

TensorFlow支持的第三方库

TensorFlow1.0.0版本第三方库

  1. appdirs (1.4.3)
  2. backports-abc (0.5)
  3. backports.shutil-get-terminal-size (1.0.0)
  4. backports.ssl-match-hostname (3.5.0.1)
  5. bleach (2.0.0)
  6. boto (2.48.0)
  7. bz2file (0.98)
  8. certifi (2017.7.27.1)
  9. chardet (3.0.4)
  10. configparser (3.5.0)
  11. cycler (0.10.0)
  12. decorator (4.1.2)
  13. docutils (0.14)
  14. easygui (0.98.1)
  15. entrypoints (0.2.3)
  16. enum34 (1.1.6)
  17. funcsigs (1.0.2)
  18. functools32 (3.2.3.post2)
  19. gensim (2.3.0)
  20. h5py (2.7.0)
  21. html5lib (0.999999999)
  22. idna (2.6)
  23. iniparse (0.4)
  24. ipykernel (4.6.1)
  25. ipython (5.4.1)
  26. ipython-genutils (0.2.0)
  27. ipywidgets (7.0.0)
  28. Jinja2 (2.9.6)
  29. jsonschema (2.6.0)
  30. jupyter (1.0.0)
  31. jupyter-client (5.1.0)
  32. jupyter-console (5.1.0)
  33. jupyter-core (4.3.0)
  34. Keras (2.0.6)
  35. kitchen (1.1.1)
  36. langtable (0.0.31)
  37. MarkupSafe (1.0)
  38. matplotlib (2.0.2)
  39. mistune (0.7.4)
  40. mock (2.0.0)
  41. nbconvert (5.2.1)
  42. nbformat (4.4.0)
  43. networkx (1.11)
  44. nose (1.3.7)
  45. notebook (5.0.0)
  46. numpy (1.13.1)
  47. olefile (0.44)
  48. pandas (0.20.3)
  49. pandocfilters (1.4.2)
  50. pathlib2 (2.3.0)
  51. pbr (3.1.1)
  52. pexpect (4.2.1)
  53. pickleshare (0.7.4)
  54. Pillow (4.2.1)
  55. pip (9.0.1)
  56. prompt-toolkit (1.0.15)
  57. protobuf (3.1.0)
  58. ptyprocess (0.5.2)
  59. pycrypto (2.6.1)
  60. pycurl (7.19.0)
  61. Pygments (2.2.0)
  62. pygobject (3.14.0)
  63. pygpgme (0.3)
  64. pyliblzma (0.5.3)
  65. pyparsing (2.2.0)
  66. python-dateutil (2.6.1)
  67. pytz (2017.2)
  68. PyWavelets (0.5.2)
  69. pyxattr (0.5.1)
  70. PyYAML (3.12)
  71. pyzmq (16.0.2)
  72. qtconsole (4.3.1)
  73. requests (2.18.4)
  74. scandir (1.5)
  75. scikit-image (0.13.0)
  76. scikit-learn (0.19.0)
  77. scikit-sound (0.1.8)
  78. scikit-stack (3.0)
  79. scikit-surprise (1.0.3)
  80. scikit-tensor (0.1)
  81. scikit-video (0.1.2)
  82. scipy (0.19.1)
  83. setuptools (36.2.7)
  84. simplegeneric (0.8.1)
  85. singledispatch (3.4.0.3)
  86. six (1.10.0)
  87. slip (0.4.0)
  88. slip.dbus (0.4.0)
  89. smart-open (1.5.3)
  90. subprocess32 (3.2.7)
  91. tensorflow (1.0.0)
  92. terminado (0.6)
  93. testpath (0.3.1)
  94. tflearn (0.3.2)
  95. Theano (0.9.0)
  96. torch (0.1.12.post2)
  97. tornado (4.5.1)
  98. traitlets (4.3.2)
  99. urlgrabber (3.10)
  100. urllib3 (1.22)
  101. wcwidth (0.1.7)
  102. webencodings (0.5.1)
  103. wheel (0.29.0)
  104. widgetsnbextension (3.0.0)
  105. yum-langpacks (0.4.2)
  106. yum-metadata-parser (1.1.4)
  107. opencv-python (3.3.0.10)

TensorFlow1.1.0版本第三方库

  1. appdirs (1.4.3)
  2. backports-abc (0.5)
  3. backports.shutil-get-terminal-size (1.0.0)
  4. backports.ssl-match-hostname (3.5.0.1)
  5. bleach (2.0.0)
  6. boto (2.48.0)
  7. bz2file (0.98)
  8. certifi (2017.7.27.1)
  9. chardet (3.0.4)
  10. configparser (3.5.0)
  11. cycler (0.10.0)
  12. decorator (4.1.2)
  13. docutils (0.14)
  14. easygui (0.98.1)
  15. entrypoints (0.2.3)
  16. enum34 (1.1.6)
  17. funcsigs (1.0.2)
  18. functools32 (3.2.3.post2)
  19. gensim (2.3.0)
  20. h5py (2.7.1)
  21. html5lib (0.999999999)
  22. idna (2.6)
  23. iniparse (0.4)
  24. ipykernel (4.6.1)
  25. ipython (5.4.1)
  26. ipython-genutils (0.2.0)
  27. ipywidgets (7.0.0)
  28. Jinja2 (2.9.6)
  29. jsonschema (2.6.0)
  30. jupyter (1.0.0)
  31. jupyter-client (5.1.0)
  32. jupyter-console (5.2.0)
  33. jupyter-core (4.3.0)
  34. jupyter-tensorboard (0.1.1)
  35. Keras (2.0.8)
  36. kitchen (1.1.1)
  37. langtable (0.0.31)
  38. MarkupSafe (1.0)
  39. matplotlib (2.0.2)
  40. mistune (0.7.4)
  41. mock (2.0.0)
  42. nbconvert (5.3.0)
  43. nbformat (4.4.0)
  44. networkx (1.11)
  45. nose (1.3.7)
  46. notebook (4.4.1)
  47. numpy (1.13.1)
  48. olefile (0.44)
  49. pandas (0.20.3)
  50. pandocfilters (1.4.2)
  51. pathlib2 (2.3.0)
  52. pbr (3.1.1)
  53. pexpect (4.2.1)
  54. pickleshare (0.7.4)
  55. Pillow (4.2.1)
  56. pip (9.0.1)
  57. prompt-toolkit (1.0.15)
  58. protobuf (3.1.0)
  59. ptyprocess (0.5.2)
  60. pycrypto (2.6.1)
  61. pycurl (7.19.0)
  62. Pygments (2.2.0)
  63. pygobject (3.14.0)
  64. pygpgme (0.3)
  65. pyliblzma (0.5.3)
  66. pyparsing (2.2.0)
  67. python-dateutil (2.6.1)
  68. pytz (2017.2)
  69. PyWavelets (0.5.2)
  70. pyxattr (0.5.1)
  71. PyYAML (3.12)
  72. pyzmq (16.0.2)
  73. qtconsole (4.3.1)
  74. requests (2.18.4)
  75. scandir (1.5)
  76. scikit-image (0.13.0)
  77. scikit-learn (0.19.0)
  78. scikit-sound (0.1.8)
  79. scikit-stack (3.0)
  80. scikit-surprise (1.0.3)
  81. scikit-tensor (0.1)
  82. scikit-video (0.1.2)
  83. scipy (0.19.1)
  84. setuptools (36.4.0)
  85. simplegeneric (0.8.1)
  86. singledispatch (3.4.0.3)
  87. six (1.10.0)
  88. slip (0.4.0)
  89. slip.dbus (0.4.0)
  90. smart-open (1.5.3)
  91. subprocess32 (3.2.7)
  92. tensorflow (1.1.0)
  93. terminado (0.6)
  94. testpath (0.3.1)
  95. tflearn (0.3.2)
  96. Theano (0.9.0)
  97. torch (0.1.12.post2)
  98. tornado (4.5.2)
  99. traitlets (4.3.2)
  100. urlgrabber (3.10)
  101. urllib3 (1.22)
  102. wcwidth (0.1.7)
  103. webencodings (0.5.1)
  104. Werkzeug (0.12.2)
  105. wheel (0.29.0)
  106. widgetsnbextension (3.0.2)
  107. yum-langpacks (0.4.2)
  108. yum-metadata-parser (1.1.4)
  109. opencv-python (3.3.0.10)

TensorFlow1.2.1版本第三方库

  1. appdirs (1.4.3)
  2. backports-abc (0.5)
  3. backports.shutil-get-terminal-size (1.0.0)
  4. backports.ssl-match-hostname (3.5.0.1)
  5. backports.weakref (1.0rc1)
  6. bleach (1.5.0)
  7. boto (2.48.0)
  8. bz2file (0.98)
  9. certifi (2017.7.27.1)
  10. chardet (3.0.4)
  11. configparser (3.5.0)
  12. cycler (0.10.0)
  13. decorator (4.1.2)
  14. docutils (0.14)
  15. easygui (0.98.1)
  16. entrypoints (0.2.3)
  17. enum34 (1.1.6)
  18. funcsigs (1.0.2)
  19. functools32 (3.2.3.post2)
  20. gensim (2.3.0)
  21. h5py (2.7.1)
  22. html5lib (0.9999999)
  23. idna (2.6)
  24. iniparse (0.4)
  25. ipykernel (4.6.1)
  26. ipython (5.4.1)
  27. ipython-genutils (0.2.0)
  28. ipywidgets (7.0.0)
  29. Jinja2 (2.9.6)
  30. jsonschema (2.6.0)
  31. jupyter (1.0.0)
  32. jupyter-client (5.1.0)
  33. jupyter-console (5.2.0)
  34. jupyter-core (4.3.0)
  35. jupyter-tensorboard (0.1.1)
  36. Keras (2.0.8)
  37. kitchen (1.1.1)
  38. langtable (0.0.31)
  39. Markdown (2.6.9)
  40. MarkupSafe (1.0)
  41. matplotlib (2.0.2)
  42. mistune (0.7.4)
  43. mock (2.0.0)
  44. nbconvert (5.3.0)
  45. nbformat (4.4.0)
  46. networkx (1.11)
  47. nose (1.3.7)
  48. notebook (4.4.1)
  49. numpy (1.13.1)
  50. olefile (0.44)
  51. pandas (0.20.3)
  52. pandocfilters (1.4.2)
  53. pathlib2 (2.3.0)
  54. pbr (3.1.1)
  55. pexpect (4.2.1)
  56. pickleshare (0.7.4)
  57. Pillow (4.2.1)
  58. pip (9.0.1)
  59. prompt-toolkit (1.0.15)
  60. protobuf (3.1.0)
  61. ptyprocess (0.5.2)
  62. pycrypto (2.6.1)
  63. pycurl (7.19.0)
  64. Pygments (2.2.0)
  65. pygobject (3.14.0)
  66. pygpgme (0.3)
  67. pyliblzma (0.5.3)
  68. pyparsing (2.2.0)
  69. python-dateutil (2.6.1)
  70. pytz (2017.2)
  71. PyWavelets (0.5.2)
  72. pyxattr (0.5.1)
  73. PyYAML (3.12)
  74. pyzmq (16.0.2)
  75. qtconsole (4.3.1)
  76. requests (2.18.4)
  77. scandir (1.5)
  78. scikit-image (0.13.0)
  79. scikit-learn (0.19.0)
  80. scikit-sound (0.1.8)
  81. scikit-stack (3.0)
  82. scikit-surprise (1.0.3)
  83. scikit-tensor (0.1)
  84. scikit-video (0.1.2)
  85. scipy (0.19.1)
  86. setuptools (36.4.0)
  87. simplegeneric (0.8.1)
  88. singledispatch (3.4.0.3)
  89. six (1.10.0)
  90. slip (0.4.0)
  91. slip.dbus (0.4.0)
  92. smart-open (1.5.3)
  93. subprocess32 (3.2.7)
  94. tensorflow (1.2.1)
  95. terminado (0.6)
  96. testpath (0.3.1)
  97. tflearn (0.3.2)
  98. Theano (0.9.0)
  99. torch (0.1.12.post2)
  100. tornado (4.5.2)
  101. traitlets (4.3.2)
  102. urlgrabber (3.10)
  103. urllib3 (1.22)
  104. wcwidth (0.1.7)
  105. webencodings (0.5.1)
  106. Werkzeug (0.12.2)
  107. wheel (0.29.0)
  108. widgetsnbextension (3.0.2)
  109. yum-langpacks (0.4.2)
  110. yum-metadata-parser (1.1.4)
  111. opencv-python (3.3.0.10)

二部图GraphSage嵌入算法

本算法同时支持同构图和二部图的Graph Node Representation Learning算法。节点编码可带属性。二部图编码时,支持u-i-u和u-i-i类型的编码方式,其中u-i-i需要提供额外的i2i相似度表。

本算法集采样和学习为一体,不需要外部系统进行采样。支持十亿点,千亿边分布式大规模数据集的运行。

1. 算法背景

本算法同时支持同构图和二部图的Graph Node Representation Learning算法。节点编码可带属性。二部图编码时,支持u-i-u和u-i-i类型的编码方式,其中u-i-i需要提供额外的i2i相似度表。

注:

  1. 目前暂时仅支持无监督训练任务。
  2. 可以任意指定编码层数,0层时算法退化为1st order Line算法的更新逻辑
  3. 不同模式train/eval/save_emb通过设置对应的mode参数来实现
  4. 输入表的schema见2.2部分

2. 使用方法

2.1 参数列表
参数key名称 参数描述 取值范围 是否必选,默认值/行为
uiTableName user与item连边的表名 表名字符串,project.tableName[/partdesc=xx] 可选,默认””
iiTableName item与item连边的表名。在同构图(homo_graph=true)设定下,就是同构图表名 表名字符串,project.tableName[/partdesc=xx] 必选
uFeatTable user的feature表名 表名字符串,project.tableName[/partdesc=xx] 可选,同构图时可以不设置
iFeatTable item的feature表名 表名字符串,project.tableName[/partdesc=xx] 必选
uiTestTable 测试边表 表名字符串,project.tableName[/partdesc=xx] 可选, mode=eval时会使用
epochs 迭代轮数 int 必选,默认1
u_discrete_feat_desc user属性编码配置表, json格式, 记录feature表里第index个特征的总数n及需要embedding的维数d,则’{“index”:[“feature_name”, n, d]}’。注意,以’idx’结尾的feature_name表示,大小为n的连续id的Lookup Table。而其他的则表示特征会经过string2hash去Lookup一个bucket数量为n的Lookup Table json 可选
i_discrete_feat_desc item属性编码配置表,同上 json 可选
host oss host string 必选
arn oss arn string 必选
validate_iter 验证集batch数 int 可选,默认200
encoding_schema item邻居的编码方式 u-i-i 或者 u-i-u 可选,默认u-i-u
dim 中间层embedding dimension int 可选,默认256
final_dim 最后一层embedding dimension int 可选,默认64,设置为-1则不需要最后一层projection
u_depth user encoding深度 int 可选,默认1
i_depth item encoding深度 int 可选,默认2
u_neighs_num user编码层数信息 string 可选,默认20
i_neighs_num item编码层数信息 string 可选,默认20
learning_rate 学习率 string 可选,默认0.0001,注意随着worker数目不同适当调节学习率
checkpointDir 模型输出的oss路径 string 可选,默认oss://checkpoint
cluster_ps_count #ps server int 可选,默认1
cluster_ps_memory 单个ps server内存, 单位为m int 可选,默认10000
cluster_worker_count #worker int 可选,默认1
cluster_worker_memory 单个worker内存,单位为m int 可选,默认8000
user_features_num user特征总数 int 可选,默认1
item_features_num item特征总数 int 可选,默认1
user_count user总数,当不指定u_discrete_feat_desc时,默认采用从0开始的连续id embedding,table大小为user_coun int 可选,默认-1
user_count item总数,当不指定i_discrete_feat_desc时,默认采用从0开始的连续id embedding,table大小为user_coun int 可选,默认-1
mode 运行模式 train/eval/save_emb 可选,默认train
outputs_u user emb输出表 表名字符串,project.tableName[/partdesc=xx] 可选
outputs_i item emb输出表 表名字符串,project.tableName[/partdesc=xx] 可选
homo_graph 是否同构图 boolean 可选,默认false
learning_algo adam或sgd string 可选,默认adam
2.2 输入输出表的schema:

输入边表(测试表同):

  1. create table ui_edge (
  2. src_id bigint comment 'source节点id',
  3. src_type string comment '目前可忽略',
  4. dst_id bigint comment 'destination节点id',
  5. dst_type string comment '目前可忽略',
  6. val string comment '边权重'
  7. );

src_type, dst_type对于该算法可以随意指定。val是一个值,会作为带权采样的依据。

输入点特征表:

  1. create table user_feature (
  2. id bigint comment '节点id',
  3. feature string comment '节点属性'
  4. );

features以分号隔开,例如’103:食品:309.0’

输出表:

  1. create table user_emb (
  2. id bigint comment '节点id',
  3. emb string comment 'embedding列'
  4. );

emb会以’,’分隔

2.3 Examples:
2.3.1 同构图/二部图的训练样例

同构图

  1. pai -name graphsage_ss_ext
  2. -project algo_public_dev
  3. -Dhomo_graph=True
  4. -Dmode='train'
  5. -DiiTableName='graph_embedding_dev.arxiv_train_edge'
  6. -DiFeatTable='graph_embedding_dev.arxiv_feature'
  7. -Dhost='cn-zhangjiakou.oss-internal.aliyun-inc.com'
  8. -Darn='acs:ram::xxxxx:role/xx'
  9. -DcheckpointDir='oss://xxxx/xxxxx/'
  10. -Dcluster_ps_count=2
  11. -Dcluster_worker_count=2
  12. -Dcluster_worker_memory=16000
  13. -Dencoding_schema='u-i-i'
  14. -Duser_features_num=1
  15. -Ditem_features_num=1
  16. -Di_discrete_feat_desc='{"0":["item_idx",5242,64]}'
  17. -Du_neighs_num='15,10'
  18. -Di_neighs_num='15,10'
  19. -Depochs=20
  20. -Ditem_count=5242
  21. -Dlearning_rate='0.001'

二部图

  1. pai -name graphsage_ss_ext
  2. -project algo_public_dev
  3. -Dmode='train'
  4. -DuiTableName='graph_embedding_dev.ui_train_rand_split'
  5. -DiiTableName='graph_embedding_dev.iu_train_rand_split'
  6. -DuFeatTable='graph_embedding_dev.user_feature'
  7. -DiFeatTable='graph_embedding_dev.item_feature'
  8. -Dhost='cn-zhangjiakou.oss-internal.aliyun-inc.com'
  9. -Darn='acs:ram::xxxxx:role/xxxx'
  10. -DcheckpointDir='oss://xxxx/xxxx/'
  11. -Dcluster_ps_count=2
  12. -Dcluster_worker_count=2
  13. -Dcluster_worker_memory=16000
  14. -Di_discrete_feat_desc='{"0":["cate",1000,32],"1":["item_idx",23033,64]}'
  15. -Du_discrete_feat_desc='{"0":["user_idx",39387,1]}'
  16. -Dencoding_schema='u-i-u'
  17. -Duser_features_num=1
  18. -Ditem_features_num=2
  19. -Du_neighs_num='20'
  20. -Di_neighs_num='20'
  21. -Depochs=20
  22. -Duser_count=39387
  23. -Ditem_count=23033
  24. -Dlearning_rate='0.001'