From 9a5f69a5ef95ff8c0167498939b899e3df9e4fb1 Mon Sep 17 00:00:00 2001 From: David Wurtz Date: Tue, 26 Nov 2024 05:21:07 -0800 Subject: [PATCH] add tests --- .../_internal/param_manipulation_test.py | 138 +++++++++++++ onnxscript/_legacy_ir/visitor_test.py | 84 ++++++++ onnxscript/backend/onnx_backend_test.py | 7 + onnxscript/evaluator_test.py | 76 +++++++ .../torch_lib/deduce_type_constraints_test.py | 60 ++++++ onnxscript/ir/_display_test.py | 74 +++++++ onnxscript/ir/_enums_test.py | 20 ++ onnxscript/ir/_external_data_test.py | 92 +++++++++ onnxscript/ir/_linked_list_test.py | 70 +++++++ onnxscript/ir/_schemas_test.py | 59 ++++++ onnxscript/ir/_type_casting_test.py | 22 ++ onnxscript/ir/serde_test.py | 61 ++++++ onnxscript/ir/tensor_adapters_test.py | 6 + onnxscript/ir/traversal_test.py | 38 ++++ onnxscript/optimizer/_inliner_test.py | 60 ++++++ .../_legacy/_simple_function_folding_test.py | 28 +++ .../rewriter/broadcast_to_matmul_test.py | 33 +++ onnxscript/rewriter/collapse_slices_test.py | 148 ++++++++++++++ onnxscript/rewriter/generic_pattern_test.py | 32 +++ onnxscript/rewriter/llama_rule_sets_test.py | 30 +++ .../bfloat16_utils/bfloat16_converter_test.py | 24 +++ .../fused_matmul_rule_sets_test.py | 16 ++ .../rewriter/onnxruntime/softmax_test.py | 7 + .../transformers/biassplitgelu_test.py | 110 ++++++++++ .../onnxruntime/transformers/fastgelu_test.py | 29 +++ .../transformers/layernorm_test.py | 13 ++ .../transformers/multihead_attention_test.py | 117 +++++++++++ onnxscript/tensor_test.py | 189 ++++++++++++++++++ .../tools/benchmark/benchmark_helpers_test.py | 135 +++++++++++++ .../tools/benchmark/export_model_test.py | 26 +++ onnxscript/tools/memory_peak_test.py | 24 +++ .../tools/transformers_models/llama_test.py | 72 +++++++ .../tools/transformers_models/mistral_test.py | 60 ++++++ .../tools/transformers_models/phi3_test.py | 99 +++++++++ .../tools/transformers_models/phi_test.py | 47 +++++ onnxscript/type_annotation_test.py | 30 +++ onnxscript/values_test.py | 23 +++ .../_version_converter_test.py | 14 ++ opgen/pygen_test.py | 82 ++++++++ tests/eager_mode_test.py | 56 ++++++ tests/external_tensor_test.py | 48 +++++ .../torch_lib/quantization_test.py | 63 ++++++ tests/ir/graph_view_test.py | 97 +++++++++ tests/onnx_types_test.py | 71 +++++++ tests/operator_test.py | 57 ++++++ 45 files changed, 2647 insertions(+) diff --git a/onnxscript/_internal/param_manipulation_test.py b/onnxscript/_internal/param_manipulation_test.py index 7b67e4380..19e04bc74 100644 --- a/onnxscript/_internal/param_manipulation_test.py +++ b/onnxscript/_internal/param_manipulation_test.py @@ -184,6 +184,144 @@ def test_it_raises_on_insufficient_args( allow_extra_kwargs=allow_extra_kwargs, ) + def test_tag_arguments_with_extra_kwargs_not_allowed(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True), + values.ParamSchema(name="b", type=int, is_input=False), + ) + + args = (TEST_INPUT,) + kwargs = {"b": 42, "extra": 100} + + with self.assertRaises(TypeError): + _, _ = param_manipulation.tag_arguments_with_param_schemas( + param_schemas, args, kwargs, allow_extra_kwargs=False + ) + + + def test_turn_to_kwargs_with_variadic_inputs(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True), + values.ParamSchema(name="b", type=int, is_input=False), + ) + + inputs = [TEST_INPUT, TEST_INPUT, TEST_INPUT] + attributes = {"b": 42} + + expected_attributes = { + "a": [TEST_INPUT, TEST_INPUT, TEST_INPUT], + "b": 42, + } + + result = param_manipulation.turn_to_kwargs_to_avoid_ordering( + param_schemas, inputs, attributes + ) + + self.assertEqual(result, expected_attributes) + + + def test_tag_arguments_with_variadic_inputs(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True), + values.ParamSchema(name="b", type=int, is_input=False), + ) + + args = (TEST_INPUT, TEST_INPUT, TEST_INPUT) + kwargs = {"b": 42} + + expected_tagged_args = [(TEST_INPUT, param_schemas[0]), (TEST_INPUT, param_schemas[0]), (TEST_INPUT, param_schemas[0])] + expected_tagged_kwargs = {"b": (42, param_schemas[1])} + + tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas( + param_schemas, args, kwargs + ) + + self.assertEqual(tagged_args, expected_tagged_args) + self.assertEqual(tagged_kwargs, expected_tagged_kwargs) + + + def test_turn_to_kwargs_to_avoid_ordering(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True), + values.ParamSchema(name="b", type=int, is_input=True), + values.ParamSchema(name="c", type=float, is_input=False, default=0.0), + ) + + inputs = [TEST_INPUT, 42] + attributes = {"c": 0.0} + + expected_attributes = { + "a": TEST_INPUT, + "b": 42, + "c": 0.0, + } + + result = param_manipulation.turn_to_kwargs_to_avoid_ordering( + param_schemas, inputs, attributes + ) + + self.assertEqual(result, expected_attributes) + + + def test_tag_arguments_with_param_schemas(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True), + values.ParamSchema(name="b", type=int, is_input=False, default=100), + values.ParamSchema(name="c", type=float, is_input=False, default=0.0), + ) + + args = (TEST_INPUT,) + kwargs = {"b": 42} + + expected_tagged_args = [(TEST_INPUT, param_schemas[0])] + expected_tagged_kwargs = { + "b": (42, param_schemas[1]), + "c": (0.0, param_schemas[2]), + } + + tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas( + param_schemas, args, kwargs + ) + + self.assertEqual(tagged_args, expected_tagged_args) + self.assertEqual(tagged_kwargs, expected_tagged_kwargs) + + + def test_required_input_not_provided(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True, required=True), + values.ParamSchema(name="b", type=int, is_input=False, default=100), + ) + + args = () + kwargs = {"b": 42} + + with self.assertRaises(TypeError): + _, _ = param_manipulation.tag_arguments_with_param_schemas( + param_schemas, args, kwargs + ) + + + def test_variadic_inputs(self): + param_schemas = ( + values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True), + values.ParamSchema(name="b", type=int, is_input=False), + ) + + args = (TEST_INPUT, TEST_INPUT, TEST_INPUT) + kwargs = {"b": 42} + + expected_inputs = [TEST_INPUT, TEST_INPUT, TEST_INPUT] + expected_attributes = collections.OrderedDict([("b", 42)]) + + inputs, attributes = param_manipulation.separate_input_attributes_from_arguments( + param_schemas, args, kwargs + ) + + self.assertEqual(inputs, expected_inputs) + self.assertEqual(attributes, expected_attributes) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/_legacy_ir/visitor_test.py b/onnxscript/_legacy_ir/visitor_test.py index 7c0ebc05d..1add6de80 100644 --- a/onnxscript/_legacy_ir/visitor_test.py +++ b/onnxscript/_legacy_ir/visitor_test.py @@ -35,6 +35,90 @@ def test_function_optional_input_is_recorded_by_shape_env(self): model_visitor.function_shape_env.lookup(model.functions[0], "optional_z") ) + def test_proto_visitor_enter_exit_function_scope(self): + function_proto = onnx.FunctionProto() + visitor_instance = visitor.ProtoVisitor() + visitor_instance.enter_function_scope(function_proto) + self.assertIsNotNone(visitor_instance.scopes.current_scope().current_function_scope()) + visitor_instance.exit_function_scope(function_proto) + self.assertIsNone(visitor_instance.scopes.current_scope().current_function_scope()) + + + def test_proto_visitor_missing_input_types(self): + node_proto = onnx.helper.make_node( + 'Add', + inputs=['A', 'B'], + outputs=['C'] + ) + visitor_instance = visitor.ProtoVisitor(do_shape_inference=True) + visitor_instance.scopes.enter_graph_scope(onnx.GraphProto()) + visitor_instance.bind('A', visitor.ir.Value(name='A', type=onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [1]))) + visitor_instance.bind('B', visitor.ir.Value(name='B', type=onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [1]))) + visitor_instance.process_node(node_proto) + output_value = visitor_instance.lookup('C') + self.assertIsNone(output_value.type) + + + def test_save_to_value_info_with_overload(self): + shape_env = visitor.FunctionShapeEnv() + value_info = onnx.helper.make_tensor_value_info('custom::function::overload/x', onnx.TensorProto.FLOAT, [1]) + with self.assertRaises(NotImplementedError): + shape_env.save_to_value_info(visitor.ir.Value(name='x', type=value_info.type), 'custom', 'function', 'overload') + + + def test_save_to_model_proto_with_function_id_and_value_info(self): + model_proto = onnx.ModelProto() + model_proto.graph.value_info.extend([ + onnx.helper.make_tensor_value_info('custom::function/x', onnx.TensorProto.FLOAT, [1]) + ]) + shape_env = visitor.FunctionShapeEnv() + shape_env.load_from_model_proto(model_proto) + shape_env.save_to_model_proto(model_proto) + self.assertEqual(len(model_proto.graph.value_info), 2) + + + def test_subscope_bind_and_lookup_ref_attribute(self): + graph_proto = onnx.GraphProto() + subscope = visitor.SubScope(graph_proto) + attr_proto = onnx.AttributeProto() + attr_proto.name = "attr1" + subscope.bind_ref_attribute("attr1", attr_proto) + self.assertEqual(subscope.lookup_ref_attribute("attr1"), attr_proto) + + + def test_scope_bind_empty_name(self): + scope = visitor.Scope() + scope.enter_sub_scope(onnx.GraphProto()) + with self.assertRaises(ValueError): + scope.bind("", visitor.ir.Value(name="value")) + + + def test_load_from_value_info_with_function_id(self): + value_info = onnx.helper.make_tensor_value_info('custom::function/x', onnx.TensorProto.FLOAT, [1]) + shape_env = visitor.FunctionShapeEnv() + shape_env.load_from_value_info(value_info) + self.assertEqual(len(shape_env._function_values), 1) + + + def test_scope_bind_none_value(self): + scope = visitor.Scope() + scope.enter_sub_scope(onnx.GraphProto()) + with self.assertRaises(ValueError): + scope.bind("test_name", None) + + + def test_load_from_value_info_with_none_function_id(self): + value_info = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1]) + shape_env = visitor.FunctionShapeEnv() + shape_env.load_from_value_info(value_info) + self.assertEqual(len(shape_env._function_values), 0) + + + def test_override_inferred_value_type_with_none_values(self): + result = visitor._override_inferred_value_type_with_symbolic_value_type(None, None) + self.assertIsNone(result) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/backend/onnx_backend_test.py b/onnxscript/backend/onnx_backend_test.py index efd9d823d..02a3a9bb5 100644 --- a/onnxscript/backend/onnx_backend_test.py +++ b/onnxscript/backend/onnx_backend_test.py @@ -8,6 +8,7 @@ from onnxscript.backend import onnx_backend +import numpy as np def load_function(obj): return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",)) @@ -41,6 +42,12 @@ def test_enumerate_onnx_tests_run_one(self): done += 1 self.assertEqual(done, 1) + def test_assert_almost_equal_string_with_floats(self): + expected = np.array([1.0, 2.0, 3.0]) + value = np.array([1.0, 2.0, 3.0]) + onnx_backend.assert_almost_equal_string(expected, value) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/evaluator_test.py b/onnxscript/evaluator_test.py index d42b1bab7..a7a70b02f 100644 --- a/onnxscript/evaluator_test.py +++ b/onnxscript/evaluator_test.py @@ -8,6 +8,8 @@ from onnxscript.onnx_opset import opset17 as op from onnxscript.onnx_types import FLOAT +from onnxscript import tensor +import onnx class EvaluatorTest(unittest.TestCase): def test_evaluator(self): @@ -62,5 +64,79 @@ def test_function(x, y: float = 1.0): _ = test_function(x, unknown=42) # pylint: disable=unexpected-keyword-arg + + def test_adapt_to_eager_mode_list_of_numpy_arrays(self): + inputs = [np.array([1, 2]), np.array([3, 4])] + expected = [tensor.Tensor(np.array([1, 2])), tensor.Tensor(np.array([3, 4]))] + result, has_array = evaluator._adapt_to_eager_mode(inputs) + for res, exp in zip(result, expected): + np.testing.assert_array_equal(res.value, exp.value) + self.assertTrue(has_array) + + + def test_compute_num_outputs_scan(self): + schema = onnx.defs.get_schema("Scan", 9) + args = [np.array([1, 2, 3, 4])] + kwargs = {'body': onnx.helper.make_graph([], "body", [], [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1])])} + expected_outputs = 1 + result = evaluator.compute_num_outputs(schema, args, kwargs) + self.assertEqual(result, expected_outputs) + + + def test_compute_num_outputs_variable_outputs(self): + schema = onnx.defs.get_schema("Split", 13) + args = [np.array([1, 2, 3, 4]), np.array([2, 2])] + kwargs = {} + expected_outputs = 2 + result = evaluator.compute_num_outputs(schema, args, kwargs) + self.assertEqual(result, expected_outputs) + + + def test_adapt_to_user_mode_single_numpy_array(self): + input_array = np.array([1, 2, 3]) + expected = np.array([1, 2, 3]) + result = evaluator._adapt_to_user_mode(input_array) + np.testing.assert_array_equal(result, expected) + + + def test_adapt_to_eager_mode_single_none(self): + input_none = None + expected = None + result, has_array = evaluator._adapt_to_eager_mode(input_none) + self.assertEqual(result, expected) + self.assertFalse(has_array) + + + def test_adapt_to_eager_mode_single_scalar(self): + input_scalar = 5 + expected = tensor.Tensor(np.array(input_scalar, dtype=np.int64)) + result, has_array = evaluator._adapt_to_eager_mode(input_scalar) + self.assertEqual(result, expected) + self.assertFalse(has_array) + + + def test_adapt_to_user_mode_tuple_of_tensors(self): + input_tensors = (tensor.Tensor(np.array([1, 2, 3])), tensor.Tensor(np.array([4, 5, 6]))) + expected = (np.array([1, 2, 3]), np.array([4, 5, 6])) + result = evaluator._adapt_to_user_mode(input_tensors) + np.testing.assert_array_equal(result[0], expected[0]) + np.testing.assert_array_equal(result[1], expected[1]) + + + def test_unwrap_tensors_in_kwargs_mixed(self): + kwargs = {'a': tensor.Tensor(np.array([1, 2, 3])), 'b': np.array([4, 5, 6])} + expected = {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])} + result = evaluator._unwrap_tensors_in_kwargs(kwargs) + np.testing.assert_array_equal(result['a'], expected['a']) + np.testing.assert_array_equal(result['b'], expected['b']) + + + def test_compute_num_outputs_split_no_num_outputs(self): + schema = onnx.defs.get_schema("Split", 13) + args = [np.array([1, 2, 3, 4])] + kwargs = {} + with self.assertRaises(evaluator.EagerModeError): + evaluator.compute_num_outputs(schema, args, kwargs) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242..35699798b 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -58,5 +58,65 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function( logger.info(signature_type_constraint) + def test_type_constraint_repr_unordered(self): + value1 = deduce_type_constraints.Value("value1") + value2 = deduce_type_constraints.Value("value2") + type_constraint = deduce_type_constraints.TypeConstraint("test_constraint", {"tensor(float)", "tensor(int64)"}) + type_constraint.bind_value(value1) + type_constraint.bind_value(value2) + expected_repr = "TypeConstraint(name=test_constraint, type_strs={'tensor(float)', 'tensor(int64)'}, values=['value1', 'value2'])" + self.assertIn(repr(type_constraint), [ + "TypeConstraint(name=test_constraint, type_strs={'tensor(float)', 'tensor(int64)'}, values=['value1', 'value2'])", + "TypeConstraint(name=test_constraint, type_strs={'tensor(int64)', 'tensor(float)'}, values=['value1', 'value2'])" + ]) + + + def test_deduce_raises_not_implemented_error_for_loop_and_scan(self): + class MockOnnxFunction: + def to_function_proto(self): + class MockFunctionProto: + opset_import = [type('MockOpset', (object,), {'version': 1})] + node = [type('MockNode', (object,), {'op_type': 'Loop', 'domain': 'onnx'})] + return MockFunctionProto() + + def param_schemas(self): + return [] + + onnx_function = MockOnnxFunction() + deducer = deduce_type_constraints.TypeConstraintDeducer(onnx_function) + + with self.assertRaises(NotImplementedError): + deducer.deduce() + + + def test_onnx_function_type_constraints_repr(self): + input_constraints = {"input1": deduce_type_constraints.TypeConstraint("T0", {"tensor(float)"})} + output_constraints = {"output1": deduce_type_constraints.TypeConstraint("T1", {"tensor(int64)"})} + intermediate_constraints = {"intermediate1": deduce_type_constraints.TypeConstraint("T2", {"tensor(int32)"})} + constraints = deduce_type_constraints.OnnxFunctionTypeConstraints(input_constraints, output_constraints, intermediate_constraints) + expected_repr = ( + "Type Constraints:\n" + " Inputs: \n" + " input1: T0\n" + " Outputs: \n" + " output1: T1\n" + " Type Constraints: \n" + " T0: {'tensor(float)'}\n" + " T1: {'tensor(int64)'}\n" + " Intermediate Values: \n" + " intermediate1: T2\n" + " Intermediate Type Constraints: \n" + " T2: {'tensor(int32)'}" + ) + self.assertEqual(repr(constraints), expected_repr) + + + def test_value_merge_type_constraint_no_constraints(self): + value1 = deduce_type_constraints.Value("value1") + value2 = deduce_type_constraints.Value("value2") + with self.assertRaises(ValueError): + value1.merge_type_constraint(value2) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_display_test.py b/onnxscript/ir/_display_test.py index ee745b484..579ec7036 100644 --- a/onnxscript/ir/_display_test.py +++ b/onnxscript/ir/_display_test.py @@ -10,6 +10,7 @@ import onnxscript.ir as ir +import ml_dtypes class DisplayTest(unittest.TestCase): def test_tensor_display_does_not_raise_on_nan_values(self): array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10], dtype=np.float32) @@ -18,5 +19,78 @@ def test_tensor_display_does_not_raise_on_nan_values(self): tensor.display() + def test_graph_register_initializer(self): + graph = ir.Graph(inputs=[], outputs=[], nodes=[]) + tensor_value = ir.Tensor(np.array([1, 2, 3], dtype=np.int32), dtype=ir.DataType.INT32) + initializer = ir.Value(name="init_tensor", const_value=tensor_value) + + # Register initializer successfully + graph.register_initializer(initializer) + self.assertIn("init_tensor", graph.initializers) + + # Attempt to register an initializer without a name + unnamed_initializer = ir.Value(const_value=tensor_value) + with self.assertRaises(ValueError): + graph.register_initializer(unnamed_initializer) + + # Attempt to register an initializer that is produced by a node + node = ir.Node(domain="", op_type="Add", inputs=[], outputs=[initializer]) + with self.assertRaises(ValueError): + graph.register_initializer(initializer) + + + def test_graph_remove_node_still_in_use(self): + graph = ir.Graph(inputs=[], outputs=[], nodes=[]) + node1 = ir.Node(domain="", op_type="Add", inputs=[], outputs=[ir.Value()]) + node2 = ir.Node(domain="", op_type="Add", inputs=[node1.outputs[0]], outputs=[ir.Value()]) + graph.append(node1) + graph.append(node2) + with self.assertRaises(ValueError): + graph.remove(node1, safe=True) + + + def test_shape_initialization_with_symbolic_dim(self): + symbolic_dim = ir.SymbolicDim("N") + shape = ir.Shape([symbolic_dim, 3, 5]) + self.assertEqual(shape.dims, (symbolic_dim, 3, 5)) + + + def test_external_tensor_with_relative_path(self): + tensor = ir.ExternalTensor( + location="relative/path/to/data", + offset=0, + length=100, + dtype=ir.DataType.FLOAT, + shape=ir.Shape([10, 10]), + name="test_tensor" + ) + self.assertEqual(tensor.location, "relative/path/to/data") + + + def test_graph_append_node_from_another_graph(self): + graph1 = ir.Graph(inputs=[], outputs=[], nodes=[]) + graph2 = ir.Graph(inputs=[], outputs=[], nodes=[]) + node = ir.Node(domain="", op_type="Add", inputs=[], outputs=[]) + graph1.append(node) + with self.assertRaises(ValueError): + graph2.append(node) + + + def test_tensor_initialization_with_mismatched_dtype(self): + array_int32 = np.array([1, 2, 3], dtype=np.int32) + with self.assertRaises(TypeError): + ir.Tensor(array_int32, dtype=ir.DataType.FLOAT) + + + def test_tensor_initialization_with_non_standard_dtypes(self): + array_bfloat16 = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.bfloat16) + tensor_bfloat16 = ir.Tensor(array_bfloat16, dtype=ir.DataType.BFLOAT16) + self.assertEqual(tensor_bfloat16.dtype, ir.DataType.BFLOAT16) + + array_float8 = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.float8_e4m3fn) + tensor_float8 = ir.Tensor(array_float8, dtype=ir.DataType.FLOAT8E4M3FN) + self.assertEqual(tensor_float8.dtype, ir.DataType.FLOAT8E4M3FN) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index 0721aaa99..d80bd19e0 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -72,6 +72,26 @@ def test_enums_are_the_same_as_spec(self): self.assertEqual(_enums.AttributeType.TYPE_PROTOS, onnx.AttributeProto.TYPE_PROTOS) self.assertEqual(_enums.AttributeType.UNDEFINED, onnx.AttributeProto.UNDEFINED) + def test_from_numpy_unsupported_dtype(self): + with self.assertRaises(TypeError): + _enums.DataType.from_numpy(np.dtype('datetime64')) + + + def test_numpy_unsupported_type(self): + with self.assertRaises(TypeError): + _enums.DataType.UNDEFINED.numpy() + + + def test_attribute_type_str(self): + self.assertEqual(str(_enums.AttributeType.FLOAT), "FLOAT") + self.assertEqual(str(_enums.AttributeType.INT), "INT") + + + def test_attribute_type_repr(self): + self.assertEqual(repr(_enums.AttributeType.FLOAT), "FLOAT") + self.assertEqual(repr(_enums.AttributeType.INT), "INT") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_external_data_test.py b/onnxscript/ir/_external_data_test.py index afcf32b20..99425a76a 100644 --- a/onnxscript/ir/_external_data_test.py +++ b/onnxscript/ir/_external_data_test.py @@ -557,5 +557,97 @@ def test_external_data_sorted(self): self.assertEqual(tensor_data, expected_tensor_order[i]) + def test_all_tensors_excludes_attributes(self): + attr_tensor_1 = ir.Tensor( + np.array([1.0], dtype=np.float32), + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1]), + name="test_attr_tensor_1" + ) + attr_tensor_2 = ir.Tensor( + np.array([2.0], dtype=np.float32), + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1]), + name="test_attr_tensor_2" + ) + node = ir.Node( + domain="", + op_type="TestOp", + inputs=[], + attributes=[ + ir.Attr(name="attr_tensor", type=ir.AttributeType.TENSOR, value=attr_tensor_1), + ir.Attr(name="attr_tensors", type=ir.AttributeType.TENSORS, value=[attr_tensor_2]) + ], + num_outputs=1 + ) + graph = ir.Graph( + inputs=[], + outputs=[], + initializers={}, + nodes=[node], + name="test_graph" + ) + tensors = list(_external_data._all_tensors(graph, include_attributes=False)) + self.assertEqual(len(tensors), 0) + + + def test_all_tensors_with_various_attributes(self): + attr_tensor_1 = ir.Tensor( + np.array([1.0], dtype=np.float32), + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1]), + name="test_attr_tensor_1" + ) + attr_tensor_2 = ir.Tensor( + np.array([2.0], dtype=np.float32), + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1]), + name="test_attr_tensor_2" + ) + node = ir.Node( + domain="", + op_type="TestOp", + inputs=[], + attributes=[ + ir.Attr(name="attr_tensor", type=ir.AttributeType.TENSOR, value=attr_tensor_1), + ir.Attr(name="attr_tensors", type=ir.AttributeType.TENSORS, value=[attr_tensor_2]) + ], + num_outputs=1 + ) + graph = ir.Graph( + inputs=[], + outputs=[], + initializers={}, + nodes=[node], + name="test_graph" + ) + tensors = list(_external_data._all_tensors(graph, include_attributes=True)) + self.assertEqual(len(tensors), 2) + self.assertEqual(tensors[0].name, "test_attr_tensor_1") + self.assertEqual(tensors[1].name, "test_attr_tensor_2") + + + def test_save_external_data_padding(self): + tensor_data = np.random.rand(1, 42).astype(np.float32) + tensor = ir.Tensor( + tensor_data, + dtype=ir.DataType.FLOAT, + shape=ir.Shape(tensor_data.shape), + name="test_tensor", + ) + external_data_info = [ + (tensor, _external_data._ExternalDataInfo("test_tensor", 100, tensor.nbytes)) + ] + file_path = os.path.join(self.base_path, "test_external_data.bin") + + _external_data._save_external_data(external_data_info, file_path) + + with open(file_path, "rb") as f: + f.seek(0, os.SEEK_END) + file_size = f.tell() + + self.assertEqual(file_size, 100 + tensor.nbytes) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py index 00f03e71e..dc56c1cda 100644 --- a/onnxscript/ir/_linked_list_test.py +++ b/onnxscript/ir/_linked_list_test.py @@ -374,5 +374,75 @@ def test_insert_after_supports_taking_elements_from_another_doubly_linked_list( self.assertEqual([elem.value for elem in other_linked_list], [42]) + def test_doubly_linked_set_repr(self): + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + self.assertEqual(repr(linked_list), "DoublyLinkedSet([_TestElement(0), _TestElement(1), _TestElement(2)])") + + + def test_insert_before_value_not_in_list(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + with self.assertRaises(ValueError): + linked_list.insert_before(_TestElement(1), [_TestElement(2)]) + + + def test_insert_one_after_none_value(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + with self.assertRaises(TypeError): + linked_list._insert_one_after(linked_list._root.next, None) + + + def test_erase_already_erased_linkbox(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + box = _linked_list._LinkBox(linked_list, elem) + box.erase() + with self.assertRaises(ValueError): + box.erase() + + + def test_insert_after_value_not_in_list(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + with self.assertRaises(ValueError): + linked_list.insert_after(_TestElement(1), [_TestElement(2)]) + + + def test_insert_one_after_box_not_in_list(self): + linked_list = _linked_list.DoublyLinkedSet() + other_linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + other_elem = _TestElement(1) + linked_list.append(elem) + other_linked_list.append(other_elem) + with self.assertRaises(ValueError): + linked_list._insert_one_after(other_linked_list._root.next, _TestElement(2)) + + + def test_iterator_element_not_in_list(self): + linked_list = _linked_list.DoublyLinkedSet() + other_linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + other_elem = _TestElement(1) + other_linked_list.append(other_elem) + linked_list._root.next = other_linked_list._root.next # Forcefully link an element from another list + with self.assertRaises(RuntimeError): + list(linked_list) + + + def test_linkbox_repr_erased(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + box = _linked_list._LinkBox(linked_list, elem) + box.erase() + self.assertEqual(repr(box), "_LinkBox(None, erased=True, prev=None, next=None)") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py index c134bd7a6..78742494d 100644 --- a/onnxscript/ir/_schemas_test.py +++ b/onnxscript/ir/_schemas_test.py @@ -12,6 +12,8 @@ from onnxscript import FLOAT, INT64, ir from onnxscript.ir import _schemas +import unittest.mock +import onnx _TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) _TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) _TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) @@ -171,6 +173,63 @@ def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]) def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access + def test_convert_formal_parameter_plain_type(self): + mock_param = unittest.mock.Mock() + mock_param.type_str = "tensor(float)" + mock_param.name = "param" + mock_param.option = onnx.defs.OpSchema.FormalParameterOption.Single + type_constraints = {} + parameter = _schemas._convert_formal_parameter(mock_param, type_constraints) + self.assertEqual(parameter.name, "param") + self.assertTrue(ir.TensorType(ir.DataType.FLOAT) in parameter.type_constraint.allowed_types) + + + def test_type_constraint_param_str_single_allowed_type(self): + type_constraint = _schemas.TypeConstraintParam( + name="TFloat", + allowed_types={ir.TensorType(ir.DataType.FLOAT)} + ) + expected_str = "TFloat=FLOAT" + self.assertEqual(str(type_constraint), expected_str) + + + def test_get_type_from_str_unknown_type_part(self): + with self.assertRaises(ValueError) as context: + _schemas._get_type_from_str("unknown(float)") + self.assertIn("Unknown type part: 'unknown'", str(context.exception)) + + + def test_parameter_has_default(self): + type_constraint = _schemas.TypeConstraintParam.any_tensor("T") + param_with_default = _schemas.Parameter( + name="param1", type_constraint=type_constraint, required=True, variadic=False, default=5 + ) + param_without_default = _schemas.Parameter( + name="param2", type_constraint=type_constraint, required=True, variadic=False + ) + self.assertTrue(param_with_default.has_default()) + self.assertFalse(param_without_default.has_default()) + + + def test_type_constraint_param_any_value(self): + param = _schemas.TypeConstraintParam.any_value("TAny") + expected_types = _schemas._ALL_VALUE_TYPES + self.assertEqual(param.name, "TAny") + self.assertEqual(param.allowed_types, expected_types) + + + def test_type_constraint_param_any_tensor(self): + param = _schemas.TypeConstraintParam.any_tensor("TFloat") + expected_types = {ir.TensorType(dtype) for dtype in ir.DataType} + self.assertEqual(param.name, "TFloat") + self.assertEqual(param.allowed_types, expected_types) + + + def test_empty_class_representation(self): + empty_instance = _schemas._Empty() + self.assertEqual(repr(empty_instance), "_EMPTY_DEFAULT") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/_type_casting_test.py b/onnxscript/ir/_type_casting_test.py index abe4923ee..99b42f7d3 100644 --- a/onnxscript/ir/_type_casting_test.py +++ b/onnxscript/ir/_type_casting_test.py @@ -6,6 +6,7 @@ import parameterized from onnxscript.ir import _type_casting +import ml_dtypes class TypeCastingTest(unittest.TestCase): @@ -45,6 +46,27 @@ def test_pack_int4_returns_flatten_array(self, _: str, dtype): actual = _type_casting.pack_int4(array) np.testing.assert_array_equal(actual, expected) + def test_unpack_uint4_with_padding(self): + packed_data = np.array([0x21, 0x43, 0x65], dtype=np.uint8) + expected = np.array([1, 2, 3, 4, 5], dtype=np.uint8) + actual = _type_casting._unpack_uint4_as_uint8(packed_data, (5,)) + np.testing.assert_array_equal(actual, expected) + + + def test_unpack_float4e2m1(self): + packed_data = np.array([0x12, 0x34, 0x56, 0x78], dtype=np.uint8) + expected_shape = (2, 4) + actual = _type_casting.unpack_float4e2m1(packed_data, expected_shape) + self.assertEqual(actual.shape, expected_shape) + + + def test_unpack_uint4(self): + packed_data = np.array([0x21, 0x43, 0x65, 0x87], dtype=np.uint8) + expected = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=ml_dtypes.uint4) + actual = _type_casting.unpack_uint4(packed_data, (2, 4)) + np.testing.assert_array_equal(actual, expected) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index f46756055..53bccf21e 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -289,6 +289,67 @@ def test_deserialize_graph_handles_unsorted_graph(self): self.assertEqual(deserialized_graph[0].op_type, "Op_1") self.assertEqual(deserialized_graph[1].op_type, "Op_0") + def test_tensor_proto_tensor_string(self): + tensor_proto = onnx.helper.make_tensor( + "test_tensor", onnx.TensorProto.STRING, [3], [b"one", b"two", b"three"] + ) + tensor = serde.TensorProtoTensor(tensor_proto) + expected_array = np.array([b"one", b"two", b"three"], dtype=object) + np.testing.assert_array_equal(tensor.numpy(), expected_array) + + + def test_serialize_function_without_default_attribute(self): + function = ir.Function( + domain="test_domain", + name="test_function", + graph=ir.Graph([], [], nodes=[]), + attributes=[ir.Attr("attr_without_default", ir.AttributeType.FLOAT, None)] + ) + function_proto = serde.serialize_function(function) + self.assertIn("attr_without_default", function_proto.attribute) + + + def test_deserialize_graph_with_empty_input(self): + node_proto = onnx.helper.make_node( + "OpType", + inputs=["", "input_1"], + outputs=["output_1"], + name="test_node" + ) + graph_proto = onnx.helper.make_graph( + nodes=[node_proto], + name="test_graph", + inputs=[], + outputs=[] + ) + deserialized_graph = serde.deserialize_graph(graph_proto) + self.assertEqual(deserialized_graph[0].inputs[0], None) + self.assertEqual(deserialized_graph[0].inputs[1].name, "input_1") + + + def test_tensor_proto_tensor_external_data(self): + tensor_proto = onnx.helper.make_tensor( + "test_tensor", onnx.TensorProto.FLOAT, [1, 9], [-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0] + ) + tensor_proto.data_location = onnx.TensorProto.EXTERNAL + tensor = serde.TensorProtoTensor(tensor_proto) + with self.assertRaises(ValueError): + tensor.numpy() + + + def test_to_proto_unsupported_type(self): + class UnsupportedIRObject: + pass + + with self.assertRaises(NotImplementedError): + serde.to_proto(UnsupportedIRObject()) + + + def test_from_proto_unsupported_type(self): + with self.assertRaises(NotImplementedError): + serde.from_proto(onnx.OperatorSetIdProto()) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/tensor_adapters_test.py b/onnxscript/ir/tensor_adapters_test.py index 34034ac51..e373f4cb9 100644 --- a/onnxscript/ir/tensor_adapters_test.py +++ b/onnxscript/ir/tensor_adapters_test.py @@ -80,5 +80,11 @@ def test_tobytes(self, dtype: torch.dtype): self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) + def test_array_conversion_with_dtype(self): + tensor = tensor_adapters.TorchTensor(torch.tensor([1.0], dtype=torch.float32)) + np_array = tensor.__array__(dtype=np.float64) + self.assertEqual(np_array.dtype, np.float64) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/traversal_test.py b/onnxscript/ir/traversal_test.py index 5ed4d3147..6a8d426b6 100644 --- a/onnxscript/ir/traversal_test.py +++ b/onnxscript/ir/traversal_test.py @@ -77,5 +77,43 @@ def test_recursive_graph_iterator_recursive_controls_recursive_behavior( self.assertEqual(tuple(node.op_type for node in nodes), expected) + def test_recursive_graph_iterator_with_multiple_graphs(self): + graph_with_multiple_graphs = ir.Graph( + [], + [], + nodes=[ + ir.Node( + "", + "Loop", + [], + attributes=[ + ir.AttrGraphs( + "body_branches", + [ + ir.Graph( + [], + [], + nodes=[ir.Node("", "Node7", []), ir.Node("", "Node8", [])], + name="body_graph_1", + ), + ir.Graph( + [], + [], + nodes=[ir.Node("", "Node9", []), ir.Node("", "Node10", [])], + name="body_graph_2", + ), + ], + ), + ], + ), + ], + name="main_graph_with_multiple_graphs", + ) + iterator = traversal.RecursiveGraphIterator(graph_with_multiple_graphs) + nodes = list(iterator) + expected = ("Loop", "Node7", "Node8", "Node9", "Node10") + self.assertEqual(tuple(node.op_type for node in nodes), expected) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/optimizer/_inliner_test.py b/onnxscript/optimizer/_inliner_test.py index e7e3bbadc..52c45cfe8 100644 --- a/onnxscript/optimizer/_inliner_test.py +++ b/onnxscript/optimizer/_inliner_test.py @@ -206,6 +206,66 @@ def test_attr_parameter_with_default_value(self): """ self._check(input_model, expected_model) + def test_opset_mismatch_error(self): + input_model = """ + + agraph (float[N] X) => (float[N] Y) + { + Y = local.foo (X) + } + + + foo (x) => (y) { + y = Selu (x) + } + """ + model_proto = parser.parse_model(input_model) + model_ir = ir.serde.deserialize_model(model_proto) + with self.assertRaises(ValueError) as context: + inline(model_ir) + self.assertIn("Opset mismatch", str(context.exception)) + + + def test_input_mismatch_error(self): + input_model = """ + + agraph (float[N] X, float[N] Z) => (float[N] Y) + { + Y = local.foo (X, Z) + } + + + foo (x) => (y) { + y = Selu (x) + } + """ + model_proto = parser.parse_model(input_model) + model_ir = ir.serde.deserialize_model(model_proto) + with self.assertRaises(ValueError) as context: + inline(model_ir) + self.assertIn("Input mismatch", str(context.exception)) + + + def test_graph_attribute_parameter_error(self): + input_model = """ + + agraph (float[N] X) => (float[N] Y) + { + Y = local.foo (X) + } + + + foo (x) => (y) { + y = Selu (x) + } + """ + model_proto = parser.parse_model(input_model) + model_ir = ir.serde.deserialize_model(model_proto) + with self.assertRaises(ValueError) as context: + inline(model_ir) + self.assertIn("Inliner does not support graph attribute parameters to functions", str(context.exception)) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/optimizer/_legacy/_simple_function_folding_test.py b/onnxscript/optimizer/_legacy/_simple_function_folding_test.py index aa0af61a0..c6334f3c0 100644 --- a/onnxscript/optimizer/_legacy/_simple_function_folding_test.py +++ b/onnxscript/optimizer/_legacy/_simple_function_folding_test.py @@ -217,5 +217,33 @@ def test_fold_function_with_unused_output(self): self.assertEqual(len(model.functions), 1) + def test_inline_function_with_fewer_inputs(self): + model = onnx.parser.parse_model( + """ + < + ir_version: 8, + opset_import: ["this" : 1, "" : 18] + > + func ( x, y) => ( return_val) { + tmp = this.foldable (x) + return_val = Add (tmp, y) + } + < + domain: "this", + opset_import: ["" : 18] + > + foldable (x, y) => (return_val) + { + return_val = Identity (x) + } + """ + ) + + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py index 49c97d2c7..be82ea756 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/broadcast_to_matmul_test.py @@ -10,6 +10,7 @@ from onnxscript import ir from onnxscript.rewriter import broadcast_to_matmul +import numpy as np def _infer_shapes(model: ir.Model) -> ir.Model: @@ -393,5 +394,37 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): self.assertEqual(len(model.graph), 3) + def test_input_b_shape_symbolic(self): + input_a = ir.Value(shape=ir.Shape([2, 3])) + input_b = ir.Value(shape=ir.Shape([3, ir.SymbolicDim("m")])) + shape_c = ir.Value(const_value=ir.Tensor(np.array([2, 4]))) + result = broadcast_to_matmul.check_if_not_need_reshape(None, input_a, input_b, shape_c) + self.assertFalse(result) + + + def test_input_a_shape_symbolic(self): + input_a = ir.Value(shape=ir.Shape([ir.SymbolicDim("n"), 3])) + input_b = ir.Value(shape=ir.Shape([3, 4])) + shape_c = ir.Value(const_value=ir.Tensor(np.array([2, 4]))) + result = broadcast_to_matmul.check_if_not_need_reshape(None, input_a, input_b, shape_c) + self.assertFalse(result) + + + def test_shape_c_tensor_multi_dimensional(self): + input_a = ir.Value(shape=ir.Shape([2, 3])) + input_b = ir.Value(shape=ir.Shape([3, 4])) + shape_c = ir.Value(const_value=ir.Tensor(np.array([[1, 2], [3, 4]]))) + result = broadcast_to_matmul.check_if_not_need_reshape(None, input_a, input_b, shape_c) + self.assertFalse(result) + + + def test_shape_c_tensor_none(self): + input_a = ir.Value(shape=ir.Shape([2, 3])) + input_b = ir.Value(shape=ir.Shape([3, 4])) + shape_c = ir.Value(const_value=None) + result = broadcast_to_matmul.check_if_not_need_reshape(None, input_a, input_b, shape_c) + self.assertFalse(result) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/collapse_slices_test.py index 6a11bd202..2f75a37b3 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/collapse_slices_test.py @@ -114,3 +114,151 @@ def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(sel input = np.random.rand(112, 16, 512).astype(np.float32) testing.assert_numerically_equal(original_model_proto, model, (input, input)) + + def test_check_if_redundant_scatternd_logs_when_indices_not_referring_to_whole_data(self): + data = ir.Value(shape=ir.Shape([112, 16, 512])) + updates = ir.Value(shape=ir.Shape([112, 16, 512])) + indices = ir.Value(const_value=ir.Tensor(np.array([[0], [1], [2]]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_scatternd(None, data, indices, updates) + self.assertFalse(result) + self.assertIn("The 'indices' is not referring to the whole data.", log.output[0]) + + + def test_check_if_redundant_scatternd_logs_when_updates_shape_not_statically_known(self): + data = ir.Value(shape=ir.Shape([112, 16, 512])) + updates = ir.Value(shape=None) + indices = ir.Value(const_value=ir.Tensor(np.arange(112).reshape(112, 1).astype(np.int64))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_scatternd(None, data, indices, updates) + self.assertFalse(result) + self.assertIn("The value 'updates' shape is not statically known.", log.output[0]) + + + def test_check_if_redundant_scatternd_logs_when_data_shape_not_statically_known(self): + data = ir.Value(shape=None) + updates = ir.Value(shape=ir.Shape([112, 16, 512])) + indices = ir.Value(const_value=ir.Tensor(np.arange(112).reshape(112, 1).astype(np.int64))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_scatternd(None, data, indices, updates) + self.assertFalse(result) + self.assertIn("The value 'data' shape is not statically known.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_start_not_zero(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([1]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'start' is not 0.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_step_not_one(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([2]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'step' is not 1.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_step_not_scalar(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1, 2]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'step' is not a scalar.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_axis_not_scalar(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0, 1]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'axis' is not a scalar.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_end_not_scalar(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999, 10000]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'end' is not a scalar.", log.output[0]) + + + def test_check_if_redundant_scatternd_logs_when_shapes_different(self): + data = ir.Value(shape=ir.Shape([112, 16, 512])) + updates = ir.Value(shape=ir.Shape([112, 16, 256])) + indices = ir.Value(const_value=ir.Tensor(np.arange(112).reshape(112, 1).astype(np.int64))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_scatternd(None, data, indices, updates) + self.assertFalse(result) + self.assertIn("The shape of 'data' and 'updates' are different.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_end_less_than_shape(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0]))) + ends = ir.Value(const_value=ir.Tensor(np.array([100]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'end' is less than the shape of the specified axis.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_start_not_scalar(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value(const_value=ir.Tensor(np.array([0, 1]))) + ends = ir.Value(const_value=ir.Tensor(np.array([9999]))) + axes = ir.Value(const_value=ir.Tensor(np.array([0]))) + steps = ir.Value(const_value=ir.Tensor(np.array([1]))) + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'start' is not a scalar.", log.output[0]) + + + def test_check_if_redundant_slice_logs_when_values_not_statically_known(self): + data = ir.Value(shape=ir.Shape([512, 16, 112])) + starts = ir.Value() + ends = ir.Value() + axes = ir.Value() + steps = ir.Value() + + with self.assertLogs(collapse_slices.logger, level='INFO') as log: + result = collapse_slices._check_if_redundant_slice(None, data, starts, ends, axes, steps) + self.assertFalse(result) + self.assertIn("The value 'start', 'end', 'axis', 'step' is not statically known.", log.output[0]) + diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index dadaf5e8b..1ce096dd1 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -603,5 +603,37 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): self.assertEqual(list(att.ints), [2, 0, 1]) + + def test_enumerate_matches_with_node(self): + def match_pattern(op, x): + return op.Add(x, x) + + def apply_pattern(op, x, **_): + return op.Add(x, x) + + pattern = generic_pattern.orp._to_graph_pattern(match_pattern) + matcher = generic_pattern.GenericPatternMatcher(pattern) + + model = onnx.helper.make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Add", ["x", "x"], ["y"]), + ], + "dummy", + [onnx.helper.make_tensor_value_info("x", FLOAT, [None, None])], + [onnx.helper.make_tensor_value_info("y", FLOAT, [None, None])], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ir_version=9, + ) + onnx.checker.check_model(model) + + model = onnx.shape_inference.infer_shapes(model) + ir_model = ir.serde.deserialize_model(model) + + matches = list(matcher.enumerate_matches(ir_model, ir_model.graph, ir_model.graph[0])) + + self.assertEqual(len(matches), 1) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/llama_rule_sets_test.py b/onnxscript/rewriter/llama_rule_sets_test.py index 2415130c7..b18a2882e 100644 --- a/onnxscript/rewriter/llama_rule_sets_test.py +++ b/onnxscript/rewriter/llama_rule_sets_test.py @@ -430,5 +430,35 @@ def test_llama_p0_rule_set_slice_split(self): self._check_model(model_proto, rewritten_model) + def test_reshape_reshape_non_positive_values(self): + model = _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]), + onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])], + [ + onnx.numpy_helper.from_array( + np.array([4, 5, 3], dtype=np.int64), name="shape_" + ), + onnx.numpy_helper.from_array( + np.array([-5, 4, 3], dtype=np.int64), name="shape" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ) + rule_set = llama_rule_sets.llama_p0_rule_set() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + + self.assertNotEqual(["Reshape"], [n.op_type for n in model.graph]) + self._check_model(model_proto, rewritten_model) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py index b9666fba3..efa2b7bac 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py @@ -95,5 +95,29 @@ def test_bfloat16_converted_model_runtime(self): ) + def test_insert_cast_nodes_logs_warning_for_output_without_producer_or_index(self): + output_without_producer = ir.Value(name="output_without_producer") + output_without_producer.dtype = ir.DataType.BFLOAT16 + output_without_producer._producer = None # Ensure no producer + output_without_producer._index = None # Ensure no index + with self.assertLogs(bfloat16_converter.logger, level='WARNING') as log: + bfloat16_converter._insert_cast_nodes_for_bfloat16_to_float16_to_outputs(output_without_producer) + self.assertTrue(any("has no producer or index" in message for message in log.output)) + + + def test_convert_outputs_no_conversion_for_non_bfloat16(self): + output_value = ir.Input(name="output_value", shape=ir.Shape([2, 3, 4])) + output_value.dtype = ir.DataType.FLOAT + bfloat16_converter._convert_outputs_from_bfloat16_to_float16(output_value) + self.assertEqual(output_value.dtype, ir.DataType.FLOAT) + + + def test_convert_inputs_no_conversion_for_non_bfloat16(self): + input_value = ir.Input(name="input_value", shape=ir.Shape([2, 3, 4])) + input_value.dtype = ir.DataType.FLOAT + bfloat16_converter._convert_inputs_from_bfloat16_to_float16(input_value) + self.assertEqual(input_value.dtype, ir.DataType.FLOAT) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py index a7d170e69..76df52f14 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py @@ -359,5 +359,21 @@ def test_should_not_match(self): self._check_model(model_proto, rewritten_model, atol=1e-6) + def test_fused_matmul_div2_check_none_const_value(self): + class MockConst: + const_value = None + + result = fused_matmul_rule_sets.FusedMatMulDiv2.check(None, None, None, MockConst()) + self.assertFalse(result) + + + def test_fused_matmul_div1_check_none_const_value(self): + class MockConst: + const_value = None + + result = fused_matmul_rule_sets.FusedMatMulDiv1.check(None, None, None, MockConst()) + self.assertFalse(result) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/onnxruntime/softmax_test.py index f2aa37c1f..adf92569d 100644 --- a/onnxscript/rewriter/onnxruntime/softmax_test.py +++ b/onnxscript/rewriter/onnxruntime/softmax_test.py @@ -89,6 +89,13 @@ def test_softmax_upcast_to_fp32_is_not_removed_when_final_output_is_not_fp16( len([node.op_type for node in model.graph if node.op_type == "Cast"]), 2 ) + def test_check_if_fp16_input_logs_warning_when_input_is_none(self): + with self.assertLogs('onnxscript.rewriter.onnxruntime.softmax', level='WARNING') as log: + result = softmax.check_if_fp16_input(None, None) + self.assertFalse(result) + self.assertIn("Cannot perform softmax upcast removal", log.output[0]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py index 0812ae3d3..85b59af5f 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py +++ b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py @@ -8,6 +8,9 @@ from tests.common import testutils +import unittest.mock +import onnx +import onnx class BiasSplitGeluParityTest(unittest.TestCase): def setUp(self): @@ -20,5 +23,112 @@ def test_geglu_stable_diffusion_unet(self): ) + def test_validate_method_calls_to_function_proto(self): + class MockFunction: + def to_function_proto(self): + return "FunctionProtoCalled" + + test_base = testutils.TestBase() + result = test_base.validate(MockFunction()) + self.assertEqual(result, "FunctionProtoCalled") + + + @unittest.mock.patch('onnx.load') + @unittest.mock.patch('tests.common.testutils.evaluation_utils.load_test_data') + @unittest.mock.patch('tests.common.testutils.optimizer.optimize') + @unittest.mock.patch('tests.common.testutils.ort_rewriter.rewrite') + @unittest.mock.patch('onnxruntime.InferenceSession') + def test_onnxruntime_rewrite_output_mismatch(self, mock_inference_session, mock_rewrite, mock_optimize, mock_load_test_data, mock_onnx_load): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + mock_onnx_load.return_value = model + mock_load_test_data.return_value = ({}, [np.array([1.0])]) + mock_optimize.return_value = model + mock_rewrite.return_value = model + mock_session_instance = mock_inference_session.return_value + mock_session_instance.run.return_value = [np.array([2.0])] + + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "mock_model", 1, {("", "Add", "")} + ) + + + @unittest.mock.patch('torch.cuda.is_available', return_value=True) + @unittest.mock.patch('onnxruntime.get_device', return_value='GPU') + def test_skip_if_no_cuda_available(self, mock_cuda, mock_device): + @testutils.skip_if_no_cuda("Test reason") + def dummy_test(self): + return "Test Passed" + + result = dummy_test(self) + self.assertEqual(result, "Test Passed") + + + @unittest.mock.patch('onnx.load') + @unittest.mock.patch('tests.common.testutils.evaluation_utils.load_test_data') + @unittest.mock.patch('tests.common.testutils.optimizer.optimize') + @unittest.mock.patch('tests.common.testutils.ort_rewriter.rewrite') + def test_onnxruntime_rewrite_missing_optype(self, mock_rewrite, mock_optimize, mock_load_test_data, mock_onnx_load): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + mock_onnx_load.return_value = model + mock_load_test_data.return_value = ({}, []) + mock_optimize.return_value = model + mock_rewrite.return_value = model + + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "mock_model", 1, {("com.microsoft", "NonExistentOp", "")} + ) + + + def test_op_type_analysis_visitor(self): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + visitor = testutils.OpTypeAnalysisVisitor() + visitor.visit_model(model) + self.assertIn(("", "Add", ""), visitor.op_types) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py index e6de540b8..455f42efb 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py +++ b/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py @@ -19,5 +19,34 @@ def test_gelu_phi_1_5(self): ) + def test_validate_method_calls_to_function_proto(self): + class MockFunction: + def to_function_proto(self): + return "FunctionProtoCalled" + + test_base = testutils.TestBase() + result = test_base.validate(MockFunction()) + self.assertEqual(result, "FunctionProtoCalled") + + + def test_output_shape_mismatch(self): + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "gelu_phi_1_5", 1, {("com.microsoft", "FastGelu", "")}, atol=0 + ) + + + def test_assertion_error_for_missing_optypes(self): + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "gelu_phi_1_5", 1, {("com.microsoft", "NonExistentOpType", "")} + ) + + + @testutils.skip_if_no_cuda("Testing skip if no CUDA.") + def test_skip_if_no_cuda(self): + self.fail("This test should be skipped if CUDA is not available.") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py index c47c77ee7..c74ecdde1 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py @@ -19,5 +19,18 @@ def test_ln_llama2(self): ) + def test_onnxruntime_rewrite_missing_optype(self): + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "ln_llama2", 1, {("fake_domain", "FakeOpType", "")} + ) + + + @testutils.skip_if_no_cuda("CUDA is required for this test.") + def test_skip_if_no_cuda(self): + # This test will be skipped if CUDA is not available + self.assertTrue(torch.cuda.is_available() and onnxruntime.get_device() == "GPU") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py index f752a00a7..df26c0413 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py @@ -7,6 +7,9 @@ import numpy as np from tests.common import testutils +import unittest.mock +import onnx +import onnx class MHAParityTest(unittest.TestCase): @@ -83,5 +86,119 @@ def test_attn_stable_diffusion_unet_without_encoder_hidden_states(self): ) + @unittest.mock.patch('onnxruntime.InferenceSession') + @unittest.mock.patch('onnx.load') + @unittest.mock.patch('tests.common.testutils.evaluation_utils.load_test_data', return_value=({"input": np.array([1.0])}, [np.array([1.0, 2.0])])) + def test_onnxruntime_rewrite_output_shape_mismatch(self, mock_load_test_data, mock_load, mock_inference_session): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + mock_load.return_value = model + mock_session = mock_inference_session.return_value + mock_session.run.return_value = [np.array([1.0])] + + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "dummy_model", 1, {("", "Add", "")} + ) + + + def test_validate_method(self): + class MockFunction: + def to_function_proto(self): + return "function_proto" + + test_base = testutils.TestBase() + result = test_base.validate(MockFunction()) + self.assertEqual(result, "function_proto") + + + @unittest.mock.patch('onnxruntime.InferenceSession') + @unittest.mock.patch('onnx.load') + @unittest.mock.patch('tests.common.testutils.evaluation_utils.load_test_data', return_value=({}, [])) + def test_onnxruntime_rewrite_success(self, mock_load_test_data, mock_load, mock_inference_session): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + onnx.helper.make_node("Relu", ["Z"], ["W"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + mock_load.return_value = model + testutils.test_onnxruntime_rewrite( + "dummy_model", 1, {("", "Add", ""), ("", "Relu", "")} + ) + + + @unittest.mock.patch('onnxruntime.InferenceSession') + @unittest.mock.patch('onnx.load') + @unittest.mock.patch('tests.common.testutils.evaluation_utils.load_test_data', return_value=({}, [])) + def test_onnxruntime_rewrite_missing_optypes(self, mock_load_test_data, mock_load, mock_inference_session): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + mock_load.return_value = model + with self.assertRaises(AssertionError): + testutils.test_onnxruntime_rewrite( + "dummy_model", 1, {("", "Relu", "")} + ) + + + def test_op_type_analysis_visitor(self): + model = onnx.helper.make_model( + onnx.helper.make_graph( + nodes=[ + onnx.helper.make_node("Add", ["X", "Y"], ["Z"], domain=""), + onnx.helper.make_node("Relu", ["Z"], ["W"], domain=""), + ], + name="test_graph", + inputs=[ + onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [1]), + ], + ) + ) + visitor = testutils.OpTypeAnalysisVisitor() + visitor.visit_model(model) + expected_op_types = {("", "Add", ""), ("", "Relu", "")} + self.assertEqual(visitor.op_types, expected_op_types) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/tensor_test.py b/onnxscript/tensor_test.py index afe490e8d..477d91fcb 100644 --- a/onnxscript/tensor_test.py +++ b/onnxscript/tensor_test.py @@ -148,5 +148,194 @@ def test_getitem_gather(self): self._check_values_and_shape(y, [0, 9], (2, 1)) + def test_tensor_reverse_matmul(self): + x = tensor.Tensor(np.array([[1, 2], [3, 4]])) + y = tensor.Tensor(np.array([[5, 6], [7, 8]])) + result = y @ x + expected = np.array([[23, 34], [31, 46]]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_equality(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([1, 2, 3])) + result = x == y + expected = np.array([True, True, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_reverse_sub_with_tensor(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([10, 20, 30])) + result = y - x + expected = np.array([9, 18, 27]) + np.testing.assert_array_equal(result.value, expected) + + + def test_getitem_negative_step_slice(self): + data = np.array(range(12), dtype=np.int32).reshape(4, 3) + x = tensor.Tensor(data) + y = x[::-1] + expected = np.array([[9, 10, 11], [6, 7, 8], [3, 4, 5], [0, 1, 2]]) + np.testing.assert_array_equal(y.value, expected) + + + def test_tensor_greater_than_or_equal(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([2, 2, 2])) + result = x >= y + expected = np.array([False, True, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_less_than_or_equal(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([2, 2, 4])) + result = x <= y + expected = np.array([True, True, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_reverse_sub(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = 10 + result = y - x + expected = np.array([9, 8, 7]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_greater_than(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([0, 2, 2])) + result = x > y + expected = np.array([True, False, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_less_than(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([2, 2, 4])) + result = x < y + expected = np.array([True, False, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_reverse_mul_with_tensor(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([2, 3, 4])) + result = y * x + expected = np.array([2, 6, 12]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_truediv(self): + x = tensor.Tensor(np.array([4.0, 9.0, 16.0])) + y = tensor.Tensor(np.array([2.0, 3.0, 4.0])) + result = x / y + expected = np.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(result.value, expected) + + + def test_tensor_pow(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = tensor.Tensor(np.array([2, 3, 4])) + result = x ** y + expected = np.array([1, 8, 81]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_or(self): + x = tensor.Tensor(np.array([True, False, True])) + y = tensor.Tensor(np.array([True, True, False])) + result = x | y + expected = np.array([True, True, True]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_and(self): + x = tensor.Tensor(np.array([True, False, True])) + y = tensor.Tensor(np.array([True, True, False])) + result = x & y + expected = np.array([True, False, False]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_reverse_mul(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = 2 + result = y * x + expected = np.array([2, 4, 6]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_reverse_add(self): + x = tensor.Tensor(np.array([1, 2, 3])) + y = 10 + result = y + x + expected = np.array([11, 12, 13]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_mod_integer(self): + x = tensor.Tensor(np.array([5, 6], dtype=np.int32)) + y = tensor.Tensor(np.array([2, 2], dtype=np.int32)) + result = x % y + expected = np.array([1, 0], dtype=np.int32) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_negation(self): + x = tensor.Tensor(np.array([1, -2, 3])) + result = -x + expected = np.array([-1, 2, -3]) + np.testing.assert_array_equal(result.value, expected) + + + def test_tensor_not_equal_scalar(self): + x = tensor.Tensor(np.array(1)) + y = tensor.Tensor(np.array(2)) + result = x != y + self.assertTrue(result.value) + + + def test_getitem_unexpected_index_type(self): + data = np.array([1, 2, 3]) + x = tensor.Tensor(data) + with self.assertRaises(TypeError): + _ = x["invalid"] + + + def test_getitem_index_exceeds_rank(self): + data = np.array([[1, 2, 3], [4, 5, 6]]) + x = tensor.Tensor(data) + with self.assertRaises(ValueError): + _ = x[0, 0, 0] + + + def test_tensor_mod_floating_point(self): + x = tensor.Tensor(np.array([5.5, 6.5], dtype=np.float32)) + y = tensor.Tensor(np.array([2.0, 2.0], dtype=np.float32)) + result = x % y + expected = np.array([1.5, 0.5], dtype=np.float32) + np.testing.assert_array_almost_equal(result.value, expected) + + + def test_getitem_opset_version_error(self): + data = np.array([1, 2, 3]) + x = tensor.Tensor(data, opset=type('Opset', (object,), {'version': 12})()) + with self.assertRaises(RuntimeError): + _ = x[0] + + + def test_tensor_repr(self): + x = tensor.Tensor(np.array([1, 2, 3])) + self.assertEqual(repr(x), "Tensor(array([1, 2, 3]))") + + + def test_tensor_initialization_with_invalid_type(self): + with self.assertRaises(TypeError): + tensor.Tensor([1, 2, 3]) # Not a numpy array + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/tools/benchmark/benchmark_helpers_test.py b/onnxscript/tools/benchmark/benchmark_helpers_test.py index ec88ffd9e..488bd11c4 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers_test.py +++ b/onnxscript/tools/benchmark/benchmark_helpers_test.py @@ -4,6 +4,11 @@ import onnxscript.tools.benchmark.benchmark_helpers as bh +import torch +import onnx +import unittest.mock +import numpy as np +import sys class BenchmarkHelperTest(unittest.TestCase): def test_make_configs(self): @@ -48,6 +53,136 @@ def test_make_configs(self): ] self.assertEqual(expected, configs) + def test_run_benchmark_high_verbose(self): + script_name = "example_script" + configs = [{"arg1": "value1"}] + with unittest.mock.patch('subprocess.Popen') as mock_popen: + process_mock = unittest.mock.Mock() + attrs = {'communicate.return_value': (b"output", b"")} + process_mock.configure_mock(**attrs) + mock_popen.return_value = process_mock + result = bh.run_benchmark(script_name, configs, verbose=10, stop_if_exception=False) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIn("ERROR", result[0]) + self.assertIn("OUTPUT", result[0]) + + + def test_common_export_valid_exporter_with_optimization(self): + import torch + import onnx + model = torch.nn.Linear(2, 2) + inputs = (torch.randn(1, 2),) + onnx_model = bh.common_export(model, inputs, exporter="script", optimization="optimize") + self.assertIsInstance(onnx_model, onnx.ModelProto) + + + def test_run_benchmark_with_onnxruntime_error_no_exception(self): + script_name = "example_script" + configs = [{"arg1": "value1"}] + with unittest.mock.patch('subprocess.Popen') as mock_popen: + process_mock = unittest.mock.Mock() + attrs = {'communicate.return_value': (b"", b"ONNXRuntimeError")} + process_mock.configure_mock(**attrs) + mock_popen.return_value = process_mock + result = bh.run_benchmark(script_name, configs, stop_if_exception=False) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIn("ERROR", result[0]) + self.assertIn("OUTPUT", result[0]) + + + def test_optimize_model_proto_unknown_step(self): + import onnx + model_proto = onnx.ModelProto() + with self.assertRaises(AssertionError): + bh.optimize_model_proto(model_proto, optimization="unknown_step") + + + def test_apply_rule_sets_valid_rule_set(self): + import onnx + model_proto = onnx.ModelProto() + rule_sets = ["llama0"] + result = bh.apply_rule_sets(model_proto, rule_sets) + self.assertIsInstance(result, onnx.ModelProto) + + + def test_run_benchmark_no_metrics_no_exception(self): + script_name = "example_script" + configs = [{"arg1": "value1"}] + result = bh.run_benchmark(script_name, configs, stop_if_exception=False) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIn("ERROR", result[0]) + self.assertIn("OUTPUT", result[0]) + + + def test_common_export_unknown_exporter(self): + model = None + inputs = [] + with self.assertRaises(ValueError): + bh.common_export(model, inputs, exporter="unknown") + + + def test_measure_discrepancies_shape_mismatch(self): + expected = [(np.array([1, 2, 3]),)] + outputs = [(np.array([1, 2]),)] + with self.assertRaises(AssertionError): + bh.measure_discrepancies(expected, outputs) + + + def test_run_benchmark_with_onnxruntime_error(self): + script_name = "example_script" + configs = [{"arg1": "value1"}] + with self.assertRaises(RuntimeError): + bh.run_benchmark(script_name, configs, stop_if_exception=True) + + + def test_make_prefix(self): + script_name = "example_script.py" + index = 3 + expected = "example_script_dort_c3_" + result = bh._make_prefix(script_name, index) + self.assertEqual(result, expected) + + + def test_extract_metrics_with_no_metrics(self): + text = "No metrics here" + expected = {} + result = bh._extract_metrics(text) + self.assertEqual(result, expected) + + + def test_get_machine_with_cuda(self): + import torch + torch.cuda.is_available = lambda: True + torch.cuda.get_device_capability = lambda x: (7, 5) + torch.cuda.get_device_name = lambda x: "Mock CUDA Device" + + result = bh.get_machine() + self.assertIn("has_cuda", result) + self.assertTrue(result["has_cuda"]) + self.assertEqual(result["capability"], (7, 5)) + self.assertEqual(result["device_name"], "Mock CUDA Device") + + + def test_get_parsed_args_with_custom_args(self): + name = "test_script" + new_args = ["--n_trees", "20", "--learning_rate", "0.05"] + kwargs = {"n_trees": (10, "number of trees to train"), "learning_rate": (0.01, "learning rate")} + expected = {"n_trees": 20, "learning_rate": 0.05} + result = bh.get_parsed_args(name, new_args=new_args, **kwargs) + self.assertEqual(result, expected) + + + def test_cmd_line_with_kwargs(self): + script_name = "example_script" + kwargs = {"arg1": "value1", "arg2": "value2"} + expected = [sys.executable, "-m", "example_script", "--arg1", "value1", "--arg2", "value2"] + result = bh._cmd_line(script_name, **kwargs) + self.assertEqual(result, expected) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/benchmark/export_model_test.py b/onnxscript/tools/benchmark/export_model_test.py index 55698be67..4795ef2f7 100644 --- a/onnxscript/tools/benchmark/export_model_test.py +++ b/onnxscript/tools/benchmark/export_model_test.py @@ -201,5 +201,31 @@ def test_export_model_phi3_cpu_dynamo_llama0(self): self.assertIn(":repeat_time,", out) + @unittest.skipIf(not has_transformers(), reason="transformers missing") + def test_export_model_with_dynamic_shapes(self): + args = [ + "--verbose", + "1", + "--config", + "medium", + "--dtype", + "float32", + "--device", + "cpu", + "--exporter", + "eager", + "--model", + "phi", + "--dynamic", + "1", + ] + f = io.StringIO() + with contextlib.redirect_stdout(f): + onnxscript.tools.benchmark.export_model.main(args) + + out = f.getvalue() + self.assertIn("dynamic_shapes=", out) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/memory_peak_test.py b/onnxscript/tools/memory_peak_test.py index 71bbc75c8..8fd4d8e7c 100644 --- a/onnxscript/tools/memory_peak_test.py +++ b/onnxscript/tools/memory_peak_test.py @@ -10,6 +10,7 @@ import onnxscript.tools.memory_peak +import multiprocessing class TestMemoryPeak(unittest.TestCase): @unittest.skipIf(sys.platform == "win32", reason="other test are failing") @@ -53,5 +54,28 @@ def test_spy_cuda(self): self.assertLessEqual(pres["gpus"][0].begin, pres["gpus"][0].max_peak) + def test_flatten_with_gpus_precise(self): + cpu_monitor = onnxscript.tools.memory_peak.Monitor() + gpu_monitor = onnxscript.tools.memory_peak.Monitor() + cpu_monitor.update(1000) + gpu_monitor.update(2000) + stats = {'cpu': [cpu_monitor], 'gpus': [gpu_monitor]} + flat_stats = onnxscript.tools.memory_peak.flatten(stats) + self.assertIn('gpu0_peak', flat_stats) + self.assertAlmostEqual(flat_stats['gpu0_peak'], 0.0019073486328125) + + + def test_memory_spy_start_handshake_failure(self): + original_pipe = multiprocessing.Pipe + def mock_pipe(): + parent_conn, child_conn = original_pipe() + parent_conn.recv = lambda: -1 # Simulate handshake failure + return parent_conn, child_conn + multiprocessing.Pipe = mock_pipe + with self.assertRaises(RuntimeError): + onnxscript.tools.memory_peak.MemorySpy(os.getpid()) + multiprocessing.Pipe = original_pipe + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py index ea4844476..dba679d47 100644 --- a/onnxscript/tools/transformers_models/llama_test.py +++ b/onnxscript/tools/transformers_models/llama_test.py @@ -138,6 +138,78 @@ def test_llama_dort_static(self): gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1.0e-5, rtol=1e-5) + def test_llama_model_from_config_small(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model_from_config(config="small") + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + try: + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + except torch._export.verifier.SpecViolationError as e: + if "Node.meta _enter_autocast is missing val field." in str(e): + raise unittest.SkipTest(str(e)) + raise + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + + def test_llama_model_from_config_medium(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model_from_config(config="medium") + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + try: + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + except torch._export.verifier.SpecViolationError as e: + if "Node.meta _enter_autocast is missing val field." in str(e): + raise unittest.SkipTest(str(e)) + raise + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + + def test_llama_model_from_config_unexpected_config(self): + with self.assertRaises(ValueError) as context: + onnxscript.tools.transformers_models.llama.get_llama_model_from_config(config="unexpected") + self.assertIn("Unexpected configuration", str(context.exception)) + + + def test_llama_model_without_mask(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model(with_mask=False) + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + try: + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + except torch._export.verifier.SpecViolationError as e: + if "Node.meta _enter_autocast is missing val field." in str(e): + raise unittest.SkipTest(str(e)) + raise + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py index 7498b9a15..a60207aa2 100644 --- a/onnxscript/tools/transformers_models/mistral_test.py +++ b/onnxscript/tools/transformers_models/mistral_test.py @@ -140,6 +140,66 @@ def test_mistral_dort_static(self): gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + def test_get_mistral_model_from_config_medium(self): + model, input_tensors_many, dynamic_shapes = onnxscript.tools.transformers_models.mistral.get_mistral_model_from_config(config="medium") + input_tensors = input_tensors_many[0] + self.assertIsInstance(model, torch.nn.Module) + self.assertIsInstance(input_tensors, tuple) + self.assertIsInstance(dynamic_shapes, dict) + + + def test_get_mistral_model_without_mask_output(self): + model, input_tensors_many, _ = onnxscript.tools.transformers_models.mistral.get_mistral_model(with_mask=False) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + self.assertIsInstance(expected, tuple) + self.assertGreater(len(expected), 0) + + + def test_get_mistral_model_from_config_invalid_config(self): + with self.assertRaises(ValueError) as context: + onnxscript.tools.transformers_models.mistral.get_mistral_model_from_config(config="invalid") + self.assertIn("Unexpected configuration", str(context.exception)) + + + def test_get_mistral_model_without_mask(self): + model, input_tensors_many, dynamic_shapes = onnxscript.tools.transformers_models.mistral.get_mistral_model(with_mask=False) + input_tensors = input_tensors_many[0] + self.assertEqual(len(input_tensors), 1) # Only input_ids should be present + self.assertIsInstance(model, torch.nn.Module) + self.assertIsInstance(dynamic_shapes, dict) + + + def test_prepare_config_and_inputs_with_labels(self): + batch_size = 2 + seq_length = 3 + vocab_size = 10 + type_sequence_label_size = 2 + num_labels = 3 + num_choices = 4 + _, _, _, sequence_labels, token_labels, choice_labels = onnxscript.tools.transformers_models.mistral._prepare_config_and_inputs( + batch_size, seq_length, vocab_size, use_labels=True, type_sequence_label_size=type_sequence_label_size, num_labels=num_labels, num_choices=num_choices + ) + self.assertIsNotNone(sequence_labels) + self.assertEqual(sequence_labels.shape, (batch_size,)) + self.assertIsNotNone(token_labels) + self.assertEqual(token_labels.shape, (batch_size, seq_length)) + self.assertIsNotNone(choice_labels) + self.assertEqual(choice_labels.shape, (batch_size,)) + + + def test_prepare_config_and_inputs_with_token_type_ids(self): + batch_size = 2 + seq_length = 3 + vocab_size = 10 + type_vocab_size = 5 + input_ids, token_type_ids, _, _, _, _ = onnxscript.tools.transformers_models.mistral._prepare_config_and_inputs( + batch_size, seq_length, vocab_size, use_token_type_ids=True, type_vocab_size=type_vocab_size + ) + self.assertIsNotNone(token_type_ids) + self.assertEqual(token_type_ids.shape, (batch_size, seq_length)) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py index d9adcfd86..51bc58dc1 100644 --- a/onnxscript/tools/transformers_models/phi3_test.py +++ b/onnxscript/tools/transformers_models/phi3_test.py @@ -141,6 +141,105 @@ def test_phi3_dort_static(self): gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + def test_get_phi3_model_from_config_large(self): + model, input_tensors_many, dynamic_shapes = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model_from_config(config="large") + ) + self.assertIsNotNone(model) + self.assertIsInstance(input_tensors_many, list) + self.assertIsInstance(dynamic_shapes, dict) + + + def test_get_phi3_model_from_config_medium(self): + model, input_tensors_many, dynamic_shapes = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model_from_config(config="medium") + ) + self.assertIsNotNone(model) + self.assertIsInstance(input_tensors_many, list) + self.assertIsInstance(dynamic_shapes, dict) + + + def test_get_phi3_model_no_mask(self): + model, input_tensors_many, dynamic_shapes = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model(with_mask=False) + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + self.assertIsNotNone(expected) + + + def test_get_phi3_model_unexpected_config(self): + with self.assertRaises(ValueError) as context: + onnxscript.tools.transformers_models.phi3.get_phi3_model_from_config(config="unexpected") + self.assertIn("Unexpected configuration", str(context.exception)) + + + def test_has_phi3_import_error(self): + import sys + import builtins + original_import = builtins.__import__ + + def mocked_import(name, *args): + if name == "transformers": + raise ImportError("No module named 'transformers'") + return original_import(name, *args) + + builtins.__import__ = mocked_import + try: + result = onnxscript.tools.transformers_models.phi3.has_phi3() + self.assertFalse(result) + finally: + builtins.__import__ = original_import + + + def test_prepare_config_and_inputs_num_choices_zero(self): + with self.assertRaises(AssertionError) as context: + onnxscript.tools.transformers_models.phi3._prepare_config_and_inputs( + batch_size=1, + seq_length=1, + vocab_size=10, + num_choices=0, + use_labels=True + ) + self.assertIn("num_choices is null", str(context.exception)) + + + def test_prepare_config_and_inputs_num_labels_zero(self): + with self.assertRaises(AssertionError) as context: + onnxscript.tools.transformers_models.phi3._prepare_config_and_inputs( + batch_size=1, + seq_length=1, + vocab_size=10, + num_labels=0, + use_labels=True + ) + self.assertIn("num_labels is null", str(context.exception)) + + + def test_prepare_config_and_inputs_type_sequence_label_size_zero(self): + with self.assertRaises(AssertionError) as context: + onnxscript.tools.transformers_models.phi3._prepare_config_and_inputs( + batch_size=1, + seq_length=1, + vocab_size=10, + type_sequence_label_size=0, + use_labels=True + ) + self.assertIn("type_sequence_label_size is null", str(context.exception)) + + + def test_prepare_config_and_inputs_type_vocab_size_zero(self): + with self.assertRaises(AssertionError) as context: + onnxscript.tools.transformers_models.phi3._prepare_config_and_inputs( + batch_size=1, + seq_length=1, + vocab_size=10, + type_vocab_size=0, + use_token_type_ids=True + ) + self.assertIn("type_vocab_size is null", str(context.exception)) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py index e835d8b1d..e2c8c03e1 100644 --- a/onnxscript/tools/transformers_models/phi_test.py +++ b/onnxscript/tools/transformers_models/phi_test.py @@ -106,5 +106,52 @@ def test_phi_dort_static(self): torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + def test_get_phi_model_from_config_unexpected_config(self): + with self.assertRaises(ValueError) as context: + onnxscript.tools.transformers_models.phi.get_phi_model_from_config(config="unexpected") + self.assertIn("Unexpected configuration", str(context.exception)) + + + def test_get_phi_model_no_mask(self): + model, input_tensors_many, dynamic_shapes = onnxscript.tools.transformers_models.phi.get_phi_model(with_mask=False) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + self.assertIsNotNone(expected) + self.assertEqual(len(input_tensors), 1) + + + def test_prepare_config_and_inputs_with_token_type_ids_and_labels(self): + batch_size = 2 + seq_length = 3 + vocab_size = 10 + type_sequence_label_size = 2 + type_vocab_size = 5 + num_labels = 3 + num_choices = 4 + use_input_mask = True + use_token_type_ids = True + use_labels = True + + result = onnxscript.tools.transformers_models.phi._prepare_config_and_inputs( + batch_size, + seq_length, + vocab_size, + type_sequence_label_size, + type_vocab_size, + num_labels, + num_choices, + use_input_mask, + use_token_type_ids, + use_labels, + ) + + input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = result + + self.assertIsNotNone(token_type_ids) + self.assertIsNotNone(sequence_labels) + self.assertIsNotNone(token_labels) + self.assertIsNotNone(choice_labels) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 4104eb51d..ff4452d68 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -241,5 +241,35 @@ def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[ self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) + def test_pytype_to_attrtype_unsupported(self): + self.assertIsNone(type_annotation.pytype_to_attrtype(dict)) + + + def test_pytype_to_attrtype_optional_and_list(self): + import onnx + self.assertEqual(type_annotation.pytype_to_attrtype(Optional[int]), onnx.AttributeProto.INT) + self.assertEqual(type_annotation.pytype_to_attrtype(Optional[float]), onnx.AttributeProto.FLOAT) + self.assertEqual(type_annotation.pytype_to_attrtype(List[int]), onnx.AttributeProto.INTS) + self.assertEqual(type_annotation.pytype_to_attrtype(List[float]), onnx.AttributeProto.FLOATS) + + + def test_is_value_type_unsupported_annotation(self): + with self.assertRaises(ValueError): + type_annotation.is_value_type(dict) + + + def test_base_type_is_bool(self): + self.assertTrue(type_annotation.base_type_is_bool(bool)) + self.assertTrue(type_annotation.base_type_is_bool(Optional[bool])) + self.assertTrue(type_annotation.base_type_is_bool(Sequence[bool])) + self.assertFalse(type_annotation.base_type_is_bool(int)) + self.assertFalse(type_annotation.base_type_is_bool(Optional[int])) + + + def test_onnx_attr_type_to_onnxscript_repr_unsupported_type(self): + with self.assertRaises(ValueError): + type_annotation.onnx_attr_type_to_onnxscript_repr(999) # Assuming 999 is unsupported + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/values_test.py b/onnxscript/values_test.py index c33e62333..3c7fd0470 100644 --- a/onnxscript/values_test.py +++ b/onnxscript/values_test.py @@ -103,6 +103,29 @@ def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "def self.assertEqual(annotations["attr2"], float) self.assertEqual(annotations["attr3"], str) + def test_onnx_function_to_model_proto_with_required_attributes(self): + opset = values.Opset("test", 1) + + @onnxscript.script(default_opset=opset) + def function(input1, attr1: int): + return input1 + attr1 + + with self.assertRaises(ValueError, msg="A function with required attributes cannot be exported as a model."): + function.to_model_proto() + + + def test_prepare_inputs_trims_none(self): + opset = values.Opset("test", 1) + inputs = [1, 2, None, None] + prepared_inputs = opset._prepare_inputs(None, *inputs) + self.assertEqual(prepared_inputs, [1, 2], "Opset._prepare_inputs should trim 'None' values from the end of the inputs list.") + + + def test_param_schema_is_attribute(self): + param_schema = values.ParamSchema(name="attr", is_input=False) + self.assertTrue(param_schema.is_attribute, "ParamSchema should identify non-input parameters as attributes.") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 472ffe2e5..5b38e2662 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -328,5 +328,19 @@ def test_version_convert_compatible(self): version_converter.convert_version(model, target_version=target_version) + + def test_get_str_attribute_invalid_type(self): + node = ir.Node("", "TestOp", []) + node.attributes["test_attr"] = "not_an_attr_instance" + result = version_converter._version_converter._get_str_attribute(node, "test_attr") + self.assertIsNone(result) + + + def test_get_int_attribute_invalid_type(self): + node = ir.Node("", "TestOp", []) + node.attributes["test_attr"] = "not_an_attr_instance" + result = version_converter._version_converter._get_int_attribute(node, "test_attr") + self.assertIsNone(result) + if __name__ == "__main__": unittest.main() diff --git a/opgen/pygen_test.py b/opgen/pygen_test.py index 831052026..7a633f18a 100644 --- a/opgen/pygen_test.py +++ b/opgen/pygen_test.py @@ -57,6 +57,88 @@ def raise_(): """, ) + def test_node_replace_with_self(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + node = TestNode() + try: + node.replace(node) + except ValueError: + self.fail("replace() raised ValueError unexpectedly!") + + + def test_node_replace_root_node(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + node = TestNode() + with self.assertRaises(ValueError): + node.replace(TestNode()) + + + def test_node_get_ancestors_and_self(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + root = TestNode() + child = TestNode() + root.append_child(child, cg.Role("child")) + ancestors = list(child.get_ancestors(and_self=True)) + self.assertEqual(ancestors, [child, root]) + + + def test_node_replace_with_none(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + parent = TestNode() + child = TestNode() + parent.append_child(child, cg.Role("child")) + child.replace(None) + self.assertIsNone(child.parent) + self.assertIsNone(parent.first_child) + + + def test_node_predicate_type_matching(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + predicate = cg.NodePredicate(type_=TestNode) + node = TestNode() + self.assertTrue(predicate.matches(node)) + + + def test_node_remove(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + parent = TestNode() + child = TestNode() + parent.append_child(child, cg.Role("child")) + removed_node = child.remove() + self.assertIsNone(removed_node.parent) + self.assertIsNone(parent.first_child) + + + def test_node_replace_with_new_node(self): + class TestNode(cg.Node): + def accept(self, visitor: cg.Visitor): + pass + parent = TestNode() + old_node = TestNode() + new_node = TestNode() + parent.append_child(old_node, cg.Role("child")) + old_node.replace(new_node) + self.assertIsNone(old_node.parent) + self.assertEqual(new_node.parent, parent) + + + def test_single_or_none_multiple_elements(self): + with self.assertRaises(StopIteration): + cg.single_or_none([1, 2]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/eager_mode_test.py b/tests/eager_mode_test.py index 566169f22..6390d0317 100644 --- a/tests/eager_mode_test.py +++ b/tests/eager_mode_test.py @@ -114,6 +114,62 @@ def test_function_attribute_by_positional_args(self): def test_function_input_and_attribute_by_kwargs_out_of_order(self): self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0) + def test_adapt_to_user_mode_mixed_tuple(self): + tensor_tuple = (onnxscript.tensor.Tensor(np.array([1, 2, 3])), np.array([4, 5, 6])) + result = onnxscript.evaluator._adapt_to_user_mode(tensor_tuple) + self.assertIsInstance(result, tuple) + for item in result: + self.assertIsInstance(item, np.ndarray) + + + def test_base_evaluator_adapt_attributes_with_callable(self): + class DummySchema: + attributes = {'attr': None} + + def dummy_function(): + pass + + class DummyEvaluator(onnxscript.evaluator.BaseEvaluator): + def _eval(self, schema, inputs, attributes, closure): + pass + + evaluator = DummyEvaluator() + with self.assertRaises(TypeError): + evaluator.adapt_attributes(DummySchema(), {'attr': dummy_function}) + + + def test_adapt_to_user_mode_tuple_of_tensors(self): + tensor_tuple = (onnxscript.tensor.Tensor(np.array([1, 2, 3])), onnxscript.tensor.Tensor(np.array([4, 5, 6]))) + result = onnxscript.evaluator._adapt_to_user_mode(tensor_tuple) + self.assertIsInstance(result, tuple) + for item in result: + self.assertIsInstance(item, np.ndarray) + + + def test_adapt_to_eager_mode_nested_list_of_integers(self): + nested_list = [[1, 2], [3, 4]] + result, has_array = onnxscript.evaluator._adapt_to_eager_mode(nested_list) + self.assertFalse(has_array) + for sublist in result: + for item in sublist: + self.assertIsInstance(item, onnxscript.tensor.Tensor) + + + def test_adapt_to_user_mode_list_of_tensors(self): + tensor_list = [onnxscript.tensor.Tensor(np.array([1, 2, 3])), onnxscript.tensor.Tensor(np.array([4, 5, 6]))] + result = onnxscript.evaluator._adapt_to_user_mode(tensor_list) + self.assertIsInstance(result, list) + for item in result: + self.assertIsInstance(item, np.ndarray) + + + def test_unwrap_tensors_in_kwargs(self): + kwargs = {'a': onnxscript.tensor.Tensor(np.array([1, 2])), 'b': 3} + result = onnxscript.evaluator._unwrap_tensors_in_kwargs(kwargs) + self.assertIsInstance(result['a'], np.ndarray) + self.assertEqual(result['b'], 3) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/external_tensor_test.py b/tests/external_tensor_test.py index f12e5720c..d1875bc0b 100644 --- a/tests/external_tensor_test.py +++ b/tests/external_tensor_test.py @@ -43,5 +43,53 @@ def TestFun(X: FLOAT[1024]) -> FLOAT[1024]: self.assertIn("external_tensor('bias', 1, [10], 'bias', length=40)", pymodel) + def test_onnx_type_to_onnxscript_repr_tensor_with_dim_params(self): + from onnxscript.onnx_types import onnx_type_to_onnxscript_repr + from onnx import TypeProto, TensorShapeProto + onnx_type = TypeProto() + onnx_type.tensor_type.elem_type = onnx.TensorProto.FLOAT + dim1 = onnx_type.tensor_type.shape.dim.add() + dim1.dim_param = 'dim1' + dim2 = onnx_type.tensor_type.shape.dim.add() + dim2.dim_param = 'dim2' + result = onnx_type_to_onnxscript_repr(onnx_type) + self.assertEqual(result, "FLOAT['dim1','dim2']") + + + def test_onnx_type_to_onnxscript_repr_tensor_unknown_rank(self): + from onnxscript.onnx_types import onnx_type_to_onnxscript_repr + from onnx import TypeProto + onnx_type = TypeProto() + onnx_type.tensor_type.elem_type = onnx.TensorProto.FLOAT + result = onnx_type_to_onnxscript_repr(onnx_type) + self.assertEqual(result, "FLOAT[...]") + + + def test_onnx_type_to_onnxscript_repr_not_implemented(self): + from onnxscript.onnx_types import onnx_type_to_onnxscript_repr + from onnx import TypeProto + unsupported_type = TypeProto() + with self.assertRaises(NotImplementedError): + onnx_type_to_onnxscript_repr(unsupported_type) + + + def test_class_getitem_shape_already_specified(self): + from onnxscript.onnx_types import FLOAT + with self.assertRaises(ValueError): + FLOAT[None][None] + + + def test_tensor_type_instantiation(self): + with self.assertRaises(NotImplementedError): + from onnxscript.onnx_types import TensorType + TensorType() + + + def test_check_dim_invalid_type(self): + with self.assertRaises(TypeError): + from onnxscript.onnx_types import _check_dim + _check_dim(3.14) # Invalid type, should raise TypeError + + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/quantization_test.py b/tests/function_libs/torch_lib/quantization_test.py index 7ec04ee77..bd74a80a7 100644 --- a/tests/function_libs/torch_lib/quantization_test.py +++ b/tests/function_libs/torch_lib/quantization_test.py @@ -14,6 +14,8 @@ from onnxscript._internal import version_utils +import warnings +import unittest.mock class QuantizedModelExportTest(unittest.TestCase): @unittest.skipIf( @@ -49,6 +51,67 @@ def forward(self, x): program = torch.onnx.dynamo_export(pt2e_torch_model, *example_inputs) onnx.checker.check_model(program.model_proto, full_check=True) + def test_is_onnxruntime_training_with_push_back_batch(self): + with unittest.mock.patch('onnxruntime.training', create=True): + mock_ortvaluevector = type('OrtValueVector', (object,), {'push_back_batch': True})() + with unittest.mock.patch('onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector', mock_ortvaluevector): + result = version_utils.is_onnxruntime_training() + self.assertTrue(result) + + + def test_ignore_warnings_suppresses_warning(self): + @version_utils.ignore_warnings(UserWarning) + def dummy_function(self): + warnings.warn("This is a user warning", UserWarning) + return True + + result = dummy_function(self) + self.assertTrue(result) + + + def test_has_transformers_not_installed(self): + with unittest.mock.patch.dict('sys.modules', {'transformers': None}): + result = version_utils.has_transformers() + self.assertFalse(result) + + + def test_has_transformers_installed(self): + with unittest.mock.patch('sys.modules', {'transformers': unittest.mock.Mock()}): + result = version_utils.has_transformers() + self.assertTrue(result) + + + def test_numpy_older_than_true(self): + with unittest.mock.patch('numpy.__version__', '1.18.0'): + result = version_utils.numpy_older_than("1.19.0") + self.assertTrue(result) + + + def test_onnxruntime_older_than_true(self): + with unittest.mock.patch('onnxruntime.__version__', '1.8.0'): + result = version_utils.onnxruntime_older_than("1.9.0") + self.assertTrue(result) + + + def test_ignore_warnings_raises_assertion(self): + with self.assertRaises(AssertionError): + @version_utils.ignore_warnings(None) + def dummy_function(): + pass + + + def test_is_onnxruntime_training_no_training(self): + with unittest.mock.patch.dict('sys.modules', {'onnxruntime.training': None}): + result = version_utils.is_onnxruntime_training() + self.assertFalse(result) + + + def test_transformers_older_than_no_transformers(self): + with unittest.mock.patch.dict('sys.modules', {'transformers': None}): + result = version_utils.transformers_older_than("4.0") + self.assertIsNone(result) + + if __name__ == "__main__": unittest.main() diff --git a/tests/ir/graph_view_test.py b/tests/ir/graph_view_test.py index 83a51cdaa..095e77eca 100644 --- a/tests/ir/graph_view_test.py +++ b/tests/ir/graph_view_test.py @@ -39,5 +39,102 @@ def test_it_can_be_serialized_as_graph_proto(self): # It should succeed + def test_deserialize_string_tensor_with_external_data_location(self): + tensor_proto = onnx.TensorProto() + tensor_proto.data_type = onnx.TensorProto.STRING + tensor_proto.data_location = onnx.TensorProto.EXTERNAL + tensor_proto.string_data.extend([b"external_data"]) + tensor_proto.dims.extend([1]) + + tensor = ir.serde.deserialize_tensor(tensor_proto) + + self.assertIsInstance(tensor, ir.ExternalTensor) + self.assertEqual(tensor.shape.dims, (1,)) + + + def test_serialize_string_tensor(self): + string_tensor = ir.StringTensor([b"foo", b"bar"], shape=ir.Shape([2])) + + tensor_proto = ir.serde.serialize_tensor(string_tensor) + + self.assertEqual(tensor_proto.data_type, onnx.TensorProto.STRING) + self.assertEqual(tensor_proto.string_data, [b"foo", b"bar"]) + + + def test_deserialize_string_tensor(self): + tensor_proto = onnx.TensorProto() + tensor_proto.data_type = onnx.TensorProto.STRING + tensor_proto.string_data.extend([b"hello", b"world"]) + tensor_proto.dims.extend([2]) + + tensor = ir.serde.deserialize_tensor(tensor_proto) + + self.assertIsInstance(tensor, ir.StringTensor) + self.assertEqual(tensor.numpy().tolist(), [b"hello", b"world"]) + + + def test_deserialize_unsorted_graph(self): + graph_proto = onnx.GraphProto() + node_proto_1 = onnx.NodeProto() + node_proto_1.name = "node1" + node_proto_1.op_type = "Add" + node_proto_1.input.extend(["input1", "input2"]) + node_proto_1.output.extend(["output1"]) + + node_proto_2 = onnx.NodeProto() + node_proto_2.name = "node2" + node_proto_2.op_type = "Mul" + node_proto_2.input.extend(["output1", "input3"]) + node_proto_2.output.extend(["output2"]) + + graph_proto.node.extend([node_proto_2, node_proto_1]) + + graph = ir.serde.deserialize_graph(graph_proto) + self.assertEqual(len(graph), 2) + self.assertEqual(graph[0].name, "node2") + self.assertEqual(graph[1].name, "node1") + + + def test_serialize_function_no_inputs(self): + function_proto = onnx.FunctionProto() + function_proto.name = "test_function" + function_proto.domain = "test_domain" + function_proto.opset_import.add(domain="", version=13) + + model_proto = onnx.ModelProto() + model_proto.ir_version = 7 + model_proto.functions.extend([function_proto]) + + model = ir.serde.deserialize_model(model_proto) + serialized_model = ir.serde.serialize_model(model) + + self.assertEqual(len(serialized_model.functions), 1) + self.assertEqual(serialized_model.functions[0].name, "test_function") + + + def test_serialize_sparse_tensor_not_implemented_error(self): + attribute_proto = onnx.AttributeProto() + with self.assertRaises(NotImplementedError): + ir.serde._fill_in_value_for_attribute(attribute_proto, ir.AttributeType.SPARSE_TENSOR, None) + + + def test_to_proto_not_implemented_error(self): + class UnsupportedIRObject: + pass + + unsupported_ir_object = UnsupportedIRObject() + with self.assertRaises(NotImplementedError): + ir.serde.to_proto(unsupported_ir_object) + + + def test_from_proto_not_implemented_error(self): + class UnsupportedProto: + pass + + unsupported_proto = UnsupportedProto() + with self.assertRaises(NotImplementedError): + ir.serde.from_proto(unsupported_proto) + + if __name__ == "__main__": unittest.main() diff --git a/tests/onnx_types_test.py b/tests/onnx_types_test.py index 1f7a98cc1..f68edd2c7 100644 --- a/tests/onnx_types_test.py +++ b/tests/onnx_types_test.py @@ -13,6 +13,9 @@ from parameterized import parameterized +import onnx +from onnxscript.onnx_types import onnx_type_to_onnxscript_repr +from onnxscript.onnx_types import _check_dim from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry @@ -73,5 +76,73 @@ def test_shapes_are_not_same_type(self, a: TensorType, b: TensorType): self.assertIsNot(a, b) + def test_to_type_proto_unsupported_onnx_type(self): + class MockTypeProto: + def HasField(self, field): + return False + + type_proto = MockTypeProto() + with self.assertRaises(NotImplementedError): + onnx_type_to_onnxscript_repr(type_proto) + + + def test_onnx_type_to_onnxscript_repr(self): + # Mocking an ONNX TypeProto + class MockDim: + def __init__(self, dim_value=None, dim_param=None): + self.dim_value = dim_value + self.dim_param = dim_param + + def HasField(self, field): + if field == "dim_value": + return self.dim_value is not None + if field == "dim_param": + return self.dim_param is not None + return False + + class MockShape: + def __init__(self, dims): + self.dim = dims + + class MockTensorType: + def __init__(self, elem_type, shape=None): + self.elem_type = elem_type + self.shape = shape + + def HasField(self, field): + if field == "shape": + return self.shape is not None + return False + + class MockTypeProto: + def __init__(self, tensor_type): + self.tensor_type = tensor_type + + def HasField(self, field): + return field == "tensor_type" + + # Test cases + tensor_type = MockTensorType(onnx.TensorProto.FLOAT, MockShape([MockDim(10), MockDim(dim_param='N')])) + type_proto = MockTypeProto(tensor_type) + self.assertEqual(onnx_type_to_onnxscript_repr(type_proto), "FLOAT[10,'N']") + + tensor_type = MockTensorType(onnx.TensorProto.INT32, MockShape([])) + type_proto = MockTypeProto(tensor_type) + self.assertEqual(onnx_type_to_onnxscript_repr(type_proto), "INT32") + + tensor_type = MockTensorType(onnx.TensorProto.BOOL) + type_proto = MockTypeProto(tensor_type) + self.assertEqual(onnx_type_to_onnxscript_repr(type_proto), "BOOL[...]") + + + def test_check_dim_invalid_type(self): + with self.assertRaises(TypeError): + _check_dim(3.14) + with self.assertRaises(TypeError): + _check_dim([1, 2, 3]) + with self.assertRaises(TypeError): + _check_dim({'dim': 1}) + + if __name__ == "__main__": unittest.main() diff --git a/tests/operator_test.py b/tests/operator_test.py index 8ff193ce4..c115783ce 100644 --- a/tests/operator_test.py +++ b/tests/operator_test.py @@ -40,6 +40,63 @@ def implicit_plus1(A: FLOAT["N"]) -> FLOAT["N"]: # noqa: F821 onnxscript.testing.assert_isomorphic_function(explicit_plus1, implicit_plus1) + def test_pow_function(self): + @script(default_opset=op) + def pow_function(X, Y): + return op.Pow(X, Y) + + X = op.Constant(value=onnx.helper.make_tensor("X", onnx.TensorProto.FLOAT, [3], [2.0, 3.0, 4.0])) + Y = op.Constant(value=onnx.helper.make_tensor("Y", onnx.TensorProto.FLOAT, [3], [1.0, 2.0, 3.0])) + result = pow_function(X, Y) + self.assertIsNotNone(result) + + + def test_optional_with_input(self): + @script(default_opset=op) + def optional_with_input(input): + return op.Optional(input) + + input = op.Constant(value=onnx.helper.make_tensor("input", onnx.TensorProto.FLOAT, [1], [1.0])) + result = optional_with_input(input) + self.assertIsNotNone(result) + + + def test_cast_like_float_to_float(self): + @script(default_opset=op) + def cast_like(input, target_type): + return op.CastLike(input, target_type) + + input = op.Constant(value=onnx.helper.make_tensor("input", onnx.TensorProto.FLOAT, [3], [1.0, 2.0, 3.0])) + target_type = op.Constant(value=onnx.helper.make_tensor("target_type", onnx.TensorProto.FLOAT, [3], [0.0, 0.0, 0.0])) + result = cast_like(input, target_type) + self.assertIsNotNone(result) + + + def test_bernoulli_with_seed_reproducibility(self): + @script(default_opset=op) + def bernoulli_with_seed(input): + return op.Bernoulli(input, seed=42.0) + + input = op.Constant(value=onnx.helper.make_tensor("input", onnx.TensorProto.FLOAT, [3], [0.5, 0.5, 0.5])) + result = bernoulli_with_seed(input) + self.assertIsNotNone(result) + + + def test_batch_normalization_inference_simple(self): + @script(default_opset=op) + def batch_norm_inference(X, scale, B, input_mean, input_var): + return op.BatchNormalization(X, scale, B, input_mean, input_var) + + X = op.Constant(value=onnx.helper.make_tensor("X", onnx.TensorProto.FLOAT, [1, 2, 2, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) + scale = op.Constant(value=onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, [2], [1.0, 1.0])) + B = op.Constant(value=onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [2], [0.0, 0.0])) + input_mean = op.Constant(value=onnx.helper.make_tensor("input_mean", onnx.TensorProto.FLOAT, [2], [0.0, 0.0])) + input_var = op.Constant(value=onnx.helper.make_tensor("input_var", onnx.TensorProto.FLOAT, [2], [1.0, 1.0])) + + result = batch_norm_inference(X, scale, B, input_mean, input_var) + self.assertIsNotNone(result) + + if __name__ == "__main__": unittest.main()