@@ -37,6 +37,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # Likely should also be in sync with the vLLM version.
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
 
 find_python_constrained_versions(${PYTHON_SUPPORTED_VERSIONS})
 
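For context on the hunk above: a pin like TORCH_SUPPORTED_VERSION_ROCM is normally compared against the Torch version reported by find_package(Torch); the next hunk adds exactly such a check for the HIP branch. As a general illustration only (the `_expected` variable is hypothetical, not part of this patch), the pattern looks like:

```cmake
# Illustrative sketch: compare the discovered Torch against the pin for the
# active backend. Only the TORCH_SUPPORTED_VERSION_* variables come from the
# patch; _expected is a hypothetical local name.
if(VLLM_GPU_LANG STREQUAL "HIP")
  set(_expected "${TORCH_SUPPORTED_VERSION_ROCM}")
else()
  set(_expected "${TORCH_SUPPORTED_VERSION_CUDA}")
endif()

if(NOT Torch_VERSION VERSION_EQUAL _expected)
  message(WARNING "Expected PyTorch ${_expected} for this backend, found ${Torch_VERSION}.")
endif()
```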
@@ -91,7 +92,19 @@ if (NOT HIP_FOUND AND CUDA_FOUND)
     "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
   message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
 elseif(HIP_FOUND)
-  message(FATAL_ERROR "ROCm build is not currently supported for vllm-flash-attn.")
+  set(VLLM_GPU_LANG "HIP")
+
+  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
+  # not let CMake recognize .hip files. In order to get CMake to understand the
+  # .hip extension automatically, HIP must be enabled explicitly.
+  enable_language(HIP)
+
+  # ROCm 5.X and 6.X
+  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+    message(WARNING "PyTorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
+                    "expected for ROCm build, saw ${Torch_VERSION} instead.")
+  endif()
 else()
   message(FATAL_ERROR "Can't find CUDA or HIP installation.")
 endif()
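The comment in the hunk above is the key point: importing torch only sets up ROCm paths, while enable_language(HIP) is what registers a HIP compiler and build rules for the .hip extension. A minimal standalone sketch of that behavior (project and target names here are illustrative, not part of the patch; requires CMake 3.21+, where HIP became a first-class language):

```cmake
cmake_minimum_required(VERSION 3.21)   # first release with first-class HIP language support
project(hip_ext_demo LANGUAGES CXX)

# Without this, a .hip source added to a target has no associated compiler;
# enabling HIP registers the ROCm clang toolchain and the .hip extension.
enable_language(HIP)

# Illustrative target: kernel.hip is now compiled as HIP device code.
add_library(demo_kernels OBJECT kernel.hip)

# Sources with other extensions can be forced to the HIP language explicitly:
set_source_files_properties(extra.cpp PROPERTIES LANGUAGE HIP)
```

This is also why the HIP branch later in this patch renames generated .cpp instantiation files to .hip: once the language is enabled, the extension alone is enough for CMake to pick the right compiler.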
@@ -110,129 +123,212 @@ if (NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_FA_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()
 
+# Replace instead of appending, nvcc doesn't like duplicate -O flags.
+string(REPLACE "-O2" "-O3" CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO "${CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO}")
 
-# Other flags
-list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)
-
-# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
-# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
-# driver API. This causes problems when linking with earlier versions of CUDA.
-# Setting this variable sidesteps the issue by calling the driver directly.
-list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
+if (VLLM_GPU_LANG STREQUAL "CUDA")
+  # Other flags
+  list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)
 
-# Replace instead of appending, nvcc doesn't like duplicate -O flags.
-string(REPLACE "-O2" "-O3" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
+  # If CUTLASS is compiled on NVCC >= 12.5, it by default uses
+  # cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
+  # driver API. This causes problems when linking with earlier versions of CUDA.
+  # Setting this variable sidesteps the issue by calling the driver directly.
+  list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 
-#
-# _C extension
-#
+  #
+  # _C extension
+  #
 
-if (FA2_ENABLED)
-  file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
+  if (FA2_ENABLED)
+    file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
 
-  # For CUDA we set the architectures on a per file basis
-  if (VLLM_GPU_LANG STREQUAL "CUDA")
+    # For CUDA we set the architectures on a per file basis
     cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0" "${CUDA_ARCHS}")
     message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")
 
     set_gencode_flags_for_srcs(
-      SRCS "${FA2_GEN_SRCS}"
-      CUDA_ARCHS "${FA2_ARCHS}")
-  endif()
-
-  define_gpu_extension_target(
-    _vllm_fa2_C
-    DESTINATION vllm_flash_attn
-    LANGUAGE ${VLLM_GPU_LANG}
-    SOURCES
-      csrc/flash_attn/flash_api.cpp
-      csrc/flash_attn/flash_api_sparse.cpp
-      csrc/flash_attn/flash_api_torch_lib.cpp
-      ${FA2_GEN_SRCS}
-    COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-    USE_SABI 3
-    WITH_SOABI)
-
-  target_include_directories(_vllm_fa2_C PRIVATE
-    csrc/flash_attn
-    csrc/flash_attn/src
-    csrc/common
-    csrc/cutlass/include)
-
-  # custom definitions
-  target_compile_definitions(_vllm_fa2_C PRIVATE
-    FLASHATTENTION_DISABLE_BACKWARD
-    FLASHATTENTION_DISABLE_DROPOUT
-    # FLASHATTENTION_DISABLE_ALIBI
-    # FLASHATTENTION_DISABLE_SOFTCAP
-    FLASHATTENTION_DISABLE_UNEVEN_K
-    # FLASHATTENTION_DISABLE_LOCAL
-    FLASHATTENTION_DISABLE_PYBIND
-  )
-endif()
+      SRCS "${FA2_GEN_SRCS}"
+      CUDA_ARCHS "${FA2_ARCHS}")
+
+    define_gpu_extension_target(
+      _vllm_fa2_C
+      DESTINATION vllm_flash_attn
+      LANGUAGE ${VLLM_GPU_LANG}
+      SOURCES
+        csrc/flash_attn/flash_api.cpp
+        csrc/flash_attn/flash_api_sparse.cpp
+        csrc/flash_attn/flash_api_torch_lib.cpp
+        ${FA2_GEN_SRCS}
+      COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+      USE_SABI 3
+      WITH_SOABI)
+
+    target_include_directories(_vllm_fa2_C PRIVATE
+      csrc/flash_attn
+      csrc/flash_attn/src
+      csrc/common
+      csrc/cutlass/include)
+
+    # custom definitions
+    target_compile_definitions(_vllm_fa2_C PRIVATE
+      FLASHATTENTION_DISABLE_BACKWARD
+      FLASHATTENTION_DISABLE_DROPOUT
+      # FLASHATTENTION_DISABLE_ALIBI
+      # FLASHATTENTION_DISABLE_SOFTCAP
+      FLASHATTENTION_DISABLE_UNEVEN_K
+      # FLASHATTENTION_DISABLE_LOCAL
+      FLASHATTENTION_DISABLE_PYBIND
+    )
+  endif()
 
 # FA3 requires CUDA 12.0 or later
 if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
   # BF16 source files
-  file(GLOB FA3_BF16_GEN_SRCS
+  file(GLOB FA3_BF16_GEN_SRCS
     "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
-  file(GLOB FA3_BF16_GEN_SRCS_
+  file(GLOB FA3_BF16_GEN_SRCS_
     "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
   list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
   # FP16 source files
-  file(GLOB FA3_FP16_GEN_SRCS
+  file(GLOB FA3_FP16_GEN_SRCS
     "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
   file(GLOB FA3_FP16_GEN_SRCS_
     "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
   list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
 
-  # TODO add fp8 source files when FP8 is enabled
-  set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
+  # TODO add fp8 source files when FP8 is enabled
+  set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
 
-  # For CUDA we set the architectures on a per file basis
-  if (VLLM_GPU_LANG STREQUAL "CUDA")
+  # For CUDA we set the architectures on a per file basis
     cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a" "${CUDA_ARCHS}")
     message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")
 
     set_gencode_flags_for_srcs(
-      SRCS "${FA3_GEN_SRCS}"
-      CUDA_ARCHS "${FA3_ARCHS}")
+      SRCS "${FA3_GEN_SRCS}"
+      CUDA_ARCHS "${FA3_ARCHS}")
     set_gencode_flags_for_srcs(
-      SRCS "hopper/flash_fwd_combine.cu"
-      CUDA_ARCHS "${FA3_ARCHS}")
+      SRCS "hopper/flash_fwd_combine.cu"
+      CUDA_ARCHS "${FA3_ARCHS}")
+
+
+  define_gpu_extension_target(
+    _vllm_fa3_C
+    DESTINATION vllm_flash_attn
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES
+      hopper/flash_fwd_combine.cu
+      hopper/flash_api.cpp
+      hopper/flash_api_torch_lib.cpp
+      ${FA3_GEN_SRCS}
+    COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
+    USE_SABI 3
+    WITH_SOABI)
+
+  target_include_directories(_vllm_fa3_C PRIVATE
+    hopper
+    csrc/common
+    csrc/cutlass/include)
+
+
+  # custom definitions
+  target_compile_definitions(_vllm_fa3_C PRIVATE
+    FLASHATTENTION_DISABLE_BACKWARD
+    FLASHATTENTION_DISABLE_DROPOUT
+    # FLASHATTENTION_DISABLE_ALIBI
+    # FLASHATTENTION_DISABLE_SOFTCAP
+    FLASHATTENTION_DISABLE_UNEVEN_K
+    # FLASHATTENTION_DISABLE_LOCAL
+    FLASHATTENTION_DISABLE_PYBIND
+    FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
+    FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+  )
+elseif (${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
+  message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
 endif()
-
+elseif(VLLM_GPU_LANG STREQUAL "HIP")
+  # Clang on ROCm
+  # --offload-compress required to keep size under 2GB (build fails with errors otherwise)
+  list(APPEND VLLM_FA_GPU_FLAGS -ffast-math -fgpu-flush-denormals-to-zero --offload-compress)
+
+  # CK fails to compile below -O2 because inlining is needed for certain inline assembly
+  string(REGEX REPLACE "-O(g|0)?" "-O2" CMAKE_HIP_FLAGS_DEBUG "${CMAKE_HIP_FLAGS_DEBUG}")
+
+  # Generate FA from CK example kernels
+  # Generate at configure time so we can glob the results
+  set(FA_GENERATED_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/gen)
+  set(CK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py)
+  file(MAKE_DIRECTORY ${FA_GENERATED_OUTDIR})
+  # TODO(luka) only run if required
+  foreach(KERNEL IN ITEMS "fwd" "fwd_appendkv" "fwd_splitkv" "bwd")
+    execute_process(
+      COMMAND
+        "${Python_EXECUTABLE}" "${CK_GEN_SCRIPT}" "-d" "${KERNEL}" "--output_dir" "${FA_GENERATED_OUTDIR}" "--receipt" "2"
+      RESULT_VARIABLE PYTHON_ERROR_CODE
+      ERROR_VARIABLE PYTHON_STDERR
+      OUTPUT_VARIABLE PYTHON_OUT
+    )
+    if (NOT PYTHON_ERROR_CODE EQUAL 0)
+      message(FATAL_ERROR "Cannot generate CK kernel sources, error code: ${PYTHON_ERROR_CODE}\n
+        stdout: ${PYTHON_OUT}\n
+        stderr: ${PYTHON_STDERR}")
+    endif()
+  endforeach()
+
+  file(GLOB FA3_GEN_SRCS "${FA_GENERATED_OUTDIR}/fmha_*wd*.cpp")
+  # Copy the .cpp files to .hip because running hipify on them is a no-op, as they only contain instantiations
+  foreach(FILE ${FA3_GEN_SRCS})
+    string(REGEX REPLACE "\.cpp$" ".hip" FILE_HIP ${FILE})
+    file(COPY_FILE ${FILE} ${FILE_HIP})
+    list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
+  endforeach()
+
+  # TODO: copy cpp->cu for correct hipification
+  # - try copying into gen/ or maybe even directly into the build tree (make sure that it's where hipify would copy it)
   define_gpu_extension_target(
-    _vllm_fa3_C
-    DESTINATION vllm_flash_attn
-    LANGUAGE ${VLLM_GPU_LANG}
-    SOURCES
-      hopper/flash_fwd_combine.cu
-      hopper/flash_api.cpp
-      hopper/flash_api_torch_lib.cpp
-      ${FA3_GEN_SRCS}
-    COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-    ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
-    USE_SABI 3
-    WITH_SOABI)
-
-  target_include_directories(_vllm_fa3_C PRIVATE
-    hopper
-    csrc/common
-    csrc/cutlass/include)
-
-  # custom definitions
-  target_compile_definitions(_vllm_fa3_C PRIVATE
-    FLASHATTENTION_DISABLE_BACKWARD
-    FLASHATTENTION_DISABLE_DROPOUT
-    # FLASHATTENTION_DISABLE_ALIBI
-    # FLASHATTENTION_DISABLE_SOFTCAP
-    FLASHATTENTION_DISABLE_UNEVEN_K
-    # FLASHATTENTION_DISABLE_LOCAL
-    FLASHATTENTION_DISABLE_PYBIND
-    FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
-    FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+    _vllm_fa2_C
+    DESTINATION vllm_flash_attn
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES
+      # csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
+      csrc/flash_attn_ck/flash_api_torch_lib.cpp
+      csrc/flash_attn_ck/flash_common.cu
+      csrc/flash_attn_ck/mha_bwd.cu
+      csrc/flash_attn_ck/mha_fwd_kvcache.cu
+      csrc/flash_attn_ck/mha_fwd.cu
+      csrc/flash_attn_ck/mha_varlen_bwd.cu
+      csrc/flash_attn_ck/mha_varlen_fwd.cu
+      ${FA3_GEN_SRCS_CU}
+    COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+    USE_SABI 3
+    WITH_SOABI
+    # CPP_AS_HIP
   )
-elseif (${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
-  message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
-endif()
+
+  target_include_directories(_vllm_fa2_C PRIVATE
+    csrc/common
+    csrc/composable_kernel/include
+    csrc/composable_kernel/library/include
+    csrc/composable_kernel/example/ck_tile/01_fmha
+  )
+
+  target_compile_definitions(_vllm_fa2_C PRIVATE
+    CK_TILE_FMHA_FWD_FAST_EXP2=1
+    CK_ENABLE_BF16
+    CK_ENABLE_BF8
+    CK_ENABLE_FP16
+    CK_ENABLE_FP32
+    CK_ENABLE_FP64
+    CK_ENABLE_FP8
+    CK_ENABLE_INT8
+    CK_USE_XDL
+    USE_PROF_API=1
+    # FLASHATTENTION_DISABLE_BACKWARD
+    __HIP_PLATFORM_HCC__=1
+    FLASHATTENTION_DISABLE_PYBIND
+  )
+
+  # Data section exceeds 2GB, compress HIP binaries
+  target_link_options(_vllm_fa2_C PRIVATE "--offload-compress")
+endif()
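The HIP branch above combines two tricks: generating kernel instantiations at configure time (so that file(GLOB) can see them) and renaming the emitted .cpp files to .hip so the enabled HIP toolchain compiles them directly instead of relying on hipify. A minimal, self-contained sketch of that pattern; gen_kernels.py, the gen/ output directory, and the underscore-prefixed variables are placeholders, not names from this patch:

```cmake
# Sketch: run a generator script at configure time, then re-extension the
# emitted .cpp instantiation files so CMake's HIP rules pick them up.
set(_gen_dir ${CMAKE_CURRENT_BINARY_DIR}/gen)
file(MAKE_DIRECTORY ${_gen_dir})

execute_process(
  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/gen_kernels.py --output_dir ${_gen_dir}
  RESULT_VARIABLE _rc
  OUTPUT_VARIABLE _out
  ERROR_VARIABLE _err)
if(NOT _rc EQUAL 0)
  message(FATAL_ERROR "Kernel generation failed (${_rc}):\n${_out}\n${_err}")
endif()

# Globbing only works here because generation already happened at configure time.
file(GLOB _gen_cpp "${_gen_dir}/*.cpp")
foreach(_src IN LISTS _gen_cpp)
  string(REGEX REPLACE "\\.cpp$" ".hip" _hip "${_src}")
  file(COPY_FILE "${_src}" "${_hip}")   # requires CMake >= 3.21
  list(APPEND _gen_hip "${_hip}")
endforeach()
```

A build-time add_custom_command would avoid re-running the generator on every reconfigure, but it would require listing the outputs explicitly, since file(GLOB) only sees files that already exist when CMake runs; that trade-off is presumably what the "TODO(luka) only run if required" note in the patch refers to.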