欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > KL散度的三种估计k1 k2 k3

KL散度的三种估计k1 k2 k3

2025/4/19 1:57:29 来源:https://blog.csdn.net/weixin_44966641/article/details/147128624  浏览:    关键词:KL散度的三种估计k1 k2 k3

KL散度的三种估计k1 k2 k3

原文是 John Schulman 的博客:http://joschu.net/blog/kl-approx.html

John Schulman 在多处代码实现中采用的对 KL 散度 KL [ p , q ] \text{KL}[p,q] KL[p,q] 的估计是 1 2 ( log ⁡ p ( x ) − log ⁡ q ( x ) ) 2 \frac{1}{2}(\log p(x)-\log q(x))^2 21(logp(x)logq(x))2,并非常规的 log ⁡ q ( x ) p ( x ) \log\frac{q(x)}{p(x)} logp(x)q(x)。本文将介绍这种对 KL 散度的估计形式为什么更好(虽然是有偏的),以及如何进一步追求无偏且低方差的 KL 散度估计。

k1 估计

KL 散度定义为:
KL [ p , q ] = ∑ x q ( x ) log ⁡ q ( x ) p ( x ) = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] \text{KL}[p,q]=\sum_xq(x)\log\frac{q(x)}{p(x)}=\mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right] \notag \\ KL[p,q]=xq(x)logp(x)q(x)=Exq[logp(x)q(x)]
由于计算复杂度高或分布本身没有闭式解等原因,KL 散度一般是无法解析求解的,我们需要从 q q q 分布中采样样本 x 1 , x 2 , ⋯ ∼ q x_1,x_2,\dots\sim q x1,x2,q,然后用蒙特卡洛方法对 KL 散度进行估计。我们对估计的 KL 散度有两个要求,第一最好是无偏的,即估计值的期望与真实值相等,第二方差要尽可能小。

最常规的对 KL 散度的估计,是用:
k 1 = log ⁡ q ( x ) p ( x ) = − log ⁡ r k1=\log\frac{q(x)}{p(x)}=-\log r \notag \\ k1=logp(x)q(x)=logr
这里记 r = p ( x ) q ( x ) r=\frac{p(x)}{q(x)} r=q(x)p(x),我们将这个估计记作 k 1 k1 k1 k 1 k1 k1 显然是无偏的,但是,它其中有个 log ⁡ \log log 函数,当 q ( x ) p ( x ) \frac{q(x)}{p(x)} p(x)q(x) 时,它的值是负的,而我们知道 KL 散度一定是正的,所以说它的方差很大。因此 k 1 k1 k1 这个估计不满足要求 2。

f 散度与k2 估计

我们考虑第二种估计:
k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 = 1 2 ( log ⁡ r ) 2 k2=\frac{1}{2}\left(\log\frac{p(x)}{q(x)}\right)^2=\frac{1}{2}(\log r)^2 \notag \\ k2=21(logq(x)p(x))2=21(logr)2

这个估计看起来很不错,它能表征出 p , q p,q p,q 之间的差异,而且它是恒正的,方差比 k 1 k1 k1 要小。

但是这个形式是哪里来的呢?实际上,这是通过更一般的 f 散度近似来的。f 散度可以看作是 KL 散度的一种推广,其定义为:
D f ( p , q ) = E x ∼ q [ f ( p ( x ) q ( x ) ) ] D_f(p,q)=\mathbb{E}_{x\sim q}\left[f(\frac{p(x)}{q(x)})\right] \notag \\ Df(p,q)=Exq[f(q(x)p(x))]
其中函数 f ( ⋅ ) f(\cdot) f() 需要是凸函数。可以看到,KL 散度,实际就是取了 f ( x ) = − log ⁡ ( x ) f(x)=-\log(x) f(x)=log(x) 的 f 散度。而我们刚刚介绍的 k 2 k2 k2,则相当于是取了 f ( x ) = 1 2 ( log ⁡ x ) 2 f(x)=\frac{1}{2}(\log x)^2 f(x)=21(logx)2,其期望也是一种 f 散度。

有这样一个事实:当两个概率分布 p p p q q q 非常接近时,所有( f f f 可微的)f 散度在二阶近似上都会表现得非常相似,这当然也包括 KL 散度。所以,我们可以选择一个其他的 f f f 函数构建 f 散度,来近似 KL 散度,只要保证 p , q p,q p,q 比较接近时的表现相似。而要保证这一点,只需要考察二者的 f ′ ′ ( 1 ) f''(1) f′′(1),很明显, 二者的 f ′ ′ ( 1 ) f''(1) f′′(1) 都为 1。

所以,我们可以将 KL 散度近似为 f ( x ) = 1 2 ( log ⁡ x ) 2 f(x)=\frac{1}{2}(\log x)^2 f(x)=21(logx)2 下的 f 散度。虽然这会带来一些偏差(后面实验显示,增加的偏差其实很小),但降低了估计值的方差。

From Kimi:

p p p q q q 接近时,我们可以将 q ( x ) p ( x ) \frac{q(x)}{p(x)} p(x)q(x) 近似为 1 + ϵ 1+\epsilon 1+ϵ,其中 ϵ \epsilon ϵ 是一个小的偏差。使用泰勒展开,我们有:
f ( 1 + ϵ ) ≈ f ( 1 ) + f ′ ( 1 ) ϵ + 1 2 f ′ ′ ( 1 ) ϵ 2 f(1+\epsilon)\approx f(1)+f'(1)\epsilon+\frac{1}{2}f''(1)\epsilon^2 \notag \\ f(1+ϵ)f(1)+f(1)ϵ+21f′′(1)ϵ2

由于 f ( 1 ) = 0 f(1)=0 f(1)=0(根据 f-散度的定义),并且 E x ∼ q [ ϵ ] = 0 \mathbb{E}_{x\sim q}[\epsilon]=0 Exq[ϵ]=0(因为 p 和 q 接近),所以 f-散度的二次近似主要取决于 f ′ ′ ( 1 ) f''(1) f′′(1)

control variate与k3 估计

更进一步,我们能不能既要无偏,又要低方差呢?想要降低方差,通用的办法是控制变量(control variate),即选用无偏的 k 1 k1 k1,但是再加上一些期望为 0 ,但是与 k 1 k1 k1 负相关的项,从而在保证无偏的同时,降低方差。很巧的是,在这里 r − 1 = p ( x ) q ( x ) − 1 r-1=\frac{p(x)}{q(x)}-1 r1=q(x)p(x)1 就是一个期望为零的项(推导如下)。
E q [ r − 1 ] = E q [ p ( x ) q ( x ) − 1 ] = ∫ [ p ( x ) q ( x ) − 1 ] q ( x ) d x = ∫ p ( x ) d x − ∫ q ( x ) d x = 1 − 1 = 0 \begin{aligned} \mathbb{E}_q[r-1]&=\mathbb{E}_q\left[\frac{p(x)}{q(x)}-1\right] \\ &=\int\left[\frac{p(x)}{q(x)}-1\right]q(x)dx \\ &=\int p(x)dx-\int q(x)dx \\ &=1-1=0 \end{aligned} \notag \\ Eq[r1]=Eq[q(x)p(x)1]=[q(x)p(x)1]q(x)dx=p(x)dxq(x)dx=11=0
所以,对于任意的 λ \lambda λ − log ⁡ r + λ ( r − 1 ) -\log r+\lambda(r-1) logr+λ(r1) 都是一个无偏估计。这样我们就可以选择一个 λ \lambda λ,使得该式的方差最小。但是由于该式依赖于 p p p q q q,所以没法直接解析求解。我们就直接考虑一个简单的选择,取 λ = 1 \lambda=1 λ=1,由于有 log ⁡ ( x ) ≤ x − 1 \log(x)\le x-1 log(x)x1,所以该式能保证是正的,已经能够尽量减小方差了。所以,我们有对 KL 散度的第三种估计:
k 3 = ( r − 1 ) − log ⁡ r k3=(r-1)-\log r \notag \\ k3=(r1)logr
我们可以将这个思想扩展到任意的 f 散度估计上,就比如 KL [ p , q ] \text{KL}[p,q] KL[p,q](注意 p , q p,q p,q 反过来了),对它的无偏低方差估计,就可以取 r log ⁡ r − ( r − 1 ) r\log r-(r-1) rlogr(r1)

实验

我们进行一个简单的实验来对比这三种 KL 散度的估计。假设 q = N ( 0 , 1 ) q=N(0,1) q=N(0,1) p = N ( 0.1 , 1 ) p=N(0.1,1) p=N(0.1,1),它们真实的 KL 散度为 0.005。

import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):print((k.mean() - truekl) / truekl, k.std() / truekl
bias/truestdev/true
k1020
k20.0021.42
k301.42

可以看到,k2 虽然不是无偏的,但是偏差非常小,只有 0.2%,因为此时的 p , q p,q p,q 两分布很接近。

我们再将 p p p 改为 N ( 1 , 1 ) N(1,1) N(1,1),此时 KL 散度的真实值为 0.5。

bias/truestdev/true
k102
k20.251.73
k301.7

此时,k2 的偏差就非常大了,因为此时的 p , q p,q p,q 两分布已经不是那么接近了。而 k3 甚至比 k2 的方差还要低,所有看起来 k3 是一个全面更优的估计。

OpenRLHF 中实现的 k1k2k3 估计方法:link。

总结

本文中我们首先介绍了 KL 散度最常用的估计 k1,但是发现它方差非常大,然后我们介绍 f 散度并设计了对 KL 散度近似的 k2 估计,k2 降低了方差但是是有偏的。为了得到无偏且低方差的估计,我们又考虑通过 control variate 构造了 k3 估计,达到了比较理想的对 KL 散度的估计。在 RL (for LLM) 中,k2、k3 都有被选用,我们需要根据实际场景分析和实验来决定选用哪种估计(比如 k2 估计要求两分布是比较接近的,才能有降低的偏差)。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com