Commit ce727b5

Add onnx op LRN lowering
This commit adds support for lowering the ONNX LRN op to aten.
1 parent d59d0b6 commit ce727b5
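
For reference, the ONNX LRN operator (https://onnx.ai/onnx/operators/onnx__LRN.html) normalizes each element by a sum of squares over a window W(c) of `size` nearby channels; a sketch of the formula being decomposed here:

    y[n,c,...] = x[n,c,...] / (bias + (alpha / size) * sum_{j in W(c)} x[n,j,...]^2)^beta

The lowering below builds the windowed sum of squares from square -> view (to n x 1 x c x d0 x rest) -> zero-pad of the channel dimension -> avg_pool3d with count_include_pad = true (whose division by `size` supplies the 1/size factor) -> squeeze -> view back, then finishes with mul.Scalar(alpha), add.Scalar(bias), pow.Tensor_Scalar(beta), and div.Tensor.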

File tree

5 files changed: +250 -2 lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ namespace mlir::torch::onnx_c {

 Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
-                            SmallVector<int64_t> cstInput);
+                            ArrayRef<int64_t> cstInput);

 Type getQTorchTypeFromTorchIntType(Type ty);
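
A side note on this signature change: taking ArrayRef<int64_t> instead of SmallVector<int64_t> by value avoids copying the elements and lets any contiguous int64_t sequence bind to the same parameter. A minimal sketch of the idea (the helper printAll is hypothetical; llvm::ArrayRef and llvm::SmallVector are the real LLVM ADT types):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include <cstdio>

// Hypothetical helper: ArrayRef is a non-owning view over contiguous
// storage, so one signature accepts SmallVector, std::vector, C arrays,
// and brace-init lists without a copy at the call site.
static void printAll(llvm::ArrayRef<int64_t> vals) {
  for (int64_t v : vals)
    std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");
}

int main() {
  llvm::SmallVector<int64_t, 3> strides{1, 1, 1};
  printAll(strides);   // binds to the SmallVector's storage
  printAll({0, 0, 0}); // binds to a temporary initializer list
}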

include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ namespace mlir {
 namespace torch {
 namespace Torch {

+class PrimListConstructOp;
+
 int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 115 additions & 0 deletions
@@ -1445,6 +1445,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, operand, constAlpha);
         return success();
       });
+  patterns.onOp(
+      "LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t size;
+        float alpha, beta, bias;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(size, "size", 2) ||
+            binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
+            binder.f32FloatAttr(beta, "beta", 0.75f) ||
+            binder.f32FloatAttr(bias, "bias", 1.0f))
+          return failure();
+        Type dtype = resultType.getOptionalDtype();
+        Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(alpha));
+        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(beta));
+        Value constBias = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(bias));
+        // Please refer to the operator description
+        // for more info on the lowering
+        // https://onnx.ai/onnx/operators/onnx__LRN.html
+
+        // squared = operand^2
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType inTy =
+            cast<Torch::ValueTensorType>(operand.getType());
+        Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
+            loc, inTy, operand, operand);
+        // view it as n x 1 x c x d0 x d..
+        if (!inTy.hasSizes()) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expected input to have sizes");
+        }
+        ArrayRef<int64_t> inTyShape = inTy.getSizes();
+        if (inTyShape.size() < 3) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported: the input dimensions should be >= 3");
+        }
+        if (inTyShape[1] == Torch::kUnknownSize) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported: the second dimension size must be "
+                         "statically known");
+        }
+        SmallVector<int64_t, 5> viewShapeInt{inTyShape[0], 1, inTyShape[1],
+                                             inTyShape[2], Torch::kUnknownSize};
+        Torch::ValueTensorType reshapeType =
+            rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
+        Value viewShapeListVal =
+            createConstantIntList(binder, rewriter, viewShapeInt);
+        auto view = rewriter.create<Torch::AtenViewOp>(
+            loc, reshapeType, sqOperand, viewShapeListVal);
+        // padding
+        int64_t highPad = (size - 1) / 2;
+        int64_t lowPad = (size - 1) - highPad;
+        SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
+        auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(0.0));
+        Value paddingListVal =
+            createConstantIntList(binder, rewriter, paddingInt);
+        SmallVector<int64_t, 5> paddedShapeInt = viewShapeInt;
+        paddedShapeInt[2] += size - 1;
+        Torch::ValueTensorType paddedType =
+            rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
+        auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
+            loc, paddedType, view, paddingListVal, constPadVal);
+        // avg_pool3d
+        SmallVector<int64_t, 3> kernelSize{size, 1, 1};
+        Value kernelSizeList =
+            createConstantIntList(binder, rewriter, kernelSize);
+        SmallVector<int64_t, 3> strides{1, 1, 1};
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+        SmallVector<int64_t, 3> padding{0, 0, 0};
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        auto cstCeilMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        auto cstCountIncludeMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        // Output of pooling is same reshape(view) type because
+        // of the padding done on the dimensions being pooled.
+        auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
+            loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
+            cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
+        // squeeze
+        auto one = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        SmallVector<int64_t, 5> squeezeShapeInt{
+            viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
+        Torch::ValueTensorType squeezeType =
+            rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
+        auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
+            loc, squeezeType, pool, one);
+        // view as input Type
+        Value intTyShapeList =
+            createConstantIntList(binder, rewriter, inTyShape);
+        auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
+            loc, inTy, squeeze, intTyShapeList);
+        // mul + add + pow + div
+        auto mul = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, resultType, viewAsInput, constAlpha);
+        auto add = rewriter.create<Torch::AtenAddScalarOp>(loc, resultType, mul,
+                                                           constBias, one);
+        auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
+            loc, resultType, add, constBeta);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
+            binder.op, resultType, operand, pow);
+        return success();
+      });
   patterns.onOp(
       "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
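
To make the decomposition concrete, here is an illustrative plain-loop reference (not part of the commit) for what the emitted sequence computes on an NCHW input. The function name lrnReference is hypothetical; the channel window mirrors the lowering's pad amounts, and the division by `size` mirrors avg_pool3d with count_include_pad = true:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative plain-loop LRN over an NCHW tensor, mirroring the lowered
// op sequence: square -> pad -> avg_pool3d(count_include_pad=true)
// -> mul(alpha) -> add(bias) -> pow(beta) -> div.
std::vector<float> lrnReference(const std::vector<float> &x, int64_t n,
                                int64_t c, int64_t h, int64_t w, int64_t size,
                                float alpha, float beta, float bias) {
  std::vector<float> y(x.size());
  auto idx = [&](int64_t ni, int64_t ci, int64_t hi, int64_t wi) {
    return ((ni * c + ci) * h + hi) * w + wi;
  };
  // Asymmetric channel window taken from the lowering's pad amounts.
  int64_t highPad = (size - 1) / 2;      // channels above
  int64_t lowPad = (size - 1) - highPad; // channels below
  for (int64_t ni = 0; ni < n; ++ni)
    for (int64_t ci = 0; ci < c; ++ci)
      for (int64_t hi = 0; hi < h; ++hi)
        for (int64_t wi = 0; wi < w; ++wi) {
          double sumSq = 0.0;
          int64_t first = std::max<int64_t>(0, ci - lowPad);
          int64_t last = std::min<int64_t>(c - 1, ci + highPad);
          for (int64_t k = first; k <= last; ++k) {
            double v = x[idx(ni, k, hi, wi)];
            sumSq += v * v;
          }
          // count_include_pad=true means the pool divides by `size` even
          // at the borders, giving the alpha * (sumSq / size) term.
          double denom = std::pow(bias + alpha * (sumSq / size), beta);
          y[idx(ni, ci, hi, wi)] =
              static_cast<float>(x[idx(ni, ci, hi, wi)] / denom);
        }
  return y;
}

int main() {
  // 1x2x1x2 input, size=3, ONNX defaults for alpha/beta/bias.
  std::vector<float> x{1.f, 2.f, 3.f, 4.f};
  auto y = lrnReference(x, 1, 2, 1, 2, 3, 1e-4f, 0.75f, 1.0f);
  for (float v : y)
    std::printf("%f\n", v);
}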

lib/Conversion/TorchOnnxToTorch/Utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ using namespace mlir::torch::onnx_c;

 Value mlir::torch::onnx_c::createConstantIntList(
     OpBinder binder, ConversionPatternRewriter &rewriter,
-    SmallVector<int64_t> cstInput) {
+    ArrayRef<int64_t> cstInput) {
   SmallVector<Value> cstValue;
   for (int64_t i : cstInput) {
     cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 131 additions & 0 deletions
@@ -310,6 +310,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor

 // -----

+// CHECK-LABEL: func.func @test_lrn_default
+func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
+  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
+  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
+
+  // CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
+  // CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
+  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]
+
+  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
+
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]
+
+  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
+
+  // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]
+
+  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
+
+  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
+
+  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
+  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
+
+  // CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
+  // CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
+  // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
+  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]
+
+  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
+  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
+  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
+  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
+  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
+  // CHECK: return %[[OUTPUT]]
+  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
+  return %0 : !torch.vtensor<[20,10,3,50],f32>
+}
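
Worth spelling out the pad arithmetic behind the two constant-1 entries of PADDING above: with size = 3,

    highPad = (size - 1) / 2 = 1
    lowPad  = (size - 1) - highPad = 1

and KERNELSIZE leads with 3. The next test uses size = 5, giving highPad = lowPad = 2 and a leading kernel extent of 5.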
+
+// -----
+
+// CHECK-LABEL: func.func @test_lrn_with_optionals
+func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
+  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
+  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
+  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
+
+  // CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
+  // CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
+  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
+  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]
+
+  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
+
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]
+
+  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
+
+  // CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
+  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]
+
+  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
+
+  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
+
+  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
+  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
+
+  // CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
+  // CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
+  // CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
+  // CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
+  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]
+
+  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
+  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
+  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
+  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
+  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
+  // CHECK: return %[[OUTPUT]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
+  return %0 : !torch.vtensor<[13,19,100,200],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_matmul_2d
 func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
