@@ -410,6 +410,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
410
410
}
411
411
return failure ();
412
412
});
413
+ patterns.onOp (" LayerNormalization" , 17 ,
414
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
415
+ Torch::ValueTensorType Y_type;
416
+ Torch::ValueTensorType Mean_type;
417
+ Torch::ValueTensorType InvStdDev_type;
418
+ Value X;
419
+ Value Scale;
420
+ Value B;
421
+ int64_t axis;
422
+ float epsilon;
423
+ int64_t stash_type;
424
+ if (binder.tensorOperandAtIndex (X, 0 ) ||
425
+ binder.tensorOperandAtIndex (Scale, 1 ) ||
426
+ binder.tensorOperandAtIndex (B, 2 ) ||
427
+ binder.tensorResultTypeAtIndex (Y_type, 0 ) ||
428
+ binder.tensorResultTypeAtIndex (Mean_type, 1 ) ||
429
+ binder.tensorResultTypeAtIndex (InvStdDev_type, 2 ) ||
430
+ binder.s64IntegerAttr (axis, " axis" , -1 ) ||
431
+ binder.f32FloatAttr (epsilon, " epsilon" , 0.00001 ) ||
432
+ binder.s64IntegerAttr (stash_type, " stash_type" , 1 ))
433
+ return failure ();
434
+ Value constEpsilon = rewriter.create <Torch::ConstantFloatOp>(
435
+ binder.getLoc (), rewriter.getType <Torch::FloatType>(),
436
+ rewriter.getF64FloatAttr (epsilon));
437
+ unsigned rank = 1 ;
438
+ if (std::optional<unsigned > maybeRank = Torch::getTensorRank (X))
439
+ rank = *maybeRank;
440
+ SmallVector<Value> normalized;
441
+ axis = Torch::toPositiveDim (axis, rank);
442
+ auto X_type = X.getType ().cast <Torch::ValueTensorType>();
443
+ ArrayRef<int64_t > X_shape = X_type.getSizes ();
444
+ for (int64_t n = axis; n < rank ; n++) {
445
+ normalized.push_back (rewriter.create <Torch::ConstantIntOp>(
446
+ binder.getLoc (), rewriter.getI64IntegerAttr (X_shape[n])));
447
+ }
448
+ Value normalized_shape = rewriter.create <Torch::PrimListConstructOp>(
449
+ binder.getLoc (),
450
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
451
+ normalized);
452
+ rewriter.replaceOpWithNewOp <Torch::AtenNativeLayerNormOp>(
453
+ binder.op , Y_type, Mean_type, InvStdDev_type, X, normalized_shape, Scale, B, constEpsilon);
454
+ return success ();
455
+ });
413
456
patterns.onOp (" LeakyRelu" , 1 ,
414
457
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
415
458
Torch::ValueTensorType resultType;
0 commit comments