본문 바로가기

머신러닝

Variational Autoencoder

Variational autoencoder(VAE)는 generative model(생성모델)의 하나로써 autoencoder와 비슷하게 latent space로 data를 매핑하지만, 여기서는 기존의 data를 그대로 복원하는 것이 목적이 아니라 새로운 data를 생성하는 것이다. 그런데 아무 data를 생성하는 것이 아니라, input data와 유사한 데이터를 생성하고 싶다. 유사하다는 것의 의미는, input data의 분포 안에 있을 법한 데이터라는 의미이다. 그런데 이 data가 latent space로 encode될 수 있고 다시 latent space에서 진짜 data로 decode될 수 있으므로, latent vector $z$의 분포에서 $z$를 뽑은 다음 이를 decode하여 data로 복원시키는 방식으로 data를 생성한다. 

 

Latent vector의 분포를 이야기했지만, input data의 분포 안에 있을 법한 데이터를 생성하는 것이므로 결론적으로는 input data($x$)들이 있는 분포를 알아내고 싶은 것이다. 그런데 고차원인 $x$의 분포를 바로 추론하는 건 매우 어렵기 때문에, 저차원인 latent vector $z$의 분포($p(z)$)를 Gaussian같이 쉬운 확률분포로 가정하고 이로부터 $x$의 분포를 추론하고자 한다. 또 이렇게 하면 우리가 $z$를 약간 조절하여 원하는 특성을 갖는 data를 생성해낼 수도 있을 것이다. $x$가 분포를 갖고 있으므로 $z$로부터 복원되는 $x$도 확률적으로 정해질 것이고, 이 확률분포를 $p(x|z)$라 하자. 그런데 $z$의 분포는 Gaussian으로 가정했지만 고차원 data $x$와 latent vector $z$를 이어주는 $p(x|z)$는 굉장히 복잡할 것이므로 이 확률분포를 neural network로 근사할 것이고, parameter $\theta$를 사용하여 $p_\theta(x|z)$라 하자. 그러면 $x$의 확률분포 $p(x)$를 다음과 같이 쓸 수 있다.

$$p_\theta(x)=\int p(z)p_\theta(x|z)dz$$

하지만 모든 $z$에 대한 적분은 불가능하기 때문에, 위 식을 계산할 수가 없다. 한편, $x$로부터 추출되는 $z$ 역시 확률적으로 정해질 것이고 이 확률분포는 Bayes' theorem에 의해 다음과 같이 표현될 수 있다.

$$p_\theta(z|x)=\frac{p_\theta(x|z)p(z)}{p_\theta(x)}$$

그런데 $p_\theta(x)$이 계산 불가능하기 때문에 이것 역시 계산할 수 없다. 따라서 이 확률분포 역시 neural network로 근사할 것이고, 이를 다른 parameter $\phi$를 사용하여 $q_\phi(z|x)$라 하자. 결론적으로 $x$에서 $z$로 이어주는 encoder를 $q_\phi(z|x)$, 반대로 $z$에서 $x$로 이어주는 decoder를 $q_\phi(z|x)$라는 확률분포로 표현하는 것이다. 사실 이들 역시 Gaussian과 같이 쉬운 확률분포로 가정하고 있다. 그 분포를 Gaussian으로 가정한다면, 실제로 neural network에서 출력되는 값은 $z|x$의 평균과 분산($\mu_{z|x}$, $\sigma_{z|x}$) 및 $x|z$의 평균 및 분산($\mu_{x|z}$, $\sigma_{x|z}$)이 된다. 왜냐하면 Gaussian을 정의하는 parameter가 평균과 분산이기 때문이다. 

 

그러면 encoder와 decoder의 확률분포를 둘 다 Gaussian으로 가정했을 때, VAE가 어떻게 작동하는지 알아보자. VAE에서는 encoder에 어떤 data $x$를 넣으면 그 $x$로부터 추출된 $z$ 자체를 출력하는 것이 아니라, 그 $z$가 어떤 값을 가질 것 같은지(분포)에 대한 정보($\mu_{z|x}$, $\sigma_{z|x}$)를 출력한다. 그러면 그 분포에서 임의로 한 $z$를 뽑아낸다. 그러면 이 $z$는 아무 $z$가 아니라, input으로 들어온 $x$로부터 추출되었을 것 같은 $z$인 것이다. 그 $z$를 다시 decoder에 넣으면 역시 $z$로부터 복원되는 하나의 $x$ 자체를 출력하는 것이 아니라, 그 $x$가 어떤 값을 가질 것 같은지(분포)에 대한 정보($\mu_{x|z}$, $\sigma_{x|z}$)를 출력한다. 그러면 이 분포로부터 $x$를 하나 뽑았을 때 그 $x$가 input으로 넣은 $x$와 유사하게 되도록 VAE가 학습된다.

 

Decoder의 확률분포를 Gaussian이 아니라 (generalized) Bernoulli distribution으로 할 수도 있다. 이런 경우에는 평균과 분산을 출력하지 않고 일반적인 classification을 할 때처럼 network를 구성하면 된다. 이는 나중에 설명할 loss term을 보면 좀더 명확해진다.

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super(VAE, self).__init__()
        # encoder
        self.z_mean = nn.Sequential(nn.Linear(self.input_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, self.latent_dim))
        self.z_log_var = nn.Sequential(nn.Linear(self.input_dim, hidden_dim),
                                       nn.ReLU(), 
                                       nn.Linear(hidden_dim, self.latent_dim))
        
        # decoder
        self.x_mean = nn.Sequential(nn.Linear(self.latent_dim, hidden_dim), 
                                    nn.ReLU(), 
                                    nn.Linear(hidden_dim, self.output_dim))
        self.x_log_var = nn.Sequential(nn.Linear(self.latent_dim, hidden_dim), 
                                       nn.ReLU(), 
                                       nn.Linear(hidden_dim, self.output_dim))
        
        '''
        Bernoulli decoder
        self.x = nn.Sequential(nn.Linear(self.latent_dim, hidden_dim), 
                               nn.ReLU(), 
                               nn.Linear(hidden_dim, self.output_dim))
        '''
        
    def encoder(self, x):
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)

        return z_mean, z_log_var
       
    def reparametrization(self, z_mean, z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        
        return z_mean + std * eps
        
    def decoder(self, z):
        x_mean = self.x_mean(z)
        x_log_var = self.x_log_var(z)
        
        return x_mean, x_log_var
        
    '''
    Bernoulli decoder   
    def decoder(self, z):
        x_hat = self.x(z)
        
        return x_hat
    '''
        
    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_mean, x_log_var = self.decoder(z)
        
        return x_mean, x_log_var, z_mean, z_log_var
        
    '''
    Bernoulli decoder
    def foward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_hat = self.decoder(z)
        
        return x_hat, z_mean, z_log_var
    '''

그런데 $z$ 자체의 확률분포 $p(z)$ 역시 앞서 Gaussian으로 가정했다. 이 Gaussian은 그냥 평균이 0, 분산이 1인 Gaussian으로 정할 것이다. 따라서 모든 input $x$로부터 얻을 수 있는 모든 $z$들의 전체 평균은 0, 분산은 1이 되게끔 만들 것이다. 이렇게 하기 위해 encoder가 출력하는 $\mu_{z|x}$ $\sigma_{z|x}$가 각각 0과 1에 가까워지도록 VAE가 학습된다. 

 

Model을 이해하는 좋은 방법 중 하나는 loss function이 어떻게 정의되는지 보는 것이다. 이 경우에는 확률분포를 추정하고 있으므로 maximum likelyhood estimation의 관점에서 loss를 정의하자. VAE의 목적은 data를 생성하기 위해 input data의 분포를 최대한 정확히 추정하는 것이 목적이므로, maximum likelihood estimation의 원리에 따라 다음 likelihood를 최대화한다.

$$\begin{align*}
\log p_\theta(x)&=\int q_\phi(z|x)\log p_\theta(x)dz \\
&=\text{E}_{z\sim q_\phi(z|x)}[\log p_\theta(x)] \\
&=\text{E}_z\left[\log \frac{p_\theta(x|z)p(z)}{p_\theta(z|x)}\right] \\
&=\text{E}_z\left[\log \frac{p_\theta(x|z)p(z)}{p_\theta(z|x)} \frac{q_\phi(z|x)}{q_\phi(z|x)}\right] \\
&=\text{E}_z\left[\log p_\theta(x|z)\right]-\text{E}_z\left[\log\frac{q_\phi(z|x)}{p_\theta(z)}\right]+\text{E}_z\left [\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right] \\
&=\text{E}_z\left[\log p_\theta(x|z)\right]-D_{KL}(q_\phi(z|x)\parallel p_\theta(z))+D_{KL}(q_\phi(z|x)\parallel p_\theta(z|x))
\end{align*}$$

먼저 $p_\theta(x)$는 $z$에 의존하지 않기 때문에 첫 줄과 같이 쓸 수 있다. 이를 Bayes' theorem을 이용하여 전개하고, 식을 정리하면 마지막 줄과 같은 식을 얻을 수 있다. 이때 마지막 항은 정확히 계산할 수는 없지만, KL divergence는 항상 0 이상의 값을 가지므로 다음과 같은 부등식이 성립한다.

$$\log p_\theta(x)\geq \text{E}_z\left[\log p_\theta(x|z)\right]-D_{KL}(q_\phi(z|x)\parallel p(z))=\mathcal{L}(x, \theta, \phi)$$

따라서 $\log p_\theta(x)$는 위와 같이 lower bound $\mathcal{L}(x, \theta, \phi)$가 존재함을 알 수 있고, 이 lower bound를 최대화하면 된다. 이것을 코드로 구현할 때는 $-\mathcal{L}$을 loss로 두고 이를 최소화한다. 

 

$\mathcal{L}(x, \theta, \phi)$을 최대화하는 것의 의미를 살펴보자. 첫 항, 즉 $z$가 주어졌을 때 원본 input $x$가 복원될 likelihood는 최대화해야 하고, 그 의미는 input data와 비슷한 data를 생성하도록 하는 것이다. 반면 두 번째 항은 최소화해야 하는데, KL divergence는 두 분포의 차이를 계산하는 것이므로 이것을 최소화한다는 것은 $q_\phi(z|x)$를 미리 가정한 $p_(z)$, 즉 평균이 0이고 분산이 1인 Gaussian과 비슷하게끔 만든다는 의미이다. 첫 항을 reconstruction term, 두 번째 항을 regularization term이라고도 한다. 

 

Gaussian decoder의 경우 reconstruction term은 다음과 같이 표현된다.

$$\log p_\theta (x|z)=\sum_{j=1}^D \log \left[\frac{1}{\sqrt{2\pi \sigma_j^2}}\exp\left(-\frac{(x_j-\mu_j)^2}{2\sigma_j^2}\right)\right]\rightarrow-\sum_{j=1}^D\left[\frac{1}{2}\log(\sigma_j^2)+\frac{(x_j-\mu_j)^2}{2\sigma_j^2}\right]$$

이때 $D$는 $x$의 차원이다. 만약 분산 $\sigma$를 학습되는 parameter가 아니라 고정된 값(예를 들어 1)으로 두면, 이 term은 mean squared loss(MSE)와 동일하다는 것을 알 수 있다. 사실 MSE를 최소화하는 것은 관측된 $y$값이 $f(x)$를 평균으로 갖는 어떤 Gaussian에서 나온 것으로 가정하고 그 $y$의 likelihood를 최대화하는 것과 같다.  

 

(Generalized) Bernoulli decoder의 경우 reconstruction term은 cross entropy와 같다.

$$\log p_\theta (x|z)=\sum_{j=1}^D x_j\log p_j\text{ where } p_j=\text{softmax}(\hat{x_j})$$

 

$p(z)=N(0,I)$이고 $q_\theta(z|x)=N(\mu,\sigma^2)$일 때 KL divergence term을 구하면 다음과 같다.

$$\frac{1}{2}\sum_{j=1}^J (\mu_j^2+\sigma_j^2-\ln(\sigma_j^2)-1)$$

이때 $J$는 latent vector $z$의 차원이다. 

# Generalized Bernoulli decoderr
rec_loss = nn.CrossEntropyLoss(x_hat, x, reduction='sum')

# Gaussian decoder
rec_loss = 0.5 * x_log_var + (x - x_mean)**2 / (2*torch.exp(x_log_var)).sum()

# KL divergence
KLD = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

loss = rec_loss + KLD

참고:

https://www.youtube.com/watch?v=5WoItGTWV54

https://velog.io/@tobigs16gm/VAEVariational-Auto-Encoder

 

VAE(Variational Auto Encoder)

Variational Auto Encoder: 15기 박진수

velog.io

'머신러닝' 카테고리의 다른 글

GPT-4를 이용한 재료 및 화학 분야 적용 연구  (0) 2023.12.05
Gaussian Process  (1) 2022.12.30
Graph Neural Network  (0) 2022.04.26
Feature Selection  (0) 2022.04.24
Principal Component Analysis (PCA)  (0) 2022.04.24