文档

WorkQueue

更新时间:
重要

本文中含有需要您注意的重要提示信息,忽略该信息可能对您的业务造成影响,请务必仔细阅读。

在大规模分布式异步训练中,您可以使用WorkQueue进行弹性数据切分,以缓解长尾效应,从而降低模型训练所需的时间。本文介绍WorkQueue的调用格式、参数及其提供的方法。同时,以文件数据源和MaxCompute表数据源为例,介绍实现数据切分的经典示例。

警告

公共云GPU服务器即将过保下线,您可以继续提交CPU版本的TensorFlow任务。如需使用GPU进行模型训练,请前往DLC提交任务,具体操作请参见创建训练任务

背景信息

在大规模分布式异步训练中,如果每个Worker读取相同数量的样本,则慢节点的训练时长会远大于其他节点,造成长尾效应。并且随着训练规模扩大,长尾效应会越来越严重,导致训练的整体数据吞吐降低,进而增加训练时间。

为解决该问题,PAI提供了pai.data.WorkQueue类,支持对多种数据源进行弹性数据切分,让慢节点获取较少的训练数据,快节点获取更多的训练数据,以缓解长尾效应,从而降低模型训练所需的时间。

版本配套关系

  • Python版本:Python 2.7

  • PAI-TensorFlow版本:PAI-TensorFlow 1.12

pai.data.WorkQueue

  • 功能

    工作项队列类,用于统一管理所有Worker上的工作项。每个Worker的当前剩余工作项被消费完后,会从同一个WorkQueue获得新的工作项,并将其作为数据源进行训练,从而使得训练快的Worker获得更多的工作项进行训练,以减少长尾效应。

  • 格式

    class pai.data.WorkQueue(works, num_epochs=1, shuffle=True, seed=None, 
                                prefix=None, num_slices=None, name='work_queue')
  • 参数

    参数名

    描述

    类型

    是否必选

    默认值

    works

    文件名或表名列表。

    LIST of STRING

    num_epochs

    读取全部数据的次数。

    INT

    1

    shuffle

    是否每个Epoch都随机重洗数据,取值如下:

    • True:每个Epoch都随机重洗数据。

    • False:不进行数据重洗。

    BOOL

    True

    seed

    重洗数据的随机种子。取值为None时,表示系统自动选取随机种子。

    INT

    None

    prefix

    工作项(文件名或表名)的前缀。取值为None时,表示无前缀。

    STRING

    None

    num_slices

    工作项的总数量。集群越不稳定,需要将工作项总数量配置的越大,通常为Worker数量的10倍以上。取值为None时,表示不分片。

    INT

    None

    num_clients

    工作队列支持的最大工作抢占并发数。

    INT

    1

    name

    工作队列的名称。

    STRING

    work_queue

  • 返回值

    返回WorkQueue对象,您可以使用该对象调用pai.data.WorkQueue类提供的方法。

pai.data.WorkQueue提供的方法

pai.data.WorkQueue类提供以下方法:

  • take

    • 功能

      从全局工作队列获取一个工作项,并下载至本地。

    • 格式

      WorkQueue.take()
    • 参数

    • 返回值

      返回值类型为tensorflow.Tensor

  • input_dataset

    • 功能

      返回一个Dataset,其每个元素为一个工作项。

    • 格式

      WorkQueue.input_dataset()
    • 参数

    • 返回值

      返回值类型为tensorflow.data.Dataset

  • input_producer

    • 功能

      返回全局工作队列在本地的代理队列,为Reader类Op使用。

    • 格式

      WorkQueue.input_producer()
    • 参数

    • 返回值

      返回值类型为tensorflow.FIFOQueue

  • add_summary

    • 功能

      在Tensorboard中显示WorkQueue的资源水位信息。

    • 格式

      WorkQueue.add_summary()
    • 参数

    • 返回值

典型示例

pai.data.WorkQueue类支持对多种数据源进行弹性数据切分,以下分别以文件数据源和MaxCompute表数据源为例,介绍如何使用pai.data.WorkQueue类实现弹性数据切分(仅提供核心代码片段):

  • 文件数据源

    import pai
    # ...
    # path1、path2及path3表示需要读取的文件列表。
    # shuffle取值为True,表示每个Epoch都随机化打散文件路径。
    work_queue = pai.data.WorkQueue([path1, path2, path3], shuffle=True)
    
    # 让WorkQueue支持TensorBoard。
    work_queue.add_summary()
    
    # 创建文件读取器。
    reader = tf.TextLineReader()
    # 从文件列表中读取2条记录。
    keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)
    with tf.train.MonitoredTrainingSession() as sess:
      sess.run(...)
  • MaxCompute表数据源

    • TableRecordDataset数据源

      import pai
      #...
      # odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。
      # shuffle取值为True,表示每个Epoch都随机化打散表路径。
      # num_slices为工作项总数量。
      # FLAGS.num_workers为训练中的Worker数量。
      work_queue = pai.data.WorkQueue([odps_path1, odps_path2, odps_path3],shuffle=True,
                          num_slices=FLAGS.num_workers * 10)
      # 创建文件名Dataset。
      filenames_dataset = work_queue.input_dataset()
      
      # 将dataset作为文件名传入TableRecordDataset。
      dataset = tf.data.TableRecordDataset(filenames_dataset, record_defaults=...)

      关于tf.data.TableRecordDataset接口的调用,请参见TableRecordDataset

    • TableRecordReader数据源

      import pai
      # ...
      # odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。
      # shuffle取值为True,表示每个Epoch都随机化打散表路径。
      # num_slices为工作项总数量。
      # FLAGS.num_workers为训练中的Worker数量。
      work_queue = pai.data.WorkQueue(
        [odps_path1, odps_path2, odps_path3], shuffle=True, num_slices=FLAGS.num_workers * 10)
      
      # 创建表读取器。
      reader = tf.TableRecordReader()
      
      # 从表中读取2条记录。
      keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)
  • 本页导读 (1)
文档反馈