Skip to content

Commit 902f3ce

Browse files
committed
Make literal and union work with raw from_dict csp, add option to go via pydantic
Signed-off-by: Nijat Khanbabayev <[email protected]>
1 parent e71f1c2 commit 902f3ce

File tree

5 files changed

+221
-326
lines changed

5 files changed

+221
-326
lines changed

csp/impl/struct.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
3535
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
3636
if v == FastList:
3737
raise TypeError(f"{v} annotation is not supported without args")
38-
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
38+
if (
39+
CspTypingUtils.is_generic_container(v)
40+
or CspTypingUtils.is_union_type(v)
41+
or CspTypingUtils.is_literal_type(v)
42+
):
3943
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
4044
if CspTypingUtils.is_generic_container(actual_type):
4145
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
@@ -147,6 +151,17 @@ def serializer(val, handler):
147151

148152

149153
class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta):
154+
@classmethod
155+
def type_adapter(cls):
156+
# Late import to avoid autogen issues
157+
from pydantic import TypeAdapter
158+
159+
internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None)
160+
if internal_type_adapter:
161+
return internal_type_adapter
162+
cls._pydantic_type_adapter = TypeAdapter(cls)
163+
return cls._pydantic_type_adapter
164+
150165
@classmethod
151166
def metadata(cls, typed=False):
152167
if typed:
@@ -191,7 +206,8 @@ def _obj_from_python(cls, json, obj_type):
191206
if CspTypingUtils.is_generic_container(obj_type):
192207
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
193208
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
194-
(expected_item_type,) = obj_type.__args__
209+
# We only take the first item, so like for a Tuple, we would ignore arguments after
210+
expected_item_type = obj_type.__args__[0]
195211
return_type = list if isinstance(return_type, list) else return_type
196212
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
197213
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
@@ -206,6 +222,13 @@ def _obj_from_python(cls, json, obj_type):
206222
return json
207223
else:
208224
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
225+
elif CspTypingUtils.is_union_type(obj_type):
226+
return json ## no checks, just let it through
227+
elif CspTypingUtils.is_literal_type(obj_type):
228+
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
229+
if isinstance(json, return_type):
230+
return json
231+
raise ValueError(f"Expected type {return_type} received {json.__class__}")
209232
elif issubclass(obj_type, Struct):
210233
if not isinstance(json, dict):
211234
raise TypeError("Representation of struct as json is expected to be of dict type")
@@ -223,7 +246,9 @@ def _obj_from_python(cls, json, obj_type):
223246
return obj_type(json)
224247

225248
@classmethod
226-
def from_dict(cls, json: dict):
249+
def from_dict(cls, json: dict, use_pydantic: bool = False):
250+
if use_pydantic:
251+
return cls.type_adapter().validate_python(json)
227252
return cls._obj_from_python(json, cls)
228253

229254
def to_dict_depr(self):

csp/impl/types/container_type_normalizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
8181
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
8282
if origin is typing.List and level == 0:
8383
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
84-
if origin is typing.Literal:
85-
# Import here to prevent circular import
86-
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
87-
88-
args = typing.get_args(typ)
89-
typ = type(args[0])
90-
for arg in args[1:]:
91-
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
92-
if typ:
93-
return typ
94-
else:
95-
return object
9684
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
9785
elif CspTypingUtils.is_union_type(typ):
9886
return object
87+
elif CspTypingUtils.is_literal_type(typ):
88+
# Import here to prevent circular import
89+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
90+
91+
args = typing.get_args(typ)
92+
typ = type(args[0])
93+
for arg in args[1:]:
94+
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
95+
if typ:
96+
return typ
97+
else:
98+
return object
9999
else:
100100
return typ
101101

csp/impl/types/typing_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing
66

77
import numpy
8+
from pydantic_core import core_schema
89

910
import csp.typing
1011

@@ -15,6 +16,22 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[
1516
def __init__(self):
1617
raise NotImplementedError("Can not init FastList class")
1718

19+
@classmethod
20+
def __get_pydantic_core_schema__(cls, source_type, handler):
21+
args = typing.get_args(source_type)
22+
if args:
23+
inner_type = args[0]
24+
list_schema = handler.generate_schema(typing.List[inner_type])
25+
else:
26+
list_schema = handler.generate_schema(typing.List)
27+
28+
def create_instance(raw_data, validator):
29+
if isinstance(raw_data, FastList):
30+
return raw_data
31+
return validator(raw_data) # just return a list
32+
33+
return core_schema.no_info_wrap_validator_function(function=create_instance, schema=list_schema)
34+
1835

1936
class CspTypingUtils39:
2037
_ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple}
@@ -23,7 +40,7 @@ class CspTypingUtils39:
2340

2441
@classmethod
2542
def is_generic_container(cls, typ):
26-
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
43+
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)
2744

2845
@classmethod
2946
def is_type_spec(cls, val):
@@ -56,6 +73,10 @@ def is_numpy_nd_array_type(cls, typ):
5673
def is_union_type(cls, typ):
5774
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union
5875

76+
@classmethod
77+
def is_literal_type(cls, typ):
78+
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal
79+
5980
@classmethod
6081
def is_forward_ref(cls, typ):
6182
return isinstance(typ, typing.ForwardRef)

0 commit comments

Comments
 (0)