@@ -310,6 +310,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
310310
// -----
312312
// Lowering of onnx.LRN with all-default attributes (alpha=1e-4, beta=0.75,
// bias=1.0) and size=3: the lowering squares the input, views it as 5-D,
// zero-pads the channel dim by (size-1)/2 = 1 on each side, averages with
// avg_pool3d, then applies x / (bias + alpha * avg)^beta.
// CHECK-LABEL: func.func @test_lrn_default
func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

  // CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
  // CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]

  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]

  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

  // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]

  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

  // NOTE(review): renamed from a second definition of I3_2 (already bound to
  // the kernel-size constant above) to avoid FileCheck variable shadowing.
  // CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
  // CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
  // CHECK-DAG: %[[I3_3:.+]] = torch.constant.int 3
  // CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_3]], %[[I50_2]]

  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
  // CHECK: return %[[OUTPUT]]
  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
  return %0 : !torch.vtensor<[20,10,3,50],f32>
}
375+
376+ // -----
377+
// Lowering of onnx.LRN with explicit alpha/beta/bias and size=5: same shape
// as the default-attribute test but with channel padding (size-1)/2 = 2 and
// the user-supplied constants folded into the mul/add/pow chain.
// CHECK-LABEL: func.func @test_lrn_with_optionals
func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

  // CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
  // CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]

  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]

  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

  // CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]

  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

  // CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
  // CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
  // CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
  // CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]

  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
  // CHECK: return %[[OUTPUT]]
  %none = torch.constant.none
  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
  return %0 : !torch.vtensor<[13,19,100,200],f32>
}
441+
442+ // -----
443+
// CHECK-LABEL: @test_matmul_2d
314445func.func @test_matmul_2d (%arg0: !torch.vtensor <[3 ,4 ],f32 >, %arg1: !torch.vtensor <[4 ,3 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
315446 // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
0 commit comments