Flow models I: Diffusion models

March 2023

This series of notes focuses on diffusion-based generative models, like the celebrated Denoising Diffusion Probabilistic Models; it contains the material I regularly present as lectures in working groups for mathematicians or math graduate students, so the style is tailored to this audience. In particular, everything is fitted into the continuous-time framework (which is not exactly how it is done in practice).

Special attention is given to the differences between ODE sampling and SDE sampling. The analysis of the time evolution of the densities $p_t$ is done using only Fokker-Planck Equations or Transport Equations.

The problem

Let $p$ be a probability density on $\mathbb{R}^d$. The goal of generative modelling is twofold: given samples $x^1, \dotsc, x^n$ from $p$, we want to estimate $p$ and generate new samples from it.

Many methods were designed to tackle these challenges: Energy-Based Models, Normalizing Flows and the famous Neural ODEs, vanilla Score-Matching, GANs. Each has its limitations: for example, EBMs are challenging to train, NFs lack expressivity, and vanilla SM fails to capture multimodal distributions. Diffusion models, and their successors flow-based models, offer sufficient flexibility to partially overcome these limitations. Their ability to be guided towards conditional generation using Classifier-Free Guidance is also a major advantage, and will be reviewed in the third note of the series.

Stochastic interpolation

Diffusion models fall into the framework of stochastic interpolants. The idea is to continuously transform the density $p$ into another easy-to-sample density $\pi$ (often called the target), while also transforming the samples $x^i$ from $p$ into samples from $\pi$; and then to reverse the process: that is, to generate a sample from $\pi$, and to invert the transformation to get a new sample from $p$. In other words, we seek a path $(p_t : t \in [0,T])$ with $p_0 = p$ and $p_T = \pi$, such that generating samples $x_t \sim p_t$ is doable.

The success of diffusion models came from the realization that some stochastic processes, such as Ornstein-Uhlenbeck processes, which connect $p_0$ with a distribution $p_T$ very close to pure noise $\mathscr{N}(0,I)$, can be reversed when the score function $\nabla \log p_t$ is available at each time $t$. Although unknown, this score can efficiently be learnt using statistical procedures called score matching.

Original formulation: Gaussian noising process and its inversion

Time-Reversal of diffusions

Let $(t,x) \mapsto f_t(x)$ and $t \mapsto w_t$ be two smooth functions. Consider the stochastic differential equation

$$dX_t = f_t(X_t)\,dt + \sqrt{2w_t^2}\,dB_t, \qquad X_0 \sim p \tag{1}$$

where $dB_t$ denotes integration with respect to a Brownian motion. Under mild conditions on $f$, an almost-surely continuous stochastic process satisfying this SDE exists. Let $p_t$ be the probability density of $X_t$; it is a non-trivial fact that this process can be reversed in time. More precisely, the following SDE is exactly the time-reversal of (1):

$$dY_t = \left(-f_{T-t}(Y_t) + 2w_{T-t}^2 \nabla \log p_{T-t}(Y_t)\right)dt + \sqrt{2w_{T-t}^2}\,dB_t, \qquad Y_0 \sim p_T. \tag{2}$$
Theorem (Anderson, 1982). The law of the stochastic process $(Y_t)_{t \in [0,T]}$ is the same as the law of $(X_{T-t})_{t \in [0,T]}$.

While this theorem is not easy to prove, we will later check using the Fokker-Planck equation that the marginal density $p_{T-t}$ of the process $X_{T-t}$ is indeed the same as the marginal density of the process $Y_t$.

Sampling paths from the reverse SDE (2) requires access to $\nabla \log p_t$. The key point of diffusion-like methods is that this quantity can be estimated.

Working out the Ornstein-Uhlenbeck process

For simple functions $f$, the process (1) has an explicit representation. Here we focus on the case where $f_t(x) = -\mu_t x$ for some function $\mu$ called the drift coefficient, that is

$$dX_t = -\mu_t X_t\,dt + \sqrt{2w_t^2}\,dB_t. \tag{3}$$

Define $\alpha_t = e^{-A_t}$ where $A_t = \int_0^t \mu_s\,ds$. Then, the solution of (3) is given by the following stochastic process:

$$X_t = \alpha_t X_0 + \sqrt{2}\int_0^t e^{A_s - A_t} w_s\,dB_s. \tag{4}$$

In particular, noting

$$\bar{\sigma}_t^2 = 2\int_0^t e^{-2\int_s^t \mu_u\,du}\,w_s^2\,ds$$

we have

$$X_t \stackrel{\mathrm{law}}{=} \alpha_t X_0 + \bar{\sigma}_t \varepsilon \tag{8}$$

where $\varepsilon \sim \mathscr{N}(0,1)$.

In the pure Ornstein-Uhlenbeck case where $w_t = 1$ and $\mu_t = 1$, we get $\alpha_t = e^{-t}$ and $X_t \stackrel{\mathrm{law}}{=} e^{-t}X_0 + \mathscr{N}(0, 1 - e^{-2t})$.
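As a quick sanity check, here is a minimal simulation (a sketch assuming numpy; the bimodal initial law is an arbitrary toy choice) comparing an Euler-Maruyama discretization of (3) in the pure Ornstein-Uhlenbeck case with the closed-form representation (8).

```python
# Euler-Maruyama simulation of dX = -X dt + sqrt(2) dB versus the closed form
# X_T = alpha_T X_0 + sigma_bar_T * eps, with alpha_T = exp(-T), sigma_bar_T^2 = 1 - exp(-2T).
import numpy as np

rng = np.random.default_rng(0)
n, T, n_steps = 100_000, 2.0, 1_000
dt = T / n_steps

x0 = rng.choice([-2.0, 2.0], size=n)              # a toy bimodal p_0
x = x0.copy()
for _ in range(n_steps):
    x += -x * dt + np.sqrt(2 * dt) * rng.standard_normal(n)

alpha, sigma_bar = np.exp(-T), np.sqrt(1 - np.exp(-2 * T))
x_closed = alpha * x0 + sigma_bar * rng.standard_normal(n)

print(x.mean(), x_closed.mean())                  # both close to 0
print(x.std(), x_closed.std())                    # both close to 1 when T is large
```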

Proof of (4). We set $F(x,t) = xe^{A_t}$ and $Y_t = F(X_t, t) = X_t e^{A_t}$; it turns out that $Y_t$ satisfies a nicer SDE. Since $\Delta_x F = 0$, $\partial_t F(x,t) = xe^{A_t}\mu_t$ and $\nabla_x F(x,t) = e^{A_t}$, Itô's formula says that

$$\begin{aligned}dY_t &= \partial_t F(X_t,t)\,dt + \nabla_x F(X_t,t)\,dX_t + \frac{2w_t^2}{2}\Delta_x F(X_t,t)\,dt \\ &= X_t e^{A_t}\mu_t\,dt + e^{A_t}\,dX_t \\ &= \sqrt{2w_t^2 e^{2A_t}}\,dB_t.\end{aligned}$$

Consequently, $Y_t = Y_0 + \int_0^t \sqrt{2w_s^2 e^{2A_s}}\,dB_s$, and the result follows upon multiplying everything by $e^{-A_t}$.

The second term in (4) reduces to a Wiener integral; it is a centered Gaussian with variance $2\int_0^t e^{2(A_s - A_t)}w_s^2\,ds$, hence

$$X_t \stackrel{\mathrm{law}}{=} e^{-A_t}X_0 + \mathscr{N}\left(0,\ 2\int_0^t e^{2A_s - 2A_t}w_s^2\,ds\right).$$

A consequence of the preceding result is that when $\alpha_t$ is small compared to the variance $\bar{\sigma}_t^2$, the distribution of $X_t$ is well-approximated by $\mathscr{N}(0, \bar{\sigma}_t^2)$. Indeed, for $\mu_t = w_t = 1$, we have $\alpha_T = e^{-T}$ and $\bar{\sigma}_T = \sqrt{1 - e^{-2T}} \approx 1$ if $T$ is sufficiently large, say $T > 10$.

The Fokker-Planck point of view

It has recently been recognized that the Ornstein-Uhlenbeck representation of $p_t$ as in (1), as well as the stochastic process (2) that has the same marginals as $p_t$, are not necessarily unique or special. Instead, what matters are two key features: (i) $p_t$ provides a path connecting $p$ and $p_T \approx \mathscr{N}(0,I)$, and (ii) its marginals are easy to sample. There are other processes besides (1) that have $p_t$ as their marginals, and that can also be reversed. The crucial point is that $p_t$ is a solution of the Fokker-Planck equation:

$$\partial_t p_t(x) = \Delta(w_t^2 p_t(x)) - \nabla \cdot (f_t(x)p_t(x)). \tag{9}$$

Just to settle the notations once and for all: $\nabla$ is the gradient, and for a function $\rho : \mathbb{R}^d \to \mathbb{R}^d$, $\nabla \cdot \rho(x)$ stands for the divergence, that is $\sum_{i=1}^d \partial_{x_i}\rho_i(x_1, \dotsc, x_d)$, and $\nabla \cdot \nabla = \Delta = \sum_{i=1}^d \partial_{x_i}^2$ is the Laplacian.

Proof (informal). For a compactly supported smooth test function $\varphi$, we have $\partial_t \mathbb{E}[\varphi(X_t)] = \int \varphi(x)\partial_t p_t(x)dx$. On the other hand, this quantity is also equal to $\mathbb{E}[d\varphi(X_t)]$. Itô's formula (the quadratic variation of $X$ being $2w_t^2\,dt$) says that $d\varphi(X_t) = \nabla \varphi(X_t) \cdot dX_t + w_t^2\Delta \varphi(X_t)dt$, which is also equal to $\nabla \varphi(X_t) \cdot f_t(X_t)dt + dM_t + w_t^2\Delta \varphi(X_t)dt$, where $M_t$ is a Brownian martingale started at 0, whose expectation is thus 0. Gathering everything, we see that $\mathbb{E}[d\varphi(X_t)]$ is also equal to $\mathbb{E}[\nabla \varphi(X_t) \cdot f_t(X_t)dt + w_t^2\Delta \varphi(X_t)dt]$, that is

$$\int \nabla\varphi(x)\cdot f_t(x)\,p_t(x)\,dx + \int \Delta\varphi(x)\,w_t^2\,p_t(x)\,dx.$$

One integration by parts on the first integral, and two on the second, lead to the expression

$$\int \varphi(x)\left[-\nabla \cdot (p_t(x)f_t(x)) + \Delta(w_t^2 p_t(x))\right]dx.$$

Comparing this with the first expression for $\partial_t \mathbb{E}[\varphi(X_t)]$ gives the result.

Importantly, equation (9) can be recast as a transport equation: with a velocity field defined as

$$v_t(x) = w_t^2 \nabla \log p_t(x) - f_t(x),$$

the equation (9) is equivalent to

$$\partial_t p_t(x) = \nabla \cdot (v_t(x)p_t(x)). \tag{13}$$
Proof. $\nabla \cdot (v_t(x)p_t(x)) = \nabla \cdot (w_t^2 \nabla \log p_t(x)\,p_t(x)) - \nabla \cdot (f_t(x)p_t(x)) = w_t^2\,\nabla \cdot \nabla p_t(x) - \nabla \cdot (f_t(x)p_t(x))$, and since $\nabla \cdot \nabla = \Delta$, this is equal to $w_t^2 \Delta p_t(x) - \nabla \cdot (f_t(x)p_t(x)) = \partial_t p_t(x)$.

An associated ODE

Equations like (13) are called transport equations or continuity equations. They come from simple ODEs; that is, there is a deterministic process with the same marginals as (1).

Let $x(t)$ be the solution of the differential equation with random initial condition

$$x'(t) = -v_t(x(t)), \qquad x(0) = X_0.$$

Then the probability density of $x(t)$ satisfies (13), hence it is equal to $p_t$.
Proof. Let $\rho_t$ be the probability density of $x(t)$ and let $\varphi$ be any smooth, compactly supported test function. Then $\mathbb{E}[\varphi(x(t))] = \int \rho_t(x)\varphi(x)dx$, so by differentiation under the integral,

$$\int \partial_t \rho_t(x)\varphi(x)dx = \partial_t \mathbb{E}[\varphi(x(t))] = \mathbb{E}[\nabla\varphi(x(t)) \cdot x'(t)] = -\int \nabla \varphi(x) \cdot v_t(x)\rho_t(x)dx = \int \varphi(x)\, \nabla \cdot (v_t(x)\rho_t(x))dx,$$

where the last equality uses the multidimensional integration by parts formula. Hence $\rho_t$ solves (13); since $\rho_0 = p_0$, it coincides with $p_t$.

Up to now, we proved that there are two continuous random processes having the same marginal probability density at each time $t$: a smooth one provided by $x(t)$, the solution of the ODE, and a continuous but non-differentiable one, $X_t$, provided by the solution of the SDE.
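To illustrate, in a Gaussian case the score is explicit, so both processes can actually be simulated; here is a minimal sketch (assuming numpy) checking that the ODE and SDE particles end up with the same Gaussian marginal.

```python
# Two processes, same marginals: for p_0 = N(m, 1) and mu_t = w_t = 1,
# p_t = N(exp(-t) m, 1), so grad log p_t(x) = -(x - exp(-t) m) is explicit.
import numpy as np

rng = np.random.default_rng(1)
n, m, T, n_steps = 100_000, 3.0, 1.0, 1_000
dt = T / n_steps
score = lambda t, x: -(x - np.exp(-t) * m)

x_ode = rng.standard_normal(n) + m        # particles following x'(t) = -v_t(x(t))
x_sde = x_ode.copy()                      # particles following dX = -X dt + sqrt(2) dB
for k in range(n_steps):
    t = k * dt
    v = score(t, x_ode) + x_ode           # v_t(x) = w_t^2 grad log p_t(x) - f_t(x)
    x_ode += -v * dt
    x_sde += -x_sde * dt + np.sqrt(2 * dt) * rng.standard_normal(n)

# both clouds should match N(exp(-T) m, 1)
print(x_ode.mean(), x_sde.mean(), np.exp(-T) * m)
print(x_ode.std(), x_sde.std())
```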

Time-reversal of Transport Equations and Fokker-Planck equations

We now have various processes $x(t), X_t$ starting at a density $p_0$ and evolving towards a density $p_T \approx \pi = \mathscr{N}(0,I)$. Can these processes be reversed in time? The answer is yes for both of them. We'll start by reversing their associated equations. From now on, we will note $p^{\mathrm{b}}_t$ the time-reversal of $p_t$, that is:

$$p^{\mathrm{b}}_t(x) = p_{T-t}(x).$$

The density $p^{\mathrm{b}}_t$ solves the backward Transport Equation:

$$\partial_t p^{\mathrm{b}}_t(x) = \nabla \cdot \left(v^{\mathrm{b}}_t(x)\,p^{\mathrm{b}}_t(x)\right) \tag{17}$$

where

$$v^{\mathrm{b}}_t(x) = -v_{T-t}(x) = -w_{T-t}^2 \nabla \log p^{\mathrm{b}}_t(x) - \mu_{T-t}\,x.$$

The density $p^{\mathrm{b}}_t$ also solves the backward Fokker-Planck Equation:

$$\partial_t p^{\mathrm{b}}_t(x) = w_{T-t}^2 \Delta p^{\mathrm{b}}_t(x) - \nabla \cdot \left(u^{\mathrm{b}}_t(x)\,p^{\mathrm{b}}_t(x)\right) \tag{19}$$

where

$$u^{\mathrm{b}}_t(x) = 2w_{T-t}^2 \nabla \log p^{\mathrm{b}}_t(x) + \mu_{T-t}\,x.$$
Proof. Noting $\dot{p}_t(x)$ the time derivative of $t \mapsto p_t(x)$, we immediately see that $\partial_t p^{\mathrm{b}}_t(x) = -\dot{p}_{T-t}(x)$, and the rest is a mere verification; for instance, for the Transport version, $\partial_t p^{\mathrm{b}}_t(x) = -\nabla \cdot (v_{T-t}(x)p_{T-t}(x)) = \nabla \cdot (v^{\mathrm{b}}_t(x)p^{\mathrm{b}}_t(x))$.

Of course, these two equations are the same, but they represent the time-evolutions of the densities of two different random processes. As explained before, the Transport version (17) represents the time-evolution of the density of the ODE system

$$y'(t) = -v^{\mathrm{b}}_t(y(t)), \qquad y(0) \sim p_T \tag{21}$$

while the Fokker-Planck version (19) represents the time-evolution of the density of the SDE system

$$dY_t = u^{\mathrm{b}}_t(Y_t)\,dt + \sqrt{2w_{T-t}^2}\,dB_t, \qquad Y_0 \sim p_T. \tag{22}$$

Both of these processes can be sampled using a range of ODE and SDE solvers, the simplest of which being the Euler scheme and the Euler-Maruyama scheme. However, this requires access to the functions $v^{\mathrm{b}}_t$ and $u^{\mathrm{b}}_t$, which in turn depend on the unknown score $\nabla \log p_t$. Fortunately, $\nabla \log p_t$ can efficiently be estimated thanks to two factors.

  1. First: we have samples from $p_t$. Remember that our only information about $p$ is a collection $x^1, \dotsc, x^n$ of samples. Thanks to the representation (8), the variables $x^i_t = \alpha_t x^i + \bar{\sigma}_t \xi^i$ with $\xi^i \sim \mathscr{N}(0,I)$ are samples from $p_t$. They are extremely easy to access, since we only need to generate iid standard Gaussian variables $\xi^i$.

  2. Second: score matching. If $p$ is a probability density and the $x^i$ are samples from $p$, estimating $\nabla \log p$ (called the score) has been thoroughly examined and is fairly doable, a technique known as score matching.

Methods for learning the score

Learning the score $\nabla_x \ln p(x)$ of a probability density $p$ is a well-known problem in statistics, and is somewhat orthogonal to the world of generative flow models. I gathered the main ideas in the next note on flow matching and Tweedie's formula. In short, it turns out that training a neural network $s(t,x)$ to denoise $X_t$ (that is, to remove the added noise $\varepsilon_t$, where $X_t = \alpha_t X_0 + \varepsilon_t$) with the loss $\mathbb{E}[|s(t,X_t) - \varepsilon_t|^2]$ directly leads to an estimator of the score,

$$\nabla \ln p_t(x) \approx -\frac{s(t,x)}{\bar{\sigma}_t^2}.$$
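Here is a minimal training sketch of this denoising objective (assuming PyTorch; the pure Ornstein-Uhlenbeck schedule, the toy dataset and the small MLP `ScoreNet` are arbitrary placeholder choices, not the architectures discussed below).

```python
# Denoising score matching, a sketch assuming PyTorch.
# Schedule: pure OU path, alpha_t = exp(-t), sigma_bar_t^2 = 1 - exp(-2t).
import torch
import torch.nn as nn

class ScoreNet(nn.Module):                       # hypothetical toy network (t, x) -> noise estimate
    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d + 1, 128), nn.SiLU(),
                                 nn.Linear(128, 128), nn.SiLU(),
                                 nn.Linear(128, d))
    def forward(self, t, x):
        return self.net(torch.cat([t[:, None], x], dim=1))

d, T = 2, 5.0
model = ScoreNet(d)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x_data = torch.randn(10_000, d) + 3.0            # toy dataset standing in for x^1, ..., x^n

for step in range(1_000):
    x0 = x_data[torch.randint(len(x_data), (256,))]
    t = torch.rand(256) * T
    alpha = torch.exp(-t)[:, None]
    sigma = torch.sqrt(1 - torch.exp(-2 * t))[:, None]
    eps = torch.randn_like(x0)
    xt = alpha * x0 + sigma * eps                # samples from p_t via the representation (8)
    loss = ((model(t, xt) - sigma * eps) ** 2).mean()   # E |s(t, X_t) - eps_t|^2
    opt.zero_grad(); loss.backward(); opt.step()

# score estimate: grad log p_t(x) ≈ -model(t, x) / sigma_bar_t^2
```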

Choice of architecture

In practice, for image generation, the go-to choice for the architecture of $s_\theta$ was first the U-Net, a special kind of convolutional neural network with a downsampling phase, an upsampling phase, and skip-connections in between. After 2023, it seemed that everyone switched to pure-transformer models, following the landmark DiT paper by Peebles and Xie.

Sampling

Once the algorithm has converged to $\theta$, we get $s_\theta(t,x)$, which is a proxy for $\nabla \log p_t(x)$ (we absorbed the constant $\bar{\sigma}_t^2$ into the definition of $s$). Now, we simply plug this proxy into the velocity $v^{\mathrm{b}}_t$ if we want to solve the ODE (21), or into the drift $u^{\mathrm{b}}_t$ if we want to solve the SDE (22); that is, the ODE sampler is $y'(t) = -\hat{v}^{\mathrm{b}}_t(y(t))$ where $\hat{v}^{\mathrm{b}}_t(x) = -w_{T-t}^2 s_\theta(T-t,x) - \mu_{T-t}\,x$, and the SDE sampler is (22) with $u^{\mathrm{b}}_t$ replaced by $\hat{u}^{\mathrm{b}}_t(x) = 2w_{T-t}^2 s_\theta(T-t,x) + \mu_{T-t}\,x$.
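Both samplers are a few lines of code; here is a sketch (assuming numpy, a generic `score(t, x)` proxy such as the one trained above, and the pure OU schedule $\mu_t = w_t = 1$).

```python
# Euler (ODE) and Euler-Maruyama (SDE) reverse samplers, a sketch assuming numpy.
# `score(t, x)` is any proxy for grad log p_t(x); schedule: mu_t = w_t = 1.
import numpy as np

def reverse_ode(score, x, T=5.0, n_steps=500):
    """Euler scheme for y'(t) = -v_hat_b(y(t))."""
    dt = T / n_steps
    for k in range(n_steps):
        t = T - k * dt                    # forward time corresponding to backward step k
        x = x + dt * (score(t, x) + x)    # -v_hat_b(x) = w^2 score + mu x, with mu = w = 1
    return x

def reverse_sde(score, x, T=5.0, n_steps=500, rng=None):
    """Euler-Maruyama scheme for dY = u_hat_b(Y) dt + sqrt(2) dB."""
    if rng is None:
        rng = np.random.default_rng(0)
    dt = T / n_steps
    for k in range(n_steps):
        t = T - k * dt
        drift = 2 * score(t, x) + x       # u_hat_b(x) = 2 w^2 score + mu x, with mu = w = 1
        x = x + dt * drift + np.sqrt(2 * dt) * rng.standard_normal(x.shape)
    return x

# usage: start from x ~ N(0, I), a proxy for p_T, and integrate back towards samples of p
```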

We must stress a subtle fact. Equations (9) and (13), or their backward counterparts, are exactly the same equation as far as $p_t$ is concerned. But since we now replace $\nabla \log p_t$ by its approximation $s_\theta$, this is no longer the case for our two samplers: their probability densities are not the same. In fact, let us note $q^{\mathrm{ode}}_t, q^{\mathrm{sde}}_t$ the densities of $y(t)$ and $Y_t$; the first one solves a Transport Equation, the second one a Fokker-Planck equation, and these two equations are different.

Backward Equations for the samplers:

$$\partial_t q^{\mathrm{ode}}_t(x) = \nabla \cdot \left(\hat{v}^{\mathrm{b}}_t(x)\,q^{\mathrm{ode}}_t(x)\right), \qquad q_0^{\mathrm{ode}} = \pi$$

$$\partial_t q^{\mathrm{sde}}_t(x) = \nabla \cdot \left(\left[w_{T-t}^2 \nabla \log q^{\mathrm{sde}}_t(x) - \hat{u}^{\mathrm{b}}_t(x)\right]q^{\mathrm{sde}}_t(x)\right), \qquad q_0^{\mathrm{sde}} = \pi \tag{25}$$

Importantly, the velocity $w_{T-t}^2 \nabla \log q^{\mathrm{sde}}_t(x) - \hat{u}^{\mathrm{b}}_t(x)$ is in general not equal to the velocity $\hat{v}^{\mathrm{b}}_t(x)$. They would be equal only in the case $s_\theta(t,x) = \nabla \log p_t(x)$.

Proof. Since $y(t)$ solves an ODE, its density directly satisfies the transport equation with velocity $\hat{v}^{\mathrm{b}}_t$. Since $Y_t$ solves an SDE, its density satisfies the Fokker-Planck equation associated with the drift $\hat{u}^{\mathrm{b}}_t$, which in turn can be transformed into the transport equation shown above.

Design choices for $\mu_t$ and $w_t$

We recall that the SDE is given by

$$dX_t = -\mu_t X_t\,dt + \sqrt{2w_t^2}\,dB_t.$$

We showed that the solution of this equation at time $t$ has the same distribution as $\alpha_t X_0 + \bar{\sigma}_t \varepsilon$ where $\varepsilon \sim \mathscr{N}(0,1)$. Here, $\alpha_t, \bar{\sigma}_t$ are related to $\mu_t, w_t$ by

$$\alpha_t = \exp\left\lbrace -\int_0^t \mu_s\,ds \right\rbrace, \qquad \bar{\sigma}_t^2 = 2\int_0^t e^{-2\int_s^t \mu_u\,du}\,w_s^2\,ds.$$

Considerable work has been done (mostly experimentally) to find good functions $\mu_t, w_t$. Some choices seem to stand out.

Variance Exploding path

The VE path takes $\mu_t = 0$ (that is, no drift) and $w_t$ a continuous, increasing function over $[0,1)$, such that $\bar{\sigma}_0 = 0$ and $\bar{\sigma}_t \to +\infty$ as $t \to 1$; typically, $w_t = (1-t)^{-1}$. This gives parameters

$$\alpha_t = 1, \qquad \bar{\sigma}_t^2 = 2\int_0^t w_s^2\,ds = 2\int_0^t (1-s)^{-2}\,ds = \frac{2t}{1-t}.$$

Variance-Preserving path

The VP path takes $w_t = \sqrt{\mu_t}$. In this case, $\alpha_t = e^{-\int_0^t \mu_s\,ds}$ and

$$\bar{\sigma}_t^2 = 2\int_0^t e^{-2\int_s^t \mu_u\,du}\,\mu_s\,ds = 1 - e^{-2\int_0^t \mu_s\,ds}.$$

The name « variance preserving » comes from the fact that the element-wise variance of $X_t = \alpha_t X_0 + \bar{\sigma}_t \varepsilon$ is exactly $\alpha_t^2 + \bar{\sigma}_t^2$, which in this case is equal to 1 (we suppose without loss of generality that $X_0$ has been standardized to have element-wise variance 1).

The pure Ornstein-Uhlenbeck path

The OU path takes $w_t = \mu_t = 1$; we have already seen that in this case

$$\alpha_t = e^{-t}, \qquad \bar{\sigma}_t^2 = 1 - e^{-2t}.$$

This is not used in practice and is more for theoretical purposes.
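The three schedules are easy to compare numerically; here is a small sketch (assuming numpy) computing $(\alpha_t, \bar{\sigma}_t^2)$ on a grid from the formulas above.

```python
# Computing (alpha_t, sigma_bar_t^2) for the VE, VP and OU paths, a sketch assuming numpy.
import numpy as np

t = np.linspace(0.0, 0.99, 100)           # the VE path lives on [0, 1)

# Variance Exploding: mu_t = 0, w_t = 1/(1-t)
alpha_ve = np.ones_like(t)
sigma2_ve = 2 * t / (1 - t)

# Variance Preserving with, e.g., mu_t = 1 (which is exactly the pure OU path)
alpha_vp = np.exp(-t)
sigma2_vp = 1 - np.exp(-2 * t)

# the VP/OU path keeps alpha^2 + sigma_bar^2 = 1 at all times
assert np.allclose(alpha_vp**2 + sigma2_vp, 1.0)
```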

Toward Flow Matching

The design choices for a diffusion reduce to the drift and diffusion coefficients $\mu_t, w_t$. These choices determine the actual coefficients $\alpha_t, \bar{\sigma}_t$. But there might be a way to directly choose $\alpha_t, \bar{\sigma}_t$ and specify paths having distribution $\alpha_t X_0 + \bar{\sigma}_t \varepsilon$. Typically, we would like to choose

$$\alpha_t = \cos(\pi t / 2), \qquad \bar{\sigma}_t = \sin(\pi t / 2).$$

Of course we could find $\mu_t, w_t$ to solve these equations, but this is weird. This decoupling is done through stochastic interpolation and will be reviewed in the third note of the series.
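To make this precise, a quick computation from the definitions of $\alpha_t$ and $\bar{\sigma}_t$ above (differentiate $\bar{\sigma}_t^2 = 2\int_0^t e^{-2(A_t - A_s)}w_s^2\,ds$) gives the inversion:

$$\mu_t = -\frac{d}{dt}\log \alpha_t, \qquad w_t^2 = \frac{1}{2}\frac{d}{dt}\bar{\sigma}_t^2 + \mu_t\,\bar{\sigma}_t^2.$$

For the cosine choice above, this means $\mu_t = \frac{\pi}{2}\tan(\pi t/2)$, which blows up as $t \to 1$: one reason why this detour through $\mu_t, w_t$ feels unnatural.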

A variational bound for the SDE sampler

Let $s : [0,T] \times \mathbb{R}^d \to \mathbb{R}^d$ be a smooth function, meant as a proxy for $\nabla \log p_t$. Our goal is to quantify the difference between the sampled densities $q^{\mathrm{ode}}_t, q^{\mathrm{sde}}_t$ and $p^{\mathrm{b}}_t = p_{T-t}$. It turns out that controlling the Fisher divergence $\mathbb{E}[|\nabla \log p_t(X_t) - s(t,X_t)|^2]$ results in a bound for $\mathrm{kl}(p \mid q_T^{\mathrm{sde}})$, but not for $\mathrm{kl}(p \mid q_T^{\mathrm{ode}})$.

Small recap on notations

The true density is $p^{\mathrm{b}}_t = p_{T-t}$; it satisfies the backward equation (17):

$$\partial_t p^{\mathrm{b}}_t(x) = \nabla \cdot \left(v^{\mathrm{b}}_t(x)\,p^{\mathrm{b}}_t(x)\right), \qquad v^{\mathrm{b}}_t(x) = -w_{T-t}^2 \nabla \log p^{\mathrm{b}}_t(x) - \mu_{T-t}\,x.$$

The density of the generative process is $q^{\mathrm{sde}}_t$, which we'll simply note $q_t$. It satisfies the backward equation (25):

$$\partial_t q_t(x) = \nabla \cdot (u_t(x)\,q_t(x))$$

where

$$u_t(x) = w_{T-t}^2 \nabla \log q_t(x) - 2w_{T-t}^2\,s(t,x) - \mu_{T-t}\,x.$$

The original distribution we want to sample is $p = p_0 = p^{\mathrm{b}}_T$, and the output distribution of our SDE sampler is $q^{\mathrm{sde}}_T = q_T$. Finally, the distribution $p_T = p_0^{\mathrm{b}}$ is approximated with $\pi$ (in practice, $\mathscr{N}(0,I)$).

The KL divergence between densities ρ1,ρ2\rho_1, \rho_2 is

$$\mathrm{kl}(\rho_1 \mid \rho_2) = \int \rho_1(x)\log(\rho_1(x)/\rho_2(x))\,dx.$$

A variational lower-bound

The bound below is stated with a general weight function $w_t$; for a first reading, one can keep in mind the simple case where the weights are constant and set to 1.

Variational lower-bound for score-based diffusion models with SDE sampler

$$\mathrm{kl}(p \mid q_T^{\mathrm{sde}}) \leqslant \mathrm{kl}(p_T \mid \pi) + \int_0^T w_t^2\,\mathbb{E}\left[|\nabla \log p_t(X_t) - s(t,X_t)|^2\right]dt. \tag{37}$$

The original proof can be found in this paper and uses the Girsanov theorem applied to the SDE representations (1)-(2) of the forward/backward processes. This is utterly complicated and too dependent on the SDE representation. The proof presented below only needs the Fokker-Planck equation and is done directly at the level of probability densities.

The following lemma is interesting on its own, since it gives an exact expression for the derivative of the KL divergence between the solutions of two transport equations.

$$\frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t) = \int p^{\mathrm{b}}_t(x)\, \nabla \log\left(\frac{p^{\mathrm{b}}_t(x)}{q_t(x)}\right) \cdot \left(u_t(x) - v^{\mathrm{b}}_t(x)\right)dx \tag{38}$$

In our case, with the specific shape assumed by $u_t, v^{\mathrm{b}}_t$, we get the following bound:

$$\frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t) \leqslant w_{T-t}^2 \int p^{\mathrm{b}}_t(x)\,|s(t,x) - \nabla \log p^{\mathrm{b}}_t(x)|^2\,dx \tag{39}$$

The proofs of (37)-(38)-(39) are only based on elementary manipulations of time-evolution equations.

Proof of (38).

A small differentiation shows that $\frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t)$ is equal to

$$\int \nabla \cdot (v^{\mathrm{b}}_t(x)p^{\mathrm{b}}_t(x))\log(p^{\mathrm{b}}_t(x)/q_t(x))\,dx + \int p^{\mathrm{b}}_t(x)\frac{\nabla \cdot (v^{\mathrm{b}}_t(x)p^{\mathrm{b}}_t(x))}{p^{\mathrm{b}}_t(x)}\,dx - \int p^{\mathrm{b}}_t(x)\frac{\nabla \cdot (u_t(x)q_t(x))}{q_t(x)}\,dx.$$

By an integration by parts, the first term is also equal to $-\int p^{\mathrm{b}}_t(x)\,v^{\mathrm{b}}_t(x) \cdot \nabla \log(p^{\mathrm{b}}_t(x)/q_t(x))\,dx$. The second term is zero, being the integral of the divergence $\nabla \cdot (v^{\mathrm{b}}_t p^{\mathrm{b}}_t)$. Finally, for the last one,

$$-\int p^{\mathrm{b}}_t(x)\frac{\nabla \cdot (u_t(x)q_t(x))}{q_t(x)}\,dx = \int \nabla (p^{\mathrm{b}}_t(x)/q_t(x)) \cdot u_t(x)\,q_t(x)\,dx = \int \nabla \log(p^{\mathrm{b}}_t(x)/q_t(x)) \cdot u_t(x)\,p^{\mathrm{b}}_t(x)\,dx.$$

Summing the three terms gives (38).

Proof of (39). We recall that

$$u_t(x) = w_{T-t}^2 \nabla \log q_t(x) - 2w_{T-t}^2\,s(t,x) - \mu_{T-t}\,x$$

and

$$v^{\mathrm{b}}_t(x) = -w_{T-t}^2 \nabla \log p^{\mathrm{b}}_t(x) - \mu_{T-t}\,x,$$

so that

$$\begin{aligned} u_t - v^{\mathrm{b}}_t &= w_{T-t}^2 \nabla \log q_t - 2w_{T-t}^2\,s + w_{T-t}^2 \nabla \log p^{\mathrm{b}}_t \\ &= w_{T-t}^2\left(\nabla \log q_t - \nabla \log p^{\mathrm{b}}_t + 2(\nabla \log p^{\mathrm{b}}_t - s)\right).\end{aligned}$$

We momentarily note $a = \nabla \log p^{\mathrm{b}}_t(x)$, $b = \nabla \log q_t(x)$ and $s = s(t,x)$. Then, (38) shows that

$$\begin{aligned}\frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t) &= w^2_{T-t}\int p^{\mathrm{b}}_t(x)\,(a - b) \cdot \left((b-a) + 2(a - s)\right)dx \\ &= -w^2_{T-t}\int p^{\mathrm{b}}_t(x)\,|a-b|^2\,dx + 2w^2_{T-t}\int p^{\mathrm{b}}_t(x)\,(a-b)\cdot(a-s)\,dx.\end{aligned}$$

We now use the classical inequality $2(x \cdot y) \leqslant |x|^2 + |y|^2$; we get

$$\frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t) \leqslant w^2_{T-t}\int p^{\mathrm{b}}_t(x)\,|s(t,x) - \nabla \log p^{\mathrm{b}}_t(x)|^2\,dx.$$

Proof of (37).

Now, we simply write

$$\mathrm{kl}(p^{\mathrm{b}}_T \mid q^{\mathrm{sde}}_T) - \mathrm{kl}(p^{\mathrm{b}}_0 \mid q_0^{\mathrm{sde}}) = \int_0^T \frac{d}{dt}\mathrm{kl}(p^{\mathrm{b}}_t \mid q_t)\,dt \tag{47}$$

and plug (39) inside the right-hand side. Here $q_0 = \pi$ and $p^{\mathrm{b}}_T = p$, hence the result after the change of variables $t \mapsto T-t$ in the time integral.

What about the ODE?

It turns out that the ODE sampler, whose density is $q^{\mathrm{ode}}_t$, does not enjoy such a nice upper bound. In fact, since $q^{\mathrm{ode}}_t$ solves a Transport Equation, we can still use (38), but with $u_t$ replaced by $\hat{v}^{\mathrm{b}}_t$, and integrate in $t$ just as in (47). We have

$$\begin{aligned}\hat{v}^{\mathrm{b}}_t(x) - v^{\mathrm{b}}_t(x) &= w_{T-t}^2\left(\nabla \log p^{\mathrm{b}}_t(x) - s(t,x)\right) \\ &= w_{T-t}^2\left(\nabla \log p^{\mathrm{b}}_t(x) - \nabla \log q_t(x)\right) + w_{T-t}^2\left(\nabla \log q_t(x) - s(t,x)\right).\end{aligned}$$

Using the Cauchy-Schwarz inequality, we could obtain an upper bound of the following form, up to universal constant factors:

$$\begin{aligned} \mathrm{kl}(p \mid q_T^{\mathrm{ode}}) - \mathrm{kl}(p_T \mid \pi) &\lesssim \int_0^T w_t^2 \int p_t(x)\left(\left|\nabla \log p_t(x) - \nabla \log q_t(x)\right|^2 + \left|\nabla \log q_t(x) - s(t,x)\right|^2\right)dx\,dt \\ &= \int_0^T w_t^2\, \mathbb{E}\left[|\nabla \log p_t(X_t) - \nabla \log q_t(X_t)|^2 + |\nabla \log q_t(X_t) - s(t,X_t)|^2\right]dt.\end{aligned}$$

There is a significant difference between the two bounds: minimizing the score matching objective minimizes the right-hand side of the SDE bound (37), but not that of the ODE bound above. This disparity is due to the Fisher divergence, which does not provide control over the KL divergence between the solutions of two transport equations; it does, however, control the KL divergence between the solutions of the associated Fokker-Planck equations, thanks to the presence of a diffusive term. This could be one of the reasons for the lower performance of ODE solvers observed by early experimenters in the field. However, more recent works (see the references just below) seem to challenge this idea: with dynamics different from the Ornstein-Uhlenbeck one, deterministic sampling techniques like ODEs now seem to outperform stochastic ones. A complete understanding of these phenomena is not yet available; the outstanding paper on stochastic interpolants proposes a remarkable framework towards this task (and inspired most of the analysis in this note).

References

On diffusion models

The original paper on diffusion models

DDPM (seminal paper for image generation)

Diffusion models beat GANs (pushing diffusions well beyond the SOTA)

Variational perspective on Diffusions, or the arXiv version (the analytical SDE approach)

Maximum likelihood training of Diffusions (proofs of the variational lower-bound)

Probability flow for FP, containing the proof of the variational lower-bound for the FP equation.