This note contains four papers for "historical perspective"... which would usually mean "no longer directly relevant", although I'm not sure that's really what the author means.
You might be looking for the author's "Understanding Large Language Models" post [1] instead.
Misspelling "Attention is All Your Need" twice in one paragraph makes for a rough start to the linked post.
And https://news.ycombinator.com/item?id=23649542 gives some context for the quoted claim "For instance, in 1991, which is about two-and-a-half decades before the original transformer paper above ("Attention Is All You Need")".
> Misspelling "Attention is All Your Need" twice in one paragraph makes for a rough start to the linked post.
100%! LOL. I was traveling and typing this on a mobile device. Must have been some weird autocorrect/autocomplete. Strange. And I didn't even notice. Thanks!
As someone who started reading ML papers 10 years ago, I find transformers pretty simple compared to most architectures. For example, Google's Inception was a complex mesh of variously sized convolutions; other models have an internal algorithm like Non-Maximum Suppression for object detection or Connectionist Temporal Classification for OCR; GANs use complicated probability theory for the loss function. Even the LSTM is more complicated.
If anything, we have been abandoning exotic neural nets in favour of a single architecture, and that one is pretty simple: just linear layers (vector-matrix products), query-key products (matrix-matrix products), softmax (fancy normalisation), weighted averaging of values (a sum of products) and skip connections (an addition). Maybe it's become hard for me to see what is complicated about it; I'd be curious to know which part is difficult. Is it the embeddings, masking, multiple heads, gradient descent, ...? Embeddings have been famous for 10 years, ever since the king - man + woman = queen paper. You don't need to be able to derive the gradients for the network by hand to understand it.
In short, a transformer is mixing information between tokens in a sequence and computing updates. The mixing part is the "self attention" or "cross attention". The updating part is the feed-forward sublayer. It has skip connections (adds the input to the output) in order to keep training stable.
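To make the "mixing plus updating" picture concrete, here is a minimal single-head block in NumPy. This is just a sketch with made-up shapes and weight names; a real implementation would add multiple heads, masking, layer normalization and learned embeddings:

    import numpy as np

    def softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    def transformer_block(X, Wq, Wk, Wv, W1, W2):
        """X: (seq_len, d_model). The W* matrices are illustrative, randomly initialised."""
        # mixing between tokens: self-attention
        Q, K, V = X @ Wq, X @ Wk, X @ Wv              # linear layers (matrix products)
        scores = Q @ K.T / np.sqrt(K.shape[-1])       # query-key products
        A = softmax(scores, axis=-1)                  # fancy normalisation
        X = X + A @ V                                 # weighted averaging + skip connection

        # per-token update: feed-forward sublayer
        hidden = np.maximum(0, X @ W1)                # linear + ReLU
        return X + hidden @ W2                        # skip connection again

    # toy usage
    rng = np.random.default_rng(0)
    d_model, d_ff, seq_len = 16, 64, 5
    X = rng.normal(size=(seq_len, d_model))
    Wq, Wk, Wv = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(3))
    W1 = rng.normal(size=(d_model, d_ff)) * 0.1
    W2 = rng.normal(size=(d_ff, d_model)) * 0.1
    print(transformer_block(X, Wq, Wk, Wv, W1, W2).shape)   # (5, 16)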
The things that still confuse me about transformers are:
1. Why do we _add_ the positional embedding to the semantic embedding? It seems like it means certain semantic directions are irreversibly entangled with certain positions.
2. I don't understand why the attention head (which I can implement and follow the math of) is described as "key query value lookup". Specifically, the Q and K matrices aren't structurally distinct – the projections into them will learn different weights, but one doesn't start out biased key-ward and the other query-ward.
The first one: transformers are "permutation invariant" by nature, so if you permute the input and apply the opposite permutation to the output you get the exact same thing. The transformer itself has no positional information. RNNs by comparison have positional information by design, they go token by token, but the transformer is parallel and all tokens are just independent "channels". So what can be done? You put positional embeddings in it - either by adding them to the tokens (concatenation was also ok, but less efficient) or by inserting relative distance biases in the attention matrix. It's a fix to make it understand time. It's still puzzling this works, because mixing text tokens with position tokens seems to cause a conflict, but it doesn't in practice. The model will learn to use the embedding vector for both, maybe specialising a part for semantics and another for position.
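For question 1, a minimal sketch of the "add positional embeddings to the tokens" fix, using the sinusoidal scheme from the original paper (shapes and names are made up for illustration):

    import numpy as np

    def sinusoidal_positions(seq_len, d_model):
        """Sin/cos positional encodings as in 'Attention Is All You Need'."""
        pos = np.arange(seq_len)[:, None]               # (seq_len, 1)
        i = np.arange(d_model // 2)[None, :]            # (1, d_model/2)
        angles = pos / (10000 ** (2 * i / d_model))     # one frequency per sin/cos pair
        pe = np.zeros((seq_len, d_model))
        pe[:, 0::2] = np.sin(angles)
        pe[:, 1::2] = np.cos(angles)
        return pe

    # the same vector ends up carrying both "what" (token) and "where" (position)
    seq_len, d_model = 8, 16
    token_emb = np.random.default_rng(0).normal(size=(seq_len, d_model))  # stand-in embeddings
    x = token_emb + sinusoidal_positions(seq_len, d_model)                # addition, not concatenation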
The second question. Neural nets find a way to differentiate the keys from queries by simply doing gradient descent. If we tell the model it should generate a specific token here, then it needs to fix the keys and queries to make it happen. The architecture is pretty dumb, the secret is the training data - everything the transformer learns comes from the training set. We should think about the training data when we marvel at what transformers can do. The architecture doesn't tell us why they work so well.
With regards to the "It's still puzzling this works" wrt positional encoding, I have developed an intuition (that may be very wrong ;-). If you take the Fourier transform of a linear or sawtooth function (akin to the progress of time), I think you get something that resembles the positional encoding in the original transformer. EDIT: fixed typo
This is a good intuition. At times it reminds me of old-school hand-rolled feature engineering used in time series modelling: assuming that the signal is made up of a stationary component and a sine wave. Though I haven't managed to figure out mathematically whether the two are equivalent.
> The architecture is pretty dumb, the secret is the training data
If this were true, we could throw the same training data at any other "dumb" architecture and it would learn language at least as well/fast as transformers do. But we don't see that happening, so the architecture must be smartly designed for this purpose.
Actually there are alternatives by the hundreds, with similar results. Reformer, Linformer, Performer, Longformer... none is better than the vanilla Transformer overall; they all have an edge in some use case.
And then we have MLP-Mixer, which doesn't do "attention" at all; MLPs are all you need. A good solution for edge models.
Other dumb architectures don't parallelize as well. Other architectures that parallelize at similar levels (RNN-RWKV, H3, S4, etc.) do perform well at similar parameter counts and data sizes.
Regarding the positional encoding, why not include a scalar in the range (0..1) with every token where the scalar encodes the position of the token? This adds a small amount of complexity to the network, but it could aid comprehensibility which to me seems preferable if you're still doing research on these networks.
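Concretely, I mean something like this (toy shapes; just a sketch of the idea, not a claim that it trains well):

    import numpy as np

    seq_len, d_model = 8, 16
    token_emb = np.random.default_rng(0).normal(size=(seq_len, d_model))   # stand-in embeddings
    pos_scalar = (np.arange(seq_len) / max(seq_len - 1, 1))[:, None]       # position in [0, 1]
    x = np.concatenate([token_emb, pos_scalar], axis=-1)                   # (seq_len, d_model + 1)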
I'm still not clear on the second question. If lalaithion's original statement "the Q and K matrices aren't structurally distinct" is true, then once the neural network is trained, how can we look at the two matrices and confidently say that one is the query matrix instead of it being the key matrix (or vice versa)? To put it another way: is the distinction between query and key roles "real" or is it just an analogy for humans?
I am not an expert, but I think they are structurally identical only in decoder-only transformers like GPT. The original transformers were used for translation, and so the encoder-decoder layers use Q from the decoder layer and K from the encoder layer. The Attention Is All You Need paper has an explanation:
> In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as...
Would this not imply that if I encrypt the input and then decrypt the output I would get the correct result (i.e. what I would have gotten if I used the plaintext input)?
I recently had the same questions and here is how I understand it:
1. You could concatenate the positional embedding and the semantic embedding and that way isolate them from each other. But if that separation is necessary, the model can learn the separation itself as well (it can make positional embeddings and semantic embeddings orthogonal to each other), so using addition is strictly more general.
2. My sense is that you could merge the Q and K matrices and everything would work mostly the same, but with multi-headed attention this will typically result in a larger matrix than the combined sizes of Q and K. It's basically a more efficient matrix factorization (small numerical check below).
Curious to see if I got this right and if there is more to it.
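A small numerical check of point 2: per head, the separate Wq and Wk can be merged into a single d_model x d_model matrix that produces identical attention scores, but the merged matrix has more parameters than the two low-rank factors combined (shapes are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    d_model, d_head, seq_len = 64, 8, 5

    X = rng.normal(size=(seq_len, d_model))
    Wq = rng.normal(size=(d_model, d_head))
    Wk = rng.normal(size=(d_model, d_head))

    scores_factored = (X @ Wq) @ (X @ Wk).T     # the usual Q K^T per head
    W_merged = Wq @ Wk.T                        # single (d_model, d_model) matrix
    scores_merged = X @ W_merged @ X.T          # exactly the same scores

    print(np.allclose(scores_factored, scores_merged))      # True
    print(Wq.size + Wk.size, "vs", W_merged.size)            # 1024 vs 4096 parameters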
One advantage of summing is that the lower frequency terms hardly change for a small text, so effectively there is more capacity for embeddings with short texts, while still encoding order in long texts.
1. High dimensional embedding space is way more vast than you'd think, so adding two vectors together doesn't really destroy information in the same way as addition does in low dimensional cartesian space - the semantic and position information remains separable (tiny numerical demo below).
2. I find the QKV nomenclature unintuitive too. Cross attention explains it a bit, where the Q and K come from different places. For self-attention they come from the same place, but the terminology stuck.
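A tiny numerical demo of point 1 (purely illustrative numbers): two random high-dimensional vectors are nearly orthogonal, so after adding them you can still recover each component reasonably well by projection.

    import numpy as np

    rng = np.random.default_rng(0)
    d = 4096
    semantic = rng.normal(size=d)
    position = rng.normal(size=d)

    cos = semantic @ position / (np.linalg.norm(semantic) * np.linalg.norm(position))
    print(round(cos, 3))                         # close to 0: nearly orthogonal

    combined = semantic + position
    # projecting the sum onto the (known) semantic direction mostly recovers it
    recovered = (combined @ semantic / (semantic @ semantic)) * semantic
    print(round(float(np.linalg.norm(recovered - semantic) / np.linalg.norm(semantic)), 3))  # small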
1. It works; the direct alternative (concatenation) allocates a smaller dimension to the initial embedding. Also, added positional embeddings are no longer commonly used in newer Transformers; schemes like RoPE and ALiBi are more common.
2. I'm not 100% sure I understand your question. The Ks correspond to the Vs, and so are used to compute the weighted sum over the Vs. This is easiest to understand when you think of an encoder-decoder model (Qs come from the decoder, KVs come from the encoder), or decoding in a decoder (there is one Q and multiple KVs).
One aspect of the specific positional embedding used there is that it explicitly encodes a signal that the very first attention layer can directly use for both relative and absolute position - i.e. there can be a trivial set of weights for "pay attention only to the token two tokens to the right of the target token", another trivial set for "pay attention to the first token in the sequence", and another for "pay attention to this aspect of all the words, weighted by their distance from the target token". As by default the transformer architecture is effectively position-blind, having the flexibility to learn all these different types of relations is important; many possible simple, clean, efficient position encodings make some relations easy to represent but others very difficult - perhaps theoretically possible, but needing extra layers and/or hard to learn by gradient descent.
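For the original sinusoidal encoding in particular, the "relative" half of this flexibility comes from a concrete property: PE(pos + k) is a fixed linear transform (a block rotation) of PE(pos), independent of pos, so a single learned matrix can implement "k tokens to the right". A small check (the helper mirrors the illustrative sinusoidal sketch above):

    import numpy as np

    def sinusoidal_positions(seq_len, d_model):
        pos = np.arange(seq_len)[:, None]
        i = np.arange(d_model // 2)[None, :]
        angles = pos / (10000 ** (2 * i / d_model))
        pe = np.zeros((seq_len, d_model))
        pe[:, 0::2] = np.sin(angles)
        pe[:, 1::2] = np.cos(angles)
        return pe

    d_model, k = 16, 3
    freqs = 1 / (10000 ** (2 * np.arange(d_model // 2) / d_model))

    # one 2x2 rotation per (sin, cos) pair, stacked into a block-diagonal matrix
    R = np.zeros((d_model, d_model))
    for j, w in enumerate(freqs):
        c, s = np.cos(w * k), np.sin(w * k)
        R[2*j:2*j+2, 2*j:2*j+2] = [[c, -s], [s, c]]

    pe = sinusoidal_positions(50, d_model)
    print(np.allclose(pe[:-k] @ R, pe[k:]))   # True: the same matrix shifts every position by k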
To answer (2): You are token i. In order to see how much of a token j's value v_j you update yourself with, you compare your query q_i with token j's key k_j. This gives you the asymmetry between queries and keys.
This is even more apparent in a cross-attention setting, where one stream of tokens will have only queries associated with it and the other will have only keys/values.
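A minimal cross-attention sketch of that asymmetry (names and shapes made up): the decoder-side tokens supply only queries, the encoder-side tokens supply only keys and values.

    import numpy as np

    def softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    def cross_attention(decoder_X, encoder_X, Wq, Wk, Wv):
        Q = decoder_X @ Wq              # queries: "what am I looking for?"
        K = encoder_X @ Wk              # keys:    "what do I offer?"
        V = encoder_X @ Wv              # values:  "what do I actually hand over?"
        A = softmax(Q @ K.T / np.sqrt(K.shape[-1]), axis=-1)   # (dec_len, enc_len) weights
        return A @ V                    # each decoder token gathers encoder information

    rng = np.random.default_rng(0)
    d = 16
    dec, enc = rng.normal(size=(4, d)), rng.normal(size=(7, d))
    Wq, Wk, Wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
    print(cross_attention(dec, enc, Wq, Wk, Wv).shape)   # (4, 16)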
Agreed, compared to other architectures, transformers are actually quite straightforward. The complicated part comes more from training them in distributed setups: making the data loading and tensor parallelism work given the large size, etc. The vanilla architecture is simple, but the practical implementation for large-scale training can be a bit complicated.
In CNNs the early layers seem to learn geometric primitives and deeper layers seem to learn more complex geometric patterns, loosely speaking.
In a transformer, what do the query and key matrices learn? How do their weights manage to extract context no matter which word appears in which position?
The transformer doesn't have the nice pyramid shape of CNNs, but it still needs multiple layers. There have been papers showing non-trivial interactions between successive layers, forming more complex circuits.
The Q and K matrices learn how to relate tokens. Each of the heads will learn to extract a different relation. For example, one will link to the next token, another will link pronouns to their references, another would be matching brackets, etc. Check out the cute diagrams here:
Agree. So attention is like a hierarchy of graphs where the nodes are tokens and the edges are attention scores, per head.
Now what's trippy is that each node also has position data. So a node's features and the position at which it appears are used to create an operator that projects a sequence into a semantic space.
This seems to work for any modality of data... so there is something about the order in which data appears that seems to be linked to semantics, and for me it hints at some deep causal structure being latent in LLMs.
Simple is good. Especially in machine learning, where a bug usually means that it kinda works, but not as well as it could. Also, when an off-the-shelf algorithm half works, it's good to be able to add your own tweaks to it, and again, this requires simplicity.
For a complicated architecture to succeed, it's going to need to reliably achieve state of the art performance on everything without requiring any adjustment or tweaks.
Even when someone understands the architecture very well, there's still great utility in having a full graph representation of the full NN architecture. This could be for teaching students, or for doing analyses on the structure of the full network, etc.
That would be wonderful and I have been trying to do this. However, unfortunately some 'assumptions'/shortcuts have to be made. For example, the attention matrix is not known without an input, so if you just want the structure of the network (weighted by the weights), you have to put some placeholder value 'p' ('1', '-1', whatever) on these edges. Also, skip connections have to be dealt with explicitly instead of just adding them to a block diagonal matrix as one would with an MLP.
I am very interested if someone has a good solution that already has done these things though.
Those are nice diagrams. Yes, well actually I'm interested in anything that's between the abstracted form in the paper and the fully expanded form where you see all neurons.
Another cool diagram. But one thing it misses is that you can't follow the arrows from input to output. For example, you might be tempted to think that the Keys or Queries are inputs to the neural net since there is no arrow going into them.
I wonder if, for example, a function is a good model of what a transformer does. So the phrase "argument one is cat, argument two is dog, and the operation is join, so the result is the word catdog" is handled by the transformer as the function concat(cat, dog). Here the query is the function, the keys are the arguments to the function, and the value is a function from words to words.
They can intelligently parse the unstructured input into a structured internal form, apply a transform, and then format the result back into unstructured text. Even the transform itself can be an argument.
The actual title "Why the Original Transformer Figure Is Wrong, and Some Other Interesting Historical Tidbits About LLMs" is way more representative of what this post is about...
As to the figure being wrong, it's kind of a nit-pick:
"While the original transformer figure above (from Attention Is All Your Need, https://arxiv.org/abs/1706.03762) is a helpful summary of the original encoder-decoder architecture, there is a slight discrepancy in this figure.
For instance, it places the layer normalization between the residual blocks, which doesn't match the official (updated) code implementation accompanying the original transformer paper. The variant shown in the Attention Is All Your Need figure is known as Post-LN Transformer."
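In code, the discrepancy is just where the layer norm sits relative to the residual add. A schematic sketch (here sublayer stands for self-attention or the feed-forward network, and layer_norm for whatever normalization is used):

    # Post-LN, as drawn in the original figure: normalize after the residual add
    def post_ln_block(x, sublayer, layer_norm):
        return layer_norm(x + sublayer(x))

    # Pre-LN, as in the updated official implementation: normalize the sublayer input
    def pre_ln_block(x, sublayer, layer_norm):
        return x + sublayer(layer_norm(x))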
So weird, I posted it with almost the original title (only slightly abbreviated to make it fit: "Why the Original Transformer Figure Is Wrong, and Some Interesting Tidbits About LLMs").
Not sure what happened there. Someone must have changed it! So weird! And I agree that the current title is a bit awkward and less representative.