
Commit a7d7a58

fix bugs with triton3.4.0 (#3946)
* fix bugs with triton3.4.0
* random seed
* bug fixing
1 parent 735f3d6 commit a7d7a58


10 files changed: +44 additions, -133 deletions


lmdeploy/pytorch/check_env/triton.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 from .base import BaseChecker
 
-MAX_TRITON_VERSION = '3.3.1'
+MAX_TRITON_VERSION = '3.4.0'
 MIN_TRITON_VERSION = '3.0.0'
 
 
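The only change here is raising the supported ceiling to the new Triton release. For reference, a minimal sketch (not the project's actual checker) of how a min/max gate like the constants above can be enforced with packaging.version:

import triton
from packaging import version

MIN_TRITON_VERSION = '3.0.0'
MAX_TRITON_VERSION = '3.4.0'


def check_triton_version() -> None:
    # Raise if the installed triton falls outside the supported range.
    v = version.parse(triton.__version__)
    if not (version.parse(MIN_TRITON_VERSION) <= v <= version.parse(MAX_TRITON_VERSION)):
        raise RuntimeError(f'triton {v} is outside the supported range '
                           f'[{MIN_TRITON_VERSION}, {MAX_TRITON_VERSION}].')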
lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def apply_rotary_pos_emb_qk_kernel(
     feat_mask = feat_offset_l < half_size
     feat_offset_l = feat_offset_l % half_size
     feat_offset_h = half_size + feat_offset_l
-    seq_mask = pos_mask[:, None] and feat_mask[None, :]
+    seq_mask = pos_mask[:, None] & feat_mask[None, :]
     cs_offset_l = pos_offset[:, None] * feat_size + feat_offset_l[None, :]
     cs_offset_h = pos_offset[:, None] * feat_size + feat_offset_h[None, :]
     q_elem_type = Q.dtype.element_ty
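Most kernel hunks in this commit follow the pattern above: boolean block masks that used to be combined with Python's `and` are switched to the elementwise `&` operator, which the diffs suggest is what Triton 3.4.0 requires for tensor operands. A minimal standalone sketch of the pattern; the kernel name and the even-index condition are illustrative, not from this repo:

import torch
import triton
import triton.language as tl


@triton.jit
def _masked_copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = offs < n_elements        # boolean tensor
    is_even = (offs % 2) == 0            # another boolean tensor
    mask = in_bounds & is_even           # elementwise AND; Python's `and` is not valid here
    val = tl.load(src_ptr + offs, mask=mask, other=0.0)
    tl.store(dst_ptr + offs, val, mask=mask)


def masked_copy(src: torch.Tensor) -> torch.Tensor:
    # Copies only the even-indexed, in-bounds elements; the rest stay zero.
    dst = torch.zeros_like(src)
    grid = (triton.cdiv(src.numel(), 128), )
    _masked_copy_kernel[grid](src, dst, src.numel(), BLOCK=128)
    return dst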

lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ def fused_moe_blocked_f8_kernel(
     k_start = BLOCK_SIZE_K
     offs_ksa = k_start // group_ak
     offs_ksb = k_start // group_bk
-    a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid and k_start < K, other=1.0)
+    a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)
     b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)
     acc_scale1 = tl.maximum(a_scale * b_scale, 1e-12)
     acc_ratio = acc_scale0 / acc_scale1
@@ -136,7 +136,7 @@ def fused_moe_blocked_f8_kernel(
         k_start = (k + 2) * BLOCK_SIZE_K
         offs_ksa = k_start // group_ak
         offs_ksb = k_start // group_bk
-        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid and k_start < K, other=1.0)
+        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)
         b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)
 
         # load ab
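Note the parentheses added around `k_start < K`: `&` binds tighter than comparison operators, so the scalar bound check has to be parenthesized before it is folded into the mask. A tiny illustration in plain PyTorch (not Triton), purely for the precedence point:

import torch

mask_sid = torch.tensor([True, False, True, True])
k_start, K = 64, 128

mask = mask_sid & (k_start < K)   # intended: apply the scalar bound check to every lane
# `mask_sid & k_start < K` would parse as `(mask_sid & k_start) < K`, a different expression.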

lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py

Lines changed: 16 additions & 38 deletions
@@ -132,15 +132,23 @@ def quant_fp8_tma(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_
     return _quant_fp8_launcher(A, group_size, out, scales)
 
 
+def _gemm_fp8_tma_pre_hook(nargs):
+    BLOCK_M = nargs['BLOCK_M']
+    BLOCK_N = nargs['BLOCK_N']
+    BLOCK_K = nargs['BLOCK_K']
+    nargs['desc_a'].block_shape = (BLOCK_M, BLOCK_K)
+    nargs['desc_b'].block_shape = (BLOCK_N, BLOCK_K)
+
+
 @triton.autotune(configs=[
     triton.Config({
         'BLOCK_M': 128,
         'BLOCK_N': 128,
-    }, num_stages=3, num_warps=8),
+    }, num_stages=3, num_warps=8, pre_hook=_gemm_fp8_tma_pre_hook),
     triton.Config({
         'BLOCK_M': 128,
         'BLOCK_N': 64,
-    }, num_stages=3, num_warps=4)
+    }, num_stages=3, num_warps=4, pre_hook=_gemm_fp8_tma_pre_hook)
 ],
                  key=['N', 'K'])
 @triton.jit
@@ -162,7 +170,6 @@ def _gemm_fp8_tma_kernel(
     stride_bsn: tl.constexpr,
     stride_cm,
     stride_cn: tl.constexpr,
-    dtype: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
@@ -200,8 +207,8 @@ def _gemm_fp8_tma_kernel(
         b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)
 
         # load ab
-        a = tl._experimental_descriptor_load(desc_a, [off_m, off_k], [BLOCK_M, BLOCK_K], dtype)
-        b = tl._experimental_descriptor_load(desc_b, [off_n, off_k], [BLOCK_N, BLOCK_K], dtype).T
+        a = desc_a.load([off_m, off_k])
+        b = desc_b.load([off_n, off_k]).T
 
         # mma
         accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])
@@ -348,42 +355,18 @@ def grid(META):
 
     # run_tma = False
     if run_tma:
-        from .utils import TmaAutoTuneHelper
+        from .utils import TensorDescriptor
 
-        desc_helper = TmaAutoTuneHelper()
-        desc_helper.init_tma_descriptor('desc_a')
-        desc_helper.init_tma_descriptor('desc_b')
-
-        desc_a = desc_helper.get_tma_descriptor_kernel_param('desc_a')
-        desc_b = desc_helper.get_tma_descriptor_kernel_param('desc_b')
+        dummy_block = (1, 1)
+        desc_a = TensorDescriptor.from_tensor(A, block_shape=dummy_block)
+        desc_b = TensorDescriptor.from_tensor(B.T, block_shape=dummy_block)
 
         def _grid_tma(META):
             """Grid tma."""
             BLOCK_M = META['BLOCK_M']
             BLOCK_N = META['BLOCK_N']
-            desc_helper.fill_2d_tma_descriptor('desc_a',
-                                               A.data_ptr(),
-                                               dim1=M,
-                                               dim0=K,
-                                               block_dim1=BLOCK_M,
-                                               block_dim0=BLOCK_K,
-                                               element_size=A.element_size())
-            desc_helper.fill_2d_tma_descriptor('desc_b',
-                                               B.data_ptr(),
-                                               dim1=N,
-                                               dim0=K,
-                                               block_dim1=BLOCK_N,
-                                               block_dim0=BLOCK_K,
-                                               element_size=B.element_size())
             return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
 
-        if A.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
-            dtype = tl.float8e4nv
-        elif A.dtype in (torch.float8_e5m2, torch.float8_e5m2fnuz):
-            dtype = tl.float8e5
-        else:
-            raise RuntimeError(f'Not supported dtype: {A.dtype}')
-
         _gemm_fp8_tma_kernel[_grid_tma](
             desc_a,
             A_scale,
@@ -402,13 +385,8 @@ def _grid_tma(META):
             stride_bsn=B_scale.stride(1),
             stride_cm=C.stride(0),
             stride_cn=C.stride(1),
-            dtype=dtype,
-            # BLOCK_M=BLOCK_M,
-            # BLOCK_N=BLOCK_N,
             BLOCK_K=BLOCK_K,
             GROUP_M=8,
-            # num_warps=num_warps,
-            # num_stages=num_stages,
         )
     else:
         _gemm_fp8_kernel[grid](
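This file carries the bulk of the change: the experimental TMA descriptor helpers are replaced by Triton 3.4's host-side TensorDescriptor, which is created with a placeholder block shape and fixed up by a per-config autotune pre_hook once BLOCK_M/BLOCK_N/BLOCK_K are known, and the kernel no longer needs an explicit dtype argument because the descriptor carries it. A minimal standalone sketch of the same pattern; kernel and helper names are illustrative, and running it assumes Triton 3.4 plus a TMA-capable GPU (Hopper or newer):

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


def _copy_pre_hook(nargs):
    # Called before launch for the chosen config; fix up the descriptor block shape.
    nargs['desc_x'].block_shape = (nargs['BLOCK_M'], nargs['BLOCK_K'])


@triton.autotune(configs=[
    triton.Config({'BLOCK_M': 64}, num_warps=4, pre_hook=_copy_pre_hook),
    triton.Config({'BLOCK_M': 128}, num_warps=8, pre_hook=_copy_pre_hook),
], key=['M', 'K'])
@triton.jit
def _tma_copy_kernel(desc_x, out_ptr, M, K, stride_om, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    off_m = pid_m * BLOCK_M
    off_k = pid_k * BLOCK_K
    x = desc_x.load([off_m, off_k])  # TMA-backed block load, no manual bounds math
    offs_m = off_m + tl.arange(0, BLOCK_M)
    offs_k = off_k + tl.arange(0, BLOCK_K)
    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    tl.store(out_ptr + offs_m[:, None] * stride_om + offs_k[None, :], x, mask=mask)


def tma_copy(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is 2-D, contiguous, and its row stride is 16-byte aligned (TMA requirement).
    out = torch.empty_like(x)
    M, K = x.shape
    BLOCK_K = 64
    # Placeholder block shape; the autotune pre_hook rewrites it before every launch.
    desc_x = TensorDescriptor.from_tensor(x, block_shape=(1, 1))
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(K, BLOCK_K))
    _tma_copy_kernel[grid](desc_x, out, M, K, out.stride(0), BLOCK_K=BLOCK_K)
    return out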

lmdeploy/pytorch/kernels/cuda/flashattention.py

Lines changed: 3 additions & 3 deletions
@@ -79,7 +79,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
         qk = qk * tl_log2(math.e)
         qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
         if window_size > 0:
-            qk_mask = qk_mask and ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
+            qk_mask = qk_mask & ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
         qk = tl.where(
             qk_mask,
             qk,
@@ -218,7 +218,7 @@ def _flash_prefill_fwd_kernel(
     offs_dk = tl.multiple_of(tl.max_contiguous(offs_dk % head_dim_k, BLOCK_DK), BLOCK_DK)
     off_q = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk[None, :] * stride_qd)
     q_ptrs = q_ptr + off_q
-    q = tl.load(q_ptrs, mask=(offs_m[:, None] < q_seqlen and mask_dk[None, :]))
+    q = tl.load(q_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk[None, :]))
 
     k_ptrs = tl.make_block_ptr(
         base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,
@@ -252,7 +252,7 @@ def _flash_prefill_fwd_kernel(
         offs_dk1 = tl.multiple_of(tl.max_contiguous(offs_dk1 % head_dim_k, BLOCK_DK1), BLOCK_DK1)
         offs_q1 = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk1[None, :] * stride_qd)
         q1_ptrs = q_ptr + offs_q1
-        q1 = tl.load(q1_ptrs, mask=(offs_m[:, None] < q_seqlen and mask_dk1[None, :]))
+        q1 = tl.load(q1_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk1[None, :]))
         k1_ptrs = tl.make_block_ptr(
             base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,
             shape=(head_dim_k, kv_seqlen),

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py

Lines changed: 4 additions & 4 deletions
@@ -74,10 +74,10 @@ def _flatten_kv_cache(
                offs_dv[None, :] * stride_vod)
 
     kc = tl.load(kc_ptrs)
-    tl.store(ko_ptrs, kc, mask=mask_bs[:, None] and mask_dk[None, :])
+    tl.store(ko_ptrs, kc, mask=mask_bs[:, None] & mask_dk[None, :])
     if HEAD_DIM_V > 0:
         vc = tl.load(vc_ptrs)
-        tl.store(vo_ptrs, vc, mask=mask_bs[:, None] and mask_dv[None, :])
+        tl.store(vo_ptrs, vc, mask=mask_bs[:, None] & mask_dv[None, :])
 
 
 @triton.jit
@@ -181,15 +181,15 @@ def _flatten_kv_cache_quant(
         kz = tl.load(ksz_ptrs + stride_kszd)
         ksz = ks * kz
         kq = (kc * ks[:, None] - ksz[:, None]).to(ko_ptr.dtype.element_ty)
-        tl.store(ko_ptrs, kq, mask=mask_bs[:, None] and mask_dok[None, :])
+        tl.store(ko_ptrs, kq, mask=mask_bs[:, None] & mask_dok[None, :])
         vc = tl.load(vc_ptrs)
         if quant_policy == 4:
             vc = _dequant_int4(vc, HEAD_DIM_V, BLOCK_DV)
         vs = tl.load(vsz_ptrs)
         vz = tl.load(vsz_ptrs + stride_vszd)
         vsz = vs * vz
         vq = (vc * vs[:, None] - vsz[:, None]).to(vo_ptr.dtype.element_ty)
-        tl.store(vo_ptrs, vq, mask=mask_bs[:, None] and mask_dov[None, :])
+        tl.store(vo_ptrs, vq, mask=mask_bs[:, None] & mask_dov[None, :])
 
 
 def flatten_kv_cache(k_caches: Tensor,

lmdeploy/pytorch/kernels/cuda/pagedattention.py

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ def _fwd_grouped_split_kernel(
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
             if window_size > 0:
-                qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc)
+                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)
             qk = tl.where(
                 qk_mask[None, :],
                 qk,
@@ -388,7 +388,7 @@ def _fwd_grouped_split_quant_kernel(
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
             if window_size > 0:
-                qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc)
+                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)
             qk = tl.where(
                 qk_mask[None, :],
                 qk,

lmdeploy/pytorch/kernels/cuda/utils.py

Lines changed: 6 additions & 81 deletions
@@ -29,6 +29,8 @@
     (12, 0): 24,
 }
 
+TRITON_VERSION = version.parse(triton.__version__)
+
 
 @functools.lru_cache
 def get_device_props(device=None):
@@ -57,86 +59,9 @@ def supports_tma():
     if not ret:
         return False
 
-    TRITON_VERSION = version.parse(triton.__version__)
-    VALID_VERSION = version.parse('3.2.0')
-    return TRITON_VERSION >= VALID_VERSION
-
-
-# Copy from:
-# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py
-class TmaDescKernelParam:
-    TMA_DESC_SIZE = 128
-
-    def __init__(self):
-        self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device='cpu')
-
-    def fill_(self, ptr, dims, block_dims, element_size):
-        assert len(dims) == len(block_dims)
-        assert 1 <= len(dims) <= 2
-        assert self.desc.data_ptr() % 64 == 0
-
-        if len(dims) == 1:
-            triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
-                                                                      self.desc.data_ptr())
-        else:
-            triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0],
-                                                                      block_dims[1], element_size, self.desc.data_ptr())
-
-    # Return a CUtensorMap* pointer in host memory
-    def tma_desc_cpu_ptr(self):
-        return self.desc.data_ptr()
-
-
-# Copy from:
-# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py
-def create_1d_tma_descriptor_custom(ptr, dim, block_dim, element_size):
-    desc = TmaDescKernelParam()
-    desc.fill_(ptr, [dim], [block_dim], element_size)
-    return desc
-
-
-# Copy from:
-# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py
-def create_2d_tma_descriptor_custom(ptr, dim1, dim0, block_dim1, block_dim0, element_size):
-    desc = TmaDescKernelParam()
-    desc.fill_(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size)
-    return desc
-
-
-try:
-    from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor  # noqa
-except BaseException:
-    create_1d_tma_descriptor = create_1d_tma_descriptor_custom
-    create_2d_tma_descriptor = create_2d_tma_descriptor_custom
-
-
-class TmaAutoTuneHelper:
-
-    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
-    class KernelParamWrapper:
-
-        def __init__(self, desc):
-            self.desc = desc
-
-        def tma_desc_cpu_ptr(self):
-            return self.desc.data_ptr()
-
-    TMA_SIZE = 128
-
-    def __init__(self):
-        self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
-        self.descriptors = {}
-
-    # Call this method outside of the lambda function for grid size
-    def init_tma_descriptor(self, name):
-        self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device='cpu', dtype=torch.int8)
+    VALID_VERSION = version.parse('3.4.0')
+    return TRITON_VERSION == VALID_VERSION
 
-    # Call this method inside the lambda function for grid size
-    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
-        desc_x = self.descriptors[name]
-        assert desc_x.data_ptr() % 64 == 0
-        self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
 
-    def get_tma_descriptor_kernel_param(self, name):
-        assert self.descriptors[name] is not None
-        return self.KernelParamWrapper(self.descriptors[name])
+if supports_tma():
+    from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F401
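Besides moving TRITON_VERSION to module scope, supports_tma() now requires exactly the pinned 3.4.0 release (the rewritten TensorDescriptor path is only exercised against that API), and TensorDescriptor is re-exported from this module when the check passes. A hypothetical caller-side sketch of that gating; the fallback assignment is illustrative:

from lmdeploy.pytorch.kernels.cuda.utils import supports_tma

if supports_tma():
    # Only defined in utils when both the GPU and the pinned Triton release support TMA.
    from lmdeploy.pytorch.kernels.cuda.utils import TensorDescriptor
else:
    TensorDescriptor = None  # take the non-TMA kernel path instead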

lmdeploy/pytorch/messages.py

Lines changed: 6 additions & 1 deletion
@@ -114,13 +114,18 @@ def from_gen_config(self, gen_config: GenerationConfig):
         logprobs = gen_config.logprobs
         if logprobs is None:
             logprobs = -1
+
+        random_seed = gen_config.random_seed
+        if random_seed is None:
+            import random
+            random_seed = random.getrandbits(64)
         return SamplingParam(top_p=top_p,
                              top_k=top_k,
                              min_p=min_p,
                              temperature=temperature,
                              repetition_penalty=repetition_penalty,
                              ignore_eos=gen_config.ignore_eos,
-                             random_seed=gen_config.random_seed,
+                             random_seed=random_seed,
                              stop_words=stop_words,
                              bad_words=bad_words,
                              response_format=response_format,
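With this change, a request that does not specify a seed gets a fresh 64-bit value instead of None, so its sampling stream is well-defined and independent of other requests. The same fallback as a tiny standalone helper; the function name is illustrative:

import random
from typing import Optional


def resolve_random_seed(requested: Optional[int] = None) -> int:
    # Use the caller's seed when given; otherwise draw a fresh 64-bit seed.
    return random.getrandbits(64) if requested is None else requested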

tests/pytorch/kernel/test_apply_rotary.py

Lines changed: 3 additions & 0 deletions
@@ -94,5 +94,8 @@ def test_apply_rotary(self, q_states, k_states, cos, sin, gt):
         if q_states.dtype == torch.float16:
             rtol = 1e-5
             atol = 1e-3
+        elif q_states.dtype == torch.bfloat16:
+            rtol = 1e-5
+            atol = 1e-2
         torch.testing.assert_close(q_embed, q_gt, rtol=rtol, atol=atol)
         torch.testing.assert_close(k_embed, k_gt, rtol=rtol, atol=atol)
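The new branch loosens the absolute tolerance for bfloat16 inputs. A quick sanity sketch of why: bfloat16 keeps only 8 significand bits, so even a plain fp32 round-trip introduces errors on the order of 1e-3 for values near 1.0, above the float16 atol used previously:

import torch

x = torch.rand(1024)                      # values in [0, 1)
bf = x.to(torch.bfloat16).float()         # round-trip through bfloat16
torch.testing.assert_close(bf, x, rtol=1e-5, atol=1e-2)   # passes
# torch.testing.assert_close(bf, x, rtol=1e-5, atol=1e-3) would typically fail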

0 commit comments
