本文中含有需要您注意的重要提示信息,忽略该信息可能对您的业务造成影响,请务必仔细阅读。
在大规模分布式异步训练中,您可以使用WorkQueue进行弹性数据切分,以缓解长尾效应,从而降低模型训练所需的时间。本文介绍WorkQueue的调用格式、参数及其提供的方法。同时,以文件数据源和MaxCompute表数据源为例,介绍实现数据切分的经典示例。
公共云GPU服务器即将过保下线,您可以继续提交CPU版本的TensorFlow任务。如需使用GPU进行模型训练,请前往DLC提交任务,具体操作请参见创建训练任务。
背景信息
在大规模分布式异步训练中,如果每个Worker读取相同数量的样本,则慢节点的训练时长会远大于其他节点,造成长尾效应。并且随着训练规模扩大,长尾效应会越来越严重,导致训练的整体数据吞吐降低,进而增加训练时间。
为解决该问题,PAI提供了pai.data.WorkQueue
类,支持对多种数据源进行弹性数据切分,让慢节点获取较少的训练数据,快节点获取更多的训练数据,以缓解长尾效应,从而降低模型训练所需的时间。
版本配套关系
Python版本:Python 2.7
PAI-TensorFlow版本:PAI-TensorFlow 1.12
pai.data.WorkQueue
功能
工作项队列类,用于统一管理所有Worker上的工作项。每个Worker的当前剩余工作项被消费完后,会从同一个WorkQueue获得新的工作项,并将其作为数据源进行训练,从而使得训练快的Worker获得更多的工作项进行训练,以减少长尾效应。
格式
class pai.data.WorkQueue(works, num_epochs=1, shuffle=True, seed=None, prefix=None, num_slices=None, name='work_queue')
参数
参数名
描述
类型
是否必选
默认值
works
文件名或表名列表。
LIST of STRING
是
无
num_epochs
读取全部数据的次数。
INT
否
1
shuffle
是否每个Epoch都随机重洗数据,取值如下:
True:每个Epoch都随机重洗数据。
False:不进行数据重洗。
BOOL
否
True
seed
重洗数据的随机种子。取值为None时,表示系统自动选取随机种子。
INT
否
None
prefix
工作项(文件名或表名)的前缀。取值为None时,表示无前缀。
STRING
否
None
num_slices
工作项的总数量。集群越不稳定,需要将工作项总数量配置的越大,通常为Worker数量的10倍以上。取值为None时,表示不分片。
INT
否
None
num_clients
工作队列支持的最大工作抢占并发数。
INT
否
1
name
工作队列的名称。
STRING
否
work_queue
返回值
返回WorkQueue对象,您可以使用该对象调用
pai.data.WorkQueue
类提供的方法。
pai.data.WorkQueue提供的方法
pai.data.WorkQueue
类提供以下方法:
take
功能
从全局工作队列获取一个工作项,并下载至本地。
格式
WorkQueue.take()
参数
无
返回值
返回值类型为
tensorflow.Tensor
。
input_dataset
功能
返回一个Dataset,其每个元素为一个工作项。
格式
WorkQueue.input_dataset()
参数
无
返回值
返回值类型为
tensorflow.data.Dataset
。
input_producer
功能
返回全局工作队列在本地的代理队列,为Reader类Op使用。
格式
WorkQueue.input_producer()
参数
无
返回值
返回值类型为
tensorflow.FIFOQueue
。
add_summary
功能
在Tensorboard中显示WorkQueue的资源水位信息。
格式
WorkQueue.add_summary()
参数
无
返回值
无
典型示例
pai.data.WorkQueue
类支持对多种数据源进行弹性数据切分,以下分别以文件数据源和MaxCompute表数据源为例,介绍如何使用pai.data.WorkQueue
类实现弹性数据切分(仅提供核心代码片段):
文件数据源
import pai # ... # path1、path2及path3表示需要读取的文件列表。 # shuffle取值为True,表示每个Epoch都随机化打散文件路径。 work_queue = pai.data.WorkQueue([path1, path2, path3], shuffle=True) # 让WorkQueue支持TensorBoard。 work_queue.add_summary() # 创建文件读取器。 reader = tf.TextLineReader() # 从文件列表中读取2条记录。 keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2) with tf.train.MonitoredTrainingSession() as sess: sess.run(...)
MaxCompute表数据源
TableRecordDataset数据源
import pai #... # odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。 # shuffle取值为True,表示每个Epoch都随机化打散表路径。 # num_slices为工作项总数量。 # FLAGS.num_workers为训练中的Worker数量。 work_queue = pai.data.WorkQueue([odps_path1, odps_path2, odps_path3],shuffle=True, num_slices=FLAGS.num_workers * 10) # 创建文件名Dataset。 filenames_dataset = work_queue.input_dataset() # 将dataset作为文件名传入TableRecordDataset。 dataset = tf.data.TableRecordDataset(filenames_dataset, record_defaults=...)
关于
tf.data.TableRecordDataset
接口的调用,请参见TableRecordDataset。TableRecordReader数据源
import pai # ... # odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。 # shuffle取值为True,表示每个Epoch都随机化打散表路径。 # num_slices为工作项总数量。 # FLAGS.num_workers为训练中的Worker数量。 work_queue = pai.data.WorkQueue( [odps_path1, odps_path2, odps_path3], shuffle=True, num_slices=FLAGS.num_workers * 10) # 创建表读取器。 reader = tf.TableRecordReader() # 从表中读取2条记录。 keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)