diff --git a/.gitignore b/.gitignore index 7453751d..5bb52dcf 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,6 @@ cython_debug/ .DS_Store dev.ipynb + +# CodeBeaver reports and artifacts +.codebeaver diff --git a/examples/ScrapegraphAI_cookbook.ipynb b/examples/ScrapegraphAI_cookbook.ipynb index 3ef7eb1e..d8f1151e 100644 --- a/examples/ScrapegraphAI_cookbook.ipynb +++ b/examples/ScrapegraphAI_cookbook.ipynb @@ -1,915 +1,914 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9_CQrFgOj78b" - }, - "outputs": [], - "source": [ - "%%capture\n", - "!pip install scrapegraphai\n", - "!apt install chromium-chromedriver\n", - "!pip install nest_asyncio\n", - "!pip install playwright\n", - "!playwright install" - ] - }, - { - "cell_type": "code", - "source": [ - "import nest_asyncio\n", - "nest_asyncio.apply()" - ], - "metadata": { - "id": "tb33AcRHywFb" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "00a84YVhhxJr" - }, - "outputs": [], - "source": [ - "# correct APIKEY\n", - "OPENAI_API_KEY = \"YOUR API KEY\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "For more examples visit [the examples folder](https://github.com/ScrapeGraphAI/Scrapegraph-ai/tree/main/examples)" - ], - "metadata": { - "id": "vGDjka17pqqg" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Mrujgp-nlp12" - }, - "source": [ - "# SmartScraperGraph\n", - "**SmartScraperGraph** is a class representing one of the default scraping pipelines. It uses a direct graph implementation where each node has its own function, from retrieving html from a website to extracting relevant information based on your query and generate a coherent answer." - ] - }, - { - "cell_type": "markdown", - "source": [ - "![Screenshot 2024-09-19 alle 17.04.56.png]()" - ], - "metadata": { - "id": "M-dmSB0_zHCQ" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uqYBNOM2YZD9" - }, - "source": [ - "## Using OpenAI models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ogiF4g5Z-bzG" - }, - "outputs": [], - "source": [ - "from scrapegraphai.graphs import SmartScraperGraph" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7ZzONlJ6-oe_" - }, - "source": [ - "Define the configuration for the graph" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MPZgrZ12-eRc" - }, - "outputs": [], - "source": [ - "graph_config = {\n", - " \"llm\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"openai/gpt-4o-mini\",\n", - " \"temperature\":0,\n", - " },\n", - " \"verbose\":True,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DjDt_10r-q8P" - }, - "source": [ - "Create the SmartScraperGraph instance and run it" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aV4VTnx9-h_d" - }, - "outputs": [], - "source": [ - "smart_scraper_graph = SmartScraperGraph(\n", - " prompt=\"List me all the projects with their descriptions.\",\n", - " # also accepts a string with the already downloaded HTML code\n", - " source=\"https://perinim.github.io/projects/\",\n", - " config=graph_config\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "graph_config = {\n", - " \"llm\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"openai/gpt-4o-mini\",\n", - " },\n", - " \"verbose\": True,\n", - " \"headless\": True,\n", - "}\n", - "\n", - "# ************************************************\n", - "# Create the SmartScraperGraph instance and run it\n", - "# ************************************************\n", - "\n", - "smart_scraper_graph = SmartScraperGraph(\n", - " prompt=\"List me all the projects with their description\",\n", - " source=\"https://perinim.github.io/projects/\",\n", - " config=graph_config\n", - ")" - ], - "metadata": { - "id": "E3pyGQZLTiZ8" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Zty23idsAtwU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "419dd75f-18c6-44d2-da82-ca8967d17e0f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "--- Executing Fetch Node ---\n", - "--- (Fetching HTML from: https://perinim.github.io/projects/) ---\n", - "--- Executing ParseNode Node ---\n", - "--- Executing GenerateAnswer Node ---\n" - ] - } - ], - "source": [ - "result = smart_scraper_graph.run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rnGhLGCuAqRU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "062aeab2-3e96-4fec-d04a-b9acae142f40" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{\n", - " \"projects\": [\n", - " {\n", - " \"name\": \"Rotary Pendulum RL\",\n", - " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", - " },\n", - " {\n", - " \"name\": \"DQN Implementation from scratch\",\n", - " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", - " },\n", - " {\n", - " \"name\": \"Multi Agents HAED\",\n", - " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", - " },\n", - " {\n", - " \"name\": \"Wireless ESC for Modular Drones\",\n", - " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", - " }\n", - " ]\n", - "}\n" - ] - } - ], - "source": [ - "import json\n", - "\n", - "output = json.dumps(result, indent=2)\n", - "\n", - "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", - "\n", - "for line in line_list:\n", - " print(line)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5poLHYLVa-6E" - }, - "source": [ - "# Search graph\n", - "This graph **transforms** the user prompt in a **internet search query**, fetch the relevant URLs, and start the scraping process. Similar to the **SmartScraperGraph** but with the addition of the **SearchInternetNode** node." - ] - }, - { - "cell_type": "markdown", - "source": [ - "![image.png]()" - ], - "metadata": { - "id": "NRIoaXSzzP8M" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RIvbQjyhbHhW" - }, - "outputs": [], - "source": [ - "from scrapegraphai.graphs import SearchGraph\n", - "\n", - "# Define the configuration for the graph\n", - "graph_config = {\n", - " \"llm\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"openai/gpt-4o-mini\",\n", - " \"temperature\": 0,\n", - " },\n", - "}\n", - "\n", - "# Create the SearchGraph instance\n", - "search_graph = SearchGraph(\n", - " prompt=\"List me all the European countries. Look in wikipedia.\",\n", - " config=graph_config\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XnVtc7SzCkUY" - }, - "outputs": [], - "source": [ - "result = search_graph.run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3LPAh-yQCqkY" - }, - "source": [ - "Prettify the result and display the JSON" - ] - }, - { - "cell_type": "code", - "source": [ - "import json\n", - "\n", - "output = json.dumps(result, indent=2)\n", - "\n", - "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", - "\n", - "for line in line_list:\n", - " print(line)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "xgnWDLTjzHwv", - "outputId": "f0c8ebf4-5ba5-4330-dbd8-1c9fdd93eaeb" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{\n", - " \"European_countries\": [\n", - " \"Albania\",\n", - " \"Andorra\",\n", - " \"Armenia\",\n", - " \"Austria\",\n", - " \"Azerbaijan\",\n", - " \"Belarus\",\n", - " \"Belgium\",\n", - " \"Bosnia and Herzegovina\",\n", - " \"Bulgaria\",\n", - " \"Croatia\",\n", - " \"Cyprus\",\n", - " \"Czech Republic\",\n", - " \"Denmark\",\n", - " \"Estonia\",\n", - " \"Finland\",\n", - " \"France\",\n", - " \"Georgia\",\n", - " \"Germany\",\n", - " \"Greece\",\n", - " \"Hungary\",\n", - " \"Iceland\",\n", - " \"Ireland\",\n", - " \"Italy\",\n", - " \"Jersey\",\n", - " \"Isle of Man\",\n", - " \"Kazakhstan\",\n", - " \"Latvia\",\n", - " \"Liechtenstein\",\n", - " \"Lithuania\",\n", - " \"Luxembourg\",\n", - " \"Malta\",\n", - " \"Moldova\",\n", - " \"Monaco\",\n", - " \"Montenegro\",\n", - " \"Netherlands\",\n", - " \"North Macedonia\",\n", - " \"Norway\",\n", - " \"Poland\",\n", - " \"Portugal\",\n", - " \"Romania\",\n", - " \"Russia\",\n", - " \"San Marino\",\n", - " \"Serbia\",\n", - " \"Slovakia\",\n", - " \"Slovenia\",\n", - " \"Spain\",\n", - " \"Sweden\",\n", - " \"Switzerland\",\n", - " \"Turkey\",\n", - " \"Ukraine\",\n", - " \"United Kingdom\",\n", - " \"Vatican City\",\n", - " \"Kosovo\",\n", - " \"Gibraltar\",\n", - " \"Faroe Islands\",\n", - " \"Guernsey\",\n", - " \"Jersey\"\n", - " ],\n", - " \"sources\": [\n", - " \"https://simple.wikipedia.org/wiki/List_of_European_countries\",\n", - " \"https://en.wikipedia.org/wiki/List_of_European_countries_by_population\",\n", - " \"https://en.wikipedia.org/wiki/Member_state_of_the_European_Union\"\n", - " ]\n", - "}\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N5IMdKHvlXFY" - }, - "source": [ - "# SpeechGraph\n", - "**SpeechGraph** is a class representing one of the default scraping pipelines that generate the answer together with an audio file. Similar to the **SmartScraperGraph** but with the addition of the **TextToSpeechNode** node.\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "![image.png]()" - ], - "metadata": { - "id": "pqJsEVgizs-M" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "W9KhWlT3lXFd" - }, - "outputs": [], - "source": [ - "from scrapegraphai.graphs import SpeechGraph\n", - "\n", - "# Define the configuration for the graph\n", - "graph_config = {\n", - " \"llm\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"gpt-3.5-turbo\",\n", - " },\n", - " \"tts_model\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"tts-1\",\n", - " \"voice\": \"alloy\"\n", - " },\n", - " \"output_path\": \"website_summary.mp3\",\n", - "}\n", - "\n", - "# Create the SpeechGraph instance\n", - "speech_graph = SpeechGraph(\n", - " prompt=\"Create a summary of the website\",\n", - " source=\"https://perinim.github.io/projects/\",\n", - " config=graph_config,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nVolb3paEczD", - "outputId": "d7d316a0-7580-4a6c-8f20-7e1cb1fc3f07" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--- Executing Fetch Node ---\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Fetching pages: 100%|##########| 1/1 [00:00<00:00, 17.07it/s]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--- Executing Parse Node ---\n", - "--- Executing RAG Node ---\n", - "--- (updated chunks metadata) ---\n", - "--- (tokens compressed and vector stored) ---\n", - "--- Executing GenerateAnswer Node ---\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Processing chunks: 100%|██████████| 1/1 [00:00<00:00, 339.78it/s]\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--- Executing TextToSpeech Node ---\n", - "Audio saved to website_summary.mp3\n" - ] - } - ], - "source": [ - "result = speech_graph.run()\n", - "answer = result.get(\"answer\", \"No answer found\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "znt2EOKZE3z2" - }, - "source": [ - "Prettify the result and display the JSON" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "QqY0TbwbEp-O", - "outputId": "c2b1127d-0c49-4121-922e-39da65c329ee" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{\n", - " \"summary\": {\n", - " \"title\": \"Projects | \",\n", - " \"projects\": [\n", - " {\n", - " \"title\": \"Rotary Pendulum RL\",\n", - " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", - " },\n", - " {\n", - " \"title\": \"DQN Implementation from scratch\",\n", - " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", - " },\n", - " {\n", - " \"title\": \"Multi Agents HAED\",\n", - " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", - " },\n", - " {\n", - " \"title\": \"Wireless ESC for Modular Drones\",\n", - " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", - " }\n", - " ]\n", - " }\n", - "}\n" - ] - } - ], - "source": [ - "import json\n", - "\n", - "output = json.dumps(answer, indent=2)\n", - "\n", - "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", - "\n", - "for line in line_list:\n", - " print(line)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 75 - }, - "id": "lfJ_jVwklXFd", - "outputId": "dc4ad491-4422-4edb-91ae-35775b23168a" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "text/html": [ - "\n", - " \n", - " " - ] - }, - "metadata": {} - } - ], - "source": [ - "from IPython.display import Audio\n", - "wn = Audio(\"website_summary.mp3\", autoplay=True)\n", - "display(wn)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p9kC0x4NuLTx" - }, - "source": [ - "# Build a Custom Graph\n", - "It is possible to **build your own scraping pipeline** by using the default nodes and place them as you wish, without using pre-defined graphs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pr6DIqt2uLUI" - }, - "source": [ - "You can create **custom graphs** based on your necessities, using standard nodes provided by the library.\n", - "\n", - "The list of the existing nodes can be found through the *nodes_metadata* json construct.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-o29vDSIvG4t", - "outputId": "be469b65-ba01-437a-e217-ed1c4f3ad264" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "dict_keys(['SearchInternetNode', 'FetchNode', 'GetProbableTagsNode', 'ParseNode', 'RAGNode', 'GenerateAnswerNode', 'ConditionalNode', 'ImageToTextNode', 'TextToSpeechNode'])" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# check available nodes\n", - "from scrapegraphai.helpers import nodes_metadata\n", - "\n", - "nodes_metadata.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "829wW5E6vrjJ", - "outputId": "58203025-64ce-4107-f6d3-3b3cfa5537d5" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "{'description': 'Converts image content to text by \\n extracting visual information and interpreting it.',\n", - " 'type': 'node',\n", - " 'args': {'image_data': 'Data of the image to be processed.'},\n", - " 'returns': \"Updated state with the textual description of the image under 'image_text' key.\"}" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# to get more information about a node\n", - "nodes_metadata['ImageToTextNode']" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3pnNFDckwWy7" - }, - "source": [ - "To create a custom graph we must:\n", - "\n", - "1. **Istantiate the nodes** you want to use\n", - "2. Create the graph using **BaseGraph** class, which must have a **list of nodes**, tuples representing the **edges** of the graph, an **entry_point**\n", - "3. Run it using the **execute** method\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eQLZJyg4uLUJ" - }, - "outputs": [], - "source": [ - "from langchain_openai import OpenAIEmbeddings\n", - "from scrapegraphai.models import OpenAI\n", - "from scrapegraphai.graphs import BaseGraph\n", - "from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode\n", - "\n", - "# Define the configuration for the graph\n", - "graph_config = {\n", - " \"llm\": {\n", - " \"api_key\": OPENAI_API_KEY,\n", - " \"model\": \"openai/gpt-4o\",\n", - " \"temperature\": 0,\n", - " \"streaming\": True\n", - " },\n", - "}\n", - "\n", - "llm_model = OpenAI(graph_config[\"llm\"])\n", - "embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)\n", - "\n", - "# define the nodes for the graph\n", - "fetch_node = FetchNode(\n", - " input=\"url | local_dir\",\n", - " output=[\"doc\", \"link_urls\", \"img_urls\"],\n", - " node_config={\n", - " \"verbose\": True,\n", - " \"headless\": True,\n", - " }\n", - ")\n", - "parse_node = ParseNode(\n", - " input=\"doc\",\n", - " output=[\"parsed_doc\"],\n", - " node_config={\n", - " \"chunk_size\": 4096,\n", - " \"verbose\": True,\n", - " }\n", - ")\n", - "rag_node = RAGNode(\n", - " input=\"user_prompt & (parsed_doc | doc)\",\n", - " output=[\"relevant_chunks\"],\n", - " node_config={\n", - " \"llm_model\": llm_model,\n", - " \"embedder_model\": embedder,\n", - " \"verbose\": True,\n", - " }\n", - ")\n", - "generate_answer_node = GenerateAnswerNode(\n", - " input=\"user_prompt & (relevant_chunks | parsed_doc | doc)\",\n", - " output=[\"answer\"],\n", - " node_config={\n", - " \"llm_model\": llm_model,\n", - " \"verbose\": True,\n", - " }\n", - ")\n", - "\n", - "# create the graph by defining the nodes and their connections\n", - "graph = BaseGraph(\n", - " nodes=[\n", - " fetch_node,\n", - " parse_node,\n", - " rag_node,\n", - " generate_answer_node,\n", - " ],\n", - " edges=[\n", - " (fetch_node, parse_node),\n", - " (parse_node, rag_node),\n", - " (rag_node, generate_answer_node)\n", - " ],\n", - " entry_point=fetch_node\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5FYKF9H1Fvb8", - "outputId": "666d51fe-5e2f-4398-a3b0-bb820960a0d1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Executing Fetch Node ---\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Fetching pages: 100%|##########| 1/1 [00:00<00:00, 28.65it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Executing Parse Node ---\n", - "--- Executing RAG Node ---\n", - "--- (updated chunks metadata) ---\n", - "--- (tokens compressed and vector stored) ---\n", - "--- Executing GenerateAnswer Node ---\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing chunks: 100%|██████████| 1/1 [00:00<00:00, 911.01it/s]\n" - ] - } + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_CQrFgOj78b" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install scrapegraphai\n", + "!apt install chromium-chromedriver\n", + "!pip install nest_asyncio\n", + "!pip install playwright\n", + "!playwright install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tb33AcRHywFb" + }, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "00a84YVhhxJr" + }, + "outputs": [], + "source": [ + "# correct APIKEY\n", + "OPENAI_API_KEY = \"YOUR API KEY\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vGDjka17pqqg" + }, + "source": [ + "For more examples visit [the examples folder](https://github.com/ScrapeGraphAI/Scrapegraph-ai/tree/main/examples)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mrujgp-nlp12" + }, + "source": [ + "# SmartScraperGraph\n", + "**SmartScraperGraph** is a class representing one of the default scraping pipelines. It uses a direct graph implementation where each node has its own function, from retrieving html from a website to extracting relevant information based on your query and generate a coherent answer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M-dmSB0_zHCQ" + }, + "source": [ + "![Screenshot 2024-09-19 alle 17.04.56.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uqYBNOM2YZD9" + }, + "source": [ + "## Using OpenAI models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ogiF4g5Z-bzG" + }, + "outputs": [], + "source": [ + "from scrapegraphai.graphs import SmartScraperGraph" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7ZzONlJ6-oe_" + }, + "source": [ + "Define the configuration for the graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MPZgrZ12-eRc" + }, + "outputs": [], + "source": [ + "graph_config = {\n", + " \"llm\": {\n", + " \"api_key\": OPENAI_API_KEY,\n", + " \"model\": \"openai/gpt-4o-mini\",\n", + " \"temperature\": 0,\n", + " },\n", + " \"verbose\": True,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DjDt_10r-q8P" + }, + "source": [ + "Create the SmartScraperGraph instance and run it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aV4VTnx9-h_d" + }, + "outputs": [], + "source": [ + "smart_scraper_graph = SmartScraperGraph(\n", + " prompt=\"List me all the projects with their descriptions.\",\n", + " # also accepts a string with the already downloaded HTML code\n", + " source=\"https://perinim.github.io/projects/\",\n", + " config=graph_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E3pyGQZLTiZ8" + }, + "outputs": [], + "source": [ + "graph_config = {\n", + " \"llm\": {\n", + " \"api_key\": OPENAI_API_KEY,\n", + " \"model\": \"openai/gpt-4o-mini\",\n", + " },\n", + " \"verbose\": True,\n", + " \"headless\": True,\n", + "}\n", + "\n", + "# ************************************************\n", + "# Create the SmartScraperGraph instance and run it\n", + "# ************************************************\n", + "\n", + "smart_scraper_graph = SmartScraperGraph(\n", + " prompt=\"List me all the projects with their description\",\n", + " source=\"https://perinim.github.io/projects/\",\n", + " config=graph_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Zty23idsAtwU", + "outputId": "419dd75f-18c6-44d2-da82-ca8967d17e0f" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "--- Executing Fetch Node ---\n", + "--- (Fetching HTML from: https://perinim.github.io/projects/) ---\n", + "--- Executing ParseNode Node ---\n", + "--- Executing GenerateAnswer Node ---\n" + ] + } + ], + "source": [ + "result = smart_scraper_graph.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rnGhLGCuAqRU", + "outputId": "062aeab2-3e96-4fec-d04a-b9acae142f40" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"projects\": [\n", + " {\n", + " \"name\": \"Rotary Pendulum RL\",\n", + " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", + " },\n", + " {\n", + " \"name\": \"DQN Implementation from scratch\",\n", + " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", + " },\n", + " {\n", + " \"name\": \"Multi Agents HAED\",\n", + " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", + " },\n", + " {\n", + " \"name\": \"Wireless ESC for Modular Drones\",\n", + " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", + " }\n", + " ]\n", + "}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "output = json.dumps(result, indent=2)\n", + "\n", + "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", + "\n", + "for line in line_list:\n", + " print(line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5poLHYLVa-6E" + }, + "source": [ + "# Search graph\n", + "This graph **transforms** the user prompt in a **internet search query**, fetch the relevant URLs, and start the scraping process. Similar to the **SmartScraperGraph** but with the addition of the **SearchInternetNode** node." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NRIoaXSzzP8M" + }, + "source": [ + "![image.png]()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RIvbQjyhbHhW" + }, + "outputs": [], + "source": [ + "from scrapegraphai.graphs import SearchGraph\n", + "\n", + "# Define the configuration for the graph\n", + "graph_config = {\n", + " \"llm\": {\n", + " \"api_key\": OPENAI_API_KEY,\n", + " \"model\": \"openai/gpt-4o-mini\",\n", + " \"temperature\": 0,\n", + " },\n", + "}\n", + "\n", + "# Create the SearchGraph instance\n", + "search_graph = SearchGraph(\n", + " prompt=\"List me all the European countries. Look in wikipedia.\", config=graph_config\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XnVtc7SzCkUY" + }, + "outputs": [], + "source": [ + "result = search_graph.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3LPAh-yQCqkY" + }, + "source": [ + "Prettify the result and display the JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xgnWDLTjzHwv", + "outputId": "f0c8ebf4-5ba5-4330-dbd8-1c9fdd93eaeb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"European_countries\": [\n", + " \"Albania\",\n", + " \"Andorra\",\n", + " \"Armenia\",\n", + " \"Austria\",\n", + " \"Azerbaijan\",\n", + " \"Belarus\",\n", + " \"Belgium\",\n", + " \"Bosnia and Herzegovina\",\n", + " \"Bulgaria\",\n", + " \"Croatia\",\n", + " \"Cyprus\",\n", + " \"Czech Republic\",\n", + " \"Denmark\",\n", + " \"Estonia\",\n", + " \"Finland\",\n", + " \"France\",\n", + " \"Georgia\",\n", + " \"Germany\",\n", + " \"Greece\",\n", + " \"Hungary\",\n", + " \"Iceland\",\n", + " \"Ireland\",\n", + " \"Italy\",\n", + " \"Jersey\",\n", + " \"Isle of Man\",\n", + " \"Kazakhstan\",\n", + " \"Latvia\",\n", + " \"Liechtenstein\",\n", + " \"Lithuania\",\n", + " \"Luxembourg\",\n", + " \"Malta\",\n", + " \"Moldova\",\n", + " \"Monaco\",\n", + " \"Montenegro\",\n", + " \"Netherlands\",\n", + " \"North Macedonia\",\n", + " \"Norway\",\n", + " \"Poland\",\n", + " \"Portugal\",\n", + " \"Romania\",\n", + " \"Russia\",\n", + " \"San Marino\",\n", + " \"Serbia\",\n", + " \"Slovakia\",\n", + " \"Slovenia\",\n", + " \"Spain\",\n", + " \"Sweden\",\n", + " \"Switzerland\",\n", + " \"Turkey\",\n", + " \"Ukraine\",\n", + " \"United Kingdom\",\n", + " \"Vatican City\",\n", + " \"Kosovo\",\n", + " \"Gibraltar\",\n", + " \"Faroe Islands\",\n", + " \"Guernsey\",\n", + " \"Jersey\"\n", + " ],\n", + " \"sources\": [\n", + " \"https://simple.wikipedia.org/wiki/List_of_European_countries\",\n", + " \"https://en.wikipedia.org/wiki/List_of_European_countries_by_population\",\n", + " \"https://en.wikipedia.org/wiki/Member_state_of_the_European_Union\"\n", + " ]\n", + "}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "output = json.dumps(result, indent=2)\n", + "\n", + "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", + "\n", + "for line in line_list:\n", + " print(line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N5IMdKHvlXFY" + }, + "source": [ + "# SpeechGraph\n", + "**SpeechGraph** is a class representing one of the default scraping pipelines that generate the answer together with an audio file. Similar to the **SmartScraperGraph** but with the addition of the **TextToSpeechNode** node.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pqJsEVgizs-M" + }, + "source": [ + "![image.png]()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W9KhWlT3lXFd" + }, + "outputs": [], + "source": [ + "from scrapegraphai.graphs import SpeechGraph\n", + "\n", + "# Define the configuration for the graph\n", + "graph_config = {\n", + " \"llm\": {\n", + " \"api_key\": OPENAI_API_KEY,\n", + " \"model\": \"gpt-3.5-turbo\",\n", + " },\n", + " \"tts_model\": {\"api_key\": OPENAI_API_KEY, \"model\": \"tts-1\", \"voice\": \"alloy\"},\n", + " \"output_path\": \"website_summary.mp3\",\n", + "}\n", + "\n", + "# Create the SpeechGraph instance\n", + "speech_graph = SpeechGraph(\n", + " prompt=\"Create a summary of the website\",\n", + " source=\"https://perinim.github.io/projects/\",\n", + " config=graph_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nVolb3paEczD", + "outputId": "d7d316a0-7580-4a6c-8f20-7e1cb1fc3f07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Executing Fetch Node ---\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching pages: 100%|##########| 1/1 [00:00<00:00, 17.07it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Executing Parse Node ---\n", + "--- Executing RAG Node ---\n", + "--- (updated chunks metadata) ---\n", + "--- (tokens compressed and vector stored) ---\n", + "--- Executing GenerateAnswer Node ---\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing chunks: 100%|██████████| 1/1 [00:00<00:00, 339.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Executing TextToSpeech Node ---\n", + "Audio saved to website_summary.mp3\n" + ] + } + ], + "source": [ + "result = speech_graph.run()\n", + "answer = result.get(\"answer\", \"No answer found\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "znt2EOKZE3z2" + }, + "source": [ + "Prettify the result and display the JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QqY0TbwbEp-O", + "outputId": "c2b1127d-0c49-4121-922e-39da65c329ee" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"summary\": {\n", + " \"title\": \"Projects | \",\n", + " \"projects\": [\n", + " {\n", + " \"title\": \"Rotary Pendulum RL\",\n", + " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", + " },\n", + " {\n", + " \"title\": \"DQN Implementation from scratch\",\n", + " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", + " },\n", + " {\n", + " \"title\": \"Multi Agents HAED\",\n", + " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", + " },\n", + " {\n", + " \"title\": \"Wireless ESC for Modular Drones\",\n", + " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", + " }\n", + " ]\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "output = json.dumps(answer, indent=2)\n", + "\n", + "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", + "\n", + "for line in line_list:\n", + " print(line)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "lfJ_jVwklXFd", + "outputId": "dc4ad491-4422-4edb-91ae-35775b23168a" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "# execute the graph\n", - "result, execution_info = graph.execute({\n", - " \"user_prompt\": \"List me the projects with their description\",\n", - " \"url\": \"https://perinim.github.io/projects/\"\n", - "})\n", - "\n", - "# get the answer from the result\n", - "result = result.get(\"answer\", \"No answer found.\")" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Audio\n", + "\n", + "wn = Audio(\"website_summary.mp3\", autoplay=True)\n", + "display(wn)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p9kC0x4NuLTx" + }, + "source": [ + "# Build a Custom Graph\n", + "It is possible to **build your own scraping pipeline** by using the default nodes and place them as you wish, without using pre-defined graphs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pr6DIqt2uLUI" + }, + "source": [ + "You can create **custom graphs** based on your necessities, using standard nodes provided by the library.\n", + "\n", + "The list of the existing nodes can be found through the *nodes_metadata* json construct.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "-o29vDSIvG4t", + "outputId": "be469b65-ba01-437a-e217-ed1c4f3ad264" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "JEP8_zZ9GHW2" - }, - "source": [ - "Prettify the result and display the JSON" + "data": { + "text/plain": [ + "dict_keys(['SearchInternetNode', 'FetchNode', 'GetProbableTagsNode', 'ParseNode', 'RAGNode', 'GenerateAnswerNode', 'ConditionalNode', 'ImageToTextNode', 'TextToSpeechNode'])" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nx9qGaxvFmfT", - "outputId": "fb327a6a-0dfa-417b-8dbb-505bebc96fe8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\n", - " \"projects\": [\n", - " {\n", - " \"title\": \"Rotary Pendulum RL\",\n", - " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", - " },\n", - " {\n", - " \"title\": \"DQN Implementation from scratch\",\n", - " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", - " },\n", - " {\n", - " \"title\": \"Multi Agents HAED\",\n", - " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", - " },\n", - " {\n", - " \"title\": \"Wireless ESC for Modular Drones\",\n", - " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", - " }\n", - " ]\n", - "}\n" - ] - } - ], - "source": [ - "import json\n", - "\n", - "output = json.dumps(result, indent=2)\n", - "\n", - "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", - "\n", - "for line in line_list:\n", - " print(line)" + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# check available nodes\n", + "from scrapegraphai.helpers import nodes_metadata\n", + "\n", + "nodes_metadata.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "829wW5E6vrjJ", + "outputId": "58203025-64ce-4107-f6d3-3b3cfa5537d5" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'description': 'Converts image content to text by \\n extracting visual information and interpreting it.',\n", + " 'type': 'node',\n", + " 'args': {'image_data': 'Data of the image to be processed.'},\n", + " 'returns': \"Updated state with the textual description of the image under 'image_text' key.\"}" ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { + ], + "source": [ + "# to get more information about a node\n", + "nodes_metadata[\"ImageToTextNode\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3pnNFDckwWy7" + }, + "source": [ + "To create a custom graph we must:\n", + "\n", + "1. **Istantiate the nodes** you want to use\n", + "2. Create the graph using **BaseGraph** class, which must have a **list of nodes**, tuples representing the **edges** of the graph, an **entry_point**\n", + "3. Run it using the **execute** method\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eQLZJyg4uLUJ" + }, + "outputs": [], + "source": [ + "from langchain_openai import OpenAIEmbeddings\n", + "from scrapegraphai.models import OpenAI\n", + "from scrapegraphai.graphs import BaseGraph\n", + "from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode\n", + "\n", + "# Define the configuration for the graph\n", + "graph_config = {\n", + " \"llm\": {\n", + " \"api_key\": OPENAI_API_KEY,\n", + " \"model\": \"openai/gpt-4o\",\n", + " \"temperature\": 0,\n", + " \"streaming\": True,\n", + " },\n", + "}\n", + "\n", + "llm_model = OpenAI(graph_config[\"llm\"])\n", + "embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)\n", + "\n", + "# define the nodes for the graph\n", + "fetch_node = FetchNode(\n", + " input=\"url | local_dir\",\n", + " output=[\"doc\", \"link_urls\", \"img_urls\"],\n", + " node_config={\n", + " \"verbose\": True,\n", + " \"headless\": True,\n", + " },\n", + ")\n", + "parse_node = ParseNode(\n", + " input=\"doc\",\n", + " output=[\"parsed_doc\"],\n", + " node_config={\n", + " \"chunk_size\": 4096,\n", + " \"verbose\": True,\n", + " },\n", + ")\n", + "rag_node = RAGNode(\n", + " input=\"user_prompt & (parsed_doc | doc)\",\n", + " output=[\"relevant_chunks\"],\n", + " node_config={\n", + " \"llm_model\": llm_model,\n", + " \"embedder_model\": embedder,\n", + " \"verbose\": True,\n", + " },\n", + ")\n", + "generate_answer_node = GenerateAnswerNode(\n", + " input=\"user_prompt & (relevant_chunks | parsed_doc | doc)\",\n", + " output=[\"answer\"],\n", + " node_config={\n", + " \"llm_model\": llm_model,\n", + " \"verbose\": True,\n", + " },\n", + ")\n", + "\n", + "# create the graph by defining the nodes and their connections\n", + "graph = BaseGraph(\n", + " nodes=[\n", + " fetch_node,\n", + " parse_node,\n", + " rag_node,\n", + " generate_answer_node,\n", + " ],\n", + " edges=[\n", + " (fetch_node, parse_node),\n", + " (parse_node, rag_node),\n", + " (rag_node, generate_answer_node),\n", + " ],\n", + " entry_point=fetch_node,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "collapsed_sections": [ - "N5IMdKHvlXFY" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" + "base_uri": "https://localhost:8080/" + }, + "id": "5FYKF9H1Fvb8", + "outputId": "666d51fe-5e2f-4398-a3b0-bb820960a0d1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Executing Fetch Node ---\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching pages: 100%|##########| 1/1 [00:00<00:00, 28.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Executing Parse Node ---\n", + "--- Executing RAG Node ---\n", + "--- (updated chunks metadata) ---\n", + "--- (tokens compressed and vector stored) ---\n", + "--- Executing GenerateAnswer Node ---\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing chunks: 100%|██████████| 1/1 [00:00<00:00, 911.01it/s]\n" + ] } + ], + "source": [ + "# execute the graph\n", + "result, execution_info = graph.execute(\n", + " {\n", + " \"user_prompt\": \"List me the projects with their description\",\n", + " \"url\": \"https://perinim.github.io/projects/\",\n", + " }\n", + ")\n", + "\n", + "# get the answer from the result\n", + "result = result.get(\"answer\", \"No answer found.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JEP8_zZ9GHW2" + }, + "source": [ + "Prettify the result and display the JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nx9qGaxvFmfT", + "outputId": "fb327a6a-0dfa-417b-8dbb-505bebc96fe8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"projects\": [\n", + " {\n", + " \"title\": \"Rotary Pendulum RL\",\n", + " \"description\": \"Open Source project aimed at controlling a real life rotary pendulum using RL algorithms\"\n", + " },\n", + " {\n", + " \"title\": \"DQN Implementation from scratch\",\n", + " \"description\": \"Developed a Deep Q-Network algorithm to train a simple and double pendulum\"\n", + " },\n", + " {\n", + " \"title\": \"Multi Agents HAED\",\n", + " \"description\": \"University project which focuses on simulating a multi-agent system to perform environment mapping. Agents, equipped with sensors, explore and record their surroundings, considering uncertainties in their readings.\"\n", + " },\n", + " {\n", + " \"title\": \"Wireless ESC for Modular Drones\",\n", + " \"description\": \"Modular drone architecture proposal and proof of concept. The project received maximum grade.\"\n", + " }\n", + " ]\n", + "}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "output = json.dumps(result, indent=2)\n", + "\n", + "line_list = output.split(\"\\n\") # Sort of line replacing \"\\n\" with a new line\n", + "\n", + "for line in line_list:\n", + " print(line)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "N5IMdKHvlXFY" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/code_generator_graph/ollama/code_generator_graph_ollama.py b/examples/code_generator_graph/ollama/code_generator_graph_ollama.py index 339bb03c..0fdcf04c 100644 --- a/examples/code_generator_graph/ollama/code_generator_graph_ollama.py +++ b/examples/code_generator_graph/ollama/code_generator_graph_ollama.py @@ -2,7 +2,6 @@ Basic example of scraping pipeline using Code Generator with schema """ -import json from typing import List from dotenv import load_dotenv diff --git a/examples/custom_graph/ollama/custom_graph_ollama.py b/examples/custom_graph/ollama/custom_graph_ollama.py index f7aebd3d..4574ee6b 100644 --- a/examples/custom_graph/ollama/custom_graph_ollama.py +++ b/examples/custom_graph/ollama/custom_graph_ollama.py @@ -2,8 +2,6 @@ Example of custom graph using existing nodes """ -import os - from langchain_openai import ChatOpenAI, OpenAIEmbeddings from scrapegraphai.graphs import BaseGraph @@ -11,7 +9,6 @@ FetchNode, GenerateAnswerNode, ParseNode, - RAGNode, RobotsNode, ) diff --git a/examples/extras/chromium_selenium.py b/examples/extras/chromium_selenium.py index 811ebc2a..e8354e7e 100644 --- a/examples/extras/chromium_selenium.py +++ b/examples/extras/chromium_selenium.py @@ -9,7 +9,6 @@ ChromiumLoader, ) from scrapegraphai.graphs import SmartScraperGraph -from scrapegraphai.utils import prettify_exec_info # Load environment variables for API keys load_dotenv() diff --git a/examples/extras/no_cut.py b/examples/extras/no_cut.py index c638df84..044a4f7a 100644 --- a/examples/extras/no_cut.py +++ b/examples/extras/no_cut.py @@ -3,7 +3,6 @@ """ import json -import os from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info diff --git a/examples/extras/serch_graph_scehma.py b/examples/extras/serch_graph_scehma.py index 0ad66d4e..02b76db6 100644 --- a/examples/extras/serch_graph_scehma.py +++ b/examples/extras/serch_graph_scehma.py @@ -40,7 +40,7 @@ class Ceos(BaseModel): # ************************************************ search_graph = SearchGraph( - prompt=f"Who is the ceo of Appke?", + prompt="Who is the ceo of Appke?", schema=Ceos, config=graph_config, ) diff --git a/examples/script_generator_graph/ollama/script_multi_generator_ollama.py b/examples/script_generator_graph/ollama/script_multi_generator_ollama.py index a8a53f1f..b4af2c9d 100644 --- a/examples/script_generator_graph/ollama/script_multi_generator_ollama.py +++ b/examples/script_generator_graph/ollama/script_multi_generator_ollama.py @@ -2,8 +2,6 @@ Basic example of scraping pipeline using ScriptCreatorGraph """ -import os - from dotenv import load_dotenv from scrapegraphai.graphs import ScriptCreatorMultiGraph diff --git a/scrapegraphai/builders/graph_builder.py b/scrapegraphai/builders/graph_builder.py index c44ea72a..1179ebe7 100644 --- a/scrapegraphai/builders/graph_builder.py +++ b/scrapegraphai/builders/graph_builder.py @@ -113,9 +113,7 @@ def _create_extraction_chain(self): {nodes_description} Based on the user's input: "{input}", identify the essential nodes required for the task and suggest a graph configuration that outlines the flow between the chosen nodes. - """.format( - nodes_description=self.nodes_description, input="{input}" - ) + """.format(nodes_description=self.nodes_description, input="{input}") extraction_prompt = ChatPromptTemplate.from_template( create_graph_prompt_template ) diff --git a/scrapegraphai/docloaders/scrape_do.py b/scrapegraphai/docloaders/scrape_do.py index be37e3f7..4c3adbb3 100644 --- a/scrapegraphai/docloaders/scrape_do.py +++ b/scrapegraphai/docloaders/scrape_do.py @@ -2,10 +2,10 @@ Scrape_do module """ +import os import urllib.parse import requests -import os import urllib3 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 3908545d..c42e4daf 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -177,7 +177,7 @@ def _create_llm(self, llm_config: dict) -> object: ] if len(possible_providers) <= 0: raise ValueError( - f"""Provider {llm_params['model_provider']} is not supported. + f"""Provider {llm_params["model_provider"]} is not supported. If possible, try to use a model instance instead.""" ) llm_params["model_provider"] = possible_providers[0] @@ -190,7 +190,7 @@ def _create_llm(self, llm_config: dict) -> object: if llm_params["model_provider"] not in known_providers: raise ValueError( - f"""Provider {llm_params['model_provider']} is not supported. + f"""Provider {llm_params["model_provider"]} is not supported. If possible, try to use a model instance instead.""" ) @@ -201,7 +201,7 @@ def _create_llm(self, llm_config: dict) -> object: ] except KeyError: print( - f"""Max input tokens for model {llm_params['model_provider']}/{llm_params['model']} not found, + f"""Max input tokens for model {llm_params["model_provider"]}/{llm_params["model"]} not found, please specify the model_tokens parameter in the llm section of the graph configuration. Using default token size: 8192""" ) diff --git a/scrapegraphai/graphs/csv_scraper_multi_graph.py b/scrapegraphai/graphs/csv_scraper_multi_graph.py index 5e3a398d..fbad1203 100644 --- a/scrapegraphai/graphs/csv_scraper_multi_graph.py +++ b/scrapegraphai/graphs/csv_scraper_multi_graph.py @@ -49,7 +49,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/json_scraper_multi_graph.py b/scrapegraphai/graphs/json_scraper_multi_graph.py index 6623718b..fb49845b 100644 --- a/scrapegraphai/graphs/json_scraper_multi_graph.py +++ b/scrapegraphai/graphs/json_scraper_multi_graph.py @@ -49,7 +49,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py index c30033ab..0fc26697 100644 --- a/scrapegraphai/graphs/omni_search_graph.py +++ b/scrapegraphai/graphs/omni_search_graph.py @@ -44,7 +44,6 @@ class OmniSearchGraph(AbstractGraph): def __init__( self, prompt: str, config: dict, schema: Optional[Type[BaseModel]] = None ): - self.max_results = config.get("max_results", 3) self.copy_config = safe_deepcopy(config) diff --git a/scrapegraphai/graphs/script_creator_multi_graph.py b/scrapegraphai/graphs/script_creator_multi_graph.py index a8e01729..843f8ad9 100644 --- a/scrapegraphai/graphs/script_creator_multi_graph.py +++ b/scrapegraphai/graphs/script_creator_multi_graph.py @@ -48,7 +48,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py index ebd7b936..8454a860 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py @@ -53,7 +53,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/smart_scraper_multi_graph.py b/scrapegraphai/graphs/smart_scraper_multi_graph.py index 226aae00..f017ec09 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_graph.py @@ -55,7 +55,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.max_results = config.get("max_results", 3) self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py b/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py index 849c85c8..8ef3211a 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py @@ -55,7 +55,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/xml_scraper_multi_graph.py b/scrapegraphai/graphs/xml_scraper_multi_graph.py index 2a3848c9..480781a8 100644 --- a/scrapegraphai/graphs/xml_scraper_multi_graph.py +++ b/scrapegraphai/graphs/xml_scraper_multi_graph.py @@ -49,7 +49,6 @@ def __init__( config: dict, schema: Optional[Type[BaseModel]] = None, ): - self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/models/openai_tts.py b/scrapegraphai/models/openai_tts.py index 714050fb..7503fe1b 100644 --- a/scrapegraphai/models/openai_tts.py +++ b/scrapegraphai/models/openai_tts.py @@ -19,7 +19,6 @@ class OpenAITextToSpeech: """ def __init__(self, tts_config: dict): - self.client = OpenAI( api_key=tts_config.get("api_key"), base_url=tts_config.get("base_url", None) ) diff --git a/scrapegraphai/nodes/base_node.py b/scrapegraphai/nodes/base_node.py index 45ee82d3..179865eb 100644 --- a/scrapegraphai/nodes/base_node.py +++ b/scrapegraphai/nodes/base_node.py @@ -54,7 +54,6 @@ def __init__( min_input_len: int = 1, node_config: Optional[dict] = None, ): - self.node_name = node_name self.input = input self.output = output @@ -197,7 +196,6 @@ def evaluate_simple_expression(exp: str) -> List[str]: """Evaluate an expression without parentheses.""" for or_segment in exp.split("|"): - and_segment = or_segment.split("&") if all(elem.strip() in state for elem in and_segment): return [ @@ -226,7 +224,7 @@ def evaluate_expression(expression: str) -> List[str]: raise ValueError( f"""No state keys matched the expression. Expression was {expression}. - State contains keys: {', '.join(state.keys())}""" + State contains keys: {", ".join(state.keys())}""" ) final_result = [] diff --git a/scrapegraphai/nodes/concat_answers_node.py b/scrapegraphai/nodes/concat_answers_node.py index c1b271c0..11cba4d6 100644 --- a/scrapegraphai/nodes/concat_answers_node.py +++ b/scrapegraphai/nodes/concat_answers_node.py @@ -36,8 +36,7 @@ def __init__( ) def _merge_dict(self, items): - - return {"products": {f"item_{i+1}": item for i, item in enumerate(items)}} + return {"products": {f"item_{i + 1}": item for i, item in enumerate(items)}} def execute(self, state: dict) -> dict: """ diff --git a/scrapegraphai/nodes/description_node.py b/scrapegraphai/nodes/description_node.py index 21917c84..90102ceb 100644 --- a/scrapegraphai/nodes/description_node.py +++ b/scrapegraphai/nodes/description_node.py @@ -58,7 +58,7 @@ def execute(self, state: dict) -> dict: template=DESCRIPTION_NODE_PROMPT, partial_variables={"content": chunk.get("document")}, ) - chain_name = f"chunk{i+1}" + chain_name = f"chunk{i + 1}" chains_dict[chain_name] = prompt | self.llm_model async_runner = RunnableParallel(**chains_dict) diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index c5790479..cd24fc21 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -96,7 +96,6 @@ def execute(self, state): doc = input_data[1] if self.node_config.get("schema", None) is not None: - if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema=self.node_config["schema"] @@ -151,7 +150,7 @@ def execute(self, state): }, ) - chain_name = f"chunk{i+1}" + chain_name = f"chunk{i + 1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser async_runner = RunnableParallel(**chains_dict) diff --git a/scrapegraphai/nodes/generate_answer_from_image_node.py b/scrapegraphai/nodes/generate_answer_from_image_node.py index 1ef653f3..808804fd 100644 --- a/scrapegraphai/nodes/generate_answer_from_image_node.py +++ b/scrapegraphai/nodes/generate_answer_from_image_node.py @@ -85,7 +85,7 @@ async def execute_async(self, state: dict) -> dict: raise ValueError( f"""The model provided is not supported. Supported models are: - {', '.join(supported_models)}.""" + {", ".join(supported_models)}.""" ) api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "") diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 5e267a4d..db4467be 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -221,7 +221,7 @@ def execute(self, state: dict) -> dict: "format_instructions": format_instructions, }, ) - chain_name = f"chunk{i+1}" + chain_name = f"chunk{i + 1}" chains_dict[chain_name] = prompt | self.llm_model if output_parser: chains_dict[chain_name] = chains_dict[chain_name] | output_parser diff --git a/scrapegraphai/nodes/generate_answer_node_k_level.py b/scrapegraphai/nodes/generate_answer_node_k_level.py index daef9d02..27106c88 100644 --- a/scrapegraphai/nodes/generate_answer_node_k_level.py +++ b/scrapegraphai/nodes/generate_answer_node_k_level.py @@ -155,7 +155,7 @@ def execute(self, state: dict) -> dict: "chunk_id": i + 1, }, ) - chain_name = f"chunk{i+1}" + chain_name = f"chunk{i + 1}" chains_dict[chain_name] = prompt | self.llm_model async_runner = RunnableParallel(**chains_dict) diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index ba5bbc6b..3e608bfb 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -89,7 +89,6 @@ def execute(self, state: dict) -> dict: imag_desc = input_data[2] if self.node_config.get("schema", None) is not None: - if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema=self.node_config["schema"] @@ -151,7 +150,7 @@ def execute(self, state: dict) -> dict: }, ) - chain_name = f"chunk{i+1}" + chain_name = f"chunk{i + 1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser async_runner = RunnableParallel(**chains_dict) diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index b867b3e0..18e9fcc8 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -82,10 +82,9 @@ def execute(self, state: dict) -> dict: answers_str = "" for i, answer in enumerate(answers): - answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n" + answers_str += f"CONTENT WEBSITE {i + 1}: {answer}\n" if self.node_config.get("schema", None) is not None: - if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema=self.node_config["schema"] diff --git a/scrapegraphai/nodes/merge_generated_scripts_node.py b/scrapegraphai/nodes/merge_generated_scripts_node.py index 5ccac699..2b4a2217 100644 --- a/scrapegraphai/nodes/merge_generated_scripts_node.py +++ b/scrapegraphai/nodes/merge_generated_scripts_node.py @@ -64,7 +64,7 @@ def execute(self, state: dict) -> dict: scripts_str = "" for i, script in enumerate(scripts): scripts_str += "-----------------------------------\n" - scripts_str += f"SCRIPT URL {i+1}\n" + scripts_str += f"SCRIPT URL {i + 1}\n" scripts_str += "-----------------------------------\n" scripts_str += script diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index cb61a643..1c409da2 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -122,7 +122,7 @@ def execute(self, state: dict) -> dict: state.update({self.output[0]: chunks}) state.update({"parsed_doc": chunks}) state.update({"content": chunks}) - + if self.parse_urls: state.update({self.output[1]: link_urls}) state.update({self.output[2]: img_urls}) diff --git a/scrapegraphai/nodes/search_link_node.py b/scrapegraphai/nodes/search_link_node.py index 614b4878..6ae5d01b 100644 --- a/scrapegraphai/nodes/search_link_node.py +++ b/scrapegraphai/nodes/search_link_node.py @@ -122,7 +122,6 @@ def execute(self, state: dict) -> dict: ) ): try: - links = re.findall(r'https?://[^\s"<>\]]+', str(chunk.page_content)) if not self.filter_links: diff --git a/scrapegraphai/utils/__init__.py b/scrapegraphai/utils/__init__.py index 0190d691..df9118c1 100644 --- a/scrapegraphai/utils/__init__.py +++ b/scrapegraphai/utils/__init__.py @@ -1,5 +1,5 @@ """ - __init__.py file for utils folder +__init__.py file for utils folder """ from .cleanup_code import extract_code diff --git a/scrapegraphai/utils/cleanup_html.py b/scrapegraphai/utils/cleanup_html.py index 6da03a90..8a40fd88 100644 --- a/scrapegraphai/utils/cleanup_html.py +++ b/scrapegraphai/utils/cleanup_html.py @@ -2,8 +2,8 @@ Module for minimizing the code """ -import re import json +import re from urllib.parse import urljoin from bs4 import BeautifulSoup, Comment @@ -12,32 +12,36 @@ def extract_from_script_tags(soup): script_content = [] - + for script in soup.find_all("script"): content = script.string if content: try: - json_pattern = r'(?:const|let|var)?\s*\w+\s*=\s*({[\s\S]*?});?$' + json_pattern = r"(?:const|let|var)?\s*\w+\s*=\s*({[\s\S]*?});?$" json_matches = re.findall(json_pattern, content) - + for potential_json in json_matches: try: parsed = json.loads(potential_json) if parsed: - script_content.append(f"JSON data from script: {json.dumps(parsed, indent=2)}") + script_content.append( + f"JSON data from script: {json.dumps(parsed, indent=2)}" + ) except json.JSONDecodeError: pass - + if "window." in content or "document." in content: - data_pattern = r'(?:window|document)\.(\w+)\s*=\s*([^;]+);' + data_pattern = r"(?:window|document)\.(\w+)\s*=\s*([^;]+);" data_matches = re.findall(data_pattern, content) - + for var_name, var_value in data_matches: - script_content.append(f"Dynamic data - {var_name}: {var_value.strip()}") + script_content.append( + f"Dynamic data - {var_name}: {var_value.strip()}" + ) except Exception: if len(content) < 1000: script_content.append(f"Script content: {content.strip()}") - + return "\n\n".join(script_content) @@ -66,9 +70,9 @@ def cleanup_html(html_content: str, base_url: str) -> str: title_tag = soup.find("title") title = title_tag.get_text() if title_tag else "" - + script_content = extract_from_script_tags(soup) - + for tag in soup.find_all("style"): tag.extract() diff --git a/scrapegraphai/utils/dict_content_compare.py b/scrapegraphai/utils/dict_content_compare.py index c94d4863..9e5efbbd 100644 --- a/scrapegraphai/utils/dict_content_compare.py +++ b/scrapegraphai/utils/dict_content_compare.py @@ -53,7 +53,9 @@ def normalize_list(lst: List[Any]) -> List[Any]: else ( normalize_list(item) if isinstance(item, list) - else item.lower().strip() if isinstance(item, str) else item + else item.lower().strip() + if isinstance(item, str) + else item ) ) for item in lst diff --git a/scrapegraphai/utils/output_parser.py b/scrapegraphai/utils/output_parser.py index a4cf9f5a..a9d9ba31 100644 --- a/scrapegraphai/utils/output_parser.py +++ b/scrapegraphai/utils/output_parser.py @@ -10,7 +10,7 @@ def get_structured_output_parser( - schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type] + schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type], ) -> Callable: """ Get the correct output parser for the LLM model. @@ -28,7 +28,7 @@ def get_structured_output_parser( def get_pydantic_output_parser( - schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type] + schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type], ) -> JsonOutputParser: """ Get the correct output parser for the LLM model. diff --git a/scrapegraphai/utils/parse_state_keys.py b/scrapegraphai/utils/parse_state_keys.py index 97531487..040f1310 100644 --- a/scrapegraphai/utils/parse_state_keys.py +++ b/scrapegraphai/utils/parse_state_keys.py @@ -56,7 +56,6 @@ def parse_expression(expression, state: dict) -> list: or "&|" in expression or "|&" in expression ): - raise ValueError("Invalid operator usage.") open_parentheses = close_parentheses = 0 diff --git a/scrapegraphai/utils/proxy_rotation.py b/scrapegraphai/utils/proxy_rotation.py index 8e8534e1..0e348377 100644 --- a/scrapegraphai/utils/proxy_rotation.py +++ b/scrapegraphai/utils/proxy_rotation.py @@ -6,11 +6,12 @@ import random import re from typing import List, Optional, Set, TypedDict +from urllib.parse import urlparse import requests from fp.errors import FreeProxyException from fp.fp import FreeProxy -from urllib.parse import urlparse + class ProxyBrokerCriteria(TypedDict, total=False): """ @@ -200,7 +201,9 @@ def parse_or_search_proxy(proxy: Proxy) -> ProxySettings: raise ValueError(f"Invalid proxy server format: {proxy['server']}") # Accept both IP addresses and domain names like 'gate.nodemaven.com' - if is_ipv4_address(server_address) or re.match(r"^[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", server_address): + if is_ipv4_address(server_address) or re.match( + r"^[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", server_address + ): return _parse_proxy(proxy) assert proxy["server"] == "broker", f"Unknown proxy server type: {proxy['server']}" diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py index 873b5a2f..4c9b026e 100644 --- a/tests/graphs/abstract_graph_test.py +++ b/tests/graphs/abstract_graph_test.py @@ -1,17 +1,19 @@ -import pytest +from unittest.mock import Mock, patch +import pytest from langchain_aws import ChatBedrock from langchain_ollama import ChatOllama from langchain_openai import AzureChatOpenAI, ChatOpenAI + from scrapegraphai.graphs import AbstractGraph, BaseGraph from scrapegraphai.models import DeepSeek, OneApi from scrapegraphai.nodes import FetchNode, ParseNode -from unittest.mock import Mock, patch """ Tests for the AbstractGraph. """ + class TestGraph(AbstractGraph): def __init__(self, prompt: str, config: dict): super().__init__(prompt, config) @@ -48,6 +50,7 @@ def run(self) -> str: return self.final_state.get("answer", "No answer found.") + class TestAbstractGraph: @pytest.mark.parametrize( "llm_config, expected_model", @@ -171,7 +174,7 @@ def test_create_llm_with_custom_model_instance(self): "llm": { "model_instance": mock_model, "model_tokens": 1000, - "model": "custom/model" + "model": "custom/model", } } @@ -192,18 +195,27 @@ def test_set_common_params(self): mock_graph.nodes = [mock_node1, mock_node2] # Create a TestGraph instance with the mock graph - with patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_graph', return_value=mock_graph): - graph = TestGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}) + with patch( + "scrapegraphai.graphs.abstract_graph.AbstractGraph._create_graph", + return_value=mock_graph, + ): + graph = TestGraph( + "Test prompt", + {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}, + ) # Call set_common_params with test parameters test_params = {"param1": "value1", "param2": "value2"} graph.set_common_params(test_params) # Assert that update_config was called on each node with the correct parameters - + def test_get_state(self): """Test that get_state returns the correct final state with or without a provided key, and raises KeyError for missing keys.""" - graph = TestGraph("dummy", {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}) + graph = TestGraph( + "dummy", + {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}, + ) # Set a dummy final state graph.final_state = {"answer": "42", "other": "value"} # Test without a key returns the entire final_state @@ -218,7 +230,10 @@ def test_get_state(self): def test_append_node(self): """Test that append_node correctly delegates to the graph's append_node method.""" - graph = TestGraph("dummy", {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}) + graph = TestGraph( + "dummy", + {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}, + ) # Replace the graph object with a mock that has append_node mock_graph = Mock() graph.graph = mock_graph @@ -228,8 +243,11 @@ def test_append_node(self): def test_get_execution_info(self): """Test that get_execution_info returns the execution info stored in the graph.""" - graph = TestGraph("dummy", {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}) + graph = TestGraph( + "dummy", + {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}, + ) dummy_info = {"execution": "info", "status": "ok"} graph.execution_info = dummy_info info = graph.get_execution_info() - assert info == dummy_info \ No newline at end of file + assert info == dummy_info diff --git a/tests/graphs/script_generator_test.py b/tests/graphs/script_generator_test.py index d90d2df0..598ccfd6 100644 --- a/tests/graphs/script_generator_test.py +++ b/tests/graphs/script_generator_test.py @@ -5,7 +5,6 @@ import pytest from scrapegraphai.graphs import ScriptCreatorGraph -from scrapegraphai.utils import prettify_exec_info @pytest.fixture @@ -35,6 +34,6 @@ def test_script_creator_graph(graph_config: dict): config=graph_config, ) result = smart_scraper_graph.run() - assert ( - result is not None - ), "ScriptCreatorGraph execution failed to produce a result." + assert result is not None, ( + "ScriptCreatorGraph execution failed to produce a result." + ) diff --git a/tests/graphs/search_link_ollama.py b/tests/graphs/search_link_ollama.py index 9801b2fa..a6c63e1e 100644 --- a/tests/graphs/search_link_ollama.py +++ b/tests/graphs/search_link_ollama.py @@ -1,5 +1,4 @@ from scrapegraphai.graphs import SearchLinkGraph -from scrapegraphai.utils import prettify_exec_info def test_smart_scraper_pipeline(): diff --git a/tests/nodes/robot_node_test.py b/tests/nodes/robot_node_test.py index 692b6062..dad24fbb 100644 --- a/tests/nodes/robot_node_test.py +++ b/tests/nodes/robot_node_test.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock import pytest -from langchain_community.chat_models import ChatOllama from scrapegraphai.nodes import RobotsNode diff --git a/tests/nodes/search_internet_node_test.py b/tests/nodes/search_internet_node_test.py index 67792424..25ad1234 100644 --- a/tests/nodes/search_internet_node_test.py +++ b/tests/nodes/search_internet_node_test.py @@ -6,7 +6,6 @@ class TestSearchInternetNode(unittest.TestCase): - def setUp(self): # Configuration for the graph self.graph_config = { diff --git a/tests/nodes/search_link_node_test.py b/tests/nodes/search_link_node_test.py index 5c64e470..2c630a2f 100644 --- a/tests/nodes/search_link_node_test.py +++ b/tests/nodes/search_link_node_test.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from langchain_community.chat_models import ChatOllama diff --git a/tests/test_chromium.py b/tests/test_chromium.py new file mode 100644 index 00000000..1c56840a --- /dev/null +++ b/tests/test_chromium.py @@ -0,0 +1,866 @@ +import asyncio +import sys +from unittest.mock import AsyncMock, patch + +import aiohttp +import pytest + + +class MockPlaywright: + def __init__(self): + self.chromium = AsyncMock() + self.firefox = AsyncMock() + + +class MockBrowser: + def __init__(self): + self.new_context = AsyncMock() + + +class MockContext: + def __init__(self): + self.new_page = AsyncMock() + + +class MockPage: + def __init__(self): + self.goto = AsyncMock() + self.wait_for_load_state = AsyncMock() + self.content = AsyncMock() + self.evaluate = AsyncMock() + self.mouse = AsyncMock() + self.mouse.wheel = AsyncMock() + + +@pytest.fixture +def mock_playwright(): + with patch("playwright.async_api.async_playwright") as mock: + mock_pw = MockPlaywright() + mock_browser = MockBrowser() + mock_context = MockContext() + mock_page = MockPage() + + mock_pw.chromium.launch.return_value = mock_browser + mock_pw.firefox.launch.return_value = mock_browser + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + mock.return_value.__aenter__.return_value = mock_pw + yield mock_pw, mock_browser, mock_context, mock_page + + +import pytest +from langchain_core.documents import Document + +from scrapegraphai.docloaders.chromium import ChromiumLoader + + +async def dummy_scraper(url): + """A dummy scraping function that returns dummy HTML content for the URL.""" + return f"dummy content for {url}" + + +@pytest.fixture +def loader_with_dummy(monkeypatch): + """Fixture returning a ChromiumLoader instance with dummy scraping methods patched.""" + urls = ["http://example.com", "http://test.com"] + loader = ChromiumLoader(urls, backend="playwright", requires_js_support=False) + monkeypatch.setattr(loader, "ascrape_playwright", dummy_scraper) + monkeypatch.setattr(loader, "ascrape_with_js_support", dummy_scraper) + monkeypatch.setattr(loader, "ascrape_undetected_chromedriver", dummy_scraper) + return loader + + +def test_lazy_load(loader_with_dummy): + """Test that lazy_load yields Document objects with the correct dummy content and metadata.""" + docs = list(loader_with_dummy.lazy_load()) + assert len(docs) == 2 + for doc, url in zip(docs, loader_with_dummy.urls): + assert isinstance(doc, Document) + assert f"dummy content for {url}" in doc.page_content + assert doc.metadata["source"] == url + + +@pytest.mark.asyncio +async def test_alazy_load(loader_with_dummy): + """Test that alazy_load asynchronously yields Document objects with dummy content and proper metadata.""" + docs = [doc async for doc in loader_with_dummy.alazy_load()] + assert len(docs) == 2 + for doc, url in zip(docs, loader_with_dummy.urls): + assert isinstance(doc, Document) + assert f"dummy content for {url}" in doc.page_content + assert doc.metadata["source"] == url + + +@pytest.mark.asyncio +async def test_scrape_method_unsupported_backend(): + """Test that the scrape method raises a ValueError when an unsupported backend is provided.""" + loader = ChromiumLoader(["http://example.com"], backend="unsupported") + with pytest.raises(ValueError): + await loader.scrape("http://example.com") + + +@pytest.mark.asyncio +async def test_scrape_method_selenium(monkeypatch): + """Test that the scrape method works correctly for selenium by returning the dummy selenium content.""" + + async def dummy_selenium(url): + return f"dummy selenium content for {url}" + + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="selenium") + loader.browser_name = "chromium" + monkeypatch.setattr(loader, "ascrape_undetected_chromedriver", dummy_selenium) + result = await loader.scrape("http://example.com") + assert "dummy selenium content" in result + + +@pytest.mark.asyncio +async def test_ascrape_playwright_scroll(mock_playwright): + """Test the ascrape_playwright_scroll method with various configurations.""" + mock_pw, mock_browser, mock_context, mock_page = mock_playwright + + url = "http://example.com" + loader = ChromiumLoader([url], backend="playwright") + + # Test with default parameters + mock_page.evaluate.side_effect = [1000, 2000, 2000] # Simulate scrolling + result = await loader.ascrape_playwright_scroll(url) + + assert mock_page.goto.call_count == 1 + assert mock_page.wait_for_load_state.call_count == 1 + assert mock_page.mouse.wheel.call_count > 0 + assert mock_page.content.call_count == 1 + + # Test with custom parameters + mock_page.evaluate.side_effect = [1000, 2000, 3000, 4000, 4000] + result = await loader.ascrape_playwright_scroll( + url, timeout=10, scroll=10000, sleep=1, scroll_to_bottom=True + ) + + assert mock_page.goto.call_count == 2 + assert mock_page.wait_for_load_state.call_count == 2 + assert mock_page.mouse.wheel.call_count > 0 + assert mock_page.content.call_count == 2 + + +@pytest.mark.asyncio +async def test_ascrape_with_js_support(mock_playwright): + """Test the ascrape_with_js_support method with different browser configurations.""" + mock_pw, mock_browser, mock_context, mock_page = mock_playwright + + url = "http://example.com" + loader = ChromiumLoader([url], backend="playwright", requires_js_support=True) + + # Test with Chromium + result = await loader.ascrape_with_js_support(url, browser_name="chromium") + assert mock_pw.chromium.launch.call_count == 1 + assert mock_page.goto.call_count == 1 + assert mock_page.content.call_count == 1 + + # Test with Firefox + result = await loader.ascrape_with_js_support(url, browser_name="firefox") + assert mock_pw.firefox.launch.call_count == 1 + assert mock_page.goto.call_count == 2 + assert mock_page.content.call_count == 2 + + # Test with invalid browser name + with pytest.raises(ValueError): + await loader.ascrape_with_js_support(url, browser_name="invalid") + + +@pytest.mark.asyncio +async def test_scrape_method_playwright(mock_playwright): + """Test the scrape method with playwright backend.""" + mock_pw, mock_browser, mock_context, mock_page = mock_playwright + + url = "http://example.com" + loader = ChromiumLoader([url], backend="playwright") + + mock_page.content.return_value = "Playwright content" + result = await loader.scrape(url) + + assert "Playwright content" in result + assert mock_pw.chromium.launch.call_count == 1 + assert mock_page.goto.call_count == 1 + assert mock_page.wait_for_load_state.call_count == 1 + assert mock_page.content.call_count == 1 + + +@pytest.mark.asyncio +async def test_scrape_method_retry_logic(mock_playwright): + """Test the retry logic in the scrape method.""" + mock_pw, mock_browser, mock_context, mock_page = mock_playwright + + url = "http://example.com" + loader = ChromiumLoader([url], backend="playwright", retry_limit=3) + + # Simulate two failures and then a success + mock_page.goto.side_effect = [asyncio.TimeoutError(), aiohttp.ClientError(), None] + mock_page.content.return_value = "Success after retries" + + result = await loader.scrape(url) + + assert "Success after retries" in result + assert mock_page.goto.call_count == 3 + assert mock_page.content.call_count == 1 + + # Test failure after all retries + mock_page.goto.side_effect = asyncio.TimeoutError() + + with pytest.raises(RuntimeError): + await loader.scrape(url) + + assert mock_page.goto.call_count == 6 # 3 more attempts + + +@pytest.mark.asyncio +async def test_ascrape_playwright_scroll_invalid_params(): + """Test that ascrape_playwright_scroll raises ValueError for invalid scroll parameters.""" + loader = ChromiumLoader(["http://example.com"], backend="playwright") + with pytest.raises( + ValueError, + match="If set, timeout value for scrolling scraper must be greater than 0.", + ): + await loader.ascrape_playwright_scroll("http://example.com", timeout=0) + with pytest.raises( + ValueError, match="Sleep for scrolling scraper value must be greater than 0." + ): + await loader.ascrape_playwright_scroll("http://example.com", sleep=0) + with pytest.raises( + ValueError, + match="Scroll value for scrolling scraper must be greater than or equal to 5000.", + ): + await loader.ascrape_playwright_scroll("http://example.com", scroll=4000) + + +@pytest.mark.asyncio +async def test_ascrape_with_js_support_retry_failure(monkeypatch): + """Test that ascrape_with_js_support retries and ultimately fails when page.goto always times out.""" + loader = ChromiumLoader( + ["http://example.com"], + backend="playwright", + requires_js_support=True, + retry_limit=2, + timeout=1, + ) + + # Create dummy classes to simulate failure in page.goto + class DummyPage: + async def goto(self, url, wait_until): + raise asyncio.TimeoutError("Forced timeout") + + async def wait_for_load_state(self, state): + return + + async def content(self): + return "Dummy" + + class DummyContext: + async def new_page(self): + return DummyPage() + + class DummyBrowser: + async def new_context(self, **kwargs): + return DummyContext() + + async def close(self): + return + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + class firefox: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + # Patch the async_playwright to return our dummy + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + + with pytest.raises(RuntimeError, match="Failed to scrape after"): + await loader.ascrape_with_js_support("http://example.com") + + +@pytest.mark.asyncio +async def test_ascrape_undetected_chromedriver_success(monkeypatch): + """Test that ascrape_undetected_chromedriver successfully returns content using the selenium backend.""" + # Create a dummy undetected_chromedriver module with a dummy Chrome driver. + import types + + dummy_module = types.ModuleType("undetected_chromedriver") + + class DummyDriver: + def __init__(self, options): + self.options = options + self.page_source = "selenium content" + + def quit(self): + pass + + dummy_module.Chrome = lambda options: DummyDriver(options) + monkeypatch.setitem(sys.modules, "undetected_chromedriver", dummy_module) + + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="selenium", retry_limit=1, timeout=5) + loader.browser_name = "chromium" + result = await loader.ascrape_undetected_chromedriver("http://example.com") + assert "selenium content" in result + + +@pytest.mark.asyncio +async def test_lazy_load_exception(loader_with_dummy, monkeypatch): + """Test that lazy_load propagates exception if the scraping function fails.""" + + async def dummy_failure(url): + raise Exception("Dummy scraping error") + + # Patch the scraping method to always raise an exception + loader_with_dummy.backend = "playwright" + monkeypatch.setattr(loader_with_dummy, "ascrape_playwright", dummy_failure) + with pytest.raises(Exception, match="Dummy scraping error"): + list(loader_with_dummy.lazy_load()) + + +@pytest.mark.asyncio +async def test_ascrape_undetected_chromedriver_unsupported_browser(monkeypatch): + """Test ascrape_undetected_chromedriver raises an error when an unsupported browser is provided.""" + import types + + dummy_module = types.ModuleType("undetected_chromedriver") + # Provide a dummy Chrome; this will not be used for an unsupported browser. + dummy_module.Chrome = lambda options: None + monkeypatch.setitem(sys.modules, "undetected_chromedriver", dummy_module) + + loader = ChromiumLoader( + ["http://example.com"], backend="selenium", retry_limit=1, timeout=1 + ) + loader.browser_name = "opera" # Unsupported browser. + with pytest.raises(UnboundLocalError): + await loader.ascrape_undetected_chromedriver("http://example.com") + + +@pytest.mark.asyncio +async def test_alazy_load_partial_failure(monkeypatch): + """Test that alazy_load propagates an exception if one of the scraping tasks fails.""" + urls = ["http://example.com", "http://fail.com"] + loader = ChromiumLoader(urls, backend="playwright") + + async def partial_scraper(url): + if "fail" in url: + raise Exception("Scraping failed for " + url) + return f"Content for {url}" + + monkeypatch.setattr(loader, "ascrape_playwright", partial_scraper) + + with pytest.raises(Exception, match="Scraping failed for http://fail.com"): + [doc async for doc in loader.alazy_load()] + + +@pytest.mark.asyncio +async def test_ascrape_playwright_retry_failure(monkeypatch): + """Test that ascrape_playwright retries scraping and raises RuntimeError after all attempts fail.""" + + # Dummy classes to simulate persistent failure in page.goto for ascrape_playwright + class DummyPage: + async def goto(self, url, wait_until): + raise asyncio.TimeoutError("Forced timeout in goto") + + async def wait_for_load_state(self, state): + return + + async def content(self): + return "This should not be returned" + + class DummyContext: + async def new_page(self): + return DummyPage() + + class DummyBrowser: + async def new_context(self, **kwargs): + return DummyContext() + + async def close(self): + return + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + class firefox: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + + loader = ChromiumLoader( + ["http://example.com"], backend="playwright", retry_limit=2, timeout=1 + ) + with pytest.raises(RuntimeError, match="Failed to scrape after 2 attempts"): + await loader.ascrape_playwright("http://example.com") + + +@pytest.mark.asyncio +async def test_init_overrides(): + """Test that ChromiumLoader picks up and overrides attributes using kwargs.""" + urls = ["http://example.com"] + loader = ChromiumLoader( + urls, + backend="playwright", + headless=False, + proxy={"http": "http://proxy"}, + load_state="load", + requires_js_support=True, + storage_state="state", + browser_name="firefox", + retry_limit=5, + timeout=120, + extra="value", + ) + # Check that attributes are correctly set + assert loader.headless is False + assert loader.proxy == {"http": "http://proxy"} + assert loader.load_state == "load" + assert loader.requires_js_support is True + assert loader.storage_state == "state" + assert loader.browser_name == "firefox" + assert loader.retry_limit == 5 + assert loader.timeout == 120 + # Check that extra kwargs go into browser_config + assert loader.browser_config.get("extra") == "value" + # Check that the backend remains as provided + assert loader.backend == "playwright" + + +@pytest.mark.asyncio +async def test_lazy_load_with_js_support(monkeypatch): + """Test that lazy_load uses ascrape_with_js_support when requires_js_support is True.""" + urls = ["http://example.com", "http://test.com"] + loader = ChromiumLoader(urls, backend="playwright", requires_js_support=True) + + async def dummy_js(url): + return f"JS content for {url}" + + monkeypatch.setattr(loader, "ascrape_with_js_support", dummy_js) + docs = list(loader.lazy_load()) + assert len(docs) == 2 + for doc, url in zip(docs, urls): + assert isinstance(doc, Document) + assert f"JS content for {url}" in doc.page_content + assert doc.metadata["source"] == url + + +@pytest.mark.asyncio +async def test_no_retry_returns_none(monkeypatch): + """Test that ascrape_playwright returns None if retry_limit is set to 0.""" + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="playwright", retry_limit=0) + + # Even if we patch ascrape_playwright, the while loop won't run since retry_limit is 0, so it should return None. + async def dummy(url, browser_name="chromium"): + return f"Content for {url}" + + monkeypatch.setattr(loader, "ascrape_playwright", dummy) + result = await loader.ascrape_playwright("http://example.com") + # With retry_limit=0, the loop never runs and the function returns None. + assert result is None + + +@pytest.mark.asyncio +async def test_alazy_load_empty_urls(): + """Test that alazy_load yields no documents when the urls list is empty.""" + loader = ChromiumLoader([], backend="playwright") + docs = [doc async for doc in loader.alazy_load()] + assert docs == [] + + +def test_lazy_load_empty_urls(): + """Test that lazy_load yields no documents when the urls list is empty.""" + loader = ChromiumLoader([], backend="playwright") + docs = list(loader.lazy_load()) + assert docs == [] + + +@pytest.mark.asyncio +async def test_ascrape_undetected_chromedriver_missing_import(monkeypatch): + """Test that ascrape_undetected_chromedriver raises ImportError when undetected_chromedriver is not installed.""" + # Remove undetected_chromedriver from sys.modules if it exists + if "undetected_chromedriver" in sys.modules: + monkeyatch_key = "undetected_chromedriver" + monkeypatch.delenitem(sys.modules, monkeyatch_key) + loader = ChromiumLoader( + ["http://example.com"], backend="selenium", retry_limit=1, timeout=5 + ) + loader.browser_name = "chromium" + with pytest.raises( + ImportError, match="undetected_chromedriver is required for ChromiumLoader" + ): + await loader.ascrape_undetected_chromedriver("http://example.com") + + +@pytest.mark.asyncio +async def test_ascrape_undetected_chromedriver_quit_called(monkeypatch): + """Test that ascrape_undetected_chromedriver calls driver.quit() on every attempt even when get() fails.""" + # List to collect each DummyDriver instance for later inspection. + driver_instances = [] + attempt_counter = [0] + + class DummyDriver: + def __init__(self, options): + self.options = options + self.quit_called = False + driver_instances.append(self) + + def get(self, url): + # Force a failure on the first attempt then succeed on subsequent attempts. + if attempt_counter[0] < 1: + attempt_counter[0] += 1 + raise aiohttp.ClientError("Forced failure") + # If no failure, simply pass. + + @property + def page_source(self): + return "driver content" + + def quit(self): + self.quit_called = True + + import types + + dummy_module = types.ModuleType("undetected_chromedriver") + dummy_module.Chrome = lambda options: DummyDriver(options) + monkeypatch.setitem(sys.modules, "undetected_chromedriver", dummy_module) + + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="selenium", retry_limit=2, timeout=5) + loader.browser_name = "chromium" + result = await loader.ascrape_undetected_chromedriver("http://example.com") + assert "driver content" in result + # Verify that two driver instances were used and that each had its quit() method called. + assert len(driver_instances) == 2 + for driver in driver_instances: + assert driver.quit_called is True + + +@pytest.mark.parametrize("backend", ["playwright", "selenium"]) +def test_dynamic_import_failure(monkeypatch, backend): + """Test that ChromiumLoader raises ImportError when dynamic_import fails.""" + + def fake_dynamic_import(backend, message): + raise ImportError("Test dynamic import error") + + monkeypatch.setattr( + "scrapegraphai.docloaders.chromium.dynamic_import", fake_dynamic_import + ) + with pytest.raises(ImportError, match="Test dynamic import error"): + ChromiumLoader(["http://example.com"], backend=backend) + + +@pytest.mark.asyncio +async def test_ascrape_with_js_support_retry_success(monkeypatch): + """Test that ascrape_with_js_support retries on failure and returns content on a subsequent successful attempt.""" + attempt_count = {"count": 0} + + class DummyPage: + async def goto(self, url, wait_until): + if attempt_count["count"] < 1: + attempt_count["count"] += 1 + raise asyncio.TimeoutError("Forced timeout") + # On second attempt, do nothing (simulate successful navigation) + + async def wait_for_load_state(self, state): + return + + async def content(self): + return "Success on retry" + + class DummyContext: + async def new_page(self): + return DummyPage() + + class DummyBrowser: + async def new_context(self, **kwargs): + return DummyContext() + + async def close(self): + return + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + class firefox: + @staticmethod + async def launch(headless, proxy, **kwargs): + return DummyBrowser() + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + + # Create a loader with JS support and a retry_limit of 2 (so one failure is allowed) + loader = ChromiumLoader( + ["http://example.com"], + backend="playwright", + requires_js_support=True, + retry_limit=2, + timeout=1, + ) + result = await loader.ascrape_with_js_support("http://example.com") + assert result == "Success on retry" + + +@pytest.mark.asyncio +async def test_proxy_parsing_in_init(monkeypatch): + """Test that providing a proxy triggers the use of parse_or_search_proxy and sets loader.proxy correctly.""" + dummy_proxy_value = {"dummy": True} + monkeypatch.setattr( + "scrapegraphai.docloaders.chromium.parse_or_search_proxy", + lambda proxy: dummy_proxy_value, + ) + loader = ChromiumLoader( + ["http://example.com"], backend="playwright", proxy="some_proxy_value" + ) + assert loader.proxy == dummy_proxy_value + + +@pytest.mark.asyncio +async def test_scrape_method_selenium_firefox(monkeypatch): + """Test that the scrape method works correctly for selenium with firefox backend.""" + + async def dummy_selenium(url): + return f"dummy selenium firefox content for {url}" + + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="selenium") + loader.browser_name = "firefox" + monkeypatch.setattr(loader, "ascrape_undetected_chromedriver", dummy_selenium) + result = await loader.scrape("http://example.com") + assert "dummy selenium firefox content" in result + + +def test_init_with_no_proxy(): + """Test that initializing ChromiumLoader with proxy=None results in loader.proxy being None.""" + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="playwright", proxy=None) + assert loader.proxy is None + + +@pytest.mark.asyncio +async def test_ascrape_playwright_negative_retry(monkeypatch): + """Test that ascrape_playwright returns None when retry_limit is negative (loop not executed).""" + + # Set-up a dummy playwright context which should never be used because retry_limit is negative. + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + # Should not be called as retry_limit is negative. + raise Exception("Should not launch browser") + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + urls = ["http://example.com"] + loader = ChromiumLoader(urls, backend="playwright", retry_limit=-1) + result = await loader.ascrape_playwright("http://example.com") + assert result is None + + +@pytest.mark.asyncio +async def test_ascrape_with_js_support_negative_retry(monkeypatch): + """Test that ascrape_with_js_support returns None when retry_limit is negative (loop not executed).""" + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + # Should not be called because retry_limit is negative. + raise Exception("Should not launch browser") + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + urls = ["http://example.com"] + loader = ChromiumLoader( + urls, backend="playwright", requires_js_support=True, retry_limit=-1 + ) + try: + result = await loader.ascrape_with_js_support("http://example.com") + except RuntimeError: + result = None + assert result is None + + +@pytest.mark.asyncio +async def test_ascrape_with_js_support_storage_state(monkeypatch): + """Test that ascrape_with_js_support passes the storage_state to the new_context call.""" + + class DummyPage: + async def goto(self, url, wait_until): + return + + async def wait_for_load_state(self, state): + return + + async def content(self): + return "Storage State Tested" + + class DummyContext: + async def new_page(self): + return DummyPage() + + class DummyBrowser: + def __init__(self): + self.last_context_kwargs = None + + async def new_context(self, **kwargs): + self.last_context_kwargs = kwargs + return DummyContext() + + async def close(self): + return + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + dummy_browser = DummyBrowser() + dummy_browser.launch_kwargs = { + "headless": headless, + "proxy": proxy, + **kwargs, + } + return dummy_browser + + class firefox: + @staticmethod + async def launch(headless, proxy, **kwargs): + dummy_browser = DummyBrowser() + dummy_browser.launch_kwargs = { + "headless": headless, + "proxy": proxy, + **kwargs, + } + return dummy_browser + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + storage_state = "dummy_state" + loader = ChromiumLoader( + ["http://example.com"], + backend="playwright", + requires_js_support=True, + storage_state=storage_state, + retry_limit=1, + ) + result = await loader.ascrape_with_js_support("http://example.com") + # To ensure that new_context was called with the correct storage_state, we simulate a launch call + browser = await DummyPW.chromium.launch( + headless=loader.headless, proxy=loader.proxy + ) + await browser.new_context(storage_state=loader.storage_state) + assert browser.last_context_kwargs is not None + assert browser.last_context_kwargs.get("storage_state") == storage_state + assert "Storage State Tested" in result + + +@pytest.mark.asyncio +async def test_ascrape_playwright_browser_config(monkeypatch): + """Test that ascrape_playwright passes extra browser_config kwargs to the browser launch.""" + captured_kwargs = {} + + class DummyPage: + async def goto(self, url, wait_until): + return + + async def wait_for_load_state(self, state): + return + + async def content(self): + return "Config Tested" + + class DummyContext: + async def new_page(self): + return DummyPage() + + class DummyBrowser: + def __init__(self, config): + self.config = config + + async def new_context(self, **kwargs): + self.context_kwargs = kwargs + return DummyContext() + + async def close(self): + return + + class DummyPW: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return + + class chromium: + @staticmethod + async def launch(headless, proxy, **kwargs): + nonlocal captured_kwargs + captured_kwargs = {"headless": headless, "proxy": proxy, **kwargs} + return DummyBrowser(captured_kwargs) + + class firefox: + @staticmethod + async def launch(headless, proxy, **kwargs): + nonlocal captured_kwargs + captured_kwargs = {"headless": headless, "proxy": proxy, **kwargs} + return DummyBrowser(captured_kwargs) + + monkeypatch.setattr("playwright.async_api.async_playwright", lambda: DummyPW()) + extra_kwarg_value = "test_value" + loader = ChromiumLoader( + ["http://example.com"], + backend="playwright", + extra=extra_kwarg_value, + retry_limit=1, + ) + result = await loader.ascrape_playwright("http://example.com") + assert captured_kwargs.get("extra") == extra_kwarg_value + assert "Config Tested" in result diff --git a/tests/test_cleanup_html.py b/tests/test_cleanup_html.py new file mode 100644 index 00000000..28cd86c3 --- /dev/null +++ b/tests/test_cleanup_html.py @@ -0,0 +1,146 @@ +import pytest +from bs4 import BeautifulSoup + +# Import the functions to be tested +from scrapegraphai.utils.cleanup_html import ( + cleanup_html, + extract_from_script_tags, + minify_html, + reduce_html, +) + + +def test_extract_from_script_tags(): + """Test extracting JSON and dynamic data from script tags.""" + html = """ + + + + + + + + + """ + soup = BeautifulSoup(html, "html.parser") + result = extract_from_script_tags(soup) + assert "JSON data from script:" in result + assert '"key": "value"' in result + assert 'Dynamic data - globalVar: "hello"' in result + + +def test_cleanup_html_success(): + """Test cleanup_html with valid HTML containing title, body, links, images, and scripts.""" + html = """ + + + Test Title + + +

Hello World!

+ Link + + + + + """ + base_url = "http://example.com" + title, minimized_body, link_urls, image_urls, script_content = cleanup_html( + html, base_url + ) + assert title == "Test Title" + assert "" in minimized_body and "" in minimized_body + # Check the link is properly joined + assert "http://example.com/page" in link_urls + # Check the image is properly joined + assert "http://example.com/image.jpg" in image_urls + # Check that we got some output from the script extraction + assert "JSON data from script" in script_content + + +def test_cleanup_html_no_body(): + """Test cleanup_html raises ValueError when no tag is present.""" + html = "No Body" + base_url = "http://example.com" + with pytest.raises(ValueError) as excinfo: + cleanup_html(html, base_url) + assert "No HTML body content found" in str(excinfo.value) + + +def test_minify_html(): + """Test minify_html function to remove comments and unnecessary whitespace.""" + raw_html = """ + + + +

Hello World!

+ + + """ + minified = minify_html(raw_html) + # There should be no comment and no unnecessary spaces between tags + assert " +

Some text

+ + + + """ + reduced = reduce_html(raw_html, 1) + # Ensure that unwanted attributes are removed (data-extra and style are gone, class remains) + assert "data-extra" not in reduced + assert "style=" not in reduced + assert 'class="keep"' in reduced + + +def test_reduce_html_reduction_2(): + """Test reduce_html at reduction level 2 (further reducing text content and decomposing style tags).""" + raw_html = """ + + + + + +

Long text with more than twenty characters. Extra content.

+ + + """ + reduced = reduce_html(raw_html, 2) + # For level 2, text should be truncated to the first 20 characters after normalization. + # The original text "Long text with more than twenty characters. Extra content." + # normalized becomes "Long text with more than twenty characters. Extra content." + # and then truncated to: "Long text with more t" (first 20 characters) + assert "Long text with more t" in reduced + # Confirm that style tags contents are completely removed + assert ".unused" not in reduced + + +def test_reduce_html_no_body(): + """Test reduce_html returns specific message when no tag is present.""" + raw_html = "No Body" + reduced = reduce_html(raw_html, 2) + assert reduced == "No tag found in the HTML" diff --git a/tests/test_depth_search_graph.py b/tests/test_depth_search_graph.py index 0197a6b8..1b8a82b2 100644 --- a/tests/test_depth_search_graph.py +++ b/tests/test_depth_search_graph.py @@ -1,8 +1,10 @@ -from unittest.mock import patch, MagicMock -from scrapegraphai.graphs.depth_search_graph import DepthSearchGraph -from scrapegraphai.graphs.abstract_graph import AbstractGraph +from unittest.mock import MagicMock, patch + import pytest +from scrapegraphai.graphs.abstract_graph import AbstractGraph +from scrapegraphai.graphs.depth_search_graph import DepthSearchGraph + class TestDepthSearchGraph: """Test suite for DepthSearchGraph class""" @@ -22,12 +24,14 @@ def test_depth_search_graph_initialization(self, source, expected_input_key): """ prompt = "Test prompt" config = {"llm": {"model": "mock_model"}} - + # Mock both BaseGraph and _create_llm method - with patch("scrapegraphai.graphs.depth_search_graph.BaseGraph"), \ - patch.object(AbstractGraph, '_create_llm', return_value=MagicMock()): + with ( + patch("scrapegraphai.graphs.depth_search_graph.BaseGraph"), + patch.object(AbstractGraph, "_create_llm", return_value=MagicMock()), + ): graph = DepthSearchGraph(prompt, source, config) - + assert graph.prompt == prompt assert graph.source == source assert graph.config == config diff --git a/tests/test_generate_answer_node.py b/tests/test_generate_answer_node.py index db9dbc91..396806ce 100644 --- a/tests/test_generate_answer_node.py +++ b/tests/test_generate_answer_node.py @@ -1,8 +1,6 @@ import json + import pytest -from langchain.prompts import ( - PromptTemplate, -) from langchain_community.chat_models import ( ChatOllama, ) @@ -12,19 +10,18 @@ from requests.exceptions import ( Timeout, ) + from scrapegraphai.nodes.generate_answer_node import ( GenerateAnswerNode, ) class DummyLLM: - def __call__(self, *args, **kwargs): return "dummy response" class DummyLogger: - def info(self, msg): pass diff --git a/tests/test_json_scraper_graph.py b/tests/test_json_scraper_graph.py index 2abcaa85..769fcd64 100644 --- a/tests/test_json_scraper_graph.py +++ b/tests/test_json_scraper_graph.py @@ -1,8 +1,10 @@ -import pytest +from unittest.mock import Mock, patch +import pytest from pydantic import BaseModel, Field + from scrapegraphai.graphs.json_scraper_graph import JSONScraperGraph -from unittest.mock import Mock, patch + class TestJSONScraperGraph: @pytest.fixture @@ -13,10 +15,17 @@ def mock_llm_model(self): def mock_embedder_model(self): return Mock() - @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') - @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') - @patch.object(JSONScraperGraph, '_create_llm') - def test_json_scraper_graph_with_directory(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + @patch("scrapegraphai.graphs.json_scraper_graph.FetchNode") + @patch("scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode") + @patch.object(JSONScraperGraph, "_create_llm") + def test_json_scraper_graph_with_directory( + self, + mock_create_llm, + mock_generate_answer_node, + mock_fetch_node, + mock_llm_model, + mock_embedder_model, + ): """ Test JSONScraperGraph with a directory of JSON files. This test checks if the graph correctly handles multiple JSON files input @@ -26,15 +35,20 @@ def test_json_scraper_graph_with_directory(self, mock_create_llm, mock_generate_ mock_create_llm.return_value = mock_llm_model # Mock the execute method of BaseGraph - with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: - mock_execute.return_value = ({"answer": "Mocked answer for multiple JSON files"}, {}) + with patch( + "scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute" + ) as mock_execute: + mock_execute.return_value = ( + {"answer": "Mocked answer for multiple JSON files"}, + {}, + ) # Create a JSONScraperGraph instance graph = JSONScraperGraph( prompt="Summarize the data from all JSON files", source="path/to/json/directory", config={"llm": {"model": "test-model", "temperature": 0}}, - schema=BaseModel + schema=BaseModel, ) # Set mocked embedder model @@ -46,10 +60,17 @@ def test_json_scraper_graph_with_directory(self, mock_create_llm, mock_generate_ # Assertions assert result == "Mocked answer for multiple JSON files" assert graph.input_key == "json_dir" - mock_execute.assert_called_once_with({"user_prompt": "Summarize the data from all JSON files", "json_dir": "path/to/json/directory"}) + mock_execute.assert_called_once_with( + { + "user_prompt": "Summarize the data from all JSON files", + "json_dir": "path/to/json/directory", + } + ) mock_fetch_node.assert_called_once() mock_generate_answer_node.assert_called_once() - mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) + mock_create_llm.assert_called_once_with( + {"model": "test-model", "temperature": 0} + ) @pytest.fixture def mock_llm_model(self): @@ -59,10 +80,17 @@ def mock_llm_model(self): def mock_embedder_model(self): return Mock() - @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') - @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') - @patch.object(JSONScraperGraph, '_create_llm') - def test_json_scraper_graph_with_single_file(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + @patch("scrapegraphai.graphs.json_scraper_graph.FetchNode") + @patch("scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode") + @patch.object(JSONScraperGraph, "_create_llm") + def test_json_scraper_graph_with_single_file( + self, + mock_create_llm, + mock_generate_answer_node, + mock_fetch_node, + mock_llm_model, + mock_embedder_model, + ): """ Test JSONScraperGraph with a single JSON file. This test checks if the graph correctly handles a single JSON file input @@ -72,15 +100,20 @@ def test_json_scraper_graph_with_single_file(self, mock_create_llm, mock_generat mock_create_llm.return_value = mock_llm_model # Mock the execute method of BaseGraph - with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: - mock_execute.return_value = ({"answer": "Mocked answer for single JSON file"}, {}) + with patch( + "scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute" + ) as mock_execute: + mock_execute.return_value = ( + {"answer": "Mocked answer for single JSON file"}, + {}, + ) # Create a JSONScraperGraph instance with a single JSON file graph = JSONScraperGraph( prompt="Analyze the data from the JSON file", source="path/to/single/file.json", config={"llm": {"model": "test-model", "temperature": 0}}, - schema=BaseModel + schema=BaseModel, ) # Set mocked embedder model @@ -92,15 +125,29 @@ def test_json_scraper_graph_with_single_file(self, mock_create_llm, mock_generat # Assertions assert result == "Mocked answer for single JSON file" assert graph.input_key == "json" - mock_execute.assert_called_once_with({"user_prompt": "Analyze the data from the JSON file", "json": "path/to/single/file.json"}) + mock_execute.assert_called_once_with( + { + "user_prompt": "Analyze the data from the JSON file", + "json": "path/to/single/file.json", + } + ) mock_fetch_node.assert_called_once() mock_generate_answer_node.assert_called_once() - mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) + mock_create_llm.assert_called_once_with( + {"model": "test-model", "temperature": 0} + ) - @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') - @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') - @patch.object(JSONScraperGraph, '_create_llm') - def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + @patch("scrapegraphai.graphs.json_scraper_graph.FetchNode") + @patch("scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode") + @patch.object(JSONScraperGraph, "_create_llm") + def test_json_scraper_graph_no_answer_found( + self, + mock_create_llm, + mock_generate_answer_node, + mock_fetch_node, + mock_llm_model, + mock_embedder_model, + ): """ Test JSONScraperGraph when no answer is found. This test checks if the graph correctly handles the scenario where no answer is generated, @@ -110,7 +157,9 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate mock_create_llm.return_value = mock_llm_model # Mock the execute method of BaseGraph to return an empty answer - with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: + with patch( + "scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute" + ) as mock_execute: mock_execute.return_value = ({}, {}) # Empty state and execution info # Create a JSONScraperGraph instance @@ -118,7 +167,7 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate prompt="Query that produces no answer", source="path/to/empty/file.json", config={"llm": {"model": "test-model", "temperature": 0}}, - schema=BaseModel + schema=BaseModel, ) # Set mocked embedder model @@ -130,10 +179,17 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate # Assertions assert result == "No answer found." assert graph.input_key == "json" - mock_execute.assert_called_once_with({"user_prompt": "Query that produces no answer", "json": "path/to/empty/file.json"}) + mock_execute.assert_called_once_with( + { + "user_prompt": "Query that produces no answer", + "json": "path/to/empty/file.json", + } + ) mock_fetch_node.assert_called_once() mock_generate_answer_node.assert_called_once() - mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) + mock_create_llm.assert_called_once_with( + {"model": "test-model", "temperature": 0} + ) @pytest.fixture def mock_llm_model(self): @@ -143,15 +199,23 @@ def mock_llm_model(self): def mock_embedder_model(self): return Mock() - @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') - @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') - @patch.object(JSONScraperGraph, '_create_llm') - def test_json_scraper_graph_with_custom_schema(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + @patch("scrapegraphai.graphs.json_scraper_graph.FetchNode") + @patch("scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode") + @patch.object(JSONScraperGraph, "_create_llm") + def test_json_scraper_graph_with_custom_schema( + self, + mock_create_llm, + mock_generate_answer_node, + mock_fetch_node, + mock_llm_model, + mock_embedder_model, + ): """ Test JSONScraperGraph with a custom schema. This test checks if the graph correctly handles a custom schema input and passes it to the GenerateAnswerNode. """ + # Define a custom schema class CustomSchema(BaseModel): name: str = Field(..., description="Name of the attraction") @@ -161,15 +225,20 @@ class CustomSchema(BaseModel): mock_create_llm.return_value = mock_llm_model # Mock the execute method of BaseGraph - with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: - mock_execute.return_value = ({"answer": "Mocked answer with custom schema"}, {}) + with patch( + "scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute" + ) as mock_execute: + mock_execute.return_value = ( + {"answer": "Mocked answer with custom schema"}, + {}, + ) # Create a JSONScraperGraph instance with a custom schema graph = JSONScraperGraph( prompt="List attractions in Chioggia", source="path/to/chioggia.json", config={"llm": {"model": "test-model", "temperature": 0}}, - schema=CustomSchema + schema=CustomSchema, ) # Set mocked embedder model @@ -181,12 +250,19 @@ class CustomSchema(BaseModel): # Assertions assert result == "Mocked answer with custom schema" assert graph.input_key == "json" - mock_execute.assert_called_once_with({"user_prompt": "List attractions in Chioggia", "json": "path/to/chioggia.json"}) + mock_execute.assert_called_once_with( + { + "user_prompt": "List attractions in Chioggia", + "json": "path/to/chioggia.json", + } + ) mock_fetch_node.assert_called_once() mock_generate_answer_node.assert_called_once() # Check if the custom schema was passed to GenerateAnswerNode generate_answer_node_call = mock_generate_answer_node.call_args[1] - assert generate_answer_node_call['node_config']['schema'] == CustomSchema + assert generate_answer_node_call["node_config"]["schema"] == CustomSchema - mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) \ No newline at end of file + mock_create_llm.assert_called_once_with( + {"model": "test-model", "temperature": 0} + ) diff --git a/tests/test_models_tokens.py b/tests/test_models_tokens.py index 032f3c15..bfde8df9 100644 --- a/tests/test_models_tokens.py +++ b/tests/test_models_tokens.py @@ -1,13 +1,15 @@ -import pytest from scrapegraphai.helpers.models_tokens import models_tokens + class TestModelsTokens: """Test suite for verifying the models_tokens dictionary content and structure.""" def test_openai_tokens(self): """Test that the 'openai' provider exists and its tokens are valid positive integers.""" openai_models = models_tokens.get("openai") - assert openai_models is not None, "'openai' key should be present in models_tokens" + assert openai_models is not None, ( + "'openai' key should be present in models_tokens" + ) for model, token in openai_models.items(): assert isinstance(model, str), "Model name should be a string" assert isinstance(token, int), "Token limit should be an integer" @@ -28,7 +30,9 @@ def test_google_providers(self): assert google_genai is not None, "'google_genai' key should be present" assert google_vertexai is not None, "'google_vertexai' key should be present" # Check a specific key from google_genai - assert "gemini-pro" in google_genai, "'gemini-pro' should be in google_genai models" + assert "gemini-pro" in google_genai, ( + "'gemini-pro' should be in google_genai models" + ) # Validate token values types for provider in [google_genai, google_vertexai]: for token in provider.values(): @@ -36,7 +40,9 @@ def test_google_providers(self): def test_non_existent_provider(self): """Test that a non-existent provider returns None.""" - assert models_tokens.get("non_existent") is None, "Non-existent provider should return None" + assert models_tokens.get("non_existent") is None, ( + "Non-existent provider should return None" + ) def test_total_model_keys(self): """Test that the total number of models across all providers is above an expected count.""" @@ -53,84 +59,120 @@ def test_non_empty_model_keys(self): """Ensure that model token names are non-empty strings.""" for provider, model_dict in models_tokens.items(): for model in model_dict.keys(): - assert model != "", f"Model name in provider '{provider}' should not be empty." + assert model != "", ( + f"Model name in provider '{provider}' should not be empty." + ) def test_token_limits_range(self): """Test that token limits for all models fall within a plausible range (e.g., 1 to 300000).""" for provider, model_dict in models_tokens.items(): for model, token in model_dict.items(): - assert 1 <= token <= 1100000, f"Token limit for {model} in provider {provider} is out of plausible range." + assert 1 <= token <= 1100000, ( + f"Token limit for {model} in provider {provider} is out of plausible range." + ) + def test_provider_structure(self): """Test that every provider in models_tokens has a dictionary as its value.""" for provider, models in models_tokens.items(): - assert isinstance(models, dict), f"Provider {provider} should map to a dictionary, got {type(models).__name__}" + assert isinstance(models, dict), ( + f"Provider {provider} should map to a dictionary, got {type(models).__name__}" + ) def test_non_empty_provider(self): """Test that each provider dictionary is not empty.""" for provider, models in models_tokens.items(): - assert len(models) > 0, f"Provider {provider} should contain at least one model." + assert len(models) > 0, ( + f"Provider {provider} should contain at least one model." + ) def test_specific_model_token_values(self): """Test specific expected token values for selected models from various providers.""" # Verify a token for a selected model from the 'openai' provider openai = models_tokens.get("openai") - assert openai.get("gpt-3.5-turbo-0125") == 16385, "Expected token limit for gpt-3.5-turbo-0125 in openai to be 16385" + assert openai.get("gpt-3.5-turbo-0125") == 16385, ( + "Expected token limit for gpt-3.5-turbo-0125 in openai to be 16385" + ) # Verify a token for a selected model from the 'azure_openai' provider azure = models_tokens.get("azure_openai") - assert azure.get("gpt-3.5") == 4096, "Expected token limit for gpt-3.5 in azure_openai to be 4096" + assert azure.get("gpt-3.5") == 4096, ( + "Expected token limit for gpt-3.5 in azure_openai to be 4096" + ) # Verify a token for a selected model from the 'anthropic' provider anthropic = models_tokens.get("anthropic") - assert anthropic.get("claude_instant") == 100000, "Expected token limit for claude_instant in anthropic to be 100000" + assert anthropic.get("claude_instant") == 100000, ( + "Expected token limit for claude_instant in anthropic to be 100000" + ) def test_providers_count(self): """Test that the total number of providers is as expected (at least 15).""" - assert len(models_tokens) >= 15, "Expected at least 15 providers in models_tokens" + assert len(models_tokens) >= 15, ( + "Expected at least 15 providers in models_tokens" + ) def test_non_existent_model(self): """Test that a non-existent model within a valid provider returns None.""" openai = models_tokens.get("openai") - assert openai.get("non_existent_model") is None, "Non-existent model should return None from a valid provider." + assert openai.get("non_existent_model") is None, ( + "Non-existent model should return None from a valid provider." + ) + def test_no_whitespace_in_model_names(self): """Test that model names do not contain leading or trailing whitespace.""" for provider, model_dict in models_tokens.items(): for model in model_dict.keys(): # Assert that stripping whitespace does not change the model name - assert model == model.strip(), f"Model name '{model}' in provider '{provider}' contains leading or trailing whitespace." + assert model == model.strip(), ( + f"Model name '{model}' in provider '{provider}' contains leading or trailing whitespace." + ) def test_specific_models_additional(self): """Test specific token values for additional models across various providers.""" # Check some models in the 'ollama' provider ollama = models_tokens.get("ollama") - assert ollama.get("llama2") == 4096, "Expected token limit for 'llama2' in ollama to be 4096" - assert ollama.get("llama2:70b") == 4096, "Expected token limit for 'llama2:70b' in ollama to be 4096" + assert ollama.get("llama2") == 4096, ( + "Expected token limit for 'llama2' in ollama to be 4096" + ) + assert ollama.get("llama2:70b") == 4096, ( + "Expected token limit for 'llama2:70b' in ollama to be 4096" + ) # Check a specific model from the 'mistralai' provider mistralai = models_tokens.get("mistralai") - assert mistralai.get("open-codestral-mamba") == 256000, "Expected token limit for 'open-codestral-mamba' in mistralai to be 256000" + assert mistralai.get("open-codestral-mamba") == 256000, ( + "Expected token limit for 'open-codestral-mamba' in mistralai to be 256000" + ) # Check a specific model from the 'deepseek' provider deepseek = models_tokens.get("deepseek") - assert deepseek.get("deepseek-chat") == 28672, "Expected token limit for 'deepseek-chat' in deepseek to be 28672" + assert deepseek.get("deepseek-chat") == 28672, ( + "Expected token limit for 'deepseek-chat' in deepseek to be 28672" + ) # Check a model from the 'ernie' provider ernie = models_tokens.get("ernie") - assert ernie.get("ernie-bot") == 4096, "Expected token limit for 'ernie-bot' in ernie to be 4096" - + assert ernie.get("ernie-bot") == 4096, ( + "Expected token limit for 'ernie-bot' in ernie to be 4096" + ) + def test_nvidia_specific(self): """Test specific token value for 'meta/codellama-70b' in the nvidia provider.""" nvidia = models_tokens.get("nvidia") assert nvidia is not None, "'nvidia' provider should exist" # Verify token for 'meta/codellama-70b' equals 16384 as defined in the nvidia dictionary - assert nvidia.get("meta/codellama-70b") == 16384, "Expected token limit for 'meta/codellama-70b' in nvidia to be 16384" + assert nvidia.get("meta/codellama-70b") == 16384, ( + "Expected token limit for 'meta/codellama-70b' in nvidia to be 16384" + ) def test_groq_specific(self): """Test specific token value for 'claude-3-haiku-20240307\'' in the groq provider.""" groq = models_tokens.get("groq") assert groq is not None, "'groq' provider should exist" # Note: The model name has an embedded apostrophe at the end in its name. - assert groq.get("claude-3-haiku-20240307'") == 8192, "Expected token limit for 'claude-3-haiku-20240307\\'' in groq to be 8192" + assert groq.get("claude-3-haiku-20240307'") == 8192, ( + "Expected token limit for 'claude-3-haiku-20240307\\'' in groq to be 8192" + ) def test_togetherai_specific(self): """Test specific token value for 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo' in the toghetherai provider.""" @@ -138,11 +180,15 @@ def test_togetherai_specific(self): assert togetherai is not None, "'toghetherai' provider should exist" expected = 128000 model_name = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" - assert togetherai.get(model_name) == expected, f"Expected token limit for '{model_name}' in toghetherai to be {expected}" + assert togetherai.get(model_name) == expected, ( + f"Expected token limit for '{model_name}' in toghetherai to be {expected}" + ) def test_ernie_all_values(self): """Test that all models in the 'ernie' provider have token values exactly 4096.""" ernie = models_tokens.get("ernie") assert ernie is not None, "'ernie' provider should exist" for model, token in ernie.items(): - assert token == 4096, f"Expected token limit for '{model}' in ernie to be 4096, got {token}" \ No newline at end of file + assert token == 4096, ( + f"Expected token limit for '{model}' in ernie to be 4096, got {token}" + ) diff --git a/tests/test_scrape_do.py b/tests/test_scrape_do.py index 9e36da5b..3fa1cd73 100644 --- a/tests/test_scrape_do.py +++ b/tests/test_scrape_do.py @@ -1,6 +1,6 @@ import urllib.parse -import pytest -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch + from scrapegraphai.docloaders.scrape_do import scrape_do_fetch @@ -29,4 +29,3 @@ def test_scrape_do_fetch_without_proxy(): mock_get.assert_called_once_with(expected_url) assert result == expected_response - diff --git a/tests/test_search_graph.py b/tests/test_search_graph.py index 099385da..9ce7e46a 100644 --- a/tests/test_search_graph.py +++ b/tests/test_search_graph.py @@ -1,18 +1,19 @@ +from unittest.mock import MagicMock, patch + import pytest from scrapegraphai.graphs.search_graph import SearchGraph -from unittest.mock import MagicMock, call, patch + class TestSearchGraph: """Test class for SearchGraph""" - @pytest.mark.parametrize("urls", [ - ["https://example.com", "https://test.com"], - [], - ["https://single-url.com"] - ]) - @patch('scrapegraphai.graphs.search_graph.BaseGraph') - @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') + @pytest.mark.parametrize( + "urls", + [["https://example.com", "https://test.com"], [], ["https://single-url.com"]], + ) + @patch("scrapegraphai.graphs.search_graph.BaseGraph") + @patch("scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm") def test_get_considered_urls(self, mock_create_llm, mock_base_graph, urls): """ Test that get_considered_urls returns the correct list of URLs @@ -35,8 +36,8 @@ def test_get_considered_urls(self, mock_create_llm, mock_base_graph, urls): # Assert assert search_graph.get_considered_urls() == urls - @patch('scrapegraphai.graphs.search_graph.BaseGraph') - @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') + @patch("scrapegraphai.graphs.search_graph.BaseGraph") + @patch("scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm") def test_run_no_answer_found(self, mock_create_llm, mock_base_graph): """ Test that the run() method returns "No answer found." when the final state @@ -59,12 +60,19 @@ def test_run_no_answer_found(self, mock_create_llm, mock_base_graph): # Assert assert result == "No answer found." - @patch('scrapegraphai.graphs.search_graph.SearchInternetNode') - @patch('scrapegraphai.graphs.search_graph.GraphIteratorNode') - @patch('scrapegraphai.graphs.search_graph.MergeAnswersNode') - @patch('scrapegraphai.graphs.search_graph.BaseGraph') - @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') - def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet): + @patch("scrapegraphai.graphs.search_graph.SearchInternetNode") + @patch("scrapegraphai.graphs.search_graph.GraphIteratorNode") + @patch("scrapegraphai.graphs.search_graph.MergeAnswersNode") + @patch("scrapegraphai.graphs.search_graph.BaseGraph") + @patch("scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm") + def test_max_results_config( + self, + mock_create_llm, + mock_base_graph, + mock_merge_answers, + mock_graph_iterator, + mock_search_internet, + ): """ Test that the max_results parameter from the config is correctly passed to the SearchInternetNode. """ @@ -79,24 +87,28 @@ def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_a # Assert mock_search_internet.assert_called_once() call_args = mock_search_internet.call_args - assert call_args.kwargs['node_config']['max_results'] == max_results - - @patch('scrapegraphai.graphs.search_graph.SearchInternetNode') - @patch('scrapegraphai.graphs.search_graph.GraphIteratorNode') - @patch('scrapegraphai.graphs.search_graph.MergeAnswersNode') - @patch('scrapegraphai.graphs.search_graph.BaseGraph') - @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm') - def test_custom_search_engine_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet): + assert call_args.kwargs["node_config"]["max_results"] == max_results + + @patch("scrapegraphai.graphs.search_graph.SearchInternetNode") + @patch("scrapegraphai.graphs.search_graph.GraphIteratorNode") + @patch("scrapegraphai.graphs.search_graph.MergeAnswersNode") + @patch("scrapegraphai.graphs.search_graph.BaseGraph") + @patch("scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm") + def test_custom_search_engine_config( + self, + mock_create_llm, + mock_base_graph, + mock_merge_answers, + mock_graph_iterator, + mock_search_internet, + ): """ Test that the custom search_engine parameter from the config is correctly passed to the SearchInternetNode. """ # Arrange prompt = "Test prompt" custom_search_engine = "custom_engine" - config = { - "llm": {"model": "test-model"}, - "search_engine": custom_search_engine - } + config = {"llm": {"model": "test-model"}, "search_engine": custom_search_engine} # Act search_graph = SearchGraph(prompt, config) @@ -104,4 +116,4 @@ def test_custom_search_engine_config(self, mock_create_llm, mock_base_graph, moc # Assert mock_search_internet.assert_called_once() call_args = mock_search_internet.call_args - assert call_args.kwargs['node_config']['search_engine'] == custom_search_engine \ No newline at end of file + assert call_args.kwargs["node_config"]["search_engine"] == custom_search_engine diff --git a/tests/utils/convert_to_md_test.py b/tests/utils/convert_to_md_test.py index f4ea1d4a..d2b64b48 100644 --- a/tests/utils/convert_to_md_test.py +++ b/tests/utils/convert_to_md_test.py @@ -1,5 +1,3 @@ -import pytest - from scrapegraphai.utils.convert_to_md import convert_to_md diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py index 607d2c53..c78be3be 100644 --- a/tests/utils/copy_utils_test.py +++ b/tests/utils/copy_utils_test.py @@ -1,5 +1,3 @@ -import copy - import pytest from pydantic.v1 import BaseModel diff --git a/tests/utils/parse_state_keys_test.py b/tests/utils/parse_state_keys_test.py index 4bfaf928..a4617482 100644 --- a/tests/utils/parse_state_keys_test.py +++ b/tests/utils/parse_state_keys_test.py @@ -2,8 +2,6 @@ Parse_state_key test module """ -import pytest - from scrapegraphai.utils.parse_state_keys import parse_expression