阿里云PAI为部分典型场景提供了示例模型,方便您快速接入TorchAcc进行训练加速;同时也支持将您自行开发的模型接入TorchAcc进行加速。本文介绍如何在自定义模型中接入TorchAcc,以提高分布式训练的速度和效率。
背景信息
TorchAcc的优化方式分为以下两类,您可以根据实际需求选择合适的优化方式,以提高模型训练速度和效率。
编译优化
TorchAcc支持将PyTorch动态图转换为静态图,并对计算图进行优化和编译:TorchAcc会将原始计算图优化为更高效的计算图,并使用JIT编译器将其编译为更高效的代码。这样可以避免PyTorch动态图执行过程中的部分性能损失,从而提高模型训练速度和效率。
定制优化
当模型包含Dynamic Shape、Custom算子、Dynamic ControlFlow等特性时,暂时无法应用全局编译优化进行分布式训练加速。针对此类场景,TorchAcc提供了定制优化:
IO优化
计算(Kernel)优化
显存优化
TorchAcc编译优化
接入分布式训练
接入TorchAcc的Compiler进行分布式训练,具体操作步骤如下:
固定随机种子。
通过固定随机种子,保证每个Worker的权重初始化结果一致,从而替代权重broadcast的作用。
将torch.manual_seed(SEED_NUMBER)替换为xm.set_rng_state(SEED_NUMBER)。
在获取xla_device后,调用set_replication、封装dataloader并设置model device placement。
# Get the XLA device and register it for replication.
device = xm.xla_device()
xm.set_replication(device, [device])
# Wrap the dataloaders so that batches are delivered on the XLA device.
data_loader_train = pl.MpDeviceLoader(data_loader_train, device)
data_loader_val = pl.MpDeviceLoader(data_loader_val, device)
# Place the model on the XLA device.
model = model.to(device)
分布式初始化。
将dist.init_process_group的backend参数配置为'xla':
# Initialize the process group with the "xla" backend (instead of e.g. "nccl").
dist.init_process_group(
    backend="xla",
    init_method="env://",
)
梯度allreduce通信。
在loss backward后对梯度进行allreduce操作:
# Collect the gradients of the optimizer's parameters and average them across
# all workers: sum-allreduce, then scale by 1 / world_size.
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
重要:如果使用混合精度(AMP)训练,且手动调用了scaler.unscale_,则必须在scaler.unscale_之前调用xm.all_reduce,以确保溢出检测基于all_reduce之后的梯度进行。
使用xlarun拉起任务。
# Start 8 worker processes on this node; the CLI mirrors torchrun's.
xlarun --nproc_per_node=8 YOUR_MODEL.py
说明:多机场景下的使用方法与torchrun相同。
接入混合精度
混合精度训练可以提高模型训练速度。在上一节(单卡训练或分布式训练)的基础上,按照以下步骤接入混合精度,完成TorchAcc编译优化下的AMP逻辑实现。
按照PyTorch原生功能实现AMP。
TorchAcc混合精度与PyTorch原生混合精度的使用方法基本一致,请先参照PyTorch官方的自动混合精度(AMP)文档实现PyTorch原生的AMP功能。
替换GradScaler。
将torch.cuda.amp.GradScaler替换为torchacc.torch_xla.amp.GradScaler:
from torchacc.torch_xla.amp import GradScaler
替换optimizer。
使用原生PyTorch optimizer性能会稍差,可将torch.optim的optimizer替换为syncfree optimizer来进一步提升训练速度。
from torchacc.torch_xla.amp import syncfree

# syncfree optimizers replace their torch.optim counterparts and, like them,
# need the parameters to optimize plus hyper-parameters such as lr.
# NOTE(review): AdamW/SGD are assumed to share torch.optim's signature — confirm.
# Pick the one that matches your original optimizer.
adam_optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
adamw_optimizer = syncfree.AdamW(model.parameters(), lr=1e-3)
sgd_optimizer = syncfree.SGD(model.parameters(), lr=1e-3)
目前syncfree optimizer只提供了以上三类optimizer的实现,其他类型的optimizer继续使用PyTorch原生optimizer即可。
接入案例
以Bert-base模型为例,代码示例如下:
import argparse
import os
import time
import torch
import torch.distributed as dist
from datasets import load_from_disk
from datetime import datetime as dt
from time import gmtime, strftime
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding
# Since PyTorch 1.12, TF32 for CUDA matmuls is disabled by default; re-enable it.
torch.backends.cuda.matmul.allow_tf32 = True
parser = argparse.ArgumentParser()
parser.add_argument("--amp-level", choices=["O0", "O1"], default="O1",
                    help="amp level (O0: no AMP, O1: AMP).")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile_folder", type=str, default="./profile_folder")
parser.add_argument("--dataset_path", type=str, default="./sst_data/train")
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--break_step_for_profiling", type=int, default=20)
parser.add_argument("--model_name", type=str, default="bert-base-cased")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--log-interval", type=int, default=10)
parser.add_argument('--max-steps', type=int, default=200, help='maximum number of training steps per epoch.')
args = parser.parse_args()
print("Job running args: ", args)
# LOCAL_RANK (presumably exported by the xlarun/torchrun launcher) overrides
# --local_rank. Cast to int because environment variables are strings.
args.local_rank = int(os.getenv("LOCAL_RANK", 0))
def enable_torchacc_compiler():
    """Return True when the TorchAcc compiler path is enabled.

    The path is enabled whenever the ``TORCHACC_COMPILER_OPT`` environment
    variable is set, to any value (even an empty string).
    """
    # TorchAcc: this flag switches the script between the XLA and CUDA paths.
    return os.getenv('TORCHACC_COMPILER_OPT') is not None
def print_rank_0(message):
    """Print ``message`` only on rank 0 when distributed is initialized.

    If torch.distributed is unavailable or not yet initialized, the message
    is printed unconditionally instead of raising.
    """
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not dist_ready or torch.distributed.get_rank() == 0:
        print(message, flush=True)
# Template for one training log line, filled positionally by print_test_update():
# timestamp, epoch, step, batch size, loss, time per step (s), samples/s,
# allocated memory (GiB).
train_format_string = (
    "{} | Epoch: {} | Step: {} | Batch size: {} | Loss: {} | "
    "Time/step: {} s | Samples/s: {} | Memory: {} GiB"
)


def print_test_update(epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem):
    """Print one formatted training log line through print_rank_0.

    All metric arguments are inserted as-is; log_metrics pre-formats the
    numeric ones as strings with three decimals.
    """
    # UTC wall-clock time. A distinct name avoids shadowing the module-level
    # ``dt`` (datetime) import.
    timestamp = strftime("%a, %d %b %Y %H:%M:%S", gmtime())
    print_rank_0(train_format_string.format(
        timestamp, epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem))
def log_metrics(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem):
    """Format numeric metrics to three decimals and emit one log line.

    With the TorchAcc compiler enabled, printing is registered through
    ``xm.add_step_closure(..., run_async=True)`` instead of being done
    immediately. This is presumably so the lazily evaluated XLA ``loss`` is
    only read once the step has executed. Otherwise the line is printed
    right away.
    """
    batch_time = f"{batch_time:.3f}"
    samples_per_step = f"{samples_per_step:.3f}"
    peak_mem = f"{peak_mem:.3f}"
    metrics = (epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem)
    if enable_torchacc_compiler():
        # TorchAcc: defer printing to a step closure.
        import torchacc.torch_xla.core.xla_model as xm
        xm.add_step_closure(print_test_update, args=metrics, run_async=True)
    else:
        print_test_update(*metrics)
if enable_torchacc_compiler():
    # TorchAcc: XLA-based model, data loading, distributed backend and AMP.
    import torchacc.torch_xla.core.xla_model as xm
    import torchacc.torch_xla.distributed.parallel_loader as pl
    # Imported for its side effects (presumably registers the "xla" backend).
    import torchacc.torch_xla.distributed.xla_backend  # noqa: F401
    from torchacc.torch_xla.amp import autocast, GradScaler, syncfree
    # Fixed seed so every worker initializes identical weights, replacing a
    # weight broadcast.
    xm.set_rng_state(101)
    dist.init_process_group(backend="xla", init_method="env://")
else:
    from torch.cuda.amp import autocast, GradScaler
    dist.init_process_group(backend="nccl", init_method="env://")
dist.barrier()
args.world_size = dist.get_world_size()
args.rank = dist.get_rank()
print("world size:", args.world_size, " rank:", args.rank, " local rank:", args.local_rank)
def get_autocast_and_scaler():
    """Return ``(autocast, scaler)`` for mixed-precision training.

    ``autocast`` and ``GradScaler`` are bound at module level to either the
    ``torchacc.torch_xla.amp`` implementations (TorchAcc compiler enabled) or
    the ``torch.cuda.amp`` ones. The same expression therefore serves both
    paths and no branching is needed here.
    """
    return autocast, GradScaler()
def loop_with_amp(model, inputs, optimizer, autocast, scaler):
    """Run one mixed-precision (AMP) training step.

    Args:
        model: Called as ``model(**inputs)``; its output must contain ``"loss"``.
        inputs: Mapping of keyword arguments for one batch.
        optimizer: Optimizer that owns the trained parameters.
        autocast: Autocast context-manager factory.
        scaler: GradScaler used to scale the loss and step the optimizer.

    Returns:
        ``(loss, optimizer)``, where ``loss`` is the unscaled loss tensor.
    """
    with autocast():
        outputs = model(**inputs)
        loss = outputs["loss"]
    scaler.scale(loss).backward()
    if enable_torchacc_compiler():
        # TorchAcc: average gradients across workers (sum scaled by 1/world_size).
        # This must run before scaler.step() and before any manual
        # scaler.unscale_(), so overflow detection sees the reduced gradients.
        # On the CUDA path the DDP wrapper handles gradient reduction instead.
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
    scaler.step(optimizer)
    scaler.update()
    return loss, optimizer
def loop_without_amp(model, inputs, optimizer):
    """Run one full-precision training step and return ``(loss, optimizer)``.

    ``model(**inputs)`` must return an output containing ``"loss"``.
    """
    outputs = model(**inputs)
    loss = outputs["loss"]
    loss.backward()
    if enable_torchacc_compiler():
        # TorchAcc: xm.optimizer_step presumably all-reduces the gradients
        # across replicas before stepping (as in torch_xla).
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
    return loss, optimizer
def full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=None):
    """Train ``model`` for one epoch, running at most ``args.max_steps`` steps.

    Args:
        epoch: Epoch index, used only for logging.
        model: Model whose output contains a ``"loss"`` entry.
        train_device_loader: Iterable of batches. On the TorchAcc path this is
            an ``MpDeviceLoader`` (which has no ``.dataset`` attribute).
        device: Device batches are moved to on the CUDA path.
        optimizer: Optimizer to step.
        autocast, scaler: AMP helpers, used only when ``args.amp_level == "O1"``.
        profiler: Optional torch profiler, stepped after every iteration.
    """
    model.train()
    iteration_time = time.time()
    # Number of steps completed when the previous log line was written, so the
    # per-step time is averaged over the steps actually run (only 1 at step 0).
    last_logged_step = 0
    for step, inputs in enumerate(train_device_loader):
        # `step` is 0-based: stop once exactly max_steps steps have run.
        if step >= args.max_steps:
            break
        if not enable_torchacc_compiler():
            # On the TorchAcc path the MpDeviceLoader presumably places batches already.
            inputs = inputs.to(device)
        optimizer.zero_grad()
        if args.amp_level == "O1":
            loss, optimizer = loop_with_amp(model, inputs, optimizer, autocast, scaler)
        else:
            loss, optimizer = loop_without_amp(model, inputs, optimizer)
        if args.profile and profiler:
            profiler.step()
        if step % args.log_interval == 0:
            steps_run = step + 1 - last_logged_step
            last_logged_step = step + 1
            time_elapsed = (time.time() - iteration_time) / steps_run
            iteration_time = time.time()
            # Throughput in samples per second across all workers.
            samples_per_step = float(args.batch_size / time_elapsed) * args.world_size
            # Currently allocated CUDA memory in GiB (not a peak, despite the name).
            peak_mem = torch.cuda.memory_allocated() / 1024.0 / 1024.0 / 1024.0
            log_metrics(epoch, step, args.batch_size, loss, time_elapsed, samples_per_step, peak_mem)
def train_bert():
    """Fine-tune ``args.model_name`` for sequence classification.

    Loads the model, tokenizer and the dataset at ``args.dataset_path``. It
    then places them either on the TorchAcc/XLA device (when
    ``enable_torchacc_compiler()``) or on the local CUDA device wrapped in
    DDP. Finally it runs ``args.num_epochs`` epochs, optionally under
    ``torch.profiler``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, cache_dir="./model")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.model_max_length = args.max_seq_length
    training_dataset = load_from_disk(args.dataset_path)
    collator = DataCollatorWithPadding(tokenizer)
    training_dataset = training_dataset.remove_columns(['text'])
    # Shard the data so each rank trains on a distinct subset. Shuffling is done
    # by the sampler, so the DataLoader itself must not shuffle.
    train_sampler = torch.utils.data.DistributedSampler(
        training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True)
    train_device_loader = torch.utils.data.DataLoader(
        training_dataset, batch_size=args.batch_size, collate_fn=collator,
        sampler=train_sampler, num_workers=4)
    if enable_torchacc_compiler():
        # TorchAcc: XLA device, replication and device-side data loading.
        device = xm.xla_device()
        xm.set_replication(device, [device])
        train_device_loader = pl.MpDeviceLoader(train_device_loader, device)
        model = model.to(device)
    else:
        device = torch.device(f"cuda:{args.local_rank}")
        torch.cuda.set_device(device)
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    if enable_torchacc_compiler() and args.amp_level == "O1":
        # TorchAcc: syncfree optimizer for the AMP path.
        optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    autocast, scaler = None, None
    if args.amp_level == "O1":
        autocast, scaler = get_autocast_and_scaler()
    if args.profile:
        with torch.profiler.profile(
                schedule=torch.profiler.schedule(wait=2, warmup=2, active=20),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(args.profile_folder)) as prof:
            for epoch in range(args.num_epochs):
                # Reshuffle differently (but consistently across ranks) every epoch.
                train_sampler.set_epoch(epoch)
                full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=prof)
    else:
        for epoch in range(args.num_epochs):
            train_sampler.set_epoch(epoch)
            full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler)
if __name__ == "__main__":
    # Run the training job only when executed as a script.
    train_bert()
TorchAcc定制优化
IO优化
Data Prefetcher
支持预先读取训练数据,且提供preprocess_fn参数支持数据预处理。
from torchacc.runtime.io.prefetcher import Prefetcher

data_loader = build_data_loader()
model = build_model()
optimizer = build_optimizer()
# Optional per-batch preprocessing hook (None presumably means no preprocessing).
preprocess_fn = None
# TorchAcc: wrap the dataloader so upcoming batches are prefetched.
prefetcher = Prefetcher(data_loader, preprocess_fn)
for step, samples in enumerate(prefetcher):
    # Clear gradients from the previous step so they do not accumulate.
    optimizer.zero_grad()
    loss = model(samples)
    loss.backward()
    # Prefetch to CPU first. Call after backward() and before optimizer.step().
    # At this point the CPU is mostly idle, waiting for the launched GPU kernels
    # (e.g. those launched by a CUDA graph) to finish, so use that time to load
    # the next input batch.
    prefetcher.prefetch_CPU()
    optimizer.step()
    # Prefetch to GPU. Call after optimizer.step().
    prefetcher.prefetch_GPU()
Pack Dataset
语言类数据集(例如文本句子、语音等)普遍存在样本长度不一的情况。为了提高计算效率,可以利用样本长短不一的特点,将多个样本打包在一起,组成固定shape的batch,从而减少padding零值的占比和batch数据的动态性,提高每个epoch的(分布式)训练效率。
pin memory
在定义dataloader时设置pin_memory=True,并适当增加num_workers。
计算优化
Kernel Fusion优化
支持以下几种优化方式:
FusedLayerNorm
# Drop-in fused kernel replacement for LayerNorm.
# Add these lines before `import torch`.
from torchacc.runtime import hooks

hooks.enable_fused_layer_norm()
FusedAdam
# Drop-in fused kernel replacement for Adam/AdamW.
# Add these lines before `import torch`.
from torchacc.runtime import hooks

hooks.enable_fused_adam()
QuickGelu
# Use QuickGelu in place of nn.GELU.
from torchacc.runtime.nn.quick_gelu import QuickGelu
fused_bias_dropout_add
# Fuse Dropout with element-wise ops such as the bias add.
# NOTE(review): dropout_add_fused is assumed to be exported from the same
# module as dropout_add_fused_train — confirm.
from torchacc.runtime.nn import dropout_add_fused_train, dropout_add_fused

if self.training:
    # train mode
    with torch.enable_grad():
        x = dropout_add_fused_train(x, to_add, drop_rate)
else:
    # inference mode
    x = dropout_add_fused(x, to_add, drop_rate)
WindowProcess
# WindowProcess kernels fuse Swin Transformer's shifted-window operations:
#   - WindowProcess: window cyclic shift + window partition
#   - WindowProcessReverse: window merge + reverse cyclic shift
from torchacc.runtime.nn.window_process import WindowProcess
from torchacc.runtime.nn.window_process import WindowProcessReverse

# Cyclic shift + window partition.
if not fused:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
else:
    x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)

# Window merge + reverse cyclic shift.
if not fused:
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
    x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
FusedSwinFmha
# Fuses the qk_result + relative_position_bias + mask + softmax part of
# Swin Transformer's MHA into one kernel.
from torchacc.runtime.nn.fmha import FusedSwinFmha

# NOTE(review): the result of .apply() is discarded here; it presumably should
# be assigned (e.g. to `attn`) unless the op works in place — confirm.
FusedSwinFmha.apply(attn, relative_pos_bias, attn_mask, batch_size, window_num, num_head, window_len)
nms/nms_normal/soft_nms/batched_soft_nms
# CUDA kernel implementations of four NMS operators:
# nms, nms_normal, soft_nms and batched_soft_nms.
from torchacc.runtime.nn.nms import nms, nms_normal
from torchacc.runtime.nn.nms import soft_nms, batched_soft_nms
Parallelized Kernel优化
DCN/DCNv2:
# The backward pass of dcn_v2_cuda is optimized with parallel computation.
from torchacc.runtime.nn.dcn_v2 import DCN, DCNv2
# NOTE(review): argument order (in/out channels, kernel_size, stride, padding,
# dilation, deformable_groups) is inferred from the names chi/cho — confirm.
self.conv = DCN(chi, cho, kernel_size, stride, padding, dilation, deformable_groups)
Multi-stream Kernel优化
利用多个CUDA stream并发地对一组输入执行同一函数,计算逻辑与mmdet.core.multi_apply函数相同。
import torch
from torchacc.runtime.utils.misc import multi_apply_multi_stream
from mmdet.core import multi_apply

# Set to False to fall back to mmdet's single-stream multi_apply.
enable_torchacc = True


def test_func(t1, t2, t3):
    """Apply a simple element-wise op to each of three tensors."""
    t1 = t1 * 2.0
    t2 = t2 + 2.0
    t3 = t3 - 2.0
    return (t1, t2, t3)


cuda = torch.device('cuda')
t1 = torch.empty((100, 1000), device=cuda).normal_(0.0, 1.0)
t2 = torch.empty((100, 1000), device=cuda).normal_(0.0, 2.0)
t3 = torch.empty((100, 1000), device=cuda).normal_(0.0, 3.0)
if enable_torchacc:
    # NOTE(review): the second argument (2) is presumably the number of CUDA streams — confirm.
    result = multi_apply_multi_stream(test_func, 2, t1, t2, t3)
else:
    result = multi_apply(test_func, t1, t2, t3)
显存优化
Gradient Checkpointing
import torchacc
# Apply TorchAcc's automatic gradient checkpointing; `model` is rebound to the
# object returned by auto_checkpoint.
model = torchacc.auto_checkpoint(model)