diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RoutingConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RoutingConfiguration.java index c3cdb631d..035218d3d 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RoutingConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RoutingConfiguration.java @@ -23,6 +23,8 @@ public class RoutingConfiguration private boolean addXForwardedHeaders = true; + private String defaultRoutingGroup = "adhoc"; + public Duration getAsyncTimeout() { return asyncTimeout; @@ -42,4 +44,14 @@ public void setAddXForwardedHeaders(boolean addXForwardedHeaders) { this.addXForwardedHeaders = addXForwardedHeaders; } + + public String getDefaultRoutingGroup() + { + return defaultRoutingGroup; + } + + public void setDefaultRoutingGroup(String defaultRoutingGroup) + { + this.defaultRoutingGroup = defaultRoutingGroup; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java index 4601e0566..6e42ed0df 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java @@ -19,6 +19,7 @@ public class RulesExternalConfiguration { private String urlPath; private List excludeHeaders; + private boolean propagateErrors; public String getUrlPath() { @@ -39,4 +40,14 @@ public void setExcludeHeaders(List excludeHeaders) { this.excludeHeaders = excludeHeaders; } + + public boolean isPropagateErrors() + { + return this.propagateErrors; + } + + public void setPropagateErrors(Boolean propagateErrors) + { + this.propagateErrors = propagateErrors; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java index cb995b273..9d9276cac 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java @@ -25,6 +25,9 @@ import io.trino.gateway.ha.router.schema.RoutingSelectorResponse; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.ws.rs.NotFoundException; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; import java.util.Arrays; import java.util.Collections; @@ -53,6 +56,7 @@ public class RoutingTargetHandler private static final Logger log = Logger.get(RoutingTargetHandler.class); private final RoutingManager routingManager; private final RoutingGroupSelector routingGroupSelector; + private final String defaultRoutingGroup; private final List statementPaths; private final List extraWhitelistPaths; private final boolean requestAnalyserClientsUseV2Format; @@ -67,6 +71,7 @@ public RoutingTargetHandler( { this.routingManager = requireNonNull(routingManager); this.routingGroupSelector = requireNonNull(routingGroupSelector); + this.defaultRoutingGroup = haGatewayConfiguration.getRouting().getDefaultRoutingGroup(); statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths()); extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList()); requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format(); @@ -76,28 +81,45 @@ public RoutingTargetHandler( public RoutingTargetResponse resolveRouting(HttpServletRequest request) { - Optional queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize); - Optional previousCluster = getPreviousCluster(queryId, request); - RoutingTargetResponse routingTargetResponse = previousCluster.map(cluster -> { - String routingGroup = queryId.map(routingManager::findRoutingGroupForQueryId) - .orElse("adhoc"); - return new RoutingTargetResponse( - new RoutingDestination(routingGroup, cluster, buildUriWithNewCluster(cluster, request)), - request); - }).orElse(getRoutingTargetResponse(request)); - logRewrite(routingTargetResponse.routingDestination().clusterHost(), request); - return routingTargetResponse; + try { + Optional queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize); + Optional previousCluster = getPreviousCluster(queryId, request); + + RoutingTargetResponse routingTargetResponse = previousCluster.map(cluster -> { + String routingGroup = queryId.map(routingManager::findRoutingGroupForQueryId) + .orElse(defaultRoutingGroup); + + return new RoutingTargetResponse( + new RoutingDestination(routingGroup, cluster, buildUriWithNewCluster(cluster, request)), + request); + }).orElse(getRoutingTargetResponse(request)); + + logRewrite(routingTargetResponse.routingDestination().clusterHost(), request); + return routingTargetResponse; + } + catch (NotFoundException e) { + throw new WebApplicationException( + Response.status(Response.Status.NOT_FOUND) + .entity(e.getMessage()) + .build()); + } } private RoutingTargetResponse getRoutingTargetResponse(HttpServletRequest request) { RoutingSelectorResponse routingDestination = routingGroupSelector.findRoutingDestination(request); String user = request.getHeader(USER_HEADER); - // This falls back on adhoc routing group if there is no cluster found (or value is empty) for the routing group. - String routingGroup = (routingDestination.routingGroup() != null && !routingDestination.routingGroup().isEmpty()) - ? routingDestination.routingGroup() - : "adhoc"; + String routingGroup; + + // This falls back on default routing group backend if there is no cluster found for the routing group. + if (!isNullOrEmpty(routingDestination.routingGroup())) { + routingGroup = routingDestination.routingGroup(); + } + else { + routingGroup = defaultRoutingGroup; + } String clusterHost = routingManager.provideClusterForRoutingGroup(routingGroup, user); + // Apply headers from RoutingDestination if there are any HttpServletRequest modifiedRequest = request; if (!routingDestination.externalHeaders().isEmpty()) { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/QueryCountBasedRouterProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/QueryCountBasedRouterProvider.java index b417fd19e..90bc0d513 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/QueryCountBasedRouterProvider.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/QueryCountBasedRouterProvider.java @@ -15,20 +15,25 @@ import com.google.inject.Scopes; import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import io.trino.gateway.ha.router.QueryCountBasedRouter; import io.trino.gateway.ha.router.RoutingManager; public class QueryCountBasedRouterProvider extends RouterBaseModule { + private final HaGatewayConfiguration configuration; + public QueryCountBasedRouterProvider(HaGatewayConfiguration configuration) { super(configuration); + this.configuration = configuration; } @Override public void configure() { + bind(RoutingConfiguration.class).toInstance(configuration.getRouting()); bind(RoutingManager.class).to(QueryCountBasedRouter.class).in(Scopes.SINGLETON); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/RouterBaseModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/RouterBaseModule.java index 3903e252b..97df5db3e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/RouterBaseModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/RouterBaseModule.java @@ -38,7 +38,7 @@ public RouterBaseModule(HaGatewayConfiguration configuration) Jdbi jdbi = Jdbi.create(configuration.getDataStore().getJdbcUrl(), configuration.getDataStore().getUser(), configuration.getDataStore().getPassword()); connectionManager = new JdbcConnectionManager(jdbi, configuration.getDataStore()); resourceGroupsManager = new HaResourceGroupsManager(connectionManager); - gatewayBackendManager = new HaGatewayManager(jdbi); + gatewayBackendManager = new HaGatewayManager(jdbi, configuration.getRouting()); queryHistoryManager = new HaQueryHistoryManager(jdbi, configuration.getDataStore().getJdbcUrl().startsWith("jdbc:oracle")); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/StochasticRoutingManagerProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/StochasticRoutingManagerProvider.java index 78c16cd32..3e0409795 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/StochasticRoutingManagerProvider.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/StochasticRoutingManagerProvider.java @@ -15,20 +15,25 @@ import com.google.inject.Scopes; import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import io.trino.gateway.ha.router.RoutingManager; import io.trino.gateway.ha.router.StochasticRoutingManager; public class StochasticRoutingManagerProvider extends RouterBaseModule { + private final HaGatewayConfiguration configuration; + public StochasticRoutingManagerProvider(HaGatewayConfiguration configuration) { super(configuration); + this.configuration = configuration; } @Override public void configure() { + bind(RoutingConfiguration.class).toInstance(configuration.getRouting()); bind(RoutingManager.class).to(StochasticRoutingManager.class).in(Scopes.SINGLETON); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/dao/GatewayBackendDao.java b/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/dao/GatewayBackendDao.java index 87009a03b..30429a950 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/dao/GatewayBackendDao.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/dao/GatewayBackendDao.java @@ -29,6 +29,10 @@ public interface GatewayBackendDao """) List findActiveBackend(); + /** + * @deprecated Use {@link #findActiveBackendByRoutingGroup(String)} with the configured default routing group + */ + @Deprecated @SqlQuery(""" SELECT * FROM gateway_backend WHERE active = true AND routing_group = 'adhoc' diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java index 0d1160325..3b04873df 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java @@ -29,6 +29,8 @@ import io.trino.gateway.ha.router.schema.RoutingGroupExternalBody; import io.trino.gateway.ha.router.schema.RoutingSelectorResponse; import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.net.URISyntaxException; @@ -52,6 +54,7 @@ public class ExternalRoutingGroupSelector private static final Logger log = Logger.get(ExternalRoutingGroupSelector.class); private final Set excludeHeaders; private final URI uri; + private final Boolean propagateErrors; private final HttpClient httpClient; private final RequestAnalyzerConfig requestAnalyzerConfig; private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider; @@ -67,6 +70,7 @@ public class ExternalRoutingGroupSelector .add("Content-Length") .addAll(rulesExternalConfiguration.getExcludeHeaders()) .build(); + this.propagateErrors = rulesExternalConfiguration.isPropagateErrors(); this.requestAnalyzerConfig = requestAnalyzerConfig; trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig); @@ -101,7 +105,13 @@ public RoutingSelectorResponse findRoutingDestination(HttpServletRequest servlet throw new RuntimeException("Unexpected response: null"); } else if (response.errors() != null && !response.errors().isEmpty()) { - throw new RuntimeException("Response with error: " + String.join(", ", response.errors())); + if (propagateErrors) { + log.warn("Query validation failed with errors: %s", String.join(", ", response.errors())); + throw new WebApplicationException( + Response.status(Response.Status.BAD_REQUEST) + .entity(response.errors()) + .build()); + } } // Filter out excluded headers and null values @@ -119,6 +129,10 @@ else if (response.errors() != null && !response.errors().isEmpty()) { } return new RoutingSelectorResponse(response.routingGroup(), filteredHeaders); } + catch (WebApplicationException e) { + // Re-throw WebApplicationException to preserve status and entity + throw e; + } catch (Exception e) { log.error(e, "Error occurred while retrieving routing group " + "from external routing rules processing at " + uri); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayBackendManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayBackendManager.java index 7016f49e6..e643e4d6c 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayBackendManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayBackendManager.java @@ -24,7 +24,7 @@ public interface GatewayBackendManager List getAllActiveBackends(); - List getActiveAdhocBackends(); + List getActiveDefaultBackends(); List getActiveBackends(String routingGroup); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/HaGatewayManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/HaGatewayManager.java index 732b06912..e83c8e4e6 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/HaGatewayManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/HaGatewayManager.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import io.trino.gateway.ha.persistence.dao.GatewayBackend; import io.trino.gateway.ha.persistence.dao.GatewayBackendDao; import org.jdbi.v3.core.Jdbi; @@ -32,10 +33,12 @@ public class HaGatewayManager private static final Logger log = Logger.get(HaGatewayManager.class); private final GatewayBackendDao dao; + private final String defaultRoutingGroup; - public HaGatewayManager(Jdbi jdbi) + public HaGatewayManager(Jdbi jdbi, RoutingConfiguration routingConfiguration) { dao = requireNonNull(jdbi, "jdbi is null").onDemand(GatewayBackendDao.class); + this.defaultRoutingGroup = requireNonNull(routingConfiguration, "routingConfiguration is null").getDefaultRoutingGroup(); } @Override @@ -53,14 +56,13 @@ public List getAllActiveBackends() } @Override - public List getActiveAdhocBackends() + public List getActiveDefaultBackends() { try { - List proxyBackendList = dao.findActiveAdhocBackend(); - return upcast(proxyBackendList); + return getActiveBackends(defaultRoutingGroup); } catch (Exception e) { - log.info("Error fetching all backends: %s", e.getLocalizedMessage()); + log.info("Error fetching backends for default routing group: %s", e.getLocalizedMessage()); } return ImmutableList.of(); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/QueryCountBasedRouter.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/QueryCountBasedRouter.java index 1d5092d9b..cae7a2af5 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/QueryCountBasedRouter.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/QueryCountBasedRouter.java @@ -21,6 +21,7 @@ import io.airlift.log.Logger; import io.trino.gateway.ha.clustermonitor.ClusterStats; import io.trino.gateway.ha.clustermonitor.TrinoStatus; +import io.trino.gateway.ha.config.RoutingConfiguration; import java.util.ArrayList; import java.util.Collections; @@ -36,6 +37,7 @@ public class QueryCountBasedRouter private static final Logger log = Logger.get(QueryCountBasedRouter.class); @GuardedBy("this") private List clusterStats; + private final String defaultRoutingGroup; @VisibleForTesting synchronized List clusterStats() @@ -138,9 +140,11 @@ public void userQueuedCount(Map userQueuedCount) @Inject public QueryCountBasedRouter( GatewayBackendManager gatewayBackendManager, - QueryHistoryManager queryHistoryManager) + QueryHistoryManager queryHistoryManager, + RoutingConfiguration routingConfiguration) { - super(gatewayBackendManager, queryHistoryManager); + super(gatewayBackendManager, queryHistoryManager, routingConfiguration); + this.defaultRoutingGroup = routingConfiguration.getDefaultRoutingGroup(); clusterStats = new ArrayList<>(); } @@ -224,16 +228,16 @@ private synchronized Optional getBackendForRoutingGroup(String routingGr } @Override - public String provideAdhocCluster(String user) + public String provideDefaultCluster(String user) { - return getBackendForRoutingGroup("adhoc", user).orElseThrow(() -> new RouterException("did not find any cluster for the adhoc routing group")); + return getBackendForRoutingGroup(defaultRoutingGroup, user).orElseThrow(() -> new RouterException("did not find any cluster for the default routing group: " + defaultRoutingGroup)); } @Override public String provideClusterForRoutingGroup(String routingGroup, String user) { return getBackendForRoutingGroup(routingGroup, user) - .orElse(provideAdhocCluster(user)); + .orElse(provideDefaultCluster(user)); } @Override diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java index 4d029ca8c..2fd0bdf10 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java @@ -20,6 +20,7 @@ import io.trino.gateway.ha.clustermonitor.ClusterStats; import io.trino.gateway.ha.clustermonitor.TrinoStatus; import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import jakarta.ws.rs.HttpMethod; import java.net.HttpURLConnection; @@ -47,12 +48,14 @@ public abstract class RoutingManager private final ExecutorService executorService = Executors.newFixedThreadPool(5); private final GatewayBackendManager gatewayBackendManager; private final ConcurrentHashMap backendToStatus; + private final String defaultRoutingGroup; private final LoadingCache queryIdRoutingGroupCache; private final QueryHistoryManager queryHistoryManager; - public RoutingManager(GatewayBackendManager gatewayBackendManager, QueryHistoryManager queryHistoryManager) + public RoutingManager(GatewayBackendManager gatewayBackendManager, QueryHistoryManager queryHistoryManager, RoutingConfiguration routingConfiguration) { this.gatewayBackendManager = gatewayBackendManager; + this.defaultRoutingGroup = routingConfiguration.getDefaultRoutingGroup(); this.queryHistoryManager = queryHistoryManager; queryIdBackendCache = CacheBuilder.newBuilder() @@ -100,13 +103,13 @@ public void setRoutingGroupForQueryId(String queryId, String routingGroup) } /** - * Performs routing to an adhoc backend. + * Performs routing to a default backend. */ - public String provideAdhocCluster(String user) + public String provideDefaultCluster(String user) { - List backends = this.gatewayBackendManager.getActiveAdhocBackends(); + List backends = gatewayBackendManager.getActiveDefaultBackends(); backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); - if (backends.size() == 0) { + if (backends.isEmpty()) { throw new IllegalStateException("Number of active backends found zero"); } int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); @@ -114,7 +117,7 @@ public String provideAdhocCluster(String user) } /** - * Performs routing to a given cluster group. This falls back to an adhoc backend, if no scheduled + * Performs routing to a given cluster group. This falls back to a default backend, if no scheduled * backend is found. */ public String provideClusterForRoutingGroup(String routingGroup, String user) @@ -123,7 +126,7 @@ public String provideClusterForRoutingGroup(String routingGroup, String user) gatewayBackendManager.getActiveBackends(routingGroup); backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); if (backends.isEmpty()) { - return provideAdhocCluster(user); + return provideDefaultCluster(user); } int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); return backends.get(backendId).getProxyTo(); @@ -214,7 +217,7 @@ protected String findBackendForUnknownQueryId(String queryId) log.warn("Query id [%s] not found", queryId); } // Fallback on first active backend if queryId mapping not found. - return gatewayBackendManager.getActiveAdhocBackends().get(0).getProxyTo(); + return gatewayBackendManager.getActiveBackends(defaultRoutingGroup).get(0).getProxyTo(); } /** diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java index f169b5592..cabaf53ad 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java @@ -16,6 +16,7 @@ import com.google.common.base.Strings; import com.google.inject.Inject; import io.airlift.log.Logger; +import io.trino.gateway.ha.config.RoutingConfiguration; public class StochasticRoutingManager extends RoutingManager @@ -23,11 +24,19 @@ public class StochasticRoutingManager private static final Logger log = Logger.get(StochasticRoutingManager.class); private final QueryHistoryManager queryHistoryManager; - @Inject public StochasticRoutingManager( GatewayBackendManager gatewayBackendManager, QueryHistoryManager queryHistoryManager) { - super(gatewayBackendManager, queryHistoryManager); + this(gatewayBackendManager, queryHistoryManager, null); + } + + @Inject + public StochasticRoutingManager( + GatewayBackendManager gatewayBackendManager, + QueryHistoryManager queryHistoryManager, + RoutingConfiguration routingConfiguration) + { + super(gatewayBackendManager, queryHistoryManager, routingConfiguration); this.queryHistoryManager = queryHistoryManager; } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestHaGatewayManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestHaGatewayManager.java index e13a071c9..c7d32c3b2 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestHaGatewayManager.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestHaGatewayManager.java @@ -14,6 +14,7 @@ package io.trino.gateway.ha.router; import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import io.trino.gateway.ha.persistence.JdbcConnectionManager; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -32,7 +33,8 @@ final class TestHaGatewayManager void setUp() { JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager(); - haGatewayManager = new HaGatewayManager(connectionManager.getJdbi()); + RoutingConfiguration routingConfiguration = new RoutingConfiguration(); + haGatewayManager = new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration); } @Test diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestQueryCountBasedRouter.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestQueryCountBasedRouter.java index cdcd34105..c3cff3115 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestQueryCountBasedRouter.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestQueryCountBasedRouter.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.trino.gateway.ha.clustermonitor.ClusterStats; import io.trino.gateway.ha.clustermonitor.TrinoStatus; +import io.trino.gateway.ha.config.RoutingConfiguration; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -155,7 +156,8 @@ public void init() .addAll(getClusterStatsList("etl")) .build(); - queryCountBasedRouter = new QueryCountBasedRouter(null, null); + RoutingConfiguration routingConfiguration = new RoutingConfiguration(); + queryCountBasedRouter = new QueryCountBasedRouter(null, null, routingConfiguration); queryCountBasedRouter.updateBackEndStats(clusters); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingManagerNotFound.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingManagerNotFound.java new file mode 100644 index 000000000..5f5453d0e --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingManagerNotFound.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import io.trino.gateway.ha.config.RoutingConfiguration; +import io.trino.gateway.ha.persistence.JdbcConnectionManager; +import jakarta.ws.rs.NotFoundException; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; + +import static io.trino.gateway.ha.TestingJdbcConnectionManager.createTestingJdbcConnectionManager; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@TestInstance(Lifecycle.PER_CLASS) +final class TestRoutingManagerNotFound +{ + private RoutingManager routingManager; + + @BeforeAll + void setUp() + { + JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager(); + RoutingConfiguration routingConfiguration = new RoutingConfiguration(); + routingConfiguration.setDefaultRoutingGroup("default"); + GatewayBackendManager backendManager = new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration); + QueryHistoryManager historyManager = new HaQueryHistoryManager(connectionManager.getJdbi(), false); + routingManager = new StochasticRoutingManager(backendManager, historyManager, routingConfiguration); + } + + @Test + void testNonExistentRoutingGroupThrowsNotFoundException() + { + // When requesting a non-existent routing group, a NotFoundException should be thrown + assertThatThrownBy(() -> routingManager.provideBackendForRoutingGroup("non_existent_group", "user")) + .isInstanceOf(NotFoundException.class) + .hasMessageContaining("Routing group does not exist: non_existent_group"); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestStochasticRoutingManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestStochasticRoutingManager.java index bd7bf67e6..38f60ab22 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestStochasticRoutingManager.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestStochasticRoutingManager.java @@ -15,6 +15,7 @@ import io.trino.gateway.ha.clustermonitor.TrinoStatus; import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.ha.config.RoutingConfiguration; import io.trino.gateway.ha.persistence.JdbcConnectionManager; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -35,9 +36,10 @@ final class TestStochasticRoutingManager void setUp() { JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager(); - backendManager = new HaGatewayManager(connectionManager.getJdbi()); + RoutingConfiguration routingConfiguration = new RoutingConfiguration(); + backendManager = new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration); historyManager = new HaQueryHistoryManager(connectionManager.getJdbi(), false); - haRoutingManager = new StochasticRoutingManager(backendManager, historyManager); + haRoutingManager = new StochasticRoutingManager(backendManager, historyManager, routingConfiguration); } @Test