
Commit eb4e59e

[Torch] support binary_cross_entropy_with_logits decomposition (#3741)
1 parent a33d123

File tree (6 files changed, +154 −0 lines):

  include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
  lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
  lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
  projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
  projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
  projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 27 additions & 0 deletions
@@ -9224,6 +9224,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
   }];
 }
 
+def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$pos_weight,
+    Torch_IntType:$reduction
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
     AllowsTypeRefinement,
     HasValueSemantics,
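Note: this op definition is generated, and its `let summary` carries the upstream ATen schema verbatim; the `5, 1` arguments to parseDefaultTorchOp/printDefaultTorchOp mean five operands and one result. A quick way to sanity-check the schema against PyTorch itself (a sketch using standard torch.ops introspection; output shown is approximate):

    import torch

    # The ATen schema the generated op mirrors: 5 inputs, 1 result.
    schema = torch.ops.aten.binary_cross_entropy_with_logits.default._schema
    print(schema)
    # aten::binary_cross_entropy_with_logits(Tensor self, Tensor target,
    #     Tensor? weight=None, Tensor? pos_weight=None,
    #     int reduction=Mean) -> Tensor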

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 16 additions & 0 deletions
@@ -10289,6 +10289,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      torch.prim.If.yield %0 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"

@@ -14634,6 +14646,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 73 additions & 0 deletions
@@ -8799,6 +8799,77 @@ class DecomposeAtenCrossEntropyLossOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenBinaryCrossEntropyWithLogitsOp
+    : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
+  using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto self = op.getSelf();
+    auto target = op.getTarget();
+    auto posWeight = op.getPosWeight();
+    auto weight = op.getWeight();
+    auto reduction = op.getReduction();
+
+    Value loss;
+    auto one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto _one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+    auto _target =
+        rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
+    auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
+                                                      _target, one, one);
+    Value mm =
+        rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
+    Value logSigm =
+        rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);
+
+    if (!isa<Torch::NoneType>(posWeight.getType())) {
+      auto logWeight = rewriter.create<AtenAddScalarOp>(
+          loc, posWeight.getType(),
+          rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
+                                           one, one),
+          one, one);
+      loss = rewriter.create<AtenSubTensorOp>(
+          loc, mm.getType(), mm,
+          rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
+                                           logSigm),
+          one);
+    } else {
+      loss =
+          rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
+    }
+
+    if (!isa<Torch::NoneType>(weight.getType())) {
+      loss =
+          rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
+    }
+
+    // Apply loss reduction.
+    int64_t reductionInt;
+    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
+      return rewriter.notifyMatchFailure(op, "reduction must be a constant int");
+    }
+
+    auto none = rewriter.create<ConstantNoneOp>(loc);
+    Value res;
+    if (reductionInt == 1) {
+      res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
+    } else if (reductionInt == 2) {
+      res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
+    } else {
+      res = loss;
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
   using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;

@@ -9936,6 +10007,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
+        patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
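In the unweighted case, the pattern rests on the identity BCEWithLogits(x, t) = (1 - t) * x - log sigmoid(x): `mm` is the (1 - t) * x term, `logSigm` is log sigmoid(x), and AtenSubTensorOp combines them before the constant `reduction` selects mean, sum, or no reduction. A minimal numeric check of that identity against eager PyTorch (a sketch for the case where `weight` and `pos_weight` are absent):

    import torch
    import torch.nn.functional as F

    def bcewl_decomposed(x, t, reduction=0):
        # (1 - t) * x - log_sigmoid(x): the `mm` and `logSigm` values
        # from the pattern above, combined by AtenSubTensorOp.
        loss = (1.0 - t) * x - F.logsigmoid(x)
        if reduction == 1:
            return loss.mean()
        if reduction == 2:
            return loss.sum()
        return loss

    x, t = torch.randn(8, 2), torch.rand(8, 2)
    for code, name in [(0, "none"), (1, "mean"), (2, "sum")]:
        ref = F.binary_cross_entropy_with_logits(x, t, reduction=name)
        assert torch.allclose(bcewl_decomposed(x, t, code), ref, atol=1e-6)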

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 12 additions & 0 deletions
@@ -1993,6 +1993,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
 def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
     return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
 
+def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]:
+    scalar_shape: List[int] = []
+    if reduction == 0:
+        result_shape = upstream_shape_functions._copy(self)
+    else:
+        result_shape = scalar_shape
+    return result_shape
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
 ])

@@ -4958,6 +4966,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
         return dtype
     return aten〇std〡dtype(self_rank_dtype)
 
+def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(
         tensor_shapes=[(3,3)],
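These are the Python sources the MLIR shape/dtype library above is generated from; the dtype rule simply propagates the dtype of `self`. A quick eager-mode confirmation of that rule (a sketch):

    import torch

    x = torch.randn(4, 3, dtype=torch.float64)
    t = torch.rand(4, 3, dtype=torch.float64)
    out = torch.ops.aten.binary_cross_entropy_with_logits(x, t, None, None, 1)
    print(out.dtype)  # torch.float64 -- result dtype follows self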

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -743,6 +743,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
     )
+    emit(
+        "aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)"
+    )
     emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
     emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 23 additions & 0 deletions
@@ -2294,6 +2294,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(8, 2), tu.randint(8, high=2))
 
 
+class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 2], torch.float32, True),
+            ([8, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, input, target):
+        return torch.ops.aten.binary_cross_entropy_with_logits(
+            input, target, reduction=0
+        )
+
+
+@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule())
+def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 2), tu.rand(8, 2))
+
+
 # ==============================================================================
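Note that the targets come from tu.rand rather than tu.randint: binary cross-entropy with logits expects targets that are probabilities in [0, 1]. Outside the e2e harness, the eager behavior the test is compared against looks roughly like this (a sketch):

    import torch

    inp, tgt = torch.rand(8, 2), torch.rand(8, 2)  # targets in [0, 1]
    ref = torch.ops.aten.binary_cross_entropy_with_logits(inp, tgt, reduction=0)
    assert ref.shape == (8, 2)  # reduction=0 ("none") keeps the static [8, 2] shape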