DDPM 数学推导
- 变分下界 (ELBO):将最大化对数似然转化为最小化变分下界,明确我们的目标是匹配逆向分布。
- 后验分布 $q(x_{t-1} \mid x_t, x_0)$:推导扩散过程逆向步骤的“真值”,为模型提供学习的靶子。
- 损失函数:结合前两步,将复杂的分布匹配转化为简单的噪声预测 (MSE)。
1. 变分下界 (Variational Lower Bound / ELBO)
我们的终极目标是训练一个模型来生成数据,即最大化 $p_\theta(x_0)$。但由于直接积分计算边缘概率极其困难,我们必须退而求其次,寻找一个可优化的下界(ELBO)。
1.1 最大似然与 Jensen 不等式
根据边缘概率定义:
\[p_\theta(x_0) = \int p_\theta(x_{0:T}) dx_{1:T}\]其中 $x_{1:T}$ 是潜变量。引入近似后验分布 $q(x_{1:T} \mid x_0)$:
\[\begin{aligned} \log p_\theta(x_0) &= \log \int p_\theta(x_{0:T}) dx_{1:T} \\\\ &= \log \int \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} q(x_{1:T} \mid x_0) dx_{1:T} \\\\ &= \log \mathbb{E}_{q(x_{1:T} \mid x_0)} \left[ \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] \end{aligned}\]根据 Jensen 不等式 $f(\mathbb{E}[x]) \ge \mathbb{E}[f(x)]$(此处 $f$ 为 $\log$ 函数,是凹函数,故 $\log \mathbb{E} \ge \mathbb{E} \log$):
\[\log p_\theta(x_0) \ge \mathbb{E}_{q(x_{1:T} \mid x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] = -L_{\text{VLB}}\]1.2 展开 ELBO (详细推导)
利用马尔可夫性质展开联合概率:
- 扩散过程 (Forward): $q(x_{1:T} \mid x_0) = \prod_{t=1}^T q(x_t \mid x_{t-1})$
- 逆向过程 (Reverse): $p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)$
代入 ELBO 公式:
\[\begin{aligned} L_{\text{VLB}} &= \mathbb{E}_q \left[ - \log \frac{p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)}{\prod_{t=1}^T q(x_t \mid x_{t-1})} \right] \\\\ &= \mathbb{E}_q \left[ - \log p(x_T) - \sum_{t=1}^T \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})} \right] \end{aligned}\]为了利用 $x_0$ 的信息减少方差,我们需要将分母中的 $q(x_t \mid x_{t-1})$ 替换掉。 利用贝叶斯公式:
\[q(x_t \mid x_{t-1}, x_0) = \frac{q(x_{t-1} \mid x_t, x_0) q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\]把这一项代入求和公式中:
\[\begin{aligned} L_{\text{VLB}} &= \mathbb{E}_q \left[ - \log p(x_T) - \sum_{t=1}^T \log \frac{p_\theta(x_{t-1} \mid x_t)}{\frac{q(x_{t-1} \mid x_t, x_0) q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}} \right] \\\\ &= \mathbb{E}_q \left[ - \log p(x_T) - \sum_{t=1}^T \left( \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)} + \log \frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \right) \right] \end{aligned}\]这里我们将对数拆分为两部分:一部分是模型预测与真实后验的比值,另一部分是纯粹关于 $q$ 的比值(这一项会发生裂项相消)。
步骤 A: 裂项相消 (Telescoping Sum)
单独观察第二部分 $\sum_{t=1}^T \log \frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}$ 的展开:
\[\begin{aligned} \sum_{t=1}^T \log \frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} &= \underbrace{\log \frac{q(x_0 \mid x_0)}{q(x_1 \mid x_0)}}_{t=1} + \underbrace{\log \frac{q(x_1 \mid x_0)}{q(x_2 \mid x_0)}}_{t=2} + \dots + \underbrace{\log \frac{q(x_{T-1} \mid x_0)}{q(x_T \mid x_0)}}_{t=T} \\\\ &= \log q(x_0 \mid x_0) - \log q(x_T \mid x_0) \end{aligned}\]由于 $q(x_0 \mid x_0)$ 表示“已知 $x_0$ 时 $x_0$ 的概率”,这是一个确定事件,概率为 1,故 $\log q(x_0 \mid x_0) = 0$。 所以,整个求和项简化为: $- \log q(x_T \mid x_0)$。
步骤 B: 重新组合各项
将裂项相消的结果代回原式,并注意负号:
\[\begin{aligned} L_{\text{VLB}} &= \mathbb{E}_q \left[ - \log p(x_T) - \sum_{t=1}^T \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)} - \left( - \log q(x_T \mid x_0) \right) \right] \\\\ &= \mathbb{E}_q \left[ \underbrace{\log \frac{q(x_T \mid x_0)}{p(x_T)}}_{\text{Term } L_T} + \sum_{t=1}^T \log \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} \right] \end{aligned}\]我们将 $t=1$ 的情况从求和中分离出来,因为 $L_0$ (重构) 与 $L_{t-1}$ (去噪) 的物理意义不同:
\[\text{Sum Term} = \sum_{t=2}^T \log \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} + \underbrace{\log \frac{q(x_0 \mid x_1, x_0)}{p_\theta(x_0 \mid x_1)}}_{t=1}\]利用 KL 散度定义 $D_{KL}(q \mid\mid p) = \mathbb{E}_q [\log \frac{q}{p}]$,最终整理得:
\[L_{\text{VLB}} = \underbrace{D_{KL}(q(x_T \mid x_0) \mid\mid p(x_T))}_{L_T: \text{Constant}} + \sum_{t=2}^T \underbrace{D_{KL}(q(x_{t-1} \mid x_t, x_0) \mid\mid p_\theta(x_{t-1} \mid x_t))}_{L_{t-1}: \text{Denoising Process}} \underbrace{- \log p_\theta(x_0 \mid x_1)}_{L_0: \text{Reconstruction}}\]- $L_T$: 常数项,因为 $q(x_T \mid x_0)$ 固定且 $p(x_T)$ 是纯高斯噪声,训练中可忽略。
- $L_{t-1}$: 核心项,衡量网络预测的分布 $p_\theta$ 与真实后验 $q$ 的差异。这是扩散模型去噪的关键。
- $L_0$: 最后一步的重构误差。
本章小结:我们发现优化的核心在于 $L_{t-1}$,即让模型 $p_\theta(x_{t-1} \mid x_t)$ 去尽可能接近真实的后验分布 $q(x_{t-1} \mid x_t, x_0)$。 下一步:要做到这一点,我们首先需要知道这个“真实目标” $q(x_{t-1} \mid x_t, x_0)$ 到底长什么样?这就引出了第二部分的推导。
2. 真实后验分布 $q(x_{t-1} \mid x_t, x_0)$ 的推导
上一章确定了目标是匹配 $q(x_{t-1} \mid x_t, x_0)$。幸运的是,由于前向扩散过程全是高斯分布,我们可以通过贝叶斯公式解析地求出这个分布的数学形式。
2.1 定义与记号
- 加噪公式: $x_t = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon$,令 $\alpha_t = 1 - \beta_t$,$\bar{\alpha}t = \prod{i=1}^t \alpha_i$。
- 单步分布: $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)I)$
- 跳步分布: $q(x_t \mid x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I)$
2.2 贝叶斯公式展开
\[q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1}, x_0) q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}\]由于都是高斯分布,其乘积和商仍然是高斯分布。我们只关注指数部分的二次项。
\[q(x_{t-1} \mid x_t, x_0) \propto \exp \left( -\frac{1}{2} \left[ \frac{(x_t - \sqrt{\alpha_t}x_{t-1})^2}{\beta_t} + \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}x_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t}x_0)^2}{1-\bar{\alpha}_t} \right] \right)\]2.3 待定系数法 (配方 Completing the Square)
我们需要整理出关于 $x_{t-1}$ 的二次型:$-\frac{1}{2\tilde{\beta}t} (x{t-1} - \tilde{\mu}t)^2$。 展开指数内部关于 $x{t-1}$ 的项:
\[\text{Exp} = -\frac{1}{2} \left( \underbrace{\left[ \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} \right]}_{A} x_{t-1}^2 - 2 \underbrace{\left[ \frac{\sqrt{\alpha_t}}{\beta_t}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0 \right]}_{B} x_{t-1} + C(x_t, x_0) \right)\]第一步:计算方差 $\tilde{\beta}_t$ (对应 $A^{-1}$)
\[\frac{1}{\tilde{\beta}_t} = \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} = \frac{1 - \bar{\alpha}_t}{\beta_t(1-\bar{\alpha}_{t-1})}\]所以方差为:
\[\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t\]第二步:计算均值 $\tilde{\mu}_t$ (对应 $B/A = B \cdot \tilde{\beta}_t$)
\[\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0\]至此,我们得到了真实后验分布的参数:
\[q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I)\]本章小结:我们成功算出了真实目标 $q$ 的均值 $\tilde{\mu}t$。这是一个确定的公式,依赖于 $x_t$ 和 $x_0$。 下一步:现在我们有了“靶子”($\tilde{\mu}_t$),我们可以构建神经网络 $p\theta$ 来瞄准这个靶子,并推导具体的 Loss 函数。 —
3. 模型预测与 Loss 简化
既然我们要让模型预测的分布 $p_\theta$ 接近真实分布 $q$,且两者都是高斯分布,那么 KL 散度最小化就等价于均值的 MSE 最小化。我们需要把上一章算出来的 $\tilde{\mu}_t$ 代入 Loss 中。
3.1 均值的重参数化 (Reparameterization)
在训练时,网络输入是 $x_t$,我们希望网络能预测出 $\tilde{\mu}_t$。但直接预测均值并不直观。 利用前向公式反解 $x_0$:$x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\epsilon}{\sqrt{\bar{\alpha}_t}}$。 将此 $x_0$ 代入刚才推导的 $\tilde{\mu}_t$ 公式中,经过代数化简可得:
\[\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon \right)\]结论:要预测分布的均值 $\tilde{\mu}_t$,本质上只需要预测当前时刻添加的噪声 $\epsilon$。
3.2 损失函数推导
我们的优化目标是最小化 KL 散度:
\[L_{t-1} = D_{KL}(q(x_{t-1} \mid x_t, x_0) \mid\mid p_\theta(x_{t-1} \mid x_t)) \propto || \tilde{\mu}_t - \mu_\theta ||^2\]我们将网络设计为预测噪声 $\epsilon_\theta(x_t, t)$,则网络输出的均值 $\mu_\theta$ 也应该具有相同的形式:
\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right)\]将真值 $\tilde{\mu}t$ 和预测值 $\mu\theta$ 代入 MSE 公式:
\[\begin{aligned} L_{t-1} &\propto \mathbb{E}_{x_0, \epsilon} \left[ \left\| \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon \right) - \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) \right\|^2 \right] \\\\ &= \mathbb{E}_{x_0, \epsilon} \left[ \frac{\beta_t^2}{\alpha_t(1-\bar{\alpha}_t)} || \epsilon - \epsilon_\theta(x_t, t) ||^2 \right] \end{aligned}\]3.3 最终的简化 Loss ($L_{\text{simple}}$)
DDPM 论文发现,去掉前面的权重系数 $\frac{\beta_t^2}{\alpha_t(1-\bar{\alpha}_t)}$,直接优化单纯的 MSE 效果更好。这意味着我们只需要训练一个网络,让它根据输入的带噪图像 $x_t$,预测出其中包含的噪声 $\epsilon$。
\[L_{\text{simple}} = \mathbb{E}_{t \sim [1, T], x_0, \epsilon \sim \mathcal{N}(0, I)} \left[ || \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon, t) ||^2 \right]\]
Leave a comment