@@ -35,20 +35,36 @@ def test_pass_with_clear_metadata_and_docstring(self):
35
35
metadata_props = {"mul_key" : "mul_value" },
36
36
doc_string = "This is a Mul node" ,
37
37
)
38
- func_inputs = [
39
- ir .Value (
40
- name = "input_a" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
41
- ),
42
- ir .Value (
43
- name = "input_b" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
44
- ),
45
- ]
46
38
function = ir .Function (
47
39
graph = ir .Graph (
48
40
name = "my_function" ,
49
- inputs = func_inputs ,
50
- outputs = mul_node .outputs ,
51
- nodes = [add_node , mul_node ],
41
+ inputs = [
42
+ input_a := ir .Value (
43
+ name = "input_a" ,
44
+ type = ir .TensorType (ir .DataType .FLOAT ),
45
+ shape = ir .Shape ((2 , 3 )),
46
+ ),
47
+ input_b := ir .Value (
48
+ name = "input_b" ,
49
+ type = ir .TensorType (ir .DataType .FLOAT ),
50
+ shape = ir .Shape ((2 , 3 )),
51
+ ),
52
+ ],
53
+ nodes = [
54
+ add_node_func := ir .node (
55
+ "Add" ,
56
+ inputs = [input_a , input_b ],
57
+ metadata_props = {"add_key" : "add_value" },
58
+ doc_string = "This is an Add node" ,
59
+ ),
60
+ mul_node_func := ir .node (
61
+ "Mul" ,
62
+ inputs = [add_node_func .o (), input_b ],
63
+ metadata_props = {"mul_key" : "mul_value" },
64
+ doc_string = "This is a Mul node" ,
65
+ ),
66
+ ],
67
+ outputs = mul_node_func .outputs ,
52
68
opset_imports = {"" : 20 },
53
69
doc_string = "This is a function docstring" ,
54
70
metadata_props = {"function_key" : "function_value" },
@@ -59,8 +75,8 @@ def test_pass_with_clear_metadata_and_docstring(self):
59
75
)
60
76
func_node = ir .node (
61
77
"my_function" ,
62
- inputs = [add_node .o (), inputs [ 1 ] ],
63
- domain = "my_domain" ,
78
+ inputs = [inputs [ 0 ], mul_node .o ()],
79
+ domain = "my_domain" ,
64
80
metadata_props = {"mul_key" : "mul_value" },
65
81
doc_string = "This is a Mul node" ,
66
82
)
@@ -77,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self):
77
93
)
78
94
sub_node = ir .node (
79
95
"Sub" ,
80
- inputs = [function .o (), const_node .o ()],
96
+ inputs = [func_node .o (), const_node .o ()],
81
97
num_outputs = 1 ,
82
98
metadata_props = {"sub_key" : "sub_value" },
83
99
doc_string = "This is a Sub node" ,
0 commit comments