欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 美景 > 图解大模型分布式训练:张量并行Megatron-LM方法

图解大模型分布式训练:张量并行Megatron-LM方法

2025/2/23 17:14:09 来源:https://blog.csdn.net/Antai_ZHU/article/details/144409187  浏览:    关键词:图解大模型分布式训练:张量并行Megatron-LM方法

AI因你而升温,记得加个星标哦!

随着大模型参数量的爆炸性增长,其所需内存也呈爆炸性增长,最现实的问题就是单块显卡装不下模型,所以我们需要进行分布式训练。

演进路线

  • 数据并行Data Parallelism:一台机器可以装下模型,所以将同一个模型同时部署在多台机器,用多份数据分开训练
  • 流水线并行Pipeline Parallelism:一台机器装不下模型,但模型的一层或多层一台设备装得下,所以同一个模型按层拆分到不同机器进行训练
  • 张量并行Tensor Parallelism:模型的一层都装不下了,所以同一个模型层内拆分开训练

不了解分布式训练的同学建议先阅读这几篇文章:

MapReduce:大模型分布式训练必备知识

图解大模型分布式训练:数据并行

图解大模型分布式训练:流水线并行

图解大模型分布式训练:张量并行Megatron-LM方法

图解大模型分布式训练:ZeRO系列方法

Megatron-LM

Megatron-LM是Nvidia提出的一种Tensor Parallelism(TP)方式,它的核心思想是将模型的层进行纵向或横向分割训练,Megatron-LM的TP主要针对基于Transformer中的Self-Attention和MLP进行拆分并行。

将一个矩阵按行拆分和按列拆分的神经网络正常传播与反向传播图:

MLP层张量并行

在MLP层中,先对A采用“列切割”,然后对B采用“行切割” :

  • f 的 forward 计算:把输入X拷贝到两块GPU上,每块GPU即可独立做forward计算
  • g 的 forward 计算:每块GPU上的forward的计算完毕,取得Z1和Z2后,GPU间做一次AllReduce,相加结果产生Z
  • g 的 backward 计算:只需要把 ∂ L ∂ Z \frac{\partial L}{\partial Z} ZL拷贝到两块GPU上,两块GPU就能各自独立做梯度计算。
  • f 的 backward 计算:当前层的梯度计算完毕,需要传递到下一层继续做梯度计算时,我们需要求得 ∂ L ∂ X \frac{\partial L}{\partial X} XL,则此时两块GPU做一次AllReduce,把各自的梯度相加即可。

那为什么我们对A采用列切割,对B采用行切割呢?这样设计的原因是,我们尽量保证各GPU上的计算相互独立,减少通讯量。对A来说,需要做一次GELU的计算,而GELU函数是非线形的,它的性质如下:

也就意味着,如果对A采用行切割,我们必须在做GELU前,做一次AllReduce,这样就会产生额外通讯量。但是如果对A采用列切割,那每块GPU就可以继续独立计算了。一旦确认好A做列切割,那么也就相应定好B需要做行切割了。

MHA层张量并行

在 MHA 层,对三个参数矩阵Q,K,V按照“列切割” ,对线性层B按照“行切割” ,切割的方式和 MLP 层基本一致,其forward与backward原理也一致。在实际应用中,并不一定按照一个head占用一块GPU来切割权重,我们也可以一个多个head占用一块GPU,这依然不会改变单块GPU上独立计算的目的。所以实际设计时,我们尽量保证head总数能被GPU个数整除。

上图为MLP与MHA块放置在一起,一个Transformer层的张量模型并行流程图。可以看到,一个Transformer层的正向和反向传播中总共有 4 个All-Reduce通信操作。

实际应用

到这里为止,我们基本把张量并行的计算架构说完了。在实际应用中,对Transformer类的模型,采用最经典方法是张量并行 + 数据并行,并在数据并行中引入ZeRO做显存优化。

其中,node表示一台机器,一般我们在同一台机器的GPU间做张量并行,在不同的机器上做数据并行。

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词