Skip to content

Commit b9cc4f5

Browse files
committed
Fix to allow Literal and Union via | in csp annotations
Signed-off-by: Nijat Khanbabayev <[email protected]>
1 parent aafe7f9 commit b9cc4f5

File tree

4 files changed

+76
-12
lines changed

4 files changed

+76
-12
lines changed

csp/impl/types/pydantic_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import types
33
import typing
4-
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
4+
from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
55

66
from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
77
from pydantic_core import CoreSchema, core_schema
@@ -184,6 +184,8 @@ def adjust_annotations(
184184
return TsType[
185185
adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars)
186186
]
187+
if origin is Literal: # for literals, we stop converting
188+
return Optional[annotation] if make_optional else annotation
187189
else:
188190
try:
189191
if origin is CspTypeVar or origin is CspTypeVarType:

csp/impl/types/type_annotation_normalizer_transformer.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def visit_arg(self, node):
5151
return node
5252

5353
def visit_Subscript(self, node):
54+
# We choose to avoid parsing here
55+
# to maintain current behavior of allowing empty lists in our types
5456
return node
5557

5658
def visit_List(self, node):
@@ -98,17 +100,13 @@ def visit_Call(self, node):
98100
return node
99101

100102
def visit_Constant(self, node):
101-
if not self._cur_arg:
102-
return node
103-
104-
if self._cur_arg:
105-
return ast.Call(
106-
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107-
args=[node],
108-
keywords=[],
109-
)
110-
else:
103+
if not self._cur_arg or not isinstance(node.value, str):
111104
return node
105+
return ast.Call(
106+
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107+
args=[node],
108+
keywords=[],
109+
)
112110

113111
def visit_Str(self, node):
114112
return self.visit_Constant(node)

csp/tests/impl/types/test_pydantic_types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from inspect import isclass
3-
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
3+
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
44
from unittest import TestCase
55

66
import csp
@@ -160,3 +160,12 @@ def test_force_tvars(self):
160160
self.assertAnnotationsEqual(
161161
adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]]
162162
)
163+
164+
def test_literal(self):
165+
self.assertAnnotationsEqual(adjust_annotations(Literal["a", "b"]), Literal["a", "b"])
166+
self.assertAnnotationsEqual(
167+
adjust_annotations(Literal["a", "b"], make_optional=True), Optional[Literal["a", "b"]]
168+
)
169+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, "a"]), Literal[123, "a"])
170+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, None]), Literal[123, None])
171+
self.assertAnnotationsEqual(adjust_annotations(ts[Literal[123, None]]), ts[Literal[123, None]])

csp/tests/test_type_checking.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pickle
33
import re
4+
import sys
45
import typing
56
import unittest
67
from datetime import datetime, time, timedelta
@@ -938,6 +939,60 @@ def test_is_callable(self):
938939
result = CspTypingUtils.is_callable(input_type)
939940
self.assertEqual(result, expected)
940941

942+
def test_literal_with_pipe_operator(self):
943+
"""Test combining Literal types with the pipe operator in Python 3.10+."""
944+
if sys.version_info >= (3, 10) and USE_PYDANTIC: # this doesn't work without pydantic type checking
945+
946+
def run_literal_pipe_test():
947+
from typing import Literal
948+
949+
@csp.node
950+
def node_with_literal_pipe(x: ts[int], choice: Literal["a", "b", "c"] | None | int) -> ts[str]:
951+
if csp.ticked(x):
952+
return str(choice) if choice is not None else "none"
953+
954+
@csp.graph
955+
def graph_with_literal_pipe(choice: Literal["a", "b", "c"] | None | int) -> ts[str]:
956+
return csp.const(str(choice) if choice is not None else "none")
957+
958+
@csp.node
959+
def dummy_node(x: ts["T"]): # to avoid pruning
960+
if csp.ticked(x):
961+
...
962+
963+
def graph():
964+
# These should work - valid literal values or None or int
965+
dummy_node(node_with_literal_pipe(csp.const(10), "a"))
966+
dummy_node(node_with_literal_pipe(csp.const(10), "b"))
967+
dummy_node(node_with_literal_pipe(csp.const(10), "c"))
968+
dummy_node(node_with_literal_pipe(csp.const(10), None))
969+
dummy_node(node_with_literal_pipe(csp.const(10), 12))
970+
971+
graph_with_literal_pipe("a")
972+
graph_with_literal_pipe("b")
973+
graph_with_literal_pipe("c")
974+
graph_with_literal_pipe(None)
975+
graph_with_literal_pipe(12)
976+
977+
msg = "(?s)2 validation errors for node_with_literal_pipe.*choice.*"
978+
with self.assertRaisesRegex(TypeError, msg):
979+
dummy_node(node_with_literal_pipe(csp.const(10), "d"))
980+
981+
csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))
982+
983+
# Test direct graph building
984+
csp.build_graph(graph_with_literal_pipe, "a")
985+
csp.build_graph(graph_with_literal_pipe, None)
986+
csp.build_graph(graph_with_literal_pipe, 12)
987+
988+
# This should fail
989+
msg = "(?s)2 validation errors for graph_with_literal_pipe.*choice.*"
990+
with self.assertRaisesRegex(TypeError, msg):
991+
csp.build_graph(graph_with_literal_pipe, "d")
992+
993+
# Run the test
994+
run_literal_pipe_test()
995+
941996

942997
if __name__ == "__main__":
943998
unittest.main()

0 commit comments

Comments
 (0)