
Commit 8f582ec

Add supported arch 8.7 where 8.6 is used
Signed-off-by: Conroy Cheers <[email protected]>
1 parent b336c7a commit 8f582ec

10 files changed: +22 additions, -21 deletions
CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
 set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")
 
 # Supported NVIDIA architectures.
-set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.9;9.0")
+set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.7;8.9;9.0")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
   list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "11.0" "12.0")
 elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
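Compute capability 8.7 corresponds to the Jetson Orin family. A minimal sketch (not part of this commit; assumes PyTorch with CUDA is available) of checking whether the local GPU now falls inside CUDA_SUPPORTED_ARCHS:

    import torch

    # Mirrors CUDA_SUPPORTED_ARCHS above (before the CUDA 12.8+/13.0 additions).
    supported = {"8.0", "8.6", "8.7", "8.9", "9.0"}
    major, minor = torch.cuda.get_device_capability()  # e.g. (8, 7) on Jetson Orin
    cap = f"{major}.{minor}"
    print(cap, "is supported" if cap in supported else "is not in CUDA_SUPPORTED_ARCHS")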

csrc/ft_attention/setup.py
Lines changed: 2 additions & 2 deletions

@@ -76,9 +76,9 @@ def append_nvcc_threads(nvcc_extra_args):
 if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.8"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
     elif bare_metal_version >= Version("11.1"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
     elif bare_metal_version == Version("11.0"):
         os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
     else:
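The same default-arch extension is applied identically to csrc/layer_norm, csrc/rotary and csrc/xentropy below. The list only acts as a default: it is written into TORCH_CUDA_ARCH_LIST when that variable is unset, and torch.utils.cpp_extension reads the variable when generating nvcc -gencode flags. A minimal sketch (hypothetical build invocation, not part of this commit) of narrowing a source build to 8.7 only:

    import os
    import subprocess

    # Explicitly target Jetson Orin class devices (compute capability 8.7) only;
    # when this variable is left unset, the default list above (now including 8.7)
    # is used instead.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.7"
    subprocess.run(["python", "setup.py", "bdist_wheel"], check=True)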

csrc/layer_norm/setup.py
Lines changed: 2 additions & 2 deletions

@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
 if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.8"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
     elif bare_metal_version >= Version("11.1"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
     elif bare_metal_version == Version("11.0"):
         os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
     else:

csrc/rotary/setup.py
Lines changed: 2 additions & 2 deletions

@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
 if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.8"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
     elif bare_metal_version >= Version("11.1"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
     elif bare_metal_version == Version("11.0"):
         os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
     else:

csrc/xentropy/setup.py
Lines changed: 2 additions & 2 deletions

@@ -74,9 +74,9 @@ def append_nvcc_threads(nvcc_extra_args):
 if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
     _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
     if bare_metal_version >= Version("11.8"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7;9.0"
     elif bare_metal_version >= Version("11.1"):
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;8.7"
     elif bare_metal_version == Version("11.0"):
         os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
     else:

hopper/flash_api.cpp
Lines changed: 4 additions & 4 deletions

@@ -467,7 +467,7 @@ inline int get_num_splits(Flash_fwd_params const& params) {
     auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
     // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
     // has not been set here. It's OK though because we might just underestimate kBlockN a bit
-    auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
+    auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 87 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
     int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
     int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
     int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
@@ -654,7 +654,7 @@ mha_fwd_get_scheduler_metadata(
 
     if (params.num_splits_dynamic_ptr) {
         auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
-        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
+        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 87 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
         int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
         int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
         auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -1387,7 +1387,7 @@ std::vector<at::Tensor> mha_bwd(
                : 64));
     int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
     int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
-    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
+    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 87 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
     int const kBlockN_sm90 = head_size_rounded <= 128
         ? 128
         : (head_size_rounded <= 192 ? 96 : 80);
@@ -1398,7 +1398,7 @@ std::vector<at::Tensor> mha_bwd(
         : (head_size_rounded <= 96 ? 128
         : (head_size_rounded <= 128 ? 96
         : (head_size_rounded <= 192 ? 64 : 64)));
-    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
+    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 87 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
     int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
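For arches below 9.0, mha_bwd already picks smaller block sizes on 8.6/8.9 than on 8.0, and 8.7 now joins that branch. An illustrative Python analogue (not the C++ implementation; only the arch < 90 path visible in the hunk above is modeled) of the kBlockM selection:

    def pick_kblock_m_sm8x(arch: int, head_size_rounded: int) -> int:
        # Block sizes taken from the mha_bwd hunk above (arch < 90 only).
        k_block_m_sm80 = 128 if head_size_rounded <= 64 else 64
        k_block_m_sm86 = 64 if head_size_rounded <= 192 else 32
        return k_block_m_sm86 if arch in (86, 87, 89) else k_block_m_sm80

    print(pick_kblock_m_sm8x(87, 128))  # -> 64: SM 8.7 now follows the sm86 path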

hopper/flash_bwd_launch_template.h
Lines changed: 5 additions & 5 deletions

@@ -308,7 +308,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
                 // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
                 run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
             }
-        } else if constexpr (Arch == 86 || Arch == 89) {
+        } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
             run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
@@ -324,7 +324,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
         if constexpr (Arch >= 90) {
             run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
+        } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
             run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
         } else {
             run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
@@ -341,7 +341,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
             } else {
                 run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
             }
-        } else if constexpr (Arch == 86 || Arch == 89) {
+        } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
             run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
         } else {
             run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
@@ -354,7 +354,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
         if constexpr (Arch >= 90) {
             run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
+        } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
             run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
         } else {
             run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
@@ -367,7 +367,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
         if constexpr (Arch >= 90) {
             run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
-        } else if constexpr (Arch == 86 || Arch == 89) {
+        } else if constexpr (Arch == 86 || Arch == 87 || Arch == 89) {
             run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
             // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
         } else {

hopper/flash_fwd_launch_template.h
Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
     // Can't use structured binding since it's not compatible with constexpr
     static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
-    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
+    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 87 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
     static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
     static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
     static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);

hopper/static_switch.h
Lines changed: 1 addition & 1 deletion

@@ -158,7 +158,7 @@
 #else
   #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \
     [&] { \
-      if (ARCH == 86 || ARCH == 89) { \
+      if (ARCH == 86 || ARCH == 87 || ARCH == 89) { \
        constexpr static int ARCH_NAME = 86; \
        return __VA_ARGS__(); \
      } else if (ARCH < 90) { \
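Note that ARCH_SWITCH maps 8.7 (like 8.6 and 8.9) onto the compile-time value ARCH_NAME = 86, so sm_87 devices reuse the kernel instantiations built for sm86 rather than getting their own. An illustrative Python analogue of that folding (the branches below ARCH < 90 are cut off in the hunk, so their values are assumptions):

    def fold_arch(arch: int) -> int:
        # 8.6, 8.7 and 8.9 all dispatch to the sm86 instantiations, per the hunk above.
        if arch in (86, 87, 89):
            return 86
        # Remaining branches are not visible in the diff; 80 / 90 is an assumption.
        return 80 if arch < 90 else 90

    print(fold_arch(87))  # -> 86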

vllm_flash_attn/flash_attn_interface.py
Lines changed: 2 additions & 1 deletion

@@ -51,10 +51,11 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
     if torch.cuda.get_device_capability(device)[0] < 8 \
         or torch.cuda.get_device_capability(device)[0] >= 10 \
         or torch.cuda.get_device_capability(device) == (8, 6) \
+        or torch.cuda.get_device_capability(device) == (8, 7) \
         or torch.cuda.get_device_capability(device) == (8, 9):
         return False, \
             "FA3 is only supported on devices with compute capability >= 8" \
-            " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
+            " excluding 8.6, 8.7 and 8.9 and Blackwell archs (>=10)"
     return True, None
 
 def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]:
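With this check, devices reporting compute capability 8.7 (e.g. Jetson Orin) are now explicitly steered away from the FA3 path alongside 8.6 and 8.9. A minimal sketch (assuming the package is importable and a CUDA device is present) of querying support before dispatching:

    import torch
    from vllm_flash_attn.flash_attn_interface import _is_fa3_supported

    ok, reason = _is_fa3_supported(torch.device("cuda:0"))
    if not ok:
        print(f"FA3 unavailable, falling back: {reason}")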
