Variational Autoencoders (VAEs)
A traditional autoencoder maps input data into smaller, compressed vectors (latent vectors) and reconstructs the output from those compressed vectors.
The benefit is that only the essential information is kept while noise is discarded; the compressed vectors capture the underlying structure of the input data.
Variational Autoencoders (VAEs) are similar, except that the compressed representation is parameterized by (1) a mean vector and (2) the log of the variance, i.e. it defines a Normal (or normal-like) distribution.
The encoder compresses inputs into probability distributions, and the decoder restores outputs from samples drawn from those distributions.
A VAE is therefore a type of generative model that learns to represent data in a compressed latent space while also enabling the generation of new, similar data. Unlike traditional autoencoders, which map inputs to fixed latent vectors, VAEs map inputs to probability distributions in the latent space.
This probabilistic approach allows VAEs to:
Capture uncertainty in the data.
Generate diverse outputs by sampling from the latent space (see the sketch after this list).
Regularize the latent space to be smooth and continuous.
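As a rough sketch of the generation idea, assuming a trained model like the VAE class defined in the example code at the end of this note (the function name here is illustrative):

import torch

def sample_from_vae(model, latent_dim, n_samples=16):
    # `model` is assumed to be a trained VAE exposing a decode() method (see example code below)
    z = torch.randn(n_samples, latent_dim)   # latent vectors drawn from the N(0, 1) prior
    with torch.no_grad():
        return model.decode(z)               # decoded into new, input-like samples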
Applications:
For noisy data, a VAE can denoise the input and output a smoother version of the data.
For anomaly detection, if an input does not conform to the learned latent distributions, the model struggles to reconstruct it, so the reconstruction loss is high.
The reconstruction loss can therefore be used as an anomaly score, as sketched below.
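A minimal sketch of that idea, assuming a trained model like the VAE class defined in the example code at the end of this note; the threshold is a hypothetical value you would pick from reconstruction errors on normal validation data:

import torch
import torch.nn.functional as F

def anomaly_score(model, x):
    # `model` is assumed to be a trained VAE with a predict() method (see example code below)
    with torch.no_grad():
        recon = model.predict(x)
    return F.mse_loss(recon, x, reduction='mean').item()

# Usage sketch: flag inputs whose score exceeds a threshold chosen on normal data
# is_anomaly = anomaly_score(model, x) > threshold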
Key Ideas Behind VAEs
1. Encoding as a Distribution
Instead of encoding an input into a single point, VAEs encode it into a Gaussian distribution:
The encoder outputs two vectors:
mu: the mean of the distribution.
logvar: the logarithm of the variance.
This means each input is represented by a distribution N(μ, σ²) in the latent space.
2. Why Use logvar Instead of var?
Taking the log rescales very large variances to moderate positive numbers and very small variances to moderate negative numbers, instead of values spread across many orders of magnitude.
This makes the training easier:
Numerical stability: log-space compresses large or small values.
Ensures positivity: the variance exp(logvar) is always positive, regardless of the raw network output.
Simplifies KL divergence: the loss function includes a KL divergence term that is easier to compute using logvar.
To convert logvar back to the standard deviation std used when sampling for the decoder:
σ² = exp(logvar)
so std, or σ, = exp(0.5 · logvar) = sqrt(exp(logvar))
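A quick numeric check of this conversion (standalone, with illustrative values only):

import torch

logvar = torch.tensor([-2.0, 0.0, 2.0])   # example log-variances
var = torch.exp(logvar)                    # sigma^2, always positive
std = torch.exp(0.5 * logvar)              # sigma = sqrt(sigma^2)
print(torch.allclose(std, var.sqrt()))     # True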
3. The Reparameterization Trick
From the compressed latent space, the decoder restores the output from a sample z drawn from the encoded probability distribution:
z = mu + eps * σ
where eps is a random number drawn from the standard normal distribution N(0, 1) (e.g. via randn), so that z follows the N(mu, σ²) distribution.
z cannot be sampled directly from N(mu, σ²) with something like torch.normal(mean=mu, std=std), because a plain random sampling operation is non-differentiable and would break backpropagation during training, i.e. you cannot compute gradients through a random draw back to mu and std.
To fix this, VAEs use the reparameterization trick to calculate z in a deterministic way:
z = mu + eps * σ
This makes z a deterministic function of mu, logvar, and eps, allowing gradients to flow through the network during training.
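A minimal standalone sketch of the gradient flow (names here are illustrative): after reparameterized sampling, gradients reach mu and logvar, which is exactly what a direct random draw would not provide. (In PyTorch terms, torch.distributions.Normal(...).rsample() implements this same trick, while .sample() returns a tensor detached from the graph.)

import torch

mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

std = torch.exp(0.5 * logvar)   # sigma = exp(0.5 * logvar)
eps = torch.randn_like(std)     # external noise, treated as a constant
z = mu + eps * std              # reparameterized sample

z.sum().backward()
print(mu.grad)       # tensor([1., 1., 1.]) -- gradients flow back to mu
print(logvar.grad)   # equals 0.5 * eps here -- gradients also reach logvar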
4. Loss Function
The VAE loss combines two terms:
Reconstruction loss: measures how well the output matches the input.
KL divergence: regularizes the latent space by measuring how close the learned distribution is to a standard normal distribution.
KL = −0.5 · Σ(1 + logvar − μ² − exp(logvar))
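As a standalone sanity check of this term (a minimal sketch, not tied to the example model below): the KL divergence is zero exactly when the encoded distribution already equals the standard normal prior, i.e. mu = 0 and logvar = 0.

import torch

def kl_to_standard_normal(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions and batch
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

mu = torch.zeros(4, 10)       # batch of 4, latent dimension 10
logvar = torch.zeros(4, 10)   # logvar = 0 means sigma^2 = 1
print(kl_to_standard_normal(mu, logvar))   # tensor(0.) -- matches the prior exactly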
5. Training and Prediction
In the VAE training process, z = mu + eps * σ includes the randomness from eps, which helps the model learn the latent probability distributions.
When running inference / prediction, use z = mu without the randomness. For this reason forward() and predict() are kept separate here, whereas other models may use forward() for both training and prediction.
6. Limitations
A VAE assumes the data can be modeled with normal (Gaussian) latent distributions, which may not hold in practice, so apply it with care.
To overcome the limitations of the normality assumption, several extensions to VAEs have been proposed:
Examples include using a Gaussian Mixture Model (GMM) prior to allow multimodal latent distributions, transforming the latent space with a series of invertible functions (normalizing flows), and using categorical or binary latent variables instead of continuous ones.
Example Code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Dummy dataset class
class ExampleDataset(Dataset):
    def __init__(self, data_tensor):
        self.data = data_tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# VAE model
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        # Encoder; there can be multiple Linear layers here
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        self.mu_layer = nn.Linear(hidden_dim, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim, latent_dim)
        # Decoder; there can be multiple Linear layers here
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Use Sigmoid if input is normalized between 0 and 1
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # z = mu + eps * std keeps sampling differentiable w.r.t. mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

    def predict(self, x):
        # At inference time, use mu directly instead of sampling
        mu, _ = self.encode(x)
        recon = self.decode(mu)
        return recon

# Loss function
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

# Example usage
if __name__ == "__main__":
    input_dim = 20
    hidden_dim = 64
    latent_dim = 10

    # Create dummy data
    data = torch.randn(1000, input_dim)
    dataset = ExampleDataset(data)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Initialize model
    model = VAE(input_dim, hidden_dim, latent_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    for epoch in range(10):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss = vae_loss(recon, batch, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.2f}")

    # Inference example
    test_input = torch.randn(1, input_dim)
    reconstructed = model.predict(test_input)
    print("Original:", test_input)
    print("Reconstructed:", reconstructed)
This code example uses an MSE reconstruction loss plus the KL divergence term (recon_loss + kl_div), which is suitable for continuous data.
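For inputs normalized to [0, 1] (which is what the Sigmoid on the decoder output assumes), a common alternative is binary cross-entropy as the reconstruction term. A minimal sketch of that variant, reusing the imports from the example above; note it requires the inputs themselves to be scaled to [0, 1], unlike the torch.randn dummy data used here:

# Alternative loss for data scaled to [0, 1] (e.g. image pixels), paired with the
# Sigmoid output layer. The KL term is unchanged.
def vae_loss_bce(recon_x, x, mu, logvar):
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div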