
Commit 72e1bd2

Merge pull request #72 from graphcore-research/remove-constraint-on-container-modules
Remove the constraint argument in the container modules uu.{MLP, MHSA}
2 parents 77e74cf + 1500b13 commit 72e1bd2

File tree: 4 files changed, +22 −41 lines

unit_scaling/_modules.py

Lines changed: 12 additions & 37 deletions
@@ -19,7 +19,6 @@
     format_docstring,
     inherit_docstring,
     mult_docstring,
-    variadic_constraint_docstring,
 )
 from .parameter import MupType, Parameter, has_parameter_data
 
@@ -428,47 +427,40 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         )
 
 
-@format_docstring(binary_constraint_docstring)
 class MLP(nn.Module):
     """A **unit-scaled** implementation of an MLP layer using SwiGLU.
 
     Args:
         hidden_size (int): the hidden dimension size of the input.
         expansion_factor (int): the factor by which the MLP's intermediate size
             increases relative to `hidden_size`.
-        {0}
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        expansion_factor: int = 4,
-        constraint: Optional[str] = "to_output_scale",
-    ) -> None:
+    def __init__(self, hidden_size: int, expansion_factor: int = 4) -> None:
         super().__init__()
         intermediate_size = hidden_size * expansion_factor
-        self.linear_1 = Linear(hidden_size, intermediate_size, constraint=constraint)
-        self.linear_gate = Linear(hidden_size, intermediate_size, constraint=constraint)
-        self.linear_2 = Linear(intermediate_size, hidden_size, constraint=constraint)
+        # Note: constraint=None is safe here, because we know that the forward and
+        # backward constraints are mirrored between {linear_1, linear_gate} and
+        # linear_2.
+        self.linear_1 = Linear(hidden_size, intermediate_size, constraint=None)
+        self.linear_gate = Linear(hidden_size, intermediate_size, constraint=None)
+        self.linear_2 = Linear(intermediate_size, hidden_size, constraint=None)
 
     def forward(self, input: Tensor) -> Tensor:
         z = U.silu_glu(self.linear_1(input), self.linear_gate(input))
         return self.linear_2(z) # type:ignore[no-any-return]
 
 
-@format_docstring(mult_docstring(), variadic_constraint_docstring)
+@format_docstring(mult_docstring())
 class MHSA(nn.Module):
     """A **unit-scaled** implementation of a multi-head self-attention layer.
 
-    Warning: using `constraint=None` here will likely give incorrect gradients.
-
     Args:
         hidden_size (int): the hidden dimension size of the input.
         heads (int): the number of attention heads.
         is_causal (bool): causal masking (for non-padded sequences).
         dropout_p (float, optional): the probability of the post-softmax dropout.
         {0}
-        {1}
     """
 
     def __init__(
@@ -478,16 +470,14 @@ def __init__(
         is_causal: bool,
         dropout_p: float = 0.0,
         mult: float = 1.0,
-        constraint: Optional[str] = "to_output_scale",
     ) -> None:
         super().__init__()
         self.heads = heads
         self.dropout_p = dropout_p
         self.is_causal = is_causal
         self.mult = mult
-        self.linear_qkv = Linear(hidden_size, 3 * hidden_size, constraint=constraint)
-        self.linear_o = Linear(hidden_size, hidden_size, constraint=constraint)
-        self.constraint = constraint
+        self.linear_qkv = Linear(hidden_size, 3 * hidden_size)
+        self.linear_o = Linear(hidden_size, hidden_size)
 
     def forward(self, input: Tensor) -> Tensor:
         q_k_v = self.linear_qkv(input)
@@ -499,7 +489,6 @@ def forward(self, input: Tensor) -> Tensor:
         return self.linear_o(qkv) # type: ignore
 
 
-@format_docstring(variadic_constraint_docstring)
 class TransformerLayer(nn.Module):
     """A **unit-scaled** implementation of a PreNorm
     (see https://arxiv.org/abs/2002.04745) transformer layer.
@@ -516,7 +505,6 @@ class TransformerLayer(nn.Module):
         is_causal (bool): causal masking (for non-padded sequences).
         dropout_p (float, optional): the probability of residual and post-softmax
             dropout.
-        {0}
     """
 
     def __init__(
@@ -527,22 +515,15 @@ def __init__(
         mlp_tau: float,
         is_causal: bool,
         dropout_p: float = 0.0,
-        constraint: Optional[str] = "to_output_scale",
     ) -> None:
         super().__init__()
         self.dropout_p = dropout_p
         self.mhsa_tau = mhsa_tau
         self.mlp_tau = mlp_tau
         self.mhsa_norm = RMSNorm(hidden_size)
-        self.mhsa = MHSA(
-            hidden_size,
-            heads,
-            is_causal=is_causal,
-            dropout_p=dropout_p,
-            constraint=constraint,
-        )
+        self.mhsa = MHSA(hidden_size, heads, is_causal=is_causal, dropout_p=dropout_p)
         self.mlp_norm = RMSNorm(hidden_size)
-        self.mlp = MLP(hidden_size, constraint=constraint)
+        self.mlp = MLP(hidden_size)
 
     def forward(self, input: Tensor) -> Tensor:
         input, skip = U.residual_split(input, tau=self.mhsa_tau)
@@ -627,16 +608,13 @@ def __init__(
         )
 
 
-@format_docstring(variadic_constraint_docstring)
 class TransformerDecoder(nn.Sequential): # pragma: no cover
     """A **unit-scaled** implementation of a decoder-type transformer.
 
     Note: this class is currently just for demonstrating scaling and lacks key
     functionality (for example masking, positional embeddings, usage for
     inference).
 
-    Warning: using `constraint=None` here will likely give incorrect gradients.
-
     Args:
         hidden_size (int): the hidden dimension size of the input.
         vocab_size (int): the number of tokens in the vocabulary.
@@ -648,7 +626,6 @@ class TransformerDecoder(nn.Sequential): # pragma: no cover
             controlling residual weights in the transformer trunk; see
             :func:`unit_scaling.core.functional.transformer_residual_scaling_rule`
             (default).
-        {0}
     """
 
     def __init__(
@@ -659,7 +636,6 @@ def __init__(
        heads: int,
        dropout_p: float = 0.0,
        residual_scaling: ResidualScalingFn = transformer_residual_scaling_rule(),
-       constraint: Optional[str] = "to_output_scale",
     ) -> None:
        super().__init__()
        self.embedding = Embedding(vocab_size, hidden_size)
@@ -670,7 +646,6 @@ def __init__(
            is_causal=True,
            dropout_p=dropout_p,
            residual_scaling=residual_scaling,
-           constraint=constraint,
        )
        self.final_norm = RMSNorm(hidden_size)
        self.projection = LinearReadout(hidden_size, vocab_size)
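
For reference, a minimal sketch of constructing the updated containers after this change. It assumes the package is imported as uu (as in the PR title) and that MLP and MHSA are exported at the top level; the sizes are illustrative.

import torch
import unit_scaling as uu

# After this PR the container modules take no `constraint` argument; their
# internal Linear layers are created with constraint=None.
mlp = uu.MLP(hidden_size=256, expansion_factor=4)
mhsa = uu.MHSA(hidden_size=256, heads=8, is_causal=True, dropout_p=0.1)

x = torch.randn(8, 128, 256)  # (batch, seq, hidden)
print(mlp(x).shape, mhsa(x).shape)  # both torch.Size([8, 128, 256])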

unit_scaling/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
 
-__version__ = "0.2"
+__version__ = "0.3"

unit_scaling/tests/test_modules.py

Lines changed: 6 additions & 0 deletions
@@ -201,6 +201,12 @@ def test_mlp() -> None:
 
     assert float(output.std()) == pytest.approx(1, abs=0.2)
 
+    assert_unit_scaled(
+        model.linear_1.weight.grad,
+        model.linear_gate.weight.grad,
+        model.linear_2.weight.grad,
+    )
+
     assert_not_unit_scaled(
         model.linear_1.weight, model.linear_gate.weight, model.linear_2.weight
     )
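
The property checked by the new assertion can also be seen in a stand-alone sketch (hypothetical sizes, not the repository's test code; it assumes unit_scaling imports as uu and exposes MLP): seeding the backward pass with a unit-normal gradient should leave the MLP's weight gradients with a standard deviation close to 1.

import torch
import unit_scaling as uu

torch.manual_seed(0)
model = uu.MLP(hidden_size=1024)
x = torch.randn(64, 1024)

out = model(x)
out.backward(torch.randn_like(out))  # unit-normal incoming gradient

for name in ("linear_1", "linear_gate", "linear_2"):
    grad = getattr(model, name).weight.grad
    print(name, float(grad.std()))  # expect values close to 1.0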

unit_scaling/tests/test_utils.py

Lines changed: 3 additions & 3 deletions
@@ -29,12 +29,12 @@ def test_analyse_mlp() -> None:
 def forward(self, input : Tensor) -> Tensor:
     input_1 = input; (-> 1.0, <- 1.44)
     linear_1_weight = self.linear_1.weight; (-> 1.0, <- 0.503)
-    linear = U.linear(input_1, linear_1_weight, None, 'to_output_scale'); (-> 1.0, <- 0.502)
+    linear = U.linear(input_1, linear_1_weight, None, None); (-> 1.0, <- 0.502)
     linear_gate_weight = self.linear_gate.weight; (-> 1.0, <- 0.519)
-    linear_1 = U.linear(input_1, linear_gate_weight, None, 'to_output_scale'); (-> 1.0, <- 0.518)
+    linear_1 = U.linear(input_1, linear_gate_weight, None, None); (-> 1.0, <- 0.518)
     silu_glu = U.silu_glu(linear, linear_1); (-> 1.0, <- 0.5)
     linear_2_weight = self.linear_2.weight; (-> 1.0, <- 1.0)
-    linear_2 = U.linear(silu_glu, linear_2_weight, None, 'to_output_scale'); (-> 1.0, <- 1.0)
+    linear_2 = U.linear(silu_glu, linear_2_weight, None, None); (-> 1.0, <- 1.0)
     return linear_2
 """.strip() # noqa: E501
 