Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder;
import org.springframework.boot.http.client.HttpClientSettings;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsPropertyMapper;
import org.springframework.boot.http.client.reactive.ClientHttpConnectorBuilder;
import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -65,15 +72,30 @@ public class AnthropicChatAutoConfiguration {
@ConditionalOnMissingBean
public AnthropicApi anthropicApi(AnthropicConnectionProperties connectionProperties,
ObjectProvider<RestClient.Builder> restClientBuilderProvider,
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ResponseErrorHandler responseErrorHandler) {
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ResponseErrorHandler responseErrorHandler,
ObjectProvider<SslBundles> sslBundles, ObjectProvider<HttpClientSettings> globalHttpClientSettings,
ObjectProvider<ClientHttpRequestFactoryBuilder<?>> factoryBuilder,
ObjectProvider<ClientHttpConnectorBuilder<?>> webConnectorBuilderProvider) {

HttpClientSettingsPropertyMapper mapper = new HttpClientSettingsPropertyMapper(sslBundles.getIfAvailable(),
globalHttpClientSettings.getIfAvailable());
HttpClientSettings httpClientSettings = mapper.map(connectionProperties);

RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder);
applyRestClientSettings(restClientBuilder, httpClientSettings,
factoryBuilder.getIfAvailable(ClientHttpRequestFactoryBuilder::detect));

WebClient.Builder webClientBuilder = webClientBuilderProvider.getIfAvailable(WebClient::builder);
applyWebClientSettings(webClientBuilder, httpClientSettings,
webConnectorBuilderProvider.getIfAvailable(ClientHttpConnectorBuilder::detect));

return AnthropicApi.builder()
.baseUrl(connectionProperties.getBaseUrl())
.completionsPath(connectionProperties.getCompletionsPath())
.apiKey(connectionProperties.getApiKey())
.anthropicVersion(connectionProperties.getVersion())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
.restClientBuilder(restClientBuilder)
.webClientBuilder(webClientBuilder)
.responseErrorHandler(responseErrorHandler)
.anthropicBetaFeatures(connectionProperties.getBetaVersion())
.build();
Expand Down Expand Up @@ -102,4 +124,16 @@ public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, Anthropi
return chatModel;
}

private void applyRestClientSettings(RestClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpRequestFactoryBuilder<?> factoryBuilder) {
ClientHttpRequestFactory requestFactory = factoryBuilder.build(httpClientSettings);
builder.requestFactory(requestFactory);
}

private void applyWebClientSettings(WebClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpConnectorBuilder<?> connectorBuilder) {
ClientHttpConnector connector = connectorBuilder.build(httpClientSettings);
builder.clientConnector(connector);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsProperties;

/**
* Anthropic API connection properties.
Expand All @@ -26,7 +27,7 @@
* @since 1.0.0
*/
@ConfigurationProperties(AnthropicConnectionProperties.CONFIG_PREFIX)
public class AnthropicConnectionProperties {
public class AnthropicConnectionProperties extends HttpClientSettingsProperties {

public static final String CONFIG_PREFIX = "spring.ai.anthropic";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package org.springframework.ai.model.anthropic.autoconfigure;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -91,4 +93,51 @@ void stream() {
});
}

@Test
void generateWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY"),
"spring.ai.deepseek.connect-timeout=1ms", "spring.ai.deepseek.read-timeout=1ms")
.withConfiguration(SpringAiTestAutoConfigurations.of(AnthropicChatAutoConfiguration.class))
.run(context -> {
AnthropicChatModel client = context.getBean(AnthropicChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(AnthropicConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofMillis(1));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMillis(1));

// Verify that the client can actually make requests with the configured
// timeout
String response = client.call("Hello");
assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

@Test
void generateStreamingWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.deepseek.apiKey=" + "sk-2567813d742c40e79fa6f1f2ee2f830c",
"spring.ai.deepseek.connect-timeout=1s", "spring.ai.deepseek.read-timeout=1s")
.withConfiguration(SpringAiTestAutoConfigurations.of(AnthropicChatAutoConfiguration.class))
.run(context -> {
AnthropicChatModel client = context.getBean(AnthropicChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(AnthropicConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofSeconds(1));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofSeconds(1));

Flux<ChatResponse> responseFlux = client.stream(new Prompt(new UserMessage("Hello")));
String response = Objects.requireNonNull(responseFlux.collectList().block())
.stream()
.map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText())
.collect(Collectors.joining());

assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder;
import org.springframework.boot.http.client.HttpClientSettings;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsPropertyMapper;
import org.springframework.boot.http.client.reactive.ClientHttpConnectorBuilder;
import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
Expand Down Expand Up @@ -67,10 +74,24 @@ public DeepSeekChatModel deepSeekChatModel(DeepSeekConnectionProperties commonPr
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention,
ObjectProvider<ToolExecutionEligibilityPredicate> deepseekToolExecutionEligibilityPredicate) {
ObjectProvider<ToolExecutionEligibilityPredicate> deepseekToolExecutionEligibilityPredicate,
ObjectProvider<SslBundles> sslBundles, ObjectProvider<HttpClientSettings> globalHttpClientSettings,
ObjectProvider<ClientHttpRequestFactoryBuilder<?>> factoryBuilder,
ObjectProvider<ClientHttpConnectorBuilder<?>> webConnectorBuilderProvider) {

var deepSeekApi = deepSeekApi(chatProperties, commonProperties,
restClientBuilderProvider.getIfAvailable(RestClient::builder),
HttpClientSettingsPropertyMapper mapper = new HttpClientSettingsPropertyMapper(sslBundles.getIfAvailable(),
globalHttpClientSettings.getIfAvailable());
HttpClientSettings httpClientSettings = mapper.map(commonProperties);

RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder);
applyRestClientSettings(restClientBuilder, httpClientSettings,
factoryBuilder.getIfAvailable(ClientHttpRequestFactoryBuilder::detect));

WebClient.Builder webClientBuilder = webClientBuilderProvider.getIfAvailable(WebClient::builder);
applyWebClientSettings(webClientBuilder, httpClientSettings,
webConnectorBuilderProvider.getIfAvailable(ClientHttpConnectorBuilder::detect));

var deepSeekApi = deepSeekApi(chatProperties, commonProperties, restClientBuilder,
webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler);

var chatModel = DeepSeekChatModel.builder()
Expand Down Expand Up @@ -111,4 +132,16 @@ private DeepSeekApi deepSeekApi(DeepSeekChatProperties chatProperties,
.build();
}

private void applyRestClientSettings(RestClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpRequestFactoryBuilder<?> factoryBuilder) {
ClientHttpRequestFactory requestFactory = factoryBuilder.build(httpClientSettings);
builder.requestFactory(requestFactory);
}

private void applyWebClientSettings(WebClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpConnectorBuilder<?> connectorBuilder) {
ClientHttpConnector connector = connectorBuilder.build(httpClientSettings);
builder.clientConnector(connector);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,36 @@
package org.springframework.ai.model.deepseek.autoconfigure;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsProperties;

/**
* Parent properties for DeepSeek.
*
* @author Geng Rong
*/
@ConfigurationProperties(DeepSeekConnectionProperties.CONFIG_PREFIX)
public class DeepSeekConnectionProperties extends DeepSeekParentProperties {
public class DeepSeekConnectionProperties extends HttpClientSettingsProperties {

public static final String CONFIG_PREFIX = "spring.ai.deepseek";

public static final String DEFAULT_BASE_URL = "https://api.deepseek.com";
private String apiKey;

public DeepSeekConnectionProperties() {
super.setBaseUrl(DEFAULT_BASE_URL);
private String baseUrl = "https://api.deepseek.com";

public String getApiKey() {
return this.apiKey;
}

public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}

public String getBaseUrl() {
return this.baseUrl;
}

public void setBaseUrl(String baseUrl) {
this.baseUrl = baseUrl;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.deepseek.autoconfigure;

import java.time.Duration;
import java.util.Objects;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -73,4 +74,51 @@ void generateStreaming() {
});
}

@Test
void generateWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY"),
"spring.ai.deepseek.connect-timeout=5s", "spring.ai.deepseek.read-timeout=30s")
.withConfiguration(SpringAiTestAutoConfigurations.of(DeepSeekChatAutoConfiguration.class))
.run(context -> {
DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(DeepSeekConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofSeconds(5));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofSeconds(30));

// Verify that the client can actually make requests with the configured
// timeout
String response = client.call("Hello");
assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

@Test
void generateStreamingWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY"),
"spring.ai.deepseek.connect-timeout=1s", "spring.ai.deepseek.read-timeout=1s")
.withConfiguration(SpringAiTestAutoConfigurations.of(DeepSeekChatAutoConfiguration.class))
.run(context -> {
DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(DeepSeekConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofSeconds(1));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofSeconds(1));

Flux<ChatResponse> responseFlux = client.stream(new Prompt(new UserMessage("Hello")));
String response = Objects.requireNonNull(responseFlux.collectList().block())
.stream()
.map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText())
.collect(Collectors.joining());

assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package org.springframework.ai.model.deepseek.autoconfigure;

import java.time.Duration;

import org.junit.jupiter.api.Test;

import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.utils.SpringAiTestAutoConfigurations;
import org.springframework.boot.http.client.HttpRedirects;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -153,4 +156,37 @@ void chatActivation() {
});
}

@Test
public void httpClientCustomTimeouts() {
new ApplicationContextRunner().withPropertyValues(
// @formatter:off
"spring.ai.deepseek.api-key=API_KEY",
"spring.ai.deepseek.base-url=TEST_BASE_URL",
"spring.ai.deepseek.connect-timeout=5s",
"spring.ai.deepseek.read-timeout=30s")
// @formatter:on
.withConfiguration(SpringAiTestAutoConfigurations.of(DeepSeekChatAutoConfiguration.class))
.run(context -> {
var connectionProperties = context.getBean(DeepSeekConnectionProperties.class);

assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofSeconds(5));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofSeconds(30));
});
}

@Test
public void httpClientRedirects() {
new ApplicationContextRunner().withPropertyValues(
// @formatter:off
"spring.ai.deepseek.api-key=API_KEY",
"spring.ai.deepseek.base-url=TEST_BASE_URL",
"spring.ai.deepseek.redirects=DONT_FOLLOW")
// @formatter:on
.withConfiguration(SpringAiTestAutoConfigurations.of(DeepSeekChatAutoConfiguration.class))
.run(context -> {
var connectionProperties = context.getBean(DeepSeekConnectionProperties.class);
assertThat(connectionProperties.getRedirects()).isEqualTo(HttpRedirects.DONT_FOLLOW);
});
}

}