Alibaba Cloud PAI provides sample models for typical scenarios to simplify TorchAcc integration for training acceleration. This topic explains how to use TorchAcc with Stable Diffusion to accelerate distributed training.
Staging environment configuration
To configure the staging environment, see Configure a staging environment.
This example uses the V100M16 GPU type in a DSW environment. For example, the node specification is ecs.gn6v-c8g1.16xlarge (64 vCPUs, 256 GB memory, NVIDIA V100 * 8).
Accelerating distributed training with TorchAcc
Using the DSW environment as an example:
-
Go to the DSW instance page to download and extract the test code and script files.
-
On the Data Science Workshop (DSW) page, find your DSW instance and click Open in the Actions column.
-
On the Launcher page of the Notebook tab, in the Quick start section, click Python3 under Notebook.
-
Run the following command to download and extract the test code and script files.
!wget http://odps-release.cn-hangzhou.oss.aliyun-inc.com/torchacc/accbench/gallery/stable_diffusion.tar.gz && tar -zxvf stable_diffusion.tar.gz
-
-
Go to the stable-diffusion directory and double-click stable_diffusion.ipynb to open it. Run the commands for each step sequentially in the file.
-
Run the following command to download a mock dataset similar to ImageNet-1k and install the third-party packages required for the Stable Diffusion model.
!bash prepare.sh
-
To verify the performance gains from TorchAcc, run distributed training for the Stable Diffusion model using both the standard method (baseline) and the TorchAcc-integrated method.
Note
-
When you test different GPU types, such as V100 or A10, adjust the batch_size to fit the VRAM of the GPU.
-
The number of GPUs per machine varies across different machine instances. If a machine has N GPUs, set nproc_per_node to start single-GPU or multi-GPU tasks. The value of nproc_per_node must be an integer from 1 to N.
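The nproc_per_node constraint in this note can be checked before a run is launched. The following is a minimal sketch, not part of the downloaded test scripts: validate_nproc_per_node is a hypothetical helper name, and the GPU count is passed in explicitly so the check runs anywhere. In a real session the count could come from torch.cuda.device_count().

```python
def validate_nproc_per_node(requested: int, gpus_on_node: int) -> int:
    """Return `requested` if it is a valid nproc_per_node value, else raise.

    Hypothetical helper: it only encodes the rule from the note above,
    namely that the value must be an integer from 1 to N on an N-GPU node.
    """
    if gpus_on_node < 1:
        raise ValueError("the node must expose at least one GPU")
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(requested, bool) or not isinstance(requested, int):
        raise TypeError("nproc_per_node must be an integer")
    if not 1 <= requested <= gpus_on_node:
        raise ValueError(
            f"nproc_per_node must be between 1 and {gpus_on_node}, got {requested}"
        )
    return requested


if __name__ == "__main__":
    # 8 matches the V100 * 8 node used in this topic.
    print(validate_nproc_per_node(8, gpus_on_node=8))
```

Failing early on an out-of-range value is cheaper than discovering it after the dataset and model have been loaded.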
-
PyTorch Eager single-card (baseline training)
!#!/bin/bash
!set -ex
!python launch_single_task.py --batch_size=4 --nproc_per_node=1
-
PyTorch Eager 8-card (baseline training)
!#!/bin/bash
!set -ex
!python launch_single_task.py --batch_size=4 --nproc_per_node=8
-
TorchAcc single-card (PAI-OPT)
!#!/bin/bash
!set -ex
!python launch_single_task.py --batch_size=4 --nproc_per_node=1 --compiler-opt
-
TorchAcc 8-card (PAI-OPT)
!#!/bin/bash
!set -ex
!python launch_single_task.py --batch_size=4 --nproc_per_node=8 --compiler-opt
The optimization configurations for baseline training and TorchAcc training are as follows:
-
Baseline: PyTorch 1.12 + DDP + AMP O1
-
PAI-OPT: PyTorch 1.12 + TorchAcc + AMP O1
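The four runs in this step differ only in the GPU count and in whether --compiler-opt is passed. The sketch below generates the same command lines programmatically, which may help when you test other batch sizes or GPU counts. The script name and flags are taken from the commands in this step; build_launch_command and benchmark_matrix are hypothetical helper names, not part of the test code.

```python
import shlex
from itertools import product
from typing import List, Sequence


def build_launch_command(batch_size: int, nproc_per_node: int,
                         use_torchacc: bool) -> List[str]:
    """Build one launch_single_task.py command as an argument list."""
    cmd = [
        "python",
        "launch_single_task.py",
        f"--batch_size={batch_size}",
        f"--nproc_per_node={nproc_per_node}",
    ]
    if use_torchacc:
        # The TorchAcc runs in this step add this flag; baseline runs omit it.
        cmd.append("--compiler-opt")
    return cmd


def benchmark_matrix(batch_size: int = 4,
                     gpu_counts: Sequence[int] = (1, 8)) -> List[List[str]]:
    """Return the baseline commands first, then the TorchAcc commands."""
    return [
        build_launch_command(batch_size, n, use_torchacc)
        for use_torchacc, n in product((False, True), gpu_counts)
    ]


if __name__ == "__main__":
    for command in benchmark_matrix():
        print(shlex.join(command))
```

Returning argument lists rather than strings keeps the commands safe to pass to subprocess.run without shell quoting issues.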
-
-
Run the following code to view the performance data.
import os
from plot import plot, traverse
from parser import parse_file
#import seaborn as sns

if __name__ == '__main__':
    path = "output"
    file_names = {}
    traverse(path, file_names)
    for model, tags in file_names.items():
        for tag, suffixes in tags.items():
            title = model + "_" + tag
            label = []
            api_data = []
            for suffix, o_suffixes in suffixes.items():
                label.append(suffix)
                for output_suffix, node_ranks in o_suffixes.items():
                    assert "0" in node_ranks
                    assert "log" in node_ranks["0"]
                    parse_data = parse_file(node_ranks["0"]["log"])
                    api_data.append(parse_data)
            plot(title, label, api_data)
For the V100M16 GPU model, the limited video memory requires a small batch_size, which prevents significant acceleration. However, in production workloads, tests on A10 GPUs show that TorchAcc delivers a speedup of over 40% on both single-card and multi-card setups. For more details about the code implementation, see Implementation details.
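To compare a baseline run with a TorchAcc run numerically, one option is to express the gain as a relative throughput increase. The sketch below assumes that speedup is measured this way, which this topic does not state explicitly. speedup_percent is a hypothetical helper, and the input values are placeholders, not measured results.

```python
def speedup_percent(baseline_throughput: float,
                    accelerated_throughput: float) -> float:
    """Relative throughput gain of the accelerated run, in percent.

    Assumption: both values use the same unit (for example samples per
    second) and were measured with the same batch_size and GPU count.
    """
    if baseline_throughput <= 0:
        raise ValueError("baseline throughput must be positive")
    return (accelerated_throughput / baseline_throughput - 1.0) * 100.0


if __name__ == "__main__":
    # Placeholder numbers for illustration only; substitute the values
    # parsed from your own baseline and TorchAcc logs.
    baseline, accelerated = 2.0, 3.0
    print(f"{speedup_percent(baseline, accelerated):.1f}%")
```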
Implementation details
When using Stable Diffusion with the pytorch-lightning==1.8.6 third-party package, you can integrate TorchAcc by importing pl_hooks.py and logger.py from the stable-diffusion directory.
Import TorchAcc API
In the main function, add the following code to the import section. For more information, see lines 35-45 in the main.py file:
from logger import create_logger, enable_torchacc_compiler, enable_torchacc_kernel, log_params, log_metrics
+if enable_torchacc_compiler():
+    from torchacc.torch_xla.amp import GradScaler
+    import torchacc.torch_xla.distributed.xla_backend
+    import torchacc.torch_xla.core.xla_model as xm
+    import torchacc.torch_xla.distributed.parallel_loader as ploader
+    dist.get_rank = xm.get_ordinal
+    dist.get_world_size = xm.xrt_world_size
+    device = xm.xla_device()
+    xm.set_replication(device, [device])
+else:
     from torch.cuda.amp import GradScaler
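The two dist assignments in this snippet rebind the rank and world-size lookups so that code written against the dist API transparently receives the XLA values. The sketch below illustrates only that rebinding pattern, using stand-in objects instead of torch.distributed and the XLA model module so that it runs without either library. The returned numbers are made up, and patch_distributed_queries and describe_worker are hypothetical names.

```python
from types import SimpleNamespace

# Stand-ins for torch.distributed (`dist`) and the XLA model module (`xm`).
# The return values are arbitrary and only make the effect visible.
dist = SimpleNamespace(get_rank=lambda: 0, get_world_size=lambda: 1)
xm = SimpleNamespace(get_ordinal=lambda: 3, xrt_world_size=lambda: 8)


def patch_distributed_queries(dist_module, xla_module):
    """Rebind the two lookups, mirroring the assignments in main.py."""
    dist_module.get_rank = xla_module.get_ordinal
    dist_module.get_world_size = xla_module.xrt_world_size


def describe_worker(dist_module):
    # Callers keep using the dist API and need no changes after patching.
    return f"rank {dist_module.get_rank()} of {dist_module.get_world_size()}"


if __name__ == "__main__":
    print(describe_worker(dist))          # values from the original functions
    patch_distributed_queries(dist, xm)
    print(describe_worker(dist))          # values from the XLA stand-ins
```

Rebinding attributes at import time affects every later caller in the process, so it belongs at the top of the entry point, before any code reads the rank.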
Enable the PyTorch Lightning hook
Use the enable_pl_hooks function from the pl_hooks.py file to integrate TorchAcc. For more information, see line 588 in the main.py file:
from pl_hooks import enable_pl_hooks
+if enable_torchacc_compiler():
+    from torchacc.torch_xla.amp import syncfree
+    torch.optim.Adam = syncfree.Adam
+    torch.optim.AdamW = syncfree.AdamW
+    torch.optim.SGD = syncfree.SGD
+if opt.use_pl_logger:
+    os.environ["USE_PL_LOGGER"] = "1"
+if opt.log_freq is not None:
+    os.environ["LOG_FREQ"] = str(opt.log_freq)
+enable_pl_hooks() # call hook to accelerate
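The snippet above passes the logging options to the hooks through the USE_PL_LOGGER and LOG_FREQ environment variables. The sketch below shows that round trip in isolation. The writer side copies the logic from the snippet. The reader side, read_logging_options, and its default frequency of 100 are assumptions for illustration, and the actual parsing in pl_hooks.py may differ. The command-line flag names are likewise inferred from the opt attribute names.

```python
import argparse
import os


def export_logging_options(opt, environ=os.environ):
    """Write the options into environment variables, as in main.py."""
    if opt.use_pl_logger:
        environ["USE_PL_LOGGER"] = "1"
    if opt.log_freq is not None:
        # Environment variables are strings, so the integer is converted.
        environ["LOG_FREQ"] = str(opt.log_freq)


def read_logging_options(environ=os.environ, default_freq=100):
    """Hypothetical reader; default_freq is an assumed fallback value."""
    use_pl_logger = environ.get("USE_PL_LOGGER") == "1"
    log_freq = int(environ.get("LOG_FREQ", default_freq))
    return use_pl_logger, log_freq


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Flag names are assumed from the opt.use_pl_logger / opt.log_freq attributes.
    parser.add_argument("--use_pl_logger", action="store_true")
    parser.add_argument("--log_freq", type=int, default=None)
    opt = parser.parse_args(["--log_freq", "50"])

    env = {}  # a plain dict keeps the demo from touching the real environment
    export_logging_options(opt, env)
    print(read_logging_options(env))
```

Because the variables are set before enable_pl_hooks() is called, the hooks can read them at setup time without any extra function arguments.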