 import torch.nn as nn
 import torch.nn.functional as F
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.gpt import GPTModel
 from megatron.core.parallel_state import (
     get_data_parallel_group,
@@ -154,22 +155,54 @@ def _setup(self):
         )


+# Embedding DynamicModule ##########################################################################
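+# NOTE: plain nn.Embedding is registered alongside VocabParallelEmbedding so that non-parallel
+# embedding tables (e.g. position/tokentype embeddings) can also be converted to dynamic modules.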
 @DMRegistry.register(
-    {VocabParallelEmbedding: "megatron.core.tensor_parallel.layers.VocabParallelEmbedding"}
+    {
+        VocabParallelEmbedding: "megatron.core.tensor_parallel.layers.VocabParallelEmbedding",
+        nn.Embedding: "nn.Embedding",
+    },
 )
-class _DynamicVocabParallelEmbedding(DynamicModule):
-    """A VocabParallelEmbedding layer with dynamic hyperparams."""
+class _DynamicEmbedding(DynamicModule):
+    """An Embedding layer with dynamic hyperparams."""

     def _setup(self):
         self._register_hparam("embedding_dim", TracedHp(list(range(1, self.embedding_dim + 1))))
         self._register_dynamic_attribute("weight", self._get_weight)

     @staticmethod
-    def _get_weight(mod: "_DynamicVocabParallelEmbedding", weight: torch.Tensor) -> torch.Tensor:
+    def _get_weight(mod: "_DynamicEmbedding", weight: torch.Tensor) -> torch.Tensor:
         """Return the weight tensor of the embedding layer."""
         return get_sliced_tensor(mod, weight, None, "embedding_dim")


+@DMRegistry.register(
+    {
+        LanguageModelEmbedding: "megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding"
+    }
+)
+class _DynamicLanguageModelEmbedding(DynamicModule):
+    """A LanguageModelEmbedding layer with dynamic hyperparams."""
+
+    def _setup(self):
+        DMRegistry.convert(self.word_embeddings)
+        hp_hidden_size = self.word_embeddings.get_hparam("embedding_dim")
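+        # Position and tokentype embeddings (if present) reuse the word embedding's embedding_dim
+        # hparam so all embedding tables share the same dynamic hidden size.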
+        if hasattr(self, "position_embeddings") and self.position_embeddings is not None:
+            DMRegistry.convert(self.position_embeddings)
+            self.position_embeddings.embedding_dim = hp_hidden_size
+        if hasattr(self, "tokentype_embeddings") and self.tokentype_embeddings is not None:
+            DMRegistry.convert(self.tokentype_embeddings)
+            self.tokentype_embeddings.embedding_dim = hp_hidden_size
+
+    def export(self) -> torch.nn.Module:
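+        # Export each sub-embedding first, then let the base class export the container itself.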
+        self.word_embeddings.export()
+        if hasattr(self, "position_embeddings") and self.position_embeddings is not None:
+            self.position_embeddings.export()
+        if hasattr(self, "tokentype_embeddings") and self.tokentype_embeddings is not None:
+            self.tokentype_embeddings.export()
+        return super().export()
+
+
+# Normalization DynamicModule ######################################################################
 @DMRegistry.register({FusedLayerNorm: "megatron.core.fusions.fused_layer_norm.FusedLayerNorm"})
 class _DynamicFusedLayerNorm(_DynamicLayerNorm):
     """A FusedLayerNorm layer with dynamic hyperparams."""
@@ -211,8 +244,8 @@ def _setup(self):
             self.hparam_name = "moe_ffn_hidden_size"
         else:
             self.hparam_name = "ffn_hidden_size"
-        self.linear_fc1 = DMRegistry.convert(self.linear_fc1)
-        self.linear_fc2 = DMRegistry.convert(self.linear_fc2)
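+        # DMRegistry.convert converts the module in place, so reassigning the attribute is no longer needed.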
+        DMRegistry.convert(self.linear_fc1)
+        DMRegistry.convert(self.linear_fc2)

         ffn_hidden_size = TracedHp(list(range(1, self.config.ffn_hidden_size + 1)))
         fc1_output_size = (
@@ -537,7 +570,7 @@ def _setup(self):
             {"_setup": lambda self: None},
         )

-        self.core_attention = _DynamicDotProductAttention.convert(self.core_attention)
+        _DynamicDotProductAttention.convert(self.core_attention)
         self.core_attention._register_dynamic_attribute(
             "hidden_size_per_partition",
             lambda mod, val: self.config.kv_channels * self.num_attention_heads_per_partition,
@@ -559,7 +592,7 @@ def _setup(self):
             {"_setup": lambda self: None},
         )

-        self.core_attention = _DynamicTEDotProductAttention.convert(self.core_attention)
+        _DynamicTEDotProductAttention.convert(self.core_attention)
         self.core_attention._register_dynamic_attribute(
             "num_attention_heads", lambda mod, val: self.num_attention_heads_per_partition
         )
@@ -571,10 +604,10 @@ def _setup(self):
         )

         # Convert the fused qkv and output projection linear layer to dynamic module
-        self.linear_qkv = _DynamicQKVColumnParallelLinear.convert(
+        _DynamicQKVColumnParallelLinear.convert(
             self.linear_qkv, num_heads_per_group, num_query_groups
         )
-        self.linear_proj = _DynamicProjRowParallelLinear.convert(
+        _DynamicProjRowParallelLinear.convert(
             self.linear_proj, num_heads_per_group, num_query_groups
         )

@@ -670,6 +703,8 @@ def _setup(self):

         # Register dynamic attributes
         self._register_dynamic_attribute("weight", self._get_router_weight)
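+        # The optional router bias is sliced along the num_experts hparam, matching the weight's expert dimension.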
+        if self.config.add_bias_linear:
+            self._register_dynamic_attribute("bias", self._get_slice_by_num_experts)
         if self.enable_expert_bias:
             self._register_dynamic_attribute(
                 "local_tokens_per_expert", self._get_slice_by_num_experts
@@ -681,8 +716,8 @@ def _get_router_weight(mod: "_DynamicTopKRouter", weight: torch.Tensor) -> torch
         return get_sliced_tensor(mod, weight, "num_experts", "hidden_size")

     @staticmethod
-    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", bias: torch.Tensor) -> torch.Tensor:
-        return get_sliced_tensor(mod, bias, "num_experts")
+    def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", val: torch.Tensor) -> torch.Tensor:
+        return get_sliced_tensor(mod, val, "num_experts")

     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden_size hparam for router weights from global hidden_size hparam."""
@@ -699,10 +734,10 @@ def _setup(self):
         self._register_hparam("num_local_experts", num_moe_experts)

         # Convert local_experts list and each individual expert MLP to dynamic modules
-        self.local_experts = DynamicModuleList.convert(self.local_experts)
+        DynamicModuleList.convert(self.local_experts)
         self.local_experts.depth = num_moe_experts  # Reuse same hparam for depth
         for i in range(len(self.local_experts)):
-            self.local_experts[i] = DMRegistry.convert(self.local_experts[i])
+            DMRegistry.convert(self.local_experts[i])

         # Track forward activations for importance estimation.
         # _activations name is needed for get_activations_and_layer_scores to save scores for re-running pruning.
@@ -777,8 +812,8 @@ def _setup(self):
         # Convert to dynamic modules
         # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so
         # importance estimator and depth hparam is retained.
-        self.router = DMRegistry.convert(self.router)
-        self.experts = DMRegistry.convert(self.experts)
+        DMRegistry.convert(self.router)
+        DMRegistry.convert(self.experts)
         num_moe_experts_hp = self.experts.get_hparam("num_local_experts")

         # NOTE: Use num_moe_experts hparam name in top-level module to match TransformerConfig's name
@@ -789,7 +824,7 @@ def _setup(self):
         )
         self.router.num_experts = num_moe_experts_hp
         if self.use_shared_expert:
-            self.shared_experts = DMRegistry.convert(self.shared_experts)
+            DMRegistry.convert(self.shared_experts)

     def forward(self, *args, **kwargs):
         """Forward pass for the MoE layer."""
@@ -898,12 +933,12 @@ def _setup(self):
         # Convert the layernorms, self-attention, and mlp/moe layers to dynamic modules
         # NOTE: Mamba stack layers have either Attention or MLP, not both unlike GPT models
         if isinstance(self.self_attention, SelfAttention):
-            self.input_layernorm = DMRegistry.convert(self.input_layernorm)
-            self.self_attention = DMRegistry.convert(self.self_attention)
+            DMRegistry.convert(self.input_layernorm)
+            DMRegistry.convert(self.self_attention)

         if isinstance(self.mlp, (MLP, MoELayer)):
-            self.pre_mlp_layernorm = DMRegistry.convert(self.pre_mlp_layernorm)
-            self.mlp = DMRegistry.convert(self.mlp)
+            DMRegistry.convert(self.pre_mlp_layernorm)
+            DMRegistry.convert(self.mlp)

         # Register forward hook to collect activations for importance estimation
         self._setup_mixin()
@@ -1168,23 +1203,23 @@ def _setup(self):
         self._register_dynamic_attribute("headdim", lambda mod, val: self.mamba_head_dim)

         # Convert to dynamic modules
-        self.in_proj = DMRegistry.convert(self.in_proj)
+        DMRegistry.convert(self.in_proj)
         self.in_proj.output_size = build_concat_hp(
             [d_inner, d_inner, bc, mamba_num_heads]
         )  # z, x, B, C, dt

         conv_dim = build_concat_hp([d_inner, bc])  # z, B, C
-        self.conv1d = DMRegistry.convert(self.conv1d)
+        DMRegistry.convert(self.conv1d)
         self.conv1d.in_channels = conv_dim
         self.conv1d.out_channels = conv_dim
         ks = self.conv1d.get_hparam("kernel_size")
         ks.choices = [ks.original]

         if self.rmsnorm:
-            self.norm = DMRegistry.convert(self.norm)
+            DMRegistry.convert(self.norm)
             self.norm.hidden_size = d_inner

-        self.out_proj = DMRegistry.convert(self.out_proj)
+        DMRegistry.convert(self.out_proj)
         self.out_proj.input_size = d_inner

         # Register dynamic attributes for Mamba-specific parameters
@@ -1310,8 +1345,8 @@ class _DynamicMambaLayer(DynamicModule, MambaTransformerLayerMixin):

     def _setup(self):
         # Convert to dynamic module
-        self.mixer = DMRegistry.convert(self.mixer)
-        self.norm = DMRegistry.convert(self.norm)
+        DMRegistry.convert(self.mixer)
+        DMRegistry.convert(self.norm)
         self._setup_mixin()

     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
@@ -1336,7 +1371,7 @@ def modify(
             ("mamba_head_dim", mamba_head_dim_divisor),
         ]:
             hp = self.mixer.get_hparam(hp_name)
-            choices = {int(make_divisible(c, divisor)) for c in hp.choices}  # type: ignore[arg-type]
+            choices = {int(make_divisible(c, divisor)) for c in hp.choices}
             hp.choices = list(set(hp.choices) & choices | {hp.original})

     def export(self):
@@ -1376,24 +1411,21 @@ def _setup(self):
         assert self.config.expert_model_parallel_size == 1, "Expert parallel is not supported."
         assert self.pre_process == is_pipeline_first_stage()
         assert self.post_process == is_pipeline_last_stage()
-        assert self.position_embedding_type in ["rope", "none"], (
-            f"Only rope position embedding is supported, got {self.position_embedding_type}."
-        )

         # Register num_layers hparam for depth pruning
         self._register_hparam("num_layers", TracedHp(list(range(1, self.config.num_layers + 1))))

         # Convert layers to dynamic modules and set the shared hidden_size hparam
         if is_pipeline_first_stage():
-            self.embedding.word_embeddings = DMRegistry.convert(self.embedding.word_embeddings)
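+            # Convert the whole LanguageModelEmbedding so position/tokentype embeddings (if present) become dynamic too.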
+            DMRegistry.convert(self.embedding)
             hidden_size = self.embedding.word_embeddings.get_hparam("embedding_dim")
         else:
             hidden_size = None
         hidden_size = dist.broadcast(hidden_size, src=0)
         self._register_hparam("hidden_size", hidden_size)

         for i in range(len(self.decoder.layers)):
-            self.decoder.layers[i] = DMRegistry.convert(self.decoder.layers[i])
+            DMRegistry.convert(self.decoder.layers[i])
             self.decoder.layers[i].set_hidden_size_hp(hidden_size)

         # NOTE: GPTModel has final_layernorm, MambaModel has final_norm
@@ -1409,7 +1441,7 @@ def _setup(self):
                 DMRegistry.convert(getattr(self.decoder, self.final_norm_attr_name)),
             )
             getattr(self.decoder, self.final_norm_attr_name).num_features = hidden_size
-            self.output_layer = DMRegistry.convert(self.output_layer)
+            DMRegistry.convert(self.output_layer)
             self.output_layer.input_size = hidden_size
             self.output_layer.get_hparam("output_size").choices = [self.output_layer.output_size]

@@ -1548,7 +1580,7 @@ def export(self) -> torch.nn.Module:
             handle.remove()
         self._export_drop_layers()
         if is_pipeline_first_stage():
-            self.embedding.word_embeddings.export()
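+            # Export the full embedding module (word plus optional position/tokentype embeddings).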
+            self.embedding.export()
         for layer in self.decoder.layers:
             layer.export()
         if is_pipeline_last_stage():
@@ -1587,9 +1619,7 @@ def set_activations_and_layer_scores(
         rank = get_pipeline_model_parallel_rank()
         pp_size = get_pipeline_model_parallel_world_size()
         assert len(activations_per_rank) == pp_size, (
-            len(activations_per_rank),
-            activations_per_rank,
-            pp_size,
+            f"Expected same PP size for stored pruning scores ({len(activations_per_rank)}) as current ({pp_size})!"
         )
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
@@ -1611,14 +1641,14 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
     )

     supported_model_types = tuple(SUPPORTED_MODELS.keys())
-    for m in model.modules():
+    for n, m in model.named_modules():
         if isinstance(m, supported_model_types):
             model = m
             break
     assert isinstance(model, supported_model_types), (
         f"Model should have one of {supported_model_types} submodule, got {model}"
     )
-    print_rank_0(f"Dropping layers {layers_to_drop} from {type(model)}.")
+    print_rank_0(f"Dropping layers {layers_to_drop} from {n} ({type(model)}).")

     # get the number of layers remaining in each pp rank
     layers_remaining_per_pp = torch.zeros(