Skip to content

Commit a39aa1a

Browse files
committed
fix(core): wire fallback model into ReActAgent
1 parent 7d3c6a1 commit a39aa1a

3 files changed

Lines changed: 104 additions & 2 deletions

File tree

agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import io.agentscope.core.middleware.MiddlewareChain;
8484
import io.agentscope.core.middleware.ModelCallInput;
8585
import io.agentscope.core.middleware.ReasoningInput;
86+
import io.agentscope.core.model.ChatResponse;
8687
import io.agentscope.core.model.ChatUsage;
8788
import io.agentscope.core.model.ExecutionConfig;
8889
import io.agentscope.core.model.GenerateOptions;
@@ -2083,7 +2084,7 @@ Flux<AgentEvent> reasoningStream(
20832084
rc,
20842085
MiddlewareBase::onModelCall,
20852086
modelCallCore)
2086-
.apply(new ModelCallInput(messages, tools, options, model))
2087+
.apply(new ModelCallInput(messages, tools, options, modelForCall()))
20872088
.doOnNext(this::publishEvent);
20882089
}
20892090

@@ -2994,7 +2995,7 @@ Flux<AgentEvent> summaryStream(
29942995
rc,
29952996
MiddlewareBase::onModelCall,
29962997
summaryModelCallCore)
2997-
.apply(new ModelCallInput(messages, null, options, model))
2998+
.apply(new ModelCallInput(messages, null, options, modelForCall()))
29982999
.doOnNext(this::publishEvent);
29993000
}
30003001

@@ -3246,6 +3247,18 @@ protected GenerateOptions buildGenerateOptions() {
32463247
// Start with user-configured generateOptions if available
32473248
GenerateOptions baseOptions = generateOptions;
32483249

3250+
// Layer the agent-level retry budget underneath explicit per-call settings.
3251+
if (modelConfig != null) {
3252+
GenerateOptions retryBudgetOptions =
3253+
GenerateOptions.builder()
3254+
.executionConfig(
3255+
ExecutionConfig.builder()
3256+
.maxAttempts(modelConfig.maxRetries())
3257+
.build())
3258+
.build();
3259+
baseOptions = GenerateOptions.mergeOptions(baseOptions, retryBudgetOptions);
3260+
}
3261+
32493262
// If modelExecutionConfig is set, merge it into the options
32503263
if (modelExecutionConfig != null) {
32513264
GenerateOptions execConfigOptions =
@@ -3256,6 +3269,41 @@ protected GenerateOptions buildGenerateOptions() {
32563269
return baseOptions != null ? baseOptions : GenerateOptions.builder().build();
32573270
}
32583271

3272+
private Model modelForCall() {
3273+
Model fallbackModel = modelConfig.fallbackModel();
3274+
if (fallbackModel == null) {
3275+
return model;
3276+
}
3277+
3278+
AtomicReference<Model> activeModel = new AtomicReference<>(model);
3279+
return new Model() {
3280+
@Override
3281+
public Flux<ChatResponse> stream(
3282+
List<Msg> messages, List<ToolSchema> tools, GenerateOptions options) {
3283+
Flux<ChatResponse> primaryFlux = model.stream(messages, tools, options);
3284+
return primaryFlux.switchOnFirst(
3285+
(signal, flux) -> {
3286+
if (signal.isOnError()) {
3287+
Throwable error = signal.getThrowable();
3288+
activeModel.set(fallbackModel);
3289+
log.warn(
3290+
"Primary model {} failed, switching to fallback {}",
3291+
model.getModelName(),
3292+
fallbackModel.getModelName(),
3293+
error);
3294+
return fallbackModel.stream(messages, tools, options);
3295+
}
3296+
return flux;
3297+
});
3298+
}
3299+
3300+
@Override
3301+
public String getModelName() {
3302+
return activeModel.get().getModelName();
3303+
}
3304+
};
3305+
}
3306+
32593307
@Override
32603308
protected Mono<Msg> handleInterrupt(InterruptContext context, Msg... originalArgs) {
32613309
return Mono.deferContextual(
@@ -4167,6 +4215,11 @@ public static Builder fromAgent(ReActAgent agent) {
41674215
b.model = agent.getModel();
41684216
b.maxIters = agent.getMaxIters();
41694217
b.generateOptions = agent.getGenerateOptions();
4218+
ModelConfig srcModelConfig = agent.getModelConfig();
4219+
if (srcModelConfig != null) {
4220+
b.flatMaxRetries = srcModelConfig.maxRetries();
4221+
b.flatFallbackModel = srcModelConfig.fallbackModel();
4222+
}
41704223
b.toolkit = agent.getToolkit().copy();
41714224
return b;
41724225
}

agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentNewLoopBuilderTest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,29 @@ void builderAppliesDefaultsWhenConfigsOmitted() {
110110
assertTrue(agent.getMiddlewares().get(0) instanceof GracefulShutdownMiddleware);
111111
}
112112

113+
@Test
114+
void fromAgentCopiesModelResilienceConfig() {
115+
ChatModelBase model = newFakeModel();
116+
ChatModelBase fallback = newFakeModel();
117+
118+
ReActAgent source =
119+
ReActAgent.builder()
120+
.name("source")
121+
.sysPrompt("sys")
122+
.model(model)
123+
.fallbackModel(fallback)
124+
.maxRetries(7)
125+
.toolkit(new Toolkit())
126+
.build();
127+
128+
ReActAgent copy = ReActAgent.Builder.fromAgent(source).build();
129+
130+
assertNotNull(copy.getModelConfig());
131+
assertEquals(7, copy.getModelConfig().maxRetries());
132+
assertSame(fallback, copy.getModelConfig().fallbackModel());
133+
assertSame(model, copy.getModel());
134+
}
135+
113136
@Test
114137
void observeAddsMessagesToState() {
115138
ReActAgent agent =

agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,32 @@ void testErrorHandling() {
552552
}
553553
}
554554

555+
@Test
556+
@DisplayName("Should switch to fallback model when primary fails")
557+
void testFallbackModel() {
558+
String errorMessage = "Primary model unavailable";
559+
MockModel primaryModel = new MockModel("").withError(errorMessage);
560+
MockModel fallbackModel = new MockModel("Fallback response");
561+
562+
agent =
563+
ReActAgent.builder()
564+
.name(TestConstants.TEST_REACT_AGENT_NAME)
565+
.sysPrompt(TestConstants.DEFAULT_SYS_PROMPT)
566+
.model(primaryModel)
567+
.fallbackModel(fallbackModel)
568+
.toolkit(mockToolkit)
569+
.build();
570+
571+
Msg userMsg = TestUtils.createUserMessage("User", TestConstants.TEST_USER_INPUT);
572+
Msg response =
573+
agent.call(userMsg).block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));
574+
575+
assertNotNull(response, "Response should not be null");
576+
assertEquals("Fallback response", TestUtils.extractTextContent(response));
577+
assertEquals(1, primaryModel.getCallCount(), "Primary model should be tried once");
578+
assertEquals(1, fallbackModel.getCallCount(), "Fallback model should be called once");
579+
}
580+
555581
@Test
556582
@DisplayName("Should support streaming responses")
557583
void testStreaming() {

0 commit comments

Comments
 (0)