Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor instance/region discovery #763

Merged
merged 4 commits into from
Dec 20, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,40 +62,39 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
boolean validateAuthority,
MsalRequest msalRequest,
ServiceBundle serviceBundle) {
String host = authorityUrl.getHost();

if (shouldUseRegionalEndpoint(msalRequest)) {
//Server side telemetry requires the result from region discovery when any part of the region API is used
String detectedRegion = discoverRegion(msalRequest, serviceBundle);

if (msalRequest.application().azureRegion() != null) {
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
}
String host = authorityUrl.getHost();

//If region autodetection is enabled and a specific region not already set,
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
}
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
//If instanceDiscovery flag set to false or instance discovery previously failed, do not do instance discovery
if (!msalRequest.application().instanceDiscovery() || instanceDiscoveryFailed) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it might be simpler to do smth like:

  • if a cached entry exists, use it
  • if not and instance discovery has been disabled, add the 1-host dumb entry to the cache. Same if instance discovery fails.

makes the code shorter and removes the need for having instanceDiscoveryFailed flag

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commit I rearranged it so this is now the pattern:

  1. If instance discovery is disabled, just add a default entry (if it doesn't already exist) and return it
  2. If a region is set at the application level, adjust the host variable to use it
  3. If no cached entry for host exists, then do instance discovery
  4. If any part of the region API is used, try to detect the region
  5. If region autodetect enabled and a region wasn't set at the application level, set the region at the application level and adjust the host to use it
  6. Do instance discovery for host, and add the default entry if it fails
  7. Create cache entries for whatever host is at the end of all that
  8. Return the cached result

This means on future runs the cached entry will be returned quicker, because:

  • Instance discovery was disabled, so step 1 returns immediately
  • Instance discovery failed and a default entry for host was cached, so it skips steps 3-7 and returns the cached entry
  • Instance discovery succeeded and either a regional or non-regional entry for host was cached, so it skips steps 3-7 and returns the cached entry

return InstanceDiscoveryMetadataEntry.builder().
preferredCache(host).
preferredNetwork(host).
aliases(Collections.singleton(host)).
build();
}

InstanceDiscoveryMetadataEntry result = cache.get(host);
//If there is no cached instance metadata, do instance and region discovery and cache the result
if (cache.get(host) == null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this top level logic needs logging statements, which will help investiagte issues.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commit I added debug logs for all the major if blocks: when instance discovery is disabled, when there's no cached metadata, when region API is used, etc.

if (shouldUseRegionalEndpoint(msalRequest)) {
//Server side telemetry requires the result from region discovery when any part of the region API is used
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the request provides an azureRegion, do you still need to do the discovery? Might be better to use the if(masalRequest.application().azureRegion() != null) logic before making this call.

Suggested change
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
if (msalRequest.application().azureRegion() != null) {
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
}
else if (msalRequest.application().autoDetectRegion() ) {
//Server side telemetry requires the result from region discovery when any part of the region API is used
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
msalRequest.application().azureRegion = detectedRegion;
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble finding the design doc, but there's a list of region telemetry values for different scenarios, including:

  • A developer set the region, and MSAL is able to detect the same one
  • A developer set the region, but MSAL found a different one
  • A developer set the region, but MSAL's autodetect failed for some reason

This means that we need to do autodetection whenever any part of the region API is used, unless we decide to just not collect this telemetry. I believe this behavior is the same in .NET and other MSALs too (feel free to correct me on that @bgavrilMS).

Our RegionTelemetry enum gives a good overview of all the different scenarios we look for: https://github.com/AzureAD/microsoft-authentication-library-for-java/blob/dev/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/RegionTelemetry.java

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we will try to auto-detect the region even if the app developer tells us "use region X". Auto-detection should only happen once per process (even if it fails) and I believe is set to timeout after 2 sec.

If this is too slow, we can discuss about removing this logic.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prior to these changes we were incorrectly doing autodetection every time a region was used, since the code to do it was happening before the code to check if instance metadata was cached.

With these changes autodetection happens only if no instance is cached, so it'll only happen once for a given authority/region and thus the 2 second timeout will also only happen once.


//If there is a specific region that should be used, adjust the authority URL to include it
// Otherwise, if region autodetection is enabled and a specific region was not set,
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
if (msalRequest.application().azureRegion() != null) {
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
} else if (msalRequest.application().autoDetectRegion() && detectedRegion != null) {
msalRequest.application().azureRegion = detectedRegion;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should you also update the host in this case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commit I refactored some of the logic here:

  1. The part in the if block has been moved up closer to the start of the method, so if the application's azureRegion value is set it will do the host = getRegionalizedHost... and the regional host will be used when looking through the cache
  2. The part in the else if block has been adjusted so the host = getRegionalizedHost... is done if autodetection succeeds and the azureRegion value was not already set

}

if (result == null) {
if(msalRequest.application().instanceDiscovery() && !instanceDiscoveryFailed){
doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
} else {
// instanceDiscovery flag is set to False. Do not perform instanceDiscovery.
return InstanceDiscoveryMetadataEntry.builder().
preferredCache(host).
preferredNetwork(host).
aliases(Collections.singleton(host)).
build();
cacheRegionInstanceMetadata(authorityUrl.getHost(), host);
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
}

doInstanceDiscoveryAndCache(authorityUrl, validateAuthority, msalRequest, serviceBundle);
}

return cache.get(host);
Expand Down Expand Up @@ -164,14 +163,13 @@ private static boolean shouldUseRegionalEndpoint(MsalRequest msalRequest){
return false;
}

static void cacheRegionInstanceMetadata(String host, String region) {
static void cacheRegionInstanceMetadata(String originalHost, String regionalHost) {

Set<String> aliases = new HashSet<>();
aliases.add(host);
String regionalHost = getRegionalizedHost(host, region);
aliases.add(originalHost);

cache.putIfAbsent(regionalHost, InstanceDiscoveryMetadataEntry.builder().
preferredCache(host).
preferredCache(originalHost).
preferredNetwork(regionalHost).
aliases(aliases).
build());
Expand Down Expand Up @@ -295,7 +293,7 @@ static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundl

//Check if the REGION_NAME environment variable has a value for the region
if (System.getenv(REGION_NAME) != null) {
log.info("Region found in environment variable: " + System.getenv(REGION_NAME));
log.info(String.format("Region found in environment variable: %s",System.getenv(REGION_NAME)));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_ENV_VARIABLE.telemetryValue);

return System.getenv(REGION_NAME);
Expand Down