-
Notifications
You must be signed in to change notification settings - Fork 120
[diffusion]: video_creator -> diffusion_video [audio]: new agent audio_generator, support doubao_tts #835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[diffusion]: video_creator -> diffusion_video [audio]: new agent audio_generator, support doubao_tts #835
Changes from 4 commits
06207b2
33bf19e
7771aab
f0beccb
bacbf02
34f499e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,6 @@ | |
| from rich.style import Style | ||
| from rich.table import Table | ||
| from rich.text import Text | ||
| import os | ||
|
|
||
| from aworld.logs.util import logger | ||
| from ._globals import console | ||
|
|
@@ -30,6 +29,7 @@ class AWorldCLI: | |
| def __init__(self): | ||
| self.console = console | ||
| self.user_input = UserInputHandler(console) | ||
| # self.team_handler = InteractiveTeamHandler(console) | ||
|
|
||
| def _get_gradient_text(self, text: str, start_color: str, end_color: str) -> Text: | ||
| """Create a Text object with a horizontal gradient.""" | ||
|
|
@@ -176,13 +176,13 @@ async def _edit_models_config(self, config, current_config: dict): | |
| else: | ||
| default_cfg.pop('base_url', None) | ||
|
|
||
| # Diffusion (models.diffusion -> DIFFUSION_* for video_creator agent) | ||
| self.console.print("\n[bold]Diffusion configuration[/bold] [dim](optional, for video_creator agent)[/dim]") | ||
| # Diffusion (models.diffusion -> DIFFUSION_* for diffusion agent) | ||
| self.console.print("\n[bold]Diffusion configuration[/bold] [dim](optional, for diffusion agent)[/dim]") | ||
| self.console.print(" [dim]Leave empty to use Media LLM or default LLM config above[/dim]\n") | ||
| if 'diffusion' not in current_config['models']: | ||
| # Migrate from legacy models.video_creator | ||
| current_config['models']['diffusion'] = current_config['models'].get('video_creator') or {} | ||
| current_config['models'].pop('video_creator', None) | ||
| # Migrate from legacy models.diffusion | ||
| current_config['models']['diffusion'] = current_config['models'].get('diffusion') or {} | ||
| current_config['models'].pop('diffusion', None) | ||
| diff_cfg = current_config['models']['diffusion'] | ||
|
|
||
| current_diff_api_key = diff_cfg.get('api_key', '') | ||
|
|
@@ -230,6 +230,58 @@ async def _edit_models_config(self, config, current_config: dict): | |
| if not diff_cfg: | ||
| current_config['models'].pop('diffusion', None) | ||
|
|
||
| # Audio (models.audio -> AUDIO_* for audio agent) | ||
| self.console.print("\n[bold]Audio configuration[/bold] [dim](optional, for audio agent)[/dim]") | ||
| self.console.print(" [dim]Leave empty to use Media LLM or default LLM config above[/dim]\n") | ||
| if 'audio' not in current_config['models']: | ||
| current_config['models']['audio'] = {} | ||
| audio_cfg = current_config['models']['audio'] | ||
|
|
||
| current_audio_api_key = audio_cfg.get('api_key', '') | ||
| if current_audio_api_key: | ||
| masked = current_audio_api_key[:8] + "..." if len(current_audio_api_key) > 8 else "***" | ||
| self.console.print(f" [dim]Current AUDIO_API_KEY: {masked}[/dim]") | ||
| audio_api_key = Prompt.ask(" AUDIO_API_KEY", default=current_audio_api_key, password=True) | ||
| if audio_api_key: | ||
| audio_cfg['api_key'] = audio_api_key | ||
| else: | ||
| audio_cfg.pop('api_key', None) | ||
|
|
||
| current_audio_model = audio_cfg.get('model', '') | ||
| self.console.print(" [dim]e.g. claude-3-5-sonnet-20241022 · Enter to inherit from Media/default[/dim]") | ||
| audio_model = Prompt.ask(" AUDIO_MODEL_NAME", default=current_audio_model) | ||
| if audio_model: | ||
| audio_cfg['model'] = audio_model | ||
| else: | ||
| audio_cfg.pop('model', None) | ||
|
|
||
| current_audio_base_url = audio_cfg.get('base_url', '') | ||
| audio_base_url = Prompt.ask(" AUDIO_BASE_URL", default=current_audio_base_url) | ||
| if audio_base_url: | ||
| audio_cfg['base_url'] = audio_base_url | ||
| else: | ||
| audio_cfg.pop('base_url', None) | ||
|
|
||
| current_audio_provider = audio_cfg.get('provider', 'openai') | ||
| audio_provider = Prompt.ask(" AUDIO_PROVIDER", default=current_audio_provider) | ||
| if audio_provider: | ||
| audio_cfg['provider'] = audio_provider | ||
| else: | ||
| audio_cfg.pop('provider', None) | ||
|
|
||
| current_audio_temp = audio_cfg.get('temperature', 0.1) | ||
| audio_temp = Prompt.ask(" AUDIO_TEMPERATURE", default=str(current_audio_temp)) | ||
| if audio_temp: | ||
| try: | ||
| audio_cfg['temperature'] = float(audio_temp) | ||
| except ValueError: | ||
| audio_cfg.pop('temperature', None) | ||
| else: | ||
| audio_cfg.pop('temperature', None) | ||
|
|
||
| if not audio_cfg: | ||
| current_config['models'].pop('audio', None) | ||
|
Comment on lines
+233
to
+283
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new "Audio configuration" section is almost identical to the "Diffusion configuration" section (lines 179-231). This significant code duplication makes the code harder to maintain. Consider refactoring the logic for prompting and setting configuration values into a reusable helper function. This would apply to both the A similar refactoring could be applied to the table display logic on lines 300-324, which also contains duplicated code for displaying the |
||
|
|
||
| config.save_config(current_config) | ||
| self.console.print(f"\n[green]✅ Configuration saved to {config.get_config_path()}[/green]") | ||
| table = Table(title="Default LLM Configuration", box=box.ROUNDED) | ||
|
|
@@ -258,6 +310,19 @@ async def _edit_models_config(self, config, current_config: dict): | |
| self.console.print() | ||
| self.console.print(diff_table) | ||
|
|
||
| if current_config['models'].get('audio'): | ||
| audio_table = Table(title="Audio Configuration (AUDIO_*)", box=box.ROUNDED) | ||
| audio_table.add_column("Setting", style="cyan") | ||
| audio_table.add_column("Value", style="green") | ||
| for key, value in current_config['models']['audio'].items(): | ||
| if key == 'api_key': | ||
| masked_value = value[:8] + "..." if len(str(value)) > 8 else "***" | ||
| audio_table.add_row(key, masked_value) | ||
| else: | ||
| audio_table.add_row(key, str(value)) | ||
| self.console.print() | ||
| self.console.print(audio_table) | ||
|
|
||
| async def _edit_skills_config(self, config, current_config: dict): | ||
| """Edit skills section of config (global SKILLS_PATH and per-agent XXX_SKILLS_PATH).""" | ||
| default_skills_path = str(Path.home() / ".aworld" / "skills") | ||
|
|
@@ -905,6 +970,7 @@ async def run_chat_session(self, agent_name: str, executor: Callable[[str], Any] | |
| f"Type '/agents' to list all available agents.\n" | ||
| f"Type '/cost' for current session, '/cost -all' for global history.\n" | ||
| f"Type '/compact' to run context compression.\n" | ||
| f"Type '/team' for agent team management.\n" | ||
| f"Type '/memory' to edit project context, '/memory view' to view, '/memory status' for status.\n" | ||
| f"Use @filename to include images or text files (e.g., @photo.jpg or @document.txt)." | ||
| ) | ||
|
|
@@ -921,6 +987,7 @@ async def run_chat_session(self, agent_name: str, executor: Callable[[str], Any] | |
| slash_cmds = [ | ||
| "/agents", "/skills", "/new", "/restore", "/latest", | ||
| "/exit", "/quit", "/switch", "/cost", "/cost -all", "/compact", | ||
| "/team", | ||
| "/memory", "/memory view", "/memory reload", "/memory status", | ||
| ] | ||
| switch_with_agents = [f"/switch {n}" for n in agent_names] if agent_names else [] | ||
|
|
@@ -941,6 +1008,7 @@ async def run_chat_session(self, agent_name: str, executor: Callable[[str], Any] | |
| "/memory view": "View current memory content", | ||
| "/memory reload": "Reload memory from file", | ||
| "/memory status": "Show memory system status", | ||
| "/team": "Agent team management commands", | ||
| "exit": "Exit chat", | ||
| "quit": "Exit chat", | ||
| } | ||
|
|
@@ -1178,12 +1246,12 @@ async def run_chat_session(self, agent_name: str, executor: Callable[[str], Any] | |
| try: | ||
| parts = user_input.split(maxsplit=1) | ||
| subcommand = parts[1] if len(parts) > 1 else "" | ||
|
|
||
| # Import required modules | ||
| import os | ||
| from pathlib import Path | ||
| import subprocess | ||
|
|
||
| # Find AWORLD.md file | ||
| def find_aworld_file(): | ||
| """Find AWORLD.md in standard locations""" | ||
|
|
@@ -1197,11 +1265,11 @@ def find_aworld_file(): | |
| if path.exists(): | ||
| return path | ||
| return None | ||
|
|
||
| def get_editor(): | ||
| """Get editor from environment variables""" | ||
| return os.environ.get('VISUAL') or os.environ.get('EDITOR') or 'nano' | ||
|
|
||
| if subcommand == "view": | ||
| # View current memory content | ||
| aworld_file = find_aworld_file() | ||
|
|
@@ -1216,20 +1284,20 @@ def get_editor(): | |
| from rich.syntax import Syntax | ||
| syntax = Syntax(content, "markdown", theme="monokai", line_numbers=False) | ||
| self.console.print(Panel(syntax, title="AWORLD.md", border_style="cyan")) | ||
|
|
||
| elif subcommand == "reload": | ||
| # Reload memory from file | ||
| self.console.print("[dim]Memory reload functionality requires agent restart.[/dim]") | ||
| self.console.print("[dim]The AWORLD.md file will be automatically loaded on next agent start.[/dim]") | ||
|
|
||
| elif subcommand == "status": | ||
| # Show memory system status | ||
| aworld_file = find_aworld_file() | ||
| from rich.table import Table | ||
| table = Table(title="Memory System Status", box=box.ROUNDED) | ||
| table.add_column("Property", style="cyan") | ||
| table.add_column("Value", style="green") | ||
|
|
||
| if aworld_file: | ||
| table.add_row("AWORLD.md Location", str(aworld_file)) | ||
| table.add_row("File Size", f"{aworld_file.stat().st_size} bytes") | ||
|
|
@@ -1240,25 +1308,25 @@ def get_editor(): | |
| else: | ||
| table.add_row("AWORLD.md Location", "Not found") | ||
| table.add_row("Status", "❌ Not configured") | ||
|
|
||
| table.add_row("Feature", "AWORLDFileNeuron") | ||
| table.add_row("Auto-load", "Enabled") | ||
| self.console.print(table) | ||
|
|
||
| else: | ||
| # Edit AWORLD.md (default action) | ||
| aworld_file = find_aworld_file() | ||
|
|
||
| if not aworld_file: | ||
| # Create new file in user directory (DEFAULT) | ||
| default_location = Path.home() / '.aworld' / 'AWORLD.md' | ||
| self.console.print(f"[yellow]No AWORLD.md found. Creating new file at:[/yellow]") | ||
| self.console.print(f"[cyan]{default_location}[/cyan]") | ||
| self.console.print(f"[dim](Default: ~/.aworld/AWORLD.md)[/dim]\n") | ||
|
|
||
| # Create directory if needed | ||
| default_location.parent.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # Create template | ||
| template = """# Project Context | ||
|
|
||
|
|
@@ -1283,11 +1351,11 @@ def get_editor(): | |
| """ | ||
| default_location.write_text(template, encoding='utf-8') | ||
| aworld_file = default_location | ||
|
|
||
| # Open in editor | ||
| editor = get_editor() | ||
| self.console.print(f"[dim]Opening {aworld_file} in {editor}...[/dim]") | ||
|
|
||
| try: | ||
| # Open editor and wait for it to close | ||
| result = subprocess.run([editor, str(aworld_file)]) | ||
|
|
@@ -1301,13 +1369,18 @@ def get_editor(): | |
| self.console.print("[dim]Set EDITOR or VISUAL environment variable to your preferred editor.[/dim]") | ||
| except Exception as e: | ||
| self.console.print(f"[red]Error opening editor: {e}[/red]") | ||
|
|
||
| except Exception as e: | ||
| self.console.print(f"[red]Error handling memory command: {e}[/red]") | ||
| import traceback | ||
| traceback.print_exc() | ||
| continue | ||
|
|
||
| # Handle team command | ||
| if user_input.lower().startswith("/team"): | ||
| # await self.team_handler.handle_command(user_input) | ||
| continue | ||
|
|
||
| # Handle agents command | ||
| if user_input.lower() in ("/agents", "agents"): | ||
| try: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -281,9 +281,9 @@ def _apply_filesystem_config(filesystem_cfg: Optional[Dict[str, Any]] = None) -> | |
|
|
||
| def _apply_diffusion_models_config(models_config: Dict[str, Any]) -> None: | ||
| """ | ||
| Apply models.diffusion config to DIFFUSION_* env vars for video_creator agent. | ||
| Apply models.diffusion config to DIFFUSION_* env vars for diffusion agent. | ||
| Priority: models.diffusion config > existing DIFFUSION_* env vars > LLM_*. | ||
| Supports models.video_creator for backwards compatibility. | ||
| Supports models.diffusion for backwards compatibility. | ||
| """ | ||
| diff_cfg = models_config.get('diffusion') | ||
| diff_cfg = diff_cfg if isinstance(diff_cfg, dict) else {} | ||
|
|
@@ -345,6 +345,69 @@ def _apply_diffusion_models_config(models_config: Dict[str, Any]) -> None: | |
| os.environ['DIFFUSION_TEMPERATURE'] = str(float(temperature)) | ||
|
|
||
|
|
||
| def _apply_audio_models_config(models_config: Dict[str, Any]) -> None: | ||
| """ | ||
| Apply models.audio config to AUDIO_* env vars for audio agent. | ||
| Priority: models.audio config > existing AUDIO_* env vars > LLM_*. | ||
| """ | ||
| audio_cfg = models_config.get('audio') | ||
| audio_cfg = audio_cfg if isinstance(audio_cfg, dict) else {} | ||
| api_key = (audio_cfg.get('api_key') or '').strip() | ||
| model_name = (audio_cfg.get('model') or '').strip() | ||
| base_url = (audio_cfg.get('base_url') or '').strip() | ||
| provider = (audio_cfg.get('provider') or '').strip() | ||
| temperature = audio_cfg.get('temperature') | ||
|
|
||
| if not api_key: | ||
| api_key = (os.environ.get('AUDIO_API_KEY') or '').strip() | ||
| if not api_key: | ||
| api_key = (os.environ.get('LLM_API_KEY') or '').strip() | ||
| if not api_key: | ||
| for key in ('OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'GEMINI_API_KEY'): | ||
| v = (os.environ.get(key) or '').strip() | ||
| if v: | ||
| api_key = v | ||
| if not provider and 'OPENAI' in key: | ||
| provider = 'openai' | ||
| elif not provider and 'ANTHROPIC' in key: | ||
| provider = 'anthropic' | ||
| elif not provider and 'GEMINI' in key: | ||
| provider = 'gemini' | ||
| break | ||
| if not model_name: | ||
| model_name = (os.environ.get('AUDIO_MODEL_NAME') or '').strip() | ||
| if not model_name: | ||
| model_name = (os.environ.get('LLM_MODEL_NAME') or '').strip() | ||
| if not base_url: | ||
| base_url = (os.environ.get('AUDIO_BASE_URL') or '').strip() | ||
| if not base_url: | ||
| base_url = (os.environ.get('LLM_BASE_URL') or '').strip() | ||
| if not base_url: | ||
| for key in ('OPENAI_BASE_URL', 'ANTHROPIC_BASE_URL', 'GEMINI_BASE_URL'): | ||
| v = (os.environ.get(key) or '').strip() | ||
| if v: | ||
| base_url = v | ||
| break | ||
| if not provider: | ||
| provider = (os.environ.get('AUDIO_PROVIDER') or '').strip() | ||
| if not provider: | ||
| provider = 'openai' | ||
| if temperature is None: | ||
| env_temp = (os.environ.get('AUDIO_TEMPERATURE') or '').strip() | ||
| if env_temp: | ||
| temperature = float(env_temp) | ||
|
|
||
| if api_key: | ||
| os.environ['AUDIO_API_KEY'] = api_key | ||
| if model_name: | ||
| os.environ['AUDIO_MODEL_NAME'] = model_name | ||
| if base_url: | ||
| os.environ['AUDIO_BASE_URL'] = base_url | ||
| os.environ['AUDIO_PROVIDER'] = provider | ||
| if temperature is not None: | ||
| os.environ['AUDIO_TEMPERATURE'] = str(float(temperature)) | ||
|
Comment on lines
+348
to
+408
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The new function `_apply_audio_models_config` is almost identical to `_apply_diffusion_models_config`. To improve maintainability, I recommend creating a single, parameterized helper function that can handle applying model configurations for different types (like `diffusion` and `audio`). |
||
|
|
||
|
|
||
| def _apply_models_config_to_env(models_config: Dict[str, Any]) -> None: | ||
| """ | ||
| Apply models config (api_key, model, base_url) to os.environ. | ||
|
|
@@ -381,6 +444,7 @@ def _apply_models_config_to_env(models_config: Dict[str, Any]) -> None: | |
| if base_url: | ||
| os.environ['LLM_BASE_URL'] = base_url | ||
| _apply_diffusion_models_config(models_config) | ||
| _apply_audio_models_config(models_config) | ||
| return | ||
| # Legacy: nested models.default.{provider} or models.{provider} | ||
| default_providers = {k: v for k, v in default_cfg.items() | ||
|
|
@@ -422,6 +486,7 @@ def _apply_models_config_to_env(models_config: Dict[str, Any]) -> None: | |
| os.environ['LLM_BASE_URL'] = base_url | ||
|
|
||
| _apply_diffusion_models_config(models_config) | ||
| _apply_audio_models_config(models_config) | ||
|
|
||
|
|
||
| def _load_from_local_env(source_path: str) -> tuple[Dict[str, Any], str, str]: | ||
|
|
@@ -439,6 +504,7 @@ def _load_from_local_env(source_path: str) -> tuple[Dict[str, Any], str, str]: | |
| }) | ||
| # Apply DIFFUSION_* from LLM_* when not set in .env | ||
| _apply_diffusion_models_config({}) | ||
| _apply_audio_models_config({}) | ||
| logger.info(f"[config] load_dotenv loaded from: {source_path} {os.environ.get('LLM_MODEL_NAME')} {os.environ.get('LLM_BASE_URL')}") | ||
| return _env_to_config(), "local", source_path | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -82,6 +82,7 @@ def check_session_token_limit( | |||||
|
|
||||||
| history = JSONLHistory(str(history_path)) | ||||||
| stats = history.get_token_stats(session_id=session_id) | ||||||
| logger.info(f"check_session_token_limit|agent_name={agent_name}|session_id={session_id}|limit={limit}|stats={stats}") | ||||||
|
|
||||||
| # Use current agent's context_window_tokens (ctx) when agent_name provided | ||||||
| if agent_name: | ||||||
|
|
@@ -90,7 +91,7 @@ def check_session_token_limit( | |||||
| total = ( | ||||||
| agent_stats.get("context_window_tokens", 0) | ||||||
| if agent_stats | ||||||
| else stats.get("total_tokens", 0) | ||||||
| else 0 | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic for calculating `total` has changed. Setting the fallback to `0` when no per-agent stats exist discards the session-level token count. I recommend reverting to the previous logic to ensure a more robust fallback.
Suggested change
|
||||||
| ) | ||||||
| else: | ||||||
| total = stats.get("total_tokens", 0) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There appears to be a logic error in the migration block for the diffusion configuration. The code currently checks if
`'diffusion'` is not in `current_config['models']`, and if so, it attempts to get `'diffusion'` from the same dictionary (which will be `None`), and then immediately removes it. This has no effect. If the goal is to simply ensure the
`diffusion` dictionary exists, this block should be simplified to match the pattern used for the new `audio` configuration.