From 89a09dee71682866d2b22198098876cebb794d16 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Thu, 23 Jan 2025 15:00:02 +0800 Subject: [PATCH] feat: Support customizing path, header and URL param predicates in AI Route (#414) --- .../sdk/constant/HigressConstants.java | 1 + .../alibaba/higress/sdk/model/ai/AiRoute.java | 23 ++++ .../sdk/model/route/RoutePredicate.java | 9 +- .../sdk/service/ai/AiRouteServiceImpl.java | 39 +++++-- frontend/src/interfaces/ai-route.ts | 1 + frontend/src/interfaces/llm-provider.ts | 1 + .../pages/ai/components/RouteForm/index.tsx | 100 +++++++++++++++++- frontend/src/pages/ai/provider.tsx | 3 +- frontend/src/pages/ai/route.tsx | 1 + 9 files changed, 161 insertions(+), 17 deletions(-) diff --git a/backend/sdk/src/main/java/com/alibaba/higress/sdk/constant/HigressConstants.java b/backend/sdk/src/main/java/com/alibaba/higress/sdk/constant/HigressConstants.java index f160c110..a48293f0 100644 --- a/backend/sdk/src/main/java/com/alibaba/higress/sdk/constant/HigressConstants.java +++ b/backend/sdk/src/main/java/com/alibaba/higress/sdk/constant/HigressConstants.java @@ -24,4 +24,5 @@ public class HigressConstants { public static final String INTERNAL_RESOURCE_NAME_SUFFIX = ".internal"; public static final String FALLBACK_ROUTE_NAME_SUFFIX = ".fallback"; public static final String FALLBACK_FROM_HEADER = "x-higress-fallback-from"; + public static final String MODEL_ROUTING_HEADER = "x-higress-llm-model"; } diff --git a/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/ai/AiRoute.java b/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/ai/AiRoute.java index 21a929e1..612f9aee 100644 --- a/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/ai/AiRoute.java +++ b/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/ai/AiRoute.java @@ -17,7 +17,11 @@ import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import com.alibaba.higress.sdk.constant.HigressConstants; import com.alibaba.higress.sdk.exception.ValidationException; +import com.alibaba.higress.sdk.model.route.KeyedRoutePredicate; +import com.alibaba.higress.sdk.model.route.RoutePredicate; +import com.alibaba.higress.sdk.model.route.RoutePredicateTypeEnum; import io.swagger.annotations.ApiModel; import lombok.AllArgsConstructor; @@ -35,6 +39,9 @@ public class AiRoute { private String name; private String version; private List domains; + private RoutePredicate pathPredicate; + private List headerPredicates; + private List urlParamPredicates; private List upstreams; private List modelPredicates; private AiRouteAuthConfig authConfig; @@ -47,6 +54,22 @@ public void validate() { if (CollectionUtils.isEmpty(upstreams)) { throw new ValidationException("upstreams cannot be empty."); } + if (pathPredicate != null) { + pathPredicate.validate(); + if (pathPredicate.getPredicateType() != RoutePredicateTypeEnum.PRE) { + throw new ValidationException("pathPredicate must be of type PRE."); + } + } + if (CollectionUtils.isNotEmpty(headerPredicates)) { + headerPredicates.forEach(KeyedRoutePredicate::validate); + if (headerPredicates.stream() + .anyMatch(p -> HigressConstants.MODEL_ROUTING_HEADER.equalsIgnoreCase(p.getKey()))) { + throw new ValidationException("headerPredicates cannot contain the model routing header."); + } + } + if (CollectionUtils.isNotEmpty(urlParamPredicates)) { + urlParamPredicates.forEach(KeyedRoutePredicate::validate); + } upstreams.forEach(AiUpstream::validate); if (authConfig != null) { authConfig.validate(); diff --git a/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/route/RoutePredicate.java b/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/route/RoutePredicate.java index c5153825..065c5e07 100644 --- a/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/route/RoutePredicate.java +++ b/backend/sdk/src/main/java/com/alibaba/higress/sdk/model/route/RoutePredicate.java @@ -12,6 +12,8 @@ */ package com.alibaba.higress.sdk.model.route; +import java.beans.Transient; + import com.alibaba.higress.sdk.exception.ValidationException; import lombok.AllArgsConstructor; @@ -34,11 +36,16 @@ public class RoutePredicate { private Boolean caseSensitive; + @Transient + public RoutePredicateTypeEnum getPredicateType() { + return RoutePredicateTypeEnum.fromName(this.getMatchType()); + } + public void validate() { if (this.getMatchType() == null) { throw new ValidationException("matchType is required"); } - RoutePredicateTypeEnum predicateType = RoutePredicateTypeEnum.fromName(this.getMatchType()); + RoutePredicateTypeEnum predicateType = getPredicateType(); if (predicateType == null) { throw new ValidationException("Unknown matchType: " + this.getMatchType()); } diff --git a/backend/sdk/src/main/java/com/alibaba/higress/sdk/service/ai/AiRouteServiceImpl.java b/backend/sdk/src/main/java/com/alibaba/higress/sdk/service/ai/AiRouteServiceImpl.java index 6fabaceb..6871ff32 100644 --- a/backend/sdk/src/main/java/com/alibaba/higress/sdk/service/ai/AiRouteServiceImpl.java +++ b/backend/sdk/src/main/java/com/alibaba/higress/sdk/service/ai/AiRouteServiceImpl.java @@ -20,7 +20,6 @@ import java.util.Map; import java.util.Optional; -import com.alibaba.higress.sdk.constant.plugin.config.AiStatisticsConfig; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; import org.apache.commons.io.IOUtils; @@ -30,6 +29,7 @@ import com.alibaba.higress.sdk.constant.HigressConstants; import com.alibaba.higress.sdk.constant.KubernetesConstants; import com.alibaba.higress.sdk.constant.plugin.BuiltInPluginName; +import com.alibaba.higress.sdk.constant.plugin.config.AiStatisticsConfig; import com.alibaba.higress.sdk.constant.plugin.config.ModelMapperConfig; import com.alibaba.higress.sdk.constant.plugin.config.ModelRouterConfig; import com.alibaba.higress.sdk.exception.BusinessException; @@ -71,7 +71,8 @@ public class AiRouteServiceImpl implements AiRouteService { private static final Map AI_ROUTE_LABEL_SELECTORS = Map.of(KubernetesConstants.Label.CONFIG_MAP_TYPE_KEY, KubernetesConstants.Label.CONFIG_MAP_TYPE_VALUE_AI_ROUTE); - private static final String MODEL_ROUTING_HEADER = "x-higress-llm-model"; + private static final RoutePredicate DEFAULT_PATH_PREDICATE = + new RoutePredicate(RoutePredicateTypeEnum.PRE.name(), "/", true); private final KubernetesModelConverter kubernetesModelConverter; @@ -127,7 +128,7 @@ public AiRoute add(AiRoute route) { writeAiRouteResources(route); writeAiRouteFallbackResources(route); - return kubernetesModelConverter.configMap2AiRoute(newConfigMap); + return configMap2AiRoute(newConfigMap); } @Override @@ -138,7 +139,7 @@ public PaginatedResult list(CommonPageQuery query) { } catch (ApiException e) { throw new BusinessException("Error occurs when listing ConfigMap.", e); } - return PaginatedResult.createFromFullList(configMaps, query, kubernetesModelConverter::configMap2AiRoute); + return PaginatedResult.createFromFullList(configMaps, query, this::configMap2AiRoute); } @Override @@ -150,7 +151,7 @@ public AiRoute query(String routeName) { } catch (ApiException e) { throw new BusinessException("Error occurs when reading the ConfigMap with name: " + configMapName, e); } - return Optional.ofNullable(configMap).map(kubernetesModelConverter::configMap2AiRoute).orElse(null); + return Optional.ofNullable(configMap).map(this::configMap2AiRoute).orElse(null); } @Override @@ -184,10 +185,21 @@ public AiRoute update(AiRoute route) { writeAiRouteResources(route); writeAiRouteFallbackResources(route); - return kubernetesModelConverter.configMap2AiRoute(updatedConfigMap); + return configMap2AiRoute(updatedConfigMap); + } + + private AiRoute configMap2AiRoute(V1ConfigMap configMap){ + AiRoute route = kubernetesModelConverter.configMap2AiRoute(configMap); + if (route != null){ + fillDefaultValues(route); + } + return route; } private void fillDefaultValues(AiRoute route) { + if (route.getPathPredicate() == null) { + route.setPathPredicate(DEFAULT_PATH_PREDICATE); + } fillDefaultWeights(route.getUpstreams()); AiRouteFallbackConfig fallbackConfig = route.getFallbackConfig(); if (fallbackConfig != null && Boolean.TRUE.equals(fallbackConfig.getEnabled())) { @@ -306,7 +318,7 @@ private void writeModelRouteResources(List modelPredicates) { instance.setConfigurations(configurations); } - configurations.put(ModelRouterConfig.MODEL_TO_HEADER, MODEL_ROUTING_HEADER); + configurations.put(ModelRouterConfig.MODEL_TO_HEADER, HigressConstants.MODEL_ROUTING_HEADER); wasmPluginInstanceService.addOrUpdate(instance); } @@ -379,12 +391,16 @@ private void writeAiStatisticsResources(String routeName) { private Route buildRoute(String routeName, AiRoute aiRoute) { Route route = new Route(); route.setName(routeName); - route.setPath(new RoutePredicate(RoutePredicateTypeEnum.PRE.name(), "/", true)); + route.setPath(Optional.ofNullable(aiRoute.getPathPredicate()).orElse(DEFAULT_PATH_PREDICATE)); route.setDomains(aiRoute.getDomains()); + List headerPredicates = new ArrayList<>(); + if (CollectionUtils.isNotEmpty(aiRoute.getHeaderPredicates())) { + headerPredicates.addAll(aiRoute.getHeaderPredicates()); + } List modelPredicates = aiRoute.getModelPredicates(); if (CollectionUtils.isNotEmpty(modelPredicates)) { - KeyedRoutePredicate headerRoutePredicate = new KeyedRoutePredicate(MODEL_ROUTING_HEADER); + KeyedRoutePredicate headerRoutePredicate = new KeyedRoutePredicate(HigressConstants.MODEL_ROUTING_HEADER); if (modelPredicates.size() == 1) { AiModelPredicate modelPredicate = modelPredicates.get(0); headerRoutePredicate.setMatchType(modelPredicate.getMatchType()); @@ -393,8 +409,11 @@ private Route buildRoute(String routeName, AiRoute aiRoute) { headerRoutePredicate.setMatchType(RoutePredicateTypeEnum.REGULAR.toString()); headerRoutePredicate.setMatchValue(buildModelRoutingHeaderRegex(modelPredicates)); } - route.setHeaders(List.of(headerRoutePredicate)); + headerPredicates.add(headerRoutePredicate); } + route.setHeaders(headerPredicates); + + route.setUrlParams(aiRoute.getUrlParamPredicates()); return route; } diff --git a/frontend/src/interfaces/ai-route.ts b/frontend/src/interfaces/ai-route.ts index a8f10711..cfec6c29 100644 --- a/frontend/src/interfaces/ai-route.ts +++ b/frontend/src/interfaces/ai-route.ts @@ -1,4 +1,5 @@ export interface AiRoute { + key?: string; name: string; version?: string; domains: string[]; diff --git a/frontend/src/interfaces/llm-provider.ts b/frontend/src/interfaces/llm-provider.ts index 9f51ad91..877c6be0 100644 --- a/frontend/src/interfaces/llm-provider.ts +++ b/frontend/src/interfaces/llm-provider.ts @@ -1,4 +1,5 @@ export interface LlmProvider { + key?: string; name: string; type: string; protocol?: string; diff --git a/frontend/src/pages/ai/components/RouteForm/index.tsx b/frontend/src/pages/ai/components/RouteForm/index.tsx index 4d74ee85..d764b1b0 100644 --- a/frontend/src/pages/ai/components/RouteForm/index.tsx +++ b/frontend/src/pages/ai/components/RouteForm/index.tsx @@ -1,18 +1,22 @@ import { Consumer } from '@/interfaces/consumer'; import { DEFAULT_DOMAIN, Domain } from '@/interfaces/domain'; import { LlmProvider } from '@/interfaces/llm-provider'; +import FactorGroup from '@/pages/route/components/FactorGroup'; import { getGatewayDomains } from '@/services'; import { getConsumers } from '@/services/consumer'; import { getLlmProviders } from '@/services/llm-provider'; import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; import { useRequest } from 'ahooks'; -import { AutoComplete, Button, Form, Input, InputNumber, Select, Space, Switch } from 'antd'; +import { AutoComplete, Button, Checkbox, Form, Input, InputNumber, Select, Space, Switch } from 'antd'; +import { uniqueId } from "lodash"; import React, { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { aiModelProviders } from '../../configs'; import { HistoryButton, RedoOutlinedBtn } from './Components'; -const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { +const { Option } = Select; + +const AiRouteForm: React.FC = forwardRef((props: { value: any }, ref) => { const { t } = useTranslation(); const { value } = props; const [form] = Form.useForm(); @@ -53,7 +57,7 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { consumerResult.run(); domainsResult.run(); form.resetFields(); - if (value) initForm(); + initForm(); return () => { setAuthConfigEnabled(false); setFallbackConfigEnabled(false); @@ -61,7 +65,15 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { }, []); const initForm = () => { - const { name = "", domains, upstreams = [], modelPredicates } = value; + const { + name = "", + domains, + pathPredicate = { matchType: 'PRE', matchValue: '/', caseSensitive: false }, + headerPredicates = [], + urlParamPredicates = [], + upstreams = [{}], + modelPredicates, + } = (value || {}); const _authConfig_enabled = value?.authConfig?.enabled || false; const _fallbackConfig_enabled = value?.fallbackConfig?.enabled || false; @@ -76,9 +88,20 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { fallbackInitValues['fallbackConfig_modelNames'] = ''; } } + + headerPredicates && headerPredicates.map((query) => { + return { ...query, uid: uniqueId() }; + }); + urlParamPredicates && urlParamPredicates.map((header) => { + return { ...header, uid: uniqueId() }; + }); + const initValues = { name, domains: domains?.length ? domains[0] : [], + pathPredicate, + headerPredicates, + urlParamPredicates, upstreams, authConfig_enabled: _authConfig_enabled, authConfig_allowedConsumers: value?.authConfig?.allowedConsumers || "", @@ -136,6 +159,9 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { const { name, domains, + pathPredicate, + headerPredicates, + urlParamPredicates, fallbackConfig_upstreams = '', authConfig_allowedConsumers = '', fallbackConfig_modelNames = '', @@ -146,6 +172,9 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { const payload = { name, domains: domains && !Array.isArray(domains) ? [domains] : domains, + pathPredicate, + headerPredicates, + urlParamPredicates, fallbackConfig: { enabled: fallbackConfig_enabled, }, @@ -246,6 +275,66 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => { + + + + + + + + + + + + + + + + + + + { <> {fields.map(({ key, name, ...restField }, index) => ( { ); }); -export default ConsumerForm; +export default AiRouteForm; diff --git a/frontend/src/pages/ai/provider.tsx b/frontend/src/pages/ai/provider.tsx index 9fba59ec..65258e2d 100644 --- a/frontend/src/pages/ai/provider.tsx +++ b/frontend/src/pages/ai/provider.tsx @@ -94,7 +94,7 @@ const LlmProviderList: React.FC = () => { if (!Array.isArray(value) || !value.length) { return '-'; } - return value.map((token) => ); + return value.map((token) => ); }, }, { @@ -124,6 +124,7 @@ const LlmProviderList: React.FC = () => { manual: true, onSuccess: (result) => { const llmProviders = (result || []) as LlmProvider[]; + llmProviders.forEach(r => { r.key = r.name; }); llmProviders.sort((i1, i2) => { return i1.name.localeCompare(i2.name); }) diff --git a/frontend/src/pages/ai/route.tsx b/frontend/src/pages/ai/route.tsx index 7e8e56e7..3f2fad77 100644 --- a/frontend/src/pages/ai/route.tsx +++ b/frontend/src/pages/ai/route.tsx @@ -109,6 +109,7 @@ const AiRouteList: React.FC = () => { manual: true, onSuccess: (result) => { const aiRoutes = (result || []) as AiRoute[]; + aiRoutes.forEach(r => { r.key = r.name; }); aiRoutes.sort((i1, i2) => { return i1.name.localeCompare(i2.name); })