Skip to content

Commit 7765112

Browse files
committed
Modify DispatchGmmCombine Python API and add new test
Signed-off-by: wangqiankun13 <[email protected]>
1 parent 8d007a9 commit 7765112

File tree

3 files changed

+346
-27
lines changed

3 files changed

+346
-27
lines changed

csrc/torch_binding.cpp

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -594,11 +594,11 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
594594
const at::Tensor &gmm1_permuted_weight_scale,
595595
const at::Tensor &gmm2_weight,
596596
const at::Tensor &gmm2_weight_scale,
597-
const at::Tensor &expert_smooth_scales_optional,
598-
const at::Tensor &expert_scales_optional,
599-
c10::string_view hcom_ep_name,
600-
int64_t num_ranks,
601-
int64_t rank,
597+
const c10::optional<at::Tensor> &expert_smooth_scales,
598+
const c10::optional<at::Tensor> &expert_scales,
599+
c10::string_view group_ep,
600+
int64_t ep_rank_size,
601+
int64_t ep_rank_id,
602602
int64_t moe_expert_num,
603603
int64_t shared_expert_num,
604604
int64_t shared_expert_rank_num,
@@ -611,11 +611,11 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
611611

612612
at::Tensor output = at::empty({bs, h}, x.options());
613613

614-
bool is_shared_expert = (rank < shared_expert_rank_num);
615-
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
616-
at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options());
614+
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
615+
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
616+
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options());
617617

618-
vector<char> group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end());
618+
vector<char> group_ep_chrs(group_ep.begin(), group_ep.end());
619619
group_ep_chrs.push_back('\0');
620620
char *group_ep_ptr = &group_ep_chrs[0];
621621
EXEC_NPU_CMD(
@@ -628,12 +628,12 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode(
628628
gmm1_permuted_weight_scale,
629629
gmm2_weight,
630630
gmm2_weight_scale,
631-
expert_smooth_scales_optional,
632-
expert_scales_optional,
631+
expert_smooth_scales,
632+
expert_scales,
633633
//input attrs
634634
group_ep_ptr,
635-
num_ranks,
636-
rank,
635+
ep_rank_size,
636+
ep_rank_id,
637637
moe_expert_num,
638638
shared_expert_num,
639639
shared_expert_rank_num,
@@ -719,12 +719,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
719719
"dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
720720
" Tensor gmm1_permuted_weight_scale,"
721721
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
722-
" Tensor expert_smooth_scales_optional, Tensor expert_scales_optional,"
723-
" str hcom_ep_name,"
724-
" int num_ranks, int rank, int moe_expert_num,"
725-
" int shared_expert_num, int shared_expert_rank_num,"
726-
" int quant_mode,"
727-
" int global_bs) -> (Tensor output, Tensor ep_recv_count)"
722+
" Tensor? expert_smooth_scales=None, Tensor? expert_scales=None,"
723+
" str group_ep='',"
724+
" int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0,"
725+
" int shared_expert_num=1, int shared_expert_rank_num=0,"
726+
" int quant_mode=0,"
727+
" int global_bs=0) -> (Tensor output, Tensor ep_recv_count)"
728728
);
729729
ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode);
730730
}

csrc/torch_binding_meta.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -158,11 +158,11 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
158158
const at::Tensor &gmm1_permuted_weight_scale,
159159
const at::Tensor &gmm2_weight,
160160
const at::Tensor &gmm2_weight_scale,
161-
const at::Tensor &expert_smooth_scales_optional,
162-
const at::Tensor &expert_scales_optional,
163-
c10::string_view hcom_ep_name,
164-
int64_t num_ranks,
165-
int64_t rank,
161+
const c10::optional<at::Tensor> &expert_smooth_scales,
162+
const c10::optional<at::Tensor> &expert_scales,
163+
c10::string_view group_ep,
164+
int64_t ep_rank_size,
165+
int64_t ep_rank_id,
166166
int64_t moe_expert_num,
167167
int64_t shared_expert_num,
168168
int64_t shared_expert_rank_num,
@@ -175,9 +175,9 @@ std::tuple<at::Tensor, at::Tensor> dispatch_gmm_combine_decode_meta(
175175

176176
at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta));
177177

178-
bool is_shared_expert = (rank < shared_expert_rank_num);
179-
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
180-
at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options().device(at::kMeta));
178+
bool is_shared_expert = (ep_rank_id < shared_expert_rank_num);
179+
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num);
180+
at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta));
181181

182182
return {output, ep_recv_count};
183183
}

0 commit comments

Comments (0)