Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions libs/langchain/langchain_classic/retrievers/merger_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, RetrieverOutputLike


class MergerRetriever(BaseRetriever):
"""Retriever that merges the results of multiple retrievers."""

retrievers: list[BaseRetriever]
retrievers: list[BaseRetriever | RetrieverOutputLike]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would RetrieverLike work? (Since it has as input a str?)

RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]

"""A list of retrievers to merge."""

def _get_relevant_documents(
Expand Down Expand Up @@ -65,13 +65,20 @@ def merge_documents(
A list of merged documents.
"""
# Get the results of all retrievers.
retriever_docs = [
retriever.invoke(
query,
config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
)
for i, retriever in enumerate(self.retrievers)
]
retriever_docs = []
for i, retriever in enumerate(self.retrievers):
if isinstance(retriever, BaseRetriever):
docs = retriever.invoke(
query,
config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
)
else:
# Handle RetrieverOutputLike (Runnable)
docs = retriever.invoke(
query,
config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
)
retriever_docs.append(docs)

# Merge the results of the retrievers.
merged_documents = []
Expand Down
128 changes: 128 additions & 0 deletions libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Tests for MergerRetriever."""

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever

from langchain_classic.chains import create_history_aware_retriever
from langchain_classic.retrievers import MergerRetriever
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever

Check failure on line 10 in libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.10) / Python 3.10

Ruff (F401)

tests/unit_tests/retrievers/test_merger_retriever.py:10:58: F401 `tests.unit_tests.retrievers.parrot_retriever.FakeParrotRetriever` imported but unused

Check failure on line 10 in libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (F401)

tests/unit_tests/retrievers/test_merger_retriever.py:10:58: F401 `tests.unit_tests.retrievers.parrot_retriever.FakeParrotRetriever` imported but unused


class MockRetriever(BaseRetriever):
    """Retriever stub that returns a fixed list of documents for every query."""

    # Declared as a model field: BaseRetriever is a pydantic model, so assigning
    # an undeclared attribute inside ``__init__`` is rejected at runtime.
    docs: list[Document]

    def __init__(self, docs: list[Document]) -> None:
        """Create the stub.

        Args:
            docs: Documents to return from every retrieval call.
        """
        # Route the value through the model initializer so the field is set
        # properly, while callers can still pass it positionally.
        super().__init__(docs=docs)

    def _get_relevant_documents(
        self,
        query: str,  # noqa: ARG002
        *,
        run_manager: CallbackManagerForRetrieverRun,  # noqa: ARG002
    ) -> list[Document]:
        """Return the canned documents, ignoring the query and callbacks."""
        return self.docs


def test_merger_retriever_with_base_retrievers() -> None:
    """Merging two plain retrievers interleaves their documents."""
    first = MockRetriever(
        [Document(page_content="doc1"), Document(page_content="doc2")]
    )
    second = MockRetriever(
        [Document(page_content="doc3"), Document(page_content="doc4")]
    )

    merged = MergerRetriever(retrievers=[first, second]).invoke("test query")

    # Results alternate between the two sources, one document at a time.
    assert [doc.page_content for doc in merged] == ["doc1", "doc3", "doc2", "doc4"]


def test_merger_retriever_with_history_aware_retriever() -> None:
    """Test MergerRetriever with a history-aware retriever.

    The object returned by ``create_history_aware_retriever`` is a runnable
    rather than a ``BaseRetriever``; constructing and invoking the merger with
    it must not raise a ``ValidationError``.
    """
    # Plain retriever that backs the history-aware chain.
    docs = [Document(page_content="test document")]
    base_retriever = MockRetriever(docs)

    # NOTE(review): history-aware chains are normally invoked with a mapping
    # such as {"input": ...}; confirm that the bare string MergerRetriever
    # forwards is accepted by this runnable.
    llm = FakeListLLM(responses=["rephrased query"])
    prompt = PromptTemplate.from_template("Rephrase: {input}")
    history_aware_retriever = create_history_aware_retriever(
        llm, base_retriever, prompt
    )

    # Second, ordinary retriever to merge with.
    docs2 = [Document(page_content="another document")]
    retriever2 = MockRetriever(docs2)

    merger = MergerRetriever(retrievers=[history_aware_retriever, retriever2])
    result = merger.invoke("test query")

    # One document from each source, in any order.
    assert len(result) == 2
    assert any(doc.page_content == "test document" for doc in result)
    assert any(doc.page_content == "another document" for doc in result)


def test_merger_retriever_mixed_types() -> None:
    """A BaseRetriever listed ahead of a history-aware runnable merges cleanly."""
    wrapped = MockRetriever([Document(page_content="base retriever doc")])
    plain = MockRetriever([Document(page_content="another base retriever doc")])

    # NOTE(review): history-aware chains are normally invoked with a mapping
    # such as {"input": ...}; confirm the bare string forwarded by
    # MergerRetriever is accepted.
    history_aware = create_history_aware_retriever(
        FakeListLLM(responses=["rephrased"]),
        wrapped,
        PromptTemplate.from_template("Rephrase: {input}"),
    )

    # Constructing and invoking with both kinds must not raise ValidationError.
    merged = MergerRetriever(retrievers=[plain, history_aware]).invoke("test query")

    # One document from each source, in any order.
    contents = [doc.page_content for doc in merged]
    assert len(merged) == 2
    assert "another base retriever doc" in contents
    assert "base retriever doc" in contents


async def test_merger_retriever_async() -> None:
    """The async entry point merges a BaseRetriever with a history-aware runnable."""
    wrapped = MockRetriever([Document(page_content="async doc 1")])
    plain = MockRetriever([Document(page_content="async doc 2")])

    # NOTE(review): history-aware chains are normally invoked with a mapping
    # such as {"input": ...}; confirm the bare string forwarded by
    # MergerRetriever is accepted on the async path as well.
    history_aware = create_history_aware_retriever(
        FakeListLLM(responses=["async rephrased"]),
        wrapped,
        PromptTemplate.from_template("Async rephrase: {input}"),
    )

    merger = MergerRetriever(retrievers=[plain, history_aware])
    merged = await merger.ainvoke("async test query")

    # One document from each source, in any order.
    contents = [doc.page_content for doc in merged]
    assert len(merged) == 2
    assert "async doc 2" in contents
    assert "async doc 1" in contents
Loading