In artificial neural networks, attention is a technique that is meant to mimic cognitive attention. It enhances some parts of the input data while diminishing other parts, so that the network devotes more focus to the small, but important, parts of the data. Which parts of the data matter more depends on the context, and this dependence is learned by gradient descent.
Attention-like mechanisms were introduced in the 1990s under names like multiplicative modules, sigma pi units, and hyper-networks.^{[1]} Their flexibility comes from their role as "soft weights" that can change during runtime, in contrast to standard weights that must remain fixed at runtime. Uses of attention include memory in neural Turing machines, reasoning tasks in differentiable neural computers,^{[2]} language processing in transformers and LSTMs, and multi-sensory data processing (sound, images, video, and text) in perceivers.^{[3]}^{[4]}^{[5]}^{[6]} The Variants section below lists many schemes for implementing the soft-weight mechanism.
Given a sequence of tokens labeled by the index $i$, a neural network computes a soft weight $w_i$ for each token $i$ with the property that $w_i$ is non-negative and $\sum_i w_i = 1$. Each token is assigned a value vector $v_i$ which is computed from the word embedding of the $i$th token. The weighted average $\sum_i w_i v_i$ is the output of the attention mechanism.
The query-key mechanism computes the soft weights. From the word embedding of each token, it computes its corresponding query vector $q_i$ and key vector $k_i$. The weights are obtained by taking the softmax function of the dot product $q_i \cdot k_j$, where $i$ represents the current token and $j$ represents the token that is being attended to.
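The computation described above can be sketched for a single current token in plain Python. This is a minimal illustration, not a reference implementation: the helper names `softmax`, `dot`, and `attend` and the toy 2-dimensional vectors are assumptions made for the example, and in a real network the query, key, and value vectors would be computed from word embeddings by trained layers.

```python
import math

def softmax(scores):
    """Turn arbitrary scores into non-negative weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def attend(query, keys, values):
    """Soft weights = softmax of query-key dot products;
    output = weighted average of the value vectors."""
    weights = softmax([dot(query, k) for k in keys])
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output

# Hypothetical toy vectors for three tokens (made-up numbers).
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]  # query vector of the current token

weights, output = attend(query, keys, values)
```

Here the first and third keys have the same dot product with the query, so they receive equal weight, while the second token is largely ignored.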
In some architectures, there are multiple "heads" of attention (termed "multi-head attention"), each operating independently with its own queries, keys, and values.
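A rough sketch of the multi-head idea follows. It assumes the per-head queries, keys, and values have already been computed, and it assumes the per-head outputs are concatenated, which is the usual choice in transformers but is not stated in the text above. The function names `attend` and `multi_head` are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head attention for one query vector."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))]

def multi_head(queries, keys_per_head, values_per_head):
    """Run every head independently on its own query, keys, and values,
    then concatenate the per-head outputs (assumed combination rule)."""
    combined = []
    for q, ks, vs in zip(queries, keys_per_head, values_per_head):
        combined.extend(attend(q, ks, vs))
    return combined

# Two hypothetical heads attending over two tokens each.
queries = [[1.0, 0.0], [0.0, 0.0]]
keys_per_head = [[[1.0, 0.0], [0.0, 1.0]],
                 [[1.0, 0.0], [0.0, 1.0]]]
values_per_head = [[[1.0, 2.0], [3.0, 4.0]],
                   [[10.0, 20.0], [30.0, 40.0]]]

out = multi_head(queries, keys_per_head, values_per_head)
```

The second head's query is all zeros, so its weights are uniform and its output is the plain average of its two value vectors; the first head is unaffected by this, which illustrates the independence of the heads.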
To build a machine that translates English to French, one takes the basic encoder-decoder and grafts an attention unit onto it (diagram below). In the simplest case, the attention unit consists of dot products of the recurrent encoder states and does not need training. In practice, the attention unit consists of three fully-connected neural network layers, called query, key, and value, that need to be trained. See the Variants section below.
Label | Description |
---|---|
100 | Max. sentence length |
300 | Embedding size (word dimension) |
500 | Length of hidden vector |
9k, 10k | Dictionary size of input & output languages respectively. |
x, Y | 9k and 10k 1-hot dictionary vectors. The mapping from the 1-hot x to its embedding x (next row) is implemented as a lookup table rather than a vector multiplication. Y is the 1-hot maximizer of the linear Decoder layer D; that is, it takes the argmax of D's linear layer output. |
x | 300-long word embedding vector. The vectors are usually pre-calculated from other projects such as GloVe or Word2Vec. |
h | 500-long encoder hidden vector. At each point in time, this vector summarizes all the preceding words before it. The final h can be viewed as a "sentence" vector, or a thought vector as Hinton calls it. |
s | 500-long decoder hidden state vector. |
E | 500-neuron RNN encoder. 500 outputs. Input count is 800: 300 from the source embedding plus 500 from recurrent connections. The encoder feeds directly into the decoder only to initialize it, but not thereafter; hence, that direct connection is shown very faintly. |
D | 2-layer decoder. The recurrent layer has 500 neurons and the fully-connected linear layer has 10k neurons (the size of the target vocabulary).^{[7]} The linear layer alone has 5 million (500 × 10k) weights, about 10 times more than the recurrent layer. |
score | 100-long alignment score |
w | 100-long attention weight vector. These are "soft" weights, which change during the forward pass, in contrast to "hard" neuronal weights, which change during the learning phase. |
A | Attention module — this can be a dot product of recurrent states, or the query-key-value fully-connected layers. The output is a 100-long vector w. |
H | 500×100. 100 hidden vectors h concatenated into a matrix |
c | 500-long context vector = H * w. c is a linear combination of h vectors weighted by w. |
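The simplest, untrained case from the legend (scores from dot products of recurrent states, then c = H * w) can be sketched as follows. The sketch uses 3 encoder states of length 4 instead of 100 states of length 500 to stay readable; the function name `context_vector` and the numbers are assumptions made for illustration only.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

def context_vector(H_columns, s):
    """H_columns: encoder hidden vectors h, i.e. the columns of H.
    s: current decoder hidden state.
    Returns the attention weights w and the context vector c = H * w."""
    # Alignment score: dot product of each encoder state with the decoder state.
    score = [sum(hd * sd for hd, sd in zip(h, s)) for h in H_columns]
    w = softmax(score)
    # c is a linear combination of the h vectors, weighted by w.
    c = [sum(wj * h[d] for wj, h in zip(w, H_columns)) for d in range(len(s))]
    return w, c

# Hypothetical encoder states (one per source word) and decoder state.
H_columns = [[1.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0, 0.0],
             [0.0, 0.0, 1.0, 0.0]]
s = [0.0, 3.0, 0.0, 0.0]

w, c = context_vector(H_columns, s)
```

With these numbers the decoder state is most similar to the second encoder state, so most of the weight in w, and therefore most of c, comes from that state.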
Viewed as a matrix, the attention weights show how the network adjusts its focus according to context.
 | I | love | you |
---|---|---|---|
je | 0.94 | 0.02 | 0.04 |
t' | 0.11 | 0.01 | 0.88 |
aime | 0.03 | 0.95 | 0.02 |
This view of the attention weights addresses the "explainability" problem that neural networks are criticized for. Networks that perform verbatim translation without regard to word order would have a diagonally dominant matrix if they were analyzable in these terms. The off-diagonal dominance shows that the attention mechanism is more nuanced. On the first pass through the decoder, 94% of the attention weight is on the first English word "I", so the network offers the word "je". On the second pass of the decoder, 88% of the attention weight is on the third English word "you", so it offers "t'". On the last pass, 95% of the attention weight is on the second English word "love", so it offers "aime".
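The reading of the weight matrix described above amounts to picking, for each decoder pass, the source word with the largest weight. A small sketch using the values from the table (the function name `strongest_alignment` is an assumption for this example):

```python
source = ["I", "love", "you"]
target = ["je", "t'", "aime"]

# Attention weights from the table: one row per decoder pass (French word),
# one column per English source word.
weights = [
    [0.94, 0.02, 0.04],
    [0.11, 0.01, 0.88],
    [0.03, 0.95, 0.02],
]

def strongest_alignment(weights, source, target):
    """Map each target word to the source word that receives the most attention."""
    alignment = {}
    for word, row in zip(target, weights):
        j = max(range(len(row)), key=row.__getitem__)  # index of the largest weight
        alignment[word] = source[j]
    return alignment

alignment = strongest_alignment(weights, source, target)
# alignment == {"je": "I", "t'": "you", "aime": "love"}
```

Each row sums to 1, as expected for soft weights, and the off-diagonal maxima in the second and third rows reflect the reordering of "love" and "you" in the French output.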
There are many variants of attention that implement soft weights, including (a) Bahdanau attention,^{[8]} also referred to as additive attention, (b) Luong attention,^{[9]} known as multiplicative attention, which builds on additive attention, and (c) self-attention, introduced in transformers. For convolutional neural networks, the attention mechanisms can also be distinguished by the dimension on which they operate, namely: spatial attention,^{[10]} channel attention,^{[11]} or combinations of both.^{[12]}^{[13]}
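The difference between additive and multiplicative scoring can be sketched as two alternative score functions between a decoder state s and an encoder state h. This is a simplified illustration under stated assumptions: the parameter names `W_s`, `W_h`, `v`, and `W` are hypothetical, the matrices would normally be trained, and the multiplicative form shown is the bilinear form s · (W h).

```python
import math

def matvec(M, x):
    """Matrix-vector product for a matrix given as a list of rows."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def additive_score(s, h, W_s, W_h, v):
    """Additive score: v . tanh(W_s s + W_h h)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_s, s), matvec(W_h, h))]
    return dot(v, hidden)

def multiplicative_score(s, h, W):
    """Multiplicative score: s . (W h)."""
    return dot(s, matvec(W, h))

# Toy inputs; identity matrices stand in for trained weights.
identity = [[1.0, 0.0], [0.0, 1.0]]
s = [1.0, 0.0]
h = [1.0, 0.0]
v = [1.0, 1.0]

add = additive_score(s, h, identity, identity, v)
mul = multiplicative_score(s, h, identity)
```

In either case the scores for all encoder positions would then be passed through a softmax to obtain the soft weights.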
These variants recombine the encoder-side inputs to redistribute those effects to each target output. Often, a correlation-style matrix of dot products provides the re-weighting coefficients (see legend).
Variant diagrams (images not reproduced): 1. encoder-decoder dot product; 2. encoder-decoder QKV; 3. encoder-only dot product; 4. encoder-only QKV; 5. PyTorch tutorial.
Label | Description |
---|---|
Variables X, H, S, T | Upper case variables represent the entire sentence, and not just the current word. For example, H is a matrix of the encoder hidden state—one word per column. |
S, T | S, decoder hidden state; T, target word embedding. In the training phase of the PyTorch tutorial variant, T alternates between two sources depending on the level of teacher forcing used. T could be the embedding of the network's output word, i.e. embedding(argmax(FC output)). Alternatively, with teacher forcing, T could be the embedding of the known correct word, which can occur with a constant forcing probability, say 1/2. |
X, H | H, encoder hidden state; X, input word embeddings. |
W | Attention coefficients |
Qw, Kw, Vw, FC | Weight matrices for query, key, and value respectively. FC is a fully-connected weight matrix. |
⊕, ⊗ | ⊕, vector concatenation; ⊗, matrix multiplication. |
corr | Column-wise softmax of the matrix of all combinations of dot products. The dot products are h_{i} * s_{j} in variant 1, column_{i}(Kw * H) * column_{j}(Qw * S) in variant 2, x_{i} * x_{j} in variant 3, and column_{i}(Kw * X) * column_{j}(Qw * X) in variant 4. Variant 5 uses a fully-connected layer to determine the coefficients. In the QKV variants, the dot products are divided by sqrt(d), where d is the height of the QKV matrices. |
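The corr entry of the legend can be sketched as a column-wise softmax over a matrix of dot products. The example below follows variant 3 (dot products of the input embeddings with themselves) and optionally applies the sqrt(d) scaling described for the QKV variants. The function names, the `scale` flag, and the toy embeddings are assumptions made for illustration; here d is taken to be the length of the vectors being multiplied.

```python
import math

def column_softmax(M):
    """Apply softmax to each column of M so that every column sums to 1."""
    rows, cols = len(M), len(M[0])
    out = [[0.0] * cols for _ in range(rows)]
    for j in range(cols):
        col = [M[i][j] for i in range(rows)]
        m = max(col)  # subtract the max for numerical stability
        exps = [math.exp(c - m) for c in col]
        total = sum(exps)
        for i in range(rows):
            out[i][j] = exps[i] / total
    return out

def corr(keys, queries, scale=False):
    """Entry (i, j) is the softmax over i of keys[i] . queries[j].
    With scale=True the dot products are first divided by sqrt(d)."""
    d = len(keys[0])
    factor = math.sqrt(d) if scale else 1.0
    dots = [[sum(a * b for a, b in zip(k, q)) / factor for q in queries]
            for k in keys]
    return column_softmax(dots)

# Hypothetical word embeddings x_i, one vector per word.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

W = corr(X, X)                     # variant 3: x_i * x_j
W_scaled = corr(X, X, scale=True)  # same dot products divided by sqrt(d)
```

For variants 2 and 4, the same `corr` function would be called with the columns of Kw * H (or Kw * X) as `keys` and the columns of Qw * S (or Qw * X) as `queries`; dividing by sqrt(d) flattens the resulting weights compared with the unscaled case.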