Mô hình năng lượng (Energy-based models – EBM) và một số cách huấn luyện (2).

Trong bài trước, ta đã biết mô hình năng lượng biểu diễn một phân bố không chuẩn hóa, cụ thể hơn p(x)=exp⁡(−E(x))Zp(x)=frac{exp( -E(x))}{Z} p(x)=Zexp(−E(x))​ Với phân bố p(x)p(x)p(x) như trên, ta sinh dữ liệu bằng phương pháp stochasic gradient Langevin dynamics. Phương pháp này sử dụng gradient tại xxx của log⁡p(x)log p(x)logp(x) để lấy mẫu.

Trong bài trước, ta đã biết mô hình năng lượng biểu diễn một phân bố không chuẩn hóa, cụ thể hơn

p(x)=exp⁡(−E(x))Zp(x)=frac{exp( -E(x))}{Z}

Với phân bố p(x)p(x) như trên, ta sinh dữ liệu bằng phương pháp stochasic gradient Langevin dynamics. Phương pháp này sử dụng gradient tại xx của log⁡p(x)log p(x) để lấy mẫu.

Từ điều này, ta có thể thấy việc học một mô hình năng lượng có thể chuyển thành học ∇xlog⁡p(x)nabla_{x}log p(x), còn được gọi là hàm score, thay vì học tham số θtheta của hàm năng lượng. Quan sát này dẫn tới một lớp các phương pháp học mới, được gọi là score matching.

Mục tiêu của chúng ta là xây dựng một hàm score sθ(x)s_{theta}(x) sao cho xấp xỉ ∇xlog⁡p(x)nabla_{x}log p(x) tốt nhất. Ta sẽ sử dụng Fisher divergence để đo sự khác nhau giữa phân bố p(x)p(x) cần xấp xỉ và phân bố ẩn q(x)q(x) nhận sθ(x)s_{theta}(x) làm hàm score. Giá trị này được tính như sau

F(p,q)=∫Rd∣∣∇log⁡q(x)−∇log⁡p(x)∣∣2p(x)dxF(p, q) = int_{mathbb{R}^d} ||nabla log q(x)-nablalog p(x)||^2p(x)dx

Fisher divergence có thể xem như khoảng cách L2L_2 trung bình giữa hai hàm score.

Tuy nhiên, ta không thể tính trực tiếp giá trị này được, do nó yêu cầu gradient của phân bố p(x)p(x) thật của dữ liệu, trong khi ta chỉ có một lượng mẫu từ phân bố này là dữ liệu huấn luyện.

Phương pháp sliced score matching

Ta có thể thêm ràng buộc của phân bố để biến đổi Fisher divergence về dạng tính được. Khoảng cách này được viết lại như sau

F(p,q)=Ep[∣∣∇log⁡q(X)∣∣2]+Ep[∣∣∇log⁡p(X)∣∣2]−2Ep[∇log⁡p(X)⊺∇log⁡q(X)]begin{aligned}
F(p, q) &= mathbb{E}_p[||nablalog q(X)||^2] + mathbb{E}_p[||nablalog p(X)||^2] – 2mathbb{E}_p[nablalog p(X)^intercalnablalog q(X)]\
end{aligned}

Hạng tử thứ nhất không chứa ∇log⁡p(x)nablalog p(x), hạng tử thứ hai không phụ thuộc vào q(x)q(x), do đó có thể bỏ qua. Ta sẽ biến đổi hạng tử cuối cùng để bỏ đi ∇log⁡p(x)nablalog p(x). Áp dụng chain rule, ta có

Ep[∇log⁡p(X)⊺∇log⁡q(X)]=∑i∫Rdp(x)∂p(x)∂xi1p(x)∂log⁡q(x)∂xidx=∑i∫Rd∂p(x)∂xisi(x)dxbegin{aligned}
mathbb{E}_p[nablalog p(X)^intercalnablalog q(X)] &= sum_iint_{mathbb{R}^d}p(x)frac{partial p(x)}{partial x_i}frac{1}{p(x)}frac{partial log q(x)}{partial x_i}dx \
&= sum_iint_{mathbb{R}^d}frac{partial p(x)}{partial x_i}s_i(x)dx
end{aligned}

với si(x)=∂log⁡q(x)∂xi laˋ chỉ soˆˊ thứ i của s(x)=∇log⁡q(x)text{với }s_i(x)=frac{partial log q(x)}{partial x_i}text{ là chỉ số thứ i của }s(x)=nablalog q(x)

Ở đây ta sẽ dùng tích phân từng phần để loại bỏ đạo hàm riêng của p(x)p(x)

∂p(x)si(x)∂xi=∂p(x)∂xisi(x)+p(x)∂si(x)∂xifrac{partial p(x) s_i(x)}{partial x_i}=frac{partial p(x)}{partial x_i}s_i(x)+p(x)frac{partial s_i(x)}{partial x_i}

Giả sử lim⁡∣∣x∣∣→∞p(x)si(x)=0lim_{||x||toinfty}p(x)s_i(x)=0, ta có

∫R∂p(x)∂xisi(x)dxi+∫Rp(x)∂si(x)∂xidxi=lim⁡xi→∞p(x)si(x)−lim⁡xi→−∞p(x)si(x)=0int_{mathbb{R}}frac{partial p(x)}{partial x_i}s_i(x)dx_i+int_{mathbb{R}}p(x)frac{partial s_i(x)}{partial x_i}dx_i=lim_{x_itoinfty}p(x)s_i(x) – lim_{x_ito-infty}p(x)s_i(x)=0

∫Rd∂p(x)∂xisi(x)dx=∫Rd−1∫R−p(x)∂si(x)∂xidxid(x1…xi−1xi+1…xn)=−∫Rdp(x)∂si(x)∂xidxbegin{aligned}
int_{mathbb{R}^d}frac{partial p(x)}{partial x_i}s_i(x)dx &=int_{mathbb{R}^{d-1}}int_{mathbb{R}}-p(x)frac{partial s_i(x)}{partial x_i}dx_id(x_1…x_{i-1}x_{i+1}…x_n)\
&=-int_{mathbb{R}^d}p(x)frac{partial s_i(x)}{partial x_i}dx
end{aligned}

Tổng hợp lại, Fisher divergence sẽ được viết lại như sau

F(p,q)=Ep[∣∣s(X)∣∣2]+2Ep[∑i∂si(X)∂xi]+c=Ep[∣∣s(X)∣∣2]+2Ep[tr(Js(X))]+cbegin{aligned}
F(p,q)&= mathbb{E}_p[||s(X)||^2]+2mathbb{E}_p[sum_ifrac{partial s_i(X)}{partial x_i}] +c\
&=mathbb{E}_p[||s(X)||^2]+2mathbb{E}_p[tr(J_s(X))] +c
end{aligned}

với JsJ_s là ma trận Jacobian của s(x)s(x).

Công thức này đã không còn ∇log⁡p(x)nabla log p(x), do đó có thể tính toán được. Tuy nhiên điều này cũng chỉ là trên lý thuyết, vì ta sẽ cần phải tính (vết của) ma trận Jacobian, trong khi xx thường có số chiều lớn. Ta có thể xấp xỉ giá trị này bằng cách chiều xuống một vector ngẫu nhiên vv (đây là kĩ thuật Hutchinson). Cụ thể hơn, với vector ngẫu nhiên vv thỏa mãn E[vv⊺]=Imathbb{E}[vv^intercal]=I, ta có

tr(Js)=tr(JsI)=tr(JsE[vv⊺])=E[tr(Jsvv⊺)]=E[v⊺Jsv]tr(J_s)=tr(J_sI)=tr(J_smathbb{E}[vv^intercal])=mathbb{E}[tr(J_svv^intercal)]=mathbb{E}[v^intercal J_sv]

Cách làm này này sẽ giúp tính vết nhanh hơn, cụ thể hơn với vv bất kì, ta có

∇v⊺s(x)=v⊺Js+(∇v)s(x)=v⊺Jsnabla v^intercal s(x)=v^intercal J_s+(nabla v)s(x)=v^intercal J_s

Nếu ta lấy mẫu mm vector vv, ta sẽ cần tính mm lần gradient của v⊺s(x)v^intercal s(x), trong khi với JsJ_s sẽ cần tính gradient dd lần với dd là số chiều của xx. Phương pháp này được gọi là sliced score matching, với hàm mục tiêu lúc này là

L(p,q)=Ep(x)[∣∣s(X)∣∣2]+2Ep(x)Ep(v)[v⊺Jsv]L(p, q)=mathbb{E}_{p(x)}[||s(X)||^2]+2mathbb{E}_{p(x)}mathbb{E}_{p(v)}[v^intercal J_sv]

Phương pháp denoise score matching

Một cách khác để loại bỏ ∇log⁡p(x)nablalog p(x) là cộng thêm nhiễu vào phân bố. Ta đạt được biến ngẫu nhiên mới X~=X+ϵtilde{X}=X+epsilon với ϵepsilon là nhiễu tùy ý, giả sử ϵ∼N(0,σ2)epsilon sim mathcal{N}(0,sigma^2). Phân bố của biến ngẫu nhiên này sẽ là q(x~)=∫q(x~∣x)p(x)dxq(tilde{x})=int q(tilde{x}|x)p(x)dx, trong đó x~∣x∼N(x,σ2)tilde{x}|x simmathcal{N}(x,sigma^2). Với phân bố mới, Fisher divergence được viết lại thành

F(p,q)=Eq(x~)[∣∣sθ(X~)∣∣2]−2Eq(x~)[∇x~log⁡q(X~)⊺sθ(X~)]+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫∇x~q(x~)⊺sθ(x~)dx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫(∫p(x)∇q(x~∣x)dx)⊺sθ(x~)dx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫∫p(x)q(x~∣x)∇log⁡q(x~∣x)⊺sθ(x~)dxdx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2Eq(x~,x)[∇x~log⁡q(X~∣X)⊺sθ(X)]+c=Eq(x~,x)[∣∣sθ(X~)−∇log⁡q(X~∣X)∣∣2]+cbegin{aligned}
F(p,q)&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2] – 2mathbb{E}_{q(tilde x)}[nabla_{tilde x}log q(tilde X)^intercal s_{theta}(tilde X)] +c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2int nabla_{tilde x} q(tilde x)^{intercal}s_{theta}(tilde x)dtilde x +c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2int(int p(x)nabla q(tilde x|x)dx)^intercal s_{theta}(tilde x)dtilde x +c\
&=mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2intint p(x)q(tilde x|x)nabla log q(tilde x|x)^intercal s_{theta}(tilde x)dx dtilde x+c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2mathbb{E}_{q(tilde x, x)}[nabla_{tilde x}log q(tilde X|X)^intercal s_{theta}(X)]+c\
&=mathbb{E}_{q(tilde x,x)}[||s_{theta}(tilde X)-nabla log q(tilde X|X)||^2]+c
end{aligned}

ở đây cc tượng trưng cho hằng số không phụ thuộc vào s(x)s(x).

Như vậy, ta không cần phải tính ∇log⁡p(x)nablalog p(x) nữa mà chuyển thành tính score của phân bố N(x,σ2)mathcal{N}(x,sigma^2) với công thức là 1σ(x−x~)frac{1}{sigma}(x-tilde x). Tuy nhiên điều này dẫn tới một điểm yếu của phương pháp này. Thứ nhất, ta muốn nhiễu có phương sai không quá lớn, nếu không sẽ làm sai lệch phân bố đi nhiều. Tuy nhiên, khi phương sai của nhiễu nhỏ thì phương sai khi ước lượng sẽ tăng. Cụ thể hơn khi σ→∞sigmatoinfty, s(x~)≈s(x)s(tilde x)approx s(x), đại lượng trên xuất hiện thành phần

(x−x~)2σ2−2(x−x~σ)⊺s(x) frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)

có phương sai tiến tới ∞infty khi σ→0sigmato 0. Ta sẽ dùng control variate để giảm phương sai khi ước lượng với hàm

c(x~,x)=(x−x~)2σ2−2(x−x~σ)⊺s(x)−Eq(x~,x)[(x−x~)2σ2−2(x−x~σ)⊺s(x)]=(x−x~)2σ2−2(x−x~σ)⊺s(x)−dσ2begin{aligned}
c(tilde x, x) &= frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)-mathbb{E}_{q(tilde x, x)}[frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)]\
&=frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)-frac{d}{sigma^2}
end{aligned}

trong đó dd là số chiều của x. Với mm mẫu xi,x~ix^i, tilde x^i từ q(x~,x)q(tilde x, x), hàm mục tiêu sẽ được xấp xỉ như sau

1m∑im∣∣sθ(x~i)−xi−x~iσ∣∣2−c(x~i,xi)frac{1}{m}sum_i^m ||s_{theta}(tilde x^i)-frac{x^i-tilde x^i}{sigma}||^2-c(tilde x^i, x^i)

Mối liên hệ giữa MCMC và score matching

Ta đã biết phương pháp MCMC đi tìm phân bố q(x)q(x) có likelihood cao nhất, tương đương với việc tìm phân bố có KL divergence nhỏ nhất với phân bố của dữ liệu p(x)p(x). Trong khi đó score matching đi tìm phân bố có Fisher divergence nhỏ nhất. Do vậy, mối liên hệ giữa hai phương pháp có thể quy về mối liên hệ giữa hai loại khoảng cách này.

Cho biến ngẫu nhiên XXXt=X+tZX_t=X+sqrt tZ với z∼N(0,1)zsimmathcal{N}(0,1), p~,q~tilde p,tilde q là hai luật của XX, p,qp, q là hai luật tương ứng của XtX_t. Giả sử p,qp, q hội tụ về 00 đủ nhanh khi ∣∣x∣∣→∞||x||toinfty, ta có đẳng thức de Bruijn

ddtKL(p∣∣q)=−12F(p,q)frac{d}{dt}KL( p||q)=-frac{1}{2}F( p, q)

Cực tiểu Fisher divergence tương đương với việc tìm phân bố q~tilde q sao cho chênh lệch của KL divergence giữa hai luật trước và sau khi cộng thêm nhiễu là nhỏ nhất. Nói cách khác, score matching đi tìm phân bố có tính ổn định với nhiễu.

Cực tiểu vi phân của KL divergence

Đẳng thức de Bruijin có thể tổng quát như sau: Cho phương trình vi phân ngẫu nhiên

dXt=V(x)dt+βdBtdX_t=V(x)dt+beta dB_t

với p,qp,q là hai luật tại X0X_0, pt,qtp_t, q_t là hai luật tương ứng tại XtX_t. Ta có

ddtKL(pt∣∣qt)=−12F(pt,qt)frac{d}{dt}KL( p_t||q_t)=-frac{1}{2}F( p_t, q_t)

nếu ddtKL(pt∣∣qt)frac{d}{dt}KL( p_t||q_t) tồn tại.
Ta có thể thể thấy đẳng thức ở phần trước là trước là một trường hợp đặc biệt khi dXt=dBtdX_t = dB_t.

Một trường hợp khác là phương trình Langevin như trong bài trước, có phân bố ổn định là pdata(x)p_{data}(x)

ddtKL(pdata∣∣qt)=−12F(pdata,qt)frac{d}{dt}KL( p_{data}||q_t)=-frac{1}{2}F( p_{data}, q_t)

dẫn đến một cách khác để cực tiểu Fisher divergence, đó là cực tiểu tốc độ thay đổi của KL divergence. Do Fisher divergence luôn không âm, đẳng thức này chỉ ra KL divergence luôn giảm, do đó ta có thể dùng một toán tử ϕphi sao cho KL(p∣∣q)≥KL(ϕ(p)∣∣ϕ(q))KL(p||q)geq KL(phi(p)||phi(q)) để mô phỏng toán tử sinh của chu trình ngẫu nhiên. Lúc này, hàm mục tiêu cần cực tiểu sẽ là

KL(p∣∣q)−KL(ϕ(p)∣∣ϕ(q))KL(p||q)- KL(phi(p)||phi(q))

Kết

Trong bài này, chúng ta đã tìm hiểu về họ phương pháp score matching và tính chất của nó. Ở các bài tiếp theo, chúng ta sẽ tiếp tục tìm hiểu về các phương pháp khác để huấn luyện EBM, và các vấn đề khi huấn luyện score-based models.

Tham khảo

Nguồn: viblo.asia

Bài viết liên quan

WebP là gì? Hướng dẫn cách để chuyển hình ảnh jpg, png qua webp

WebP là gì? WebP là một định dạng ảnh hiện đại, được phát triển bởi Google

Điểm khác biệt giữa IPv4 và IPv6 là gì?

IPv4 và IPv6 là hai phiên bản của hệ thống địa chỉ Giao thức Internet (IP). IP l

Check nameservers của tên miền xem website trỏ đúng chưa

Tìm hiểu cách check nameservers của tên miền để xác định tên miền đó đang dùn

Mình đang dùng Google Domains để check tên miền hàng ngày

Từ khi thông báo dịch vụ Google Domains bỏ mác Beta, mình mới để ý và bắt đầ