Skip to content

Commit bb712ab

Browse files
committed
fix bug
1 parent 6deb62d commit bb712ab

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

csrc/torch_binding.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -527,12 +527,12 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe(const at::Tensor &x, const at:
527527
const at::Tensor &gmm1_permuted_weight,
528528
const at::Tensor &gmm1_permuted_weight_scale,
529529
const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale,
530+
const at::Tensor &expert_smooth_scales_optional,
530531
const at::Tensor &expert_scales_optional,
531-
c10::optional<c10::string_view> hcom_ep_name,
532-
int64_t num_ranks, int64_t rank,
532+
c10::string_view hcom_ep_name,
533+
int64_t num_ranks, int64_t rank, int64_t moe_expert_num,
533534
int64_t shared_expert_num, int64_t shared_expert_rank_num,
534-
int64_t num_experts, int64_t global_bs,
535-
int64_t quant_mode)
535+
int64_t quant_mode, int64_t global_bs)
536536
{
537537
auto x_shape = x.sizes();
538538
auto experts_shape = expert_ids.sizes();
@@ -542,15 +542,18 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe(const at::Tensor &x, const at:
542542
at::Tensor output = at::empty({bs, h}, x.options());
543543

544544
bool is_shared_expert = (rank < shared_expert_rank_num);
545-
int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num);
545+
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
546546
at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options());
547547

548+
vector<char> group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end());
549+
group_ep_chrs.push_back('\0');
550+
char *group_ep_ptr = &group_ep_chrs[0];
548551
EXEC_NPU_CMD(aclnnFusedDeepMoe,
549552
// input
550553
x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
551-
gmm2_weight_scale, static_cast<const std::nullptr_t &>(nullptr), expert_scales_optional,
554+
gmm2_weight_scale, expert_smooth_scales_optional, expert_scales_optional,
552555
//attr
553-
hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode,
556+
group_ep_ptr, num_ranks, rank, moe_expert_num, shared_expert_num, shared_expert_rank_num, quant_mode,
554557
global_bs,
555558
// output
556559
output, ep_recv_count);
@@ -619,12 +622,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
619622
"fused_deep_moe(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
620623
" Tensor gmm1_permuted_weight_scale,"
621624
" Tensor gmm2_weight, Tensor gmm2_weight_scale,"
622-
" Tensor expert_scales_optional,"
623-
" str? hcom_ep_name,"
624-
" int num_ranks, int rank,"
625+
" Tensor expert_smooth_scales_optional, Tensor expert_scales_optional,"
626+
" str hcom_ep_name,"
627+
" int num_ranks, int rank, int moe_expert_num,"
625628
" int shared_expert_num, int shared_expert_rank_num,"
626-
" int num_experts, int global_bs,"
627-
" int quant_mode) -> (Tensor output, Tensor ep_recv_count)"
629+
" int quant_mode,"
630+
" int global_bs) -> (Tensor output, Tensor ep_recv_count)"
628631
);
629632

630633
ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe);

csrc/torch_binding_meta.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe_meta(const at::Tensor &x, cons
7373
const at::Tensor &gmm1_permuted_weight,
7474
const at::Tensor &gmm1_permuted_weight_scale,
7575
const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale,
76+
const at::Tensor &expert_smooth_scales_optional,
7677
const at::Tensor &expert_scales_optional,
77-
c10::optional<c10::string_view> hcom_ep_name,
78-
int64_t num_ranks, int64_t rank,
78+
c10::string_view hcom_ep_name,
79+
int64_t num_ranks, int64_t rank, int64_t moe_expert_num,
7980
int64_t shared_expert_num, int64_t shared_expert_rank_num,
80-
int64_t num_experts, int64_t global_bs,
81-
int64_t quant_mode)
81+
int64_t quant_mode, int64_t global_bs)
8282
{
8383
auto x_shape = x.sizes();
8484
auto experts_shape = expert_ids.sizes();
@@ -88,7 +88,7 @@ std::tuple<at::Tensor, at::Tensor> fused_deep_moe_meta(const at::Tensor &x, cons
8888
at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta));
8989

9090
bool is_shared_expert = (rank < shared_expert_rank_num);
91-
int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num);
91+
int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num);
9292
at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options().device(at::kMeta));
9393

9494
return {output, ep_recv_count};

0 commit comments

Comments (0)