Skip to content

Commit 58e3afc

Browse files
committed
split ablation on main
1 parent 57b4e68 commit 58e3afc

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

hopper/flash_api.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,9 @@ mha_fwd_get_scheduler_metadata(
625625
params.pack_gqa |= params.num_splits > 1;
626626
// printf("Num splits (metadata) = %d.\n", params.num_splits);
627627

628+
// ABLATION: set num splits to 1. Should cripple perf
629+
params.num_splits = 1;
630+
628631
bool is_varlen = true;
629632

630633
// Otherwise the kernel will be launched from cuda:0 device
@@ -986,6 +989,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
986989
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
987990
// Always enable PackGQA for Split
988991
params.pack_gqa |= (params.num_splits > 1);
992+
// ABLATION: set num splits to 1. Should cripple perf
993+
params.num_splits = 1;
989994

990995
// This needs to be set after get_num_splits
991996
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic

0 commit comments

Comments
 (0)