Trong bài trước, ta đã biết mô hình năng lượng biểu diễn một phân bố không chuẩn hóa, cụ thể hơn
p(x)=exp(−E(x))Zp(x)=frac{exp( -E(x))}{Z}
Với phân bố p(x)p(x) như trên, ta sinh dữ liệu bằng phương pháp stochasic gradient Langevin dynamics. Phương pháp này sử dụng gradient tại xx của logp(x)log p(x) để lấy mẫu.
Từ điều này, ta có thể thấy việc học một mô hình năng lượng có thể chuyển thành học ∇xlogp(x)nabla_{x}log p(x), còn được gọi là hàm score, thay vì học tham số θtheta của hàm năng lượng. Quan sát này dẫn tới một lớp các phương pháp học mới, được gọi là score matching.
Mục tiêu của chúng ta là xây dựng một hàm score sθ(x)s_{theta}(x) sao cho xấp xỉ ∇xlogp(x)nabla_{x}log p(x) tốt nhất. Ta sẽ sử dụng Fisher divergence để đo sự khác nhau giữa phân bố p(x)p(x) cần xấp xỉ và phân bố ẩn q(x)q(x) nhận sθ(x)s_{theta}(x) làm hàm score. Giá trị này được tính như sau
F(p,q)=∫Rd∣∣∇logq(x)−∇logp(x)∣∣2p(x)dxF(p, q) = int_{mathbb{R}^d} ||nabla log q(x)-nablalog p(x)||^2p(x)dx
Fisher divergence có thể xem như khoảng cách L2L_2 trung bình giữa hai hàm score.
Tuy nhiên, ta không thể tính trực tiếp giá trị này được, do nó yêu cầu gradient của phân bố p(x)p(x) thật của dữ liệu, trong khi ta chỉ có một lượng mẫu từ phân bố này là dữ liệu huấn luyện.
Phương pháp sliced score matching
Ta có thể thêm ràng buộc của phân bố để biến đổi Fisher divergence về dạng tính được. Khoảng cách này được viết lại như sau
F(p,q)=Ep[∣∣∇logq(X)∣∣2]+Ep[∣∣∇logp(X)∣∣2]−2Ep[∇logp(X)⊺∇logq(X)]begin{aligned}
F(p, q) &= mathbb{E}_p[||nablalog q(X)||^2] + mathbb{E}_p[||nablalog p(X)||^2] – 2mathbb{E}_p[nablalog p(X)^intercalnablalog q(X)]\
end{aligned}
Hạng tử thứ nhất không chứa ∇logp(x)nablalog p(x), hạng tử thứ hai không phụ thuộc vào q(x)q(x), do đó có thể bỏ qua. Ta sẽ biến đổi hạng tử cuối cùng để bỏ đi ∇logp(x)nablalog p(x). Áp dụng chain rule, ta có
Ep[∇logp(X)⊺∇logq(X)]=∑i∫Rdp(x)∂p(x)∂xi1p(x)∂logq(x)∂xidx=∑i∫Rd∂p(x)∂xisi(x)dxbegin{aligned}
mathbb{E}_p[nablalog p(X)^intercalnablalog q(X)] &= sum_iint_{mathbb{R}^d}p(x)frac{partial p(x)}{partial x_i}frac{1}{p(x)}frac{partial log q(x)}{partial x_i}dx \
&= sum_iint_{mathbb{R}^d}frac{partial p(x)}{partial x_i}s_i(x)dx
end{aligned}
với si(x)=∂logq(x)∂xi laˋ chỉ soˆˊ thứ i của s(x)=∇logq(x)text{với }s_i(x)=frac{partial log q(x)}{partial x_i}text{ là chỉ số thứ i của }s(x)=nablalog q(x)
Ở đây ta sẽ dùng tích phân từng phần để loại bỏ đạo hàm riêng của p(x)p(x)
∂p(x)si(x)∂xi=∂p(x)∂xisi(x)+p(x)∂si(x)∂xifrac{partial p(x) s_i(x)}{partial x_i}=frac{partial p(x)}{partial x_i}s_i(x)+p(x)frac{partial s_i(x)}{partial x_i}
Giả sử lim∣∣x∣∣→∞p(x)si(x)=0lim_{||x||toinfty}p(x)s_i(x)=0, ta có
∫R∂p(x)∂xisi(x)dxi+∫Rp(x)∂si(x)∂xidxi=limxi→∞p(x)si(x)−limxi→−∞p(x)si(x)=0int_{mathbb{R}}frac{partial p(x)}{partial x_i}s_i(x)dx_i+int_{mathbb{R}}p(x)frac{partial s_i(x)}{partial x_i}dx_i=lim_{x_itoinfty}p(x)s_i(x) – lim_{x_ito-infty}p(x)s_i(x)=0
∫Rd∂p(x)∂xisi(x)dx=∫Rd−1∫R−p(x)∂si(x)∂xidxid(x1…xi−1xi+1…xn)=−∫Rdp(x)∂si(x)∂xidxbegin{aligned}
int_{mathbb{R}^d}frac{partial p(x)}{partial x_i}s_i(x)dx &=int_{mathbb{R}^{d-1}}int_{mathbb{R}}-p(x)frac{partial s_i(x)}{partial x_i}dx_id(x_1…x_{i-1}x_{i+1}…x_n)\
&=-int_{mathbb{R}^d}p(x)frac{partial s_i(x)}{partial x_i}dx
end{aligned}
Tổng hợp lại, Fisher divergence sẽ được viết lại như sau
F(p,q)=Ep[∣∣s(X)∣∣2]+2Ep[∑i∂si(X)∂xi]+c=Ep[∣∣s(X)∣∣2]+2Ep[tr(Js(X))]+cbegin{aligned}
F(p,q)&= mathbb{E}_p[||s(X)||^2]+2mathbb{E}_p[sum_ifrac{partial s_i(X)}{partial x_i}] +c\
&=mathbb{E}_p[||s(X)||^2]+2mathbb{E}_p[tr(J_s(X))] +c
end{aligned}
với JsJ_s là ma trận Jacobian của s(x)s(x).
Công thức này đã không còn ∇logp(x)nabla log p(x), do đó có thể tính toán được. Tuy nhiên điều này cũng chỉ là trên lý thuyết, vì ta sẽ cần phải tính (vết của) ma trận Jacobian, trong khi xx thường có số chiều lớn. Ta có thể xấp xỉ giá trị này bằng cách chiều xuống một vector ngẫu nhiên vv (đây là kĩ thuật Hutchinson). Cụ thể hơn, với vector ngẫu nhiên vv thỏa mãn E[vv⊺]=Imathbb{E}[vv^intercal]=I, ta có
tr(Js)=tr(JsI)=tr(JsE[vv⊺])=E[tr(Jsvv⊺)]=E[v⊺Jsv]tr(J_s)=tr(J_sI)=tr(J_smathbb{E}[vv^intercal])=mathbb{E}[tr(J_svv^intercal)]=mathbb{E}[v^intercal J_sv]
Cách làm này này sẽ giúp tính vết nhanh hơn, cụ thể hơn với vv bất kì, ta có
∇v⊺s(x)=v⊺Js+(∇v)s(x)=v⊺Jsnabla v^intercal s(x)=v^intercal J_s+(nabla v)s(x)=v^intercal J_s
Nếu ta lấy mẫu mm vector vv, ta sẽ cần tính mm lần gradient của v⊺s(x)v^intercal s(x), trong khi với JsJ_s sẽ cần tính gradient dd lần với dd là số chiều của xx. Phương pháp này được gọi là sliced score matching, với hàm mục tiêu lúc này là
L(p,q)=Ep(x)[∣∣s(X)∣∣2]+2Ep(x)Ep(v)[v⊺Jsv]L(p, q)=mathbb{E}_{p(x)}[||s(X)||^2]+2mathbb{E}_{p(x)}mathbb{E}_{p(v)}[v^intercal J_sv]
Phương pháp denoise score matching
Một cách khác để loại bỏ ∇logp(x)nablalog p(x) là cộng thêm nhiễu vào phân bố. Ta đạt được biến ngẫu nhiên mới X~=X+ϵtilde{X}=X+epsilon với ϵepsilon là nhiễu tùy ý, giả sử ϵ∼N(0,σ2)epsilon sim mathcal{N}(0,sigma^2). Phân bố của biến ngẫu nhiên này sẽ là q(x~)=∫q(x~∣x)p(x)dxq(tilde{x})=int q(tilde{x}|x)p(x)dx, trong đó x~∣x∼N(x,σ2)tilde{x}|x simmathcal{N}(x,sigma^2). Với phân bố mới, Fisher divergence được viết lại thành
F(p,q)=Eq(x~)[∣∣sθ(X~)∣∣2]−2Eq(x~)[∇x~logq(X~)⊺sθ(X~)]+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫∇x~q(x~)⊺sθ(x~)dx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫(∫p(x)∇q(x~∣x)dx)⊺sθ(x~)dx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2∫∫p(x)q(x~∣x)∇logq(x~∣x)⊺sθ(x~)dxdx~+c=Eq(x~)[∣∣sθ(X~)∣∣2]−2Eq(x~,x)[∇x~logq(X~∣X)⊺sθ(X)]+c=Eq(x~,x)[∣∣sθ(X~)−∇logq(X~∣X)∣∣2]+cbegin{aligned}
F(p,q)&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2] – 2mathbb{E}_{q(tilde x)}[nabla_{tilde x}log q(tilde X)^intercal s_{theta}(tilde X)] +c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2int nabla_{tilde x} q(tilde x)^{intercal}s_{theta}(tilde x)dtilde x +c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2int(int p(x)nabla q(tilde x|x)dx)^intercal s_{theta}(tilde x)dtilde x +c\
&=mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2intint p(x)q(tilde x|x)nabla log q(tilde x|x)^intercal s_{theta}(tilde x)dx dtilde x+c\
&= mathbb{E}_{q(tilde x)}[||s_{theta}(tilde X)||^2]-2mathbb{E}_{q(tilde x, x)}[nabla_{tilde x}log q(tilde X|X)^intercal s_{theta}(X)]+c\
&=mathbb{E}_{q(tilde x,x)}[||s_{theta}(tilde X)-nabla log q(tilde X|X)||^2]+c
end{aligned}
ở đây cc tượng trưng cho hằng số không phụ thuộc vào s(x)s(x).
Như vậy, ta không cần phải tính ∇logp(x)nablalog p(x) nữa mà chuyển thành tính score của phân bố N(x,σ2)mathcal{N}(x,sigma^2) với công thức là 1σ(x−x~)frac{1}{sigma}(x-tilde x). Tuy nhiên điều này dẫn tới một điểm yếu của phương pháp này. Thứ nhất, ta muốn nhiễu có phương sai không quá lớn, nếu không sẽ làm sai lệch phân bố đi nhiều. Tuy nhiên, khi phương sai của nhiễu nhỏ thì phương sai khi ước lượng sẽ tăng. Cụ thể hơn khi σ→∞sigmatoinfty, s(x~)≈s(x)s(tilde x)approx s(x), đại lượng trên xuất hiện thành phần
(x−x~)2σ2−2(x−x~σ)⊺s(x) frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)
có phương sai tiến tới ∞infty khi σ→0sigmato 0. Ta sẽ dùng control variate để giảm phương sai khi ước lượng với hàm
c(x~,x)=(x−x~)2σ2−2(x−x~σ)⊺s(x)−Eq(x~,x)[(x−x~)2σ2−2(x−x~σ)⊺s(x)]=(x−x~)2σ2−2(x−x~σ)⊺s(x)−dσ2begin{aligned}
c(tilde x, x) &= frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)-mathbb{E}_{q(tilde x, x)}[frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)]\
&=frac{(x-tilde x)^2}{sigma^2}-2(frac{x-tilde x}{sigma})^intercal s(x)-frac{d}{sigma^2}
end{aligned}
trong đó dd là số chiều của x. Với mm mẫu xi,x~ix^i, tilde x^i từ q(x~,x)q(tilde x, x), hàm mục tiêu sẽ được xấp xỉ như sau
1m∑im∣∣sθ(x~i)−xi−x~iσ∣∣2−c(x~i,xi)frac{1}{m}sum_i^m ||s_{theta}(tilde x^i)-frac{x^i-tilde x^i}{sigma}||^2-c(tilde x^i, x^i)
Mối liên hệ giữa MCMC và score matching
Ta đã biết phương pháp MCMC đi tìm phân bố q(x)q(x) có likelihood cao nhất, tương đương với việc tìm phân bố có KL divergence nhỏ nhất với phân bố của dữ liệu p(x)p(x). Trong khi đó score matching đi tìm phân bố có Fisher divergence nhỏ nhất. Do vậy, mối liên hệ giữa hai phương pháp có thể quy về mối liên hệ giữa hai loại khoảng cách này.
Cho biến ngẫu nhiên XX và Xt=X+tZX_t=X+sqrt tZ với z∼N(0,1)zsimmathcal{N}(0,1), p~,q~tilde p,tilde q là hai luật của XX, p,qp, q là hai luật tương ứng của XtX_t. Giả sử p,qp, q hội tụ về 00 đủ nhanh khi ∣∣x∣∣→∞||x||toinfty, ta có đẳng thức de Bruijn
ddtKL(p∣∣q)=−12F(p,q)frac{d}{dt}KL( p||q)=-frac{1}{2}F( p, q)
Cực tiểu Fisher divergence tương đương với việc tìm phân bố q~tilde q sao cho chênh lệch của KL divergence giữa hai luật trước và sau khi cộng thêm nhiễu là nhỏ nhất. Nói cách khác, score matching đi tìm phân bố có tính ổn định với nhiễu.
Cực tiểu vi phân của KL divergence
Đẳng thức de Bruijin có thể tổng quát như sau: Cho phương trình vi phân ngẫu nhiên
dXt=V(x)dt+βdBtdX_t=V(x)dt+beta dB_t
với p,qp,q là hai luật tại X0X_0, pt,qtp_t, q_t là hai luật tương ứng tại XtX_t. Ta có
ddtKL(pt∣∣qt)=−12F(pt,qt)frac{d}{dt}KL( p_t||q_t)=-frac{1}{2}F( p_t, q_t)
nếu ddtKL(pt∣∣qt)frac{d}{dt}KL( p_t||q_t) tồn tại.
Ta có thể thể thấy đẳng thức ở phần trước là trước là một trường hợp đặc biệt khi dXt=dBtdX_t = dB_t.
Một trường hợp khác là phương trình Langevin như trong bài trước, có phân bố ổn định là pdata(x)p_{data}(x)
ddtKL(pdata∣∣qt)=−12F(pdata,qt)frac{d}{dt}KL( p_{data}||q_t)=-frac{1}{2}F( p_{data}, q_t)
dẫn đến một cách khác để cực tiểu Fisher divergence, đó là cực tiểu tốc độ thay đổi của KL divergence. Do Fisher divergence luôn không âm, đẳng thức này chỉ ra KL divergence luôn giảm, do đó ta có thể dùng một toán tử ϕphi sao cho KL(p∣∣q)≥KL(ϕ(p)∣∣ϕ(q))KL(p||q)geq KL(phi(p)||phi(q)) để mô phỏng toán tử sinh của chu trình ngẫu nhiên. Lúc này, hàm mục tiêu cần cực tiểu sẽ là
KL(p∣∣q)−KL(ϕ(p)∣∣ϕ(q))KL(p||q)- KL(phi(p)||phi(q))
Kết
Trong bài này, chúng ta đã tìm hiểu về họ phương pháp score matching và tính chất của nó. Ở các bài tiếp theo, chúng ta sẽ tiếp tục tìm hiểu về các phương pháp khác để huấn luyện EBM, và các vấn đề khi huấn luyện score-based models.
Tham khảo
Nguồn: viblo.asia