Amazon SageMaker新模型并行库进一步加速PyTorch FSDP工作负载
关键要点
加速性能:Amazon SageMaker模型并行库现在能够将PyTorch FSDP工作负载的训练速度提高多达20。强化用户体验:与开源PyTorch完全对接,降低使用门槛。功能扩展:新增张量并行性,支持在大型集群上训练超过千亿参数的模型。优化技术:通过混合分片、SMDDP 和激活离线等技术,加快模型训练,提高效率。在过去一年中,随着Llama 2、Falcon和Mistral等流行大型语言模型LLM的发布,LLM的训练迅速走红。现在,客户正在预训练和微调参数从10亿到超过1750亿的LLM,以优化各行业从医疗到金融和市场营销的应用性能。
在如此大规模的模型训练中,性能优化成为了一个重大挑战。高效的LLM可能需要数TB的训练数据,以及数千甚至数百万小时的加速计算时间,以实现目标精度。因此,客户依赖并行技术将庞大的工作负载分配到数千个加速器设备上。然而,这些并行技术难以使用:不同的技术和库仅与特定的工作负载兼容,或仅限于特定模型架构,训练性能对配置极为敏感,而目前的技术发展迅速。因此,机器学习从业者往往需要花费数周的时间来准备将LLM工作负载扩展到大型GPU集群。
在本篇文章中,我们将详细介绍Amazon SageMaker模型并行SMP库的新特性,这些特性可以简化大型模型训练过程,帮助您更快地训练LLM。具体来说,我们将介绍SMP库的新简化用户体验,建立在开源PyTorch完全分片数据并行FSDPAPI之上,扩展张量并行功能,支持训练包含数百亿参数的模型,并通过优化技术使模型训练时间和成本降低多达20。
有关SageMaker模型并行库的更多信息,请参考SageMaker模型并行主义库文档。您还可以查看我们的示例笔记本以快速上手。
简化和加速大型模型训练的新特性
本文讨论SageMaker模型并行库版本20中的最新特性。这些特性提升了库的可用性,扩展了功能,加快了训练速度。在接下来的部分中,我们将总结这些新功能,并讨论如何利用该库加速您的大型模型训练。
云梯加速器试用将SMP与开源PyTorch对接
自2020年推出以来,SMP库已在SageMaker计算实例上实现高性能的大规模训练。此次SMP库的最新版本简化了用户体验,将其API与开源PyTorch对齐。
PyTorch提供完全分片数据并行FSDP作为支持大规模训练工作负载的主要方式。如以下代码片段所示,SMP更新后的API与PyTorch的分片数据并行API相似。您只需运行 import torchsagemaker 即可替代 torch 使用。
python
trainingscriptpy
import torchsagemaker as tsmtsminit()
设置一个PyTorch模型
model =
使用PyTorch FSDP模块包装PyTorch模型
model = FSDP( model )
optimizer =
通过这些API更新,您可以在不重塑现有PyTorch FSDP训练脚本的情况下,实现SageMaker和SMP库的性能优势。这种范式还允许您在本地和SageMaker上使用相同的代码基础,简化了在多个环境中进行训练的用户体验。
有关如何使用现有的PyTorch FSDP训练脚本启用SMP的更多信息,请参考开始使用SMP。
集成张量并行性以支持在大型集群上训练
本次SMP更新还扩展了PyTorch FSDP的功能,加入了张量并行技术。单独使用分片数据并行性时,集群规模扩大可能导致收敛问题。这是因为在数据并行范围内分片参数、梯度和优化器状态同时也会增加全局批量大小;在大型集群上,这种全局批量大小可能会超过模型收敛的阈值。因此,您需要结合额外的并行技术,确保在扩展集群时不增加全局批量大小。
为了解决这一问题,SMP v20引入了将分片数据并行性与张量并行性结合的能力。张量并行性允许集群规模扩大而不更改全局批量大小或影响模型的收敛性。通过这一功能,您可以安全地增加训练吞吐量,配备256个节点或更多的集群。
当前,PyTorch FSDP的张量并行性仅在SMP v2中可用。SMP v2允许您通过少量代码更改启用此技术,并在大型集群上实现稳定训练。SMP v2与Transformer Engine集成,以实现张量并行性,并使之与PyTorch FSDP API兼容。您可以同时启用PyTorch FSDP和SMP的张量并行性,而无需对PyTorch模型或FSDP配置做任何更改。以下代码片段展示了如何在训练脚本中设置SMP配置字典,并添加SMP初始化模块 torchsagemakerinit() ,将配置字典传递给后台以启动训练任务。
SMP配置如下:
json{ tensorparalleldegree 8 tensorparallelseed 0}
在您的训练脚本中,请使用如下代码:
pythonimport torchsagemaker as tsmtsminit()
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLMfromconfig()model = tsmtransform(model)
有关如何在SMP中使用张量并行性的更多信息,请参阅我们文档中的张量并行性部分。
使用先进的特性将模型训练加速高达20
除了在拥有数百个实例的集群上启用分布式训练外,SMP还提供可加速模型训练的优化技术,提升速度高达20。在这一部分中,我们将重点介绍其中的一些优化。要了解更多,请参阅我们文档中的核心功能部分。
混合分片
分片数据并行性是一种节省内存的分布式训练技术,它将模型状态模型参数、梯度和优化器状态分散到多个设备上。这种较小的内存占用使您可以将更大的模型放入集群中或增加批量大小。然而,分片数据并行性也增加了训练作业的通信需求,因为在训练过程中分片的模型工件需要频繁地从不同设备收集。因此,分片的程度是一个重要配置,它在内存消耗和通信开销之间进行权衡。
默认情况下,PyTorch FSDP将在集群中所有加速设备上分片模型工件。根据训练作业的不同,这种分片方法可能会增加通信开销,产生瓶颈。为此,SMP库在PyTorch FSDP之上提供了可配置的混合分片数据并行性。这一功能允许您设置适合训练工作负载的分片程度。只需在配置JSON对象中指定分片程度,并将其包含在您的SMP训练脚本中。
SMP配置如下:
json{ hybridsharddegree 16 }

要了解混合分片的优势,请参阅AWS的巨型模型训练近线性扩展。有关如何使用您现有的FSDP训练脚本实现混合分片的更多信息,请参见我们文档中的混合分片部分。
使用优化AWS基础设施的SMDDP集体通信操作
您可以使用SMP库与SageMaker分布式数据并行SMDDP库一起,加速您的分布式训练工作负载。SMDDP包括一个为SageMaker p4d和p4de加速实例优化的 AllGather 集体通信操作。在分布式训练中,集体通信操作用于在GPU工作节点之间同步信息。AllGather 是在分片数据并行性中通常使用的核心集体通信操作之一,在进行前向和后向计算步骤之前,将层参数聚集到一起。对于那些被通信瓶颈限制的训练任务,更快的集体操作可以降低训练时间和成本,而不会对收敛性产生任何副作用。
要使用SMDDP库,您只需在训练脚本中添加两行代码:
pythonimport torchdistributed as dist
使用SMDDP初始化
import smdistributeddataparalleltorchtorchsmddpdistinitprocessgroup(backend=smddp) # 替换nccl
使用SMP初始化
import torchsagemaker as tsmtsminit()
除了SMP,SMDDP还支持开源PyTorch FSDP和DeepSpeed。要了解更多有关SMDDP库的信息,请参阅使用SageMaker分布式数据并行库运行分布式训练。
激活离线处理
通常,在模型训练的前向过程中,计算activations并在反向传播完成之前将其保留在GPU内存中。这些存储的激活在训练期间可能会消耗大量GPU内存。激活离线处理是一种技术,它将这些张量在前向传递后移动到CPU内存,并在需要时再将其获取到GPU。这种方法可以显著减少训练时的GPU内存使用。
尽管PyTorch支持激活离线处理,但其实现效率不高,并可能导致GPU在反向传播期间因从CPU获取激活时闲置。这可能会造成显著的性能下降。
SMP v2提供了一种优化的激活离线处理算法,可以提高训练性能。SMP的实现提前将激活进行预取,减少了GPU的闲置时间。
因为SMP是建立在PyTorch API之上的,所以启用优化的激活离线处理只需进行少量代码更改。只需添加相关配置smactivationoffloading和activationloadinghorizon参数,并将其包含在您的训练脚本中。
SMP配置如下:
json{ activationloadinghorizon 2 smactivationoffloading true}
在训练脚本中,使用如下代码:
pythonimport torchsagemaker as tsmtsminit()
用于激活离线处理的PyTorch原生模块
from torchdistributedalgorithmscheckpointcheckpointwrapper import ( applyactivationcheckpointing offloadwrapper)
model = FSDP()
激活离线处理需要激活检查点。
applyactivationcheckpointing( model checkfn=checkpointtformerlayerspolicy)
model = offloadwrapper(model)
要了解有关激活离线处理的开源PyTorch检查点工具的更多信息,请查看PyTorch GitHub存储库中的checkpointwrapperpy脚本,以及PyTorch博客文章激活检查点中的内容。有关SMP优化的激活离线处理的更多信息,请参阅我们文档中的激活离线处理部分。
除了混合分片、SMDDP和激活离线处理之外,SMP还提供其他优化,可以加速您的大型模型训练工作负载。这包括优化激活检查点、延迟参数初始化等。要了解更多内容,请参阅我们文档中的核心功能部分。
结论
随着数据集、模型尺寸和训练集群的不断增长,效率高的分布式训练对于及时且经济地交付模型和产品变得越来越关键。SageMaker模型并行库的最新版本通过减少代码变动,和PyTorch FSDP APIs对齐,支持通过张量并行性在大型集群加载训练,并通过优化技加快训练时间高达20。
要开始使用SMP v2,请参阅我们的文档以及我们的示例笔记本。
作者简介
罗伯特范杜森是Amazon SageMaker的高级产品经理,负责深度学习训练的框架、编译器和优化技术。
路易斯昆特拉是AWS SageMaker模型并行库的软件开发经理。在闲暇时,他常常骑着他的哈雷摩托在旧金山湾区巡游。
高塔姆库马尔是AWS AI深度学习的软件工程师,热衷于构建AI工具和系统。闲暇时,他喜欢骑自行车和阅读书籍。
拉胡尔赫尔戈尔是亚马逊云服务中分布式深度学习的高级软件开发工程师。
标签:分布式训练、生成式AI