Skip to content

Commit 09bbd27

Browse files
authored
Remove usages of ir.Input in test (#2591)
It was deprecated Signed-off-by: Justin Chu <[email protected]>
1 parent a1db753 commit 09bbd27

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ def clone_model(self, model: ir.Model) -> ir.Model:
1818

1919
def test_conv_affine_fusion(self):
2020
tape = ir.tape.Tape()
21-
x = ir.Input(
22-
"x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
23-
)
21+
x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]))
2422
w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
2523
b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
2624
scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
@@ -31,10 +29,10 @@ def test_conv_affine_fusion(self):
3129
z = tape.op(
3230
"Add",
3331
[mul_out, offset],
34-
output=ir.Input(
32+
output=ir.val(
3533
"z",
34+
dtype=ir.DataType.FLOAT,
3635
shape=ir.Shape([1, 3, 32, 32]),
37-
type=ir.TensorType(ir.DataType.FLOAT),
3836
),
3937
)
4038

@@ -65,9 +63,7 @@ def test_conv_affine_fusion(self):
6563

6664
def test_affine_conv_fusion_without_pad(self):
6765
tape = ir.tape.Tape()
68-
x = ir.Input(
69-
"x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
70-
)
66+
x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]))
7167
w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
7268
b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
7369
scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
@@ -77,10 +73,10 @@ def test_affine_conv_fusion_without_pad(self):
7773
z = tape.op(
7874
"Add",
7975
[mul_out, offset],
80-
output=ir.Input(
76+
output=ir.val(
8177
"z",
78+
dtype=ir.DataType.FLOAT,
8279
shape=ir.Shape([1, 3, 32, 32]),
83-
type=ir.TensorType(ir.DataType.FLOAT),
8480
),
8581
)
8682
conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]})

0 commit comments

Comments
 (0)