Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,21 @@ endif ()
if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdim64_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_bf16*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_bf16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_bf16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_bf16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_bf16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_bf16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_bf16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_bf16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_bf16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_bf16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_bf16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_bf16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_bf16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_bf16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_bf16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_bf16_paged*_sm90.cu")
# Add these for hdim diff cases
file(GLOB FA3_BF16_GEN_SRCS_
# "hopper/instantiations/flash_fwd_hdim64_256_bf16*_sm90.cu"
Expand All @@ -195,11 +205,22 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdim64_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_fp16*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_fp16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_fp16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_fp16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_fp16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_fp16_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_fp16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_fp16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_fp16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_fp16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_fp16_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_fp16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_fp16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_fp16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_fp16_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_fp16_paged*_sm90.cu"
)
# Add these for hdim diff cases
file(GLOB FA3_FP16_GEN_SRCS_
# "hopper/instantiations/flash_fwd_hdim64_256_fp16*_sm90.cu"
Expand All @@ -212,11 +233,21 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"hopper/instantiations/flash_fwd_hdim64_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_e4m3*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_e4m3_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_e4m3_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_e4m3_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_e4m3_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_e4m3_*packgqa_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_e4m3_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_e4m3_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_e4m3_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_e4m3_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_e4m3_split*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim64_e4m3_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_e4m3_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_e4m3_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_e4m3_paged*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_e4m3_paged*_sm90.cu")
# Add these for hdim diff cases (192 only)
file(GLOB FA3_FP8_GEN_SRCS_
"hopper/instantiations/flash_fwd_hdim192_128_e4m3*_sm90.cu")
Expand Down Expand Up @@ -265,11 +296,12 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
FLASHATTENTION_DISABLE_SOFTCAP
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
FLASHATTENTION_PACKGQA_ONLY # Custom flag to save on binary size
FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case
# FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_HDIMDIFF64
Expand Down
7 changes: 7 additions & 0 deletions hopper/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,16 @@ struct Flash_fwd_params : public Qkv_params {

int * __restrict__ tile_count_semaphore;
// int * __restrict__ num_m_blocks_ptr;
int * __restrict__ prepare_seqlen_q_ptr;
// int * __restrict__ num_n_blocks_ptr;
int * __restrict__ num_splits_dynamic_ptr;
int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual
int * __restrict__ num_nheads_in_l2_ptr;
bool skip_scheduler_metadata_computation;
bool varlen_sort_batches;
int tile_count_semaphore_offset;
bool head_swizzle;
bool prepare_varlen_pdl;

int arch;
int num_sm;
Expand Down
Loading