Skip to content

Commit

Permalink
refactor: update typing to union rather than pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
dhrudevalia committed Jan 2, 2025
1 parent 224be4d commit 1fc1c1c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/templates/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
]

Arrow = Union[pa.Table, pa.RecordBatch]
ArrowData = Union[pa.Table, Iterable[pa.RecordBatch]]
DataStream = Generator[Arrow, None, None]


Expand Down Expand Up @@ -80,7 +81,7 @@ def fn(x: Any) -> Any:
return [x for y in map(fn, lists) for x in y]


def map_tables(model: Graph, entity_type: str) -> Callable[[DataStream], list[Table | Iterable[RecordBatch]]]:
def map_tables(model: Graph, entity_type: str) -> Callable[[DataStream], List[ArrowData]]:
def map_entity(table: Arrow, mapper: MappingFn):
if not table:
raise Exception("empty iterable of record batches provided")
Expand All @@ -89,7 +90,7 @@ def map_entity(table: Arrow, mapper: MappingFn):
mapped_table = [mapped_table]
return mapped_table

def _map_arrow_tables(table: DataStream) -> list[Table | Iterable[RecordBatch]]:
def _map_arrow_tables(table: DataStream) -> List[ArrowData]:
source_field = "_table"
if entity_type == "NODE":
mapper = node_mapper(model, source_field)
Expand All @@ -105,13 +106,13 @@ def _map_arrow_tables(table: DataStream) -> list[Table | Iterable[RecordBatch]]:
def send_nodes(
client: GdsArrowClient,
model: Optional[Graph] = None,
) -> Callable[[Table | Iterable[RecordBatch]], int]:
) -> Callable[[ArrowData], int]:
"""
Wrap the given client, model, and (optional) source_field in a function that
streams PyArrow data (Table or RecordBatch) to Neo4j as nodes.
"""

def _send_nodes_gds(table: Table | Iterable[RecordBatch]) -> int:
def _send_nodes_gds(table: ArrowData) -> int:
nodes_progress = ProgressTracker()

client.upload_nodes(model.name, table, progress_callback=nodes_progress.callback)
Expand All @@ -123,13 +124,13 @@ def _send_nodes_gds(table: Table | Iterable[RecordBatch]) -> int:
def send_edges(
client: GdsArrowClient,
model: Optional[Graph] = None,
) -> Callable[[list[Table | Iterable[RecordBatch]]], int]:
) -> Callable[[ArrowData], int]:
"""
Wrap the given client, model, and (optional) source_field in a function that
streams PyArrow data (Table or RecordBatch) to Neo4j as relationships.
"""

def _send_edges_gds(table: list[Table | Iterable[RecordBatch]]) -> int:
def _send_edges_gds(table: ArrowData) -> int:
edge_progress = ProgressTracker()

client.upload_relationships(model.name, table, progress_callback=edge_progress.callback)
Expand Down

0 comments on commit 1fc1c1c

Please sign in to comment.