Skip to content

Commit

Permalink
NIFI-14154 Add support for oAuth to GetWorkdayReport processor
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre Villard <[email protected]>

This closes apache#9631.
  • Loading branch information
sfc-gh-mgemra authored and pvillard31 committed Jan 15, 2025
1 parent fd327e6 commit 53911bb
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-web-client-provider-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-oauth2-provider-api</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-utils</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnScheduled;
import org.apache.nifi.components.AllowableValue;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.AbstractProcessor;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
Expand Down Expand Up @@ -90,6 +92,7 @@ public class GetWorkdayReport extends AbstractProcessor {
protected static final String GET_WORKDAY_REPORT_JAVA_EXCEPTION_MESSAGE = "getworkdayreport.java.exception.message";
protected static final String RECORD_COUNT = "record.count";
protected static final String BASIC_PREFIX = "Basic ";
protected static final String BEARER_PREFIX = "Bearer ";
protected static final String HEADER_AUTHORIZATION = "Authorization";
protected static final String HEADER_CONTENT_TYPE = "Content-Type";
protected static final String USERNAME_PASSWORD_SEPARATOR = ":";
Expand All @@ -103,10 +106,31 @@ public class GetWorkdayReport extends AbstractProcessor {
.addValidator(URL_VALIDATOR)
.build();

public static AllowableValue BASIC_AUTH_TYPE = new AllowableValue(
"BASIC_AUTH",
"Basic Auth",
"Used to access resources using Workday password and username."
);

public static AllowableValue OAUTH_TYPE = new AllowableValue(
"OAUTH",
"OAuth",
"Used to get fresh access tokens based on a previously acquired refresh token. Requires Client ID, Client Secret and Refresh Token."
);

public static final PropertyDescriptor AUTH_TYPE = new PropertyDescriptor.Builder()
.name("Authorization Type")
.description("The type of authorization for retrieving data from Workday resources.")
.required(true)
.allowableValues(BASIC_AUTH_TYPE, OAUTH_TYPE)
.defaultValue(BASIC_AUTH_TYPE.getValue())
.build();

protected static final PropertyDescriptor WORKDAY_USERNAME = new PropertyDescriptor.Builder()
.name("Workday Username")
.displayName("Workday Username")
.description("The username provided for authentication of Workday requests. Encoded using Base64 for HTTP Basic Authentication as described in RFC 7617.")
.dependsOn(AUTH_TYPE, BASIC_AUTH_TYPE)
.required(true)
.addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x39\\x3b-\\x7e\\x80-\\xff]+$")))
.expressionLanguageSupported(FLOWFILE_ATTRIBUTES)
Expand All @@ -116,6 +140,7 @@ public class GetWorkdayReport extends AbstractProcessor {
.name("Workday Password")
.displayName("Workday Password")
.description("The password provided for authentication of Workday requests. Encoded using Base64 for HTTP Basic Authentication as described in RFC 7617.")
.dependsOn(AUTH_TYPE, BASIC_AUTH_TYPE)
.required(true)
.sensitive(true)
.addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x7e\\x80-\\xff]+$")))
Expand All @@ -129,6 +154,14 @@ public class GetWorkdayReport extends AbstractProcessor {
.identifiesControllerService(WebClientServiceProvider.class)
.build();

public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
.name("Access Token Provider")
.description("Enables managed retrieval of OAuth2 Bearer Token.")
.dependsOn(AUTH_TYPE, OAUTH_TYPE)
.identifiesControllerService(OAuth2AccessTokenProvider.class)
.required(true)
.build();

protected static final PropertyDescriptor RECORD_READER_FACTORY = new PropertyDescriptor.Builder()
.name("record-reader")
.displayName("Record Reader")
Expand Down Expand Up @@ -170,6 +203,8 @@ public class GetWorkdayReport extends AbstractProcessor {

protected static final List<PropertyDescriptor> PROPERTIES = List.of(
REPORT_URL,
AUTH_TYPE,
OAUTH2_ACCESS_TOKEN_PROVIDER,
WORKDAY_USERNAME,
WORKDAY_PASSWORD,
WEB_CLIENT_SERVICE,
Expand All @@ -178,6 +213,7 @@ public class GetWorkdayReport extends AbstractProcessor {
);

private final AtomicReference<WebClientService> webClientReference = new AtomicReference<>();
private final AtomicReference<OAuth2AccessTokenProvider> tokenProviderReference = new AtomicReference<>();
private final AtomicReference<RecordReaderFactory> recordReaderFactoryReference = new AtomicReference<>();
private final AtomicReference<RecordSetWriterFactory> recordSetWriterFactoryReference = new AtomicReference<>();

Expand All @@ -193,10 +229,12 @@ public Set<Relationship> getRelationships() {

@OnScheduled
public void setUpClient(final ProcessContext context) {
OAuth2AccessTokenProvider tokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
WebClientServiceProvider standardWebClientServiceProvider = context.getProperty(WEB_CLIENT_SERVICE).asControllerService(WebClientServiceProvider.class);
RecordReaderFactory recordReaderFactory = context.getProperty(RECORD_READER_FACTORY).asControllerService(RecordReaderFactory.class);
RecordSetWriterFactory recordSetWriterFactory = context.getProperty(RECORD_WRITER_FACTORY).asControllerService(RecordSetWriterFactory.class);
WebClientService webClientService = standardWebClientServiceProvider.getWebClientService();
tokenProviderReference.set(tokenProvider);
webClientReference.set(webClientService);
recordReaderFactoryReference.set(recordReaderFactory);
recordSetWriterFactoryReference.set(recordSetWriterFactory);
Expand Down Expand Up @@ -296,10 +334,16 @@ private FlowFile createResponseFlowFile(FlowFile flowfile, ProcessSession sessio
}

private String createAuthorizationHeader(ProcessContext context, FlowFile flowfile) {
String userName = context.getProperty(WORKDAY_USERNAME).evaluateAttributeExpressions(flowfile).getValue();
String password = context.getProperty(WORKDAY_PASSWORD).evaluateAttributeExpressions(flowfile).getValue();
String base64Credential = Base64.getEncoder().encodeToString((userName + USERNAME_PASSWORD_SEPARATOR + password).getBytes(StandardCharsets.UTF_8));
return BASIC_PREFIX + base64Credential;
String authType = context.getProperty(AUTH_TYPE).getValue();
if (BASIC_AUTH_TYPE.getValue().equals(authType)) {
String userName = context.getProperty(WORKDAY_USERNAME).evaluateAttributeExpressions(flowfile).getValue();
String password = context.getProperty(WORKDAY_PASSWORD).evaluateAttributeExpressions(flowfile).getValue();
String base64Credential = Base64.getEncoder().encodeToString((userName + USERNAME_PASSWORD_SEPARATOR + password).getBytes(StandardCharsets.UTF_8));
return BASIC_PREFIX + base64Credential;
} else {
OAuth2AccessTokenProvider tokenProvider = tokenProviderReference.get();
return BEARER_PREFIX + tokenProvider.getAccessDetails().getAccessToken();
}
}

private TransformResult transformRecords(ProcessSession session, FlowFile flowfile, FlowFile responseFlowFile, InputStream responseBodyStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.net.URISyntaxException;
Expand All @@ -44,6 +46,7 @@
import org.apache.nifi.csv.CSVRecordSetWriter;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
import org.apache.nifi.json.JsonTreeReader;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.serialization.RecordReaderFactory;
import org.apache.nifi.serialization.RecordSetWriterFactory;
Expand All @@ -55,7 +58,9 @@
import org.apache.nifi.web.client.provider.service.StandardWebClientServiceProvider;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.Answers;

class GetWorkdayReportTest {

Expand Down Expand Up @@ -84,40 +89,92 @@ public void shutdownServer() throws IOException {
mockWebServer.shutdown();
}

@Test
public void testNotValidWithoutReportUrlProperty() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);

runner.assertNotValid();
}

@Test
public void testNotValidWithInvalidReportUrlProperty() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);
runner.setProperty(GetWorkdayReport.REPORT_URL, INVALID_URL);
runner.assertNotValid();
@Nested
class BasicAuthPropertiesValidation {
@Test
void testNotValidWithoutReportUrlProperty() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);

runner.assertNotValid();
}

@Test
void testNotValidWithInvalidReportUrlProperty() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);
runner.setProperty(GetWorkdayReport.REPORT_URL, INVALID_URL);
runner.assertNotValid();
}

@Test
void testNotValidWithoutUserName() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}

@Test
void testNotValidWithoutPassword() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}

@Test
void testNotValidWithoutWebClient() {
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}
}

@Test
public void testNotValidWithoutUserName() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_PASSWORD, PASSWORD);
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}

@Test
public void testNotValidWithoutPassword() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.WORKDAY_USERNAME, USER_NAME);
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
@Nested
class OAuthPropertiesValidation {
@BeforeEach
void setUp() {
runner.setProperty(GetWorkdayReport.AUTH_TYPE, GetWorkdayReport.OAUTH_TYPE);
}

@Test
void testNotValidWithoutOAuth2AccessTokenProvider() throws InitializationException {
withWebClientService();
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}

@Test
void testNotValidWithInvalidReportUrlProperty() throws InitializationException {
withWebClientService();
withAccessTokenProvider();
runner.setProperty(GetWorkdayReport.REPORT_URL, INVALID_URL);
runner.assertNotValid();
}

@Test
void testNotValidWithoutReportUrlProperty() throws InitializationException {
withWebClientService();
withAccessTokenProvider();

runner.assertNotValid();
}

@Test
void testNotValidWithoutWebClient() throws InitializationException {
withAccessTokenProvider();
runner.setProperty(GetWorkdayReport.REPORT_URL, REPORT_URL);

runner.assertNotValid();
}
}

@Test
Expand Down Expand Up @@ -296,6 +353,26 @@ void testContentIsTransformedIfRecordReaderAndWriterIsDefined() throws Initializ
flowFile.assertContentEquals(csvContent);
}

@Test
void testOAuthAuthorization() throws InitializationException, InterruptedException {
runner.setIncomingConnection(false);
withWebClientService();
runner.setProperty(GetWorkdayReport.REPORT_URL, getMockWebServerUrl());
runner.setProperty(GetWorkdayReport.AUTH_TYPE, GetWorkdayReport.OAUTH_TYPE);
withAccessTokenProvider();

mockWebServer.enqueue(new MockResponse().setResponseCode(200).setHeader(CONTENT_TYPE, APPLICATION_JSON));

runner.run();

RecordedRequest recordedRequest = mockWebServer.takeRequest(1, TimeUnit.SECONDS);
String authorization = recordedRequest.getHeader(HEADER_AUTHORIZATION);
assertNotNull(authorization, "Authorization Header not found");

Pattern bearerPattern = Pattern.compile("^Bearer \\S+$");
assertTrue(bearerPattern.matcher(authorization).matches(), "OAuth bearer not matched");
}

@Test
void testBasicAuthentication() throws InitializationException, InterruptedException {
runner.setIncomingConnection(false);
Expand All @@ -320,6 +397,19 @@ private String getMockWebServerUrl() {
return mockWebServer.url("workdayReport").newBuilder().host(LOCALHOST).build().toString();
}

private void withAccessTokenProvider() throws InitializationException {
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
String accessToken = "access_token";

OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);

runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
runner.setProperty(GetWorkdayReport.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
}

private void withWebClientService() throws InitializationException {
String serviceIdentifier = StandardWebClientServiceProvider.class.getName();
WebClientServiceProvider webClientServiceProvider = new StandardWebClientServiceProvider();
Expand Down

0 comments on commit 53911bb

Please sign in to comment.