通过whale.auto_parallel接口,您可以快速实现模型分布式训练。本文介绍该接口的调用格式、参数说明及调用示例。

背景信息

Whale可以通过资源Cluster划分和模型Scope划分实现模型分布式训练。该实现方式需要理解资源和模型的配置组合,以达到高效地分布式训练性能。为了进一步降低使用成本及简化操作,Whale提供了whale.auto_parallel接口,可以通过一行代码的简易模式自动进行并行化操作。

接口说明

  • 格式
    auto_parallel(modes)
  • 功能
    通过一行代码的简易模式自动进行并行化操作。例如:
    • whale.auto_parallel(whale.replica):自动对整个模型进行数据并行。
    • whale.auto_parallel(whale.split):自动对整个模型进行算子拆分。
    • whale.auto_parallel(whale.pipeline):自动对整个模型进行模型拆分,再进行流水并行。
    • whale.auto_parallel([whale.pipeline, whale.replica]):Whale通过Split和Replica两种模式进行自动并行,其中的参数modes可以自由组合多种whale.scopes原语。
    • whale.auto_parallel():全自动并行化训练模式,Whale自动推断生成分布式策略。
    说明 当前仅支持whale.auto_parallel(whale.replica)自动数据并行模式。
  • 参数

    modes:并行化模式,SCOPES类型。例如数据并行、模型并行、流水并行及组合的并行方式。当前仅支持自动数据并行,即modes=whale.replica

  • 返回值

  • 示例
    import whale as wh之后,添加wh.auto_parallel(wh.replica)即可自动配置Cluster,并实现数据并行。此处仅给出主体代码,完整代码请参见auto_data_parallel.py
    import whale as wh
    
    wh.auto_parallel(wh.replica)
    
    # Construct your model here.
    model_definition()