Machine learning-based attention is a mechanism mimicking cognitive attention. It calculates "soft" weights for each word (more precisely, for its embedding) in the context window, either in parallel (as in transformers) or sequentially (as in recurrent neural networks). "Soft" weights can change during each run, in contrast to "hard" weights, which are (pre-)trained and fine-tuned and remain frozen afterwards. Multiple attention heads are used in transformer-based large language models.
Predecessors of the mechanism were used in recurrent neural networks which, however, calculated "soft" weights sequentially and, at each step, considered the current word and the other words within the context window. They were known as multiplicative modules, sigma-pi units,[1] and hyper-networks.[2] They have been used in LSTMs, in multi-sensory data processing (sound, images, video, and text) in perceivers, in the memory of fast weight controllers,[3] and in reasoning tasks in differentiable neural computers and neural Turing machines.[4][5][6][7][8]
Correlating the different parts within a sentence or a picture can help capture its structure and meaning. In the sentence "see that girl run" the attention weights originating from the word "that" are calculated by the Q and K sub-networks of a single "attention head" in the illustration below. As a result, the largest soft weight (or attention) is given to the word "girl".
The sentence is split into three paths (left), which merge at the end as the context vector (right). The word embedding size is 300 and the neuron count is 100 in each sub-network of the attention head.
- The capital letter X denotes a matrix sized 4 × 300, consisting of the embeddings of all four words.
- The small underlined letter x denotes the embedding vector (sized 300) of the word "that".
- The attention head includes three (vertically arranged in the illustration) sub-networks, each having 100 neurons and a weight matrix sized 300 × 100.
- The asterisk within parentheses "(*)" denotes softmax( qKᵀ / √100 ), i.e. the result before multiplication by the matrix V.
- Rescaling by √100 prevents a high variance in qKᵀ that would allow a single word to excessively dominate the softmax, resulting in attention to only one word, as a discrete hard max would do.
Notation: the commonly written row-wise softmax formula above assumes that vectors are rows, which contradicts the standard mathematical convention of column vectors. More correctly, we should take the transpose of the context vector and use the column-wise softmax, resulting in the more correct form contextᵀ = Vᵀ × softmax( K qᵀ / √100 ), where the softmax is applied to each column.
The query vector is compared (via dot product) with the key vector of each word. This helps the model discover the most relevant word for the query word. In this case "girl" was determined to be the most relevant word for "that". The result (size 4 in this case) is run through the softmax function, producing a vector of size 4 with probabilities summing to 1. Multiplying this against the value matrix effectively amplifies the signal for the most important words in the sentence and diminishes the signal for less important words.[9]
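A minimal NumPy sketch of this single-query calculation, using the sizes from the illustration (4 words, embedding size 300, 100 neurons per sub-network); the random embeddings and weight matrices are placeholders rather than trained values:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the illustration: 4 word embeddings of size 300 and
# small random 300 x 100 weight matrices (scaled like a typical initialization).
X  = rng.normal(size=(4, 300))                   # embeddings of "see", "that", "girl", "run"
Qw = rng.normal(size=(300, 100)) / np.sqrt(300)  # query sub-network weights
Kw = rng.normal(size=(300, 100)) / np.sqrt(300)  # key sub-network weights
Vw = rng.normal(size=(300, 100)) / np.sqrt(300)  # value sub-network weights

x = X[1]                                  # embedding of the query word "that"
q = x @ Qw                                # query vector, size 100
K = X @ Kw                                # key matrix, 4 x 100 (one key per word)
V = X @ Vw                                # value matrix, 4 x 100

scores  = q @ K.T / np.sqrt(100)          # one alignment score per word
weights = np.exp(scores) / np.exp(scores).sum()   # softmax: the "soft" attention weights
context = weights @ V                     # context vector for "that", size 100

print(weights.round(3))                   # sums to 1; the largest entry marks the most relevant word
```

With random weights the attention comes out roughly uniform; the pattern described above, with most weight on "girl", appears only after Qw and Kw have been trained.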
The structure of the input data is captured in the Qw and Kw weights, and the Vw weights express that structure in terms of more meaningful features for the task being trained for. For this reason, the attention head components are called Query (Q), Key (K), and Value (V)—a loose and possibly misleading analogy with relational database systems.
Note that the context vector for "that" does not rely on context vectors for the other words; therefore the context vectors of all words can be calculated using the whole matrix X, which includes all the word embeddings, instead of a single word's embedding vector x in the formula above, thus parallelizing the calculations. Now, the softmax can be interpreted as a matrix softmax acting on separate rows. This is a huge advantage over recurrent networks which must operate sequentially.
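Continuing the toy example above, the parallel form replaces the single query vector q with the full query matrix X Qw, and the softmax is applied row by row:

```python
def softmax_rows(S):
    """Softmax applied independently to each row of S."""
    e = np.exp(S - S.max(axis=1, keepdims=True))     # subtract the row max for numerical stability
    return e / e.sum(axis=1, keepdims=True)

Q = X @ Qw                                           # queries for all four words at once, 4 x 100
contexts = softmax_rows(Q @ K.T / np.sqrt(100)) @ V  # 4 x 100: one context vector per word

# The row for "that" matches the single-word calculation above.
assert np.allclose(contexts[1], context)
```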
To build a machine that translates English to French, an attention unit is grafted to the basic Encoder-Decoder (diagram below). In the simplest case, the attention unit consists of dot products of the recurrent encoder states and does not need training. In practice, the attention unit consists of 3 trained, fully-connected neural network layers called query, key, and value.
Label | Description |
---|---|
100 | Max. sentence length |
300 | Embedding size (word dimension) |
500 | Length of hidden vector |
9k, 10k | Dictionary size of input & output languages respectively. |
x, Y | 9k and 10k 1-hot dictionary vectors. The mapping from the 1-hot vector x to its embedding is implemented as a lookup table rather than a vector multiplication. Y is the 1-hot maximizer of the linear Decoder layer D; that is, it takes the argmax of D's linear layer output. |
x | 300-long word embedding vector. The vectors are usually pre-calculated from other projects such as GloVe or Word2Vec. |
h | 500-long encoder hidden vector. At each point in time, this vector summarizes all the words preceding it. The final h can be viewed as a "sentence" vector, or a thought vector as Hinton calls it. |
s | 500-long decoder hidden state vector. |
E | 500-neuron RNN encoder. 500 outputs. Input count is 800 (300 from the source embedding + 500 from recurrent connections). The encoder feeds directly into the decoder only to initialize it, but not thereafter; hence, that direct connection is shown very faintly. |
D | 2-layer decoder. The recurrent layer has 500 neurons and the fully-connected linear layer has 10k neurons (the size of the target vocabulary).[10] The linear layer alone has 5 million (500 × 10k) weights – ~10 times more weights than the recurrent layer. |
score | 100-long alignment score |
w | 100-long attention weight vector. These are "soft" weights which change during the forward pass, in contrast to "hard" neuronal weights that change during the learning phase. |
A | Attention module – this can be a dot product of recurrent states, or the query-key-value fully-connected layers. The output is a 100-long vector w. |
H | 500×100. 100 hidden vectors h concatenated into a matrix |
c | 500-long context vector = H * w. c is a linear combination of h vectors weighted by w. |
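As a rough NumPy sketch of the simplest (untrained dot-product) attention unit A described above, using the legend's hidden size of 500; the source length of 8 and the random vectors are placeholders:

```python
import numpy as np

rng = np.random.default_rng(1)
d_h, n_words = 500, 8                        # hidden size and source sentence length (at most 100)

H = rng.normal(size=(d_h, n_words))          # encoder hidden vectors h, one per column
s = rng.normal(size=d_h)                     # current decoder hidden state

score = H.T @ s                              # alignment scores: dot product of s with each column h of H
w = np.exp(score - score.max())              # softmax (max subtracted for numerical stability)
w = w / w.sum()                              # "soft" attention weights, sum to 1
c = H @ w                                    # context vector c = H * w, a weighted combination of the h's

print(w.round(2), c.shape)                   # weights over the 8 source words; c is 500-long
```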
Viewed as a matrix, the attention weights show how the network adjusts its focus according to context.
 | I | love | you |
---|---|---|---|
je | 0.94 | 0.02 | 0.04 |
t' | 0.11 | 0.01 | 0.88 |
aime | 0.03 | 0.95 | 0.02 |
This view of the attention weights addresses the neural network "explainability" problem. Networks that perform verbatim translation without regard to word order would show the highest scores along the (dominant) diagonal of the matrix. The off-diagonal dominance shows that the attention mechanism is more nuanced. On the first pass through the decoder, 94% of the attention weight is on the first English word "I", so the network offers the word "je". On the second pass of the decoder, 88% of the attention weight is on the third English word "you", so it offers "t'". On the last pass, 95% of the attention weight is on the second English word "love", so it offers "aime".
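Read as code, each decoder step selects its dominant source word via the largest weight in its row; a tiny sketch with the values copied from the table above:

```python
import numpy as np

english = ["I", "love", "you"]
french  = ["je", "t'", "aime"]

# Attention weights from the table: one row per decoder step (French token),
# one column per source word (English token). Each row sums to 1.
A = np.array([[0.94, 0.02, 0.04],
              [0.11, 0.01, 0.88],
              [0.03, 0.95, 0.02]])

for fr, row in zip(french, A):
    print(f"{fr!r} attends mostly to {english[row.argmax()]!r}")  # je->I, t'->you, aime->love
```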
Many variants of attention implement soft weights, such as fast weight programmers (in which a "slow" network generates the "fast" weights of another network through outer products), Bahdanau-style (additive) attention, Luong-style (multiplicative) attention, and the highly parallelizable self-attention used in transformers.
For convolutional neural networks, attention mechanisms can be distinguished by the dimension on which they operate, namely: spatial attention,[17] channel attention,[18] or combinations.[19][20]
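As an illustration of channel attention, here is a minimal NumPy sketch in the style of the squeeze-and-excitation pattern; the layer sizes, reduction ratio, and random weights are placeholder assumptions:

```python
import numpy as np

rng = np.random.default_rng(2)
C, height, width = 64, 16, 16                 # channels, height, width of a CNN feature map
fmap = rng.normal(size=(C, height, width))

# Squeeze: one descriptor per channel via global average pooling.
z = fmap.mean(axis=(1, 2))                    # size C

# Excitation: a small two-layer bottleneck yields one "soft" weight per channel.
W1 = rng.normal(size=(C, C // 8))             # reduction layer (ratio 8 chosen arbitrarily)
W2 = rng.normal(size=(C // 8, C))             # expansion layer
a = 1.0 / (1.0 + np.exp(-(np.maximum(z @ W1, 0) @ W2)))   # ReLU then sigmoid, size C

# Re-weight: each channel of the feature map is scaled by its attention weight.
out = fmap * a[:, None, None]
```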
These variants recombine the encoder-side inputs to redistribute those effects to each target output. Often, a correlation-style matrix of dot products provides the re-weighting coefficients.
Diagrams of five attention variants (left to right): 1. encoder-decoder dot product, 2. encoder-decoder QKV, 3. encoder-only dot product, 4. encoder-only QKV, 5. Pytorch tutorial.
Label | Description |
---|---|
Variables X, H, S, T | Upper case variables represent the entire sentence, and not just the current word. For example, H is a matrix of the encoder hidden state—one word per column. |
S, T | S, decoder hidden state; T, target word embedding. In the Pytorch Tutorial variant training phase, T alternates between 2 sources depending on the level of teacher forcing used. T could be the embedding of the network's output word; i.e. embedding(argmax(FC output)). Alternatively with teacher forcing, T could be the embedding of the known correct word which can occur with a constant forcing probability, say 1/2. |
X, H | H, encoder hidden state; X, input word embeddings. |
W | Attention coefficients |
Qw, Kw, Vw, FC | Weight matrices for query, key, and value respectively. FC is a fully-connected weight matrix. |
⊕, ⊗ | ⊕, vector concatenation; ⊗, matrix multiplication. |
corr | Column-wise softmax of the matrix of all combinations of dot products. The dot products are x_i · x_j in variant 3, h_i · s_j in variant 1, column_i(Kw H) · column_j(Qw S) in variant 2, and column_i(Kw X) · column_j(Qw X) in variant 4. Variant 5 uses a fully-connected layer to determine the coefficients. If the variant is QKV, the dot products are normalized by √d, where d is the height of the QKV matrices. |
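For concreteness, a NumPy sketch of the corr step for variant 4 (encoder-only QKV), under the column-vector convention used in this table; the sizes and random weights are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(3)
d_e, d_k, n = 300, 100, 4                         # embedding size, QKV height d, sentence length

X  = rng.normal(size=(d_e, n))                    # input word embeddings, one column per word
Qw = rng.normal(size=(d_k, d_e)) / np.sqrt(d_e)   # query weights (small random init)
Kw = rng.normal(size=(d_k, d_e)) / np.sqrt(d_e)   # key weights

def softmax_cols(S):
    """Softmax applied independently to each column of S."""
    e = np.exp(S - S.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

# corr[i, j] = column_i(Kw X) . column_j(Qw X), normalized by sqrt(d_k).
corr = softmax_cols((Kw @ X).T @ (Qw @ X) / np.sqrt(d_k))   # n x n; every column sums to 1
```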