@@ -594,11 +594,11 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
     const at::Tensor &gmm1_permuted_weight_scale,
     const at::Tensor &gmm2_weight,
     const at::Tensor &gmm2_weight_scale,
-    const at::Tensor &expert_smooth_scales_optional,
-    const at::Tensor &expert_scales_optional,
-    c10::string_view hcom_ep_name,
-    int64_t num_ranks,
-    int64_t rank,
+    const c10::optional<at::Tensor> &expert_smooth_scales,
+    const c10::optional<at::Tensor> &expert_scales,
+    c10::string_view group_ep,
+    int64_t ep_rank_size,
+    int64_t ep_rank_id,
     int64_t moe_expert_num,
     int64_t shared_expert_num,
     int64_t shared_expert_rank_num,
@@ -611,11 +611,11 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
 
     at::Tensor output = at::empty({bs, h}, x.options());
 
-    bool is_shared_expert = (rank < shared_expert_rank_num);
-    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
-    at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options());
+    bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
+    int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
+    at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
 
-    vector<char> group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end());
+    vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
     group_ep_chrs.push_back('\0');
     char *group_ep_ptr = &group_ep_chrs[0];
     EXEC_NPU_CMD(
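
Note on the `ep_recv_count` sizing above (reviewer sketch, not part of the patch): ranks below `shared_expert_rank_num` host only the single shared expert, the remaining ranks split `moe_expert_num` routed experts evenly, and the buffer plausibly holds one counter per (local expert, peer rank) pair. A minimal standalone C++ sketch of the arithmetic, with made-up numbers (256 routed experts, 34 EP ranks, 2 shared-expert ranks):

    #include <cstdint>
    #include <iostream>

    int main() {
        // All values are hypothetical, chosen only to illustrate the arithmetic.
        const int64_t moe_expert_num = 256;        // routed experts across the EP group
        const int64_t ep_rank_size = 34;           // total EP ranks
        const int64_t shared_expert_rank_num = 2;  // ranks dedicated to the shared expert

        for (int64_t ep_rank_id : {int64_t{0}, int64_t{2}}) {
            const bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
            const int64_t num_local_experts =
                is_shared_expert ? 1
                                 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
            // rank 0 -> 1 local expert, recv-count length 34;
            // rank 2 -> 256 / 32 = 8 local experts, recv-count length 272.
            std::cout << "rank " << ep_rank_id << ": " << num_local_experts
                      << " local experts, ep_recv_count length "
                      << num_local_experts * ep_rank_size << "\n";
        }
    }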
@@ -628,12 +628,12 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
         gmm1_permuted_weight_scale,
         gmm2_weight,
         gmm2_weight_scale,
-        expert_smooth_scales_optional,
-        expert_scales_optional,
+        expert_smooth_scales,
+        expert_scales,
         // input attrs
         group_ep_ptr,
-        num_ranks,
-        rank,
+        ep_rank_size,
+        ep_rank_id,
         moe_expert_num,
         shared_expert_num,
         shared_expert_rank_num,
@@ -719,12 +719,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
719719 " dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
720720 " Tensor gmm1_permuted_weight_scale,"
721721 " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
722- " Tensor expert_smooth_scales_optional , Tensor expert_scales_optional ,"
723- " str hcom_ep_name ,"
724- " int num_ranks , int rank , int moe_expert_num,"
725- " int shared_expert_num, int shared_expert_rank_num,"
726- " int quant_mode,"
727- " int global_bs) -> (Tensor output, Tensor ep_recv_count)"
722+ " Tensor? expert_smooth_scales=None , Tensor? expert_scales=None ,"
723+ " str group_ep='' ,"
724+ " int ep_rank_size=0 , int ep_rank_id=0 , int moe_expert_num=0 ,"
725+ " int shared_expert_num=1 , int shared_expert_rank_num=0 ,"
726+ " int quant_mode=0 ,"
727+ " int global_bs=0 ) -> (Tensor output, Tensor ep_recv_count)"
728728 );
729729 ops.impl (" dispatch_gmm_combine_decode" , torch::kPrivateUse1 , &vllm_ascend::dispatch_gmm_combine_decode);
730730}
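
With the schema defaults above, the two scale tensors become genuinely optional. A hypothetical call-site sketch against the new C++ signature (tensor values, the group name, and the rank layout are invented for illustration; only the parameter order is taken from the registered schema):

    // Hypothetical call site; inputs are illustrative only.
    // Parameter order follows the registered schema.
    auto [output, ep_recv_count] = vllm_ascend::dispatch_gmm_combine_decode(
        x, expert_ids,
        gmm1_permuted_weight, gmm1_permuted_weight_scale,
        gmm2_weight, gmm2_weight_scale,
        /*expert_smooth_scales=*/c10::nullopt,  // now optional
        /*expert_scales=*/c10::nullopt,         // now optional
        /*group_ep=*/"ep_group_0",              // HCCL EP communicator name
        /*ep_rank_size=*/34,
        /*ep_rank_id=*/0,
        /*moe_expert_num=*/256,
        /*shared_expert_num=*/1,
        /*shared_expert_rank_num=*/2,
        /*quant_mode=*/0,
        /*global_bs=*/0);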