1A3ORN

A Note on Per-Token Linear Transformations in Transformers

Created: 2025-07-12
Wordcount: 0.8k

For the last few weeks, I've been experimenting with per-token linear projections added to the residual stream of an LLM. They consistently seem to give a reasonably large decrease in loss, despite their absurd simplicity.

All we do is add one more element to the standard Transformer block:

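from torch import nn

# SelfAttention, MLP, PerTokenTrans, and norm are defined elsewhere;
# only the wiring of the block is shown here.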

class TransformerBlock(nn.Module):

    def __init__(self, config, index):
        super().__init__()
        self.attention = SelfAttention(config, index)
        self.ff = MLP(config, index)

        # New!
        self.extra = PerTokenTrans(config, index)

    def forward(self, x, token_indices):

        x_att = self.attention(norm(x))
        x = x + x_att

        # New!
        x = x + self.extra(x, token_indices)

        x_ff = self.ff(norm(x))
        x = x + x_ff

        return x

The per-token-transform here does three things:

  1. Projects the residual stream down to a much smaller (~square root of hidden dim) size through a simple linear layer.
  2. Multiplies the stream with a token-index-specific matrix, within that tiny subspace.
  3. Projects back up to the full hidden dimension. Following NanoGPT, I initialize this up-projection matrix to zero.
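The three steps above can be sketched in PyTorch roughly like this. It's a minimal sketch, not my exact code: the standalone constructor `PerTokenTrans(hidden_dim, vocab_size, rank)` (instead of the `(config, index)` signature used in the block), the einsum layout, and the identity initialization of the per-token matrices are all assumptions. Some nonzero init for the per-token matrices matters here: with the up-projection at zero, per-token matrices that also started at zero would never receive a gradient.

```python
import math
from typing import Optional

import torch
from torch import nn


class PerTokenTrans(nn.Module):
    """Down-project, apply a per-token r x r matrix, project back up (sketch)."""

    def __init__(self, hidden_dim: int, vocab_size: int, rank: Optional[int] = None):
        super().__init__()
        # Assumed default: subspace size ~ sqrt(hidden_dim).
        self.rank = rank if rank is not None else max(1, round(math.sqrt(hidden_dim)))
        r = self.rank

        # Step 1: shared down-projection into the small subspace.
        self.down = nn.Linear(hidden_dim, r, bias=False)

        # Step 2: one r x r matrix per token id, stored flat in an embedding table.
        # Assumption: identity init, so gradients can reach these matrices and `up`.
        self.token_mats = nn.Embedding(vocab_size, r * r)
        with torch.no_grad():
            self.token_mats.weight.copy_(torch.eye(r).flatten().expand(vocab_size, -1))

        # Step 3: up-projection back to hidden_dim, zero-initialized (NanoGPT-style).
        self.up = nn.Linear(r, hidden_dim, bias=False)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor, token_indices: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, hidden_dim]; token_indices: [batch, seq] of token ids
        r = self.rank
        h = self.down(x)  # [B, T, r]
        mats = self.token_mats(token_indices).view(*token_indices.shape, r, r)  # [B, T, r, r]
        h = torch.einsum("btij,btj->bti", mats, h)  # per-token matrix-vector product
        return self.up(h)  # [B, T, hidden_dim]
```

Because `up` starts at zero, the module returns exact zeros at initialization, so `x = x + self.extra(x, token_indices)` is a no-op until training moves the up-projection off zero.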

This is a ton of parameters (~4x more), but very few floating-point operations, and also relatively little memory. It dependably improves performance up to the largest size I've managed to test, and across a variety of Transformer widths and heights. Here's a typical improvement, which looks about the same at every scale I've tried.

[Chart: Token-Linear Transform]

Note that this chart is typical in the sense that it starts off even or a bit behind, in terms of wall-clock time, before pulling ahead.

Note also that this is a slightly pessimistic chart -- I've found that this change may let you bump the learning rate higher than you otherwise could, but this chart compares learning-rate-matched runs.
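Some back-of-the-envelope arithmetic shows why the parameter count balloons while the FLOPs barely move: each layer stores one r x r matrix for every token in the vocabulary, but each token only ever touches its own matrix. The config here (hidden size 768, 30k vocab, r = round(sqrt(768)) = 28) is a hypothetical example; the baseline is assumed to be a standard block of 12 * d^2 parameters (4 * d^2 for the attention projections plus an MLP with 4x expansion), and norms, biases, and attention-score FLOPs are ignored.

```python
import math


def extra_params_per_layer(d: int, vocab: int, r: int) -> int:
    # vocab r x r matrices, plus the shared down- and up-projections
    return vocab * r * r + 2 * d * r


def block_params(d: int) -> int:
    # attention q/k/v/out (4 d^2) + MLP with 4x expansion (8 d^2); no norms or biases
    return 12 * d * d


def extra_flops_per_token(d: int, r: int) -> int:
    # down (d -> r), one r x r mat-vec, up (r -> d); 2 FLOPs per multiply-add.
    # Only one r x r matrix is used per token, no matter how large the vocab is.
    return 2 * (d * r + r * r + r * d)


def block_flops_per_token(d: int) -> int:
    # matmul FLOPs of the block's projections only (attention scores ignored)
    return 2 * block_params(d)


d, vocab = 768, 30_000
r = round(math.sqrt(d))
param_ratio = extra_params_per_layer(d, vocab, r) / block_params(d)
flop_ratio = extra_flops_per_token(d, r) / block_flops_per_token(d)
print(f"r={r}  extra params: {param_ratio:.2f}x a block  extra FLOPs: {flop_ratio:.2%}")
```

For this made-up config, each layer gains roughly 3.3x the parameters of a standard block for about 0.6% of its projection FLOPs; the exact multiple depends heavily on vocabulary size and hidden dimension.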

You can frame this change in two different ways, I think:

  • You can see it as a simplified per-token Hash MoE, if you want: we route each token to a different "expert"; it's just a very small expert, with no discontinuity.
  • You can see it as a way to have an extra per-token embedding, and thus save on compute, in the same way that the NanoGPT speedrun adds extra per-token embeddings to the values in attention.
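The Hash MoE framing can be checked directly. Using the token id as a fixed router and dispatching each token to its own tiny linear "expert" gives the same result as gathering one matrix per token and applying them all in one batched op, which is what the per-token transform does. This is a toy sketch with random weights and made-up sizes, not a real MoE implementation (no capacity limits, no load balancing).

```python
import torch

torch.manual_seed(0)
vocab, r, n_tokens = 10, 4, 32
experts = torch.randn(vocab, r, r)          # one tiny linear "expert" per token id
h = torch.randn(n_tokens, r)                # down-projected residual, one row per token
ids = torch.randint(0, vocab, (n_tokens,))  # token ids act as the (fixed) hash router

# MoE-style: dispatch the tokens routed to each expert, apply it, write results back.
moe_out = torch.zeros_like(h)
for e in range(vocab):
    mask = ids == e
    if mask.any():
        moe_out[mask] = h[mask] @ experts[e].T

# Per-token-transform style: gather each token's matrix and apply it in one batched op.
gathered = experts[ids]                           # [n_tokens, r, r]
pt_out = torch.einsum("nij,nj->ni", gathered, h)  # [n_tokens, r]
```

Since routing is fixed by the token id rather than chosen by a learned gate, there's no discontinuous top-k decision anywhere in the forward pass.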

This seems to be the best-working version of a family of possible alterations.

Here's how I stumbled on this:

First, BlinkDL described on Twitter a way to improve the MLP in a Transformer.

Basically, you have an embedding vector for each token for each layer, of size [hidden dimension], with each element initialized to 1. Then, after each MLP, you elementwise multiply the MLP's output by this embedding before adding it back to the residual stream. At the start of training this does nothing, because it's just elementwise multiplication by 1; but the model soon learns some kind of per-token modulator for the output of the MLP. This consistently
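A minimal sketch of that per-token MLP rescaling, as I understand the description; the class name `TokenScaledMLPOutput` and its signature are hypothetical, and one instance would exist per layer.

```python
import torch
from torch import nn


class TokenScaledMLPOutput(nn.Module):
    """Per-token, per-layer elementwise scale on the MLP output (sketch)."""

    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        # One hidden_dim-sized vector per token id, every element initialized to 1.
        self.scale = nn.Embedding(vocab_size, hidden_dim)
        nn.init.ones_(self.scale.weight)

    def forward(self, mlp_out: torch.Tensor, token_indices: torch.Tensor) -> torch.Tensor:
        # mlp_out: [batch, seq, hidden_dim]; token_indices: [batch, seq]
        return mlp_out * self.scale(token_indices)


# Inside a block, `x = x + x_ff` would become:
#   x = x + scaler(x_ff, token_indices)
```

At initialization the output is exactly the MLP output, so the change starts as a no-op.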

Second, well -- in this scheme, each element of a token's embedding can only alter one number, through the elementwise multiplication. I naturally wondered whether you could improve performance by allowing an actual linear transform.

So I reshaped the vector into a matrix with the same number of elements. Then, after each MLP, I projected down to this smaller dimension, multiplied by the per-token matrix, and projected back up. This did indeed work much better, instantly.
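Put differently: when the hidden size is a perfect square, d = r * r, the same per-token parameter budget can be read as an r x r matrix instead of a length-d vector. A toy illustration with made-up sizes (1024 = 32 * 32); the down-projection is omitted, and a random vector stands in for its output.

```python
import torch

torch.manual_seed(0)
hidden_dim, r, vocab = 1024, 32, 8  # illustrative sizes; r * r == hidden_dim

# BlinkDL-style table: one hidden_dim-sized vector per token (for one layer).
vec_table = torch.randn(vocab, hidden_dim)

# The same numbers viewed as one r x r matrix per token: no new parameters.
mat_table = vec_table.view(vocab, r, r)

token = 3
x_small = torch.randn(r)              # stand-in for the MLP output after projecting down to r dims
y_small = mat_table[token] @ x_small  # each output coordinate now mixes all r inputs
```

With elementwise scaling, each parameter touches exactly one coordinate; as a matrix, each output coordinate is a learned combination of every coordinate in the subspace.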

But then -- why tie the transform to the MLP at all? Why not add it directly to the residual stream instead? That also seems to improve performance a bit, and so it reached its current form.

I've tried a lot of local modifications to this (a non-linearity in the subspace, various initialization schemes, grouped convolutions), but it seems hard to improve from here. It's a local maximum, at least as far as the changes that readily come to mind are concerned.

I'm not sure how promising this is.

This leans a little hard on tokens as meaningful units of semantic information, which might not be the future. I'm also working with a small-ish vocabulary of only 30k tokens; it's conceivable this would work less well with a larger vocabulary. And I have no idea how it would interact with a real MoE.

But on the other hand, this has worked really well over a bunch of different hyperparameters. Seems relatively promising.