Skip to content

Commit 235f6a0

Browse files
committed
Make literal and union work with raw from_dict csp, add option to go via pydantic
Signed-off-by: Nijat Khanbabayev <[email protected]>
1 parent e71f1c2 commit 235f6a0

File tree

5 files changed

+219
-326
lines changed

5 files changed

+219
-326
lines changed

csp/impl/struct.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import ruamel.yaml
66
from deprecated import deprecated
7+
from pydantic import TypeAdapter
78

89
import csp
910
from csp.impl.__csptypesimpl import _csptypesimpl
@@ -35,7 +36,11 @@ def __new__(cls, name, bases, dct):
3536
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
3637
if v == FastList:
3738
raise TypeError(f"{v} annotation is not supported without args")
38-
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
39+
if (
40+
CspTypingUtils.is_generic_container(v)
41+
or CspTypingUtils.is_union_type(v)
42+
or CspTypingUtils.is_literal_type(v)
43+
):
3944
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
4045
if CspTypingUtils.is_generic_container(actual_type):
4146
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
@@ -147,6 +152,14 @@ def serializer(val, handler):
147152

148153

149154
class Struct(_csptypesimpl.PyStruct, metaclass=StructMeta):
155+
@classmethod
156+
def type_adapter(cls) -> TypeAdapter:
157+
internal_type_adapter = getattr(cls, "_pydantic_type_adapter", None)
158+
if internal_type_adapter:
159+
return internal_type_adapter
160+
cls._pydantic_type_adapter = TypeAdapter(cls)
161+
return cls._pydantic_type_adapter
162+
150163
@classmethod
151164
def metadata(cls, typed=False):
152165
if typed:
@@ -191,7 +204,8 @@ def _obj_from_python(cls, json, obj_type):
191204
if CspTypingUtils.is_generic_container(obj_type):
192205
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
193206
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
194-
(expected_item_type,) = obj_type.__args__
207+
# We only take the first item, so like for a Tuple, we would ignore arguments after
208+
expected_item_type = obj_type.__args__[0]
195209
return_type = list if isinstance(return_type, list) else return_type
196210
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
197211
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
@@ -206,6 +220,13 @@ def _obj_from_python(cls, json, obj_type):
206220
return json
207221
else:
208222
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
223+
elif CspTypingUtils.is_union_type(obj_type):
224+
return json ## no checks, just let it through
225+
elif CspTypingUtils.is_literal_type(obj_type):
226+
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
227+
if isinstance(json, return_type):
228+
return json
229+
raise ValueError(f"Expected type {return_type} received {json.__class__}")
209230
elif issubclass(obj_type, Struct):
210231
if not isinstance(json, dict):
211232
raise TypeError("Representation of struct as json is expected to be of dict type")
@@ -223,7 +244,9 @@ def _obj_from_python(cls, json, obj_type):
223244
return obj_type(json)
224245

225246
@classmethod
226-
def from_dict(cls, json: dict):
247+
def from_dict(cls, json: dict, use_pydantic: bool = False):
248+
if use_pydantic:
249+
return cls.type_adapter().validate_python(json)
227250
return cls._obj_from_python(json, cls)
228251

229252
def to_dict_depr(self):

csp/impl/types/container_type_normalizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
8181
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
8282
if origin is typing.List and level == 0:
8383
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
84-
if origin is typing.Literal:
85-
# Import here to prevent circular import
86-
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
87-
88-
args = typing.get_args(typ)
89-
typ = type(args[0])
90-
for arg in args[1:]:
91-
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
92-
if typ:
93-
return typ
94-
else:
95-
return object
9684
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
9785
elif CspTypingUtils.is_union_type(typ):
9886
return object
87+
elif CspTypingUtils.is_literal_type(typ):
88+
# Import here to prevent circular import
89+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
90+
91+
args = typing.get_args(typ)
92+
typ = type(args[0])
93+
for arg in args[1:]:
94+
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
95+
if typ:
96+
return typ
97+
else:
98+
return object
9999
else:
100100
return typ
101101

csp/impl/types/typing_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing
66

77
import numpy
8+
from pydantic_core import core_schema
89

910
import csp.typing
1011

@@ -15,6 +16,22 @@ class FastList(typing.List, typing.Generic[T]): # Need to inherit from Generic[
1516
def __init__(self):
1617
raise NotImplementedError("Can not init FastList class")
1718

19+
@classmethod
20+
def __get_pydantic_core_schema__(cls, source_type, handler):
21+
args = typing.get_args(source_type)
22+
if args:
23+
inner_type = args[0]
24+
list_schema = handler.generate_schema(typing.List[inner_type])
25+
else:
26+
list_schema = handler.generate_schema(typing.List)
27+
28+
def create_instance(raw_data, validator):
29+
if isinstance(raw_data, FastList):
30+
return raw_data
31+
return validator(raw_data) # just return a list
32+
33+
return core_schema.no_info_wrap_validator_function(function=create_instance, schema=list_schema)
34+
1835

1936
class CspTypingUtils39:
2037
_ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple}
@@ -23,7 +40,7 @@ class CspTypingUtils39:
2340

2441
@classmethod
2542
def is_generic_container(cls, typ):
26-
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
43+
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)
2744

2845
@classmethod
2946
def is_type_spec(cls, val):
@@ -56,6 +73,10 @@ def is_numpy_nd_array_type(cls, typ):
5673
def is_union_type(cls, typ):
5774
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union
5875

76+
@classmethod
77+
def is_literal_type(cls, typ):
78+
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal
79+
5980
@classmethod
6081
def is_forward_ref(cls, typ):
6182
return isinstance(typ, typing.ForwardRef)

0 commit comments

Comments
 (0)