@@ -244,6 +244,294 @@ class ConvertAtenReflectionPad1dOp
};
}
+ namespace {
+
+ // Lower the aten.reflection_pad2d operator into a sequence of
+ // tensor.extract_slice, linalg.generic, and tensor.insert_slice
+ // operations.
+
+ // To understand the lowering, consider this pytorch example:
+ //
+ //   >>> t = torch.tensor([[[1.0,2,3],[4,5,6],[7,8,9]]])
+ //   >>> t
+ //   tensor([[[1., 2., 3.],
+ //            [4., 5., 6.],
+ //            [7., 8., 9.]]])
+ //   >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
+ //   tensor([[[5., 4., 5., 6., 5., 4.],
+ //            [2., 1., 2., 3., 2., 1.],
+ //            [5., 4., 5., 6., 5., 4.],
+ //            [8., 7., 8., 9., 8., 7.],
+ //            [5., 4., 5., 6., 5., 4.],
+ //            [2., 1., 2., 3., 2., 1.]]])
+ //
+ // The result can be subdivided into "tiles" corresponding to either
+ // the input tensor (in the center) or slices of the input tensor
+ // whose width and height are determined by the padding sizes and which
+ // are reflected through the side of the central input tensor that
+ // they touch.
+ // In the example above, the tiles are:
+ //   top left:      [[5]]
+ //   top center:    [[4,5,6]]
+ //   top right:     [[5,4]]
+ //   center left:   [[2],[5],[8]]
+ //   center:        copy of the input tensor
+ //   center right:  [[2,1],[5,4],[8,7]]
+ //   bottom left:   [[5],[2]]
+ //   bottom center: [[4,5,6],[1,2,3]]
+ //   bottom right:  [[5,4],[2,1]]
+ //
+ // The lowering uses a tensor.extract_slice operation to create each tile,
+ // a linalg.generic for the reflection, and a tensor.insert_slice to
+ // insert the tile into the resulting tensor.
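+ // The result thus has height top + H + bottom and width left + W + right;
+ // in the example, (1 + 3 + 2) x (1 + 3 + 2) = 6 x 6.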
+ class ConvertAtenReflectionPad2dOp
+     : public OpConversionPattern<AtenReflectionPad2dOp> {
+ public:
+   using OpConversionPattern::OpConversionPattern;
+   LogicalResult
+   matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter) const override {
+     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+       return failure();
+
+     SmallVector<int64_t> padInts;
+     if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+       return rewriter.notifyMatchFailure(
+           op, "only support constant int pad ranges");
+
+     Location loc = op.getLoc();
+     // Some generic helper functions for creating arithmetic operations.
+     auto createAdd = [&](Value x, Value y) {
+       return rewriter.create<arith::AddIOp>(loc, x, y);
+     };
+
+     auto createAdds = [&](std::initializer_list<Value> values) {
+       assert(values.size() >= 2);
+       return std::accumulate(values.begin() + 1, values.end(),
+                              data(values)[0], createAdd);
+     };
+
+     auto createSub = [&](Value x, Value y) {
+       return rewriter.create<arith::SubIOp>(loc, x, y);
+     };
+
+     auto createSubs = [&](std::initializer_list<Value> values) {
+       assert(values.size() >= 2);
+       return std::accumulate(values.begin() + 1, values.end(),
+                              data(values)[0], createSub);
+     };
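+     // Both variadic helpers fold from the left, e.g. createAdds({a, b, c})
+     // produces createAdd(createAdd(a, b), c).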
+
+     // Enums for specifying the coordinates of a tile. An "h" prefix
+     // is used to stand for "horizontal" and "v" for "vertical"
+     // throughout.
+     enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
+     enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };
+
+     // Helper functions for obtaining information about the operator's
+     // padding arguments.
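+     // The padding operand follows PyTorch's reflection_pad2d convention
+     // [left, right, top, bottom]: the horizontal sizes are padInts[0] and
+     // padInts[1], the vertical sizes padInts[2] and padInts[3].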
+     auto getHPadArgument = [&](PadHLoc l) {
+       assert(l < HCENTER);
+       return padInts[l];
+     };
+
+     auto getVPadArgument = [&](PadVLoc l) {
+       assert(l < VCENTER);
+       return padInts[2 + l];
+     };
+
+     auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
+       if (!(h == HCENTER || getHPadArgument(h) > 0))
+         return false;
+       if (!(v == VCENTER || getVPadArgument(v) > 0))
+         return false;
+       return true;
+     };
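+     // I.e. a non-center tile is only materialized when the padding on its
+     // side is nonzero, which avoids creating zero-sized slices.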
+
+     Value input = adaptor.getSelf();
+     MLIRContext *context = rewriter.getContext();
+     auto inputType = llvm::cast<RankedTensorType>(input.getType());
+     auto outputType = llvm::cast<RankedTensorType>(
+         getTypeConverter()->convertType(op->getResult(0).getType()));
+     unsigned numDims = inputType.getRank();
+
+     assert(numDims >= 2 && "Not enough input dimensions");
+
+     SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+     int64_t hDim = numDims - 1;
+     int64_t vDim = numDims - 2;
+     Value hDimSize = inputShape[hDim];
+     Value vDimSize = inputShape[vDim];
+
+     assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
+            "Left padding too large");
+     assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
+            "Right padding too large");
+     assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
+            "Top padding too large");
+     assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
+            "Bottom padding too large");
+
+     Type indexType = rewriter.getIndexType();
+     Value zero = getConstant(rewriter, loc, 0, indexType);
+     Value one = getConstant(rewriter, loc, 1, indexType);
+
+     Value tileWidth[3];
+     tileWidth[HCENTER] = hDimSize;
+     for (auto h : {LEFT, RIGHT})
+       tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
+
+     Value tileHeight[3];
+     tileHeight[VCENTER] = vDimSize;
+     for (auto v : {TOP, BOTTOM})
+       tileHeight[v] =
+           getConstant(rewriter, loc, getVPadArgument(v), indexType);
+
+     // Helper to reflect/reverse the i-th dimension of an affine map
+     // without symbols. This only works on tensors for which the
+     // corresponding dimension has a statically known size, which is good
+     // enough here since we only apply it to reflect the padding slices.
+     auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
+                          int64_t size) {
+       AffineExpr d = map.getResult(i);
+       return map.replace(d, size - d - 1, numDims, 0);
+     };
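+     // For example, reflecting dimension 1 of (d0, d1) -> (d0, d1) with
+     // size 3 yields (d0, d1) -> (d0, 2 - d1), so indices 0, 1, 2 are read
+     // as 2, 1, 0.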
+
+     // Create the output shape and tensor.
+     SmallVector<Value> resultShape{inputShape};
+     resultShape[vDim] =
+         createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
+     resultShape[hDim] =
+         createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
+
+     Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
+                                               inputType.getElementType());
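+     // The nine (possibly empty) tiles partition the result exactly, so
+     // every element of the zero-initialized tensor is overwritten below.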
+
+     // Construction of the tiles.
+
+     // Example: central left tile
+     //
+     // Let m be the width of the left padding as returned by
+     // getHPadArgument(LEFT) and n the size of the input tensor's
+     // "vertical" dimension, i.e. vDimSize. Assume that the subtensor of
+     // the input tensor in the relevant (i.e. last two) dimensions is:
+     //
+     //    x_1,1  x_1,2  ...  x_1,m+1  ...
+     //    x_2,1  x_2,2  ...  x_2,m+1  ...
+     //      .
+     //      .
+     //      .
+     //    x_n,1  x_n,2  ...  x_n,m+1  ...
+     //
+     // The padding tile consists of the columns 2, ..., m + 1
+     // of the input in reverse order. The first column gets
+     // skipped because it is the column through which the
+     // reflection happens.
+     //
+     //    x_1,m+1  x_1,m  ...  x_1,2
+     //    x_2,m+1  x_2,m  ...  x_2,2
+     //      .
+     //      .
+     //      .
+     //    x_n,m+1  x_n,m  ...  x_n,2
+     //
+     // The tile is inserted to the left of the copy of the input tensor
+     // in the output tensor, i.e. with horizontal offset 0. The top
+     // padding determines the vertical offset.
+
+     // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
+     // two sides, i.e. both their columns and rows must be reversed.
+
+     // Set up information about the tiles.
+
+     // Compute the offsets for extracting the slice from the
+     // input. We need to skip the row or column through which
+     // the tile should be reflected, if any (none for the center tile).
+     Value extractHOffset[3];
+     extractHOffset[LEFT] = one;
+     extractHOffset[HCENTER] = zero;
+     extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
+
+     Value extractVOffset[3];
+     extractVOffset[TOP] = one;
+     extractVOffset[VCENTER] = zero;
+     extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
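+     // E.g. for the RIGHT tile in the example above: hDimSize = 3 and
+     // tileWidth[RIGHT] = 2, so extraction starts at column 3 - 2 - 1 = 0,
+     // skipping the last column (the reflection axis).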
+
+     // Compute the horizontal and vertical offsets for inserting
+     // the tiles into the resultTensor.
+     Value insertHOffset[3];
+     insertHOffset[LEFT] = zero;
+     insertHOffset[HCENTER] = tileWidth[LEFT];
+     insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);
+
+     Value insertVOffset[3];
+     insertVOffset[TOP] = zero;
+     insertVOffset[VCENTER] = tileHeight[TOP];
+     insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
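+     // In the example, insertHOffset is {LEFT: 0, HCENTER: 1, RIGHT: 4},
+     // so the right tile lands in output columns 4 and 5.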
+
+     auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
+     auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
+
+     SmallVector<utils::IteratorType> iteratorTypes{
+         numDims, utils::IteratorType::parallel};
+     auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+     SmallVector<Value> allOneStrides(numDims, one);
+
+     auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
+       // Create the tile by extracting a slice from the input tensor.
+       SmallVector<Value> extractShape{inputShape};
+       extractShape[hDim] = tileWidth[horizontalPos];
+       extractShape[vDim] = tileHeight[verticalPos];
+
+       SmallVector<Value> extractOffsets(numDims, zero);
+       extractOffsets[hDim] = extractHOffset[horizontalPos];
+       extractOffsets[vDim] = extractVOffset[verticalPos];
+
+       Value tile = rewriter.create<tensor::ExtractSliceOp>(
+           loc, input, extractOffsets, extractShape, allOneStrides);
+
+       // Reverse the tile along the horizontal, vertical, or both
+       // dimensions.
+       auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+       if (shouldHReflect(horizontalPos)) {
+         inputMap = reflectDim(inputMap, numDims, hDim,
+                               getHPadArgument(horizontalPos));
+       }
+       if (shouldVReflect(verticalPos)) {
+         inputMap = reflectDim(inputMap, numDims, vDim,
+                               getVPadArgument(verticalPos));
+       }
+
+       tile = rewriter
+                  .create<linalg::GenericOp>(
+                      loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
+                      tile, ArrayRef({inputMap, idMap}), iteratorTypes,
+                      [](OpBuilder &b, Location nestedLoc, ValueRange args) {
+                        b.create<linalg::YieldOp>(nestedLoc, args[0]);
+                      })
+                  .getResult(0);
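+       // The reversal is expressed entirely through the input operand's
+       // indexing map; the region body is a plain element-wise copy.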
+
+       // Insert the tile into the resultTensor.
+       SmallVector<Value> insertOffsets(numDims, zero);
+       insertOffsets[hDim] = insertHOffset[horizontalPos];
+       insertOffsets[vDim] = insertVOffset[verticalPos];
+
+       resultTensor = rewriter.create<tensor::InsertSliceOp>(
+           loc, tile, resultTensor, insertOffsets, extractShape,
+           allOneStrides);
+     };
+
+     for (auto v : {TOP, BOTTOM, VCENTER})
+       for (auto h : {LEFT, RIGHT, HCENTER})
+         if (shouldCreateTile(v, h))
+           createTile(v, h);
+
+     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
+
+     return success();
+   }
+ };
+ } // namespace
+
namespace {
class ConvertAtenFlattenUsingIntsOp
    : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1552,6 +1840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReflectionPad1dOp>();
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
+ target.addIllegalOp<AtenReflectionPad2dOp>();
+ patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenViewOp>();