Understanding CausalAttention — the thing that makes GPT-style models actually work

February 21, 2026

Before we get into any code — let's get the intuition right.

Imagine you're reading the sentence:

"The cat sat on the mat"

When you hit the word "sat", you don't process it in isolation. You subconsciously pull in "The cat" to understand who sat. You blend context from earlier words to give "sat" meaning.

That's exactly what attention does. For every word in a sequence, it figures out which other words are relevant — and how much — then mixes their information together.

CausalAttention adds one constraint on top of that: you can only look backwards. When processing "sat", you can see "The" and "cat" — but not "on" or "mat". This is what makes it causal — no peeking at the future.

That constraint is what lets a language model generate text one token at a time.


The class at a glance

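import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Three independent learned projections: query, key, value
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark the future positions to block
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)
        # (b, T, d_out) @ (b, d_out, T) -> (b, T, T)
        attn_scores = queries @ keys.transpose(1, 2)
        # In-place masking: future positions become -inf, so softmax gives them weight 0
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        # Scale by sqrt(d_out), then normalise each row into a distribution
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Weighted sum of value vectors: (b, T, T) @ (b, T, d_out) -> (b, T, d_out)
        context_vec = attn_weights @ values
        return context_vec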

Six moving parts. Let's take them one by one.


Part 1 — the three projection matrices

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

Each token embedding comes in with shape (d_in,). These three linear layers project it into three different (d_out,) vectors.

Think of it like this — every token plays three roles simultaneously:

Role      Matrix        What it represents
Shopper   W_query (Q)   What am I looking for?
Label     W_key (K)     What do I advertise to others?
Product   W_value (V)   What do I actually give if selected?

Token i uses its query to scan the keys of every token in the sequence, itself included. A high dot product means a strong match. Once we know the match scores, we use the values to do the actual mixing.

The query and key are used purely for routing — deciding who pays attention to whom. The value is the content. They're entirely separate learned projections, which is what gives the model the flexibility to route and retrieve independently.

Concrete example. Take "sat" in our sentence. Its query might be tuned to look for subject tokens. It scans the keys of "The", "cat" — and "cat" scores high because its key fires for animate subject patterns. So "sat" ends up borrowing heavily from "cat"'s value. The output representation of "sat" is now contextualised — it knows who sat.
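
To make the three roles concrete, here is a minimal sketch in plain Python rather than PyTorch, so the arithmetic stays visible. The embedding and the weight matrices are made-up illustrative numbers, not learned values, and `project` is a hypothetical helper that mimics a bias-free nn.Linear (which stores its weight with shape (out_features, in_features)).

```python
# Illustrative sketch: one token embedding projected into query, key and value.
# All numbers are made up; a real model learns these weights.

def project(x, weight):
    """Bias-free linear layer: returns weight @ x, with weight shaped (d_out, d_in)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]

d_in, d_out = 3, 2
x_sat = [1.0, 0.5, -1.0]            # hypothetical embedding for "sat"

W_query = [[0.2, -0.1, 0.4], [0.7, 0.0, -0.3]]
W_key   = [[-0.5, 0.3, 0.1], [0.0, 0.6, 0.2]]
W_value = [[0.1, 0.1, 0.1], [0.9, -0.4, 0.0]]

# Same input, three different outputs, one per role
query = project(x_sat, W_query)
key   = project(x_sat, W_key)
value = project(x_sat, W_value)

print("query:", query)
print("key:  ", key)
print("value:", value)
```

The same embedding yields three different 2-dimensional vectors, which is the separation between routing (query, key) and content (value) described above.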


Part 2 — the causal mask

self.register_buffer(
    'mask',
    torch.triu(torch.ones(context_length, context_length), diagonal=1)
)

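torch.triu(..., diagonal=1) keeps only the entries strictly above the main diagonal, giving an upper-triangular matrix of 1s with 0s on and below the diagonal. For a sequence of 5 tokens (our sentence with the second "the" dropped to keep the grid small) it looks like this: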

      The  cat  sat  on  mat
The  [ 0    1    1    1    1 ]
cat  [ 0    0    1    1    1 ]
sat  [ 0    0    0    1    1 ]
on   [ 0    0    0    0    1 ]
mat  [ 0    0    0    0    0 ]

1 means blocked. 0 means allowed.

Row = the token doing the attending. Column = the token being attended to.

  • "The" can only look at itself.
  • "cat" can look at "The" and itself.
  • "sat" can look at "The", "cat", and itself.
  • "mat" can see the whole sentence — but only because it's last.

This enforces the causal constraint: no token can see the future. The positions where j > i are masked out.

register_buffer saves the mask as part of the module — it moves to GPU with the model — but it's not a learned parameter. It never gets gradients. It's just a constant constraint we apply at every forward pass.
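
If it helps to see the mask without PyTorch, the sketch below rebuilds the same pattern with nested lists and prints which tokens each row is allowed to see. `causal_mask` is a hypothetical helper name used only for this illustration.

```python
# Plain-Python equivalent of torch.triu(torch.ones(n, n), diagonal=1):
# 1 strictly above the main diagonal, 0 on and below it.

def causal_mask(n):
    return [[1 if j > i else 0 for j in range(n)] for i in range(n)]

tokens = ["The", "cat", "sat", "on", "mat"]
mask = causal_mask(len(tokens))

for i, row in enumerate(mask):
    # A 0 in column j means token i may attend to token j
    visible = [tokens[j] for j, blocked in enumerate(row) if not blocked]
    print(f"{tokens[i]:>3} can attend to: {visible}")
```

The diagonal stays 0, so every token can always attend to itself.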


Part 3 — computing attention scores

queries = self.W_query(x)   # shape: (B, T, d_out)
keys    = self.W_key(x)     # shape: (B, T, d_out)

attn_scores = queries @ keys.transpose(1, 2)   # shape: (B, T, T)

We project x into queries and keys, then do a batched matrix multiply.

The result is a (T, T) grid per batch item. Entry [i, j] answers: how much does token i's query align with token j's key?

Let's keep the numbers simple. Say we have 4 tokens and d_out = 2. The raw score grid might look like:

        tok0   tok1   tok2   tok3
tok0  [ 0.9    0.2    0.7    0.1 ]
tok1  [ 0.4    0.8    0.3    0.6 ]
tok2  [ 0.1    0.9    0.8    0.2 ]
tok3  [ 0.5    0.3    0.6    0.7 ]

Before masking, every token is comparing itself to every other token — including the future ones. Those upper-right scores get killed next.
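
Here is a small sketch of how one such grid comes about for a single sequence, again in plain Python. The query and key vectors are made-up numbers chosen so the dot products are easy to check by hand, and `dot` is a helper defined just for this example.

```python
# Sketch: attention scores for one sequence as Q @ K^T.
# queries and keys are made-up (T, d_out) lists with T = 3 and d_out = 2.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

queries = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
keys    = [[1.0, 1.0], [0.0, 1.0], [2.0, 0.0]]

# scores[i][j] = how well token i's query aligns with token j's key
scores = [[dot(q, k) for k in keys] for q in queries]

for row in scores:
    print(row)
# [1.0, 0.0, 2.0]
# [1.0, 0.5, 1.0]
# [2.0, 2.0, 0.0]
```

Note that the grid is not symmetric: scores[0][1] is 0.0 while scores[1][0] is 1.0, because queries and keys come from different projections.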


Part 4 — applying the mask

attn_scores.masked_fill_(
    self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
)

masked_fill_ replaces every position where mask == 1 with −∞.

After masking, our score grid becomes:

        tok0   tok1   tok2   tok3
tok0  [ 0.9    -inf   -inf   -inf ]
tok1  [ 0.4    0.8    -inf   -inf ]
tok2  [ 0.1    0.9    0.8    -inf ]
tok3  [ 0.5    0.3    0.6    0.7  ]

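Why −∞ specifically? Because the next step is softmax, which exponentiates every score, and exp(−∞) = 0 exactly. No floating point fuzz, no tiny residuals. Those positions contribute nothing to the output. The future is completely invisible. Every row keeps at least one finite score (a token can always see itself), so the softmax never has to normalise a row that is all −∞.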

The [:num_tokens, :num_tokens] slice handles sequences shorter than context_length — we only apply as much mask as we need.
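
The masking and slicing steps can be sketched without PyTorch as follows. The value `context_length = 6` is an assumption made for this example, picked only so that it is larger than the 4-token score grid.

```python
import math

scores = [
    [0.9, 0.2, 0.7, 0.1],
    [0.4, 0.8, 0.3, 0.6],
    [0.1, 0.9, 0.8, 0.2],
    [0.5, 0.3, 0.6, 0.7],
]

context_length = 6   # assumed maximum length; the full mask is built once at this size
full_mask = [[j > i for j in range(context_length)] for i in range(context_length)]

num_tokens = len(scores)
# Same idea as mask[:num_tokens, :num_tokens]: keep only the top-left corner
mask = [row[:num_tokens] for row in full_mask[:num_tokens]]

masked = [
    [-math.inf if blocked else s for s, blocked in zip(score_row, mask_row)]
    for score_row, mask_row in zip(scores, mask)
]

print(masked[1])             # [0.4, 0.8, -inf, -inf]
print(math.exp(-math.inf))   # 0.0
```

Unlike masked_fill_, this sketch builds a new grid instead of modifying the scores in place.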


Part 5 — scaled softmax

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

Two things happening here.

Scaling by √d_out

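Dot products grow with dimension size. If the query and key components are roughly independent with zero mean and unit variance, the dot product has variance d_out, so its typical magnitude grows like √d_out. With d_out = 512 the raw scores can get large, pushing softmax into its saturation region where gradients are near zero. Dividing by √d_out keeps them in a healthy range.

Think of it like this: if you add up 512 independent random numbers, each with variance 1, the sum has variance 512 and standard deviation √512 ≈ 22.6. Dividing by √512 brings the standard deviation back to 1.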
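
A quick simulation makes the effect visible. This is a rough sketch under the assumption that query and key components are independent standard normal values, which real projections only approximate. `dot_product_std` is a hypothetical helper, and the dimensions are kept smaller than 512 so it runs quickly.

```python
import random
import statistics

def dot_product_std(d, trials=500, scale=False, seed=0):
    """Empirical standard deviation of q . k for random d-dimensional q and k."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d)]
        k = [rng.gauss(0, 1) for _ in range(d)]
        s = sum(a * b for a, b in zip(q, k))
        # Optionally apply the 1 / sqrt(d) scaling used in attention
        samples.append(s / d ** 0.5 if scale else s)
    return statistics.pstdev(samples)

for d in (4, 64, 256):
    raw = dot_product_std(d)
    scaled = dot_product_std(d, scale=True)
    print(f"d={d:>3}  raw std={raw:6.2f}  scaled std={scaled:5.2f}")
```

The raw spread should come out close to √d (about 2, 8 and 16), while the scaled spread stays near 1 regardless of dimension.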

Softmax over dim=-1

Softmax turns each row into a probability distribution — all values between 0 and 1, summing to exactly 1. The −∞ entries become exactly 0.

After this step, our masked example (with d_out = 2, so every score is divided by √2 before the softmax) becomes, rounded to two decimals:

        tok0   tok1   tok2   tok3
tok0  [ 1.00   0.00   0.00   0.00 ]
tok1  [ 0.43   0.57   0.00   0.00 ]
tok2  [ 0.23   0.40   0.37   0.00 ]
tok3  [ 0.24   0.21   0.26   0.28 ]

These are the attention weights. Each row is a mixing recipe — how much of each token's value to blend into the output.
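
For readers who want to run the numbers themselves, the sketch below applies the scaled softmax to the masked grid from Part 4, assuming d_out = 2 as in the toy example. `softmax` is a small helper written for this illustration, not the torch.softmax implementation.

```python
import math

def softmax(row):
    """Numerically stable softmax; -inf entries come out as exactly 0.0."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

d_out = 2
masked_scores = [
    [0.9, -math.inf, -math.inf, -math.inf],
    [0.4, 0.8, -math.inf, -math.inf],
    [0.1, 0.9, 0.8, -math.inf],
    [0.5, 0.3, 0.6, 0.7],
]

# Scale every score by sqrt(d_out), then normalise each row
attn_weights = [softmax([s / d_out ** 0.5 for s in row]) for row in masked_scores]

for row in attn_weights:
    print([round(w, 2) for w in row])
```

Two properties are worth checking in the output: each row sums to 1, and within a row a higher score always receives a higher weight.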


Part 6 — dropout + weighted sum of values

attn_weights = self.dropout(attn_weights)
context_vec  = attn_weights @ values

Dropout randomly zeros out some attention weights during training and scales the surviving ones by 1 / (1 − p), so a row no longer sums to exactly 1 while training. This stops the model from becoming over-reliant on a specific attention pattern. At inference (after model.eval()), dropout is a no-op.
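
As a rough sketch of what happens to one row of weights, here is inverted dropout in plain Python. nn.Dropout follows this scheme: each entry is zeroed with probability p during training and the survivors are scaled by 1 / (1 − p). The `dropout` function below is a hypothetical stand-in that assumes 0 <= p < 1.

```python
import random

def dropout(weights, p, training, rng):
    """Inverted dropout on a list of weights; assumes 0 <= p < 1."""
    if not training or p == 0.0:
        return list(weights)          # inference: weights pass through unchanged
    keep_scale = 1.0 / (1.0 - p)
    # Zero each entry with probability p, rescale the ones that survive
    return [0.0 if rng.random() < p else w * keep_scale for w in weights]

rng = random.Random(0)
row = [0.25, 0.25, 0.25, 0.25]

print(dropout(row, p=0.5, training=True, rng=rng))    # each entry is 0.0 or 0.5
print(dropout(row, p=0.5, training=False, rng=rng))   # [0.25, 0.25, 0.25, 0.25]
```

The rescaling keeps the expected value of each weight unchanged between training and inference.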

The weighted sum is where the actual information mixing happens:

context_vec = attn_weights @ values   # shape: (B, T, d_out)

For each token position i, the output is a weighted average of all allowed value vectors:

context[i] = w[i,0]*v[0] + w[i,1]*v[1] + ... + w[i,i]*v[i]

For "sat" (token index 2), if the weights were [0.08, 0.56, 0.36], the output would be:

context["sat"] = 0.08 * value["The"]
               + 0.56 * value["cat"]
               + 0.36 * value["sat"]

The output for "sat" is now a blend — heavily influenced by "cat" — even though the input was just "sat". That blending is the whole point. The token has been contextualised.
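
The blend can be worked through by hand with tiny vectors. The weights below are the ones from the text. The 2-dimensional value vectors are invented for this example.

```python
weights = [0.08, 0.56, 0.36]   # attention weights for "sat" from the text
values = {                     # made-up 2-d value vectors
    "The": [1.0, 0.0],
    "cat": [0.0, 2.0],
    "sat": [1.0, 1.0],
}

# context["sat"] = sum over visible tokens of weight * value
context_sat = [0.0, 0.0]
for w, token in zip(weights, ["The", "cat", "sat"]):
    for d, v in enumerate(values[token]):
        context_sat[d] += w * v

print([round(c, 2) for c in context_sat])   # [0.44, 1.48]
```

The second component is dominated by 0.56 * 2.0 from "cat", which is what "heavily influenced" means in numbers.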


Putting it all together

Here's what happens to a single batch of tokens in one forward pass:

x                         (B, T, d_in)
  ↓ W_query, W_key, W_value
queries, keys, values     (B, T, d_out) each
  ↓ queries @ keysᵀ
attn_scores               (B, T, T)   # raw similarity grid
  ↓ masked_fill_(-inf)
attn_scores               (B, T, T)   # future positions set to -inf
  ↓ softmax(scores / √d_out)
attn_weights              (B, T, T)   # probability distributions, future weights exactly 0
  ↓ dropout
attn_weights              (B, T, T)   # some weights zeroed (training only)
  ↓ @ values
context_vec               (B, T, d_out)   # contextualised token representations

Input: raw token embeddings, one per position. Output: contextualised token embeddings, where each position has borrowed from its past.
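
One way to convince yourself that each position borrows only from its past is to change a later token and confirm that earlier outputs do not move. The sketch below does this for a single sequence in plain Python, without dropout. The helpers `matvec`, `softmax` and `causal_attention`, along with all weights and inputs, are made up for this check and are not the class above.

```python
import math

def matvec(weight, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]

def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(xs, W_q, W_k, W_v):
    """Single-sequence causal attention without dropout (plain-Python sketch)."""
    queries = [matvec(W_q, x) for x in xs]
    keys = [matvec(W_k, x) for x in xs]
    values = [matvec(W_v, x) for x in xs]
    d_out = len(keys[0])
    outputs = []
    for i, q in enumerate(queries):
        # Scores for positions j > i are masked to -inf
        scores = [
            sum(a * b for a, b in zip(q, k)) if j <= i else -math.inf
            for j, k in enumerate(keys)
        ]
        weights = softmax([s / math.sqrt(d_out) for s in scores])
        # Weighted sum of value vectors, one output dimension at a time
        outputs.append([
            sum(w * v[d] for w, v in zip(weights, values)) for d in range(d_out)
        ])
    return outputs

W_q = [[0.1, 0.4], [-0.3, 0.2]]
W_k = [[0.5, -0.2], [0.3, 0.7]]
W_v = [[1.0, 0.0], [0.0, 1.0]]

xs = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
changed = xs[:2] + [[-9.0, 9.0]]   # alter only the last token

out_a = causal_attention(xs, W_q, W_k, W_v)
out_b = causal_attention(changed, W_q, W_k, W_v)

print(out_a[:2] == out_b[:2])   # True: earlier positions never saw the last token
print(out_a[2] == out_b[2])     # False: the last position did change
```

This independence from later tokens is the property that allows text to be generated one token at a time.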

The model learns what to look for (Q and K weights) and what to contribute (V weights) entirely from data — through backprop — with no manual rules about which words relate to which.


The one thing to remember

If you forget everything else, hold onto this:

Every output token is a weighted average of the value vectors at its own and earlier positions, where the weights are softmax-normalised similarities between learned queries and keys, and future positions always get zero weight.

That's CausalAttention. Everything else is implementation detail.


Next up: we'll stack multiple of these into Multi-Head Attention — where we run several attention mechanisms in parallel, each learning to route on a different type of relationship.
