Skip to content

Commit d3b6066

Browse files
author
jax authors
committed
Merge pull request #22820 from Rifur13:mha-faster
PiperOrigin-RevId: 660461104
2 parents 32131d0 + 181d17e commit d3b6066

File tree

2 files changed

+164
-112
lines changed

2 files changed

+164
-112
lines changed

jax/experimental/pallas/ops/gpu/attention.py

Lines changed: 162 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,6 @@ def body(start_k, carry):
7070
curr_k_slice = pl.dslice(start_k * block_k, block_k)
7171

7272
k = pl.load(k_ref, (curr_k_slice, slice(None)))
73-
kv_segment_ids = (
74-
None
75-
if segment_ids_ref is None
76-
else pl.load(segment_ids_ref, (curr_k_slice,))
77-
)
7873
qk = pl.dot(q, k.T) # [block_q, block_k]
7974
if sm_scale != 1.:
8075
qk *= sm_scale # [block_q, block_k]
@@ -87,6 +82,7 @@ def body(start_k, carry):
8782
if causal or segment_ids_ref is not None:
8883
mask = None
8984
if segment_ids_ref is not None:
85+
kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
9086
mask = segment_mask(q_segment_ids, kv_segment_ids)
9187
if causal:
9288
span_q = start_q * block_q + jnp.arange(block_q)
@@ -354,6 +350,9 @@ def _preprocess_backward(out, do, l, block_q: int,
354350
return do_scaled, delta
355351

356352

353+
# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
354+
# length.
355+
# Inspired by the Triton fused-attention tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
357356
def mha_backward_kernel(
358357
# Inputs
359358
q_ref,
@@ -365,92 +364,148 @@ def mha_backward_kernel(
365364
l_ref,
366365
m_ref,
367366
delta_ref,
368-
_,
369367
# Outputs
370368
dq_ref,
371369
dk_ref,
372370
dv_ref,
373371
*,
374372
sm_scale: float,
375373
causal: bool,
376-
block_q: int,
374+
block_q1: int,
375+
block_k1: int,
376+
block_q2: int,
377+
block_k2: int,
377378
block_d: int,
378-
block_k: int,
379379
):
380380
del out_ref, l_ref # Not needed
381381
seq_len = q_ref.shape[0]
382382

383-
def outer_loop(start_k, _):
384-
385-
dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
386-
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
387-
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
388-
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
389-
span_k = start_k * block_k + jnp.arange(block_k)
390-
kv_segment_ids = (
391-
None
392-
if segment_ids_ref is None
393-
else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),))
394-
)
395-
396-
def inner_loop(start_q, carry):
397-
dv, dk = carry
398-
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
399-
qk = pl.dot(q, k.T)
400-
qk = qk.astype(q_ref.dtype)
401-
qk = qk.astype(jnp.float32)
402-
if sm_scale != 1.0:
403-
qk *= sm_scale
404-
405-
q_segment_ids = (
406-
None
407-
if segment_ids_ref is None
408-
else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),))
409-
)
410-
411-
if causal or segment_ids_ref is not None:
412-
mask = None
413-
if segment_ids_ref is not None:
414-
mask = segment_mask(q_segment_ids, kv_segment_ids)
415-
416-
if causal:
417-
span_q = start_q * block_q + jnp.arange(block_q)
418-
causal_mask = span_q[:, None] >= span_k[None, :]
419-
mask = (
420-
causal_mask
421-
if mask is None
422-
else jnp.logical_and(mask, causal_mask)
423-
)
424-
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
425-
426-
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
427-
p = jnp.exp(qk - m[:, None])
428-
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
429-
dv = dv + pl.dot(p.astype(do.dtype).T, do)
430-
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
431-
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
432-
dp = dp + pl.dot(do, v.T)
433-
ds = p * dp
434-
if sm_scale != 1.0:
435-
ds = ds * sm_scale
436-
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
437-
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
438-
slice(None)), eviction_policy="evict_last")
439-
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
440-
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
441-
slice(None)), dq, eviction_policy="evict_last")
442-
return dv, dk
443-
if causal:
444-
lower_bound = lax.div(start_k * block_k, block_q)
445-
else:
446-
lower_bound = 0
447-
dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
448-
(dv, dk))
449-
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
450-
slice(None)), dv.astype(dv_ref.dtype))
451-
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
452-
slice(None)), dk.astype(dk_ref.dtype))
453-
lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
383+
# Scan #1: dK and dV
384+
# 1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
385+
# 2. Iterate through Q in chunks of (block_q1, head_dim) to accumulate
386+
# dK and dV.
387+
start_k = pl.program_id(2)
388+
curr_k_slice = pl.dslice(start_k * block_k1, block_k1)
389+
390+
dv = jnp.zeros([block_k1, block_d], dtype=jnp.float32)
391+
dk = jnp.zeros([block_k1, block_d], dtype=jnp.float32)
392+
393+
v = pl.load(v_ref, (curr_k_slice, slice(None)))
394+
k = pl.load(k_ref, (curr_k_slice, slice(None)))
395+
span_k = start_k * block_k1 + jnp.arange(block_k1)
396+
kv_segment_ids = (
397+
None
398+
if segment_ids_ref is None
399+
else pl.load(segment_ids_ref, (curr_k_slice,))
400+
)
401+
402+
def inner_loop_dkdv(start_q, carry):
403+
dv, dk = carry
404+
curr_q_slice = pl.dslice(start_q * block_q1, block_q1)
405+
406+
q = pl.load(q_ref, (curr_q_slice, slice(None)))
407+
qk = pl.dot(q, k.T)
408+
if sm_scale != 1.0:
409+
qk *= sm_scale
410+
411+
if causal or segment_ids_ref is not None:
412+
mask = None
413+
if segment_ids_ref is not None:
414+
q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,))
415+
mask = segment_mask(q_segment_ids, kv_segment_ids)
416+
417+
if causal:
418+
span_q = start_q * block_q1 + jnp.arange(block_q1)
419+
causal_mask = span_q[:, None] >= span_k[None, :]
420+
mask = (
421+
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
422+
)
423+
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
424+
425+
m = pl.load(m_ref, (curr_q_slice,))
426+
di = pl.load(delta_ref, (curr_q_slice,))
427+
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
428+
429+
p = jnp.exp(qk - m[:, None])
430+
dv = dv + pl.dot(p.astype(do.dtype).T, do)
431+
dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None]
432+
dp = dp + pl.dot(do, v.T)
433+
ds = p * dp
434+
if sm_scale != 1.0:
435+
ds = ds * sm_scale
436+
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
437+
438+
return dv, dk
439+
440+
lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
441+
dv, dk = lax.fori_loop(
442+
lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
443+
)
444+
pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype))
445+
pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype))
446+
447+
del dv, dk
448+
449+
# Scan #2: dQ
450+
# 1. Load a block of Q of size (block_q2, head_dim) in SMEM.
451+
# 2. Iterate through K and V in chunks of (block_k2, head_dim) to
452+
# accumulate dQ.
453+
start_q = pl.program_id(2)
454+
curr_q_slice = pl.ds(start_q * block_q2, block_q2)
455+
span_q = start_q * block_q2 + jnp.arange(block_q2)
456+
dq = jnp.zeros([block_q2, block_d], dtype=jnp.float32)
457+
458+
q = pl.load(q_ref, (curr_q_slice, slice(None)))
459+
q_segment_ids = (
460+
None
461+
if segment_ids_ref is None
462+
else pl.load(segment_ids_ref, (curr_q_slice,))
463+
)
464+
m = pl.load(m_ref, (curr_q_slice,))
465+
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
466+
di = pl.load(delta_ref, (curr_q_slice,))
467+
468+
def inner_loop_dq(start_k, dq):
469+
curr_k_slice = pl.dslice(start_k * block_k2, block_k2)
470+
k = pl.load(k_ref, (curr_k_slice, slice(None)))
471+
v = pl.load(v_ref, (curr_k_slice, slice(None)))
472+
473+
qk = pl.dot(q, k.T)
474+
if sm_scale != 1.0:
475+
qk *= sm_scale
476+
477+
if causal or segment_ids_ref is not None:
478+
mask = None
479+
if segment_ids_ref is not None:
480+
kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
481+
mask = segment_mask(q_segment_ids, kv_segment_ids)
482+
483+
if causal:
484+
span_k = start_k * block_k2 + jnp.arange(block_k2)
485+
causal_mask = span_q[:, None] >= span_k[None, :]
486+
mask = (
487+
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
488+
)
489+
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
490+
491+
p = jnp.exp(qk - m[:, None])
492+
dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None]
493+
dp = dp + pl.dot(do, v.T)
494+
ds = p * dp
495+
if sm_scale != 1.0:
496+
ds = ds * sm_scale
497+
498+
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
499+
500+
return dq
501+
502+
if causal:
503+
upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
504+
else:
505+
upper_bound = pl.cdiv(seq_len, block_k2)
506+
507+
dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
508+
pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype))
454509

455510

456511
def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
@@ -473,75 +528,72 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
473528
block_q = min(block_q, seq_len)
474529
block_k = min(block_k, seq_len)
475530
do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
476-
# We accumulate into dq so we need to initialize it to zeros.
477-
dq = jnp.zeros(q.shape, jnp.float32)
478531
out_shapes = [
479-
jax.ShapeDtypeStruct(dq.shape, dq.dtype),
480-
jax.ShapeDtypeStruct(k.shape, k.dtype),
481-
jax.ShapeDtypeStruct(v.shape, v.dtype),
532+
jax.ShapeDtypeStruct(q.shape, q.dtype),
533+
jax.ShapeDtypeStruct(k.shape, k.dtype),
534+
jax.ShapeDtypeStruct(v.shape, v.dtype),
482535
]
483536

484537
in_specs = [
485538
pl.BlockSpec(
486-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
487-
),
488-
pl.BlockSpec(
489-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
539+
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
490540
),
491541
pl.BlockSpec(
492-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
542+
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
493543
),
494544
pl.BlockSpec(
495-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
545+
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
496546
),
497547
pl.BlockSpec(
498-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
548+
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
499549
),
500-
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
501-
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
502-
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
503550
pl.BlockSpec(
504-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
551+
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
505552
),
553+
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
554+
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
555+
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
506556
]
507557
if segment_ids is None:
508558
in_specs.insert(3, None) # type: ignore[arg-type]
509-
input_output_aliases = {8: 0}
510559
else:
511-
in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda j, k: (j, 0)))
512-
input_output_aliases = {9: 0}
513-
grid = (batch_size, num_heads)
514-
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
560+
in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))
561+
562+
grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
515563
num_warps = 8
516564
dq, dk, dv = pl.pallas_call(
517565
functools.partial(
518566
mha_backward_kernel,
519-
block_q=block_q,
520-
block_d=head_dim,
521-
block_k=block_k,
522567
sm_scale=sm_scale,
523568
causal=causal,
569+
block_q1=block_q,
570+
block_k1=block_k,
571+
block_q2=block_q,
572+
block_k2=block_k,
573+
block_d=head_dim,
524574
),
525-
grid=grid,
526575
out_shape=out_shapes,
527576
in_specs=in_specs,
577+
grid=grid,
528578
out_specs=[
529579
pl.BlockSpec(
530-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
580+
(None, seq_len, None, head_dim),
581+
lambda i, j, _: (i, 0, j, 0), # dq
531582
),
532583
pl.BlockSpec(
533-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
584+
(None, seq_len, None, head_dim),
585+
lambda i, j, _: (i, 0, j, 0), # dk
534586
),
535587
pl.BlockSpec(
536-
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
588+
(None, seq_len, None, head_dim),
589+
lambda i, j, _: (i, 0, j, 0), # dv
537590
),
538591
],
539592
name="mha_backward",
540593
debug=debug,
541594
interpret=interpret,
542-
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
543-
input_output_aliases=input_output_aliases,
544-
)(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq)
595+
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)),
596+
)(q, k, v, segment_ids, out, do_scaled, l, m, delta)
545597
else:
546598
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
547599
return dq.astype(q.dtype), dk, dv, None

tests/pallas/gpu_ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def impl(q, k, v):
252252
(1, 384, 1, 32, False, False),
253253
(2, 384, 2, 32, False, True),
254254
(2, 384, 2, 32, False, False),
255-
# TODO(b/283035396): (1, 384, 1, 32, True, True),
256-
# TODO(b/283035396): (2, 384, 2, 32, True, True),
255+
(1, 384, 1, 32, True, True),
256+
(2, 384, 2, 32, True, True),
257257
]
258258
]
259259
)

0 commit comments

Comments
 (0)