0%

论文阅读 [粗读]- 强化学习和 RLHF 中的 PPO 算法

今天讲讲强化学习里的经典算法 PPO,也是现在 Gym 库里默认的强化学习算法,最后再讲讲 RLHF 中的 PPO 算法是怎么算的。参考

Proximal Policy Optimization Algorithms

Trust Region Policy Optimization

A (Long) peek into Reinforcement Learning

这是 2017 年的论文,作者团队是 openAI。讲 PPO,必须先说明一大堆前置知识。我先简单说说前置知识,不保证正确。

前置知识

强化学习领域就是让智能体 agent 和环境 e 一直交互,最终强化智能体。

强化学习的一次交互为多个 state,action 链式连接。

强化学习优化的对象叫做 return。环境对于每个 action 给出的反馈叫做 reward。return 是所有未来 state 的 reward 指数衰减

returni=t=iTγtiRewardt
  • return 是个未来函数,只有跑完后面才知道前面的 reward。我们只能用模型模拟或者求期望
  • 在很多情况下,环境也不能给出密集的 reward,很有可能只有交互结束才能给一个 reward (比如下棋),这个领域叫 sparse reward
  • γ 是时间衰减函数,说明未来的收益要衰减。这是因为未来有不确定性,所以未来的回报没有现在的重要。至于为什么用指数衰减,这是未了后面的数学推导更简单

强化学习分为两大类方法,基于动作的 policy method 和基于评估的 value method

value method

这是指用一个模型来拟合 value。要知道 value,要先知道 Q (s,a)

Q (s,a) 是衡量模型在状态 s 做出动作 a 到底好不好,值就等于未来状态 return。$Q\pi(s_t,a) = return{t+1}$

Q (s,a) 包含动作 a,不好。把 a 积分掉 (或者离散情况下就是 ) 以后得到 $V\pi(S) = \sum{a} \pi(a|s)Q(s,a)$

另外,Adavantagea=Q(s,a)V(s) 专门表示采取动作 a 带来了多少额外收益

经典的 Q-learning 用模型去学习 Q (s,a),怎么学?通过贝尔曼方程,根据 Q 的定义

V(St)=Rt+1+γV(St+1)...V(ST1)=RTQ(st,at)=Rt+γmaxaAQ(st+1,a)
  • $Q^Qs_ta_tQ^(s_t,a_t)$ 了

这衍生出了 TD-error (temporal difference): 等式右边的 V 估计更准确,因为右边计算用到了真实的回报 Rt+1。通过这个归纳偏置就可以学习 Q 了:

  • 用一个模型拟合 Qθ(st,at)

  • 在状态 $st\epsilon-\text {greedy}\epsilonQ\theta (s_t,a_t)a_t$ 采取动作

  • 执行 $ats{t+1}R_t$
  • 对于 $Q\theta (s_t,a_t)R_t + \gamma \max{a\in A} Q\theta(s{t+1},a)$。用做 target,然后用 mse loss 优化就行

Q-learning 里的最经典的算法叫做 deep-Q Network,大概做了几个改进:

  • 经验重放:就是说实际在环境里跑交互太慢了,我可以提前跑,跑完把状态转移对 $(st,a_t,s{t+1},R_t)$ 存下来,然后再多次更新、多次利用
  • periodically update target,就是我算 loss 的时候那个 target 里的 Q 不用现在的参数,而是用一段时间之前的参数,每隔一段时间同步一次。这是为了减小 loss 尖峰带来的影响,让训练更稳定

policy method

前面提到 Q-learning 等方法学习 Q,然后采取动作是用 ϵgreedy 解码。这里的思路是说我不学 Q,我学习一个 agent π(a|st) 输入一个状态,我返回在这个状态下做所有动作的概率分布。

优化的目标当然就是让好的动作概率更大,坏的动作概率小

J(θ)=sS(aAπ(a|s)Qπ(s,a))

经典的 REINFORCE 算法就是一大堆数学推导,最后使得

ΔJ(θ)=Eπθ[Δlnπ(a|s)Qπ(s,a)]

ΔJ(θ)θ 的函数,可以看做 loss,用梯度上升来做参数更新

很优雅,用了 log 换元。但是上面的式子要求期望,还是有两个 ,显然没法直接算。一个简单的方法就是用蒙特卡洛近似,就是在环境里跑完一大堆状态转移,然后倒着算回来所有的 Qπ(s,a)(这里视为这个是数字,和 θ 无关,是一个近似)。再拿上面的式子就能训练 π

上面的 reinforce 算法有个问题:上面的梯度受到环境影响大,方差大不稳定,要找个东西减小方差,这叫做 baseline 方法。

一个思路就是用上面提到的 advantage=Q(s,a)V(s) 代替 Q (s,a)。

actor-critic

上面说我们蒙特卡洛模拟,然后反着求 Q,这个 Q 的估计显然是不准确的,有解决办法吗?我们搞 ai 就是 working is all you need:我们用另一个模型拟合 Vθ(S)。这种两个模型的方法就是 actor-critic。一个演员采取行动,一个裁判进行打分。

经典的 actor-critic 是用另一个模型来拟合 Q(s,a)。Q 的参数更新就用上面讲到的 Q-learning 里的 temporal difference (TD-error),然后 π 的参数更新用模型输出的 Q 来算 J(θ)。注意这里要把 Q 的输出的梯度 detach 掉,就是不要反向传播到 Q 的参数里。

这种双模型的方法是现在 RL 的最常用的方法。简单区分几种方法,来自知乎:

actor 就好比是你,critic 就好比你妈。你做一件事情,比如抓蜜蜂,结果被蜇疼了,下次你再抓蜜蜂的概率就减小了,这个就是 policy gradient。你刚手伸出去要去抓蜜蜂,你妈就说,别抓,十有八九会被蜇疼。你听了后停止了抓蜜蜂,并且下次抓蜜蜂的概率减小了,这个就是 actor-critic。你每次看见蜜蜂的时候都问你妈,抓蜜蜂好还是不抓蜜蜂好?你妈说不抓蜜蜂好,通常你听你妈的话就不抓蜜蜂了,偶尔心情不好的时候(以 ϵ 的概率)还要去抓蜜蜂,这个就是 Q-learning。那么妈妈是怎么知道抓蜜蜂会疼的?当然她也是抓过蜜蜂的(Q-value update)

PPO

下面正式开始 PPO 讲解。PPO 基于更早的一个叫 TRPO (trust region) 的算法。可以理解成一个 actor-critic 方法,大概做了几个改进。

Trust region 是一篇数学很多的论文,大概讲的事情是

J(θ)=sS(aAπ(a|s)At)

上面的估计是很不稳定的,要想办法用另一个让它变稳定:用一个叫做的方法,可以用另一个分布来估计原始的分布。

J(θ)Eπθold[πθ(at|st)πθold(at|st)At]

这样这个分布的采用使用 θold 就在数学上成立了。然后训练时我们每隔一段时间更新一次 θ,在这段时间内用之前跑出来的状态转移对进行训练就行。极大地提高了训练效率。

我们希望两个分布不要差太远,这是因为上式可以一阶等价于 J(θ),因此需要分布很接近才行。因此可以再给 loss 加一个 KL 散度惩罚。

Et[πθ(at|st)πθold(at|st)AtKL(πθ(·|st),πθold(·|st))]

这就是论文里说的代理训练目标

PPO 把这个分式重新命名了一下

rt(θ)=πθ(at|st)πθold(at|st)

把 TRPO 中的 loss 称为 LCPI=E[rt(θ)At]

然后提出了两个变体

LCLIP=E[min(rt(θ)At,clip(rt(θ,1ϵ,1+ϵ)At)]LKLPEN=Et[rt(θ)AtβKL(πθ(·|st),πθold(·|st))]

这两个变体都能使得训练更稳定。其中变体二的超参数 β 是自适应的:

  • ββ2 如果 KL/1.5>KL
  • β2β 如果 KL1.5<KL

除此之外,模型的更新就和 actor-critic 方法没什么别的区别了。

  • 它使用了经验重放,交互用的链可以更新多次 (4 epoch) 再同步参数 θold

  • 他的更新用的不是 Q, 而是 Advantage At,这是为了减少方差,稳定训练。

At=δt+(λγ)δt+1+...+(λγ)Tt+1δT1δt=rt+γV(st+1)V(st)δt=rt+γV(st+1)V(st)
  • 其中 V(s) 是需要训练的第二个模型。然后训练流程是先用 π 跑完一整轮,得到 V(s1),V(sT) 再倒过来算出所有的 At,再把 At 拿过来训练 π
  • V 的值同样通过 TD-error 的 mse-loss 进行更新 Vtarget 注意计算要用到那个慢更新的优化,用的是一段时间以前的参数算的 V

RLHF 中的 PPO

上面讲完了 PPO,那 RLHF 中的 PPO 又是怎么算的呢?这一部分我是根据代码阅读的结果来讲的,实际上可能每边都有每边的实现,我讲其中一种。

RLHF 算法分 3 个大块:SFT,reward model training, RLHF。我这里讲最后一个部分,也就是说我们已经有了一个 reward model 可以给任何一个 (query, response) 对输出一个 r[1,1] 的讲理,越高说明越好越符合人的期望输出。这个就是假的” 环境”,可以给反馈

what is model

首先,PPO 明显是需要两个模型 policy model, value model。怎么实现呢?

  • policy model 其实就是语言模型自己,我们把一次 response 生成的每个 token 生成视为一个状态转移,然后 action 就是对应的 token t。
  • value model 这里和语言模型是共享参数的,我们在最后语言模型的 hidden state 层后面加一个 model_dim -> 1 的映射(正常是 model_dim->vocab_size->softmax 映射到词表),所以所有的状态下都有一个 float 的输出,把这个当做 V(st)
  • 至于两个模型要不要共享参数,是个实验问题。我只能说,如果不共享参数的话,首先需要存两份,然后每次交互都要跑两次前向,反向传播也是,这个在尤其是大模型场景下,对算力的消耗是多很多的。

我们需要记载一个最开始的模型的参数,这个开始是指原始的 LLM 语言模型的参数 πstart,这是为了让 RLHF 阶段的更新不要太多,不要丢失语言模型原本的语言、推理能力

advantage

然后就是 reward 的计算

  • 对于正常的 token,除了最后一个 token 以外,reward 就是引入模型分布和原始分布的 KL 散度 $KL (\pi\theta(r|q), \pi{\theta_{start}}(r|q))$
  • 对于最后一个 token,额外引入最终 response 的 reward,就是 reward 和 KL 散度的和。

KL 散度怎么算?其实就是 logP(|s) 的差值。其实就是你的模型的输出 logit [seqlength,vocabsize]。过完 log-softmax 层,然后把你实际取的 token 的对应位置的值取出来,得到一个 logprobs [seqlength,float]。把现在的参数 θ 和之前参数 θstart 的对应 logprobs 都取出来,然后直接做减法就行,得出来 seq-length 长度的序列,就是对应每个位置的 KL 散度 reward

有了所有未知的 reward 以后,就可以用 PPO 的方法倒过来算出所有位置的 advantage $AtA_tValueoldV{old}(s_t)\theta$ 的输出,这是为了训练更稳定

At=δt+(λγ)δt+1+...+(λγ)Tt+1δT1δt=rt+γVold(st+1)Vold(st)

value model loss

接下来计算 Value 的 loss,使用 lossV=||V(st)A(st)||2 作为 loss,这里可以进行一下 clip 使得更新不要太大

lossVclip=||clip(V(st),Voldϵ1,Vold+ϵ1)A(st)||2lossV=lossV+lossVclip2

policy model loss

接下来计算 policy 的梯度。首先是 openAI 定义的那个 rt(θ)。因为是除法,相当于 lo 概率的减法再 e 指数。直接用之前算 KL 时那个 log-softmax 的输出减法再指数就行,得出来也是 seq-length 长度的链。

LCLIP=E[min(rt(θ)At,clip(rt(θ,1ϵ2,1+ϵ2)At)]losspg=LCLIP

然后要注意!这里的 L 本来按 policy gradient 算法是要最大化的,所以在 AI 框架实现中最后要取个符号作为 loss (或者就直接用负的,然后换成 max)

最终得到最终的优化目标

loss=losspg+α·lossV

训练框架

  • 1. 初始化模型 θ,用一个预训练 LLM 来初始化。同时初始化 value model, 其实就是加个 linear 层。
  • 2. 用现在的 θ 跑一堆 response 数据。query 就是采样自你的 query 数据集
  • 3. 把 repsonse 送进 reward model 打分。注意 reward model 要锁参,然后这里记得 detach。
  • 4. 存下来 2,3 出来的 (query,response,reward) 对,记录现在的 θθold
  • 5. 对 4 的数据做几个 epoch 的更新 (比如 4 个),每个 epoch
    • 按上述方法算出 losspg,lossV
    • 不管是不是分离的 policy、value model,反正进行梯度更新
  • 6. 删除刚才的跑出来的数据集 4,节省空间。回到 2
Powered By Valine
v1.5.2