Skip to content

Commit

Permalink
docs: added all docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
PeriniM committed Oct 7, 2024
1 parent 9b06758 commit cede480
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 14 deletions.
1 change: 1 addition & 0 deletions brickllm/compiled_graphs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .graphs import BrickSchemaGraph

# compiled graph
brickschema_graph = BrickSchemaGraph()._compiled_graph()
10 changes: 9 additions & 1 deletion brickllm/edges/validate_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
# from langgraph.graph import END

def validate_condition(state) -> Literal["schema_to_ttl", "__end__"]:
"""
Validate the condition for the next node to visit.
# Often, we will use state to decide on the next node to visit
Args:
state (State): The current state containing the validation result.
Returns:
Literal["schema_to_ttl", "__end__"]: The next node to visit.
"""

is_valid = state.get("is_valid")
max_iter = state.get("validation_max_iter")

Expand Down
8 changes: 7 additions & 1 deletion brickllm/graphs/brickschema_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
from ..edges import validate_condition
from ..helpers.llm_models import _get_model


class BrickSchemaGraph:
def __init__(self, model: Union[str, BaseChatModel] = "openai"):
"""Initialize the StateGraph object and build the graph."""
"""
Initialize the StateGraph object and build the graph.
Args:
model (Union[str, BaseChatModel]): The model type as a string or an instance of BaseChatModel.
"""

# Define a new graph
self.workflow = StateGraph(State, config_schema=GraphConfig)
Expand Down
10 changes: 10 additions & 0 deletions brickllm/helpers/llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@


def _get_model(model: Union[str, BaseChatModel]):
"""
Get the LLM model based on the provided model type.
Args:
model (Union[str, BaseChatModel]): The model type as a string or an instance of BaseChatModel.
Returns:
BaseChatModel: The LLM model instance.
"""

if isinstance(model, BaseChatModel):
return model

Expand Down
4 changes: 4 additions & 0 deletions brickllm/helpers/prompts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Module containing the prompts used for the LLM models
"""

get_elem_instructions = """
You are a BrickSchema ontology expert and you are provided with a user prompt which describes a building or facility.\n
You are provided with a list of common elements that can be used to describe a building or facility.\n
Expand Down
10 changes: 10 additions & 0 deletions brickllm/nodes/get_elem_children.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@


def get_elem_children(state: State, config):
"""
Identify child elements for each category in the element list using a language model.
Args:
state (State): The current state containing the user prompt and element list.
config (dict): Configuration dictionary containing the language model.
Returns:
dict: A dictionary containing the hierarchical structure of identified elements.
"""
print("---Get Elem Children Node---")

user_prompt = state["user_prompt"]
Expand Down
11 changes: 11 additions & 0 deletions brickllm/nodes/get_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@


def get_elements(state: State, config):
"""
Process the user prompt to identify elements within specified categories
using a language model.
Args:
state (State): The current state containing the user prompt.
config (dict): Configuration dictionary containing the language model.
Returns:
dict: A dictionary containing the list of identified elements.
"""
print("---Get Elements Node---")

user_prompt = state["user_prompt"]
Expand Down
11 changes: 10 additions & 1 deletion brickllm/nodes/get_relationships.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import json
from collections import defaultdict

Expand All @@ -10,6 +9,16 @@


def get_relationships(state: State, config):
"""
Determine relationships between building components using a language model.
Args:
state (State): The current state containing the user prompt and element hierarchy.
config (dict): Configuration dictionary containing the language model.
Returns:
dict: A dictionary containing the grouped sensor paths.
"""
print("---Get Relationships Node---")

user_prompt = state["user_prompt"]
Expand Down
9 changes: 9 additions & 0 deletions brickllm/nodes/get_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@


def get_sensors(state: State):
"""
Retrieve sensor information for the building structure.
Args:
state (State): The current state.
Returns:
dict: A dictionary containing sensor UUIDs mapped to their locations.
"""
print("---Get Sensor Node---")

uuid_dict = {
Expand Down
10 changes: 10 additions & 0 deletions brickllm/nodes/schema_to_ttl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@


def schema_to_ttl(state: State, config):
"""
Generate a TTL (Turtle) script from the building description and component hierarchy.
Args:
state (State): The current state containing the user prompt, sensors, and element hierarchy.
config (dict): Configuration dictionary containing the language model.
Returns:
dict: A dictionary containing the generated TTL output.
"""
print("---Schema To TTL Node---")

user_prompt = state["user_prompt"]
Expand Down
9 changes: 9 additions & 0 deletions brickllm/nodes/validate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@


def validate_schema(state: State):
"""
Validate the generated TTL output against the BrickSchema.
Args:
state (State): The current state containing the TTL output and validation parameters.
Returns:
dict: A dictionary containing the validation status and report.
"""
print("---Validate Schema Node---")

ttl_output = state.get("ttl_output", None)
Expand Down
2 changes: 1 addition & 1 deletion brickllm/states.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List, Dict
from typing_extensions import TypedDict

# graph state
# state for BrickSchemaGraph class
class State(TypedDict):
user_prompt: str
elem_list: List[str]
Expand Down
118 changes: 113 additions & 5 deletions brickllm/utils/get_hierarchy_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@

# Function to recursively find parents
def find_parents(current_data, target, parents=None):
"""
Recursively find the parent nodes of a target node in a hierarchical data structure.
Args:
current_data (dict): The current level of the hierarchy to search.
target (str): The target node to find parents for.
parents (list, optional): Accumulated list of parent nodes. Defaults to None.
Returns:
tuple: A tuple containing a boolean indicating if the target was found and a list of parent nodes.
"""
if parents is None:
parents = []
for key, value in current_data.items():
Expand All @@ -28,6 +39,16 @@ def find_parents(current_data, target, parents=None):

# Function to get the children of a node
def get_children(current_data, target):
"""
Get the children of a target node in a hierarchical data structure.
Args:
current_data (dict): The current level of the hierarchy to search.
target (str): The target node to find children for.
Returns:
list: A list of child nodes.
"""
if target in current_data:
return list(current_data[target].keys())
for key, value in current_data.items():
Expand All @@ -39,6 +60,17 @@ def get_children(current_data, target):

# Function to flatten the hierarchy
def flatten_hierarchy(current_data, parent=None, result=None):
"""
Flatten a hierarchical data structure into a list of parent-child tuples.
Args:
current_data (dict): The current level of the hierarchy to flatten.
parent (str, optional): The parent node. Defaults to None.
result (list, optional): Accumulated list of parent-child tuples. Defaults to None.
Returns:
list: A list of tuples representing parent-child relationships.
"""
if result is None:
result = []
for key, value in current_data.items():
Expand All @@ -50,6 +82,15 @@ def flatten_hierarchy(current_data, parent=None, result=None):

# Main function to get hierarchy info
def get_hierarchical_info(key):
"""
Get the hierarchical information of a node, including its parents and children.
Args:
key (str): The target node to get information for.
Returns:
tuple: A tuple containing a list of parent nodes and a list of child nodes.
"""
# Get parents
found, parents = find_parents(data, key)
# Get children
Expand All @@ -58,6 +99,16 @@ def get_hierarchical_info(key):

# Function to recursively get all children and subchildren
def get_all_subchildren(current_data, target):
"""
Recursively get all children and subchildren of a target node.
Args:
current_data (dict): The current level of the hierarchy to search.
target (str): The target node to find children for.
Returns:
dict: A dictionary representing the subtree of the target node.
"""
if target in current_data:
return current_data[target]
for key, value in current_data.items():
Expand All @@ -69,12 +120,31 @@ def get_all_subchildren(current_data, target):

# Main function to get hierarchy dictionary
def get_children_hierarchy(key, flatten=False):
"""
Get the hierarchy of children for a target node, optionally flattening the result.
Args:
key (str): The target node to get children for.
flatten (bool, optional): Whether to flatten the hierarchy. Defaults to False.
Returns:
dict or list: A dictionary representing the hierarchy or a list of parent-child tuples if flattened.
"""
if flatten:
return flatten_hierarchy(get_all_subchildren(data, key))
return get_all_subchildren(data, key)

# Function to filter elements based on the given conditions
def filter_elements(elements):
"""
Filter elements based on their hierarchical relationships.
Args:
elements (list): A list of elements to filter.
Returns:
list: A list of filtered elements.
"""
elements_info = {element: get_hierarchical_info(element) for element in elements}
filtered_elements = []

Expand All @@ -91,6 +161,16 @@ def filter_elements(elements):
return filtered_elements

def create_hierarchical_dict(elements, properties=False):
"""
Create a hierarchical dictionary from a list of elements, optionally including properties.
Args:
elements (list): A list of elements to include in the hierarchy.
properties (bool, optional): Whether to include properties in the hierarchy. Defaults to False.
Returns:
dict: A dictionary representing the hierarchical structure.
"""
hierarchy = {}

for category in elements:
Expand Down Expand Up @@ -119,6 +199,16 @@ def create_hierarchical_dict(elements, properties=False):
return hierarchy

def find_sensor_paths(tree, path=None):
"""
Find paths to sensor nodes in a hierarchical tree structure.
Args:
tree (dict): The hierarchical tree structure.
path (list, optional): Accumulated path to the current node. Defaults to None.
Returns:
list: A list of dictionaries containing sensor names and their paths.
"""
if path is None:
path = []

Expand All @@ -136,6 +226,15 @@ def find_sensor_paths(tree, path=None):
return sensor_paths

def build_hierarchy(relationships):
"""
Build a hierarchical tree structure from a list of parent-child relationships.
Args:
relationships (list): A list of tuples representing parent-child relationships.
Returns:
dict: A dictionary representing the hierarchical tree structure.
"""
# Helper function to recursively build the tree structure
def build_tree(node, tree_dict):
return {'name': node, 'children': [build_tree(child, tree_dict) for child in tree_dict[node]]} if tree_dict[node] else {'name': node}
Expand All @@ -157,8 +256,17 @@ def build_tree(node, tree_dict):
return hierarchy

def extract_ttl_content(input_string: str) -> str:
# Use regex to match content between ```python and ```
match = re.search(r"```code\s*(.*?)\s*```", input_string, re.DOTALL)
if match:
return match.group(1).strip()
return ""
"""
Extract content between code block markers in a string.
Args:
input_string (str): The input string containing code blocks.
Returns:
str: The extracted content between the code block markers.
"""
# Use regex to match content between ```python and ```
match = re.search(r"```code\s*(.*?)\s*```", input_string, re.DOTALL)
if match:
return match.group(1).strip()
return ""
Loading

0 comments on commit cede480

Please sign in to comment.