Commit 9119ef3

format check

gaudyb committed Dec 27, 2024
1 parent 71a411d commit 9119ef3
Showing 9 changed files with 55 additions and 44 deletions.
5 changes: 4 additions & 1 deletion graphrag/api/query.py
@@ -423,7 +423,8 @@ async def drift_search(
             return response, context_data
         case list():
             return response, context_data
 
+
 @validate_call(config={"arbitrary_types_allowed": True})
 async def basic_search(
     config: GraphRagConfig,
@@ -472,6 +473,7 @@ async def basic_search
     context_data = _reformat_context_data(result.context_data)  # type: ignore
     return response, context_data
 
+
 @validate_call(config={"arbitrary_types_allowed": True})
 async def basic_search_streaming(
     config: GraphRagConfig,
@@ -525,6 +527,7 @@ async def basic_search_streaming(
     else:
         yield stream_chunk
 
+
 def _get_embedding_store(
     config_args: dict,
     embedding_name: str,
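All three hunks in this file make the same fix: the formatter (ruff format, by the look of this commit) follows PEP 8 in requiring two blank lines before a top-level definition, not one. The same two-blank-line change recurs in several files below. A minimal before/after sketch with throwaway function names:

# before: a single blank line between top-level definitions
def first():
    return 1

def second():
    return 2


# after formatting: two blank lines separate top-level definitions
def first():
    return 1


def second():
    return 2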
2 changes: 2 additions & 0 deletions graphrag/cli/query.py
@@ -256,6 +256,7 @@ def run_drift_search(
     # TODO: Map/Reduce Drift Search answer to a single response
     return response, context_data
 
+
 def run_basic_search(
     config_filepath: Path | None,
     data_dir: Path | None,
@@ -318,6 +319,7 @@ async def run_streaming_search():
     # External users should use the API directly to get the response and context data.
     return response, context_data
 
+
 def _resolve_output_files(
     config: GraphRagConfig,
     output_list: list[str],
2 changes: 1 addition & 1 deletion graphrag/config/create_graphrag_config.py
@@ -675,7 +675,7 @@ def hydrate_parallelization_params(
         local_search=local_search_model,
         global_search=global_search_model,
         drift_search=drift_search_model,
-        basic_search=basic_search_model
+        basic_search=basic_search_model,
     )
 
 
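The only change here is a trailing comma after the final keyword argument. Black-style formatters such as ruff format treat this "magic trailing comma" as a signal: once present, the call is kept expanded with one argument per line rather than collapsed. A small sketch with placeholder values:

# without a trailing comma, the formatter may join a short call onto one line
settings = dict(local_search=1, global_search=2, drift_search=3, basic_search=4)

# with a magic trailing comma after the last argument, it stays one-per-line
settings = dict(
    local_search=1,
    global_search=2,
    drift_search=3,
    basic_search=4,
)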
2 changes: 1 addition & 1 deletion graphrag/prompts/query/basic_search_system_prompt.py
@@ -65,4 +65,4 @@
 
 {response_type}
 Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
-"""
\ No newline at end of file
+"""
3 changes: 2 additions & 1 deletion graphrag/query/context_builder/builders.py
@@ -61,6 +61,7 @@ def build_context(
     ) -> tuple[pd.DataFrame, dict[str, int]]:
         """Build the context for the primer search actions."""
 
+
 class BasicContextBuilder(ABC):
     """Base class for basic-search context builders."""
 
@@ -71,4 +72,4 @@ def build_context(
         conversation_history: ConversationHistory | None = None,
         **kwargs,
     ) -> ContextBuilderResult:
-        """Build the context for the basic search mode."""
\ No newline at end of file
+        """Build the context for the basic search mode."""
51 changes: 24 additions & 27 deletions graphrag/query/factory.py
@@ -196,6 +196,7 @@ def get_drift_search_engine(
         token_encoder=token_encoder,
     )
 
+
 def get_basic_search_engine(
     text_units: list[TextUnit],
     text_unit_embeddings: BaseVectorStore,
@@ -210,31 +211,27 @@ def get_basic_search_engine(
     ls_config = config.basic_search
 
     return BasicSearch(
-        llm=llm,
-        system_prompt=system_prompt,
-        context_builder=BasicSearchContext(
-        text_embedder=text_embedder,
-        text_unit_embeddings=text_unit_embeddings,
-        text_units=text_units,
-        token_encoder=token_encoder,
-        ),
-        llm_params={
-            "max_tokens": ls_config.llm_max_tokens,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
-            "temperature": ls_config.temperature,
-            "top_p": ls_config.top_p,
-            "n": ls_config.n,
-        },
-        context_builder_params={
-            "text_unit_prop": ls_config.text_unit_prop,
-            "conversation_history_max_turns": ls_config.conversation_history_max_turns,
-            "conversation_history_user_turns_only": True,
-            "return_candidate_context": False,
-            "embedding_vectorstore_key": "id",
-            "max_tokens": ls_config.max_tokens,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
-        },
-    )
-
-
-
-
+        llm=llm,
+        system_prompt=system_prompt,
+        context_builder=BasicSearchContext(
+            text_embedder=text_embedder,
+            text_unit_embeddings=text_unit_embeddings,
+            text_units=text_units,
+            token_encoder=token_encoder,
+        ),
+        token_encoder=token_encoder,
+        llm_params={
+            "max_tokens": ls_config.llm_max_tokens,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
+            "temperature": ls_config.temperature,
+            "top_p": ls_config.top_p,
+            "n": ls_config.n,
+        },
+        context_builder_params={
+            "text_unit_prop": ls_config.text_unit_prop,
+            "conversation_history_max_turns": ls_config.conversation_history_max_turns,
+            "conversation_history_user_turns_only": True,
+            "return_candidate_context": False,
+            "embedding_vectorstore_key": "id",
+            "max_tokens": ls_config.max_tokens,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
+        },
+    )
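Besides re-indenting the BasicSearch construction, this hunk shows two separate token budgets: context_builder_params["max_tokens"] caps the retrieved text-unit context, while llm_params["max_tokens"] caps the generated answer, and the inline comments suggest roughly 5000 and 1000-1500 respectively for an 8k-context model. A back-of-the-envelope sketch of that split; the 8k window and the concrete numbers are illustrations taken from those comments, not defaults:

# Illustrative token budget for an assumed 8k-context model.
MODEL_CONTEXT_WINDOW = 8_000

context_budget = 5_000  # context_builder_params["max_tokens"]: retrieved text units
answer_budget = 1_500   # llm_params["max_tokens"]: generated answer

remaining = MODEL_CONTEXT_WINDOW - (context_budget + answer_budget)
print(f"tokens left for the system prompt and conversation history: {remaining}")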
8 changes: 7 additions & 1 deletion graphrag/query/structured_search/base.py
@@ -42,7 +42,13 @@ class SearchResult:
     output_tokens_categories: dict[str, int] | None = None
 
 
-T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder, BasicContextBuilder)
+T = TypeVar(
+    "T",
+    GlobalContextBuilder,
+    LocalContextBuilder,
+    DRIFTContextBuilder,
+    BasicContextBuilder,
+)
 
 
 class BaseSearch(ABC, Generic[T]):
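The reflowed definition is a constrained TypeVar: T may only be one of the four listed context-builder types, which lets BaseSearch(ABC, Generic[T]) pair each search engine with exactly one builder family. A self-contained sketch of the same pattern, using hypothetical builder classes rather than graphrag's:

from typing import Generic, TypeVar


class LocalBuilder:
    """Stand-in for one context-builder family (hypothetical)."""


class GlobalBuilder:
    """Stand-in for another context-builder family (hypothetical)."""


# A constrained TypeVar: T must be exactly LocalBuilder or GlobalBuilder.
T = TypeVar("T", LocalBuilder, GlobalBuilder)


class BaseSearch(Generic[T]):
    def __init__(self, context_builder: T) -> None:
        self.context_builder = context_builder


search = BaseSearch(LocalBuilder())  # T is inferred as LocalBuilder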
2 changes: 1 addition & 1 deletion graphrag/query/structured_search/basic_search/basic_context.py
@@ -51,7 +51,7 @@ def build_context(
             for r in search_results
         ]
         # make a delimited table for the context; this imitates graphrag context building
-        table = ["id|text"] + [f'{s["id"]}|{s["text"]}' for s in sources]
+        table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
 
         return ContextBuilderResult(
             context_chunks="\n\n".join(table),
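This one-line change only swaps which quote style is nested. Before Python 3.12 (PEP 701), an f-string could not reuse its own quote character inside a replacement field, so the formatter standardizes on double-quoted f-strings with single-quoted dictionary keys inside; the result is identical. A standalone check with sample rows:

sources = [{"id": 1, "text": "alpha"}, {"id": 2, "text": "beta"}]

old = ["id|text"] + [f'{s["id"]}|{s["text"]}' for s in sources]  # old spelling
new = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]  # formatted spelling

assert old == new  # purely stylistic; rows render as "1|alpha", "2|beta"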
24 changes: 13 additions & 11 deletions graphrag/query/structured_search/basic_search/search.py
@@ -29,6 +29,7 @@
 Implementation of a generic RAG algorithm (vector search on raw text chunks)
 """
 
+
 class BasicSearch(BaseSearch[BasicContextBuilder]):
     """Search orchestration for basic search mode."""
 
@@ -79,7 +80,8 @@ async def asearch(
         log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
         try:
             search_prompt = self.system_prompt.format(
-                context_data=context_result.context_chunks, response_type=self.response_type
+                context_data=context_result.context_chunks,
+                response_type=self.response_type,
             )
             search_messages = [
                 {"role": "system", "content": search_prompt},
@@ -104,21 +106,21 @@ async def asearch(
                 completion_time=time.time() - start_time,
                 llm_calls=1,
                 prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-                output_tokens=sum(output_tokens.values())
+                output_tokens=sum(output_tokens.values()),
             )
 
         except Exception:
             log.exception("Exception in _asearch")
             return SearchResult(
-            response="",
-            context_data=context_result.context_records,
-            context_text=context_result.context_chunks,
-            completion_time=time.time() - start_time,
-            llm_calls=1,
-            prompt_tokens=num_tokens(search_prompt, self.token_encoder),
-            output_tokens=0,
-            )
+                response="",
+                context_data=context_result.context_records,
+                context_text=context_result.context_chunks,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=0,
+            )
 
     def search(
         self,
         query: str,
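The hunks above reflow two SearchResult constructions that charge the call's cost via num_tokens(search_prompt, self.token_encoder). For readers unfamiliar with that helper, here is a minimal sketch of what such a token counter can look like, assuming the tiktoken library is available; this is an illustration, not graphrag's actual implementation:

import tiktoken


def num_tokens(text: str, encoder: tiktoken.Encoding | None = None) -> int:
    """Count the tokens in text, defaulting to a common OpenAI encoding."""
    if encoder is None:
        encoder = tiktoken.get_encoding("cl100k_base")
    return len(encoder.encode(text))


print(num_tokens("GENERATE ANSWER"))  # a small integer; exact value depends on the encoding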
