forked from neo4j/neo4j-graphrag-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunner.py
138 lines (117 loc) · 4.86 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline config wrapper (router based on 'template_' key)
and pipeline runner.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import (
Annotated,
Any,
Union,
)
from pydantic import (
BaseModel,
Discriminator,
Field,
Tag,
)
from pydantic.v1.utils import deep_update
from typing_extensions import Self
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
AbstractPipelineConfig,
PipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import (
SimpleKGPipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
from neo4j_graphrag.utils.logging import prettify
# Module-level logger, named after this module; used by the config wrapper and the runner below.
logger = logging.getLogger(__name__)
def _get_discriminator_value(model: Any) -> PipelineType:
    """Return the pipeline type used to pick the right config class.

    The ``template_`` value is looked up first as a key (``"template_" in
    model``) and then as an attribute, the attribute taking precedence when
    both are present.

    Args:
        model: The raw value being validated for the ``config`` field;
            presumably a dict or an already-built config object.

    Returns:
        The PipelineType built from the ``template_`` value, or
        ``PipelineType.NONE`` when no ``template_`` value is provided.

    Raises:
        ValueError: If ``template_`` is set to something ``PipelineType``
            does not accept (assuming ``PipelineType`` is a standard Enum).
    """
    template_ = None
    if "template_" in model:
        template_ = model["template_"]
    if hasattr(model, "template_"):
        template_ = model.template_
    if template_ is None:
        # Without this guard PipelineType(None) is evaluated before the
        # `or` fallback below, so the default could never take effect.
        return PipelineType.NONE
    return PipelineType(template_) or PipelineType.NONE
class PipelineConfigWrapper(BaseModel):
    """The pipeline config wrapper will parse the right pipeline config based on the `template_` field."""

    # Tagged union of the supported config classes. `_get_discriminator_value`
    # computes a PipelineType from the raw input, and the member whose Tag
    # matches that value is the one used for validation.
    config: Union[
        Annotated[PipelineConfig, Tag(PipelineType.NONE)],
        Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)],
    ] = Field(discriminator=Discriminator(_get_discriminator_value))

    def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
        """Delegate parsing to the wrapped config.

        Args:
            resolved_data: Optional mapping forwarded unchanged to the
                wrapped config's ``parse`` method.

        Returns:
            The pipeline definition produced by the wrapped config.
        """
        logger.debug("PIPELINE_CONFIG: start parsing config...")
        definition = self.config.parse(resolved_data)
        return definition

    def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
        """Delegate the computation of run parameters to the wrapped config.

        Args:
            user_input: The user-provided input, forwarded unchanged.

        Returns:
            The run parameters returned by the wrapped config.
        """
        wrapped = self.config
        return wrapped.get_run_params(user_input)
class PipelineRunner:
    """Pipeline runner builds a pipeline from different objects and exposes a run method to run pipeline

    Pipeline can be built from:

    - A PipelineDefinition (`__init__` method)
    - A PipelineConfig (`from_config` method)
    - A config file (`from_config_file` method)
    """

    def __init__(
        self,
        pipeline_definition: PipelineDefinition,
        config: AbstractPipelineConfig | None = None,
        do_cleaning: bool = False,
    ) -> None:
        """Build the pipeline from its definition.

        Args:
            pipeline_definition: Definition the Pipeline is instantiated from;
                it also provides the base run parameters.
            config: Optional config object. When set, it is used to compute
                run parameters in `run` and is closed by `close`.
            do_cleaning: When True, `run` calls `close` once it is done.
        """
        self.config = config
        self.pipeline = Pipeline.from_definition(pipeline_definition)
        self.run_params = pipeline_definition.get_run_params()
        self.do_cleaning = do_cleaning

    @classmethod
    def from_config(
        cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
    ) -> Self:
        """Create a runner from a config object or a config dict.

        The config is validated through PipelineConfigWrapper, which selects
        the config class matching its `template_` value.

        Args:
            config: A config object or the equivalent dict.
            do_cleaning: Forwarded to `__init__`.

        Returns:
            A runner holding the parsed pipeline and the validated config.
        """
        wrapper = PipelineConfigWrapper.model_validate({"config": config})
        logger.debug(
            f"PIPELINE_RUNNER: instantiating Pipeline from config type: {wrapper.config.template_}"
        )
        return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)

    @classmethod
    def from_config_file(cls, file_path: Union[str, Path]) -> Self:
        """Create a runner from a config file.

        The runner is created with ``do_cleaning=True``, so `close` is called
        at the end of `run`.

        Args:
            file_path: Location of the config file, as a string or a Path.

        Returns:
            A runner built from the content of the file.
        """
        logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}")
        if not isinstance(file_path, str):
            # The reader is given a plain string path.
            file_path = str(file_path)
        data = ConfigReader().read(file_path)
        return cls.from_config(data, do_cleaning=True)

    async def run(self, user_input: dict[str, Any]) -> PipelineResult:
        """Run the pipeline with the given user input.

        The base run parameters from the pipeline definition are deep-updated
        with the parameters computed by the config from ``user_input`` or,
        when there is no config, with ``user_input`` itself.

        Args:
            user_input: Input provided by the caller for this run.

        Returns:
            The result returned by the underlying pipeline.
        """
        try:
            if self.config:
                run_param = deep_update(
                    self.run_params, self.config.get_run_params(user_input)
                )
            else:
                run_param = deep_update(self.run_params, user_input)
            logger.info(
                f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={prettify(run_param)}"
            )
            return await self.pipeline.run(data=run_param)
        finally:
            # Clean up even when the run fails; otherwise whatever the config
            # opened would stay open after an exception.
            if self.do_cleaning:
                await self.close()

    async def close(self) -> None:
        """Close the config, if any, so that it can release what it holds."""
        logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)")
        if self.config:
            await self.config.close()