Skip to content

add configurable default routing group, add indicative errors, use ex… #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public class RoutingConfiguration

private boolean addXForwardedHeaders = true;

private String defaultRoutingGroup = "adhoc";

public Duration getAsyncTimeout()
{
return asyncTimeout;
Expand All @@ -42,4 +44,14 @@ public void setAddXForwardedHeaders(boolean addXForwardedHeaders)
{
this.addXForwardedHeaders = addXForwardedHeaders;
}

/**
 * Returns the configured default routing group name.
 *
 * @return the value of the {@code defaultRoutingGroup} field, which is initialised to {@code "adhoc"}
 */
public String getDefaultRoutingGroup()
{
    return this.defaultRoutingGroup;
}

/**
 * Replaces the default routing group name (the field starts out as {@code "adhoc"}).
 *
 * @param defaultRoutingGroup the new default routing group name; stored as given, without validation
 */
public void setDefaultRoutingGroup(String defaultRoutingGroup)
{
this.defaultRoutingGroup = defaultRoutingGroup;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import io.trino.gateway.ha.router.RoutingGroupSelector;
import io.trino.gateway.ha.router.RoutingManager;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Response;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -68,11 +71,23 @@ public RoutingTargetHandler(

/**
 * Resolves the destination that the incoming request should be proxied to.
 *
 * <p>The backend returned by {@code getPreviousBackend} is used when present; otherwise one is
 * selected through {@code getBackendFromRoutingGroup}.
 *
 * @param request the incoming client request
 * @return the value produced by {@code buildUriWithNewBackend} for the chosen backend
 * @throws WebApplicationException with status 404 and the original message as entity when backend
 *         selection raises {@link NotFoundException}; any other {@code WebApplicationException}
 *         (for example BAD_REQUEST from query validation) propagates unchanged
 */
public String getRoutingDestination(HttpServletRequest request)
{
    try {
        Optional<String> previousBackend = getPreviousBackend(request);
        String clusterHost = previousBackend.orElseGet(() -> getBackendFromRoutingGroup(request));
        logRewrite(clusterHost, request);
        return buildUriWithNewBackend(clusterHost, request);
    }
    catch (NotFoundException e) {
        // Rewrap so the response carries the message as its entity, keeping the
        // original exception as the cause so the stack trace is not lost.
        throw new WebApplicationException(
                e,
                Response.status(Response.Status.NOT_FOUND)
                        .entity(e.getMessage())
                        .build());
    }
    // No catch-and-rethrow for other WebApplicationExceptions: they propagate as-is.
}

public boolean isPathWhiteListed(String path)
Expand All @@ -92,7 +107,7 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
String routingGroup = routingGroupSelector.findRoutingGroup(request);
String user = request.getHeader(USER_HEADER);
if (!isNullOrEmpty(routingGroup)) {
// This falls back on adhoc backend if there is no cluster found for the routing group.
// This falls back on default routing group backend if there is no cluster found for the routing group.
return routingManager.provideBackendForRoutingGroup(routingGroup, user);
}
return routingManager.provideAdhocBackend(user);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class QueryCountBasedRouterProvider
/**
 * Builds the provider and its {@link QueryCountBasedRouter}.
 *
 * @param configuration gateway configuration; its routing section is handed to the router
 */
public QueryCountBasedRouterProvider(HaGatewayConfiguration configuration)
{
    super(configuration);
    // Single assignment: the router receives the routing configuration alongside the managers.
    routingManager = new QueryCountBasedRouter(gatewayBackendManager, queryHistoryManager, configuration.getRouting());
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public RouterBaseModule(HaGatewayConfiguration configuration)
Jdbi jdbi = Jdbi.create(configuration.getDataStore().getJdbcUrl(), configuration.getDataStore().getUser(), configuration.getDataStore().getPassword());
connectionManager = new JdbcConnectionManager(jdbi, configuration.getDataStore());
resourceGroupsManager = new HaResourceGroupsManager(connectionManager);
gatewayBackendManager = new HaGatewayManager(jdbi);
gatewayBackendManager = new HaGatewayManager(jdbi, configuration.getRouting());
queryHistoryManager = new HaQueryHistoryManager(jdbi, configuration.getDataStore().getJdbcUrl().startsWith("jdbc:oracle"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class StochasticRoutingManagerProvider
/**
 * Builds the provider and its {@link StochasticRoutingManager}.
 *
 * @param configuration gateway configuration; its routing section is handed to the routing manager
 */
public StochasticRoutingManagerProvider(HaGatewayConfiguration configuration)
{
    super(configuration);
    // Single assignment: always use the constructor that takes the routing configuration.
    routingManager = new StochasticRoutingManager(gatewayBackendManager, queryHistoryManager, configuration.getRouting());
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ public interface GatewayBackendDao
""")
List<GatewayBackend> findActiveBackend();

/**
* @deprecated Use {@link #findActiveBackendByRoutingGroup(String)} with the configured default routing group
*/
@Deprecated
@SqlQuery("""
SELECT * FROM gateway_backend
WHERE active = true AND routing_group = 'adhoc'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import io.trino.gateway.ha.router.schema.RoutingGroupExternalBody;
import io.trino.gateway.ha.router.schema.RoutingGroupExternalResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Response;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -98,10 +100,18 @@ public String findRoutingGroup(HttpServletRequest servletRequest)
throw new RuntimeException("Unexpected response: null");
}
else if (response.errors() != null && !response.errors().isEmpty()) {
throw new RuntimeException("Response with error: " + String.join(", ", response.errors()));
log.warn("Query validation failed with errors: %s", String.join(", ", response.errors()));
throw new WebApplicationException(
Response.status(Response.Status.BAD_REQUEST)
.entity(response.errors())
.build());
}
return response.routingGroup();
}
catch (WebApplicationException e) {
// Re-throw WebApplicationException to preserve status and entity
throw e;
}
catch (Exception e) {
log.error(e, "Error occurred while retrieving routing group "
+ "from external routing rules processing at " + uri);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public interface GatewayBackendManager

List<ProxyBackendConfiguration> getAllActiveBackends();

/**
* @deprecated Use {@link #getActiveBackends(String)} with the configured default routing group
*/
@Deprecated
List<ProxyBackendConfiguration> getActiveAdhocBackends();

List<ProxyBackendConfiguration> getActiveBackends(String routingGroup);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import io.trino.gateway.ha.config.RoutingConfiguration;
import io.trino.gateway.ha.persistence.dao.GatewayBackend;
import io.trino.gateway.ha.persistence.dao.GatewayBackendDao;
import org.jdbi.v3.core.Jdbi;
Expand All @@ -32,10 +33,12 @@ public class HaGatewayManager
private static final Logger log = Logger.get(HaGatewayManager.class);

private final GatewayBackendDao dao;
private final String defaultRoutingGroup;

/**
 * Creates the manager backed by an on-demand {@link GatewayBackendDao}.
 *
 * @param jdbi the Jdbi instance used to obtain the DAO; must not be null
 * @param routingConfiguration source of the default routing group name; must not be null
 * @throws NullPointerException if either argument is null
 */
public HaGatewayManager(Jdbi jdbi, RoutingConfiguration routingConfiguration)
{
    dao = requireNonNull(jdbi, "jdbi is null").onDemand(GatewayBackendDao.class);
    this.defaultRoutingGroup = requireNonNull(routingConfiguration, "routingConfiguration is null").getDefaultRoutingGroup();
}

@Override
Expand All @@ -56,11 +59,10 @@ public List<ProxyBackendConfiguration> getAllActiveBackends()
/**
 * Returns the active backends of the configured default routing group.
 *
 * <p>Delegates to {@code getActiveBackends(defaultRoutingGroup)}. Any exception raised by that call
 * is logged and an empty list is returned instead.
 *
 * @return the active backends for the default routing group, or an empty list on failure
 */
public List<ProxyBackendConfiguration> getActiveAdhocBackends()
{
    try {
        return getActiveBackends(defaultRoutingGroup);
    }
    catch (Exception e) {
        // Best-effort lookup: report the failure and fall through to the empty result.
        log.info("Error fetching backends for default routing group: %s", e.getLocalizedMessage());
    }
    return ImmutableList.of();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.airlift.log.Logger;
import io.trino.gateway.ha.clustermonitor.ClusterStats;
import io.trino.gateway.ha.clustermonitor.TrinoStatus;
import io.trino.gateway.ha.config.RoutingConfiguration;

import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -35,6 +36,7 @@ public class QueryCountBasedRouter
private static final Logger log = Logger.get(QueryCountBasedRouter.class);
@GuardedBy("this")
private List<LocalStats> clusterStats;
private final String defaultRoutingGroup;

@VisibleForTesting
synchronized List<LocalStats> clusterStats()
Expand Down Expand Up @@ -136,9 +138,11 @@ public void userQueuedCount(Map<String, Integer> userQueuedCount)

/**
 * Creates the router with an empty cluster-stats list.
 *
 * @param gatewayBackendManager passed through to the superclass constructor
 * @param queryHistoryManager passed through to the superclass constructor
 * @param routingConfiguration source of the default routing group name; it is dereferenced here,
 *        so it must not be null
 */
public QueryCountBasedRouter(
        GatewayBackendManager gatewayBackendManager,
        QueryHistoryManager queryHistoryManager,
        RoutingConfiguration routingConfiguration)
{
    super(gatewayBackendManager, queryHistoryManager, routingConfiguration);
    this.defaultRoutingGroup = routingConfiguration.getDefaultRoutingGroup();
    clusterStats = new ArrayList<>();
}

Expand Down Expand Up @@ -224,7 +228,7 @@ private synchronized Optional<String> getBackendForRoutingGroup(String routingGr
/**
 * Picks a backend from the configured default routing group.
 *
 * @param user the user passed on to {@code getBackendForRoutingGroup}
 * @return the backend chosen for the default routing group
 * @throws RouterException if no cluster is found for the default routing group
 */
@Override
public String provideAdhocBackend(String user)
{
    return getBackendForRoutingGroup(defaultRoutingGroup, user)
            .orElseThrow(() -> new RouterException("did not find any cluster for the default routing group: " + defaultRoutingGroup));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import io.trino.gateway.ha.clustermonitor.ClusterStats;
import io.trino.gateway.ha.clustermonitor.TrinoStatus;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import io.trino.gateway.ha.config.RoutingConfiguration;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.NotFoundException;

import java.net.HttpURLConnection;
import java.net.URL;
Expand All @@ -47,10 +49,12 @@ public abstract class RoutingManager
private final ExecutorService executorService = Executors.newFixedThreadPool(5);
private final GatewayBackendManager gatewayBackendManager;
private final ConcurrentHashMap<String, TrinoStatus> backendToStatus;
private final String defaultRoutingGroup;

public RoutingManager(GatewayBackendManager gatewayBackendManager)
public RoutingManager(GatewayBackendManager gatewayBackendManager, RoutingConfiguration routingConfiguration)
{
this.gatewayBackendManager = gatewayBackendManager;
this.defaultRoutingGroup = routingConfiguration.getDefaultRoutingGroup();
queryIdBackendCache =
CacheBuilder.newBuilder()
.maximumSize(10000)
Expand Down Expand Up @@ -83,7 +87,7 @@ public void setBackendForQueryId(String queryId, String backend)
*/
public String provideAdhocBackend(String user)
{
List<ProxyBackendConfiguration> backends = this.gatewayBackendManager.getActiveAdhocBackends();
List<ProxyBackendConfiguration> backends = this.gatewayBackendManager.getActiveBackends(defaultRoutingGroup);
backends.removeIf(backend -> isBackendNotHealthy(backend.getName()));
if (backends.size() == 0) {
throw new IllegalStateException("Number of active backends found zero");
Expand All @@ -100,6 +104,10 @@ public String provideBackendForRoutingGroup(String routingGroup, String user)
{
List<ProxyBackendConfiguration> backends =
gatewayBackendManager.getActiveBackends(routingGroup);
// Check if any backends exist for the routing group (even before filtering unhealthy ones)
if (backends.isEmpty()) {
throw new NotFoundException("Routing group does not exist: " + routingGroup);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message here does not reflect what this check is doing. I think something like "Cannot find any backends for routing group: " + routingGroup would be better.

}
backends.removeIf(backend -> isBackendNotHealthy(backend.getName()));
if (backends.isEmpty()) {
return provideAdhocBackend(user);
Expand Down Expand Up @@ -177,7 +185,7 @@ protected String findBackendForUnknownQueryId(String queryId)
log.warn("Query id [%s] not found", queryId);
}
// Fallback on first active backend if queryId mapping not found.
return gatewayBackendManager.getActiveAdhocBackends().get(0).getProxyTo();
return gatewayBackendManager.getActiveBackends(defaultRoutingGroup).get(0).getProxyTo();
}

// Predicate helper function to remove the backends from the list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Strings;
import io.airlift.log.Logger;
import io.trino.gateway.ha.config.RoutingConfiguration;

public class StochasticRoutingManager
extends RoutingManager
Expand All @@ -25,7 +26,15 @@ public class StochasticRoutingManager
/**
 * Creates a routing manager that uses the default routing configuration.
 *
 * <p>Delegates to the three-argument constructor with a freshly constructed
 * {@link RoutingConfiguration} rather than {@code null}, because the configuration is
 * dereferenced further up the constructor chain.
 * NOTE(review): assumes {@code RoutingConfiguration} has an accessible no-arg constructor — confirm.
 *
 * @param gatewayBackendManager passed through to the three-argument constructor
 * @param queryHistoryManager passed through to the three-argument constructor
 */
public StochasticRoutingManager(
        GatewayBackendManager gatewayBackendManager, QueryHistoryManager queryHistoryManager)
{
    this(gatewayBackendManager, queryHistoryManager, new RoutingConfiguration());
}

/**
 * Creates a routing manager with an explicit routing configuration.
 *
 * @param gatewayBackendManager forwarded to the superclass constructor
 * @param queryHistoryManager stored in this instance's {@code queryHistoryManager} field
 * @param routingConfiguration forwarded to the superclass constructor
 *        (NOTE(review): the superclass appears to dereference it, so null is presumably not allowed — confirm)
 */
public StochasticRoutingManager(
GatewayBackendManager gatewayBackendManager,
QueryHistoryManager queryHistoryManager,
RoutingConfiguration routingConfiguration)
{
super(gatewayBackendManager, routingConfiguration);
this.queryHistoryManager = queryHistoryManager;
}

Expand Down
Loading