[CLS] Token 在 ViT(Vision Transformer)中的作用与实现
1. 什么是 [CLS] Token?
[CLS](classification token)是Transformer模型中一个可学习的嵌入向量,最初在 BERT(Bidirectional Encoder Representations from Transformers)中提出,用于文本分类任务。在 ViT(Vision Transformer)中,[CLS] Token 也被用来汇总图像全局特征,最终用于分类或其他任务。
2. [CLS] Token 在 ViT 中的作用
在 ViT 中,输入图像被划分为多个固定大小的 Patch(如 16×1616 \times 16 像素),然后每个 Patch 被投影到一个 DD 维的特征向量空间,形成 Token 序列。为了让模型学习到全局信息,ViT 在输入序列的最前面添加一个特殊的 [CLS]
Token:
其中:
-
:一个可学习的 DD 维向量,与 Patch Tokens 具有相同的维度。
-
:每个 Patch 经过线性投影后的特征向量。
在 Transformer 计算过程中,所有 Token 之间会进行自注意力计算,这样 [CLS] Token 就能聚合所有 Patch Token 的信息。最终,ViT 只使用 [CLS] Token 的最终表示来进行图像分类。
3. [CLS] Token 计算流程
ViT 处理 [CLS] Token 的主要流程如下:
-
输入图像划分 Patch
-
设输入图像大小为
(如 224×224)。
-
设 Patch 大小为
(如 16×16)。
-
图像被划分成
个 Patch。
-
-
线性投影到特征空间
-
每个 Patch 被展平为 P2P^2 维向量,然后通过一个可训练的线性变换投影到 DD 维:
-
这里
是可训练的权重矩阵。
-
-
添加 [CLS] Token
-
初始化一个可学习的 [CLS] 向量:
-
并将其拼接到 Patch Tokens 的前面,形成输入序列:
-
-
位置编码(Positional Encoding)
-
由于 Transformer 不能感知位置信息,ViT 为每个 Token 添加可学习的位置编码:
-
其中
是可训练的位置编码矩阵。
-
-
Transformer 计算
-
经过多层 Transformer(注意力计算 + FFN),输出 [CLS] Token 的最终表示:
-
其中 LL 是 Transformer 层数。
-
-
使用 [CLS] Token 进行分类
-
最后,[CLS] Token 的表示会输入到一个 MLP(多层感知机)分类器:
-
其中
是 Transformer 最后一个 Block 输出的 [CLS] Token 表示。
-
4. 为什么 [CLS] Token 能代表全局信息?
Transformer 通过自注意力机制(Self-Attention)让 [CLS] Token 逐步聚合所有 Patch Token 的信息:
-
在每一层 Transformer 中,[CLS] Token 作为 Query,与所有 Token(包括自己)进行注意力计算。
-
由于 [CLS] Token 参与了多层 Transformer 的注意力计算,它会学习到整个图像的全局信息。
-
最终,经过多层 Transformer 计算后,[CLS] Token 变成了图像的全局表示,可用于分类任务。
5. ViT 中 [CLS] Token 的优化与变种
虽然 [CLS] Token 在 ViT 中效果不错,但有些研究提出了改进方案:
-
GAP(Global Average Pooling)替代 [CLS] Token
-
研究发现,直接对所有 Patch Token 做平均池化(Global Average Pooling, GAP)可以代替 [CLS] Token,减少参数并提高分类精度:
-
这种方法能提高模型的稳定性,同时减少 [CLS] Token 可能带来的信息损失。
-
-
Distilled Token(在 DeiT 模型中)
-
在 DeiT(Data-efficient Image Transformer)中,除了标准的 [CLS] Token,还引入了一个蒸馏 Token(Distillation Token),用于知识蒸馏:
-
这个蒸馏 Token 用于模仿 CNN 预训练教师模型的行为,提高训练效率。
-
-
Hybrid ViT(使用 CNN 提取特征)
-
研究发现,ViT 仅靠 [CLS] Token 可能会忽略局部细节信息。因此,一些变种如 Hybrid ViT 结合了 CNN 进行特征提取,再输入 Transformer,提高了细粒度信息的捕获能力。
-
6. 总结
-
[CLS] Token 是 ViT 中用于分类的全局表征向量,其通过 Transformer 自注意力机制聚合整个图像的信息。
-
计算流程:
-
输入图像划分 Patch 并投影到高维特征空间。
-
在序列开头添加 [CLS] Token,作为全局特征的代理。
-
添加位置编码 以维持位置信息。
-
Transformer 计算,[CLS] Token 逐步聚合全局信息。
-
最终 [CLS] Token 通过 MLP 进行分类。
-
-
优化方法:
-
使用 GAP 代替 [CLS] Token 提高稳定性。
-
采用 Distilled Token 进行知识蒸馏(如 DeiT)。
-
结合 CNN 提取局部特征(Hybrid ViT)。
-
[CLS] Token 在 ViT 中是关键组件,决定了模型的分类性能,但研究也表明它并非唯一的最佳选择。