Skip to content

Commit 1ba34fa

Browse files
committed
test: update task_manager tests
1 parent 90fa907 commit 1ba34fa

File tree

2 files changed

+219
-5
lines changed

2 files changed

+219
-5
lines changed

test/agents/experimental/document_agent/test_task_manager.py

Lines changed: 206 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
from pathlib import Path
7+
from typing import Any
68
from unittest.mock import MagicMock, patch
79

810
import pytest
911

12+
from autogen.agentchat.group.context_variables import ContextVariables
1013
from autogen.agents.experimental.document_agent.task_manager import TASK_MANAGER_SYSTEM_MESSAGE, TaskManagerAgent
1114
from autogen.import_utils import skip_on_missing_imports
1215

@@ -142,7 +145,7 @@ def test_task_manager_agent_cleanup(
142145
agent.__del__()
143146

144147
# Verify shutdown was called
145-
mock_executor.shutdown.assert_called_once_with(wait=True)
148+
mock_executor.shutdown.assert_called_once_with(wait=False)
146149

147150
@pytest.mark.openai
148151
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
@@ -371,3 +374,205 @@ def test_execute_single_query_no_query_engine(self, credentials_gpt_4o_mini: Cre
371374
# Test basic agent functionality
372375
assert agent.query_engine is None
373376
assert agent.parsed_docs_path == tmp_path
377+
378+
379+
@pytest.mark.openai
380+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
381+
def test_task_manager_agent_init_with_rag_config(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
382+
"""Test TaskManagerAgent initialization with rag_config."""
383+
llm_config = credentials_gpt_4o_mini.llm_config
384+
rag_config: dict[str, Any] = {"vector": {}, "graph": {"host": "bolt://localhost"}}
385+
386+
with (
387+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine") as mock_ve,
388+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
389+
):
390+
mock_ve.return_value = MagicMock()
391+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config)
392+
assert agent.rag_config == rag_config
393+
394+
395+
@pytest.mark.openai
396+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
397+
def test_task_manager_agent_init_with_custom_system_message(
398+
credentials_gpt_4o_mini: Credentials, tmp_path: Path
399+
) -> None:
400+
"""Test TaskManagerAgent initialization with custom system message."""
401+
llm_config = credentials_gpt_4o_mini.llm_config
402+
custom_message = "Custom system message"
403+
404+
with (
405+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
406+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
407+
):
408+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, custom_system_message=custom_message)
409+
assert agent.system_message == custom_message
410+
411+
412+
@pytest.mark.openai
413+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
414+
def test_create_rag_engines_with_graph_config(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
415+
"""Test _create_rag_engines with graph configuration."""
416+
llm_config = credentials_gpt_4o_mini.llm_config
417+
rag_config: dict[str, Any] = {"graph": {"host": "bolt://localhost", "port": 7687}}
418+
419+
with (
420+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
421+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
422+
patch("autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine.Neo4jGraphQueryEngine") as mock_neo4j,
423+
):
424+
mock_neo4j.return_value = MagicMock()
425+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config)
426+
assert "graph" in agent.rag_engines
427+
428+
429+
@pytest.mark.openai
430+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
431+
def test_create_neo4j_engine_import_error(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
432+
"""Test _create_neo4j_engine with ImportError."""
433+
llm_config = credentials_gpt_4o_mini.llm_config
434+
rag_config: dict[str, Any] = {"graph": {}}
435+
436+
with (
437+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
438+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
439+
patch(
440+
"autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine.Neo4jGraphQueryEngine",
441+
side_effect=ImportError("No module"),
442+
),
443+
):
444+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config)
445+
assert agent.rag_engines.get("graph") is None
446+
447+
448+
@pytest.mark.openai
449+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
450+
def test_safe_context_update(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
451+
"""Test _safe_context_update method."""
452+
llm_config = credentials_gpt_4o_mini.llm_config
453+
454+
with (
455+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
456+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
457+
):
458+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
459+
context_vars = ContextVariables()
460+
agent._safe_context_update(context_vars, "test_key", "test_value")
461+
assert context_vars["test_key"] == "test_value"
462+
463+
464+
@pytest.mark.openai
465+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
466+
def test_ingest_documents_empty_list(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
467+
"""Test ingest_documents with empty document list."""
468+
llm_config = credentials_gpt_4o_mini.llm_config
469+
470+
with (
471+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
472+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
473+
):
474+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
475+
context_vars = ContextVariables()
476+
477+
# Access the ingest_documents function from the agent's tools
478+
ingest_tool = None
479+
for tool in agent.tools:
480+
if tool.name == "ingest_documents":
481+
ingest_tool = tool
482+
break
483+
484+
assert ingest_tool is not None, "ingest_documents tool not found"
485+
result = asyncio.run(ingest_tool.func([], context_vars))
486+
assert "No documents provided" in str(result.message)
487+
488+
489+
@pytest.mark.openai
490+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
491+
def test_ingest_documents_invalid_paths(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
492+
"""Test ingest_documents with invalid document paths."""
493+
llm_config = credentials_gpt_4o_mini.llm_config
494+
495+
with (
496+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
497+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
498+
):
499+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
500+
context_vars = ContextVariables()
501+
502+
# Access the ingest_documents function from the agent's tools
503+
ingest_tool = None
504+
for tool in agent.tools:
505+
if tool.name == "ingest_documents":
506+
ingest_tool = tool
507+
break
508+
509+
assert ingest_tool is not None, "ingest_documents tool not found"
510+
result = asyncio.run(ingest_tool.func(["", " "], context_vars))
511+
assert "No valid documents found" in str(result.message)
512+
513+
514+
@pytest.mark.openai
515+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
516+
def test_execute_query_empty_list(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
517+
"""Test execute_query with empty query list."""
518+
llm_config = credentials_gpt_4o_mini.llm_config
519+
520+
with (
521+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
522+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
523+
):
524+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
525+
context_vars = ContextVariables()
526+
527+
# Access the execute_query function from the agent's tools
528+
query_tool = None
529+
for tool in agent.tools:
530+
if tool.name == "execute_query":
531+
query_tool = tool
532+
break
533+
534+
assert query_tool is not None, "execute_query tool not found"
535+
result = asyncio.run(query_tool.func([], context_vars))
536+
assert result == "No queries to run"
537+
538+
539+
@pytest.mark.openai
540+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
541+
def test_execute_query_invalid_queries(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
542+
"""Test execute_query with invalid queries."""
543+
llm_config = credentials_gpt_4o_mini.llm_config
544+
545+
with (
546+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
547+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"),
548+
):
549+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
550+
context_vars = ContextVariables()
551+
552+
# Access the execute_query function from the agent's tools
553+
query_tool = None
554+
for tool in agent.tools:
555+
if tool.name == "execute_query":
556+
query_tool = tool
557+
break
558+
559+
assert query_tool is not None, "execute_query tool not found"
560+
result = asyncio.run(query_tool.func(["", " "], context_vars))
561+
assert result == "No valid queries provided"
562+
563+
564+
@pytest.mark.openai
565+
@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag")
566+
def test_del_with_exception(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None:
567+
"""Test __del__ method with exception during shutdown."""
568+
llm_config = credentials_gpt_4o_mini.llm_config
569+
mock_executor = MagicMock()
570+
mock_executor.shutdown.side_effect = Exception("Shutdown failed")
571+
572+
with (
573+
patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"),
574+
patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor", return_value=mock_executor),
575+
):
576+
agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path)
577+
agent.__del__()
578+
mock_executor.shutdown.assert_called_once_with(wait=False)

test/agents/experimental/document_agent/test_task_manager_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def test_extract_text_from_pdf_url_success(
3535
mock_response.content = b"fake pdf content"
3636
mock_get.return_value = mock_response
3737

38-
mock_temp_dir.return_value.__enter__.return_value = "/tmp/test"
38+
# Mock the temporary directory context manager properly
39+
mock_temp_dir_instance = Mock()
40+
mock_temp_dir_instance.__enter__ = Mock(return_value="/tmp/test")
41+
mock_temp_dir_instance.__exit__ = Mock(return_value=None)
42+
mock_temp_dir.return_value = mock_temp_dir_instance
3943

4044
# Create a proper mock document that supports iteration
4145
mock_doc = Mock()
@@ -56,9 +60,12 @@ def test_extract_text_from_pdf_url_success(
5660
mock_llm_lingua_class.return_value = mock_llm_lingua
5761

5862
# Mock urllib3 to return a URL with scheme
59-
with patch(
60-
"autogen.agents.experimental.document_agent.task_manager_utils.urllib3.util.url.parse_url"
61-
) as mock_parse_url:
63+
with (
64+
patch(
65+
"autogen.agents.experimental.document_agent.task_manager_utils.urllib3.util.url.parse_url"
66+
) as mock_parse_url,
67+
patch("builtins.open", mock_open()) as mock_file,
68+
):
6269
mock_parsed_url = Mock()
6370
mock_parsed_url.scheme = "https"
6471
mock_parse_url.return_value = mock_parsed_url
@@ -71,6 +78,8 @@ def test_extract_text_from_pdf_url_success(
7178
mock_get.assert_called_once_with("https://example.com/test.pdf")
7279
mock_fitz.open.assert_called_once()
7380
mock_compressor.apply_transform.assert_called_once_with([{"content": "Page 1 contentPage 2 content"}])
81+
# Verify file was opened for writing - use Path object instead of string
82+
mock_file.assert_called_with(Path("/tmp/test/temp.pdf"), "wb")
7483

7584
@patch("autogen.agents.experimental.document_agent.task_manager_utils.urllib3.util.url.parse_url")
7685
def test_extract_text_from_pdf_non_url_raises_error(self, mock_parse_url: Any) -> None:

0 commit comments

Comments
 (0)