From 89ae5d84c22c6d873c248d9f32bf8f5ccaae3fef Mon Sep 17 00:00:00 2001
From: Ido Berkovich
Date: Sun, 29 Dec 2024 11:20:53 +0200
Subject: [PATCH] OPIK-610 code style

---
 .../opik/domain/ChatCompletionService.java    |  1 -
 .../opik/domain/llmproviders/Anthropic.java   | 23 +++++-----
 .../llmproviders/LlmProviderFactory.java      |  5 +--
 .../opik/domain/llmproviders/OpenAi.java      |  1 -
 .../llmproviders/LlmProviderFactoryTest.java  | 45 +++++++------------
 5 files changed, 30 insertions(+), 45 deletions(-)

diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
index 4ee68ea76d..3de9a09c63 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
@@ -44,7 +44,6 @@ public ChatCompletionService(
     }
 
     public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-        log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
         var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
         llmProviderClient.validateRequest(request);
 
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/Anthropic.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/Anthropic.java
index 81e936fd34..7801410f3f 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/Anthropic.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/Anthropic.java
@@ -49,12 +49,12 @@ public Anthropic(LlmProviderClientConfig llmProviderClientConfig, String apiKey)
 
     @Override
     public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-        var response = anthropicClient.createMessage(mapToAnthropicCreateMessageRequest(request));
+        var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));
 
         return ChatCompletionResponse.builder()
                 .id(response.id)
                 .model(response.model)
-                .choices(response.content.stream().map(content -> mapContentToChoice(response, content))
+                .choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
                         .toList())
                 .usage(Usage.builder()
                         .promptTokens(response.usage.inputTokens)
@@ -71,7 +71,7 @@ public void generateStream(
             @NonNull Consumer<ChatCompletionResponse> handleMessage,
             @NonNull Runnable handleClose,
             @NonNull Consumer<Throwable> handleError) {
         validateRequest(request);
-        anthropicClient.createMessage(mapToAnthropicCreateMessageRequest(request),
+        anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
                 new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
     }
@@ -100,14 +100,15 @@ public int getHttpErrorStatusCode(Throwable throwable) {
         return 500;
     }
 
-    private AnthropicCreateMessageRequest mapToAnthropicCreateMessageRequest(ChatCompletionRequest request) {
+    private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
         var builder = AnthropicCreateMessageRequest.builder();
         Optional.ofNullable(request.toolChoice())
-                .ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(request.toolChoice().toString())));
+                .ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
+                        request.toolChoice().toString())));
         return builder
                 .stream(request.stream())
                 .model(request.model())
-                .messages(request.messages().stream().map(this::mapMessage).toList())
+                .messages(request.messages().stream().map(this::toMessage).toList())
                 .temperature(request.temperature())
                 .topP(request.topP())
                 .stopSequences(request.stop())
@@ -115,7 +116,7 @@ private AnthropicCreateMessageRequest mapToAnthropicCreateMessageRequest(ChatCom
                 .build();
     }
 
-    private AnthropicMessage mapMessage(Message message) {
+    private AnthropicMessage toMessage(Message message) {
         if (message.role() == Role.ASSISTANT) {
             return AnthropicMessage.builder()
                     .role(AnthropicRole.ASSISTANT)
@@ -124,7 +125,7 @@ private AnthropicMessage mapMessage(Message message) {
         } else if (message.role() == Role.USER) {
             return AnthropicMessage.builder()
                     .role(AnthropicRole.USER)
-                    .content(List.of(mapMessageContent(((UserMessage) message).content())))
+                    .content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
                     .build();
         }
 
@@ -132,7 +133,7 @@ private AnthropicMessage mapMessage(Message message) {
         throw new BadRequestException("not supported message role: " + message.role());
     }
 
-    private AnthropicMessageContent mapMessageContent(Object rawContent) {
+    private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
         if (rawContent instanceof String content) {
             return new AnthropicTextContent(content);
         }
@@ -140,7 +141,8 @@ private AnthropicMessageContent mapMessageContent(Object rawContent) {
         throw new BadRequestException("only text content is supported");
     }
 
-    private ChatCompletionChoice mapContentToChoice(AnthropicCreateMessageResponse response, AnthropicContent content) {
+    private ChatCompletionChoice toChatCompletionChoice(
+            AnthropicCreateMessageResponse response, AnthropicContent content) {
         return ChatCompletionChoice.builder()
                 .message(AssistantMessage.builder()
                         .name(content.name)
@@ -172,6 +174,7 @@ private AnthropicClient newClient(String apiKey) {
         Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
                 .map(LlmProviderClientConfig.AnthropicClientConfig::logResponses)
                 .ifPresent(anthropicClientBuilder::logResponses);
+        // anthropic client builder only receives one timeout variant
        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
                 .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
         return anthropicClientBuilder
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
index 130a355197..3d9cd4ab89 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java
@@ -42,9 +42,6 @@ public LlmProviderService getService(@NonNull String workspaceId, @NonNull Strin
 
     /**
      * The agreed requirement is to resolve the LLM provider and its API key based on the model.
-     * Currently, only OPEN AI is supported, so model param is ignored.
-     * No further validation is needed on the model, as it's just forwarded in the OPEN AI request and will be rejected
-     * if not valid.
      */
     private LlmProvider getLlmProvider(String model) {
         if (isModelBelongToProvider(model, ChatCompletionModel.class, ChatCompletionModel::toString)) {
@@ -58,7 +55,7 @@ private LlmProvider getLlmProvider(String model) {
     }
 
     /**
-     * Finding API keys isn't paginated at the moment, since only OPEN AI is supported.
+     * Finding API keys isn't paginated at the moment.
      * Even in the future, the number of supported LLM providers per workspace is going to be very low.
      */
     private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
index 7a2b77a7a5..0104afe099 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
@@ -63,7 +63,6 @@ public int getHttpErrorStatusCode(Throwable throwable) {
     }
 
     /**
-     * Initially, only OPEN AI is supported, so no need for a more sophisticated client resolution to start with.
      * At the moment, openai4j client and also langchain4j wrappers, don't support dynamic API keys. That can imply
      * an important performance penalty for next phases. The following options should be evaluated:
      * - Cache clients, but can be unsafe.
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
index bff65c67bd..db205ed7c4 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java
@@ -15,16 +15,20 @@
 import io.dropwizard.jackson.Jackson;
 import io.dropwizard.jersey.validation.Validators;
 import jakarta.validation.Validator;
+import org.apache.commons.lang3.EnumUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 import java.io.IOException;
 import java.util.List;
 import java.util.UUID;
+import java.util.stream.Stream;
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.when;
@@ -52,15 +56,15 @@ public void tearDown() {
     }
 
     @ParameterizedTest
-    @EnumSource(value = ChatCompletionModel.class)
-    void testGetServiceOpenai(ChatCompletionModel model) {
+    @MethodSource
+    void testGetService(String model, LlmProvider llmProvider, Class<? extends LlmProviderService> providerClass) {
         // setup
         String workspaceId = UUID.randomUUID().toString();
         String apiKey = UUID.randomUUID().toString();
 
         when(llmProviderApiKeyService.find(workspaceId)).thenReturn(ProviderApiKey.ProviderApiKeyPage.builder()
                 .content(List.of(ProviderApiKey.builder()
-                        .provider(LlmProvider.OPEN_AI)
+                        .provider(llmProvider)
                         .apiKey(EncryptionUtils.encrypt(apiKey))
                         .build()))
                 .total(1)
@@ -71,35 +75,18 @@ void testGetServiceOpenai(ChatCompletionModel model) {
         // SUT
         var llmProviderFactory = new LlmProviderFactory(llmProviderClientConfig, llmProviderApiKeyService);
 
-        LlmProviderService actual = llmProviderFactory.getService(workspaceId, model.toString());
+        LlmProviderService actual = llmProviderFactory.getService(workspaceId, model);
 
         // assertions
-        assertThat(actual).isInstanceOf(OpenAi.class);
+        assertThat(actual).isInstanceOf(providerClass);
     }
 
-    @ParameterizedTest
-    @EnumSource(value = AnthropicChatModelName.class)
-    void testGetServiceAnthropic(AnthropicChatModelName model) {
-        // setup
-        String workspaceId = UUID.randomUUID().toString();
-        String apiKey = UUID.randomUUID().toString();
-
-        when(llmProviderApiKeyService.find(workspaceId)).thenReturn(ProviderApiKey.ProviderApiKeyPage.builder()
-                .content(List.of(ProviderApiKey.builder()
-                        .provider(LlmProvider.ANTHROPIC)
-                        .apiKey(EncryptionUtils.encrypt(apiKey))
-                        .build()))
-                .total(1)
-                .page(1)
-                .size(1)
-                .build());
+    private static Stream<Arguments> testGetService() {
+        var openAiModels = EnumUtils.getEnumList(ChatCompletionModel.class).stream()
+                .map(model -> arguments(model.toString(), LlmProvider.OPEN_AI, OpenAi.class));
+        var anthropicModels = EnumUtils.getEnumList(AnthropicChatModelName.class).stream()
+                .map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, Anthropic.class));
 
-        // SUT
-        var llmProviderFactory = new LlmProviderFactory(llmProviderClientConfig, llmProviderApiKeyService);
-
-        LlmProviderService actual = llmProviderFactory.getService(workspaceId, model.toString());
-
-        // assertions
-        assertThat(actual).isInstanceOf(Anthropic.class);
+        return Stream.concat(openAiModels, anthropicModels);
     }
 }
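
--
Note on the test consolidation above (not part of the patch): with no value supplied, JUnit 5 resolves @MethodSource to a static factory method that has the same name as the test method and returns a Stream<Arguments>, which is why testGetService(...) and private static Stream<Arguments> testGetService() pair up without any explicit wiring. Below is a minimal, self-contained sketch of the same pattern. MethodSourceConventionTest and its Fruit/Vegetable enums are hypothetical stand-ins for this PR's ChatCompletionModel and AnthropicChatModelName; it assumes junit-jupiter-params and commons-lang3 on the classpath.

    import java.util.stream.Stream;

    import org.apache.commons.lang3.EnumUtils;
    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.Arguments;
    import org.junit.jupiter.params.provider.MethodSource;

    import static org.junit.jupiter.api.Assertions.assertTrue;
    import static org.junit.jupiter.params.provider.Arguments.arguments;

    class MethodSourceConventionTest {

        enum Fruit { APPLE, BANANA }
        enum Vegetable { CARROT, LEEK }

        @ParameterizedTest
        @MethodSource // bare annotation: resolves to the same-named static factory below
        void classify(String name, String expectedKind) {
            // each enum constant arrives as one invocation, paired with its expected label
            assertTrue(name != null && !name.isEmpty());
            assertTrue(expectedKind.equals("fruit") || expectedKind.equals("vegetable"));
        }

        // one Stream<Arguments> per enum, concatenated, mirroring openAiModels/anthropicModels
        static Stream<Arguments> classify() {
            var fruits = EnumUtils.getEnumList(Fruit.class).stream()
                    .map(value -> arguments(value.toString(), "fruit"));
            var vegetables = EnumUtils.getEnumList(Vegetable.class).stream()
                    .map(value -> arguments(value.toString(), "vegetable"));
            return Stream.concat(fruits, vegetables);
        }
    }

Folding both provider cases into one @MethodSource test keeps the assertions in a single body, so supporting a third provider becomes one extra mapped stream in the factory rather than a copy of the whole test method.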