阿里云PAI为您提供了部分典型场景下的示例模型,便于您便捷地接入TorchAcc进行训练加速。本文为您介绍如何在Stable Diffusion分布式训练中接入TorchAcc并实现训练加速。
测试环境配置
测试环境配置方法,请参见配置测试环境。
本案例以DSW环境V100M16卡型为例,例如:节点规格选择ecs.gn6v-c8g1.16xlarge-64c256gNVIDIA V100 * 8
。
接入TorchAcc加速Stable Diffusion分布式训练
以DSW环境为例:
进入DSW实例页面下载并解压测试代码及脚本文件。
在交互式建模(DSW)页面,单击DSW实例操作列下的打开。
在Notebook页签的Launcher页面,单击快速开始区域Notebook下的Python3。
执行以下命令下载并解压测试代码及脚本文件。
!wget http://odps-release.cn-hangzhou.oss.aliyun-inc.com/torchacc/accbench/gallery/stable_diffusion.tar.gz && tar -zxvf stable_diffusion.tar.gz
进入
stable-diffusion
目录,双击打开stable_diffusion.ipynb
文件。后续,您可以直接在该文件中运行下述步骤中的命令,当成功运行结束一个步骤命令后,再顺次运行下个步骤的命令。
执行以下命令下载类Imagenet-1k的mock数据集并安装Stable Diffusion模型依赖的第三方包。
!bash prepare.sh
分别使用普通训练方法(baseline)和接入TorchAcc进行Stable Diffusion模型分布式训练,来验证TorchAcc的性能提升效果。
说明在测试不同GPU卡型(例如V100、A10等)时,可以通过调整batch_size来适配不同卡型的显存大小。
在测试不同机器实例时,由于单机GPU卡数不同(假设为N),因此可以通过设置nproc_per_node来启动单卡或多卡的任务,其中:1<=nproc_per_node<=N。
Pytorch Eager单卡(baseline训练)
!#!/bin/bash !set -ex !python launch_single_task.py --batch_size=4 --nproc_per_node=1
Pytorch Eager八卡(baseline训练)
!#!/bin/bash !set -ex !python launch_single_task.py --batch_size=4 --nproc_per_node=8
TorchAcc单卡(PAI-OPT)
!#!/bin/bash !set -ex !python launch_single_task.py --batch_size=4 --nproc_per_node=1 --compiler-opt
TorchAcc八卡(PAI-OPT)
!#!/bin/bash !set -ex !python launch_single_task.py --batch_size=4 --nproc_per_node=8 --compiler-opt
其中:普通训练方法和接入TorchAcc训练方法的优化配置如下:
baseline:Torch112+DDP+AMPO1
PAI-Opt:Torch112+TorchAcc+AMPO1
执行以下命令,获取性能数据结果。
import os from plot import plot, traverse from parser import parse_file #import seaborn as sns if __name__ == '__main__': path = "output" file_names = {} traverse(path, file_names) for model, tags in file_names.items(): for tag, suffixes in tags.items(): title = model + "_" + tag label = [] api_data = [] for suffix, o_suffixes in suffixes.items(): label.append(suffix) for output_suffix, node_ranks in o_suffixes.items(): assert "0" in node_ranks assert "log" in node_ranks["0"] parse_data = parse_file(node_ranks["0"]["log"]) api_data.append(parse_data) plot(title, label, api_data)
对于V100M16卡型,由于显存有限,batch_size设置的值比较小,无法获得较大程度的加速效果。但在实际场景中,经过在A10上的测试验证,使用TorchAcc在单卡和多卡上均能够获得40%以上的提速效果。关于接入TorchAcc更详细的代码实现原理,请参见代码实现原理。
代码实现原理
基于StableDiffusion使用三方包pytorch-lighting==1.8.6版本时,可以直接导入stable-diffusion
目录下的pl_hooks.py和logger.py完成TorchAcc接入。
Import TorchAcc API
在main
函数import处添加以下代码,具体请参考main.py文件中35-45行代码:
from logger import create_logger, enable_torchacc_compiler, enable_torchacc_kernel, log_params, log_metrics
+if enable_torchacc_compiler():
+ from torchacc.torch_xla.amp import GradScaler
+ import torchacc.torch_xla.distributed.xla_backend
+ import torchacc.torch_xla.core.xla_model as xm
+ import torchacc.torch_xla.distributed.parallel_loader as ploader
+ dist.get_rank = xm.get_ordinal
+ dist.get_world_size = xm.xrt_world_size
+ device = xm.xla_device()
+ xm.set_replication(device, [device])
+else:
from torch.cuda.amp import GradScaler
Enable Pytorch-lightning hook
使用pl_hooks.py文件的enable_pl_hooks.py完成TorchAcc接入,具体请参考main.py文件中588行代码:
from pl_hooks import enable_pl_hooks
+if enable_torchacc_compiler():
+ from torchacc.torch_xla.amp import syncfree
+ torch.optim.Adam = syncfree.Adam
+ torch.optim.AdamW = syncfree.AdamW
+ torch.optim.SGD = syncfree.SGD
+if opt.use_pl_logger:
+ os.environ["USE_PL_LOGGER"] = "1"
+if opt.log_freq is not None:
+ os.environ["LOG_FREQ"] = str(opt.log_freq)
+enable_pl_hooks() # call hook of acclerate