阿里云PAI为您提供了部分典型场景下的示例模型,便于您便捷地接入TorchAcc进行训练加速。本文为您介绍如何在Swin Transformer分布式训练中接入TorchAcc并实现训练加速。
测试环境配置
测试环境配置方法,请参见配置测试环境。
本案例以DSW环境V100M16卡型为例,例如:节点规格选择ecs.gn6v-c8g1.16xlarge-64c256gNVIDIA V100 * 8
。
接入TorchAcc加速Swin Transformer分布式训练
以DSW环境为例:
进入DSW实例页面下载并解压测试代码及脚本文件。
在交互式建模(DSW)页面,单击DSW实例操作列下的打开。
在Notebook页签的Launcher页面,单击快速开始区域Notebook下的Python3。
执行以下命令下载并解压测试代码及脚本文件。
!wget http://odps-release.cn-hangzhou.oss.aliyun-inc.com/torchacc/accbench/gallery/swin_transformer.tar.gz && tar -zxvf swin_transformer.tar.gz
进入
Swin-Transformer
目录,双击打开swin_transformer.ipynb
文件。后续,您可以直接在该文件中运行下述步骤中的命令,当成功运行结束一个步骤命令后,再顺次运行下个步骤的命令。
执行以下命令下载类似Imagenet-1k的mock数据集并安装Swin Transformer模型依赖的第三方包。
!bash prepare.sh
分别使用普通训练方法(baseline)和接入TorchAcc进行Swin Transformer模型分布式训练,来验证TorchAcc的性能提升效果。
普通训练方法和接入TorchAcc训练方法的优化配置如下:
baseline:Torch112+DDP+AMPO1
PAI-Opt:Torch112+TorchAcc+AMPO1
说明在测试不同GPU卡型(例如V100、A10等)时,可以通过调整batch_size来适配不同卡型的显存大小。
在测试不同机器实例时,由于单机GPU卡数不同(假设为N),因此可以通过设置nproc_per_node来启动单卡或多卡的任务,其中:1<=nproc_per_node<=N。
Pytorch Eager单卡(baseline训练)
!#!/bin/bash !set -ex !python launch_single_task.py --amp_level=O1 --batch_size=32 --nproc_per_node=1
Pytorch Eager八卡(baseline训练)
!#!/bin/bash !set -ex !python launch_single_task.py --amp_level=O1 --batch_size=32 --nproc_per_node=8
TorchAcc单卡(PAI-OPT)
!#!/bin/bash !set -ex !python launch_single_task.py --nproc_per_node=1 --amp_level=O2 --kernel-opt --batch_size=32 --nproc_per_node=1
TorchAcc八卡(PAI-OPT)
!#!/bin/bash !set -ex !python launch_single_task.py --nproc_per_node=1 --amp_level=O2 --kernel-opt --batch_size=32 --nproc_per_node=8
执行以下命令,获取性能数据结果。
import os from plot import plot, traverse from parser import parse_file # import seaborn as sns if __name__ == '__main__': path = "output" file_names = {} traverse(path, file_names) for model, tags in file_names.items(): for tag, suffixes in tags.items(): title = model + "_" + tag label = [] api_data = [] for suffix, o_suffixes in suffixes.items(): label.append(suffix) for output_suffix, node_ranks in o_suffixes.items(): assert "0" in node_ranks assert "log" in node_ranks["0"] parse_data = parse_file(node_ranks["0"]["log"]) api_data.append(parse_data) plot(title, label, api_data)
生成如下图所示结果。
实验结果表明,使用TorchAcc进行Swin Transformer分布式训练可以明显提升性能。接入TorchAcc更详细的代码实现原理,请参见代码实现原理。
代码实现原理
将上述的Swin Transformer模型接入TorchAcc框架进行分布式训练加速的代码配置,请参考已下载的代码文件Swin-Transformer/main.py
。
Import TorchAcc API
在main
函数import处添加以下代码:
def enable_torchacc_compiler():
return os.getenv('USE_TORCHACC') is not None
如果打开TorchAcc,则会在main.py文件import处添加以下代码:
from logger import create_logger, enable_torchacc_compiler, enable_torchacc_kernel, log_params, log_metrics
+if enable_torchacc_compiler():
+ import torchacc.torch_xla.core.xla_model as xm
+ import torchacc.torch_xla.distributed.parallel_loader as pl
+ import torchacc.torch_xla.test.test_utils as test_utils
+ import torchacc.torch_xla.utils.utils as xu
+ from torchacc.torch_xla.amp import autocast, GradScaler
+ dist.get_rank = xm.get_ordinal
+ dist.get_world_size = xm.xrt_world_size
+ scaler = GradScaler()
+ device = xm.xla_device()
else:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
分布式初始化
在调用dist.init_process_group
函数时,将backend参数设置为xla:
dist.init_process_group(backend="xla", init_method="env://")
set_replication+封装dataloader+model placement+optimizer
在模型和dataloader定义完成之后,获取xla_device并调用set_replication函数,以封装dataloader并设置模型的设备位置。
+if enable_torchacc_compiler():
+ xm.set_replication(device, [device])
+ model.to(device)
+ data_loader_train = pl.MpDeviceLoader(data_loader_train, device)
+ data_loader_val = pl.MpDeviceLoader(data_loader_val, device)
+ model_without_ddp = model
+ optimizer = build_optimizer(config, model)
+else:
model.cuda()
optimizer = build_optimizer(config, model)
if config.AMP_OPT_LEVEL == "O2":
loss_scale = float(config.AMP_LOSS_SCALE) if config.AMP_LOSS_SCALE != "dynamic" else "dynamic"
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL, loss_scale=loss_scale)
local_rank = int(os.environ["LOCAL_RANK"]) if 'LOCAL_RANK' in os.environ else config.LOCAL_RANK
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False)
model_without_ddp = model.module
在Swin-Transformer/data/build.py
中,如果dataset使用了mixup_fn,则TorchAcc场景下需要替换成collate_mixedup function,如果没有使用mixup_fn,则可以忽略。
# setup mixup / cutmix
mixup_fn = None
collate_mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
# 使用TorchAcc时使用collate_mixedup_fn
+ if config.AUG.COLLATE_MIXUP:
+ collate_mixup_fn = CollateMixup(
+ mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+ prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+ label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES
+ )
+ else:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES
)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
collate_fn=collate_mixup_fn, # TorchAcc enabled
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)
梯度allreduce通信
如果启用了AMP开关,需要在loss backward后对梯度进行allreduce,并在backward和apply计算阶段修改代码。具体请参考main.py文件的273-324行代码。
if config.TRAIN.ACCUMULATION_STEPS > 1:
loss = loss / config.TRAIN.ACCUMULATION_STEPS
if config.AMP_OPT_LEVEL == "O2":
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(amp.master_params(optimizer))
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step_update(epoch * num_steps + idx)
else:
optimizer.zero_grad()
if config.AMP_OPT_LEVEL != "O0":
if config.AMP_OPT_LEVEL == "O2":
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(amp.master_params(optimizer))
optimizer.step()
else:
scaler.scale(loss).backward()
+ if not enable_torchacc_compiler():
if config.TRAIN.CLIP_GRAD:
scaler.unscale_(optimizer)
+ else:
+ gradients = xm._fetch_gradients(optimizer)
+ xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
optimizer.step()
lr_scheduler.step_update(epoch * num_steps + idx)
Training Loop封装
更新代码逻辑:
从dataloader取出样本(数据)作为后面训练的输入,具体请参考main.py文件的262-264行代码。
+if not enable_torchacc_compiler(): samples = samples.cuda(non_blocking=True) targets = targets.cuda(non_blocking=True)
如果开启了AMP功能,由于TorchAcc暂时只能使用AMP的AutoCast功能,因此需要在training loop中添加autocast_context_manager代码,具体请参考main.py文件的269-270行代码。
with autocast_context_manager(config): outputs = model(samples)
其中
autocast_context_manager
函数的实现可以参考main.py文件的79-87行代码。def autocast_context_manager(config): if config.AMP_OPT_LEVEL == "O2": if enable_torchacc_compiler(): ctx_manager = autocast() else: ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() else: ctx_manager = torch.cuda.amp.autocast(enabled=config.AMP_ENABLE) return ctx_manager