
Commit dc58ac5

Merge pull request #64 from graphcore-research/migrate-almost-scaled-blog
Migrate blog post 'almost scaled dot product attention' to the graphcore blog
2 parents fa7bde6 + b9ea20d commit dc58ac5

File tree: 6 files changed, +2 -98 lines

docs/blog.rst

Lines changed: 0 additions & 14 deletions
This file was deleted.

docs/index.rst

Lines changed: 0 additions & 1 deletion
@@ -49,5 +49,4 @@ instructions in our :doc:`developer guide <development>`.
   User guide <user_guide>
   Developer guide <development>
   Limitations <limitations>
-  Blog <blog>
   API reference <api_reference>
Lines changed: 2 additions & 83 deletions
@@ -1,86 +1,5 @@
# Almost-scaled dot-product attention

TL;DR: _Scaled dot product attention isn't properly scaled, and that's a good thing!_
+ **This post [has moved](https://graphcore-research.github.io/posts/almost_scaled/)**.

Notebook: _[almost-scaled dot-product attention](https://github.com/graphcore-research/unit-scaling/tree/main/analysis/almost_scaled_dot_product_attention/almost_scaled_dot_product_attention.ipynb)_

---

Transformers seem to be all you need, but we don't fully understand why they work so well. While working on [unit scaling](https://arxiv.org/abs/2303.11257), we noticed something surprising about attention, the heart of the transformer architecture, and how its outputs are scaled.

Many deep learning modules are designed and initialised to roughly preserve variance in the forward and/or backward (gradient) passes. This is a useful property, as the behaviour of many modules depends on the scale of their inputs (e.g. saturating nonlinearities). Dot product attention explicitly includes a <span style="color: #008000">scaling factor</span> for this purpose, ensuring that the variance going into the softmax is stable:

```{math}
A^{\prime} &= Q K^T \cdot \color{green}{d_{head}^{-1/2}}

Z &= \mathrm{Softmax}(A^{\prime})\, V
```
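
As a minimal sketch of this first scaling factor (our own illustration rather than code from the notebook; unit-variance {math}`Q` and {math}`K` and the shapes below are assumptions), the {math}`d_{head}^{-1/2}` factor brings the pre-softmax logits back to roughly unit standard deviation:

```python
import torch

torch.manual_seed(0)
d_seq, d_head = 256, 64

# Unit-variance queries and keys, as assumed at initialisation.
q = torch.randn(d_seq, d_head)
k = torch.randn(d_seq, d_head)

raw = q @ k.T                  # entries have std ~ sqrt(d_head) = 8
scaled = raw * d_head ** -0.5  # the standard 1/sqrt(d_head) factor
print(f"std(Q K^T)            = {raw.std().item():.2f}")
print(f"std(Q K^T / sqrt(dh)) = {scaled.std().item():.2f}")  # ~1.0
```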

But this is _insufficient for the attention operation as a whole_. We have derived a <span style="color: #fc4349">post-scaling factor</span> for attention to correct this:

```{math}
Z = \mathrm{Softmax}(A^{\prime})\, V \color{red}{\,\cdot\, (d_{seq}/e)^{1/2}}
```

Where {math}`d_{seq}` is the sequence length. For example, this gives the following scaling behaviour:

```{figure} img/attention_scaling.png
---
width: 30em
align: center
alt: "attention scaling: regular attention is underscaled to sigma=0.1 when d_seq=256, but scaled to sigma=1.0 when using a sqrt(d_seq/e) multiplier"
---
```
<p/>
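
The measurement behind this figure can be reproduced with a short sketch (again ours, not the notebook's code; a single head with i.i.d. unit-normal {math}`Q, K, V` is assumed), comparing standard attention with the fully scaled version:

```python
import math

import torch

torch.manual_seed(0)
d_seq, d_head = 256, 64
q, k, v = (torch.randn(d_seq, d_head) for _ in range(3))

# Standard scaled dot-product attention.
scores = torch.softmax(q @ k.T * d_head ** -0.5, dim=-1)
z = scores @ v

# "Fully scaled" attention: also multiply the output by sqrt(d_seq / e).
z_full = z * math.sqrt(d_seq / math.e)

print(f"std(standard)     = {z.std().item():.2f}")       # ~0.1 for d_seq = 256
print(f"std(fully scaled) = {z_full.std().item():.2f}")  # ~1.0
```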

In this post, we'll look at the variance-scaling behaviour of attention, and explain this scaling factor, before seeing that it makes training dynamics _worse_, not better. The post is a condensed summary of our [almost-scaled dot-product attention notebook](https://github.com/graphcore-research/unit-scaling/tree/main/analysis/almost_scaled_dot_product_attention/almost_scaled_dot_product_attention.ipynb).

## Where does {math}`(d_{seq}/e)^{1/2}` come from?

Attention contains the expression {math}`Z=\mathrm{Softmax}(A^{\prime})V`. If we modify this slightly to introduce a temperature {math}`t`, {math}`Z=\mathrm{Softmax}(A^{\prime}/t)V`, we can think about three cases (assuming {math}`V \sim N(0, 1)`):

- {math}`t\to \infty`: the scale of {math}`Z` is {math}`d_{seq}^{-1/2}` — the softmax output is flat, with all values {math}`= d_{seq}^{-1}`, followed by a sum over {math}`d_{seq}` uncorrelated values which scales up by {math}`d_{seq}^{1/2}`
- {math}`t\to 0`: the scale of {math}`Z` is {math}`1` and the output is a single unit spike — attention selects a single element of {math}`V`
- {math}`t \gt 1/2`: the scale of {math}`Z` is {math}`(e^{t^{-2}}/d_{seq})^{1/2}` and, with some assumptions, the output follows a log-normal distribution — we explain this further in the [companion notebook](https://github.com/graphcore-research/unit-scaling/tree/main/analysis/almost_scaled_dot_product_attention/almost_scaled_dot_product_attention.ipynb)

```{figure} img/softmax_temperature.png
---
align: center
width: 30em
alt: "effect of softmax temperature: flat when temperature is infinite, a spike when temperature is zero and a bumpy curve when temperature is one"
---
```
<p/>

We find that the log-normal scaling rule works well for temperature near 1, so we propose multiplying by its inverse, i.e. scaling the attention output by {math}`(d_{seq}/e)^{1/2}`.
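
A rough numerical check of this rule (our own sketch; i.i.d. unit-normal logits and values are an assumption, and the fit is closest for temperatures near 1):

```python
import math

import torch

torch.manual_seed(0)
d_seq, d_hidden = 256, 512
a = torch.randn(d_seq, d_seq)     # stand-in for unit-variance logits A'
v = torch.randn(d_seq, d_hidden)  # unit-variance values

for t in (1.0, 1.5, 2.0):
    z = torch.softmax(a / t, dim=-1) @ v
    predicted = math.sqrt(math.exp(t ** -2) / d_seq)
    print(f"t = {t}: measured std {z.std().item():.3f}, "
          f"predicted (e^(1/t^2) / d_seq)^0.5 = {predicted:.3f}")
```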

## Does it work? ...No!

We tested this change, introducing "fully scaled attention" into a full transformer model—a small autoregressive character language model trained on Shakespeare. This is what we saw from a learning rate sweep:

```{figure} img/scaled_attention_lr_sweep.png
---
align: center
width: 25em
alt: "learning rate sweep for baseline (standard attention) and fully scaled attention. Fully scaled attention behaves worse than the baseline (final training loss 1.2 for baseline, 1.4 for fully scaled)"
---
```
<p/>

This is most unfortunate. It seems that under-scaled tensors coming out of the attention block are important and helpful for transformer training dynamics. It isn't just tiny Shakespeare models—we've also seen this effect when training BERT. We don't yet have an explanation for this difference, but find it intriguing that such a (presumed) accident of under-scaling turns out to be helpful for training dynamics!

Unit scaling has a solution for this, allowing unit-scaled tensors while retaining the original training dynamics. The bad training behaviour must come from scale-dependent operations, in particular when attention's residual output is added to the skip connection. So we found that we can reproduce the same dynamics as the original model by applying a relative weight to the residual vs skip connections.
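
As a hedged sketch of that last idea (our illustration, not the unit-scaling library's API; the helper name, the fixed weight `tau = 0.1` and the shapes are assumptions), a relative weight on the residual branch reproduces the effect of the under-scaled baseline while keeping unit-scaled tensors:

```python
import torch

def weighted_residual_add(skip: torch.Tensor, residual: torch.Tensor,
                          tau: float = 0.1) -> torch.Tensor:
    """Add a residual branch to the skip path with a relative weight.

    Down-weighting a unit-scaled residual (tau < 1) mimics the under-scaled
    attention output of the baseline model, roughly 0.1 for d_seq = 256.
    """
    return skip + tau * residual

# Unit-scaled attention output joining the trunk of the network.
skip = torch.randn(2, 256, 512)      # hidden states on the skip path
residual = torch.randn(2, 256, 512)  # unit-scaled attention output
out = weighted_residual_add(skip, residual)
print(out.std().item())  # ~1.005: the skip path dominates, as in the baseline
```

The real unit-scaling implementation addresses this more carefully (including gradient scales); the snippet only illustrates the shape of the intervention.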

## Conclusion

It is helpful to think through the scales of tensors in deep learning models. Indeed, careful reasoning about scale is the core principle underpinning unit scaling (which also considers the scale of gradients, not just activations).

In the above example, we saw how to "fix" attention's scaling behaviour, multiplying the outputs by {math}`(d_{seq}/e)^{1/2}` so that they are unit-variance. However, we also saw that this change can make training dynamics worse, not better. Why this happens is, as far as we know, an open question.

If you're interested to find out more, check out our [accompanying notebook](https://github.com/graphcore-research/unit-scaling/tree/main/analysis/almost_scaled_dot_product_attention/almost_scaled_dot_product_attention.ipynb) and [unit scaling](https://arxiv.org/abs/2303.11257) paper.

---

With thanks to Charlie Blake for help & feedback.

— Douglas Orr ([[email protected]](mailto:[email protected])), October 2023

+ Note that the approach and equations described in this post are legacy and do not reflect the current implementation of u-μP. Please see the code for a definitive reference.
Binary files removed (not shown): -90.8 KB, -25.9 KB, -23.3 KB
