KL散度的三种估计k1 k2 k3
原文是 John Schulman 的博客:http://joschu.net/blog/kl-approx.html
John Schulman 在多处代码实现中采用的对 KL 散度 KL [ p , q ] \text{KL}[p,q] KL[p,q] 的估计是 1 2 ( log p ( x ) − log q ( x ) ) 2 \frac{1}{2}(\log p(x)-\log q(x))^2 21(logp(x)−logq(x))2,并非常规的 log q ( x ) p ( x ) \log\frac{q(x)}{p(x)} logp(x)q(x)。本文将介绍这种对 KL 散度的估计形式为什么更好(虽然是有偏的),以及如何进一步追求无偏且低方差的 KL 散度估计。
k1 估计
KL 散度定义为:
KL [ p , q ] = ∑ x q ( x ) log q ( x ) p ( x ) = E x ∼ q [ log q ( x ) p ( x ) ] \text{KL}[p,q]=\sum_xq(x)\log\frac{q(x)}{p(x)}=\mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right] \notag \\ KL[p,q]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
由于计算复杂度高或分布本身没有闭式解等原因,KL 散度一般是无法解析求解的,我们需要从 q q q 分布中采样样本 x 1 , x 2 , ⋯ ∼ q x_1,x_2,\dots\sim q x1,x2,⋯∼q,然后用蒙特卡洛方法对 KL 散度进行估计。我们对估计的 KL 散度有两个要求,第一最好是无偏的,即估计值的期望与真实值相等,第二方差要尽可能小。
最常规的对 KL 散度的估计,是用:
k 1 = log q ( x ) p ( x ) = − log r k1=\log\frac{q(x)}{p(x)}=-\log r \notag \\ k1=logp(x)q(x)=−logr
这里记 r = p ( x ) q ( x ) r=\frac{p(x)}{q(x)} r=q(x)p(x),我们将这个估计记作 k 1 k1 k1。 k 1 k1 k1 显然是无偏的,但是,它其中有个 log \log log 函数,当 q ( x ) p ( x ) \frac{q(x)}{p(x)} p(x)q(x) 时,它的值是负的,而我们知道 KL 散度一定是正的,所以说它的方差很大。因此 k 1 k1 k1 这个估计不满足要求 2。
f 散度与k2 估计
我们考虑第二种估计:
k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 = 1 2 ( log r ) 2 k2=\frac{1}{2}\left(\log\frac{p(x)}{q(x)}\right)^2=\frac{1}{2}(\log r)^2 \notag \\ k2=21(logq(x)p(x))2=21(logr)2
这个估计看起来很不错,它能表征出 p , q p,q p,q 之间的差异,而且它是恒正的,方差比 k 1 k1 k1 要小。
但是这个形式是哪里来的呢?实际上,这是通过更一般的 f 散度近似来的。f 散度可以看作是 KL 散度的一种推广,其定义为:
D f ( p , q ) = E x ∼ q [ f ( p ( x ) q ( x ) ) ] D_f(p,q)=\mathbb{E}_{x\sim q}\left[f(\frac{p(x)}{q(x)})\right] \notag \\ Df(p,q)=Ex∼q[f(q(x)p(x))]
其中函数 f ( ⋅ ) f(\cdot) f(⋅) 需要是凸函数。可以看到,KL 散度,实际就是取了 f ( x ) = − log ( x ) f(x)=-\log(x) f(x)=−log(x) 的 f 散度。而我们刚刚介绍的 k 2 k2 k2,则相当于是取了 f ( x ) = 1 2 ( log x ) 2 f(x)=\frac{1}{2}(\log x)^2 f(x)=21(logx)2,其期望也是一种 f 散度。
有这样一个事实:当两个概率分布 p p p 和 q q q 非常接近时,所有( f f f 可微的)f 散度在二阶近似上都会表现得非常相似,这当然也包括 KL 散度。所以,我们可以选择一个其他的 f f f 函数构建 f 散度,来近似 KL 散度,只要保证 p , q p,q p,q 比较接近时的表现相似。而要保证这一点,只需要考察二者的 f ′ ′ ( 1 ) f''(1) f′′(1),很明显, 二者的 f ′ ′ ( 1 ) f''(1) f′′(1) 都为 1。
所以,我们可以将 KL 散度近似为 f ( x ) = 1 2 ( log x ) 2 f(x)=\frac{1}{2}(\log x)^2 f(x)=21(logx)2 下的 f 散度。虽然这会带来一些偏差(后面实验显示,增加的偏差其实很小),但降低了估计值的方差。
From Kimi:
当 p p p 和 q q q 接近时,我们可以将 q ( x ) p ( x ) \frac{q(x)}{p(x)} p(x)q(x) 近似为 1 + ϵ 1+\epsilon 1+ϵ,其中 ϵ \epsilon ϵ 是一个小的偏差。使用泰勒展开,我们有:
f ( 1 + ϵ ) ≈ f ( 1 ) + f ′ ( 1 ) ϵ + 1 2 f ′ ′ ( 1 ) ϵ 2 f(1+\epsilon)\approx f(1)+f'(1)\epsilon+\frac{1}{2}f''(1)\epsilon^2 \notag \\ f(1+ϵ)≈f(1)+f′(1)ϵ+21f′′(1)ϵ2由于 f ( 1 ) = 0 f(1)=0 f(1)=0(根据 f-散度的定义),并且 E x ∼ q [ ϵ ] = 0 \mathbb{E}_{x\sim q}[\epsilon]=0 Ex∼q[ϵ]=0(因为 p 和 q 接近),所以 f-散度的二次近似主要取决于 f ′ ′ ( 1 ) f''(1) f′′(1)。
control variate与k3 估计
更进一步,我们能不能既要无偏,又要低方差呢?想要降低方差,通用的办法是控制变量(control variate),即选用无偏的 k 1 k1 k1,但是再加上一些期望为 0 ,但是与 k 1 k1 k1 负相关的项,从而在保证无偏的同时,降低方差。很巧的是,在这里 r − 1 = p ( x ) q ( x ) − 1 r-1=\frac{p(x)}{q(x)}-1 r−1=q(x)p(x)−1 就是一个期望为零的项(推导如下)。
E q [ r − 1 ] = E q [ p ( x ) q ( x ) − 1 ] = ∫ [ p ( x ) q ( x ) − 1 ] q ( x ) d x = ∫ p ( x ) d x − ∫ q ( x ) d x = 1 − 1 = 0 \begin{aligned} \mathbb{E}_q[r-1]&=\mathbb{E}_q\left[\frac{p(x)}{q(x)}-1\right] \\ &=\int\left[\frac{p(x)}{q(x)}-1\right]q(x)dx \\ &=\int p(x)dx-\int q(x)dx \\ &=1-1=0 \end{aligned} \notag \\ Eq[r−1]=Eq[q(x)p(x)−1]=∫[q(x)p(x)−1]q(x)dx=∫p(x)dx−∫q(x)dx=1−1=0
所以,对于任意的 λ \lambda λ, − log r + λ ( r − 1 ) -\log r+\lambda(r-1) −logr+λ(r−1) 都是一个无偏估计。这样我们就可以选择一个 λ \lambda λ,使得该式的方差最小。但是由于该式依赖于 p p p 和 q q q,所以没法直接解析求解。我们就直接考虑一个简单的选择,取 λ = 1 \lambda=1 λ=1,由于有 log ( x ) ≤ x − 1 \log(x)\le x-1 log(x)≤x−1,所以该式能保证是正的,已经能够尽量减小方差了。所以,我们有对 KL 散度的第三种估计:
k 3 = ( r − 1 ) − log r k3=(r-1)-\log r \notag \\ k3=(r−1)−logr
我们可以将这个思想扩展到任意的 f 散度估计上,就比如 KL [ p , q ] \text{KL}[p,q] KL[p,q](注意 p , q p,q p,q 反过来了),对它的无偏低方差估计,就可以取 r log r − ( r − 1 ) r\log r-(r-1) rlogr−(r−1)。
实验
我们进行一个简单的实验来对比这三种 KL 散度的估计。假设 q = N ( 0 , 1 ) q=N(0,1) q=N(0,1), p = N ( 0.1 , 1 ) p=N(0.1,1) p=N(0.1,1),它们真实的 KL 散度为 0.005。
import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):print((k.mean() - truekl) / truekl, k.std() / truekl
bias/true | stdev/true | |
---|---|---|
k1 | 0 | 20 |
k2 | 0.002 | 1.42 |
k3 | 0 | 1.42 |
可以看到,k2 虽然不是无偏的,但是偏差非常小,只有 0.2%,因为此时的 p , q p,q p,q 两分布很接近。
我们再将 p p p 改为 N ( 1 , 1 ) N(1,1) N(1,1),此时 KL 散度的真实值为 0.5。
bias/true | stdev/true | |
---|---|---|
k1 | 0 | 2 |
k2 | 0.25 | 1.73 |
k3 | 0 | 1.7 |
此时,k2 的偏差就非常大了,因为此时的 p , q p,q p,q 两分布已经不是那么接近了。而 k3 甚至比 k2 的方差还要低,所有看起来 k3 是一个全面更优的估计。
OpenRLHF 中实现的 k1k2k3 估计方法:link。
总结
本文中我们首先介绍了 KL 散度最常用的估计 k1,但是发现它方差非常大,然后我们介绍 f 散度并设计了对 KL 散度近似的 k2 估计,k2 降低了方差但是是有偏的。为了得到无偏且低方差的估计,我们又考虑通过 control variate 构造了 k3 估计,达到了比较理想的对 KL 散度的估计。在 RL (for LLM) 中,k2、k3 都有被选用,我们需要根据实际场景分析和实验来决定选用哪种估计(比如 k2 估计要求两分布是比较接近的,才能有降低的偏差)。