For more background on VAEs, check out Lilian Weng’s blog post on VAEs and the original papers!
To train a VQVAE, there are three components we need to train: the encoder, the quantiser, and the decoder. Our objective is to:

- train the codebook (together with the encoder) so that the encoder’s outputs are quantised effectively,
- propagate gradients through the non-differentiable quantisation step so the encoder and decoder can be trained end-to-end,
- optimise a reconstruction loss so that the decoded image is close to the original.

Let’s tackle these problems one by one.
To train the codebook, we want the codebook and the encoder to work together to effectively quantise the data. In other words, we want the encoder’s output and the codebook’s vectors to be close to each other. There are two ways to do this: the straight-through estimator and the Gumbel-Softmax. In this article, I’ll be using the straight-through estimator since it’s easier to implement and more widely used.
Since we want the encoder’s output and the codebook’s vectors to be close to each other, we have two options: the commitment loss and the VQ loss.
Commitment loss is simply the L2 loss between the encoder’s output and the codebook’s vectors. Mathematically, it can be written as:

$$\mathcal{L}_{\text{commit}} = \lVert z_e(x) - \text{sg}[z_q(x)] \rVert_2^2$$

where $z_e(x)$ is the encoder’s output, $z_q(x)$ is the quantised vector (i.e. the closest vector in the codebook), and $\text{sg}[\cdot]$ denotes the stop_gradient operator.
Figure - Commitment loss between the encoder’s output and the codebook’s vectors.
It’s important that we apply the stop_gradient operator (.detach() in PyTorch, jax.lax.stop_gradient() in JAX) to the codebook vector, since we want to treat it as a constant. This way, the encoder’s parameters are updated to bring the encoder’s output close to the codebook’s vectors. One way to think about this is that you’re training the encoder to “commit” to a specific vector in the codebook, instead of switching between different vectors.
Vice versa, the VQ loss is defined similarly, but with stop_gradient applied to the encoder’s output.
Since the codebook is typically initialised randomly, it’s beneficial to introduce the VQ loss so that the codebook vectors are pulled closer to the distribution of the encoder’s output. It also makes the encoder’s task slightly easier, since the codebook vectors are drawn closer to the encoder’s output. To optimise these objectives jointly, we simply take the sum of the two losses, yielding the final objective for the codebook:

$$\mathcal{L}_{\text{codebook}} = \lVert \text{sg}[z_e(x)] - z_q(x) \rVert_2^2 + \beta \lVert z_e(x) - \text{sg}[z_q(x)] \rVert_2^2$$

where $\beta$ is a hyperparameter that controls the tradeoff between the two losses. It’s been shown that $\beta = 0.25$ works well in practice [1].
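As a concrete (if simplified) sketch, here is how the codebook objective could be computed in PyTorch. The tensor names z_e and z_q, and the use of F.mse_loss, are my own illustration rather than the actual training code:

```python
import torch.nn.functional as F

beta = 0.25
# z_e: encoder output, z_q: nearest codebook vectors (same shape as z_e)
vq_loss       = F.mse_loss(z_q, z_e.detach())   # pulls codebook vectors towards the encoder's output
commit_loss   = F.mse_loss(z_e, z_q.detach())   # pulls the encoder's output towards its chosen code
codebook_loss = vq_loss + beta * commit_loss
```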
Now that we have an idea of how to train the codebook, let’s move on to training the encoder and decoder. In a typical autoencoder setup, all the components of the network are continuous, differentiable operations (linear layers, activations, etc.). However, in VQVAE, we clearly have a non-differentiable operation: the quantisation step.
All differentiable functions are continuous, so discontinuous functions (argmax, random sampling, quantisation, …) are non-differentiable. In these cases, we need to define a custom gradient for the non-differentiable operation.
In the straight-through estimator, we use the following neat trick:
def straight_through_quantisation(x):
    # forward pass evaluates to quantise(x); the backward pass treats it as the identity on x
    return x + (quantise(x) - x).detach()
The output of straight_through_quantisation is quantise(x). But notice that quantise(x) - x is effectively a “constant”, since we apply detach to it. Hence, when backpropagation is performed, the gradient that reaches quantise(x) flows straight into x, i.e. we have “taped” the gradient of x onto quantise(x).
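Putting the nearest-neighbour lookup, the codebook loss, and the straight-through trick together, a minimal vector quantiser might look like the sketch below. This is my own simplified version (no l2-normalised codes or separate codebook dimension); it only shares its output interface with the VectorQuantiser used later in this post:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalVectorQuantiser(nn.Module):
    def __init__(self, features: int, codes: int, beta: float = 0.25):
        super().__init__()
        self.codebook = nn.Embedding(codes, features)
        self.beta = beta

    def forward(self, z):                                    # z: (batch, tokens, features)
        b, t, d = z.shape
        flat = z.reshape(-1, d)
        # nearest codebook vector (L2 distance) for every token
        idxes = torch.cdist(flat, self.codebook.weight).argmin(dim=-1)
        codes = self.codebook(idxes).reshape(b, t, d)
        # codebook loss: VQ loss + beta * commitment loss
        loss = F.mse_loss(codes, z.detach()) + self.beta * F.mse_loss(z, codes.detach())
        # straight-through: forward pass returns the codes, gradients flow back into z
        codes = z + (codes - z).detach()
        return codes, loss, idxes.reshape(b, t)
```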
Using this trick, we can allow the encoder and decoder to be trained end-to-end to optimise the reconstruction loss.
In layman’s terms, we want the reconstructed image to be close to the original image.
There are two common losses that are used in VQVAE training: perceptual loss and GAN loss. Perceptual loss is typically implemented by computing the L2 distance between intermediate features of a pretrained network (e.g. VGG, ResNet, etc.). It is designed to capture how “perceptually similar” the reconstructed image is to the original image.
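In practice, the easiest way to get a perceptual loss is the lpips package (which is also what I use later, with its vgg backbone). A rough sketch, assuming images normalised to [-1, 1]:

```python
import lpips
import torch

# LPIPS perceptual loss with a VGG backbone; expects images scaled to [-1, 1]
perceptual = lpips.LPIPS(net='vgg')

original      = torch.rand(4, 3, 256, 256) * 2 - 1
reconstructed = torch.rand(4, 3, 256, 256) * 2 - 1
loss = perceptual(original, reconstructed).mean()
```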
GAN loss is implemented by training a discriminator to distinguish between the original image and the reconstructed image. GANs [5] were originally developed to generate images, but they have also been adapted to many super-resolution [6] and image reconstruction tasks. Ok, but why do we care about perceptual loss and GAN loss?
The limitation of a simple L2 loss is that it’s agnostic to aspects such as the “sharpness”, “texture”, and “structure” of the image. For instance, under L2 loss, all of the surrounding images below have the same L2 distance to the image in the center.
All the surrounding images have the same L2 distance!
When training any generative model, one pitfall is that the model can get away with generating the “average” of all the images in the dataset. For instance, below is an example of a model that fell into this local minimum on the animefaces dataset.
When your neural network takes the easy way out…
Both GAN loss and perceptual loss are designed to address this issue. Perceptual loss helps capture the “structure” and “texture” of the image, while GAN loss in particular has been found to be effective at producing “sharp” and “realistic” images, as shown in the super-resolution literature [6].
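As noted in the setup below, a perceptual loss alone turned out to be sufficient for my runs, so I did not need an adversarial term. For reference, though, a common hinge-style GAN loss (the kind often paired with VQ autoencoders) can be sketched as follows; the discriminator here stands for any patch-based CNN discriminator, which is an assumption on my part:

```python
import torch
import torch.nn.functional as F

def discriminator_loss(real_logits, fake_logits):
    # hinge loss: push logits on real images above +1 and on reconstructions below -1
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def generator_loss(fake_logits):
    # the generator (the VQVAE decoder) tries to raise the discriminator's logits on its outputs
    return -fake_logits.mean()
```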
For my setup of training VQVAE, I used the following configuration:

- Loss: a perceptual loss (LPIPS with a vgg backbone) was sufficient to generate sharp images
- Optimiser: lr = 3e-4, betas=(0.9, 0.95), weight decay = 0.05
- Augmentations: RandomCrop(scale=(0.8,1.0)), RandomHorizontalFlip(p=0.3), RandomAdjustSharpness(2,0.3), RandomAutocontrast(0.3) (a sketch of the transform pipeline follows below)
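Here is roughly what that augmentation pipeline looks like with torchvision transforms. I’m assuming the crop with scale=(0.8, 1.0) corresponds to RandomResizedCrop and that images are trained at 256x256:

```python
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
    transforms.RandomAutocontrast(p=0.3),
    transforms.ToTensor(),
])
```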
Architecturally, I used the following setup:
It’s been found that using a large codebook size with a small codebook dimension and l2 normalised codes improves codebook usage and reconstruction quality [7][8]. The main research question is whether Mamba, Attention, or Convolution backbones are better for VQVAE training.
Apart from a few exceptions [7][9], most works use a CNN backbone for VQVAE training. I used a standard ConvNext [10] backbone design. To match the token count of the Mamba and Transformer backbones, I downsample the image to the same latent resolution. I also added a transformer layer before/after the quantiser. At each resolution (256, 128, 64, 32), I used two ConvNext blocks. The architecture of the ConvNext block is as follows:
class NeXtformer(nn.Module):
    def __init__(self, features: int):
        super().__init__()
        self.depthwise = nn.Conv2d(features, features, 7, padding=3, groups=features)
        self.prenorm = RMSNorm(features)
        self.postnorm = RMSNorm(features)
        self.mlp = SwiGLU(features, bias=False)

    def forward(self, x):
        x = self.depthwise(self.prenorm(x)) + x
        x = self.mlp(self.postnorm(x)) + x
        return x
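The post doesn’t spell out the full encoder, so here is a hedged sketch of how these blocks could be stacked: two NeXtformer blocks per resolution, with a strided conv between stages (256 -> 128 -> 64 -> 32). The stem, channel widths, and kernel sizes are my assumptions, not the exact configuration used:

```python
import torch.nn as nn

def convnext_encoder(widths=(64, 128, 256, 512)):
    # stem keeps the input at 256x256
    layers = [nn.Conv2d(3, widths[0], kernel_size=3, padding=1)]
    for i, width in enumerate(widths):
        # two ConvNext blocks at each resolution (256, 128, 64, 32)
        layers += [NeXtformer(width), NeXtformer(width)]
        if i + 1 < len(widths):
            # strided conv for the 2x downsample between stages
            layers += [nn.Conv2d(width, widths[i + 1], kernel_size=2, stride=2)]
    return nn.Sequential(*layers)
```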
ViT-VQGAN [9] was one of the first works to explore the use of attention backbones for VQVAE training (it also introduced l2-normalised codes). I closely replicated the architecture from the paper.
class VQVAE(nn.Module):
    def __init__(self, features: int = 768, ...):
        super().__init__()
        self.size = size
        self.patch = patch
        self.strides = strides
        self.ntoken = (size // strides)
        self.epe = WPE(features, self.ntoken ** 2)
        self.dpe = WPE(features, self.ntoken ** 2)

        # patchify
        self.input = nn.Conv2d(3, features, patch, ...)

        # encoder
        transformers = [Transformer(features, ...) for _ in range(depth)]
        self.encoder = nn.Sequential(*[
            *transformers, RMSNorm(features),
            nn.Linear(features, 4 * features),
            nn.Tanh(), nn.Linear(4 * features, features)
        ])

        # quantiser
        self.quantiser = VectorQuantiser(features, codes, pages)

        # decoder
        transformers = [Transformer(features, ...) for _ in range(depth)]
        self.decoder = nn.Sequential(*[
            *transformers, RMSNorm(features),
            nn.Linear(features, 4 * features),
            nn.Tanh(), nn.Linear(4 * features, features)
        ])

        # unpatchify
        self.output = nn.ConvTranspose2d(features, 3, ...)

    def forward(self, x):
        x = rearrange(self.input(x), 'b c h w -> b (h w) c')
        x = self.encoder(self.epe(x))
        codes, loss, idxes = self.quantiser(x)
        x = self.decoder(self.dpe(codes))
        x = rearrange(x, 'b (h w) c -> b c h w', h=self.ntoken, w=self.ntoken)
        x = self.output(x)
        return x, loss, idxes
For the Mamba backbone, I simply used the same architecture as ViT-VQGAN, but with the attention layers swapped for Mamba layers. I tried two variants of Mamba: standard Mamba and bidirectional Mamba.
class SSD(nn.Module):
    def __init__(self, features: int, heads: int, bias=False):
        super().__init__()
        self.mamba = Mamba2(features)
        self.prenorm = RMSNorm(features)
        self.postnorm = RMSNorm(features)
        self.mlp = SwiGLU(features)

    def forward(self, x):
        x = self.mamba(self.prenorm(x)) + x
        x = self.mlp(self.postnorm(x)) + x
        return x
# simple bidirectional mamba, similar to a bidirectional LSTM
class BSSD(nn.Module):
    def __init__(self, features: int, heads: int, bias=False):
        super().__init__()
        self.fwd = Mamba2(features)
        self.bwd = Mamba2(features)
        self.prenorm = RMSNorm(features)
        self.fwdnorm = RMSNorm(features)
        self.bwdnorm = RMSNorm(features)
        self.fwdmlp = SwiGLU(features)
        self.bwdmlp = SwiGLU(features)

    def forward(self, x):
        # b t d
        f = x
        b = torch.flip(x, dims=[1])
        f = self.fwd(self.prenorm(f)) + f
        f = self.fwdmlp(self.fwdnorm(f)) + f
        b = self.bwd(self.prenorm(b)) + b
        b = self.bwdmlp(self.bwdnorm(b)) + b
        # flip back
        b = torch.flip(b, dims=[1])  # <- important
        return f + b
It’s important to note that you flip b back. Otherwise, the model would spend most of its time learning how to reverse the input sequence.
For benchmarking, I used models of the following sizes:
An initial comparison of the three backbones reveals that CNN is a clear winner despite having the fewest parameters. Another surprising result was that bidirectional Mamba and standard Mamba performed similarly, so it’s not clear whether bidirectionality helps. Works such as Vim [11] have used bidirectional modelling for the Mamba backbone, but recent work such as MambaVision [12] simply uses a concatenation of SSD features and gated features to achieve better performance.
CNN backbone outperforms Attention and Mamba backbones
One major difference between the CNN backbone and the others is that the compress (codebook) loss tends to converge very smoothly. Another noticeable difference is that CNN backbones have a better initialisation.
Comparison of reconstructions and initialisation for different VQVAE backbones.
Left: Original / Reconstructed image at initialisation. Right: Original / Reconstructed image after 30K steps.
CNN’s initialisation is noisy, but resembles the original image. On the other hand, Mamba and Attention backbones mostly output a grid-like noise pattern. These “grid” artifacts persist throughout the training.
Grid artifacts in Mamba and Attention backbones
To fix this issue, I had two ideas in mind: overlapping the patches, and adding extra convolutional layers at the output.
To overlap the patches, I simply add extra padding and increase the kernel size of the input Conv2d and output ConvTranspose2d layers. For the extra convolutional layers, I added a single 3x3 Conv2d as the final layer so that it can learn to clean up the grid artifacts. A sketch of both changes is shown below.
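Here is roughly what those two changes look like; the specific kernel size and stride are illustrative assumptions (chosen so a 256x256 image still maps to a 32x32 token grid), not the exact values I used:

```python
import torch.nn as nn

features, patch, stride = 768, 12, 8          # kernel larger than the stride -> overlapping patches
pad = (patch - stride) // 2                   # keeps 256x256 -> 32x32 tokens -> 256x256

patchify   = nn.Conv2d(3, features, kernel_size=patch, stride=stride, padding=pad)
unpatchify = nn.Sequential(
    nn.ConvTranspose2d(features, features, kernel_size=patch, stride=stride, padding=pad),
    nn.Conv2d(features, 3, kernel_size=3, padding=1),   # final 3x3 conv to clean up grid artifacts
)
```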
Fixing grid artifacts with overlapping patches and extra convolutional layers
From the experiments, it seems that adding extra convolutional layers and overlapping the patches does help performance. However, it’s still nowhere close to the CNN backbone.
It should be noted that I have not performed a comprehensive hyperparameter search or ablation studies; I’m only a person with a handful of A6000 GPUs, running these experiments as a hobby. That said, I’m interested in applying [14] and Katie et al. [15] for fast hyperparameter search.
As previously mentioned, bidirectionality doesn’t seem to help much. I’m currently running a modified version of MambaVision [13] which uses SSD [16] instead of SSM [3] for better performance. (For the record, its performance is not looking good so far, but that’s a story for another day.) The obvious next step I want to take is to train an autoregressive model on top of the CNN VQVAE using Mamba.