今天讲讲强化学习里的经典算法 PPO,也是现在 Gym 库里默认的强化学习算法,最后再讲讲 RLHF 中的 PPO 算法是怎么算的。参考
Proximal Policy Optimization Algorithms
Trust Region Policy Optimization
A (Long) peek into Reinforcement Learning
这是 2017 年的论文,作者团队是 openAI。讲 PPO,必须先说明一大堆前置知识。我先简单说说前置知识,不保证正确。
前置知识
强化学习领域就是让智能体 agent 和环境 e 一直交互,最终强化智能体。
强化学习的一次交互为多个 state,action 链式连接。
强化学习优化的对象叫做 return。环境对于每个 action 给出的反馈叫做 reward。return 是所有未来 state 的 reward 指数衰减
- return 是个未来函数,只有跑完后面才知道前面的 reward。我们只能用模型模拟或者求期望
- 在很多情况下,环境也不能给出密集的 reward,很有可能只有交互结束才能给一个 reward (比如下棋),这个领域叫 sparse reward
是时间衰减函数,说明未来的收益要衰减。这是因为未来有不确定性,所以未来的回报没有现在的重要。至于为什么用指数衰减,这是未了后面的数学推导更简单
强化学习分为两大类方法,基于动作的 policy method 和基于评估的 value method
value method
这是指用一个模型来拟合 value。要知道 value,要先知道 Q (s,a)
Q (s,a) 是衡量模型在状态 s 做出动作 a 到底好不好,值就等于未来状态 return。$Q\pi(s_t,a) = return{t+1}$
Q (s,a) 包含动作 a,不好。把 a 积分掉 (或者离散情况下就是
另外,
经典的 Q-learning 用模型去学习 Q (s,a),怎么学?通过贝尔曼方程,根据 Q 的定义
- $Q^
s_t a_t Q^(s_t,a_t)$ 了
这衍生出了 TD-error (temporal difference): 等式右边的 V 估计更准确,因为右边计算用到了真实的回报
用一个模型拟合
在状态 $st
\epsilon-\text {greedy} \epsilon Q\theta (s_t,a_t) a_t$ 采取动作- 执行 $at
s{t+1} R_t$ - 对于 $Q\theta (s_t,a_t)
R_t + \gamma \max{a\in A} Q\theta(s{t+1},a)$。用做 target,然后用 mse loss 优化就行
Q-learning 里的最经典的算法叫做 deep-Q Network,大概做了几个改进:
- 经验重放:就是说实际在环境里跑交互太慢了,我可以提前跑,跑完把状态转移对 $(st,a_t,s{t+1},R_t)$ 存下来,然后再多次更新、多次利用
- periodically update target,就是我算 loss 的时候那个 target 里的
不用现在的参数,而是用一段时间之前的参数,每隔一段时间同步一次。这是为了减小 loss 尖峰带来的影响,让训练更稳定
policy method
前面提到 Q-learning 等方法学习 Q,然后采取动作是用
优化的目标当然就是让好的动作概率更大,坏的动作概率小
经典的 REINFORCE 算法就是一大堆数学推导,最后使得
很优雅,用了 log 换元。但是上面的式子要求期望,还是有两个
上面的 reinforce 算法有个问题:上面的梯度受到环境影响大,方差大不稳定,要找个东西减小方差,这叫做 baseline 方法。
一个思路就是用上面提到的 advantage
actor-critic
上面说我们蒙特卡洛模拟,然后反着求 Q,这个 Q 的估计显然是不准确的,有解决办法吗?我们搞 ai 就是 working is all you need:我们用另一个模型拟合
经典的 actor-critic 是用另一个模型来拟合
这种双模型的方法是现在 RL 的最常用的方法。简单区分几种方法,来自知乎:
actor 就好比是你,critic 就好比你妈。你做一件事情,比如抓蜜蜂,结果被蜇疼了,下次你再抓蜜蜂的概率就减小了,这个就是 policy gradient。你刚手伸出去要去抓蜜蜂,你妈就说,别抓,十有八九会被蜇疼。你听了后停止了抓蜜蜂,并且下次抓蜜蜂的概率减小了,这个就是 actor-critic。你每次看见蜜蜂的时候都问你妈,抓蜜蜂好还是不抓蜜蜂好?你妈说不抓蜜蜂好,通常你听你妈的话就不抓蜜蜂了,偶尔心情不好的时候(以
的概率)还要去抓蜜蜂,这个就是 Q-learning。那么妈妈是怎么知道抓蜜蜂会疼的?当然她也是抓过蜜蜂的(Q-value update)
PPO
下面正式开始 PPO 讲解。PPO 基于更早的一个叫 TRPO (trust region) 的算法。可以理解成一个 actor-critic 方法,大概做了几个改进。
Trust region 是一篇数学很多的论文,大概讲的事情是
上面的估计是很不稳定的,要想办法用另一个让它变稳定:用一个叫做的方法,可以用另一个分布来估计原始的分布。
这样这个分布的采用使用
我们希望两个分布不要差太远,这是因为上式可以一阶等价于
这就是论文里说的代理训练目标
PPO 把这个分式重新命名了一下
把 TRPO 中的 loss 称为
然后提出了两个变体
这两个变体都能使得训练更稳定。其中变体二的超参数
如果 如果
除此之外,模型的更新就和 actor-critic 方法没什么别的区别了。
它使用了经验重放,交互用的链可以更新多次 (4 epoch) 再同步参数
他的更新用的不是 Q, 而是 Advantage
,这是为了减少方差,稳定训练。
- 其中
是需要训练的第二个模型。然后训练流程是先用 跑完一整轮,得到 再倒过来算出所有的 ,再把 拿过来训练 - V 的值同样通过 TD-error 的 mse-loss 进行更新
注意计算要用到那个慢更新的优化,用的是一段时间以前的参数算的 V
RLHF 中的 PPO
上面讲完了 PPO,那 RLHF 中的 PPO 又是怎么算的呢?这一部分我是根据代码阅读的结果来讲的,实际上可能每边都有每边的实现,我讲其中一种。
RLHF 算法分 3 个大块:SFT,reward model training, RLHF。我这里讲最后一个部分,也就是说我们已经有了一个 reward model 可以给任何一个 (query, response) 对输出一个
what is model
首先,PPO 明显是需要两个模型 policy model, value model。怎么实现呢?
- policy model 其实就是语言模型自己,我们把一次 response 生成的每个 token 生成视为一个状态转移,然后 action 就是对应的 token t。
- value model 这里和语言模型是共享参数的,我们在最后语言模型的 hidden state 层后面加一个 model_dim -> 1 的映射(正常是 model_dim->vocab_size->softmax 映射到词表),所以所有的状态下都有一个 float 的输出,把这个当做
- 至于两个模型要不要共享参数,是个实验问题。我只能说,如果不共享参数的话,首先需要存两份,然后每次交互都要跑两次前向,反向传播也是,这个在尤其是大模型场景下,对算力的消耗是多很多的。
我们需要记载一个最开始的模型的参数,这个开始是指原始的 LLM 语言模型的参数
advantage
然后就是 reward 的计算
- 对于正常的 token,除了最后一个 token 以外,reward 就是引入模型分布和原始分布的 KL 散度 $KL (\pi\theta(r|q), \pi{\theta_{start}}(r|q))$
- 对于最后一个 token,额外引入最终 response 的 reward,就是 reward 和 KL 散度的和。
KL 散度怎么算?其实就是
有了所有未知的 reward 以后,就可以用 PPO 的方法倒过来算出所有位置的 advantage $At
value model loss
接下来计算 Value 的 loss,使用
policy model loss
接下来计算 policy 的梯度。首先是 openAI 定义的那个
然后要注意!这里的 L 本来按 policy gradient 算法是要最大化的,所以在 AI 框架实现中最后要取个符号作为 loss (或者就直接用负的,然后换成
最终得到最终的优化目标
训练框架
- 1. 初始化模型
,用一个预训练 LLM 来初始化。同时初始化 value model, 其实就是加个 linear 层。 - 2. 用现在的
跑一堆 response 数据。query 就是采样自你的 query 数据集 - 3. 把 repsonse 送进 reward model 打分。注意 reward model 要锁参,然后这里记得 detach。
- 4. 存下来 2,3 出来的 (query,response,reward) 对,记录现在的
为 - 5. 对 4 的数据做几个 epoch 的更新 (比如 4 个),每个 epoch
- 按上述方法算出
- 不管是不是分离的 policy、value model,反正进行梯度更新
- 按上述方法算出
- 6. 删除刚才的跑出来的数据集 4,节省空间。回到 2
v1.5.2