diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java index c09f300c87..e5319711c2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java @@ -1,7 +1,6 @@ package com.comet.opik; import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper; -import com.comet.opik.domain.llmproviders.LlmProviderClientModule; import com.comet.opik.infrastructure.ConfigurationModule; import com.comet.opik.infrastructure.EncryptionUtils; import com.comet.opik.infrastructure.OpikConfiguration; @@ -16,6 +15,10 @@ import com.comet.opik.infrastructure.events.EventModule; import com.comet.opik.infrastructure.http.HttpModule; import com.comet.opik.infrastructure.job.JobGuiceyInstaller; +import com.comet.opik.infrastructure.llm.LlmModule; +import com.comet.opik.infrastructure.llm.antropic.AnthropicModule; +import com.comet.opik.infrastructure.llm.gemini.GeminiModule; +import com.comet.opik.infrastructure.llm.openai.OpenAIModule; import com.comet.opik.infrastructure.ratelimit.RateLimitModule; import com.comet.opik.infrastructure.redis.RedisModule; import com.comet.opik.utils.JsonBigDecimalDeserializer; @@ -74,7 +77,8 @@ public void initialize(Bootstrap bootstrap) { .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin())) .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(), - new ConfigurationModule(), new BiModule(), new CacheModule(), new LlmProviderClientModule()) + new ConfigurationModule(), new BiModule(), new CacheModule(), new AnthropicModule(), + new GeminiModule(), new OpenAIModule(), new LlmModule()) .installers(JobGuiceyInstaller.class) .listen(new OpikGuiceyLifecycleEventListener()) .enableAutoConfig() diff --git 
a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ChunkedResponseHandler.java similarity index 98% rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java rename to apps/opik-backend/src/main/java/com/comet/opik/api/ChunkedResponseHandler.java index d5f89a59a7..254d74fe46 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/ChunkedResponseHandler.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ChunkedResponseHandler.java @@ -1,4 +1,4 @@ -package com.comet.opik.domain.llmproviders; +package com.comet.opik.api; import dev.ai4j.openai4j.chat.ChatCompletionChoice; import dev.ai4j.openai4j.chat.ChatCompletionResponse; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/LogCriteria.java b/apps/opik-backend/src/main/java/com/comet/opik/api/LogCriteria.java new file mode 100644 index 0000000000..a8fe507f1e --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/LogCriteria.java @@ -0,0 +1,19 @@ +package com.comet.opik.api; + +import lombok.Builder; +import lombok.NonNull; + +import java.util.Map; +import java.util.UUID; + +import static com.comet.opik.api.LogItem.LogLevel; + +@Builder +public record LogCriteria( + @NonNull String workspaceId, + UUID entityId, + LogLevel level, + Integer size, + Map markers) { + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/LogItem.java b/apps/opik-backend/src/main/java/com/comet/opik/api/LogItem.java new file mode 100644 index 0000000000..ec9ed59a88 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/LogItem.java @@ -0,0 +1,43 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import 
com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +@Builder +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record LogItem( + @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant timestamp, + @JsonIgnore String workspaceId, + @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID ruleId, + @Schema(accessMode = Schema.AccessMode.READ_ONLY) LogLevel level, + @Schema(accessMode = Schema.AccessMode.READ_ONLY) String message, + @Schema(accessMode = Schema.AccessMode.READ_ONLY) Map markers) { + + public enum LogLevel { + INFO, + WARN, + ERROR, + DEBUG, + TRACE + } + + @Builder + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + public record LogPage(List content, int page, int size, long total) implements Page { + + public static LogPage empty(int page) { + return new LogPage(List.of(), page, 0, 0); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringLlmAsJudgeScorer.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringLlmAsJudgeScorer.java index 47dd27fbff..df999de488 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringLlmAsJudgeScorer.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringLlmAsJudgeScorer.java @@ -2,15 +2,16 @@ import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.events.TraceToScoreLlmAsJudge; -import com.comet.opik.domain.ChatCompletionService; import com.comet.opik.domain.FeedbackScoreService; import com.comet.opik.domain.UserLog; +import com.comet.opik.domain.llm.ChatCompletionService; import com.comet.opik.infrastructure.OnlineScoringConfig; 
import com.comet.opik.infrastructure.OnlineScoringConfig.StreamConfiguration; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.log.UserFacingLoggingFactory; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.response.ChatResponse; +import io.dropwizard.lifecycle.Managed; import jakarta.inject.Inject; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; @@ -22,6 +23,7 @@ import org.redisson.client.codec.Codec; import org.slf4j.Logger; import org.slf4j.MDC; +import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -29,8 +31,10 @@ import ru.vyarus.dropwizard.guice.module.yaml.bind.Config; import java.math.BigDecimal; +import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; import static com.comet.opik.api.AutomationRuleEvaluatorType.LLM_AS_JUDGE; @@ -41,10 +45,12 @@ /** * This service listens a Redis stream for Traces to be scored in a LLM provider. It will prepare the LLM request * by rendering message templates using values from the Trace and prepare the schema for the return (structured output). + * + * The service has to implement the Managed interface to be able to start and stop the stream connected to the application lifecycle. 
*/ @EagerSingleton @Slf4j -public class OnlineScoringLlmAsJudgeScorer { +public class OnlineScoringLlmAsJudgeScorer implements Managed { private final OnlineScoringConfig config; private final ChatCompletionService aiProxyService; @@ -53,6 +59,10 @@ public class OnlineScoringLlmAsJudgeScorer { private final String consumerId; private final StreamReadGroupArgs redisReadConfig; private final Logger userFacingLogger; + private final RedissonReactiveClient redisson; + + private RStreamReactive stream; + private Disposable streamSubscription; // Store the subscription reference @Inject public OnlineScoringLlmAsJudgeScorer(@NonNull @Config("onlineScoring") OnlineScoringConfig config, @@ -62,29 +72,80 @@ public OnlineScoringLlmAsJudgeScorer(@NonNull @Config("onlineScoring") OnlineSco this.config = config; this.aiProxyService = aiProxyService; this.feedbackScoreService = feedbackScoreService; + this.redisson = redisson; this.redisReadConfig = StreamReadGroupArgs.neverDelivered().count(config.getConsumerBatchSize()); this.consumerId = "consumer-" + config.getConsumerGroupName() + "-" + UUID.randomUUID(); userFacingLogger = UserFacingLoggingFactory.getLogger(OnlineScoringLlmAsJudgeScorer.class); + } + + @Override + public void start() { + if (stream != null) { + log.warn("OnlineScoringLlmAsJudgeScorer already started. 
Ignoring start request."); + return; + } // as we are a LLM consumer, lets check only LLM stream - initStream(config, redisson); + stream = initStream(config, redisson); + log.info("OnlineScoringLlmAsJudgeScorer started."); } - private void initStream(OnlineScoringConfig config, RedissonReactiveClient redisson) { - config.getStreams().stream() + @Override + public void stop() { + log.info("Shutting down OnlineScoringLlmAsJudgeScorer and closing stream."); + if (stream != null) { + if (streamSubscription != null && !streamSubscription.isDisposed()) { + log.info("Waiting for last messages to be processed before shutdown..."); + + try { + // Read any remaining messages before stopping + stream.readGroup(config.getConsumerGroupName(), consumerId, redisReadConfig) + .flatMap(messages -> { + if (!messages.isEmpty()) { + log.info("Processing last {} messages before shutdown.", messages.size()); + + return Flux.fromIterable(messages.entrySet()) + .publishOn(Schedulers.boundedElastic()) + .doOnNext(entry -> processReceivedMessages(stream, entry)) + .collectList() + .then(Mono.fromRunnable(() -> streamSubscription.dispose())); + } + + return Mono.fromRunnable(() -> streamSubscription.dispose()); + }) + .block(Duration.ofSeconds(2)); + } catch (Exception e) { + log.error("Error processing last messages before shutdown: {}", e.getMessage(), e); + } + } else { + log.info("No active subscription, deleting Redis stream."); + } + + stream.delete().doOnTerminate(() -> log.info("Redis Stream deleted")).subscribe(); + } + } + + private RStreamReactive initStream(OnlineScoringConfig config, + RedissonReactiveClient redisson) { + Optional configuration = config.getStreams().stream() .filter(this::isLlmAsJudge) - .findFirst() - .ifPresentOrElse( - llmConfig -> setupListener(redisson, llmConfig), - this::logIfEmpty); + .findFirst(); + + if (configuration.isEmpty()) { + this.logIfEmpty(); + return null; + } + + return setupListener(redisson, configuration.get()); } private void 
logIfEmpty() { log.warn("No '{}' redis stream config found. Online Scoring consumer won't start.", LLM_AS_JUDGE.name()); } - private void setupListener(RedissonReactiveClient redisson, StreamConfiguration llmConfig) { + private RStreamReactive setupListener(RedissonReactiveClient redisson, + StreamConfiguration llmConfig) { var scoringCodecs = OnlineScoringCodecs.fromString(llmConfig.getCodec()); String streamName = llmConfig.getStreamName(); Codec codec = scoringCodecs.getCodec(); @@ -95,6 +156,8 @@ private void setupListener(RedissonReactiveClient redisson, StreamConfiguration enforceConsumerGroup(stream); setupStreamListener(stream); + + return stream; } private boolean isLlmAsJudge(StreamConfiguration streamConfiguration) { @@ -118,7 +181,7 @@ private void enforceConsumerGroup(RStreamReactive stream) { // Listen for messages - Flux.interval(config.getPoolingInterval().toJavaDuration()) + this.streamSubscription = Flux.interval(config.getPoolingInterval().toJavaDuration()) .flatMap(i -> stream.readGroup(config.getConsumerGroupName(), consumerId, redisReadConfig)) .flatMap(messages -> Flux.fromIterable(messages.entrySet())) .publishOn(Schedulers.boundedElastic()) @@ -150,7 +213,7 @@ private void processReceivedMessages(RStreamReactive findLogs(LogCriteria criteria); + +} + +@Slf4j +@Singleton +@RequiredArgsConstructor(onConstructor_ = @Inject) +class AutomationRuleEvaluatorLogsDAOImpl implements AutomationRuleEvaluatorLogsDAO { + + private static final String INSERT_STATEMENT = """ + INSERT INTO automation_rule_evaluator_logs (timestamp, level, workspace_id, rule_id, message, markers) + VALUES , 9), + :level, + :workspace_id, + :rule_id, + :message, + mapFromArrays(:marker_keys, :marker_values) + ) + , + }> + ; + """; + + public static final String FIND_ALL = """ + SELECT * FROM automation_rule_evaluator_logs + WHERE workspace_id = :workspaceId + AND level = :level + AND rule_id = :ruleId + ORDER BY timestamp DESC + LIMIT :limit OFFSET :offset + """; + + private 
final @NonNull ConnectionFactory connectionFactory; + + public Mono findLogs(@NonNull LogCriteria criteria) { + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> { + log.info("Finding logs with criteria: {}", criteria); + + var template = new ST(FIND_ALL); + + bindTemplateParameters(criteria, template); + + Statement statement = connection.createStatement(template.render()); + + bindParameters(criteria, statement); + + return statement.execute(); + }) + .flatMap(result -> result.map((row, rowMetadata) -> mapRow(row))) + .collectList() + .map(this::mapPage); + } + + private LogPage mapPage(List logs) { + return LogPage.builder() + .content(logs) + .page(1) + .total(logs.size()) + .size(logs.size()) + .build(); + } + + private LogItem mapRow(Row row) { + return LogItem.builder() + .timestamp(row.get("timestamp", Instant.class)) + .level(LogLevel.valueOf(row.get("level", String.class))) + .workspaceId(row.get("workspace_id", String.class)) + .ruleId(row.get("rule_id", UUID.class)) + .message(row.get("message", String.class)) + .markers(row.get("markers", Map.class)) + .build(); + } + + private void bindTemplateParameters(LogCriteria criteria, ST template) { + Optional.ofNullable(criteria.level()).ifPresent(level -> template.add("level", level)); + Optional.ofNullable(criteria.entityId()).ifPresent(ruleId -> template.add("ruleId", ruleId)); + Optional.ofNullable(criteria.size()).ifPresent(limit -> template.add("limit", limit)); + } + + private void bindParameters(LogCriteria criteria, Statement statement) { + statement.bind("workspaceId", criteria.workspaceId()); + Optional.ofNullable(criteria.level()).ifPresent(level -> statement.bind("level", level)); + Optional.ofNullable(criteria.entityId()).ifPresent(ruleId -> statement.bind("ruleId", ruleId)); + Optional.ofNullable(criteria.size()).ifPresent(limit -> statement.bind("limit", limit)); + } + + @Override + public Mono saveAll(@NonNull List events) { + return 
Mono.from(connectionFactory.create()) + .flatMapMany(connection -> { + var template = new ST(INSERT_STATEMENT); + List queryItems = getQueryItemPlaceHolder(events.size()); + + template.add("items", queryItems); + + Statement statement = connection.createStatement(template.render()); + + for (int i = 0; i < events.size(); i++) { + ILoggingEvent event = events.get(i); + + String logLevel = event.getLevel().toString(); + String workspaceId = Optional.ofNullable(event.getMDCPropertyMap().get("workspace_id")) + .orElseThrow(() -> failWithMessage("workspace_id is not set")); + String traceId = Optional.ofNullable(event.getMDCPropertyMap().get("trace_id")) + .orElseThrow(() -> failWithMessage("trace_id is not set")); + String ruleId = Optional.ofNullable(event.getMDCPropertyMap().get("rule_id")) + .orElseThrow(() -> failWithMessage("rule_id is not set")); + + statement + .bind("timestamp" + i, event.getInstant().toString()) + .bind("level" + i, logLevel) + .bind("workspace_id" + i, workspaceId) + .bind("rule_id" + i, ruleId) + .bind("message" + i, event.getFormattedMessage()) + .bind("marker_keys" + i, new String[]{"trace_id"}) + .bind("marker_values" + i, new String[]{traceId}); + } + + return statement.execute(); + }) + .collectList() + .then(); + } + + private IllegalStateException failWithMessage(String message) { + log.error(message); + return new IllegalStateException(message); + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/AutomationRuleEvaluatorService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/AutomationRuleEvaluatorService.java index 3335750068..b658e2084c 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/AutomationRuleEvaluatorService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/AutomationRuleEvaluatorService.java @@ -6,6 +6,7 @@ import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.api.AutomationRuleEvaluatorType; import 
com.comet.opik.api.AutomationRuleEvaluatorUpdate; +import com.comet.opik.api.LogCriteria; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.infrastructure.cache.CacheEvict; @@ -19,6 +20,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import reactor.core.publisher.Mono; import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; import java.sql.SQLIntegrityConstraintViolationException; @@ -28,6 +30,7 @@ import java.util.UUID; import static com.comet.opik.api.AutomationRuleEvaluator.AutomationRuleEvaluatorPage; +import static com.comet.opik.api.LogItem.LogPage; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; @@ -50,6 +53,8 @@ AutomationRuleEvaluatorPage find(@NonNull UUID projectId, @NonNull String worksp List findAll(@NonNull UUID projectId, @NonNull String workspaceId, AutomationRuleEvaluatorType automationRuleEvaluatorType); + + Mono getLogs(LogCriteria criteria); } @Singleton @@ -61,6 +66,7 @@ class AutomationRuleEvaluatorServiceImpl implements AutomationRuleEvaluatorServi private final @NonNull IdGenerator idGenerator; private final @NonNull TransactionTemplate template; + private final @NonNull AutomationRuleEvaluatorLogsDAO logsDAO; @Override @CacheEvict(name = "automation_rule_evaluators_find_by_type", key = "$projectId +'-'+ $workspaceId +'-'+ $inputRuleEvaluator.type") @@ -250,4 +256,9 @@ public List findAll(@NonNull UUID projectId, }); } + @Override + public Mono getLogs(@NonNull LogCriteria criteria) { + return logsDAO.findLogs(criteria); + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/ChatCompletionService.java similarity index 97% rename from 
apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java rename to apps/opik-backend/src/main/java/com/comet/opik/domain/llm/ChatCompletionService.java index 8bda68aacc..4bc1912a82 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/ChatCompletionService.java @@ -1,8 +1,6 @@ -package com.comet.opik.domain; +package com.comet.opik.domain.llm; import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; -import com.comet.opik.domain.llmproviders.LlmProviderFactory; -import com.comet.opik.domain.llmproviders.LlmProviderService; import com.comet.opik.infrastructure.LlmProviderClientConfig; import com.comet.opik.utils.ChunkedOutputHandlers; import dev.ai4j.openai4j.chat.ChatCompletionRequest; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderFactory.java new file mode 100644 index 0000000000..f70aaab932 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderFactory.java @@ -0,0 +1,19 @@ +package com.comet.opik.domain.llm; + +import com.comet.opik.api.LlmProvider; +import com.comet.opik.infrastructure.llm.LlmServiceProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters; + +public interface LlmProviderFactory { + + String ERROR_MODEL_NOT_SUPPORTED = "model not supported %s"; + + void register(LlmProvider llmProvider, LlmServiceProvider service); + + LlmProviderService getService(String workspaceId, String model); + + ChatLanguageModel getLanguageModel(String workspaceId, LlmAsJudgeModelParameters modelParameters); + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java 
b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderService.java similarity index 95% rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java rename to apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderService.java index 83d9140f1d..3ef78b8c5a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llm/LlmProviderService.java @@ -1,4 +1,4 @@ -package com.comet.opik.domain.llmproviders; +package com.comet.opik.domain.llm; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java deleted file mode 100644 index f6ea5a4cb8..0000000000 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientModule.java +++ /dev/null @@ -1,19 +0,0 @@ -package com.comet.opik.domain.llmproviders; - -import com.comet.opik.infrastructure.LlmProviderClientConfig; -import com.comet.opik.infrastructure.OpikConfiguration; -import com.google.inject.Provides; -import jakarta.inject.Singleton; -import lombok.NonNull; -import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule; -import ru.vyarus.dropwizard.guice.module.yaml.bind.Config; - -public class LlmProviderClientModule extends DropwizardAwareModule { - - @Provides - @Singleton - public LlmProviderClientGenerator clientGenerator( - @NonNull @Config("llmProviderClient") LlmProviderClientConfig config) { - return new LlmProviderClientGenerator(config); - } -} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmModule.java new file mode 100644 index 
0000000000..4b769d93f8 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmModule.java @@ -0,0 +1,20 @@ +package com.comet.opik.infrastructure.llm; + +import com.comet.opik.domain.LlmProviderApiKeyService; +import com.comet.opik.domain.llm.LlmProviderFactory; +import com.google.inject.AbstractModule; +import com.google.inject.Provides; +import jakarta.inject.Singleton; + +public class LlmModule extends AbstractModule { + + @Provides + @Singleton + public LlmProviderFactory llmProviderFactory(LlmProviderApiKeyService llmProviderApiKeyService) { + return createInstance(llmProviderApiKeyService); + } + + public LlmProviderFactory createInstance(LlmProviderApiKeyService llmProviderApiKeyService) { + return new LlmProviderFactoryImpl(llmProviderApiKeyService); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderClientGenerator.java new file mode 100644 index 0000000000..268b1c1a28 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderClientGenerator.java @@ -0,0 +1,12 @@ +package com.comet.opik.infrastructure.llm; + +import dev.langchain4j.model.chat.ChatLanguageModel; + +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters; + +public interface LlmProviderClientGenerator { + + T generate(String apiKey, Object... 
params); + + ChatLanguageModel generateChat(String apiKey, LlmAsJudgeModelParameters modelParameters); +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryImpl.java similarity index 64% rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryImpl.java index 3b1d6f5d2f..99bfb3f633 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryImpl.java @@ -1,49 +1,58 @@ -package com.comet.opik.domain.llmproviders; +package com.comet.opik.infrastructure.llm; -import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.api.LlmProvider; import com.comet.opik.domain.LlmProviderApiKeyService; +import com.comet.opik.domain.llm.LlmProviderFactory; +import com.comet.opik.domain.llm.LlmProviderService; import com.comet.opik.infrastructure.EncryptionUtils; +import com.comet.opik.infrastructure.llm.antropic.AnthropicModelName; +import com.comet.opik.infrastructure.llm.gemini.GeminiModelName; +import com.comet.opik.infrastructure.llm.openai.OpenaiModelName; import dev.langchain4j.model.chat.ChatLanguageModel; import jakarta.inject.Inject; -import jakarta.inject.Singleton; import jakarta.ws.rs.BadRequestException; import lombok.NonNull; import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.EnumUtils; +import java.util.EnumMap; +import java.util.Map; +import java.util.Optional; import java.util.function.Function; -@Singleton +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters; + @RequiredArgsConstructor(onConstructor_ = @Inject) -public class LlmProviderFactory { - public static final String 
ERROR_MODEL_NOT_SUPPORTED = "model not supported %s"; +class LlmProviderFactoryImpl implements LlmProviderFactory { private final @NonNull LlmProviderApiKeyService llmProviderApiKeyService; - private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator; + private final Map services = new EnumMap<>(LlmProvider.class); + + public void register(LlmProvider llmProvider, LlmServiceProvider service) { + services.put(llmProvider, service); + } public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) { var llmProvider = getLlmProvider(model); var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider)); - return switch (llmProvider) { - case LlmProvider.OPEN_AI -> new LlmProviderOpenAi(llmProviderClientGenerator.newOpenAiClient(apiKey)); - case LlmProvider.ANTHROPIC -> - new LlmProviderAnthropic(llmProviderClientGenerator.newAnthropicClient(apiKey)); - case LlmProvider.GEMINI -> new LlmProviderGemini(llmProviderClientGenerator, apiKey); - }; + return Optional.ofNullable(services.get(llmProvider)) + .map(provider -> provider.getService(apiKey)) + .orElseThrow(() -> new LlmProviderUnsupportedException( + "LLM provider not supported: %s".formatted(llmProvider))); } public ChatLanguageModel getLanguageModel(@NonNull String workspaceId, - @NonNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters modelParameters) { + @NonNull LlmAsJudgeModelParameters modelParameters) { var llmProvider = getLlmProvider(modelParameters.name()); var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider)); - return switch (llmProvider) { - case LlmProvider.OPEN_AI -> llmProviderClientGenerator.newOpenAiChatLanguageModel(apiKey, modelParameters); - default -> throw new BadRequestException(String.format(ERROR_MODEL_NOT_SUPPORTED, modelParameters.name())); - }; + return Optional.ofNullable(services.get(llmProvider)) + .map(provider -> provider.getLanguageModel(apiKey, modelParameters)) + 
.orElseThrow(() -> new BadRequestException( + String.format(ERROR_MODEL_NOT_SUPPORTED, modelParameters.name()))); } + /** * The agreed requirement is to resolve the LLM provider and its API key based on the model. */ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderUnsupportedException.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderUnsupportedException.java new file mode 100644 index 0000000000..f83d32c443 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmProviderUnsupportedException.java @@ -0,0 +1,15 @@ +package com.comet.opik.infrastructure.llm; + +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.ws.rs.ClientErrorException; +import jakarta.ws.rs.core.Response; + +public class LlmProviderUnsupportedException extends ClientErrorException { + + public LlmProviderUnsupportedException(String message) { + super( + Response.status(Response.Status.CONFLICT) + .entity(new ErrorMessage(Response.Status.CONFLICT.getStatusCode(), message)) + .build()); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmServiceProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmServiceProvider.java new file mode 100644 index 0000000000..085ec5ee8a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/LlmServiceProvider.java @@ -0,0 +1,13 @@ +package com.comet.opik.infrastructure.llm; + +import com.comet.opik.domain.llm.LlmProviderService; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters; + +public interface LlmServiceProvider { + + LlmProviderService getService(String apiKey); + + ChatLanguageModel getLanguageModel(String apiKey, LlmAsJudgeModelParameters modelParameters); +} diff --git 
a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicClientGenerator.java
new file mode 100644
index 0000000000..b599da7bf2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicClientGenerator.java
@@ -0,0 +1,84 @@
+package com.comet.opik.infrastructure.llm.antropic;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmProviderClientGenerator;
+import dev.langchain4j.model.anthropic.AnthropicChatModel;
+import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.Optional;
+
+import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters;
+
+/**
+ * Builds Anthropic clients and chat models from the shared {@link LlmProviderClientConfig}.
+ */
+@RequiredArgsConstructor
+public class AnthropicClientGenerator implements LlmProviderClientGenerator {
+
+    private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
+
+    /**
+     * Creates an Anthropic client for {@code apiKey}, applying the optional base URL, API version, logging
+     * flags and call timeout from the configuration.
+     */
+    private AnthropicClient newAnthropicClient(@NonNull String apiKey) {
+        var anthropicClientBuilder = AnthropicClient.builder();
+        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+                .map(LlmProviderClientConfig.AnthropicClientConfig::url)
+                .filter(StringUtils::isNotEmpty)
+                .ifPresent(anthropicClientBuilder::baseUrl);
+        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+                .map(LlmProviderClientConfig.AnthropicClientConfig::version)
+                .filter(StringUtils::isNotBlank)
+                .ifPresent(anthropicClientBuilder::version);
+        Optional.ofNullable(llmProviderClientConfig.getLogRequests())
+                .ifPresent(anthropicClientBuilder::logRequests);
+        Optional.ofNullable(llmProviderClientConfig.getLogResponses())
+                .ifPresent(anthropicClientBuilder::logResponses);
+        // anthropic client builder only receives one timeout variant
+        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
+                .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
+        return anthropicClientBuilder
+                .apiKey(apiKey)
+                .build();
+    }
+
+    /**
+     * Creates an Anthropic chat model for {@code modelParameters}, applying the optional connect timeout,
+     * base URL and temperature.
+     */
+    private ChatLanguageModel newChatLanguageModel(String apiKey, LlmAsJudgeModelParameters modelParameters) {
+        var builder = AnthropicChatModel.builder()
+                .apiKey(apiKey)
+                .modelName(modelParameters.name());
+
+        Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
+                .ifPresent(connectTimeout -> builder.timeout(connectTimeout.toJavaDuration()));
+
+        // Use the Anthropic endpoint override (same setting as newAnthropicClient), not the OpenAI one.
+        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
+                .map(LlmProviderClientConfig.AnthropicClientConfig::url)
+                .filter(StringUtils::isNotBlank)
+                .ifPresent(builder::baseUrl);
+
+        Optional.ofNullable(modelParameters.temperature()).ifPresent(builder::temperature);
+
+        return builder.build();
+    }
+
+    /** Returns a new Anthropic client for {@code apiKey}; {@code params} is not used. */
+    @Override
+    public AnthropicClient generate(@NonNull String apiKey, Object... params) {
+        return newAnthropicClient(apiKey);
+    }
+
+    /** Returns a new Anthropic chat model for {@code apiKey} and {@code modelParameters}. */
+    @Override
+    public ChatLanguageModel generateChat(@NonNull String apiKey, @NonNull LlmAsJudgeModelParameters modelParameters) {
+        return newChatLanguageModel(apiKey, modelParameters);
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicLlmServiceProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicLlmServiceProvider.java
new file mode 100644
index 0000000000..0e9b9d53c2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicLlmServiceProvider.java
@@ -0,0 +1,36 @@
+package com.comet.opik.infrastructure.llm.antropic;
+
+import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters;
+import com.comet.opik.api.LlmProvider;
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.domain.llm.LlmProviderService;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import lombok.NonNull;
+
+/**
+ * Anthropic {@link LlmServiceProvider}; registers itself with the {@link LlmProviderFactory} on construction.
+ */
+class AnthropicLlmServiceProvider implements LlmServiceProvider {
+
+    private final AnthropicClientGenerator clientGenerator;
+
+    // Deliberately the only constructor: every instance is registered under LlmProvider.ANTHROPIC.
+    AnthropicLlmServiceProvider(@NonNull AnthropicClientGenerator clientGenerator,
+            @NonNull LlmProviderFactory factory) {
+        this.clientGenerator = clientGenerator;
+        factory.register(LlmProvider.ANTHROPIC, this);
+    }
+
+    @Override
+    public LlmProviderService getService(String apiKey) {
+        return new LlmProviderAnthropic(clientGenerator.generate(apiKey));
+    }
+
+    @Override
+    public ChatLanguageModel getLanguageModel(@NonNull String apiKey,
+            @NonNull LlmAsJudgeModelParameters modelParameters) {
+        return clientGenerator.generateChat(apiKey, modelParameters);
+    }
+
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/AnthropicModelName.java
b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModelName.java
similarity index 94%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/AnthropicModelName.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModelName.java
index 25759662bc..5ef415c3e8 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/AnthropicModelName.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModelName.java
@@ -1,4 +1,4 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.antropic;
 
 import lombok.RequiredArgsConstructor;
 
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModule.java
new file mode 100644
index 0000000000..0c87397db6
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/AnthropicModule.java
@@ -0,0 +1,30 @@
+package com.comet.opik.infrastructure.llm.antropic;
+
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import com.google.inject.AbstractModule;
+import com.google.inject.Provides;
+import com.google.inject.Singleton;
+import jakarta.inject.Named;
+import lombok.NonNull;
+import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
+
+public class AnthropicModule extends AbstractModule {
+
+    @Provides
+    @Singleton
+    public AnthropicClientGenerator clientGenerator(
+            @NonNull @Config("llmProviderClient") LlmProviderClientConfig config) {
+        return new AnthropicClientGenerator(config);
+    }
+
+    @Provides
+    @Singleton
+    @Named("anthropic")
+    public LlmServiceProvider llmServiceProvider(@NonNull LlmProviderFactory llmProviderFactory,
+            @NonNull AnthropicClientGenerator clientGenerator) {
+        return new AnthropicLlmServiceProvider(clientGenerator, llmProviderFactory);
+    }
+
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropic.java
similarity index 77%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropic.java
index 511cd250a9..80ea278948 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropic.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropic.java
@@ -1,5 +1,7 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.antropic;
 
+import com.comet.opik.api.ChunkedResponseHandler;
+import com.comet.opik.domain.llm.LlmProviderService;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
@@ -14,20 +16,21 @@
 import java.util.Optional;
 import java.util.function.Consumer;
 
-import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES;
-import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;
+import static com.comet.opik.domain.llm.ChatCompletionService.ERROR_EMPTY_MESSAGES;
+import static com.comet.opik.domain.llm.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;
 
 @RequiredArgsConstructor
 @Slf4j
 class LlmProviderAnthropic implements LlmProviderService {
+
     private final @NonNull AnthropicClient anthropicClient;
 
     @Override
     public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-        var mapper = LlmProviderAnthropicMapper.INSTANCE;
-        var response = anthropicClient.createMessage(mapper.toCreateMessageRequest(request));
+        var response = anthropicClient
+                .createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request));
 
-        return mapper.toResponse(response);
+        return LlmProviderAnthropicMapper.INSTANCE.toResponse(response);
     }
 
     @Override
@@ -35,7 +38,8 @@ public void generateStream(
             @NonNull ChatCompletionRequest request,
             @NonNull String workspaceId,
             @NonNull Consumer handleMessage,
-            @NonNull Runnable handleClose, @NonNull Consumer handleError) {
+            @NonNull Runnable handleClose,
+            @NonNull Consumer handleError) {
         validateRequest(request);
         anthropicClient.createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request),
                 new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropicMapper.java
similarity index 98%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropicMapper.java
index dcce42e8a0..7f3fad6770 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderAnthropicMapper.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/antropic/LlmProviderAnthropicMapper.java
@@ -1,4 +1,4 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.antropic;
 
 import dev.ai4j.openai4j.chat.AssistantMessage;
 import dev.ai4j.openai4j.chat.ChatCompletionChoice;
@@ -27,7 +27,7 @@
 import java.util.List;
 
 @Mapper
-public interface LlmProviderAnthropicMapper {
+interface LlmProviderAnthropicMapper {
     LlmProviderAnthropicMapper INSTANCE = Mappers.getMapper(LlmProviderAnthropicMapper.class);
 
     @Mapping(source = "response", target = "choices", qualifiedByName = "mapToChoices")
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiClientGenerator.java
new file mode 100644
index 0000000000..52b53468c2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiClientGenerator.java
@@ -0,0 +1,72 @@
+package com.comet.opik.infrastructure.llm.gemini;
+
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmProviderClientGenerator;
+import com.google.common.base.Preconditions;
+import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
+import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Objects;
+import java.util.Optional;
+
+import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters;
+import static dev.langchain4j.model.googleai.GoogleAiGeminiChatModel.GoogleAiGeminiChatModelBuilder;
+
+/**
+ * Builds Google Gemini chat models from the shared {@link LlmProviderClientConfig}.
+ */
+@RequiredArgsConstructor
+public class GeminiClientGenerator implements LlmProviderClientGenerator {
+
+    private static final int MAX_RETRIES = 1;
+    private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
+
+    // NOTE(review): getCallTimeout() is dereferenced directly in the two factory methods below, unlike the
+    // Optional-guarded connect timeout in generateChat; presumably callTimeout is always configured -- confirm.
+    public GoogleAiGeminiChatModel newGeminiClient(@NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+        return LlmProviderGeminiMapper.INSTANCE.toGeminiChatModel(apiKey, request,
+                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+    }
+
+    public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient(
+            @NonNull String apiKey, @NonNull ChatCompletionRequest request) {
+        return LlmProviderGeminiMapper.INSTANCE.toGeminiStreamingChatModel(apiKey, request,
+                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
+    }
+
+    /**
+     * Creates a Gemini chat model for the request passed as the first element of {@code params}.
+     *
+     * @throws IllegalArgumentException if {@code params} is null or empty, or its first element is not a
+     *         {@link ChatCompletionRequest}
+     * @throws NullPointerException if {@code apiKey} or the first element of {@code params} is {@code null}
+     */
+    @Override
+    public GoogleAiGeminiChatModel generate(@NonNull String apiKey, Object... params) {
+        Preconditions.checkArgument(params != null && params.length >= 1,
+                "Expected at least 1 parameter, got " + (params == null ? 0 : params.length));
+        Object first = Objects.requireNonNull(params[0], "ChatCompletionRequest is required");
+        Preconditions.checkArgument(first instanceof ChatCompletionRequest,
+                "Expected a ChatCompletionRequest, got %s", first.getClass().getName());
+        return newGeminiClient(apiKey, (ChatCompletionRequest) first);
+    }
+
+    /** Creates a Gemini chat model for the given model name, applying the optional timeout and temperature. */
+    @Override
+    public ChatLanguageModel generateChat(@NonNull String apiKey, @NonNull LlmAsJudgeModelParameters modelParameters) {
+        GoogleAiGeminiChatModelBuilder modelBuilder = GoogleAiGeminiChatModel.builder()
+                .modelName(modelParameters.name())
+                .apiKey(apiKey);
+
+        Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
+                .ifPresent(connectTimeout -> modelBuilder.timeout(connectTimeout.toJavaDuration()));
+
+        Optional.ofNullable(modelParameters.temperature()).ifPresent(modelBuilder::temperature);
+
+        return modelBuilder.build();
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiLlmServiceProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiLlmServiceProvider.java
new file mode 100644
index 0000000000..6dba7bb4ec
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiLlmServiceProvider.java
@@ -0,0 +1,29 @@
+package com.comet.opik.infrastructure.llm.gemini;
+
+import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge;
+import com.comet.opik.api.LlmProvider;
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.domain.llm.LlmProviderService;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+
+public class GeminiLlmServiceProvider implements LlmServiceProvider {
+
+    private final GeminiClientGenerator clientGenerator;
+
+    GeminiLlmServiceProvider(GeminiClientGenerator clientGenerator, LlmProviderFactory factory) {
+        this.clientGenerator = clientGenerator;
+        factory.register(LlmProvider.GEMINI, this);
+    }
+
+    @Override
+    public LlmProviderService getService(String apiKey) {
+        return new LlmProviderGemini(clientGenerator, apiKey);
+    }
+
+    @Override
+    public ChatLanguageModel getLanguageModel(String apiKey,
+            AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters modelParameters) {
+        return clientGenerator.generateChat(apiKey, modelParameters);
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModelName.java
similarity index 92%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModelName.java
index a7a83d160f..48ccf36fdc 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/GeminiModelName.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModelName.java
@@ -1,4 +1,4 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.gemini;
 
 import lombok.RequiredArgsConstructor;
 
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModule.java
new file mode 100644
index 0000000000..89ced9add7
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/GeminiModule.java
@@ -0,0 +1,28 @@
+package com.comet.opik.infrastructure.llm.gemini;
+
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import com.google.inject.AbstractModule;
+import com.google.inject.Provides;
+import com.google.inject.Singleton;
+import jakarta.inject.Named;
+import lombok.NonNull;
+import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
+
+public class GeminiModule extends AbstractModule {
+
+    @Provides
+    @Singleton
+    public GeminiClientGenerator clientGenerator(@NonNull @Config("llmProviderClient") LlmProviderClientConfig config) {
+        return new GeminiClientGenerator(config);
+    }
+
+    @Provides
+    @Singleton
+    @Named("gemini")
+    public LlmServiceProvider llmServiceProvider(@NonNull LlmProviderFactory llmProviderFactory,
+            @NonNull GeminiClientGenerator clientGenerator) {
+        return new GeminiLlmServiceProvider(clientGenerator, llmProviderFactory);
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGemini.java
similarity index 83%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGemini.java
index 3ae0c2efa9..94b5b26c11 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGemini.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGemini.java
@@ -1,5 +1,7 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.gemini;
 
+import com.comet.opik.api.ChunkedResponseHandler;
+import com.comet.opik.domain.llm.LlmProviderService;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import io.dropwizard.jersey.errors.ErrorMessage;
@@ -11,13 +13,13 @@
 
 @RequiredArgsConstructor
 public class LlmProviderGemini implements LlmProviderService {
-    private final @NonNull LlmProviderClientGenerator llmProviderClientGenerator;
+    private final @NonNull GeminiClientGenerator llmProviderClientGenerator;
     private final @NonNull String apiKey;
 
     @Override
     public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
         var mapper = LlmProviderGeminiMapper.INSTANCE;
-        var response = llmProviderClientGenerator.newGeminiClient(apiKey, request)
+        var response = llmProviderClientGenerator.generate(apiKey, request)
                 .generate(request.messages().stream().map(mapper::toChatMessage).toList());
 
         return mapper.toChatCompletionResponse(request, response);
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGeminiMapper.java
similarity index 98%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGeminiMapper.java
index 65fc410a1c..a5fdfc0560 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderGeminiMapper.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/gemini/LlmProviderGeminiMapper.java
@@ -1,4 +1,4 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.gemini;
 
 import dev.ai4j.openai4j.chat.AssistantMessage;
 import dev.ai4j.openai4j.chat.ChatCompletionChoice;
@@ -25,7 +25,7 @@
 import java.util.List;
 
 @Mapper
-public interface LlmProviderGeminiMapper {
+interface LlmProviderGeminiMapper {
     String ERR_UNEXPECTED_ROLE = "unexpected role '%s'";
     String ERR_ROLE_MSG_TYPE_MISMATCH = "role and message instance are not matching, role: '%s', instance: '%s'";
 
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/LlmProviderOpenAi.java
similarity index 94%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
rename to
apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/LlmProviderOpenAi.java
index 60c902d97a..d1eb422c3e 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderOpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/LlmProviderOpenAi.java
@@ -1,5 +1,6 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.openai;
 
+import com.comet.opik.domain.llm.LlmProviderService;
 import dev.ai4j.openai4j.OpenAiClient;
 import dev.ai4j.openai4j.OpenAiHttpException;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIClientGenerator.java
similarity index 50%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIClientGenerator.java
index 6671e534c1..56aa3189de 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIClientGenerator.java
@@ -1,13 +1,9 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.openai;
 
-import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge;
 import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmProviderClientGenerator;
 import dev.ai4j.openai4j.OpenAiClient;
-import dev.ai4j.openai4j.chat.ChatCompletionRequest;
-import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
 import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
-import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
 import dev.langchain4j.model.openai.OpenAiChatModel;
 import lombok.NonNull;
 import lombok.RequiredArgsConstructor;
@@ -15,34 +11,13 @@
 
 import java.util.Optional;
 
+import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters;
+
 @RequiredArgsConstructor
-public class LlmProviderClientGenerator {
-    private static final int MAX_RETRIES = 1;
+public class OpenAIClientGenerator implements LlmProviderClientGenerator {
 
     private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
 
-    public AnthropicClient newAnthropicClient(@NonNull String apiKey) {
-        var anthropicClientBuilder = AnthropicClient.builder();
-        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-                .map(LlmProviderClientConfig.AnthropicClientConfig::url)
-                .filter(StringUtils::isNotEmpty)
-                .ifPresent(anthropicClientBuilder::baseUrl);
-        Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
-                .map(LlmProviderClientConfig.AnthropicClientConfig::version)
-                .filter(StringUtils::isNotBlank)
-                .ifPresent(anthropicClientBuilder::version);
-        Optional.ofNullable(llmProviderClientConfig.getLogRequests())
-                .ifPresent(anthropicClientBuilder::logRequests);
-        Optional.ofNullable(llmProviderClientConfig.getLogResponses())
-                .ifPresent(anthropicClientBuilder::logResponses);
-        // anthropic client builder only receives one timeout variant
-        Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
-                .ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
-        return anthropicClientBuilder
-                .apiKey(apiKey)
-                .build();
-    }
-
     public OpenAiClient newOpenAiClient(@NonNull String apiKey) {
         var openAiClientBuilder = OpenAiClient.builder();
         Optional.ofNullable(llmProviderClientConfig.getOpenAiClient())
@@ -62,19 +37,7 @@ public OpenAiClient newOpenAiClient(@NonNull String apiKey) {
                 .build();
     }
 
-    public GoogleAiGeminiChatModel newGeminiClient(@NonNull String apiKey, @NonNull ChatCompletionRequest request) {
-        return LlmProviderGeminiMapper.INSTANCE.toGeminiChatModel(apiKey, request,
-                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
-    }
-
-    public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient(
-            @NonNull String apiKey, @NonNull ChatCompletionRequest request) {
-        return LlmProviderGeminiMapper.INSTANCE.toGeminiStreamingChatModel(apiKey, request,
-                llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES);
-    }
-
-    public ChatLanguageModel newOpenAiChatLanguageModel(String apiKey,
-            AutomationRuleEvaluatorLlmAsJudge.@NonNull LlmAsJudgeModelParameters modelParameters) {
+    public ChatLanguageModel newOpenAiChatLanguageModel(String apiKey, LlmAsJudgeModelParameters modelParameters) {
         var builder = OpenAiChatModel.builder()
                 .modelName(modelParameters.name())
                 .apiKey(apiKey)
@@ -93,4 +56,14 @@ public ChatLanguageModel newOpenAiChatLanguageModel(String apiKey,
 
         return builder.build();
     }
+
+    @Override
+    public OpenAiClient generate(@NonNull String apiKey, Object... params) {
+        return newOpenAiClient(apiKey);
+    }
+
+    @Override
+    public ChatLanguageModel generateChat(@NonNull String apiKey, @NonNull LlmAsJudgeModelParameters modelParameters) {
+        return newOpenAiChatLanguageModel(apiKey, modelParameters);
+    }
 }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAILlmServiceProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAILlmServiceProvider.java
new file mode 100644
index 0000000000..b5e0b810e2
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAILlmServiceProvider.java
@@ -0,0 +1,33 @@
+package com.comet.opik.infrastructure.llm.openai;
+
+import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge;
+import com.comet.opik.api.LlmProvider;
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.domain.llm.LlmProviderService;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import lombok.NonNull;
+
+/**
+ * OpenAI {@link LlmServiceProvider}; registers itself with the {@link LlmProviderFactory} on construction.
+ */
+class OpenAILlmServiceProvider implements LlmServiceProvider {
+
+    private final OpenAIClientGenerator clientGenerator;
+
+    OpenAILlmServiceProvider(@NonNull OpenAIClientGenerator clientGenerator, @NonNull LlmProviderFactory factory) {
+        this.clientGenerator = clientGenerator;
+        factory.register(LlmProvider.OPEN_AI, this);
+    }
+
+    @Override
+    public LlmProviderService getService(String apiKey) {
+        return new LlmProviderOpenAi(clientGenerator.generate(apiKey));
+    }
+
+    @Override
+    public ChatLanguageModel getLanguageModel(@NonNull String apiKey,
+            AutomationRuleEvaluatorLlmAsJudge.@NonNull LlmAsJudgeModelParameters modelParameters) {
+        return clientGenerator.generateChat(apiKey, modelParameters);
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIModule.java
new file mode 100644
index 0000000000..50e3043993
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenAIModule.java
@@ -0,0 +1,34 @@
+package com.comet.opik.infrastructure.llm.openai;
+
+import com.comet.opik.domain.llm.LlmProviderFactory;
+import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.infrastructure.llm.LlmServiceProvider;
+import com.google.inject.AbstractModule;
+import com.google.inject.Provides;
+import com.google.inject.Singleton;
+import jakarta.inject.Named;
+import lombok.NonNull;
+import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
+
+/**
+ * Guice bindings for the OpenAI provider: the client generator and the {@code "openai"}-named service provider.
+ */
+public class OpenAIModule extends AbstractModule {
+
+    /** Provides the singleton generator, built from the {@code llmProviderClient} configuration section. */
+    @Provides
+    @Singleton
+    public OpenAIClientGenerator clientGenerator(
+            @NonNull @Config("llmProviderClient") LlmProviderClientConfig clientConfig) {
+        return new OpenAIClientGenerator(clientConfig);
+    }
+
+    /** Provides the singleton OpenAI service provider, constructed with the generator and the factory. */
+    @Provides
+    @Singleton
+    @Named("openai")
+    public LlmServiceProvider llmServiceProvider(@NonNull LlmProviderFactory factory,
+            @NonNull OpenAIClientGenerator generator) {
+        return new OpenAILlmServiceProvider(generator, factory);
+    }
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenaiModelName.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenaiModelName.java
similarity index 96%
rename from apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenaiModelName.java
rename to apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenaiModelName.java
index c014c40ced..2b57219ed9 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenaiModelName.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/llm/openai/OpenaiModelName.java
@@ -1,4 +1,4 @@
-package com.comet.opik.domain.llmproviders;
+package com.comet.opik.infrastructure.llm.openai;
 
 import lombok.RequiredArgsConstructor;
 
diff --git
a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/ClickHouseAppender.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/ClickHouseAppender.java
index b24f3b4114..533aae0c66 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/ClickHouseAppender.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/ClickHouseAppender.java
@@ -1,5 +1,6 @@
 package com.comet.opik.infrastructure.log;
 
+import ch.qos.logback.classic.LoggerContext;
 import ch.qos.logback.classic.spi.ILoggingEvent;
 import ch.qos.logback.core.AppenderBase;
 import com.comet.opik.domain.UserLog;
@@ -17,6 +18,7 @@
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 
 import static java.util.stream.Collectors.groupingBy;
@@ -27,19 +29,16 @@ class ClickHouseAppender extends AppenderBase {
 
     private static ClickHouseAppender instance;
 
-    public static synchronized void init(@NonNull UserLogTableFactory userLogTableFactory, int batchSize,
-            @NonNull Duration flushIntervalDuration) {
+    public static synchronized ClickHouseAppender init(@NonNull UserLogTableFactory userLogTableFactory, int batchSize,
+            @NonNull Duration flushIntervalDuration, @NonNull LoggerContext context) {
 
         if (instance == null) {
-            setInstance(new ClickHouseAppender(userLogTableFactory, flushIntervalDuration, batchSize));
+            ClickHouseAppender appender = new ClickHouseAppender(userLogTableFactory, flushIntervalDuration, batchSize);
+            setInstance(appender);
+            appender.setContext(context);
             instance.start();
         }
-    }
 
-    public static synchronized ClickHouseAppender getInstance() {
-        if (instance == null) {
-            throw new IllegalStateException("ClickHouseAppender is not initialized");
-        }
         return instance;
     }
 
@@ -52,17 +51,16 @@ private static synchronized void setInstance(ClickHouseAppender instance) {
     private final int batchSize;
     private volatile boolean running = true;
 
-    private BlockingQueue logQueue;
-    private ScheduledExecutorService scheduler;
+    private final BlockingQueue logQueue = new LinkedBlockingQueue<>();
+    // Held in an AtomicReference so stop() can install a fresh executor after shutting the old one down.
+    private final AtomicReference<ScheduledExecutorService> scheduler = new AtomicReference<>(
+            Executors.newSingleThreadScheduledExecutor());
 
     @Override
     public void start() {
 
-        logQueue = new LinkedBlockingQueue<>();
-        scheduler = Executors.newSingleThreadScheduledExecutor();
-
         // Background flush thread
-        scheduler.scheduleAtFixedRate(this::flushLogs, flushIntervalDuration.toMillis(),
+        scheduler.get().scheduleAtFixedRate(this::flushLogs, flushIntervalDuration.toMillis(),
                 flushIntervalDuration.toMillis(), TimeUnit.MILLISECONDS);
 
         super.start();
@@ -113,7 +111,7 @@ protected void append(ILoggingEvent event) {
         }
 
         if (logQueue.size() >= batchSize) {
-            scheduler.execute(this::flushLogs);
+            scheduler.get().execute(this::flushLogs);
         }
     }
 
@@ -123,20 +121,23 @@ public void stop() {
         super.stop();
         flushLogs();
         setInstance(null);
-        scheduler.shutdown();
+        scheduler.get().shutdown();
         awaitTermination();
+        logQueue.clear();
+        scheduler.set(Executors.newSingleThreadScheduledExecutor());
     }
 
     private void awaitTermination() {
         try {
-            if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {
-                scheduler.shutdownNow();
-                if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) { // Final attempt
+            if (!scheduler.get().awaitTermination(5, TimeUnit.SECONDS)) {
+                scheduler.get().shutdownNow();
+                if (!scheduler.get().awaitTermination(5, TimeUnit.SECONDS)) { // Final attempt
                     log.error("ClickHouseAppender did not terminate");
                 }
             }
         } catch (InterruptedException ex) {
-            scheduler.shutdownNow();
+            Thread.currentThread().interrupt();
+            scheduler.get().shutdownNow();
             log.warn("ClickHouseAppender interrupted while waiting for termination", ex);
         }
     }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/UserFacingLoggingFactory.java
b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/UserFacingLoggingFactory.java index 57c23e0f93..9d62dd2a47 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/UserFacingLoggingFactory.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/UserFacingLoggingFactory.java @@ -21,10 +21,8 @@ public static synchronized void init(@NonNull ConnectionFactory connectionFactor @NonNull Duration flushIntervalSeconds) { UserLogTableFactory tableFactory = UserLogTableFactory.getInstance(connectionFactory); - ClickHouseAppender.init(tableFactory, batchSize, flushIntervalSeconds); - - ClickHouseAppender clickHouseAppender = ClickHouseAppender.getInstance(); - clickHouseAppender.setContext(CONTEXT); + ClickHouseAppender clickHouseAppender = ClickHouseAppender.init(tableFactory, batchSize, flushIntervalSeconds, + CONTEXT); asyncAppender = new AsyncAppender(); asyncAppender.setContext(CONTEXT); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/AutomationRuleEvaluatorLogDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/AutomationRuleEvaluatorLogDAO.java deleted file mode 100644 index 83f2a33307..0000000000 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/AutomationRuleEvaluatorLogDAO.java +++ /dev/null @@ -1,83 +0,0 @@ -package com.comet.opik.infrastructure.log.tables; - -import ch.qos.logback.classic.spi.ILoggingEvent; -import com.comet.opik.utils.TemplateUtils; -import io.r2dbc.spi.ConnectionFactory; -import io.r2dbc.spi.Statement; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.stringtemplate.v4.ST; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.Optional; - -import static com.comet.opik.infrastructure.log.tables.UserLogTableFactory.UserLogTableDAO; -import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder; - 
-@RequiredArgsConstructor -@Slf4j -class AutomationRuleEvaluatorLogDAO implements UserLogTableDAO { - - private final ConnectionFactory factory; - - private static final String INSERT_STATEMENT = """ - INSERT INTO automation_rule_evaluator_logs (timestamp, level, workspace_id, rule_id, message, markers) - VALUES , 9), - :level, - :workspace_id, - :rule_id, - :message, - mapFromArrays(:marker_keys, :marker_values) - ) - , - }> - ; - """; - - @Override - public Mono saveAll(List events) { - return Mono.from(factory.create()) - .flatMapMany(connection -> { - var template = new ST(INSERT_STATEMENT); - List queryItems = getQueryItemPlaceHolder(events.size()); - - template.add("items", queryItems); - - Statement statement = connection.createStatement(template.render()); - - for (int i = 0; i < events.size(); i++) { - ILoggingEvent event = events.get(i); - - String logLevel = event.getLevel().toString(); - String workspaceId = Optional.ofNullable(event.getMDCPropertyMap().get("workspace_id")) - .orElseThrow(() -> failWithMessage("workspace_id is not set")); - String traceId = Optional.ofNullable(event.getMDCPropertyMap().get("trace_id")) - .orElseThrow(() -> failWithMessage("trace_id is not set")); - String ruleId = Optional.ofNullable(event.getMDCPropertyMap().get("rule_id")) - .orElseThrow(() -> failWithMessage("rule_id is not set")); - - statement - .bind("timestamp" + i, event.getInstant().toString()) - .bind("level" + i, logLevel) - .bind("workspace_id" + i, workspaceId) - .bind("rule_id" + i, ruleId) - .bind("message" + i, event.getFormattedMessage()) - .bind("marker_keys" + i, new String[]{"trace_id"}) - .bind("marker_values" + i, new String[]{traceId}); - } - - return statement.execute(); - }) - .collectList() - .then(); - } - - private IllegalStateException failWithMessage(String message) { - log.error(message); - return new IllegalStateException(message); - } - -} diff --git 
a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/UserLogTableFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/UserLogTableFactory.java
index 9bce54ce50..5900326a46 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/UserLogTableFactory.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/log/tables/UserLogTableFactory.java
@@ -1,13 +1,14 @@
 package com.comet.opik.infrastructure.log.tables;
 
 import ch.qos.logback.classic.spi.ILoggingEvent;
+import com.comet.opik.domain.AutomationRuleEvaluatorLogsDAO;
 import com.comet.opik.domain.UserLog;
 import io.r2dbc.spi.ConnectionFactory;
 import lombok.NonNull;
-import lombok.RequiredArgsConstructor;
 import reactor.core.publisher.Mono;
 
 import java.util.List;
+import java.util.Map;
 
 public interface UserLogTableFactory {
 
@@ -23,15 +24,31 @@ interface UserLogTableDAO {
 
 }
 
-@RequiredArgsConstructor
+/**
+ * Default {@link UserLogTableFactory}: keeps one DAO per {@link UserLog} type, created once in
+ * the constructor.
+ */
 class UserLogTableFactoryImpl implements UserLogTableFactory {
 
-    private final ConnectionFactory factory;
+    private final Map<UserLog, UserLogTableDAO> daoMap;
+
+    UserLogTableFactoryImpl(@NonNull ConnectionFactory factory) {
+        daoMap = Map.of(
+                UserLog.AUTOMATION_RULE_EVALUATOR, AutomationRuleEvaluatorLogsDAO.create(factory));
+    }
 
+    /**
+     * Returns the DAO registered for the given log type.
+     *
+     * @throws IllegalStateException if no DAO is registered for {@code userLog}
+     */
     @Override
     public UserLogTableDAO getDAO(@NonNull UserLog userLog) {
-        return switch (userLog) {
-            case AUTOMATION_RULE_EVALUATOR -> new AutomationRuleEvaluatorLogDAO(factory);
-        };
+        UserLogTableDAO dao = daoMap.get(userLog);
+        if (dao == null) {
+            // Map.get returns null when no DAO is mapped for the constant; fail fast with a clear error.
+            throw new IllegalStateException("No DAO registered for user log type: " + userLog);
+        }
+        return dao;
     }
 }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java
index 24de1f7737..5c51ce6a0c 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestDropwizardAppExtensionUtils.java
@@ -6,6 +7,7 @@
 import
com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.google.common.eventbus.EventBus; import com.google.inject.AbstractModule; +import com.google.inject.Module; import lombok.Builder; import lombok.experimental.UtilityClass; import org.apache.commons.collections4.CollectionUtils; @@ -48,6 +49,7 @@ public record AppContextConfig( EventBus mockEventBus, boolean corsEnabled, List customConfigs, + List> disableModules, List modules) { } @@ -129,9 +131,9 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension(AppContex GuiceyConfigurationHook hook = injector -> { injector.modulesOverride(TestHttpClientUtils.testAuthModule()); - Optional.ofNullable(appContextConfig.modules) + Optional.ofNullable(appContextConfig.disableModules) .orElse(List.of()) - .forEach(injector::modulesOverride); + .forEach(injector::disableModules); if (appContextConfig.mockEventBus() != null) { injector.modulesOverride(new EventModule() { @@ -154,6 +156,9 @@ public void run(GuiceyEnvironment environment) { } }); + Optional.ofNullable(appContextConfig.modules) + .orElse(List.of()) + .forEach(injector::modulesOverride); }; if (appContextConfig.redisUrl() != null) { diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/AutomationRuleEvaluatorResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/AutomationRuleEvaluatorResourceClient.java index abe0f9339d..5dea375edc 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/AutomationRuleEvaluatorResourceClient.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/AutomationRuleEvaluatorResourceClient.java @@ -2,6 +2,7 @@ import com.comet.opik.api.AutomationRuleEvaluator; import com.comet.opik.api.AutomationRuleEvaluatorUpdate; +import com.comet.opik.api.LogItem.LogPage; import com.comet.opik.api.resources.utils.TestHttpClientUtils; import 
com.comet.opik.api.resources.utils.TestUtils;
 import jakarta.ws.rs.HttpMethod;
@@ -70,4 +71,27 @@ public void updateEvaluator(UUID evaluatorId, UUID projectId, String workspaceNa
             }
         }
     }
+
+    /**
+     * Requests the {@code logs} sub-resource of the given evaluator and returns the parsed page.
+     * Asserts that the response status is 200 before reading the entity.
+     */
+    public LogPage getEvaluatorLogs(UUID evaluatorId, UUID projectId, String workspaceName,
+            String apiKey) {
+        var logsTarget = client.target(RESOURCE_PATH.formatted(baseURI, projectId))
+                .path(evaluatorId.toString())
+                .path("logs");
+
+        try (var actualResponse = logsTarget.request()
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .accept(MediaType.APPLICATION_JSON_TYPE)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .get()) {
+
+            int statusCode = actualResponse.getStatusInfo().getStatusCode();
+            assertThat(statusCode).isEqualTo(200);
+
+            return actualResponse.readEntity(LogPage.class);
+        }
+    }
 }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java
index e8de4415da..c0f26c949d 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java
@@ -17,8 +17,8 @@
 import com.comet.opik.api.resources.utils.WireMockUtils;
 import com.comet.opik.api.resources.utils.resources.AutomationRuleEvaluatorResourceClient;
 import com.comet.opik.api.resources.utils.resources.ProjectResourceClient;
-import com.comet.opik.domain.ChatCompletionService;
 import com.comet.opik.domain.FeedbackScoreService;
+import com.comet.opik.domain.llm.ChatCompletionService;
 import com.comet.opik.infrastructure.DatabaseAnalyticsFactory;
 import com.comet.opik.podam.PodamFactoryUtils;
 import com.comet.opik.utils.JsonUtils;
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AuthenticationResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AuthenticationResourceTest.java
index dabc9eae97..710e490f7f 100644
---
a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AuthenticationResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AuthenticationResourceTest.java @@ -10,7 +10,6 @@ import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; -import com.comet.opik.podam.PodamFactoryUtils; import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.client.Entity; @@ -32,7 +31,6 @@ import org.testcontainers.lifecycle.Startables; import ru.vyarus.dropwizard.guice.test.ClientSupport; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; -import uk.co.jemos.podam.api.PodamFactory; import java.sql.SQLException; import java.util.UUID; @@ -65,7 +63,7 @@ class AuthenticationResourceTest { private static final String UNAUTHORISED_WORKSPACE_NAME = UUID.randomUUID().toString(); @RegisterExtension - private static final TestDropwizardAppExtension app; + private static final TestDropwizardAppExtension APP; private static final WireMockUtils.WireMockRuntime wireMock; @@ -77,12 +75,10 @@ class AuthenticationResourceTest { DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); - app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + APP = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); } - private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); - private String baseURI; private ClientSupport client; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AutomationRuleEvaluatorsResourceTest.java 
b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AutomationRuleEvaluatorsResourceTest.java index c95b1664ec..d5f8f3ea9b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AutomationRuleEvaluatorsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/AutomationRuleEvaluatorsResourceTest.java @@ -4,21 +4,36 @@ import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.api.AutomationRuleEvaluatorUpdate; import com.comet.opik.api.BatchDelete; +import com.comet.opik.api.LogItem; +import com.comet.opik.api.Trace; import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; import com.comet.opik.api.resources.utils.MigrationUtils; import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils.AppContextConfig; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.api.resources.utils.resources.AutomationRuleEvaluatorResourceClient; +import com.comet.opik.api.resources.utils.resources.ProjectResourceClient; +import com.comet.opik.api.resources.utils.resources.TraceResourceClient; +import com.comet.opik.domain.llm.LlmProviderFactory; +import com.comet.opik.infrastructure.llm.LlmModule; import com.comet.opik.podam.PodamFactoryUtils; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.github.tomakehurst.wiremock.client.WireMock; +import com.google.inject.AbstractModule; import com.redis.testcontainers.RedisContainer; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.response.ChatResponse; import 
jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import org.apache.http.HttpStatus; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -31,19 +46,25 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.testcontainers.clickhouse.ClickHouseContainer; import org.testcontainers.containers.MySQLContainer; import org.testcontainers.lifecycle.Startables; +import org.testcontainers.shaded.org.awaitility.Awaitility; import ru.vyarus.dropwizard.guice.test.ClientSupport; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import java.time.Instant; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; import static com.comet.opik.api.resources.utils.TestHttpClientUtils.UNAUTHORIZED_RESPONSE; import static com.comet.opik.infrastructure.auth.RequestContext.SESSION_COOKIE; import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; @@ -55,6 +76,9 @@ import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; @TestInstance(TestInstance.Lifecycle.PER_CLASS) @DisplayName("Automation Rule Evaluators Resource Test") @@ -62,6 +86,53 @@ class 
AutomationRuleEvaluatorsResourceTest { private static final String URL_TEMPLATE = "%s/v1/private/automations/projects/%s/evaluators/"; + private static final String messageToTest = "Summary: {{summary}}\\nInstruction: {{instruction}}\\n\\n"; + private static final String testEvaluator = """ + { + "model": { "name": "gpt-4o", "temperature": 0.3 }, + "messages": [ + { "role": "USER", "content": "%s" }, + { "role": "SYSTEM", "content": "You're a helpful AI, be cordial." } + ], + "variables": { + "summary": "input.questions.question1", + "instruction": "output.output", + "nonUsed": "input.questions.question2", + "toFail1": "metadata.nonexistent.path" + }, + "schema": [ + { "name": "Relevance", "type": "INTEGER", "description": "Relevance of the summary" }, + { "name": "Conciseness", "type": "DOUBLE", "description": "Conciseness of the summary" }, + { "name": "Technical Accuracy", "type": "BOOLEAN", "description": "Technical accuracy of the summary" } + ] + } + """ + .formatted(messageToTest).trim(); + + private static final String summaryStr = "What was the approach to experimenting with different data mixtures?"; + private static final String outputStr = "The study employed a systematic approach to experiment with varying data mixtures by manipulating the proportions and sources of datasets used for model training."; + private static final String input = """ + { + "questions": { + "question1": "%s", + "question2": "Whatever, we wont use it anyway" + }, + "pdf_url": "https://arxiv.org/pdf/2406.04744", + "title": "CRAG -- Comprehensive RAG Benchmark" + } + """.formatted(summaryStr).trim(); + private static final String output = """ + { + "output": "%s" + } + """.formatted(outputStr).trim(); + + private static final String validAiMsgTxt = "{\"Relevance\":{\"score\":5,\"reason\":\"The summary directly addresses the approach taken in the study by mentioning the systematic experimentation with varying data mixtures and the manipulation of proportions and sources.\"}," + + + 
"\"Conciseness\":{\"score\":4,\"reason\":\"The summary is mostly concise but could be slightly more streamlined by removing redundant phrases.\"}," + + + "\"Technical Accuracy\":{\"score\":0,\"reason\":\"The summary accurately describes the experimental approach involving data mixtures, proportions, and sources, reflecting the technical details of the study.\"}}"; + private static final String USER = UUID.randomUUID().toString(); private static final String API_KEY = UUID.randomUUID().toString(); private static final String WORKSPACE_ID = UUID.randomUUID().toString(); @@ -71,18 +142,37 @@ class AutomationRuleEvaluatorsResourceTest { private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + private static final ClickHouseContainer CLICKHOUSE = ClickHouseContainerUtils.newClickHouseContainer(); + @RegisterExtension - private static final TestDropwizardAppExtension app; + private static TestDropwizardAppExtension APP; private static final WireMockUtils.WireMockRuntime wireMock; static { - Startables.deepStart(REDIS, MYSQL).join(); + Startables.deepStart(REDIS, MYSQL, CLICKHOUSE).join(); wireMock = WireMockUtils.startWireMock(); - app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension(MYSQL.getJdbcUrl(), null, - wireMock.runtimeInfo(), REDIS.getRedisURI()); + var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory(CLICKHOUSE, DATABASE_NAME); + + APP = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + AppContextConfig.builder() + .jdbcUrl(MYSQL.getJdbcUrl()) + .databaseAnalyticsFactory(databaseAnalyticsFactory) + .redisUrl(REDIS.getRedisURI()) + .runtimeInfo(wireMock.runtimeInfo()) + .disableModules(List.of(LlmModule.class)) + .modules(List.of(new AbstractModule() { + + @Override + public void configure() { + bind(LlmProviderFactory.class) + .toInstance(Mockito.mock(LlmProviderFactory.class, Mockito.RETURNS_DEEP_STUBS)); + } + + })) + .build()); } private final PodamFactory factory = 
PodamFactoryUtils.newPodamFactory(); @@ -90,6 +180,8 @@ class AutomationRuleEvaluatorsResourceTest { private String baseURI; private ClientSupport client; private AutomationRuleEvaluatorResourceClient evaluatorsResourceClient; + private TraceResourceClient traceResourceClient; + private ProjectResourceClient projectResourceClient; @BeforeAll void setUpAll(ClientSupport client, Jdbi jdbi) { @@ -104,6 +196,8 @@ void setUpAll(ClientSupport client, Jdbi jdbi) { mockTargetWorkspace(API_KEY, WORKSPACE_NAME, WORKSPACE_ID); this.evaluatorsResourceClient = new AutomationRuleEvaluatorResourceClient(this.client, baseURI); + this.traceResourceClient = new TraceResourceClient(this.client, baseURI); + this.projectResourceClient = new ProjectResourceClient(this.client, baseURI, factory); } private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { @@ -430,6 +524,107 @@ void deleteProjectAutomationRuleEvaluators__whenApiKeyIsPresent__thenReturnPrope } } } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("get logs per rule evaluators: when api key is present, then return proper response") + void getLogsPerRuleEvaluators__whenSessionTokenIsPresent__thenReturnProperResponse( + String apikey, + boolean isAuthorized, + LlmProviderFactory llmProviderFactory) throws JsonProcessingException { + + ChatResponse chatResponse = ChatResponse.builder() + .aiMessage(AiMessage.from(validAiMsgTxt)) + .build(); + + when(llmProviderFactory.getLanguageModel(anyString(), any()) + .chat(any())) + .thenAnswer(invocationOnMock -> chatResponse); + + String projectName = UUID.randomUUID().toString(); + + String workspaceName = "workspace-" + UUID.randomUUID(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(okApikey, workspaceName, workspaceId); + + ObjectMapper mapper = new ObjectMapper(); + + var projectId = projectResourceClient.createProject(projectName, okApikey, workspaceName); + + var evaluator = 
factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.class).toBuilder() + .id(null) + .code(mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class)) + .samplingRate(1f) + .build(); + + var trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .input(mapper.readTree(input)) + .output(mapper.readTree(output)) + .build(); + + var id = evaluatorsResourceClient.createEvaluator(evaluator, projectId, workspaceName, okApikey); + + Instant startTime = Instant.now(); + traceResourceClient.createTrace(trace, okApikey, workspaceName); + + Awaitility.await().untilAsserted(() -> { + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI, projectId)) + .path(id.toString()) + .path("logs") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apikey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + if (isAuthorized) { + assertLogResponse(actualResponse, startTime, id, trace); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()) + .isEqualTo(HttpStatus.SC_UNAUTHORIZED); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + }); + } + } + + private static void assertLogResponse(Response actualResponse, Instant startTime, UUID id, Trace trace) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK); + assertThat(actualResponse.hasEntity()).isTrue(); + + var actualEntity = actualResponse.readEntity(LogItem.LogPage.class); + + assertThat(actualEntity.content()).hasSize(4); + assertThat(actualEntity.total()).isEqualTo(4); + assertThat(actualEntity.size()).isEqualTo(4); + assertThat(actualEntity.page()).isEqualTo(1); + + assertThat(actualEntity.content()) + .allSatisfy(log -> { + assertThat(log.timestamp()).isBetween(startTime, Instant.now()); + assertThat(log.ruleId()).isEqualTo(id); + 
assertThat(log.markers()).isEqualTo(Map.of("trace_id", trace.id().toString())); + assertThat(log.level()).isEqualTo(LogItem.LogLevel.INFO); + }); + + assertThat(actualEntity.content()) + .anyMatch(log -> log.message() + .matches("Scores for traceId '.*' stored successfully:\\n\\n.*")); + + assertThat(actualEntity.content()) + .anyMatch(log -> log.message().matches("Received response for traceId '.*':\\n\\n.*")); + + assertThat(actualEntity.content()) + .anyMatch(log -> log.message().matches( + "(?s)Sending traceId '([^']*)' to LLM using the following input:\\n\\n.*")); + + assertThat(actualEntity.content()) + .anyMatch(log -> log.message().matches("Evaluating traceId '.*' sampled by rule '.*'")); } @Nested @@ -710,6 +905,68 @@ void deleteProjectAutomationRuleEvaluators__whenSessionTokenIsPresent__thenRetur } } } - } + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("get logs per rule evaluators: when api key is present, then return proper response") + void getLogsPerRuleEvaluators__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, + boolean isAuthorized, + String workspaceName, + LlmProviderFactory llmProviderFactory) throws JsonProcessingException { + + ChatResponse chatResponse = ChatResponse.builder() + .aiMessage(AiMessage.from(validAiMsgTxt)) + .build(); + + when(llmProviderFactory.getLanguageModel(anyString(), any()) + .chat(any())) + .thenAnswer(invocationOnMock -> chatResponse); + + String projectName = UUID.randomUUID().toString(); + + ObjectMapper mapper = new ObjectMapper(); + + var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME); + + var evaluator = factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.class).toBuilder() + .id(null) + .code(mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class)) + .samplingRate(1f) + .build(); + + var trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + 
.input(mapper.readTree(input)) + .output(mapper.readTree(output)) + .build(); + + var id = evaluatorsResourceClient.createEvaluator(evaluator, projectId, WORKSPACE_NAME, API_KEY); + + Instant startTime = Instant.now(); + traceResourceClient.createTrace(trace, API_KEY, WORKSPACE_NAME); + + Awaitility.await().untilAsserted(() -> { + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI, projectId)) + .path(id.toString()) + .path("logs") + .request() + .cookie(SESSION_COOKIE, sessionToken) + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + if (isAuthorized) { + assertLogResponse(actualResponse, startTime, id, trace); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()) + .isEqualTo(HttpStatus.SC_UNAUTHORIZED); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + }); + } + } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java index 8c3d5cc830..3d9f529957 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java @@ -11,9 +11,10 @@ import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.api.resources.utils.resources.ChatCompletionsClient; import com.comet.opik.api.resources.utils.resources.LlmProviderApiKeyResourceClient; -import com.comet.opik.domain.llmproviders.AnthropicModelName; -import com.comet.opik.domain.llmproviders.GeminiModelName; -import com.comet.opik.domain.llmproviders.OpenaiModelName; +import com.comet.opik.domain.llm.LlmProviderFactory; +import com.comet.opik.infrastructure.llm.antropic.AnthropicModelName; +import 
com.comet.opik.infrastructure.llm.gemini.GeminiModelName; +import com.comet.opik.infrastructure.llm.openai.OpenaiModelName; import com.comet.opik.podam.PodamFactoryUtils; import com.redis.testcontainers.RedisContainer; import dev.ai4j.openai4j.chat.ChatCompletionRequest; @@ -41,9 +42,9 @@ import java.util.UUID; import java.util.stream.Stream; -import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES; -import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS; -import static com.comet.opik.domain.llmproviders.LlmProviderFactory.ERROR_MODEL_NOT_SUPPORTED; +import static com.comet.opik.domain.llm.ChatCompletionService.ERROR_EMPTY_MESSAGES; +import static com.comet.opik.domain.llm.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS; +import static com.comet.opik.domain.llm.LlmProviderFactory.ERROR_MODEL_NOT_SUPPORTED; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assumptions.assumeThat; import static org.junit.jupiter.api.Named.named; @@ -182,7 +183,7 @@ void createReturnsBadRequestWhenModelIsInvalid(String model) { assertThat(errorMessage.getCode()).isEqualTo(HttpStatus.SC_BAD_REQUEST); assertThat(errorMessage.getMessage()) - .containsIgnoringCase(ERROR_MODEL_NOT_SUPPORTED.formatted(model)); + .containsIgnoringCase(LlmProviderFactory.ERROR_MODEL_NOT_SUPPORTED.formatted(model)); } @ParameterizedTest diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryTest.java similarity index 65% rename from apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java rename to apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryTest.java index 5d7a2725f6..6dd9279eb1 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderFactoryTest.java +++ 
b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/LlmProviderFactoryTest.java @@ -1,11 +1,21 @@ -package com.comet.opik.domain.llmproviders; +package com.comet.opik.infrastructure.llm; import com.comet.opik.api.LlmProvider; import com.comet.opik.api.ProviderApiKey; import com.comet.opik.domain.LlmProviderApiKeyService; +import com.comet.opik.domain.llm.LlmProviderService; import com.comet.opik.infrastructure.EncryptionUtils; import com.comet.opik.infrastructure.LlmProviderClientConfig; import com.comet.opik.infrastructure.OpikConfiguration; +import com.comet.opik.infrastructure.llm.antropic.AnthropicClientGenerator; +import com.comet.opik.infrastructure.llm.antropic.AnthropicModelName; +import com.comet.opik.infrastructure.llm.antropic.AnthropicModule; +import com.comet.opik.infrastructure.llm.gemini.GeminiClientGenerator; +import com.comet.opik.infrastructure.llm.gemini.GeminiModelName; +import com.comet.opik.infrastructure.llm.gemini.GeminiModule; +import com.comet.opik.infrastructure.llm.openai.OpenAIClientGenerator; +import com.comet.opik.infrastructure.llm.openai.OpenAIModule; +import com.comet.opik.infrastructure.llm.openai.OpenaiModelName; import com.fasterxml.jackson.databind.ObjectMapper; import io.dropwizard.configuration.ConfigurationException; import io.dropwizard.configuration.FileConfigurationSourceProvider; @@ -50,7 +60,7 @@ void setUpAll() throws ConfigurationException, IOException { @ParameterizedTest @MethodSource - void testGetService(String model, LlmProvider llmProvider, Class providerClass) { + void testGetService(String model, LlmProvider llmProvider, String providerClass) { // setup LlmProviderApiKeyService llmProviderApiKeyService = mock(LlmProviderApiKeyService.class); String workspaceId = UUID.randomUUID().toString(); @@ -67,22 +77,34 @@ void testGetService(String model, LlmProvider llmProvider, Class testGetService() { var openAiModels = EnumUtils.getEnumList(OpenaiModelName.class).stream() - .map(model -> 
arguments(model.toString(), LlmProvider.OPEN_AI, LlmProviderOpenAi.class)); + .map(model -> arguments(model.toString(), LlmProvider.OPEN_AI, "LlmProviderOpenAi")); var anthropicModels = EnumUtils.getEnumList(AnthropicModelName.class).stream() - .map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, LlmProviderAnthropic.class)); + .map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, "LlmProviderAnthropic")); var geminiModels = EnumUtils.getEnumList(GeminiModelName.class).stream() - .map(model -> arguments(model.toString(), LlmProvider.GEMINI, LlmProviderGemini.class)); + .map(model -> arguments(model.toString(), LlmProvider.GEMINI, "LlmProviderGemini")); return Stream.of(openAiModels, anthropicModels, geminiModels).flatMap(Function.identity()); } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/antropic/AnthropicMappersTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/antropic/AnthropicMappersTest.java new file mode 100644 index 0000000000..0c610fa710 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/antropic/AnthropicMappersTest.java @@ -0,0 +1,73 @@ +package com.comet.opik.infrastructure.llm.antropic; + +import com.comet.opik.podam.PodamFactoryUtils; +import dev.ai4j.openai4j.chat.AssistantMessage; +import dev.ai4j.openai4j.chat.ChatCompletionChoice; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.Role; +import dev.ai4j.openai4j.shared.Usage; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import uk.co.jemos.podam.api.PodamFactory; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +public class AnthropicMappersTest { + private final 
PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory(); + + @Nested + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class AnthropicMappers { + @Test + void testToResponse() { + var response = podamFactory.manufacturePojo(AnthropicCreateMessageResponse.class); + + var actual = LlmProviderAnthropicMapper.INSTANCE.toResponse(response); + assertThat(actual).isNotNull(); + assertThat(actual.id()).isEqualTo(response.id); + assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder() + .message(AssistantMessage.builder() + .name(response.content.getFirst().name) + .content(response.content.getFirst().text) + .build()) + .finishReason(response.stopReason) + .build())); + assertThat(actual.usage()).isEqualTo(Usage.builder() + .promptTokens(response.usage.inputTokens) + .completionTokens(response.usage.outputTokens) + .totalTokens(response.usage.inputTokens + response.usage.outputTokens) + .build()); + } + + @Test + void toCreateMessage() { + var request = podamFactory.manufacturePojo(ChatCompletionRequest.class); + + AnthropicCreateMessageRequest actual = LlmProviderAnthropicMapper.INSTANCE + .toCreateMessageRequest(request); + + assertThat(actual).isNotNull(); + assertThat(actual.model).isEqualTo(request.model()); + assertThat(actual.stream).isEqualTo(request.stream()); + assertThat(actual.temperature).isEqualTo(request.temperature()); + assertThat(actual.topP).isEqualTo(request.topP()); + assertThat(actual.stopSequences).isEqualTo(request.stop()); + assertThat(actual.messages).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo( + request.messages().stream() + .filter(message -> List.of(Role.USER, Role.ASSISTANT).contains(message.role())) + .map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage) + .toList()); + assertThat(actual.system).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo( + request.messages().stream() + .filter(message -> message.role() == Role.SYSTEM) + 
.map(LlmProviderAnthropicMapper.INSTANCE::mapToSystemMessage) + .toList()); + } + } + +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/gemini/GeminiMappersTest.java similarity index 54% rename from apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java rename to apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/gemini/GeminiMappersTest.java index 37f073194b..6b795a6314 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/llmproviders/LlmProviderClientsMappersTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/infrastructure/llm/gemini/GeminiMappersTest.java @@ -1,18 +1,15 @@ -package com.comet.opik.domain.llmproviders; +package com.comet.opik.infrastructure.llm.gemini; import com.comet.opik.podam.PodamFactoryUtils; import dev.ai4j.openai4j.chat.AssistantMessage; import dev.ai4j.openai4j.chat.ChatCompletionChoice; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.Message; -import dev.ai4j.openai4j.chat.Role; import dev.ai4j.openai4j.shared.Usage; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; -import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; -import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -31,59 +28,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; -public class LlmProviderClientsMappersTest { +public class GeminiMappersTest { private final PodamFactory podamFactory = 
PodamFactoryUtils.newPodamFactory(); - @Nested - @TestInstance(TestInstance.Lifecycle.PER_CLASS) - class AnthropicMappers { - @Test - void testToResponse() { - var response = podamFactory.manufacturePojo(AnthropicCreateMessageResponse.class); - - var actual = LlmProviderAnthropicMapper.INSTANCE.toResponse(response); - assertThat(actual).isNotNull(); - assertThat(actual.id()).isEqualTo(response.id); - assertThat(actual.choices()).isEqualTo(List.of(ChatCompletionChoice.builder() - .message(AssistantMessage.builder() - .name(response.content.getFirst().name) - .content(response.content.getFirst().text) - .build()) - .finishReason(response.stopReason) - .build())); - assertThat(actual.usage()).isEqualTo(Usage.builder() - .promptTokens(response.usage.inputTokens) - .completionTokens(response.usage.outputTokens) - .totalTokens(response.usage.inputTokens + response.usage.outputTokens) - .build()); - } - - @Test - void toCreateMessage() { - var request = podamFactory.manufacturePojo(ChatCompletionRequest.class); - - AnthropicCreateMessageRequest actual = LlmProviderAnthropicMapper.INSTANCE - .toCreateMessageRequest(request); - - assertThat(actual).isNotNull(); - assertThat(actual.model).isEqualTo(request.model()); - assertThat(actual.stream).isEqualTo(request.stream()); - assertThat(actual.temperature).isEqualTo(request.temperature()); - assertThat(actual.topP).isEqualTo(request.topP()); - assertThat(actual.stopSequences).isEqualTo(request.stop()); - assertThat(actual.messages).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo( - request.messages().stream() - .filter(message -> List.of(Role.USER, Role.ASSISTANT).contains(message.role())) - .map(LlmProviderAnthropicMapper.INSTANCE::mapToAnthropicMessage) - .toList()); - assertThat(actual.system).usingRecursiveComparison().ignoringCollectionOrder().isEqualTo( - request.messages().stream() - .filter(message -> message.role() == Role.SYSTEM) - .map(LlmProviderAnthropicMapper.INSTANCE::mapToSystemMessage) - 
.toList()); - } - } - @Nested @TestInstance(TestInstance.Lifecycle.PER_CLASS) class GeminiMappers {