
Commit 075a68b

initial checkin for cutile kernel support
1 parent 9bba6d7 commit 075a68b

File tree

10 files changed: +814 -19 lines changed


tests/py/dynamo/automatic_plugin/cutile/__init__.py

Whitespace-only changes.
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/raw/main/test/kernels/attention.py?ref_type=heads

import math

import cuda.tile as ct
import numpy as np
from cuda.tile.numeric_semantics import RoundingMode as RMd

INV_LOG_2 = 1.0 / math.log(2)


@ct.kernel(occupancy=2)
def fmha_kernel(
    Q,
    K,
    V,
    Out,
    qk_scale: float,
    input_pos: int,
    TILE_D: ct.Constant[int],  # TILE_D = hidden_size
    H: ct.Constant[int],
    TILE_M: ct.Constant[int],
    TILE_N: ct.Constant[int],
    QUERY_GROUP_SIZE: ct.Constant[int],
    CAUSAL: ct.Constant[bool],
    EVEN_K: ct.Constant[bool],
):
    bid_x = ct.bid(0)  # int
    bid_y = ct.bid(1)  # int
    batch_idx = bid_y // H  # int
    head_idx = bid_y % H  # int
    off_kv_h = head_idx // QUERY_GROUP_SIZE  # int
    qk_scale = qk_scale * INV_LOG_2
    # init offset
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32)  # [TILE_M]
    offs_m += input_pos
    offs_m = offs_m[:, None]  # [TILE_M, 1]
    offs_n_tile = ct.arange(TILE_N, dtype=np.int32)  # [TILE_N]
    offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]

    # initialize m, l, acc
    m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
    # load q
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
    ).reshape(
        (TILE_M, TILE_D)
    )  # [TILE_M, TILE_D]

    # loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M  # int
    k_seqlen = K.shape[2]  # int
    if CAUSAL:
        # when kv pos could exceed q pos
        mask_start = (input_pos + bid_x * TILE_M) // TILE_N
        # when kv pos could exceed k_seqlen
        mask_start = min(mask_start, k_seqlen // TILE_N)
        Tc = ct.cdivi(min(m_end, k_seqlen), TILE_N)
    else:
        Tc = ct.cdivi(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        # -- compute qk ----
        k = ct.load(
            K,
            index=(batch_idx, off_kv_h, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool)
            # out of bound mask
            if not EVEN_K:
                mask = mask & (offs_n < k_seqlen)
            # causal mask
            if CAUSAL:
                mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
            mask = ct.where(mask, 0.0, -np.inf)  # [TILE_M, TILE_N]
            qk += mask
        # Multiplying by qk_scale after reduce_max (rather than before) improves performance.
        m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
        qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]

        # attention weights
        p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
        l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
        # update m_i and l_i
        l_i = l_i * alpha + l_ij  # [TILE_M, 1]
        # scale acc
        acc = acc * alpha  # [TILE_M, TILE_D]
        # compute pv
        v = ct.load(
            V,
            index=(batch_idx, off_kv_h, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4,
        ).reshape(
            (TILE_N, TILE_D)
        )  # [TILE_N, TILE_D]
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_D]
        m_i = m_ij  # [TILE_M, 1]

    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
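
The kernel above is a flash-attention-style streaming softmax: for each K/V tile it keeps a running row maximum m_i and normalizer l_i, rescales the partial output by alpha = 2^(m_i - m_ij) whenever the maximum grows, and divides by l_i only once at the end. It works in base 2 (qk_scale is pre-multiplied by INV_LOG_2 so ct.exp2 stands in for exp), and off_kv_h maps query heads onto shared K/V heads for grouped-query attention. The result should therefore match plain softmax attention; below is a minimal NumPy reference for a single (batch, head) slice, assuming input_pos == 0. The helper name fmha_reference is illustrative only and is not part of the commit.

import numpy as np


def fmha_reference(q, k, v, qk_scale, causal=True):
    """Plain softmax attention for one head; q: [M, D], k and v: [N, D]."""
    scores = (q.astype(np.float32) @ k.astype(np.float32).T) * qk_scale  # [M, N]
    if causal:
        # query position i may only attend to key positions j <= i (input_pos == 0)
        keep = np.tril(np.ones(scores.shape, dtype=bool))
        scores = np.where(keep, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # subtract the row max for stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return (p @ v.astype(np.float32)).astype(v.dtype)  # [M, D]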
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

# copied from https://gitlab-master.nvidia.com/cuda-python/cuda-python-tile-compiler/-/blob/main/test/kernels/matmul.py?ref_type=heads

import cuda.tile as ct
from cuda.tile.by_target import ByTarget

ConstInt = ct.Constant[int]


def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M):
    bid = ct.bid(0)
    num_bid_m = ct.cdivi(M, tm)
    num_bid_n = ct.cdivi(N, tn)
    num_bid_in_group = GROUP_SIZE_M * num_bid_n
    group_id = bid // num_bid_in_group
    first_bid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M)
    bid_m = first_bid_m + (bid % group_size_m)
    bid_n = (bid % num_bid_in_group) // group_size_m
    return bid_m, bid_n


@ct.kernel(num_ctas=ByTarget(sm_100=2))
def matmul_kernel(A, B, C, tm: ConstInt, tn: ConstInt, tk: ConstInt):
    GROUP_SIZE_M = 8
    M = A.shape[0]
    N = B.shape[1]
    bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M)

    num_tiles = ct.dim(A, axis=1, shape=(tm, tk))
    sum = ct.full((tm, tn), 0, dtype=ct.float32)
    zero_pad = ct.PaddingValue.ZERO

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype

    for k in range(num_tiles):
        a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_value=zero_pad).astype(
            dtype
        )
        b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_value=zero_pad).astype(
            dtype
        )
        sum = ct.mma(a, b, sum)

    sum = ct.astype(sum, C.dtype)
    ct.store(C, index=(bidx, bidy), tile=sum)


@ct.kernel
def matmul_split_k_kernel(
    A, B, C, LOCKS, COUNTS, tm: ConstInt, tn: ConstInt, tk: ConstInt, SPLIT_K: ConstInt
):
    GROUP_SIZE_M = 8
    M = A.shape[0]
    N = B.shape[1]
    bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M)
    bidz = ct.bid(1)

    num_tiles = ct.dim(A, axis=1, shape=(tm, tk))
    sum = ct.full((tm, tn), 0, dtype=ct.float32)
    zero_pad = ct.PaddingValue.ZERO

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype

    for k in range(bidz, num_tiles, SPLIT_K):
        a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_value=zero_pad).astype(
            dtype
        )
        b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_value=zero_pad).astype(
            dtype
        )
        sum = ct.mma(a, b, sum)

    sum = ct.astype(sum, C.dtype)
    lock_offset = ct.bid(0)
    count_offset = lock_offset
    while (
        ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE)
        == 1
    ):
        pass
    count = ct.load_offset(COUNTS, count_offset)
    if count == 0:
        ct.store(C, index=(bidx, bidy), tile=sum)
    else:
        curr = ct.load(C, index=(bidx, bidy), shape=(tm, tn))
        ct.store(C, index=(bidx, bidy), tile=(curr + sum))
    ct.store_offset(COUNTS, count_offset, (count + 1) % SPLIT_K)
    ct.atomic_xchg(LOCKS, lock_offset, 0, memory_order=ct.MemoryOrder.RELEASE)
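
swizzle_2d is the usual grouped launch-order trick: it remaps the linear block id so that up to GROUP_SIZE_M consecutive row tiles are processed before moving to the next column tile, so a small group of blocks shares the same A rows and B columns and gets better L2 reuse. matmul_split_k_kernel additionally splits the reduction (K) dimension across SPLIT_K blocks along grid dimension 1: each block accumulates every SPLIT_K-th K-tile, and the LOCKS/COUNTS pair serializes the read-modify-write of C so the partial results are summed. The sketch below mirrors that decomposition sequentially in NumPy; split_k_reference, tk, and split_k are illustrative names and defaults, not part of the commit.

import numpy as np


def split_k_reference(A, B, tk=32, split_k=4):
    """Sum of SPLIT_K partial products over interleaved K-tiles, as in the kernel."""
    M, K = A.shape
    _, N = B.shape
    num_tiles = (K + tk - 1) // tk
    C = np.zeros((M, N), dtype=np.float32)
    for z in range(split_k):  # plays the role of bidz = ct.bid(1)
        partial = np.zeros((M, N), dtype=np.float32)
        for t in range(z, num_tiles, split_k):  # every SPLIT_K-th K-tile
            a = A[:, t * tk : (t + 1) * tk].astype(np.float32)
            b = B[t * tk : (t + 1) * tk, :].astype(np.float32)
            partial += a @ b
        C += partial  # in the kernel this is the lock-guarded store of curr + sum
    return C

For float32 inputs this agrees with the direct product up to rounding, e.g. np.allclose(split_k_reference(A, B), A.astype(np.float32) @ B.astype(np.float32), rtol=1e-4).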
