Skip to content

Commit 0860c41

Browse files
frederik-h authored and stellaraccident committed
Implement aten.reflection_pad2d lowering to linalg
1 parent aee1fca commit 0860c41

File tree

8 files changed

+553
-0
lines changed

8 files changed

+553
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
78937893
}];
78947894
}
78957895

7896+
def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
7897+
AllowsTypeRefinement,
7898+
HasValueSemantics,
7899+
ReadOnly
7900+
]> {
7901+
let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`";
7902+
let arguments = (ins
7903+
AnyTorchTensorType:$self,
7904+
AnyTorchListOfTorchIntType:$padding
7905+
);
7906+
let results = (outs
7907+
AnyTorchTensorType:$result
7908+
);
7909+
let hasCustomAssemblyFormat = 1;
7910+
let extraClassDefinition = [{
7911+
ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) {
7912+
return parseDefaultTorchOp(parser, result, 2, 1);
7913+
}
7914+
void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) {
7915+
printDefaultTorchOp(printer, *this, 2, 1);
7916+
}
7917+
}];
7918+
}
7919+
78967920
def Torch_AtenPadOp : Torch_Op<"aten.pad", [
78977921
AllowsTypeRefinement,
78987922
HasValueSemantics,

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,294 @@ class ConvertAtenReflectionPad1dOp
244244
};
245245
}
246246

247+
namespace {
248+
249+
// Lower the aten.reflection.pad_2d operator into a sequence of
250+
// tensor.extract_slice, linalg.generic, and tensor_insert_slice
251+
// operations.
252+
253+
// To understand the lowering, consider this pytorch example:
254+
//
255+
// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]])
256+
// >>> t
257+
// tensor([[[1., 2., 3.],
258+
// [4., 5., 6.],
259+
// [7., 8., 9.]]])
260+
// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
261+
// tensor([[[5., 4., 5., 6., 5., 4.],
262+
// [2., 1., 2., 3., 2., 1.],
263+
// [5., 4., 5., 6., 5., 4.],
264+
// [8., 7., 8., 9., 8., 7.],
265+
// [5., 4., 5., 6., 5., 4.],
266+
// [2., 1., 2., 3., 2., 1.]]])
267+
//
268+
// The result can be subdivided into "tiles" corresponding to either
269+
// the input tensor (in the center) or slices of the input tensor
270+
// whose width and height is determined by the padding sizes and which
271+
// are reflected through the side of the central input tensor that
272+
// they touch.
273+
// In the example above, the tiles are:
274+
// top left: [[5]]
275+
// top center: [[4,5,6]]
276+
// top right: [[5,4]]
277+
// center left [[2,1],[5,4],[8,7]]
278+
// center: copy of the input tensor
279+
// center right: [[2,1],[5,4],[8,7]]
280+
// bottom left: [[5,4],[2,1]]
281+
// center bottom: [[2,3,2]]
282+
// center right: [[2,1]]
283+
//
284+
// The lowering uses a tensor.extract_slice operation to create each tile,
285+
// a linalg.generic for the reflection, and a tensor.insert_slice to
286+
// insert the tile in the resulting tensor.
287+
class ConvertAtenReflectionPad2dOp
288+
: public OpConversionPattern<AtenReflectionPad2dOp> {
289+
public:
290+
using OpConversionPattern::OpConversionPattern;
291+
LogicalResult
292+
matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
293+
ConversionPatternRewriter &rewriter) const override {
294+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
295+
return failure();
296+
297+
SmallVector<int64_t> padInts;
298+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
299+
return rewriter.notifyMatchFailure(
300+
op, "only support constant int pad ranges");
301+
302+
Location loc = op.getLoc();
303+
// Some generic helper functions for creating arithmetic operations.
304+
auto createAdd = [&](Value x, Value y) {
305+
return rewriter.create<arith::AddIOp>(loc, x, y);
306+
};
307+
308+
auto createAdds = [&](std::initializer_list<Value> values) {
309+
assert(values.size() >= 2);
310+
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
311+
createAdd);
312+
};
313+
314+
auto createSub = [&](Value x, Value y) {
315+
return rewriter.create<arith::SubIOp>(loc, x, y);
316+
};
317+
318+
auto createSubs = [&](std::initializer_list<Value> values) {
319+
assert(values.size() >= 2);
320+
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
321+
createSub);
322+
};
323+
324+
// Enums for specifying the coordinates of a tile. An "h" prefix
325+
// is used to stand for "horizontal" and "v" for "vertical"
326+
// throughout.
327+
enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
328+
enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };
329+
330+
// Helper functions for obtaining information about the operator's
331+
// padding arguments.
332+
auto getHPadArgument = [&](PadHLoc l) {
333+
assert(l < HCENTER);
334+
return padInts[l];
335+
};
336+
337+
auto getVPadArgument = [&](PadVLoc l) {
338+
assert(l < VCENTER);
339+
return padInts[2 + l];
340+
};
341+
342+
auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
343+
if (!(h == HCENTER || getHPadArgument(h) > 0))
344+
return false;
345+
if (!(v == VCENTER || getVPadArgument(v) > 0))
346+
return false;
347+
348+
return true;
349+
};
350+
351+
Value input = adaptor.getSelf();
352+
MLIRContext *context = rewriter.getContext();
353+
auto inputType = llvm::cast<RankedTensorType>(input.getType());
354+
auto outputType = llvm::cast<RankedTensorType>(
355+
getTypeConverter()->convertType(op->getResult(0).getType()));
356+
unsigned numDims = inputType.getRank();
357+
358+
assert(numDims >= 2 && "Not enough input dimensions");
359+
360+
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
361+
int64_t hDim = numDims - 1;
362+
int64_t vDim = numDims - 2;
363+
Value hDimSize = inputShape[hDim];
364+
Value vDimSize = inputShape[vDim];
365+
366+
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
367+
"Left padding too large");
368+
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
369+
"Right padding too large");
370+
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
371+
"Top padding too large");
372+
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
373+
"Bottom padding too large");
374+
375+
Type indexType = rewriter.getIndexType();
376+
Value zero = getConstant(rewriter, loc, 0, indexType);
377+
Value one = getConstant(rewriter, loc, 1, indexType);
378+
379+
Value tileWidth[3];
380+
tileWidth[HCENTER] = hDimSize;
381+
for (auto h : {LEFT, RIGHT})
382+
tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
383+
384+
Value tileHeight[3];
385+
tileHeight[VCENTER] = vDimSize;
386+
for (auto v : {TOP, BOTTOM})
387+
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
388+
389+
// Helper to reflect/reverse the i-th dimension of an affine map
390+
// without symbols. This only works if applied on a tensor
391+
// for which the corresponding dimension has a statically
392+
// known size which is good enough since we only apply
393+
// it to reflect the padding slices.
394+
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
395+
int64_t size) {
396+
AffineExpr d = map.getResult(i);
397+
return map.replace(d, size - d - 1, numDims, 0);
398+
};
399+
400+
// Create output shape and tensor
401+
SmallVector<Value> resultShape{inputShape};
402+
resultShape[vDim] =
403+
createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
404+
resultShape[hDim] =
405+
createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
406+
407+
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
408+
inputType.getElementType());
409+
410+
// Construction of the tiles
411+
412+
// Example: central left tile
413+
//
414+
// Let m the width of the left padding as returned by getHPadargument(LEFT)
415+
// and n the size of the input tensor's "horizontal" dimension, i.e.
416+
// hDimSize. Assume that the subtensor of the input tensor in the relevant
417+
// (i.e. last two) dimensions is:
418+
//
419+
// x_1,1 x_1,2 ... x_1,m
420+
// x_2,1 x_2,2 ... x_2,m
421+
// .
422+
// .
423+
// .
424+
// x_n,1 x_n,2 ... x_n,m
425+
//
426+
// The padding tile consists of the columns 2, ..., m + 1
427+
// of the input in reverse order. The first column gets
428+
// skipped because this is the column through which the
429+
// reflection happens.
430+
//
431+
// x_1,m x_1,m-1 ... x_1,2
432+
// x_2,m x_1,m-1 ... x_2,2
433+
// .
434+
// .
435+
// .
436+
// x_n,m x_n,m-1 ... x_n,2
437+
//
438+
// The tile will be inserted to the left of the copy of the input tensor
439+
// in the output tensor, i.e. with horizontal offset 0.
440+
// The top padding determines the vertical offset.
441+
442+
// Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
443+
// two sides, i.e. their columns and rows must be reversed.
444+
445+
// Setup information about the tiles
446+
447+
// Compute the offsets for extracting the slice from the
448+
// input. We need to skip the row or column through which
449+
// the tile should be reflected, if any (none for the center tile).
450+
Value extractHOffset[3];
451+
extractHOffset[LEFT] = one;
452+
extractHOffset[HCENTER] = zero;
453+
extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
454+
455+
Value extractVOffset[3];
456+
extractVOffset[TOP] = one;
457+
extractVOffset[VCENTER] = zero;
458+
extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
459+
460+
// Compute the horizontal and vertical offsets for inserting
461+
// the tiles in the resultTensor.
462+
Value insertHOffset[3];
463+
insertHOffset[LEFT] = zero;
464+
insertHOffset[HCENTER] = tileWidth[LEFT];
465+
insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);
466+
467+
Value insertVOffset[3];
468+
insertVOffset[TOP] = zero;
469+
insertVOffset[VCENTER] = tileHeight[TOP];
470+
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
471+
472+
auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
473+
auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
474+
475+
SmallVector<utils::IteratorType> iteratorTypes{
476+
numDims, utils::IteratorType::parallel};
477+
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
478+
SmallVector<Value> allOneStrides(numDims, one);
479+
480+
auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
481+
// Create the tile by extracting a slice from the input tenor.
482+
SmallVector<Value> extractShape{inputShape};
483+
extractShape[hDim] = tileWidth[horizontalPos];
484+
extractShape[vDim] = tileHeight[verticalPos];
485+
486+
SmallVector<Value> extractOffsets(numDims, zero);
487+
extractOffsets[hDim] = extractHOffset[horizontalPos];
488+
extractOffsets[vDim] = extractVOffset[verticalPos];
489+
490+
Value tile = rewriter.create<tensor::ExtractSliceOp>(
491+
loc, input, extractOffsets, extractShape, allOneStrides);
492+
493+
// Reverse the tile along the horizontal, vertical, or both
494+
// dimensions.
495+
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
496+
if (shouldHReflect(horizontalPos)) {
497+
inputMap =
498+
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
499+
}
500+
if (shouldVReflect(verticalPos)) {
501+
inputMap =
502+
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
503+
}
504+
505+
tile = rewriter
506+
.create<linalg::GenericOp>(
507+
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
508+
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
509+
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
510+
b.create<linalg::YieldOp>(nestedLoc, args[0]);
511+
})
512+
.getResult(0);
513+
514+
// Insert the tile in the resultTensor.
515+
SmallVector<Value> insertOffsets(numDims, zero);
516+
insertOffsets[hDim] = insertHOffset[horizontalPos];
517+
insertOffsets[vDim] = insertVOffset[verticalPos];
518+
519+
resultTensor = rewriter.create<tensor::InsertSliceOp>(
520+
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
521+
};
522+
523+
for (auto v : {TOP, BOTTOM, VCENTER})
524+
for (auto h : {LEFT, RIGHT, HCENTER})
525+
if (shouldCreateTile(v, h))
526+
createTile(v, h);
527+
528+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
529+
530+
return success();
531+
}
532+
};
533+
} // namespace
534+
247535
namespace {
248536
class ConvertAtenFlattenUsingIntsOp
249537
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1552,6 +1840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
15521840
MLIRContext *context = patterns.getContext();
15531841
target.addIllegalOp<AtenReflectionPad1dOp>();
15541842
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
1843+
target.addIllegalOp<AtenReflectionPad2dOp>();
1844+
patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
15551845
target.addIllegalOp<AtenFlattenUsingIntsOp>();
15561846
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
15571847
target.addIllegalOp<AtenViewOp>();

0 commit comments

Comments
 (0)