@@ -308,7 +308,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
308308 // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
309309 run_mha_bwd_dispatch<Arch, T, 128 , 128 , 64 , Is_causal, Is_local, Has_softcap, 2 , 2 , true , false , false , 2 , 1 , 2 , 2 , false >(params, stream);
310310 }
311- } else if constexpr (Arch == 86 || Arch == 89 ) {
311+ } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89 ) {
312312 run_mha_bwd_dispatch<Arch, T, 64 , 128 , 64 , Is_causal, Is_local, Has_softcap, 2 , 2 , false , false , false , 2 , 2 , 4 , 2 , true >(params, stream);
313313 // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
314314 // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
@@ -324,7 +324,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
324324 CAUSAL_LOCAL_SWITCH (params.is_causal , params.is_local , Is_causal, Is_local, [&] {
325325 if constexpr (Arch >= 90 ) {
326326 run_mha_bwd_dispatch<Arch, T, 64 , 128 , 96 , Is_causal, Is_local, Has_softcap, 2 , 2 , true , false , false , 2 , 1 , 2 , 1 , true >(params, stream);
327- } else if constexpr (Arch == 86 || Arch == 89 ) {
327+ } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89 ) {
328328 run_mha_bwd_dispatch<Arch, T, 64 , 128 , 96 , Is_causal, Is_local, Has_softcap, 1 , 2 , false , false , false , 2 , 2 , 4 , 2 , true >(params, stream);
329329 } else {
330330 run_mha_bwd_dispatch<Arch, T, 64 , 128 , 96 , Is_causal, Is_local, Has_softcap, 2 , 2 , false , false , false , 2 , 2 , 4 , 2 , false >(params, stream);
@@ -341,7 +341,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
341341 } else {
342342 run_mha_bwd_dispatch<Arch, T, 80 , 128 , 128 , Is_causal, Is_local, Has_softcap, 2 , 2 , true , false , true , 2 , 1 , 2 , 1 , false >(params, stream);
343343 }
344- } else if constexpr (Arch == 86 || Arch == 89 ) {
344+ } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89 ) {
345345 run_mha_bwd_dispatch<Arch, T, 64 , 96 , 128 , Is_causal, Is_local, Has_softcap, 1 , 2 , false , false , false , 2 , 2 , 2 , 2 , true >(params, stream);
346346 } else {
347347 run_mha_bwd_dispatch<Arch, T, 64 , 128 , 128 , Is_causal, Is_local, Has_softcap, 2 , 2 , false , false , false , 2 , 2 , 2 , 2 , false >(params, stream);
@@ -354,7 +354,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
354354 CAUSAL_LOCAL_SWITCH (params.is_causal , params.is_local , Is_causal, Is_local, [&] {
355355 if constexpr (Arch >= 90 ) {
356356 run_mha_bwd_dispatch<Arch, T, 64 , 96 , 192 , Is_causal, Is_local, Has_softcap, 1 , 1 , false , true , false , 3 , 1 , 1 , 1 , false >(params, stream);
357- } else if constexpr (Arch == 86 || Arch == 89 ) {
357+ } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89 ) {
358358 run_mha_bwd_dispatch<Arch, T, 64 , 64 , 192 , Is_causal, Is_local, Has_softcap, 1 , 1 , false , false , false , 2 , 2 , 2 , 2 , true >(params, stream);
359359 } else {
360360 run_mha_bwd_dispatch<Arch, T, 64 , 80 , 192 , Is_causal, Is_local, Has_softcap, 1 , 2 , false , true , false , 2 , 4 , 2 , 2 , false >(params, stream);
@@ -367,7 +367,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
367367 CAUSAL_LOCAL_SWITCH (params.is_causal , params.is_local , Is_causal, Is_local, [&] {
368368 if constexpr (Arch >= 90 ) {
369369 run_mha_bwd_dispatch<Arch, T, 64 , 80 , 256 , Is_causal, Is_local, Has_softcap, 1 , 1 , false , true , true , 2 , 1 , 1 , 1 , false >(params, stream);
370- } else if constexpr (Arch == 86 || Arch == 89 ) {
370+ } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89 ) {
371371 run_mha_bwd_dispatch<Arch, T, 32 , 64 , 256 , Is_causal, Is_local, Has_softcap, 1 , 1 , false , false , false , 2 , 2 , 2 , 1 , true >(params, stream);
372372 // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
373373 } else {
0 commit comments