Skip to content

Commit e45faae

Browse files
authored
Merge branch 'main' into feat/csv-row-mode
2 parents fc7f91f + 443101e commit e45faae

15 files changed

+686
-24
lines changed

haystack/components/agents/agent.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def _initialize_fresh_execution(
255255
messages: list[ChatMessage],
256256
streaming_callback: Optional[StreamingCallbackT],
257257
requires_async: bool,
258+
system_prompt: Optional[str] = None,
258259
**kwargs,
259260
) -> _ExecutionContext:
260261
"""
@@ -263,10 +264,12 @@ def _initialize_fresh_execution(
263264
:param messages: List of ChatMessage objects to start the agent with.
264265
:param streaming_callback: Optional callback for streaming responses.
265266
:param requires_async: Whether the agent run requires asynchronous execution.
267+
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
266268
:param kwargs: Additional data to pass to the State used by the Agent.
267269
"""
268-
if self.system_prompt is not None:
269-
messages = [ChatMessage.from_system(self.system_prompt)] + messages
270+
system_prompt = system_prompt or self.system_prompt
271+
if system_prompt is not None:
272+
messages = [ChatMessage.from_system(system_prompt)] + messages
270273

271274
if all(m.is_from(ChatRole.SYSTEM) for m in messages):
272275
logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")
@@ -443,6 +446,7 @@ def run(
443446
*,
444447
break_point: Optional[AgentBreakpoint] = None,
445448
snapshot: Optional[AgentSnapshot] = None,
449+
system_prompt: Optional[str] = None,
446450
**kwargs: Any,
447451
) -> dict[str, Any]:
448452
"""
@@ -455,6 +459,7 @@ def run(
455459
for "tool_invoker".
456460
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
457461
the relevant information to restart the Agent execution from where it left off.
462+
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
458463
:param kwargs: Additional data to pass to the State schema used by the Agent.
459464
The keys must match the schema defined in the Agent's `state_schema`.
460465
:returns:
@@ -482,7 +487,11 @@ def run(
482487
)
483488
else:
484489
exe_context = self._initialize_fresh_execution(
485-
messages=messages, streaming_callback=streaming_callback, requires_async=False, **kwargs
490+
messages=messages,
491+
streaming_callback=streaming_callback,
492+
requires_async=False,
493+
system_prompt=system_prompt,
494+
**kwargs,
486495
)
487496

488497
with self._create_agent_span() as span:
@@ -558,6 +567,7 @@ async def run_async(
558567
*,
559568
break_point: Optional[AgentBreakpoint] = None,
560569
snapshot: Optional[AgentSnapshot] = None,
570+
system_prompt: Optional[str] = None,
561571
**kwargs: Any,
562572
) -> dict[str, Any]:
563573
"""
@@ -574,6 +584,7 @@ async def run_async(
574584
for "tool_invoker".
575585
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
576586
the relevant information to restart the Agent execution from where it left off.
587+
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
577588
:param kwargs: Additional data to pass to the State schema used by the Agent.
578589
The keys must match the schema defined in the Agent's `state_schema`.
579590
:returns:
@@ -601,7 +612,11 @@ async def run_async(
601612
)
602613
else:
603614
exe_context = self._initialize_fresh_execution(
604-
messages=messages, streaming_callback=streaming_callback, requires_async=False, **kwargs
615+
messages=messages,
616+
streaming_callback=streaming_callback,
617+
requires_async=False,
618+
system_prompt=system_prompt,
619+
**kwargs,
605620
)
606621

607622
with self._create_agent_span() as span:

haystack/components/rankers/sentence_transformers_diversity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def __init__( # noqa: PLR0913 # pylint: disable=too-many-positional-arguments
184184
self.meta_fields_to_embed = meta_fields_to_embed or []
185185
self.embedding_separator = embedding_separator
186186
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
187-
self.lambda_threshold = lambda_threshold or 0.5
187+
self.lambda_threshold = lambda_threshold
188188
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
189189
self.model_kwargs = model_kwargs
190190
self.tokenizer_kwargs = tokenizer_kwargs

haystack/core/component/component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def copy_class_namespace(namespace):
605605

606606
return new_cls
607607

608-
# Call signature when the the decorator is usead without parens (@component).
608+
# Call signature when the decorator is used without parens (@component).
609609
@overload
610610
def __call__(self, cls: type[T]) -> type[T]: ...
611611

haystack/core/pipeline/breakpoint.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import replace
88
from datetime import datetime
99
from pathlib import Path
10-
from typing import Any, Optional, Union
10+
from typing import TYPE_CHECKING, Any, Optional, Union
1111

1212
from networkx import MultiDiGraph
1313

@@ -22,9 +22,12 @@
2222
PipelineState,
2323
ToolBreakpoint,
2424
)
25-
from haystack.tools import Tool, Toolset
2625
from haystack.utils.base_serialization import _serialize_value_with_schema
2726

27+
if TYPE_CHECKING:
28+
from haystack.tools.tool import Tool
29+
from haystack.tools.toolset import Toolset
30+
2831
logger = logging.getLogger(__name__)
2932

3033

@@ -323,7 +326,9 @@ def _create_agent_snapshot(
323326
)
324327

325328

326-
def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: Union[list[Tool], Toolset]) -> None:
329+
def _validate_tool_breakpoint_is_valid(
330+
agent_breakpoint: AgentBreakpoint, tools: Union[list["Tool"], "Toolset"]
331+
) -> None:
327332
"""
328333
Validates the AgentBreakpoint passed to the agent.
329334

haystack/core/pipeline/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from haystack import logging
1212
from haystack.core.component import Component
13-
from haystack.tools import Tool, Toolset
1413

1514
logger = logging.getLogger(__name__)
1615

@@ -29,6 +28,10 @@ def _deepcopy_with_exceptions(obj: Any) -> Any:
2928
:returns:
3029
A deep-copied version of the object, or the original object if deepcopying fails.
3130
"""
31+
# Import here to avoid circular imports
32+
from haystack.tools.tool import Tool
33+
from haystack.tools.toolset import Toolset
34+
3235
if isinstance(obj, (list, tuple, set)):
3336
return type(obj)(_deepcopy_with_exceptions(v) for v in obj)
3437

haystack/core/super_component/super_component.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,21 @@ def __init__(
4646
:param pipeline: The pipeline instance or async pipeline instance to be wrapped
4747
:param input_mapping: A dictionary mapping component input names to pipeline input socket paths.
4848
If not provided, a default input mapping will be created based on all pipeline inputs.
49+
Example:
50+
```python
51+
input_mapping={
52+
"query": ["retriever.query", "prompt_builder.query"],
53+
}
54+
```
4955
:param output_mapping: A dictionary mapping pipeline output socket paths to component output names.
5056
If not provided, a default output mapping will be created based on all pipeline outputs.
57+
Example:
58+
```python
59+
output_mapping={
60+
"retriever.documents": "documents",
61+
"generator.replies": "replies",
62+
}
63+
```
5164
:raises InvalidMappingError: Raised if any mapping is invalid or type conflicts occur
5265
:raises ValueError: Raised if no pipeline is provided
5366
"""
@@ -177,10 +190,14 @@ def _validate_input_mapping(
177190
for path in pipeline_input_paths:
178191
comp_name, socket_name = self._split_component_path(path)
179192
if comp_name not in pipeline_inputs:
180-
raise InvalidMappingValueError(f"Component '{comp_name}' not found in pipeline inputs.")
193+
raise InvalidMappingValueError(
194+
f"Component '{comp_name}' not found in pipeline inputs.\n"
195+
f"Available components: {list(pipeline_inputs.keys())}"
196+
)
181197
if socket_name not in pipeline_inputs[comp_name]:
182198
raise InvalidMappingValueError(
183-
f"Input socket '{socket_name}' not found in component '{comp_name}'."
199+
f"Input socket '{socket_name}' not found in component '{comp_name}'.\n"
200+
f"Available inputs for '{comp_name}': {list(pipeline_inputs[comp_name].keys())}"
184201
)
185202

186203
def _resolve_input_types_from_mapping(

haystack/tools/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44

55
# NOTE: we do not use LazyImporter here because it creates conflicts between the tool module and the tool decorator
66

7-
# ruff: noqa: I001 (ignore import order as we need to import Tool before ComponentTool)
8-
from .from_function import create_tool_from_function, tool
9-
from .tool import Tool, _check_duplicate_tool_names
10-
from .component_tool import ComponentTool
11-
from .serde_utils import deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
12-
from .toolset import Toolset
7+
# ruff: noqa: I001 (ignore import order as we need to import Tool before ComponentTool and PipelineTool)
8+
from haystack.tools.from_function import create_tool_from_function, tool
9+
from haystack.tools.tool import Tool, _check_duplicate_tool_names
10+
from haystack.tools.toolset import Toolset
11+
from haystack.tools.component_tool import ComponentTool
12+
from haystack.tools.pipeline_tool import PipelineTool
13+
from haystack.tools.serde_utils import deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
1314

1415
__all__ = [
1516
"_check_duplicate_tool_names",
1617
"ComponentTool",
1718
"create_tool_from_function",
1819
"deserialize_tools_or_toolset_inplace",
20+
"PipelineTool",
1921
"serialize_tools_or_toolset",
2022
"Tool",
2123
"Toolset",

haystack/tools/component_tool.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,9 @@ def to_dict(self) -> dict[str, Any]:
222222
"description": self.description,
223223
"parameters": self._unresolved_parameters,
224224
"inputs_from_state": self.inputs_from_state,
225-
# This is soft-copied as to not modify the attributes in place
226-
"outputs_to_state": self.outputs_to_state.copy() if self.outputs_to_state else None,
225+
"outputs_to_state": _serialize_outputs_to_state(self.outputs_to_state) if self.outputs_to_state else None,
227226
}
228227

229-
if self.outputs_to_state is not None:
230-
serialized["outputs_to_state"] = _serialize_outputs_to_state(self.outputs_to_state)
231-
232228
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
233229
# This is soft-copied as to not modify the attributes in place
234230
serialized["outputs_to_string"] = self.outputs_to_string.copy()
@@ -239,7 +235,7 @@ def to_dict(self) -> dict[str, Any]:
239235
return {"type": generate_qualified_class_name(type(self)), "data": serialized}
240236

241237
@classmethod
242-
def from_dict(cls, data: dict[str, Any]) -> "Tool":
238+
def from_dict(cls, data: dict[str, Any]) -> "ComponentTool":
243239
"""
244240
Deserializes the ComponentTool from a dictionary.
245241
"""

0 commit comments

Comments
 (0)