Skip to content

Commit 19f6271

Browse files
committed
fix agent switcher in front-end
1 parent 8b3bcbc commit 19f6271

File tree

2 files changed

+76
-55
lines changed

2 files changed

+76
-55
lines changed

src/mcp_agent/core/enhanced_prompt.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from prompt_toolkit.styles import Style
1515
from rich import print as rich_print
1616

17+
from mcp_agent.core.agent_types import AgentType
1718
from mcp_agent.core.exceptions import PromptExitError
1819

1920
# Get the application version
@@ -86,7 +87,7 @@ def get_completions(self, document, complete_event):
8687
for agent in self.agents:
8788
if agent.lower().startswith(agent_name.lower()):
8889
# Get agent type or default to "Agent"
89-
agent_type = self.agent_types.get(agent, "Agent")
90+
agent_type = self.agent_types.get(agent, AgentType.BASIC).value
9091
yield Completion(
9192
agent,
9293
start_position=-len(agent_name),
@@ -149,7 +150,7 @@ async def get_enhanced_input(
149150
show_stop_hint: bool = False,
150151
multiline: bool = False,
151152
available_agent_names: List[str] = None,
152-
agent_types: dict = None,
153+
agent_types: dict[str, AgentType] = None,
153154
is_human_input: bool = False,
154155
toolbar_color: str = "ansiblue",
155156
) -> str:
@@ -430,18 +431,18 @@ async def get_argument_input(
430431
async def handle_special_commands(command, agent_app=None):
431432
"""
432433
Handle special input commands.
433-
434+
434435
Args:
435436
command: The command to handle, can be string or dictionary
436437
agent_app: Optional agent app reference
437-
438+
438439
Returns:
439440
True if command was handled, False if not, or a dict with action info
440441
"""
441442
# Quick guard for empty or None commands
442443
if not command:
443444
return False
444-
445+
445446
# If command is already a dictionary, it has been pre-processed
446447
# Just return it directly (like when /prompts converts to select_prompt dict)
447448
if isinstance(command, dict):

src/mcp_agent/core/interactive_prompt.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from rich.console import Console
2121
from rich.table import Table
2222

23+
from mcp_agent.core.agent_types import AgentType
2324
from mcp_agent.core.enhanced_prompt import (
2425
get_argument_input,
2526
get_enhanced_input,
@@ -36,14 +37,14 @@ class InteractivePrompt:
3637
This is extracted from the original AgentApp implementation to support DirectAgentApp.
3738
"""
3839

39-
def __init__(self, agent_types: Optional[Dict[str, str]] = None) -> None:
40+
def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None:
4041
"""
4142
Initialize the interactive prompt.
4243
4344
Args:
4445
agent_types: Dictionary mapping agent names to their types for display
4546
"""
46-
self.agent_types = agent_types or {}
47+
self.agent_types: Dict[str, AgentType] = agent_types or {}
4748

4849
async def prompt_loop(
4950
self,
@@ -97,7 +98,7 @@ async def prompt_loop(
9798

9899
# Handle special commands - pass "True" to enable agent switching
99100
command_result = await handle_special_commands(user_input, True)
100-
101+
101102
# Check if we should switch agents
102103
if isinstance(command_result, dict):
103104
if "switch_agent" in command_result:
@@ -113,32 +114,43 @@ async def prompt_loop(
113114
# Use the list_prompts_func directly
114115
await self._list_prompts(list_prompts_func, agent)
115116
continue
116-
elif "select_prompt" in command_result and (list_prompts_func and apply_prompt_func):
117+
elif "select_prompt" in command_result and (
118+
list_prompts_func and apply_prompt_func
119+
):
117120
# Handle prompt selection, using both list_prompts and apply_prompt
118121
prompt_name = command_result.get("prompt_name")
119122
prompt_index = command_result.get("prompt_index")
120-
123+
121124
# If a specific index was provided (from /prompt <number>)
122125
if prompt_index is not None:
123126
# First get a list of all prompts to look up the index
124127
all_prompts = await self._get_all_prompts(list_prompts_func, agent)
125128
if not all_prompts:
126129
rich_print("[yellow]No prompts available[/yellow]")
127130
continue
128-
131+
129132
# Check if the index is valid
130133
if 1 <= prompt_index <= len(all_prompts):
131134
# Get the prompt at the specified index (1-based to 0-based)
132135
selected_prompt = all_prompts[prompt_index - 1]
133136
# Use the already created namespaced_name to ensure consistency
134-
await self._select_prompt(list_prompts_func, apply_prompt_func, agent, selected_prompt["namespaced_name"])
137+
await self._select_prompt(
138+
list_prompts_func,
139+
apply_prompt_func,
140+
agent,
141+
selected_prompt["namespaced_name"],
142+
)
135143
else:
136-
rich_print(f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]")
144+
rich_print(
145+
f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]"
146+
)
137147
# Show the prompt list for convenience
138148
await self._list_prompts(list_prompts_func, agent)
139149
else:
140150
# Use the name-based selection
141-
await self._select_prompt(list_prompts_func, apply_prompt_func, agent, prompt_name)
151+
await self._select_prompt(
152+
list_prompts_func, apply_prompt_func, agent, prompt_name
153+
)
142154
continue
143155

144156
# Skip further processing if command was handled
@@ -158,11 +170,11 @@ async def prompt_loop(
158170
async def _get_all_prompts(self, list_prompts_func, agent_name):
159171
"""
160172
Get a list of all available prompts.
161-
173+
162174
Args:
163175
list_prompts_func: Function to get available prompts
164176
agent_name: Name of the agent
165-
177+
166178
Returns:
167179
List of prompt info dictionaries, sorted by server and name
168180
"""
@@ -171,52 +183,59 @@ async def _get_all_prompts(self, list_prompts_func, agent_name):
171183
# the agent_name parameter should never be used as a server name
172184
prompt_servers = await list_prompts_func(None)
173185
all_prompts = []
174-
186+
175187
# Process the returned prompt servers
176188
if prompt_servers:
177189
# First collect all prompts
178190
for server_name, prompts_info in prompt_servers.items():
179191
if prompts_info and hasattr(prompts_info, "prompts") and prompts_info.prompts:
180192
for prompt in prompts_info.prompts:
181193
# Use the SEP constant for proper namespacing
182-
all_prompts.append({
183-
"server": server_name,
184-
"name": prompt.name,
185-
"namespaced_name": f"{server_name}{SEP}{prompt.name}",
186-
"description": getattr(prompt, "description", "No description"),
187-
"arg_count": len(getattr(prompt, "arguments", [])),
188-
"arguments": getattr(prompt, "arguments", [])
189-
})
194+
all_prompts.append(
195+
{
196+
"server": server_name,
197+
"name": prompt.name,
198+
"namespaced_name": f"{server_name}{SEP}{prompt.name}",
199+
"description": getattr(prompt, "description", "No description"),
200+
"arg_count": len(getattr(prompt, "arguments", [])),
201+
"arguments": getattr(prompt, "arguments", []),
202+
}
203+
)
190204
elif isinstance(prompts_info, list) and prompts_info:
191205
for prompt in prompts_info:
192206
if isinstance(prompt, dict) and "name" in prompt:
193-
all_prompts.append({
194-
"server": server_name,
195-
"name": prompt["name"],
196-
"namespaced_name": f"{server_name}{SEP}{prompt['name']}",
197-
"description": prompt.get("description", "No description"),
198-
"arg_count": len(prompt.get("arguments", [])),
199-
"arguments": prompt.get("arguments", [])
200-
})
207+
all_prompts.append(
208+
{
209+
"server": server_name,
210+
"name": prompt["name"],
211+
"namespaced_name": f"{server_name}{SEP}{prompt['name']}",
212+
"description": prompt.get("description", "No description"),
213+
"arg_count": len(prompt.get("arguments", [])),
214+
"arguments": prompt.get("arguments", []),
215+
}
216+
)
201217
else:
202-
all_prompts.append({
203-
"server": server_name,
204-
"name": str(prompt),
205-
"namespaced_name": f"{server_name}{SEP}{str(prompt)}",
206-
"description": "No description",
207-
"arg_count": 0,
208-
"arguments": []
209-
})
210-
218+
all_prompts.append(
219+
{
220+
"server": server_name,
221+
"name": str(prompt),
222+
"namespaced_name": f"{server_name}{SEP}{str(prompt)}",
223+
"description": "No description",
224+
"arg_count": 0,
225+
"arguments": [],
226+
}
227+
)
228+
211229
# Sort prompts by server and name for consistent ordering
212230
all_prompts.sort(key=lambda p: (p["server"], p["name"]))
213-
231+
214232
return all_prompts
215-
233+
216234
except Exception as e:
217235
import traceback
218236

219237
from rich import print as rich_print
238+
220239
rich_print(f"[red]Error getting prompts: {e}[/red]")
221240
rich_print(f"[dim]{traceback.format_exc()}[/dim]")
222241
return []
@@ -238,11 +257,11 @@ async def _list_prompts(self, list_prompts_func, agent_name) -> None:
238257
try:
239258
# Directly call the list_prompts function for this agent
240259
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
241-
260+
242261
# Get all prompts using the helper function - pass None as server name
243262
# to get prompts from all available servers
244263
all_prompts = await self._get_all_prompts(list_prompts_func, None)
245-
264+
246265
if all_prompts:
247266
# Create a table for better display
248267
table = Table(title="Available MCP Prompts")
@@ -251,19 +270,19 @@ async def _list_prompts(self, list_prompts_func, agent_name) -> None:
251270
table.add_column("Prompt Name", style="bright_blue")
252271
table.add_column("Description")
253272
table.add_column("Args", justify="center")
254-
273+
255274
# Add prompts to table
256275
for i, prompt in enumerate(all_prompts):
257276
table.add_row(
258277
str(i + 1),
259278
prompt["server"],
260279
prompt["name"],
261280
prompt["description"],
262-
str(prompt["arg_count"])
281+
str(prompt["arg_count"]),
263282
)
264-
283+
265284
console.print(table)
266-
285+
267286
# Add usage instructions
268287
rich_print("\n[bold]Usage:[/bold]")
269288
rich_print(" • Use [cyan]/prompt <number>[/cyan] to select a prompt by number")
@@ -272,10 +291,13 @@ async def _list_prompts(self, list_prompts_func, agent_name) -> None:
272291
rich_print("[yellow]No prompts available[/yellow]")
273292
except Exception as e:
274293
import traceback
294+
275295
rich_print(f"[red]Error listing prompts: {e}[/red]")
276296
rich_print(f"[dim]{traceback.format_exc()}[/dim]")
277297

278-
async def _select_prompt(self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None) -> None:
298+
async def _select_prompt(
299+
self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None
300+
) -> None:
279301
"""
280302
Select and apply a prompt.
281303
@@ -293,7 +315,7 @@ async def _select_prompt(self, list_prompts_func, apply_prompt_func, agent_name,
293315
try:
294316
# Get all available prompts directly from the list_prompts function
295317
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
296-
# IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
318+
# IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
297319
# So we pass None to get prompts from all servers, not using agent_name as server name
298320
prompt_servers = await list_prompts_func(None)
299321

@@ -514,9 +536,7 @@ async def _select_prompt(self, list_prompts_func, apply_prompt_func, agent_name,
514536

515537
# Apply the prompt
516538
namespaced_name = selected_prompt["namespaced_name"]
517-
rich_print(
518-
f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]"
519-
)
539+
rich_print(f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]")
520540

521541
# Call apply_prompt function with the prompt name and arguments
522542
await apply_prompt_func(namespaced_name, arg_values, agent_name)

0 commit comments

Comments
 (0)