[CI/Build][AMD] Use ROCM_ATTN instead of FLASH_ATTN test for test_register_kv_caches for ROCm and update test for TRITON_ATTN #29985
Conversation
Signed-off-by: Randall Smith <[email protected]>
Code Review
This pull request updates the test_register_kv_caches test to correctly handle differences in KV cache shapes between FLASH_ATTN and TRITON_ATTN backends. The changes adjust expected values for tensor sizes, base addresses, and block lengths based on the attention backend being tested. The logic seems correct and addresses the test failure described. I have one suggestion to make the test more robust and readable.
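For instance, one way to keep these per-backend expectations robust (a sketch only, assuming the classic four-argument get_kv_cache_shape signature) is to derive them from the backend class rather than hard-coding constants:

```python
# Sketch: derive the expected layout from the backend rather than hard-coding
# it. Assumes get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
# head_size); adjust if the backend takes extra arguments.
def expected_layout(backend_cls, num_blocks, block_size, num_kv_heads, head_size):
    shape = backend_cls.get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size
    )
    # Blocks-first layouts start with num_blocks (e.g. TRITON_ATTN);
    # split-first layouts start with 2 (one K plane and one V plane).
    blocks_first = shape[0] == num_blocks
    return shape, blocks_first
```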
Signed-off-by: Randall Smith <[email protected]>
| "FLASH_ATTN", | ||
| marks=pytest.mark.skipif( | ||
| current_platform.is_rocm(), | ||
| reason="Attention backend FLASH_ATTN is not supported on ROCm", |
Can you check whether ROCM_AITER_FA or ROCM_ATTN is a suitable replacement for FLASH_ATTN?
FLASH_ATTN, ROCM_AITER_FA, and ROCM_ATTN have the same KV cache layout.
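A minimal sketch of that parametrization (the skip reasons and marks are illustrative, not the actual test code):

```python
import pytest
from vllm.platforms import current_platform

BACKENDS = [
    pytest.param(
        "FLASH_ATTN",
        marks=pytest.mark.skipif(
            current_platform.is_rocm(),
            reason="FLASH_ATTN is not supported on ROCm",
        ),
    ),
    pytest.param(
        "ROCM_ATTN",  # same KV cache layout as FLASH_ATTN
        marks=pytest.mark.skipif(
            not current_platform.is_rocm(),
            reason="ROCM_ATTN is only available on ROCm",
        ),
    ),
    "TRITON_ATTN",  # runs on both platforms
]

@pytest.mark.parametrize("attn_backend", BACKENDS)
def test_register_kv_caches(attn_backend):
    ...
```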
@tjtanaa So, I was able to do this and it worked! However, it turns out that get_attn_backend was using _cached_get_attn_backend, which was returning the backend class from the previous run, so FLASH_ATTN was being retested during the TRITON_ATTN test run. I mocked get_attn_backend, and now all tests pass and test against the correct backend.
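Roughly like this (the patch target is an assumption about where the backend is resolved; wanted_backend_cls and run_test_body are placeholders):

```python
from unittest.mock import patch

# The patch target is illustrative: it must point at the module that actually
# calls get_attn_backend, so the cached lookup is bypassed entirely.
with patch("vllm.attention.get_attn_backend", return_value=wanted_backend_cls):
    run_test_body()  # stands in for the body of test_register_kv_caches
```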
This PR skips FLASH_ATTN for test_register_kv_caches since it is not supported on ROCm. It also updates test_register_kv_caches for TRITON_ATTN, which was failing with the following error:
This is because FLASH_ATTN and TRITON_ATTN use different shapes for the KV cache, according to
get_kv_cache_shape. In particular, TritonAttentionBackend.get_kv_cache_shape returns something of the form [num_blocks, 2, H, N, D], which causes TpKVTopology to set self._is_kv_layout_blocks_first to True for TRITON_ATTN, but False for FLASH_ATTN. I adjusted the expected outputs in the test to reflect these differences.
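To make the layout difference concrete, an illustrative comparison (the FLASH_ATTN shape shown is an assumption based on the split-first layout described above):

```python
import torch

num_blocks, H, N, D = 4, 8, 16, 64  # H = KV heads, N = block size, D = head dim

# Blocks-first (TRITON_ATTN, per get_kv_cache_shape above): K and V for each
# block are interleaved, so one contiguous region covers both.
triton_kv = torch.empty(num_blocks, 2, H, N, D)
triton_block_len = 2 * H * N * D * triton_kv.element_size()
triton_base = triton_kv.data_ptr()  # a single base address

# Split-first (assumed FLASH_ATTN layout): separate K and V planes, so two
# base addresses and half the per-block length.
flash_kv = torch.empty(2, num_blocks, N, H, D)
flash_block_len = N * H * D * flash_kv.element_size()
flash_bases = (flash_kv[0].data_ptr(), flash_kv[1].data_ptr())
```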
=================
I found a second problem when using ROCM_ATTN instead of FLASH_ATTN on ROCm. When the test ran with a second backend, the _cached_get_attn_backend function returned the backend class from the previous test run (FLASH_ATTN on upstream CI, since it runs before the TRITON_ATTN test), so the test was simply re-testing the previous backend.
So, I mocked the get_attn_backend function to return the backend that we want to test.
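An alternative, assuming _cached_get_attn_backend is an lru_cache-wrapped function importable from vllm.attention.selector, would be clearing the cache between parametrized runs:

```python
import pytest
from vllm.attention.selector import _cached_get_attn_backend  # assumed import path

@pytest.fixture(autouse=True)
def _reset_attn_backend_cache():
    # lru_cache-wrapped functions expose cache_clear(); this assumes
    # _cached_get_attn_backend is such a wrapper.
    _cached_get_attn_backend.cache_clear()
    yield
```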
=================
All tests pass now.