文章目录
- 概要
- 1. 数值梯度的公式
- 2. 数值梯度计算过程
- 3. 数值梯度的特点
概要
前文已经简单介绍梯度,本文主要介绍大语言模型中使用数值梯度的方法实现 损失值 L L L 对模型权重矩阵的梯度计算,而不是传统的链式法则进行梯度计算。如果想要理解整体计算方式,先明白损失值 L L L的计算方式,通过公式了解其和权重矩阵 W V W_V WV的关系。然后再理解损失值 L L L对权重矩阵 W V W_V WV的梯度计算。
1. 数值梯度的公式
数值梯度通过有限差分法近似计算梯度,对权重矩阵 W V W_V WV 中每个元素的梯度 ∂ L ∂ W V i j \frac{\partial L}{\partial W_{V_{ij}}} ∂WVij∂L:
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} ∇LWVij=hLplus−Lcurrent
其中,每个参数的含义在下文中有讲解。
2. 数值梯度计算过程
(1) 初始化
- 给定权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WV∈Fm×n,与 W V W_V WV大小相同的梯度矩阵 ∇ L W V = zeros ( m , n ) \nabla L_{W_V} = \text{zeros}(m, n) ∇LWV=zeros(m,n)。
- 确定增量 h h h 的值(如 h = 1 0 − 5 h=10^{−5} h=10−5)。
(2) 遍历权重矩阵的每个元素
对于 W V W_V WV中的每个元素 W V i j W_{V_{ij}} WVij:
- 创建一个单位矩阵 E i j E_{ij} Eij,大小与 W V W_V WV相同,且 E i j = 1 E_{ij}=1 Eij=1。
- 计算损失值:
- L p l u s = L ( W v + h ∗ E i j ) L_{plus}=L(W_v+h*E_{ij}) Lplus=L(Wv+h∗Eij):
- 在 W V W_V WV的第 ( i , j ) (i,j) (i,j) 元素增加一个微小值 h h h,得到新的权重矩阵,然后计算损失值 L p l u s L_{plus} Lplus.
- L c u r r e n t = L ( W v ) L_{current}=L(W_v) Lcurrent=L(Wv):
- 使用当前的权重矩阵 W V W_V WV计算损失值 L c u r r e n t L_{current} Lcurrent。
- 使用当前的权重矩阵 W V W_V WV计算损失值 L c u r r e n t L_{current} Lcurrent。
(3) 梯度估算
通过有限差分公式,计算第 ( i , j ) (i,j) (i,j)元素的梯度:
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} ∇LWVij=hLplus−Lcurrent
这个公式的含义是:通过观察 W V i j W_{V_{ij}} WVij 增加 h h h 后损失函数的变化,我们可以估算出损失函数对该参数的敏感程度(梯度)。
3. 数值梯度的特点
优点:
- 简单直观:无需解析推导梯度公式,直接利用损失函数计算。
- 适合验证解析梯度:可以作为解析梯度的参考标准,用于检测实现是否正确。
缺点:
- 计算效率低:
- 对于权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WV∈Fm×n,需要计算 m × n m×n m×n 次损失。
- 如果网络规模较大,数值梯度的计算会非常耗时。
- 数值误差:
- 梯度近似的精度取决于 h h h 的选择。
- h h h 太大会导致误差较大, h h h 太小可能引入浮点数精度问题。