Skip to content

Commit

Permalink
Simplify flow config (#1554)
Browse files Browse the repository at this point in the history
* Flatten compute_communities config

* Remove cluster strategy type

* Flatten create_base_text_units config

* Move cluster seed to config default, leave as None in functions

* Remove "prechunked" logic

* Remove hard-coded encoding model

* Remove unused variables

* Strongly type embed_config

* Simplify layout_graph config

* Semver

* Fix integration test

* Fix config unit tests: ignore new config defaults

* Remove pipeline integ test
  • Loading branch information
natoverse authored Dec 28, 2024
1 parent e6de713 commit a2647da
Show file tree
Hide file tree
Showing 44 changed files with 285 additions and 626 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241224192900934104.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Simplify and streamline internal config."
}
17 changes: 14 additions & 3 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
from graphrag.config.input_models.llm_config_input import LLMConfigInput
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
Expand Down Expand Up @@ -318,13 +318,16 @@ def hydrate_parallelization_params(
reader.envvar_prefix(Section.node2vec),
reader.use(values.get("embed_graph")),
):
use_lcc = reader.bool("use_lcc")
embed_graph_model = EmbedGraphConfig(
enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED,
dimensions=reader.int("dimensions") or defs.NODE2VEC_DIMENSIONS,
num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS,
walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH,
window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE,
iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS,
random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED,
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
)
with reader.envvar_prefix(Section.input), reader.use(values.get("input")):
input_type = reader.str("type")
Expand Down Expand Up @@ -412,12 +415,15 @@ def hydrate_parallelization_params(
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)

strategy = reader.str("strategy")
chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=encoding_model,
strategy=ChunkStrategyType(strategy)
if strategy
else ChunkStrategyType.tokens,
)
with (
reader.envvar_prefix(Section.snapshot),
Expand Down Expand Up @@ -522,8 +528,13 @@ def hydrate_parallelization_params(
)

with reader.use(values.get("cluster_graph")):
use_lcc = reader.bool("use_lcc")
seed = reader.int("seed")
cluster_graph_model = ClusterGraphConfig(
max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE
max_cluster_size=reader.int("max_cluster_size")
or defs.MAX_CLUSTER_SIZE,
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
seed=seed if seed is not None else defs.CLUSTER_GRAPH_SEED,
)

with (
Expand Down
3 changes: 3 additions & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
CLAIM_MAX_GLEANINGS = 1
CLAIM_EXTRACTION_ENABLED = False
MAX_CLUSTER_SIZE = 10
USE_LCC = True
CLUSTER_GRAPH_SEED = 0xDEADBEEF
COMMUNITY_REPORT_MAX_LENGTH = 2000
COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000
ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"]
Expand All @@ -74,6 +76,7 @@
PARALLELIZATION_STAGGER = 0.3
PARALLELIZATION_NUM_THREADS = 50
NODE2VEC_ENABLED = False
NODE2VEC_DIMENSIONS = 1536
NODE2VEC_NUM_WALKS = 10
NODE2VEC_WALK_LENGTH = 40
NODE2VEC_WINDOW_SIZE = 2
Expand Down
4 changes: 2 additions & 2 deletions graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
### LLM settings ###
## There are a number of settings to tune the threading and token limits for LLM calls - check the docs.
encoding_model: cl100k_base # this needs to be matched to your model!
encoding_model: {defs.ENCODING_MODEL} # this needs to be matched to your model!
llm:
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
Expand Down Expand Up @@ -111,7 +111,7 @@
enabled: false # if true, will generate node2vec embeddings for nodes
umap:
enabled: false # if true, will generate UMAP embeddings for nodes
enabled: false # if true, will generate UMAP embeddings for nodes (embed_graph must also be enabled)
snapshots:
graphml: false
Expand Down
34 changes: 17 additions & 17 deletions graphrag/config/models/chunking_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,24 @@

"""Parameterization settings for the default configuration."""

from enum import Enum

from pydantic import BaseModel, Field

import graphrag.config.defaults as defs


class ChunkStrategyType(str, Enum):
    """Enumerates the chunking strategies that can be selected in configuration.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    # Selected by the string "tokens".
    tokens = "tokens"
    # Selected by the string "sentence".
    sentence = "sentence"

    def __repr__(self):
        """Return the member's value wrapped in double quotes."""
        return f'"{self.value}"'


class ChunkingConfig(BaseModel):
"""Configuration section for chunking."""

Expand All @@ -19,22 +32,9 @@ class ChunkingConfig(BaseModel):
description="The chunk by columns to use.",
default=defs.CHUNK_GROUP_BY_COLUMNS,
)
strategy: dict | None = Field(
description="The chunk strategy to use, overriding the default tokenization strategy",
default=None,
strategy: ChunkStrategyType = Field(
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL
)

def resolved_strategy(self, encoding_model: str | None) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.operations.chunk_text import ChunkStrategyType

return self.strategy or {
"type": ChunkStrategyType.tokens,
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
"encoding_name": encoding_model or self.encoding_model,
}
18 changes: 7 additions & 11 deletions graphrag/config/models/cluster_graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@ class ClusterGraphConfig(BaseModel):
max_cluster_size: int = Field(
description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE
)
strategy: dict | None = Field(
description="The cluster strategy to use.", default=None
use_lcc: bool = Field(
description="Whether to use the largest connected component.",
default=defs.USE_LCC,
)
seed: int | None = Field(
description="The seed to use for the clustering.",
default=defs.CLUSTER_GRAPH_SEED,
)

def resolved_strategy(self) -> dict:
"""Get the resolved cluster strategy."""
from graphrag.index.operations.cluster_graph import GraphCommunityStrategyType

return self.strategy or {
"type": GraphCommunityStrategyType.leiden,
"max_cluster_size": self.max_cluster_size,
}
23 changes: 6 additions & 17 deletions graphrag/config/models/embed_graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class EmbedGraphConfig(BaseModel):
description="A flag indicating whether to enable node2vec.",
default=defs.NODE2VEC_ENABLED,
)
dimensions: int = Field(
description="The node2vec vector dimensions.", default=defs.NODE2VEC_DIMENSIONS
)
num_walks: int = Field(
description="The node2vec number of walks.", default=defs.NODE2VEC_NUM_WALKS
)
Expand All @@ -30,21 +33,7 @@ class EmbedGraphConfig(BaseModel):
random_seed: int = Field(
description="The node2vec random seed.", default=defs.NODE2VEC_RANDOM_SEED
)
strategy: dict | None = Field(
description="The graph embedding strategy override.", default=None
use_lcc: bool = Field(
description="Whether to use the largest connected component.",
default=defs.USE_LCC,
)

def resolved_strategy(self) -> dict:
"""Get the resolved node2vec strategy."""
from graphrag.index.operations.embed_graph.typing import (
EmbedGraphStrategyType,
)

return self.strategy or {
"type": EmbedGraphStrategyType.node2vec,
"num_walks": self.num_walks,
"walk_length": self.walk_length,
"window_size": self.window_size,
"iterations": self.iterations,
"random_seed": self.iterations,
}
2 changes: 0 additions & 2 deletions graphrag/config/models/entity_extraction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
if self.prompt
else None,
"max_gleanings": self.max_gleanings,
# It's prechunked in create_base_text_units
"encoding_name": encoding_model or self.encoding_model,
"prechunked": True,
}
16 changes: 4 additions & 12 deletions graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,8 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_base_text_units,
config={
"chunks": settings.chunks,
"snapshot_transient": settings.snapshots.transient,
"chunk_by": settings.chunks.group_by_columns,
"text_chunk": {
"strategy": settings.chunks.resolved_strategy(
settings.encoding_model
)
},
},
),
PipelineWorkflowReference(
Expand Down Expand Up @@ -243,9 +238,7 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
PipelineWorkflowReference(
name=compute_communities,
config={
"cluster_graph": {
"strategy": settings.cluster_graph.resolved_strategy()
},
"cluster_graph": settings.cluster_graph,
"snapshot_transient": settings.snapshots.transient,
},
),
Expand All @@ -260,9 +253,8 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
PipelineWorkflowReference(
name=create_final_nodes,
config={
"layout_graph_enabled": settings.umap.enabled,
"embed_graph_enabled": settings.embed_graph.enabled,
"embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
"layout_enabled": settings.umap.enabled,
"embed_graph": settings.embed_graph,
},
),
]
Expand Down
10 changes: 6 additions & 4 deletions graphrag/index/flows/compute_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""All the steps to create the base entity graph."""

from typing import Any

import pandas as pd

from graphrag.index.operations.cluster_graph import cluster_graph
Expand All @@ -13,14 +11,18 @@

def compute_communities(
base_relationship_edges: pd.DataFrame,
clustering_strategy: dict[str, Any],
max_cluster_size: int,
use_lcc: bool,
seed: int | None = None,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)

communities = cluster_graph(
graph,
strategy=clustering_strategy,
max_cluster_size,
use_lcc,
seed=seed,
)

base_communities = pd.DataFrame(
Expand Down
40 changes: 25 additions & 15 deletions graphrag/index/flows/create_base_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
aggregate_operation_mapping,
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.utils.hashing import gen_sha512_hash


def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
group_by_columns: list[str],
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
Expand All @@ -35,7 +39,7 @@ def create_base_text_units(

aggregated = _aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
groupby=[*group_by_columns] if len(group_by_columns) > 0 else None,
aggregations=[
{
"column": "text_with_ids",
Expand All @@ -47,30 +51,36 @@ def create_base_text_units(

callbacks.progress(Progress(percent=1))

chunked = chunk_text(
aggregated["chunks"] = chunk_text(
aggregated,
column="texts",
to="chunks",
size=size,
overlap=overlap,
encoding_model=encoding_model,
strategy=strategy,
callbacks=callbacks,
strategy=chunk_strategy,
)

chunked = cast("pd.DataFrame", chunked[[*chunk_by_columns, "chunks"]])
chunked = chunked.explode("chunks")
chunked.rename(
aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
aggregated = aggregated.explode("chunks")
aggregated.rename(
columns={
"chunks": "chunk",
},
inplace=True,
)
chunked["id"] = chunked.apply(lambda row: gen_sha512_hash(row, ["chunk"]), axis=1)
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
chunked["chunk"].tolist(), index=chunked.index
aggregated["id"] = aggregated.apply(
lambda row: gen_sha512_hash(row, ["chunk"]), axis=1
)
aggregated[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
aggregated["chunk"].tolist(), index=aggregated.index
)
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)
aggregated.rename(columns={"chunk": "text"}, inplace=True)

return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
return cast(
"pd.DataFrame", aggregated[aggregated["text"].notna()].reset_index(drop=True)
)


# TODO: would be nice to inline this completely in the main method with pandas
Expand Down
Loading

0 comments on commit a2647da

Please sign in to comment.