@@ -206,36 +206,72 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams(%arg0: !to
   return %1 : !torch.vtensor<[3,3],f32>
 }
 
-// CHECK-LABEL: func.func @torch.aten.flex_attention
-func.func @torch.aten.flex_attention(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+
+//===----------------------------------------------------------------------===//
+// FlexAttention variant tests
+//===----------------------------------------------------------------------===//
+
+func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
+  %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %5 : !torch.vtensor<[],f32>
+}
+
+func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
+  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention
+func.func @torch.hop_flex_attention(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
   %float1.0 = torch.constant.float 1.000000e+00
   %false_0 = torch.constant.bool false
   // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
   // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
   // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
   // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
-  %output, %logsumexp, %maxscore = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
   return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }
 
-func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32>
-  %float1.000000e-01 = torch.constant.float 1.000000e-01
-  %1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
-  %float1.000000e-02 = torch.constant.float 1.000000e-02
-  %2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
-  %int1_0 = torch.constant.int 1
-  %3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-  %int1_1 = torch.constant.int 1
-  %4 = torch.aten.add.Tensor %3, %1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-  %5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
-  return %5 : !torch.vtensor<[],f32>
+// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask
+func.func @torch.hop_flex_attention_nomask(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {score_mod_fn = @sdpa_score0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }
 
-func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
-  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
-  return %0 : !torch.vtensor<[],i1>
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore
+func.func @torch.hop_flex_attention_noscore(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {mask_mod_fn = @sdpa_mask0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask
+func.func @torch.hop_flex_attention_noscore_nomask(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
 }