diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index cf4e2b4f07f0..4688ffc7808a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1282,29 +1282,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createRemainderPayload(b, loc, converter, payloadArgs, remTensor, operands); } - if (auto fmod = dyn_cast(op)) { - Type newResultType = - cast(converter->convertType(fmod.getType())) - .getElementType(); - - Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); - Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); - Value result; - - if (isa(newResultType)) { - Value n = b.create(loc, self, other); - n = b.create(loc, n); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else if (isa(newResultType)) { - Value n = b.create(loc, self, other); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else { - fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); - } - return result; - } if (auto reciprocal = dyn_cast(op)) { Type dtype = cast(converter->convertType(reciprocal.getType())) @@ -1612,23 +1589,23 @@ class ConvertElementwiseOp : public ConversionPattern { AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, - AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp, - AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, - AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, - Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, - AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, - AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, - AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, - AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, - AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, - AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, - AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, + Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp, + AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, + AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, + AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, + AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, + AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, + AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, + AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, + AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -3385,10 +3362,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp, - AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, - AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(); + AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp, AtenIscloseOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context);