|
1 | 1 | import enum |
2 | 2 | import json |
| 3 | +import os |
3 | 4 | import pickle |
| 5 | +import sys |
4 | 6 | import typing |
5 | 7 | import unittest |
6 | 8 | from datetime import date, datetime, time, timedelta |
|
17 | 19 | from csp.impl.types.typing_utils import FastList |
18 | 20 | from csp.typing import Numpy1DArray |
19 | 21 |
|
| 22 | +USE_PYDANTIC = os.environ.get("CSP_PYDANTIC", True) |
20 | 23 |
|
21 | 24 | class MyEnum(csp.Enum): |
22 | 25 | A = 1 |
@@ -3903,6 +3906,297 @@ class DataPoint(csp.Struct): |
3903 | 3906 | self.assertNotIn("_last_updated", json_data) |
3904 | 3907 | self.assertNotIn("_source", json_data["data"]) |
3905 | 3908 |
|
| 3909 | + def test_literal_types_validation(self): |
| 3910 | + """Test that Literal type annotations correctly validate input values in CSP Structs""" |
| 3911 | + # Define a simple class with various Literal types |
| 3912 | + class StructWithLiterals(csp.Struct): |
| 3913 | + # String literals |
| 3914 | + color: Literal["red", "green", "blue"] |
| 3915 | + # Integer literals |
| 3916 | + size: Literal[1, 2, 3] |
| 3917 | + # Mixed type literals |
| 3918 | + status: Literal["on", "off", 0, 1, True, False] |
| 3919 | + # Optional literal with default |
| 3920 | + mode: Optional[Literal["fast", "slow"]] = "fast" |
| 3921 | + |
| 3922 | + # Test valid assignments |
| 3923 | + s1 = StructWithLiterals(color="red", size=2, status="on") |
| 3924 | + self.assertEqual(s1.color, "red") |
| 3925 | + self.assertEqual(s1.size, 2) |
| 3926 | + self.assertEqual(s1.status, "on") |
| 3927 | + self.assertEqual(s1.mode, "fast") # Default value |
| 3928 | + |
| 3929 | + # Test another valid instance with different values |
| 3930 | + s2 = StructWithLiterals(color="blue", size=1, status=True, mode="slow") |
| 3931 | + self.assertEqual(s2.color, "blue") |
| 3932 | + self.assertEqual(s2.size, 1) |
| 3933 | + self.assertEqual(s2.status, True) |
| 3934 | + self.assertEqual(s2.mode, "slow") |
| 3935 | + |
| 3936 | + # Test direct assignment |
| 3937 | + s3 = StructWithLiterals(color="green", size=3, status=0) |
| 3938 | + s3.color = "blue" |
| 3939 | + s3.size = 2 |
| 3940 | + s3.status = False |
| 3941 | + self.assertEqual(s3.color, "blue") |
| 3942 | + self.assertEqual(s3.size, 2) |
| 3943 | + self.assertEqual(s3.status, False) |
| 3944 | + |
| 3945 | + # This should fail! But old csp type checking doesnt catch |
| 3946 | + StructWithLiterals(color="yellow", size=1, status="on") # Invalid color |
| 3947 | + |
| 3948 | + # This should fail! But old csp type checking doesnt catch |
| 3949 | + StructWithLiterals(color="red", size=4, status="on") # Invalid size |
| 3950 | + |
| 3951 | + # This should fail! But old csp type checking doesnt catch |
| 3952 | + StructWithLiterals(color="red", size=1, status="standby") # Invalid status |
| 3953 | + |
| 3954 | + # Test with Pydantic validation |
| 3955 | + if USE_PYDANTIC: |
| 3956 | + # Test valid values |
| 3957 | + result = TypeAdapter(StructWithLiterals).validate_python({ |
| 3958 | + "color": "green", "size": 3, "status": 0 |
| 3959 | + }) |
| 3960 | + self.assertEqual(result.color, "green") |
| 3961 | + self.assertEqual(result.size, 3) |
| 3962 | + self.assertEqual(result.status, 0) |
| 3963 | + |
| 3964 | + # Test invalid color with Pydantic validation |
| 3965 | + with self.assertRaises(ValidationError) as exc_info: |
| 3966 | + TypeAdapter(StructWithLiterals).validate_python({ |
| 3967 | + "color": "yellow", "size": 1, "status": "on" |
| 3968 | + }) |
| 3969 | + self.assertIn('1 validation error for', str(exc_info.exception)) |
| 3970 | + self.assertIn('color', str(exc_info.exception)) |
| 3971 | + |
| 3972 | + # Test invalid size with Pydantic validation |
| 3973 | + with self.assertRaises(ValidationError) as exc_info: |
| 3974 | + TypeAdapter(StructWithLiterals).validate_python({ |
| 3975 | + "color": "red", "size": 4, "status": "on" |
| 3976 | + }) |
| 3977 | + self.assertIn('1 validation error for', str(exc_info.exception)) |
| 3978 | + self.assertIn('size', str(exc_info.exception)) |
| 3979 | + |
| 3980 | + # Test invalid status with Pydantic validation |
| 3981 | + with self.assertRaises(ValidationError) as exc_info: |
| 3982 | + TypeAdapter(StructWithLiterals).validate_python({ |
| 3983 | + "color": "red", "size": 1, "status": "standby" |
| 3984 | + }) |
| 3985 | + self.assertIn('1 validation error for', str(exc_info.exception)) |
| 3986 | + self.assertIn('status', str(exc_info.exception)) |
| 3987 | + |
| 3988 | + # Test invalid mode with Pydantic validation |
| 3989 | + with self.assertRaises(ValidationError) as exc_info: |
| 3990 | + TypeAdapter(StructWithLiterals).validate_python({ |
| 3991 | + "color": "red", "size": 1, "status": "on", "mode": "medium" |
| 3992 | + }) |
| 3993 | + self.assertIn('1 validation error for', str(exc_info.exception)) |
| 3994 | + self.assertIn('mode', str(exc_info.exception)) |
| 3995 | + # Test serialization and deserialization preserves literal values |
| 3996 | + result = TypeAdapter(StructWithLiterals).validate_python({ |
| 3997 | + "color": "green", "size": 3, "status": 0 |
| 3998 | + }) |
| 3999 | + json_data = TypeAdapter(StructWithLiterals).dump_json(result) |
| 4000 | + restored = TypeAdapter(StructWithLiterals).validate_json(json_data) |
| 4001 | + self.assertEqual(restored.color, "green") |
| 4002 | + self.assertEqual(restored.size, 3) |
| 4003 | + self.assertEqual(restored.status, 0) |
| 4004 | + |
| 4005 | + def test_literal_in_complex_structures(self): |
| 4006 | + """Test Literal type annotations in more complex structures with nesting and containers""" |
| 4007 | + # Define a class using Literal in collection types and nested structs |
| 4008 | + class Configuration(csp.Struct): |
| 4009 | + mode: Literal["debug", "production", "test"] |
| 4010 | + |
| 4011 | + class ItemType(csp.Enum): |
| 4012 | + WEAPON = 1 |
| 4013 | + ARMOR = 2 |
| 4014 | + POTION = 3 |
| 4015 | + |
| 4016 | + class Item(csp.Struct): |
| 4017 | + name: str |
| 4018 | + type: ItemType |
| 4019 | + rarity: Literal["common", "uncommon", "rare", "epic", "legendary"] |
| 4020 | + |
| 4021 | + class Character(csp.Struct): |
| 4022 | + name: str |
| 4023 | + # Literal in list |
| 4024 | + classes: List[Literal["warrior", "mage", "rogue"]] |
| 4025 | + # Literal in dictionary values |
| 4026 | + attributes: Dict[str, Literal[1, 2, 3, 4, 5]] |
| 4027 | + # Nested struct with literal |
| 4028 | + config: Configuration |
| 4029 | + # List of nested structs with literals |
| 4030 | + inventory: List[Item] |
| 4031 | + |
| 4032 | + # Create valid instance with various literal usages |
| 4033 | + character = Character( |
| 4034 | + name="Test Character", |
| 4035 | + classes=["warrior", "mage"], |
| 4036 | + attributes={"strength": 5, "intelligence": 3, "dexterity": 4}, |
| 4037 | + config=Configuration(mode="debug"), |
| 4038 | + inventory=[ |
| 4039 | + Item(name="Sword", type=ItemType.WEAPON, rarity="common"), |
| 4040 | + Item(name="Health Potion", type=ItemType.POTION, rarity="rare") |
| 4041 | + ] |
| 4042 | + ) |
| 4043 | + |
| 4044 | + # Test data is correctly set |
| 4045 | + self.assertEqual(character.name, "Test Character") |
| 4046 | + self.assertEqual(character.classes, ["warrior", "mage"]) |
| 4047 | + self.assertEqual(character.attributes, {"strength": 5, "intelligence": 3, "dexterity": 4}) |
| 4048 | + self.assertEqual(character.config.mode, "debug") |
| 4049 | + self.assertEqual(len(character.inventory), 2) |
| 4050 | + self.assertEqual(character.inventory[0].rarity, "common") |
| 4051 | + self.assertEqual(character.inventory[1].rarity, "rare") |
| 4052 | + |
| 4053 | + # This should fail! But default csp struct type checking doesnt catch |
| 4054 | + Configuration(mode="invalid") |
| 4055 | + |
| 4056 | + # This should fail! But default csp struct type checking doesnt catch |
| 4057 | + Item(name="Bad Item", type=ItemType.ARMOR, rarity="unknown") |
| 4058 | + |
| 4059 | + # This should fail! But we dont check on mutation |
| 4060 | + character.classes.append("paladin") # Invalid class |
| 4061 | + |
| 4062 | + # This should fail! But we dont check on mutation |
| 4063 | + character.attributes["wisdom"] = 6 # Value out of range |
| 4064 | + |
| 4065 | + if USE_PYDANTIC: |
| 4066 | + # Test valid nested data |
| 4067 | + data = { |
| 4068 | + "name": "Pydantic Character", |
| 4069 | + "classes": ["rogue", "warrior"], |
| 4070 | + "attributes": {"strength": 2, "wisdom": 4}, |
| 4071 | + "config": {"mode": "production"}, |
| 4072 | + "inventory": [ |
| 4073 | + {"name": "Shield", "type": ItemType.ARMOR, "rarity": "uncommon"} |
| 4074 | + ] |
| 4075 | + } |
| 4076 | + result = TypeAdapter(Character).validate_python(data) |
| 4077 | + self.assertEqual(result.name, "Pydantic Character") |
| 4078 | + self.assertEqual(result.classes, ["rogue", "warrior"]) |
| 4079 | + self.assertEqual(result.config.mode, "production") |
| 4080 | + self.assertEqual(result.inventory[0].rarity, "uncommon") |
| 4081 | + |
| 4082 | + # Test invalid literal in nested structure |
| 4083 | + invalid_data = data.copy() |
| 4084 | + invalid_data["config"] = {"mode": "invalid_mode"} |
| 4085 | + with self.assertRaises(ValidationError) as exc_info: |
| 4086 | + TypeAdapter(Character).validate_python(invalid_data) |
| 4087 | + |
| 4088 | + # Test serialization/deserialization round trip |
| 4089 | + round_trip = TypeAdapter(Character).validate_python( |
| 4090 | + TypeAdapter(Character).dump_python(result) |
| 4091 | + ) |
| 4092 | + self.assertEqual(round_trip.name, result.name) |
| 4093 | + self.assertEqual(round_trip.classes, result.classes) |
| 4094 | + self.assertEqual(round_trip.config.mode, result.config.mode) |
| 4095 | + self.assertEqual(round_trip.inventory[0].rarity, result.inventory[0].rarity) |
| 4096 | + |
| 4097 | + def test_pipe_operator_types(self): |
| 4098 | + """Test using the pipe operator for union types in Python 3.10+""" |
| 4099 | + if sys.version_info >= (3, 10): # Only run on Python 3.10+ |
| 4100 | + # Define a class using various pipe operator combinations |
| 4101 | + class PipeTypesConfig(csp.Struct): |
| 4102 | + # Basic primitive types with pipe |
| 4103 | + id_field: str | int |
| 4104 | + # Pipe with None (similar to Optional) |
| 4105 | + description: str | None = None |
| 4106 | + # Multiple types with pipe |
| 4107 | + value: str | int | float | bool |
| 4108 | + # Container with pipe |
| 4109 | + tags: List[str] | Dict[str, str] | None = None |
| 4110 | + # Pipe with literal for comparison |
| 4111 | + status: Literal["active", "inactive"] | None = "active" |
| 4112 | + |
| 4113 | + # Test with string ID |
| 4114 | + p1 = PipeTypesConfig(id_field="abc123", value="test_value") |
| 4115 | + self.assertEqual(p1.id_field, "abc123") |
| 4116 | + self.assertIsNone(p1.description) |
| 4117 | + self.assertEqual(p1.value, "test_value") |
| 4118 | + self.assertIsNone(p1.tags) |
| 4119 | + self.assertEqual(p1.status, "active") |
| 4120 | + |
| 4121 | + # Test with integer ID |
| 4122 | + p2 = PipeTypesConfig(id_field=42, value=3.14, description="A config") |
| 4123 | + self.assertEqual(p2.id_field, 42) |
| 4124 | + self.assertEqual(p2.description, "A config") |
| 4125 | + self.assertEqual(p2.value, 3.14) |
| 4126 | + |
| 4127 | + # Test with boolean value and list tags |
| 4128 | + p3 = PipeTypesConfig(id_field=99, value=True, tags=["tag1", "tag2"]) |
| 4129 | + self.assertEqual(p3.id_field, 99) |
| 4130 | + self.assertEqual(p3.value, True) |
| 4131 | + self.assertEqual(p3.tags, ["tag1", "tag2"]) |
| 4132 | + |
| 4133 | + # Test with dict tags |
| 4134 | + p4 = PipeTypesConfig(id_field="xyz", value=42, tags={"key1": "val1", "key2": "val2"}) |
| 4135 | + self.assertEqual(p4.id_field, "xyz") |
| 4136 | + self.assertEqual(p4.value, 42) |
| 4137 | + self.assertEqual(p4.tags, {"key1": "val1", "key2": "val2"}) |
| 4138 | + |
| 4139 | + # Test direct assignment |
| 4140 | + p5 = PipeTypesConfig(id_field="test", value=1) |
| 4141 | + p5.id_field = 100 |
| 4142 | + p5.value = False |
| 4143 | + p5.tags = ["new", "tags"] |
| 4144 | + p5.description = "Updated" |
| 4145 | + self.assertEqual(p5.id_field, 100) |
| 4146 | + self.assertEqual(p5.value, False) |
| 4147 | + self.assertEqual(p5.tags, ["new", "tags"]) |
| 4148 | + self.assertEqual(p5.description, "Updated") |
| 4149 | + |
| 4150 | + # Test Pydantic validation if available |
| 4151 | + if USE_PYDANTIC: |
| 4152 | + # Test all valid types |
| 4153 | + valid_cases = [ |
| 4154 | + {"id_field": "string_id", "value": "string_value"}, |
| 4155 | + {"id_field": 42, "value": 123}, |
| 4156 | + {"id_field": "mixed", "value": 3.14}, |
| 4157 | + {"id_field": 999, "value": True}, |
| 4158 | + {"id_field": "with_desc", "value": 1, "description": "Description"}, |
| 4159 | + {"id_field": "with_tags", "value": 1, "tags": ["a", "b", "c"]}, |
| 4160 | + {"id_field": "with_dict", "value": 1, "tags": {"a": "A", "b": "B"}} |
| 4161 | + ] |
| 4162 | + |
| 4163 | + for case in valid_cases: |
| 4164 | + result = TypeAdapter(PipeTypesConfig).validate_python(case) |
| 4165 | + self.assertEqual(result.id_field, case["id_field"]) |
| 4166 | + self.assertEqual(result.value, case["value"]) |
| 4167 | + |
| 4168 | + # Test invalid values |
| 4169 | + invalid_cases = [ |
| 4170 | + {"id_field": 3.14, "value": 1}, # Float for id_field |
| 4171 | + {"id_field": None, "value": 1}, # None for required id_field |
| 4172 | + {"id_field": "test", "value": {}}, # Dict for value |
| 4173 | + {"id_field": "test", "value": None}, # None for required value |
| 4174 | + {"id_field": "test", "value": 1, "status": "unknown"} # Invalid literal |
| 4175 | + ] |
| 4176 | + |
| 4177 | + for case in invalid_cases: |
| 4178 | + with self.assertRaises(ValidationError): |
| 4179 | + TypeAdapter(PipeTypesConfig).validate_python(case) |
| 4180 | + |
| 4181 | + # Test serialization/deserialization |
| 4182 | + original = PipeTypesConfig( |
| 4183 | + id_field="test_id", |
| 4184 | + value=42, |
| 4185 | + description="Test description", |
| 4186 | + tags=["tag1", "tag2"], |
| 4187 | + status="inactive" |
| 4188 | + ) |
| 4189 | + |
| 4190 | + # Convert to JSON and back |
| 4191 | + json_data = TypeAdapter(PipeTypesConfig).dump_json(original) |
| 4192 | + restored = TypeAdapter(PipeTypesConfig).validate_json(json_data) |
| 4193 | + |
| 4194 | + # Verify data integrity |
| 4195 | + self.assertEqual(restored.id_field, original.id_field) |
| 4196 | + self.assertEqual(restored.value, original.value) |
| 4197 | + self.assertEqual(restored.description, original.description) |
| 4198 | + self.assertEqual(restored.tags, original.tags) |
| 4199 | + self.assertEqual(restored.status, original.status) |
3906 | 4200 |
|
3907 | 4201 | if __name__ == "__main__": |
3908 | 4202 | unittest.main() |
0 commit comments