
Integrate a custom model with TorchAcc


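Alibaba Cloud PAI provides example models for several typical scenarios so that you can easily accelerate their training with TorchAcc. You can also integrate models you develop yourself with TorchAcc. This topic describes how to integrate TorchAcc into a custom model to improve the speed and efficiency of distributed training.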

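Background information

TorchAcc optimizations fall into the following two categories. Choose the one that fits your needs to improve model training speed and efficiency.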

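  • Compilation optimization

    TorchAcc can convert a PyTorch dynamic graph into a static graph, then optimize and compile that computation graph. It rewrites the graph into a more efficient form and uses a JIT compiler to compile it into more efficient code. This avoids some of the performance overhead of PyTorch's dynamic-graph execution and improves training speed and efficiency.

  • Custom optimization

    When a model contains features such as dynamic shapes, custom operators, or dynamic control flow, global compilation optimization cannot currently be applied to accelerate its distributed training. For these scenarios, TorchAcc provides custom optimizations:

    • IO optimization

    • Compute (kernel) optimization

    • GPU memory optimization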

TorchAcc compilation optimization

Integrate distributed training

To use the TorchAcc compiler for distributed training, perform the following steps:

  1. Fix the random seed.

    Fixing the random seed makes every worker initialize its weights identically, which replaces broadcasting the weights.

    torch.manual_seed(SEED_NUMBER)
    Replace with:
    xm.set_rng_state(SEED_NUMBER)
  2. After obtaining the XLA device, call set_replication, wrap the dataloaders, and place the model on the device.

    device = xm.xla_device()
    xm.set_replication(device, [device])
    
    # Wrap the dataloaders
    data_loader_train = pl.MpDeviceLoader(data_loader_train, device)
    data_loader_val = pl.MpDeviceLoader(data_loader_val, device)
    
    # Move the model to the device
    model.to(device)
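  3. Initialize the distributed process group.

    Set the backend parameter of dist.init_process_group to 'xla'. The 'xla' backend is registered by importing torchacc.torch_xla.distributed.xla_backend.

    import torchacc.torch_xla.distributed.xla_backend
    dist.init_process_group(backend='xla', init_method='env://')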
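  4. All-reduce the gradients.

    After loss.backward(), all-reduce the gradients. Summing with scale=1.0/world_size averages them across workers:

    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
    Important

    If you train with AMP and call scaler.unscale_ manually, you must call xm.all_reduce before scaler.unscale_, so that overflow detection runs on the all-reduced gradients.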

  5. Launch the job with xlarun.

    xlarun --nproc_per_node=8 YOUR_MODEL.py
    Note

    For multi-node jobs, use xlarun the same way as torchrun.
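Steps 1 and 4 together are what keep the replicas identical without a weight broadcast. The following pure-Python simulation illustrates the idea; it does not use TorchAcc or XLA, and init_weights, all_reduce_sum_scaled, and sgd_step are hypothetical helpers. all_reduce_sum_scaled mimics xm.all_reduce('sum', gradients, scale=1.0/world_size) by summing each gradient element across workers and multiplying by the scale.

```python
import random


def init_weights(seed, n):
    # Seeding every worker identically yields identical initial weights,
    # which is why no weight broadcast is needed.
    rng = random.Random(seed)
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


def all_reduce_sum_scaled(per_worker_grads, scale):
    # Sum each gradient element across workers, multiply by `scale`,
    # and give every worker its own copy of the same reduced result.
    reduced = [sum(vals) * scale for vals in zip(*per_worker_grads)]
    return [list(reduced) for _ in per_worker_grads]


def sgd_step(weights, grads, lr):
    # Plain SGD update applied locally on each worker.
    return [w - lr * g for w, g in zip(weights, grads)]


world_size, seed, lr = 4, 101, 0.1
replicas = [init_weights(seed, 3) for _ in range(world_size)]
# Each worker sees different data, so its local gradient differs (hypothetical values).
local_grads = [[float(rank + 1)] * 3 for rank in range(world_size)]
synced_grads = all_reduce_sum_scaled(local_grads, scale=1.0 / world_size)
replicas = [sgd_step(w, g, lr) for w, g in zip(replicas, synced_grads)]
```

Because every replica starts from the same seeded weights and applies the same averaged gradient, all replicas finish the step with identical weights. If the seeds differed, or each worker stepped on its own local gradient, the replicas would drift apart.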

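Integrate mixed precision

Mixed-precision training speeds up model training. On top of the single-GPU or distributed training setup from the previous section, perform the following steps to add AMP to TorchAcc compilation optimization.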

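  1. Implement AMP with native PyTorch.

    TorchAcc mixed precision is used in essentially the same way as native PyTorch mixed precision. First implement native PyTorch AMP by following the PyTorch automatic mixed precision documentation.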

  2. Replace GradScaler.

    Replace torch.cuda.amp.GradScaler with torchacc.torch_xla.amp.GradScaler:

    from torchacc.torch_xla.amp import GradScaler
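  3. Replace the optimizer.

    Native PyTorch optimizers are somewhat slower here. Replacing the torch.optim optimizer with a syncfree optimizer, which avoids a device-to-host synchronization when deciding whether to skip a step with overflowed gradients, further improves training speed.

    from torchacc.torch_xla.amp import syncfree
    
    # Use the syncfree counterpart of your torch.optim optimizer, for example:
    adam_optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
    adamw_optimizer = syncfree.AdamW(model.parameters(), lr=1e-3)
    sgd_optimizer = syncfree.SGD(model.parameters(), lr=1e-2)

    Currently, syncfree optimizers are only available for these three types. For other optimizers, keep using the native PyTorch implementation.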
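The GradScaler swap above and the Important note in step 4 of the distributed training section (all-reduce before scaler.unscale_) both revolve around dynamic loss scaling. The following pure-Python sketch is a simplified model of that mechanism, not TorchAcc's or PyTorch's implementation: ToyGradScaler and all_reduce_mean are hypothetical, and the default values are taken from torch.cuda.amp.GradScaler. It shows why the overflow check must run on the all-reduced gradients: an inf on one rank propagates through the sum, so every rank reaches the same skip decision.

```python
import math


def all_reduce_mean(per_worker_grads):
    # Sum each gradient element across workers, then scale by 1/world_size.
    world_size = len(per_worker_grads)
    return [sum(vals) / world_size for vals in zip(*per_worker_grads)]


class ToyGradScaler:
    """Simplified dynamic loss scaling (defaults mirror torch.cuda.amp.GradScaler)."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0

    def unscale(self, grads):
        # Undo the loss scaling that was applied before backward.
        return [g / self.scale for g in grads]

    @staticmethod
    def has_overflow(grads):
        # Any inf or NaN means the optimizer step must be skipped.
        return any(not math.isfinite(g) for g in grads)

    def update(self, found_inf):
        # Shrink the scale after an overflow; grow it after
        # `growth_interval` consecutive overflow-free steps.
        if found_inf:
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self.scale *= self.growth_factor
                self._growth_tracker = 0


# Hypothetical scaled gradients: rank 0 overflowed, rank 1 did not.
local_grads = [[float("inf"), 1.0], [2.0, 3.0]]

# Wrong order: each rank checks its own gradients before the all-reduce,
# so the ranks disagree on whether to skip the step.
local_decisions = [ToyGradScaler.has_overflow(g) for g in local_grads]

# Right order: all-reduce first, so every rank checks the same gradients.
reduced = all_reduce_mean(local_grads)
synced_decisions = [ToyGradScaler.has_overflow(reduced) for _ in local_grads]
```

With the wrong order, rank 0 skips its update while rank 1 applies one, and the replicas diverge; with the right order, both ranks skip.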

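Integration example

The following example uses the BERT-base model. Lines prefixed with + are the changes for TorchAcc. The TorchAcc path is taken only when the TORCHACC_COMPILER_OPT environment variable is set (see enable_torchacc_compiler); otherwise the script runs with native CUDA, NCCL, and DistributedDataParallel.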

import argparse
import os
import time
import torch
import torch.distributed as dist

from datasets import load_from_disk
from datetime import datetime as dt
from time import gmtime, strftime
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

# TF32 matmul is disabled by default since PyTorch 1.12; enable it explicitly.
torch.backends.cuda.matmul.allow_tf32 = True

parser = argparse.ArgumentParser()
parser.add_argument("--amp-level", choices=["O0", "O1"], default="O1", help="O1: mixed precision (AMP); O0: no AMP.")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile_folder", type=str, default="./profile_folder")
parser.add_argument("--dataset_path", type=str, default="./sst_data/train")
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--break_step_for_profiling", type=int, default=20)
parser.add_argument("--model_name", type=str, default="bert-base-cased")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--log-interval", type=int, default=10)
parser.add_argument('--max-steps', type=int, default=200, help='maximum number of training steps per epoch.')
args = parser.parse_args()
print("Job running args: ", args)
args.local_rank = int(os.getenv("LOCAL_RANK", 0))


+def enable_torchacc_compiler():
+    return os.getenv('TORCHACC_COMPILER_OPT') is not None


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.get_rank() == 0:
        print(message, flush=True)


train_format_string = ("[{}] epoch: {}, step: {}, batch size: {}, loss: {:.4f}, "
                       "step time: {}s, samples/s: {}, peak mem: {}GB")


def print_test_update(epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem):
  # Current date and time (UTC)
  now = strftime("%a, %d %b %Y %H:%M:%S", gmtime())
  print_rank_0(train_format_string.format(now, epoch, step, batch_size, loss.item(), time_elapsed, samples_per_step, peak_mem))


def log_metrics(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem):
    batch_time = f"{batch_time:.3f}"
    samples_per_step = f"{samples_per_step:.3f}"
    peak_mem = f"{peak_mem:.3f}"
+    if enable_torchacc_compiler():
+        import torchacc.torch_xla.core.xla_model as xm
+        xm.add_step_closure(
+            print_test_update, args=(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem), run_async=True)
+    else:
        print_test_update(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem)


+if enable_torchacc_compiler():
+  import torchacc.torch_xla.core.xla_model as xm
+  import torchacc.torch_xla.distributed.parallel_loader as pl
+  import torchacc.torch_xla.distributed.xla_backend
+  from torchacc.torch_xla.amp import autocast, GradScaler, syncfree
+  xm.set_rng_state(101)
+  dist.init_process_group(backend="xla", init_method="env://")
+else:
  from torch.cuda.amp import autocast, GradScaler
  dist.init_process_group(backend="nccl", init_method="env://")

dist.barrier()
args.world_size = dist.get_world_size()
args.rank = dist.get_rank()
print("world size:", args.world_size, " rank:", args.rank, " local rank:", args.local_rank)


def get_autocast_and_scaler():
+  if enable_torchacc_compiler():
+    return autocast, GradScaler()

  return autocast, GradScaler()


def loop_with_amp(model, inputs, optimizer, autocast, scaler):
  with autocast():
    outputs = model(**inputs)
    loss = outputs["loss"]

  scaler.scale(loss).backward()
+  if enable_torchacc_compiler():
+    gradients = xm._fetch_gradients(optimizer)
+    xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
  scaler.step(optimizer)
  scaler.update()

  return loss, optimizer


def loop_without_amp(model, inputs, optimizer):
  outputs = model(**inputs)
  loss = outputs["loss"]
  loss.backward()
+  if enable_torchacc_compiler():
+    xm.optimizer_step(optimizer)
+  else:
    optimizer.step()
  return loss, optimizer


def full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=None):
  model.train()

  iteration_time = time.time()
  for step, inputs in enumerate(train_device_loader):
    if step >= args.max_steps:
      break
+    if not enable_torchacc_compiler():
      inputs.to(device)

    optimizer.zero_grad()

    if args.amp_level == "O1":
      loss, optimizer = loop_with_amp(model, inputs, optimizer, autocast, scaler)
    else:
      loss, optimizer = loop_without_amp(model, inputs, optimizer)

    if args.profile and profiler:
      profiler.step()

    if step % args.log_interval == 0:
      time_elapsed = (time.time() - iteration_time) / args.log_interval
      iteration_time = time.time()
      samples_per_step = float(args.batch_size / time_elapsed) * args.world_size
      peak_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
      log_metrics(epoch, step, args.batch_size, loss, time_elapsed, samples_per_step, peak_mem)


def train_bert():
  model = AutoModelForSequenceClassification.from_pretrained(args.model_name, cache_dir="./model")
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  tokenizer.model_max_length = args.max_seq_length

  training_dataset = load_from_disk(args.dataset_path)
  collator = DataCollatorWithPadding(tokenizer)
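  training_dataset = training_dataset.remove_columns(['text'])
  # Give each rank a distinct shard of the training data.
  sampler = torch.utils.data.DistributedSampler(
      training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True)
  train_device_loader = torch.utils.data.DataLoader(
      training_dataset, batch_size=args.batch_size, collate_fn=collator, sampler=sampler, num_workers=4)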

+  if enable_torchacc_compiler():
+    device = xm.xla_device()
+    xm.set_replication(device, [device])
+    train_device_loader = pl.MpDeviceLoader(train_device_loader, device)
+    model = model.to(device)
+  else:
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)


+  if enable_torchacc_compiler() and args.amp_level == "O1":
+    optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
+  else:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  autocast, scaler = None, None
  if args.amp_level == "O1":
    autocast, scaler = get_autocast_and_scaler()

  if args.profile:
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=2, warmup=2, active=20),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(args.profile_folder)) as prof:
      for epoch in range(args.num_epochs):
        full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=prof)
  else:
    for epoch in range(args.num_epochs):
      full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler)


if __name__ == "__main__":
  train_bert()

TorchAcc custom optimizations

IO optimization

Data Prefetcher

Prefetches training data in advance and accepts a preprocess_fn parameter for data preprocessing.

+ from torchacc.runtime.io.prefetcher import Prefetcher

# build_data_loader, build_model, and build_optimizer stand for your own setup code.
data_loader = build_data_loader()
model = build_model()
optimizer = build_optimizer()

# Define a preprocessing function (none in this example)
preprocess_fn = None

+ prefetcher = Prefetcher(data_loader, preprocess_fn)

for step, samples in enumerate(prefetcher):
    optimizer.zero_grad()
    loss = model(samples)
    loss.backward()

    # Prefetch to CPU first. Call after backward and before the update.
    # At this point the CPU is idle while it waits for the kernels launched
    # by the CUDA graph to finish, so use that time to load the next
    # input batch before calling step.
+    prefetcher.prefetch_CPU()

    optimizer.step()

    # Prefetch to GPU. Call after the optimizer step.
+    prefetcher.prefetch_GPU()
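The overlap the Prefetcher relies on, loading the next batch while the current step is still running, can be illustrated with a minimal standard-library sketch. This is not TorchAcc's Prefetcher: ThreadPrefetcher and its depth parameter are hypothetical, and the sketch only models CPU-side prefetching with a background thread and a bounded queue, not host-to-GPU copies.

```python
import queue
import threading


class ThreadPrefetcher:
    """Yield items from `source` while a background thread loads up to `depth` items ahead."""

    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self, source, preprocess_fn=None, depth=2):
        self._queue = queue.Queue(maxsize=depth)
        self._preprocess_fn = preprocess_fn
        self._error = None
        self._thread = threading.Thread(
            target=self._worker, args=(iter(source),), daemon=True)
        self._thread.start()

    def _worker(self, it):
        try:
            for item in it:
                if self._preprocess_fn is not None:
                    item = self._preprocess_fn(item)
                # Blocks when `depth` items are already waiting, bounding memory use.
                self._queue.put(item)
        except Exception as exc:  # surface loader errors in the consuming loop
            self._error = exc
        finally:
            self._queue.put(self._DONE)

    def __iter__(self):
        while True:
            item = self._queue.get()
            if item is self._DONE:
                if self._error is not None:
                    raise self._error
                return
            yield item
```

It is iterated like the original loader, for example `for step, batch in enumerate(ThreadPrefetcher(loader)):`. The bounded queue keeps the background thread from running arbitrarily far ahead of training.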

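Pack Dataset

Language datasets usually contain variable-length samples, such as text sentences or speech. Because sample lengths differ, several samples can be packed together into one fixed-shape batch. This reduces the fraction of zero padding and the dynamism of batch shapes, which improves per-epoch (distributed) training efficiency.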
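A minimal sketch of one way to pack samples, assuming each sample is a list of token IDs. It uses greedy first-fit-decreasing bin packing, a swapped-in heuristic; this document does not say which packing algorithm TorchAcc uses. pack_sequences and its segment_ids output are hypothetical.

```python
def pack_sequences(samples, max_len, pad_id=0):
    """Pack variable-length token lists into fixed-length rows (first-fit decreasing)."""
    bins, bin_lens = [], []
    # Placing longer samples first tends to leave less unused space.
    for sample in sorted(samples, key=len, reverse=True):
        if len(sample) > max_len:
            raise ValueError(f"sample of length {len(sample)} exceeds max_len={max_len}")
        for i, used in enumerate(bin_lens):
            if used + len(sample) <= max_len:
                bins[i].append(sample)
                bin_lens[i] += len(sample)
                break
        else:
            # No existing row has room: start a new one.
            bins.append([sample])
            bin_lens.append(len(sample))

    rows, segment_ids = [], []
    for packed in bins:
        row = [tok for s in packed for tok in s]
        # Segment IDs number the samples in a row from 1; 0 marks padding.
        seg = [k + 1 for k, s in enumerate(packed) for _ in s]
        pad = max_len - len(row)
        rows.append(row + [pad_id] * pad)
        segment_ids.append(seg + [0] * pad)
    return rows, segment_ids
```

For example, pack_sequences([[1, 2, 3], [4, 5], [6], [7, 8, 9, 10], [11]], max_len=5) produces three rows, [7, 8, 9, 10, 6], [1, 2, 3, 4, 5], and [11, 0, 0, 0, 0], with 4 padding tokens, whereas padding each of the five samples to length 4 would need 9. A model that consumes packed rows typically needs segment_ids, or an equivalent attention mask and position IDs, so that packed samples do not attend to each other.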

pin memory

Set pin_memory=True when defining the DataLoader, and increase num_workers as appropriate.
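A minimal sketch of such a DataLoader with placeholder values. build_loader and to_device are hypothetical helpers. The design choice here is to enable pin_memory only when CUDA is available, since page-locked host memory only helps host-to-GPU copies, and to pair it with non_blocking=True copies, which can overlap with computation only when the source tensor is pinned.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(dataset, batch_size=32, num_workers=4):
    use_pinned = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,  # worker processes load and collate batches in parallel
        pin_memory=use_pinned,    # collate batches into page-locked (pinned) host memory
    )


def to_device(batch, device):
    # non_blocking=True lets a pinned CPU tensor be copied to the GPU asynchronously.
    return [t.to(device, non_blocking=True) for t in batch]


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for x, y in build_loader(ds, num_workers=2):
        x, y = to_device((x, y), device)
```

Increase num_workers gradually until data loading no longer stalls the training step; each extra worker costs CPU and host memory.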

Compute optimization

Kernel Fusion optimization

The following optimizations are supported:

  • FusedLayerNorm

    # Equivalent replacement kernel for LayerNorm
    from torchacc.runtime import hooks
    # add before import torch
    hooks.enable_fused_layer_norm()
  • FusedAdam

    # Equivalent replacement kernel for Adam/AdamW
    from torchacc.runtime import hooks
    # add before import torch
    hooks.enable_fused_adam()
  • QuickGelu

    # Replace nn.GELU with QuickGelu
    from torchacc.runtime.nn.quick_gelu import QuickGelu
  • fused_bias_dropout_add

    # Fuse Dropout with the element-wise (bias) add
    from torchacc.runtime.nn import dropout_add_fused_train, dropout_add_fused

    if self.training:
        # Training mode
        with torch.enable_grad():
            x = dropout_add_fused_train(x, to_add, drop_rate)
    else:
        # Inference mode
        x = dropout_add_fused(x, to_add, drop_rate)
  • WindowProcess

    # WindowProcess fuses the shift-window and window-partition steps of Swin Transformer:
    # window cyclic shift + window partition, and window merge + reverse cyclic shift.
    from torchacc.runtime.nn.window_process import WindowProcess
    
    if not fused:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
    
    from torchacc.runtime.nn.window_process import WindowProcessReverse
    
    if not fused:
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
  • FusedSwinFmha

    # Fuses the qk_result + relative_position_bias + mask + softmax part of MHA in Swin Transformer
    from torchacc.runtime.nn.fmha import FusedSwinFmha
    
    FusedSwinFmha.apply(attn, relative_pos_bias, attn_mask, batch_size, window_num,
                        num_head, window_len)
  • nms/nms_normal/soft_nms/batched_soft_nms

    # Fused CUDA kernel implementations of four operators: nms, nms_normal, soft_nms, and batched_soft_nms.
    from torchacc.runtime.nn.nms import nms, nms_normal
    from torchacc.runtime.nn.nms import soft_nms, batched_soft_nms
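Unlike FusedLayerNorm and FusedAdam, which are listed above as equivalent replacements, QuickGelu is usually an approximation of GELU. Assuming TorchAcc's QuickGelu uses the common definition x * sigmoid(1.702 * x) (as in CLIP's QuickGELU; this document does not confirm the formula), the following standard-library sketch measures how far it deviates from the exact GELU that nn.GELU() computes with its default settings. gelu_exact and quick_gelu are hypothetical reference functions.

```python
import math


def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def _sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def quick_gelu(x):
    # Assumed QuickGelu definition: x * sigmoid(1.702 * x).
    return x * _sigmoid(1.702 * x)


# Largest absolute deviation on a grid over [-6, 6].
grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - quick_gelu(x)) for x in grid)
```

The deviation is small but non-zero (on the order of 1e-2 on this range), so validate model accuracy after swapping nn.GELU for QuickGelu.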
    

Parallelized Kernel optimization

DCN/DCNv2:

# Parallelizes the backward computation of dcn_v2_cuda.
from torchacc.runtime.nn.dcn_v2 import DCN, DCNv2

self.conv = DCN(chi, cho, kernel_size, stride, padding, dilation, deformable_groups)

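Multi-stream Kernel optimization

Uses multiple CUDA streams to compute a function over a group of inputs concurrently. The computation logic is the same as mmdet.core.multi_apply.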

import torch
from torchacc.runtime.utils.misc import multi_apply_multi_stream
from mmdet.core import multi_apply

# Set to False to fall back to mmdet's sequential multi_apply.
enable_torchacc = True

def test_func(t1, t2, t3):
  t1 = t1 * 2.0
  t2 = t2 + 2.0
  t3 = t3 - 2.0
  return (t1, t2, t3)

cuda = torch.device('cuda')
t1 = torch.empty((100, 1000), device=cuda).normal_(0.0, 1.0)
t2 = torch.empty((100, 1000), device=cuda).normal_(0.0, 2.0)
t3 = torch.empty((100, 1000), device=cuda).normal_(0.0, 3.0)

if enable_torchacc:
    result = multi_apply_multi_stream(test_func, 2, t1, t2, t3)
else:
    result = multi_apply(test_func, t1, t2, t3)
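Since multi_apply_multi_stream follows mmdet.core.multi_apply's computation logic, it helps to see that contract in isolation. In the standard-library sketch below, multi_apply mirrors mmdet's implementation: it maps the function over the zipped arguments, then transposes the per-call result tuples into per-output lists. multi_apply_concurrent is a hypothetical stand-in that runs the calls on a thread pool to show that concurrency must not change the result order. Threads do not model GPU-side stream concurrency, and its num_workers parameter is not claimed to match the meaning of the 2 passed to multi_apply_multi_stream.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def multi_apply(func, *args, **kwargs):
    # Call func on each zipped group of arguments, then transpose the
    # per-call tuples into one list per output.
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def multi_apply_concurrent(func, num_workers, *args, **kwargs):
    # Same contract, but the calls run concurrently on a thread pool.
    # Executor.map returns results in input order, so the output matches multi_apply.
    pfunc = partial(func, **kwargs) if kwargs else func
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        map_results = list(pool.map(pfunc, *args))
    return tuple(map(list, zip(*map_results)))
```

Any concurrent variant, stream-based or not, is only a drop-in replacement if it preserves this ordering, which is what makes the enable_torchacc switch in the example above safe.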

GPU memory optimization

Gradient Checkpointing

import torchacc

model = torchacc.auto_checkpoint(model)
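torchacc.auto_checkpoint applies gradient checkpointing to the model automatically. To make the underlying trade-off concrete, the sketch below applies the same technique by hand with PyTorch's native torch.utils.checkpoint.checkpoint; it is not how auto_checkpoint is implemented, and Block, Model, and use_checkpoint are hypothetical. Checkpointed blocks do not keep their intermediate activations after the forward pass and recompute them during backward, trading extra compute for lower activation memory while leaving gradients unchanged. The use_reentrant argument requires a recent PyTorch version.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)


class Model(nn.Module):
    def __init__(self, dim=32, depth=4, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are recomputed during backward instead of stored.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Checkpointing every block, as here, minimizes activation memory at the cost of roughly one extra forward pass per block; checkpointing only some blocks trades memory for speed.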