欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > 大模型微调中显存占用和训练时间的影响因素

大模型微调中显存占用和训练时间的影响因素

2025/3/16 12:02:21 来源:https://blog.csdn.net/Louise_Trender/article/details/146283561  浏览:    关键词:大模型微调中显存占用和训练时间的影响因素

BatchSize

显存占用:与batch_size呈线性关系,可理解为 M t o t a l = M f i x e d + B a t c h S i z e ∗ M p e r − s a m p l e M_{total}=M_{fixed}+BatchSize*M_{per-sample} Mtotal=Mfixed+BatchSizeMpersample,其中 M f i x e d M_{fixed} Mfixed指的是模型本身固定占用的显存(由参数数量决定)和优化器状态(也由参数数量决定)

总训练时间:理论上与BatchSize无关(总数不变,单步训练时间增加,总步数减少),但实际中随BatchSize越大,总时间可能减少(硬件并行效率提升),直到显存或硬件并行能力达到瓶颈。

截断长度(输入序列分词后的最大长度,即每条样本被大模型读取的最大长度)

1. 显存占用

在大型语言模型(如 Transformer)中,显存占用主要与模型的激活值(Activations)有关,而激活值的大小受到输入序列长度(即截断长度)的直接影响。以下是逐步分析:

激活值的定义

激活值是指模型在正向传播过程中每一层计算出的中间结果,通常存储在显存中,以便反向传播时计算梯度。对于 Transformer 模型,激活值主要与注意力机制(Self-Attention)和前馈网络(Feed-Forward Network, FFN)的计算相关。

显存占用的组成

显存占用主要包括:

  • 模型参数(权重和偏置):与模型规模(层数、隐藏维度)相关,与截断长度无关。
  • 激活值:与输入序列长度(截断长度 L L L)、批次大小(batch size B B B)、隐藏维度(hidden size H H H)和层数( N N N)成正比。
  • 梯度(训练时):与参数量和激活值大小相关。

对于激活值部分,显存占用主要来源于:

  1. 注意力机制:计算 Q ⋅ K T Q \cdot K^T QKT的注意力分数矩阵,尺寸为 ( B , L , L ) (B, L, L) (B,L,L),每层需要存储。
  2. 中间张量:如 V V V的加权和、前馈层的输出等。
数学表达式

假设: L L L:截断长度(序列长度), B B B:批次大小, H H H:隐藏维度, N N N:模型层数, P P P:浮点数精度(如 FP32 为 4 字节,FP16 为 2 字节)

激活值的显存占用近似为:
显存 激活值 ≈ N ⋅ B ⋅ L ⋅ H ⋅ P + N ⋅ B ⋅ L 2 ⋅ P \text{显存}_{\text{激活值}} \approx N \cdot B \cdot L \cdot H \cdot P + N \cdot B \cdot L^2 \cdot P 显存激活值NBLHP+NBL2P

  • 第一项 N ⋅ B ⋅ L ⋅ H ⋅ P N \cdot B \cdot L \cdot H \cdot P NBLHP:表示每层的线性张量(如 Q , K , V Q, K, V Q,K,V或 FFN 输出)的显存占用。
  • 第二项 N ⋅ B ⋅ L 2 ⋅ P N \cdot B \cdot L^2 \cdot P NBL2P:表示注意力分数矩阵的显存占用(仅在标准注意力机制中显著,若使用优化如 FlashAttention,则可能减少)。

结论:显存占用与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的实现方式。


2. 训练时间

训练时间主要与计算量(FLOPs,浮点运算次数)和硬件并行能力有关,而截断长度会影响计算量。

计算量的组成
  1. 注意力机制:每层的计算量与 L 2 L^2 L2相关,因为需要计算 L × L L \times L L×L的注意力矩阵。
  2. 前馈网络:每层的计算量与 L L L线性相关,因为对每个 token 独立计算。

总计算量(FLOPs)近似为:
FLOPs ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) \text{FLOPs} \approx N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2) FLOPsNB(2L2H+4LH2)

  • 2 ⋅ L 2 ⋅ H 2 \cdot L^2 \cdot H 2L2H:注意力机制的矩阵乘法(如 Q ⋅ K T Q \cdot K^T QKT softmax ⋅ V \text{softmax} \cdot V softmaxV),
  • 4 ⋅ L ⋅ H 2 4 \cdot L \cdot H^2 4LH2:前馈网络的计算(假设 FFN 隐藏层维度为 4 H 4H 4H)。
训练时间

训练时间与 FLOPs 成正比,同时受硬件并行能力(如 GPU 的计算核心数)影响。假设每秒浮点运算能力为 F GPU F_{\text{GPU}} FGPU(单位:FLOPs/s),则单次前向+反向传播的训练时间为:
时间 ≈ FLOPs F GPU ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) F GPU \text{时间} \approx \frac{\text{FLOPs}}{F_{\text{GPU}}} \approx \frac{N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2)}{F_{\text{GPU}}} 时间FGPUFLOPsFGPUNB(2L2H+4LH2)

结论:训练时间与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的计算占比。


3. 总结

  • 显存占用:与 L L L O ( L ) O(L) O(L) O ( L 2 ) O(L^2) O(L2)关系,取决于是否存储完整的注意力矩阵。
  • 训练时间:与 L L L O ( L ) O(L) O(L) O ( L 2 ) O(L^2) O(L2)关系,注意力机制的二次项通常更显著。

1

假设某模型大小为5GB,推理所需显存也为5GB,普通Lora微调(FP16)所需显存为5GB*2=10GB,8bit的QLora量化为5GB/2=2.5GB,4bit的QLora量化为5GB/4=1.25GB

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词