Skip to content

Commit 238610e

Browse files
committed
Update test
1 parent 78e668c commit 238610e

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

src/onnx_ir/serde_test.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,53 @@ def test_from_to_onnx_text_with_initializers(self):
8080
opset_import: ["" : 17]
8181
>
8282
agraph (float[1] input_x, float[2] input_y) => (float[2] result) {
83-
[node_name] result = Add (input_x, input_y)
83+
[node_1] add = Add (input_x, input_y)
84+
[node_2] result = Add (add, initializer_z)
8485
}"""
8586
self.maxDiff = None
8687
array = np.array([1.0, 2.0], dtype=np.float32)
88+
init_array = np.array([3.0, 4.0], dtype=np.float32)
8789
model = serde.from_onnx_text(
88-
model_text, initializers=[ir.tensor(array, name="input_y")]
90+
model_text,
91+
initializers=[
92+
ir.tensor(init_array, name="initializer_z"),
93+
ir.tensor(array, name="input_y"),
94+
],
8995
)
9096
np.testing.assert_array_equal(model.graph.inputs[1].const_value.numpy(), array)
97+
np.testing.assert_array_equal(
98+
model.graph.initializers["initializer_z"].const_value.numpy(), init_array
99+
)
91100
expected_text = """\
92101
<
93102
ir_version: 10,
94103
opset_import: ["" : 17]
95104
>
96105
agraph (float[1] input_x, float[2] input_y) => (float[2] result)
97-
<float[2] input_y = {1,2}>
106+
<float[2] initializer_z = {3,4}, float[2] input_y = {1,2}>
98107
{
99-
[node_name] result = Add (input_x, input_y)
108+
[node_1] add = Add (input_x, input_y)
109+
[node_2] result = Add (add, initializer_z)
100110
}"""
101111
onnx_text_roundtrip = serde.to_onnx_text(model)
102112
stripped_lines = [line.rstrip() for line in onnx_text_roundtrip.splitlines()]
103113
result = "\n".join(stripped_lines)
104114
self.assertEqual(result, expected_text)
105115

116+
def test_to_onnx_text_excluding_initializers(self):
117+
model_text = """\
118+
<
119+
ir_version: 10,
120+
opset_import: ["" : 17]
121+
>
122+
agraph (float[1] input_x, float[2] input_y) => (float[2] result) {
123+
[node_name] result = Add (input_x, input_y)
124+
}"""
125+
self.maxDiff = None
126+
array = np.array([1.0, 2.0], dtype=np.float32)
127+
model = serde.from_onnx_text(
128+
model_text, initializers=[ir.tensor(array, name="input_y")]
129+
)
106130
onnx_text_without_initializers = serde.to_onnx_text(model, exclude_initializers=True)
107131
expected_text_without_initializers = """\
108132
<

0 commit comments

Comments
 (0)