@@ -284,21 +284,39 @@ elseif (VLLM_GPU_LANG STREQUAL "HIP")
284284 list (APPEND FA3_GEN_SRCS_CU ${FILE_HIP} )
285285 endforeach ()
286286
287- # TODO: copy cpp->cu for correct hipification
288- # - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it)
# Upstream "converts" these files to .cu before passing them to torch.build_extension.
# We need to do the same so that hipify treats them correctly. As upstream does, we copy the files within the source tree.
# C++ sources of the CK FlashAttention backend that get a .cu twin, listed
# relative to the source directory.
# csrc/flash_attn_ck/flash_api.cpp is deliberately left out: it only contains
# declarations & PyBind.
set(VLLM_FA2_CPP_CU_SRCS "")
list(APPEND VLLM_FA2_CPP_CU_SRCS
  "csrc/flash_attn_ck/flash_common.cpp"
  "csrc/flash_attn_ck/mha_bwd.cpp"
  "csrc/flash_attn_ck/mha_fwd_kvcache.cpp"
  "csrc/flash_attn_ck/mha_fwd.cpp"
  "csrc/flash_attn_ck/mha_varlen_bwd.cpp"
  "csrc/flash_attn_ck/mha_varlen_fwd.cpp"
)
298+
# For every entry of VLLM_FA2_CPP_CU_SRCS, register a build-time rule that
# copies <name>.cpp to a sibling <name>.cu inside the source directory, and
# record the .cu path (relative to the source dir) in VLLM_FA2_CU_SRCS.
# Any target that lists those .cu files as sources triggers the copy rules.
foreach(CPP_FILE IN LISTS VLLM_FA2_CPP_CU_SRCS)
  # "\\." reaches the regex engine as "\.", i.e. a literal dot, so only a
  # trailing ".cpp" extension is rewritten.
  string(REGEX REPLACE "\\.cpp$" ".cu" CU_FILE "${CPP_FILE}")

  # Quoted so each absolute path stays a single argument even when the
  # source directory contains spaces or semicolons.
  set(CU_FILE_ABS "${CMAKE_CURRENT_SOURCE_DIR}/${CU_FILE}")
  set(CPP_FILE_ABS "${CMAKE_CURRENT_SOURCE_DIR}/${CPP_FILE}")

  add_custom_command(
    OUTPUT "${CU_FILE_ABS}"
    COMMAND "${CMAKE_COMMAND}" -E copy "${CPP_FILE_ABS}" "${CU_FILE_ABS}"
    DEPENDS "${CPP_FILE_ABS}"
    COMMENT "Copying ${CPP_FILE} to ${CU_FILE_ABS} "
    VERBATIM
  )
  list(APPEND VLLM_FA2_CU_SRCS "${CU_FILE}") # relative to source dir
endforeach()
311+
# This target depends on the copy rules automatically, because its sources include the copied .cu files.
289313 define_gpu_extension_target(
290314 _vllm_fa2_C
291315 DESTINATION vllm_flash_attn
292316 LANGUAGE ${VLLM_GPU_LANG}
293317 SOURCES
294- # csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
295318 csrc/flash_attn_ck/flash_api_torch_lib.cpp
296- csrc/flash_attn_ck/flash_common.cu
297- csrc/flash_attn_ck/mha_bwd.cu
298- csrc/flash_attn_ck/mha_fwd_kvcache.cu
299- csrc/flash_attn_ck/mha_fwd.cu
300- csrc/flash_attn_ck/mha_varlen_bwd.cu
301- csrc/flash_attn_ck/mha_varlen_fwd.cu
319+ ${VLLM_FA2_CU_SRCS}
302320 ${FA3_GEN_SRCS_CU}
303321 COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
304322 USE_SABI 3
0 commit comments