@@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
153153    return  rewriter.notifyMatchFailure (op,
154154                                       " Unable to extract the scalar constant"  );
155155
156+   int64_t  numElem = 1 ;
157+   for  (int64_t  dim : dshape)
158+     numElem *= dim;
159+ 
156160  if  (isa<mlir::FloatType>(dtype)) {
157-     tosaTensor = tosa::getConstTensor<float >(rewriter, op,
158-                                              (isFloat ? doubleValue : intValue),
159-                                              dshape, dtype)
160-                      .value ();
161+     tosaTensor =
162+         tosa::getConstTensor<float >(
163+             rewriter, op,
164+             SmallVector<float >(numElem, (isFloat ? doubleValue : intValue)),
165+             dshape, dtype)
166+             .value ();
161167  } else  if  (auto  intType = dyn_cast<mlir::IntegerType>(dtype)) {
162168    auto  w = intType.getWidth ();
163169    if  (w != 1  && w != 32  && w != 64 )
@@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
173179      }
174180      bool  d = isFloat ? static_cast <bool >(doubleValue)
175181                       : static_cast <bool >(intValue);
176-       tosaTensor =
177-           tosa::getConstTensor<bool >(rewriter, op, {d}, dshape).value ();
182+       tosaTensor = tosa::getConstTensor<bool >(
183+                        rewriter, op, SmallVector<bool >(numElem, d), dshape)
184+                        .value ();
178185    } else  if  (w == 32 ) {
179186      if  (!isInValidRange<int32_t >(isFloat, doubleValue, isInt, intValue)) {
180187        return  rewriter.notifyMatchFailure (
@@ -183,17 +190,19 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
183190      }
184191      int32_t  d = isFloat ? static_cast <int32_t >(doubleValue)
185192                          : static_cast <int32_t >(intValue);
186-       tosaTensor =
187-           tosa::getConstTensor<int32_t >(rewriter, op, {d}, dshape).value ();
193+       tosaTensor = tosa::getConstTensor<int32_t >(
194+                        rewriter, op, SmallVector<int32_t >(numElem, d), dshape)
195+                        .value ();
188196    } else  if  (w == 64 ) {
189197      if  (!isInValidRange<int64_t >(isFloat, doubleValue, isInt, intValue)) {
190198        return  rewriter.notifyMatchFailure (
191199            op, " Supplied value of scalar constant exceeds limits " 
192200                " of destination type"  );
193201      }
194202      int64_t  d = (isFloat ? static_cast <int64_t >(doubleValue) : intValue);
195-       tosaTensor =
196-           tosa::getConstTensor<int64_t >(rewriter, op, {d}, dshape).value ();
203+       tosaTensor = tosa::getConstTensor<int64_t >(
204+                        rewriter, op, SmallVector<int64_t >(numElem, d), dshape)
205+                        .value ();
197206    }
198207  } else  {
199208    return  rewriter.notifyMatchFailure (op, " Usupported element type"  );
@@ -5320,7 +5329,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
53205329};
53215330
53225331template  <typename  AtenOpT>
5323- class  ConvertAtenFillScalarOp  : public  OpConversionPattern <AtenOpT> {
5332+ class  ConvertAtenFillOp  : public  OpConversionPattern <AtenOpT> {
53245333public: 
53255334  using  OpConversionPattern<AtenOpT>::OpConversionPattern;
53265335  using  OpAdaptor = typename  AtenOpT::Adaptor;
@@ -5336,18 +5345,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
53365345          op, " Only Tensor types with static shapes are currently supported"  );
53375346
53385347    Type outElemTy = outType.getElementType ();
5339-     if  (!outElemTy.isIntOrFloat ()) { 
5348+     if  (!outElemTy.isIntOrFloat ())
53405349      return  rewriter.notifyMatchFailure (
53415350          op, " Only floating-point or integer datatype legalization supported"  );
5351+ 
5352+     Value fillValueTargetTensor;
5353+     if  constexpr  (std::is_same<AtenOpT, AtenFillTensorOp>()) {
5354+       //  Reshape value tensor to have same rank and shape as input
5355+       auto  inputRank =
5356+           cast<RankedTensorType>(adaptor.getSelf ().getType ()).getRank ();
5357+ 
5358+       auto  fillValue = adaptor.getValue ();
5359+       auto  fillValueType = dyn_cast<TensorType>(fillValue.getType ());
5360+       if  (!fillValueType)
5361+         return  rewriter.notifyMatchFailure (op, " Fill value is not a tensor"  );
5362+       auto  fillValueElemTy = fillValueType.getElementType ();
5363+ 
5364+       SmallVector<int64_t > fillValueMatchedInputRankShape (inputRank, 1 );
5365+ 
5366+       auto  fillValueMatchedInputRankType = RankedTensorType::get (
5367+           makeShapeTorchCompatible (fillValueMatchedInputRankShape),
5368+           fillValueElemTy);
5369+ 
5370+       auto  fillValueMatchedInputRankTensor = rewriter.create <tosa::ReshapeOp>(
5371+           op->getLoc (), fillValueMatchedInputRankType, fillValue,
5372+           rewriter.getDenseI64ArrayAttr (fillValueMatchedInputRankShape));
5373+ 
5374+       fillValueTargetTensor = rewriter.create <tosa::TileOp>(
5375+           op->getLoc (),
5376+           RankedTensorType::get (makeShapeTorchCompatible (outType.getShape ()),
5377+                                 fillValueElemTy),
5378+           fillValueMatchedInputRankTensor.getResult (),
5379+           makeShapeTorchCompatible (outType.getShape ()));
5380+     } else  {
5381+       if  (failed (torchScalarToTosaTensor (
5382+               rewriter, op, op.getValue (), fillValueTargetTensor, outElemTy,
5383+               makeShapeTorchCompatible (outType.getShape ()))))
5384+         return  rewriter.notifyMatchFailure (
5385+             op, " Fill value must be a scalar constant"  );
53425386    }
5343-     Value constOp;
5344-     if  (failed (torchScalarToTosaTensor (
5345-             rewriter, op, op.getValue (), constOp, outElemTy,
5346-             makeShapeTorchCompatible (outType.getShape ()))))
5347-       return  rewriter.notifyMatchFailure (
5348-           op, " Supplied value must be a Scalar constant"  );
53495387
5350-     rewriter.replaceOpWithNewOp <tosa::CastOp>(op, outType, constOp);
5388+     rewriter.replaceOpWithNewOp <tosa::CastOp>(op, outType,
5389+                                               fillValueTargetTensor);
53515390
53525391    return  success ();
53535392  }
@@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
58695908  return  success ();
58705909}
58715910
5911+ //  Legalization for aten.flip
5912+ template  <>
5913+ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
5914+     AtenFlipOp op, OpAdaptor adaptor,
5915+     ConversionPatternRewriter &rewriter) const  {
5916+ 
5917+   auto  self = adaptor.getSelf ();
5918+ 
5919+   auto  selfTy = dyn_cast<RankedTensorType>(self.getType ());
5920+   if  (!selfTy)
5921+     return  rewriter.notifyMatchFailure (
5922+         op, " Only ranked tensor types are currently supported"  );
5923+ 
5924+   SmallVector<int64_t > dims;
5925+   if  (!matchPattern (adaptor.getDims (), m_TorchListOfConstantInts (dims)))
5926+     return  rewriter.notifyMatchFailure (
5927+         op, " Only constant dims are currently supported"  );
5928+ 
5929+   auto  selfRank = selfTy.getRank ();
5930+ 
5931+   auto  resultTy = getTypeConverter ()->convertType (op.getType ());
5932+   Value result = self;
5933+ 
5934+   for  (auto  &dim : dims) {
5935+     dim = toPositiveDim (dim, selfRank);
5936+     if  (!isValidDim (dim, selfRank))
5937+       return  rewriter.notifyMatchFailure (op, " Not all dims are valid"  );
5938+ 
5939+     result = rewriter.create <tosa::ReverseOp>(op->getLoc (), resultTy, result,
5940+                                               static_cast <int32_t >(dim));
5941+   }
5942+ 
5943+   rewriter.replaceOp (op, result);
5944+   return  success ();
5945+ }
5946+ 
5947+ //  Legalization for aten.round:
5948+ //  Rounds elements of input to the nearest integer.
5949+ //  Implements "round half to even" to break ties when a number is equidistant
5950+ //  from two integers.
5951+ template  <>
5952+ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
5953+     AtenRoundOp op, OpAdaptor adaptor,
5954+     ConversionPatternRewriter &rewriter) const  {
5955+   //  To round to the nearest integer, we will consider the fractional part of
5956+   //  the input element (= input element - integer part of element). If the
5957+   //  fractional part is smaller than 0.5, round the number down. If the
5958+   //  fractional part is 0.5, apply "round half to even" rule. If the fractional
5959+   //  part is greater than 0.5, round up.
5960+   // 
5961+   //  if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
5962+   //    res = floor(input)
5963+   //  else:
5964+   //    res = ceil(input)
5965+ 
5966+   auto  self = adaptor.getSelf ();
5967+ 
5968+   auto  selfTy = dyn_cast<TensorType>(self.getType ());
5969+   if  (!selfTy)
5970+     return  rewriter.notifyMatchFailure (op, " Only tensor types supported"  );
5971+ 
5972+   auto  resultTy =
5973+       cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
5974+ 
5975+   auto  boolTy =
5976+       RankedTensorType::get (resultTy.getShape (), rewriter.getIntegerType (1 ));
5977+ 
5978+   auto  resultElemTy = resultTy.getElementType ();
5979+ 
5980+   auto  oneHalf =
5981+       tosa::getConstTensor<float >(rewriter, op, 0.5 , {}, resultElemTy).value ();
5982+ 
5983+   auto  two =
5984+       tosa::getConstTensor<float >(rewriter, op, 2 , {}, resultElemTy).value ();
5985+ 
5986+   auto  floorInput =
5987+       rewriter.create <tosa::FloorOp>(op->getLoc (), resultTy, self);
5988+ 
5989+   //  input - floor(input)
5990+   auto  fractionalPart = rewriter.create <tosa::SubOp>(
5991+       op->getLoc (), resultTy, self, floorInput.getResult ());
5992+ 
5993+   auto  ceilInput = rewriter.create <tosa::CeilOp>(op->getLoc (), resultTy, self);
5994+ 
5995+   auto  floorInputDivByTwo = rewriter.create <tosa::MulOp>(
5996+       op->getLoc (), resultTy, floorInput.getResult (), oneHalf, /* shift=*/ 0 );
5997+ 
5998+   auto  floorDivResult = rewriter.create <tosa::FloorOp>(
5999+       op->getLoc (), resultTy, floorInputDivByTwo.getResult ());
6000+ 
6001+   //  (floor(input) // 2) * 2
6002+   auto  evenComparison = rewriter.create <tosa::MulOp>(
6003+       op->getLoc (), resultTy, floorDivResult.getResult (), two, /* shift=*/ 0 );
6004+ 
6005+   //  floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
6006+   auto  floorInputEven = rewriter.create <tosa::EqualOp>(
6007+       op->getLoc (), boolTy, floorInput.getResult (), evenComparison.getResult ());
6008+ 
6009+   auto  fracEqualOneHalf = rewriter.create <tosa::EqualOp>(
6010+       op->getLoc (), boolTy, fractionalPart.getResult (), oneHalf);
6011+ 
6012+   auto  fracLtOneHalf = rewriter.create <tosa::GreaterOp>(
6013+       op->getLoc (), boolTy, oneHalf, fractionalPart.getResult ());
6014+ 
6015+   //  (frac == 0.5) && (floor(input) % 2 == 0)
6016+   auto  fracEqualOneHalfCond = rewriter.create <tosa::LogicalAndOp>(
6017+       op->getLoc (), boolTy, fracEqualOneHalf.getResult (),
6018+       floorInputEven.getResult ());
6019+ 
6020+   //  (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
6021+   auto  floorResultCond = rewriter.create <tosa::LogicalOrOp>(
6022+       op->getLoc (), boolTy, fracLtOneHalf.getResult (),
6023+       fracEqualOneHalfCond.getResult ());
6024+ 
6025+   rewriter.replaceOpWithNewOp <tosa::SelectOp>(
6026+       op, resultTy, floorResultCond.getResult (), floorInput.getResult (),
6027+       ceilInput.getResult ());
6028+ 
6029+   return  success ();
6030+ }
6031+ 
58726032//  Template to create supporting diagonal mask tensor for aten.diagonal
58736033template  <typename  T>
58746034Value createDiagonalMask (PatternRewriter &rewriter, Operation *op,
@@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
60526212
60536213  return  success ();
60546214}
6215+ 
60556216} //  namespace
60566217
60576218//  -----------------------------------------------------------------------------
@@ -6283,11 +6444,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
62836444    INSERT_CONSTANT_FILL_PATTERN (AtenZerosOp, 0 );
62846445#undef  INSERT_CONSTANT_FILL_PATTERN
62856446
6286- #define  INSERT_FILL_SCALAR_PATTERN (AtenOp )                                     \
6447+ #define  INSERT_FILL_PATTERN (AtenOp )                                             \
62876448  target.addIllegalOp <AtenOp>();                                               \
6288-   patterns.add <ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
6289-     INSERT_FILL_SCALAR_PATTERN (AtenFill_ScalarOp);
6290- #undef  INSERT_FILL_SCALAR_PATTERN
6449+   patterns.add <ConvertAtenFillOp<AtenOp>>(typeConverter, context);
6450+     INSERT_FILL_PATTERN (AtenFill_ScalarOp);
6451+     INSERT_FILL_PATTERN (AtenFillScalarOp);
6452+     INSERT_FILL_PATTERN (AtenFillTensorOp);
6453+ #undef  INSERT_FILL_PATTERN
62916454
62926455#define  INSERT_MASKED_FILL_PATTERN (AtenOp )                                     \
62936456  target.addIllegalOp <AtenOp>();                                               \
@@ -6359,6 +6522,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
63596522    INSERT_ATENOP_PATTERN (AtenTrilOp);
63606523    INSERT_ATENOP_PATTERN (AtenDiagonalOp);
63616524    INSERT_ATENOP_PATTERN (AtenIndexSelectOp);
6525+     INSERT_ATENOP_PATTERN (AtenFlipOp);
6526+     INSERT_ATENOP_PATTERN (AtenRoundOp);
63626527#undef  INSERT_ATENOP_PATTERN
63636528
63646529#define  INSERT_CLONE_ATENOP_PATTERN (AtenOp )                                    \
0 commit comments