feat: added pre-commit and tox integration
PeriniM committed Oct 7, 2024
1 parent cede480 commit 25dff8b
Showing 31 changed files with 425 additions and 178 deletions.
2 changes: 1 addition & 1 deletion .env.example
@@ -1,3 +1,3 @@
 OPENAI_API_KEY=sk-example-1234567890abcdef1234567890abcdef
 ANTHROPIC_API_KEY=sk-example-1234567890abcdef1234567890abcdef
-FIREWORKS_API_KEY=sk-example-1234567890abcdef1234567890abcdef
\ No newline at end of file
+FIREWORKS_API_KEY=sk-example-1234567890abcdef1234567890abcdef
3 changes: 2 additions & 1 deletion .gitignore
@@ -49,6 +49,7 @@ coverage.xml
 *.py,cover
 .hypothesis/
 .pytest_cache/
+.ruff_cache/
 cover/

 # Translations
@@ -147,4 +148,4 @@ cython_debug/
 # macOS
 .DS_Store

-dev.ipynb
\ No newline at end of file
+dev.ipynb
22 changes: 22 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,22 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 24.8.0
+    hooks:
+      - id: black
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.6.9
+    hooks:
+      - id: ruff
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
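With this file in the repository root, the hooks are enabled once per clone with "pre-commit install" and can be run against the whole tree with "pre-commit run --all-files" (standard pre-commit usage; the commands themselves are not part of this commit). The trailing-whitespace and end-of-file-fixer hooks account for most of the whitespace-only and end-of-file changes in the Python files below, while black, ruff, and isort explain the import reordering and line-wrapping churn.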
4 changes: 2 additions & 2 deletions brickllm/__init__.py
@@ -1,11 +1,11 @@
+from .configs import GraphConfig
 from .schemas import ElemListSchema, RelationshipsSchema, TTLSchema
 from .states import State
-from .configs import GraphConfig

 __all__ = [
     "ElemListSchema",
     "RelationshipsSchema",
     "TTLSchema",
     "State",
     "GraphConfig",
-]
\ No newline at end of file
+]
2 changes: 1 addition & 1 deletion brickllm/compiled_graphs.py
@@ -1,4 +1,4 @@
 from .graphs import BrickSchemaGraph

 # compiled graph
-brickschema_graph = BrickSchemaGraph()._compiled_graph()
\ No newline at end of file
+brickschema_graph = BrickSchemaGraph()._compiled_graph()
6 changes: 4 additions & 2 deletions brickllm/configs.py
@@ -1,6 +1,8 @@
-from typing import TypedDict, Literal, Union
+from typing import Literal, TypedDict, Union
+
 from langchain.chat_models.base import BaseChatModel

+
 # Define the config
 class GraphConfig(TypedDict):
-    model: Union[Literal["anthropic", "openai", "fireworks"], BaseChatModel]
\ No newline at end of file
+    model: Union[Literal["anthropic", "openai", "fireworks"], BaseChatModel]
1 change: 0 additions & 1 deletion brickllm/edges/__init__.py
@@ -1,4 +1,3 @@
 from .validate_condition import validate_condition

 __all__ = ["validate_condition"]
-
6 changes: 4 additions & 2 deletions brickllm/edges/validate_condition.py
@@ -1,6 +1,8 @@
 from typing import Literal

 # from langgraph.graph import END
+
+
 def validate_condition(state) -> Literal["schema_to_ttl", "__end__"]:
     """
     Validate the condition for the next node to visit.
@@ -11,11 +13,11 @@ def validate_condition(state) -> Literal["schema_to_ttl", "__end__"]:
     Returns:
         Literal["schema_to_ttl", "__end__"]: The next node to visit.
     """
-    
+
     is_valid = state.get("is_valid")
     max_iter = state.get("validation_max_iter")

     if max_iter > 0 and not is_valid:
         return "schema_to_ttl"

-    return "__end__"
\ No newline at end of file
+    return "__end__"
2 changes: 1 addition & 1 deletion brickllm/graphs/__init__.py
@@ -2,4 +2,4 @@

 __all__ = [
     "BrickSchemaGraph",
-]
\ No newline at end of file
+]
67 changes: 41 additions & 26 deletions brickllm/graphs/brickschema_graph.py
@@ -1,16 +1,21 @@
-from typing import Union
-from PIL import Image
 import os
+from typing import Union

 from langchain.chat_models.base import BaseChatModel
-from langgraph.graph import START, END, StateGraph
-from .. import State, GraphConfig
-from ..nodes import (
-    get_elements, get_elem_children, get_relationships,
-    schema_to_ttl, validate_schema, get_sensors
-)
+from langgraph.graph import END, START, StateGraph
+from PIL import Image
+
+from .. import GraphConfig, State
 from ..edges import validate_condition
 from ..helpers.llm_models import _get_model
+from ..nodes import (
+    get_elem_children,
+    get_elements,
+    get_relationships,
+    get_sensors,
+    schema_to_ttl,
+    validate_schema,
+)


 class BrickSchemaGraph:
@@ -21,13 +26,13 @@ def __init__(self, model: Union[str, BaseChatModel] = "openai"):
         Args:
             model (Union[str, BaseChatModel]): The model type as a string or an instance of BaseChatModel.
         """
-        
+
         # Define a new graph
         self.workflow = StateGraph(State, config_schema=GraphConfig)
-        
+
         # Store the model
         self.model = _get_model(model)
-        
+
         # Build graph by adding nodes
         self.workflow.add_node("get_elements", get_elements)
         self.workflow.add_node("get_elem_children", get_elem_children)
@@ -45,13 +50,13 @@ def __init__(self, model: Union[str, BaseChatModel] = "openai"):
         self.workflow.add_conditional_edges("validate_schema", validate_condition)
         self.workflow.add_edge("get_relationships", "get_sensors")
         self.workflow.add_edge("get_sensors", END)
-        
+
         # Compile graph
         try:
             self.graph = self.workflow.compile()
         except Exception as e:
             raise ValueError(f"Failed to compile the graph: {e}")
-        
+
         # Update the config with the model
         self.config = {"configurable": {"thread_id": "1", "llm_model": self.model}}

@@ -61,18 +66,22 @@ def __init__(self, model: Union[str, BaseChatModel] = "openai"):
     def _compiled_graph(self):
         """Check if the graph is compiled and return the compiled graph."""
         if self.graph is None:
-            raise ValueError("Graph is not compiled yet. Please compile the graph first.")
+            raise ValueError(
+                "Graph is not compiled yet. Please compile the graph first."
+            )
         return self.graph
-    
+
     def display(self, filename="graph.png") -> None:
         """Display the compiled graph as an image.
         Args:
             filename (str): The filename to save the graph image.
         """
         if self.graph is None:
-            raise ValueError("Graph is not compiled yet. Please compile the graph first.")
-        
+            raise ValueError(
+                "Graph is not compiled yet. Please compile the graph first."
+            )
+
         # Save the image to the specified file
         self.graph.get_graph().draw_mermaid_png(output_file_path=filename)

@@ -81,7 +90,9 @@ def display(self, filename="graph.png") -> None:
         with Image.open(filename) as img:
             img.show()
         else:
-            raise FileNotFoundError(f"Failed to generate the graph image file: {filename}")
+            raise FileNotFoundError(
+                f"Failed to generate the graph image file: {filename}"
+            )

     def run(self, prompt, stream=False):
         """Run the graph with the given user prompt.
@@ -95,9 +106,11 @@ def run(self, prompt, stream=False):
         if stream:
             events = []
             # Stream the content of the graph state at each node
-            for event in self.graph.stream(input_data, self.config, stream_mode="values"):
+            for event in self.graph.stream(
+                input_data, self.config, stream_mode="values"
+            ):
                 events.append(event)
-            
+
             # Store the last event as the result
             self.result = events[-1]
             return events
@@ -112,18 +125,20 @@ def get_state_snapshots(self) -> list:
         all_states = []
         for state in self.graph.get_state_history(self.config):
             all_states.append(state)
-        
+
         return all_states
-    
+
     def save_ttl_output(self, output_file="brick_schema_output.ttl"):
         """Save the TTL output to a file."""
         if self.result is None:
             raise ValueError("No result found. Please run the graph first.")
-        ttl_output = self.result.get('ttl_output', None)
-        
+        ttl_output = self.result.get("ttl_output", None)
+
         if ttl_output:
-            with open(output_file, 'w') as f:
+            with open(output_file, "w") as f:
                 f.write(ttl_output)
         else:
-            raise ValueError("No TTL output found in the result. Please run the graph with a valid prompt.")
+            raise ValueError(
+                "No TTL output found in the result. Please run the graph with a valid prompt."
+            )
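Taken together, BrickSchemaGraph compiles the LangGraph workflow at construction time and exposes run/save helpers. A minimal usage sketch built only from the methods visible in this diff (the prompt text is hypothetical):

    from brickllm.graphs import BrickSchemaGraph

    # Build the graph with one of the supported provider shortcuts.
    graph = BrickSchemaGraph(model="openai")

    # Run the workflow on a natural-language building description (hypothetical prompt).
    result = graph.run("Office building with two floors and a CO2 sensor in each room.", stream=False)

    # Persist the generated Brick model; raises if no TTL output was produced.
    graph.save_ttl_output(output_file="brick_schema_output.ttl")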
4 changes: 2 additions & 2 deletions brickllm/helpers/__init__.py
@@ -1,7 +1,7 @@
 from .llm_models import _get_model
 from .prompts import (
-    get_elem_instructions,
     get_elem_children_instructions,
+    get_elem_instructions,
     get_relationships_instructions,
     schema_to_ttl_instructions,
     ttl_example,
@@ -14,4 +14,4 @@
     "get_relationships_instructions",
     "schema_to_ttl_instructions",
     "ttl_example",
-]
\ No newline at end of file
+]
16 changes: 9 additions & 7 deletions brickllm/helpers/llm_models.py
@@ -1,10 +1,10 @@
 from typing import Union
-from dotenv import load_dotenv

+from dotenv import load_dotenv
+from langchain.chat_models.base import BaseChatModel
 from langchain_anthropic import ChatAnthropic
-from langchain_openai import ChatOpenAI
 from langchain_fireworks import ChatFireworks
-from langchain.chat_models.base import BaseChatModel
+from langchain_openai import ChatOpenAI


 def _get_model(model: Union[str, BaseChatModel]):
@@ -17,10 +17,10 @@ def _get_model(model: Union[str, BaseChatModel]):
     Returns:
         BaseChatModel: The LLM model instance.
     """
-    
+
     if isinstance(model, BaseChatModel):
         return model
-    
+
     # Load environment variables
     load_dotenv()

@@ -29,6 +29,8 @@ def _get_model(model: Union[str, BaseChatModel]):
     elif model == "anthropic":
         return ChatAnthropic(temperature=0, model="claude-3-sonnet-20240229")
     elif model == "fireworks":
-        return ChatFireworks(temperature=0, model="accounts/fireworks/models/llama-v3p1-70b-instruct")
+        return ChatFireworks(
+            temperature=0, model="accounts/fireworks/models/llama-v3p1-70b-instruct"
+        )
     else:
-        raise ValueError(f"Unsupported model type: {model}")
\ No newline at end of file
+        raise ValueError(f"Unsupported model type: {model}")
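As the diff shows, _get_model passes an existing BaseChatModel instance straight through and only builds a provider default for the string shortcuts. A small sketch of both call styles (the custom settings are hypothetical; the Anthropic default model name is taken from the diff):

    from langchain_anthropic import ChatAnthropic

    from brickllm.helpers.llm_models import _get_model

    # String shortcut: builds the provider default shown above.
    anthropic_default = _get_model("anthropic")

    # Pre-configured instance: returned unchanged, so any BaseChatModel works.
    custom = ChatAnthropic(temperature=0.2, model="claude-3-sonnet-20240229")
    assert _get_model(custom) is custom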
2 changes: 1 addition & 1 deletion brickllm/nodes/__init__.py
@@ -12,4 +12,4 @@
     "get_sensors",
     "schema_to_ttl",
     "validate_schema",
-]
\ No newline at end of file
+]
23 changes: 15 additions & 8 deletions brickllm/nodes/get_elem_children.py
@@ -1,8 +1,8 @@
 from langchain_core.messages import HumanMessage, SystemMessage

-from .. import State, ElemListSchema
+from .. import ElemListSchema, State
 from ..helpers import get_elem_children_instructions
-from ..utils import get_children_hierarchy, create_hierarchical_dict, filter_elements
+from ..utils import create_hierarchical_dict, filter_elements, get_children_hierarchy


 def get_elem_children(state: State, config):
@@ -24,7 +24,9 @@ def get_elem_children(state: State, config):
     category_dict = {}
     for category in categories:
         children_list = get_children_hierarchy(category, flatten=True)
-        children_string = "\n".join([f"{parent} -> {child}" for parent, child in children_list])
+        children_string = "\n".join(
+            [f"{parent} -> {child}" for parent, child in children_list]
+        )
         category_dict[category] = children_string

     # Get the model name from the config
@@ -36,11 +38,16 @@ def get_elem_children(state: State, config):
     identified_children = []
     for category in categories:
         # if the category is not "\n", then add the category to the prompt
-        if category_dict[category] != '':
-            # System message
-            system_message = get_elem_children_instructions.format(prompt=user_prompt, elements_list=category_dict[category])
+        if category_dict[category] != "":
+            # System message
+            system_message = get_elem_children_instructions.format(
+                prompt=user_prompt, elements_list=category_dict[category]
+            )
             # Generate question
-            elements = structured_llm.invoke([SystemMessage(content=system_message)]+[HumanMessage(content="Find the elements.")])
+            elements = structured_llm.invoke(
+                [SystemMessage(content=system_message)]
+                + [HumanMessage(content="Find the elements.")]
+            )
             identified_children.extend(elements.elem_list)
         else:
             identified_children.append(category)
@@ -52,4 +59,4 @@ def get_elem_children(state: State, config):
     # create hierarchical dictionary
     hierarchical_dict = create_hierarchical_dict(filtered_children, properties=True)

-    return {"elem_hierarchy": hierarchical_dict}
\ No newline at end of file
+    return {"elem_hierarchy": hierarchical_dict}
… (diffs for the remaining changed files not shown)
