Trong bài viết này, mình sẽ giới thiệu về mô hình diffusion, một mô hình sinh với sự đột phá gần đây, cùng với mô hình score matching đã vượt qua GAN trong việc sinh dữ liệu. Hai mô hình này có thể xem như trường hợp đặc biệt của phương trình vi phân ngẫu nhiên, và được tổng quát thành mô hình dạng SDE, đưa ra một góc nhìn mới cũng như việc kết hợp hai loại mô hình này. Mô hình diffusion cũng như mô hình dạng SDE khi sinh dữ liệu không điều kiện thậm chí còn cho kết quả tốt hơn GAN khi sinh dữ liệu với nhãn cho trước.
Do nội dung khá dài nên phần cài đặt mình sẽ để sang bài khác nếu có thời gian, các bạn có thể xem trước notebook tutorial của tác giả tại đây. Một số chứng minh chi tiết mình sẽ để ở cuối để tránh đi xa khỏi nội dung chính, các bạn quan tâm có thể đọc thêm.
Mô hình diffusion
Ý tưởng của phương pháp này là biến đổi phân bố dữ liệu thành một phân bố có thể lấy mẫu được. Việc sinh dữ liệu sẽ bắt đầu từ phân bố này, sau đó biến đổi ngược về phân bố ban đầu. Mô hình cần học ở đây sẽ là phép biến đổi ngược đó. Quá trình biến đổi này được mô tả bằng một chuỗi các phân bố, cụ thể hơn chúng ta sẽ sử dụng quá trình ngẫu nhiên để mô tả chuỗi này.
Định nghĩa: Quá trình ngẫu nhiên là một họ các biến ngẫu nhiên {Xt}t∈T{X_t}_{tin T} từ cùng một không gian xác suất sang cùng một không gian trạng thái. Ở đây tập chỉ số TT có thứ tự, ví dụ T=R+T=mathbb{R}^+ hoặc T=Z+T=mathbb{Z}^+.
Quá trình ngẫu nhiên được gọi là quá trình Markov nếu nó thỏa mãn tính chất Markov. Một cách trực quan, xác suất của trạng thái tại tương lai khi biết trạng thái hiện tại không phụ thuộc vào quá khứ. Đối với chuỗi Markov, tính chất này có thể được viết thành
P(Xn+m=i∣X1,…,Xn)=P(Xn+m=i∣Xn)mathbb{P}(X_{n+m}=i|X_1,dots,X_n)=mathbb{P}(X_{n+m}=i|X_n)
Quá trình thuận
Để cho đơn giản, xác suất chuyển từ thời điểm tt sang thời điểm ss sẽ được kí hiệu là q(xs∣xt)q(x_s|x_t).Từ tính chất Markov, xác suất liên hợp được phân tích thành
q(x0…xt)=q(x0)∏i=1Tq(xi∣xi−1)q(x_0dots x_t)=q(x_0)prod_{i=1}^T q(x_i|x_{i-1})
Xác suất chuyển q(xt∣xt−1)q(x_t|x_{t-1}) được mô hình bởi N(xt;1−βtxt−1,βtI)mathcal{N}(x_t;sqrt{1-beta_t} x_{t-1}, beta_tI).
Xác suất khi biết trạng thái x0x_0 cũng là phân bố Gaussian, đạt được nhờ tính chất Markov. Đặt αt=1−βt, αtˉ=∏i=1tαialpha_t = 1-beta_t,,bar{alpha_t}=prod_{i=1}^t alpha_i, ta có q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I)q(x_t|x_0) = mathcal{N}(x_t; sqrt{bar{alpha_t}}x_0, (1-bar{alpha_t})I).
Phân bố tại trạng thái xTx_T được xem như prior, sao cho có thể lấy mẫu được. Nhờ vào tính chất của xác suất điều kiện trên, với βtbeta_t phù hợp, q(xT)≈N(0,I)q(x_T)approx mathcal{N}(0, I).
Quá trình nghịch
Lúc này, quá trình sẽ bắt đầu từ phân bố p(xT)p(x_T) tại xTx_T, biến đổi ngược lại để quay về phân bố gốc của dữ liệu p(x0)p(x_0). Quá trình này có thể xem như một chuỗi Markov với chiều ngược lại, do đó xác suất liên hợp được phân tích thành
p(x0…xT)=p(xT)∏i=1Tp(xi−1∣xi)p(x_0dots x_T) = p(x_T)prod_{i=1}^{T}p(x_{i-1}|x_i)
Mục tiêu lúc này là tìm xác suất chuyển p(xt−1∣xt)p(x_{t-1}|x_t) của chuỗi Markov này. Ta sẽ mô hình xác suất này bởi phân bố Gaussian, có dạng N(xt−1;μθ(xt,t),Σθ(xt,t))mathcal{N}(x_{t-1};mu_{theta}(x_t,t), Sigma_{theta}(x_t, t)).
Huấn luyện
Mục tiêu của quá trình huấn luyện là cực đại likelihood của phân bố dữ liệu của mô hình sinh
p(x0)=∫p(x0…xT)dx1…xT=∫p(x0…xT)q(x1…xT∣x0)q(x1…xT∣x0)dx1…xT=∫p(xT)∏i=1Tp(xi−1∣xi)q(xi∣xi−1)dQ(x1…xT∣x0)begin{aligned}
p(x_0)&=int p(x_0dots x_T)dx_1dots x_T\
&=int frac{p(x_0dots x_T)}{q(x_1dots x_T|x_0)}q(x_1dots x_T|x_0)dx_1dots x_T\
&= int p(x_T)prod_{i=1}^T frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})} dQ(x_1dots x_T|x_0)
end{aligned}
Áp dụng bất đẳng thức Jensen ta có
logp(x0)≥∫log(p(xT)∏i=1Tp(xi−1∣xi)q(xi∣xi−1))dQ(x1…xT∣x0)begin{aligned}
log p(x_0) &geqint log(p(x_T)prod_{i=1}^T frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})}) dQ(x_1dots x_T|x_0)
end{aligned}
với t>1t>1, ta có thể tính posterior như sau
q(xt∣xt−1)=q(xt∣xt−1,x0)tıˊnh chaˆˊt Markov=q(xt−1∣xt,x0)q(xt∣x0)q(xt−1∣x0)begin{aligned}
q(x_t|x_{t-1}) &= q(x_t|x_{t-1}, x_0)quad text{tính chất Markov}\
&=frac{q(x_{t-1}|x_t, x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)}
end{aligned}
Chặn dưới của log likelihood trở thành
L(x0)=Eq(x1…xT∣x0)[logp(XT)+∑t=2Tlogp(Xt−1∣Xt)q(Xt∣Xt−1)+logp(x0∣X1)+logq(X1∣x0)]=Eq(x1…xT∣x0)[logp(XT)q(XT∣x0)+∑t=2Tlogp(Xt−1∣XT)q(Xt−1∣Xt,x0)+logp(x0∣X1)+logq(X1∣x0)]begin{aligned}
L(x_0)&=mathbb{E}_{q(x_1dots x_{T}|x_0)}[log p(X_T)+sum_{t=2}^Tlogfrac{p(X_{t-1}|X_t)}{q(X_t|X_{t-1})} + log p(x_0|X_1) + log q(X_1|x_0)]\
&=mathbb{E}_{q(x_1dots x_{T}|x_0)}[logfrac{p(X_T)}{q(X_T|x_0)}+sum_{t=2}^Tlog frac{p(X_{t-1}|X_T)}{q(X_{t-1}|X_t, x_0)} + log p(x_0|X_1) + log q(X_1|x_0)]\
end{aligned}
Thành phần logq(x1∣x0)log q(x_1|x_0) là xác suất chuyển của quá trình thuận, do đó không có tham số và có thể loại bỏ trong quá trình huấn luyện. Mục tiêu của chúng ta là cực đại chặn dưới của log likelihood, tương đương với việc cực tiểu hàm mục tiêu sau
L=Eq[KL(q(xT∣X0)∣∣p(xT))+∑t=2TKL(q(xt−1∣Xt,X0)∣∣p(xt−1∣Xt))−logp(X0∣X1)]L=mathbb{E}_q[KL(q(x_T|X_0)||p(x_T)) +sum_{t=2}^TKL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))-log p(X_0|X_1)]
với kì vọng được lấy theo q(x0…xT)q(x_0dots x_T).
Các xác suất ở trên đều là phân bố Gaussian, do đó khoảng cách KL có thể tính từ kì vọng và phương sai. Đối với posterior, q(xt−1∣xt,x0)q(x_{t-1}|x_t, x_0) sẽ là phân bố Gaussian N(xt−1;αˉt−1βt1−αˉtx0+αt(1−αˉt−1)1−αˉtxt,β~tI)mathcal{N}(x_{t-1}; frac{sqrt{baralpha_{t-1}}beta_t}{1-baralpha_t}x_0 +frac{sqrt{alpha_t}(1-baralpha_{t-1})}{1-baralpha_{t}}x_t, tildebeta_tI), với β~t=1−αˉt−11−αˉtβttilde beta_t=frac{1-baralpha_{t-1}}{1-baralpha_t}beta_t.
Mô hình denoise diffusion
Để cho đơn giản, Σθ(xt,t)Sigma_{theta}(x_t, t) sẽ được đặt là σt2Isigma_t^2I, với σtsigma_t được chọn trước, do đó không tham gia vào quá trình huấn luyện. Tác giả đưa ra hai lựa chọn σt2=βtsigma_t^2=beta_t và σt2=β~tsigma_t^2=tilde beta_t, tương đương với việc entropy H(q(xt−1∣xt))H(q(x_{t-1}|x_t)) lớn nhất và nhỏ nhất, qua thực nghiệm hai cách chọn này cho kết quả tương đương.
Kí hiệu μ~t(xt,x0)tildemu_t(x_t, x_0) là kì vọng của q(xt−1∣xt,x0)q(x_{t-1}|x_t, x_0), với khoảng cách KL giữa hai phân bố Gaussian ta có
Lt−1=Eq[KL(q(xt−1∣Xt,X0)∣∣p(xt−1∣Xt))]=Ex0,xt[12σt2∣∣μ~t(Xt,X0)−μθ(Xt,t)∣∣2]+Cbegin{aligned}
L_{t-1} &= mathbb{E}_q[KL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))]\
&=mathbb{E}_{x_0, x_t}[frac{1}{2sigma_t^2}||tildemu_t(X_t,X_0)-mu_{theta}(X_t,t)||^2] + C
end{aligned}
Hàm μθ(xt,t)mu_{theta}(x_t,t) dự đoán kì vọng μ~(xt,x0)=αˉt−1βt1−αˉtx0+αt(1−αˉt−1)1−αˉtxttildemu(x_t,x_0)=frac{sqrt{baralpha_{t-1}}beta_t}{1-baralpha_t}x_0 +frac{sqrt{alpha_t}(1-baralpha_{t-1})}{1-baralpha_{t}}x_t của q(xt−1∣xt,x0)q(x_{t-1}|x_t, x_0) khi biết xtx_t và tt. Điều này tương đương với việc dự đoán x0x_0 khi biết xtx_t. Tuy nhiên, từ thực nghiệm, tác giả thấy việc tham số như vậy không đưa ra kết quả tốt. Từ xác suất chuyển của quá trình thuận, chúng ta có xt(x0,ϵ)=αtˉx0+1−αtˉϵx_t(x_0,epsilon) = sqrt{bar{alpha_t}}x_0+ sqrt{1-bar{alpha_t}}epsilon, trong đó ϵ∼N(0,I)epsilonsimmathcal{N}(0,I). Nói cách khác, trong quá trình thuận, x0x_0 có thể được tham số bởi xt(x0,ϵ)x_t(x_0,epsilon) và một biến ngẫu nhiên độc lập ϵepsilon thông qua x0=1αˉt(xt−1−αˉtϵ)x_0=frac{1}{sqrt{bar alpha_t}}(x_t-sqrt{1-baralpha_t}epsilon). Như vậy, thay vì đoán x0x_0 khi biết xtx_t, chúng ta có thể xây dựng mô hình ϵθ(xt,t)epsilon_{theta}(x_t,t) đoán nhiễu ϵepsilon khi biết xtx_t (đây là lí do cho từ denoise trong tên gọi).
Từ cách tham số này, chúng ta có thể thay vào μ~(xt,x0)tildemu(x_t,x_0) để được
μ~(xt,x0)=αt(1−αˉt−1)1−αˉtxt+αˉt−1βt1−αˉt(1αˉt(xt−1−αˉtϵ))=1αt(xt−βt1−αˉtϵ)begin{aligned}
tildemu(x_t,x_0)&=frac{sqrt{alpha_t}(1-baralpha_{t-1})}{1-baralpha_{t}}x_t+frac{sqrt{baralpha_{t-1}}beta_t}{1-baralpha_t}(frac{1}{sqrt{bar alpha_t}}(x_t-sqrt{1-baralpha_t}epsilon))\
&=frac{1}{sqrt{alpha_t}}(x_t-frac{beta_t}{sqrt{1-baralpha_t}}epsilon)
end{aligned}
Tương tự như vậy, μθ(xt,t)mu_{theta}(x_t,t) lúc này sẽ được tham số như sau
μθ(xt,t)=1αt(xt−βt1−αˉtϵθ(xt,t))mu_{theta}(x_t,t)=frac{1}{sqrt{alpha_t}}(x_t-frac{beta_t}{sqrt{1-baralpha_t}}epsilon_{theta}(x_t,t))
Nhắc lại, trong quá trình huấn luyện (quá trình thuận), xtx_t có thể tính từ x0x_0 thông qua xt(x0,ϵ)=αtˉx0+1−αtˉϵx_t(x_0,epsilon) = sqrt{bar{alpha_t}}x_0+ sqrt{1-bar{alpha_t}}epsilon. Lúc này, hàm mục tiêu sẽ trở thành
Lt−1=Ex0,ϵ[βt22σt2αt(1−αtˉ)∣∣ϵ−ϵθ(αtˉx0+1−αtˉϵ,t)∣∣2]L_{t-1}=mathbb{E}_{x_0,epsilon}[frac{beta_t^2}{2sigma_t^2alpha_t(1-bar{alpha_t})}||epsilon-epsilon_{theta}(sqrt{bar{alpha_t}}x_0+ sqrt{1-bar{alpha_t}}epsilon,t)||^2]
và hàm mục tiêu cho tại toàn bộ vị trí sẽ là L=Et[Lt−1]L=mathbb{E}_t[L_{t-1}] với tt tuân theo phân bố đều U{1,T}mathcal{U}{1,T}.
Để cho đơn giản, chúng ta có thể tối ưu với phiên bản không trọng số của hàm mục tiêu bên trên
L=Ex0,ϵ,t[∣∣ϵ−ϵθ(αtˉx0+1−αtˉϵ,t)∣∣2]L=mathbb{E}_{x_0,epsilon,t}[||epsilon-epsilon_{theta}(sqrt{bar{alpha_t}}x_0+ sqrt{1-bar{alpha_t}}epsilon,t)||^2]
Lấy mẫu
Thay vì mô hình trực tiếp kì vọng của p(xt−1∣xt)p(x_{t-1}|x_t), chúng ta đã mô hình nhiễu ϵθ(xt,t)epsilon_{theta}(x_t,t). Do đó, ở bước lấy mẫu, giả sử đã biết xtx_t, chúng ta sẽ tính lại kì vọng này qua công thức
μ~(xt,t)=1αt(xt−βt1−αˉtϵθ(xt,t))tilde mu(x_t,t)=frac{1}{sqrt{alpha_t}}(x_t-frac{beta_t}{sqrt{1-baralpha_t}}epsilon_{theta}(x_t,t))
Lúc này xt−1x_{t-1} sẽ được tính bởi
xt−1=μ~(xt,t)+σtz, z∼N(0,I)x_{t-1}=tildemu(x_t,t)+sigma_t z,,zsimmathcal{N}(0,I)
Bắt đầu từ xT∼N(0,I)x_Tsim mathcal{N}(0,I), chúng ta thực hiện tuần tự TT bước đến khi tìm được x0x_0.
Mô hình SDE tổng quát
Liên hệ giữa mô hình diffusion và score matching
Hàm mục tiêu của mô hình denoise diffusion có thể xem như denoise score matching. Với phân bố q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I)q(x_t|x_0) = mathcal{N}(x_t; sqrt{bar{alpha_t}}x_0, (1-bar{alpha_t})I), score ∇logq(xt∣x0)nablalog q(x_t|x_0) của phân bố này sẽ là αˉtx0−xt1−αˉtfrac{sqrt{baralpha_t}x_0-x_t}{1-baralpha_t}. Chú ý αˉtx0−xt1−αˉt∼N(0,I)frac{sqrt{baralpha_t}x_0-x_t}{sqrt{1-baralpha_t}}simmathcal{N}(0,I), nếu ta thay biến ngẫu nhiên này cho ϵepsilon trong thành phần Lt−1L_{t-1} của hàm mục tiêu không trọng số trong mô hình denoise diffusion, ta có
Lt−1=Ex0,xt[∣∣1−αˉt∇logq(xt∣x0)−ϵθ(xt,t)∣∣2]=(1−αˉt)Ex0,xt[∣∣∇logq(xt∣x0)−sθ(xt,t)∣∣2]begin{aligned}
L_{t-1}&=mathbb{E}_{x_0,x_t}[||sqrt{1-baralpha_t}nablalog q(x_t|x_0)-epsilon_{theta}(x_t,t)||^2]\
&=(1-baralpha_t)mathbb{E}_{x_0,x_t}[||nablalog q(x_t|x_0)-s_{theta}(x_t,t)||^2]
end{aligned}
với sθ(xt,t)=−ϵθ(xt,t)1−αˉts_{theta}(x_t,t)=-frac{epsilon_{theta}(x_t,t)}{sqrt{1-baralpha_t}}. Lúc này, hàm mục tiêu sẽ là
L=∑t=1T(1−αˉt)Ex0,xt[∣∣∇logq(xt∣x0)−sθ(xt,t)∣∣2]L=sum_{t=1}^T(1-baralpha_t)mathbb{E}_{x_0,x_t}[||nablalog q(x_t|x_0)-s_{theta}(x_t,t)||^2]
Đây chính là hàm mục tiêu của NCSN với trọng số (1−αˉt)(1-baralpha_t) khi sử dụng denoise score matching. Tương tự như NCSN, trọng số (1−αˉt)(1-baralpha_t) có tính chất (1−αˉt)∝1/E[∣∣∇logq(xt∣x0)∣∣2](1-baralpha_t)propto1/mathbb{E}[||nablalog q(x_t|x_0)||^2]. Cách nhìn này cho thấy sự liên hệ giữa phương pháp score matching và mô hình diffusion, đó là thay đổi phân bố dữ liệu bằng một họ các nhiễu, và học mô hình khử nhiễu lần lượt. Từ đây, ta có thể tổng quát cả hai phương pháp này, bằng cách mô hình họ các nhiễu bởi quá trình ngẫu nhiên liên tục, biểu diễn bởi một phương trình vi phân ngẫu nhiên (SDE).
Mô hình với SDE
Cụ thể hơn, với phân bố dữ liệu p0p_0 ban đầu, ta mong muốn biến đổi nó thành một phân bố đơn giản pTp_T, theo nghĩa có thể lấy mẫu một cách dễ dàng, ví dụ như N(0,I)mathcal{N}(0,I) trong mô hình diffusion. Nói cách khác, ta cần một quá trình ngẫu nhiên Xt{X_t} với t∈[0,T]tin[0,T] sao cho p(x0)=p0,p(xT)=pTp(x_0)=p_0, p(x_T)=p_T. Quá trình ngẫu nhiên này có thể mô tả bởi phương trình vi phân ngẫu nhiên Itô (từ bây giờ khi nhắc đến SDE, chúng ta sẽ hiểu đó là Itô SDE)
dxt=f(x,t)dt+g(t)dwdx_t=f(x,t)dt + g(t)dw
trong đó f(x,t):Rd×R+↦Rdf(x,t):mathbb{R}^dtimesmathbb{R}^+mapsto mathbb{R}^d, g(t):R+↦Rg(t):mathbb{R}^+mapstomathbb{R}, dwdw kí hiệu một cách hình thức vi phân của chuyển động Brown. Một cách trực quan, dw=N(0,Δt)dw=mathcal{N}(0,Delta t) với Δt→0Delta tto 0. Để cho đơn giản, chúng ta chỉ xét g(t)g(t) có dạng trên, tuy nhiên tất cả kết quả bên dưới đều có thể mở rộng cho hàm g(t)g(t) trả về ma trận.
SDE của mô hình diffusion
Nhắc lại quá trình thuận của mô hình diffusion có thể được mô tả bởi quá trình ngẫu nhiên {xt}t=0T{x_t}_{t=0}^T. Giả sử chúng ta dùng σt2=βtsigma_t^2=beta_t, chuỗi Markov có dạng
xt=1−βtxt−1+βtzt−1, z∼N(0,1)x_t=sqrt{1-beta_t}x_{t-1}+sqrt{beta_t}z_{t-1},, zsimmathcal{N}(0,1)
Quá trình ngẫu nhiên này có thể xem như rời rạc của một quá trình ngẫu nhiên liên tục, chúng ta sẽ tìm quá trình này bằng cách cho T→∞Ttoinfty. Đặt βˉt=Tβtbarbeta_t=Tbeta_t, chuỗi này sẽ tiến về một hàm β(t):[0,1]↦Rbeta(t):[0,1]mapstomathbb{R}, β(tT)=βˉtbeta(frac{t}{T})=barbeta_t. Tương tự quá trình ngẫu nhiên của xix_i và ziz_i cũng tiến tới quá trình ngẫu nhiên liên tục x(tT)=xt,z(tT)=ztx(frac{t}{T})=x_t, z(frac{t}{T})=z_t. Đặt Δt=t/TDelta t=t/T, dùng khai triển Taylor bậc 1, phương trình trên có thể viết lại thành
x(t+Δt)=1−β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)≈x(t)−12β(t)Δtx(t)+β(t)Δtz(t)begin{aligned}
x(t+Delta t)&=sqrt{1-beta(t+Delta t)Delta t}x(t)+sqrt{beta(t+Delta t)Delta t}z(t)\
&approx x(t)-frac{1}{2}beta(t)Delta tx(t)+sqrt{beta(t)Delta t}z(t)
end{aligned}
Khi Δt→0Delta tto 0, phương trình này hội tụ tới SDE
dxt=−12β(t)xtdt+β(t)dwdx_t=-frac{1}{2}beta(t)x_tdt+sqrt{beta(t)}dw
SDE của mô hình NCSN
Nhắc lại, mô hình NCSN thêm lần lượt nhiễu với phương sai {σt}t=1N{sigma_t}_{t=1}^N vào phân bố dữ liệu. Quá trình này có thể viết lại thành
xt=xt−1−σt2−σt−12z,z∼N(0,I)x_{t}=x_{t-1}-sqrt{sigma_{t}^2-sigma_{t-1}^2}z,qquad zsimmathcal{N}(0,I)
với σ0=0sigma_0=0.
Lập luận tương tự như trên, ta có thể tính giới hạn khi N→∞Ntoinfty
x(t+Δt)=x(t)+σ2(t+Δt)−σ2(t)z(t)≈dσ2(t)dtΔtz(t)x(t+Delta t)=x(t)+sqrt{sigma^2(t+Delta t)-sigma^2(t)}z(t)approxsqrt{frac{d sigma^2(t)}{dt}Delta t}z(t)
sử dụng khai triển Taylor bậc 1 của σ2(t)sigma^2(t). Khi Δt→0Delta tto 0, chuỗi xtx_t hội tụ tới quá trình ngẫu nhiên mô tả bởi
dxt=dσ2(t)dtdwdx_t=sqrt{frac{d sigma^2(t)}{dt}}dw
Lấy mẫu
Việc lấy mẫu tương đương với đảo chiều thời gian của quá trình ngẫu nhiên. Quá trình nghịch này được mô tả bởi SDE sau
dxt=(f(x,t)−g(t)2∇xtlogpt(xt))dt+g(t)dwˉdx_t=(f(x,t)-g(t)^2nabla_{x_t}log p_t(x_t))dt + g(t)dbar w
ở đây wˉbar w là chuyển động Brown theo chiều ngược lại, từ TT về 00.
Nếu biết được score của pt(x)p_t(x), chúng ta có thể mô phỏng lại quá trình ngược này. Bắt đầu từ xT∼pTx_Tsim p_T, từ phương trình trên, chúng ta sẽ biến đổi xTx_T thành x0x_0 tuân theo phân bố p0p_0 của dữ liệu. Như vậy, mục tiêu của chúng ta là xây dựng mô hình sθ(x(t),t)s_{theta}(x(t),t) xấp xỉ ∇xtlogpt(xt))nabla_{x_t}log p_t(x_t)).
Giải SDE
Quá trình lấy mẫu được thực hiện bằng cách giải phương trình SDE nghịch. Tương tự như khi rời rạc hóa quá trình thuận, chúng ta có thể giải bằng cách rời rạc hóa quá trình nghịch
xt=xt+1−ft+1(xt+1)+gt+12sθ(xi+1,i+1)+gt+1z, z∼N(0,I)x_t=x_{t+1}-f_{t+1}(x_{t+1})+g_{t+1}^2s_{theta}(x_{i+1},i+1)+g_{t+1}z,,zsimmathcal{N}(0,I)
Quay lại với cách cập nhật của mô hình denoise diffusion, giả sử ta dùng σt2=βtsigma_t^2=beta_t
xt−1=1αt(xt−βt1−αˉtϵθ(xt,t))+βtzx_{t-1} = frac{1}{sqrt{alpha_t}}(x_t-frac{beta_t}{sqrt{1-baralpha_t}}epsilon_{theta}(x_t,t)) + sqrt{beta_t}z
Đặt sθ(xt,t)=−ϵθ(xt,t)1−αˉts_{theta}(x_t,t)=-frac{epsilon_{theta}(x_t,t)}{sqrt{1-baralpha_t}}, ta có thể biến đổi như sau
xt−1=11−βt(xt+βts(xt,t))+βtz≈(1+12βt)(xt+βts(xt,t))+βtzkhai triển Taylor=(1+12βt)xt+βts(xt,t)+12βt2s(xt,t)+βtz≈xt+12βtxt+βts(xt,t)+βtzbegin{aligned}
x_{t-1}&=frac{1}{sqrt{1-beta_t}}(x_t+beta_ts(x_t,t))+sqrt{beta_t}z\
&approx (1+frac{1}{2}beta_t)(x_t+beta_ts(x_t,t)) +sqrt{beta_t}zqquad text{khai triển Taylor}\
&=(1+frac{1}{2}beta_t)x_t+ beta_ts(x_t,t)+frac{1}{2}beta_t^2s(x_t,t)+sqrt{beta_t}z\
&approx x_t+frac{1}{2}beta_tx_t+ beta_ts(x_t,t)+sqrt{beta_t}z
end{aligned}
Quá trình nghịch của SDE ứng với mô hình diffusion là
dxt=(−12β(t)xt−βt∇xlogpt(xt))dt+β(t)dwdx_t=(-frac{1}{2}beta(t)x_t-beta_tnabla_xlog p_t(x_t))dt+sqrt{beta(t)}dw
Ta có thể thấy thuật toán lấy mẫu của mô hình denoise diffusion gần giống với việc giải quá trình nghịch thông qua rời rạc hóa.
Lấy mẫu với Predictor-Corrector
Ở phần trước, ta đã biết quá trình lấy mẫu có thể thực hiện bằng việc giải phương trình SDE nghịch, và thuật toán lấy mẫu của mô hình diffusion thuộc loại này. Mặt khác, ta đang mô hình score của pt(xt)p_t(x_t), do đó ta cũng có thể lấy mẫu với (annealed) Langevin dynamics.
Để có thể sinh dữ liệu tốt hơn, chúng ta có thể kết hợp hai phương pháp này. Lấy mẫu thông qua giải SDE sẽ được xem như thuật toán chính, gọi là Predictor. Ở bước thứ ii trong Predictor, sau khi cập nhật xT−ix_{T-i} qua xT−i+1x_{T-i+1}, chúng ta sẽ thực hiện Langevin dynamics MM lần với s(xT−i,T−i)s(x_{T-i},T-i)
xT−i=xT−i+ϵis(xT−i,T−i)+2ϵiz, z∼N(0,I)x_{T-i}=x_{T-i}+epsilon_i s(x_{T-i},T-i) + sqrt{2epsilon_i}z,,zsimmathcal{N}(0,I)
Từ góc nhìn này, cách sinh dữ liệu của NCSN có thể xem như Predictor là hàm đồng nhất, Corrector là Langevin dynamics, cách sinh dữ liệu của mô hình denoise diffusion có thể xem như Predictor là giải quá trình nghịch, Corrector là hàm đồng nhất.
Huấn luyện
Tương tự như hàm mục tiêu của NCSN cũng như mô hình denoise diffusion, hàm mục tiêu của mô hình SDE sẽ có dạng score matching trên tất cả mức độ nhiễu. Điểm khác biệt là biến thời gian tt lúc này là biến ngẫu nhiên liên tục tuân theo phân bố đều U[0,1]mathcal{U}[0,1]
L=Et[λ(t)Ex0,xt[∣∣∇logp(xt∣x0)−sθ(xt,t)∣∣2]]L=mathbb{E}_t[lambda(t)mathbb{E}_{x_0,x_t}[||nablalog p(x_t|x_0)-s_{theta}(x_t,t)||^2]]
Ở đây λ(t)lambda(t) là hàm trọng số, có thể chọn giống như NCSN và mô hình denoise diffusion là λ(t)∝1/E[∣∣∇logq(xt∣x0)∣∣2]lambda(t)propto1/mathbb{E}[||nablalog q(x_t|x_0)||^2].
Việc tính hàm mất mát yêu cầu score của phân bố chuyển trong quá trình thuận. Đối với trường hợp SDE tổng quát, ta cần giải phương trình Kolmogorov tiến để tìm phân bố này. Khi f(x,t)=a(t)x+b(t)f(x,t)=a(t)x+b(t), phân bố chuyển là phân bố Gaussian, do đó chỉ cần biết kì vọng và phương sai để tính score. Kì vọng mtm_t và ma trận hiệp phương sai PtP_t sẽ thỏa mãn phương trình vi phân sau
dmtdt=a(t)mt+b(t)frac{dm_t}{dt}=a(t)m_t+b(t)
dPtdt=2a(t)Pt+g(t)2frac{dP_t}{dt}=2a(t)P_t+g(t)^2
Để tránh việc phải tính phân bố chuyển, chúng ta có thể dùng phương pháp score matching khác, ví dụ như sliced score matching, với hàm mục tiêu
L=Et[λ(t)Ex0ExtEv[12∣∣sθ(xt,t)∣∣2+v⊺Js(.,t)(xt)v]]L=mathbb{E}_t[lambda(t)mathbb{E}_{x_0}mathbb{E}_{x_t}mathbb{E}_v[frac{1}{2}||s_{theta}(x_t,t)||^2+v^intercal J_{s(.,t)}(x_t)v]]
với Js(.,t)(xt)J_{s(.,t)}(x_t) là ma trận Jacobian của s(xt,t)s(x_t,t), v⊺Js(.,t)(xt)vv^intercal J_{s(.,t)}(x_t)v tính bởi v⊺∇(v⊺s(xt,t))v^intercalnabla(v^intercal s(x_t,t)).
Kết luận
Trong bài này, mình đã giới thiệu mô hình diffusion và mô hình dạng SDE tổng quát mà trong đó score matching và mô hình diffusion là trường hợp đặc biệt. Cách tiếp cận này hiện đã cho kết quả tốt nhất hiện tại cho mô hình sinh.
Tuy nhiên cách tiếp cận này có các nhược điểm sau: Các trạng thái có cùng số chiều, do đó việc mô hình quá trình nghịch cần đảm bảo điều đó chứ không thể thay đổi số chiều dữ liệu. Việc lấy mẫu tốn khá nhiều thời gian, do cần phải đi từng bước để giải phương trình SDE nghịch, chưa tính đến việc kết hợp với Corrector trong quá trình lấy mẫu.
Tham khảo
- Deep Unsupervised Learning using
Nonequilibrium Thermodynamics
- Denoising Diffusion Probabilistic Models
- SCORE-BASED GENERATIVE MODELING THROUGH
STOCHASTIC DIFFERENTIAL EQUATIONS
- Blog của tác giả
- Công thức posterior của quá trình thuận DIFFWAVE: A VERSATILE DIFFUSION MODEL FOR
AUDIO SYNTHESIS - Tham khảo thêm về SDE
Một số định nghĩa và chứng minh chi tiết
Công thức các phân bố trong quá trình thuận của mô hình diffusion
Tính chất: Với q(xt∣xt−1)=N(xt;αtxt−1,(1−αt)I)q(x_t|x_{t-1})=mathcal{N}(x_t;sqrt{alpha_t}x_{t-1}, (1-alpha_t)I), ta có q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I)q(x_t|x_0) = mathcal{N}(x_t; sqrt{bar{alpha_t}}x_0, (1-bar{alpha_t})I),
trong đó αˉt=∏i=1tαibaralpha_t=prod_{i=1}^talpha_i.
Chứng minh:
Quá trình Markov thỏa mãn tính chất sau
Mệnh đề: Với t1>t2>t3t_1>t_2>t_3, xác suất chuyển thỏa mãn phương trình Chapman-Kolmogorov
pt3t1(xt1∣xt3)=∫pt3t2(xt2∣xt3)pt2t1(xt1∣xt2)dxt2p_{t_3t_1}(x_{t_1}|x_{t_3})=int p_{t_3t_2}(x_{t_2}|x_{t_3})p_{t_2t_1}(x_{t_1}|x_{t_2})dx_{t_2}
Tính chất trên có thể chứng minh dễ dàng bằng tính chất Markov.
Chúng ta chỉ cần chứng minh cho t=2t=2, các trường hợp còn lại có thể suy ra theo quy nạp. Hơn nữa, ma trận hiệp phương sai có dạng βtIbeta_tI, do đó ta chỉ cần chứng minh cho trường hợp x∈Rxinmathbb{R}.
Từ phương trình Chapman-Kolmogorov, ta có
q(x2∣x0)=∫q(x2∣x1)q(x1∣x0)dx1=12(1−α1)(1−α2)π∫exp(−(x2−α2×1)22(1−α2))exp(−(x1−α1×0)22(1−α1))dx1=12(1−α1)(1−α2)π∫exp(−12((x2−α1α2×0)21−α1α2+(1−α1α2)(x1−α2(1−α1)x2+α1(1−α2)x01−α1α2)2(1−α1)(1−α2)))dx1=12π(1−α1α2)exp(−12(x2−α1α2×0)21−α1α2).begin{aligned}
q(x_2|x_0) &= int q(x_2|x_1)q(x_1|x_0)dx_1\
&=frac{1}{2sqrt{(1-alpha_1)(1-alpha_2)}pi}int exp(-frac{(x_2-sqrtalpha_2x_1)^2}{2(1-alpha_2)})exp(-frac{(x_1-sqrtalpha_1x_0)^2}{2(1-alpha_1)})dx_1\
&=frac{1}{2sqrt{(1-alpha_1)(1-alpha_2)}pi}intexp(-frac{1}{2}(frac{(x_2-sqrt{alpha_1alpha_2}x_0)^2}{1-alpha_1alpha_2} +frac{(1-alpha_1alpha_2)(x_1-frac{sqrtalpha_2(1-alpha_1)x_2+sqrtalpha_1(1-alpha_2)x_0}{1-alpha_1alpha_2})^2}{(1-alpha_1)(1-alpha_2)}))dx_1\
&=frac{1}{sqrt{2pi(1-alpha_1alpha_2)}}exp(-frac{1}{2}frac{(x_2-sqrt{alpha_1alpha_2}x_0)^2}{1-alpha_1alpha_2}).
end{aligned}
□square
Tính chất:q(xt−1∣xt,x0)=N(xt−1;αˉt−1βt1−αˉtx0+αt(1−αˉt−1)1−αˉtxt,β~tI)q(x_{t-1}|x_t, x_0) =mathcal{N}(x_{t-1}; frac{sqrt{baralpha_{t-1}}beta_t}{1-baralpha_t}x_0 +frac{sqrt{alpha_t}(1-baralpha_{t-1})}{1-baralpha_{t}}x_t, tildebeta_tI), với β~t=1−αˉt−11−αˉtβttilde beta_t=frac{1-baralpha_{t-1}}{1-baralpha_t}beta_t
Chứng minh: Tương tự như trên, chúng ta cũng chỉ cần chứng minh cho trường hợp x∈Rxinmathbb{R}.
q(xt−1∣xt,x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)=(2πβt)−1/2(2π(1−αˉt−1))−1/2(2π(1−αˉt))1/2exp(−∣∣xt−αtxt−1∣∣22βt−∣∣xt−1−αˉt−1×0∣∣22(1−αˉt−1)+∣∣xt−αˉtx0∣∣22(1−αˉt))=(2πβ~t)−1/2exp(−1β~t∣∣xt−1−αˉt−1βt1−αˉtx0+αt(1−αˉt−1)1−αˉtxt∣∣2).begin{aligned}
q(x_{t-1}|x_t, x_0)&=frac{q(x_t|x_{t-1})q(x_t|x_0)}{q(x_{t-1}|x_0)}\
&=(2pibeta_t)^{-1/2}(2pi(1-baralpha_{t-1}))^{-1/2}(2pi(1-baralpha_t))^{1/2}\
&quadexpleft(-frac{||x_t-sqrt{alpha_t}x_{t-1}||^2}{2beta_t}-frac{||x_{t-1}-sqrt{baralpha_{t-1}}x_0||^2}{2(1-baralpha_{t-1})}+frac{||x_t-sqrt{baralpha_t}x_0||^2}{2(1-baralpha_t)}right)\
&=(2pitildebeta_t)^{-1/2}expleft(-frac{1}{tildebeta_t}||x_{t-1}-frac{sqrt{baralpha_{t-1}}beta_t}{1-baralpha_t}x_0 +frac{sqrt{alpha_t}(1-baralpha_{t-1})}{1-baralpha_{t}}x_t||^2right).
end{aligned}
□square
Một số tính chất của phương trình vi phân ngẫu nhiên
Phương trình vi phân ngẫu nhiên Itô với điều kiện đầu x0=xx_0=x
dxt=f(xt,t)dt+g(t)dwdx_t=f(x_t,t)dt+g(t)dw
là biểu diễn hình thức của phương trình tích phân sau
xt=x+∫0tf(xt,t)dt+∫0tg(t)dwx_t=x+int_0^tf(x_t,t)dt+int_0^tg(t)dw
Tích phân đầu tiên là tích phân Riemann-Stieltjes thông thường. Tuy nhiên ta không thể tính tích phân thứ hai như vậy, do chuyển động Brown không thỏa mãn tính chất bounded variation. Thay vào đó, ta sẽ sử dụng tích phân Itô để tính đại lượng này
∫0tg(t)dw=limK→∞∑k=0K−1g(tk)(wtk+1−wtk)int_0^tg(t)dw=lim_{Kto infty}sum_{k=0}^{K-1}g(t_k)(w_{t_{k+1}}-w_{t_k})
với tk=kΔt,t=KΔtt_k=kDelta t, t=KDelta t.
Từ đây chúng ta có tính chất sau
∑k=0K−1E[g(tk)(wtk+1−wtk)]=∑k=0K−1E[g(tk)]E[wtk+1−wtk]=0sum_{k=0}^{K-1}mathbb{E}[g(t_k)(w_{t_{k+1}}-w_{t_k})]=sum_{k=0}^{K-1}mathbb{E}[g(t_k)]mathbb{E}[w_{t_{k+1}}-w_{t_k}]=0
theo định nghĩa của chuyển động Brown, do đó
E[∫0tg(t)dw]=0mathbb{E}[int_0^tg(t)dw]=0
Với quá trình ngẫu nhiên xtx_t và một hàm tất định u(x,t):Rd×R+↦Ru(x,t):mathbb{R}^dtimes mathbb{R}^+mapstomathbb{R}, chúng ta cũng không thể tính đạo hàm toàn phân du(xt,t)dtfrac{du(x_t,t)}{dt} bằng chain rule như thông thường, thay vào đó chúng ta sẽ dùng công thức Itô
du(xt,t)=∂u(xt,t)∂tdt+∇u(xt,t)⊺dxt+12(dxt)⊺Hxu(xt,t)dxtdu(x_t,t)=frac{partial u(x_t,t)}{partial t}dt+nabla u(x_t,t)^intercal dx_t+frac{1}{2}(dx_t)^intercal H_xu(x_t,t)dx_t
trong đó Hxu(xt,t)H_xu(x_t,t) là ma trận Hessian của uu.
Chứng minh SDE nghịch
Từ công thức Itô, chúng ta có hai phương trình quan trọng. Đó là phương trình Kolmogorov tiến
∂p(xt,t)∂t=−∑i∂fi(xt,t)p(xt,t)∂xi+12∑i,j∂2g(t)2p(xt,t)∂xixjfrac{partial p(x_t,t)}{partial t}=-sum_ifrac{partial f^i(x_t,t)p(x_t,t)}{partial x^i}+frac{1}{2}sum_{i,j}frac{partial^2 g(t)^2p(x_t,t)}{partial x^ix^j}
và phương trình Kolmogorov lùi
−∂pts(xs∣xt)∂t=∑ifi(xt,t)∂pts(xs∣xt)∂x+∑i,jg(t)22∂2pts(xs∣xt)∂x2-frac{partial p_{ts}(x_s|x_t)}{partial t}=sum_if^i(x_t,t)frac{partial p_{ts}(x_s|x_t)}{partial x}+sum_{i,j}frac{g(t)^2}{2}frac{partial^2 p_{ts}(x_s|x_t)}{partial x^2}
với t<st<s, pts(xs∣xt)p_{ts}(x_s|x_t) là xác suất chuyển từ trạng thái xtx_t tại tt sang trạng thái xsx_s tại ss, fi,xif_i,x_i là chỉ số thứ ii của f,xf,x.
Phương trình SDE nghịch có thể suy ngược từ phương trình Kolmogorov như sau: Với xác suất liên hợp xs,xtx_s, x_t, ta có
p(xs,xt)=pts(xs∣xt)p(xt)p(x_s,x_t)=p_{ts}(x_s|x_t)p(x_t)
∂p(xs,xt)∂t=pts(xs∣xt)∂p(xt,t)∂t+p(xt)∂pts(xs∣xt)∂tfrac{partial p(x_s,x_t)}{partial t}= p_{ts}(x_s|x_t)frac{partial p(x_t,t)}{partial t}+ p(x_t)frac{partial p_{ts}(x_s|x_t)}{partial t}
Thay phương trình Kolmogorov vào phương trình trên, ta được
−∂p(xs,xt)∂t=∑i∂fˉi(xt,t)p(xs,xt)∂xi+12∑i,j∂2g(t)2p(xs,xt)∂xixj-frac{partial p(x_s,x_t)}{partial t}=sum_ifrac{partial bar f^i(x_t,t)p(x_s,x_t)}{partial x^i}+frac{1}{2}sum_{i,j}frac{partial^2 g(t)^2p(x_s,x_t)}{partial x^ix^j}
trong đó
fˉ(xt,t)=f(x,t)−g(t)21p(xt)∇pt(xt)=f(x,t)−g(t)2∇logpt(xt)bar f(x_t,t)= f(x,t)-g(t)^2frac{1}{p(x_t)}nabla p_t(x_t)=f(x,t)-g(t)^2nablalog p_t(x_t)
Tích phân cả hai vế cho xsx_s, ta được phương trình Kolmogorov tiến của quá trình nghịch, ứng với SDE
dxt=fˉ(xt,t)dt+g(t)dwˉdx_t=bar f(x_t,t)dt+g(t)dbar w
Kì vọng và phương sai của SDE tuyến tính
Với hàm uu bất kì, từ công thức Itô
du(xt,t)=∂u(xt,t)∂tdt+∇u(xt,t)⊺dxt+12(dxt)⊺Hxu(xt,t)dxt=(∂u(xt,t)∂t+∇u(xt,t)⊺f(xt,t)+12g(t)2∑i,j∂2u∂xi∂xj)dt+∇u g(t)dwbegin{aligned}
du(x_t,t)&=frac{partial u(x_t,t)}{partial t}dt+nabla u(x_t,t)^intercal dx_t+frac{1}{2}(dx_t)^intercal H_xu(x_t,t)dx_t\
&=left(frac{partial u(x_t,t)}{partial t}+nabla u(x_t,t)^intercal f(x_t,t)+frac{1}{2}g(t)^2sum_{i,j}frac{partial^2u}{partial x^ipartial x^j}right)dt + nabla u ,g(t)dw
end{aligned}
Lấy kì vọng hai vế và dùng tính chất kì vọng của tích phân Itô bằng 0, ta có
dE[u]dt=E[∂u∂t]+E[∇u⊺f(xt,t)]+12g(t)2E[∑i,j∂2u∂xi∂xj]frac{dmathbb{E}[u]}{dt}=mathbb{E}[frac{partial u}{partial t}]+mathbb{E}[nabla u^intercal f(x_t,t)]+frac{1}{2}g(t)^2mathbb{E}[sum_{i,j}frac{partial^2u}{partial x^ipartial x^j}]
Thay u=xiu=x^i, ta tính được kì vọng của xix^i, từ đó suy ra kì vọng của xx. Với ma trận hiệp phương sai, ta thay u=xixj−mi(t)mj(t)u=x^ix^j-m^i(t)m^j(t), với m(t)m(t) là kì vọng của xtx_t.
Nguồn: viblo.asia