Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ public Mono<ToolResultBlock> callAsync(ToolCallParam p) {
() -> {
try {
Thread.sleep(4_000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try {
Files.writeString(
tmpFile, "slow_tool round done\n");
} catch (Exception ignored) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public class OpenAIMultiModalTool {

private static final Logger log = LoggerFactory.getLogger(OpenAIMultiModalTool.class);

private static final String DEFAULT_VISION_MODEL = "gpt-4o";
private static final String DEFAULT_IMAGE_GEN_MODEL = "dall-e-3";
private static final String DEFAULT_TTS_MODEL = "tts-1";
private static final String DEFAULT_STT_MODEL = "whisper-1";

/** OpenAI API key. */
private final String apiKey;

Expand All @@ -68,7 +73,16 @@ public class OpenAIMultiModalTool {
private final String baseUrl;

/** Default vision model used when the caller does not specify one. */
private final String defaultModelName;
private final String defaultVisionModel;

/** Default image generation model used when the caller does not specify one. */
private final String defaultImageGenModel;

/** Default TTS model used when the caller does not specify one. */
private final String defaultTtsModel;

/** Default speech-to-text model used when the caller does not specify one. */
private final String defaultSttModel;

/**
* Create a new OpenAIMultiModalTool with default base URL.
Expand All @@ -94,19 +108,38 @@ public OpenAIMultiModalTool(String apiKey, String baseUrl) {
*
* @param apiKey the OpenAI API key
* @param baseUrl the base URL (null for default https://api.openai.com)
* @param defaultModelName the default vision model name used when the caller omits the model
* @param defaultVisionModel the default vision model name used when the caller omits the model
* parameter (e.g., "gpt-4o" for OpenAI, or your backend's vision model name)
*/
public OpenAIMultiModalTool(String apiKey, String baseUrl, String defaultModelName) {
public OpenAIMultiModalTool(String apiKey, String baseUrl, String defaultVisionModel) {
this(
apiKey,
baseUrl,
defaultVisionModel,
DEFAULT_IMAGE_GEN_MODEL,
DEFAULT_TTS_MODEL,
DEFAULT_STT_MODEL);
}

private OpenAIMultiModalTool(
String apiKey,
String baseUrl,
String defaultVisionModel,
String defaultImageGenModel,
String defaultTtsModel,
String defaultSttModel) {
if (apiKey == null || apiKey.trim().isEmpty()) {
throw new IllegalArgumentException("OpenAI API key cannot be empty.");
}
if (defaultModelName == null || defaultModelName.trim().isEmpty()) {
throw new IllegalArgumentException("defaultModelName cannot be empty.");
if (defaultVisionModel == null || defaultVisionModel.trim().isEmpty()) {
throw new IllegalArgumentException("defaultVisionModel cannot be empty.");
}
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.defaultModelName = defaultModelName;
this.defaultVisionModel = defaultVisionModel;
this.defaultImageGenModel = defaultImageGenModel;
this.defaultTtsModel = defaultTtsModel;
this.defaultSttModel = defaultSttModel;
this.client = new OpenAIClient();
}

Expand All @@ -123,18 +156,138 @@ protected OpenAIMultiModalTool(OpenAIClient client) {
* Create a new OpenAIMultiModalTool with custom client and default model (for testing).
*
* @param client the OpenAI client
* @param defaultModelName the default vision model name
* @param defaultVisionModel the default vision model name
*/
protected OpenAIMultiModalTool(OpenAIClient client, String defaultModelName) {
if (defaultModelName == null || defaultModelName.trim().isEmpty()) {
throw new IllegalArgumentException("defaultModelName cannot be empty.");
protected OpenAIMultiModalTool(OpenAIClient client, String defaultVisionModel) {
if (defaultVisionModel == null || defaultVisionModel.trim().isEmpty()) {
throw new IllegalArgumentException("defaultVisionModel cannot be empty.");
}
this.apiKey = "test-key";
this.baseUrl = null;
this.defaultModelName = defaultModelName;
this.defaultVisionModel = defaultVisionModel;
this.defaultImageGenModel = DEFAULT_IMAGE_GEN_MODEL;
this.defaultTtsModel = DEFAULT_TTS_MODEL;
this.defaultSttModel = DEFAULT_STT_MODEL;
this.client = client;
}

/**
* Creates a new builder for OpenAIMultiModalTool.
*
* @return a new Builder instance
*/
public static Builder builder() {
return new Builder();
}

/**
* Builder for OpenAIMultiModalTool.
*
* <p>Each capability has an independent default model. When not set via builder,
* the existing hardcoded default is used — no change to current behavior.
*/
public static class Builder {
private String apiKey;
private String baseUrl;
private String defaultVisionModel;
private String defaultImageGenModel;
private String defaultTtsModel;
private String defaultSttModel;

/**
* Sets the API key for OpenAI authentication.
*
* @param apiKey the API key
* @return this builder instance
*/
public Builder apiKey(String apiKey) {
this.apiKey = apiKey;
return this;
}

/**
* Sets a custom base URL for OpenAI API.
*
* @param baseUrl the base URL (null for default)
* @return this builder instance
*/
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

/**
* Sets the default vision model for image-to-text.
*
* <p>When not set, falls back to {@code "gpt-4o"}.
*
* @param defaultVisionModel the vision model name
* @return this builder instance
*/
public Builder defaultVisionModel(String defaultVisionModel) {
this.defaultVisionModel = defaultVisionModel;
return this;
}

/**
* Sets the default model for text-to-image generation.
*
* <p>When not set, falls back to {@code "dall-e-3"}.
*
* @param defaultImageGenModel the image generation model name
* @return this builder instance
*/
public Builder defaultImageGenModel(String defaultImageGenModel) {
this.defaultImageGenModel = defaultImageGenModel;
return this;
}

/**
* Sets the default TTS model for text-to-audio.
*
* <p>When not set, falls back to {@code "tts-1"}.
*
* @param defaultTtsModel the TTS model name
* @return this builder instance
*/
public Builder defaultTtsModel(String defaultTtsModel) {
this.defaultTtsModel = defaultTtsModel;
return this;
}

/**
* Sets the default speech-to-text model for audio-to-text.
*
* <p>When not set, falls back to {@code "whisper-1"}.
*
* @param defaultSttModel the STT model name
* @return this builder instance
*/
public Builder defaultSttModel(String defaultSttModel) {
this.defaultSttModel = defaultSttModel;
return this;
}

/**
* Builds the OpenAIMultiModalTool instance.
*
* @return configured OpenAIMultiModalTool instance
* @throws IllegalArgumentException if apiKey is not set
*/
public OpenAIMultiModalTool build() {
if (apiKey == null || apiKey.trim().isEmpty()) {
throw new IllegalArgumentException("apiKey must be set");
}
return new OpenAIMultiModalTool(
apiKey,
baseUrl,
defaultVisionModel != null ? defaultVisionModel : DEFAULT_VISION_MODEL,
defaultImageGenModel != null ? defaultImageGenModel : DEFAULT_IMAGE_GEN_MODEL,
defaultTtsModel != null ? defaultTtsModel : DEFAULT_TTS_MODEL,
defaultSttModel != null ? defaultSttModel : DEFAULT_STT_MODEL);
}
}

/**
* Generate image(s) based on the given prompt.
*
Expand Down Expand Up @@ -186,7 +339,9 @@ public Mono<ToolResultBlock> openaiTextToImage(
String responseFormat) {

String finalModel =
Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("dall-e-3");
Optional.ofNullable(model)
.filter(s -> !s.trim().isEmpty())
.orElse(this.defaultImageGenModel);
Integer finalN = Optional.ofNullable(n).orElse(1);
String finalSize =
Optional.ofNullable(size).filter(s -> !s.trim().isEmpty()).orElse("1024x1024");
Expand Down Expand Up @@ -318,7 +473,7 @@ public Mono<ToolResultBlock> openaiImageToText(
String finalModel =
Optional.ofNullable(model)
.filter(s -> !s.trim().isEmpty())
.orElse(this.defaultModelName);
.orElse(this.defaultVisionModel);
String finalPrompt =
Optional.ofNullable(prompt)
.filter(s -> !s.trim().isEmpty())
Expand Down Expand Up @@ -447,7 +602,9 @@ public Mono<ToolResultBlock> openaiTextToAudio(
Double speed) {

String finalModel =
Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("tts-1");
Optional.ofNullable(model)
.filter(s -> !s.trim().isEmpty())
.orElse(this.defaultTtsModel);
String finalVoice =
Optional.ofNullable(voice).filter(s -> !s.trim().isEmpty()).orElse("alloy");
String finalResponseFormat =
Expand Down Expand Up @@ -541,7 +698,9 @@ public Mono<ToolResultBlock> openaiAudioToText(
Double temperature) {

String finalModel =
Optional.ofNullable(model).filter(s -> !s.trim().isEmpty()).orElse("whisper-1");
Optional.ofNullable(model)
.filter(s -> !s.trim().isEmpty())
.orElse(this.defaultSttModel);
String finalResponseFormat =
Optional.ofNullable(responseFormat).filter(s -> !s.trim().isEmpty()).orElse("text");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.agentscope.core.message.URLSource;
import io.agentscope.extensions.model.openai.OpenAIClient;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -81,8 +82,7 @@ void testImageToText_usesCustomDefaultModel() {
argThat(
req -> {
@SuppressWarnings("unchecked")
java.util.Map<String, Object> map =
(java.util.Map<String, Object>) req;
Map<String, Object> map = (Map<String, Object>) req;
return "my-custom-vision-model".equals(map.get("model"));
})))
.thenReturn(jsonResponse);
Expand All @@ -104,4 +104,34 @@ void testConstructor_rejectsBlankDefaultModelName() {
assertThrows(
IllegalArgumentException.class, () -> new OpenAIMultiModalTool("key", null, null));
}

@Test
void testBuilder_withAllCustomModels() {
OpenAIMultiModalTool tool =
OpenAIMultiModalTool.builder()
.apiKey("sk-test")
.baseUrl("https://custom.api.com")
.defaultVisionModel("gpt-4o-mini")
.defaultImageGenModel("dall-e-2")
.defaultTtsModel("tts-1-hd")
.defaultSttModel("whisper-1")
.build();
assertNotNull(tool);
}

@Test
void testBuilder_usesDefaultsWhenNotSet() {
OpenAIMultiModalTool tool = OpenAIMultiModalTool.builder().apiKey("sk-test").build();
assertNotNull(tool);
}

@Test
void testBuilder_rejectsBlankApiKey() {
assertThrows(
IllegalArgumentException.class,
() -> OpenAIMultiModalTool.builder().apiKey("").build());
assertThrows(
IllegalArgumentException.class,
() -> OpenAIMultiModalTool.builder().apiKey(null).build());
}
}
Loading