diff --git a/tests/test_base_node.py b/tests/test_base_node.py
new file mode 100644
index 00000000..395b2572
--- /dev/null
+++ b/tests/test_base_node.py
@@ -0,0 +1,122 @@
+import pytest
+
+from scrapegraphai.nodes.base_node import BaseNode
+
+
+class DummyNode(BaseNode):
+    """Dummy node for testing BaseNode methods."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def execute(self, state: dict) -> dict:
+        """Simple execute implementation that returns the state unchanged."""
+        return state
+
+
+# A constant representing a dummy state for testing input keys
+TEST_STATE = {"a": 1, "b": 2, "c": 3}
+
+
+class TestBaseNode:
+    """Test suite for BaseNode functionality."""
+
+    def setup_method(self):
+        """Set up a DummyNode instance for the tests."""
+        self.node = DummyNode(
+            node_name="TestNode",
+            node_type="node",
+            input="a",
+            output=["x"],
+            min_input_len=1,
+        )
+
+    def test_execute_returns_state(self):
+        """Test that the execute method returns the state unchanged."""
+        state = {"a": 10}
+        updated = self.node.execute(state)
+        assert updated == state
+
+    def test_invalid_node_type(self):
+        """Test that an invalid node_type raises ValueError."""
+        with pytest.raises(ValueError):
+            DummyNode(
+                node_name="InvalidNode", node_type="invalid", input="a", output=["x"]
+            )
+
+    def test_update_config_without_overwrite(self):
+        """Test that update_config does not overwrite existing attributes when overwrite is False."""
+        original_input = self.node.input
+        self.node.update_config({"input": "new_input"})
+        assert self.node.input == original_input
+
+    def test_update_config_with_overwrite(self):
+        """Test that update_config updates attributes when overwrite is True."""
+        self.node.update_config({"input": "new_input_value"}, overwrite=True)
+        assert self.node.input == "new_input_value"
+
+    @pytest.mark.parametrize(
+        "expression, expected",
+        [
+            ("a", ["a"]),
+            ("a|b", ["a"]),
+            ("a&b", ["a", "b"]),
+            (
+                "(a&b)|c",
+                ["a", "b"],
+            ),  # Since a and b are valid, returns the first matching OR segment.
+            (
+                "a&(b|c)",
+                ["a", "b"],
+            ),  # Evaluation returns the first matching AND condition.
+        ],
+    )
+    def test_get_input_keys_valid(self, expression, expected):
+        """Test that get_input_keys returns the correct keys for valid expressions."""
+        self.node.input = expression
+        result = self.node.get_input_keys(TEST_STATE)
+        # Check that both sets are equal, ignoring order.
+        assert set(result) == set(expected)
+
+    @pytest.mark.parametrize(
+        "expression",
+        [
+            "",  # empty expression should raise an error
+            "a||b",  # consecutive operator ||
+            "a&&b",  # consecutive operator &&
+            "a b",  # adjacent keys without an operator should be caught by the regex
+            "(a&b",  # missing a closing parenthesis
+            "a&b)",  # extra closing parenthesis
+            "&a",  # invalid start operator
+            "a|",  # invalid end operator
+            "a&|b",  # invalid operator order
+        ],
+    )
+    def test_get_input_keys_invalid(self, expression):
+        """Test that get_input_keys raises ValueError for invalid expressions."""
+        self.node.input = expression
+        with pytest.raises(ValueError):
+            self.node.get_input_keys(TEST_STATE)
+
+    def test_validate_input_keys_insufficient_keys(self):
+        """Test that _validate_input_keys raises an error if the returned input keys are insufficient."""
+        self.node.min_input_len = 2
+        # Use an expression that returns only one key
+        self.node.input = "a"
+        with pytest.raises(ValueError):
+            self.node.get_input_keys(TEST_STATE)
+
+    def test_nested_parentheses(self):
+        """Test that get_input_keys correctly parses nested parentheses in expressions."""
+        # Expression with nested parentheses; expected to yield keys "a" and "b"
+        self.node.input = "((a)&(b|c))"
+        result = self.node.get_input_keys(TEST_STATE)
+        assert set(result) == {"a", "b"}
+
+    def test_execute_integration_with_state(self):
+        """Integration test: pass a non-trivial state to execute and ensure the output matches."""
+        state = {"a": 100, "b": 200, "c": 300}
+        result = self.node.execute(state)
+        assert result == state
+
+    # End of tests
diff --git a/tests/test_chromium.py b/tests/test_chromium.py
index 976a3cdd..f42d6c98 100644
--- a/tests/test_chromium.py
+++ b/tests/test_chromium.py
@@ -1987,13 +1987,17 @@ async def test_ascrape_playwright_scroll_invalid_type(monkeypatch):
     )
 
 
-@pytest.mark.asyncio
-async def test_alazy_load_non_iterable_urls():
-    """Test that alazy_load raises TypeError when urls is not an iterable (e.g., integer)."""
-    with pytest.raises(TypeError):
-        # Passing an integer as urls should cause a TypeError during iteration.
-        loader = ChromiumLoader(123, backend="playwright")
-        [doc async for doc in loader.alazy_load()]
+def test_lazy_load_single_non_iterable_url():
+    """Test that lazy_load treats a non-iterable urls value as a single URL and returns one Document."""
+    from langchain_core.documents import Document
+
+    loader = ChromiumLoader(456, backend="playwright")
+    docs = list(loader.lazy_load())
+    assert len(docs) == 1, (
+        "Expected one Document when a single URL (non-iterable) is provided"
+    )
+    assert isinstance(docs[0], Document)
+    assert docs[0].metadata["source"] == 456
 
 
 def test_lazy_load_non_iterable_urls():
diff --git a/tests/test_concat_answers_node.py b/tests/test_concat_answers_node.py
new file mode 100644
index 00000000..44c83680
--- /dev/null
+++ b/tests/test_concat_answers_node.py
@@ -0,0 +1,156 @@
+import pytest
+
+from scrapegraphai.nodes.concat_answers_node import ConcatAnswersNode
+
+
+class DummyBaseNode:
+    """Dummy class to simulate BaseNode's get_input_keys method for testing."""
+
+    def get_input_keys(self, state: dict):
+        # For testing, assume self.input is a single key or a comma-separated list of keys.
+        if "," in self.input:
+            return [key.strip() for key in self.input.split(",")]
+        else:
+            return [self.input]
+
+
+# Monkey-patch ConcatAnswersNode to use DummyBaseNode's get_input_keys method.
+ConcatAnswersNode.get_input_keys = DummyBaseNode.get_input_keys
+
+
+class TestConcatAnswersNode:
+    """Test suite for the ConcatAnswersNode functionality."""
+
+    def test_execute_multiple_answers(self):
+        """Test execute with multiple answers concatenated into a merged dictionary."""
+        node = ConcatAnswersNode(
+            input="answers", output=["result"], node_config={"verbose": True}
+        )
+        state = {"answers": ["Answer one", "Answer two", "Answer three"]}
+        updated_state = node.execute(state)
+        expected = {
+            "products": {
+                "item_1": "Answer one",
+                "item_2": "Answer two",
+                "item_3": "Answer three",
+            }
+        }
+        assert updated_state["result"] == expected
+
+    def test_execute_single_answer(self):
+        """Test that execute with a single answer returns the answer directly."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": ["Only answer"]}
+        updated_state = node.execute(state)
+        assert updated_state["result"] == "Only answer"
+
+    def test_execute_missing_input_key_raises_keyerror(self):
+        """Test that execute raises KeyError when the required input key is missing from the state."""
+        node = ConcatAnswersNode(input="missing_key", output=["result"])
+        state = {"some_other_key": "data"}
+        with pytest.raises(KeyError):
+            node.execute(state)
+
+    def test_merge_dict_private_method(self):
+        """Test the _merge_dict private method to ensure a correct merge of a list of answers."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        data = ["A", "B"]
+        merged = node._merge_dict(data)
+        expected = {"products": {"item_1": "A", "item_2": "B"}}
+        assert merged == expected
+
+    def test_verbose_flag(self):
+        """Test that node initialization with the verbose flag does not interfere with execute."""
+        node = ConcatAnswersNode(
+            input="answers", output=["result"], node_config={"verbose": True}
+        )
+        state = {"answers": ["Verbose answer"]}
+        updated_state = node.execute(state)
+        # When only one answer is provided, the answer should be returned directly.
+        assert updated_state["result"] == "Verbose answer"
+
+    def test_merge_dict_empty(self):
+        """Test that _merge_dict with an empty list returns an empty products dictionary."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        merged = node._merge_dict([])
+        expected = {"products": {}}
+        assert merged == expected
+
+    def test_execute_empty_answers(self):
+        """Test that execute raises an IndexError when the 'answers' list is empty."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": []}
+        with pytest.raises(IndexError):
+            node.execute(state)
+
+    def test_execute_comma_separated_input(self):
+        """Test that execute with comma-separated input keys returns the correct result using the first key."""
+        node = ConcatAnswersNode(input="answers, extra", output=["result"])
+        state = {"answers": ["First answer", "Second answer"], "extra": "dummy"}
+        updated_state = node.execute(state)
+        # Since the "answers" list has more than one element, a merged dictionary is expected.
+        expected = {"products": {"item_1": "First answer", "item_2": "Second answer"}}
+        assert updated_state["result"] == expected
+
+    def test_verbose_logging(self):
+        """Test that verbose mode triggers logging of the execution start message."""
+        node = ConcatAnswersNode(
+            input="answers", output=["result"], node_config={"verbose": True}
+        )
+        # Set up a dummy logger to capture log messages.
+        logged_messages = []
+        node.logger = type(
+            "DummyLogger", (), {"info": lambda self, msg: logged_messages.append(msg)}
+        )()
+        state = {"answers": ["Only answer"]}
+        node.execute(state)
+        # Check that one of the logged messages includes 'Executing ConcatAnswers'
+        assert any("Executing ConcatAnswers" in message for message in logged_messages)
+
+    def test_execute_tuple_input(self):
+        """Test execute with tuple input for answers returns a merged dictionary."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": ("first", "second")}
+        updated_state = node.execute(state)
+        expected = {"products": {"item_1": "first", "item_2": "second"}}
+        assert updated_state["result"] == expected
+
+    def test_execute_string_input(self):
+        """Test execute with a string input for answers returns a merged dictionary by iterating each character."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": "hello"}
+        updated_state = node.execute(state)
+        expected = {
+            "products": {
+                "item_1": "h",
+                "item_2": "e",
+                "item_3": "l",
+                "item_4": "l",
+                "item_5": "o",
+            }
+        }
+        assert updated_state["result"] == expected
+
+    def test_execute_non_iterable_input_raises_error(self):
+        """Test execute with a non-iterable input for answers raises a TypeError."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": 123}
+        with pytest.raises(TypeError):
+            node.execute(state)
+
+    def test_execute_dict_input(self):
+        """Test execute with dict input for answers. Since iterating over a dict yields its keys,
+        the merged dictionary should consist of the dict keys."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": {"k1": "Answer one", "k2": "Answer two"}}
+        updated_state = node.execute(state)
+        expected = {"products": {"item_1": "k1", "item_2": "k2"}}
+        assert updated_state["result"] == expected
+
+    def test_execute_generator_input(self):
+        """Test execute with a generator input for answers raises a TypeError because generators do
+        not support len() calls."""
+        node = ConcatAnswersNode(input="answers", output=["result"])
+        state = {"answers": (x for x in ["A", "B"])}
+        with pytest.raises(TypeError):
+            node.execute(state)
diff --git a/tests/test_omni_search_graph.py b/tests/test_omni_search_graph.py
index 656421d5..a0841228 100644
--- a/tests/test_omni_search_graph.py
+++ b/tests/test_omni_search_graph.py
@@ -25,14 +25,14 @@ class TestOmniSearchGraph:
     def test_run_with_answer(self):
         """Test that the run() method returns the correct answer when present."""
         config = {
-            "llm": {"model": "dummy-model"},
+            "llm": {"model": "openai/dummy-model"},
             "max_results": 3,
             "search_engine": "dummy-engine",
         }
         prompt = "Test prompt?"
         graph_instance = OmniSearchGraph(prompt, config)
         # Set required attribute manually
-        graph_instance.llm_model = {"model": "dummy-model"}
+        graph_instance.llm_model = {"model": "dummy/dummy-model"}
         # Inject a DummyGraph that returns a final state containing an "answer"
         dummy_final_state = {"answer": "expected answer"}
         graph_instance.graph = DummyGraph(dummy_final_state)
@@ -42,13 +42,13 @@ def test_run_with_answer(self):
     def test_run_without_answer(self):
         """Test that the run() method returns the default message when no answer is found."""
         config = {
-            "llm": {"model": "dummy-model"},
+            "llm": {"model": "openai/dummy-model"},
             "max_results": 3,
             "search_engine": "dummy-engine",
         }
         prompt = "Test prompt without answer?"
         graph_instance = OmniSearchGraph(prompt, config)
-        graph_instance.llm_model = {"model": "dummy-model"}
+        graph_instance.llm_model = {"model": "dummy/dummy-model"}
         # Inject a DummyGraph that returns an empty final state
         dummy_final_state = {}
         graph_instance.graph = DummyGraph(dummy_final_state)
@@ -58,14 +58,14 @@ def test_run_without_answer(self):
     def test_create_graph_structure(self):
         """Test that the _create_graph() method returns a graph with the expected structure."""
         config = {
-            "llm": {"model": "dummy-model"},
+            "llm": {"model": "openai/dummy-model"},
             "max_results": 4,
             "search_engine": "dummy-engine",
         }
         prompt = "Structure test prompt"
         # Using a dummy schema for testing
         graph_instance = OmniSearchGraph(prompt, config, schema=DummySchema)
-        graph_instance.llm_model = {"model": "dummy-model"}
+        graph_instance.llm_model = {"model": "dummy/dummy-model"}
         constructed_graph = graph_instance._create_graph()
         # Ensure constructed_graph has essential attributes
         assert hasattr(constructed_graph, "nodes")
@@ -81,7 +81,7 @@ def test_create_graph_structure(self):
     def test_config_deepcopy(self):
         """Test that the config passed to OmniSearchGraph is deep copied properly."""
         config = {
-            "llm": {"model": "dummy-model"},
+            "llm": {"model": "openai/dummy-model"},
             "max_results": 2,
             "search_engine": "dummy-engine",
         }
@@ -96,7 +96,7 @@ def test_config_deepcopy(self):
     def test_schema_deepcopy(self):
         """Test that the schema is deep copied correctly so external changes do not affect it."""
         config = {
-            "llm": {"model": "dummy-model"},
+            "llm": {"model": "openai/dummy-model"},
             "max_results": 2,
             "search_engine": "dummy-engine",
         }
diff --git a/tests/test_openai_tts.py b/tests/test_openai_tts.py
new file mode 100644
index 00000000..b896f3b6
--- /dev/null
+++ b/tests/test_openai_tts.py
@@ -0,0 +1,174 @@
+import pytest
+
+from scrapegraphai.models.openai_tts import OpenAITextToSpeech
+
+
+class FakeResponse:
+    """Fake response object to simulate the API call return value."""
+
+    def __init__(self, text):
+        # Emulate audio conversion by prepending a fixed prefix to the input text.
+        self.content = b"converted:" + text.encode()
+
+
+class FakeSpeech:
+    """Fake speech class with a create method that simulates generating speech."""
+
+    def create(self, model, voice, input):
+        # We ignore the model and voice for the fake implementation.
+        return FakeResponse(input)
+
+
+class FakeAudio:
+    """Fake audio class that provides a speech attribute."""
+
+    def __init__(self):
+        self.speech = FakeSpeech()
+
+
+class FakeClient:
+    """Fake client to simulate OpenAI's audio API without making actual network calls."""
+
+    def __init__(self):
+        self.audio = FakeAudio()
+
+
+@pytest.fixture
+def tts_config():
+    """Fixture for providing configuration for OpenAITextToSpeech."""
+    return {"api_key": "dummy_key", "model": "custom-model", "voice": "custom_voice"}
+
+
+@pytest.fixture
+def tts_instance(tts_config):
+    """Fixture for an OpenAITextToSpeech instance with a fake client to avoid external API calls."""
+    tts = OpenAITextToSpeech(tts_config)
+    # Override the client with our FakeClient.
+    tts.client = FakeClient()
+    return tts
+
+
+def test_run_valid_text(tts_instance):
+    """Test that run method returns the appropriate byte result for a valid text input."""
+    input_text = "Hello, OpenAI!"
+    result = tts_instance.run(input_text)
+    # The expected response is the fake conversion with the prefix b"converted:".
+ expected = b"converted:" + input_text.encode() + assert result == expected + + +def test_run_empty_text(tts_instance): + """Test that run method works correctly when provided with an empty string.""" + input_text = "" + result = tts_instance.run(input_text) + expected = b"converted:" # b"converted:" + b"" is just b"converted:" + assert result == expected + + +def test_attributes_set(tts_config): + """Test that the OpenAITextToSpeech instance correctly sets attributes from the configuration.""" + tts = OpenAITextToSpeech(tts_config) + # The model and voice should be set to the values provided in configuration. + assert tts.model == tts_config["model"] + assert tts.voice == tts_config["voice"] + # The client should be an instance of OpenAI; check that it is not None. + assert tts.client is not None + + +def test_default_config_no_model_voice(): + """Test that default values for model and voice are used when they are not provided in the configuration.""" + # Create a configuration without 'model' and 'voice' keys. + config = {"api_key": "dummy_key"} + tts = OpenAITextToSpeech(config) + # Default values should be "tts-1" for model and "alloy" for voice. + assert tts.model == "tts-1" + assert tts.voice == "alloy" + + +def test_run_unicode_text(tts_instance): + """Test that run method correctly handles Unicode characters in the input text.""" + input_text = "こんにちは、世界!" # "Hello, World!" in Japanese. + result = tts_instance.run(input_text) + expected = b"converted:" + input_text.encode() + assert result == expected + + +def test_run_non_string_input(tts_instance): + """Test that run method raises an error when non-string input is provided.""" + with pytest.raises(AttributeError): + # Passing an integer to run should fail when trying to call .encode() on a non-string type. + tts_instance.run(123) + + +def test_run_exception(tts_instance): + """Test that run method propagates exceptions from the client's API call.""" + + def raise_exception(model, voice, input): + raise Exception("API failure") + + tts_instance.client.audio.speech.create = raise_exception + with pytest.raises(Exception, match="API failure"): + tts_instance.run("Any text") + + +def test_run_long_text(tts_instance): + """Test that run method correctly handles long text input.""" + long_text = "a" * 10000 # a string of 10,000 'a' characters + result = tts_instance.run(long_text) + expected = b"converted:" + long_text.encode() + assert result == expected + + +def test_run_whitespace_text(tts_instance): + """Test that run method correctly handles text that is only whitespace.""" + whitespace_text = " \n\t " + result = tts_instance.run(whitespace_text) + expected = b"converted:" + whitespace_text.encode() + assert result == expected + + +def test_constructor_base_url_usage(tts_config, monkeypatch): + """Test that OpenAITextToSpeech passes the base_url value from the configuration to the OpenAI client.""" + + # Define a fake OpenAI class that captures the initialization parameters. + class FakeOpenAI: + def __init__(self, api_key, base_url=None): + self.api_key = api_key + self.base_url = base_url + + def __getattr__(self, name): + # Return a dummy function for any method calls. + return lambda *args, **kwargs: None + + # Ensure the configuration has a base_url. + custom_url = "https://custom.api.openai.com" + tts_config_with_url = tts_config.copy() + tts_config_with_url["base_url"] = custom_url + + # Monkey-patch the OpenAI class in the module to use our FakeOpenAI. 
+    monkeypatch.setattr("scrapegraphai.models.openai_tts.OpenAI", FakeOpenAI)
+
+    # Create an instance of OpenAITextToSpeech and check that the client has the expected base_url.
+    from scrapegraphai.models.openai_tts import OpenAITextToSpeech
+
+    tts = OpenAITextToSpeech(tts_config_with_url)
+    assert hasattr(tts.client, "base_url")
+    assert tts.client.base_url == custom_url
+
+
+def test_run_response_no_content(tts_instance):
+    """Test that run method raises AttributeError if the response from the API does not contain a 'content' attribute."""
+
+    # Create a fake function that simulates a response missing the "content" attribute.
+    def fake_create_no_content(model, voice, input):
+        # Return an object with no content attribute.
+        class NoContent:
+            pass
+
+        return NoContent()
+
+    # Patch the fake client's speech.create method.
+    tts_instance.client.audio.speech.create = fake_create_no_content
+
+    with pytest.raises(AttributeError):
+        tts_instance.run("Test text without content attribute")
diff --git a/tests/test_smart_scraper_multi_concat_graph.py b/tests/test_smart_scraper_multi_concat_graph.py
index e69de29b..4bd413f3 100644
--- a/tests/test_smart_scraper_multi_concat_graph.py
+++ b/tests/test_smart_scraper_multi_concat_graph.py
@@ -0,0 +1,89 @@
+import pytest
+
+from scrapegraphai.graphs.base_graph import BaseGraph
+from scrapegraphai.graphs.smart_scraper_multi_concat_graph import (
+    SmartScraperMultiConcatGraph,
+)
+
+
+class DummyGraph:
+    """A dummy graph to simulate the execute() behavior for testing."""
+
+    def __init__(self, answer=None):
+        self.answer = answer
+
+    def execute(self, inputs):
+        """
+        Simulate execution of the graph.
+        If answer is None, return an empty state to simulate a missing answer.
+        Otherwise return a state with an answer.
+        """
+        if self.answer is None:
+            return ({}, {})
+        return ({"answer": self.answer}, {})
+
+
+class TestSmartScraperMultiConcatGraph:
+    """Tests for SmartScraperMultiConcatGraph."""
+
+    @pytest.fixture
+    def graph_instance(self):
+        """Fixture to create an instance of SmartScraperMultiConcatGraph with default parameters."""
+        prompt = "What is test?"
+        source = ["http://example.com"]
+        config = {"test_config": True, "llm": {"model": "dummy-model"}}
+        instance = SmartScraperMultiConcatGraph(prompt, source, config)
+        # Manually set llm_model for testing purposes
+        instance.llm_model = {"model": "dummy-model"}
+        return instance
+
+    def test_run_with_answer(self, graph_instance):
+        """Test that run() returns the expected answer when provided by the dummy graph."""
+        expected_answer = "This is the merged answer."
+        graph_instance.graph = DummyGraph(answer=expected_answer)
+        result = graph_instance.run()
+        assert result == expected_answer
+
+    def test_run_no_answer(self, graph_instance):
+        """Test that run() returns 'No answer found.' when no answer key is present."""
+        graph_instance.graph = DummyGraph(answer=None)  # simulate an empty final state
+        result = graph_instance.run()
+        assert result == "No answer found."
+
+    def test_create_graph_structure(self, graph_instance):
+        """Test that _create_graph returns a BaseGraph instance with expected node names and structure."""
+        graph = graph_instance._create_graph()
+        assert isinstance(graph, BaseGraph)
+        # Verify that the entry point is the GraphIteratorNode and graph name is set correctly
+        assert graph.entry_point.node_name == "GraphIteratorNode"
+        assert graph.graph_name == "SmartScraperMultiConcatGraph"
+
+        # Check all expected node names exist in the graph
+        node_names = [node.node_name for node in graph.nodes]
+        expected_nodes = [
+            "GraphIteratorNode",
+            "ConditionalNode",
+            "MergeAnswersNode",
+            "ConcatNode",
+        ]
+        for expected in expected_nodes:
+            assert expected in node_names
+
+        # Check that ConditionalNode has edges to both MergeAnswersNode and ConcatNode
+        edges_from_conditional = [
+            edge for edge in graph.edges if edge[0].node_name == "ConditionalNode"
+        ]
+        targets = [edge[1].node_name for edge in edges_from_conditional]
+        assert "MergeAnswersNode" in targets
+        assert "ConcatNode" in targets
+
+    def test_conditional_node_config(self, graph_instance):
+        """Test that the ConditionalNode is configured with the correct condition and key_name."""
+        graph = graph_instance._create_graph()
+        cond_nodes = [
+            node for node in graph.nodes if node.node_name == "ConditionalNode"
+        ]
+        assert len(cond_nodes) == 1
+        cond_node = cond_nodes[0]
+        assert cond_node.node_config.get("condition") == "len(results) > 2"
+        assert cond_node.node_config.get("key_name") == "results"
diff --git a/tests/test_smart_scraper_multi_graph.py b/tests/test_smart_scraper_multi_graph.py
new file mode 100644
index 00000000..537c0e8b
--- /dev/null
+++ b/tests/test_smart_scraper_multi_graph.py
@@ -0,0 +1,94 @@
+from pydantic import BaseModel
+
+from scrapegraphai.graphs.base_graph import BaseGraph
+from scrapegraphai.graphs.smart_scraper_multi_graph import SmartScraperMultiGraph
+
+
+# Dummy classes to simulate graph execution
+class DummyGraph:
+    """Dummy graph that always returns an answer."""
+
+    def execute(self, inputs):
+        return {"answer": "Test answer"}, "dummy execution info"
+
+
+class DummyGraphNoAnswer:
+    """Dummy graph that returns no answer key."""
+
+    def execute(self, inputs):
+        return {"not_answer": "missing"}, "dummy execution info"
+
+
+class DummyGraphRecord:
+    """Dummy graph that records the input provided to execute()."""
+
+    def __init__(self):
+        self.last_inputs = None
+
+    def execute(self, inputs):
+        self.last_inputs = inputs
+        return {"answer": "Recorded Answer"}, "dummy execution info"
+
+
+class DummySchema(BaseModel):
+    answer: str
+
+
+# Tests for SmartScraperMultiGraph
+def test_run_returns_answer():
+    """
+    Test that run() returns the answer provided by the dummy graph.
+    """
+    config = {"llm": {"model": "dummy"}}
+    scraper = SmartScraperMultiGraph(
+        prompt="Test prompt", source=["http://example.com"], config=config
+    )
+    scraper.graph = DummyGraph()
+    scraper.llm_model = {"model": "dummy"}
+    result = scraper.run()
+    assert result == "Test answer"
+
+
+def test_run_no_answer_found():
+    """
+    Test that run() returns 'No answer found.' when the dummy graph does not provide an answer.
+    """
+    config = {"llm": {"model": "dummy"}}
+    scraper = SmartScraperMultiGraph(
+        prompt="Another test", source=["http://example.org"], config=config
+    )
+    scraper.graph = DummyGraphNoAnswer()
+    scraper.llm_model = {"model": "dummy"}
+    result = scraper.run()
+    assert result == "No answer found."
+
+
+def test_create_graph_structure():
+    """
+    Test that _create_graph() returns a valid BaseGraph with two nodes and one edge.
+    """
+    config = {"llm": {"model": "dummy"}}
+    scraper = SmartScraperMultiGraph(
+        prompt="Graph test", source=["http://example.net"], config=config
+    )
+    scraper.llm_model = {"model": "dummy"}
+    graph = scraper._create_graph()
+    assert isinstance(graph, BaseGraph)
+    assert len(graph.nodes) == 2
+    assert len(graph.edges) == 1
+
+
+def test_run_records_proper_input():
+    """
+    Test that run() sends the correct input to the graph's execute method.
+    """
+    config = {"llm": {"model": "dummy"}}
+    scraper = SmartScraperMultiGraph(
+        prompt="Record input", source=["http://recorder.com"], config=config
+    )
+    dummy_record = DummyGraphRecord()
+    scraper.graph = dummy_record
+    scraper.llm_model = {"model": "dummy"}
+    scraper.run()
+    expected_inputs = {"user_prompt": "Record input", "urls": ["http://recorder.com"]}
+    assert dummy_record.last_inputs == expected_inputs
diff --git a/tests/test_smart_scraper_multi_lite_graph.py b/tests/test_smart_scraper_multi_lite_graph.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_xml_scraper_multi_graph.py b/tests/test_xml_scraper_multi_graph.py
new file mode 100644
index 00000000..79810fc9
--- /dev/null
+++ b/tests/test_xml_scraper_multi_graph.py
@@ -0,0 +1,77 @@
+from pydantic import BaseModel
+
+from scrapegraphai.graphs.xml_scraper_multi_graph import XMLScraperMultiGraph
+
+XMLScraperMultiGraph._create_llm = lambda self, llm_config: None
+
+
+# Define a fake graph class to simulate the behavior of the graph.execute() method.
+class FakeGraph:
+    def __init__(self, final_state):
+        self.final_state = final_state
+
+    def execute(self, inputs):
+        # Return final_state and dummy execution_info.
+        return self.final_state, {"info": "dummy execution info"}
+
+
+class DummySchema(BaseModel):
+    dummy_field: str
+
+
+class TestXMLScraperMultiGraph:
+    """Test suite for XMLScraperMultiGraph"""
+
+    def test_run_returns_answer(self):
+        """Test run method returns the expected answer when provided in final state."""
+        prompt = "Test prompt"
+        source = ["http://example.com"]
+        config = {"llm": {"model": "openai/test-model"}}
+        graph = XMLScraperMultiGraph(prompt, source, config)
+        expected_answer = "Expected Answer"
+        # Inject fake graph that returns expected answer
+        graph.graph = FakeGraph({"answer": expected_answer})
+        result = graph.run()
+        assert result == expected_answer
+
+    def test_run_no_answer_found(self):
+        """Test run method returns default answer when no answer is present in final state."""
+        prompt = "Test prompt"
+        source = ["http://example.com"]
+        config = {"llm": {"model": "openai/test-model"}}
+        graph = XMLScraperMultiGraph(prompt, source, config)
+        # Inject fake graph that returns empty final_state
+        graph.graph = FakeGraph({})
+        result = graph.run()
+        assert result == "No answer found."
+
+    def test_create_graph_structure(self):
+        """Test that _create_graph produces a graph structure with expected nodes and edges."""
+        prompt = "Test prompt"
+        source = ["http://example.com"]
+        config = {"llm": {"model": "openai/test-model"}, "other_config": "value"}
+        dummy_schema = DummySchema
+        graph_instance = XMLScraperMultiGraph(prompt, source, config, dummy_schema)
+        # Create the graph structure using _create_graph
+        created_graph = graph_instance._create_graph()
+        # Check that the created graph has nodes and edges defined
+        assert hasattr(created_graph, "nodes")
+        assert hasattr(created_graph, "edges")
+        # Check that entry_point is one of the nodes
+        assert created_graph.entry_point in created_graph.nodes
+
+    def test_config_and_schema_deepcopy(self):
+        """Test that modifying the original config and schema does not affect the instance copies."""
+        prompt = "Test prompt"
+        source = ["http://example.com"]
+        original_config = {"llm": {"model": "openai/test-model"}, "list": [1, 2, 3]}
+        original_schema = DummySchema
+        graph_instance = XMLScraperMultiGraph(
+            prompt, source, original_config, original_schema
+        )
+        # Modify the original config after initialization
+        original_config["list"].append(4)
+        # The instance copy should remain unchanged
+        assert graph_instance.copy_config["list"] == [1, 2, 3]
+        # The schema is deep copied as a class reference, so the copy should still equal original_schema
+        assert graph_instance.copy_schema == original_schema