[example] flex attention #764
@@ -0,0 +1,287 @@
| """ | ||||||||||
| Flex Attention Example | ||||||||||
| ======================== | ||||||||||
|
|
||||||||||
| This code implements a custom attention kernel using Helion and PyTorch for efficient computation of scaled dot-product attention, | ||||||||||
| with support for both static and dynamic input shapes. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| # %% | ||||||||||
| # Imports | ||||||||||
| # ------- | ||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| import math | ||||||||||
| from typing import Any | ||||||||||
| from typing import Callable | ||||||||||
| from typing import cast | ||||||||||
|
|
||||||||||
| import torch | ||||||||||
| from torch.nn.attention.flex_attention import BlockMask | ||||||||||
| from torch.nn.attention.flex_attention import _create_empty_block_mask | ||||||||||
| from torch.nn.attention.flex_attention import _identity | ||||||||||
| from torch.nn.attention.flex_attention import _score_mod_signature | ||||||||||
| from torch.nn.attention.flex_attention import flex_attention | ||||||||||
|
|
||||||||||
| import helion | ||||||||||
| from helion._testing import run_example | ||||||||||
| import helion.language as hl | ||||||||||
|
|
||||||||||
|
|
||||||||||
| # %% | ||||||||||
| # Flex Attention Kernel Implementation | ||||||||||
| # ---------------------------- | ||||||||||
| @helion.kernel(autotune_accuracy_check=False) | ||||||||||
| def helion_flex_attention_kernel( | ||||||||||
| query: torch.Tensor, | ||||||||||
| key: torch.Tensor, | ||||||||||
| value: torch.Tensor, | ||||||||||
| score_mod: Callable, | ||||||||||
| block_mask_kv_num_blocks: torch.Tensor, | ||||||||||
| block_mask_kv_indices: torch.Tensor, | ||||||||||
| block_mask_full_kv_num_blocks: torch.Tensor | None, | ||||||||||
| block_mask_full_kv_indices: torch.Tensor | None, | ||||||||||
| block_mask_mask_mod: Callable, | ||||||||||
| block_mask_m: int, | ||||||||||
| block_mask_n: int, | ||||||||||
| scale: float, | ||||||||||
| enable_gqa: bool = False, | ||||||||||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||||||||||
| B, H, M, D = query.size() | ||||||||||
| D = hl.specialize(D) | ||||||||||
| assert key.size() == value.size() | ||||||||||
| Bk, Hk, N, Dk = key.size() | ||||||||||
| assert Bk == B | ||||||||||
| assert Dk == D | ||||||||||
| if enable_gqa: | ||||||||||
| assert H % Hk == 0 | ||||||||||
| num_groups = H // Hk | ||||||||||
| else: | ||||||||||
| assert Hk == H | ||||||||||
| num_groups = 1 | ||||||||||
| out = torch.empty_like(query) | ||||||||||
| lse = torch.empty((B, H, M), dtype=torch.float32, device=out.device) | ||||||||||
| log_2_e = 1.44269504 | ||||||||||
| block_m = hl.register_block_size(min(256, block_mask_m)) | ||||||||||
| block_n = hl.register_block_size(min(256, block_mask_n)) | ||||||||||
| assert (block_mask_full_kv_indices is None) == ( | ||||||||||
| block_mask_full_kv_num_blocks is None | ||||||||||
| ) | ||||||||||
| for tile_b, tile_h, tile_m in hl.tile([B, H, M], block_size=[1, 1, block_m]): | ||||||||||
| m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32) | ||||||||||
| l_i = torch.full_like(m_i, 1.0) | ||||||||||
| acc = hl.zeros([tile_m, D], dtype=torch.float32) | ||||||||||
| q_i = query[tile_b.begin, tile_h.begin, tile_m, :] | ||||||||||
|
|
||||||||||
| # iterate through full tiles | ||||||||||
| if block_mask_full_kv_indices is not None: | ||||||||||
| sparse_row = tile_m.begin // block_mask_m | ||||||||||
| sparse_num_blocks = block_mask_full_kv_num_blocks[ # pyright: ignore[reportOptionalSubscript] | ||||||||||
| tile_b.begin, tile_h.begin, sparse_row | ||||||||||
| ] | ||||||||||
|
|
||||||||||
| for block_idx in hl.tile(sparse_num_blocks, block_size=1): | ||||||||||
| start_n = block_mask_full_kv_indices[ | ||||||||||
| tile_b.begin, tile_h.begin, sparse_row, block_idx.id | ||||||||||
| ] | ||||||||||
| end_n = start_n + block_mask_n | ||||||||||
| end_N = end_n.new_full([], N) | ||||||||||
| end_n = torch.minimum(end_n, end_N) | ||||||||||
|
|
||||||||||
| # figure out how many tiles there are here | ||||||||||
| for tile_n in hl.tile(start_n, end_n, block_size=block_n): | ||||||||||
| k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :] | ||||||||||
| bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[ | ||||||||||
| :, None, None, None | ||||||||||
| ] | ||||||||||
| bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[ | ||||||||||
| None, :, None, None | ||||||||||
| ] | ||||||||||
| bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[ | ||||||||||
| None, None, :, None | ||||||||||
| ] | ||||||||||
| bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[ | ||||||||||
| None, None, None, : | ||||||||||
| ] | ||||||||||
| qk = hl.zeros([tile_m, tile_n], dtype=torch.float32) | ||||||||||
| qk = hl.dot(q_i, k.T, acc=qk) | ||||||||||
| qk = qk * scale | ||||||||||
| bcast_qk = qk[None, None, :, :] | ||||||||||
| score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n) | ||||||||||
|
                    # The following masking code is temporarily disabled for debugging purposes.
                    # Re-enable if masking is required for your use case.
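The diff excerpt ends here; the new file runs to 287 lines in total. For orientation, below is a minimal smoke-test sketch of how a kernel with this signature might be compared against PyTorch's flex_attention. The wiring of the BlockMask fields onto the kernel parameters and the use of _create_empty_block_mask and _identity are assumptions inferred from the imports above, not code from this PR.

# Hypothetical smoke test (illustrative only; the parameter wiring is an assumption).
def _check_against_flex_attention() -> None:
    B, H, M, N, D = 2, 8, 1024, 1024, 64
    q = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    scale = 1.0 / math.sqrt(D)
    # An empty block mask masks nothing out, i.e. dense attention.
    block_mask = _create_empty_block_mask(q, k)
    out, lse = helion_flex_attention_kernel(
        q,
        k,
        v,
        _identity,  # score_mod that returns the score unchanged
        block_mask.kv_num_blocks,
        block_mask.kv_indices,
        block_mask.full_kv_num_blocks,
        block_mask.full_kv_indices,
        block_mask.mask_mod,
        block_mask.BLOCK_SIZE[0],
        block_mask.BLOCK_SIZE[1],
        scale,
    )
    ref = flex_attention(q, k, v, score_mod=_identity, block_mask=block_mask, scale=scale)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)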
Copilot AI reviewed Oct 6, 2025

The indexing logic tile_h.begin // num_groups is duplicated across multiple locations. Consider extracting this into a variable for better maintainability.
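A sketch of what that could look like inside the tile loop; the name kv_head is illustrative and not part of the PR:

# Hypothetical refactor (illustrative): hoist the shared KV-head index once
# and reuse it for every key/value load in both KV passes.
kv_head = tile_h.begin // num_groups
k = key[tile_b.begin, kv_head, tile_n, :]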
Copilot AI reviewed Oct 6, 2025

There's significant code duplication between the two attention computation blocks (lines 92-125 and 144-177). Consider extracting this logic into a helper function to reduce duplication.
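A rough sketch of such a helper, assuming Helion can trace through a plain Python function called from inside the device loops (an assumption, not verified here); it covers only the score computation visible in the excerpt above:

# Hypothetical helper (illustrative only): compute the score_mod-adjusted
# scores for one (query tile, kv tile) pair so the "full" and "partial"
# KV passes can share it.
def _modded_scores(q_i, k, score_mod, scale, tile_b, tile_h, tile_m, tile_n):
    bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[:, None, None, None]
    bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[None, :, None, None]
    bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[None, None, :, None]
    bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[None, None, None, :]
    qk = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    qk = hl.dot(q_i, k.T, acc=qk) * scale
    return score_mod(qk[None, None, :, :], bcast_b, bcast_h, bcast_m, bcast_n)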
Copilot AI reviewed Oct 6, 2025

Replace if True: with a meaningful condition or a comment explaining why this branch is always executed. This makes the code structure unclear.