直接策略优化

2024-03-10

前段时间看了一下去年NeruIPS的一篇关于RLHF的文章,这篇文章介绍了一种新的算法——直接偏好优化(Direct Preference Optimization, DPO),它用于优化大规模非监督语言模型(Language Models, LMs),使其行为更符合人类的偏好。随着大语言模型技术的不断发展,对模型的对齐也越来越重要。大语言模型虽然能够学习广泛的世界知识和一些推理技能,但由于其训练的非监督性质,精确控制它们的行为非常困难。目前的方法通过收集人类对模型生成内容质量的标签,然后使用基于人类反馈的强化学习(Reinforcement Learning from Human Feedback, RLHF)对其进行微调,以符合这些偏好。RLHF是一个复杂且通常不稳定的过程,它首先拟合一个反映人类偏好的奖励模型,然后使用强化学习对大型非监督LM进行微调,以最大化这个估计的奖励,同时不过分偏离原始模型。可以看到,RLHF的实现需要多个模型的参与,这也使得实现其变得非常困难。为了简化这一过程,文章提出了DPO算法,它通过一种新的奖励模型参数化,可以直接提取相应的最优策略,而不需要复杂的强化学习过程。DPO算法稳定、高效且计算成本较低,它通过简单的分类损失来优化策略,从而符合人类的偏好。

本文主要注重对文章如何推到得到一个新的目标函数进行理论解释。

Formulate the optimization problem

和之前RLHF的方法一样,DPO使用了常用的Bradley-Terry (BT)模型来量化偏好。具体过程如下:

对于一对偏好数据$(x, y_w, y_l)$,其中$y_w$为对于$x$偏好的回答,$y_l$为对于$x$不偏好的回答,模型认为这两个回答背后都有一个潜在的奖励值$r^{\star}(x,y)$,并且偏好是由下面的概率决定的:

$$\begin{aligned} p^{\star}(y_1\succ y_2\mid x)=\frac{\exp\left(r^\star(x,y_1)\right)}{\exp\left(r^{\star}(x,y_1)\right)+\exp\left(r^{\star}(x,y_2)\right)}. \end{aligned}$$

对于一组从上面的概率采样出的偏好数据\(\mathcal{D}=\left\{x^{(i)},y_w^{(i)},y_l^{(i)}\right\}_{i=1}^N\),我们就有如下的负似然损失:

$$\begin{aligned} \mathcal{L}_R(r_\phi,\mathcal{D})=-\mathbb{E}_{(x,y_{w},y_{l})\sim\mathcal{D}}\left[\log\sigma(r_\phi(x,y_w)-r_\phi(x,y_l))\right] \end{aligned}$$

在传统RLHF中,算法将最小化该损失来得到对$r^{\star}$的估计$r_\phi$,并且利用该奖励模型在RL阶段优化下面这个目标函数:

$$\begin{aligned} \max_{\pi_\theta}\mathbb{E}_{x\sim\mathcal{D},y\sim\pi_\theta(y|x)}\begin{bmatrix}r_\phi(x,y)\end{bmatrix}-\beta\mathbb{D}_{\mathbf{KL}}\begin{bmatrix}\pi_\theta(y\mid x)\parallel\pi_{\mathrm{ref}}(y\mid x)\end{bmatrix} \end{aligned}$$

其中第一项是为了最大化模型输出的奖励,后一项是为了不让训练后的模型和原来的模型有太大的偏移。

Solve the optimal solution

传统的RLHF对上述的目标函数使用PPO等算法来逼近最优解,而在DPO中,作者对上述目标函数进行了求解。过程如下:

$$ \begin{aligned} \max_{\pi}\mathbb{E}_{x\sim\mathcal{D},y\sim\pi}& \begin{bmatrix}r(x,y)\end{bmatrix}-\beta\mathbb{D}_{\text{KL}}\begin{bmatrix}\pi(y|x)&\mid\mid\pi_{\text{ref}}(y|x)\end{bmatrix} \\ &=\max_\pi\mathbb{E}_{x\thicksim\mathcal{D}}\mathbb{E}_{y\thicksim\pi(y|x)}\left[r(x,y)-\beta\log\frac{\pi(y|x)}{\pi_{\mathsf{ref}}(y|x)}\right] \\ &=\min_\pi\mathbb{E}_{x\sim\mathcal{D}}\mathbb{E}_{y\sim\pi(y|x)}\left[\log\frac{\pi(y|x)}{\pi_{\mathrm{ref}}(y|x)}-\frac1\beta r(x,y)\right] \\ &=\min_\pi\mathbb{E}_{x\sim\mathcal{D}}\mathbb{E}_{y\sim\pi(y|x)}\left[\log\frac{\pi(y|x)}{\frac1{Z(x)}\pi_{\text{ref}}(y|x)\exp\left(\frac1\beta r(x,y)\right)}-\log Z(x)\right] \end{aligned} $$

其中,

$$\begin{aligned} Z(x)=\sum_y\pi_\text{ref}(y|x)\exp\left(\frac1\beta r(x,y)\right) \end{aligned}$$

当我们进行如下定义后

$$\begin{aligned} \pi^*(y|x)=\frac1{Z(x)}\pi_\text{ref}(y|x)\exp\left(\frac1\beta r(x,y)\right) \end{aligned}$$

便得到了

$$\begin{aligned} \min_\pi\mathbb{E}_{x\sim\mathcal{D}}\left[\mathbb{E}_{y\sim\pi(y|x)}\left[\log\frac{\pi(y|x)}{\pi^*(y|x)}\right]-\log Z(x)\right]=\\\min_\pi\mathbb{E}_{x\sim\mathcal{D}}\left[\mathbb{D}_{\mathrm{KL}}(\pi(y|x)\parallel\pi^*(y|x))-\log Z(x)\right] \end{aligned}$$

可以很显然地看出,为了是整个最小,KL散度需要为0,也就得出了最优解

$$ \pi(y|x)=\pi^*(y|x)=\frac1{Z(x)}\pi_{\text{ref}}(y|x)\exp\left(\frac1\beta r(x,y)\right) $$

得到最优解后,我们也可以发现这个解是非常难以估算的,这也是为什么强化学习在这个地方被引入的原因。

DPO Loss

将上面得到的最优解进行巧妙地变形,得到

$$ r(x,y)=\beta\log\frac{\pi_r(y\mid x)}{\pi_{\mathrm{ref}}(y\mid x)}+\beta\log Z(x). $$

再将得到的这个表达式代入之前所说的奖励偏好模型中,便可以得到

$$ p^*(y_1\succ y_2\mid x)=\frac1{1+\exp\left(\beta\log\frac{\pi^*(y_2\mid x)}{\pi_{\mathrm{ref}}(y_2\mid x)}-\beta\log\frac{\pi^*(y_1\mid x)}{\pi_{\mathrm{ref}}(y_1\mid x)}\right)} $$

可以看到,因为偏好只依赖于两者之差,表示prompt x的Z(x)被消除了。像之前一样,我们便得到了DPO的Loss函数:

$$ \mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})=-\mathbb{E}_{(x,y_w,y_l)\thicksim\mathcal{D}}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}-\beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right] $$

求导,得到

$$ \begin{aligned} &\nabla_\theta\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})=\\&-\beta\mathbb{E}_{(x,y_w,y_l)\thicksim\mathcal{D}}\left[\quad{\sigma(\hat{r}_\theta(x,y_l)-\hat{r}_\theta(x,y_w))}\quad\left[\quad{\nabla_\theta\log\pi(y_w\mid x)}-{\nabla_\theta\log\pi(y_l\mid x)}\right]\right] \end{aligned} $$

其中,

$$ \hat{r}_\theta(x,y)=\beta\log\frac{\pi_\theta(y|x)}{\pi_{\mathrm{ref}}(y|x)} $$

可看出,这个隐含的奖励函数只由策略函数决定。

DPO Outline

DPO算法的一般流程如下:1) 对于每个提示$x$,从参考策略$\pi_{\text{ref}}(\cdot\mid x)$中采样生成回答$y_1$和$y_2$,使用人类偏好进行标注,以构建离线偏好数据集$\mathcal{D}=\left\{x^{(i)},y_w^{(i)},y_l^{(i)}\right\}_{i=1}^N$;2) 对给定的参考策略$\pi_{ref}$、数据集$\mathcal{D}$和期望的$\beta$值,优化语言模型$\pi_\theta$以最小化DPO损失函数。