69 changes: 67 additions & 2 deletions lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -161,20 +161,77 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;

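// Width in bits of a scalar element type. A complex type counts as twice the
// width of its element type (e.g. complex<f32> is 64 bits wide).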
unsigned getBitWidth(Type type) const {
if (auto complexTy = dyn_cast<ComplexType>(type))
return 2 * getBitWidth(complexTy.getElementType());
return type.getIntOrFloatBitWidth();
}

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
if (!rankType)
return op.emitError("Only ranked tensor types are currently supported");
return op.emitError("Only ranked tensor types are currently supported.");
auto loc = op.getLoc();

// Handle AtenViewDtypeOp: reinterpret the tensor's underlying bits as the
// target dtype.
if (isa<AtenViewDtypeOp>(op)) {
auto self = adaptor.getSelf();
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
if (!baseResultTy)
return rewriter.notifyMatchFailure(op, "result must be a tensor type");

// Infer the intermediate shape for the bitcast from the element bit widths.
auto operandElt = rankType.getElementType();
auto targetElt = baseResultTy.getDtype();
auto operandEltBitWidth = getBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandSizes = rankType.getShape();
SmallVector<int64_t> castShape(operandSizes);
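// When the source element is wider than the target, each element expands to
// operandEltBitWidth / targetEltBitWidth target elements in a new innermost
// dimension (e.g. bitcasting f32 -> i8 turns [12, 1] into [12, 1, 4]). When
// it is narrower, the innermost dimension must equal the inverse ratio and
// is folded away.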
if (operandEltBitWidth > targetEltBitWidth) {
int64_t lastSize = operandEltBitWidth / targetEltBitWidth;
castShape.push_back(lastSize);
} else if (operandEltBitWidth < targetEltBitWidth) {
int64_t lastSize = targetEltBitWidth / operandEltBitWidth;
if (!ShapedType::isDynamic(castShape.back()) &&
lastSize != castShape.back()) {
return rewriter.notifyMatchFailure(
op, "The last dim size must equal targetEltBitWidth / "
"operandEltBitWidth.");
}
castShape.pop_back();
}

auto resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
baseResultTy);
if (!cast<ShapedType>(resultType).hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "Currently only static output shapes are supported.");
}

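// Bitcast to the intermediate shape, then reshape to the final result type.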
auto castType =
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
loc,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
castType),
self);

auto reshape =
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);

rewriter.replaceOp(op, reshape);

return success();
}

// Collect the dimension sizes as Values.
SmallVector<Value, 4> dimSizes;
if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
return op.emitError("Dim sizes must be a list of scalars");
}

auto loc = op.getLoc();
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
@@ -236,6 +293,13 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
SmallVector<Value, 4> &dimSizes) const;
};

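// AtenViewDtypeOp has no size-list operand; it is handled entirely in
// matchAndRewrite above, so there are no dimension sizes to collect.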
template <>
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const {
return false;
}

template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
@@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -506,6 +506,7 @@
"UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"ViewDtypeStaticModule_basic",
"WeightNormInterfaceModule_basic",
# Error: `aten.as_strided` op is not supported
"ChunkListUnpackDynamic_Module_basic",
@@ -3167,6 +3168,7 @@
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
"ViewDtypeStaticModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================


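# Exercises aten.view.dtype with a fully static shape: a [12, 1] float32
# input reinterpreted as int8 yields a [12, 4] result (four int8 elements
# per float32).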
class ViewDtypeStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([12, 1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return a.view(torch.int8)


@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(12, 1))

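# A minimal sketch (plain PyTorch, independent of the e2e harness) of the
# bit-reinterpretation semantics exercised above: viewing a wider dtype as a
# narrower one multiplies the innermost dimension by the bit-width ratio, and
# the reverse view collapses it again.
def _view_dtype_roundtrip_sketch():
    a = torch.rand(12, 1)              # float32: 32 bits per element
    narrow = a.view(torch.int8)        # shape [12, 4]: each f32 yields 4 int8s
    assert narrow.shape == (12, 4)
    wide = narrow.view(torch.float32)  # back to [12, 1], bit-identical to a
    assert torch.equal(wide, a)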

# ==============================================================================


class ReshapeAliasCollapseModule(torch.nn.Module):
    def __init__(self):
        super().__init__()