@@ -8612,3 +8612,217 @@ def forward(self, x):
86128612 self .run_compare_torch (x , OuterModel (),
86138613 input_as_shape = False , use_scripting = True ,
86148614 backend = backend , compute_unit = compute_unit )
8615+
class TestScaledDotProductAttention(TorchBaseTest):
    """
    Tests for torch.nn.functional.scaled_dot_product_attention op
    (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
    """

    @pytest.mark.parametrize(
        "compute_unit, backend, rank",
        itertools.product(
            compute_units,
            backends,
            [2, 3, 4, 5],
        ),
    )
    def test_different_input_ranks_no_mask(self, compute_unit, backend, rank):
        """
        The query/key/value inputs can be any rank 2 or greater.
        """
        batch_size, seq_len, n_heads_1, n_heads_2, d = 2, 10, 3, 4, 7
        if rank == 2:
            input_shape = (seq_len, d)
        elif rank == 3:
            input_shape = (batch_size, seq_len, d)
        elif rank == 4:
            input_shape = (batch_size, n_heads_1, seq_len, d)
        elif rank == 5:
            # use two *distinct* head dims (3 and 4) so that a transposed-axes
            # bug in the conversion cannot slip through unnoticed
            # (original code repeated n_heads_1, leaving n_heads_2 unused)
            input_shape = (batch_size, n_heads_1, n_heads_2, seq_len, d)
        else:
            raise ValueError("invalid rank")

        model = ModuleWrapper(
            function=nn.functional.scaled_dot_product_attention,
            kwargs={
                "attn_mask": None,
                "dropout_p": 0.0,
                "is_causal": False,
            },
        )

        # query, key and value all share the same shape here
        self.run_compare_torch(
            [input_shape] * 3,
            model,
            backend=backend,
            compute_unit=compute_unit,
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, seq_lengths, include_heads",
        itertools.product(
            compute_units,
            backends,
            [(5, 5), (5, 7), (6, 4)],
            [False, True],
        ),
    )
    def test_is_causal_flag(self, compute_unit, backend, seq_lengths, include_heads):
        """
        With is_causal=True (and no explicit mask), the converter must build the
        causal mask itself; verify conversion succeeds and the mask is folded to
        a constant in the resulting MIL program.
        """
        source_seq_len, target_seq_len = seq_lengths
        query_shape = (2, 2, target_seq_len, 7) if include_heads else (2, target_seq_len, 7)
        key_shape = (2, 2, source_seq_len, 7) if include_heads else (2, source_seq_len, 7)
        value_shape = key_shape

        model = ModuleWrapper(
            function=nn.functional.scaled_dot_product_attention,
            kwargs={
                "attn_mask": None,
                "is_causal": True,
            },
        )
        res = self.run_compare_torch(
            [query_shape, key_shape, value_shape],
            model,
            backend=backend,
            compute_unit=compute_unit,
        )
        # the causal mask is fully determined at conversion time, so the "fill"
        # and "band_part" ops used to compute it should have been constant folded
        # away and must not appear in the MIL program
        mil_prog = res[1]._get_mil_internal()
        assert len(mil_prog.find_ops(op_type="fill")) == 0
        assert len(mil_prog.find_ops(op_type="band_part")) == 0

    @pytest.mark.parametrize(
        "compute_unit, backend, seq_lengths, bool_mask",
        itertools.product(
            compute_units,
            backends,
            [(5, 5), (7, 5)],
            [False, True],
        ),
    )
    def test_attn_mask(self, compute_unit, backend, seq_lengths, bool_mask):
        """
        An explicit attn_mask input may be either a float mask (added to the
        attention scores) or a boolean keep/drop mask.
        """
        source_seq_len, target_seq_len = seq_lengths
        query_shape = (2, 3, target_seq_len, 7)
        key_shape = (2, 3, source_seq_len, 7)
        value_shape = key_shape
        mask_shape = (target_seq_len, source_seq_len)

        query = generate_input_data(query_shape)
        key = generate_input_data(key_shape)
        value = generate_input_data(value_shape)
        if bool_mask:
            # the comparison already yields a torch.bool tensor, so no
            # additional .bool() cast is needed
            mask = torch.rand(mask_shape) > 0.5
        else:
            mask = generate_input_data(mask_shape)

        model = ModuleWrapper(function=nn.functional.scaled_dot_product_attention)
        self.run_compare_torch(
            (query, key, value, mask),
            model,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, mask_as_input",
        itertools.product(
            compute_units,
            backends,
            [True, False],
        ),
    )
    def test_toy_xformer_with_sdpa(self, compute_unit, backend, mask_as_input):
        """
        End-to-end check: a small transformer that uses SDPA inside multi-head
        attention converts correctly, both with an explicit mask input and with
        is_causal=True.
        """
        embedding_size = 32
        seq_length = 16
        n_heads = 4
        batch_size = 2
        num_blocks = 3

        class AttentionBlock(nn.Module):
            def __init__(self, embed_dim=embedding_size, n_head=n_heads):
                super().__init__()
                self.query_proj_op = nn.Linear(embed_dim, embed_dim)
                self.key_proj_op = nn.Linear(embed_dim, embed_dim)
                self.value_proj_op = nn.Linear(embed_dim, embed_dim)
                self.out_proj_op = nn.Linear(embed_dim, embed_dim)
                self.n_head = n_head

            def forward(self, x, mask=None):
                # in comments below for shapes, using following notation:
                # B: batch_size, S: seq_length, E: embedding_size, h: n_heads
                # x: (B,S,E)
                # mask: (S,S)
                batch_size, seq_len, dim = x.shape
                query_proj = self.query_proj_op(x)  # (B,S,E)
                key_proj = self.key_proj_op(x)  # (B,S,E)
                value_proj = self.value_proj_op(x)  # (B,S,E)
                # reshape to (B, h, S, E/h)
                query_proj = query_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                key_proj = key_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                value_proj = value_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                # now do scaled dot product attention
                if mask is None:
                    out = nn.functional.scaled_dot_product_attention(
                        query_proj, key_proj, value_proj, is_causal=True
                    )  # (B, h, S, E/h)
                else:
                    out = nn.functional.scaled_dot_product_attention(
                        query_proj, key_proj, value_proj, mask
                    )  # (B, h, S, E/h)
                # reshape back to (B, S, E)
                out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)  # (B, S, E)
                return self.out_proj_op(out)

        class MLPBlock(nn.Module):
            def __init__(self, embed_dim=embedding_size):
                super().__init__()
                self.fc1 = nn.Linear(embed_dim, embed_dim)
                self.activation = nn.GELU()
                self.fc2 = nn.Linear(embed_dim, embed_dim)

            def forward(self, x):
                x = self.fc1(x)
                x = self.activation(x)
                return self.fc2(x)

        class ToyTransformer(nn.Module):
            # NOTE: the same attention/MLP/layer-norm modules are reused for all
            # n_blocks iterations (weight sharing) — fine for a conversion test
            def __init__(self, n_blocks=num_blocks, embed_dim=embedding_size):
                super().__init__()
                self.attn_block = AttentionBlock(embed_dim=embed_dim)
                self.mlp = MLPBlock(embed_dim=embed_dim)
                self.n_blocks = n_blocks
                self.lnorm = nn.LayerNorm(embed_dim)

            def forward(self, x, mask=None):
                for i in range(self.n_blocks):
                    x = self.attn_block(x, mask) + x
                    x = self.lnorm(x)
                    x = self.mlp(x) + x
                    x = self.lnorm(x)
                return x

        model = ToyTransformer()
        self.run_compare_torch(
            [(batch_size, seq_length, embedding_size), (seq_length, seq_length)]
            if mask_as_input
            else [(batch_size, seq_length, embedding_size)],
            model,
            backend=backend,
            compute_unit=compute_unit,
        )
0 commit comments