Keywords: Transformers, attention mechanism, sequence transduction, natural language processing, neural networks
Overview: This article provides a deep dive into transformers, starting from basic concepts like one-hot encoding and matrix multiplication, and gradually building up to the full transformer architecture. It explains the core mechanisms, such as attention, embeddings, and positional encoding, while also addressing practical considerations like computational efficiency and gradient stability. The author aims to provide a clear mental model of how transformers work, enabling readers to understand the latest advancements in natural language processing. The article emphasizes the importance of attention, skip connections, and layer normalization in achieving good performance.
Section-by-section reading:
- One-hot encoding: Converts words into numerical representations for computation. One-hot encoding represents each word as a vector with a single ‘1’ and the rest ‘0’s. This allows for mathematical operations on words.
- Dot product: A mathematical operation used to measure similarity between vectors. The dot product of two one-hot vectors is 1 if they are the same and 0 if they are different. Taking the dot product of a one-hot encoded word with a vector of values picks out the value at that word's position.
- Matrix multiplication: A way to combine two-dimensional arrays using dot products. Matrix multiplication can act as a lookup table, pulling out specific rows of a matrix based on a one-hot vector. This is a core mechanism in how transformers work.
- First order sequence model: Represents sequences using a transition model (Markov chain). It shows the likelihood of the next word based on the current word. Markov chains can be expressed in matrix form.
- Second order sequence model: Predicts the next word based on the two most recent words. This provides more context and reduces uncertainty compared to a first-order model. The transition matrix has N^2 rows, where N is the vocabulary size.
- Second order sequence model with skips: Considers combinations of the most recent word with each of the words that came before. This allows capturing long-range dependencies. Each value in the matrix represents a vote, and votes are summed to determine next word predictions.
- Masking: Improves prediction by weeding out uninformative feature votes. A mask is a vector that forces unhelpful features to zero. Selective masking is the “attention” mechanism in transformers.
- Attention as matrix multiplication: Expresses the mask lookup as a matrix multiplication (QK^T). Q represents the feature of interest, and K represents the collection of masks. This is a differentiable lookup table.
- Second order sequence model as matrix multiplications: Constructs transition matrices using matrix multiplications. A single-layer fully connected neural network creates word pair features. ReLU nonlinearity is applied to clean up the results.
- Sequence completion: Generates long sequences from a prompt using greedy selection. The decoder runs a forward pass and predicts a probability distribution over the next word. The word with the highest probability is appended to the sequence, and the process repeats.
- Embeddings: Projects words from their one-hot representation to a lower-dimensional space. This reduces the number of parameters needed. A good embedding groups words with similar meanings together.
- Positional encoding: Adds position information to the embedded representation of words. This is done by adding a circular wiggle to the position of the word in the embedding space. The wiggle is repeated in all pairs of dimensions with different angular frequencies.
- De-embeddings: Converts embedded words back to words from the original vocabulary. This is done with a projection from one space to another (matrix multiplication). The de-embedding results are converted to a probability distribution using softmax.
- Softmax: Converts de-embedding results to a probability distribution. It exaggerates the lead of the highest-scoring words, which thins the field near the top, and it is differentiable, so it can be used with backpropagation to train the transformer.
- Multi-head attention: Runs several different instances of attention (heads) at once. This allows the transformer to consider several previous words simultaneously. Each head has its own transforms (Wv, Wq, Wk).
- Single head attention revisited: Attention becomes a relationship between word groups rather than individual words. The Mask block enforces the constraint that we can’t look into the future. It avoids introducing any weird artifacts from imaginary future words.
- Skip connection: Adds a copy of the input to the output of a set of calculations. This helps keep the gradient smooth and preserves the original input sequence. Skip connections are now a standard feature in neural network architectures.
- Layer normalization: Shifts the values of the matrix to have a mean of zero and scales them to have a standard deviation of one. It maintains a consistent distribution of signal values throughout many-layered neural networks. Normalization encourages convergence of parameter values and usually results in much better performance.
- Multiple layers: Having multiple attention layers provides redundancy. If one layer fails, another can take up the slack. Layers become workers on the assembly line, where each does what it can, but doesn’t worry about catching every piece, because the next worker will catch the ones they miss.
- Decoder stack: The decoder alone is useful for sequence completion tasks. Generative models like GPT use the decoder half of a transformer.
- Encoder stack: The encoder stack produces a semantic representation of the sequence. It is a useful signal for communicating intent and meaning to the decoder stack.
- Cross-attention: Connects the encoder and decoder stacks. The key and value matrices are based on the output of the final encoder layer. The query matrix is calculated from the results of the previous decoder layer.
- Tokenizing: Breaks the unbroken stream of text into a sequence of distinct chunks, and provides a concise code for each one.
- Byte pair encoding: Infers which long sequences of characters to learn from the data, as opposed to dumbly representing all possible sequences.
- Audio input: Raw audio is turned into a sequence of vectors, where each element represents the change of audio activity in a particular frequency range.
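
A few of the mechanisms in the list above (one-hot lookup by matrix multiplication, masked attention as QK^T followed by softmax, and layer normalization) can be sketched in plain Python. This is a minimal illustration, not the article's code: the helper names, the three-word vocabulary, and the toy matrices are all hypothetical, and the division by the square root of the key dimension follows the scaling described in "Attention Is All You Need" rather than anything stated in this summary.

```python
import math

def one_hot(index, size):
    """Return a list with a single 1.0 at `index` and 0.0 everywhere else."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def transpose(m):
    """Swap the rows and columns of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    """Matrix product: each output entry is a dot product of a row and a column."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(scores):
    """Turn a list of scores into non-negative values that sum to 1."""
    peak = max(scores)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q, k, v):
    """Scaled dot-product attention where position i only sees positions <= i."""
    scale = math.sqrt(len(k[0]))  # assumed scaling, taken from the original paper
    scores = matmul(q, transpose(k))
    weights = []
    for i, row in enumerate(scores):
        # Future positions get -inf, which softmax maps to a weight of exactly 0.
        masked = [s / scale if j <= i else -math.inf for j, s in enumerate(row)]
        weights.append(softmax(masked))
    return matmul(weights, v), weights

def layer_norm(values, eps=1e-5):
    """Shift values to zero mean and scale them to roughly unit standard deviation."""
    mean = sum(values) / len(values)
    var = sum((x - mean) ** 2 for x in values) / len(values)
    return [(x - mean) / math.sqrt(var + eps) for x in values]

# Hypothetical first-order transition matrix over the vocabulary ["the", "cat", "sat"].
transition = [
    [0.0, 1.0, 0.0],  # after "the" comes "cat"
    [0.0, 0.0, 1.0],  # after "cat" comes "sat"
    [1.0, 0.0, 0.0],  # after "sat" comes "the"
]

# A one-hot row vector times the matrix pulls out the matching row (a lookup).
next_word_probs = matmul([one_hot(1, 3)], transition)[0]

# Toy embeddings, reused as queries, keys, and values for a three-word sequence.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attended, weights = causal_attention(x, x, x)

normed = layer_norm([1.0, 2.0, 3.0])
```

Here `next_word_probs` is the row of `transition` for "cat", and the first row of `weights` puts all of its weight on the first position because every later position is masked out. A real transformer would use learned projection matrices (the Wq, Wk, Wv mentioned above) instead of reusing `x` directly.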
Related Tools:
- None mentioned in the article.
References:
- Attention Is All You Need: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
- The Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
- The Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
- Explanation of how transformers work: https://www.youtube.com/watch?v=rBCqOTEfxvg
- Illustrations in Google Slides: https://docs.google.com/presentation/d/1Po-GY7X-mXmPKHr8Vh29S4tFPv23TjeY-jq-yShlivM/edit?usp=sharing
Original Article Link: https://www.brandonrohrer.com/transformers