Skip to content

Pre/beta - Unit Tests #969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/test_base_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from scrapegraphai.nodes.base_node import BaseNode


class DummyNode(BaseNode):
    """Minimal concrete ``BaseNode`` subclass used to exercise the base-class methods.

    No ``__init__`` is defined here: construction is inherited from
    ``BaseNode``, so every positional and keyword argument reaches the base
    constructor exactly as passed.
    """

    def execute(self, state: dict) -> dict:
        """Return ``state`` unchanged.

        Args:
            state: The state dictionary handed to the node.

        Returns:
            The very same ``state`` object that was passed in.
        """
        return state


# Shared state passed to get_input_keys in the tests below; it provides the
# keys "a", "b" and "c" that the input expressions refer to.
TEST_STATE = {"a": 1, "b": 2, "c": 3}


class TestBaseNode:
    """Tests for BaseNode behaviour, exercised through DummyNode."""

    def setup_method(self):
        """Build a fresh DummyNode before every test method."""
        node_kwargs = {
            "node_name": "TestNode",
            "node_type": "node",
            "input": "a",
            "output": ["x"],
            "min_input_len": 1,
        }
        self.node = DummyNode(**node_kwargs)

    def test_execute_returns_state(self):
        """execute() hands back a state equal to the one it was given."""
        state = {"a": 10}
        assert self.node.execute(state) == state

    def test_invalid_node_type(self):
        """Constructing a node with an unknown node_type raises ValueError."""
        bad_kwargs = {
            "node_name": "InvalidNode",
            "node_type": "invalid",
            "input": "a",
            "output": ["x"],
        }
        with pytest.raises(ValueError):
            DummyNode(**bad_kwargs)

    def test_update_config_without_overwrite(self):
        """update_config leaves an existing attribute alone when overwrite is not requested."""
        input_before = self.node.input
        self.node.update_config({"input": "new_input"})
        assert self.node.input == input_before

    def test_update_config_with_overwrite(self):
        """update_config replaces an existing attribute when overwrite=True."""
        replacement = {"input": "new_input_value"}
        self.node.update_config(replacement, overwrite=True)
        assert self.node.input == "new_input_value"

    @pytest.mark.parametrize(
        "expression, expected",
        [
            ("a", ["a"]),
            ("a|b", ["a"]),
            ("a&b", ["a", "b"]),
            # Both a and b are present in TEST_STATE, so the left-hand OR
            # branch is the one expected back.
            ("(a&b)|c", ["a", "b"]),
            # Inside the parentheses the first available alternative (b) is
            # the one expected alongside a.
            ("a&(b|c)", ["a", "b"]),
        ],
    )
    def test_get_input_keys_valid(self, expression, expected):
        """get_input_keys yields the expected keys for well-formed expressions."""
        self.node.input = expression
        keys = self.node.get_input_keys(TEST_STATE)
        # Order is irrelevant for this check, so compare as sets.
        assert set(keys) == set(expected)

    @pytest.mark.parametrize(
        "expression",
        [
            "",  # empty expression
            "a||b",  # doubled OR operator
            "a&&b",  # doubled AND operator
            "a b",  # two keys with no operator between them
            "(a&b",  # closing parenthesis missing
            "a&b)",  # stray closing parenthesis
            "&a",  # expression starts with an operator
            "a|",  # expression ends with an operator
            "a&|b",  # two different operators back to back
        ],
    )
    def test_get_input_keys_invalid(self, expression):
        """get_input_keys raises ValueError for malformed expressions."""
        self.node.input = expression
        with pytest.raises(ValueError):
            self.node.get_input_keys(TEST_STATE)

    def test_validate_input_keys_insufficient_keys(self):
        """get_input_keys raises ValueError when fewer keys than min_input_len are resolved."""
        self.node.min_input_len = 2
        # A single-key expression cannot satisfy a minimum of two keys.
        self.node.input = "a"
        with pytest.raises(ValueError):
            self.node.get_input_keys(TEST_STATE)

    def test_nested_parentheses(self):
        """get_input_keys copes with nested parentheses; "a" and "b" are expected."""
        self.node.input = "((a)&(b|c))"
        keys = self.node.get_input_keys(TEST_STATE)
        assert set(keys) == {"a", "b"}

    def test_execute_integration_with_state(self):
        """execute() returns a multi-key state equal to the one passed in."""
        state = {"a": 100, "b": 200, "c": 300}
        assert self.node.execute(state) == state

# End of tests
18 changes: 11 additions & 7 deletions tests/test_chromium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,13 +1987,17 @@ async def test_ascrape_playwright_scroll_invalid_type(monkeypatch):
)


@pytest.mark.asyncio
async def test_alazy_load_non_iterable_urls():
    """Check that alazy_load raises TypeError when urls is a non-iterable value (an int)."""
    with pytest.raises(TypeError):
        # The loader is created inside the ``raises`` block, so a TypeError
        # from either construction or iteration satisfies the check.
        loader = ChromiumLoader(123, backend="playwright")
        # Drain the async generator; the yielded documents are not needed.
        async for _doc in loader.alazy_load():
            pass
def test_lazy_load_non_iterable_urls():
    """Check that lazy_load treats a non-iterable urls value as a single URL and yields one Document."""
    # NOTE(review): another ``def test_lazy_load_non_iterable_urls`` appears
    # right after this function; if both live at module level the later
    # definition replaces this one and pytest never collects it. Confirm and
    # give one of them a distinct name.
    from langchain_core.documents import Document

    loader = ChromiumLoader(456, backend="playwright")
    documents = list(loader.lazy_load())

    assert len(documents) == 1, (
        "Expected one Document when a single URL (non-iterable) is provided"
    )
    only_document = documents[0]
    assert isinstance(only_document, Document)
    assert only_document.metadata["source"] == 456


def test_lazy_load_non_iterable_urls():
Expand Down
156 changes: 156 additions & 0 deletions tests/test_concat_answers_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import pytest

from scrapegraphai.nodes.concat_answers_node import ConcatAnswersNode


class DummyBaseNode:
    """Stand-in that supplies a simplified ``get_input_keys`` for the tests."""

    def get_input_keys(self, state: dict):
        """Return the input key names derived from ``self.input``.

        ``state`` is not read. When ``self.input`` contains a comma it is
        split on commas and each piece is stripped of surrounding whitespace;
        otherwise the value is returned untouched as a one-element list.
        """
        raw_input = self.input
        if "," not in raw_input:
            return [raw_input]
        return [piece.strip() for piece in raw_input.split(",")]


@pytest.fixture(autouse=True)
def _patch_concat_answers_get_input_keys(monkeypatch):
    """Swap ``ConcatAnswersNode.get_input_keys`` for the dummy version during each test.

    Assigning the attribute at import time would leave the class patched for
    every other test module collected in the same session. Going through
    pytest's ``monkeypatch`` fixture applies the same replacement for each
    test in this module and restores the original attribute afterwards.
    """
    monkeypatch.setattr(
        ConcatAnswersNode, "get_input_keys", DummyBaseNode.get_input_keys
    )


class TestConcatAnswersNode:
    """Tests covering ConcatAnswersNode.execute and its _merge_dict helper."""

    def test_execute_multiple_answers(self):
        """Several answers are merged into one numbered "products" dictionary."""
        node = ConcatAnswersNode(
            input="answers", output=["result"], node_config={"verbose": True}
        )
        new_state = node.execute(
            {"answers": ["Answer one", "Answer two", "Answer three"]}
        )
        assert new_state["result"] == {
            "products": {
                "item_1": "Answer one",
                "item_2": "Answer two",
                "item_3": "Answer three",
            }
        }

    def test_execute_single_answer(self):
        """A single answer is stored as-is rather than wrapped in a dictionary."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        new_state = node.execute({"answers": ["Only answer"]})
        assert new_state["result"] == "Only answer"

    def test_execute_missing_input_key_raises_keyerror(self):
        """execute raises KeyError when the configured input key is absent from the state."""
        node = ConcatAnswersNode(input="missing_key", output=["result"])
        incomplete_state = {"some_other_key": "data"}
        with pytest.raises(KeyError):
            node.execute(incomplete_state)

    def test_merge_dict_private_method(self):
        """_merge_dict numbers the supplied answers under the "products" key."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        assert node._merge_dict(["A", "B"]) == {
            "products": {"item_1": "A", "item_2": "B"}
        }

    def test_verbose_flag(self):
        """Enabling verbose in node_config does not change the single-answer result."""
        node = ConcatAnswersNode(
            input="answers", output=["result"], node_config={"verbose": True}
        )
        new_state = node.execute({"answers": ["Verbose answer"]})
        # One answer only, so the answer itself is the expected result.
        assert new_state["result"] == "Verbose answer"

    def test_merge_dict_empty(self):
        """_merge_dict on an empty list gives an empty "products" dictionary."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        assert node._merge_dict([]) == {"products": {}}

    def test_execute_empty_answers(self):
        """execute raises IndexError when the answers list is empty."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        empty_state = {"answers": []}
        with pytest.raises(IndexError):
            node.execute(empty_state)

    def test_execute_comma_separated_input(self):
        """With comma-separated input keys, the first key's answers drive the result."""
        node = ConcatAnswersNode(input="answers, extra", output=["result"])
        new_state = node.execute(
            {"answers": ["First answer", "Second answer"], "extra": "dummy"}
        )
        # More than one answer, so a merged dictionary is expected.
        assert new_state["result"] == {
            "products": {"item_1": "First answer", "item_2": "Second answer"}
        }

    def test_verbose_logging(self):
        """In verbose mode the execution-start message is sent to the logger."""
        node = ConcatAnswersNode(
            input="answers", output=["result"], node_config={"verbose": True}
        )
        captured = []

        class DummyLogger:
            """Logger double that records every info() message."""

            def info(self, msg):
                captured.append(msg)

        node.logger = DummyLogger()
        node.execute({"answers": ["Only answer"]})
        # At least one recorded message must mention the node being executed.
        assert any("Executing ConcatAnswers" in entry for entry in captured)

    def test_execute_tuple_input(self):
        """A tuple of answers is merged the same way as a list."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        new_state = node.execute({"answers": ("first", "second")})
        assert new_state["result"] == {
            "products": {"item_1": "first", "item_2": "second"}
        }

    def test_execute_string_input(self):
        """A string of answers is merged character by character."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        new_state = node.execute({"answers": "hello"})
        assert new_state["result"] == {
            "products": {
                "item_1": "h",
                "item_2": "e",
                "item_3": "l",
                "item_4": "l",
                "item_5": "o",
            }
        }

    def test_execute_non_iterable_input_raises_error(self):
        """execute raises TypeError when the answers value is an int."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        int_state = {"answers": 123}
        with pytest.raises(TypeError):
            node.execute(int_state)

    def test_execute_dict_input(self):
        """A dict of answers is merged by its keys, since iterating a dict yields keys."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        new_state = node.execute(
            {"answers": {"k1": "Answer one", "k2": "Answer two"}}
        )
        assert new_state["result"] == {
            "products": {"item_1": "k1", "item_2": "k2"}
        }

    def test_execute_generator_input(self):
        """execute raises TypeError for a generator, which has no len()."""
        node = ConcatAnswersNode(input="answers", output=["result"])
        generator_state = {"answers": (item for item in ["A", "B"])}
        with pytest.raises(TypeError):
            node.execute(generator_state)
16 changes: 8 additions & 8 deletions tests/test_omni_search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class TestOmniSearchGraph:
def test_run_with_answer(self):
"""Test that the run() method returns the correct answer when present."""
config = {
"llm": {"model": "dummy-model"},
"llm": {"model": "openai/dummy-model"},
"max_results": 3,
"search_engine": "dummy-engine",
}
prompt = "Test prompt?"
graph_instance = OmniSearchGraph(prompt, config)
# Set required attribute manually
graph_instance.llm_model = {"model": "dummy-model"}
graph_instance.llm_model = {"model": "dummy/dummy-model"}
# Inject a DummyGraph that returns a final state containing an "answer"
dummy_final_state = {"answer": "expected answer"}
graph_instance.graph = DummyGraph(dummy_final_state)
Expand All @@ -42,13 +42,13 @@ def test_run_with_answer(self):
def test_run_without_answer(self):
"""Test that the run() method returns the default message when no answer is found."""
config = {
"llm": {"model": "dummy-model"},
"llm": {"model": "openai/dummy-model"},
"max_results": 3,
"search_engine": "dummy-engine",
}
prompt = "Test prompt without answer?"
graph_instance = OmniSearchGraph(prompt, config)
graph_instance.llm_model = {"model": "dummy-model"}
graph_instance.llm_model = {"model": "dummy/dummy-model"}
# Inject a DummyGraph that returns an empty final state
dummy_final_state = {}
graph_instance.graph = DummyGraph(dummy_final_state)
Expand All @@ -58,14 +58,14 @@ def test_run_without_answer(self):
def test_create_graph_structure(self):
"""Test that the _create_graph() method returns a graph with the expected structure."""
config = {
"llm": {"model": "dummy-model"},
"llm": {"model": "openai/dummy-model"},
"max_results": 4,
"search_engine": "dummy-engine",
}
prompt = "Structure test prompt"
# Using a dummy schema for testing
graph_instance = OmniSearchGraph(prompt, config, schema=DummySchema)
graph_instance.llm_model = {"model": "dummy-model"}
graph_instance.llm_model = {"model": "dummy/dummy-model"}
constructed_graph = graph_instance._create_graph()
# Ensure constructed_graph has essential attributes
assert hasattr(constructed_graph, "nodes")
Expand All @@ -81,7 +81,7 @@ def test_create_graph_structure(self):
def test_config_deepcopy(self):
"""Test that the config passed to OmniSearchGraph is deep copied properly."""
config = {
"llm": {"model": "dummy-model"},
"llm": {"model": "openai/dummy-model"},
"max_results": 2,
"search_engine": "dummy-engine",
}
Expand All @@ -96,7 +96,7 @@ def test_config_deepcopy(self):
def test_schema_deepcopy(self):
"""Test that the schema is deep copied correctly so external changes do not affect it."""
config = {
"llm": {"model": "dummy-model"},
"llm": {"model": "openai/dummy-model"},
"max_results": 2,
"search_engine": "dummy-engine",
}
Expand Down
Loading