This topic describes how to use Pai-Megatron-Patch to optimize the training of PyTorch Transformer models.
Background information
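All experiments below were run on an Alibaba Cloud ECS instance with the following configuration: instance type ecs.gn6e-c12g1.12xlarge, 48 CPU cores, 368 GiB of memory, 4 × NVIDIA V100 GPUs, 64-bit Ubuntu 18.04, image ID ubuntu_18_04_x64_20G_alibase_20211227.vhd, and a peak bandwidth of 100 Mbps. The nvidia-smi output below confirms that driver version 440.64.00 and CUDA 10.2 are in place and that all four Tesla V100-SXM2 GPUs are idle.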
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:07.0 Off |                    0 |
| N/A   32C    P0    41W / 300W |      0MiB / 32510MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:08.0 Off |                    0 |
| N/A   31C    P0    41W / 300W |      0MiB / 32510MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:09.0 Off |                    0 |
| N/A   30C    P0    39W / 300W |      0MiB / 32510MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:0A.0 Off |                    0 |
| N/A   31C    P0    40W / 300W |      0MiB / 32510MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
Mixed precision
Experiment environment: pre-training an English Hugging Face BERT model
- num-layers 12
- hidden-size 768
- num-attention-heads 12
- num-params 110106428
- local-rank 4
- seq-length 512
- micro-batch-size 16
- global-batch-size 64
| Scheme | Throughput (samples/s) | Peak memory (MB) |
| --- | --- | --- |
| Single-precision training | 103.07 +/- 1.03 | 17025 |
| Mixed-precision training | 178.15 +/- 2.10 | 12698 |
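The batch settings listed for this experiment are tied together by the usual data-parallel relation: global batch size = micro batch size × number of data-parallel replicas × gradient accumulation steps. The sketch below works that relation out for the values above. It assumes that `local-rank 4` denotes four data-parallel GPU workers, which the source does not state explicitly, and `accumulation_steps` is a hypothetical helper name used only for illustration.

```python
def accumulation_steps(global_batch_size: int,
                       micro_batch_size: int,
                       data_parallel_size: int) -> int:
    """Return how many micro-batches each replica processes per optimizer step."""
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError(
            "global batch size must be a multiple of "
            "micro batch size x data-parallel size"
        )
    return global_batch_size // samples_per_pass


# Values from the BERT pre-training configuration above.
# Assumption: the 4 GPUs act as 4 data-parallel replicas.
steps = accumulation_steps(global_batch_size=64,
                           micro_batch_size=16,
                           data_parallel_size=4)
print(f"gradient accumulation steps: {steps}")
```

Under that assumption, one forward/backward pass per GPU already covers the whole global batch, so no gradient accumulation is needed in this configuration.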
Distributed memory optimization: model state sharding
Experiment environment: pre-training an English Megatron GPT model
- num-layers 24
- hidden-size 2048
- num-attention-heads 32
- num-params 1313722368 (1.3 billion)
- local-rank 4
- seq-length 1024
- micro-batch-size 1
- global-batch-size 4
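Training with PyTorch's native distributed data parallel (DDP) runs out of memory (OOM). The root cause is that the model does not fit on a 32 GB GPU: the Adam optimizer states alone consume about 16 GB of GPU memory.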
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/megatron/model/language_model.py", line 351, in forward
    encoder_output = self.encoder(encoder_input,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/megatron/model/transformer.py", line 703, in forward
    hidden_states = layer(hidden_states,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/megatron/model/transformer.py", line 441, in forward
    self.self_attention(layernorm_output,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/megatron/model/transformer.py", line 264, in forward
    matmul_result = torch.baddbmm(
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 1; 31.75 GiB total capacity; 28.56 GiB already allocated; 84.00 MiB free; 30.19 GiB reserved in total by PyTorch)
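The roughly 16 GB attributed to the Adam optimizer states can be sanity-checked with simple arithmetic. The sketch below is one possible accounting, not a figure taken from the source: it assumes the optimizer keeps three fp32 tensors per parameter (a copy of the weights, the momentum, and the variance), which is 12 bytes per parameter. Gradients and activations are not included, and the helper names are hypothetical.

```python
NUM_PARAMS = 1_313_722_368  # num-params from the GPT configuration above

# Assumption (not stated in the source): Adam holds an fp32 weight copy,
# momentum, and variance, i.e. 4 + 4 + 4 bytes per parameter.
ADAM_BYTES_PER_PARAM = 12


def state_bytes(num_params: int, bytes_per_param: int) -> int:
    """Total bytes needed to hold the given per-parameter state."""
    return num_params * bytes_per_param


def to_gb(num_bytes: float) -> float:
    """Convert bytes to decimal gigabytes (10**9 bytes)."""
    return num_bytes / 1e9


def to_gib(num_bytes: float) -> float:
    """Convert bytes to binary gibibytes (2**30 bytes)."""
    return num_bytes / 2**30


adam_bytes = state_bytes(NUM_PARAMS, ADAM_BYTES_PER_PARAM)
print(f"Adam states: {to_gb(adam_bytes):.1f} GB ({to_gib(adam_bytes):.1f} GiB)")
```

Under this assumption the estimate comes to about 15.8 GB, which is in line with the figure quoted above and leaves limited room for weights, gradients, and activations on a 32 GB card.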
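| Scheme | Throughput (samples/s) | Peak memory (MB) |
| --- | --- | --- |
| No acceleration | OOM | OOM |
| Mixed-precision training | 9.57 +/- 0.26 | 25061 |
| Mixed-precision training + OSS model state sharding | 6.02 +/- 0.06 | 22077 |
| Mixed-precision training + OSS/SDP model state sharding | 7.01 +/- 0.07 | 17113 |
| Mixed-precision training + FSDP model state sharding | NA | NA |
| Mixed-precision training + ZeRO-1 | 12.88 +/- 0.10 | 15709 |
| Mixed-precision training + ZeRO-2 | 10.27 +/- 0.08 | 15693 |
| Mixed-precision training + ZeRO-3 | NA | NA |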
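To see why sharding model states lowers memory, it helps to estimate the per-GPU model-state footprint for each ZeRO stage. The sketch below follows the commonly cited accounting for mixed-precision Adam training: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer states, with stage 1 sharding the optimizer states, stage 2 also sharding gradients, and stage 3 also sharding weights. It assumes that the 4 GPUs form one data-parallel group and that the ZeRO-1/2/3 rows map to these stages. The function name is hypothetical, and the results are theoretical model-state sizes, not the measured peak memory in the table (which also includes activations and buffers).

```python
def zero_model_state_bytes(num_params: int,
                           dp_size: int,
                           stage: int,
                           optimizer_bytes_per_param: int = 12) -> float:
    """Estimate per-GPU bytes for weights, gradients, and optimizer states."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    param_bytes = 2.0 * num_params          # fp16 weights
    grad_bytes = 2.0 * num_params           # fp16 gradients
    opt_bytes = float(optimizer_bytes_per_param) * num_params
    if stage >= 1:
        opt_bytes /= dp_size                # stage 1: shard optimizer states
    if stage >= 2:
        grad_bytes /= dp_size               # stage 2: also shard gradients
    if stage >= 3:
        param_bytes /= dp_size              # stage 3: also shard weights
    return param_bytes + grad_bytes + opt_bytes


NUM_PARAMS = 1_313_722_368  # num-params from the GPT configuration
for stage in range(4):
    gb = zero_model_state_bytes(NUM_PARAMS, dp_size=4, stage=stage) / 1e9
    print(f"stage {stage}: {gb:.2f} GB of model states per GPU")
```

The estimate only explains the direction of the memory savings; throughput differences between the schemes depend on communication patterns that this sketch does not model.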
3D hybrid parallelism
Experiment environment: pre-training an English Megatron GPT model
- num-layers 24
- hidden-size 2048
- num-attention-heads 32
- num-params 1313722368 (1.3 billion)
- local-rank 4
- seq-length 1024
- micro-batch-size 1
- global-batch-size 4
With mixed-precision training enabled:
| Operator splitting (tensor parallelism) | Pipeline parallelism | Throughput (samples/s) | Peak memory (MB) |
| --- | --- | --- | --- |
| 1 | 1 | 9.63 +/- 0.29 | 25061 |
| 2 | 1 | 7.59 +/- 0.14 | 11300 |
| 4 | 1 | 6.16 +/- 0.06 | 5673 |
| 1 | 2 | 8.46 +/- 0.17 | 12375 |
| 1 | 4 | 8.03 +/- 0.12 | 8141 |
| 2 | 2 | 7.37 +/- 0.11 | 6211 |
| 4 | 4 | 6.24 +/- 0.08 | 5673 |
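In 3D parallelism the GPUs are divided among three dimensions, so the data-parallel degree is whatever remains after operator splitting and pipeline parallelism take their share: world size = tensor-parallel size × pipeline-parallel size × data-parallel size. The sketch below illustrates that relation. It assumes that the "operator splitting" column is the tensor-parallel size and that the world size is the 4 GPUs of the test machine; `data_parallel_size` is a hypothetical helper name.

```python
def data_parallel_size(world_size: int,
                       tensor_parallel: int,
                       pipeline_parallel: int) -> int:
    """Return the data-parallel degree left over for a given layout."""
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            "world size must be a multiple of "
            "tensor-parallel size x pipeline-parallel size"
        )
    return world_size // model_parallel


WORLD_SIZE = 4  # assumption: the four V100 GPUs of the test instance

# A few (tensor-parallel, pipeline-parallel) layouts from the table.
for tp, pp in [(1, 1), (2, 1), (1, 2), (2, 2)]:
    dp = data_parallel_size(WORLD_SIZE, tp, pp)
    print(f"tensor={tp} pipeline={pp} -> data-parallel={dp}")
```

Raising either model-parallel degree shrinks the per-GPU share of the model, which matches the falling peak memory in the table, at the cost of fewer data-parallel replicas and more inter-GPU communication.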
ORT computation graph optimization
Experiment environment: fine-tuning an English Hugging Face BERT model
- num-layers 12
- hidden-size 768
- num-attention-heads 12
- num-params 110106428
- local-rank 4
- seq-length 512
- micro-batch-size 16
- global-batch-size 64
ORT computation graph optimization improves throughput by about 15.6% over single-precision training:
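| Scheme | Throughput (samples/s) | Peak memory (MB) |
| --- | --- | --- |
| Single-precision training | 479.15 +/- 1.67 | 2112 |
| Mixed-precision training | 589.66 +/- 4.79 | 2127 |
| ORT computation graph optimization | 554.24 +/- 1.98 | 2430 |
| ORT + mixed precision | 614.70 +/- 8.69 | 2289 |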
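The relative gains in this table can be reproduced directly from the mean throughput values. The sketch below computes each scheme's improvement over single-precision training. It uses only the means and ignores the reported +/- spread, and `relative_gain_percent` is a hypothetical helper name.

```python
BASELINE = 479.15  # single-precision training, samples/s

# Mean throughput values copied from the table above.
THROUGHPUT = {
    "Mixed-precision training": 589.66,
    "ORT computation graph optimization": 554.24,
    "ORT + mixed precision": 614.70,
}


def relative_gain_percent(value: float, baseline: float) -> float:
    """Return the throughput improvement over the baseline, in percent."""
    return (value / baseline - 1.0) * 100.0


for scheme, samples_per_s in THROUGHPUT.items():
    gain = relative_gain_percent(samples_per_s, BASELINE)
    print(f"{scheme}: +{gain:.1f}% over single precision")
```

The ORT-only row is the one that corresponds to the improvement of about 15.6% stated for this experiment; the combined ORT and mixed-precision row gives the largest gain of the three.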