Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.function.CollectResponseAttachmentsFn;
import com.epam.aidial.core.server.function.CollectResponseChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.token.TokenUsage;
import com.epam.aidial.core.server.token.TokenUsageParser;
Expand All @@ -35,6 +34,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.epam.aidial.core.server.Proxy.HEADER_APPLICATION_ID;
import static com.epam.aidial.core.server.Proxy.HEADER_APPLICATION_PROPERTIES;
Expand All @@ -49,21 +49,28 @@ public class BaseDeploymentPostController {
protected final Proxy proxy;
protected final ProxyContext context;

protected BufferingReadStream createResponseStream(HttpClientResponse proxyResponse) {
protected BufferingReadStream createResponseStream(HttpClientResponse proxyResponse,
Supplier<BufferingReadStream.BaseEventListener> listenerSupplier) {
BufferingReadStream.BaseEventListener eventListener = null;
if (isEventStreamResponse(proxyResponse)) {
eventListener = listenerSupplier.get();
}
return new BufferingReadStream(proxyResponse, ProxyUtil.contentLength(proxyResponse, 1024), eventListener);
}

protected boolean isEventStreamResponse(HttpClientResponse proxyResponse) {
String contentType = proxyResponse.getHeader(HttpHeaders.CONTENT_TYPE);
boolean isEventStreamResponse = Strings.CI.contains(contentType, "text/event-stream") && context.isStreamingRequest();
CollectResponseAttachmentsFn handler = isEventStreamResponse ? new CollectResponseChatCompletionAttachmentsFn(proxy, context) : null;
return new BufferingReadStream(proxyResponse, ProxyUtil.contentLength(proxyResponse, 1024), handler);
return Strings.CI.contains(contentType, "text/event-stream") && context.isStreamingRequest();
}

protected Future<Void> collectResponseAttachments(Buffer responseBody) {
if (context.isStreamingRequest()) {
if (isEventStreamResponse(context.getProxyResponse())) {
return Future.succeededFuture();
}
try (InputStream stream = new ByteBufInputStream(responseBody.getByteBuf())) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);
var fn = new CollectResponseChatCompletionAttachmentsFn(proxy, context);
return fn.apply(tree);
return fn.apply(tree).map(ignored -> null);
} catch (Throwable e) {
log.warn("Can't parse JSON response body. Error:", e);
return Future.failedFuture(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,26 @@
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.function.BaseResponseFunction;
import com.epam.aidial.core.server.function.BuildUpstreamCacheFn;
import com.epam.aidial.core.server.function.CollectDeploymentsFn;
import com.epam.aidial.core.server.function.CollectRequestApplicationFilesFn;
import com.epam.aidial.core.server.function.CollectRequestChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.function.CollectRequestDataFn;
import com.epam.aidial.core.server.function.CollectResponseChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.function.enhancement.ApplyDefaultDeploymentSettingsFn;
import com.epam.aidial.core.server.function.enhancement.EnhanceModelRequestFn;
import com.epam.aidial.core.server.limiter.RateLimitResult;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.sse.SseEvent;
import com.epam.aidial.core.server.token.TokenUsage;
import com.epam.aidial.core.server.upstream.UpstreamRoute;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.exception.ResourceNotFoundException;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBufInputStream;
Expand All @@ -41,6 +45,7 @@

import java.io.InputStream;
import java.util.List;
import java.util.function.Supplier;

@Slf4j
public class DeploymentPostController extends BaseDeploymentPostController {
Expand Down Expand Up @@ -283,7 +288,9 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
upstreamRoute.fail(proxyResponse);
}

BufferingReadStream responseStream = createResponseStream(proxyResponse);
Supplier<BufferingReadStream.BaseEventListener> eventListenerSupplier = () ->
new ChatCompletionSseListener(new CollectResponseChatCompletionAttachmentsFn(proxy, context));
BufferingReadStream responseStream = createResponseStream(proxyResponse, eventListenerSupplier);

context.setProxyResponse(proxyResponse);
context.setProxyResponseTimestamp(System.currentTimeMillis());
Expand Down Expand Up @@ -389,4 +396,27 @@ private void handleResponseError(Throwable error, BufferingReadStream responseSt
context.getProxyRequest().reset();
}
}

public static class ChatCompletionSseListener extends BufferingReadStream.BaseEventListener {

public static final String CHAT_COMPLETION_FINAL_MESSAGE = "[DONE]";

public ChatCompletionSseListener(BaseResponseFunction function) {
super(function);
}

@Override
protected boolean isLastEvent(SseEvent event, JsonNode data) {
return isFinalEvent(event);
}

@Override
protected boolean skipEvent(SseEvent event) {
return isFinalEvent(event);
}

private static boolean isFinalEvent(SseEvent event) {
return CHAT_COMPLETION_FINAL_MESSAGE.equals(event.getData());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.function.CollectRequestChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.function.CollectRequestDataFn;
import com.epam.aidial.core.server.function.CollectResponseChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.function.enhancement.ApplyDefaultDeploymentSettingsFn;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
Expand All @@ -27,6 +28,7 @@

import java.io.InputStream;
import java.util.List;
import java.util.function.Supplier;

@Slf4j
public class InterceptorController extends BaseDeploymentPostController {
Expand Down Expand Up @@ -143,7 +145,9 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
context.getDeployment().getEndpoint(),
proxyResponse.statusCode(), proxyResponse.headers().size());

BufferingReadStream responseStream = createResponseStream(proxyResponse);
Supplier<BufferingReadStream.BaseEventListener> eventListenerSupplier = () ->
new DeploymentPostController.ChatCompletionSseListener(new CollectResponseChatCompletionAttachmentsFn(proxy, context));
BufferingReadStream responseStream = createResponseStream(proxyResponse, eventListenerSupplier);

context.setProxyResponse(proxyResponse);
context.setProxyResponseTimestamp(System.currentTimeMillis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.function.BaseResponseFunction;
import com.epam.aidial.core.server.function.CollectResponseChatCompletionAttachmentsFn;
import com.epam.aidial.core.server.function.enhancement.EnhanceModelRequestFn;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.sse.SseEvent;
import com.epam.aidial.core.server.token.TokenUsage;
import com.epam.aidial.core.server.upstream.UpstreamRoute;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.exception.ResourceNotFoundException;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.buffer.ByteBufInputStream;
import io.vertx.core.Future;
Expand All @@ -33,6 +37,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.function.Supplier;

@Slf4j
public class ResponsesController extends BaseDeploymentPostController {
Expand Down Expand Up @@ -188,7 +193,8 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
upstreamRoute.fail(proxyResponse);
}

BufferingReadStream responseStream = createResponseStream(proxyResponse);
Supplier<BufferingReadStream.BaseEventListener> eventListenerSupplier = ResponsesSseListener::new;
BufferingReadStream responseStream = createResponseStream(proxyResponse, eventListenerSupplier);

context.setProxyResponse(proxyResponse);
context.setProxyResponseTimestamp(System.currentTimeMillis());
Expand Down Expand Up @@ -278,4 +284,16 @@ private void handleResponseError(Throwable error, BufferingReadStream responseSt
context.getProxyRequest().reset();
}
}

private static class ResponsesSseListener extends BufferingReadStream.BaseEventListener {

public ResponsesSseListener() {
super();
}

@Override
protected boolean isLastEvent(SseEvent event, JsonNode data) {
return "response.incomplete".equals(event.getEvent()) || "response.completed".equals(event.getEvent());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.data.ErrorData;
import com.epam.aidial.core.server.function.BaseRequestFunction;
import com.epam.aidial.core.server.function.BaseResponseFunction;
import com.epam.aidial.core.server.function.FilterAllowedToolsFn;
import com.epam.aidial.core.server.function.enhancement.InjectApplicationPropsToMcpRequest;
import com.epam.aidial.core.server.limiter.RateLimitResult;
import com.epam.aidial.core.server.limiter.RateLimiter;
Expand All @@ -24,6 +26,7 @@
import com.epam.aidial.core.server.service.ConsentService;
import com.epam.aidial.core.server.service.DeploymentService;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.sse.SseEvent;
import com.epam.aidial.core.server.token.TokenStatsTracker;
import com.epam.aidial.core.server.upstream.UpstreamRoute;
import com.epam.aidial.core.server.upstream.UpstreamRouteProvider;
Expand Down Expand Up @@ -68,8 +71,6 @@
@Slf4j
public class ToolSetProxyController implements Controller {

private static final ArrayNode EMPTY_JSON_ARRAY = ProxyUtil.MAPPER.createArrayNode();

private final String toolSetId;

private final CredentialsLocator credentialsLocator;
Expand Down Expand Up @@ -100,6 +101,8 @@ public class ToolSetProxyController implements Controller {

private final List<BaseRequestFunction<ObjectNode>> enhancementFunctions;

private final Proxy proxy;

private String mcpMethodName;

private boolean useAllowedTools;
Expand All @@ -125,6 +128,7 @@ public ToolSetProxyController(Proxy proxy, ProxyContext context, String toolSetI
this.resourceCredentialsService = proxy.getResourceCredentialsService();
this.applicationSchemaService = proxy.getApplicationSchemaService();
this.enhancementFunctions = List.of(new InjectApplicationPropsToMcpRequest(proxy, context));
this.proxy = proxy;
}

@Override
Expand Down Expand Up @@ -346,8 +350,14 @@ private void handleProxyResponse(HttpClientResponse proxyResponse) {
}

private void handleSseProxyResponse(HttpClientResponse proxyResponse) {
BufferingReadStream.BaseEventListener eventListener = null;
if (requireToolFiltering()) {
FilterAllowedToolsFn fn = new FilterAllowedToolsFn(proxy, context);
eventListener = new BufferingReadStream.BaseEventListener(fn);
}

BufferingReadStream proxyResponseStream = new BufferingReadStream(proxyResponse,
ProxyUtil.contentLength(proxyResponse, 1024));
ProxyUtil.contentLength(proxyResponse, 1024), eventListener);

context.setProxyResponse(proxyResponse);
context.setResponseStream(proxyResponseStream);
Expand Down Expand Up @@ -383,12 +393,12 @@ private void handleResponse() {
}

private void handleResponse(int responseStatus, Buffer proxyResponseBody) {
if ("tools/list".equalsIgnoreCase(mcpMethodName) && useAllowedTools) {
Future<Buffer> future;
if (requireToolFiltering()) {
try (InputStream stream = new ByteBufInputStream(proxyResponseBody.getByteBuf())) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);
if (filterToolList(tree)) {
proxyResponseBody = Buffer.buffer(ProxyUtil.MAPPER.writeValueAsBytes(tree));
}
JsonNode tree = ProxyUtil.MAPPER.readTree(stream);
FilterAllowedToolsFn fn = new FilterAllowedToolsFn(proxy, context);
future = fn.apply(tree).map(result -> Buffer.buffer(result.toString()));
} catch (Throwable e) {
if (e instanceof HttpException httpException) {
respond(httpException.getStatus(), httpException.getMessage());
Expand All @@ -398,38 +408,21 @@ private void handleResponse(int responseStatus, Buffer proxyResponseBody) {
log.warn("Can't process JSON response body. Error:", e);
return;
}
} else {
future = Future.succeededFuture(proxyResponseBody);
}
context.setResponseBody(proxyResponseBody);
respond(responseStatus, proxyResponseBody);
logStore.save(context);
}

private boolean filterToolList(ObjectNode body) {
ArrayNode tools = (ArrayNode) Optional.ofNullable(body.get("result")).map(result -> result.get("tools"))
.filter(JsonNode::isArray).orElse(EMPTY_JSON_ARRAY);
List<String> allowedTools = getAllowedTools(context.getDeployment());
if (allowedTools.isEmpty()) {
return false;
}
boolean modified = false;
for (Iterator<JsonNode> iter = tools.iterator(); iter.hasNext();) {
JsonNode tool = iter.next();
String name = tool.get("name").asText();
if (!allowedTools.contains(name)) {
iter.remove();
modified = true;
}
}
return modified;
future.onSuccess(result -> {
context.setResponseBody(result);
respond(responseStatus, result);
logStore.save(context);
}).onFailure(error -> {
log.error("Failed to handle MCP response body", error);
respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to handle MCP response body");
});
}

private List<String> getAllowedTools(Deployment deployment) {
if (deployment instanceof ToolSet toolSet) {
return toolSet.getAllowedTools();
} else if (deployment instanceof Application application) {
return application.getMcp().getAllowedTools();
}
throw new IllegalArgumentException("Unsupported deployment type: " + deployment.getName());
private boolean requireToolFiltering() {
return "tools/list".equalsIgnoreCase(mcpMethodName) && useAllowedTools;
}

private Future<?> handleRateLimitSuccess(Deployment deployment) {
Expand Down Expand Up @@ -562,4 +555,5 @@ protected void finalizeRequest() {
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ protected Future<Void> handleProxyResponseBody(Buffer responseBody) {
try (InputStream stream = new ByteBufInputStream(responseBody.getByteBuf())) {
ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream);
var fn = new CollectResponseCustomAttachmentsFn(proxy, context);
return fn.apply(tree);
return fn.apply(tree).map(ignored -> null);
} catch (Throwable e) {
log.warn("Can't parse JSON response body. Error:", e);
return Future.failedFuture(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.JsonNode;
import io.vertx.core.Future;

public abstract class BaseResponseFunction extends BaseFunction<ObjectNode, Future<Void>> {
public abstract class BaseResponseFunction extends BaseFunction<JsonNode, Future<JsonNode>> {
public BaseResponseFunction(Proxy proxy, ProxyContext context) {
super(proxy, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.storage.resource.ResourceDescriptor;
import com.epam.aidial.core.storage.resource.ResourceUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.vertx.core.Future;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -24,10 +25,14 @@ public CollectResponseAttachmentsFn(Proxy proxy, ProxyContext context) {
}

@Override
public Future<Void> apply(ObjectNode tree) {
public Future<JsonNode> apply(JsonNode tree) {
if (!tree.isObject()) {
return Future.succeededFuture(tree);
}
ObjectNode objectNode = (ObjectNode) tree;
try {
Map<String, Set<ResourceAccessType>> permittedAttachments = new HashMap<>();
Set<String> attachments = collectAttachments(tree);
Set<String> attachments = collectAttachments(objectNode);
for (String attachment : attachments) {
processAttachedFile(attachment, permittedAttachments);
}
Expand All @@ -37,7 +42,7 @@ public Future<Void> apply(ObjectNode tree) {
String perRequestKey = context.getApiKeyData().getPerRequestKey();
return proxy.getTaskExecutor().submit(() -> {
proxy.getApiKeyStore().updatePerRequestApiKey(perRequestKey, json -> updateAutoSharedAttachments(json, permittedAttachments, perRequestKey));
return null;
return tree;
});
} catch (Throwable e) {
return Future.failedFuture(e);
Expand Down
Loading
Loading