Commit 52ecff8

yyp0 authored and qingyunqu committed
[stablehlo] support aten.view.dtype lowering (#3778)
1 parent 52f5450 commit 52ecff8
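For context (this note is not part of the commit): `aten.view.dtype` reinterprets a tensor's raw bytes as another dtype without copying, scaling the last dimension by the element-size ratio. The patch lowers it to `stablehlo.bitcast_convert` followed by `stablehlo.reshape`. A quick eager-PyTorch illustration:

    import torch

    a = torch.rand(12, 1)      # float32, 4 bytes per element
    b = a.view(torch.int8)     # reinterpret bytes, no copy
    print(b.shape)             # torch.Size([12, 4]): last dim scaled by 32/8
    c = b.view(torch.float32)  # round trip: last dim must divide evenly
    print(c.shape)             # torch.Size([12, 1])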

File tree

3 files changed: +97 -2 lines changed

lib/Conversion/TorchToStablehlo/ViewLike.cpp
projects/pt1/e2e_testing/xfail_sets.py
projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py

lib/Conversion/TorchToStablehlo/ViewLike.cpp

Lines changed: 67 additions & 2 deletions
@@ -161,20 +161,77 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
   using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
 
+  unsigned getBitWidth(Type type) const {
+    if (auto complexTy = dyn_cast<ComplexType>(type))
+      return 2 * getBitWidth(complexTy.getElementType());
+    return type.getIntOrFloatBitWidth();
+  }
+
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
     if (!rankType)
-      return op.emitError("Only ranked tensor types are currently supported");
+      return op.emitError("Only ranked tensor types are currently supported.");
+    auto loc = op.getLoc();
+
+    // support AtenViewDtypeOp
+    if (isa<AtenViewDtypeOp>(op)) {
+      auto self = adaptor.getSelf();
+      auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
+
+      // infer the result shape
+      auto operandElt = rankType.getElementType();
+      auto targetElt = baseResultTy.getDtype();
+      auto operandEltBitWidth = getBitWidth(operandElt);
+      auto targetEltBitWidth = getBitWidth(targetElt);
+      auto operandSizes = rankType.getShape();
+      SmallVector<int64_t> castShape(operandSizes);
+      if (operandEltBitWidth > targetEltBitWidth) {
+        int64_t last_size = operandEltBitWidth / targetEltBitWidth;
+        castShape.push_back(last_size);
+      } else if (operandEltBitWidth < targetEltBitWidth) {
+        int64_t last_size = targetEltBitWidth / operandEltBitWidth;
+        if (!ShapedType::isDynamic(castShape.back()) &&
+            last_size != castShape.back()) {
+          return rewriter.notifyMatchFailure(
+              op, "The last dim size is not equal to targetEltBitWidth / "
+                  "operandEltBitWidth.");
+        } else {
+          castShape.pop_back();
+        }
+      }
+
+      auto resultType =
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              baseResultTy);
+      if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
+        return rewriter.notifyMatchFailure(
+            op, "Currently only support static output shape.");
+      }
+
+      auto castType =
+          baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
+      auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
+          loc,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              castType),
+          self);
+
+      auto reshape =
+          rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
+
+      rewriter.replaceOp(op, reshape);
+
+      return success();
+    }
 
     // collect Value of dims
     SmallVector<Value, 4> dimSizes;
     if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
       return op.emitError("Dims size must be a list of Scalar");
     }
 
-    auto loc = op.getLoc();
     if (dimSizes.size() == 0 || rankType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
           op,
@@ -236,6 +293,13 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
                       SmallVector<Value, 4> &dimSizes) const;
 };
 
+template <>
+bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
+    AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
+    SmallVector<Value, 4> &dimSizes) const {
+  return false;
+}
+
 template <>
 bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
     AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
@@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                        \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
+  INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
   INSERT_VIEW_OP_PATTERN(AtenViewOp);
   INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
 #undef INSERT_VIEW_OP_PATTERN
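A minimal Python sketch of the castShape computation in the pattern above (the names infer_cast_shape and DYNAMIC are illustrative, not part of the patch): the bitcast either appends a trailing dim equal to the bit-width ratio (narrowing) or consumes a trailing dim of that size (widening), and the final stablehlo.reshape then produces the converted result type.

    DYNAMIC = -1  # stand-in for ShapedType::kDynamic

    def infer_cast_shape(operand_shape, operand_bits, target_bits):
        # The shape stablehlo.bitcast_convert produces, before the
        # final stablehlo.reshape to the converted result type.
        cast_shape = list(operand_shape)
        if operand_bits > target_bits:
            # Narrowing (e.g. f32 -> i8): bitcast appends a trailing dim.
            cast_shape.append(operand_bits // target_bits)
        elif operand_bits < target_bits:
            # Widening: a static last dim must equal the ratio, and the
            # bitcast consumes it.
            ratio = target_bits // operand_bits
            if cast_shape[-1] != DYNAMIC and cast_shape[-1] != ratio:
                raise ValueError(
                    "last dim != targetEltBitWidth / operandEltBitWidth")
            cast_shape.pop()
        return cast_shape

    assert infer_cast_shape([12, 1], 32, 8) == [12, 1, 4]  # reshape -> [12, 4]
    assert infer_cast_shape([12, 4], 8, 32) == [12]        # reshape -> [12, 1]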

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -498,6 +498,7 @@
     "UpSampleNearest2dDynamicFactor_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    "ViewDtypeStaticModule_basic",
     "WeightNormInterfaceModule_basic",
 }
 
@@ -2849,6 +2850,11 @@
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
     "UnfoldModule_basic",
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_Rank_Zero_Size_Zero_basic",
+    "Unfold_Module_Dynamic_basic",
+    "ViewDtypeStaticModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):

projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py

Lines changed: 24 additions & 0 deletions
@@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ViewDtypeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([12, 1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        res = a.view(torch.int8)
+        return res
+
+
+@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
+def ViewDtypeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(12, 1))
+
+
+# ==============================================================================
+
+
 class ReshapeAliasCollapseModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
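For reference, the eager-PyTorch result the e2e harness compares the lowered module against (an illustrative check, not part of the test suite):

    m = ViewDtypeStaticModule()
    out = m.forward(torch.rand(12, 1))
    print(out.shape, out.dtype)  # torch.Size([12, 4]) torch.int8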
