
Commit 5af0053

Initial ROCm build working (missing .cpp->.cu copies)

1 parent 96f4ed7

File tree

8 files changed, +315 −113 lines changed

.gitignore
Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ __pycache__/
 # Distribution / packaging
 bin/
 build/
+cmake-build-*/
 develop-eggs/
 dist/
 eggs/

CMakeLists.txt
Lines changed: 192 additions & 96 deletions

@@ -37,6 +37,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # Likely should also be in sync with the vLLM version.
 #
 set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
 
 find_python_constrained_versions(${PYTHON_SUPPORTED_VERSIONS})
 

@@ -91,7 +92,19 @@ if (NOT HIP_FOUND AND CUDA_FOUND)
         "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
     message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
 elseif (HIP_FOUND)
-    message(FATAL_ERROR "ROCm build is not currently supported for vllm-flash-attn.")
+    set(VLLM_GPU_LANG "HIP")
+
+    # Importing torch recognizes and sets up some HIP/ROCm configuration but does
+    # not let cmake recognize .hip files. In order to get cmake to understand the
+    # .hip extension automatically, HIP must be enabled explicitly.
+    enable_language(HIP)
+
+    # ROCm 5.X and 6.X
+    if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+        NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+        message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
+            "expected for ROCm build, saw ${Torch_VERSION} instead.")
+    endif ()
 else ()
     message(FATAL_ERROR "Can't find CUDA or HIP installation.")
 endif ()
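
Note: the `enable_language(HIP)` call above relies on CMake's first-class HIP support (available since CMake 3.21), which is what makes plain `.hip` sources compile without custom rules. A minimal standalone sketch of that mechanism, with `gfx90a` as an assumed example architecture rather than anything taken from this commit:

    # Minimal sketch of CMake's built-in HIP language support (CMake >= 3.21).
    # gfx90a is an illustrative target; real builds derive the list from
    # HIP_SUPPORTED_ARCHS as in the hunk above.
    cmake_minimum_required(VERSION 3.21)
    project(hip_demo LANGUAGES CXX HIP)

    set(CMAKE_HIP_ARCHITECTURES gfx90a)
    # .hip files now go through the HIP toolchain automatically:
    add_library(demo_kernels STATIC kernels.hip)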
@@ -110,129 +123,212 @@ if (NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND VLLM_FA_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif ()
 
+# Replace instead of appending, nvcc doesn't like duplicate -O flags.
+string(REPLACE "-O2" "-O3" CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO "${CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO}")
 
-# Other flags
-list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)
-
-# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
-# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
-# driver API. This causes problems when linking with earlier versions of CUDA.
-# Setting this variable sidesteps the issue by calling the driver directly.
-list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
+if (VLLM_GPU_LANG STREQUAL "CUDA")
+    # Other flags
+    list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)
 
-# Replace instead of appending, nvcc doesn't like duplicate -O flags.
-string(REPLACE "-O2" "-O3" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
+    # If CUTLASS is compiled on NVCC >= 12.5, it by default uses
+    # cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
+    # driver API. This causes problems when linking with earlier versions of CUDA.
+    # Setting this variable sidesteps the issue by calling the driver directly.
+    list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 
-#
-# _C extension
-#
+    #
+    # _C extension
+    #
 
-if (FA2_ENABLED)
-    file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
+    if (FA2_ENABLED)
+        file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
 
-    # For CUDA we set the architectures on a per file basis
-    if (VLLM_GPU_LANG STREQUAL "CUDA")
+        # For CUDA we set the architectures on a per file basis
         cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0" "${CUDA_ARCHS}")
         message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")
 
         set_gencode_flags_for_srcs(
-            SRCS "${FA2_GEN_SRCS}"
-            CUDA_ARCHS "${FA2_ARCHS}")
-    endif()
-
-    define_gpu_extension_target(
-        _vllm_fa2_C
-        DESTINATION vllm_flash_attn
-        LANGUAGE ${VLLM_GPU_LANG}
-        SOURCES
-        csrc/flash_attn/flash_api.cpp
-        csrc/flash_attn/flash_api_sparse.cpp
-        csrc/flash_attn/flash_api_torch_lib.cpp
-        ${FA2_GEN_SRCS}
-        COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-        USE_SABI 3
-        WITH_SOABI)
-
-    target_include_directories(_vllm_fa2_C PRIVATE
-        csrc/flash_attn
-        csrc/flash_attn/src
-        csrc/common
-        csrc/cutlass/include)
-
-    # custom definitions
-    target_compile_definitions(_vllm_fa2_C PRIVATE
-        FLASHATTENTION_DISABLE_BACKWARD
-        FLASHATTENTION_DISABLE_DROPOUT
-        # FLASHATTENTION_DISABLE_ALIBI
-        # FLASHATTENTION_DISABLE_SOFTCAP
-        FLASHATTENTION_DISABLE_UNEVEN_K
-        # FLASHATTENTION_DISABLE_LOCAL
-        FLASHATTENTION_DISABLE_PYBIND
-    )
-endif ()
+            SRCS "${FA2_GEN_SRCS}"
+            CUDA_ARCHS "${FA2_ARCHS}")
+
+        define_gpu_extension_target(
+            _vllm_fa2_C
+            DESTINATION vllm_flash_attn
+            LANGUAGE ${VLLM_GPU_LANG}
+            SOURCES
+            csrc/flash_attn/flash_api.cpp
+            csrc/flash_attn/flash_api_sparse.cpp
+            csrc/flash_attn/flash_api_torch_lib.cpp
+            ${FA2_GEN_SRCS}
+            COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+            USE_SABI 3
+            WITH_SOABI)
+
+        target_include_directories(_vllm_fa2_C PRIVATE
+            csrc/flash_attn
+            csrc/flash_attn/src
+            csrc/common
+            csrc/cutlass/include)
+
+        # custom definitions
+        target_compile_definitions(_vllm_fa2_C PRIVATE
+            FLASHATTENTION_DISABLE_BACKWARD
+            FLASHATTENTION_DISABLE_DROPOUT
+            # FLASHATTENTION_DISABLE_ALIBI
+            # FLASHATTENTION_DISABLE_SOFTCAP
+            FLASHATTENTION_DISABLE_UNEVEN_K
+            # FLASHATTENTION_DISABLE_LOCAL
+            FLASHATTENTION_DISABLE_PYBIND
+        )
+    endif ()
 
     # FA3 requires CUDA 12.0 or later
     if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         # BF16 source files
-        file(GLOB FA3_BF16_GEN_SRCS
+        file(GLOB FA3_BF16_GEN_SRCS
             "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
-        file(GLOB FA3_BF16_GEN_SRCS_
+        file(GLOB FA3_BF16_GEN_SRCS_
            "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
         list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
         # FP16 source files
-        file(GLOB FA3_FP16_GEN_SRCS
+        file(GLOB FA3_FP16_GEN_SRCS
             "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
         file(GLOB FA3_FP16_GEN_SRCS_
             "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
         list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
 
-        # TODO add fp8 source files when FP8 is enabled
-        set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
+        # TODO add fp8 source files when FP8 is enabled
+        set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
 
-        # For CUDA we set the architectures on a per file basis
-        if (VLLM_GPU_LANG STREQUAL "CUDA")
+        # For CUDA we set the architectures on a per file basis
         cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a" "${CUDA_ARCHS}")
         message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")
 
         set_gencode_flags_for_srcs(
-            SRCS "${FA3_GEN_SRCS}"
-            CUDA_ARCHS "${FA3_ARCHS}")
+            SRCS "${FA3_GEN_SRCS}"
+            CUDA_ARCHS "${FA3_ARCHS}")
         set_gencode_flags_for_srcs(
-            SRCS "hopper/flash_fwd_combine.cu"
-            CUDA_ARCHS "${FA3_ARCHS}")
+            SRCS "hopper/flash_fwd_combine.cu"
+            CUDA_ARCHS "${FA3_ARCHS}")
+
+
+        define_gpu_extension_target(
+            _vllm_fa3_C
+            DESTINATION vllm_flash_attn
+            LANGUAGE ${VLLM_GPU_LANG}
+            SOURCES
+            hopper/flash_fwd_combine.cu
+            hopper/flash_api.cpp
+            hopper/flash_api_torch_lib.cpp
+            ${FA3_GEN_SRCS}
+            COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+            ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
+            USE_SABI 3
+            WITH_SOABI)
+
+        target_include_directories(_vllm_fa3_C PRIVATE
+            hopper
+            csrc/common
+            csrc/cutlass/include)
+
+
+        # custom definitions
+        target_compile_definitions(_vllm_fa3_C PRIVATE
+            FLASHATTENTION_DISABLE_BACKWARD
+            FLASHATTENTION_DISABLE_DROPOUT
+            # FLASHATTENTION_DISABLE_ALIBI
+            # FLASHATTENTION_DISABLE_SOFTCAP
+            FLASHATTENTION_DISABLE_UNEVEN_K
+            # FLASHATTENTION_DISABLE_LOCAL
+            FLASHATTENTION_DISABLE_PYBIND
+            FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
+            FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+        )
+    elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
+        message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
     endif()
-
+elseif (VLLM_GPU_LANG STREQUAL "HIP")
+    # CLang on ROCm
+    # --offload-compress required to keep size under 2GB (fails with errs)
+    list(APPEND VLLM_FA_GPU_FLAGS -ffast-math -fgpu-flush-denormals-to-zero --offload-compress)
+
+    # CK fails to compile below O2 as inlining is needed for certain inline assembly
+    string(REGEX REPLACE "-O(g|0)?" "-O2" CMAKE_HIP_FLAGS_DEBUG "${CMAKE_HIP_FLAGS_DEBUG}")
+
+    # Generate FA from CK example kernels
+    # Generate at configure time so we can glob
+    set(FA_GENERATED_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/gen)
+    set(CK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py)
+    file(MAKE_DIRECTORY ${FA_GENERATED_OUTDIR})
+    # TODO(luka) only run if required
+    foreach (KERNEL IN ITEMS "fwd" "fwd_appendkv" "fwd_splitkv" "bwd")
+        execute_process(
+            COMMAND
+            "${Python_EXECUTABLE}" "${CK_GEN_SCRIPT}" "-d" "${KERNEL}" "--output_dir" "${FA_GENERATED_OUTDIR}" "--receipt" "2"
+            RESULT_VARIABLE PYTHON_ERROR_CODE
+            ERROR_VARIABLE PYTHON_STDERR
+            OUTPUT_VARIABLE PYTHON_OUT
+        )
+        if (NOT PYTHON_ERROR_CODE EQUAL 0)
+            message(FATAL_ERROR "Cannot generate Python sources with error: ${PYTHON_ERROR_CODE}\n
+            stdout:${PYTHON_OUT}\n
+            stderr:${PYTHON_STDERR}")
+        endif ()
+    endforeach ()
+
+    file(GLOB FA3_GEN_SRCS "${FA_GENERATED_OUTDIR}/fmha_*wd*.cpp")
+    # Copy cpp files to hip because running hipify on them is a no-op as they only contain instantiations
+    foreach(FILE ${FA3_GEN_SRCS})
+        string(REGEX REPLACE "\.cpp$" ".hip" FILE_HIP ${FILE})
+        file(COPY_FILE ${FILE} ${FILE_HIP})
+        list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
+    endforeach ()
+
+    # TODO: copy cpp->cu for correct hipification
+    # - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it)
     define_gpu_extension_target(
-        _vllm_fa3_C
-        DESTINATION vllm_flash_attn
-        LANGUAGE ${VLLM_GPU_LANG}
-        SOURCES
-        hopper/flash_fwd_combine.cu
-        hopper/flash_api.cpp
-        hopper/flash_api_torch_lib.cpp
-        ${FA3_GEN_SRCS}
-        COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-        ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
-        USE_SABI 3
-        WITH_SOABI)
-
-    target_include_directories(_vllm_fa3_C PRIVATE
-        hopper
-        csrc/common
-        csrc/cutlass/include)
-
-    # custom definitions
-    target_compile_definitions(_vllm_fa3_C PRIVATE
-        FLASHATTENTION_DISABLE_BACKWARD
-        FLASHATTENTION_DISABLE_DROPOUT
-        # FLASHATTENTION_DISABLE_ALIBI
-        # FLASHATTENTION_DISABLE_SOFTCAP
-        FLASHATTENTION_DISABLE_UNEVEN_K
-        # FLASHATTENTION_DISABLE_LOCAL
-        FLASHATTENTION_DISABLE_PYBIND
-        FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
-        FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+        _vllm_fa2_C
+        DESTINATION vllm_flash_attn
+        LANGUAGE ${VLLM_GPU_LANG}
+        SOURCES
+        # csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
+        csrc/flash_attn_ck/flash_api_torch_lib.cpp
+        csrc/flash_attn_ck/flash_common.cu
+        csrc/flash_attn_ck/mha_bwd.cu
+        csrc/flash_attn_ck/mha_fwd_kvcache.cu
+        csrc/flash_attn_ck/mha_fwd.cu
+        csrc/flash_attn_ck/mha_varlen_bwd.cu
+        csrc/flash_attn_ck/mha_varlen_fwd.cu
+        ${FA3_GEN_SRCS_CU}
+        COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+        USE_SABI 3
+        WITH_SOABI
+        # CPP_AS_HIP
     )
-elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
-    message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
-endif ()
+
+    target_include_directories(_vllm_fa2_C PRIVATE
+        csrc/common
+        csrc/composable_kernel/include
+        csrc/composable_kernel/library/include
+        csrc/composable_kernel/example/ck_tile/01_fmha
+    )
+
+    target_compile_definitions(_vllm_fa2_C PRIVATE
+        CK_TILE_FMHA_FWD_FAST_EXP2=1
+        CK_ENABLE_BF16
+        CK_ENABLE_BF8
+        CK_ENABLE_FP16
+        CK_ENABLE_FP32
+        CK_ENABLE_FP64
+        CK_ENABLE_FP8
+        CK_ENABLE_INT8
+        CK_USE_XDL
+        USE_PROF_API=1
+        # FLASHATTENTION_DISABLE_BACKWARD
+        __HIP_PLATFORM_HCC__=1
+        FLASHATTENTION_DISABLE_PYBIND
+    )
+
+    # Data section exceeds 2GB, compress HIP binaries
+    target_link_options(_vllm_fa2_C PRIVATE "--offload-compress")
+endif ()
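
Note: the commit title's "missing .cpp->.cu copies" corresponds to the TODO in the hunk above: the generated `.cpp` instantiation files are currently copied to `.hip`, which bypasses the hipify step. A hypothetical sketch of the intended follow-up, copying to `.cu` so the generated files take the normal hipify path (variable names mirror the diff; whether gen/ or the build tree is the right destination is the open question the TODO records):

    # Hypothetical, untested sketch: copy generated .cpp instantiations to .cu
    # so hipify translates them like the rest of the CUDA sources.
    foreach(FILE ${FA3_GEN_SRCS})
        string(REGEX REPLACE "\\.cpp$" ".cu" FILE_CU ${FILE})
        file(COPY_FILE ${FILE} ${FILE_CU})  # COPY_FILE requires CMake >= 3.21
        list(APPEND FA3_GEN_SRCS_CU ${FILE_CU})
    endforeach()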

cmake/utils.cmake
Lines changed: 8 additions & 5 deletions

@@ -61,9 +61,11 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   # Split into C++ and non-C++ (i.e. CUDA) sources.
   #
   set(SRCS ${ORIG_SRCS})
-  set(CXX_SRCS ${ORIG_SRCS})
-  list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
-  list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
+  set(EXCLUDED_SRCS ${ORIG_SRCS})
+  set(EXCLUDE_REGEX "\.(cc|cpp|hip)$")
+  list(FILTER SRCS EXCLUDE REGEX ${EXCLUDE_REGEX})
+  list(FILTER EXCLUDED_SRCS INCLUDE REGEX ${EXCLUDE_REGEX})
+  message(DEBUG "Excluded source files: ${EXCLUDED_SRCS}")
 
   #
   # Generate ROCm/HIP source file names from CUDA file names.

@@ -78,15 +80,16 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   endforeach()
 
   set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
+  set(CSRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc)
   add_custom_target(
     hipify${NAME}
-    COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
+    COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p "${CSRC_DIR}" -o "${CSRC_BUILD_DIR}" ${SRCS}
     DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
     BYPRODUCTS ${HIP_SRCS}
    COMMENT "Running hipify on ${NAME} extension source files.")
 
   # Swap out original extension sources with hipified sources.
-  list(APPEND HIP_SRCS ${CXX_SRCS})
+  list(APPEND HIP_SRCS ${EXCLUDED_SRCS})
   set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
 endfunction()
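
Note: the old filter regex `\.(cc)|(cpp)$` was subtly broken: the alternation binds the `$` anchor only to the `cpp` branch, so `cc` could match anywhere in a path instead of only as an extension. The replacement `\.(cc|cpp|hip)$` groups the alternation before anchoring, and adding `hip` routes the new pre-hipified files around the hipify step. A small illustrative sketch of the resulting split (file names are made up):

    # Illustrative demo of how hipify_sources_target now partitions sources.
    set(SRCS a.cu b.cpp c.hip)
    set(EXCLUDED_SRCS ${SRCS})
    set(EXCLUDE_REGEX "\\.(cc|cpp|hip)$")
    list(FILTER SRCS EXCLUDE REGEX ${EXCLUDE_REGEX})           # a.cu -> gets hipified
    list(FILTER EXCLUDED_SRCS INCLUDE REGEX ${EXCLUDE_REGEX})  # b.cpp;c.hip -> passed through
    message(STATUS "hipify: ${SRCS}; pass through: ${EXCLUDED_SRCS}")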

csrc/flash_attn_ck/flash_api.cpp
Lines changed: 5 additions & 0 deletions

@@ -111,6 +111,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits);
 
+#ifndef FLASHATTENTION_DISABLE_PYBIND
+
+#include <torch/python.h>
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.doc() = "FlashAttention";

@@ -120,3 +124,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
     m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
+#endif
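
Note: guarding `PYBIND11_MODULE` this way lets the same `flash_api.cpp` serve two builds: the standalone pybind11 extension, and the torch-library build where the bindings come from `flash_api_torch_lib.cpp` instead. On the CMake side the switch is just the macro, as already set on the `_vllm_fa2_C` target in the CMakeLists.txt diff above; a minimal sketch:

    # The torch-library build defines the macro, compiling out PYBIND11_MODULE.
    target_compile_definitions(_vllm_fa2_C PRIVATE FLASHATTENTION_DISABLE_PYBIND)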
