"""
SearchLinkNode Module
"""

import re
from typing import List, Optional
from urllib.parse import parse_qs, urlparse

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from tqdm import tqdm

from ..helpers import default_filters
from ..prompts import TEMPLATE_RELEVANT_LINKS
from .base_node import BaseNode


class SearchLinkNode(BaseNode):
    """
    A node that filters the links found in the webpage content down to the ones
    relevant to the user prompt. The node expects the links to have already been
    scraped from the webpage, so it should be placed after the FetchNode.

    Attributes:
        llm_model: An instance of the language model client used for generating answers.
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node.
        node_name (str): The unique identifier name for the node, defaulting to "SearchLinks".
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "SearchLinks",
    ):
        super().__init__(node_name, "node", input, output, 1, node_config)

        # node_config is Optional, so normalize it before reading keys;
        # the original code called .get() on a potentially None value.
        node_config = node_config or {}

        # The language model client used by the LLM fallback in execute(),
        # expected under the "llm_model" key as documented in Attributes.
        self.llm_model = node_config.get("llm_model")

        if node_config.get("filter_links", False) or "filter_config" in node_config:
            # User-provided settings take precedence over the library defaults
            # because they are unpacked last.
            provided_filter_config = node_config.get("filter_config", {})
            self.filter_config = {
                **default_filters.filter_dict,
                **provided_filter_config,
            }
            self.filter_links = True
        else:
            self.filter_config = None
            self.filter_links = False

        self.verbose = node_config.get("verbose", False)
        self.seen_links = set()
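
    # A minimal illustration of the merge above (hypothetical values, assuming
    # default_filters.filter_dict contains an "img_exts" entry):
    #
    #   node_config = {"filter_links": True, "filter_config": {"img_exts": [".png"]}}
    #
    # yields a self.filter_config with every default key intact but "img_exts"
    # replaced by [".png"], since the user dict is unpacked after the defaults.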

    def _is_same_domain(self, url, domain):
        # When link filtering is off, or the cross-domain filter is explicitly
        # disabled, every URL passes this check.
        if not self.filter_links or not self.filter_config.get(
            "diff_domain_filter", True
        ):
            return True
        parsed_url = urlparse(url)
        parsed_domain = urlparse(domain)
        return parsed_url.netloc == parsed_domain.netloc
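
    # For example (illustrative values): urlparse("https://example.com/a").netloc
    # is "example.com", so "https://example.com/a" and "https://example.com/b"
    # count as the same domain, while "https://docs.example.com/a" does not,
    # because a subdomain changes the netloc.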

    def _is_image_url(self, url):
        # Treat a URL as an image link if its lowercased form ends with any of
        # the configured image extensions (e.g. ".png", ".jpg").
        if not self.filter_links:
            return False
        image_extensions = self.filter_config.get("img_exts", [])
        return any(url.lower().endswith(ext) for ext in image_extensions)

    def _is_language_url(self, url):
        # Flag URLs that select a language variant of a page, either via a
        # path segment or a query-string parameter.
        if not self.filter_links:
            return False
        lang_indicators = self.filter_config.get("lang_indicators", [])
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)
        return any(
            indicator in parsed_url.path.lower() or indicator in query_params
            for indicator in lang_indicators
        )
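
    # For example (illustrative values): with lang_indicators = ["lang", "/fr"],
    # "https://example.com/fr/about" matches on the path, and
    # "https://example.com/about?lang=fr" matches because parse_qs yields
    # {"lang": ["fr"]}, whose keys the indicators are checked against.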

    def _is_potentially_irrelevant(self, url):
        # Flag URLs containing any configured keyword (e.g. "login", "privacy")
        # anywhere in their lowercased form.
        if not self.filter_links:
            return False
        irrelevant_keywords = self.filter_config.get("irrelevant_keywords", [])
        return any(keyword in url.lower() for keyword in irrelevant_keywords)

    def execute(self, state: dict) -> dict:
        """
        Filters the links extracted from the webpage down to the ones relevant
        to the user prompt, keeping only links that are navigable.

        Args:
            state (dict): The current state of the graph. The input keys will be used to fetch the
                correct data types from the state.

        Returns:
            dict: The updated state with the output key containing the list of links.

        Raises:
            KeyError: If the input keys are not found in the state, indicating that the
                necessary information for generating the answer is missing.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        parsed_content_chunks = state.get("doc")
        source_url = state.get("url") or state.get("local_dir")
        output_parser = JsonOutputParser()

        relevant_links = []

        for chunk in tqdm(
            parsed_content_chunks,
            desc="Processing chunks",
            disable=not self.verbose,
        ):
            try:
                # Extract absolute http(s) URLs from the raw chunk text.
                links = re.findall(r'https?://[^\s"<>\]]+', str(chunk.page_content))

                if not self.filter_links:
                    links = list(set(links))
                    relevant_links += links
                    self.seen_links.update(links)
                else:
                    filtered_links = [
                        link
                        for link in links
                        if self._is_same_domain(link, source_url)
                        and not self._is_image_url(link)
                        and not self._is_language_url(link)
                        and not self._is_potentially_irrelevant(link)
                        and link not in self.seen_links
                    ]
                    filtered_links = list(set(filtered_links))
                    relevant_links += filtered_links
                    self.seen_links.update(filtered_links)
            except Exception as e:
                # Fallback: ask the LLM to pick the relevant links for this chunk.
                self.logger.error(f"Error extracting links: {e}. Falling back to LLM.")
                merge_prompt = PromptTemplate(
                    template=TEMPLATE_RELEVANT_LINKS,
                    input_variables=["content", "user_prompt"],
                )
                merge_chain = merge_prompt | self.llm_model | output_parser
                # The template declares both "content" and "user_prompt", so both
                # must be supplied; the original call omitted "user_prompt".
                answer = merge_chain.invoke(
                    {
                        "content": chunk.page_content,
                        "user_prompt": state.get("user_prompt", ""),
                    }
                )
                relevant_links += answer

        state.update({self.output[0]: relevant_links})
        return state
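
# A minimal usage sketch (hypothetical, not part of this module): it assumes an
# LLM client `llm` exposing LangChain's runnable interface and a state shaped
# like the one FetchNode produces; the URLs and prompt are placeholders.
#
#     from langchain_core.documents import Document
#
#     node = SearchLinkNode(
#         input="user_prompt & doc",
#         output=["relevant_links"],
#         node_config={"llm_model": llm, "filter_links": True, "verbose": True},
#     )
#     state = {
#         "user_prompt": "Find the pricing pages",
#         "url": "https://example.com",
#         "doc": [Document(page_content='<a href="https://example.com/pricing">Pricing</a>')],
#     }
#     state = node.execute(state)
#     print(state["relevant_links"])  # e.g. ["https://example.com/pricing"]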