Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,24 @@
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.nifi.annotation.behavior.DynamicProperties;
import org.apache.nifi.annotation.behavior.DynamicProperty;
import org.apache.nifi.annotation.behavior.SupportsSensitiveDynamicProperties;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnDisabled;
import org.apache.nifi.annotation.lifecycle.OnEnabled;
import org.apache.nifi.components.AllowableValue;
import org.apache.nifi.components.ConfigVerificationResult;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.PropertyValue;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.controller.AbstractControllerService;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.controller.VerifiableControllerService;
import org.apache.nifi.expression.AttributeExpression;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.migration.PropertyConfiguration;
Expand All @@ -58,14 +63,24 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

@SupportsSensitiveDynamicProperties
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
" Can use either Resource Owner Password Credentials Grant or Client Credentials Grant." +
" Client authentication can be done with either HTTP Basic authentication or in the request body.")
@DynamicProperties({
@DynamicProperty(
name = "FORM.Request parameter name",
value = "Request parameter value",
expressionLanguageScope = ExpressionLanguageScope.ENVIRONMENT,
description = "Custom parameters that should be added to the body of the request against the token endpoint."
)
})
public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider, VerifiableControllerService {
public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
.name("Authorization Server URL")
Expand Down Expand Up @@ -218,6 +233,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
);

private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String FORM_PREFIX = "FORM.";

public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
Expand All @@ -237,6 +253,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
private volatile String resource;
private volatile String audience;
private volatile long refreshWindowSeconds;
private volatile Map<String, String> customFormParameters = new HashMap<>();

private volatile AccessToken accessDetails;

Expand All @@ -262,6 +279,22 @@ public List<PropertyDescriptor> getSupportedPropertyDescriptors() {
return PROPERTY_DESCRIPTORS;
}

@Override
protected PropertyDescriptor getSupportedDynamicPropertyDescriptor(final String propertyDescriptorName) {
if (propertyDescriptorName.startsWith(FORM_PREFIX)) {
return new PropertyDescriptor.Builder()
.required(false)
.name(propertyDescriptorName)
.description("The value of the form parameter to add to the request body.")
.addValidator(StandardValidators.createAttributeExpressionLanguageValidator(AttributeExpression.ResultType.STRING, true))
.dynamic(true)
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.build();
}

return null;
}

@OnEnabled
public void onEnabled(ConfigurationContext context) {
getProperties(context);
Expand Down Expand Up @@ -396,6 +429,27 @@ private void getProperties(ConfigurationContext context) {
}

refreshWindowSeconds = context.getProperty(REFRESH_WINDOW).asTimePeriod(TimeUnit.SECONDS);

for (PropertyDescriptor descriptor : context.getProperties().keySet()) {
if (!descriptor.isDynamic() || !descriptor.getName().startsWith(FORM_PREFIX)) {
continue;
}

String parameterName = descriptor.getName().substring(FORM_PREFIX.length());
if (parameterName.isEmpty()) {
continue;
}

PropertyValue propertyValue = context.getProperty(descriptor);
if (propertyValue == null) {
continue;
}

String evaluatedValue = propertyValue.evaluateAttributeExpressions().getValue();
if (evaluatedValue != null) {
customFormParameters.put(parameterName, evaluatedValue);
}
}
}

private boolean isRefreshRequired() {
Expand Down Expand Up @@ -438,6 +492,7 @@ private void addFormData(FormBody.Builder formBuilder) {
if (audience != null) {
formBuilder.add("audience", audience);
}
customFormParameters.forEach(formBuilder::add);
}

private AccessToken requestToken(FormBody.Builder formBuilder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
import okhttp3.ResponseBody;
import okio.Buffer;
import org.apache.nifi.components.ConfigVerificationResult;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.controller.VerifiableControllerService;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.Processor;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.util.NoOpProcessor;
import org.apache.nifi.util.MockPropertyValue;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -49,7 +52,9 @@
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -123,6 +128,7 @@ protected ComponentLog getLogger() {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
when(mockContext.getProperties()).thenReturn(Collections.emptyMap());
}

@Nested
Expand Down Expand Up @@ -411,6 +417,46 @@ public void testRequestBodyFormData() throws Exception {
assertEquals(expected, buffer.readString(Charset.defaultCharset()));
}

@Test
public void testRequestBodyFormDataIncludesCustomParameters() throws Exception {
PropertyDescriptor accountIdDescriptor = new PropertyDescriptor.Builder()
.name("FORM.account_id")
.dynamic(true)
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.build();

when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
Map<PropertyDescriptor, String> properties = new HashMap<>();
properties.put(accountIdDescriptor, "12345");
when(mockContext.getProperties()).thenReturn(properties);
when(mockContext.getProperty(accountIdDescriptor)).thenReturn(new MockPropertyValue("12345"));

testSubject.onEnabled(mockContext);

Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);

testSubject.getAccessDetails();

verify(mockHttpClient, atLeast(1)).newCall(requestCaptor.capture());
FormBody formBody = (FormBody) requestCaptor.getValue().body();
assertNotNull(formBody);

Map<String, String> parameters = new HashMap<>();
for (int i = 0; i < formBody.size(); i++) {
parameters.put(formBody.encodedName(i), formBody.encodedValue(i));
}

assertEquals("client_credentials", parameters.get("grant_type"));
assertEquals(CLIENT_ID, parameters.get("client_id"));
assertEquals(CLIENT_SECRET, parameters.get("client_secret"));
assertEquals(SCOPE, parameters.get("scope"));
assertEquals(RESOURCE, parameters.get("resource"));
assertEquals(AUDIENCE, parameters.get("audience"));
assertEquals("12345", parameters.get("account_id"));
}

@Test
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
Expand Down