diff --git a/tests/py/dynamo/automatic_plugin/cutile/__init__.py b/tests/py/dynamo/automatic_plugin/cutile/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/dynamo/automatic_plugin/cutile/attention.py b/tests/py/dynamo/automatic_plugin/cutile/attention.py new file mode 100644 index 0000000000..603c728de3 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/cutile/attention.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import cuda.tile as ct +import numpy as np +from cuda.tile import RoundingMode as RMd + +INV_LOG_2 = 1.0 / math.log(2) + + +# Define type aliases for Constant integers and booleans +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] + + +# --- FMHA Kernel Implementation --- +@ct.kernel(occupancy=2) +def fmha_kernel( + Q, + K, + V, + Out, + qk_scale: float, + input_pos: int, + TILE_D: ConstInt, # TILE_D = hidden_size + H: ConstInt, + TILE_M: ConstInt, + TILE_N: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, + EVEN_K: ConstBool, +): + """ + cuTile kernel for Fused Multi-Head Attention (FMHA). + Computes attention output for a specific batch item and head, using tiling and online softmax. + """ + # Map block IDs to batch and head indices + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + off_kv_h = head_idx // QUERY_GROUP_SIZE + + # Adjust qk_scale for exp2 + qk_scale = qk_scale * INV_LOG_2 + + # Initialize offsets for current query tile (M-dimension) + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + + # Initialize local offsets for key/value tile (N-dimension) + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # Initialize online softmax accumulators in float32 for stability + m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) + l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) + acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) + + # Load query tile for this batch, head, and M-chunk + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) + ).reshape( + (TILE_M, TILE_D) + ) # [TILE_M, TILE_D] + + # loop over k, v and update accumulator + m_end = input_pos + (bid_x + 1) * TILE_M + k_seqlen = K.shape[2] + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdiv(k_seqlen, TILE_N) + mask_start = k_seqlen // TILE_N + + # Loop over K, V blocks (N-dimension chunks) + for j in range(0, Tc): + # --- Compute QK product --- + k = ct.load( + K, + index=(batch_idx, off_kv_h, 0, j), + shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # --- Apply Causal Masking --- + if (CAUSAL or not EVEN_K) and j >= mask_start: + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + # out of bound mask + if not EVEN_K: + mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, 
TILE_N] + qk += mask + + # --- Online Softmax Update --- + # Moving qk_scale multiplication after reduce_max is to improve performance. + m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale) + qk = qk * qk_scale - m_ij # [TILE_M, TILE_N] + + # attention weights + p = ct.exp2(qk, flush_to_zero=True) # [TILE_M, TILE_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [TILE_M, 1] + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [TILE_M, 1] + # update m_i and l_i + l_i = l_i * alpha + l_ij # [TILE_M, 1] + # scale acc + acc = acc * alpha # [TILE_M, TILE_N] + + # --- Compute PV product --- + v = ct.load( + V, + index=(batch_idx, off_kv_h, j, 0), + shape=(1, 1, TILE_N, TILE_D), + latency=4, + ).reshape( + (TILE_N, TILE_D) + ) # [TILE_N, TILE_D] + p = p.astype(Q.dtype) + acc = ct.mma(p, v, acc) # [TILE_M, TILE_N] + m_i = m_ij # [TILE_M, 1] + + # --- Final Normalization and Store --- + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype) + ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) diff --git a/tests/py/dynamo/automatic_plugin/cutile/matmul.py b/tests/py/dynamo/automatic_plugin/cutile/matmul.py new file mode 100644 index 0000000000..f355dd4b5b --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/cutile/matmul.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import cuda.tile as ct + +# Define a type alias for Constant integers. +# This makes kernel signatures cleaner and indicates that these parameters +# are compile-time constants, which cuTile uses for optimization. +ConstInt = ct.Constant[int] + + +def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M): + # Get the global IDs of the current CUDA block (CTA) in a 1D grid. + bid = ct.bid(0) + num_bid_m = ct.cdiv(M, tm) + num_bid_n = ct.cdiv(N, tn) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n + + +@ct.kernel(num_ctas=ct.ByTarget(sm_100=2)) +def matmul_kernel( + A, + B, + C, + tm: ConstInt, # Tile size along M dimension (rows of C) + tn: ConstInt, # Tile size along N dimension (columns of C) + tk: ConstInt, +): # Tile size along K dimension (inner product dimension) + """ + cuTile kernel for performing matrix multiplication C = A @ B. + + This kernel uses a tiled approach, where each CUDA thread block (CTA) + computes a `tm` x `tn` tile of the output matrix C. The computation + involves iterating over the K-dimension in chunks of `tk`. + + Args: + A: Input matrix A (M x K). + B: Input matrix B (K x N). + C: Output matrix C (M x N). + tm (ConstInt): The height of the output tile computed by this block. + Corresponds to rows of A and C. + tn (ConstInt): The width of the output tile computed by this block. + Corresponds to columns of B and C. + tk (ConstInt): The depth of the inner loop (K-dimension) tile size. + Corresponds to columns of A and rows of B. + """ + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + + # Calculate the total number of K-tiles that need to be processed. 
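+ # (Equivalently ceil(A.shape[1] / tk); batch_matmul_kernel below computes the same count directly with ct.cdiv.)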
+ # `ct.num_tiles(A, axis=1, shape=(tm, tk))` extracts the K-dimension (axis 1) + # from matrix A's shape, assuming A's shape is conceptually (M_tiles, K_tiles), + # and then implicitly performs ceiling division by `tk` to get the number of K-tiles. + num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) + + # Initialize an accumulator for the current output tile (tm x tn). + # It's common practice to use `float32` for accumulation even with `float16` inputs + # to maintain higher precision during the sum-reduction of the matrix multiplication. + accumulator = ct.full((tm, tn), 0, dtype=ct.float32) + zero_pad = ct.PaddingMode.ZERO + + # Convert fp32 to tf32 to use tensorcore + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + # K-dimension loop: Iterate over the K-dimension in chunks of 'tk'. + # In each iteration, a `tm` x `tk` tile from A and a `tk` x `tn` tile from B + # are loaded, multiplied, and accumulated. + for k in range(num_tiles_k): + # Load tile from matrix A. + # The `index=(bidx, k_tile_idx)` specifies which (M-tile, K-tile) to load + # from global memory A. `shape=(tm, tk)` defines the size of this tile. + a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype( + dtype + ) + + # Load tile from matrix B. + # The `index=(k_tile_idx, bidy)` specifies which (K-tile, N-tile) to load + # from global memory B. `shape=(tk, tn)` defines the size of this tile. + b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype( + dtype + ) + + # Perform Matrix Multiplication for the current tiles. + # `ct.mma` computes the product of the two loaded tiles and accumulates the result. + accumulator = ct.mma(a, b, accumulator) + + # Convert the final accumulated result to the desired output data type (C.dtype). + # This might downcast from float32 to float16 if the output is float16. + accumulator = ct.astype(accumulator, C.dtype) + + # Store the computed tile to the global memory of the output matrix C. + # The `(bidx, bidy)` directly corresponds to the tile's position in the 2D output matrix. 
+ ct.store(C, index=(bidx, bidy), tile=accumulator) + + +@ct.kernel +def matmul_split_k_kernel( + A, B, C, LOCKS, COUNTS, tm: ConstInt, tn: ConstInt, tk: ConstInt, SPLIT_K: ConstInt +): + GROUP_SIZE_M = 8 + M = A.shape[0] + N = B.shape[1] + bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M) + bidz = ct.bid(1) + + num_tiles = ct.num_tiles(A, axis=1, shape=(tm, tk)) + sum = ct.full((tm, tn), 0, dtype=ct.float32) + zero_pad = ct.PaddingMode.ZERO + + # Convert fp32 to tf32 to use tensorcore + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + + for k in range(bidz, num_tiles, SPLIT_K): + a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype( + dtype + ) + b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype( + dtype + ) + sum = ct.mma(a, b, sum) + + sum = ct.astype(sum, C.dtype) + lock_offset = ct.bid(0) + count_offset = lock_offset + while ( + ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) + == 1 + ): + pass + count = ct.gather(COUNTS, count_offset) + if count == 0: + ct.store(C, index=(bidx, bidy), tile=sum) + else: + curr = ct.load(C, index=(bidx, bidy), shape=(tm, tn)) + ct.store(C, index=(bidx, bidy), tile=(curr + sum)) + ct.scatter(COUNTS, count_offset, (count + 1) % SPLIT_K) + ct.atomic_xchg(LOCKS, lock_offset, 0, memory_order=ct.MemoryOrder.RELEASE) + + +@ct.kernel +def batch_matmul_kernel(A, B, C, tm: ConstInt, tn: ConstInt, tk: ConstInt): + """CuTile kernel for batch matrix multiplication + A has shape (Batch, M, K), B has shape (Batch, K, N) and C has shape (Batch, M, N) + Each thread block computes one (tm x tn) tile for a specific batch item. + The grid is 3D: (Batch_idx, M_tile_idx, N_tile_idx). + """ + pid_batch = ct.bid(0) # Batch dimension + pidx = ct.bid(1) # M dimension + pidy = ct.bid(2) # N dimension + + # Calculate number of K tiles + # A is (Batch, M, K), so K is axis 2 + # Use A.shape[2] for the total K dimension and ct.cdiv for ceiling division + num_k_tiles = ct.cdiv(A.shape[2], tk) + + # Initialize accumulator + accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32) + zero_pad = ct.PaddingMode.ZERO + # K-dimension loop + for k in range(num_k_tiles): + # Load tiles with 3D index and 3D shape + # A is (Batch, M, K), load (1, tm, tk) tile + a = ct.load( + A, index=(pid_batch, pidx, k), shape=(1, tm, tk), padding_mode=zero_pad + ) + a = ct.reshape(a, (tm, tk)) # Reshape to 2D for ct.mma + + # B is (Batch, K, N), load (1, tk, tn) tile + b = ct.load( + B, index=(pid_batch, k, pidy), shape=(1, tk, tn), padding_mode=zero_pad + ) + b = ct.reshape(b, (tk, tn)) # Reshape to 2D for ct.mma + + accumulator = ct.mma(a, b, acc=accumulator) + + # Convert to output dtype and store + result = ct.astype(accumulator, C.dtype) + # Store with 3D index and 3D shape, C is (Batch, M, N) + result_3d = ct.reshape(result, (1, tm, tn)) + ct.store(C, index=(pid_batch, pidx, pidy), tile=result_3d) diff --git a/tests/py/dynamo/automatic_plugin/test_cutile_attention.py b/tests/py/dynamo/automatic_plugin/test_cutile_attention.py new file mode 100644 index 0000000000..d7dfdf5479 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_cutile_attention.py @@ -0,0 +1,228 @@ +import importlib.util +import math +import platform +import unittest + +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt import Input + +if not importlib.util.find_spec("cuda.tile"): + print("cuda.tile is not installed, skipping cuTile tests") +else: + import cuda.tile as ct + + from .cutile.attention import 
fmha_kernel + from .cutile.matmul import matmul_kernel + + def register_cutile_flash_attention(): + + @torch.library.custom_op("cutile::flash_attention", mutates_args=()) # type: ignore[misc] + def cutile_flash_attention( + Q: torch.Tensor, # (batch_size, q_heads, q_len, hidden_size) + K: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + V: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + is_causal: bool = False, + tile_size_m: int = 8, + tile_size_n: int = 16, + ) -> torch.Tensor: + TILE_M, TILE_N = tile_size_m, tile_size_n + batch_size, q_heads, q_len, hidden_size = Q.shape + _, k_heads, k_len, _ = K.shape + query_group_size = q_heads // k_heads + qk_scale = 1 / math.sqrt(hidden_size) + O = torch.zeros_like(Q) + EVEN_K = (k_len % TILE_N) == 0 + grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1) + input_pos = 0 # TODO: figure out how to use the input_pos, for now do not use, set to 0 + ct.launch( + torch.cuda.current_stream(), + grid, + fmha_kernel, + ( + Q, + K, + V, + O, + qk_scale, + input_pos, + hidden_size, + q_heads, + TILE_M, + TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) + return O + + @torch.library.register_fake("cutile::flash_attention") + def _cutile_flash_attention_fake( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + is_causal: bool = False, + tile_size_m: int = 8, + tile_size_n: int = 16, + ) -> torch.Tensor: + return Q + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::flash_attention", supports_dynamic_shapes=True + ) + + class cutile_flash_attention(nn.Module): + def forward(self, Q, K, V): + return torch.ops.cutile.flash_attention.default(Q, K, V, True) + + class torch_flash_attention(nn.Module): + def forward(self, Q, K, V): + return torch.nn.functional.scaled_dot_product_attention( + Q, + K, + V, + attn_mask=None, + dropout_p=0.0, + is_causal=True, + ) + + @unittest.skipIf( + torch.cuda.get_device_capability() < (10, 0), + "cuTile requires compute capability 10.0 or later", + ) + @unittest.skipIf( + not importlib.util.find_spec("cuda.tile"), + "cuda.tile is required to run this test", + ) + @unittest.skipIf( + platform.system() != "Linux", + "cuTile is only supported on Linux for now", + ) + @unittest.skipIf( + torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, + "TensorRT RTX does not support plugins which is required for cuTile", + ) + class TestCutile: + register_cutile_flash_attention() + + def test_cutile_flash_attention(self): + data_type = torch.float32 + inputs = ( + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + ) + enable_cutile, enable_trt_native = True, True + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_flash_attention() + cutile_mod_ep = torch.export.export(cutile_mod, inputs) + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir="debuglogs_cutile_attention", + capture_fx_graph_after=["constant_fold"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=False, + ): + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + inputs, + min_block_size=1, + ) + outputs_cutile = trt_cutile_mod(*inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_flash_attention() + torch_mod_ep = torch.export.export(torch_mod, inputs) + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + inputs, + min_block_size=1, + ) + outputs_trt = 
trt_torch_mod(*inputs) + + if enable_cutile and enable_trt_native: + assert torch.allclose(outputs_cutile, outputs_trt, atol=5e-3, rtol=1e-2) + + def test_cutile_flash_attention_dynamic_shape(self): + data_type = torch.float32 + input_specs = [ + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + Input( + min_shape=(32, 8, 1, 64), + opt_shape=(32, 8, 128, 64), + max_shape=(32, 8, 256, 64), + dtype=data_type, + ), + ] + + compile_inputs = ( + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 128, 64), device="cuda", dtype=data_type).cuda(), + ) + inference_inputs = ( + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + torch.randn((32, 8, 256, 64), device="cuda", dtype=data_type).cuda(), + ) + + enable_cutile, enable_trt_native = True, True + q_len_dim = torch.export.Dim("q_len", min=1, max=256) + dynamic_shapes = { + "Q": {2: q_len_dim}, + "K": {2: q_len_dim}, + "V": {2: q_len_dim}, + } + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_flash_attention() + cutile_mod_ep = torch.export.export( + cutile_mod, + compile_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + input_specs, + min_block_size=1, + enable_precisions={data_type}, + ) + outputs_cutile = trt_cutile_mod(*inference_inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_flash_attention() + torch_mod_ep = torch.export.export( + torch_mod, + compile_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + input_specs, + min_block_size=1, + enable_precisions={data_type}, + ) + outputs_trt = trt_torch_mod(*inference_inputs) + + if enable_cutile and enable_trt_native: + assert torch.allclose(outputs_cutile, outputs_trt, atol=5e-3, rtol=1e-2) diff --git a/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py b/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py new file mode 100644 index 0000000000..28bf7fb1f8 --- /dev/null +++ b/tests/py/dynamo/automatic_plugin/test_cutile_matmul.py @@ -0,0 +1,114 @@ +import importlib.util +import platform +import unittest +from math import ceil + +import torch +import torch.nn as nn +import torch_tensorrt +from parameterized import parameterized + +if not importlib.util.find_spec("cuda.tile"): + print("cuda.tile is not installed, skipping cuTile tests") +else: + import cuda.tile as ct + + from .cutile.matmul import matmul_kernel + + def register_cutile_matmul(): + @torch.library.custom_op("cutile::matmul", mutates_args=()) # type: ignore[misc] + def cutile_matmul( + A: torch.Tensor, + B: torch.Tensor, + tile_size_m: int = 256, + tile_size_n: int = 256, + tile_size_k: int = 64, + ) -> torch.Tensor: + C = torch.empty(A.shape[0], B.shape[1], device=A.device, dtype=A.dtype) + tm, tn, tk = tile_size_m, tile_size_n, tile_size_k + m, n, _ = A.shape[0], B.shape[1], A.shape[1] + grid = (ceil(m / tm) * ceil(n / tn), 1, 1) + ct.launch( + torch.cuda.current_stream(), grid, matmul_kernel, (A, B, C, tm, tn, tk) + ) + return C + + @torch.library.register_fake("cutile::matmul") + def _( + A: torch.Tensor, + B: torch.Tensor, + 
tile_size_m: int = 256, + tile_size_n: int = 256, + tile_size_k: int = 64, + ) -> torch.Tensor: + return torch.empty(A.shape[0], B.shape[1], device=A.device, dtype=A.dtype) + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::matmul", supports_dynamic_shapes=True + ) + + @unittest.skipIf( + torch.cuda.get_device_capability() < (10, 0), + "cuTile requires compute capability 10.0 or later", + ) + @unittest.skipIf( + not importlib.util.find_spec("cuda.tile"), + "cuda.tile is required to run this test", + ) + @unittest.skipIf( + platform.system() != "Linux", + "cuTile is only supported on Linux for now", + ) + @unittest.skipIf( + torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, + "TensorRT RTX does not support plugins which is required for cuTile", + ) + class TestMatmul: + register_cutile_matmul() + + @parameterized.expand( + [ + ((64, 64), (64, 128), torch.float16), + ((256, 256), (256, 16), torch.float16), + ] + ) + def test_matmul(self, a_shape, b_shape, data_type): + class cutile_matmul(nn.Module): + def forward(self, a, b): + return torch.ops.cutile.matmul.default(a, b) + + class torch_matmul(nn.Module): + def forward(self, a, b): + return torch.matmul(a, b) + + inputs = ( + torch.randn(a_shape, device="cuda", dtype=data_type), + torch.randn(b_shape, device="cuda", dtype=data_type), + ) + enable_cutile, enable_trt_native = True, True + if enable_cutile: + with torch.no_grad(): + cutile_mod = cutile_matmul() + cutile_mod_ep = torch.export.export(cutile_mod, inputs) + trt_cutile_mod = torch_tensorrt.dynamo.compile( + cutile_mod_ep, + inputs, + min_block_size=1, + ) + outputs_cutile = trt_cutile_mod(*inputs) + + if enable_trt_native: + with torch.no_grad(): + torch_mod = torch_matmul() + torch_mod_ep = torch.export.export(torch_mod, inputs) + + trt_torch_mod = torch_tensorrt.dynamo.compile( + torch_mod_ep, + inputs, + min_block_size=1, + ) + outputs_trt = trt_torch_mod(*inputs) + print(f"outputs_trt: {outputs_trt.shape}") + + if enable_trt_native and enable_cutile: + assert torch.allclose(outputs_cutile, outputs_trt, atol=1e-4, rtol=1e-4) diff --git a/tools/llm/README.md b/tools/llm/README.md index 05a1e3cc60..d06816e9a3 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -58,6 +58,7 @@ python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --c - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). - `--benchmark`: Enable benchmarking mode. - `--enable_pytorch_run`: Also run and compare PyTorch baseline. +- `--enable_cutile_attention`: use cutile attention kernel registered as QDP of TensorRT. 
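+  For example: `python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --precision FP16 --num_tokens 128 --cache static_v1 --enable_cutile_attention` (requires the `cuda.tile` package).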
### Caching Strategies diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 1531c30622..0f8ef8cb75 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -59,7 +59,7 @@ def get_model(args): .cuda() ) # register SDPA variant for the model - register_sdpa.enable_sdpa_converter(args.model, model.config) + register_sdpa.enable_sdpa_converter(args, model.config) if args.precision == "FP16": model = model.to(torch.float16) @@ -236,6 +236,11 @@ def measure_perf(trt_model, input_signature, backend_name): arg_parser.add_argument( "--benchmark", action="store_true", help="Enable benchmark (default: False)" ) + arg_parser.add_argument( + "--enable_cutile_attention", + action="store_true", + help="Enable cutile attention (default: False)", + ) args = arg_parser.parse_args() with torch.inference_mode(): diff --git a/tools/llm/torchtrt_ext/attention.py b/tools/llm/torchtrt_ext/attention.py new file mode 100644 index 0000000000..0ea7f65ac7 --- /dev/null +++ b/tools/llm/torchtrt_ext/attention.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import cuda.tile as ct +import numpy as np +import math + +from cuda.tile import RoundingMode as RMd + +INV_LOG_2 = 1.0 / math.log(2) + + +# Define type aliases for Constant integers and booleans +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] + + +# --- FMHA Kernel Implementation --- +@ct.kernel(occupancy=2) +def fmha_kernel(Q, K, V, Out, + qk_scale: float, + input_pos: int, + TILE_D: ConstInt, # TILE_D = hidden_size + H: ConstInt, + TILE_M: ConstInt, + TILE_N: ConstInt, + QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, + EVEN_K: ConstBool): + """ + cuTile kernel for Fused Multi-Head Attention (FMHA). + Computes attention output for a specific batch item and head, using tiling and online softmax. 
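+ Q, K and V are laid out as (batch, heads, seq_len, hidden_size); K/V may carry fewer heads than Q,
+ in which case QUERY_GROUP_SIZE (= q_heads // kv_heads) maps each query head to its shared KV head (grouped-query attention).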
+ """ + # Map block IDs to batch and head indices + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + off_kv_h = head_idx // QUERY_GROUP_SIZE + + # Adjust qk_scale for exp2 + qk_scale = qk_scale * INV_LOG_2 + + # Initialize offsets for current query tile (M-dimension) + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + + # Initialize local offsets for key/value tile (N-dimension) + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # Initialize online softmax accumulators in float32 for stability + m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) + l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) + acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) + + # Load query tile for this batch, head, and M-chunk + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) + ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + + # loop over k, v and update accumulator + m_end = input_pos + (bid_x + 1) * TILE_M + k_seqlen = K.shape[2] + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdiv(k_seqlen, TILE_N) + mask_start = k_seqlen // TILE_N + + # Loop over K, V blocks (N-dimension chunks) + for j in range(0, Tc): + # --- Compute QK product --- + k = ct.load( + K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # --- Apply Causal Masking --- + if (CAUSAL or not EVEN_K) and j >= mask_start: + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + # out of bound mask + if not EVEN_K: + mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + qk += mask + + # --- Online Softmax Update --- + # Moving qk_scale multiplication after reduce_max is to improve performance. 
+ m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale) + qk = qk * qk_scale - m_ij # [TILE_M, TILE_N] + + # attention weights + p = ct.exp2(qk, flush_to_zero=True) # [TILE_M, TILE_N] + l_ij = ct.sum(p, axis=-1, keepdims=True) # [TILE_M, 1] + alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # [TILE_M, 1] + # update m_i and l_i + l_i = l_i * alpha + l_ij # [TILE_M, 1] + # scale acc + acc = acc * alpha # [TILE_M, TILE_N] + + # --- Compute PV product --- + v = ct.load( + V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D), + latency=4, + ).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D] + p = p.astype(Q.dtype) + acc = ct.mma(p, v, acc) # [TILE_M, TILE_N] + m_i = m_ij # [TILE_M, 1] + + # --- Final Normalization and Store --- + acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) + acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype) + ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) \ No newline at end of file diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index c86ee6f3a4..1191508e3d 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,3 +1,4 @@ +import argparse import copy import logging import operator @@ -58,6 +59,7 @@ def _process_sdpa_node( settings: CompilationSettings, sliding_window_size: Optional[int] = None, use_gqa: bool = False, + is_cutile_attention: bool = False, ) -> torch.fx.GraphModule: """ Helper function to process SDPA nodes with common logic. @@ -73,7 +75,7 @@ def _process_sdpa_node( settings: TensorRT compilation settings sliding_window_size: Optional sliding window size for models with sliding attention use_gqa: Whether the model uses Grouped Query Attention - + is_cutile_attention: Whether the attention is cutile attention Returns: The modified graph module with SDPA nodes replaced @@ -134,20 +136,29 @@ def _process_sdpa_node( f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, " f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}" ) - - modified_input_args = ( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - ) + if is_cutile_attention: + modified_input_args = ( + query, + key, + value, + is_causal, + ) + call_function_target = torch.ops.cutile.flash_attention.default + else: + modified_input_args = ( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + ) + call_function_target = torch.nn.functional.scaled_dot_product_attention # Create a new node with torch.nn.functional.scaled_dot_product_attention with gm.graph.inserting_after(node): new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, + call_function_target, args=modified_input_args, kwargs={ "scale": node.kwargs.get("scale", None), @@ -282,7 +293,111 @@ def default_sdpa_pass( return gm -def enable_sdpa_converter(model_name: str, model_config: Any) -> None: +def register_cutile_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + """ + Register cutile SDPA pass for models without specific implementations. + + This function creates and registers a default SDPA replacement pass that can be used + for any model type. It provides basic SDPA replacement functionality without + model-specific optimizations. 
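+
+ Registering this pass also defines the cutile::flash_attention custom op and exposes it to
+ TensorRT via torch_tensorrt.dynamo.conversion.plugins.custom_op, so that _process_sdpa_node
+ can target it in place of torch.nn.functional.scaled_dot_product_attention when
+ is_cutile_attention is set.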
+ + Args: + index: Position in the lowering pass list where this pass should be inserted + model_config: The model configuration object (optional, for consistency) + + Example: + # Register default pass at index 0 + register_cutile_sdpa_pass(index=0) + + # Or with model config for consistency + config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + register_cutile_sdpa_pass(index=0, model_config=config) + + Note: + This is a fallback pass that should be used when no model-specific + SDPA pass is available or when you want generic SDPA replacement behavior. + """ + import math + + import cuda.tile as ct + + from .attention import fmha_kernel + + @torch.library.custom_op("cutile::flash_attention", mutates_args=()) # type: ignore[misc] + def cutile_flash_attention( + Q: torch.Tensor, # (batch_size, q_heads, q_len, hidden_size) + K: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + V: torch.Tensor, # (batch_size, k_heads, k_len, hidden_size) + is_causal: bool = False, + tile_size_m: int = 128, + tile_size_n: int = 256, + ) -> torch.Tensor: + TILE_M, TILE_N = tile_size_m, tile_size_n + batch_size, q_heads, q_len, hidden_size = Q.shape + _, k_heads, k_len, _ = K.shape + query_group_size = q_heads // k_heads + qk_scale = 1 / math.sqrt(hidden_size) + O = torch.zeros_like(Q) + EVEN_K = (k_len % TILE_N) == 0 + grid = (math.ceil(q_len / TILE_M), batch_size * q_heads, 1) + input_pos = ( + 0 # TODO: figure out how to use the input_pos, for now do not use, set to 0 + ) + ct.launch( + torch.cuda.current_stream(), + grid, + fmha_kernel, + ( + Q, + K, + V, + O, + qk_scale, + input_pos, + hidden_size, + q_heads, + TILE_M, + TILE_N, + query_group_size, + is_causal, + EVEN_K, + ), + ) + return O + + @torch.library.register_fake("cutile::flash_attention") + def _( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + is_causal: bool = False, + tile_size_m: int = 128, + tile_size_n: int = 256, + ) -> torch.Tensor: + return Q + + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "cutile::flash_attention", supports_dynamic_shapes=True + ) + + @_aten_lowering_pass(index=index, model_config=model_config) + def cutile_sdpa_pass( + gm: torch.fx.GraphModule, + settings: CompilationSettings, + ) -> torch.fx.GraphModule: + """cutile SDPA pass for models without specific implementations.""" + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + # Process the node with default logic + gm = _process_sdpa_node(gm, node, settings) + + clean_up_graph_after_modifications(gm) + logger.debug("Applied cutile SDPA replacement") + return gm + + +def enable_sdpa_converter(args: argparse.Namespace, model_config: Any) -> None: """ Enables the custom SDPA converter for a given model. @@ -300,15 +415,17 @@ def enable_sdpa_converter(model_name: str, model_config: Any) -> None: like sliding window attention. """ _remove_decompositions() - - pass_registrator = _SDPA_MAPPING.get(model_name) - + if args.enable_cutile_attention: + pass_registrator = register_cutile_sdpa_pass + logger.info(f"Registering cutile SDPA lowering pass for model: {args.model}") + else: + pass_registrator = _SDPA_MAPPING.get(args.model) + logger.info(f"Registering specific SDPA lowering pass for model: {args.model}") if pass_registrator: - logger.info(f"Registering specific SDPA lowering pass for model: {model_name}") pass_registrator(model_config=model_config) else: logger.info( - f"No specific SDPA lowering pass for model '{model_name}'. 
" + f"No specific SDPA lowering pass for model '{args.model}'. " "Using default SDPA pass." ) _SDPA_MAPPING["default"](model_config=model_config) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index b9e3506f4b..2b44e99cd0 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -194,7 +194,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = (