
Commit 215ba63

Merge pull request #83 from graphcore-research/emb_lr_notebook
add emb lr analysis notebook
2 parents b0696ec + 81164f5 commit 215ba63

File tree

1 file changed: +364 -0 lines changed


analysis/emb_lr_analysis.ipynb

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis of the effect of the embedding LR update on the subsequent matmul\n",
    "\n",
    "I wanted to write this out in a notebook to make sure I understood the way in which the embedding update affects the subsequent matmul.\n",
    "\n",
    "No revelations unfortunately - it still seems as though our rule can't be justified this way (it is \"unnatural\"!). Under the \"no-alignment\" assumption the standard embedding LR breaks, but unfortunately our fix does nothing to help. Oh well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import randn\n",
    "from typing import Iterable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rms(*xs: torch.Tensor) -> torch.Tensor | Iterable[torch.Tensor]:\n",
    "    if len(xs) == 1:\n",
    "        return xs[0].pow(2).mean().sqrt()\n",
    "    return tuple(rms(x) for x in xs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Toggle `full_alignment` and `umup_lr_rule` to see the effect. muP scaling is used by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 2**11\n",
    "full_alignment = True\n",
    "umup_lr_rule = False\n",
    "\n",
    "w_lr = d ** -(1 if full_alignment else 0.5)\n",
    "e_lr = d ** -(0.5 if umup_lr_rule else 0)"
   ]
  },
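
For reference, a small self-contained sketch of what the two flags do to the learning rates, tabulating the same formulas as the cell above over all four combinations (the loop and variable names here are illustrative, not from the notebook):

import itertools

d = 2**11
for full_alignment, umup_lr_rule in itertools.product((True, False), repeat=2):
    w_lr = d ** -(1 if full_alignment else 0.5)
    e_lr = d ** -(0.5 if umup_lr_rule else 0)
    print(f"full_alignment={full_alignment}, umup_lr_rule={umup_lr_rule}: "
          f"w_lr={w_lr:.2e}, e_lr={e_lr:.2e}")

So full alignment shrinks w_lr from d**-0.5 to d**-1, and the u-muP rule shrinks e_lr from 1 to d**-0.5.
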
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model & update\n",
    "\n",
    "Everything can be described in terms of these three tensors (a single embedding vector, weight matrix and a gradient vector). Note that I assume the gradient is unit-scale, and then just use the Adam LR rules but under an SGD-like update (I appreciate this is a bit odd, but it's simple and the maths should work out)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.9984), tensor(0.0221), tensor(0.9882))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "e1 = randn(d, 1)\n",
    "W1 = randn(d + 1, d) * d**-0.5\n",
    "g = randn(d + 1, 1)\n",
    "rms(\n",
    "    e1, W1, g\n",
    ") # all \"well-scaled\", except the weight which is 1/sqrt(d) (this isn't unit scaling!)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we just run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9953)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1 = W1 @ e1\n",
    "rms(x1) # well-scaled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor(0.9977), tensor(0.0005)), 0.00048828125)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u_e = W1.T @ g * e_lr\n",
    "u_W = g @ e1.T * w_lr\n",
    "(\n",
    "    rms(u_e, u_W),\n",
    "    1 / d,\n",
    ") # the weight update is under-scaled (to be expected I think), though as a rank-1 matrix it has a much higher (O(1)) spectral norm! This means its effect doesn't \"go to zero\" in inf. width, though the rms does."
   ]
  },
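
A quick check of the rank-1 claim in the comment above, reusing u_W and rms from the cells above (torch.linalg.matrix_norm with ord=2 gives the spectral norm):

spectral = torch.linalg.matrix_norm(u_W, ord=2)  # largest singular value of the rank-1 update
print(spectral, rms(u_W), 1 / d)                 # the spectral norm stays O(1) while the rms shrinks like 1/d
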
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.9998), tensor(0.0221))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "e2 = e1 + u_e\n",
    "e2_std = e2.std()\n",
    "e2 /= e2_std # Why is `/ e2.std()` allowed/justified? Normally we'd have a much smaller weight update (scaled down by a small LR constant), and then the original weight would be decayed a bit, keeping this at about rms=1. This re-scaling does something similar, though it allows us to see the effect of the weight update scaling more clearly.\n",
    "W2 = W1 + u_W\n",
    "rms(\n",
    "    e2, W2\n",
    ") # Update is well-scaled. Weight has barely changed from its 1/sqrt(d) starting point"
   ]
  },
173+
{
174+
"cell_type": "code",
175+
"execution_count": 8,
176+
"metadata": {},
177+
"outputs": [
178+
{
179+
"data": {
180+
"text/plain": [
181+
"tensor(1.7412)"
182+
]
183+
},
184+
"execution_count": 8,
185+
"metadata": {},
186+
"output_type": "execute_result"
187+
}
188+
],
189+
"source": [
190+
"x2 = W2 @ e2\n",
191+
"rms(x2) # ~well-scaled. Certainly doesn't scale with a significant power of d"
192+
]
193+
},
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis\n",
    "\n",
    "Now we break this down into its constituent terms.\n",
    "\n",
    "First, we check that the terms combine to give the original:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(x2, (W1 + u_W) @ (e1 + u_e) / e2_std, atol=1e-6)\n",
    "torch.allclose(x2, (W1 + g @ e1.T * w_lr) @ (e1 + W1.T @ g * e_lr) / e2_std, atol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# t1 = W1 @ e1 (== x1)\n",
    "t2 = W1 @ W1.T @ g * e_lr\n",
    "t3 = g @ e1.T * w_lr @ e1\n",
    "t4 = g @ e1.T * w_lr @ W1.T @ g * e_lr\n",
    "torch.allclose(x2, (x1 + t2 + t3 + t4) / e2_std, atol=1e-5)"
   ]
  },
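
To make the correspondence explicit, t2, t3 and t4 are just the cross terms from expanding the rescaled product; a minimal sketch reusing u_W and u_e from the earlier cells:

# W2 @ e2 = (W1 + u_W) @ (e1 + u_e) / e2_std
#         = (W1 @ e1 + W1 @ u_e + u_W @ e1 + u_W @ u_e) / e2_std
#         = (x1      + t2       + t3       + t4       ) / e2_std
torch.allclose(x2, (W1 @ e1 + W1 @ u_e + u_W @ e1 + u_W @ u_e) / e2_std, atol=1e-5)
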
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weight @ emb_update (t2)\n",
    "\n",
    "This is well-scaled under the original emb lr rule, but not under our lr rule - which isn't a great sign for our approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rms(W1, g), e_lr=((tensor(0.0221), tensor(0.9882)), 1)\n",
      "rms(W1 @ W1.T)=tensor(0.0312)\n",
      "rms(W1.T @ g)=tensor(0.9977)\n",
      "rms(W1 @ W1.T @ g * e_lr / e2_std)=tensor(0.9857)\n"
     ]
    }
   ],
   "source": [
    "print(f\"{rms(W1, g), e_lr=}\")\n",
    "print(f\"{rms(W1 @ W1.T)=}\")\n",
    "print(f\"{rms(W1.T @ g)=}\")\n",
    "print(f\"{rms(W1 @ W1.T @ g * e_lr / e2_std)=}\")"
   ]
  },
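
A rough sketch of the scaling behind this claim, reusing W1 and g from above (this reasoning is not spelled out in the notebook): W1.T @ g is unit-scale, and applying W1 (entries of scale 1/sqrt(d)) to a unit-scale vector keeps the result O(1), so the size of t2 is set almost entirely by e_lr.

v = W1.T @ g                  # unit-scale, as printed above
print(rms(v), rms(W1 @ v))    # applying W1 to a unit-scale vector keeps the rms O(1)
print(rms(W1 @ v) * d**-0.5)  # so e_lr = d**-0.5 (our rule) leaves t2 under-scaled
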
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weight_update @ emb (t3)\n",
    "\n",
    "This is well-scaled under the original emb lr rule and our rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rms(g, e1)=(tensor(0.9882), tensor(0.9984))\n",
      "rms(g @ e1.T)=tensor(0.9866)\n",
      "rms(e1.T @ e1 * w_lr)=tensor(0.9968)\n",
      "rms(g @ e1.T * w_lr @ e1)=tensor(0.9850)\n"
     ]
    }
   ],
   "source": [
    "print(f\"{rms(g, e1)=}\")\n",
    "print(f\"{rms(g @ e1.T)=}\")\n",
    "print(f\"{rms(e1.T @ e1 * w_lr)=}\")\n",
    "print(f\"{rms(g @ e1.T * w_lr @ e1)=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weight_update @ emb_update (t4)\n",
    "\n",
    "This vanishes with width under the original emb lr and our rule. Probably a good thing?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rms(g @ e1.T @ W1.T @ g)=tensor(46.5558)\n",
      "rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=tensor(0.0227)\n"
     ]
    }
   ],
   "source": [
    "print(f\"{rms(g @ e1.T @ W1.T @ g)=}\")\n",
    "print(f\"{rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=}\")"
   ]
  }
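
To make "vanishes with width" concrete: t4 equals g times the scalar e1.T @ W1.T @ g (which grows like sqrt(d)) times w_lr * e_lr, so with w_lr = 1/d it shrinks like 1/sqrt(d), and faster still under our e_lr rule. A quick check reusing the tensors above:

print(abs((e1.T @ W1.T @ g).item()), d**0.5)  # the inner scalar grows like sqrt(d)
print(rms(t4), d**-0.5)                       # so t4's rms shrinks like 1/sqrt(d) once w_lr = 1/d is applied
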
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
