Skip to content

Commit e581e93

Browse files
committed
Add support for .cu->.cpp copy
Signed-off-by: Luka Govedič <[email protected]>
1 parent 5af0053 commit e581e93

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

CMakeLists.txt

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,21 +284,39 @@ elseif (VLLM_GPU_LANG STREQUAL "HIP")
284284
list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
285285
endforeach ()
286286

287-
# TODO: copy cpp->cu for correct hipification
288-
# - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it)
287+
# These files are "converted" to .cu before being passed to torch.build_extension on upstream.
288+
# We need to do the same so that hipify treats them correctly. We copy the files in the source tree like upstream.
289+
set(VLLM_FA2_CPP_CU_SRCS
290+
# csrc/flash_attn_ck/flash_api.cpp # only contains declarations & PyBind
291+
csrc/flash_attn_ck/flash_common.cpp
292+
csrc/flash_attn_ck/mha_bwd.cpp
293+
csrc/flash_attn_ck/mha_fwd_kvcache.cpp
294+
csrc/flash_attn_ck/mha_fwd.cpp
295+
csrc/flash_attn_ck/mha_varlen_bwd.cpp
296+
csrc/flash_attn_ck/mha_varlen_fwd.cpp
297+
)
298+
299+
foreach(CPP_FILE ${VLLM_FA2_CPP_CU_SRCS})
300+
string(REGEX REPLACE "\.cpp$" ".cu" CU_FILE ${CPP_FILE})
301+
set(CU_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CU_FILE})
302+
set(CPP_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CPP_FILE})
303+
add_custom_command(
304+
OUTPUT ${CU_FILE_ABS}
305+
COMMAND ${CMAKE_COMMAND} -E copy ${CPP_FILE_ABS} ${CU_FILE_ABS}
306+
DEPENDS ${CPP_FILE_ABS}
307+
COMMENT "Copying ${CPP_FILE} to ${CU_FILE_ABS}"
308+
)
309+
list(APPEND VLLM_FA2_CU_SRCS ${CU_FILE}) # relative to source dir
310+
endforeach ()
311+
312+
# This target automatically depends on the copy by depending on copied files
289313
define_gpu_extension_target(
290314
_vllm_fa2_C
291315
DESTINATION vllm_flash_attn
292316
LANGUAGE ${VLLM_GPU_LANG}
293317
SOURCES
294-
# csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
295318
csrc/flash_attn_ck/flash_api_torch_lib.cpp
296-
csrc/flash_attn_ck/flash_common.cu
297-
csrc/flash_attn_ck/mha_bwd.cu
298-
csrc/flash_attn_ck/mha_fwd_kvcache.cu
299-
csrc/flash_attn_ck/mha_fwd.cu
300-
csrc/flash_attn_ck/mha_varlen_bwd.cu
301-
csrc/flash_attn_ck/mha_varlen_fwd.cu
319+
${VLLM_FA2_CU_SRCS}
302320
${FA3_GEN_SRCS_CU}
303321
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
304322
USE_SABI 3

0 commit comments

Comments
 (0)