def flatten(
    self,
    obfuscated: bool = False,
    named: bool = False,
) -> Iterator["Node"]:
    """
    Flatten the sub-ast of the node as an iterator

    Nodes are produced in depth-first pre-order, starting with this node.
    Every node object is produced at most once, even when it can be reached
    through more than one parent or through a cycle.

    Args:
        obfuscated: passed through unchanged to `fields`.
        named: when true, `(parent_key, node)` pairs are produced instead of
            bare nodes. The first pair is `(self.parent_key, self)`; the
            remaining pairs are whatever `fields(named=True)` returns.

    Returns:
        An iterator over the nodes (or pairs, see `named`).
    """
    seen = set()
    # Resolved here rather than inside the generator so that `parent_key` is
    # read at call time, not when the first item is requested.
    root = (self.parent_key, self) if named else self

    def _walk():
        # An explicit stack of child iterators replaces generator recursion,
        # so the depth of the tree is not bounded by the interpreter's
        # recursion limit.
        pending = [iter((root,))]
        while pending:
            try:
                item = next(pending[-1])
            except StopIteration:
                pending.pop()
                continue
            if named:
                key, node = item
                item = (key, node)
            else:
                node = item
            # De-duplicate on the node's identity; in named mode `item` is a
            # pair, so its own id says nothing about the node.
            if id(node) in seen:
                continue
            yield item
            seen.add(id(node))
            # Children are materialised once the consumer resumes the walk,
            # so changes made to the tree between items are picked up at the
            # same point as a per-node eager listing would.
            children = list(
                node.fields(
                    nodes_only=True,
                    flat=True,
                    named=named,
                    nones=False,
                    obfuscated=obfuscated,
                ),
            )
            pending.append(iter(children))

    return _walk()
###################################
###SERIALIZATION/DESERIALIZATION###
###################################
def get_node_key(node: Node) -> int:
    """
    Returns the unique identifier of a node.

    The identifier is the object's `id()`, so it is only meaningful while the
    node object is alive and within the process that produced it.
    """
    return id(node)


def serialize_value(
    value: Any,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited_nodes: Set[int],
) -> Any:
    """
    Serializes a value to a dictionary representation.

    The result is a `{"kind": ..., "value": ...}` dictionary:

    * nodes become a `"node"` reference holding the node key; the node itself
      is written into `serialization` the first time it is met
    * column types become `"type"` with their string form
    * `DJEnum` members and values whose exact type is in `PRIMITIVES` become
      `"primitive"`
    * lists, tuples and sets become `"list"`, `"tuple"` and `"set"` with a
      list of serialized items

    Anything else, including `None`, serializes to `None`.
    """
    if isinstance(value, Node):
        node_key = get_node_key(value)
        if node_key not in visited_nodes:
            visited_nodes.add(node_key)
            _serialize_ast(value, serialization, visited_nodes)
        return {"kind": "node", "value": node_key}
    if isinstance(value, ColumnType):
        return {"kind": "type", "value": str(value)}
    if isinstance(value, DJEnum):
        return {"kind": "primitive", "value": value.value}
    # Exact type match on purpose: subclasses of primitive types are not
    # treated as primitives here.
    if type(value) in PRIMITIVES:
        return {"kind": "primitive", "value": value}
    for kind, container_type in (("list", list), ("tuple", tuple), ("set", set)):
        if isinstance(value, container_type):
            return {
                "kind": kind,
                "value": [
                    serialize_value(item, serialization, visited_nodes)
                    for item in value
                ],
            }
    # Unsupported values are dropped rather than rejected; deserialize_value
    # maps a missing value back to None.
    return None


def _serialize_ast(
    node: Node,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited_nodes: Set[int],
):
    """
    Recursively serializes an AST node and its children.

    Every entry of the node's `__dict__` is serialized and the result is
    stored in `serialization` under the node's key as `(class name, data)`.
    Nodes that already have an entry are left untouched.
    """
    node_key = get_node_key(node)
    if node_key in serialization:
        return
    # Mark the node before descending: its entry is only stored once all of
    # its attributes are done, so without this a reference back to the node
    # (from a descendant's attributes) would serialize it a second time.
    visited_nodes.add(node_key)

    data = {
        attr: serialize_value(attr_value, serialization, visited_nodes)
        for attr, attr_value in node.__dict__.items()
    }
    serialization[node_key] = (type(node).__name__, data)


def serialize_ast(node: Node) -> Dict[int, Tuple[str, Dict[str, Any]]]:
    """
    Serializes an AST node and returns its serialization.

    The returned dictionary maps node keys to `(class name, data)` for `node`
    and every node reachable from its attributes. Pass `id(node)` together
    with this dictionary to `deserialize_ast` to rebuild the tree.
    """
    ret: Dict[int, Tuple[str, Dict[str, Any]]] = {}
    _serialize_ast(node, ret, set())
    return ret


def _coerce_node_key(key: Any) -> Any:
    """
    Returns `key` as an integer when it is the string form of one.

    JSON object keys are always strings, so a payload that went through JSON
    has string keys while the node references inside it are still integers.
    """
    if isinstance(key, str):
        try:
            return int(key)
        except ValueError:
            return key
    return key


def deserialize_value(
    parent_id: int,
    parent_key: str,
    value: Any,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited: Set[int],
    lazies: List["LazyNode"],
) -> Any:
    """
    Deserializes a value from its dictionary representation.

    Args:
        parent_id: key of the node whose attribute is being rebuilt.
        parent_key: name of that attribute.
        value: a `{"kind": ..., "value": ...}` dictionary, or a falsy value.
        serialization: the working payload; finished nodes replace their
            entries in it.
        visited: keys of nodes whose deserialization has started.
        lazies: collects the placeholders created for circular references.

    Returns:
        The rebuilt value; `None` when `value` is falsy.

    Raises:
        TypeError: if the kind is not one this function knows about.
    """
    if not value:
        return None
    kind = value["kind"]
    if kind == "node":
        return _deserialize_ast(
            parent_id,
            parent_key,
            value["value"],
            serialization,
            visited,
            lazies,
        )
    if kind == "type":
        from datajunction_server.sql.parsing.backends.antlr4 import parse

        # The type is recovered by parsing a CAST to it and reading the type
        # of the resulting projection.
        return parse(f"select CAST(x as {value['value']})").select.projection[0].type  # type: ignore
    if kind == "primitive":
        return value["value"]
    builders = {"list": list, "tuple": tuple, "set": set}
    if kind in builders:
        return builders[kind](
            deserialize_value(
                parent_id,
                parent_key,
                item,
                serialization,
                visited,
                lazies,
            )
            for item in value["value"]
        )
    raise TypeError(f"Cannot deserialize value `{value}`.")


# eq=False like the other AST dataclasses in this module: a bare @dataclass
# would generate a field-wise __eq__ and make instances unhashable, which
# breaks placeholders that end up inside a deserialized set.
@dataclass(eq=False)
class LazyNode(Node):
    """
    Type used during deserialization of an AST
    in place of nodes that have yet to finish
    deserializing.
    """

    # key of the node that holds this placeholder
    parent_id: Optional[int] = None
    # attribute of that node in which the placeholder sits
    parent_key: Optional[str] = None
    # key of the node this placeholder stands for
    key: Optional[int] = None
    # the working payload, in which finished nodes replace their entries
    refs: Optional[Dict[int, Tuple[str, Dict[str, Any]]]] = None

    def finalize(self) -> Node:
        """
        Replaces the lazy node with the fully deserialized node.

        Looks up the finished node and the finished parent in `refs`, points
        this placeholder at the parent, and swaps a copy of the finished node
        into the placeholder's position. Returns that copy.
        """
        node = self.refs[self.key]  # type: ignore
        self.parent = self.refs[self.parent_id]  # type: ignore
        node = node.copy()
        self.swap(node)
        return node

    def __str__(self):
        raise NotImplementedError()


def _deserialize_ast(
    parent_id: Optional[int],
    parent_key: Optional[str],
    node_id: int,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited: Set[int],
    lazies: List[LazyNode],
) -> Node:
    """
    Recursively deserializes an AST node and its children.

    Args:
        parent_id: key of the node that references `node_id`, if any.
        parent_key: attribute of the parent that holds the reference.
        node_id: key of the node to rebuild.
        serialization: the working payload; the entry for `node_id` is
            replaced by the rebuilt node.
        visited: keys of nodes whose deserialization has started.
        lazies: collects the placeholders created for circular references.

    Returns:
        The rebuilt node, or a `LazyNode` placeholder when `node_id` is
        still being rebuilt further up the call stack.

    Raises:
        TypeError: if the entry names a class that is not an AST node type
            defined in this module.
    """
    entry = serialization[node_id]
    if isinstance(entry, Node):  # node already deserialized
        return entry
    if node_id in visited:  # circular references
        # create a lazy node to be swapped once deserialization is complete
        lazy = LazyNode(
            parent_id=parent_id,
            parent_key=parent_key,
            key=node_id,
            refs=serialization,
        )
        lazies.append(lazy)
        return lazy

    # node not deserialized yet
    cls_name, data = entry
    visited.add(node_id)
    cls = globals().get(cls_name)
    # Only AST node classes may be instantiated from a payload.
    if not (isinstance(cls, type) and issubclass(cls, Node)):
        raise TypeError(f"Cannot deserialize unknown AST node type `{cls_name}`.")
    # get the fields we can feed the class init
    init_fields = {spec.name for spec in fields(cls) if spec.init}
    kwargs = {}
    deferred = []
    for attr, raw in data.items():
        if attr in init_fields:
            kwargs[attr] = deserialize_value(
                node_id,
                attr,
                raw,
                serialization,
                visited,
                lazies,
            )
        else:
            deferred.append(attr)
    ret = cls(**kwargs)  # type: ignore
    # set the rest of the attributes from data
    for attr in deferred:
        setattr(
            ret,
            attr,
            deserialize_value(
                node_id,
                attr,
                data[attr],
                serialization,
                visited,
                lazies,
            ),
        )

    serialization[node_id] = ret

    return ret


def deserialize_ast(
    node_id: int,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
) -> Node:
    """
    Deserializes an AST from its serialization and returns the root node.

    Args:
        node_id: key of the root node, i.e. `id(root)` at serialization time.
        serialization: the dictionary produced by `serialize_ast`, either
            as is or after a JSON round trip.

    The payload passed in is left unchanged, so it can be deserialized again
    or stored afterwards.
    """
    # Work on a copy: entries are replaced by live nodes while rebuilding.
    working = {_coerce_node_key(key): entry for key, entry in serialization.items()}
    lazies: List[LazyNode] = []
    ret = _deserialize_ast(
        None,
        None,
        _coerce_node_key(node_id),
        working,
        set(),
        lazies,
    )
    for lazy in lazies:
        lazy.finalize()
    return ret
def test_compile_node_serde(construction_session: Session):
    """
    tests that a compiled ast survives a serialization-deserialization round trip
    """
    revision = NodeRevision(
        node=Node(name="A", current_version="1"),
        version="1",
        query="SELECT country FROM basic.transform.country_agg",
    )
    tree = parse(revision.query)
    tree.compile(
        CompileContext(session=construction_session, exception=DJException()),
    )
    round_tripped = deserialize_ast(id(tree), serialize_ast(tree))
    assert tree.compare(round_tripped)