File tree 1 file changed +17
-4
lines changed
1 file changed +17
-4
lines changed Original file line number Diff line number Diff line change @@ -111,11 +111,24 @@ def assert_numpy_allclose(
111
111
)
112
112
113
113
if not (is_allclose ):
114
- diff = np .abs (
115
- np .asarray (output_parameters [i ]) - np .asarray (output_references [i ])
116
- )
114
+ output_parameters_i_arr = np .asarray (output_parameters [i ])
115
+ output_references_i_arr = np .asarray (output_references [i ])
116
+
117
+ diff = np .abs (output_parameters_i_arr - output_references_i_arr )
117
118
abs_diff = np .sum (diff )
118
- rel_diff = np .sum (diff / np .abs (output_references [i ]))
119
+ rel_diff_dividend = np .max (
120
+ np .vstack (
121
+ (
122
+ np .abs (output_parameters_i_arr ),
123
+ np .abs (output_references_i_arr ),
124
+ )
125
+ ),
126
+ axis = 0 ,
127
+ )
128
+ # when both are zero the diff is also zero, so we set it to 1
129
+ # so no division by zero error is raised
130
+ rel_diff_dividend [rel_diff_dividend == 0.0 ] = 1.0
131
+ rel_diff = np .sum (diff / rel_diff_dividend )
119
132
120
133
message = (
121
134
f"Output is not close to reference absolute difference "
You can’t perform that action at this time.
0 commit comments