Skip to content

Commit

Permalink
feat: Support customizing path, header and URL param predicates in AI…
Browse files Browse the repository at this point in the history
… Route (higress-group#414)
  • Loading branch information
CH3CHO authored Jan 23, 2025
1 parent 5b1309d commit 89a09de
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ public class HigressConstants {
public static final String INTERNAL_RESOURCE_NAME_SUFFIX = ".internal";
public static final String FALLBACK_ROUTE_NAME_SUFFIX = ".fallback";
public static final String FALLBACK_FROM_HEADER = "x-higress-fallback-from";
public static final String MODEL_ROUTING_HEADER = "x-higress-llm-model";
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import com.alibaba.higress.sdk.constant.HigressConstants;
import com.alibaba.higress.sdk.exception.ValidationException;
import com.alibaba.higress.sdk.model.route.KeyedRoutePredicate;
import com.alibaba.higress.sdk.model.route.RoutePredicate;
import com.alibaba.higress.sdk.model.route.RoutePredicateTypeEnum;

import io.swagger.annotations.ApiModel;
import lombok.AllArgsConstructor;
Expand All @@ -35,6 +39,9 @@ public class AiRoute {
private String name;
private String version;
private List<String> domains;
private RoutePredicate pathPredicate;
private List<KeyedRoutePredicate> headerPredicates;
private List<KeyedRoutePredicate> urlParamPredicates;
private List<AiUpstream> upstreams;
private List<AiModelPredicate> modelPredicates;
private AiRouteAuthConfig authConfig;
Expand All @@ -47,6 +54,22 @@ public void validate() {
if (CollectionUtils.isEmpty(upstreams)) {
throw new ValidationException("upstreams cannot be empty.");
}
if (pathPredicate != null) {
pathPredicate.validate();
if (pathPredicate.getPredicateType() != RoutePredicateTypeEnum.PRE) {
throw new ValidationException("pathPredicate must be of type PRE.");
}
}
if (CollectionUtils.isNotEmpty(headerPredicates)) {
headerPredicates.forEach(KeyedRoutePredicate::validate);
if (headerPredicates.stream()
.anyMatch(p -> HigressConstants.MODEL_ROUTING_HEADER.equalsIgnoreCase(p.getKey()))) {
throw new ValidationException("headerPredicates cannot contain the model routing header.");
}
}
if (CollectionUtils.isNotEmpty(urlParamPredicates)) {
urlParamPredicates.forEach(KeyedRoutePredicate::validate);
}
upstreams.forEach(AiUpstream::validate);
if (authConfig != null) {
authConfig.validate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package com.alibaba.higress.sdk.model.route;

import java.beans.Transient;

import com.alibaba.higress.sdk.exception.ValidationException;

import lombok.AllArgsConstructor;
Expand All @@ -34,11 +36,16 @@ public class RoutePredicate {

private Boolean caseSensitive;

@Transient
public RoutePredicateTypeEnum getPredicateType() {
return RoutePredicateTypeEnum.fromName(this.getMatchType());
}

public void validate() {
if (this.getMatchType() == null) {
throw new ValidationException("matchType is required");
}
RoutePredicateTypeEnum predicateType = RoutePredicateTypeEnum.fromName(this.getMatchType());
RoutePredicateTypeEnum predicateType = getPredicateType();
if (predicateType == null) {
throw new ValidationException("Unknown matchType: " + this.getMatchType());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Map;
import java.util.Optional;

import com.alibaba.higress.sdk.constant.plugin.config.AiStatisticsConfig;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.IOUtils;
Expand All @@ -30,6 +29,7 @@
import com.alibaba.higress.sdk.constant.HigressConstants;
import com.alibaba.higress.sdk.constant.KubernetesConstants;
import com.alibaba.higress.sdk.constant.plugin.BuiltInPluginName;
import com.alibaba.higress.sdk.constant.plugin.config.AiStatisticsConfig;
import com.alibaba.higress.sdk.constant.plugin.config.ModelMapperConfig;
import com.alibaba.higress.sdk.constant.plugin.config.ModelRouterConfig;
import com.alibaba.higress.sdk.exception.BusinessException;
Expand Down Expand Up @@ -71,7 +71,8 @@ public class AiRouteServiceImpl implements AiRouteService {
private static final Map<String, String> AI_ROUTE_LABEL_SELECTORS =
Map.of(KubernetesConstants.Label.CONFIG_MAP_TYPE_KEY, KubernetesConstants.Label.CONFIG_MAP_TYPE_VALUE_AI_ROUTE);

private static final String MODEL_ROUTING_HEADER = "x-higress-llm-model";
private static final RoutePredicate DEFAULT_PATH_PREDICATE =
new RoutePredicate(RoutePredicateTypeEnum.PRE.name(), "/", true);

private final KubernetesModelConverter kubernetesModelConverter;

Expand Down Expand Up @@ -127,7 +128,7 @@ public AiRoute add(AiRoute route) {
writeAiRouteResources(route);
writeAiRouteFallbackResources(route);

return kubernetesModelConverter.configMap2AiRoute(newConfigMap);
return configMap2AiRoute(newConfigMap);
}

@Override
Expand All @@ -138,7 +139,7 @@ public PaginatedResult<AiRoute> list(CommonPageQuery query) {
} catch (ApiException e) {
throw new BusinessException("Error occurs when listing ConfigMap.", e);
}
return PaginatedResult.createFromFullList(configMaps, query, kubernetesModelConverter::configMap2AiRoute);
return PaginatedResult.createFromFullList(configMaps, query, this::configMap2AiRoute);
}

@Override
Expand All @@ -150,7 +151,7 @@ public AiRoute query(String routeName) {
} catch (ApiException e) {
throw new BusinessException("Error occurs when reading the ConfigMap with name: " + configMapName, e);
}
return Optional.ofNullable(configMap).map(kubernetesModelConverter::configMap2AiRoute).orElse(null);
return Optional.ofNullable(configMap).map(this::configMap2AiRoute).orElse(null);
}

@Override
Expand Down Expand Up @@ -184,10 +185,21 @@ public AiRoute update(AiRoute route) {
writeAiRouteResources(route);
writeAiRouteFallbackResources(route);

return kubernetesModelConverter.configMap2AiRoute(updatedConfigMap);
return configMap2AiRoute(updatedConfigMap);
}

private AiRoute configMap2AiRoute(V1ConfigMap configMap){
AiRoute route = kubernetesModelConverter.configMap2AiRoute(configMap);
if (route != null){
fillDefaultValues(route);
}
return route;
}

private void fillDefaultValues(AiRoute route) {
if (route.getPathPredicate() == null) {
route.setPathPredicate(DEFAULT_PATH_PREDICATE);
}
fillDefaultWeights(route.getUpstreams());
AiRouteFallbackConfig fallbackConfig = route.getFallbackConfig();
if (fallbackConfig != null && Boolean.TRUE.equals(fallbackConfig.getEnabled())) {
Expand Down Expand Up @@ -306,7 +318,7 @@ private void writeModelRouteResources(List<AiModelPredicate> modelPredicates) {
instance.setConfigurations(configurations);
}

configurations.put(ModelRouterConfig.MODEL_TO_HEADER, MODEL_ROUTING_HEADER);
configurations.put(ModelRouterConfig.MODEL_TO_HEADER, HigressConstants.MODEL_ROUTING_HEADER);

wasmPluginInstanceService.addOrUpdate(instance);
}
Expand Down Expand Up @@ -379,12 +391,16 @@ private void writeAiStatisticsResources(String routeName) {
private Route buildRoute(String routeName, AiRoute aiRoute) {
Route route = new Route();
route.setName(routeName);
route.setPath(new RoutePredicate(RoutePredicateTypeEnum.PRE.name(), "/", true));
route.setPath(Optional.ofNullable(aiRoute.getPathPredicate()).orElse(DEFAULT_PATH_PREDICATE));
route.setDomains(aiRoute.getDomains());

List<KeyedRoutePredicate> headerPredicates = new ArrayList<>();
if (CollectionUtils.isNotEmpty(aiRoute.getHeaderPredicates())) {
headerPredicates.addAll(aiRoute.getHeaderPredicates());
}
List<AiModelPredicate> modelPredicates = aiRoute.getModelPredicates();
if (CollectionUtils.isNotEmpty(modelPredicates)) {
KeyedRoutePredicate headerRoutePredicate = new KeyedRoutePredicate(MODEL_ROUTING_HEADER);
KeyedRoutePredicate headerRoutePredicate = new KeyedRoutePredicate(HigressConstants.MODEL_ROUTING_HEADER);
if (modelPredicates.size() == 1) {
AiModelPredicate modelPredicate = modelPredicates.get(0);
headerRoutePredicate.setMatchType(modelPredicate.getMatchType());
Expand All @@ -393,8 +409,11 @@ private Route buildRoute(String routeName, AiRoute aiRoute) {
headerRoutePredicate.setMatchType(RoutePredicateTypeEnum.REGULAR.toString());
headerRoutePredicate.setMatchValue(buildModelRoutingHeaderRegex(modelPredicates));
}
route.setHeaders(List.of(headerRoutePredicate));
headerPredicates.add(headerRoutePredicate);
}
route.setHeaders(headerPredicates);

route.setUrlParams(aiRoute.getUrlParamPredicates());

return route;
}
Expand Down
1 change: 1 addition & 0 deletions frontend/src/interfaces/ai-route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export interface AiRoute {
key?: string;
name: string;
version?: string;
domains: string[];
Expand Down
1 change: 1 addition & 0 deletions frontend/src/interfaces/llm-provider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export interface LlmProvider {
key?: string;
name: string;
type: string;
protocol?: string;
Expand Down
100 changes: 95 additions & 5 deletions frontend/src/pages/ai/components/RouteForm/index.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import { Consumer } from '@/interfaces/consumer';
import { DEFAULT_DOMAIN, Domain } from '@/interfaces/domain';
import { LlmProvider } from '@/interfaces/llm-provider';
import FactorGroup from '@/pages/route/components/FactorGroup';
import { getGatewayDomains } from '@/services';
import { getConsumers } from '@/services/consumer';
import { getLlmProviders } from '@/services/llm-provider';
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
import { useRequest } from 'ahooks';
import { AutoComplete, Button, Form, Input, InputNumber, Select, Space, Switch } from 'antd';
import { AutoComplete, Button, Checkbox, Form, Input, InputNumber, Select, Space, Switch } from 'antd';
import { uniqueId } from "lodash";
import React, { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { aiModelProviders } from '../../configs';
import { HistoryButton, RedoOutlinedBtn } from './Components';

const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
const { Option } = Select;

const AiRouteForm: React.FC = forwardRef((props: { value: any }, ref) => {
const { t } = useTranslation();
const { value } = props;
const [form] = Form.useForm();
Expand Down Expand Up @@ -53,15 +57,23 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
consumerResult.run();
domainsResult.run();
form.resetFields();
if (value) initForm();
initForm();
return () => {
setAuthConfigEnabled(false);
setFallbackConfigEnabled(false);
}
}, []);

const initForm = () => {
const { name = "", domains, upstreams = [], modelPredicates } = value;
const {
name = "",
domains,
pathPredicate = { matchType: 'PRE', matchValue: '/', caseSensitive: false },
headerPredicates = [],
urlParamPredicates = [],
upstreams = [{}],
modelPredicates,
} = (value || {});
const _authConfig_enabled = value?.authConfig?.enabled || false;
const _fallbackConfig_enabled = value?.fallbackConfig?.enabled || false;

Expand All @@ -76,9 +88,20 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
fallbackInitValues['fallbackConfig_modelNames'] = '';
}
}

headerPredicates && headerPredicates.map((query) => {
return { ...query, uid: uniqueId() };
});
urlParamPredicates && urlParamPredicates.map((header) => {
return { ...header, uid: uniqueId() };
});

const initValues = {
name,
domains: domains?.length ? domains[0] : [],
pathPredicate,
headerPredicates,
urlParamPredicates,
upstreams,
authConfig_enabled: _authConfig_enabled,
authConfig_allowedConsumers: value?.authConfig?.allowedConsumers || "",
Expand Down Expand Up @@ -136,6 +159,9 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
const {
name,
domains,
pathPredicate,
headerPredicates,
urlParamPredicates,
fallbackConfig_upstreams = '',
authConfig_allowedConsumers = '',
fallbackConfig_modelNames = '',
Expand All @@ -146,6 +172,9 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
const payload = {
name,
domains: domains && !Array.isArray(domains) ? [domains] : domains,
pathPredicate,
headerPredicates,
urlParamPredicates,
fallbackConfig: {
enabled: fallbackConfig_enabled,
},
Expand Down Expand Up @@ -246,6 +275,66 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
</Form.Item>
<RedoOutlinedBtn getList={domainsResult} />
</div>
<Form.Item label={t('route.routeForm.path')} required>
<Input.Group compact>
<Form.Item
name={['pathPredicate', 'matchType']}
noStyle
rules={[
{
required: true,
message: t('route.routeForm.pathPredicatesRequired'),
},
]}
>
<Select
style={{ width: '20%' }}
placeholder={t('route.routeForm.matchType')}
>
<Option value="PRE">{t('route.matchTypes.PRE')}</Option>
</Select>
</Form.Item>
<Form.Item
name={['pathPredicate', 'matchValue']}
noStyle
rules={[
{
required: true,
message: t('route.routeForm.pathMatcherRequired'),
},
]}
>
<Input style={{ width: '60%' }} placeholder={t('route.routeForm.pathMatcherPlacedholder')} />
</Form.Item>
<Form.Item
name={['pathPredicate', 'ignoreCase']}
noStyle
>
<Checkbox.Group
options={[
{
label: t('route.routeForm.caseInsensitive'), value: 'ignore',
},
]}
style={{ width: '18%', display: 'inline-flex', marginLeft: 12, marginTop: 4 }}
/>
</Form.Item>
</Input.Group>
</Form.Item>
<Form.Item
label={t('route.routeForm.header')}
name="headerPredicates"
tooltip={t('route.routeForm.headerTooltip')}
>
<FactorGroup />
</Form.Item>
<Form.Item
label={t('route.routeForm.query')}
name="urlParamPredicates"
tooltip={t('route.routeForm.queryTooltip')}
>
<FactorGroup />
</Form.Item>
<Form.Item
style={{ marginBottom: 10 }}
label={t("aiRoute.routeForm.selectModelService")} // {/* 选择模型服务 */}
Expand Down Expand Up @@ -338,6 +427,7 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
<>
{fields.map(({ key, name, ...restField }, index) => (
<Form.Item
key={key}
label={t("aiRoute.routeForm.label.serviceName")}
{...restField}
name={[name, 'provider']}
Expand Down Expand Up @@ -526,4 +616,4 @@ const ConsumerForm: React.FC = forwardRef((props: { value: any }, ref) => {
);
});

export default ConsumerForm;
export default AiRouteForm;
3 changes: 2 additions & 1 deletion frontend/src/pages/ai/provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ const LlmProviderList: React.FC = () => {
if (!Array.isArray(value) || !value.length) {
return '-';
}
return value.map((token) => <EllipsisMiddle value={token} />);
return value.map((token) => <EllipsisMiddle key={token} value={token} />);
},
},
{
Expand Down Expand Up @@ -124,6 +124,7 @@ const LlmProviderList: React.FC = () => {
manual: true,
onSuccess: (result) => {
const llmProviders = (result || []) as LlmProvider[];
llmProviders.forEach(r => { r.key = r.name; });
llmProviders.sort((i1, i2) => {
return i1.name.localeCompare(i2.name);
})
Expand Down
Loading

0 comments on commit 89a09de

Please sign in to comment.