@@ -161,20 +161,77 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
161161 using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
162162 using OpAdaptor = typename AtenOpT::Adaptor;
163163
164+ unsigned getBitWidth (Type type) const {
165+ if (auto complexTy = dyn_cast<ComplexType>(type))
166+ return 2 * getBitWidth (complexTy.getElementType ());
167+ return type.getIntOrFloatBitWidth ();
168+ }
169+
// Lowers Aten view-like ops to StableHLO. The AtenViewDtypeOp case is
// handled inline below via bitcast_convert + reshape; other view ops fall
// through to the dim-size collection path (continues past this excerpt).
164170 LogicalResult
165171 matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
166172 ConversionPatternRewriter &rewriter) const override {
167173 auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf ().getType ());
168174 if (!rankType)
169- return op.emitError (" Only ranked tensor types are currently supported" );
175+ return op.emitError (" Only ranked tensor types are currently supported." );
176+ auto loc = op.getLoc ();
177+
178+ // support AtenViewDtypeOp
179+ if (isa<AtenViewDtypeOp>(op)) {
180+ auto self = adaptor.getSelf ();
// NOTE(review): dyn_cast result is used without a null check below
// (baseResultTy.getDtype()); if op.getType() can ever be a non-BaseTensorType
// this dereferences null — confirm the verifier guarantees this, or cast<>.
181+ auto baseResultTy = dyn_cast<BaseTensorType>(op.getType ());
182+
183+ // infer the result shape
184+ auto operandElt = rankType.getElementType ();
185+ auto targetElt = baseResultTy.getDtype ();
186+ auto operandEltBitWidth = getBitWidth (operandElt);
187+ auto targetEltBitWidth = getBitWidth (targetElt);
188+ auto operandSizes = rankType.getShape ();
189+ SmallVector<int64_t > castShape (operandSizes);
// Narrower target element: bitcast_convert yields an extra innermost
// dimension of size (operand width / target width).
190+ if (operandEltBitWidth > targetEltBitWidth) {
191+ int64_t last_size = operandEltBitWidth / targetEltBitWidth;
192+ castShape.push_back (last_size);
// Wider target element: bitcast_convert consumes the innermost dimension,
// which must statically equal (target width / operand width).
193+ } else if (operandEltBitWidth < targetEltBitWidth) {
194+ int64_t last_size = targetEltBitWidth / operandEltBitWidth;
// NOTE(review): `and` is the ISO alternative token for `&&`; LLVM style uses
// `&&`. Also: a *dynamic* last dim passes this guard and is only rejected
// later by the static-shape check — presumably intentional, verify.
195+ if (!ShapedType::isDynamic (castShape.back ()) and
196+ last_size != castShape.back ()) {
197+ return rewriter.notifyMatchFailure (
198+ op, " The last dim size is not equal to targetEltBitWidth / "
199+ " operandEltBitWidth." );
200+ } else {
201+ castShape.pop_back ();
202+ }
203+ }
204+
205+ auto resultType =
206+ OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
207+ baseResultTy);
// NOTE(review): unchecked dyn_cast<ShapedType> — cast<> (or a null check)
// would make the invariant explicit. Dynamic output shapes are rejected.
208+ if (!dyn_cast<ShapedType>(resultType).hasStaticShape ()) {
209+ return rewriter.notifyMatchFailure (
210+ op, " Currently only support static output shape." );
211+ }
212+
// Reinterpret the bits at the adjusted shape, then reshape to the final
// converted result type.
213+ auto castType =
214+ baseResultTy.getWithSizesAndDtype (castShape, baseResultTy.getDtype ());
215+ auto cast = rewriter.create <stablehlo::BitcastConvertOp>(
216+ loc,
217+ OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
218+ castType),
219+ self);
220+
221+ auto reshape =
222+ rewriter.create <stablehlo::ReshapeOp>(loc, resultType, cast);
223+
224+ rewriter.replaceOp (op, reshape);
225+
226+ return success ();
227+ }
170228
171229 // collect Value of dims
172230 SmallVector<Value, 4 > dimSizes;
173231 if (!getAtenViewOpSizes (op, adaptor, rewriter, dimSizes)) {
174232 return op.emitError (" Dims size must be a list of Scalar" );
175233 }
176234
177- auto loc = op.getLoc ();
178235 if (dimSizes.size () == 0 || rankType.getRank () == 0 ) {
179236 rewriter.replaceOpWithNewOp <stablehlo::ReshapeOp>(
180237 op,
@@ -236,6 +293,13 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
236293 SmallVector<Value, 4 > &dimSizes) const ;
237294};
238295
296+ template <>
297+ bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
298+ AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
299+ SmallVector<Value, 4 > &dimSizes) const {
300+ return false ;
301+ }
302+
239303template <>
240304bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
241305 AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
@@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
496560#define INSERT_VIEW_OP_PATTERN (AtenOp ) \
497561 target.addIllegalOp <AtenOp>(); \
498562 patterns.add <ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
563+ INSERT_VIEW_OP_PATTERN (AtenViewDtypeOp);
499564 INSERT_VIEW_OP_PATTERN (AtenViewOp);
500565 INSERT_VIEW_OP_PATTERN (AtenReshapeOp);
501566#undef INSERT_VIEW_OP_PATTERN
0 commit comments