
Commit ccaac85

Authored by renxida (Xida Ren) and frederik-h (Frederik Harwath)
implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757)
aten.conv_tbc is a convolution with [time, batch, channel] ordering, as opposed to the default [batch, channel, time]. It is currently implemented by transposing the input and output, but it may need its own lowering in the future, because this op is supposed to provide a speedup. It is used by fairseq (facebookresearch/fairseq#172). (In case you were wondering, as I was: this is different from transposed convolution, which has fractional strides.)

Co-authored-by: Xida Ren <[email protected]>
Co-authored-by: Frederik Harwath <[email protected]>
1 parent: 77ae563 · commit: ccaac85
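The [time, batch, channel] contract is easy to sanity-check against stock PyTorch. The following sketch (my illustration, not part of this patch) shows conv_tbc agreeing with a plain conv1d whose operands and result are transposed, mirroring the transpose-based lowering described above:

```python
import torch
import torch.nn.functional as F

# aten::conv_tbc takes input as (time, batch, channels) and
# weight as (kernel_width, in_channels, out_channels).
t, b, c_in, c_out, kw = 8, 2, 4, 6, 3
x = torch.randn(t, b, c_in)
w = torch.randn(kw, c_in, c_out)
bias = torch.randn(c_out)

y_tbc = torch.conv_tbc(x, w, bias, 1)  # pad of 1 along the time dim

# The same result via the default (batch, channel, time) conv1d,
# transposing the operands on the way in and the result on the way out.
y_ref = F.conv1d(
    x.permute(1, 2, 0),  # (t, b, c_in)      -> (b, c_in, t)
    w.permute(2, 1, 0),  # (kw, c_in, c_out) -> (c_out, c_in, kw)
    bias,
    padding=1,
).permute(2, 0, 1)       # (b, c_out, t)     -> (t, b, c_out)

assert torch.allclose(y_tbc, y_ref, atol=1e-6)
```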

File tree: 9 files changed (+626, -19 lines)

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 113 additions & 0 deletions
```diff
@@ -5494,6 +5494,35 @@ def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
   }];
 }
 
+def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConv3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenConv3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -5523,6 +5552,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   }];
 }
 
+def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConv1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenConv1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
     AllowsTypeRefinement,
     HasValueSemantics,
```
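Torch_AtenConv1dOp and Torch_AtenConv3dOp share the `(Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)` signature, and the seven arguments map one-to-one onto the functional PyTorch calls. A quick sketch of the modeled semantics (my example, not from the commit):

```python
import torch
import torch.nn.functional as F

# aten::conv1d(input, weight, bias?, stride, padding, dilation, groups)
x1 = torch.randn(2, 4, 16)       # (batch, channels, time)
w1 = torch.randn(8, 4, 3)        # (out_channels, in_channels, kernel)
y1 = F.conv1d(x1, w1, bias=None, stride=[1], padding=[1], dilation=[1], groups=1)

# aten::conv3d has the same argument list, with three spatial dims.
x3 = torch.randn(2, 4, 8, 8, 8)  # (batch, channels, depth, height, width)
w3 = torch.randn(8, 4, 3, 3, 3)
y3 = F.conv3d(x3, w3, bias=None, stride=[1, 1, 1], padding=[1, 1, 1],
              dilation=[1, 1, 1], groups=1)

print(y1.shape)  # torch.Size([2, 8, 16])
print(y3.shape)  # torch.Size([2, 8, 8, 8, 8])
```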
```diff
@@ -5613,6 +5671,61 @@ def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
   }];
 }
 
+def Torch_AtenConvTbcOp : Torch_Op<"aten.conv_tbc", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$weight,
+    AnyTorchTensorType:$bias,
+    Torch_IntType:$pad
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTbcOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenConvTbcOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTbcBackwardOp : Torch_Op<"aten.conv_tbc_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchTensorType:$bias,
+    Torch_IntType:$pad
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTbcBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 3);
+    }
+    void AtenConvTbcBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 3);
+    }
+  }];
+}
+
 def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
     AllowsTypeRefinement,
     HasValueSemantics,
```
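aten::conv_tbc_backward produces the three gradients (input, weight, bias) for conv_tbc. From Python it is normally reached through autograd, which defines the derivative of conv_tbc in terms of this op; a minimal sketch (mine, for illustration):

```python
import torch

t, b, c_in, c_out, kw = 8, 2, 4, 6, 3
x = torch.randn(t, b, c_in, requires_grad=True)
w = torch.randn(kw, c_in, c_out, requires_grad=True)
bias = torch.randn(c_out, requires_grad=True)

# Backward through conv_tbc dispatches to aten::conv_tbc_backward.
torch.conv_tbc(x, w, bias, 1).sum().backward()

print(x.grad.shape, w.grad.shape, bias.grad.shape)
# torch.Size([8, 2, 4]) torch.Size([3, 4, 6]) torch.Size([6])
```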

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 52 additions & 18 deletions
```diff
@@ -566,9 +566,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return op.emitError("unimplemented: non-floating point type");
     size_t inRank = input.getType().cast<RankedTensorType>().getRank();
     size_t numSpacialDims = inRank - 2;
-    if (numSpacialDims != 2)
+    if (numSpacialDims < 1 || numSpacialDims > 3)
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 2D convolution currently supported");
+          op, "unimplemented: only 1d-3d convolution currently supported");
 
     Type intType = IntegerType::get(context, 64);
     auto castIndexToInt = [&](Value v) {
@@ -796,15 +796,50 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     weightSliceSizes.append(weightDims);
 
     Value conv;
+    // the code so far is able to respect all numSpacialDims
+    // the code below this point is numSpacialDims specific and groupSize specific
+    // TODO: factor out the above code into a helper function, and then separate convolution into:
+    // - grouped 1d-3d
+    // - ungrouped 1d-3d
     if (groupSize == 1) {
-      // TODO: add 1D and 3D case
-      conv =
-          rewriter
-              .create<linalg::Conv2DNchwFchwOp>(
-                  loc, outputTensor.getType(), ValueRange{paddedInput, weight},
-                  outputTensor, stridesAttr, dilationAttr)
-              .getResult(0);
+      // TODO: 3D case
+      switch (numSpacialDims) {
+      case 1:
+        conv = rewriter
+                   .create<linalg::Conv1DNcwFcwOp>(
+                       loc, outputTensor.getType(),
+                       ValueRange{paddedInput, weight}, outputTensor,
+                       stridesAttr, dilationAttr)
+                   .getResult(0);
+        break;
+      case 2:
+        conv =
+            rewriter
+                .create<linalg::Conv2DNchwFchwOp>(
+                    loc, outputTensor.getType(), ValueRange{paddedInput, weight},
+                    outputTensor, stridesAttr, dilationAttr)
+                .getResult(0);
+        break;
+      case 3:
+        conv =
+            rewriter
+                .create<linalg::Conv3DNcdhwFcdhwOp>(
+                    loc, outputTensor.getType(), ValueRange{paddedInput, weight},
+                    outputTensor, stridesAttr, dilationAttr)
+                .getResult(0);
+        break;
+      default:
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only 1D, 2D, and 3D convolution supported");
+      };
+      Type newResultType = getTypeConverter()->convertType(op.getType());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
+      return success();
     } else {
+      if(numSpacialDims != 2)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only 2D grouped convolution supported");
+
       // Special depthwise case
       auto inShape = makeShapeTorchCompatible(
           input.getType().cast<RankedTensorType>().getShape());
@@ -824,11 +859,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           loc, collapsedType, weight, collapsedDims);
 
       conv = rewriter
-                  .create<linalg::DepthwiseConv2DNchwChwOp>(
-                      loc, outputTensor.getType(),
-                      ValueRange{paddedInput, collapsedWeight}, outputTensor,
-                      stridesAttr, dilationAttr)
-                  .getResult(0);
+                 .create<linalg::DepthwiseConv2DNchwChwOp>(
+                     loc, outputTensor.getType(),
+                     ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                     stridesAttr, dilationAttr)
+                 .getResult(0);
 
       Type newResultType = getTypeConverter()->convertType(op.getType());
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
@@ -902,11 +937,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       conv = rewriter.create<tensor::CollapseShapeOp>(
           loc, outputTensor.getType(), conv,
           expandOutputTensor.getReassociationIndices());
+      Type newResultType = getTypeConverter()->convertType(op.getType());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
+      return success();
     }
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
-    return success();
   }
 };
 } // namespace
```
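After this change the grouped path keeps only the depthwise special case, collapsing the weight's unit per-group input-channel dimension before emitting linalg::DepthwiseConv2DNchwChwOp. A short PyTorch sketch (my illustration, not from the patch) of why dropping that unit dimension is sound: with groups equal to the channel count, each output channel depends only on its matching input channel.

```python
import torch
import torch.nn.functional as F

# Depthwise 2-D convolution: groups == in_channels, so the weight is
# (C, 1, kh, kw); the unit dim is what the lowering collapses away.
n, c, h, w, k = 1, 4, 8, 8, 3
x = torch.randn(n, c, h, w)
weight = torch.randn(c, 1, k, k)

y = F.conv2d(x, weight, padding=1, groups=c)

# Reference: each output channel is an independent single-channel conv.
y_ref = torch.cat(
    [F.conv2d(x[:, i:i + 1], weight[i:i + 1], padding=1) for i in range(c)],
    dim=1,
)
assert torch.allclose(y, y_ref, atol=1e-5)
```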
