@@ -527,12 +527,12 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe(const at::Tensor &x, const at:
527527 const at::Tensor &gmm1_permuted_weight,
528528 const at::Tensor &gmm1_permuted_weight_scale,
529529 const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale,
530+ const at::Tensor &expert_smooth_scales_optional,
530531 const at::Tensor &expert_scales_optional,
531- c10::optional<c10::string_view> hcom_ep_name,
532- int64_t num_ranks, int64_t rank,
532+ c10::string_view hcom_ep_name,
533+ int64_t num_ranks, int64_t rank, int64_t moe_expert_num,
533534 int64_t shared_expert_num, int64_t shared_expert_rank_num,
534- int64_t num_experts, int64_t global_bs,
535- int64_t quant_mode)
535+ int64_t quant_mode, int64_t global_bs)
536536{
537537 auto x_shape = x.sizes();
538538 auto experts_shape = expert_ids.sizes();
@@ -542,15 +542,18 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe(const at::Tensor &x, const at:
542542 at::Tensor output = at::empty({bs, h}, x.options());
543543
544544 bool is_shared_expert = (rank < shared_expert_rank_num);
545- int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num);
545+ int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
546546 at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options());
548+ vector<char> group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end());
549+ group_ep_chrs.push_back('\0');
550+ char *group_ep_ptr = &group_ep_chrs[0];
548551 EXEC_NPU_CMD(aclnnFusedDeepMoe,
549552 // input
550553 x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
551- gmm2_weight_scale, static_cast<const std::nullptr_t &>(nullptr), expert_scales_optional,
554+ gmm2_weight_scale, expert_smooth_scales_optional, expert_scales_optional,
552555 // attr
553- hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode,
556+ group_ep_ptr, num_ranks, rank, moe_expert_num, shared_expert_num, shared_expert_rank_num, quant_mode,
554557 global_bs,
555558 // output
556559 output, ep_recv_count);
@@ -619,12 +622,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
619622 " fused_deep_moe(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
620623 " Tensor gmm1_permuted_weight_scale,"
621624 " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
622- " Tensor expert_scales_optional,"
623- " str? hcom_ep_name,"
624- " int num_ranks, int rank,"
625+ " Tensor expert_smooth_scales_optional, Tensor expert_scales_optional,"
626+ " str hcom_ep_name,"
627+ " int num_ranks, int rank, int moe_expert_num, "
625628 " int shared_expert_num, int shared_expert_rank_num,"
626- " int num_experts, int global_bs,"
627- " int quant_mode) -> (Tensor output, Tensor ep_recv_count)"
629+ " int quant_mode,"
630+ " int global_bs) -> (Tensor output, Tensor ep_recv_count)"
628631 );
629632
630633 ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe);
0 commit comments