From 557b45436532490b36d766808bcf0f77929a229f Mon Sep 17 00:00:00 2001 From: khushmanvar Date: Wed, 15 Oct 2025 00:31:54 +0530 Subject: [PATCH] feat(retrievers): enhance MergerRetriever to support mixed retriever types - Updated the `retrievers` attribute in `MergerRetriever` to accept both `BaseRetriever` and `RetrieverOutputLike` types. - Modified the `_get_relevant_documents` method to handle invocation for both types of retrievers. - Added comprehensive unit tests for `MergerRetriever` to validate functionality with mixed retriever types, including async support and integration with history-aware retrievers. Fixes #33184 --- .../retrievers/merger_retriever.py | 25 ++-- .../retrievers/test_merger_retriever.py | 128 ++++++++++++++++++ 2 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py diff --git a/libs/langchain/langchain_classic/retrievers/merger_retriever.py b/libs/langchain/langchain_classic/retrievers/merger_retriever.py index f556b7d1785cf..70bf1e27e0302 100644 --- a/libs/langchain/langchain_classic/retrievers/merger_retriever.py +++ b/libs/langchain/langchain_classic/retrievers/merger_retriever.py @@ -5,13 +5,13 @@ CallbackManagerForRetrieverRun, ) from langchain_core.documents import Document -from langchain_core.retrievers import BaseRetriever +from langchain_core.retrievers import BaseRetriever, RetrieverOutputLike class MergerRetriever(BaseRetriever): """Retriever that merges the results of multiple retrievers.""" - retrievers: list[BaseRetriever] + retrievers: list[BaseRetriever | RetrieverOutputLike] """A list of retrievers to merge.""" def _get_relevant_documents( @@ -65,13 +65,20 @@ def merge_documents( A list of merged documents. """ # Get the results of all retrievers. - retriever_docs = [ - retriever.invoke( - query, - config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")}, - ) - for i, retriever in enumerate(self.retrievers) - ] + retriever_docs = [] + for i, retriever in enumerate(self.retrievers): + if isinstance(retriever, BaseRetriever): + docs = retriever.invoke( + query, + config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")}, + ) + else: + # Handle RetrieverOutputLike (Runnable) + docs = retriever.invoke( + query, + config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")}, + ) + retriever_docs.append(docs) # Merge the results of the retrievers. merged_documents = [] diff --git a/libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py new file mode 100644 index 0000000000000..c6734b99f3888 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/test_merger_retriever.py @@ -0,0 +1,128 @@ +"""Tests for MergerRetriever.""" + +from langchain_core.documents import Document +from langchain_core.language_models import FakeListLLM +from langchain_core.prompts import PromptTemplate +from langchain_core.retrievers import BaseRetriever + +from langchain_classic.chains import create_history_aware_retriever +from langchain_classic.retrievers import MergerRetriever +from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever + + +class MockRetriever(BaseRetriever): + """Mock retriever for testing.""" + + def __init__(self, docs: list[Document]): + super().__init__() + self.docs = docs + + def _get_relevant_documents( + self, query: str, *, run_manager + ) -> list[Document]: + return self.docs + + +def test_merger_retriever_with_base_retrievers() -> None: + """Test MergerRetriever with BaseRetriever objects.""" + docs1 = [Document(page_content="doc1"), Document(page_content="doc2")] + docs2 = [Document(page_content="doc3"), Document(page_content="doc4")] + + retriever1 = MockRetriever(docs1) + retriever2 = MockRetriever(docs2) + + merger = MergerRetriever(retrievers=[retriever1, retriever2]) + + result = merger.invoke("test query") + + # Should merge documents from both retrievers + assert len(result) == 4 + assert result[0].page_content == "doc1" + assert result[1].page_content == "doc3" + assert result[2].page_content == "doc2" + assert result[3].page_content == "doc4" + + +def test_merger_retriever_with_history_aware_retriever() -> None: + """Test MergerRetriever with create_history_aware_retriever (RetrieverOutputLike).""" + # Create a simple retriever + docs = [Document(page_content="test document")] + base_retriever = MockRetriever(docs) + + # Create a history aware retriever + llm = FakeListLLM(responses=["rephrased query"]) + prompt = PromptTemplate.from_template("Rephrase: {input}") + history_aware_retriever = create_history_aware_retriever( + llm, base_retriever, prompt + ) + + # Create another simple retriever + docs2 = [Document(page_content="another document")] + retriever2 = MockRetriever(docs2) + + # Create MergerRetriever with both types + merger = MergerRetriever(retrievers=[history_aware_retriever, retriever2]) + + # This should work without ValidationError + result = merger.invoke("test query") + + # Should have documents from both retrievers + assert len(result) == 2 + assert any(doc.page_content == "test document" for doc in result) + assert any(doc.page_content == "another document" for doc in result) + + +def test_merger_retriever_mixed_types() -> None: + """Test MergerRetriever with mixed BaseRetriever and RetrieverOutputLike types.""" + # Create base retrievers + docs1 = [Document(page_content="base retriever doc")] + base_retriever1 = MockRetriever(docs1) + + docs2 = [Document(page_content="another base retriever doc")] + base_retriever2 = MockRetriever(docs2) + + # Create history aware retriever + llm = FakeListLLM(responses=["rephrased"]) + prompt = PromptTemplate.from_template("Rephrase: {input}") + history_aware_retriever = create_history_aware_retriever( + llm, base_retriever1, prompt + ) + + # Create MergerRetriever with mixed types + merger = MergerRetriever(retrievers=[base_retriever2, history_aware_retriever]) + + # This should work without ValidationError + result = merger.invoke("test query") + + # Should have documents from both retrievers + assert len(result) == 2 + assert any(doc.page_content == "another base retriever doc" for doc in result) + assert any(doc.page_content == "base retriever doc" for doc in result) + + +async def test_merger_retriever_async() -> None: + """Test MergerRetriever async functionality with mixed types.""" + # Create base retrievers + docs1 = [Document(page_content="async doc 1")] + base_retriever1 = MockRetriever(docs1) + + docs2 = [Document(page_content="async doc 2")] + base_retriever2 = MockRetriever(docs2) + + # Create history aware retriever + llm = FakeListLLM(responses=["async rephrased"]) + prompt = PromptTemplate.from_template("Async rephrase: {input}") + history_aware_retriever = create_history_aware_retriever( + llm, base_retriever1, prompt + ) + + # Create MergerRetriever with mixed types + merger = MergerRetriever(retrievers=[base_retriever2, history_aware_retriever]) + + # Test async invoke + result = await merger.ainvoke("async test query") + + # Should have documents from both retrievers + assert len(result) == 2 + assert any(doc.page_content == "async doc 2" for doc in result) + assert any(doc.page_content == "async doc 1" for doc in result)