欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > Transformer特辑

Transformer特辑

2025/4/28 2:53:54 来源:https://blog.csdn.net/weixin_38812492/article/details/140265996  浏览:    关键词:Transformer特辑

https://github.com/LongxingTan/Machine-learning-interview

模型结构

transformer 原文插图

  • 基本单元:token_embedding + positional encoding, encoder, token_embedding + positional encoding, decoder
  • encoder: (self-attention, skip-connect, ln), (ffn, skip-connect, ln)
  • decoder: (self-attention, skip-connect, ln), (cross-attention, skip-connect, ln), (ffn, skip-connect, ln)

复杂度

参数与计算量

参考文献5插图

  • 反向传播优化过程:(1)前向计算损失函数,(2)后向计算梯度,(3)优化器更新参数

开始训练一个大模型之前,根据scaling law来估算,有多少数据,需要多少算力,要计算多少时间

  • 深度学习每次前向计算,矩阵乘法就是一次加一次乘,一个parameters,要对应2次浮点计算,所以要乘以2

我们采用文献6中的约定:

  • L: Transfomer 层树
  • H:d_model, 也就是attention hidden_size维度
  • h: 多头注意力有几个attention 头
  • B: batchsize
  • S:序列的长度,比如GPT 2K,LLama2 4K
  • V: 词表里词的数量 vocab

Attention

从模型结构中拿出一个标准单元
attention, skip_connect, ln + ffn, skip_connect, ln

在这里插入图片描述

输入的embedding形状为: [B,S,H]

  • 多头注意力先把Q, K, V都dense层到H维度,[B, S, H] X [H, H] = [B, S, H], 共计算BSH^2次 x 3
  • 计算attention score, softmax(Q* K转置 / sqrt(d_model)),[B, h, S, H’] X [B, h, H’, S] = [B, h, S, S],考虑其中多头,共计算 BHS^2次
  • 与V点积,[B, h, S, S] X [B, h, S, H’] = [B, h, S, H’],共计算 BhS^2 H’ = BHS**2次
  • 经过dense线性层,多头转换回去,[B, h, S, H’] X [H’, H’] = [B, S, H],共计算BSH^2次

以上Attention过程总共计算为 2 * (3BSH2 + BSH2 + BHS2 + BHS2) = 8BSH**2 + 4BHS **2,乘以2是因为神经网络计算一次加法 一次乘法。

FFN

输入embedding形状为:[B, S, H]

  • ffn第一层,[B, S, H] x [H, 4H] = [B, S, 4H], 共计算 4BSH^2
  • ffn第二层,[B, S, 4H] x [4H, H] = [B, S, H], 共计算 4BSH^2

以上FFN总计算为 16 BSH**2

总计算量

前向计算量

  • 一个attention + ffn单元:24BSH**2 + 4 BHS **2
  • L层: L * (单元)
  • 生成:2BSHV

反向求导的时候,Loss算梯度得到新weight然后更新,所以是前向计算的两倍,乘以4

完成每个参数,都过一遍所有Token的情况下,也就是一个epoch,要经过6次浮点运算

对于Llama 65B模型推导

  • 模型参数: 65 * 10 ^9
  • token: 1.4 * 10 ^12

因此需要算力 =6 * (模型参数 * 总token)

实际算力=GPU总数单个GPU算力单个GPU利用率。
实际算力 = 2048 * 312(A100 Tflops)* 10^12() * 0.45

需要算力/实际算力 = 时间(原文21天)

显存

  • 装载模型,假如模型的参数是以FP16来计算的(A100之后BF16的居多,防止计算的时候上溢出)

  • 一个参数被表示16位的浮点数,所以它也就占用2个byte 。

  • 7B的话,静态显存占用量,指模型的所有参数被load到显存里,如果以BF16的话,要占据14个G

  • 训练过程中,除了模型参数本身外,还有梯度和优化器

在一次用AdamW和混合精度训练的Epcho里,每一个模型参数,需要占用:
2byte的模型静态参数权重(以16bit存储)
2byte的模型更新参数权重(以16bit存储)
2byte的梯度(以16bit存储)
2byte的梯度更新(以16bit存储)
4byte的一阶动量优化器更新(以32bit存储)
4byte的二阶方差优化器更新(以32bit存储)

也就是: 一个模型参数需要占用16bytes的内存

更详细可以参考 LLM 参数,显存,Tflops? 训练篇(5) - 周博洋的文章 - 知乎

Tokenizer

  • Byte Pair Encoding(BPE), Byte-level BPE(BBPE),Uniform Language Model(ULM),WordPiece
  • https://github.com/LongxingTan/Machine-learning-interview/blob/main/02_ml/11_nlp.md

Positional encoding/embedding

由于attention模型自身没有衡量位置的能力,因此需要位置编码。至于输入为什么是token_embedding + positional encoding, 可参考为什么 Bert 的三个 Embedding 可以进行相加? - 知乎

transformer论文中使用的是 positional encoding,

  • 位置编码是和token embeding一样的形状,[B, S, H]
  • 位置编码是位于[0, 1]的连续数值

旋转位置编码: RoPE

Attention 推理优化

KV-Cache

  • KV cache主要分成5个方向的优化,即Sparse、Quantization、Allocator、Window、share
  • 关于为什么Q不需要缓存,可参考为什么加速LLM推断有KV Cache而没有Q Cache? - 知乎
  • KC cache计算量,显存分析,可参考KV cache详解 图示,显存,计算量分析,代码 - 莫笑傅立叶的文章 - 知乎

Multi-Query Attention (MQA)

  • MQA 在 encoder 上的提速没有非常明显,但在 decoder 上的提速是很显著

在这里插入图片描述

Grouped Query Attention (GQA)

在这里插入图片描述

Sliding window attention (SWA)

在这里插入图片描述

Flash Attention

  • FlashAttention主要解决Transformer计算速度慢和存储占用高的问题. 将优化重点放在了降低存储访问开销(Memory Access Cost,MAC)上

PagedAttention

在这里插入图片描述

Quantization

Decoding/sampling

Constrained sampling

Speculative decoding 投机采样

  • Accelerating Large Language Model Decoding with Speculative Sampling
  • Fast Inference from Transformers via Speculative Decoding

Decoder-only 推理

  • How continuous batching enables 23x throughput in LLM inference while reducing p50 latency
  • 关于 prefill cache, 可参考原理&图解vLLM Automatic Prefix Cache(RadixAttention): 首Token时延优化 - DefTruth的文章 - 知乎

参考

  • https://jalammar.github.io/illustrated-transformer/
  • http://nlp.seas.harvard.edu/annotated-transformer/
  • https://github.com/Kyubyong/transformer
  • Transformer学习笔记一:Positional Encoding(位置编码) - 猛猿的文章 - 知乎
  • 浅谈后向传递的计算量大约是前向传递的两倍 - 回旋托马斯x的文章 - 知乎
  • LLM 参数,显存,Tflops? 训练篇(1) - 周博洋的文章 - 知乎
  • Llama源码深入解析 - 一个有毅力的吃货的文章 - 知乎
  • 十分钟读懂旋转编码(RoPE) - 绝密伏击的文章 - 知乎
  • 大模型推理性能优化之KV Cache解读 - Young的文章 - 知乎
  • Muti Query Attention 和 Attention with Linear Bias(附源码) - 何枝的文章 - 知乎
  • DistServe速读——Prefill & Decode解耦、模型并行策略&GPU资源分配解耦 - 阿杰的文章 - 知乎
  • https://github.com/alibaba/Megatron-LLaMA
  • 稀疏注意力计算:sliding window attention - Linsight的文章 - 知乎
  • 大模型推理加速:KV Cache Sparsity(稀疏化)方法 - 歪门正道的文章 - 知乎

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词