From 6b0b77a61e8b98db6527b196ae60ff562166cd96 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 22 Aug 2025 12:15:57 +0200 Subject: [PATCH 1/6] Add constant folding of Split with all constant inputs Signed-off-by: Christoph Berganski --- onnxscript/optimizer/_constant_folding.py | 41 +++++++++ .../optimizer/_constant_folding_test.py | 92 +++++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 55fb8759d4..3c1f2c935f 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -547,6 +547,47 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None +# Replaces Split operators with all constant inputs by a list of Constant +# operators +@register("Split") +def split(node: ir.Node, op, _): + # Replace single output split by Identity(x) + if len(node.outputs) == 1: + return op.Identity(node.inputs[0]) + + # Skip non-constant inputs + if (x := ir.convenience.get_const_tensor(node.inputs[0])) is None: + return None + + _split = None + + # Option A: Sizes per split + if len(node.inputs) == 2: + # Skip non-constant splits + if (_split := ir.convenience.get_const_tensor(node.inputs[1])) is None: + return None + # Numpy expects splits as starting indices for each section + _split = np.cumsum(_split.numpy()[:-1]) + + # Option B: Number of (even) splits + if (num_outputs := node.attributes.get("num_outputs")) is not None: + # Numpy accepts single integer of (even) splits as well + _split = num_outputs.as_int() + + # Hm, something must be terribly wrong... + if _split is None: + return None + + # Default split axis is 0, according to ONNX operators reference: + # https://onnx.ai/onnx/operators/onnx__Split.html + if (axis := node.attributes.get("axis")) is None: + axis = ir.Attr("axis", ir.AttributeType.INT, 0) + + # Split constant tensor and wrap a list of Constant operators + splits = np.array_split(x.numpy(), _split, axis.as_int()) + return [op.Constant(value=ir.tensor(x)) for x in splits] + + @register("Concat") def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Concat node with a single input by Identity""" diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index e58ee0ba19..93f2bfb0b0 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -403,6 +403,98 @@ def test_dropout_identity_mask(self, dropout_node: str): ops = [node.op_type for node in nodes] self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"]) + def test_split_identity_num_outputs(self): + model = """ + + agraph (float[N] x) => (float[N] z) + { + z = Split (x) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(len(optimized.graph[-1].outputs), 1) + self.assertEqual(optimized.graph[-1].op_type, "Identity") + + def test_split_identity_splits(self): + model = """ + + agraph (float[N] x, float[1] split) => (float[N] z) + { + z = Split (x, split) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(len(optimized.graph[-1].outputs), 1) + self.assertEqual(optimized.graph[-1].op_type, "Identity") + + + def test_split_constant_num_outputs_even(self): + model = """ + + agraph () => (float[N] z1, float[N] z2) + { + x = Constant () + z1, z2 = Split (x) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(len(optimized.graph[-2].outputs), 1) + self.assertEqual(len(optimized.graph[-1].outputs), 1) + self.assertEqual(optimized.graph[-2].outputs[0].shape, [3]) + self.assertEqual(optimized.graph[-1].outputs[0].shape, [3]) + self.assertEqual(optimized.graph[-2].op_type, "Constant") + self.assertEqual(optimized.graph[-1].op_type, "Constant") + + def test_split_constant_num_outputs_odd(self): + model = """ + + agraph () => (float[N] z1, float[M] z2) + { + x = Constant () + z1, z2 = Split (x) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(len(optimized.graph[-2].outputs), 1) + self.assertEqual(len(optimized.graph[-1].outputs), 1) + self.assertEqual(optimized.graph[-2].outputs[0].shape, [4]) + self.assertEqual(optimized.graph[-1].outputs[0].shape, [3]) + self.assertEqual(optimized.graph[-2].op_type, "Constant") + self.assertEqual(optimized.graph[-1].op_type, "Constant") + + def test_split_constant_splits(self): + model = """ + + agraph () => (float[N] z1, float[M] z2, float[L] z3, float[K] z4) + { + x = Constant () + split = Constant () + z1, z2, z3, z4 = Split (x, split) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 4) + self.assertEqual(len(optimized.graph[-3].outputs), 1) + self.assertEqual(len(optimized.graph[-2].outputs), 1) + self.assertEqual(len(optimized.graph[-1].outputs), 1) + self.assertEqual(optimized.graph[-4].outputs[0].shape, [2]) + self.assertEqual(optimized.graph[-3].outputs[0].shape, [3]) + self.assertEqual(optimized.graph[-2].outputs[0].shape, [1]) + self.assertEqual(optimized.graph[-1].outputs[0].shape, [1]) + self.assertEqual(optimized.graph[-4].op_type, "Constant") + self.assertEqual(optimized.graph[-3].op_type, "Constant") + self.assertEqual(optimized.graph[-2].op_type, "Constant") + self.assertEqual(optimized.graph[-1].op_type, "Constant") + def test_concat_identity(self): model = """ From 83ef99dc4ebc032db33f36189a87d3bab334e731 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 22 Aug 2025 17:18:54 +0200 Subject: [PATCH 2/6] Add a more descriptive comment for rejecting Split configuration Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/optimizer/_constant_folding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 3c1f2c935f..21603aca64 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -574,7 +574,7 @@ def split(node: ir.Node, op, _): # Numpy accepts single integer of (even) splits as well _split = num_outputs.as_int() - # Hm, something must be terribly wrong... + # Unable to determine split configuration, skip optimization if _split is None: return None From 096d51e203e747a5e97558078d430d7103249b05 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 22 Aug 2025 17:19:58 +0200 Subject: [PATCH 3/6] Add proper docstring for Split constant folding Co-authored-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 21603aca64..13f49ae4ab 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -547,10 +547,9 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None -# Replaces Split operators with all constant inputs by a list of Constant -# operators @register("Split") def split(node: ir.Node, op, _): + """Replaces Split operators with all constant inputs by a list of Constant operators.""" # Replace single output split by Identity(x) if len(node.outputs) == 1: return op.Identity(node.inputs[0]) From 607a17e8f2763a180fe43e2f9dce5faee96eb145 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 22 Aug 2025 17:40:44 +0200 Subject: [PATCH 4/6] Make split constant folding options mutually exclusive Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/optimizer/_constant_folding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 13f49ae4ab..253957f0c0 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -569,7 +569,7 @@ def split(node: ir.Node, op, _): _split = np.cumsum(_split.numpy()[:-1]) # Option B: Number of (even) splits - if (num_outputs := node.attributes.get("num_outputs")) is not None: + elif (num_outputs := node.attributes.get("num_outputs")) is not None: # Numpy accepts single integer of (even) splits as well _split = num_outputs.as_int() From 956742cc347ac94ad1486b582165ebbda87e765b Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 28 Aug 2025 11:45:21 +0200 Subject: [PATCH 5/6] No _ prefix for local variable split Co-authored-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 253957f0c0..2af3f40994 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -558,7 +558,7 @@ def split(node: ir.Node, op, _): if (x := ir.convenience.get_const_tensor(node.inputs[0])) is None: return None - _split = None + split_ = None # Option A: Sizes per split if len(node.inputs) == 2: From 89d57a034c9485713e4f3bc2b0b85bf37268847f Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 28 Aug 2025 15:44:15 +0200 Subject: [PATCH 6/6] Rename other instances of _split to split_ Signed-off-by: Christoph Berganski --- onnxscript/optimizer/_constant_folding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 2af3f40994..375bca02f5 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -563,18 +563,18 @@ def split(node: ir.Node, op, _): # Option A: Sizes per split if len(node.inputs) == 2: # Skip non-constant splits - if (_split := ir.convenience.get_const_tensor(node.inputs[1])) is None: + if (split_ := ir.convenience.get_const_tensor(node.inputs[1])) is None: return None # Numpy expects splits as starting indices for each section - _split = np.cumsum(_split.numpy()[:-1]) + split_ = np.cumsum(split_.numpy()[:-1]) # Option B: Number of (even) splits elif (num_outputs := node.attributes.get("num_outputs")) is not None: # Numpy accepts single integer of (even) splits as well - _split = num_outputs.as_int() + split_ = num_outputs.as_int() # Unable to determine split configuration, skip optimization - if _split is None: + if split_ is None: return None # Default split axis is 0, according to ONNX operators reference: @@ -583,7 +583,7 @@ def split(node: ir.Node, op, _): axis = ir.Attr("axis", ir.AttributeType.INT, 0) # Split constant tensor and wrap a list of Constant operators - splits = np.array_split(x.numpy(), _split, axis.as_int()) + splits = np.array_split(x.numpy(), split_, axis.as_int()) return [op.Constant(value=ir.tensor(x)) for x in splits]