From 075a68ba3d899d534b580d46109d8df527023852 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 9 Dec 2025 17:59:45 -0800 Subject: [PATCH] intial checkin for cutile kernel support --- .../automatic_plugin/cutile/__init__.py | 0 .../automatic_plugin/cutile/attention.py | 118 +++++++++ .../dynamo/automatic_plugin/cutile/matmul.py | 94 ++++++++ .../automatic_plugin/test_cutile_attention.py | 228 ++++++++++++++++++ .../automatic_plugin/test_cutile_matmul.py | 114 +++++++++ tools/llm/README.md | 1 + tools/llm/run_llm.py | 7 +- tools/llm/torchtrt_ext/attention.py | 118 +++++++++ tools/llm/torchtrt_ext/register_sdpa.py | 152 ++++++++++-- tools/llm/utils.py | 1 - 10 files changed, 814 insertions(+), 19 deletions(-) create mode 100644 tests/py/dynamo/automatic_plugin/cutile/__init__.py create mode 100644 tests/py/dynamo/automatic_plugin/cutile/attention.py create mode 100644 tests/py/dynamo/automatic_plugin/cutile/matmul.py create mode 100644 tests/py/dynamo/automatic_plugin/test_cutile_attention.py create mode 100644 tests/py/dynamo/automatic_plugin/test_cutile_matmul.py create mode 100644 tools/llm/torchtrt_ext/attention.py diff --git a/tests/py/dynamo/automatic_plugin/cutile/__init__.py b/tests/py/dynamo/automatic_plugin/cutile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/dynamo/automatic_plugin/cutile/attention.py b/tests/py/dynamo/automatic_plugin/cutile/attention.py new file mode 100644 index 0000000000..819abf28a4 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/cutile/attention.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/raw/main/test/kernels/attention.py?ref_type=heads + +import math + +import cuda.tile as ct +import numpy as np +from cuda.tile.numeric_semantics import RoundingMode as RMd + +INV_LOG_2 = 1.0 / math.log(2) + + +@ct.kernel(occupancy=2) +def fmha_kernel( + Q, + K, + V, + Out, + qk_scale: float, + input_pos: int, + TILE_D: ct.Constant[int], # TILE_D = hidden_size + H: ct.Constant[int], + TILE_M: ct.Constant[int], + TILE_N: ct.Constant[int], + QUERY_GROUP_SIZE: ct.Constant[int], + CAUSAL: ct.Constant[bool], + EVEN_K: ct.Constant[bool], +): + bid_x = ct.bid(0) # int + bid_y = ct.bid(1) # int + batch_idx = bid_y // H # int + head_idx = bid_y % H # int + off_kv_h = head_idx // QUERY_GROUP_SIZE # int + qk_scale = qk_scale * INV_LOG_2 + # init offset + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # initialize m, l, acc + m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) + l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) + acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) + # load q + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) + ).reshape( + (TILE_M, TILE_D) + ) # [TILE_M, TILE_D] + + # loop over k, v and update accumulator + m_end = input_pos + (bid_x + 1) * TILE_M # int + k_seqlen = K.shape[2] # int + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdivi(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdivi(k_seqlen, TILE_N) + 
mask_start = k_seqlen // TILE_N + + for j in range(0, Tc): + # -- compute qk ---- + k = ct.load( + K, + index=(batch_idx, off_kv_h, 0, j), + shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + if (CAUSAL or not EVEN_K) and j >= mask_start: + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + # out of bound mask + if not EVEN_K: + mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + qk += mask + # Moving qk_scale multiplication after reduce_max is to improve performance. + m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale) + qk = qk * qk_scale - m_ij # [TILE_M, TILE_N] + + # attention weights + p = ct.exp2(qk, flush_to_zero=True) # [TILE_M, TILE_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [TILE_M, 1] + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [TILE_M, 1] + # update m_i and l_i + l_i = l_i * alpha + l_ij # [TILE_M, 1] + # scale acc + acc = acc * alpha # [TILE_M, TILE_N] + # compute pv + v = ct.load( + V, + index=(batch_idx, off_kv_h, j, 0), + shape=(1, 1, TILE_N, TILE_D), + latency=4, + ).reshape( + (TILE_N, TILE_D) + ) # [TILE_N, TILE_D] + p = p.astype(Q.dtype) + acc = ct.mma(p, v, acc) # [TILE_M, TILE_N] + m_i = m_ij # [TILE_M, 1] + + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype) + ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) diff --git a/tests/py/dynamo/automatic_plugin/cutile/matmul.py b/tests/py/dynamo/automatic_plugin/cutile/matmul.py new file mode 100644 index 0000000000..29d42cfa77 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/cutile/matmul.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/blob/main/test/kernels/matmul.py?ref_type=heads + +import cuda.tile as ct +from cuda.tile.by_target import ByTarget + +ConstInt = ct.Constant[int] + + +def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): + bid = ct.bid(0) + num_bid_m = ct.cdivi(M, tm) + num_bid_n = ct.cdivi(N, tn) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n + + +@ct.kernel(num_ctas=ByTarget(sm_100=2)) +def matmul_kernel(A, B, C, tm: ConstInt, tn: ConstInt, tk: ConstInt): + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + + num_tiles = ct.dim(A, axis=1, shape=(tm, tk)) + sum = ct.full((tm, tn), 0, dtype=ct.float32) + zero_pad = ct.PaddingValue.ZERO + + # Convert fp32 to tf32 to use tensorcore + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + for k in range(num_tiles): + a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_value=zero_pad).astype( + dtype + ) + b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_value=zero_pad).astype( + dtype + ) + sum = ct.mma(a, b, sum) + + sum = ct.astype(sum, C.dtype) + ct.store(C, index=(bidx, bidy), tile=sum) + + +@ct.kernel +def matmul_split_k_kernel( + A, B, C, LOCKS, COUNTS, tm: ConstInt, tn: ConstInt, tk: ConstInt, SPLIT_K: ConstInt +): + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + bidz = ct.bid(1) + + num_tiles = ct.dim(A, axis=1, shape=(tm, tk)) + sum = ct.full((tm, tn), 0, dtype=ct.float32) + zero_pad = ct.PaddingValue.ZERO + + # Convert fp32 to tf32 to use tensorcore + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + for k in range(bidz, num_tiles, SPLIT_K): + a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_value=zero_pad).astype( + dtype + ) + b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_value=zero_pad).astype( + dtype + ) + sum = ct.mma(a, b, sum) + + sum = ct.astype(sum, C.dtype) + lock_offset = ct.bid(0) + count_offset = lock_offset + while ( + ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) + == 1 + ): + pass + count = ct.load_offset(COUNTS, count_offset) + if count == 0: + ct.store(C, index=(bidx, bidy), tile=sum) + else: + curr = ct.load(C, index=(bidx, bidy), shape=(tm, tn)) + ct.store(C, index=(bidx, bidy), tile=(curr + sum)) + ct.store_offset(COUNTS, count_offset, (count + 1) % SPLIT_K) + ct.atomic_xchg(LOCKS, lock_offset, 0, memory_order=ct.MemoryOrder.RELEASE) diff --git a/tests/py/dynamo/automatic_plugin/test_cutile_attention.py b/tests/py/dynamo/automatic_plugin/test_cutile_attention.py new file mode 100644 index 0000000000..d7dfdf5479 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_cutile_attention.py @@ -0,0 +1,228 @@ +import importlib.util +import math +import platform +import unittest + +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt import Input + +if not importlib.util.find_spec("cuda.tile"): + print("cuda.tile is not installed, skipping cuTile tests") +else: + import cuda.tile as ct + + from .cutile.attention import fmha_kernel + from .cutile.matmul import matmul_kernel + + def register_cutile_flash_attention(): + + 
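        # Registration pattern (descriptive note, no behavior change): the code below
        #   1. defines the cutile::flash_attention custom op, which launches the cuTile
        #      fmha_kernel on the current CUDA stream,
        #   2. registers a fake (meta) implementation so the op can be traced and its
        #      output shape propagated during export, and
        #   3. calls torch_tensorrt.dynamo.conversion.plugins.custom_op to generate the
        #      TensorRT QDP plugin and converter for the op.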
@torch.library.custom_op("cutile::flash_attention", mutates_args=()) # type: ignore[misc] + def cutile_flash_attention( + Q: torch.Tensor, # (batch_size, q_heads, q_len, hidden_size) + K: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + V: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + is_causal: bool = False, + tile_size_m: int = 8, + tile_size_n: int = 16, + ) -> torch.Tensor: + TILE_M, TILE_N = tile_size_m, tile_size_n + batch_size, q_heads, q_len, hidden_size = Q.shape + _, k_heads, k_len, _ = K.shape + query_group_size = q_heads // k_heads + qk_scale = 1 / math.sqrt(hidden_size) + O = torch.zeros_like(Q) + EVEN_K = (k_len % TILE_N) == 0 + grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1) + input_pos = 0 # TODO: figure out how to use the input_pos, for now do not use, set to 0 + ct.launch( + torch.cuda.current_stream(), + grid, + fmha_kernel, + ( + Q, + K, + V, + O, + qk_scale, + input_pos, + hidden_size, + q_heads, + TILE_M, + TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) + return O + + @torch.library.register_fake("cutile::flash_attention") + def _cutile_flash_attention_fake( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + is_causal: bool = False, + tile_size_m: int = 8, + tile_size_n: int = 16, + ) -> torch.Tensor: + return Q + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::flash_attention", supports_dynamic_shapes=True + ) + + class cutile_flash_attention(nn.Module): + def forward(self, Q, K, V): + return torch.ops.cutile.flash_attention.default(Q, K, V, True) + + class torch_flash_attention(nn.Module): + def forward(self, Q, K, V): + return torch.nn.functional.scaled_dot_product_attention( + Q, + K, + V, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + ) + + @unittest.skipIf( + torch.cuda.get_device_capability() < (10, 0), + "cuTile requires compute capability 10.0 or later", + ) + @unittest.skipIf( + not importlib.util.find_spec("cuda.tile"), + "cuda.tile is required to run this test", + ) + @unittest.skipIf( + platform.system() != "Linux", + "cuTile is only supported on Linux for now", + ) + @unittest.skipIf( + torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, + "TensorRT RTX does not support plugins which is required for cuTile", + ) + class TestCutile: + register_cutile_flash_attention() + + def test_cutile_flash_attention(self): + data_type = torch.float32 + inputs = ( + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + ) + enable_cutile, enable_trt_native = True, True + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_flash_attention() + cutile_mod_ep = torch.export.export(cutile_mod, inputs) + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir="debuglogs_cutile_attention", + capture_fx_graph_after=["constant_fold"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=False, + ): + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + inputs, + min_block_size=1, + ) + outputs_cutile = trt_cutile_mod(*inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_flash_attention() + torch_mod_ep = torch.export.export(torch_mod, inputs) + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + inputs, + min_block_size=1, + ) + outputs_trt = trt_torch_mod(*inputs) + + if enable_cutile and enable_trt_native: + assert torch.allclose(outputs_cutile, 
outputs_trt, atol=5e-3, rtol=1e-2) + + def test_cutile_flash_attention_dynamic_shape(self): + data_type = torch.float32 + input_specs = [ + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + ] + + compile_inputs = ( + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + ) + inference_inputs = ( + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + ) + + enable_cutile, enable_trt_native = True, True + q_len_dim = torch.export.Dim("q_len", min=1, max=256) + dynamic_shapes = { + "Q": {2: q_len_dim}, + "K": {2: q_len_dim}, + "V": {2: q_len_dim}, + } + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_flash_attention() + cutile_mod_ep = torch.export.export( + cutile_mod, + compile_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + input_specs, + min_block_size=1, + enable_precisions={data_type}, + ) + outputs_cutile = trt_cutile_mod(*inference_inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_flash_attention() + torch_mod_ep = torch.export.export( + torch_mod, + compile_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + input_specs, + min_block_size=1, + enable_precisions={data_type}, + ) + outputs_trt = trt_torch_mod(*inference_inputs) + + if enable_cutile and enable_trt_native: + assert torch.allclose(outputs_cutile, outputs_trt, atol=5e-3, rtol=1e-2) diff --git a/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py b/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py new file mode 100644 index 0000000000..28bf7fb1f8 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py @@ -0,0 +1,114 @@ +import importlib.util +import platform +import unittest +from math import ceil + +import torch +import torch.nn as nn +import torch_tensorrt +from parameterized import parameterized + +if not importlib.util.find_spec("cuda.tile"): + print("cuda.tile is not installed, skipping cuTile tests") +else: + import cuda.tile as ct + + from .cutile.matmul import matmul_kernel + + def register_cutile_matmul(): + @torch.library.custom_op("cutile::matmul", mutates_args=()) # type: ignore[misc] + def cutile_matmul( + A: torch.Tensor, + B: torch.Tensor, + tile_size_m: int = 256, + tile_size_n: int = 256, + tile_size_k: int = 64, + ) -> torch.Tensor: + C = torch.empty(A.shape[0], B.shape[1], device=A.device, dtype=A.dtype) + tm, tn, tk = tile_size_m, tile_size_n, tile_size_k + m, n, _ = A.shape[0], B.shape[1], A.shape[1] + grid = (ceil(m / tm) * ceil(n / tn), 1, 1) + ct.launch( + torch.cuda.current_stream(), grid, matmul_kernel, (A, B, C, tm, tn, tk) + ) + return C + + @torch.library.register_fake("cutile::matmul") + def _( + A: torch.Tensor, + B: torch.Tensor, + tile_size_m: int = 256, + tile_size_n: int = 256, + tile_size_k: int = 64, + ) -> torch.Tensor: + return 
torch.empty(A.shape[0], B.shape[1], device=A.device, dtype=A.dtype) + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::matmul", supports_dynamic_shapes=True + ) + + @unittest.skipIf( + torch.cuda.get_device_capability() < (10, 0), + "cuTile requires compute capability 10.0 or later", + ) + @unittest.skipIf( + not importlib.util.find_spec("cuda.tile"), + "cuda.tile is required to run this test", + ) + @unittest.skipIf( + platform.system() != "Linux", + "cuTile is only supported on Linux for now", + ) + @unittest.skipIf( + torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, + "TensorRT RTX does not support plugins which is required for cuTile", + ) + class TestMatmul: + register_cutile_matmul() + + @parameterized.expand( + [ + ((64, 64), (64, 128), torch.float16), + ((256, 256), (256, 16), torch.float16), + ] + ) + def test_matmul(self, a_shape, b_shape, data_type): + class cutile_matmul(nn.Module): + def forward(self, a, b): + return torch.ops.cutile.matmul.default(a, b) + + class torch_matmul(nn.Module): + def forward(self, a, b): + return torch.matmul(a, b) + + inputs = ( + torch.randn(a_shape, device="cuda", dtype=data_type), + torch.randn(b_shape, device="cuda", dtype=data_type), + ) + enable_cutile, enable_trt_native = True, True + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_matmul() + cutile_mod_ep = torch.export.export(cutile_mod, inputs) + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + inputs, + min_block_size=1, + ) + outputs_cutile = trt_cutile_mod(*inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_matmul() + torch_mod_ep = torch.export.export(torch_mod, inputs) + + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + inputs, + min_block_size=1, + ) + outputs_trt = trt_torch_mod(*inputs) + print(f"outputs_trt: {outputs_trt.shape}") + + if enable_trt_native and enable_cutile: + assert torch.allclose(outputs_cutile, outputs_trt, atol=1e-4, rtol=1e-4) diff --git a/tools/llm/README.md b/tools/llm/README.md index 05a1e3cc60..d06816e9a3 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -58,6 +58,7 @@ python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --c - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). - `--benchmark`: Enable benchmarking mode. - `--enable_pytorch_run`: Also run and compare PyTorch baseline. +- `--enable_cutile_attention`: use cutile attention kernel registered as QDP of TensorRT. 
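+  For example (model shown for illustration), the flag can be added to a normal `run_llm.py` invocation: `python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --precision FP16 --enable_cutile_attention`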
### Caching Strategies diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 1531c30622..0f8ef8cb75 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -59,7 +59,7 @@ def get_model(args): .cuda() ) # register SDPA variant for the model - register_sdpa.enable_sdpa_converter(args.model, model.config) + register_sdpa.enable_sdpa_converter(args, model.config) if args.precision == "FP16": model = model.to(torch.float16) @@ -236,6 +236,11 @@ def measure_perf(trt_model, input_signature, backend_name): arg_parser.add_argument( "--benchmark", action="store_true", help="Enable benchmark (default: False)" ) + arg_parser.add_argument( + "--enable_cutile_attention", + action="store_true", + help="Enable cutile attention (default: False)", + ) args = arg_parser.parse_args() with torch.inference_mode(): diff --git a/tools/llm/torchtrt_ext/attention.py b/tools/llm/torchtrt_ext/attention.py new file mode 100644 index 0000000000..819abf28a4 --- /dev/null +++ b/tools/llm/torchtrt_ext/attention.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/raw/main/test/kernels/attention.py?ref_type=heads + +import math + +import cuda.tile as ct +import numpy as np +from cuda.tile.numeric_semantics import RoundingMode as RMd + +INV_LOG_2 = 1.0 / math.log(2) + + +@ct.kernel(occupancy=2) +def fmha_kernel( + Q, + K, + V, + Out, + qk_scale: float, + input_pos: int, + TILE_D: ct.Constant[int], # TILE_D = hidden_size + H: ct.Constant[int], + TILE_M: ct.Constant[int], + TILE_N: ct.Constant[int], + QUERY_GROUP_SIZE: ct.Constant[int], + CAUSAL: ct.Constant[bool], + EVEN_K: ct.Constant[bool], +): + bid_x = ct.bid(0) # int + bid_y = ct.bid(1) # int + batch_idx = bid_y // H # int + head_idx = bid_y % H # int + off_kv_h = head_idx // QUERY_GROUP_SIZE # int + qk_scale = qk_scale * INV_LOG_2 + # init offset + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # initialize m, l, acc + m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) + l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) + acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) + # load q + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) + ).reshape( + (TILE_M, TILE_D) + ) # [TILE_M, TILE_D] + + # loop over k, v and update accumulator + m_end = input_pos + (bid_x + 1) * TILE_M # int + k_seqlen = K.shape[2] # int + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdivi(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdivi(k_seqlen, TILE_N) + mask_start = k_seqlen // TILE_N + + for j in range(0, Tc): + # -- compute qk ---- + k = ct.load( + K, + index=(batch_idx, off_kv_h, 0, j), + shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + if (CAUSAL or not EVEN_K) and j >= mask_start: + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + # out of bound mask + if not EVEN_K: 
+ mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + qk += mask + # Moving qk_scale multiplication after reduce_max is to improve performance. + m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale) + qk = qk * qk_scale - m_ij # [TILE_M, TILE_N] + + # attention weights + p = ct.exp2(qk, flush_to_zero=True) # [TILE_M, TILE_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [TILE_M, 1] + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [TILE_M, 1] + # update m_i and l_i + l_i = l_i * alpha + l_ij # [TILE_M, 1] + # scale acc + acc = acc * alpha # [TILE_M, TILE_N] + # compute pv + v = ct.load( + V, + index=(batch_idx, off_kv_h, j, 0), + shape=(1, 1, TILE_N, TILE_D), + latency=4, + ).reshape( + (TILE_N, TILE_D) + ) # [TILE_N, TILE_D] + p = p.astype(Q.dtype) + acc = ct.mma(p, v, acc) # [TILE_M, TILE_N] + m_i = m_ij # [TILE_M, 1] + + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype) + ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index c86ee6f3a4..8e8d916a62 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,3 +1,4 @@ +import argparse import copy import logging import operator @@ -58,6 +59,7 @@ def _process_sdpa_node( settings: CompilationSettings, sliding_window_size: Optional[int] = None, use_gqa: bool = False, + is_cutile_attention: bool = False, ) -> torch.fx.GraphModule: """ Helper function to process SDPA nodes with common logic. @@ -73,7 +75,7 @@ def _process_sdpa_node( settings: TensorRT compilation settings sliding_window_size: Optional sliding window size for models with sliding attention use_gqa: Whether the model uses Grouped Query Attention - + is_cutile_attention: Whether the attention is cutile attention Returns: The modified graph module with SDPA nodes replaced @@ -134,20 +136,29 @@ def _process_sdpa_node( f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, " f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}" ) - - modified_input_args = ( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - ) + if is_cutile_attention: + modified_input_args = ( + query, + key, + value, + is_causal, + ) + call_function_target = torch.ops.cutile.flash_attention.default + else: + modified_input_args = ( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + ) + call_function_target = torch.nn.functional.scaled_dot_product_attention # Create a new node with torch.nn.functional.scaled_dot_product_attention with gm.graph.inserting_after(node): new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, + call_function_target, args=modified_input_args, kwargs={ "scale": node.kwargs.get("scale", None), @@ -282,7 +293,111 @@ def default_sdpa_pass( return gm -def enable_sdpa_converter(model_name: str, model_config: Any) -> None: +def register_cutile_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + """ + Register cutile SDPA pass for models without specific implementations. + + This function creates and registers a default SDPA replacement pass that can be used + for any model type. It provides basic SDPA replacement functionality without + model-specific optimizations. 
+ + Args: + index: Position in the lowering pass list where this pass should be inserted + model_config: The model configuration object (optional, for consistency) + + Example: + # Register default pass at index 0 + register_cutile_sdpa_pass(index=0) + + # Or with model config for consistency + config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + register_cutile_sdpa_pass(index=0, model_config=config) + + Note: + This is a fallback pass that should be used when no model-specific + SDPA pass is available or when you want generic SDPA replacement behavior. + """ + import math + + import cuda.tile as ct + + from .attention import fmha_kernel + + @torch.library.custom_op("cutile::flash_attention", mutates_args=()) # type: ignore[misc] + def cutile_flash_attention( + Q: torch.Tensor, # (batch_size, q_heads, q_len, hidden_size) + K: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + V: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + is_causal: bool = False, + tile_size_m: int = 128, + tile_size_n: int = 256, + ) -> torch.Tensor: + TILE_M, TILE_N = tile_size_m, tile_size_n + batch_size, q_heads, q_len, hidden_size = Q.shape + _, k_heads, k_len, _ = K.shape + query_group_size = q_heads // k_heads + qk_scale = 1 / math.sqrt(hidden_size) + O = torch.zeros_like(Q) + EVEN_K = (k_len % TILE_N) == 0 + grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1) + input_pos = ( + 0 # TODO: figure out how to use the input_pos, for now do not use, set to 0 + ) + ct.launch( + torch.cuda.current_stream(), + grid, + fmha_kernel, + ( + Q, + K, + V, + O, + qk_scale, + input_pos, + hidden_size, + q_heads, + TILE_M, + TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) + return O + + @torch.library.register_fake("cutile::flash_attention") + def _( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + is_causal: bool = False, + tile_size_m: int = 128, + tile_size_n: int = 256, + ) -> torch.Tensor: + return Q + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::flash_attention", supports_dynamic_shapes=True + ) + + @_aten_lowering_pass(index=index, model_config=model_config) + def cutile_sdpa_pass( + gm: torch.fx.GraphModule, + settings: CompilationSettings, + ) -> torch.fx.GraphModule: + """cutile SDPA pass for models without specific implementations.""" + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + # Process the node with default logic + gm = _process_sdpa_node(gm, node, settings) + + clean_up_graph_after_modifications(gm) + logger.debug("Applied cutile SDPA replacement") + return gm + + +def enable_sdpa_converter(args: argparse.Namespace, model_config: Any) -> None: """ Enables the custom SDPA converter for a given model. @@ -300,15 +415,17 @@ def enable_sdpa_converter(model_name: str, model_config: Any) -> None: like sliding window attention. """ _remove_decompositions() - - pass_registrator = _SDPA_MAPPING.get(model_name) - + if args.enable_cutile_attention: + pass_registrator = register_cutile_sdpa_pass + logger.info(f"Registering cutile SDPA lowering pass for model: {args.model}") + else: + pass_registrator = _SDPA_MAPPING.get(args.model) + logger.info(f"Registering specific SDPA lowering pass for model: {args.model}") if pass_registrator: - logger.info(f"Registering specific SDPA lowering pass for model: {model_name}") pass_registrator(model_config=model_config) else: logger.info( - f"No specific SDPA lowering pass for model '{model_name}'. 
" + f"No specific SDPA lowering pass for model '{args.model}'. " "Using default SDPA pass." ) _SDPA_MAPPING["default"](model_config=model_config) @@ -317,5 +434,6 @@ def enable_sdpa_converter(model_name: str, model_config: Any) -> None: # Global registry for SDPA passes _SDPA_MAPPING: Dict[str, Callable] = { "google/gemma-3-1b-it": register_gemma3_sdpa_pass, + # "Qwen/Qwen2.5-0.5B-Instruct": register_cutile_sdpa_pass, "default": register_default_sdpa_pass, } diff --git a/tools/llm/utils.py b/tools/llm/utils.py index b9e3506f4b..2b44e99cd0 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -194,7 +194,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = (