# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/raw/main/test/kernels/attention.py?ref_type=heads

import math

import cuda.tile as ct
import numpy as np
from cuda.tile.numeric_semantics import RoundingMode as RMd

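# exp(x) == exp2(x * log2(e)) and log2(e) == 1 / ln(2): folding this constant
# into qk_scale lets the softmax below use the fast exp2 instruction.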
INV_LOG_2 = 1.0 / math.log(2)


@ct.kernel(occupancy=2)
def fmha_kernel(
    Q,
    K,
    V,
    Out,
    qk_scale: float,
    input_pos: int,
    TILE_D: ct.Constant[int],  # TILE_D = hidden_size
    H: ct.Constant[int],
    TILE_M: ct.Constant[int],
    TILE_N: ct.Constant[int],
    QUERY_GROUP_SIZE: ct.Constant[int],
    CAUSAL: ct.Constant[bool],
    EVEN_K: ct.Constant[bool],
):
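    """FlashAttention-style fused attention with an online, base-2 softmax.

    Each block produces one TILE_M x TILE_D output tile for a single
    (batch, head) pair: bid(0) indexes the query tile, bid(1) encodes
    batch * H + head. Running statistics m_i (row max) and l_i (softmax
    denominator) are carried across KV tiles, so the full attention
    matrix is never materialized.
    """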
    bid_x = ct.bid(0)  # int
    bid_y = ct.bid(1)  # int
    batch_idx = bid_y // H  # int
    head_idx = bid_y % H  # int
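    # Grouped-query attention: every QUERY_GROUP_SIZE query heads share one
    # KV head, so the KV head index is derived from the query head index.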
    off_kv_h = head_idx // QUERY_GROUP_SIZE  # int
    qk_scale = qk_scale * INV_LOG_2
    # init offsets
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32)  # [TILE_M]
    offs_m += input_pos
    offs_m = offs_m[:, None]  # [TILE_M, 1]
    offs_n_tile = ct.arange(TILE_N, dtype=np.int32)  # [TILE_N]
    offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]

    # initialize m, l, acc
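    # Online-softmax state: m_i is the running row max of the scaled logits,
    # l_i the running softmax denominator, acc the unnormalized output.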
    m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
    # load q
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]

    # loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M  # int
    k_seqlen = K.shape[2]  # int
    if CAUSAL:
        # first tile in which a kv position could exceed a q position
        mask_start = (input_pos + bid_x * TILE_M) // TILE_N
        # ...and in which a kv position could exceed k_seqlen
        mask_start = min(mask_start, k_seqlen // TILE_N)
        Tc = ct.cdivi(min(m_end, k_seqlen), TILE_N)
    else:
        Tc = ct.cdivi(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

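    # Tc: number of KV tiles this query tile attends to; tiles at or past
    # mask_start may contain future or out-of-range positions and get masked.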
    for j in range(0, Tc):
        # -- compute qk ----
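        # The K tile is loaded with its last two dims swapped (order), so
        # ct.mma(q, k, qk) computes q @ k_tile^T directly.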
        k = ct.load(
            K,
            index=(batch_idx, off_kv_h, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_)
            # out-of-bounds mask
            if not EVEN_K:
                mask = mask & (offs_n < k_seqlen)
            # causal mask
            if CAUSAL:
                mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
            mask = ct.where(mask, 0.0, -np.inf)  # [TILE_M, TILE_N]
            qk += mask
        # Multiplying by qk_scale after (not before) the reduce_max
        # improves performance.
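        # m_ij is the updated running max; subtracting it keeps the exp2
        # arguments <= 0, so the exponentials below cannot overflow.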
        m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
        qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]

        # attention weights
        p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
        l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
        # update m_i and l_i
        l_i = l_i * alpha + l_ij  # [TILE_M, 1]
        # rescale acc: alpha corrects tiles accumulated under the previous max
        acc = acc * alpha  # [TILE_M, TILE_D]
        # compute pv
        v = ct.load(
            V,
            index=(batch_idx, off_kv_h, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4,
        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_D]
        m_i = m_ij  # [TILE_M, 1]

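    # Final softmax normalization: divide the accumulated P @ V by the
    # softmax denominator (approximate, flush-to-zero division for speed).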
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
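

# For sanity-checking: a minimal NumPy reference of the computation the kernel
# above performs (scaled-dot-product attention with optional causal masking
# and grouped-query heads). This is an illustrative sketch, not part of the
# original file; the function name and layout assumptions (Q: [B, H, S, D],
# K/V: [B, H_kv, S_kv, D], input_pos == 0) are ours.
def attention_reference(q, k, v, qk_scale, causal=True):
    B, H, S, D = q.shape
    H_kv, S_kv = k.shape[1], k.shape[2]
    group = H // H_kv  # QUERY_GROUP_SIZE
    out = np.empty_like(q, dtype=np.float32)
    for b in range(B):
        for h in range(H):
            kv = h // group  # KV head shared by this query head
            scores = q[b, h].astype(np.float32) @ k[b, kv].astype(np.float32).T
            scores *= qk_scale
            if causal:
                # query i attends to kv positions j <= i (matches the kernel's
                # offs_m >= offs_n mask when input_pos == 0)
                mask = np.tril(np.ones((S, S_kv), dtype=bool))
                scores = np.where(mask, scores, -np.inf)
            scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
            p = np.exp(scores)
            p /= p.sum(axis=-1, keepdims=True)
            out[b, h] = p @ v[b, kv].astype(np.float32)
    return out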