onnxscript/optimizer/_constant_folding.py (41 additions, 0 deletions)

@@ -547,6 +547,47 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    return None


# Replaces Split operators with all constant inputs by a list of Constant
# operators
@register("Split")
Collaborator:
I think a better, more general, solution would be to fix the restriction in the general constant-folding (this else branch) ... so the underlying evaluator evaluates the Split op ... and we just need to bind each output value to the corresponding constant value.

Collaborator:

Could you share what needs to be done to enable that part?

Collaborator:

@iksnagreb as suggested here, we will instead fix handling for multi-output nodes. Please feel free to contribute or provide your inputs, thanks!

Collaborator:

Good question. What needs to be done is to create an appropriate Replacement object which wraps the list of new nodes (a list of Constant nodes, like in the current PR) and a list of values (which should be more or less what the reference implementation returns).

I realize that this is lacking even in the current PR down below ... if we don't specify the list of new values (to replace the list of old values produced by the constant-folded node), we are not going to connect these values to the consumers of the old values.

So, if we added a test case where the outputs of Split are used subsequently (or if we checked that the graph outputs in the tests below actually refer to the correct split constant values), we would likely find that this PR doesn't work. It stems from a limitation in the core constant-folding logic that needs to be fixed.
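
For concreteness, a minimal sketch of what that multi-output replacement could look like, assuming a Replacement container holding the new output values and the new nodes (the exact names and constructor signature in _constant_folding.py are assumed here, not verified):

    # Hypothetical sketch -- Replacement's field order and names are assumed.
    # Produce one Constant per split piece, then return both the new nodes
    # and their output values so the folding pass can rewire the consumers
    # of the old Split outputs.
    new_values = [op.Constant(value=ir.tensor(piece)) for piece in splits]
    new_nodes = [value.producer() for value in new_values]
    return Replacement(new_values, new_nodes)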

def split(node: ir.Node, op, _) -> ReturnValue:
    # Replace a single-output Split by Identity(x)
    if len(node.outputs) == 1:
        return op.Identity(node.inputs[0])

    # Skip non-constant inputs
    if (x := ir.convenience.get_const_tensor(node.inputs[0])) is None:
        return None

    _split = None

    # Option A: explicit sizes per split section
    if len(node.inputs) == 2:
        # Skip non-constant splits
        if (_split := ir.convenience.get_const_tensor(node.inputs[1])) is None:
            return None
        # np.array_split expects the indices at which to split, i.e., the
        # cumulative sizes of all sections except the last
        _split = np.cumsum(_split.numpy()[:-1])
    # Option B: a number of (nearly even) sections; mutually exclusive with
    # the explicit split input, hence the elif
    elif (num_outputs := node.attributes.get("num_outputs")) is not None:
        # np.array_split also accepts a single integer section count
        _split = num_outputs.as_int()

    # Neither split sizes nor num_outputs available: cannot fold
    if _split is None:
        return None

    # The split axis defaults to 0, per the ONNX operators reference:
    # https://onnx.ai/onnx/operators/onnx__Split.html
    if (axis := node.attributes.get("axis")) is None:
        axis = ir.Attr("axis", ir.AttributeType.INT, 0)

    # Split the constant tensor and wrap each piece in a Constant operator
    splits = np.array_split(x.numpy(), _split, axis.as_int())
    return [op.Constant(value=ir.tensor(piece)) for piece in splits]
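
As a sanity check of the cumsum-to-array_split convention used above, a small standalone NumPy example (illustrative only, not part of the diff):

    import numpy as np

    # Section sizes [2, 3, 1, 1] become split indices [2, 5, 6]
    x = np.arange(7, dtype=np.float32)
    indices = np.cumsum([2, 3, 1, 1][:-1])
    pieces = np.array_split(x, indices, 0)
    assert [p.tolist() for p in pieces] == [[0, 1], [2, 3, 4], [5], [6]]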


@register("Concat")
def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Concat node with a single input by Identity"""
onnxscript/optimizer/_constant_folding_test.py (92 additions, 0 deletions)

@@ -403,6 +403,98 @@ def test_dropout_identity_mask(self, dropout_node: str):
        ops = [node.op_type for node in nodes]
        self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"])

    def test_split_identity_num_outputs(self):
        model = """
            <ir_version: 8, opset_import: [ "" : 18]>
            agraph (float[N] x) => (float[N] z)
            {
                z = Split <axis=-1, num_outputs=1> (x)
            }
        """

        optimized = self._fold(model)
        self.assertEqual(len(optimized.graph), 1)
        self.assertEqual(len(optimized.graph[-1].outputs), 1)
        self.assertEqual(optimized.graph[-1].op_type, "Identity")

    def test_split_identity_splits(self):
        model = """
            <ir_version: 8, opset_import: [ "" : 18]>
            agraph (float[N] x, int64[1] split) => (float[N] z)
            {
                z = Split <axis=-1> (x, split)
            }
        """

        optimized = self._fold(model)
        self.assertEqual(len(optimized.graph), 1)
        self.assertEqual(len(optimized.graph[-1].outputs), 1)
        self.assertEqual(optimized.graph[-1].op_type, "Identity")

    def test_split_constant_num_outputs_even(self):
        model = """
            <ir_version: 8, opset_import: [ "" : 18]>
            agraph () => (float[N] z1, float[N] z2)
            {
                x = Constant <value: tensor = float[6] const {0,1,2,3,4,5}> ()
                z1, z2 = Split <axis=-1, num_outputs=2> (x)
            }
        """

        optimized = self._fold(model)
        self.assertEqual(len(optimized.graph), 2)
        self.assertEqual(len(optimized.graph[-2].outputs), 1)
        self.assertEqual(len(optimized.graph[-1].outputs), 1)
        self.assertEqual(optimized.graph[-2].outputs[0].shape, [3])
        self.assertEqual(optimized.graph[-1].outputs[0].shape, [3])
        self.assertEqual(optimized.graph[-2].op_type, "Constant")
        self.assertEqual(optimized.graph[-1].op_type, "Constant")

    def test_split_constant_num_outputs_odd(self):
        model = """
            <ir_version: 8, opset_import: [ "" : 18]>
            agraph () => (float[N] z1, float[M] z2)
            {
                x = Constant <value: tensor = float[7] const {0,1,2,3,4,5,6}> ()
                z1, z2 = Split <axis=-1, num_outputs=2> (x)
            }
        """

        optimized = self._fold(model)
        self.assertEqual(len(optimized.graph), 2)
        self.assertEqual(len(optimized.graph[-2].outputs), 1)
        self.assertEqual(len(optimized.graph[-1].outputs), 1)
        self.assertEqual(optimized.graph[-2].outputs[0].shape, [4])
        self.assertEqual(optimized.graph[-1].outputs[0].shape, [3])
        self.assertEqual(optimized.graph[-2].op_type, "Constant")
        self.assertEqual(optimized.graph[-1].op_type, "Constant")

    def test_split_constant_splits(self):
        model = """
            <ir_version: 8, opset_import: [ "" : 18]>
            agraph () => (float[N] z1, float[M] z2, float[L] z3, float[K] z4)
            {
                x = Constant <value: tensor = float[7] const {0,1,2,3,4,5,6}> ()
                split = Constant <value_ints = [2, 3, 1, 1]> ()
                z1, z2, z3, z4 = Split <axis=-1> (x, split)
            }
        """

        optimized = self._fold(model)
        self.assertEqual(len(optimized.graph), 4)
        self.assertEqual(len(optimized.graph[-4].outputs), 1)
        self.assertEqual(len(optimized.graph[-3].outputs), 1)
        self.assertEqual(len(optimized.graph[-2].outputs), 1)
        self.assertEqual(len(optimized.graph[-1].outputs), 1)
        self.assertEqual(optimized.graph[-4].outputs[0].shape, [2])
        self.assertEqual(optimized.graph[-3].outputs[0].shape, [3])
        self.assertEqual(optimized.graph[-2].outputs[0].shape, [1])
        self.assertEqual(optimized.graph[-1].outputs[0].shape, [1])
        self.assertEqual(optimized.graph[-4].op_type, "Constant")
        self.assertEqual(optimized.graph[-3].op_type, "Constant")
        self.assertEqual(optimized.graph[-2].op_type, "Constant")
        self.assertEqual(optimized.graph[-1].op_type, "Constant")

    def test_concat_identity(self):
        model = """
            <ir_version: 7, opset_import: [ "" : 17]>