From bb6f4f1472de9d1d5b31ca4ea8b422609cbc1911 Mon Sep 17 00:00:00 2001
From: Patrick Yoho
Date: Mon, 26 Jan 2026 13:07:49 -0600
Subject: [PATCH 1/5] Improve some string_utils functions and add tests

---
 components/lif/string_utils/core.py           | 26 ++++-
 test/components/lif/string_utils/test_core.py | 95 ++++++++++++++++++-
 2 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/components/lif/string_utils/core.py b/components/lif/string_utils/core.py
index 6884657..b56effb 100644
--- a/components/lif/string_utils/core.py
+++ b/components/lif/string_utils/core.py
@@ -19,6 +19,7 @@ def safe_identifier(name: str) -> str:
     s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
     s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1)
     safe = re.sub(r"\W|^(?=\d)", "_", s2)
+    safe = re.sub(r"_+", "_", safe)  # Collapse consecutive underscores
     return safe.lower()
 
 
@@ -31,7 +32,24 @@ def to_pascal_case(*parts: str) -> str:
     Returns:
         str: PascalCase string.
     """
-    return "".join("".join(word.capitalize() for word in part.split("_")) for part in parts if part)
+    result = []
+    for part in parts:
+        if not part:
+            continue
+        # First, insert separators at case boundaries (camelCase -> camel_Case)
+        s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", part.strip())
+        s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1)
+        # Now split on all separators (underscores, hyphens, spaces)
+        words = re.split(r"[_\-\s]+", s2)
+        for word in words:
+            if not word:
+                continue
+            # If word is all uppercase (acronym), keep it; otherwise capitalize
+            if word.isupper():
+                result.append(word)
+            else:
+                result.append(word.capitalize())
+    return "".join(result)
 
 
 def to_snake_case(name: str) -> str:
@@ -41,12 +59,10 @@
 
 
 def to_camel_case(s: str) -> str:
-    """Converts snake_case to lowerCamelCase."""
+    """Convert string to camelCase."""
+    s = re.sub(r"([_\-\s]+)([a-zA-Z])", lambda m: m.group(2).upper(), s)
     if not s:
         return s
-    if "_" in s:
-        parts = s.lower().split("_")
-        return parts[0] + "".join(word.capitalize() for word in parts[1:])
     return s[0].lower() + s[1:]
 
 
diff --git a/test/components/lif/string_utils/test_core.py b/test/components/lif/string_utils/test_core.py
index da55029..aeb2210 100644
--- a/test/components/lif/string_utils/test_core.py
+++ b/test/components/lif/string_utils/test_core.py
@@ -1,5 +1,94 @@
-from lif.string_utils import core
+from datetime import date, datetime
 
+from lif.string_utils.core import (
+    safe_identifier,
+    to_pascal_case,
+    to_snake_case,
+    to_camel_case,
+    camelcase_path,
+    dict_keys_to_snake,
+    dict_keys_to_camel,
+    convert_dates_to_strings,
+    to_value_enum_name,
+)
 
-def test_sample():
-    assert core is not None
+
+class TestSafeIdentifier:
+    def test_basic(self):
+        assert safe_identifier("First Name") == "first_name"
+        assert safe_identifier("first-name") == "first_name"
+        assert safe_identifier("first$name") == "first_name"
+
+    def test_leading_digit(self):
+        assert safe_identifier("123abc") == "_123abc"
+
+    def test_camel_pascal(self):
+        assert safe_identifier("CamelCase") == "camel_case"
+        assert safe_identifier("camelCaseABC") == "camel_case_abc"
+
+
+class TestToPascalCase:
+    def test_single_part(self):
+        assert to_pascal_case("hello world") == "HelloWorld"
+        assert to_pascal_case("hello-world") == "HelloWorld"
+        assert to_pascal_case("hello_world") == "HelloWorld"
+
+    def test_multiple_parts(self):
+        assert to_pascal_case("hello", "world") == "HelloWorld"
+        assert to_pascal_case("HTTP", "status", "200") == "HTTPStatus200"
+
+    def test_mixed_case(self):
+        assert 
to_pascal_case("camelCase") == "CamelCase" + assert to_pascal_case("PascalCase") == "PascalCase" + + +class TestToSnakeCase: + def test_basic(self): + assert to_snake_case("CamelCase") == "camel_case" + assert to_snake_case("camelCase") == "camel_case" + + def test_with_acronyms(self): + assert to_snake_case("HTTPServerID") == "http_server_id" + + +class TestToCamelCase: + def test_basic(self): + assert to_camel_case("hello_world") == "helloWorld" + assert to_camel_case("Hello World") == "helloWorld" + assert to_camel_case("hello-world") == "helloWorld" + + def test_empty(self): + assert to_camel_case("") == "" + + +class TestCamelcasePath: + def test_path(self): + assert camelcase_path("a.b_c.d-e f") == "a.bC.dEF" + + +class TestDictKeyTransforms: + def test_to_snake(self): + data = {"FirstName": "Alice", "Address": {"zipCode": 12345}, "items": [{"itemID": 1}]} + out = dict_keys_to_snake(data) + assert out == {"first_name": "Alice", "address": {"zip_code": 12345}, "items": [{"item_id": 1}]} + + def test_to_camel(self): + data = {"first_name": "Bob", "address": {"zip_code": 12345}, "items": [{"item_id": 1}]} + out = dict_keys_to_camel(data) + assert out == {"firstName": "Bob", "address": {"zipCode": 12345}, "items": [{"itemId": 1}]} + + +class TestConvertDatesToStrings: + def test_nested(self): + d = date(2020, 1, 2) + dt = datetime(2020, 1, 2, 3, 4, 5) + obj = {"when": d, "arr": [dt, {"n": 1}]} + out = convert_dates_to_strings(obj) + assert out == {"when": d.isoformat(), "arr": [dt.isoformat(), {"n": 1}]} + + +class TestToValueEnumName: + def test_basic(self): + assert to_value_enum_name("in progress") == "IN_PROGRESS" + assert to_value_enum_name("done!") == "DONE_" + assert to_value_enum_name("123start") == "_123START" From ed6ffe1ed4af98ae4916c507300594aa8f6e496b Mon Sep 17 00:00:00 2001 From: Patrick Yoho Date: Mon, 26 Jan 2026 13:08:42 -0600 Subject: [PATCH 2/5] Add 'schema' polylith component and tests --- components/lif/datatypes/schema.py | 11 + components/lif/schema/__init__.py | 3 + components/lif/schema/core.py | 144 ++++++++ test/components/lif/schema/__init__.py | 0 test/components/lif/schema/test_core.py | 424 ++++++++++++++++++++++++ 5 files changed, 582 insertions(+) create mode 100644 components/lif/datatypes/schema.py create mode 100644 components/lif/schema/__init__.py create mode 100644 components/lif/schema/core.py create mode 100644 test/components/lif/schema/__init__.py create mode 100644 test/components/lif/schema/test_core.py diff --git a/components/lif/datatypes/schema.py b/components/lif/datatypes/schema.py new file mode 100644 index 0000000..26a9994 --- /dev/null +++ b/components/lif/datatypes/schema.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Any, Dict + +@dataclass +class SchemaField: + """Represents a single schema field, including its path and attributes.""" + + json_path: str + description: str + attributes: Dict[str, Any] + py_field_name: str = "" diff --git a/components/lif/schema/__init__.py b/components/lif/schema/__init__.py new file mode 100644 index 0000000..3138d3f --- /dev/null +++ b/components/lif/schema/__init__.py @@ -0,0 +1,3 @@ +from lif.schema import core + +__all__ = ["core"] diff --git a/components/lif/schema/core.py b/components/lif/schema/core.py new file mode 100644 index 0000000..3d80c39 --- /dev/null +++ b/components/lif/schema/core.py @@ -0,0 +1,144 @@ +from pathlib import Path +from typing import Any, List, Optional, Union + +import jsonref + +from lif.datatypes.schema import SchemaField +from 
lif.logging import get_logger
+from lif.string_utils.core import camelcase_path, to_camel_case
+
+
+logger = get_logger(__name__)
+
+
+ATTRIBUTE_KEYS = [
+    "x-queryable",
+    "x-mutable",
+    "DataType",
+    "Required",
+    "Array",
+    "UniqueName",
+    "enum",
+    "type",
+]
+
+# ===== SCHEMA FIELD EXTRACTION =====
+
+
+def extract_nodes(obj: Any, path_prefix: str = "") -> List[SchemaField]:
+    """
+    Recursively extract schema fields from an OpenAPI/JSON Schema object.
+
+    Returns:
+        List[SchemaField]: Flat list of SchemaField objects.
+    """
+    nodes = []
+
+    def is_array(node: dict) -> bool:
+        """Return True if node is an array."""
+        return node.get("type") == "array" or "items" in node
+
+    def get_description(node: dict) -> str:
+        """Get description from node, preferring the upper-case 'Description' key."""
+        return node.get("Description", "") or node.get("description", "")
+
+    def extract_attributes(node: dict) -> dict:
+        """Extract core attributes from node."""
+        attributes = {
+            to_camel_case(k): node.get(k) for k in ATTRIBUTE_KEYS if k in node
+        }
+        if "Array" in node:
+            attributes["array"] = node["Array"]
+        else:
+            attributes["array"] = "Yes" if is_array(node) else "No"
+        attributes["type"] = node.get("type", node.get("DataType", None))
+        return attributes
+
+    if isinstance(obj, dict):
+        key = camelcase_path(path_prefix.rstrip("."))
+
+        branch = (
+            "properties" in obj
+            and isinstance(obj["properties"], dict)
+            and obj["properties"]
+        ) or ("items" in obj and isinstance(obj["items"], dict))
+        attributes = extract_attributes(obj)
+        attributes["branch"] = bool(branch)
+        attributes["leaf"] = not attributes["branch"]
+        nodes.append(
+            SchemaField(
+                json_path=key,
+                description=get_description(obj),
+                attributes=attributes,
+            )
+        )
+
+        # Recurse children
+        if "properties" in obj and isinstance(obj["properties"], dict):
+            for prop, val in obj["properties"].items():
+                new_prefix = f"{path_prefix}.{prop}" if path_prefix else prop
+                nodes.extend(extract_nodes(val, new_prefix))
+
+        if "items" in obj:
+            items = obj["items"]
+            if isinstance(items, dict):
+                nodes.extend(extract_nodes(items, path_prefix))
+            elif isinstance(items, list):  # tuple validation
+                for sub_item in items:
+                    nodes.extend(extract_nodes(sub_item, path_prefix))
+
+    return nodes
+
+
+# ===== ROOT SCHEMA RESOLUTION =====
+
+
+def resolve_openapi_root(doc: dict, root: str):
+    """Return the schema node for a given root in the OpenAPI spec."""
+    candidates = []
+    if "components" in doc and "schemas" in doc["components"]:
+        schemas = doc["components"]["schemas"]
+        if root in schemas:
+            return schemas[root], root
+        candidates.extend(schemas.keys())
+    if "definitions" in doc:
+        definitions = doc["definitions"]
+        if root in definitions:
+            return definitions[root], root
+        candidates.extend(definitions.keys())
+    raise ValueError(f"Root schema '{root}' not found. Available: {sorted(candidates)}")
+
+
+# ===== FILE LOADING =====
+
+
+def load_schema_nodes(
+    openapi: Union[str, Path, dict],
+    root: Optional[str] = None,
+) -> List[SchemaField]:
+    """
+    Load and extract schema fields from an OpenAPI document given as a file path or a dictionary.
+
+    Args:
+        openapi (str | Path | dict): Either a file path (str or Path) to the OpenAPI JSON file,
+            or a dictionary representing the OpenAPI schema.
+        root (str, optional): Root key in the schema to resolve.
+
+    Returns:
+        List[SchemaField]: Extracted SchemaField objects.
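+
+    Example (a minimal sketch using an inline schema dict; matches the behavior
+    exercised in the tests below):
+
+        >>> doc = {"type": "object", "properties": {"name": {"type": "string"}}}
+        >>> [f.json_path for f in load_schema_nodes(doc)]
+        ['', 'name']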
+ """ + if isinstance(openapi, (str, Path)): + with open(openapi, "r") as f: + doc = jsonref.load(f) + elif isinstance(openapi, dict): + # Replace $ref references in dict input + doc = jsonref.JsonRef.replace_refs(openapi) + else: + raise TypeError("openapi must be a str, Path, or dict") + + node = doc + path_prefix = "" + if root: + node, path_prefix = resolve_openapi_root(doc, root) + + return extract_nodes(node, path_prefix) diff --git a/test/components/lif/schema/__init__.py b/test/components/lif/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/components/lif/schema/test_core.py b/test/components/lif/schema/test_core.py new file mode 100644 index 0000000..aded94a --- /dev/null +++ b/test/components/lif/schema/test_core.py @@ -0,0 +1,424 @@ +import pytest +from pathlib import Path +from unittest.mock import mock_open, patch + +from lif.schema import core + + +class TestExtractNodes: + """Tests for the extract_nodes function.""" + + def test_extract_simple_string_field(self): + """Test extracting a simple string field.""" + obj = {"type": "string", "description": "A simple string field", "x-queryable": True} + + nodes = core.extract_nodes(obj, "test.field") + + assert len(nodes) == 1 + node = nodes[0] + assert node.json_path == "test.field" + assert node.description == "A simple string field" + assert node.attributes["xQueryable"] is True + assert node.attributes["type"] == "string" + assert node.attributes["array"] == "No" + assert node.attributes["leaf"] is True + assert node.attributes["branch"] is False + + def test_extract_array_field(self): + """Test extracting an array field.""" + obj = {"type": "array", "items": {"type": "string"}, "description": "An array of strings"} + + nodes = core.extract_nodes(obj, "test.array") + + assert len(nodes) == 2 + # First node is the array container + array_node = nodes[0] + assert array_node.json_path == "test.array" + assert array_node.description == "An array of strings" + assert array_node.attributes["type"] == "array" + assert array_node.attributes["array"] == "Yes" + assert array_node.attributes["branch"] is True + + # Second node is the items + items_node = nodes[1] + assert items_node.json_path == "test.array" + assert items_node.attributes["type"] == "string" + + def test_extract_object_with_properties(self): + """Test extracting an object with properties.""" + obj = { + "type": "object", + "description": "An object with properties", + "properties": { + "name": {"type": "string", "description": "Name field", "x-mutable": False}, + "age": {"type": "integer", "description": "Age field"}, + }, + } + + nodes = core.extract_nodes(obj, "person") + + assert len(nodes) == 3 + + # First node is the object container + obj_node = nodes[0] + assert obj_node.json_path == "person" + assert obj_node.description == "An object with properties" + assert obj_node.attributes["type"] == "object" + assert obj_node.attributes["branch"] is True + + # Second and third nodes are the properties + name_node = next(n for n in nodes if n.json_path == "person.name") + assert name_node.description == "Name field" + assert name_node.attributes["type"] == "string" + assert name_node.attributes["xMutable"] is False + + age_node = next(n for n in nodes if n.json_path == "person.age") + assert age_node.description == "Age field" + assert age_node.attributes["type"] == "integer" + + def test_extract_with_custom_attributes(self): + """Test extracting fields with custom LIF attributes.""" + obj = { + "type": "string", + "Description": "Custom 
description", # uppercase D + "DataType": "xsd:string", + "Required": "Yes", + "Array": "No", + "UniqueName": "Person.Name.firstName", + "enum": ["option1", "option2"], + } + + nodes = core.extract_nodes(obj, "name") + + assert len(nodes) == 1 + node = nodes[0] + assert node.description == "Custom description" + assert node.attributes["dataType"] == "xsd:string" + assert node.attributes["required"] == "Yes" + assert node.attributes["array"] == "No" + assert node.attributes["uniqueName"] == "Person.Name.firstName" + assert node.attributes["enum"] == ["option1", "option2"] + + def test_extract_nested_structure(self): + """Test extracting a deeply nested structure.""" + obj = { + "type": "object", + "properties": { + "contact": { + "type": "array", + "items": { + "type": "object", + "properties": { + "email": {"type": "string", "description": "Email address", "x-queryable": True} + }, + }, + } + }, + } + + nodes = core.extract_nodes(obj, "person") + + # Should have: person, person.contact, person.contact (items), person.contact.email + assert len(nodes) == 4 + + email_node = next(n for n in nodes if n.json_path == "person.contact.email") + assert email_node.description == "Email address" + assert email_node.attributes["xQueryable"] is True + + def test_extract_with_empty_path_prefix(self): + """Test extracting with empty path prefix.""" + obj = {"type": "string", "description": "Root field"} + + nodes = core.extract_nodes(obj, "") + + assert len(nodes) == 1 + assert nodes[0].json_path == "" + + def test_extract_non_dict_returns_empty(self): + """Test that non-dict objects return empty list.""" + nodes = core.extract_nodes("not a dict", "path") + assert nodes == [] + + nodes = core.extract_nodes(123, "path") + assert nodes == [] + + nodes = core.extract_nodes(None, "path") + assert nodes == [] + + +class TestResolveOpenApiRoot: + """Tests for the resolve_openapi_root function.""" + + def test_resolve_from_components_schemas(self): + """Test resolving root from components.schemas.""" + doc = {"components": {"schemas": {"Person": {"type": "object"}, "Organization": {"type": "object"}}}} + + schema, root = core.resolve_openapi_root(doc, "Person") + + assert root == "Person" + assert schema == {"type": "object"} + + def test_resolve_from_definitions(self): + """Test resolving root from definitions (older OpenAPI/JSON Schema).""" + doc = {"definitions": {"User": {"type": "object"}, "Product": {"type": "object"}}} + + schema, root = core.resolve_openapi_root(doc, "User") + + assert root == "User" + assert schema == {"type": "object"} + + def test_resolve_components_takes_precedence(self): + """Test that components.schemas takes precedence over definitions.""" + doc = { + "components": {"schemas": {"Person": {"type": "object", "description": "from components"}}}, + "definitions": {"Person": {"type": "object", "description": "from definitions"}}, + } + + schema, root = core.resolve_openapi_root(doc, "Person") + + assert schema["description"] == "from components" + + def test_resolve_nonexistent_root_raises_error(self): + """Test that resolving a non-existent root raises ValueError.""" + doc = {"components": {"schemas": {"Person": {"type": "object"}}}} + + with pytest.raises(ValueError) as exc_info: + core.resolve_openapi_root(doc, "NonExistent") + + assert "Root schema 'NonExistent' not found" in str(exc_info.value) + assert "Person" in str(exc_info.value) + + def test_resolve_empty_doc_raises_error(self): + """Test that resolving from empty doc raises ValueError.""" + doc = {} + + with 
pytest.raises(ValueError) as exc_info: + core.resolve_openapi_root(doc, "Person") + + assert "Root schema 'Person' not found" in str(exc_info.value) + + +class TestLoadSchemaNodes: + """Tests for the load_schema_nodes function.""" + + def test_load_from_dict(self): + """Test loading schema nodes from a dictionary.""" + schema_dict = {"type": "object", "properties": {"name": {"type": "string", "description": "Person name"}}} + + nodes = core.load_schema_nodes(schema_dict) + + assert len(nodes) == 2 + assert any(n.json_path == "" for n in nodes) + assert any(n.json_path == "name" for n in nodes) + + def test_load_from_dict_with_root(self): + """Test loading schema nodes from dict with specific root.""" + schema_dict = { + "components": {"schemas": {"Person": {"type": "object", "properties": {"name": {"type": "string"}}}}} + } + + nodes = core.load_schema_nodes(schema_dict, root="Person") + + assert len(nodes) == 2 + # The root should be resolved and path should start with the camelCase root name + assert any(n.json_path == "person" for n in nodes) # Root node + assert any(n.json_path == "person.name" for n in nodes) # Property node + + @patch("builtins.open", new_callable=mock_open) + @patch("jsonref.load") + def test_load_from_file_path_string(self, mock_jsonref_load, mock_file_open): + """Test loading schema nodes from file path as string.""" + schema_data = {"type": "object", "properties": {"id": {"type": "string"}}} + mock_jsonref_load.return_value = schema_data + + nodes = core.load_schema_nodes("/path/to/schema.json") + + mock_file_open.assert_called_once_with("/path/to/schema.json", "r") + mock_jsonref_load.assert_called_once() + assert len(nodes) == 2 + + @patch("builtins.open", new_callable=mock_open) + @patch("jsonref.load") + def test_load_from_pathlib_path(self, mock_jsonref_load, mock_file_open): + """Test loading schema nodes from pathlib.Path.""" + schema_data = {"type": "string"} + mock_jsonref_load.return_value = schema_data + + path = Path("/path/to/schema.json") + nodes = core.load_schema_nodes(path) + + mock_file_open.assert_called_once_with(path, "r") + assert len(nodes) == 1 + + def test_load_invalid_type_raises_error(self): + """Test that invalid input type raises TypeError.""" + with pytest.raises(TypeError) as exc_info: + core.load_schema_nodes(123) # type: ignore + + assert "openapi must be a str, Path, or dict" in str(exc_info.value) + + @patch("lif.schema.core.jsonref") + def test_load_dict_calls_replace_refs(self, mock_jsonref): + """Test that loading from dict calls jsonref.replace_refs.""" + schema_dict = {"type": "string"} + mock_jsonref.JsonRef.replace_refs.return_value = schema_dict + + core.load_schema_nodes(schema_dict) + + mock_jsonref.JsonRef.replace_refs.assert_called_once_with(schema_dict) + + +class TestAttributeKeys: + """Test that ATTRIBUTE_KEYS contains expected values.""" + + def test_attribute_keys_content(self): + """Test that ATTRIBUTE_KEYS contains the expected keys.""" + expected_keys = ["x-queryable", "x-mutable", "DataType", "Required", "Array", "UniqueName", "enum", "type"] + + assert core.ATTRIBUTE_KEYS == expected_keys + + +class TestHelperFunctions: + """Test helper functions within extract_nodes.""" + + def test_is_array_detection(self): + """Test array detection logic.""" + # Test with direct array access since is_array is nested + obj_with_type_array = {"type": "array"} + nodes = core.extract_nodes(obj_with_type_array, "test") + assert nodes[0].attributes["array"] == "Yes" + + obj_with_items = {"items": {"type": "string"}} + nodes = 
core.extract_nodes(obj_with_items, "test") + assert nodes[0].attributes["array"] == "Yes" + + obj_without_array = {"type": "string"} + nodes = core.extract_nodes(obj_without_array, "test") + assert nodes[0].attributes["array"] == "No" + + def test_description_preference(self): + """Test that uppercase 'Description' is preferred over 'description'.""" + obj_with_both = { + "type": "string", + "Description": "Uppercase description", + "description": "Lowercase description", + } + + nodes = core.extract_nodes(obj_with_both, "test") + assert nodes[0].description == "Uppercase description" + + obj_with_lowercase_only = {"type": "string", "description": "Lowercase only"} + + nodes = core.extract_nodes(obj_with_lowercase_only, "test") + assert nodes[0].description == "Lowercase only" + + +class TestIntegrationWithTestSchema: + """Integration tests using the test schema.""" + + def test_extract_from_simple_person_schema(self): + """Test extracting from a simple person schema.""" + person_schema = { + "type": "object", + "properties": { + "Identifier": { + "type": "array", + "properties": { + "identifier": {"type": "string", "description": "A unique identifier", "x-queryable": True}, + "identifierType": {"type": "string", "x-queryable": True}, + }, + }, + "Name": { + "type": "array", + "properties": { + "firstName": {"type": "string", "x-mutable": False}, + "lastName": {"type": "string", "x-mutable": False}, + }, + }, + }, + } + + nodes = core.extract_nodes(person_schema, "Person") + + # Should extract: Person, Identifier, identifier, identifierType, Name, firstName, lastName + assert len(nodes) == 7 + + # Check specific nodes + person_node = next(n for n in nodes if n.json_path == "person") + assert person_node.attributes["type"] == "object" + assert person_node.attributes["branch"] is True + + id_node = next(n for n in nodes if n.json_path == "person.identifier.identifier") + assert id_node.attributes["xQueryable"] is True + assert id_node.description == "A unique identifier" + + first_name_node = next(n for n in nodes if n.json_path == "person.name.firstName") + assert first_name_node.attributes["xMutable"] is False + + +class TestWithRealTestSchema: + """Test with the actual test schema file.""" + + def test_load_test_schema_file(self): + """Test loading the actual test_openapi_schema.json file.""" + test_schema_path = Path(__file__).parent.parent.parent.parent / "data" / "test_openapi_schema.json" + + nodes = core.load_schema_nodes(test_schema_path, root="Person") + + # Should have nodes for Person and its properties + assert len(nodes) > 0 + + # Check that we have the expected main properties + paths = [n.json_path for n in nodes] + assert "person" in paths # Root Person object + assert any("identifier" in path for path in paths) # Identifier array + assert any("name" in path for path in paths) # Name array + assert any("proficiency" in path for path in paths) # Proficiency array + assert any("contact" in path for path in paths) # Contact array + + # Check that x-queryable and x-mutable attributes are preserved + queryable_nodes = [n for n in nodes if n.attributes.get("xQueryable")] + mutable_nodes = [n for n in nodes if "xMutable" in n.attributes] + + assert len(queryable_nodes) > 0 # Should have some queryable fields + assert len(mutable_nodes) > 0 # Should have some mutable fields + + +# Additional edge case tests +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_extract_with_tuple_validation_items(self): + """Test extracting with items as list (tuple 
validation).""" + obj = {"type": "array", "items": [{"type": "string"}, {"type": "number"}]} + + nodes = core.extract_nodes(obj, "tuple_array") + + # Should have the array node plus nodes for each item type + assert len(nodes) >= 1 + assert nodes[0].attributes["array"] == "Yes" + + def test_extract_with_missing_attributes(self): + """Test extraction when some attributes are missing.""" + obj = { + # No type, no description + "x-queryable": True + } + + nodes = core.extract_nodes(obj, "test") + + assert len(nodes) == 1 + node = nodes[0] + assert node.description == "" + assert node.attributes["xQueryable"] is True + assert node.attributes["type"] is None + + def test_camelcase_path_conversion(self): + """Test that paths are properly converted to camelCase.""" + obj = {"type": "string"} + + nodes = core.extract_nodes(obj, "some-complex_path.with-dashes") + + # The camelcase_path function should be called on the path + assert len(nodes) == 1 + # The actual conversion is handled by camelcase_path function in string_utils From 22b35cf1b71d3326e15d28185706eb4e0a2c9ea8 Mon Sep 17 00:00:00 2001 From: Patrick Yoho Date: Mon, 26 Jan 2026 13:09:39 -0600 Subject: [PATCH 3/5] Add 'dynamic_models' polylith component and tests --- components/lif/dynamic_models/__init__.py | 3 + components/lif/dynamic_models/core.py | 406 ++++++++ .../components/lif/dynamic_models/__init__.py | 0 .../lif/dynamic_models/test_core.py | 980 ++++++++++++++++++ test/data/test_openapi_schema.json | 71 ++ 5 files changed, 1460 insertions(+) create mode 100644 components/lif/dynamic_models/__init__.py create mode 100644 components/lif/dynamic_models/core.py create mode 100644 test/components/lif/dynamic_models/__init__.py create mode 100644 test/components/lif/dynamic_models/test_core.py create mode 100644 test/data/test_openapi_schema.json diff --git a/components/lif/dynamic_models/__init__.py b/components/lif/dynamic_models/__init__.py new file mode 100644 index 0000000..18ec27f --- /dev/null +++ b/components/lif/dynamic_models/__init__.py @@ -0,0 +1,3 @@ +from lif.dynamic_models import core + +__all__ = ["core"] diff --git a/components/lif/dynamic_models/core.py b/components/lif/dynamic_models/core.py new file mode 100644 index 0000000..4257832 --- /dev/null +++ b/components/lif/dynamic_models/core.py @@ -0,0 +1,406 @@ +""" +Reusable module to build nested Pydantic models dynamically +from a list of schema fields (endpoints in the schema tree). + +This module supports building Pydantic models at runtime based on a schema definition, +allowing flexible data validation for various use cases (e.g., query filters, mutations, or full models). +All fields in generated models are Optional and default to None. + +Functions: + build_dynamic_model: Creates nested Pydantic models from schema fields. + build_dynamic_models: Loads schema fields and builds all model variants (filters, mutations, full model). 
+""" + +import logging +import os +import re +from datetime import date, datetime +from enum import Enum +from typing import Annotated, Any, Dict, List, Optional, Tuple, Type, TypeVar, cast + +from pydantic import BaseModel, ConfigDict, Field + +from lif.datatypes.schema import SchemaField +from lif.schema.core import load_schema_nodes +from lif.string_utils.core import to_pascal_case + + +# ===== Environment and Global Config ===== + +ROOT_NODE: str | None = os.getenv("ROOT_NODE") +OPENAPI_SCHEMA_FILE: str | None = os.getenv("OPENAPI_SCHEMA_FILE") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +T = TypeVar("T") +ModelDict = Dict[str, Type[BaseModel]] + +#: Map XML-Schema datatypes ➜ native Python types +DATATYPE_MAP: dict[str, type[Any]] = { + "xsd:string": str, + "xsd:decimal": float, + "xsd:integer": int, + "xsd:boolean": bool, + "xsd:date": date, + "xsd:dateTime": datetime, + "xsd:datetime": datetime, + "xsd:anyURI": str, +} + +#: Singleton cache to prevent duplicate Enum definitions +_ENUM_CLASS_CACHE: Dict[str, Type[Enum]] = {} + +_TRUTHY = {"yes", "true", "1"} + + +# ===== Helpers ===== + + +def _is_yes(value: Any) -> bool: + """Check if value is a truthy 'yes' string.""" + return str(value).strip().lower() in _TRUTHY + + +def _to_enum_member(label: str) -> str: + """Convert a string label into a valid Enum member name.""" + key = re.sub(r"\W|^(?=\d)", "_", label.upper()) + return key + + +def make_enum(name: str, values: List[Any]) -> Type[Enum]: + """Return a cached Enum type for the given values.""" + cache_key = f"{name}_{'_'.join(map(str, sorted(values)))}" + if cache_key in _ENUM_CLASS_CACHE: # short-circuit hit + return _ENUM_CLASS_CACHE[cache_key] + + # Ensure value type is compatible with typing stubs and tooling + members: Dict[str, object] = {_to_enum_member(v): v for v in values} + # Use functional API with explicit module for better pickling/introspection + enum_cls = cast(Type[Enum], Enum(name, members, module=__name__)) + # enum_cls = type(name, (Enum,), {"__module__": __name__, **members}) # alternative + _ENUM_CLASS_CACHE[cache_key] = enum_cls + return enum_cls + + +# ===== Core builders ===== + + +def build_dynamic_model( + schema_fields: List[SchemaField], + *, + attribute_flag: str | None = "xQueryable", + model_doc: str = "Data model", + allow_extra: bool = False, + model_suffix: str = "", + all_optional: bool = True, +) -> ModelDict: + """Create nested Pydantic models for the given schema fields. + + Args: + schema_fields (List[SchemaField]): Leaf-level schema nodes. + attribute_flag (str | None): Only include fields whose attributes[flag] is truthy. + Use None to include all fields. + model_doc (str): Base docstring used for generated classes. + allow_extra (bool): If True, allows extra properties (additionalProperties: true). + model_suffix (str): Suffix for model class names (e.g., "Type", "Filter", "Mutation"). + all_optional (bool): If True, all fields will be Optional and default to None. + + Returns: + ModelDict: A mapping { model_name: Pydantic class }. + """ + + # Build a *tree* structure and a quick lookup for leaf nodes. 
+ tree: dict[str, Any] = {} + leaf_by_path: Dict[Tuple[str, ...], SchemaField] = {} + + for sf in schema_fields: + if attribute_flag and not sf.attributes.get(attribute_flag, False): + continue + + parts = sf.json_path.split(".") + node = tree + for i, part in enumerate(parts): + node = node.setdefault(part, {}) + if i == len(parts) - 1: + leaf_by_path[tuple(parts)] = sf + + if not tree: + return {} + + if len(tree) != 1: + raise ValueError(f"All {attribute_flag or 'selected'} json_path values must share a common root") + + # ===== Internal utilities ===== + + def _field_type(sf: SchemaField, name: str) -> Any: + """Translate a SchemaField to a typing type. + + Args: + sf (SchemaField): The schema field. + name (str): Name for enum classes. + + Returns: + Any: The Python type or Enum for the field. + """ + if "enum" in sf.attributes: + base: Any = make_enum(name.capitalize(), sf.attributes["enum"]) + else: + base = DATATYPE_MAP.get(sf.attributes.get("dataType", "xsd:string"), str) + + if _is_yes(sf.attributes.get("array", "No")): + base = List[base] + + if all_optional: + return Optional[base] + return base + + def _wrap_root(root_name: str, inner: Type[BaseModel], is_array: bool) -> Type[BaseModel]: + """Create a top-level Pydantic wrapper model with the specified root name. + + Args: + root_name (str): Root field name. + inner (Type[BaseModel]): The inner model class. + is_array (bool): If True, wraps the model in a List[]. + + Returns: + Type[BaseModel]: The generated wrapper model. + """ + field_type: Any = List[inner] if is_array else inner + annotations = {root_name: field_type} + doc = f"Top-level wrapper with `{root_name}` field." if root_name else "Top-level wrapper." + class_name = f"{root_name.capitalize()}{model_suffix}" + namespace = { + "__annotations__": annotations, + "__doc__": doc, + "model_config": ConfigDict(strict=False, extra="allow" if allow_extra else "forbid"), + } + # Only set default if all_optional + if all_optional: + namespace[root_name] = None + return type(class_name, (BaseModel,), namespace) + + # ===== Recursive Model Builder ===== + + models: ModelDict = {} + + def strip_root(parts): + """Remove the root node from the path.""" + if parts and parts[0].lower() == root_name.lower(): + return parts[1:] + return parts + + def _build_model(name: str, subtree: dict[str, Any], path: Tuple[str, ...]) -> Type[BaseModel] | None: + """Recursively build nested Pydantic models. + + Args: + name (str): Model class name. + subtree (dict[str, Any]): Subtree of the schema. + path (Tuple[str, ...]): Current path in the tree. + + Returns: + Type[BaseModel] | None: The constructed model or None if no fields. 
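Note: the class name is derived from `path` with the root segment stripped
+        and pascal-cased, then suffixed with model_suffix; e.g. ("person", "name")
+        with suffix "Filter" yields "NameFilter".
+        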
+ """ + sf = leaf_by_path.get(path) + stripped = strip_root(path) + if stripped: + class_name = to_pascal_case("".join(x for x in stripped)) + else: + class_name = to_pascal_case(root_name) + if model_suffix and not class_name.endswith(model_suffix): + class_name = f"{class_name}{model_suffix}" + # Guarantee non-empty unique_name for root + unique_name = (sf.attributes.get("uniqueName") if sf else None) or ".".join(stripped) or class_name + + annotations: Dict[str, Any] = {} + defaults: Dict[str, Any] = {} + + for key, child in subtree.items(): + child_path = path + (key,) + leaf_sf = leaf_by_path.get(child_path) + + if not child: # leaf + if leaf_sf: + if all_optional: + annotations[key] = Annotated[ + Optional[_field_type(leaf_sf, to_pascal_case(key))], Field(description=leaf_sf.description) + ] + defaults[key] = None + else: + annotations[key] = Annotated[ + _field_type(leaf_sf, to_pascal_case(key)), Field(description=leaf_sf.description) + ] + # No default: required + else: # branch ➜ nested model + is_array = _is_yes(leaf_sf.attributes.get("array", "No")) if leaf_sf else False + child_model = _build_model(key.capitalize(), child, child_path) + if child_model: + if all_optional: + annotations[key] = Optional[List[child_model]] if is_array else Optional[child_model] + defaults[key] = None + else: + annotations[key] = List[child_model] if is_array else child_model + # No default: required + + if not annotations: + return None + + desc = f"{model_doc} for `{class_name}`." + namespace = { + "__annotations__": annotations, + "__doc__": desc, + "__module__": __name__, + # Use supported ConfigDict keys; attach metadata via json_schema_extra + "model_config": ConfigDict( + # TODO (from before integration into this repo): Make sure the change from this works: title=class_name, description=desc, strict=False, extra="allow" if allow_extra else "forbid" + title=class_name, + strict=False, + extra="allow" if allow_extra else "forbid", + json_schema_extra={"description": desc}, + ), + } + # Only set defaults for all_optional + if all_optional: + namespace.update(defaults) + cls = type(class_name, (BaseModel,), namespace) + models[unique_name] = cls + return cls + + # ===== Build Root + Wrapper ===== + + # TODO (from before integration into this repo): This forces the wrapper structure. It should use OpenAPI schema + + root_name = next(iter(tree)) + + inner_model = _build_model(root_name.capitalize(), tree[root_name], (root_name,)) + if inner_model is None: # pragma: no cover – safeguard + return {} + + # Just force as array: + wrapper_model = _wrap_root(root_name, inner_model, True) + + models[root_name] = inner_model + models[f"{root_name}_wrapper"] = wrapper_model + return models + + +# ===== External Entrypoints ===== + + +def get_schema_fields() -> List[SchemaField]: + """ + Load and return the list of schema fields from the configured schema source. + + Returns: + List[SchemaField]: The schema fields for the root node. + """ + # Read environment at call time to allow tests/runtime overrides + openapi_file = os.getenv("OPENAPI_SCHEMA_FILE", OPENAPI_SCHEMA_FILE) + root_node = os.getenv("ROOT_NODE", ROOT_NODE) + if openapi_file is None: + raise ValueError("OPENAPI_SCHEMA_FILE environment variable is not set") + return load_schema_nodes(openapi_file, root_node) + + +def build_filter_models(fields: List[SchemaField], *, allow_extra: bool = True, all_optional: bool = True) -> ModelDict: + """ + Build filter models from schema fields. + + Args: + fields (List[SchemaField]): Schema fields. 
+ allow_extra (bool, optional): Allow extra properties in models. Default is True. + all_optional (bool, optional): Make all fields Optional and default to None. Default is True. + + Returns: + ModelDict: A mapping { model_name: Pydantic class } for filter models. + """ + return build_dynamic_model( + fields, + attribute_flag="xQueryable", + model_doc="Filter data model", + model_suffix="Filter", + allow_extra=allow_extra, + all_optional=all_optional, + ) + + +def build_mutation_models( + fields: List[SchemaField], *, allow_extra: bool = False, all_optional: bool = True +) -> ModelDict: + """ + Build mutation models from schema fields. + + Args: + fields (List[SchemaField]): Schema fields. + allow_extra (bool, optional): Allow extra properties in models. Default is False. + all_optional (bool, optional): Make all fields Optional and default to None. Default is True. + + Returns: + ModelDict: A mapping { model_name: Pydantic class } for mutation models. + """ + return build_dynamic_model( + fields, + attribute_flag="xMutable", + model_doc="Mutation data model", + model_suffix="Mutation", + allow_extra=allow_extra, + all_optional=all_optional, + ) + + +def build_full_models(fields: List[SchemaField], *, allow_extra: bool = False, all_optional: bool = False) -> ModelDict: + """ + Build full (strict) models from schema fields. + + Args: + fields (List[SchemaField]): Schema fields. + allow_extra (bool, optional): Allow extra properties in models. Default is False. + all_optional (bool, optional): Make all fields Optional and default to None. Default is False (fields required). + + Returns: + ModelDict: A mapping { model_name: Pydantic class } for full models. + """ + return build_dynamic_model( + fields, + attribute_flag=None, + model_doc="Full data model", + model_suffix="Type", + allow_extra=allow_extra, + all_optional=all_optional, + ) + + +def build_all_models( + *, + filter_allow_extra: bool = True, + filter_all_optional: bool = True, + mutation_allow_extra: bool = False, + mutation_all_optional: bool = True, + full_allow_extra: bool = False, + full_all_optional: bool = False, +) -> tuple[List[SchemaField], ModelDict, ModelDict, ModelDict]: + """ + Build all three model sets (filter, mutation, full) in one go, optionally customizing allow_extra and all_optional for each. + + Keyword Args: + filter_allow_extra (bool): Allow extra properties in filter models. Default True. + filter_all_optional (bool): Make all filter model fields Optional. Default True. + mutation_allow_extra (bool): Allow extra properties in mutation models. Default False. + mutation_all_optional (bool): Make all mutation model fields Optional. Default True. + full_allow_extra (bool): Allow extra properties in full models. Default False. + full_all_optional (bool): Make all full model fields Optional. Default False (fields required). + + Returns: + tuple: + - List[SchemaField]: Schema fields. + - ModelDict: Filter models. + - ModelDict: Mutation models. + - ModelDict: Full models. 
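
+    Example (a sketch; assumes OPENAPI_SCHEMA_FILE and ROOT_NODE are set and the
+    schema root resolves to "person"):
+
+        fields, filters, mutations, full = build_all_models()
+        person_filter_cls = filters["person"]  # class named "PersonFilter"
+    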
+ """ + fields = get_schema_fields() + return ( + fields, + build_filter_models(fields, allow_extra=filter_allow_extra, all_optional=filter_all_optional), + build_mutation_models(fields, allow_extra=mutation_allow_extra, all_optional=mutation_all_optional), + build_full_models(fields, allow_extra=full_allow_extra, all_optional=full_all_optional), + ) diff --git a/test/components/lif/dynamic_models/__init__.py b/test/components/lif/dynamic_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/components/lif/dynamic_models/test_core.py b/test/components/lif/dynamic_models/test_core.py new file mode 100644 index 0000000..bed0144 --- /dev/null +++ b/test/components/lif/dynamic_models/test_core.py @@ -0,0 +1,980 @@ +""" +Comprehensive unit tests for the dynamic_models.core module. + +Tests the dynamic Pydantic model building functionality including: +- Building nested models from schema fields +- Filter, mutation, and full model variants +- Enum handling +- Field type mapping +- Error conditions +""" + +import os +import pytest +from enum import Enum +from pathlib import Path +from unittest.mock import patch + +from pydantic import BaseModel, ValidationError + +from lif.dynamic_models import core +from lif.datatypes.schema import SchemaField + + +PATH_TO_TEST_SCHEMA = Path(__file__).parent.parent.parent.parent / "data" / "test_openapi_schema.json" + + +class TestHelperFunctions: + """Test helper functions in the core module.""" + + def test_is_yes(self): + """Test the _is_yes helper function.""" + assert core._is_yes("yes") is True + assert core._is_yes("YES") is True + assert core._is_yes("true") is True + assert core._is_yes("TRUE") is True + assert core._is_yes("1") is True + assert core._is_yes(" yes ") is True + + assert core._is_yes("no") is False + assert core._is_yes("false") is False + assert core._is_yes("0") is False + assert core._is_yes("") is False + assert core._is_yes("maybe") is False + + def test_to_enum_member(self): + """Test the _to_enum_member helper function.""" + assert core._to_enum_member("Valid Option") == "VALID_OPTION" + assert core._to_enum_member("123invalid") == "_123INVALID" + assert core._to_enum_member("special-chars!@#") == "SPECIAL_CHARS___" + assert core._to_enum_member("") == "" + + def test_make_enum(self): + """Test enum creation and caching.""" + # Create enum + enum_cls = core.make_enum("TestEnum", ["option1", "option2", "option3"]) + assert issubclass(enum_cls, Enum) + + # Test enum values + assert hasattr(enum_cls, "OPTION1") + assert hasattr(enum_cls, "OPTION2") + assert hasattr(enum_cls, "OPTION3") + + # Test enum value content + option1_member = getattr(enum_cls, "OPTION1") + assert option1_member.value == "option1" + + # Test caching - same values should return same class + enum_cls2 = core.make_enum("TestEnum", ["option1", "option2", "option3"]) + assert enum_cls is enum_cls2 + + # Different values should return different class + enum_cls3 = core.make_enum("TestEnum", ["option1", "option2", "option4"]) + assert enum_cls is not enum_cls3 + + +class TestFieldTypeMapping: + """Test field type mapping through model creation and validation.""" + + def test_string_field_in_model(self): + """Test string field type mapping and validation.""" + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid 
string assignment + instance = test_model(field="hello") + assert getattr(instance, "field") == "hello" + + # Test None assignment (optional) + instance_none = test_model(field=None) + assert getattr(instance_none, "field") is None + + # Test string conversion from valid types + instance_converted = test_model(field="123") + assert getattr(instance_converted, "field") == "123" + + # Test that invalid types raise ValidationError + with pytest.raises(ValidationError): + test_model(field=123) # Integer not allowed for strict string + + def test_integer_field_in_model(self): + """Test integer field type mapping and validation.""" + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:integer", "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid integer assignment + instance = test_model(field=42) + assert getattr(instance, "field") == 42 + + # Test string to integer conversion (if supported) + try: + instance_converted = test_model(field="100") + assert getattr(instance_converted, "field") == 100 + except ValidationError: + # If strict mode doesn't allow string conversion, that's also valid + pass + + # Test invalid conversion should raise ValidationError + with pytest.raises(ValidationError): + test_model(field="not_a_number") + + def test_boolean_field_in_model(self): + """Test boolean field type mapping and validation.""" + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:boolean", "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid boolean assignment + instance_true = test_model(field=True) + assert getattr(instance_true, "field") is True + + instance_false = test_model(field=False) + assert getattr(instance_false, "field") is False + + # Test truthy/falsy conversion + instance_truthy = test_model(field=1) + assert getattr(instance_truthy, "field") is True + + instance_falsy = test_model(field=0) + assert getattr(instance_falsy, "field") is False + + def test_date_field_in_model(self): + """Test date field type mapping and validation.""" + from datetime import date + + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:date", "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid date assignment + test_date = date(2023, 12, 25) + instance = test_model(field=test_date) + assert getattr(instance, "field") == test_date + + # Test string date parsing (if supported by Pydantic) + try: + instance_str = test_model(field="2023-12-25") + parsed_date = getattr(instance_str, "field") + assert isinstance(parsed_date, date) + except ValidationError: + # If strict parsing is not enabled, that's also valid behavior + pass + + def test_datetime_field_in_model(self): + """Test datetime field type mapping and validation.""" + from datetime import datetime + + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:dateTime", "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid datetime assignment + test_datetime = datetime(2023, 12, 25, 14, 30, 
0) + instance = test_model(field=test_datetime) + assert getattr(instance, "field") == test_datetime + + # Test ISO string parsing (if supported) + try: + instance_str = test_model(field="2023-12-25T14:30:00") + parsed_datetime = getattr(instance_str, "field") + assert isinstance(parsed_datetime, datetime) + except ValidationError: + # If strict parsing is not enabled, that's also valid behavior + pass + + def test_enum_field_in_model(self): + """Test enum field type mapping and validation.""" + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "enum": ["option1", "option2", "option3"], "array": "No"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid enum value assignment + instance = test_model(field="option1") + enum_value = getattr(instance, "field") + assert enum_value.value == "option1" + + # Test all enum options + for option in ["option1", "option2", "option3"]: + instance = test_model(field=option) + assert getattr(instance, "field").value == option + + # Test invalid enum value + with pytest.raises(ValidationError): + test_model(field="invalid_option") + + def test_array_field_in_model(self): + """Test array field type mapping and validation.""" + fields = [ + SchemaField( + json_path="test.field", + description="Test field", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "Yes"}, + ) + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Test valid list assignment + instance = test_model(field=["item1", "item2", "item3"]) + field_value = getattr(instance, "field") + assert field_value == ["item1", "item2", "item3"] + + # Test empty list + instance_empty = test_model(field=[]) + assert getattr(instance_empty, "field") == [] + + # Test None assignment (optional) + instance_none = test_model(field=None) + assert getattr(instance_none, "field") is None + + def test_field_descriptions_preserved(self): + """Test that field descriptions are preserved in the model.""" + fields = [ + SchemaField( + json_path="test.name", + description="A person's full name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="test.age", + description="Age in years", + attributes={"xQueryable": True, "dataType": "xsd:integer", "array": "No"}, + ), + ] + models = core.build_dynamic_model(fields) + assert "test" in models + + test_model = models["test"] + + # Check that the model has the expected fields + instance = test_model() + assert hasattr(instance, "name") + assert hasattr(instance, "age") + + # Verify field information is accessible through model schema + schema = test_model.model_json_schema() + if "properties" in schema: + if "name" in schema["properties"]: + assert "description" in schema["properties"]["name"] + if "age" in schema["properties"]: + assert "description" in schema["properties"]["age"] + + +class TestBuildDynamicModel: + """Test the main build_dynamic_model function.""" + + def test_empty_fields(self): + """Test with empty schema fields.""" + result = core.build_dynamic_model([]) + assert result == {} + + def test_no_matching_fields(self): + """Test with fields that don't match the attribute flag.""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xMutable": True}, # No xQueryable + ) + ] + result = core.build_dynamic_model(fields, attribute_flag="xQueryable") + 
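# No field carries the xQueryable flag, so the builder selects nothing
+        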
assert result == {} + + def test_multiple_roots_error(self): + """Test error when fields have different root paths.""" + fields = [ + SchemaField(json_path="person.name", description="Person name", attributes={"xQueryable": True}), + SchemaField( + json_path="organization.name", description="Organization name", attributes={"xQueryable": True} + ), + ] + with pytest.raises(ValueError, match="must share a common root"): + core.build_dynamic_model(fields) + + def test_simple_model_creation(self): + """Test creating a simple model with basic fields.""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.age", + description="Person age", + attributes={"xQueryable": True, "dataType": "xsd:integer", "array": "No"}, + ), + ] + + models = core.build_dynamic_model(fields) + + assert "person" in models + assert "person_wrapper" in models + + # Test inner model + person_model = models["person"] + assert issubclass(person_model, BaseModel) + + # Create instance + instance = person_model(name="John", age=30) + assert getattr(instance, "name") == "John" + assert getattr(instance, "age") == 30 + + # Test with None values (optional fields) + instance2 = person_model() + assert getattr(instance2, "name") is None + assert getattr(instance2, "age") is None + + def test_nested_model_creation(self): + """Test creating nested models.""" + fields = [ + SchemaField( + json_path="person.name.firstName", + description="First name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.name.lastName", + description="Last name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.age", + description="Person age", + attributes={"xQueryable": True, "dataType": "xsd:integer", "array": "No"}, + ), + ] + + models = core.build_dynamic_model(fields) + + assert "person" in models + # Check that nested models were created (key may vary based on implementation) + assert len(models) > 1 # Should have more than just the root model + + # Test nested structure + person_model = models["person"] + instance = person_model() + + # Should have name and age fields + assert hasattr(instance, "name") + assert hasattr(instance, "age") + + def test_array_fields(self): + """Test handling of array fields.""" + fields = [ + SchemaField( + json_path="person.identifier", + description="Identifiers", + attributes={"xQueryable": True, "array": "Yes", "branch": True}, + ), + SchemaField( + json_path="person.identifier.value", + description="Identifier value", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + ] + + models = core.build_dynamic_model(fields) + + # Should create models + assert len(models) > 0 + + def test_enum_fields(self): + """Test handling of enum fields.""" + fields = [ + SchemaField( + json_path="person.gender", + description="Gender", + attributes={"xQueryable": True, "enum": ["Male", "Female", "Other"], "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + + assert "person" in models + person_model = models["person"] + + # Create instance with enum value + instance = person_model(gender="Male") + # Enum fields return enum members, so we need to check the value + gender_value = getattr(instance, "gender") + assert gender_value.value == "Male" + + def test_model_suffix(self): + """Test model suffix 
application.""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields, model_suffix="Filter") + + # Model names should include suffix + person_model = models["person"] + assert person_model.__name__.endswith("Filter") + + def test_all_optional_false(self): + """Test with all_optional=False.""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields, all_optional=False) + + person_model = models["person"] + + # Should require fields when all_optional=False + with pytest.raises(ValidationError): + person_model() # Missing required field + + def test_allow_extra_true(self): + """Test with allow_extra=True.""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields, allow_extra=True) + + person_model = models["person"] + + # Should allow extra fields + instance = person_model(name="John", extra_field="value") + assert getattr(instance, "extra_field") == "value" + + def test_attribute_flag_none(self): + """Test with attribute_flag=None (include all fields).""" + fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"dataType": "xsd:string", "array": "No"}, # No xQueryable + ) + ] + + models = core.build_dynamic_model(fields, attribute_flag=None) + + assert "person" in models + person_model = models["person"] + assert hasattr(person_model(), "name") + + +class TestBuilderFunctions: + """Test the convenience builder functions.""" + + @patch("lif.dynamic_models.core.get_schema_fields") + def test_build_filter_models(self, mock_get_schema_fields): + """Test build_filter_models function.""" + mock_fields = [ + # SchemaField( + # json_path="person.name", + # description="Person name", + # attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + # ) + # SchemaField( + # json_path='person', + # description='', + # attributes={'xMutable': False, 'type': 'object', 'array': 'No', 'branch': True, 'leaf': False}, + # py_field_name='' + # ) + SchemaField( + json_path="person.identifier.identifier", + description='A number and/or alphanumeric code used to uniquely identify the entity. 
Use "missing at will", "ad-hoc" and "not applicable" for missing data to avoid skewed outcomes.', + attributes={ + "xQueryable": True, + "xMutable": False, + "dataType": "xsd:string", + "required": "Yes", + "array": "No", + "uniqueName": "Common.Identifier.identifier", + "type": "xsd:string", + "branch": False, + "leaf": True, + }, + py_field_name="", + ) + ] + mock_get_schema_fields.return_value = mock_fields + + models = core.build_filter_models(mock_fields) + + print(models) + + assert len(models) > 0 + # Should have Filter suffix + for model_cls in models.values(): + if hasattr(model_cls, "__name__"): + assert "Filter" in model_cls.__name__ + + @patch("lif.dynamic_models.core.get_schema_fields") + def test_build_mutation_models(self, mock_get_schema_fields): + """Test build_mutation_models function.""" + mock_fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xMutable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + mock_get_schema_fields.return_value = mock_fields + + models = core.build_mutation_models(mock_fields) + + assert len(models) > 0 + # Should have Mutation suffix + for model_cls in models.values(): + if hasattr(model_cls, "__name__"): + assert "Mutation" in model_cls.__name__ + + @patch("lif.dynamic_models.core.get_schema_fields") + def test_build_full_models(self, mock_get_schema_fields): + """Test build_full_models function.""" + mock_fields = [ + SchemaField( + json_path="person.name", description="Person name", attributes={"dataType": "xsd:string", "array": "No"} + ) + ] + mock_get_schema_fields.return_value = mock_fields + + models = core.build_full_models(mock_fields) + + assert len(models) > 0 + # Should have Type suffix + for model_cls in models.values(): + if hasattr(model_cls, "__name__"): + assert "Type" in model_cls.__name__ + + +class TestGetSchemaFields: + """Test the get_schema_fields function.""" + + @patch.dict(os.environ, {"OPENAPI_SCHEMA_FILE": str(PATH_TO_TEST_SCHEMA), "ROOT_NODE": "Person"}) + @patch("lif.dynamic_models.core.load_schema_nodes") + def test_get_schema_fields_with_env_vars(self, mock_load_schema_nodes): + """Test get_schema_fields with environment variables.""" + mock_fields = [SchemaField("person.name", "Name", {})] + mock_load_schema_nodes.return_value = mock_fields + + result = core.get_schema_fields() + + # Verify the function was called with correct arguments + mock_load_schema_nodes.assert_called_once() + call_args = mock_load_schema_nodes.call_args[0] + assert str(call_args[0]) == str(PATH_TO_TEST_SCHEMA) + assert call_args[1] == "Person" + assert result == mock_fields + + def test_get_schema_fields_no_env_var(self): + """Test get_schema_fields without OPENAPI_SCHEMA_FILE.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="OPENAPI_SCHEMA_FILE environment variable is not set"): + core.get_schema_fields() + + +class TestBuildAllModels: + """Test the build_all_models function.""" + + @patch("lif.dynamic_models.core.get_schema_fields") + def test_build_all_models(self, mock_get_schema_fields): + """Test build_all_models function.""" + mock_fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "xMutable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + mock_get_schema_fields.return_value = mock_fields + + fields, filter_models, mutation_models, full_models = core.build_all_models() + + assert fields == mock_fields + assert len(filter_models) > 0 + assert 
len(mutation_models) > 0 + assert len(full_models) > 0 + + @patch("lif.dynamic_models.core.get_schema_fields") + def test_build_all_models_custom_options(self, mock_get_schema_fields): + """Test build_all_models with custom options.""" + mock_fields = [ + SchemaField( + json_path="person.name", + description="Person name", + attributes={"xQueryable": True, "xMutable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + mock_get_schema_fields.return_value = mock_fields + + fields, filter_models, mutation_models, full_models = core.build_all_models( + filter_allow_extra=False, + filter_all_optional=False, + mutation_allow_extra=True, + mutation_all_optional=False, + full_allow_extra=True, + full_all_optional=True, + ) + + assert fields == mock_fields + assert len(filter_models) > 0 + assert len(mutation_models) > 0 + assert len(full_models) > 0 + + +class TestRealWorldIntegration: + """Integration tests with the actual test schema file.""" + + def test_with_test_schema_file(self): + """Test with the actual test_openapi_schema.json file.""" + with patch.dict(os.environ, {"OPENAPI_SCHEMA_FILE": str(PATH_TO_TEST_SCHEMA), "ROOT_NODE": "Person"}): + fields = core.get_schema_fields() + + # Should have loaded fields from the test schema + assert len(fields) > 0 + + # Test building models + filter_models = core.build_filter_models(fields) + assert len(filter_models) > 0 + + # Test that we can create instances + if "person" in filter_models: + person_model = filter_models["person"] + instance = person_model() + assert instance is not None + + def test_end_to_end_model_creation(self): + """End-to-end test of model creation and usage.""" + # Create test fields manually to simulate real usage + fields = [ + SchemaField( + json_path="person.identifier.identifier", + description="A number and/or alphanumeric code used to uniquely identify the entity", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.identifier.identifierType", + description="The types of sources of identifiers used to uniquely identify the entity", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.name.firstName", + description="The first name of a person or individual", + attributes={"xMutable": False, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.name.lastName", + description="The last name of a person or individual", + attributes={"xMutable": False, "dataType": "xsd:string", "array": "No"}, + ), + ] + + # Build filter models (only xQueryable fields) + filter_models = core.build_filter_models(fields) + + # Should have models + assert len(filter_models) > 0 + + # Build full models (all fields) - use all_optional=True for this test + full_models = core.build_full_models(fields, all_optional=True) + + # Should have models + assert len(full_models) > 0 + + # Test that models work correctly + if "person" in full_models: + person_model = full_models["person"] + + # Create instance (should work with all_optional=True) + instance = person_model() + assert instance is not None + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_invalid_enum_values(self): + """Test handling of invalid enum values.""" + fields = [ + SchemaField( + json_path="person.status", + description="Status", + attributes={"xQueryable": True, "enum": ["Active", "Inactive"], "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + person_model = 
models["person"] + + # Valid enum value should work + instance = person_model(status="Active") + assert getattr(instance, "status").value == "Active" + + # Invalid enum value should raise validation error + with pytest.raises(ValidationError): + person_model(status="Invalid") + + def test_complex_nested_structure(self): + """Test deeply nested field structures.""" + fields = [ + SchemaField( + json_path="person.contact.address.street", + description="Street address", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.contact.email.address", + description="Email address", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + ] + + models = core.build_dynamic_model(fields) + + # Should create nested models + assert "person" in models + assert len(models) > 1 + + # Test that nested structure works + person_model = models["person"] + instance = person_model() + + # Should have contact field + assert hasattr(instance, "contact") + + def test_special_characters_in_enum(self): + """Test enum with special characters.""" + fields = [ + SchemaField( + json_path="person.type", + description="Person type", + attributes={"xQueryable": True, "enum": ["Type-A", "Type B", "Type@C"], "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + person_model = models["person"] + + # Should handle special characters in enum values + instance = person_model(type="Type-A") + assert getattr(instance, "type").value == "Type-A" + + # Test all special character variants + for type_val in ["Type-A", "Type B", "Type@C"]: + instance = person_model(type=type_val) + assert getattr(instance, "type").value == type_val + + def test_empty_description(self): + """Test fields with empty descriptions.""" + fields = [ + SchemaField( + json_path="person.field", + description="", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + assert len(models) > 0 + + # Should still create working model + person_model = models["person"] + instance = person_model(field="test") + assert getattr(instance, "field") == "test" + + def test_missing_attributes(self): + """Test fields with minimal attributes.""" + fields = [ + SchemaField( + json_path="person.field", + description="Basic field", + attributes={"xQueryable": True}, # Minimal attributes + ) + ] + + models = core.build_dynamic_model(fields) + assert len(models) > 0 + + # Should create model with default string type + person_model = models["person"] + instance = person_model(field="test") + assert getattr(instance, "field") == "test" + + def test_array_with_nested_objects(self): + """Test arrays containing nested objects.""" + fields = [ + SchemaField( + json_path="person.addresses", + description="List of addresses", + attributes={"xQueryable": True, "array": "Yes", "branch": True}, + ), + SchemaField( + json_path="person.addresses.street", + description="Street address", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + SchemaField( + json_path="person.addresses.city", + description="City", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ), + ] + + models = core.build_dynamic_model(fields) + assert len(models) > 0 + + # Should create model with nested array structure + person_model = models["person"] + instance = person_model() + assert hasattr(instance, "addresses") + + def test_very_long_field_names(self): + """Test handling of very long 
field names.""" + long_name = "a" * 100 # 100 character field name + fields = [ + SchemaField( + json_path=f"person.{long_name}", + description="Field with very long name", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + person_model = models["person"] + + # Should handle long field names gracefully + instance = person_model(**{long_name: "test_value"}) + assert getattr(instance, long_name) == "test_value" + + def test_unicode_in_descriptions(self): + """Test handling of unicode characters in descriptions.""" + fields = [ + SchemaField( + json_path="person.name", + description="Имя пользователя (User name in Cyrillic) 用户名 (Chinese)", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ] + + models = core.build_dynamic_model(fields) + assert len(models) > 0 + + person_model = models["person"] + instance = person_model(name="Test") + assert getattr(instance, "name") == "Test" + + +class TestPerformance: + """Test performance aspects of model generation.""" + + def test_large_schema_handling(self): + """Test model generation with a large number of fields.""" + import time + + # Generate 100 fields + fields = [] + for i in range(100): + fields.append( + SchemaField( + json_path=f"person.field_{i}", + description=f"Test field number {i}", + attributes={"xQueryable": True, "dataType": "xsd:string", "array": "No"}, + ) + ) + + start_time = time.time() + models = core.build_dynamic_model(fields) + end_time = time.time() + + # Should complete in reasonable time (less than 5 seconds) + assert (end_time - start_time) < 5.0 + assert len(models) > 0 + + # Test that the resulting model works + person_model = models["person"] + test_data = {f"field_{i}": f"value_{i}" for i in range(10)} # Test first 10 fields + instance = person_model(**test_data) + + for i in range(10): + assert getattr(instance, f"field_{i}") == f"value_{i}" + + def test_enum_caching_efficiency(self): + """Test that enum caching works efficiently.""" + # Create the same enum multiple times + enum_values = ["A", "B", "C"] + + enum1 = core.make_enum("TestEnum", enum_values) + enum2 = core.make_enum("TestEnum", enum_values) + enum3 = core.make_enum("TestEnum", enum_values) + + # Should return the same class (cached) + assert enum1 is enum2 is enum3 + + # Different values should create different enums + enum4 = core.make_enum("TestEnum", ["A", "B", "D"]) + assert enum1 is not enum4 + + +def test_sample(): + """Legacy test to maintain compatibility.""" + assert core is not None diff --git a/test/data/test_openapi_schema.json b/test/data/test_openapi_schema.json new file mode 100644 index 0000000..249c9c8 --- /dev/null +++ b/test/data/test_openapi_schema.json @@ -0,0 +1,71 @@ +{ + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "paths": {}, + "components": { + "schemas": { + "Person": { + "type": "object", + "properties": { + "Identifier": { + "type": "array", + "properties": { + "identifier": { + "type": "string", + "description": "A number and/or alphanumeric code used to uniquely identify the entity", + "x-queryable": true + }, + "identifierType": { + "type": "string", + "description": "The types of sources of identifiers used to uniquely identify the entity", + "x-queryable": true + } + } + }, + "Name": { + "type": "array", + "properties": { + "firstName": { + "type": "string", + "description": "The first name of a person or individual", + "x-mutable": false + }, + "lastName": { + "type": "string", + 
"description": "The last name of a person or individual", + "x-mutable": false + } + } + }, + "Proficiency": { + "type": "array", + "properties": { + "name": { + "type": "string", + "description": "Name of the proficiency" + }, + "description": { + "type": "string", + "description": "Description of the proficiency" + } + } + }, + "Contact": { + "type": "array", + "properties": { + "Email": { + "type": "array", + "properties": { + "emailAddress": { + "type": "string", + "description": "The electronic mail address of an individual or person" + } + } + } + } + } + } + } + } + } +} From 3b4bf36399bd9f06af361d68d880ce27ba1f6ce8 Mon Sep 17 00:00:00 2001 From: Patrick Yoho Date: Mon, 26 Jan 2026 13:10:35 -0600 Subject: [PATCH 4/5] Register new polylith components in pyproject.toml --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bef3a42..45bdf63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,8 @@ dev-dependencies = ["polylith-cli"] "components/lif/translator" = "lif/translator" "components/lif/mdr_utils" = "lif/mdr_utils" "components/lif/mdr_dto" = "lif/mdr_dto" +"components/lif/schema" = "lif/schema" +"components/lif/dynamic_models" = "lif/dynamic_models" [tool.ruff] line-length = 120 From 26dc6621bc5adc63b464095875f45024aa9e0b76 Mon Sep 17 00:00:00 2001 From: Patrick Yoho Date: Wed, 28 Jan 2026 10:44:17 -0600 Subject: [PATCH 5/5] WIP - attempt to decompose a&m work into logical components. Adds graphql, openapi_schema, and utils polylith components. --- bases/lif/api_graphql/core.py | 76 ++-- components/lif/graphql/__init__.py | 3 + components/lif/graphql/core.py | 108 ++++++ components/lif/graphql/schema_factory.py | 341 +++++++++++++++++ components/lif/graphql/type_registry.py | 125 +++++++ components/lif/graphql/utils.py | 87 +++++ components/lif/openapi_schema/__init__.py | 3 + components/lif/openapi_schema/core.py | 71 ++++ components/lif/utils/__init__.py | 3 + components/lif/utils/core.py | 84 +++++ components/lif/utils/strings.py | 123 +++++++ components/lif/utils/validation.py | 66 ++++ projects/lif_graphql_api/pyproject.toml | 6 + pyproject.toml | 3 + test/components/lif/graphql/__init__.py | 0 test/components/lif/graphql/test_core.py | 5 + test/components/lif/graphql/test_schema.py | 48 +++ .../lif/graphql/test_type_registry.py | 346 ++++++++++++++++++ test/components/lif/graphql/test_utils.py | 9 + .../components/lif/openapi_schema/__init__.py | 0 .../lif/openapi_schema/test_core.py | 53 +++ test/components/lif/utils/__init__.py | 0 test/components/lif/utils/test_core.py | 46 +++ test/components/lif/utils/test_strings.py | 94 +++++ test/components/lif/utils/test_validation.py | 117 ++++++ 25 files changed, 1788 insertions(+), 29 deletions(-) create mode 100644 components/lif/graphql/__init__.py create mode 100644 components/lif/graphql/core.py create mode 100644 components/lif/graphql/schema_factory.py create mode 100644 components/lif/graphql/type_registry.py create mode 100644 components/lif/graphql/utils.py create mode 100644 components/lif/openapi_schema/__init__.py create mode 100644 components/lif/openapi_schema/core.py create mode 100644 components/lif/utils/__init__.py create mode 100644 components/lif/utils/core.py create mode 100644 components/lif/utils/strings.py create mode 100644 components/lif/utils/validation.py create mode 100644 test/components/lif/graphql/__init__.py create mode 100644 test/components/lif/graphql/test_core.py create mode 100644 test/components/lif/graphql/test_schema.py 
create mode 100644 test/components/lif/graphql/test_type_registry.py create mode 100644 test/components/lif/graphql/test_utils.py create mode 100644 test/components/lif/openapi_schema/__init__.py create mode 100644 test/components/lif/openapi_schema/test_core.py create mode 100644 test/components/lif/utils/__init__.py create mode 100644 test/components/lif/utils/test_core.py create mode 100644 test/components/lif/utils/test_strings.py create mode 100644 test/components/lif/utils/test_validation.py diff --git a/bases/lif/api_graphql/core.py b/bases/lif/api_graphql/core.py index 9ab499d..a505c0b 100644 --- a/bases/lif/api_graphql/core.py +++ b/bases/lif/api_graphql/core.py @@ -1,48 +1,66 @@ """ -ASGI application generator for OpenAPI-to-GraphQL. +ASGI application generator for LIF GraphQL. -Converts OpenAPI schema definitions to a Strawberry GraphQL API dynamically. -Generates Python types, input filters, enums, and root query objects from OpenAPI JSON schemas. +This base wires environment configuration, constructs the HTTP backend, +builds the GraphQL schema via the `lif.graphql.schema_factory`, and mounts +the GraphQL endpoint using Strawberry's FastAPI router. """ +from __future__ import annotations + import os -from contextlib import asynccontextmanager -from typing import AsyncGenerator +from pathlib import Path + from fastapi import FastAPI from strawberry.fastapi import GraphQLRouter -from lif.logging import get_logger -from lif.mdr_client import get_openapi_lif_data_model -from lif.openapi_to_graphql.core import generate_graphql_schema +from lif.graphql.core import HttpBackend +from lif.graphql.schema_factory import build_schema +from lif.logging.core import get_logger +from lif.openapi_schema.core import get_schema_fields +from lif.utils.core import get_required_env_var +from lif.utils.validation import is_truthy logger = get_logger(__name__) +# Environment variable validation at import time +LIF_QUERY_PLANNER_URL = get_required_env_var("LIF_QUERY_PLANNER_URL") + -LIF_QUERY_PLANNER_URL = os.getenv("LIF_QUERY_PLANNER_URL", "http://localhost:8002") -LIF_GRAPHQL_ROOT_TYPE_NAME = os.getenv("LIF_GRAPHQL_ROOT_TYPE_NAME", "Person") +def create_app() -> FastAPI: + # Ensure process cwd is the project root so Rich/Strawberry can compute relative paths + try: + project_root = Path(__file__).resolve().parents[4] # lif-main + os.chdir(project_root) + logger.debug(f"Set working directory to project root: {project_root}") + except Exception: + logger.debug("Could not change working directory to project root", exc_info=True) -logger.info(f"LIF_QUERY_PLANNER_URL: {LIF_QUERY_PLANNER_URL}") -logger.info(f"LIF_GRAPHQL_ROOT_TYPE_NAME: {LIF_GRAPHQL_ROOT_TYPE_NAME}") -logger.info(f"LIF_MDR_API_URL: {os.getenv('LIF_MDR_API_URL')}") + root_type = os.getenv("LIF_GRAPHQL_ROOT_TYPE_NAME", "Person") + # TODO: The graphql api should only contact the query planner and not the query cache directly + # HTTP-only backend configuration -async def fetch_dynamic_graphql_schema(openapi: dict): - return await generate_graphql_schema( - openapi=openapi, - root_type_name=LIF_GRAPHQL_ROOT_TYPE_NAME, - query_planner_query_url=LIF_QUERY_PLANNER_URL.rstrip("/") + "/query", - query_planner_update_url=LIF_QUERY_PLANNER_URL.rstrip("/") + "/update", - ) + # Back-compat fallback from planner URL if provided + base_url = LIF_QUERY_PLANNER_URL.rstrip("/") + query_url = f"{base_url}/query" + update_url = f"{base_url}/update" + logger.info(f"GraphQL root type: {root_type}") + logger.info(f"Query URL: {query_url}") + 
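+    # Both endpoints derive from the single LIF_QUERY_PLANNER_URL setting; the
+    # planner is assumed to expose /query and /update routes (see HttpBackend).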
logger.info(f"Update URL: {update_url}") -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - openapi = await get_openapi_lif_data_model() - schema = await fetch_dynamic_graphql_schema(openapi=openapi) - logger.info("GraphQL schema successfully created") - app.include_router(GraphQLRouter(schema, prefix="/graphql")) - logger.info("GraphQL router successfully created and included in FastAPI app") - yield + backend = HttpBackend(query_url=query_url, update_url=update_url) + fields = get_schema_fields() + schema = build_schema(schema_fields=fields, root_node=root_type, backend=backend) + # Optional schema artifact dumping for tooling + if is_truthy(os.getenv("LIF_GRAPHQL_DUMP_SCHEMA")): + out_dir = Path(__file__).parent / "_artifacts" + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "schema.graphql").write_text(schema.as_str(), encoding="utf-8") + logger.info(f"Wrote schema to {out_dir / 'schema.graphql'}") -app = FastAPI(lifespan=lifespan) + app = FastAPI() + app.include_router(GraphQLRouter(schema), prefix="/graphql") + return app diff --git a/components/lif/graphql/__init__.py b/components/lif/graphql/__init__.py new file mode 100644 index 0000000..4b8c0c4 --- /dev/null +++ b/components/lif/graphql/__init__.py @@ -0,0 +1,3 @@ +from lif.graphql import core + +__all__ = ["core"] diff --git a/components/lif/graphql/core.py b/components/lif/graphql/core.py new file mode 100644 index 0000000..f8e297b --- /dev/null +++ b/components/lif/graphql/core.py @@ -0,0 +1,108 @@ +""" +GraphQL backend and resolver helpers. + +This module defines a minimal Backend protocol and an HTTP implementation +used by the GraphQL schema factory. It handles payload shapes expected by +the LIF query cache service and returns plain dicts for persons. +""" + +# TODO: Figure out if we actually want the annotations directive for this module +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Protocol + +import httpx + +from lif.logging.core import get_logger +from lif.graphql.utils import get_fragments_from_info, get_selected_field_paths, serialize_for_json + +logger = get_logger(__name__) + + +# TODO: Figure out if this backend protocol belongs in this component +class Backend(Protocol): + async def query(self, filter_dict: Optional[dict], selected_fields: List[str]) -> List[dict]: ... + + async def update( + self, filter_dict: Optional[dict], input_dict: Optional[dict], selected_fields: List[str] + ) -> List[dict]: ... + + +class HttpBackend: + """HTTP client for the Query Planner endpoints. 
+
+    Expects two endpoints:
+    - query_url: POST {"filter": {"person": <filter dict> | None}, "selected_fields": [..]}
+    - update_url: POST {"updatePerson": {"filter": {"person": ...}, "input": {...}, "selected_fields": [..]}}
+    """
+
+    def __init__(self, query_url: Optional[str], update_url: Optional[str]) -> None:
+        self.query_url = query_url
+        self.update_url = update_url
+
+    async def _post_json(self, url: Optional[str], payload: dict) -> httpx.Response:
+        if not url:
+            raise RuntimeError("Backend URL is not configured")
+        async with httpx.AsyncClient() as client:
+            logger.debug("POST %s payload=%s", url, payload)
+            resp = await client.post(url, json=payload)
+            resp.raise_for_status()
+            return resp
+
+    # TODO: Potentially remove
+    async def query(self, filter_dict: Optional[dict], selected_fields: List[str]) -> List[dict]:
+        wrapped_filter = {"person": filter_dict} if filter_dict is not None else None
+        payload = {"filter": wrapped_filter, "selected_fields": selected_fields}
+        resp = await self._post_json(self.query_url, payload)
+        raw = resp.json()
+        logger.debug("Query response: %s", raw)
+        persons: List[dict] = []
+        for item in raw or []:
+            if isinstance(item, dict):
+                val = item.get("person")
+                if isinstance(val, list):
+                    persons.extend(val)
+        return persons
+
+    # TODO: Potentially remove
+    async def update(
+        self, filter_dict: Optional[dict], input_dict: Optional[dict], selected_fields: List[str]
+    ) -> List[dict]:
+        wrapped_filter = {"person": filter_dict} if filter_dict is not None else None
+        payload = {"updatePerson": {"filter": wrapped_filter, "input": input_dict, "selected_fields": selected_fields}}
+        resp = await self._post_json(self.update_url, payload)
+        data = resp.json() or {}
+        logger.debug("Update response: %s", data)
+        persons = data.get("person", [])
+        if persons and not isinstance(persons, list):
+            persons = [persons]
+        return persons or []
+
+
+def extract_selected_fields(info: Any) -> List[str]:
+    """Compute dotted JSON paths actually requested by the client."""
+    fragments = get_fragments_from_info(info)
+    return get_selected_field_paths(info.field_nodes, fragments, [info.field_name])
+
+
+def pydantic_inputs_to_dict(
+    filter_input: Optional[Any] = None, mutation_input: Optional[Any] = None
+) -> tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
+    """Serialize Strawberry-Pydantic inputs into JSON-safe dicts."""
+
+    def to_dict(obj: Any) -> Optional[Dict[str, Any]]:
+        if obj is None:
+            return None
+        if hasattr(obj, "to_pydantic"):
+            return serialize_for_json(obj.to_pydantic())
+        # Pydantic models, plain dicts, and anything else all go through the
+        # same recursive serializer.
+        return serialize_for_json(obj)
+
+    filter_dict = to_dict(filter_input)
+    input_dict = to_dict(mutation_input)
+    return filter_dict, input_dict
diff --git a/components/lif/graphql/schema_factory.py b/components/lif/graphql/schema_factory.py
new file mode 100644
index 0000000..d439f90
--- /dev/null
+++ b/components/lif/graphql/schema_factory.py
@@ -0,0 +1,341 @@
+"""
+Schema factory for building the Strawberry GraphQL schema from dynamic models.
+
+This module is side-effect free: given a root node name and a backend,
+it builds all Pydantic models, Strawberry types/inputs, and returns a Schema.
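+
+Example (a sketch; "my_backend" stands in for any object satisfying
+lif.graphql.core.Backend):
+
+    from lif.openapi_schema.core import get_schema_fields
+
+    fields = get_schema_fields()
+    schema = build_schema(schema_fields=fields, root_node="Person", backend=my_backend)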
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Dict, List, Optional, Any, Protocol +import os + +import strawberry +from strawberry.experimental.pydantic import input as pyd_input +from strawberry.experimental.pydantic import type as pyd_type +from strawberry.types import Info +from strawberry.experimental.pydantic.exceptions import UnregisteredTypeException +from strawberry.scalars import JSON + +from lif.datatypes.schema import SchemaField +from lif.dynamic_models.core import build_filter_models, build_full_models, build_mutation_models +from lif.openapi_schema.core import get_schema_fields +from lif.graphql.core import Backend, extract_selected_fields, pydantic_inputs_to_dict +from lif.graphql.utils import to_pascal_case_from_str, unique_type_name +from lif.utils.validation import to_bool + + + + + +class StrawberryType(Protocol): + """Mixin for Strawberry types from Pydantic models.""" + + @classmethod + def from_pydantic(cls, data): ... + + +RootType: StrawberryType | None = None +FilterInput: StrawberryType | None = None +MutationInput: StrawberryType | None = None + + +def build_schema(*, schema_fields: List[SchemaField], root_node: str, backend: Backend) -> strawberry.Schema: + """Build a Strawberry GraphQL schema for the given root node and backend. + + Args: + schema_fields (List[SchemaField]): List of schema fields from OpenAPI/JSON Schema. + root_node (str): The root node name (e.g., "Person"). + backend (Backend): The backend implementation for queries and mutations. + Returns: + strawberry.Schema: The constructed GraphQL schema. + """ + # Build dynamic Pydantic models + filter_models = build_filter_models(schema_fields, allow_extra=False, all_optional=False) + mutation_models = build_mutation_models(schema_fields, allow_extra=False, all_optional=True) + full_models = build_full_models(schema_fields, allow_extra=False, all_optional=True) + + root_name = to_pascal_case_from_str(root_node) + + # Build dynamic Strawberry types and inputs + strawberry_types: Dict[str, type[StrawberryType]] = {} + strawberry_inputs: Dict[str, type[StrawberryType]] = {} + + # ===== Output (Query) Types ===== + for name, model in full_models.items(): + if "wrapper" in name.lower(): + continue # Skip wrapper types + tname = to_pascal_case_from_str(name) + if tname not in strawberry_types: + strawberry_types[tname] = pyd_type(model=model, all_fields=True)( + type(tname, (), {}) + ) + + # ===== Filter (Input) Types ===== + for name, model in filter_models.items(): + if "wrapper" in name.lower(): + continue # Skip wrapper types + iname = unique_type_name(name, "FilterInput", root_node) + iname = to_pascal_case_from_str(iname) + if iname not in strawberry_inputs: + strawberry_inputs[iname] = pyd_input(model=model, all_fields=True)( + type(iname, (), {}) + ) + + # ===== Mutation (Input) Types ===== + for name, model in mutation_models.items(): + if "wrapper" in name.lower(): + continue # Skip wrapper types + iname = unique_type_name(name, "MutationInput", root_node) + iname = to_pascal_case_from_str(iname) + if iname not in strawberry_inputs: + strawberry_inputs[iname] = pyd_input(model=model, all_fields=True)( + type(iname, (), {}) + ) + + + global RootType, FilterInput, MutationInput + RootType = strawberry_types[root_name] + FilterInput = strawberry_inputs[ + to_pascal_case_from_str(unique_type_name(root_name, "FilterInput", root_node)) + ] + MutationInput = strawberry_inputs[ + to_pascal_case_from_str(unique_type_name(root_name, "MutationInput", root_node)) + 
]
+
+    # Prefer the exact root key, falling back to the lower-camel-case key
+    # inserted by the dynamic model builder (e.g., "person")
+    FullModel = full_models.get(root_node) or full_models[root_node[:1].lower() + root_node[1:]]
+
+    # ===== Strawberry Query & Mutation Roots =====
+    @strawberry.type
+    class Query:
+        @strawberry.field
+        async def persons(
+            self,
+            info: Info,
+            filter: FilterInput  # type: ignore[type-arg]
+        ) -> List[RootType]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return [RootType.from_pydantic(FullModel(**p)) for p in persons]
+
+        @strawberry.field
+        async def person(
+            self,
+            info: Info,
+            filter: FilterInput  # type: ignore[type-arg]
+        ) -> Optional[RootType]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return RootType.from_pydantic(FullModel(**persons[0])) if persons else None
+
+    # ===== Strawberry Schema =====
+
+    schema = strawberry.Schema(
+        query=Query,
+        # mutation=Mutation,
+        types=[*strawberry_types.values(), *strawberry_inputs.values()],
+    )
+
+    return schema
+
+
+def build_schema_old(*, root_node: str, backend: Backend) -> strawberry.Schema:
+    """Legacy builder retained during the WIP decomposition; superseded by build_schema()."""
+    # Allow a minimal mode for tests or environments where Pydantic-to-Strawberry registration is undesirable
+    minimal_mode = to_bool(os.getenv("LIF_GRAPHQL_DISABLE_PYDANTIC_TYPES", "false"))
+
+    # Build dynamic Pydantic models
+    fields = get_schema_fields()
+    filter_models = build_filter_models(fields, allow_extra=False, all_optional=False)
+    mutation_models = build_mutation_models(fields, allow_extra=False, all_optional=True)
+    full_models = build_full_models(fields, allow_extra=False, all_optional=True)
+
+    root_name = to_pascal_case_from_str(root_node)
+    lc_root = root_node[:1].lower() + root_node[1:]
+    # Prefer the lower-camel-case root key inserted by the dynamic model builder (e.g., "person")
+    FullModel = full_models.get(lc_root)
+    if FullModel is None:
+        # Fallback to the suffixed key (e.g., "PersonType") if present
+        FullModel = full_models.get(f"{root_name}Type")
+    if FullModel is None:
+        raise KeyError(f"Cannot locate FullModel for root '{root_node}' in dynamic models")
+
+    FilterModel = filter_models[f"{root_name}Filter"]
+    MutationModel = mutation_models[f"{root_name}Mutation"]
+    WrapperModel = full_models[f"{root_name.lower()}_wrapper"]
+
+    # Generate Strawberry output types
+    StrawberryTypes: Dict[str, type[StrawberryType]] = {}
+    StrawberryInputs: Dict[str, type[StrawberryType]] = {}
+
+    if not minimal_mode:
+        # ===== Output (Query) Types =====
+
+        for name, model in full_models.items():
+            if "wrapper" in name.lower():
+                continue  # Skip wrapper types
+            tname = to_pascal_case_from_str(name)
+            if tname not in StrawberryTypes:
+                StrawberryTypes[tname] = pyd_type(model=model, all_fields=True)(
+                    type(tname, (), {})
+                )
+
+        # ===== Filter (Input) Types =====
+
+        for name, model in filter_models.items():
+            if "wrapper" in name.lower():
+                continue  # Skip wrapper types
+            iname = unique_type_name(name, "FilterInput", root_node)
+            iname = to_pascal_case_from_str(iname)
+            if iname not in StrawberryInputs:
+                StrawberryInputs[iname] = pyd_input(model=model, all_fields=True)(
+                    type(iname, (), {})
+                )
+
+        # ===== Mutation (Input) Types =====
+
+        for name, model in mutation_models.items():
+            if "wrapper" in name.lower():
+                continue  # Skip wrapper types
+            iname = unique_type_name(name, "MutationInput", root_node)
+            iname = to_pascal_case_from_str(iname)
+            if iname not in StrawberryInputs:
+                StrawberryInputs[iname] = pyd_input(model=model, all_fields=True)(
+                    type(iname, (), {})
+                )
+
+    RootType = StrawberryTypes[root_name]
+    FilterInput = StrawberryInputs[
+        to_pascal_case_from_str(unique_type_name(root_name, "FilterInput", root_node))
+    ]
+    MutationInput = StrawberryInputs[
+        to_pascal_case_from_str(unique_type_name(root_name, "MutationInput", root_node))
+    ]
+
+    # Resolvers
+    @strawberry.type
+    class QueryJSON:
+        @strawberry.field
+        async def persons(self, info: Info, filter: JSON) -> List[JSON]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return persons
+
+        @strawberry.field
+        async def person(self, info: Info, filter: JSON) -> Optional[JSON]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return persons[0] if persons else None
+
+    @strawberry.type
+    class QueryTyped:
+        @strawberry.field
+        async def persons(
+            self,
+            info: Info,
+            filter: FilterInput  # type: ignore[type-arg]
+        ) -> List[RootType]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return [RootType.from_pydantic(FullModel(**p)) for p in persons]
+
+        @strawberry.field
+        async def person(
+            self,
+            info: Info,
+            filter: FilterInput  # type: ignore[type-arg]
+        ) -> Optional[RootType]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, _ = pydantic_inputs_to_dict(filter_input=filter)
+            persons = await backend.query(filter_dict, selected)
+            return RootType.from_pydantic(FullModel(**persons[0])) if persons else None
+
+    Query = QueryJSON if minimal_mode else QueryTyped
+
+    @strawberry.type
+    class MutationJSON:
+        @strawberry.mutation
+        async def update_person(
+            self, info: Info, filter: JSON, input_: JSON) -> List[JSON]:  # type: ignore[valid-type]
+            selected = extract_selected_fields(info)
+            filter_dict, input_dict = pydantic_inputs_to_dict(filter_input=filter, mutation_input=input_)
+            persons = await backend.update(filter_dict, input_dict, selected)
+            return persons
+
+    @strawberry.type
+    class MutationTyped:
+        @strawberry.mutation
+        async def update_person(
+            self,
+            info: Info,
+            filter: FilterInput,
+            input_: MutationInput,
+        ) -> List[RootType]:  # runtime type is (FilterInput, MutationInput) -> List[RootType]
+            selected = extract_selected_fields(info)
+            filter_dict, input_dict = pydantic_inputs_to_dict(filter_input=filter, mutation_input=input_)
+            persons = await backend.update(filter_dict, input_dict, selected)
+            return [RootType.from_pydantic(FullModel(**p)) for p in persons]
+
+    Mutation = MutationJSON if minimal_mode else MutationTyped
+
+    types_arg = [*StrawberryTypes.values(), *StrawberryInputs.values()] if not minimal_mode else []
+    schema = strawberry.Schema(query=Query, mutation=Mutation, types=types_arg)
+    return schema
diff --git a/components/lif/graphql/type_registry.py b/components/lif/graphql/type_registry.py
new file mode 100644
index 0000000..db2b026
--- /dev/null
+++ b/components/lif/graphql/type_registry.py
@@ -0,0 +1,125 @@
+"""
+Centralized type registry for GraphQL schema generation.
+Manages dynamic creation and caching of Pydantic and Strawberry types.
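+
+Typical usage (sketch; assumes the schema environment consumed by
+lif.openapi_schema.core.get_schema_fields is configured):
+
+    from lif.graphql.type_registry import type_registry
+
+    type_registry.initialize_types("Person")
+    models = type_registry.get_models_for_root("Person")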
+""" + +from typing import Dict, Type, Optional +# import strawberry +from strawberry.experimental.pydantic import input as pyd_input +from strawberry.experimental.pydantic import type as pyd_type + +from lif.dynamic_models.core import build_filter_models, build_full_models, build_mutation_models +from lif.openapi_schema.core import get_schema_fields +from lif.graphql.utils import to_pascal_case_from_str, unique_type_name + +class TypeRegistry: + """Singleton registry for managing dynamically generated types.""" + + _instance: Optional['TypeRegistry'] = None + _initialized: bool = False + + def __new__(cls) -> 'TypeRegistry': + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + self.strawberry_types: Dict[str, Type] = {} + self.strawberry_inputs: Dict[str, Type] = {} + self.pydantic_models: Dict[str, Type] = {} + self._initialized = True + + def initialize_types(self, root_node: str, minimal_mode: bool = False): + """Initialize all dynamic types for the given root node.""" + if self.strawberry_types: # Already initialized + return + + # Build dynamic Pydantic models + fields = get_schema_fields() + filter_models = build_filter_models(fields, allow_extra=False, all_optional=False) + mutation_models = build_mutation_models(fields, allow_extra=False, all_optional=True) + full_models = build_full_models(fields, allow_extra=False, all_optional=True) + + # Store Pydantic models + self.pydantic_models.update({ + 'filter_models': filter_models, + 'mutation_models': mutation_models, + 'full_models': full_models + }) + + if not minimal_mode: + # Generate Strawberry types + self._create_strawberry_types(full_models, filter_models, mutation_models, root_node) + + def _create_strawberry_types(self, full_models, filter_models, mutation_models, root_node): + """Create Strawberry types from Pydantic models.""" + # Output types + for name, model in full_models.items(): + if "wrapper" in name.lower(): + continue + tname = to_pascal_case_from_str(name) + if tname not in self.strawberry_types: + self.strawberry_types[tname] = pyd_type(model=model, all_fields=True)( + type(tname, (), {}) + ) + + # Filter input types + for name, model in filter_models.items(): + if "wrapper" in name.lower(): + continue + iname = unique_type_name(name, "FilterInput", root_node) + iname = to_pascal_case_from_str(iname) + if iname not in self.strawberry_inputs: + self.strawberry_inputs[iname] = pyd_input(model=model, all_fields=True)( + type(iname, (), {}) + ) + + # Mutation input types + for name, model in mutation_models.items(): + if "wrapper" in name.lower(): + continue + iname = unique_type_name(name, "MutationInput", root_node) + iname = to_pascal_case_from_str(iname) + if iname not in self.strawberry_inputs: + self.strawberry_inputs[iname] = pyd_input(model=model, all_fields=True)( + type(iname, (), {}) + ) + + def get_models_for_root(self, root_node: str): + """Get the main models for a root node.""" + root_name = to_pascal_case_from_str(root_node) + lc_root = root_node[:1].lower() + root_node[1:] + + full_models = self.pydantic_models['full_models'] + filter_models = self.pydantic_models['filter_models'] + mutation_models = self.pydantic_models['mutation_models'] + + # Get FullModel + full_model = full_models.get(lc_root) or full_models.get(f"{root_name}Type") + if full_model is None: + raise KeyError(f"Cannot locate FullModel for root '{root_node}' in dynamic models") + + return { + 'FullModel': full_model, + 'FilterModel': 
filter_models[f"{root_name}Filter"], + 'MutationModel': mutation_models[f"{root_name}Mutation"], + 'WrapperModel': full_models[f"{root_name.lower()}_wrapper"] + } + + def get_strawberry_types_for_root(self, root_node: str): + """Get Strawberry types for a root node.""" + root_name = to_pascal_case_from_str(root_node) + + return { + 'RootType': self.strawberry_types[root_name], + 'FilterInput': self.strawberry_inputs[ + to_pascal_case_from_str(unique_type_name(root_name, "FilterInput", root_node)) + ], + 'MutationInput': self.strawberry_inputs[ + to_pascal_case_from_str(unique_type_name(root_name, "MutationInput", root_node)) + ] + } + +# Global registry instance +type_registry = TypeRegistry() \ No newline at end of file diff --git a/components/lif/graphql/utils.py b/components/lif/graphql/utils.py new file mode 100644 index 0000000..fd9513d --- /dev/null +++ b/components/lif/graphql/utils.py @@ -0,0 +1,87 @@ +import dataclasses +import datetime as _dt +import enum +import re +from typing import Any, Dict, List, Optional, Union + +import strawberry + + + +# TODO: Replace with lif.string_utils function +def to_pascal_case_from_str(s: str) -> str: + """Convert a string to PascalCase.""" + parts = re.findall( + r"[A-Za-z][a-z]*|[A-Z]+(?![a-z])", s.replace("-", " ").replace("_", " ") + ) + return "".join(p.capitalize() for p in parts if p) + + +# TODO: Replace with lif.string_utils function +def to_pascal_case(parts): + """Join list of strings as PascalCase.""" + return "".join(p.capitalize() for p in parts if p) + + +# TODO: Update docstring +def unique_type_name(name: str, suffix: str, root_node: str) -> str: + """Make a unique PascalCase type name with suffix.""" + parts = re.split(r"\W+", name) + root = to_pascal_case_from_str(root_node) + if parts and to_pascal_case_from_str(parts[0]) == root: + parts = parts[1:] + if not parts: + parts = [root] + return f"{to_pascal_case(parts)}{suffix[:1].upper() + suffix[1:]}" + + +def _iso(obj: Union[_dt.date, _dt.datetime]) -> str: + """Return date/datetime as ISO-8601 string.""" + return obj.isoformat() + + +def serialize_for_json(obj: Any) -> Any: + """Convert obj into JSON-serialisable structures.""" + if hasattr(obj, "model_dump"): + return serialize_for_json(obj.model_dump(exclude_none=True)) + if dataclasses.is_dataclass(obj): + return serialize_for_json(dataclasses.asdict(obj)) + if isinstance(obj, enum.Enum): + return obj.value + if isinstance(obj, (_dt.datetime, _dt.date)): + return _iso(obj) + if isinstance(obj, (list, tuple)): + return [serialize_for_json(x) for x in obj] + if isinstance(obj, dict): + return {str(k): serialize_for_json(v) for k, v in obj.items()} + return obj + + +def get_fragments_from_info(info: strawberry.types.Info) -> Dict[str, Any]: + """Return fragments map from GraphQL info object.""" + return getattr(info, "fragments", {}) + + +def get_selected_field_paths( + field_nodes: List[Any], + fragments: Dict[str, Any], + prefix: Optional[Union[List[str], str]] = None, +) -> List[str]: + """Get dotted JSON paths actually requested in the selection.""" + prefix = [] if prefix is None else ([prefix] if isinstance(prefix, str) else prefix) + paths: set[str] = set() + + for node in field_nodes: + sel_set = getattr(node, "selection_set", None) + if not sel_set: + continue + for selection in sel_set.selections: + match selection.kind: + case "field": + path = prefix + [selection.name.value] + paths.add(".".join(path)) + paths.update(get_selected_field_paths([selection], fragments, path)) + case "fragment_spread": + frag = 
fragments[selection.name.value]
+                    paths.update(get_selected_field_paths([frag], fragments, prefix))
+    return list(paths)
\ No newline at end of file
diff --git a/components/lif/openapi_schema/__init__.py b/components/lif/openapi_schema/__init__.py
new file mode 100644
index 0000000..ae06857
--- /dev/null
+++ b/components/lif/openapi_schema/__init__.py
@@ -0,0 +1,3 @@
+from lif.openapi_schema import core
+
+__all__ = ["core"]
diff --git a/components/lif/openapi_schema/core.py b/components/lif/openapi_schema/core.py
new file mode 100644
index 0000000..3496fcd
--- /dev/null
+++ b/components/lif/openapi_schema/core.py
@@ -0,0 +1,71 @@
+"""OpenAPI schema provider.
+
+This component is responsible for sourcing the OpenAPI document that defines the
+dynamic model fields. It integrates with the MDR client as the primary source,
+and falls back to a local file (primarily for development and tests).
+
+Public API:
+- get_schema_fields() -> List[SchemaField]
+
+Configuration (env vars):
+- LIF_OPENAPI_SCHEMA_PATH: Optional file path fallback to read JSON.
+- LIF_OPENAPI_ROOT: Optional root schema name (e.g., "Person").
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from typing import List
+
+from lif.datatypes.schema import SchemaField
+from lif.logging.core import get_logger
+from lif.schema.core import load_schema_nodes
+import lif.mdr_client.core as mdr_core
+
+logger = get_logger(__name__)
+
+
+def _load_from_file(path: str | Path) -> dict:
+    p = Path(path)
+    if not p.exists():
+        raise FileNotFoundError(f"OpenAPI schema file not found: {p}")
+    with p.open("r", encoding="utf-8") as f:
+        try:
+            return json.load(f)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in schema file {p}: {e}") from e
+
+
+def get_schema_fields() -> List[SchemaField]:
+    """Return SchemaField list sourced from MDR client or file based on env configuration.
+
+    Resolution order:
+    1) Attempt to load OpenAPI from MDR via get_openapi_lif_data_model()
+    2) Else if LIF_OPENAPI_SCHEMA_PATH is set, load from file
+    """
+    # Root resolution (optional); can be provided from env
+    root = os.getenv("LIF_OPENAPI_ROOT") or os.getenv("ROOT_NODE") or "Person"
+    logger.debug("Root node for schema fields: %s", root)
+
+    # Preferred: MDR client provider
+    try:
+        logger.info("Loading OpenAPI schema via MDR client")
+        doc = mdr_core.get_openapi_lif_data_model()
+        return load_schema_nodes(doc, root)
+    except Exception as e:  # noqa: BLE001 - surface clean fallback path
+        logger.warning("MDR client schema load failed: %s. Falling back...", e)
+
+    # Fallback: file path provider
+    path = os.getenv("LIF_OPENAPI_SCHEMA_PATH")
+    if path:
+        logger.info("Loading OpenAPI schema from file: %s", path)
+        doc = _load_from_file(path)
+        return load_schema_nodes(doc, root)
+
+    raise RuntimeError(
+        "No OpenAPI schema source configured. Provide MDR client implementation, "
+        "or set LIF_OPENAPI_SCHEMA_PATH."
+ ) diff --git a/components/lif/utils/__init__.py b/components/lif/utils/__init__.py new file mode 100644 index 0000000..9192b16 --- /dev/null +++ b/components/lif/utils/__init__.py @@ -0,0 +1,3 @@ +from lif.utils import core + +__all__ = ["core"] diff --git a/components/lif/utils/core.py b/components/lif/utils/core.py new file mode 100644 index 0000000..d84e4f3 --- /dev/null +++ b/components/lif/utils/core.py @@ -0,0 +1,84 @@ +"""Core utilities for the LIF system.""" + +import os +import sys +from typing import Dict, List + +from lif.exceptions.core import MissingEnvironmentVariableException + + +def check_required_env_vars( + required_vars: List[str], + raise_exception: bool = True, + logger=None +) -> Dict[str, str]: + """Check that all required environment variables are set. + + Args: + required_vars: List of environment variable names to check + raise_exception: If True, raise MissingEnvironmentVariableException. + If False, log critical error and exit with sys.exit(1) + logger: Optional logger instance for error reporting + + Returns: + Dict[str, str]: Dictionary mapping env var names to their values + + Raises: + MissingEnvironmentVariableException: If raise_exception=True and vars are missing + + Examples: + >>> env_vars = check_required_env_vars(["DATABASE_URL", "API_KEY"]) + >>> database_url = env_vars["DATABASE_URL"] + + >>> # For server applications that should exit on missing config + >>> check_required_env_vars(["CONFIG_FILE"], raise_exception=False, logger=logger) + """ + missing = [var for var in required_vars if not os.getenv(var)] + + if missing: + error_msg = f"Missing required environment variables: {', '.join(missing)}" + + if raise_exception: + # For libraries and components that should raise exceptions + if len(missing) == 1: + raise MissingEnvironmentVariableException(missing[0]) + else: + # For multiple missing vars, use a generic message + from lif.exceptions.core import LIFException + raise LIFException(error_msg) + else: + # For standalone applications that should exit + if logger: + logger.critical(error_msg) + else: + print(f"CRITICAL: {error_msg}", file=sys.stderr) + sys.exit(1) + + # Return the validated environment variables + result = {} + for var in required_vars: + value = os.getenv(var) + if value is not None: # We know this is true since we checked above + result[var] = value + return result + + +def get_required_env_var(var_name: str) -> str: + """Get a single required environment variable. + + Args: + var_name: Name of the environment variable + + Returns: + str: The environment variable value + + Raises: + MissingEnvironmentVariableException: If the variable is not set + + Examples: + >>> database_url = get_required_env_var("DATABASE_URL") + """ + value = os.getenv(var_name) + if not value: + raise MissingEnvironmentVariableException(var_name) + return value \ No newline at end of file diff --git a/components/lif/utils/strings.py b/components/lif/utils/strings.py new file mode 100644 index 0000000..c7ceb75 --- /dev/null +++ b/components/lif/utils/strings.py @@ -0,0 +1,123 @@ +import re +from typing import Any +from datetime import date, datetime + +from lif.logging import get_logger + +logger = get_logger(__name__) + + +def safe_identifier(name: str) -> str: + """Convert any string to a safe Python identifier (snake_case, no special chars). + + Args: + name (str): The input name string. + + Returns: + str: A valid Python identifier. 
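+
+    Examples (doctest-style, derived from the implementation below):
+
+        >>> safe_identifier("First Name")
+        'first_name'
+        >>> safe_identifier("CamelCase")
+        'camel_case'
+        >>> safe_identifier("123abc")
+        '_123abc'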
+ """ + # Replace any non-word characters with spaces to isolate tokens + cleaned = re.sub(r"[^0-9A-Za-z]+", " ", name).strip() + + tokens: list[str] = [] + for chunk in cleaned.split(): + # If the chunk contains camel-case boundaries or acronym patterns, split accordingly + if re.search(r"[a-z][A-Z]", chunk) or re.search(r"[A-Z]{2,}[a-z]", chunk): + # Use standard camelCase -> snake_case splitting for such chunks + s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", chunk) + s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1) + tokens.append(s2.lower()) + else: + tokens.append(chunk.lower()) + + # Join with single underscores + result = "_".join(t for t in tokens if t) + # If result starts with a digit, prefix underscore + if result and result[0].isdigit(): + result = f"_{result}" + return result + + +def to_pascal_case(*parts: str) -> str: + """Convert strings or parts to PascalCase. + + Accepts any number of string parts. Each part can contain spaces, dashes, + underscores, or mixed case; all will be tokenized and joined as PascalCase. + + Args: + *parts (str): Parts to be converted. + + Returns: + str: PascalCase string. + """ + tokens: list[str] = [] + for p in parts: + if not p: + continue + s = p.replace("-", " ").replace("_", " ") + # Capture standard words, all-caps acronyms, and numbers + tokens += re.findall(r"[A-Za-z][a-z]*|[A-Z]+(?![a-z])|\d+", s) + return "".join(t[:1].upper() + t[1:] for t in tokens) + + +def to_snake_case(name: str) -> str: + """Converts CamelCase or PascalCase to snake_case.""" + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def to_camel_case(s: str) -> str: + """Convert string to lower camelCase.""" + s = re.sub(r"([_\-\s]+)([a-zA-Z])", lambda m: m.group(2).upper(), s) + if not s: + return s + return s[0].lower() + s[1:] + + +def camelcase_path(path: str) -> str: + """Convert a dotted path string to camelCase segments. + + Args: + path: The dot-separated path. + + Returns: + The camelCase path. 
+ """ + return ".".join([to_camel_case(p) for p in path.split(".")]) + + +def dict_keys_to_snake(obj: Any) -> Any: + """Recursively converts dict keys to snake_case.""" + if isinstance(obj, list): + return [dict_keys_to_snake(item) for item in obj] + if isinstance(obj, dict): + return {to_snake_case(k): dict_keys_to_snake(v) for k, v in obj.items()} + return obj + + +def dict_keys_to_camel(obj: Any) -> Any: + """Recursively converts dict keys to camelCase.""" + if isinstance(obj, list): + return [dict_keys_to_camel(item) for item in obj] + if isinstance(obj, dict): + return {to_camel_case(k): dict_keys_to_camel(v) for k, v in obj.items()} + return obj + + +def convert_dates_to_strings(obj: Any) -> Any: + """Recursively converts dict date and datetime to strings.""" + if isinstance(obj, dict): + return {k: convert_dates_to_strings(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_dates_to_strings(item) for item in obj] + elif isinstance(obj, (date, datetime)): + return obj.isoformat() + else: + return obj + + +def to_value_enum_name(label: str) -> str: + """Convert a string label to a valid Python Enum name.""" + key = str(label).upper() + key = re.sub(r"\W|^(?=\d)", "_", key) + return key diff --git a/components/lif/utils/validation.py b/components/lif/utils/validation.py new file mode 100644 index 0000000..c7d15a4 --- /dev/null +++ b/components/lif/utils/validation.py @@ -0,0 +1,66 @@ +"""Data validation utilities for the LIF system.""" + +from typing import Any, Set + +# Truthy value constants +TRUTHY_VALUES: Set[str] = {"1", "true", "yes", "on", "y"} +FALSY_VALUES: Set[str] = {"0", "false", "no", "off", "n"} + + +def is_truthy(value: Any) -> bool: + """Check if a value represents a truthy state. + + Args: + value: The value to evaluate (string, bool, int, etc.) + + Returns: + bool: True if the value is considered truthy + + Examples: + >>> is_truthy("yes") + True + >>> is_truthy("1") + True + >>> is_truthy("false") + False + >>> is_truthy(None) + False + """ + if value is None: + return False + if isinstance(value, bool): + return value + return str(value).strip().lower() in TRUTHY_VALUES + + +def is_falsy(value: Any) -> bool: + """Check if a value represents a falsy state. + + Args: + value: The value to evaluate + + Returns: + bool: True if the value is considered falsy + """ + if value is None: + return True + if isinstance(value, bool): + return not value + return str(value).strip().lower() in FALSY_VALUES + + +def to_bool(value: Any, default: bool = False) -> bool: + """Convert a value to boolean with explicit truthy/falsy evaluation. 
+ + Args: + value: The value to convert + default: Default value if the input is ambiguous + + Returns: + bool: The boolean representation + """ + if is_truthy(value): + return True + if is_falsy(value): + return False + return default diff --git a/projects/lif_graphql_api/pyproject.toml b/projects/lif_graphql_api/pyproject.toml index a93aeb3..fe321f3 100644 --- a/projects/lif_graphql_api/pyproject.toml +++ b/projects/lif_graphql_api/pyproject.toml @@ -29,3 +29,9 @@ packages = ["lif"] "../../components/lif/string_utils" = "lif/string_utils" "../../components/lif/logging" = "lif/logging" "../../components/lif/exceptions" = "lif/exceptions" +"../../components/lif/utils" = "lif/utils" +"../../components/lif/datatypes" = "lif/datatypes" +"../../components/lif/openapi_schema" = "lif/openapi_schema" +"../../components/lif/dynamic_models" = "lif/dynamic_models" +"../../components/lif/schema" = "lif/schema" +"../../components/lif/graphql" = "lif/graphql" diff --git a/pyproject.toml b/pyproject.toml index 45bdf63..f9c2001 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,10 @@ dev-dependencies = ["polylith-cli"] "components/lif/mdr_utils" = "lif/mdr_utils" "components/lif/mdr_dto" = "lif/mdr_dto" "components/lif/schema" = "lif/schema" +"components/lif/utils" = "lif/utils" "components/lif/dynamic_models" = "lif/dynamic_models" +"components/lif/openapi_schema" = "lif/openapi_schema" +"components/lif/graphql" = "lif/graphql" [tool.ruff] line-length = 120 diff --git a/test/components/lif/graphql/__init__.py b/test/components/lif/graphql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/components/lif/graphql/test_core.py b/test/components/lif/graphql/test_core.py new file mode 100644 index 0000000..d6f6e06 --- /dev/null +++ b/test/components/lif/graphql/test_core.py @@ -0,0 +1,5 @@ +from lif.graphql import core + + +def test_sample(): + assert core is not None diff --git a/test/components/lif/graphql/test_schema.py b/test/components/lif/graphql/test_schema.py new file mode 100644 index 0000000..759397f --- /dev/null +++ b/test/components/lif/graphql/test_schema.py @@ -0,0 +1,48 @@ +import asyncio + +from lif.graphql.schema_factory import build_schema + + +class FakeBackend: + async def query(self, filter_dict, selected_fields): + return [] + + async def update(self, filter_dict, input_dict, selected_fields): + return [] + + +def test_build_schema_smoke(monkeypatch): + # Configure dynamic models to read our small test schema + monkeypatch.setenv("OPENAPI_SCHEMA_FILE", "test_openapi_schema.json") + monkeypatch.setenv("ROOT_NODE", "Person") + + schema = build_schema(root_node="Person", backend=FakeBackend()) + s = schema.as_str() + assert "type Query" in s + assert "type Mutation" in s + + +def test_execute_query_smoke(monkeypatch): + monkeypatch.setenv("OPENAPI_SCHEMA_FILE", "test_openapi_schema.json") + monkeypatch.setenv("ROOT_NODE", "Person") + + class CapturingBackend(FakeBackend): + def __init__(self): + self.seen = [] + + async def query(self, filter_dict, selected_fields): + self.seen.append((filter_dict, tuple(selected_fields))) + return [] + + backend = CapturingBackend() + schema = build_schema(root_node="Person", backend=backend) + + query = """ + query Q($f: PersonFilterInput!) 
+        persons(filter: $f) { person { identifier name } }
+    }
+    """
+    variables = {"f": {"person": {"identifier": "123"}}}
+    result = asyncio.run(schema.execute(query, variable_values=variables))
+    assert result.errors is None
+    assert backend.seen, "backend was not called"
diff --git a/test/components/lif/graphql/test_type_registry.py b/test/components/lif/graphql/test_type_registry.py
new file mode 100644
index 0000000..ec3e7ef
--- /dev/null
+++ b/test/components/lif/graphql/test_type_registry.py
@@ -0,0 +1,351 @@
+"""
+Unit tests for the TypeRegistry singleton class.
+"""
+
+import pytest
+from unittest.mock import Mock, patch
+
+from lif.graphql.type_registry import TypeRegistry, type_registry
+
+
+class TestTypeRegistry:
+    """Test the TypeRegistry singleton class."""
+
+    def setup_method(self):
+        """Reset the singleton state before each test."""
+        # Reset the singleton instance for clean tests
+        TypeRegistry._instance = None
+        TypeRegistry._initialized = False
+        # Also reset the global instance
+        type_registry.__dict__.clear()
+        type_registry._initialized = False
+
+    def test_singleton_behavior(self):
+        """Test that TypeRegistry behaves as a singleton."""
+        # Create first instance
+        registry1 = TypeRegistry()
+
+        # Create second instance
+        registry2 = TypeRegistry()
+
+        # They should be the same object
+        assert registry1 is registry2
+        assert id(registry1) == id(registry2)
+
+    def test_global_instance_is_singleton(self):
+        """Test that the global type_registry instance follows singleton pattern."""
+        # Create a new instance
+        new_registry = TypeRegistry()
+
+        # The global instance should be the same
+        assert new_registry is type_registry
+
+    def test_initialization_only_happens_once(self):
+        """Test that __init__ only initializes attributes once."""
+        registry1 = TypeRegistry()
+
+        # Add some data to test persistence
+        registry1.strawberry_types["test"] = Mock()
+
+        # Create another instance
+        registry2 = TypeRegistry()
+
+        # Data should persist (same instance)
+        assert "test" in registry2.strawberry_types
+        assert registry2.strawberry_types["test"] == registry1.strawberry_types["test"]
+
+    @patch("lif.graphql.type_registry.get_schema_fields")
+    @patch("lif.graphql.type_registry.build_filter_models")
+    @patch("lif.graphql.type_registry.build_mutation_models")
+    @patch("lif.graphql.type_registry.build_full_models")
+    def test_initialize_types_called_once(self, mock_full, mock_mutation, mock_filter, mock_schema):
+        """Test that initialize_types only builds models once."""
+        # Setup mocks
+        mock_schema.return_value = {"field1": "value1"}
+        mock_filter.return_value = {"PersonFilter": Mock()}
+        mock_mutation.return_value = {"PersonMutation": Mock()}
+        mock_full.return_value = {"person": Mock(), "person_wrapper": Mock()}
+
+        registry = TypeRegistry()
+
+        # First call should build models
+        registry.initialize_types("Person", minimal_mode=True)
+
+        # Verify mocks were called
+        assert mock_schema.called
+        assert mock_filter.called
+        assert mock_mutation.called
+        assert mock_full.called
+
+        # Reset mock call counts
+        mock_schema.reset_mock()
+        mock_filter.reset_mock()
+        mock_mutation.reset_mock()
+        mock_full.reset_mock()
+
+        # Second call should not rebuild models
+        registry.initialize_types("Person", minimal_mode=True)
+
+        # Verify mocks were NOT called again
+        assert not mock_schema.called
+        assert not mock_filter.called
+        assert not mock_mutation.called
+        assert not mock_full.called
+
+    @patch("lif.graphql.type_registry.get_schema_fields")
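+    # NOTE: stacked @patch decorators inject mocks bottom-up, i.e. in reverse of the order listed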
@patch("lif.graphql.type_registry.build_filter_models") + @patch("lif.graphql.type_registry.build_mutation_models") + @patch("lif.graphql.type_registry.build_full_models") + def test_initialize_types_stores_pydantic_models(self, mock_full, mock_mutation, mock_filter, mock_schema): + """Test that initialize_types stores Pydantic models correctly.""" + # Setup mocks + mock_schema.return_value = {"field1": "value1"} + filter_models = {"PersonFilter": Mock()} + mutation_models = {"PersonMutation": Mock()} + full_models = {"person": Mock(), "person_wrapper": Mock()} + + mock_filter.return_value = filter_models + mock_mutation.return_value = mutation_models + mock_full.return_value = full_models + + registry = TypeRegistry() + registry.initialize_types("Person", minimal_mode=True) + + # Check that models are stored + assert registry.pydantic_models["filter_models"] == filter_models + assert registry.pydantic_models["mutation_models"] == mutation_models + assert registry.pydantic_models["full_models"] == full_models + + @patch("lif.graphql.type_registry.get_schema_fields") + @patch("lif.graphql.type_registry.build_filter_models") + @patch("lif.graphql.type_registry.build_mutation_models") + @patch("lif.graphql.type_registry.build_full_models") + @patch("lif.graphql.type_registry.pyd_type") + @patch("lif.graphql.type_registry.pyd_input") + def test_initialize_types_creates_strawberry_types( + self, mock_pyd_input, mock_pyd_type, mock_full, mock_mutation, mock_filter, mock_schema + ): + """Test that initialize_types creates Strawberry types when not in minimal mode.""" + # Setup mocks + mock_schema.return_value = {"field1": "value1"} + + # Create mock models + mock_person_model = Mock() + mock_filter_model = Mock() + mock_mutation_model = Mock() + + mock_filter.return_value = {"PersonFilter": mock_filter_model} + mock_mutation.return_value = {"PersonMutation": mock_mutation_model} + mock_full.return_value = {"person": mock_person_model, "person_wrapper": Mock()} + + # Setup strawberry type creation mocks + mock_strawberry_type = Mock() + mock_strawberry_input = Mock() + + mock_pyd_type.return_value = lambda cls: mock_strawberry_type + mock_pyd_input.return_value = lambda cls: mock_strawberry_input + + registry = TypeRegistry() + registry.initialize_types("Person", minimal_mode=False) + + # Verify strawberry types were created + assert "Person" in registry.strawberry_types + assert registry.strawberry_types["Person"] == mock_strawberry_type + + # Verify inputs were created (exact names depend on unique_type_name implementation) + assert len(registry.strawberry_inputs) > 0 + + def test_get_models_for_root_success(self): + """Test successful retrieval of models for a root node.""" + registry = TypeRegistry() + + # Setup mock models + mock_full_model = Mock() + mock_filter_model = Mock() + mock_mutation_model = Mock() + mock_wrapper_model = Mock() + + # Mock the pydantic_models structure + with patch.object( + registry, + "pydantic_models", + { + "full_models": {"person": mock_full_model, "person_wrapper": mock_wrapper_model}, + "filter_models": {"PersonFilter": mock_filter_model}, + "mutation_models": {"PersonMutation": mock_mutation_model}, + }, + ): + result = registry.get_models_for_root("Person") + + assert result["FullModel"] == mock_full_model + assert result["FilterModel"] == mock_filter_model + assert result["MutationModel"] == mock_mutation_model + assert result["WrapperModel"] == mock_wrapper_model + + def test_get_models_for_root_fallback_to_type_suffix(self): + """Test fallback to Type 
+        registry = TypeRegistry()
+
+        mock_full_model = Mock()
+        mock_filter_model = Mock()
+        mock_mutation_model = Mock()
+        mock_wrapper_model = Mock()
+
+        with patch.object(
+            registry,
+            "pydantic_models",
+            {
+                "full_models": {
+                    "PersonType": mock_full_model,  # Only Type suffix exists
+                    "person_wrapper": mock_wrapper_model,
+                },
+                "filter_models": {"PersonFilter": mock_filter_model},
+                "mutation_models": {"PersonMutation": mock_mutation_model},
+            },
+        ):
+            result = registry.get_models_for_root("Person")
+
+            assert result["FullModel"] == mock_full_model
+
+    def test_get_models_for_root_missing_model_raises_error(self):
+        """Test that missing FullModel raises KeyError."""
+        registry = TypeRegistry()
+
+        with patch.object(
+            registry,
+            "pydantic_models",
+            {
+                "full_models": {"other_model": Mock()},
+                "filter_models": {"PersonFilter": Mock()},
+                "mutation_models": {"PersonMutation": Mock()},
+            },
+        ):
+            with pytest.raises(KeyError, match="Cannot locate FullModel for root 'Person'"):
+                registry.get_models_for_root("Person")
+
+    @patch("lif.graphql.type_registry.unique_type_name")
+    @patch("lif.graphql.type_registry.to_pascal_case_from_str")
+    def test_get_strawberry_types_for_root(self, mock_pascal, mock_unique):
+        """Test retrieval of Strawberry types for a root node."""
+        # Setup mocks
+        mock_pascal.side_effect = lambda x: x  # Identity function for simplicity
+        mock_unique.return_value = "PersonFilterInput"
+
+        registry = TypeRegistry()
+
+        mock_root_type = Mock()
+        mock_filter_input = Mock()
+        mock_mutation_input = Mock()
+
+        registry.strawberry_types = {"Person": mock_root_type}
+        registry.strawberry_inputs = {
+            "PersonFilterInput": mock_filter_input,
+            "PersonMutationInput": mock_mutation_input,
+        }
+
+        result = registry.get_strawberry_types_for_root("Person")
+
+        assert result["RootType"] == mock_root_type
+        assert result["FilterInput"] == mock_filter_input
+
+    def test_empty_initialization(self):
+        """Test that a new registry starts with empty collections."""
+        registry = TypeRegistry()
+
+        assert registry.strawberry_types == {}
+        assert registry.strawberry_inputs == {}
+        assert registry.pydantic_models == {}
+
+    @patch("lif.graphql.type_registry.get_schema_fields")
+    @patch("lif.graphql.type_registry.build_filter_models")
+    @patch("lif.graphql.type_registry.build_mutation_models")
+    @patch("lif.graphql.type_registry.build_full_models")
+    def test_minimal_mode_skips_strawberry_creation(self, mock_full, mock_mutation, mock_filter, mock_schema):
+        """Test that minimal mode skips Strawberry type creation."""
+        # Setup mocks
+        mock_schema.return_value = {"field1": "value1"}
+        mock_filter.return_value = {"PersonFilter": Mock()}
+        mock_mutation.return_value = {"PersonMutation": Mock()}
+        mock_full.return_value = {"person": Mock(), "person_wrapper": Mock()}
+
+        registry = TypeRegistry()
+        registry.initialize_types("Person", minimal_mode=True)
+
+        # Should have Pydantic models but no Strawberry types
+        assert registry.pydantic_models
+        assert not registry.strawberry_types
+        assert not registry.strawberry_inputs
+
+    @patch("lif.graphql.type_registry.get_schema_fields")
+    @patch("lif.graphql.type_registry.build_filter_models")
+    @patch("lif.graphql.type_registry.build_mutation_models")
+    @patch("lif.graphql.type_registry.build_full_models")
+    @patch("lif.graphql.type_registry.pyd_type")
+    @patch("lif.graphql.type_registry.pyd_input")
+    def test_wrapper_types_are_skipped_in_strawberry_creation(
+        self, mock_pyd_input, mock_pyd_type, mock_full, mock_mutation, mock_filter, mock_schema
+    ):
+        """Test that wrapper models never get standalone Strawberry types."""
+        # Setup mocks (wrapper entries included alongside the real models)
+        mock_schema.return_value = {"field1": "value1"}
+        mock_filter.return_value = {"PersonFilter": Mock()}
+        mock_mutation.return_value = {"PersonMutation": Mock()}
+        mock_full.return_value = {
+            "person": Mock(),
+            "person_wrapper": Mock(),
+            "organization": Mock(),
+            "organization_wrapper": Mock(),
+        }
+        mock_pyd_type.return_value = lambda cls: Mock()
+        mock_pyd_input.return_value = lambda cls: Mock()
+
+        registry = TypeRegistry()
+        registry.initialize_types("Person", minimal_mode=False)
+
+        # Wrapper models must be filtered out of the Strawberry type set
+        assert registry.strawberry_types
+        assert all("wrapper" not in name.lower() for name in registry.strawberry_types)
+
+
+class TestGlobalTypeRegistryInstance:
+    """Test the global type_registry instance."""
+
+    def setup_method(self):
+        """Reset the singleton state before each test."""
+        TypeRegistry._instance = None
+        TypeRegistry._initialized = False
+        type_registry.__dict__.clear()
+        type_registry._initialized = False
+
+    def test_global_instance_exists(self):
+        """Test that the global type_registry instance exists."""
+        assert type_registry is not None
+        assert isinstance(type_registry, TypeRegistry)
+
+    def test_global_instance_is_singleton(self):
+        """Test that multiple references to type_registry are the same object."""
+        from lif.graphql.type_registry import type_registry as imported_registry
+
+        assert type_registry is imported_registry
+
+    @patch("lif.graphql.type_registry.get_schema_fields")
+    @patch("lif.graphql.type_registry.build_filter_models")
+    @patch("lif.graphql.type_registry.build_mutation_models")
+    @patch("lif.graphql.type_registry.build_full_models")
+    def test_global_instance_functionality(self, mock_full, mock_mutation, mock_filter, mock_schema):
+        """Test that the global instance works correctly."""
+        # Setup mocks
+        mock_schema.return_value = {"field1": "value1"}
+        mock_filter.return_value = {"PersonFilter": Mock()}
+        mock_mutation.return_value = {"PersonMutation": Mock()}
+        mock_full.return_value = {"person": Mock(), "person_wrapper": Mock()}
+
+        # Use the global instance
+        type_registry.initialize_types("Person", minimal_mode=True)
+
+        # Should have stored the models
+        assert type_registry.pydantic_models
+        assert "filter_models" in type_registry.pydantic_models
+        assert "mutation_models" in type_registry.pydantic_models
+        assert "full_models" in type_registry.pydantic_models
diff --git a/test/components/lif/graphql/test_utils.py b/test/components/lif/graphql/test_utils.py
new file mode 100644
index 0000000..a209d41
--- /dev/null
+++ b/test/components/lif/graphql/test_utils.py
@@ -0,0 +1,9 @@
+from lif.graphql import utils
+
+
+def test_to_pascal_case_from_str():
+    assert utils.to_pascal_case_from_str("person_name") == "PersonName"
+
+
+def test_unique_type_name():
+    assert utils.unique_type_name("Person", "FilterInput", "Person").endswith("FilterInput")
diff --git a/test/components/lif/openapi_schema/__init__.py b/test/components/lif/openapi_schema/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/components/lif/openapi_schema/test_core.py b/test/components/lif/openapi_schema/test_core.py
new file mode 100644
index 0000000..9f22ef2
--- /dev/null
+++ b/test/components/lif/openapi_schema/test_core.py
@@ -0,0 +1,54 @@
+import os
+import json
+from unittest.mock import patch
+
+import pytest
+
+from lif.openapi_schema.core import get_schema_fields
+
+
+def test_get_schema_fields_from_file(tmp_path):
+    # Create a temporary OpenAPI file
+    doc = {
+        "openapi": "3.0.0",
+        "components": {"schemas": {"Person": {"type": "object", "properties": {"Name": {"type": "string"}}}}},
+    }
+    p = tmp_path / "openapi.json"
+    p.write_text(json.dumps(doc), encoding="utf-8")
+
+    with patch.dict(os.environ, {"LIF_OPENAPI_SCHEMA_PATH": str(p), "LIF_OPENAPI_ROOT": "Person"}, clear=True):
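+        # With LIF_OPENAPI_SCHEMA_PATH set, the fields should be loaded from the file on disk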
+        fields = get_schema_fields()
+        assert any(f.json_path.lower().startswith("person") for f in fields)
+
+
+def test_get_schema_fields_from_mdr():
+    # Patch MDR client to return a small OpenAPI doc
+    def fake_get_openapi_lif_data_model():
+        return {
+            "openapi": "3.0.0",
+            "components": {"schemas": {"Person": {"type": "object", "properties": {"Name": {"type": "string"}}}}},
+        }
+
+    with patch("lif.mdr_client.core.get_openapi_lif_data_model", fake_get_openapi_lif_data_model):
+        with patch.dict(os.environ, {"LIF_OPENAPI_ROOT": "Person"}, clear=True):
+            fields = get_schema_fields()
+            assert any(f.json_path.lower().startswith("person") for f in fields)
+
+
+def test_get_schema_fields_legacy():
+    # Uses repo test_data path convention
+    with patch.dict(os.environ, {"OPENAPI_SCHEMA_FILE": "test_openapi_schema.json", "ROOT_NODE": "Person"}, clear=True):
+        fields = get_schema_fields()
+        assert any(f.json_path.lower().startswith("person") for f in fields)
+
+
+def test_no_configuration_raises():
+    # Patch MDR to raise, and no fallbacks set
+    def fake_get_openapi_lif_data_model():
+        raise RuntimeError("mdr unavailable")
+
+    with patch("lif.mdr_client.core.get_openapi_lif_data_model", fake_get_openapi_lif_data_model):
+        with patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(RuntimeError):
+                get_schema_fields()
diff --git a/test/components/lif/utils/__init__.py b/test/components/lif/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/components/lif/utils/test_core.py b/test/components/lif/utils/test_core.py
new file mode 100644
index 0000000..bb7f7fd
--- /dev/null
+++ b/test/components/lif/utils/test_core.py
@@ -0,0 +1,47 @@
+import os
+import pytest
+from unittest.mock import patch
+
+from lif.exceptions.core import MissingEnvironmentVariableException
+from lif.utils.core import check_required_env_vars, get_required_env_var
+
+
+class TestEnvironmentUtilities:
+    """Test environment variable utilities in core module."""
+
+    def test_get_required_env_var_success(self):
+        """Test getting a required environment variable that exists."""
+        with patch.dict(os.environ, {"TEST_VAR": "test_value"}):
+            result = get_required_env_var("TEST_VAR")
+            assert result == "test_value"
+
+    def test_get_required_env_var_missing(self):
+        """Test getting a required environment variable that doesn't exist."""
+        with patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(MissingEnvironmentVariableException) as exc_info:
+                get_required_env_var("MISSING_VAR")
+            assert "MISSING_VAR" in str(exc_info.value)
+
+    def test_check_required_env_vars_success(self):
+        """Test checking multiple required environment variables that exist."""
+        test_env = {"VAR1": "value1", "VAR2": "value2", "VAR3": "value3"}
+        with patch.dict(os.environ, test_env):
+            result = check_required_env_vars(["VAR1", "VAR2", "VAR3"])
+            assert result == test_env
+
+    def test_check_required_env_vars_missing_single(self):
+        """Test checking required environment variables with one missing."""
+        with patch.dict(os.environ, {"VAR1": "value1"}, clear=True):
+            with pytest.raises(MissingEnvironmentVariableException) as exc_info:
+                check_required_env_vars(["VAR1", "MISSING_VAR"])
+            assert "MISSING_VAR" in str(exc_info.value)
+
+    def test_check_required_env_vars_missing_multiple(self):
+        """Test checking required environment variables with multiple missing."""
+        with patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(Exception) as exc_info:
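+                # Both missing names should be reported in a single exception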
+                check_required_env_vars(["MISSING1", "MISSING2"])
+            assert "Missing required environment variables" in str(exc_info.value)
+            assert "MISSING1" in str(exc_info.value)
+            assert "MISSING2" in str(exc_info.value)
diff --git a/test/components/lif/utils/test_strings.py b/test/components/lif/utils/test_strings.py
new file mode 100644
index 0000000..aeb2210
--- /dev/null
+++ b/test/components/lif/utils/test_strings.py
@@ -0,0 +1,94 @@
+from datetime import date, datetime
+
+from lif.string_utils import (
+    safe_identifier,
+    to_pascal_case,
+    to_snake_case,
+    to_camel_case,
+    camelcase_path,
+    dict_keys_to_snake,
+    dict_keys_to_camel,
+    convert_dates_to_strings,
+    to_value_enum_name,
+)
+
+
+class TestSafeIdentifier:
+    def test_basic(self):
+        assert safe_identifier("First Name") == "first_name"
+        assert safe_identifier("first-name") == "first_name"
+        assert safe_identifier("first$name") == "first_name"
+
+    def test_leading_digit(self):
+        assert safe_identifier("123abc") == "_123abc"
+
+    def test_camel_pascal(self):
+        assert safe_identifier("CamelCase") == "camel_case"
+        assert safe_identifier("camelCaseABC") == "camel_case_abc"
+
+
+class TestToPascalCase:
+    def test_single_part(self):
+        assert to_pascal_case("hello world") == "HelloWorld"
+        assert to_pascal_case("hello-world") == "HelloWorld"
+        assert to_pascal_case("hello_world") == "HelloWorld"
+
+    def test_multiple_parts(self):
+        assert to_pascal_case("hello", "world") == "HelloWorld"
+        assert to_pascal_case("HTTP", "status", "200") == "HTTPStatus200"
+
+    def test_mixed_case(self):
+        assert to_pascal_case("camelCase") == "CamelCase"
+        assert to_pascal_case("PascalCase") == "PascalCase"
+
+
+class TestToSnakeCase:
+    def test_basic(self):
+        assert to_snake_case("CamelCase") == "camel_case"
+        assert to_snake_case("camelCase") == "camel_case"
+
+    def test_with_acronyms(self):
+        assert to_snake_case("HTTPServerID") == "http_server_id"
+
+
+class TestToCamelCase:
+    def test_basic(self):
+        assert to_camel_case("hello_world") == "helloWorld"
+        assert to_camel_case("Hello World") == "helloWorld"
+        assert to_camel_case("hello-world") == "helloWorld"
+
+    def test_empty(self):
+        assert to_camel_case("") == ""
+
+
+class TestCamelcasePath:
+    def test_path(self):
+        assert camelcase_path("a.b_c.d-e f") == "a.bC.dEF"
+
+
+class TestDictKeyTransforms:
+    def test_to_snake(self):
+        data = {"FirstName": "Alice", "Address": {"zipCode": 12345}, "items": [{"itemID": 1}]}
+        out = dict_keys_to_snake(data)
+        assert out == {"first_name": "Alice", "address": {"zip_code": 12345}, "items": [{"item_id": 1}]}
+
+    def test_to_camel(self):
+        data = {"first_name": "Bob", "address": {"zip_code": 12345}, "items": [{"item_id": 1}]}
+        out = dict_keys_to_camel(data)
+        assert out == {"firstName": "Bob", "address": {"zipCode": 12345}, "items": [{"itemId": 1}]}
+
+
+class TestConvertDatesToStrings:
+    def test_nested(self):
+        d = date(2020, 1, 2)
+        dt = datetime(2020, 1, 2, 3, 4, 5)
+        obj = {"when": d, "arr": [dt, {"n": 1}]}
+        out = convert_dates_to_strings(obj)
+        assert out == {"when": d.isoformat(), "arr": [dt.isoformat(), {"n": 1}]}
+
+
+class TestToValueEnumName:
+    def test_basic(self):
+        assert to_value_enum_name("in progress") == "IN_PROGRESS"
+        assert to_value_enum_name("done!") == "DONE_"
+        assert to_value_enum_name("123start") == "_123START"
diff --git a/test/components/lif/utils/test_validation.py b/test/components/lif/utils/test_validation.py
new file mode 100644
index 0000000..519b776
--- /dev/null
+++ b/test/components/lif/utils/test_validation.py
@@ -0,0 +1,117 @@
+import os
+import pytest
+from unittest.mock import patch
+
+from lif.exceptions.core import MissingEnvironmentVariableException
+from lif.utils.core import (
+    check_required_env_vars,
+    get_required_env_var,
+)
+from lif.utils.validation import (
+    is_truthy,
+    is_falsy,
+    to_bool,
+)
+
+
+class TestEnvironmentValidation:
+    """Test environment variable validation functions."""
+
+    def test_get_required_env_var_success(self):
+        """Test getting a required environment variable that exists."""
+        with patch.dict(os.environ, {"TEST_VAR": "test_value"}):
+            result = get_required_env_var("TEST_VAR")
+            assert result == "test_value"
+
+    def test_get_required_env_var_missing(self):
+        """Test getting a required environment variable that doesn't exist."""
+        with patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(MissingEnvironmentVariableException) as exc_info:
+                get_required_env_var("MISSING_VAR")
+            assert "MISSING_VAR" in str(exc_info.value)
+
+    def test_check_required_env_vars_success(self):
+        """Test checking multiple required environment variables that exist."""
+        test_env = {
+            "VAR1": "value1",
+            "VAR2": "value2",
+            "VAR3": "value3"
+        }
+        with patch.dict(os.environ, test_env):
+            result = check_required_env_vars(["VAR1", "VAR2", "VAR3"])
+            assert result == test_env
+
+    def test_check_required_env_vars_missing_single(self):
+        """Test checking required environment variables with one missing."""
+        with patch.dict(os.environ, {"VAR1": "value1"}, clear=True):
+            with pytest.raises(MissingEnvironmentVariableException) as exc_info:
+                check_required_env_vars(["VAR1", "MISSING_VAR"])
+            assert "MISSING_VAR" in str(exc_info.value)
+
+    def test_check_required_env_vars_missing_multiple(self):
+        """Test checking required environment variables with multiple missing."""
+        with patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(Exception) as exc_info:
+                check_required_env_vars(["MISSING1", "MISSING2"])
+            assert "Missing required environment variables" in str(exc_info.value)
+            assert "MISSING1" in str(exc_info.value)
+            assert "MISSING2" in str(exc_info.value)
+
+
+class TestTruthyFalsy:
+    """Test truthy/falsy validation functions."""
+
+    def test_is_truthy(self):
+        """Test is_truthy function with various inputs."""
+        # Truthy values
+        assert is_truthy("true") is True
+        assert is_truthy("TRUE") is True
+        assert is_truthy("yes") is True
+        assert is_truthy("1") is True
+        assert is_truthy("on") is True
+        assert is_truthy("y") is True
+        assert is_truthy(True) is True
+
+        # Falsy values
+        assert is_truthy("false") is False
+        assert is_truthy("no") is False
+        assert is_truthy("0") is False
+        assert is_truthy("off") is False
+        assert is_truthy(False) is False
+        assert is_truthy(None) is False
+        assert is_truthy("") is False
+
+    def test_is_falsy(self):
+        """Test is_falsy function with various inputs."""
+        # Falsy values
+        assert is_falsy("false") is True
+        assert is_falsy("FALSE") is True
+        assert is_falsy("no") is True
+        assert is_falsy("0") is True
+        assert is_falsy("off") is True
+        assert is_falsy("n") is True
+        assert is_falsy(False) is True
+        assert is_falsy(None) is True
+
+        # Truthy values
+        assert is_falsy("true") is False
+        assert is_falsy("yes") is False
+        assert is_falsy("1") is False
+        assert is_falsy(True) is False
+
+    def test_to_bool(self):
+        """Test to_bool function with various inputs."""
+        # Truthy values
+        assert to_bool("true") is True
+        assert to_bool("yes") is True
+        assert to_bool("1") is True
+
+        # Falsy values
+        assert to_bool("false") is False
+        assert to_bool("no") is False
+        assert to_bool("0") is False
+
+        # Ambiguous values use default
+        assert to_bool("maybe") is False  # default is False
+        assert to_bool("maybe", default=True) is True
+        assert to_bool("random", default=False) is False