@@ -70,11 +70,6 @@ def body(start_k, carry):
7070 curr_k_slice = pl .dslice (start_k * block_k , block_k )
7171
7272 k = pl .load (k_ref , (curr_k_slice , slice (None )))
73- kv_segment_ids = (
74- None
75- if segment_ids_ref is None
76- else pl .load (segment_ids_ref , (curr_k_slice ,))
77- )
7873 qk = pl .dot (q , k .T ) # [block_q, block_k]
7974 if sm_scale != 1. :
8075 qk *= sm_scale # [block_q, block_k]
@@ -87,6 +82,7 @@ def body(start_k, carry):
8782 if causal or segment_ids_ref is not None :
8883 mask = None
8984 if segment_ids_ref is not None :
85+ kv_segment_ids = pl .load (segment_ids_ref , (curr_k_slice ,))
9086 mask = segment_mask (q_segment_ids , kv_segment_ids )
9187 if causal :
9288 span_q = start_q * block_q + jnp .arange (block_q )
@@ -354,6 +350,9 @@ def _preprocess_backward(out, do, l, block_q: int,
354350 return do_scaled , delta
355351
356352
# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
# length.
# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
def mha_backward_kernel(
    # Inputs
    q_ref,
    k_ref,
    v_ref,
    segment_ids_ref,
    out_ref,
    do_scaled_ref,
    l_ref,
    m_ref,
    delta_ref,
    # Outputs
    dq_ref,
    dk_ref,
    dv_ref,
    *,
    sm_scale: float,
    causal: bool,
    block_q1: int,
    block_k1: int,
    block_q2: int,
    block_k2: int,
    block_d: int,
):
  """Flash-attention backward kernel: one grid cell computes one dK/dV block
  and one dQ block.

  Grid axis 2 indexes the sequence block; axes 0/1 select batch and head via
  the BlockSpecs (the refs seen here are already sliced to one head).

  Args:
    q_ref, k_ref, v_ref: (seq_len, head_dim) input blocks.
    segment_ids_ref: (seq_len,) segment ids, or None when segment masking is
      disabled.
    out_ref, l_ref: unused here (kept for a uniform kernel signature).
    do_scaled_ref: preprocessed dOut (scaled by the softmax normalizer).
    m_ref, delta_ref: per-row softmax max and delta = rowsum(dOut * Out).
    dq_ref, dk_ref, dv_ref: output gradient blocks.
    sm_scale: softmax temperature applied to q @ k^T.
    causal: whether to apply a lower-triangular mask.
    block_q1/block_k1: tile sizes for the dK/dV scan.
    block_q2/block_k2: tile sizes for the dQ scan.
    block_d: head dimension tile (full head_dim).
  """
  del out_ref, l_ref  # Not needed
  seq_len = q_ref.shape[0]

  # Scan #1: dK and dV
  #   1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
  #   2. Iterate through Q in chunks of (block_q1, head_dim) to accumulate
  #      dK and dV.
  start_k = pl.program_id(2)
  curr_k_slice = pl.dslice(start_k * block_k1, block_k1)

  # Accumulate in f32 regardless of the input dtype for numerical stability.
  dv = jnp.zeros([block_k1, block_d], dtype=jnp.float32)
  dk = jnp.zeros([block_k1, block_d], dtype=jnp.float32)

  v = pl.load(v_ref, (curr_k_slice, slice(None)))
  k = pl.load(k_ref, (curr_k_slice, slice(None)))
  span_k = start_k * block_k1 + jnp.arange(block_k1)
  kv_segment_ids = (
      None
      if segment_ids_ref is None
      else pl.load(segment_ids_ref, (curr_k_slice,))
  )

  def inner_loop_dkdv(start_q, carry):
    dv, dk = carry
    curr_q_slice = pl.dslice(start_q * block_q1, block_q1)

    q = pl.load(q_ref, (curr_q_slice, slice(None)))
    qk = pl.dot(q, k.T)
    if sm_scale != 1.0:
      qk *= sm_scale

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,))
        mask = segment_mask(q_segment_ids, kv_segment_ids)

      if causal:
        span_q = start_q * block_q1 + jnp.arange(block_q1)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    m = pl.load(m_ref, (curr_q_slice,))
    di = pl.load(delta_ref, (curr_q_slice,))
    do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))

    # p is the softmax probability block recomputed from the saved row max.
    p = jnp.exp(qk - m[:, None])
    dv = dv + pl.dot(p.astype(do.dtype).T, do)
    # dp = do @ v^T - delta (delta folds in the softmax normalization term).
    dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None]
    dp = dp + pl.dot(do, v.T)
    ds = p * dp
    if sm_scale != 1.0:
      ds = ds * sm_scale
    dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)

    return dv, dk

  # With a causal mask, Q blocks strictly above the diagonal contribute
  # nothing to this K block, so start the scan at the diagonal.
  lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
  dv, dk = lax.fori_loop(
      lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
  )
  pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype))
  pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype))

  del dv, dk

  # Scan #2: dQ
  #   1. Load a block of Q of size (block_q2, head_dim) in SMEM.
  #   2. Iterate through K and V in chunks of (block_k2, head_dim) to
  #      accumulate dQ.
  start_q = pl.program_id(2)
  curr_q_slice = pl.ds(start_q * block_q2, block_q2)
  span_q = start_q * block_q2 + jnp.arange(block_q2)
  dq = jnp.zeros([block_q2, block_d], dtype=jnp.float32)

  q = pl.load(q_ref, (curr_q_slice, slice(None)))
  q_segment_ids = (
      None
      if segment_ids_ref is None
      else pl.load(segment_ids_ref, (curr_q_slice,))
  )
  m = pl.load(m_ref, (curr_q_slice,))
  do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
  di = pl.load(delta_ref, (curr_q_slice,))

  def inner_loop_dq(start_k, dq):
    curr_k_slice = pl.dslice(start_k * block_k2, block_k2)
    k = pl.load(k_ref, (curr_k_slice, slice(None)))
    v = pl.load(v_ref, (curr_k_slice, slice(None)))

    qk = pl.dot(q, k.T)
    if sm_scale != 1.0:
      qk *= sm_scale

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
        mask = segment_mask(q_segment_ids, kv_segment_ids)

      if causal:
        span_k = start_k * block_k2 + jnp.arange(block_k2)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    p = jnp.exp(qk - m[:, None])
    dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None]
    dp = dp + pl.dot(do, v.T)
    ds = p * dp
    if sm_scale != 1.0:
      ds = ds * sm_scale

    dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)

    return dq

  if causal:
    # BUGFIX: use a ceiling division so the (possibly partial) diagonal K
    # block is always visited. A floor division (lax.div) yields too few
    # steps whenever block_q2 is not a multiple of block_k2 — e.g. with
    # block_q2=64, block_k2=128 the bound is 0 for start_q=0 and the
    # diagonal contribution to dQ is silently dropped. Positions past the
    # diagonal inside the extra block are suppressed by the causal mask
    # (p == exp(DEFAULT_MASK_VALUE - m) underflows to 0), so overshooting
    # is numerically harmless.
    upper_bound = pl.cdiv((start_q + 1) * block_q2, block_k2)
  else:
    upper_bound = pl.cdiv(seq_len, block_k2)

  dq = lax.fori_loop(0, upper_bound, inner_loop_dq, dq)
  pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype))
454509
455510
456511def _mha_backward (sm_scale : float , causal : bool , block_q : int , block_k : int ,
@@ -473,75 +528,72 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
473528 block_q = min (block_q , seq_len )
474529 block_k = min (block_k , seq_len )
475530 do_scaled , delta = _preprocess_backward (out , do , l , block_q , debug , interpret )
476- # We accumulate into dq so we need to initialize it to zeros.
477- dq = jnp .zeros (q .shape , jnp .float32 )
478531 out_shapes = [
479- jax .ShapeDtypeStruct (dq .shape , dq .dtype ),
480- jax .ShapeDtypeStruct (k .shape , k .dtype ),
481- jax .ShapeDtypeStruct (v .shape , v .dtype ),
532+ jax .ShapeDtypeStruct (q .shape , q .dtype ),
533+ jax .ShapeDtypeStruct (k .shape , k .dtype ),
534+ jax .ShapeDtypeStruct (v .shape , v .dtype ),
482535 ]
483536
484537 in_specs = [
485538 pl .BlockSpec (
486- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
487- ),
488- pl .BlockSpec (
489- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
539+ (None , seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
490540 ),
491541 pl .BlockSpec (
492- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
542+ (None , seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
493543 ),
494544 pl .BlockSpec (
495- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
545+ (None , seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
496546 ),
497547 pl .BlockSpec (
498- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
548+ (None , seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
499549 ),
500- pl .BlockSpec ((None , None , seq_len ), lambda j , k : (j , k , 0 )),
501- pl .BlockSpec ((None , None , seq_len ), lambda j , k : (j , k , 0 )),
502- pl .BlockSpec ((None , None , seq_len ), lambda j , k : (j , k , 0 )),
503550 pl .BlockSpec (
504- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
551+ (None , seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
505552 ),
553+ pl .BlockSpec ((None , None , seq_len ), lambda i , j , _ : (i , j , 0 )),
554+ pl .BlockSpec ((None , None , seq_len ), lambda i , j , _ : (i , j , 0 )),
555+ pl .BlockSpec ((None , None , seq_len ), lambda i , j , _ : (i , j , 0 )),
506556 ]
507557 if segment_ids is None :
508558 in_specs .insert (3 , None ) # type: ignore[arg-type]
509- input_output_aliases = {8 : 0 }
510559 else :
511- in_specs .insert (3 , pl .BlockSpec ((None , seq_len ), lambda j , k : (j , 0 )))
512- input_output_aliases = {9 : 0 }
513- grid = (batch_size , num_heads )
514- # TODO(sharadmv): figure out why num_warps=8 doesn't work!
560+ in_specs .insert (3 , pl .BlockSpec ((None , seq_len ), lambda i , j , _ : (i , 0 )))
561+
562+ grid = (batch_size , num_heads , pl .cdiv (seq_len , block_k ))
515563 num_warps = 8
516564 dq , dk , dv = pl .pallas_call (
517565 functools .partial (
518566 mha_backward_kernel ,
519- block_q = block_q ,
520- block_d = head_dim ,
521- block_k = block_k ,
522567 sm_scale = sm_scale ,
523568 causal = causal ,
569+ block_q1 = block_q ,
570+ block_k1 = block_k ,
571+ block_q2 = block_q ,
572+ block_k2 = block_k ,
573+ block_d = head_dim ,
524574 ),
525- grid = grid ,
526575 out_shape = out_shapes ,
527576 in_specs = in_specs ,
577+ grid = grid ,
528578 out_specs = [
529579 pl .BlockSpec (
530- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
580+ (None , seq_len , None , head_dim ),
581+ lambda i , j , _ : (i , 0 , j , 0 ), # dq
531582 ),
532583 pl .BlockSpec (
533- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
584+ (None , seq_len , None , head_dim ),
585+ lambda i , j , _ : (i , 0 , j , 0 ), # dk
534586 ),
535587 pl .BlockSpec (
536- (None , seq_len , None , head_dim ), lambda j , k : (j , 0 , k , 0 )
588+ (None , seq_len , None , head_dim ),
589+ lambda i , j , _ : (i , 0 , j , 0 ), # dv
537590 ),
538591 ],
539592 name = "mha_backward" ,
540593 debug = debug ,
541594 interpret = interpret ,
542- compiler_params = dict (triton = dict (num_warps = num_warps , num_stages = 1 )),
543- input_output_aliases = input_output_aliases ,
544- )(q , k , v , segment_ids , out , do_scaled , l , m , delta , dq )
595+ compiler_params = dict (triton = dict (num_warps = num_warps , num_stages = 2 )),
596+ )(q , k , v , segment_ids , out , do_scaled , l , m , delta )
545597 else :
546598 raise ValueError (f"Invalid backward pass implementation: { backward_pass_impl } " )
547599 return dq .astype (q .dtype ), dk , dv , None
0 commit comments