Skip to content

Commit

Permalink
Update json serializer so that we automatically short-circuit circula…
Browse files Browse the repository at this point in the history
…r references and thus can serialize more of the AST
  • Loading branch information
shangyian committed Aug 9, 2023
1 parent 96e9bf7 commit b7eff1c
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 84 deletions.
9 changes: 9 additions & 0 deletions datajunction-server/datajunction_server/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,15 @@ def has_available_materialization(self, build_criteria: BuildCriteria) -> bool:
)
)

def __json_encode__(self):
    """
    Build the compact JSON form of a node revision.

    Only the ``name`` and ``type`` attributes are emitted.
    """
    encoded = {}
    encoded["name"] = self.name
    encoded["type"] = self.type
    return encoded


class ImmutableNodeFields(BaseSQLModel):
"""
Expand Down
36 changes: 30 additions & 6 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ class Node(ABC):
def json_ignore_keys(self):
return ["parent", "parent_key", "_is_compiled"]

def __json_encode__(self):
    """
    Collect the instance attributes to serialize to JSON, leaving out
    every attribute whose name appears in ``json_ignore_keys``.
    """
    encoded = {}
    for attr, value in vars(self).items():
        if attr in self.json_ignore_keys:
            continue
        encoded[attr] = value
    return encoded

def __post_init__(self):
self.add_self_as_parent()

Expand Down Expand Up @@ -628,6 +635,10 @@ def identifier(self, quotes: bool = True) -> str:
f"{namespace}{quote_style}{self.name}{quote_style}" # pylint: disable=C0301
)

@property
def json_ignore_keys(self):
    """Attribute names that are left out when this object is encoded to JSON."""
    ignored = ["names", "parent", "parent_key"]
    return ignored


TNamed = TypeVar("TNamed", bound="Named") # pylint: disable=C0103

Expand Down Expand Up @@ -711,9 +722,7 @@ class Column(Aliasable, Named, Expression):

@property
def json_ignore_keys(self):
if set(self._expression.columns).intersection(self.columns):
return ["parent", "parent_key", "_is_compiled", "_expression", "columns"]
return ["parent", "parent_key", "_is_compiled", "columns"]
return ["parent", "parent_key", "columns"]

@property
def type(self):
Expand Down Expand Up @@ -1000,10 +1009,11 @@ def json_ignore_keys(self):
return [
"parent",
"parent_key",
"_is_compiled",
# "_is_compiled",
"_columns",
"column_list",
# "column_list",
"_ref_columns",
"columns",
]

@property
Expand Down Expand Up @@ -1250,6 +1260,11 @@ class BinaryOpKind(DJEnum):
Minus = "-"
Modulo = "%"

def __json_encode__(self):
    """Encode this member as a dict holding only its ``value``."""
    encoded = {"value": self.value}
    return encoded


@dataclass(eq=False)
class BinaryOp(Operation):
Expand Down Expand Up @@ -2026,7 +2041,16 @@ class FunctionTable(FunctionTableExpression):

@property
def json_ignore_keys(self):
return ["parent", "parent_key", "_is_compiled", "_table"]
return [
"parent",
"parent_key",
"_is_compiled",
"_table",
"_columns",
"column_list",
"_ref_columns",
"columns",
]

def __str__(self) -> str:
alias = f" {self.alias}" if self.alias else ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
"""
from json import JSONEncoder

from sqlmodel import select

from datajunction_server.models import Node
from datajunction_server.sql.parsing import ast


def remove_circular_refs(obj, _seen: set = None):
    """
    Short-circuit circular references in AST nodes.

    Walks the attributes of ``obj`` (when it is an ``ast.Node``) that are not
    listed in its ``json_ignore_keys`` and replaces, in place, any attribute
    value that points back to an object currently being visited with ``None``.

    ``_seen`` holds the ``id`` of every object on the current traversal path;
    callers normally leave it unset. Returns ``None`` when ``obj`` is already
    on that path, otherwise returns ``obj`` itself.
    """
    visiting = set() if _seen is None else _seen
    obj_id = id(obj)
    if obj_id in visiting:
        return None
    visiting.add(obj_id)
    if issubclass(obj.__class__, ast.Node):
        # Snapshot the attribute names first, since the loop reassigns them.
        attrs_to_walk = [
            attr for attr in obj.__dict__ if attr not in obj.json_ignore_keys
        ]
        for attr in attrs_to_walk:
            cleaned = remove_circular_refs(getattr(obj, attr), visiting)
            setattr(obj, attr, cleaned)
    # Only the current path is tracked, so the id is released on the way out.
    visiting.remove(obj_id)
    return obj


class ASTEncoder(JSONEncoder):
"""
Expand All @@ -12,26 +36,50 @@ class ASTEncoder(JSONEncoder):
"""

def __init__(self, *args, **kwargs):
kwargs["check_circular"] = False # no need to check anymore
kwargs["check_circular"] = False
self.markers = set()
super().__init__(*args, **kwargs)
self._processed = set()

def default(self, o):
if id(o) in self._processed:
return None
self._processed.add(id(o))

if o.__class__.__name__ == "NodeRevision":
return {
"__class__": o.__class__.__name__,
"name": o.name,
"type": o.type,
}

o = remove_circular_refs(o)
json_dict = {
k: o.__dict__[k]
for k in o.__dict__
if hasattr(o, "json_ignore_keys") and k not in o.json_ignore_keys
"__class__": o.__class__.__name__,
}
json_dict["__class__"] = o.__class__.__name__
if hasattr(o, "__json_encode__"):
json_dict = {**json_dict, **o.__json_encode__()}
return json_dict


def ast_decoder(session, json_dict):
    """
    Rebuild an object from a decoded JSON dict that carries a ``"__class__"`` entry.

    Args:
        session: database session, used only to load the node when the dict
            describes a ``NodeRevision``.
        json_dict: decoded JSON object; ``"__class__"`` names the class to
            rebuild and the remaining entries are its attributes.

    Returns:
        The rebuilt instance, or ``None`` when ``json_dict`` has no (or an
        empty) ``"__class__"`` entry or names something the ``ast`` module
        does not define.
    """
    # ``.get`` lets a dict without a class marker reach the ``None`` guard
    # below instead of raising ``KeyError`` before the guard can run.
    class_name = json_dict.get("__class__")
    if not class_name or not hasattr(ast, class_name):
        return None
    clazz = getattr(ast, class_name)
    if class_name == "NodeRevision":
        # Node revisions are not rebuilt from the JSON payload; the current
        # revision of the node with the encoded name is loaded instead.
        instance = (
            session.exec(select(Node).where(Node.name == json_dict["name"]))
            .one()
            .current
        )
    else:
        instance = clazz(
            **{
                k: v
                for k, v in json_dict.items()
                if k not in {"__class__", "_type", "laterals", "_is_compiled"}
            },
        )
    for key, value in json_dict.items():
        if key not in {"__class__", "_is_compiled"}:
            try:
                setattr(instance, key, value)
            except AttributeError:
                # Best effort: attributes that cannot be assigned are skipped.
                pass

    if class_name == "Table":
        # Columns are rebuilt from the table's attached DJ node.
        instance._columns = [  # pylint: disable=protected-access
            ast.Column(ast.Name(col.name), _table=instance, _type=col.type)
            for col in instance._dj_node.columns  # pylint: disable=protected-access
        ]
    return instance
5 changes: 5 additions & 0 deletions datajunction-server/datajunction_server/sql/parsing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def __str__(self):
def __deepcopy__(self, memo):
return self

def __json_encode__(self):
    """Encode this object as a dict carrying only its class name."""
    class_name = self.__class__.__name__
    return {"__class__": class_name}

@classmethod
def __get_validators__(cls) -> Generator[AnyCallable, None, None]:
"""
Expand Down
Loading

0 comments on commit b7eff1c

Please sign in to comment.