Skip to content

Commit 86e7054

Browse files
committed
Fix test
Signed-off-by: Justin Chu <[email protected]>
1 parent 644f67d commit 86e7054

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,36 @@ def test_pass_with_clear_metadata_and_docstring(self):
3535
metadata_props={"mul_key": "mul_value"},
3636
doc_string="This is a Mul node",
3737
)
38-
func_inputs = [
39-
ir.Value(
40-
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
41-
),
42-
ir.Value(
43-
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
44-
),
45-
]
4638
function = ir.Function(
4739
graph=ir.Graph(
4840
name="my_function",
49-
inputs=func_inputs,
50-
outputs=mul_node.outputs,
51-
nodes=[add_node, mul_node],
41+
inputs=[
42+
input_a := ir.Value(
43+
name="input_a",
44+
type=ir.TensorType(ir.DataType.FLOAT),
45+
shape=ir.Shape((2, 3)),
46+
),
47+
input_b := ir.Value(
48+
name="input_b",
49+
type=ir.TensorType(ir.DataType.FLOAT),
50+
shape=ir.Shape((2, 3)),
51+
),
52+
],
53+
nodes=[
54+
add_node_func := ir.node(
55+
"Add",
56+
inputs=[input_a, input_b],
57+
metadata_props={"add_key": "add_value"},
58+
doc_string="This is an Add node",
59+
),
60+
mul_node_func := ir.node(
61+
"Mul",
62+
inputs=[add_node_func.o(), input_b],
63+
metadata_props={"mul_key": "mul_value"},
64+
doc_string="This is a Mul node",
65+
),
66+
],
67+
outputs=mul_node_func.outputs,
5268
opset_imports={"": 20},
5369
doc_string="This is a function docstring",
5470
metadata_props={"function_key": "function_value"},
@@ -59,8 +75,8 @@ def test_pass_with_clear_metadata_and_docstring(self):
5975
)
6076
func_node = ir.node(
6177
"my_function",
62-
inputs=[add_node.o(), inputs[1]],
63-
domain = "my_domain",
78+
inputs=[inputs[0], mul_node.o()],
79+
domain="my_domain",
6480
metadata_props={"mul_key": "mul_value"},
6581
doc_string="This is a Mul node",
6682
)
@@ -77,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self):
7793
)
7894
sub_node = ir.node(
7995
"Sub",
80-
inputs=[function.o(), const_node.o()],
96+
inputs=[func_node.o(), const_node.o()],
8197
num_outputs=1,
8298
metadata_props={"sub_key": "sub_value"},
8399
doc_string="This is a Sub node",

0 commit comments

Comments
 (0)