引言
在在线广告点击率(CTR)预估、推荐系统等应用中,如何从高维稀疏的类别特征中提取有效信息成为关键问题。本文将详细由伦敦大学学院的研究人员Zhang et al. (2016)
提出的介绍一种融合因子分解机(Factorisation Machine, FM)和深度神经网络(Deep Neural Networks, DNN)的模型——FNN(Factorisation-machine supported Neural Networks)。这种模型利用 FM 在底层的有效特征表示,再通过多层非线性变换进一步捕捉数据内部的复杂模式,从而提升预估性能。
1. FNN 模型结构概述
FNN 结合了 Factorization Machine(FM) 和 神经网络(DNN) :
-
底层是 FM:利用 FM预训练得到的各特征隐向量初始化 Embedding 层的参数,而不是随机初始化。
-
多层神经网络:
- 第一层(Dense Real Layer,z) : 由 FM 提供的向量
z
作为输入。 - 隐藏层(Hidden Layers, l1 & l2) : 采用 tanh 激活函数(经验上比 ReLU 更好)。
- 输出层(CTR 预测) : 使用 sigmoid 激活函数,将输出映射到 (0,1) 之间,表示点击率的概率。
- 第一层(Dense Real Layer,z) : 由 FM 提供的向量
具体来说,FNN 模型由以下部分组成:
-
输入层
- 输入特征为经过 field-wise one-hot 编码的稀疏二值向量 x x x。例如,对于字段“城市(city)”,每个可能取值(如 London、New York 等)对应一个神经元,其中只有一个取值为 1,其余均为 0。
-
FM 底层(特征表示层)
- 对于每个字段 i i i,构造参数向量 z i z_i zi:
z i = W i ′ ⋅ x [ start i : end i ] = ( w i , v 1 i , v 2 i , … , v K i ) z_i = W'_i \cdot x[\text{start}_i:\text{end}_i] = \left(w_i, v^i_1, v^i_2, \dots, v^i_K\right) zi=Wi′⋅x[starti:endi]=(wi,v1i,v2i,…,vKi) - FM 模型通过如下公式预训练参数,得到初始特征表示:
y F M ( x ) = sigmoid ( w 0 + ∑ i = 1 N w i x i + ∑ i = 1 N ∑ j = i + 1 N ⟨ v i , v j ⟩ x i x j ) y_{FM}(x) = \operatorname{sigmoid}\!\left(w_0 + \sum_{i=1}^{N} w_i x_i + \sum_{i=1}^{N} \sum_{j=i+1}^{N} \langle v^i, v^j \rangle x_i x_j\right) yFM(x)=sigmoid(w0+i=1∑Nwixi+i=1∑Nj=i+1∑N⟨vi,vj⟩xixj) - 这种预训练方式既能够捕捉特征间的二阶交互,又减少了计算复杂度,因为只需更新与正输入单元相连的权重。
- 对于每个字段 i i i,构造参数向量 z i z_i zi:
-
隐藏层
- FNN 设计了多层隐藏层,每一层都采用非线性激活函数(例如 tanh):
- 第一隐藏层:
l 1 = tanh ( W 1 z + b 1 ) l_1 = \tanh\!\left(W_1 z + b_1\right) l1=tanh(W1z+b1) - 第二隐藏层:
l 2 = tanh ( W 2 l 1 + b 2 ) l_2 = \tanh\!\left(W_2 l_1 + b_2\right) l2=tanh(W2l1+b2)
- 第一隐藏层:
- 隐藏层的作用是逐步将 FM 层获得的局部信息整合并抽象成更高层次的全局特征。
- FNN 设计了多层隐藏层,每一层都采用非线性激活函数(例如 tanh):
-
输出层
- 输出层将隐藏层的输出映射为一个预测值 y ^ ∈ ( 0 , 1 ) \hat{y} \in (0, 1) y^∈(0,1),表示点击概率:
y ^ = sigmoid ( W 3 l 2 + b 3 ) \hat{y} = \operatorname{sigmoid}\!\left(W_3 l_2 + b_3\right) y^=sigmoid(W3l2+b3)
- 输出层将隐藏层的输出映射为一个预测值 y ^ ∈ ( 0 , 1 ) \hat{y} \in (0, 1) y^∈(0,1),表示点击概率:
整体网络结构的分层设计,使得模型既能利用 FM 进行局部特征表示,又能通过 DNN 捕捉数据的全局模式。
2. 为什么EMbedding层训练缓慢
这里引入王喆老师提过的例子
1. 参数数量巨大,优化难度高
Embedding 层本质上是一个查找表,将离散的类别(如单词、用户 ID、物品 ID)映射到一个连续的低维向量空间。假设在一个推荐系统中,输入维度为 10,000,每个用户的 embedding 维度为 128,那么总参数量就是:
10000 × 128 = 1 , 280 , 000 10000 \times 128 = 1,280,000 10000×128=1,280,000
- 但在某些任务中,Embedding 层的权重占比可能超过 90% ,以刚刚例子,embedding层是128维,如果加上4层128维全连接,和10维的输出层,那么其余所有层的参数就为:
( 128 ∗ 128 ) ∗ 3 + 128 ∗ 10 = 50 , 432 (128 * 128)* 3 + 128 * 10 = 50,432 (128∗128)∗3+128∗10=50,432
那么Embedding层的占比就为1,280,000/ (1,280,000 + 50,432) ≈ 96.2%
- 另外,在 NLP 任务中,BERT 模型的 Transformer 层有大量参数,但在大规模词表,Embedding 层的参数仍然占据主导地位。如此庞大的参数量导致优化器在更新参数时计算量巨大,梯度更新的步伐变慢,从而影响收敛速度。此外,参数过多还会导致存储开销大,增加计算成本。
2. 输入向量稀疏,梯度更新受限
在实际任务中,Embedding 层的输入通常是稀疏的索引,这意味着每次仅有部分 embedding 向量被激活并进行更新。例如,在 NLP 任务中,一个句子通常只包含几十个单词,但词表可能有上百万个单词,导致大部分 embedding 向量在每个 batch 里都不会被更新。
例如,在一个文本分类任务中,假设词表大小为 50,000,但一个输入句子只有 10 个单词,每次训练时,仅有这 10 个单词对应的 embedding 向量会被更新,而其他 49,990 个单词的 embedding 仍保持不变。这种稀疏更新使得整个 embedding 矩阵需要更长的训练时间才能达到较优的表示,导致收敛缓慢。
所以利用 FM 预训练得到的各特征隐向量初始化 Embedding 层的参数,而不是随机初始化是十分有必要的
。
- 一来是FM训练得到的隐向量已经包含了特征交互信息,提高了初始化的质量。
- 二来随机初始化的 Embedding 需要较长的时间去探索合适的参数空间,而FM预训练提供一个较优的起点使得模型训练时参数调整的幅度更小,能更快接近最优解。
3. 模型数学描述
下面给出 FNN 的详细数学公式说明。
3.1 输出层计算
输出层采用 logistic 激活函数,其公式为:
y ^ = sigmoid ( W 3 l 2 + b 3 ) , sigmoid ( x ) = 1 1 + e − x \hat{y} = \operatorname{sigmoid}\!\left(W_3 l_2 + b_3\right), \quad \operatorname{sigmoid}(x) = \frac{1}{1 + e^{-x}} y^=sigmoid(W3l2+b3),sigmoid(x)=1+e−x1
其中:
- W 3 ∈ R 1 × L W_3 \in \mathbb{R}^{1 \times L} W3∈R1×L
- b 3 ∈ R b_3 \in \mathbb{R} b3∈R
- l 2 ∈ R L l_2 \in \mathbb{R}^{L} l2∈RL
3.2 隐藏层计算
第二隐藏层:
l 2 = tanh ( W 2 l 1 + b 2 ) , tanh ( x ) = 1 − e − 2 x 1 + e − 2 x l_2 = \tanh\!\left(W_2 l_1 + b_2\right), \quad \tanh(x) = \frac{1 - e^{-2x}}{1 + e^{-2x}} l2=tanh(W2l1+b2),tanh(x)=1+e−2x1−e−2x
其中:
- W 2 ∈ R L × M W_2 \in \mathbb{R}^{L \times M} W2∈RL×M
- b 2 ∈ R L b_2 \in \mathbb{R}^{L} b2∈RL
- l 1 ∈ R M l_1 \in \mathbb{R}^{M} l1∈RM
第一隐藏层:
l 1 = tanh ( W 1 z + b 1 ) l_1 = \tanh\!\left(W_1 z + b_1\right) l1=tanh(W1z+b1)
其中:
- W 1 ∈ R M × J W_1 \in \mathbb{R}^{M \times J} W1∈RM×J
- b 1 ∈ R M b_1 \in \mathbb{R}^{M} b1∈RM
- z ∈ R J z \in \mathbb{R}^{J} z∈RJ
3.3 FM 底层特征表示
输入向量 z z z 包含全局偏置项和各字段的向量表示:
z = ( w 0 , z 1 , z 2 , … , z n ) z = \left(w_0, z_1, z_2, \dots, z_n\right) z=(w0,z1,z2,…,zn)
对于第 i i i 个字段,其表示为:
z i = W i ′ ⋅ x [ start i : end i ] = ( w i , v 1 i , v 2 i , … , v K i ) z_i = W'_i \cdot x[\text{start}_i:\text{end}_i] = \left(w_i, v^i_1, v^i_2, \dots, v^i_K\right) zi=Wi′⋅x[starti:endi]=(wi,v1i,v2i,…,vKi)
这里, W i ′ W'_i Wi′ 的初始化依赖于因子分解机训练结果。具体来说,FM 模型的训练目标为:
y F M ( x ) = sigmoid ( w 0 + ∑ i = 1 N w i x i + ∑ i = 1 N ∑ j = i + 1 N ⟨ v i , v j ⟩ x i x j ) y_{FM}(x) = \operatorname{sigmoid}\!\left(w_0 + \sum_{i=1}^{N} w_i x_i + \sum_{i=1}^{N} \sum_{j=i+1}^{N} \langle v^i, v^j \rangle x_i x_j\right) yFM(x)=sigmoid(w0+i=1∑Nwixi+i=1∑Nj=i+1∑N⟨vi,vj⟩xixj)
3.4 模型训练
FNN 的训练分为两个阶段:
-
预训练阶段
- FM 层预训练:采用随机梯度下降(SGD)优化 FM 的参数(仅更新正输入对应的权重),得到良好的初始特征表示。
- 隐藏层预训练:使用基于受限玻尔兹曼机(RBM)的逐层预训练(利用对比散度算法),以保存输入数据的重要信息。
-
监督微调阶段
- 在预训练完成后,通过反向传播算法以交叉熵损失函数对整个网络进行监督微调:
L ( y , y ^ ) = − y log y ^ − ( 1 − y ) log ( 1 − y ^ ) L(y, \hat{y}) = -y \log \hat{y} - (1 - y) \log (1 - \hat{y}) L(y,y^)=−ylogy^−(1−y)log(1−y^) - 这种训练策略能够有效利用 FM 提供的先验信息,使得 DNN 在数据歧义较大的情况下(如广告点击行为)依然能够保持稳定的参数更新。
- 在预训练完成后,通过反向传播算法以交叉熵损失函数对整个网络进行监督微调:
4. FNN 的优势与启示
4.1 结构优势
-
局部连接与全局捕捉:
底层 FM 对每个字段进行局部连接,有效捕捉了稀疏输入的局部交互;而上层的全连接隐藏层能够整合全局信息,捕捉更深层次的数据模式。 -
降维与表示学习:
通过 FM 层将高维稀疏数据映射到低维潜在空间,不仅缓解了特征稀疏性问题,同时为深度网络提供了较好的初始化。
4.2 训练策略
-
层级预训练:
利用 RBM 进行逐层预训练能够较好地保存输入信息,使得后续的监督微调更为高效。 -
针对正样本更新:
在 FM 训练过程中仅更新与正输入单元相连的权重,大大降低了计算复杂度。
4.3 理论与实践启示
-
融合思想的应用:
FNN 的设计灵感部分来自于卷积神经网络(CNN),即利用局部连接结构来捕捉数据的局部相关性。这种跨领域的融合为解决高维稀疏问题提供了新思路。 -
从先验到后验:
在数据存在高度歧义的场景下(如广告点击预测),利用 FM 获得稳定的先验信息,再由 DNN 进行后续微调,使得模型能够更稳健地捕捉数据中的细微差异。
5. 总结
FNN 模型通过在深度神经网络的底层引入因子分解机,有效解决了类别特征高维稀疏的问题,并利用多层非线性变换对特征进行抽象和整合。通过预训练与监督微调相结合的策略,FNN 在 CTR 预估等任务中表现出优越的性能。
这种模型不仅在理论上为我们提供了深度学习与传统因子分解方法融合的思路,同时在实践中也展现了在复杂大规模数据场景下的应用潜力。未来的研究可以进一步探索如何在其他领域(如推荐系统、风险评估等)中利用这种混合建模策略,发挥各自模型的优势。
Reference
- Weinan Zhang, Tianming Du, and Jun Wang. 2016. Deep Learning over Multi-field Categorical Data – A Case Study on User Response Prediction. In Proceedings of the 25th International Conference on World Wide Web (WWW '16) . 471–480.
- 王喆 《深度学习推荐系统》