0%

论文阅读 [粗读]-Improved Denoising Diffusion Probabilistic Models

这篇论文探索了 DDPM 对于 NLL 指标效果不好的原因,并且从实际训练的角度给出了很多可行的改进。

这篇论文是 21 年的 NIPS,作者来自 OpenAI,其实就是后面 GLIDE 的作者。我理解大概是 OpenAI 看到了 DDPM 的论文,然后用” 财大气粗 “的方式来了一波复现和改进。这篇论文其实更偏向于分析性文章。另外,这篇论文的方法的代码和复现性很好。 > 发现一个讲 IDDPM 代码的是视频

本文大概探索了这个几个问题:

  • DDPM 的训练 object
  • βt 的选取
  • 训练的采样方法
  • DDPM 模型的可扩展性

Denoising Diffusion Probabilistic Models

这里重新梳理了一遍 DDPM,我就说几个新颖的地方

Definition

作者写出了 LVLB 的形式

可以看出中间的项都是 KL 散度,可以通过前后两个高斯分布的均值和方差直接计算出来:

  • 正常的 DDPM 模型 pθ 模型的均值由模型得出,而方差是确定性的
  • q(xt1|xt,x0) 的均值方差均可以被代数表示

Training in Practice

这一部分,作者说到了三种模型目标:

  • 生成 μθ(xt,t),预测 xt1 的均值
  • 预测 x0, 采样时根据公式 10 线性拟合出 xt1 进行去噪
  • 预测加噪时的噪声 ϵ, 去噪根据下式和公式 10 拟合出 xt1

x0=1αt(xtβt1αtϵ)

DDPM 在训练过程中使用最后一种,表现最好

作者把 LVLB 进行了一波 reweight,得出下式作为训练 loss 进行训练 Lsimple=Et,x0,ϵ||ϵϵθ(xt,t)||2

同时,作者说到在 predict 阶段,使用下面两种方差效果很接近

  • σt2=βt
  • σt2=β~t

Improving the Log-likelihood

这里作者提到 Log-likelihood 这个指标在图像生成也很重要,但 DDPM 在这个指标的效果不好 (虽然 FID 不错)。

learning Σθ(xt,t)

作者思考为什么 DDPM 的两个方差效果接近

  • 图一可以发现两种方差只在 t 很小的时候有差距,别的时候基本一致

  • 图二可以发现 t 小的时候对 NLL 的影响是最大的

综合上面观点,作者觉得 NLL 效果差的重要原因就是对方差的估计有偏差 (上面两种 σ 都是估计的形式)。因此作者想要找到更好的方差,作者希望通过网络拟合它,看做两种 σ 的拟合 Σθ(xt,t)=exp(vlogβt+(1v)logβ~t) 其中 v 是可学习的参数。作者在这里没有对 v 的范围做限制,但实际训练还是控制在 0-1 内。说明 DDPM 作者选取的两个估计还是很靠谱的

想要学习这个 v,就只能把 loss 直接表示成 LVLB, 才能反向传播到 v,因此作者选择了最终的 loss 函数 Lhybird=Lsimple+λ·LVLB 其中 λ=0.001

Improving the Noise Schedule

这一部分,作者探索了 β 的选取

DDPM 使用 linear 来选取,作者采用新的 cosine 的公式选取 αt=f(t)f(0),f(t)=cos(t/T+s1+s·π2)2βt=min(0.999,1αtαt1) 这中选取方式比 linear 更稳定。同时作者指出,类似 cosine 的这种形状的,最后效果都差不多

Reducing Gradient Noise

第三个改进是让训练更稳定

作者发现直接优化 LVLB 的训练非常不稳定,比起优化 Lsimple。作者发现这是由于 LVLB 参数对梯度更分散。

进一步分析,发现是在训练时平均的选取 t 导致训练额外的噪声,实际上的选取应该是: LVLB=Etp(t)[Ltpt],whereptE[Lt2],pt=1 对于每个 t,在训练过程中都可以维护 Lt,只要训练开始一段时间以后。

  • 先平均选取 t,直到每个 t 都有 10 个 Lt 的历史数据为止
  • 接下来,按照前 10 个历史数据的平均作为目前 Lt 情况的估计,进而带入公式来采样下一个 t 的选取

最终的训练变得很稳定,如上图的绿线

Improving Sampling Speed

这一部分提到了可以加速采样

这里主要对比的是 DDIM,方法也差不多:采样一个 1,2,...T 的长度为 K 子序列进行 sample

可以发现,DDIM 的没有 IDDPM 好:

  • DDIM 没有做 IDDPM 里这些改进。
  • 作者提到没法做改进,这是因为 DDIM 的公式里的方差项没了

## Scaling Model Size

这一部分是分析 DDPM 模型在加大 size 以后会不会变得更好,因此作者直接加参数

  • FID 得分和模型的参数 (计算量) 基本呈线性关系。满足 power law,因此扩大规模是个好选择
  • NLL 得分不满足线性关系,扩大规模效果不好,可能原因有:
    • 很快过拟合
    • loss 不稳定
    • 作者没有直接优化降低 NLL 用的 LVLB,而是一个 Lhybrid

我的思考

  • 总体而言,这篇文章更像是分析性文章,复习笔记。主要方法还是 DDPM。
  • OpenAI 研究员解决问题的眼光真的很高明呀,尤其是关于改变 t 选取和分析 scale 能力这一部分。
Powered By Valine
v1.5.2