Skip to content

Commit cf6662c

Browse files
committed
async_to_sync wrapper
1 parent a2aea05 commit cf6662c

File tree

4 files changed

+18
-14
lines changed

4 files changed

+18
-14
lines changed

src/neo4j_graphrag/experimental/components/resolver.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from neo4j_graphrag.experimental.components.types import ResolutionStats
2121
from neo4j_graphrag.experimental.pipeline import Component
22+
from neo4j_graphrag.utils import async_to_sync
2223

2324

2425
class EntityResolver(Component, abc.ABC):
@@ -136,3 +137,5 @@ async def run(self) -> ResolutionStats:
136137
number_of_nodes_to_resolve=number_of_nodes_to_resolve,
137138
number_of_created_nodes=number_of_created_nodes,
138139
)
140+
141+
run_sync = async_to_sync(run)

src/neo4j_graphrag/experimental/pipeline/component.py

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import BaseModel
2222

2323
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
24+
from neo4j_graphrag.utils import async_to_sync
2425

2526

2627
class DataModel(BaseModel):
@@ -63,6 +64,8 @@ def __new__(
6364
}
6465
for f, field in return_model.model_fields.items()
6566
}
67+
# create sync method:
68+
attrs["run_sync"] = async_to_sync(run_method)
6669
return type.__new__(meta, name, bases, attrs)
6770

6871

src/neo4j_graphrag/experimental/pipeline/pipeline.py

+6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from timeit import default_timer
2525
from typing import Any, AsyncGenerator, Optional
2626

27+
from neo4j_graphrag.utils import async_to_sync
28+
2729
try:
2830
import pygraphviz as pgv
2931
except ImportError:
@@ -107,6 +109,8 @@ async def run(self, inputs: dict[str, Any]) -> RunResult | None:
107109
logger.debug(f"TASK RESULT {self.name=} {res=}")
108110
return res
109111

112+
run_sync = async_to_sync(run)
113+
110114

111115
class Orchestrator:
112116
"""Orchestrate a pipeline.
@@ -638,3 +642,5 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
638642
run_id=orchestrator.run_id,
639643
result=await self.final_results.get(orchestrator.run_id),
640644
)
645+
646+
run_sync = async_to_sync(run)

src/neo4j_graphrag/utils.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
from functools import wraps
1718
from typing import Optional
1819
import asyncio
1920
import concurrent.futures
@@ -33,17 +34,8 @@ def run_sync(function, *args, **kwargs):
3334
return return_value
3435

3536

36-
if __name__ == "__main__":
37-
async def async_run(char: str, repeat: int = 2) -> str:
38-
await asyncio.sleep(5)
39-
return char * repeat
40-
41-
async def async_run_multiple(char, n=10):
42-
return await asyncio.gather(*[
43-
async_run(char)
44-
for _ in range(n)
45-
])
46-
47-
print(
48-
run_sync(async_run_multiple, "abc")
49-
)
37+
def async_to_sync(func):
38+
@wraps(func)
39+
def wrapper(*args, **kwargs):
40+
return run_sync(func, *args, **kwargs)
41+
return wrapper

0 commit comments

Comments
 (0)