前言
前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不专业的地方,欢迎大家来指正!
参考文章:FedMDFG: Federated Learning with Multi-Gradient Descent and Fair Guidance (AAAI-2023)
联邦学习背景
横向联邦学习它的基本原理很通俗易懂,就是想象一下这样的场景:
-
有m个用户,原本每个用户都只用自己的数据来各顾各地进行本地的模型训练,这就叫做Individual Learning;
-
而如果把所有用户的数据都上传到一处地方,然后在那儿用收集起来的所有数据进行模型训练,就叫做Centralized Learning;
-
而现在用户不愿意分享数据了,但又怕只做本地模型训练得到的模型泛化性不好,想共同合作训练一个模型,这个时候,它们想到,把本地训练的结果/梯度上传到一个地方(称之为server),然后server收集到本地模型/梯度后,进行聚合运算,得到一个新的模型(称之为全局模型),并将其发给用户;用户基于收到这个新的全局模型来再进行本地训练,完后再回传给server进行聚合,如此往复。这样一来,用户就不需要share数据了,同时又能从这样的联合训练过程中获得一个泛化性更好的模型。
上述三种模型训练方式的对比如下图:
总结来说,generic FL的流程如下:
-
在第 t 轮通信时,Server把global model分发给各个clients。
-
Clients收到模型后,进行本地训练,然后上传模型更新的gradients(或local updates)给server。
-
Server基于收到的gradients(或local updates)更新global model。
-
t=t+1 ,然后回到第1步,并循环若干轮直到达到终止条件。
公平联邦学习
因为联邦学习是一个涉及多个用户共同参与的过程,因此合作的公平性尤为重要。FL的公平性有很多种,其中最直观的就是performance fairness。简单来说,当一个模型在某些用户的数据上精度很高,但在其他用户上则表现很差,则这个模型是不公平的了,例如下图:
为了提升模型的performance fairness,我们的目标是去让模型在各个clients上的精度更平均一些。那怎么衡量公平性呢?我们很容易想到用模型在各个clients上的accuracy的标准差。但标准差不是scale-free的。为此,这里我们引入余弦相似度来衡量公平性:
导致模型不公平的直接原因
为了提升FL模型的公平性,前人提出了很多复杂的方法,例如降低模型的本地更新与全局更新的冲突;在模型聚合的时候提高效果较弱的用户的权重,等等。但究其根本,很容易想到,导致模型不公平有两个直接原因:
-
使用了一个会加剧不公平的更新方向来更新模型;
-
使用了一个不恰当的学习率来更新模型。
具体来说,如果用一个错误的更新方向,它对于某些用户而言不是梯度下降的,因此就会直接导致模型不公平。此外,哪怕更新方向能促进公平,但由于步长(学习率)太大了,也会破坏模型的公平性。
为此,FedMDFG这篇文章从多目标优化的角度来给出了提升模型公平性的方法。在我之前的多目标优化研究里面,通常是有若干个目标函数,然后并不是用加权聚合的方法把它转成单目标,而是用多目标优化算法,使得解逐渐趋于帕累托最优。帕累托最优的概念可以参考此链接:多目标优化。
而对于联邦学习而言,如果我们把模型参数ω看作是决策变量,每个用户的training loss(记为L)看作优化目标,那么就可以构建这样一个多目标优化问题:
由于这是一个连续变量的优化问题,我们可以用Multiple Gradient Descent Algorithm (MGDA)来求解出一个common descent direction来更新模型,因为该direction是common descent的,所以它能同时让模型在各个用户上的效果变好。
但“不让模型在某些用户上的效果变差”只是FL公平性的一个必要条件,只做到这一点,并不能使模型变得公平。举个简单的例子:假设有两个用户,它们的loss值分别是5和6,用common descent direction来更新模型后,两者的loss值变为4和1,显然,虽然两者的loss都下降了,但模型变得更加不公平了。
文章画了一个直观的图来展现什么样的更新方向才能促进模型公平性:
从图中可看出,三个用户的梯度g1, g2, g3相互冲突,中间黄色+灰色锥的表示所有满足common descent direction的向量组成的区域,而黄色锥中的更新方向既满足common descent,又能够避免上面所说的导致模型不公平的问题。可以直观地看到,前面的几种公平联邦学习算法,并不能确保更新方向是common descent的。因此它们不但会破坏模型的公平性,甚至可能还会出现收敛问题,尤其是在non-IID的场景下。
为了计算出一个公平的更新方向(即落在黄色区域内),文章在FL的多目标优化的formulation基础上,进一步引入了一个新的优化目标:
其中h是这样一个向量:
它是在clients的loss组成的向量L(ω)上作一个垂直的向量,并进行归一化所得。如下图所示:
同时优化这个目标,就能让L(ω)在偏离公平引导向量p=(1,1,...,1)的时候,将它拉回去,促进了模型的公平性。
因此,文章的优化目标变成:
至此,我们就可以用MGDA来求解上述多目标优化问题:
-
首先创建一个矩阵Q,它由上述的m+1个目标函数的梯度拼接而成。
-
然后求解下述dual problem得到λ:
-
最后模型更新方向 就等于:
根据对偶理论,上述优化问题是下面这个优化问题的dual problem。
因此,上面求解得到的 满足:
-
如果 已经是Pareto critical,则 ;
-
如果 还不是是Pareto critical,则:
因此, 不单对于各用户的loss而言是common descent的,它对于新增的目标而言也是common descent的,因此它能够促进模型公平性。
步长线性搜索
文章还给出了一套适用于FL的线性搜索步长的策略,很容易实施。简单来说,就是server在计算出公平更新方向 并发给clients后,clients用这个 以及一系列从大到小的步长,分别进行local training,并评估结果的loss,发给server。server在收集所有的loss反馈后,得到一个最大的、并且相较于旧模型在精度和公平性上变得更好的步长,来作为最终正式更新模型的参数。值得注意的是,由于这个额外的通信过程中,只需要来回传标量值,因此它并不会增加很多额外的通信开销。
我觉得这是很有意义的一件事情,步长线性搜索策略能让FL不再依赖调节learning rate这个超参数。并且它是目前而言绝无仅有的可行的方法。我分析原因大概是这样:想要确保模型能收敛的同时,用一个更好、收敛性更快的步长,则最好要保证模型的更新方向是common descent的。但此前的FL算法并不能保障梯度下降,因此,如果贸然地加大更新步长(learning rate),则无法保障模型收敛。而FedMDFG因为能计算出一个common descent且公平的梯度下降方向,所以可以用line search搜索出恰当的步长来更新模型。反之,如果用的是传统的FL算法,比如FedAvg,它们不能确保更新方向是common descent的,如果强行用line search,则需要结合更新方向,设计一套更加复杂的步长搜索策略,才能保障收敛性。
实验
文章还给出了算法的收敛性严格证明,并且在多个场景下进行了实验。具体的实验设置这里就不赘述了,这里贴其中一个实验图,可以直观看到FedMDFG显著地提升了算法的收敛速度,以及收敛效果。并且在公平性上也明显好很多。文章还给出了复现的代码。
后记
将多目标优化与联邦学习结合,确实是一个很令人信服的方法。它能够具有理论保障地改善联邦学习的公平性,使得联邦学习在non-IID的场景下表现更佳。并且引入步长的线性搜索策略,能让联邦学习更具备落地实用性。
我目前已成功将公平联邦学习算法应用到智能电网的负荷预测和非侵入式负荷监测中,取得了令人满意的效果。后续我会继续关注这一块。希望这篇文章在帮助自己记录学习点滴之余,也能帮助大家!