diff --git a/agentscope-core/src/test/java/io/agentscope/core/tool/subagent/SubAgentToolTimeoutRetryIntegrationTest.java b/agentscope-core/src/test/java/io/agentscope/core/tool/subagent/SubAgentToolTimeoutRetryIntegrationTest.java index e49c2c033..4b304776d 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/tool/subagent/SubAgentToolTimeoutRetryIntegrationTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/tool/subagent/SubAgentToolTimeoutRetryIntegrationTest.java @@ -103,6 +103,10 @@ public Mono callAsync(ToolCallParam p) { () -> { try { Thread.sleep(4_000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + try { Files.writeString( tmpFile, "slow_tool round done\n"); } catch (Exception ignored) { diff --git a/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/main/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalTool.java b/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/main/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalTool.java index 2097c96f9..88ba0b1d2 100644 --- a/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/main/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalTool.java +++ b/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/main/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalTool.java @@ -58,6 +58,11 @@ public class OpenAIMultiModalTool { private static final Logger log = LoggerFactory.getLogger(OpenAIMultiModalTool.class); + private static final String DEFAULT_VISION_MODEL = "gpt-4o"; + private static final String DEFAULT_IMAGE_GEN_MODEL = "dall-e-3"; + private static final String DEFAULT_TTS_MODEL = "tts-1"; + private static final String DEFAULT_STT_MODEL = "whisper-1"; + /** OpenAI API key. */ private final String apiKey; @@ -68,7 +73,16 @@ public class OpenAIMultiModalTool { private final String baseUrl; /** Default vision model used when the caller does not specify one. */ - private final String defaultModelName; + private final String defaultVisionModel; + + /** Default image generation model used when the caller does not specify one. */ + private final String defaultImageGenModel; + + /** Default TTS model used when the caller does not specify one. */ + private final String defaultTtsModel; + + /** Default speech-to-text model used when the caller does not specify one. */ + private final String defaultSttModel; /** * Create a new OpenAIMultiModalTool with default base URL. @@ -94,19 +108,38 @@ public OpenAIMultiModalTool(String apiKey, String baseUrl) { * * @param apiKey the OpenAI API key * @param baseUrl the base URL (null for default https://api.openai.com) - * @param defaultModelName the default vision model name used when the caller omits the model + * @param defaultVisionModel the default vision model name used when the caller omits the model * parameter (e.g., "gpt-4o" for OpenAI, or your backend's vision model name) */ - public OpenAIMultiModalTool(String apiKey, String baseUrl, String defaultModelName) { + public OpenAIMultiModalTool(String apiKey, String baseUrl, String defaultVisionModel) { + this( + apiKey, + baseUrl, + defaultVisionModel, + DEFAULT_IMAGE_GEN_MODEL, + DEFAULT_TTS_MODEL, + DEFAULT_STT_MODEL); + } + + private OpenAIMultiModalTool( + String apiKey, + String baseUrl, + String defaultVisionModel, + String defaultImageGenModel, + String defaultTtsModel, + String defaultSttModel) { if (apiKey == null || apiKey.trim().isEmpty()) { throw new IllegalArgumentException("OpenAI API key cannot be empty."); } - if (defaultModelName == null || defaultModelName.trim().isEmpty()) { - throw new IllegalArgumentException("defaultModelName cannot be empty."); + if (defaultVisionModel == null || defaultVisionModel.trim().isEmpty()) { + throw new IllegalArgumentException("defaultVisionModel cannot be empty."); } this.apiKey = apiKey; this.baseUrl = baseUrl; - this.defaultModelName = defaultModelName; + this.defaultVisionModel = defaultVisionModel; + this.defaultImageGenModel = defaultImageGenModel; + this.defaultTtsModel = defaultTtsModel; + this.defaultSttModel = defaultSttModel; this.client = new OpenAIClient(); } @@ -123,18 +156,138 @@ protected OpenAIMultiModalTool(OpenAIClient client) { * Create a new OpenAIMultiModalTool with custom client and default model (for testing). * * @param client the OpenAI client - * @param defaultModelName the default vision model name + * @param defaultVisionModel the default vision model name */ - protected OpenAIMultiModalTool(OpenAIClient client, String defaultModelName) { - if (defaultModelName == null || defaultModelName.trim().isEmpty()) { - throw new IllegalArgumentException("defaultModelName cannot be empty."); + protected OpenAIMultiModalTool(OpenAIClient client, String defaultVisionModel) { + if (defaultVisionModel == null || defaultVisionModel.trim().isEmpty()) { + throw new IllegalArgumentException("defaultVisionModel cannot be empty."); } this.apiKey = "test-key"; this.baseUrl = null; - this.defaultModelName = defaultModelName; + this.defaultVisionModel = defaultVisionModel; + this.defaultImageGenModel = DEFAULT_IMAGE_GEN_MODEL; + this.defaultTtsModel = DEFAULT_TTS_MODEL; + this.defaultSttModel = DEFAULT_STT_MODEL; this.client = client; } + /** + * Creates a new builder for OpenAIMultiModalTool. + * + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for OpenAIMultiModalTool. + * + *

Each capability has an independent default model. When not set via builder, + * the existing hardcoded default is used — no change to current behavior. + */ + public static class Builder { + private String apiKey; + private String baseUrl; + private String defaultVisionModel; + private String defaultImageGenModel; + private String defaultTtsModel; + private String defaultSttModel; + + /** + * Sets the API key for OpenAI authentication. + * + * @param apiKey the API key + * @return this builder instance + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets a custom base URL for OpenAI API. + * + * @param baseUrl the base URL (null for default) + * @return this builder instance + */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the default vision model for image-to-text. + * + *

When not set, falls back to {@code "gpt-4o"}. + * + * @param defaultVisionModel the vision model name + * @return this builder instance + */ + public Builder defaultVisionModel(String defaultVisionModel) { + this.defaultVisionModel = defaultVisionModel; + return this; + } + + /** + * Sets the default model for text-to-image generation. + * + *

When not set, falls back to {@code "dall-e-3"}. + * + * @param defaultImageGenModel the image generation model name + * @return this builder instance + */ + public Builder defaultImageGenModel(String defaultImageGenModel) { + this.defaultImageGenModel = defaultImageGenModel; + return this; + } + + /** + * Sets the default TTS model for text-to-audio. + * + *

When not set, falls back to {@code "tts-1"}. + * + * @param defaultTtsModel the TTS model name + * @return this builder instance + */ + public Builder defaultTtsModel(String defaultTtsModel) { + this.defaultTtsModel = defaultTtsModel; + return this; + } + + /** + * Sets the default speech-to-text model for audio-to-text. + * + *

When not set, falls back to {@code "whisper-1"}. + * + * @param defaultSttModel the STT model name + * @return this builder instance + */ + public Builder defaultSttModel(String defaultSttModel) { + this.defaultSttModel = defaultSttModel; + return this; + } + + /** + * Builds the OpenAIMultiModalTool instance. + * + * @return configured OpenAIMultiModalTool instance + * @throws IllegalArgumentException if apiKey is not set + */ + public OpenAIMultiModalTool build() { + if (apiKey == null || apiKey.trim().isEmpty()) { + throw new IllegalArgumentException("apiKey must be set"); + } + return new OpenAIMultiModalTool( + apiKey, + baseUrl, + defaultVisionModel != null ? defaultVisionModel : DEFAULT_VISION_MODEL, + defaultImageGenModel != null ? defaultImageGenModel : DEFAULT_IMAGE_GEN_MODEL, + defaultTtsModel != null ? defaultTtsModel : DEFAULT_TTS_MODEL, + defaultSttModel != null ? defaultSttModel : DEFAULT_STT_MODEL); + } + } + /** * Generate image(s) based on the given prompt. * @@ -186,7 +339,9 @@ public Mono openaiTextToImage( String responseFormat) { String finalModel = - Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("dall-e-3"); + Optional.ofNullable(model) + .filter(s -> !s.trim().isEmpty()) + .orElse(this.defaultImageGenModel); Integer finalN = Optional.ofNullable(n).orElse(1); String finalSize = Optional.ofNullable(size).filter(s -> !s.trim().isEmpty()).orElse("1024x1024"); @@ -318,7 +473,7 @@ public Mono openaiImageToText( String finalModel = Optional.ofNullable(model) .filter(s -> !s.trim().isEmpty()) - .orElse(this.defaultModelName); + .orElse(this.defaultVisionModel); String finalPrompt = Optional.ofNullable(prompt) .filter(s -> !s.trim().isEmpty()) @@ -447,7 +602,9 @@ public Mono openaiTextToAudio( Double speed) { String finalModel = - Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("tts-1"); + Optional.ofNullable(model) + .filter(s -> !s.trim().isEmpty()) + .orElse(this.defaultTtsModel); String finalVoice = Optional.ofNullable(voice).filter(s -> !s.trim().isEmpty()).orElse("alloy"); String finalResponseFormat = @@ -541,7 +698,9 @@ public Mono openaiAudioToText( Double temperature) { String finalModel = - Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("whisper-1"); + Optional.ofNullable(model) + .filter(s -> !s.trim().isEmpty()) + .orElse(this.defaultSttModel); String finalResponseFormat = Optional.ofNullable(responseFormat).filter(s -> !s.trim().isEmpty()).orElse("text"); diff --git a/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/test/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalToolTest.java b/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/test/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalToolTest.java index 602dd8902..779a2115c 100644 --- a/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/test/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalToolTest.java +++ b/agentscope-extensions/agentscope-extensions-model/agentscope-extensions-model-openai/src/test/java/io/agentscope/extensions/model/openai/tool/OpenAIMultiModalToolTest.java @@ -31,6 +31,7 @@ import io.agentscope.core.message.URLSource; import io.agentscope.extensions.model.openai.OpenAIClient; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -81,8 +82,7 @@ void testImageToText_usesCustomDefaultModel() { argThat( req -> { @SuppressWarnings("unchecked") - java.util.Map map = - (java.util.Map) req; + Map map = (Map) req; return "my-custom-vision-model".equals(map.get("model")); }))) .thenReturn(jsonResponse); @@ -104,4 +104,34 @@ void testConstructor_rejectsBlankDefaultModelName() { assertThrows( IllegalArgumentException.class, () -> new OpenAIMultiModalTool("key", null, null)); } + + @Test + void testBuilder_withAllCustomModels() { + OpenAIMultiModalTool tool = + OpenAIMultiModalTool.builder() + .apiKey("sk-test") + .baseUrl("https://custom.api.com") + .defaultVisionModel("gpt-4o-mini") + .defaultImageGenModel("dall-e-2") + .defaultTtsModel("tts-1-hd") + .defaultSttModel("whisper-1") + .build(); + assertNotNull(tool); + } + + @Test + void testBuilder_usesDefaultsWhenNotSet() { + OpenAIMultiModalTool tool = OpenAIMultiModalTool.builder().apiKey("sk-test").build(); + assertNotNull(tool); + } + + @Test + void testBuilder_rejectsBlankApiKey() { + assertThrows( + IllegalArgumentException.class, + () -> OpenAIMultiModalTool.builder().apiKey("").build()); + assertThrows( + IllegalArgumentException.class, + () -> OpenAIMultiModalTool.builder().apiKey(null).build()); + } }