Skip to content

Commit b1e408a

Browse files
committed
Make literal and union work with raw from_dict csp, add option to go via pydantic
Signed-off-by: Nijat Khanbabayev <[email protected]>
1 parent e71f1c2 commit b1e408a

File tree

5 files changed

+223
-326
lines changed

5 files changed

+223
-326
lines changed

csp/impl/struct.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
3535
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
3636
if v == FastList:
3737
raise TypeError(f"{v} annotation is not supported without args")
38-
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
38+
if (
39+
CspTypingUtils.is_generic_container(v)
40+
or CspTypingUtils.is_union_type(v)
41+
or CspTypingUtils.is_literal_type(v)
42+
):
3943
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
4044
if CspTypingUtils.is_generic_container(actual_type):
4145
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
@@ -147,6 +151,17 @@ def serializer(val, handler):
147151

148152

149153
class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta):
154+
@classmethod
155+
def type_adapter(cls):
156+
# Late import to avoid autogen issues
157+
from pydantic import TypeAdapter
158+
159+
internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None)
160+
if internal_type_adapter:
161+
return internal_type_adapter
162+
cls._pydantic_type_adapter = TypeAdapter(cls)
163+
return cls._pydantic_type_adapter
164+
150165
@classmethod
151166
def metadata(cls, typed=False):
152167
if typed:
@@ -191,7 +206,8 @@ def _obj_from_python(cls, json, obj_type):
191206
if CspTypingUtils.is_generic_container(obj_type):
192207
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
193208
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
194-
(expected_item_type,) = obj_type.__args__
209+
# We only take the first item, so like for a Tuple, we would ignore arguments after
210+
expected_item_type = obj_type.__args__[0]
195211
return_type = list if isinstance(return_type, list) else return_type
196212
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
197213
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
@@ -206,6 +222,13 @@ def _obj_from_python(cls, json, obj_type):
206222
return json
207223
else:
208224
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
225+
elif CspTypingUtils.is_union_type(obj_type):
226+
return json ## no checks, just let it through
227+
elif CspTypingUtils.is_literal_type(obj_type):
228+
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
229+
if isinstance(json, return_type):
230+
return json
231+
raise ValueError(f"Expected type {return_type} received {json.__class__}")
209232
elif issubclass(obj_type, Struct):
210233
if not isinstance(json, dict):
211234
raise TypeError("Representation of struct as json is expected to be of dict type")
@@ -223,7 +246,9 @@ def _obj_from_python(cls, json, obj_type):
223246
return obj_type(json)
224247

225248
@classmethod
226-
def from_dict(cls, json: dict):
249+
def from_dict(cls, json: dict, use_pydantic: bool = False):
250+
if use_pydantic:
251+
return cls.type_adapter().validate_python(json)
227252
return cls._obj_from_python(json, cls)
228253

229254
def to_dict_depr(self):

csp/impl/types/container_type_normalizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
8181
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
8282
if origin is typing.List and level == 0:
8383
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
84-
if origin is typing.Literal:
85-
# Import here to prevent circular import
86-
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
87-
88-
args = typing.get_args(typ)
89-
typ = type(args[0])
90-
for arg in args[1:]:
91-
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
92-
if typ:
93-
return typ
94-
else:
95-
return object
9684
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
9785
elif CspTypingUtils.is_union_type(typ):
9886
return object
87+
elif CspTypingUtils.is_literal_type(typ):
88+
# Import here to prevent circular import
89+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
90+
91+
args = typing.get_args(typ)
92+
typ = type(args[0])
93+
for arg in args[1:]:
94+
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
95+
if typ:
96+
return typ
97+
else:
98+
return object
9999
else:
100100
return typ
101101

csp/impl/types/typing_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[
1515
def __init__(self):
1616
raise NotImplementedError("Can not init FastList class")
1717

18+
@classmethod
19+
def __get_pydantic_core_schema__(cls, source_type, handler):
20+
from pydantic_core import core_schema
21+
22+
# Late import to not interfere with autogen
23+
args = typing.get_args(source_type)
24+
if args:
25+
inner_type = args[0]
26+
list_schema = handler.generate_schema(typing.List[inner_type])
27+
else:
28+
list_schema = handler.generate_schema(typing.List)
29+
30+
def create_instance(raw_data, validator):
31+
if isinstance(raw_data, FastList):
32+
return raw_data
33+
return validator(raw_data) # just return a list
34+
35+
return core_schema.no_info_wrap_validator_function(function=create_instance, schema=list_schema)
36+
1837

1938
class CspTypingUtils39:
2039
_ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple}
@@ -23,7 +42,7 @@ class CspTypingUtils39:
2342

2443
@classmethod
2544
def is_generic_container(cls, typ):
26-
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
45+
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)
2746

2847
@classmethod
2948
def is_type_spec(cls, val):
@@ -56,6 +75,10 @@ def is_numpy_nd_array_type(cls, typ):
5675
def is_union_type(cls, typ):
5776
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union
5877

78+
@classmethod
79+
def is_literal_type(cls, typ):
80+
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal
81+
5982
@classmethod
6083
def is_forward_ref(cls, typ):
6184
return isinstance(typ, typing.ForwardRef)

0 commit comments

Comments
 (0)