diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/moe/dropless_layer.py
index 031c23065..0414bf677 100644
--- a/internlm/model/moe/dropless_layer.py
+++ b/internlm/model/moe/dropless_layer.py
@@ -229,9 +229,7 @@ def __init__(
         if self.token_dispatch_policy == "alltoall":
             self.token_permutation_func = self.token_permutation_by_alltoall
             self.token_unpermutation_func = self.token_unpermutation_by_alltoall
-            self.enable_fused_permute = (
-                GEMM_INSTALLED and enable_fused_permute and not drop_and_pad and capacity_factor is None
-            )
+            self.enable_fused_permute = GEMM_INSTALLED and enable_fused_permute and not drop_and_pad
             self.input_splits = None
             self.output_splits = None
             self.num_global_tokens_per_local_expert_cpu = None
@@ -342,7 +340,7 @@ def topk_softmax_with_capacity(self, gates):
         exceed_mask = torch.gather(drop_mask, 1, indices)  # shape: [num_token, topk]
         final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
-        final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)
+        final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.int32).max)
         tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
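
Note on the sentinel change (editor's sketch, not part of the patch): a plausible reason for switching the fill value from torch.long to torch.int32 is that the fused permute path narrows the routing indices to 32-bit, and the old 64-bit sentinel wraps around to -1 under that cast, so dropped tokens no longer look "out of range". The snippet below reproduces that overflow under this assumption; the tensors and mask are hypothetical and exist only for illustration.

    import torch

    # Hypothetical routing indices and capacity-overflow mask, shaped [num_token, topk].
    indices = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
    exceed_mask = torch.tensor([[False, True], [False, False]])

    # Old sentinel: 2**63 - 1 truncates to -1 once the indices are narrowed to int32,
    # so the masked entry is indistinguishable from a negative index.
    old = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)
    print(old.to(torch.int32))  # masked entry becomes -1

    # New sentinel: 2**31 - 1 survives the narrowing and stays safely out of range.
    new = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.int32).max)
    print(new.to(torch.int32))  # masked entry stays 2147483647

If the sentinel is safe under int32, gating enable_fused_permute on capacity_factor is None is no longer necessary, which would be consistent with the first hunk dropping that condition.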