Auto-Encoding Variational Bayes 리뷰
VAE의 핵심을 잘 짚은 설명이 부족하다고 느껴 직접 리뷰를 남기기로 결심했습니다. 많은 사람들이 model design에 중점을 두는 것 같아서 training strategy나 loss에 대한 설명이 많이 부족하다고 느꼈습니다. 이에 동기부여가 되어 나중에 헷갈릴때 돌아볼 겸 리뷰를 작성합니다.
주의!
: 아래 키워드, 논문, 지식은 있다는 가정 하에 작성되었습니다. 이들은 본 포스팅을 이해하기 위한 필수 요소입니다.
Prerequisite
- Latent Variable
- Posterior, Prior, Likelihood
- Marginal Likelihood
Intruduction
생성이 어려운 이유는 분류처럼 특징을 뽑아 데이터를 압축하는 것이 아닌 같은 크기의 비슷한 값을 뽑아야 하기 때문이라 생각합니다. 심지어 출력값도 카테고리처럼 정해진 값이 아닌 자연스러우면서도 비슷한 특징을 공유해야 함. 다시 말해 생성 문제의 정답은 무한하다고 말할 수 있죠. 그래서 생성 모델
에 관한 연구는 정답 자체를 찾는 것보다는 정답의 분포를 찾으려는 시도
가 주를 이루는 것 같습니다. 본 논문은 학습 데이터를 샘플링하여 모분포를 추정하는 생성 모델을 디자인하는 방법을 소개합니다. 데이터의 모분포를 학습해버리면 그 분포에서 뽑아낸 값은 생성 결과값으로 사용할 수 있겠죠? 이러한 방법을 소개한 참신한 고전적인(2014년이면 생성형에선 고전인듯.) 아이디어를 소개합니다 !
Method
Problem scenario
먼저 분포를 학습할 데이터 $X={x^{(i)}}^{N}_{i=1}$가 존재한다고 가정합니다.
Introduction에서 설명했듯, 우리의 목표는 X의 분포를 학습하는 것입니다. 그런데 논문에서 주장하는 문제가 크게 두 가지 있습니다.
- $X$의 분포를 계산하기가 너무 어렵습니다. 이미지 튜토리얼 문제인 숫자 분류 MNIST 데이터셋 마저도 (28 x 28 = 576) 사이즈인데, 이것의 분포를 계산한다는건 상당히 난이도가 있습니다.
- 데이터의 크기 $N$이 너무 클때 이를 한번에 최적화하기 어렵습니다. 이 문제는 사실 모든 딥러닝 공통된 문제긴 하죠. 리뷰를 쓰는 지금에서야 batch를 나눠 확률적 경사 하강법(SGD) 기반 옵티마이저를 사용하는게 기본이니까 뭐 무시해도 될 문제인듯.
이러한 이유로 제 생각엔 $X$의 분포 계산을 효율적으로 하는 것이 가장 중요한 문제입니다. 본 논문에서는 $X$의 계산량을 줄이기 위해 아주 멋진 아이디어를 제시합니다. $X$의 모분포를 직접 추정하는 대신, $X$의 정보를 담고 있는 훨씬 작은 차원의 분포 $z \sim P(z\mid x)$와 $z$로부터 생성 결과 $x’$을 추정할 수 있는 또다른 분포 $x’ \sim P(x\mid z)$를 추정합니다. 자, 이러면 $X$의 모분포를 추정하는 문제를 아예 다른 관점으로 바라볼 수 있겠습니다. x가 이미 관찰된 값이고 $z$가 latent variable이므로 $x$에 대한 $z$의 posterior($P(z\mid x)$), 그리고 $z$에 대한 $x$의 likelihood($P(x\mid z)$)를 추정하는 문제로 변환할 수 있겠네요! 또한 $p(z)$는 prior로 정의할 수 있겠네요. 이 문제에서는 $p(z)$가 표준 정규 분포라는 가설을 세웁니다. 굳이 분포를 표준 정규 분포로 잡은건 아마 계산이 편리하기 때문이라 생각해요.
The Variational Bound
이 섹션에서는 Variational Inference라는 수리통계 이론을 기반으로 한 loss function을 정의하는 법에 대해 다룹니다. 먼저, 생성 모델의 목표는 성공적으로 생성할 확률을 높이는 것입니다. 모델이 x를 입력으로 받아 x와 비슷한 값을 생성할 확률을 $P_\theta(x)$라 정의할 때, 이 값을 maximize 하는거죠. 그런데 위에서 언급했듯, $P_\theta(x)$ 이 값 자체를 근사하기란 너무 어렵습니다. 대신 $z$라는 latent variable을 도입하기로 결심했죠. 이때, $P_\theta(x)$은 $\log$ 함수를 붙이면 다음과 같이 전개가 가능합니다.
\[\log p_{\theta} (\mathbf{x}^{(i)}) = D_{KL} (q_{\phi} (\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta} (\mathbf{z} | \mathbf{x}^{(i)})) - D_{KL}(q_{\phi}(z|x^{(i)}) || p_{\theta}(z)) + \mathbb{E}_{q_{\phi}(z|x^{(i)})} [\log p_{\theta}(x^{(i)}|z)]\]이때 $P_\phi(a)$는 $a$를 추정하는 모델의 확률입니다. 본 수식은 수리통계 혹은 정보이론 관점으로 다루는 유명한 수식인듯 한데 증명은 아래 참고.
Proof.
$$ \begin{align*} \log p_\theta(\mathbf{x}) &= \int q_\phi(\mathbf{z}|\mathbf{x}) \log p_\theta(\mathbf{x}) \, d\mathbf{z} \\ &= \int q_\phi(\mathbf{z}|\mathbf{x}) \log \frac{p_\theta(\mathbf{x}|\mathbf{z})p(\mathbf{z})}{p_\theta(\mathbf{z}|\mathbf{x})} \, d\mathbf{z} \\ &= \int q_\phi(\mathbf{z}|\mathbf{x}) \log \frac{p_\theta(\mathbf{x}|\mathbf{z})p(\mathbf{z}) q_\phi(\mathbf{z}|\mathbf{x})}{p_\theta(\mathbf{z}|\mathbf{x}) q_\phi(\mathbf{z}|\mathbf{x})} \, d\mathbf{z} \\ &= \int q_\phi(\mathbf{z}|\mathbf{x}) \log \frac{p_\phi(\mathbf{z}|\mathbf{x})}{p_\theta(\mathbf{z}|\mathbf{x})} \, d\mathbf{z} + \int q_\phi(\mathbf{z}|\mathbf{x}) \log \frac{p_\theta(z)}{p_\phi(\mathbf{z}|\mathbf{x})} \, d\mathbf{z} + \int q_\phi(\mathbf{z}|\mathbf{x}) \log p_\theta(\mathbf{x}|\mathbf{z}) \, d\mathbf{z} \\ &= D_{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x})\big) - D_{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big) + \mathbb{E}_{q_{\phi}(z|x^{(i)})} [\log p_{\theta}(x^{(i)}|z)] \end{align*} $$자 이제, 계산이 어려운 $P_\theta(x)$를 latent variable을 통해 세 개의 term으로 줄였습니다.
$q_\phi(\mathbf{z}\mid\mathbf{x})$는 $x$가 주어지면 $z$를 추론할 확률로 해석이 되고 $p_\theta(\mathbf{x}\mid\mathbf{z})$는 $z$가 주어지면 $x$를 생성할 확률로 해석이 됩니다. 그런데 조금 수상한 term이 하나 있습니다.
$p_\theta(\mathbf{z}\mid\mathbf{x})$ 이 term은 생성을 위한 parameter인 $\theta$로 $z$를 추정할 확률을 나타냅니다. 즉, x에 대한 z의 posterior죠. 역분포를 추정하기 너무 어렵기 때문에 본 term은 계산하기 굉장히 어렵습니다. 그렇기 때문에 여기서 참신한 전략을 적용합니다.
\[D_{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x})\big) - D_{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big) + \mathbb{E}_{q_{\phi}(z\mid x^{(i)})} [\log p_{\theta}(x^{(i)}|z)] \\ \geq - D_{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big) + \mathbb{E}_{q_{\phi}(z|x^{(i)})} [\log p_{\theta}(x^{(i)}|z)]\]KL divergence는 항상 0보다 같거나 크기 때문에, 위와 같은 부등식을 세울 수 있습니다. $P_\theta(x)$를 maximize하기 어렵기 때문에 계산이 어려운 KL divergence 항은 고려하지 않고, Lower Boundary를 만들어서 이 값을 최적화합니다. Lower Boundary 부분은 ELBO입니다. 여기까지가 Variational Inference에 대한 내용이고, VAE는 이를 loss function으로 사용하는 신경망 모델입니다.
\[\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}(q_{\phi}(z|x^{(i)})||p_{\theta}(z)) + \mathbb{E}_{q_{\phi}(z|x^{(i)})} [ \log p_{\theta}(x^{(i)}|z) ]\]The SGVB estimator and AEVB algorithm
이 섹션은 loss를 학습하는 한가지 전략을 기술합니다. 우리는 위에서 x의 분포를 학습하기 위해 z를 이용한다는 아이디어를 적용중이었죠. 신경망은 이를 따라 모델링합니다. 위의 Loss Function을 학습하기 위해서는 사실 분포 자체를 학습해야 합니다. 그런데, 당연하게도 데이터는 분포 자체를 나타내진 않고, 분포 위 하나의 점을 나타내죠. 이때 우리는 이 데이터를 통해 표본분포를 학습합니다.
The reparameterization trick
대부분 리뷰글은 이 부분을 강조하면서 설명하던데 저는 이 부분이 크게 중요하다고 느끼진 않았습니다. 데이터 $x$가 입력되면 사용할 latent variable $z$는 deterministic이 아닌 stochastic합니다. 그 이유는 기존 학습 목적이 분포를 학습하는 것이었고, $x$의 분포를 학습하기 어려우니, $z$의 분포를 학습하자는게 메인 아이디어이기 때문입니다. 그래서 z를 자체를 학습하기보다는 z의 분포를 학습합니다. 이때, 위에서 z는 표준 정규 분포라는 prior가 있었죠? 그래서 $q_{\phi}(z\mid x^{(i)})$ 각 벡터 원소의 평균과 표준 편차를 학습합니다. 이후 실제 생성에 사용할 $z$는 본 분포로부터 sampling합니다.
Example: Variational Auto-Encoder
이 부분부터는 VAE 신경망 디자인에 대해 다룹니다. 이전 두 섹션에서 다룬 점은 loss function은 데이터로 표본분포를 학습한다는 점, Encoder는 $z$의 분포를 학습하기 위한 목적이므로 출력이 정규 분포라는 점을 강조합니다. 이를 기반으로 신경망을 디자인할 수 있습니다. $q_{\phi}(z\mid x^{(i)})$는 Encoder Layer로 사용하고 $x$(의 분포)가 입력되면 $z$(의 분포)를 출력합니다. $p_{\theta}(x^{(i)}\mid z)$는 Decoder Layer로 사용하고 $z$가 입력되면 $x$를 출력합니다. 이 사이에는 z를 샘플링하는 과정을 포함합니다.
우리는 loss function을 표본분포로 근사하는데, 이때, loss function된 두 term을 상세히 살펴보겠습니다.
\[\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) \approx \frac{1}{2} \sum_{j=1}^{J} \left(1 + \log((\sigma_{j}^{(i)})^2) - (\mu_{j}^{(i)})^2 - (\sigma_{j}^{(i)})^2\right) + \frac{1}{L} \sum_{l=1}^{L} \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}^{(i,l)}) \\ \text{where} \quad \mathbf{z}^{(i,l)} = \boldsymbol{\mu}^{(i)} + \boldsymbol{\sigma}^{(i)} \odot \boldsymbol{\epsilon}^{(l)} \quad \text{and} \quad \boldsymbol{\epsilon}^{(l)} \sim \mathcal{N}(0, \mathbf{I})\]오른쪽 항의 첫번째 term은 KL Divergence를 구한 term입니다. $P(z)$는 prior로 인해 표준 정규 분포($\mathcal{N}(0,1)$)로 가정할 수 있고, $q_{\phi}(z\mid x^{(i)})$에서 출력한 평균, 표준편차를 통해 loss 값을 구해 backpropagation이 가능하죠.
오른쪽 항의 두번째 term은 sampling한 $z$ 하에서 $x$를 성공적으로 복원할 수 있는지를 기준으로 삼아 각 벡터별로 MSE를 계산합니다.
이렇게, 신경망을 성공적으로 디자인할 수 있습니다.