     format_docstring,
     inherit_docstring,
     mult_docstring,
-    variadic_constraint_docstring,
 )
 from .parameter import MupType, Parameter, has_parameter_data

@@ -428,47 +427,40 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         )
 
 
-@format_docstring(binary_constraint_docstring)
 class MLP(nn.Module):
     """A **unit-scaled** implementation of an MLP layer using SwiGLU.
 
     Args:
         hidden_size (int): the hidden dimension size of the input.
         expansion_factor (int): the factor by which the MLP's intermediate size
             increases relative to `hidden_size`.
-        {0}
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        expansion_factor: int = 4,
-        constraint: Optional[str] = "to_output_scale",
-    ) -> None:
+    def __init__(self, hidden_size: int, expansion_factor: int = 4) -> None:
         super().__init__()
         intermediate_size = hidden_size * expansion_factor
-        self.linear_1 = Linear(hidden_size, intermediate_size, constraint=constraint)
-        self.linear_gate = Linear(hidden_size, intermediate_size, constraint=constraint)
-        self.linear_2 = Linear(intermediate_size, hidden_size, constraint=constraint)
+        # Note: constraint=None is safe here, because we know that the forward and
+        # backward constraints are mirrored between {linear_1, linear_gate} and
+        # linear_2.
+        self.linear_1 = Linear(hidden_size, intermediate_size, constraint=None)
+        self.linear_gate = Linear(hidden_size, intermediate_size, constraint=None)
+        self.linear_2 = Linear(intermediate_size, hidden_size, constraint=None)
 
     def forward(self, input: Tensor) -> Tensor:
         z = U.silu_glu(self.linear_1(input), self.linear_gate(input))
         return self.linear_2(z)  # type:ignore[no-any-return]
 
 
-@format_docstring(mult_docstring(), variadic_constraint_docstring)
+@format_docstring(mult_docstring())
 class MHSA(nn.Module):
     """A **unit-scaled** implementation of a multi-head self-attention layer.
 
-    Warning: using `constraint=None` here will likely give incorrect gradients.
-
     Args:
         hidden_size (int): the hidden dimension size of the input.
         heads (int): the number of attention heads.
         is_causal (bool): causal masking (for non-padded sequences).
         dropout_p (float, optional): the probability of the post-softmax dropout.
         {0}
-        {1}
     """
 
     def __init__(
@@ -478,16 +470,14 @@ def __init__(
         is_causal: bool,
         dropout_p: float = 0.0,
         mult: float = 1.0,
-        constraint: Optional[str] = "to_output_scale",
     ) -> None:
         super().__init__()
         self.heads = heads
         self.dropout_p = dropout_p
         self.is_causal = is_causal
         self.mult = mult
-        self.linear_qkv = Linear(hidden_size, 3 * hidden_size, constraint=constraint)
-        self.linear_o = Linear(hidden_size, hidden_size, constraint=constraint)
-        self.constraint = constraint
+        self.linear_qkv = Linear(hidden_size, 3 * hidden_size)
+        self.linear_o = Linear(hidden_size, hidden_size)
 
     def forward(self, input: Tensor) -> Tensor:
         q_k_v = self.linear_qkv(input)
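A quick way to sanity-check the comment above about `constraint=None` being safe in MLP: feed the updated layer a unit-variance input and a unit-variance incoming gradient, and both the output and the input gradient should come out with standard deviation close to 1. This is only a sketch, and it assumes MLP is importable from the unit_scaling package top level.

import torch
from unit_scaling import MLP  # assumed import path for the layer defined above

torch.manual_seed(0)
mlp = MLP(hidden_size=256)
x = torch.randn(64, 256, requires_grad=True)  # ~unit-variance activations
y = mlp(x)
y.backward(torch.randn_like(y))  # ~unit-variance incoming gradient
# Both values should be close to 1.0 if the mirrored constraints behave as described.
print(f"output std: {y.std().item():.2f}, input-grad std: {x.grad.std().item():.2f}")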
@@ -499,7 +489,6 @@ def forward(self, input: Tensor) -> Tensor:
         return self.linear_o(qkv)  # type: ignore
 
 
-@format_docstring(variadic_constraint_docstring)
 class TransformerLayer(nn.Module):
     """A **unit-scaled** implementation of a PreNorm
     (see https://arxiv.org/abs/2002.04745) transformer layer.
@@ -516,7 +505,6 @@ class TransformerLayer(nn.Module):
         is_causal (bool): causal masking (for non-padded sequences).
         dropout_p (float, optional): the probability of residual and post-softmax
             dropout.
-        {0}
     """
 
     def __init__(
@@ -527,22 +515,15 @@ def __init__(
         mlp_tau: float,
         is_causal: bool,
         dropout_p: float = 0.0,
-        constraint: Optional[str] = "to_output_scale",
     ) -> None:
         super().__init__()
         self.dropout_p = dropout_p
         self.mhsa_tau = mhsa_tau
         self.mlp_tau = mlp_tau
         self.mhsa_norm = RMSNorm(hidden_size)
-        self.mhsa = MHSA(
-            hidden_size,
-            heads,
-            is_causal=is_causal,
-            dropout_p=dropout_p,
-            constraint=constraint,
-        )
+        self.mhsa = MHSA(hidden_size, heads, is_causal=is_causal, dropout_p=dropout_p)
         self.mlp_norm = RMSNorm(hidden_size)
-        self.mlp = MLP(hidden_size, constraint=constraint)
+        self.mlp = MLP(hidden_size)
 
     def forward(self, input: Tensor) -> Tensor:
         input, skip = U.residual_split(input, tau=self.mhsa_tau)
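For reference, the forward pass that this hunk truncates follows the unit-scaled PreNorm residual pattern. The sketch below reconstructs that pattern for the MHSA branch only; it assumes that U.residual_add is the counterpart of the U.residual_split call shown above (only the split appears in this diff), and that MHSA and RMSNorm are importable from the package top level.

import torch
import unit_scaling.functional as U
from unit_scaling import MHSA, RMSNorm  # assumed import paths

norm = RMSNorm(256)
mhsa = MHSA(256, heads=8, is_causal=True)
x = torch.randn(8, 128, 256)  # (batch, seq, hidden), ~unit variance

residual, skip = U.residual_split(x, tau=0.5)  # as in the forward above
residual = mhsa(norm(residual))                # PreNorm: normalise, then attend
x = U.residual_add(residual, skip, tau=0.5)    # assumed counterpart to residual_split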
@@ -627,16 +608,13 @@ def __init__(
         )
 
 
-@format_docstring(variadic_constraint_docstring)
 class TransformerDecoder(nn.Sequential):  # pragma: no cover
     """A **unit-scaled** implementation of a decoder-type transformer.
 
     Note: this class is currently just for demonstrating scaling and lacks key
     functionality (for example masking, positional embeddings, usage for
     inference).
 
-    Warning: using `constraint=None` here will likely give incorrect gradients.
-
     Args:
         hidden_size (int): the hidden dimension size of the input.
         vocab_size (int): the number of tokens in the vocabulary.
@@ -648,7 +626,6 @@ class TransformerDecoder(nn.Sequential):  # pragma: no cover
             controlling residual weights in the transformer trunk; see
             :func:`unit_scaling.core.functional.transformer_residual_scaling_rule`
             (default).
-        {0}
     """
 
     def __init__(
@@ -659,7 +636,6 @@ def __init__(
         heads: int,
         dropout_p: float = 0.0,
         residual_scaling: ResidualScalingFn = transformer_residual_scaling_rule(),
-        constraint: Optional[str] = "to_output_scale",
     ) -> None:
         super().__init__()
         self.embedding = Embedding(vocab_size, hidden_size)
@@ -670,7 +646,6 @@ def __init__(
             is_causal=True,
             dropout_p=dropout_p,
             residual_scaling=residual_scaling,
-            constraint=constraint,
         )
         self.final_norm = RMSNorm(hidden_size)
         self.projection = LinearReadout(hidden_size, vocab_size)
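Finally, a rough end-to-end sketch of how the simplified decoder might be driven. The keyword names hidden_size, vocab_size, heads and dropout_p come from the docstring and code above; `layers` is a guess at the remaining constructor argument (not visible in this diff), and the call assumes the plain nn.Sequential forward (embedding, transformer trunk, final norm, readout) so that token ids map straight to logits.

import torch
from unit_scaling import TransformerDecoder  # assumed import path

model = TransformerDecoder(
    hidden_size=256, vocab_size=1024, layers=2, heads=8, dropout_p=0.1  # `layers` is assumed
)
ids = torch.randint(0, 1024, (4, 128))  # (batch, seq) token ids
logits = model(ids)  # if the default nn.Sequential forward applies: (4, 128, 1024) logits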