|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +import asyncio |
5 | 6 | from pathlib import Path |
| 7 | +from typing import Any |
6 | 8 | from unittest.mock import MagicMock, patch |
7 | 9 |
|
8 | 10 | import pytest |
9 | 11 |
|
| 12 | +from autogen.agentchat.group.context_variables import ContextVariables |
10 | 13 | from autogen.agents.experimental.document_agent.task_manager import TASK_MANAGER_SYSTEM_MESSAGE, TaskManagerAgent |
11 | 14 | from autogen.import_utils import skip_on_missing_imports |
12 | 15 |
|
@@ -142,7 +145,7 @@ def test_task_manager_agent_cleanup( |
142 | 145 | agent.__del__() |
143 | 146 |
|
144 | 147 | # Verify shutdown was called |
145 | | - mock_executor.shutdown.assert_called_once_with(wait=True) |
| 148 | + mock_executor.shutdown.assert_called_once_with(wait=False) |
146 | 149 |
|
147 | 150 | @pytest.mark.openai |
148 | 151 | @skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
@@ -371,3 +374,205 @@ def test_execute_single_query_no_query_engine(self, credentials_gpt_4o_mini: Cre |
371 | 374 | # Test basic agent functionality |
372 | 375 | assert agent.query_engine is None |
373 | 376 | assert agent.parsed_docs_path == tmp_path |
| 377 | + |
| 378 | + |
| 379 | +@pytest.mark.openai |
| 380 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 381 | +def test_task_manager_agent_init_with_rag_config(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 382 | + """Test TaskManagerAgent initialization with rag_config.""" |
| 383 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 384 | + rag_config: dict[str, Any] = {"vector": {}, "graph": {"host": "bolt://localhost"}} |
| 385 | + |
| 386 | + with ( |
| 387 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine") as mock_ve, |
| 388 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 389 | + ): |
| 390 | + mock_ve.return_value = MagicMock() |
| 391 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config) |
| 392 | + assert agent.rag_config == rag_config |
| 393 | + |
| 394 | + |
| 395 | +@pytest.mark.openai |
| 396 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 397 | +def test_task_manager_agent_init_with_custom_system_message( |
| 398 | + credentials_gpt_4o_mini: Credentials, tmp_path: Path |
| 399 | +) -> None: |
| 400 | + """Test TaskManagerAgent initialization with custom system message.""" |
| 401 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 402 | + custom_message = "Custom system message" |
| 403 | + |
| 404 | + with ( |
| 405 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 406 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 407 | + ): |
| 408 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, custom_system_message=custom_message) |
| 409 | + assert agent.system_message == custom_message |
| 410 | + |
| 411 | + |
| 412 | +@pytest.mark.openai |
| 413 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 414 | +def test_create_rag_engines_with_graph_config(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 415 | + """Test _create_rag_engines with graph configuration.""" |
| 416 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 417 | + rag_config: dict[str, Any] = {"graph": {"host": "bolt://localhost", "port": 7687}} |
| 418 | + |
| 419 | + with ( |
| 420 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 421 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 422 | + patch("autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine.Neo4jGraphQueryEngine") as mock_neo4j, |
| 423 | + ): |
| 424 | + mock_neo4j.return_value = MagicMock() |
| 425 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config) |
| 426 | + assert "graph" in agent.rag_engines |
| 427 | + |
| 428 | + |
| 429 | +@pytest.mark.openai |
| 430 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 431 | +def test_create_neo4j_engine_import_error(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 432 | + """Test _create_neo4j_engine with ImportError.""" |
| 433 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 434 | + rag_config: dict[str, Any] = {"graph": {}} |
| 435 | + |
| 436 | + with ( |
| 437 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 438 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 439 | + patch( |
| 440 | + "autogen.agentchat.contrib.graph_rag.neo4j_graph_query_engine.Neo4jGraphQueryEngine", |
| 441 | + side_effect=ImportError("No module"), |
| 442 | + ), |
| 443 | + ): |
| 444 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path, rag_config=rag_config) |
| 445 | + assert agent.rag_engines.get("graph") is None |
| 446 | + |
| 447 | + |
| 448 | +@pytest.mark.openai |
| 449 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 450 | +def test_safe_context_update(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 451 | + """Test _safe_context_update method.""" |
| 452 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 453 | + |
| 454 | + with ( |
| 455 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 456 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 457 | + ): |
| 458 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 459 | + context_vars = ContextVariables() |
| 460 | + agent._safe_context_update(context_vars, "test_key", "test_value") |
| 461 | + assert context_vars["test_key"] == "test_value" |
| 462 | + |
| 463 | + |
| 464 | +@pytest.mark.openai |
| 465 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 466 | +def test_ingest_documents_empty_list(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 467 | + """Test ingest_documents with empty document list.""" |
| 468 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 469 | + |
| 470 | + with ( |
| 471 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 472 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 473 | + ): |
| 474 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 475 | + context_vars = ContextVariables() |
| 476 | + |
| 477 | + # Access the ingest_documents function from the agent's tools |
| 478 | + ingest_tool = None |
| 479 | + for tool in agent.tools: |
| 480 | + if tool.name == "ingest_documents": |
| 481 | + ingest_tool = tool |
| 482 | + break |
| 483 | + |
| 484 | + assert ingest_tool is not None, "ingest_documents tool not found" |
| 485 | + result = asyncio.run(ingest_tool.func([], context_vars)) |
| 486 | + assert "No documents provided" in str(result.message) |
| 487 | + |
| 488 | + |
| 489 | +@pytest.mark.openai |
| 490 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 491 | +def test_ingest_documents_invalid_paths(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 492 | + """Test ingest_documents with invalid document paths.""" |
| 493 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 494 | + |
| 495 | + with ( |
| 496 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 497 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 498 | + ): |
| 499 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 500 | + context_vars = ContextVariables() |
| 501 | + |
| 502 | + # Access the ingest_documents function from the agent's tools |
| 503 | + ingest_tool = None |
| 504 | + for tool in agent.tools: |
| 505 | + if tool.name == "ingest_documents": |
| 506 | + ingest_tool = tool |
| 507 | + break |
| 508 | + |
| 509 | + assert ingest_tool is not None, "ingest_documents tool not found" |
| 510 | + result = asyncio.run(ingest_tool.func(["", " "], context_vars)) |
| 511 | + assert "No valid documents found" in str(result.message) |
| 512 | + |
| 513 | + |
| 514 | +@pytest.mark.openai |
| 515 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 516 | +def test_execute_query_empty_list(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 517 | + """Test execute_query with empty query list.""" |
| 518 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 519 | + |
| 520 | + with ( |
| 521 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 522 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 523 | + ): |
| 524 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 525 | + context_vars = ContextVariables() |
| 526 | + |
| 527 | + # Access the execute_query function from the agent's tools |
| 528 | + query_tool = None |
| 529 | + for tool in agent.tools: |
| 530 | + if tool.name == "execute_query": |
| 531 | + query_tool = tool |
| 532 | + break |
| 533 | + |
| 534 | + assert query_tool is not None, "execute_query tool not found" |
| 535 | + result = asyncio.run(query_tool.func([], context_vars)) |
| 536 | + assert result == "No queries to run" |
| 537 | + |
| 538 | + |
| 539 | +@pytest.mark.openai |
| 540 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 541 | +def test_execute_query_invalid_queries(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 542 | + """Test execute_query with invalid queries.""" |
| 543 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 544 | + |
| 545 | + with ( |
| 546 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 547 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor"), |
| 548 | + ): |
| 549 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 550 | + context_vars = ContextVariables() |
| 551 | + |
| 552 | + # Access the execute_query function from the agent's tools |
| 553 | + query_tool = None |
| 554 | + for tool in agent.tools: |
| 555 | + if tool.name == "execute_query": |
| 556 | + query_tool = tool |
| 557 | + break |
| 558 | + |
| 559 | + assert query_tool is not None, "execute_query tool not found" |
| 560 | + result = asyncio.run(query_tool.func(["", " "], context_vars)) |
| 561 | + assert result == "No valid queries provided" |
| 562 | + |
| 563 | + |
| 564 | +@pytest.mark.openai |
| 565 | +@skip_on_missing_imports(["selenium", "webdriver_manager"], "rag") |
| 566 | +def test_del_with_exception(credentials_gpt_4o_mini: Credentials, tmp_path: Path) -> None: |
| 567 | + """Test __del__ method with exception during shutdown.""" |
| 568 | + llm_config = credentials_gpt_4o_mini.llm_config |
| 569 | + mock_executor = MagicMock() |
| 570 | + mock_executor.shutdown.side_effect = Exception("Shutdown failed") |
| 571 | + |
| 572 | + with ( |
| 573 | + patch("autogen.agents.experimental.document_agent.task_manager.VectorChromaQueryEngine"), |
| 574 | + patch("autogen.agents.experimental.document_agent.task_manager.ThreadPoolExecutor", return_value=mock_executor), |
| 575 | + ): |
| 576 | + agent = TaskManagerAgent(llm_config=llm_config, parsed_docs_path=tmp_path) |
| 577 | + agent.__del__() |
| 578 | + mock_executor.shutdown.assert_called_once_with(wait=False) |
0 commit comments