
Gazing in the Latent Space with Sparse Autoencoders


You can access fluxlens to explore the features of Sparse Autoencoders.

Preface

Over the years, ML researchers have burned countless hours training large neural networks. However, their inner workings remain opaque and incomprehensible to humans. There have been many research efforts in the past [1] [2] to introspect what happens inside neural networks. Most recently, Sparse Autoencoders have been shown to be a promising tool for analysing Large Language Models' behavior.

This work explores how Sparse Autoencoders (SAEs) can be applied to diffusion models and vision encoders. We demonstrate that SAEs can be used to decompose a model's activation space into interpretable, semantic components. Furthermore, to the best of our knowledge, we are the first to show that SAEs can be trained directly on flow matching models (FLUX). This work is accompanied by fluxlens, a minimal visualisation tool I've built to explore SAE features and multimodal datasets.

This work began in early October in collaboration with krea.ai, who dragged me out of GPU poverty. The initial exploration and visualisation tool were built on top of krea's internal assets. However, to facilitate open SAE research efforts, I've retrained the SAEs on a public multimodal dataset (CC3M). The code for SAE training will be open-sourced in the coming weeks.

I'd also like to acknowledge that this work was directly influenced by Anthropic's SAE work and the Transformer Circuits discourse. A fair amount of the SAE implementation was referenced from saprmarks's dictionary learning codebase and OpenAI's implementation. Among other resources, I found Linus's Prism, and Gytis Daujotas's and hugofry's posts on LessWrong, useful references for training SAEs, and vision SAEs in particular.

Background

Generative models such as GPTs and Stable Diffusion have changed the information landscape in just a couple of years. However, it's not at all clear how these massively large matrix multiplications generate their outputs. Neuroscientists have been facing similar challenges with a different subject: the brain. A classical way to study the brain is to present a wide range of stimuli and see which part of the brain "lights up" using fMRI.

Similarly, with neural networks, we can feed in a wide range of data and see which neurons get "activated". In CNNs, this task is a bit easier, since one can directly look at the CNN kernels (i.e. filters).

Filters learned from AlexNet - Source

Feeding a bunch of data to the model and collecting statistics on where the model gets activated is a viable idea. However, it is often the case that multiple parts of the model fire given a single image. This behavior makes it inherently hard to dissect what individual components of the model focus on.

One of the earliest solutions to this problem, "DeepDream", flips the problem around: instead of feeding data that fires specific parts of the model, you optimize the input itself to maximally activate a specific part of the model.
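Here's a minimal sketch of that idea (activation maximization) in PyTorch. The model, layer index, and channel choice are illustrative assumptions, not the original DeepDream recipe:

```python
# A minimal activation-maximization sketch: optimize the input image to
# excite one channel of a chosen layer. Layer/channel picks are arbitrary.
import torch
import torchvision.models as models

model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT).eval()
for p in model.parameters():
    p.requires_grad_(False)

acts = {}
model.features[10].register_forward_hook(
    lambda module, inputs, output: acts.update(out=output)
)

# Start from noise and ascend the gradient of one channel's mean activation.
img = torch.randn(1, 3, 224, 224, requires_grad=True)
opt = torch.optim.Adam([img], lr=0.05)
for _ in range(200):
    opt.zero_grad()
    model(img)
    loss = -acts["out"][0, 42].mean()  # channel 42 is an arbitrary choice
    loss.backward()
    opt.step()
```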

Source: Distill.pub

Early works in interpretability research mainly focused on ImageNet-pretrained models and multimodal models such as CLIP. But as large language models became a more attractive specimen, interpretability researchers naturally got interested in applying these techniques to LLMs.

Polysemanticity is a phenomenon where a "single neuron" gets activated by a mixture of different concepts (semantics). For instance, a neuron could get excited by "human faces, the fronts of cars, and cats' backs". Ideally, we'd want some kind of method to separate these concepts. This is where Sparse Autoencoders (SAEs) come in.

Sparse autoencoding, also known as dictionary learning, is a classical method for learning a dictionary that sparsely encodes the data. In practice, SAEs are trained as shallow neural networks with a single encoder/decoder pair.


A well-trained SAE satisfies two objectives: (1) it reconstructs the input accurately, and (2) its hidden activations are sparse.

The first objective is easily achieved with a standard reconstruction loss. The second is typically achieved with a ReLU activation that thresholds negative values, plus a sparsity penalty on the activations.
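Here's a minimal sketch of this setup; the dimensions and sparsity weight are illustrative, not the exact values used in this work:

```python
# A minimal standard (ReLU + L1) SAE sketch: a single encoder/decoder pair
# trained to reconstruct activations from sparse, non-negative features.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model: int, d_dict: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_dict)
        self.decoder = nn.Linear(d_dict, d_model)

    def forward(self, x):
        f = F.relu(self.encoder(x))  # objective 2: sparse, non-negative features
        x_hat = self.decoder(f)      # reconstruction from dictionary columns
        return x_hat, f

def sae_loss(x, x_hat, f, sparsity_weight=1e-3):
    recon = F.mse_loss(x_hat, x)        # objective 1: reconstruction loss
    sparsity = f.abs().sum(-1).mean()   # L1 penalty encourages sparsity
    return recon + sparsity_weight * sparsity
```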


It may still be unclear at this point why SAEs are useful for decomposing a single feature vector. But if you think about what the SAE decoder is doing in terms of matrix multiplication, it becomes quite clear.


Multiplying the decoder matrix by a sparsely activated hidden vector is equivalent to taking a weighted sum over a small subset of decoder columns. In other words, one can think of the columns of the decoder matrix as a "dictionary". One parallel I would like to draw here is that the MLP and GLU blocks found in large transformer models can be thought of as SAE-like, as beautifully put by Grant Sanderson.
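A tiny numerical check of this view, with illustrative shapes:

```python
# Decoding a sparse feature vector equals a weighted sum of the few active
# decoder columns; the dictionary/feature sizes here are arbitrary.
import torch

W_dec = torch.randn(768, 4096)  # columns act as the "dictionary"
f = torch.zeros(4096)
f[[3, 100, 2048]] = torch.tensor([0.5, 1.2, 0.8])  # only three active features

full = W_dec @ f
manual = 0.5 * W_dec[:, 3] + 1.2 * W_dec[:, 100] + 0.8 * W_dec[:, 2048]
assert torch.allclose(full, manual, atol=1e-5)
```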

Another key advantage of SAEs with generative models is that once you identify each SAE feature's semantics, you can simply "boost" that feature to guide the model's generation, as seen with Golden Gate Claude.

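In code, steering can be as simple as adding a scaled dictionary column to the model's activation. A minimal sketch, assuming an SAE like the one above and a hypothetical feature index; in practice this would run inside a forward hook at the layer the SAE was trained on:

```python
# A minimal steering sketch: boost feature `feature_idx` by adding its decoder
# direction to the activation. `sae` and `feature_idx` are assumptions.
import torch

def steer(activation: torch.Tensor, sae, feature_idx: int, strength: float = 5.0):
    direction = sae.decoder.weight[:, feature_idx]  # the feature's dictionary column
    return activation + strength * direction
```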

Sadly, interpreting SAE features is still a somewhat manual process. The most common interpretability strategy is to collect the top-activating dataset examples for each feature, inspect them, and manually assign a semantic label.

Now that we've laid the foundation for SAEs, let's go into the details.

Training Sparse Autoencoders

When it comes to training SAEs, there are two metrics we need to focus on: reconstruction quality (how faithfully the SAE reproduces the input activations) and sparsity (L0, the number of active features per input).

In summary, the best SAEs are those that can accurately reconstruct the input using a minimal number of dictionary features. The architecture I explained in the previous section is the standard SAE implementation used by the Anthropic team. Since then, multiple variants of SAEs have emerged. The variants I've tested are Gated SAE, JumpReLU SAE, and TopK SAE. In the end, I settled on TopK SAEs because they allow you to explicitly set the desired L0 (i.e. k) value. Furthermore, TopK SAEs come with the benefit that you don't have to tune the sparsity loss weight λ as you do in standard, Gated, and JumpReLU SAEs.
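The TopK activation itself is simple: keep the k largest pre-activations per sample and zero the rest, which pins L0 to (at most) k and removes the need for an L1 term. A minimal sketch:

```python
# A minimal TopK activation sketch: scatter the k largest pre-activations
# back into a zero tensor, zeroing everything else.
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    values, indices = pre_acts.topk(k, dim=-1)
    sparse = torch.zeros_like(pre_acts)
    return sparse.scatter_(-1, indices, values)
```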


After running sweeps over learning rate, λ, and k, I found TopK SAEs to be the most performant and robust across different configurations.

Training configuration

hugofry has previously trained an SAE on the CLIP vision encoder using the ImageNet dataset. To expand on their work, I decided to use the CC3M dataset with LLaVA recaptioning to train SAEs on the CLIP vision/text encoders. For the vision encoder, the CLS token was sampled from the penultimate layer of the CLIP vision encoder. For the text encoder, the pooled text output was used as input to the SAE. To make training faster, I cached the activations in .safetensors format to avoid recomputing them.
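A minimal sketch of the caching step, assuming the Hugging Face CLIP implementation; the checkpoint name and single-batch handling are illustrative:

```python
# Cache the CLIP vision CLS token (penultimate layer) and the pooled text
# output for one batch into a .safetensors file.
import torch
from safetensors.torch import save_file
from transformers import CLIPModel, CLIPProcessor

name = "openai/clip-vit-large-patch14"  # assumed checkpoint
model = CLIPModel.from_pretrained(name).eval()
processor = CLIPProcessor.from_pretrained(name)

@torch.no_grad()
def cache_batch(images, texts, path):
    inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)
    vision = model.vision_model(pixel_values=inputs["pixel_values"],
                                output_hidden_states=True)
    cls_penultimate = vision.hidden_states[-2][:, 0]  # CLS token, penultimate layer
    text_pooled = model.text_model(input_ids=inputs["input_ids"],
                                   attention_mask=inputs["attention_mask"]).pooler_output
    save_file({"vision": cls_penultimate.contiguous(),
               "text": text_pooled.contiguous()}, path)
```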

As for FLUX.1.schnell, I used the same training setup as the CLIP SAEs. I sampled activations from the outputs of the following locations at the first time step:

- transformer_blocks.0 (an early block)
- a middle transformer block
- single_transformer_blocks.37 (a late block)

For FLUX SAEs, the activations were generated on the fly by using CC3M prompts to generate images. Motivated by Ostris's finding on skipping FLUX blocks, I thought it'd be interesting to train SAEs on early, middle, and late layers of the model. We only train on activations sampled from the first denoising step because FLUX.1.schnell tends to converge on its content, style, and composition extremely quickly. We generate images using the standard FLUX.1.schnell sampling pipeline at a base resolution of 256x256.
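A minimal sketch of the on-the-fly capture, assuming diffusers' FluxPipeline; the tuple handling is deliberately defensive since double-stream and single-stream blocks return different outputs:

```python
# Register forward hooks on an early and a late FLUX block and keep the
# hidden states from the first denoising step only.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                    torch_dtype=torch.bfloat16).to("cuda")

captured = {}
def make_hook(name):
    def hook(module, inputs, output):
        # Double-stream blocks return a tuple; keep the image-stream states.
        hidden = output[1] if isinstance(output, tuple) else output
        if name not in captured:  # first denoising step only
            captured[name] = hidden.detach().cpu()
    return hook

pipe.transformer.transformer_blocks[0].register_forward_hook(make_hook("early"))
pipe.transformer.single_transformer_blocks[37].register_forward_hook(make_hook("late"))

image = pipe("a prompt from CC3M", height=256, width=256,
             num_inference_steps=4, guidance_scale=0.0).images[0]
```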

Notes on training (Optional)

It should be taken with a grain of salt that I did not implement "all" the tricks that were used in the Gated / JumpReLU SAEs. To keep my implementation simple, I intentionally avoided techniques such as:

However, a few techniques I’ve used were:

For TopK SAEs, I use a learning rate of 3e-4 with the AdamW defaults. I didn't see a noticeable boost from setting β0 = 0. The effective batch size was set to 8192. SAEs were trained for 6~12 epochs on 8 A100 GPUs. DDP + bf16 mixed precision was used to speed up training. I used k = 4, 16, 32, 64, 128 to control for L0.
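Putting it together, the optimization step looks roughly like this, reusing the SparseAutoencoder sketched earlier (with its ReLU swapped for the TopK activation above); the data loader is a stand-in for cached activations and DDP is omitted:

```python
# A rough training-step sketch mirroring the hyperparameters in the text.
import torch
import torch.nn.functional as F

sae = SparseAutoencoder(d_model=768, d_dict=768 * 16).cuda()  # expansion factor assumed
opt = torch.optim.AdamW(sae.parameters(), lr=3e-4)
loader = (torch.randn(8192, 768) for _ in range(100))  # stand-in activation batches

for acts in loader:
    acts = acts.cuda()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        x_hat, f = sae(acts)
        loss = F.mse_loss(x_hat, acts)  # no sparsity term needed with TopK
    opt.zero_grad()
    loss.backward()
    opt.step()
```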

Visualising SAEs and multimodal datasets

Visualising a dataset is always a joyful activity. I wanted to develop a minimal interface to explore the dataset itself and the features learned by the SAEs. Over a weekend, I built fluxlens, a minimal visualisation tool to explore SAE features and multimodal datasets. The points are rendered using nomic's open-source deepscatter engine. I use cuML's CUDA implementations of UMAP and HDBSCAN for efficient clustering and dimensionality reduction.
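A minimal sketch of that projection/clustering step, assuming cuML's GPU implementations; the input file and parameters are illustrative:

```python
# Project cached embeddings to 2D with UMAP, then cluster with HDBSCAN.
import numpy as np
from cuml.manifold import UMAP
from cuml.cluster import HDBSCAN

embeddings = np.load("clip_embeddings.npy")  # hypothetical cached features
xy = UMAP(n_components=2, n_neighbors=15).fit_transform(embeddings)
labels = HDBSCAN(min_cluster_size=50).fit_predict(xy)
```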

Quick tour of SAE explorer interface

CC3M dataset visualisation

Interesting defects in the CC3M dataset

Near duplicate newspaper images

Color palette images

404 images

Examples of CLIP SAE feature visualisation

For the CLIP vision/text encoders, the results were quite promising.

Houses SAE features

Drinks SAE features

Dance SAE features

Examples of FLUX SAE feature visualisation

I found that SAEs trained on the early blocks of the FLUX model (transformer_blocks.0) were significantly harder to interpret than those trained on later layers. Later layers such as single_transformer_blocks.37 seemed to have slightly more interpretable features. Below are a few examples of features visualised from FLUX SAEs on single_transformer_blocks.37.

Natural road SAE features

Text SAE features

These features are significantly less interpretable than the CLIP features. I hypothesize that this is due to a few factors:

Preliminary Steering results

After extracting these SAE features, I tried steering the model with them, using both CLIP SAE features and FLUX SAE features, but the results were not as promising as I had hoped.

Steering results using FLUX SAE features

This result was a bit unexpected since Gytis had previously shown that steering works well with CLIP SAEs on Kandinsky 2.2 models. You can increase the steering strength further, but I found that it starts to break the generation or make the images extremely bright. However, I was overall happy with the visualisation results and the CLIP SAEs that I trained. As for why steering didn't work as well, there are several possible reasons:

Possible reasons for steering failure

CLIP has relatively minimal impact on generation content.

I ablated the effect of CLIP on 6~8 different images by setting the CLIP features to zero (one way to run this ablation is sketched below). At least on first glance, the generated images did not seem to change much. Of course, this observation is based on a handful of images, so it might not be conclusive.
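For reference, a minimal sketch of this ablation with diffusers' FluxPipeline, reusing the pipe from the earlier sketch: FLUX conditions on a CLIP pooled embedding alongside the T5 embeddings, so zeroing the pooled embedding removes the CLIP signal. The prompt is illustrative:

```python
# Encode the prompt normally, then replace the CLIP pooled embedding with
# zeros while keeping the T5 prompt embeddings intact.
import torch

prompt = "a red house by a lake"  # illustrative prompt
prompt_embeds, pooled, _ = pipe.encode_prompt(prompt=prompt, prompt_2=prompt)
image = pipe(prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=torch.zeros_like(pooled),
             height=256, width=256, num_inference_steps=4,
             guidance_scale=0.0).images[0]
```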

FLUX is a flow matching model.

At the end of the day, FLUX is a flow matching model, not an autoregressive model that predicts the image directly. Gytis's result serves as a counterpoint to this hypothesis, but Kandinsky 2.2's training recipe is quite different from FLUX's. It might be more interesting to apply SAE steering techniques to image generation models such as HART, MAR, and Lumina.

I’m very dumb and my steering implementation is flawed (quite likely).

I might have made a mistake in my steering implementation. In that case, I would love to hear feedback from the community. If you have any insights and would like to dive into the code, details, or anything else, please reach out! I’m available via twitter / email / … etc.

Future directions

Of course, there are a few stones left unturned that I would like to explore in the future.

One experimental idea I have is to simply take a few hundred images with a distinct style, sample activations from FLUX, run an SVD/PCA, and use the resulting components as steering vectors (a minimal sketch follows below). Obviously, there's also the task of adding extra UI components to fluxlens to allow users to steer generations on the site. Ideally, I should try to make it as feature-complete as possible, like Neuronpedia.
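A minimal sketch of that idea; the activation matrix is a stand-in for real FLUX activations sampled from style-consistent images:

```python
# Center the activations and take the top principal directions via SVD;
# each row of Vh is a candidate steering vector.
import torch

acts = torch.randn(300, 3072)    # stand-in: activations from ~300 styled images
acts = acts - acts.mean(dim=0)   # center before PCA
U, S, Vh = torch.linalg.svd(acts, full_matrices=False)
steering_vectors = Vh[:8]        # top-8 components (number is an assumption)
```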