diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-client-service-api/src/main/java/org/apache/nifi/graph/GraphClientService.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-client-service-api/src/main/java/org/apache/nifi/graph/GraphClientService.java index 4dd338b8068b..886f9bedbc57 100644 --- a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-client-service-api/src/main/java/org/apache/nifi/graph/GraphClientService.java +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-client-service-api/src/main/java/org/apache/nifi/graph/GraphClientService.java @@ -18,7 +18,9 @@ package org.apache.nifi.graph; import org.apache.nifi.controller.ControllerService; +import org.apache.nifi.util.Tuple; +import java.util.List; import java.util.Map; public interface GraphClientService extends ControllerService { @@ -30,6 +32,28 @@ public interface GraphClientService extends ControllerService { String PROPERTIES_SET = "graph.properties.set"; String ROWS_RETURNED = "graph.rows.returned"; - Map executeQuery(String query, Map parameters, GraphQueryResultCallback handler); + String NODES_TYPE = "Nodes"; + String EDGES_TYPE = "Edges"; + + /** + * Executes the given query using the parameters provided and returns a map of results. + * @param query The query to execute + * @param parameters The parameter values to be used in the query + * @param handler A callback handler to process the results of the query + * @return A map containing the results of the query execution, where the keys are column names and the values are the corresponding data. + */ + Map executeQuery(final String query, final Map parameters, final GraphQueryResultCallback handler); String getTransitUrl(); + + /** + * Generates a query/statement for setting properties on matched node(s) in the language associated with the Graph Database + * @param componentType The type of component that is executing the query, e.g. "Nodes", "Edges", etc. + * @param identifiersAndValues A tuple containing the name of and value for the property to match on, + * @param nodeType The type of node to match on, e.g. "Person", "Organization", etc. + * @param propertyMap A map of key/value pairs of property names and values to set on the matched node(s) + * @return A query/statement that can be executed against the Graph Database to set the properties on the matched node(s) + */ + default String generateSetPropertiesStatement(final String componentType, final List> identifiersAndValues, final String nodeType, final Map propertyMap) { + throw new UnsupportedOperationException("This capability is not implemented for this GraphClientService"); + } } \ No newline at end of file diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/java/org/apache/nifi/processors/graph/EnrichGraphRecord.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/java/org/apache/nifi/processors/graph/EnrichGraphRecord.java new file mode 100644 index 000000000000..ac93716785a8 --- /dev/null +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/java/org/apache/nifi/processors/graph/EnrichGraphRecord.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.graph; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.nifi.annotation.behavior.DynamicProperty; +import org.apache.nifi.annotation.behavior.InputRequirement; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; +import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnScheduled; +import org.apache.nifi.components.AllowableValue; +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.Validator; +import org.apache.nifi.expression.ExpressionLanguageScope; +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.graph.GraphClientService; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processor.Relationship; +import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.record.path.FieldValue; +import org.apache.nifi.record.path.RecordPath; +import org.apache.nifi.record.path.RecordPathResult; +import org.apache.nifi.record.path.util.RecordPathCache; +import org.apache.nifi.serialization.RecordReader; +import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.RecordSetWriter; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.util.Tuple; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +@Tags({"graph", "gremlin", "cypher", "enrich", "record"}) +@CapabilityDescription("This processor uses fields from FlowFile records to add property values to nodes in a graph. Each record is associated with an individual node " + + "(associated by the specified 'identifier' field value), and a single FlowFile will be output for all successful operations. Failed records will be sent as " + + "individual FlowFiles to the failure relationship.") +@WritesAttributes({ + @WritesAttribute(attribute = EnrichGraphRecord.GRAPH_OPERATION_TIME, description = "The amount of time it took to execute all of the graph operations."), +}) +@InputRequirement(InputRequirement.Requirement.INPUT_REQUIRED) +@DynamicProperty(name = "Field(s) containing values to be added to the matched node as properties", + value = "The variable name to be set", expressionLanguageScope = ExpressionLanguageScope.FLOWFILE_ATTRIBUTES, + description = "A dynamic property specifying a RecordField Expression identifying field(s) for whose values will be added to the matched node as properties") +public class EnrichGraphRecord extends AbstractGraphExecutor { + + private static final AllowableValue NODES = new AllowableValue( + GraphClientService.NODES_TYPE, + GraphClientService.NODES_TYPE, + "Enrich nodes in the graph with properties from the incoming records. The node identifier is determined by the 'Identifier Field(s)' property." + ); + + private static final AllowableValue EDGES = new AllowableValue( + GraphClientService.EDGES_TYPE, + GraphClientService.EDGES_TYPE, + "Enrich edges in the graph with properties from the incoming records. The edge identifier is determined by the 'Identifier Field(s)' property." + ); + + public static final PropertyDescriptor CLIENT_SERVICE = new PropertyDescriptor.Builder() + .name("Graph Client Service") + .description("The graph client service for connecting to a graph database.") + .identifiesControllerService(GraphClientService.class) + .addValidator(Validator.VALID) + .required(true) + .build(); + + public static final PropertyDescriptor READER_SERVICE = new PropertyDescriptor.Builder() + .name("Record Reader") + .description("The record reader to use with this processor to read incoming records.") + .identifiesControllerService(RecordReaderFactory.class) + .required(true) + .addValidator(Validator.VALID) + .build(); + + public static final PropertyDescriptor WRITER_SERVICE = new PropertyDescriptor.Builder() + .name("Failed Record Writer") + .description("The record writer to use for writing failed records.") + .identifiesControllerService(RecordSetWriterFactory.class) + .required(true) + .addValidator(Validator.VALID) + .build(); + + public static final PropertyDescriptor UPDATE_TYPE = new PropertyDescriptor.Builder() + .name("Components to Enrich") + .description("The components in the graph to enrich with properties from the incoming records.") + .addValidator(Validator.VALID) + .allowableValues(NODES, EDGES) + .defaultValue(NODES.getValue()) + .required(true) + .build(); + + public static final PropertyDescriptor IDENTIFIER_FIELD = new PropertyDescriptor.Builder() + .name("Identifier Field(s)") + .description("A RecordPath Expression for field(s) in the record used to match the node identifier(s) in order to set properties on that node") + .required(true) + .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) + .addValidator(StandardValidators.NON_EMPTY_EL_VALIDATOR) + .addValidator(StandardValidators.ATTRIBUTE_KEY_PROPERTY_NAME_VALIDATOR) + .build(); + + public static final PropertyDescriptor NODE_TYPE = new PropertyDescriptor.Builder() + .name("Node/Edge Type") + .description("The type of the nodes or edges to match on. Setting this can result in faster execution") + .required(false) + .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) + .addValidator(StandardValidators.NON_EMPTY_EL_VALIDATOR) + .build(); + + @Override + protected PropertyDescriptor getSupportedDynamicPropertyDescriptor(final String propertyDescriptorName) { + return new PropertyDescriptor.Builder() + .name(propertyDescriptorName) + .required(false) + .addValidator(StandardValidators.ATTRIBUTE_KEY_PROPERTY_NAME_VALIDATOR) + .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) + .dynamic(true) + .build(); + } + + public static final Relationship ORIGINAL = new Relationship.Builder().name("original") + .description("Original flow files that successfully interacted with " + + "graph server.") + .build(); + public static final Relationship FAILURE = new Relationship.Builder().name("failure") + .description("Flow files that fail to interact with graph server.") + .build(); + public static final Relationship GRAPH = new Relationship.Builder().name("response") + .description("The response object from the graph server.") + .autoTerminateDefault(true) + .build(); + + private static final List PROPERTY_DESCRIPTORS = List.of( + CLIENT_SERVICE, + READER_SERVICE, + WRITER_SERVICE, + UPDATE_TYPE, + IDENTIFIER_FIELD, + NODE_TYPE + ); + + private static final Set RELATIONSHIPS = Set.of( + ORIGINAL, + FAILURE, + GRAPH + ); + + public static final String RECORD_COUNT = "record.count"; + public static final String GRAPH_OPERATION_TIME = "graph.operations.took"; + private volatile RecordPathCache recordPathCache; + + @Override + public Set getRelationships() { + return RELATIONSHIPS; + } + + @Override + public List getSupportedPropertyDescriptors() { + return PROPERTY_DESCRIPTORS; + } + + private GraphClientService clientService; + private RecordReaderFactory recordReaderFactory; + private RecordSetWriterFactory recordSetWriterFactory; + private final ObjectMapper mapper = new ObjectMapper(); + + @Override + @OnScheduled + public void onScheduled(ProcessContext context) { + clientService = context.getProperty(CLIENT_SERVICE).asControllerService(GraphClientService.class); + recordReaderFactory = context.getProperty(READER_SERVICE).asControllerService(RecordReaderFactory.class); + recordSetWriterFactory = context.getProperty(WRITER_SERVICE).asControllerService(RecordSetWriterFactory.class); + recordPathCache = new RecordPathCache(100); + } + + private List getRecordValue(Record record, RecordPath recordPath) { + final RecordPathResult result = recordPath.evaluate(record); + final List values = result.getSelectedFields().toList(); + if (!values.isEmpty()) { + if (values.size() == 1) { + FieldValue fieldValue = values.get(0); + Object raw = fieldValue.getValue(); + + if (raw != null && raw.getClass().isArray()) { + return Collections.emptyList(); + } + + return List.of(fieldValue); + } else { + return values; + } + } else { + return null; + } + } + + @Override + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { + FlowFile input = session.get(); + if ( input == null ) { + return; + } + + Map dynamic = new HashMap<>(); + + FlowFile finalInput = input; + context.getProperties() + .keySet().stream() + .filter(PropertyDescriptor::isDynamic) + .forEach(it -> + dynamic.put(it.getName(), recordPathCache.getCompiled( + context + .getProperty(it.getName()) + .evaluateAttributeExpressions(finalInput) + .getValue())) + ); + + long delta; + FlowFile failedRecords = session.create(input); + WriteResult failedWriteResult = null; + try (InputStream is = session.read(input); + RecordReader reader = recordReaderFactory.createRecordReader(input, is, getLogger()); + OutputStream os = session.write(failedRecords); + RecordSetWriter failedWriter = recordSetWriterFactory.createWriter(getLogger(), reader.getSchema(), os, input.getAttributes()) + ) { + Record record; + long start = System.currentTimeMillis(); + failedWriter.beginRecordSet(); + int records = 0; + while ((record = reader.nextRecord()) != null) { + FlowFile graph = session.create(input); + + try { + Map dynamicPropertyMap = new HashMap<>(); + for (String entry : dynamic.keySet()) { + if (!dynamicPropertyMap.containsKey(entry)) { + final List propertyValues = getRecordValue(record, dynamic.get(entry)); + // Use the first value if multiple are found + if (propertyValues == null || propertyValues.isEmpty() || propertyValues.get(0).getValue() == null) { + throw new IOException("Dynamic property field(s) not found in record (check the RecordPath Expression), sending this record to failure"); + } + + dynamicPropertyMap.put(entry, propertyValues.get(0).getValue()); + } + } + + dynamicPropertyMap.putAll(input.getAttributes()); + getLogger().debug("Dynamic Properties: {}", dynamicPropertyMap); + final String identifierField = context.getProperty(IDENTIFIER_FIELD).evaluateAttributeExpressions(input).getValue(); + final String nodeType = context.getProperty(NODE_TYPE).evaluateAttributeExpressions(input).getValue(); + final RecordPath identifierPath = recordPathCache.getCompiled(identifierField); + final List identifierValues = getRecordValue(record, identifierPath); + if (identifierValues == null || identifierValues.isEmpty()) { + throw new IOException("Identifier field(s) not found in record (check the RecordPath Expression), sending this record to failure"); + } + List> identifiersAndValues = new ArrayList<>(identifierValues.size()); + for (FieldValue fieldValue: identifierValues) { + if (fieldValue.getValue() == null) { + throw new IOException(String.format("Identifier field '%s' is null for record at index %d, sending this record to failure", identifierField, records)); + } + identifiersAndValues.add(new Tuple<>(fieldValue.getField().getFieldName(), fieldValue.getValue().toString())); + } + + final String setStatement = clientService.generateSetPropertiesStatement( + GraphClientService.NODES_TYPE, + identifiersAndValues, + nodeType, + dynamicPropertyMap); + + List> graphResponses = new ArrayList<>(executeQuery(setStatement, dynamicPropertyMap)); + + OutputStream graphOutputStream = session.write(graph); + String graphOutput = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(graphResponses); + graphOutputStream.write(graphOutput.getBytes(StandardCharsets.UTF_8)); + graphOutputStream.close(); + session.transfer(graph, GRAPH); + } catch (Exception e) { + getLogger().error("Error processing record at index {}", records, e); + // write failed records to a flowfile destined for the failure relationship + failedWriter.write(record); + session.remove(graph); + } finally { + records++; + } + } + long end = System.currentTimeMillis(); + delta = (end - start) / 1000; + if (getLogger().isDebugEnabled()) { + getLogger().debug(String.format("Took %s seconds.\nHandled %d records", delta, records)); + } + failedWriteResult = failedWriter.finishRecordSet(); + failedWriter.flush(); + + } catch (Exception ex) { + getLogger().error("Error reading records, routing input FlowFile to failure", ex); + session.remove(failedRecords); + session.transfer(input, FAILURE); + return; + } + + // Generate provenance and send input flowfile to success + session.getProvenanceReporter().send(input, clientService.getTransitUrl(), delta * 1000); + + if (failedWriteResult.getRecordCount() < 1) { + // No failed records, remove the failure flowfile and send the input flowfile to success + session.remove(failedRecords); + input = session.putAttribute(input, GRAPH_OPERATION_TIME, String.valueOf(delta)); + session.transfer(input, ORIGINAL); + } else { + failedRecords = session.putAttribute(failedRecords, RECORD_COUNT, String.valueOf(failedWriteResult.getRecordCount())); + session.transfer(failedRecords, FAILURE); + // There were failures, don't send the input FlowFile to SUCCESS + session.remove(input); + } + } + + private List> executeQuery(String recordScript, Map parameters) { + ObjectMapper mapper = new ObjectMapper(); + List> graphResponses = new ArrayList<>(); + clientService.executeQuery(recordScript, parameters, (map, b) -> { + if (getLogger().isDebugEnabled()) { + try { + getLogger().debug(mapper.writerWithDefaultPrettyPrinter().writeValueAsString(map)); + } catch (JsonProcessingException ex) { + getLogger().error("Error converted map to JSON ", ex); + } + } + graphResponses.add(map); + }); + return graphResponses; + } +} diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor index 2ab7e95b40d1..57578cc49fd6 100644 --- a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor @@ -12,5 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +org.apache.nifi.processors.graph.EnrichGraphRecord org.apache.nifi.processors.graph.ExecuteGraphQuery org.apache.nifi.processors.graph.ExecuteGraphQueryRecord \ No newline at end of file diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/MockCypherClientService.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/MockCypherClientService.java index 0b55ef12bd93..dc3af0445758 100644 --- a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/MockCypherClientService.java +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/MockCypherClientService.java @@ -20,16 +20,20 @@ import org.apache.nifi.controller.AbstractControllerService; import org.apache.nifi.graph.GraphClientService; import org.apache.nifi.graph.GraphQueryResultCallback; +import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.util.Tuple; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; public class MockCypherClientService extends AbstractControllerService implements GraphClientService { @Override public Map executeQuery(String query, Map parameters, GraphQueryResultCallback handler) { - handler.process(Map.of("name", "John Smith", "age", 40), true); - handler.process(Map.of("name", "John Smith", "age", 40), false); + handler.process(Map.of("name", "John Smith", "age", 40, "relationship", "ASSOCIATED_WITH"), true); + handler.process(Map.of("name", "John Smith", "age", 40, "relationship", "ASSOCIATED_WITH"), false); Map resultAttributes = new HashMap<>(); resultAttributes.put(NODES_CREATED, String.valueOf(1)); @@ -47,4 +51,85 @@ public Map executeQuery(String query, Map parame public String getTransitUrl() { return "mock://localhost:12345/fake_database"; } + + @Override + public String generateSetPropertiesStatement(final String componentType, + final List> identifiersAndValues, + final String nodeType, + final Map propertyMap) { + + StringBuilder queryBuilder = switch (componentType) { + case GraphClientService.NODES_TYPE -> getNodeQueryBuilder(identifiersAndValues, nodeType); + case GraphClientService.EDGES_TYPE -> getEdgeQueryBuilder(identifiersAndValues, nodeType); + default -> throw new ProcessException("Unsupported component type: " + componentType); + }; + + queryBuilder.append(")\n") + .append("ON MATCH SET "); + + List setClauses = new ArrayList<>(); + for (Map.Entry entry : propertyMap.entrySet()) { + StringBuilder setClause = new StringBuilder("n.") + .append(entry.getKey()) + .append(" = "); + if (entry.getValue() == null) { + setClause.append(" NULL"); + } else { + setClause.append("'") + .append(entry.getValue()) + .append("'"); + } + setClauses.add(setClause.toString()); + } + String setClauseString = String.join(", ", setClauses); + queryBuilder.append(setClauseString) + .append("\nON CREATE SET ") + .append(setClauseString); + + return queryBuilder.toString(); + } + + private static StringBuilder getNodeQueryBuilder(List> identifiersAndValues, String nodeType) { + StringBuilder queryBuilder = new StringBuilder("MERGE (n"); + if (nodeType != null && !nodeType.isEmpty()) { + queryBuilder.append(":").append(nodeType); + } + + buildMatchClause(identifiersAndValues, queryBuilder); + return queryBuilder; + } + + private static StringBuilder getEdgeQueryBuilder(List> identifiersAndValues, String edgeType) { + StringBuilder queryBuilder = new StringBuilder("MERGE (n)<-[e:"); + + if (edgeType == null || edgeType.isEmpty()) { + throw new ProcessException("Edge type must not be null or empty"); + } + queryBuilder.append(edgeType); + + buildMatchClause(identifiersAndValues, queryBuilder); + queryBuilder.append("]-> (x)"); + return queryBuilder; + } + + private static void buildMatchClause(List> identifiersAndValues, StringBuilder queryBuilder) { + if (!identifiersAndValues.isEmpty()) { + queryBuilder.append(" {"); + + List identifierNamesAndValues = new ArrayList<>(); + for (Tuple identifierAndValue : identifiersAndValues) { + if (identifierAndValue == null || identifierAndValue.getKey() == null || identifierAndValue.getValue() == null) { + throw new ProcessException("Identifiers and values must not be null"); + } + + final String identifierName = identifierAndValue.getKey(); + final Object identifierObject = identifierAndValue.getValue(); + if (identifierObject != null) { + identifierNamesAndValues.add(identifierName + ": '" + identifierObject + "'"); + } + } + queryBuilder.append(String.join(", ", identifierNamesAndValues)) + .append("}"); + } + } } diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestEnrichGraphRecord.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestEnrichGraphRecord.java new file mode 100644 index 000000000000..5bbc1f886c17 --- /dev/null +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestEnrichGraphRecord.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.graph; + +import org.apache.nifi.graph.GraphClientService; +import org.apache.nifi.json.JsonTreeReader; +import org.apache.nifi.processor.Processor; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.RecordReader; +import org.apache.nifi.serialization.record.MockRecordWriter; +import org.apache.nifi.util.MockComponentLog; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.nifi.processors.graph.EnrichGraphRecord.CLIENT_SERVICE; +import static org.apache.nifi.processors.graph.EnrichGraphRecord.IDENTIFIER_FIELD; +import static org.apache.nifi.processors.graph.EnrichGraphRecord.NODE_TYPE; +import static org.apache.nifi.processors.graph.EnrichGraphRecord.READER_SERVICE; +import static org.apache.nifi.processors.graph.EnrichGraphRecord.WRITER_SERVICE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class TestEnrichGraphRecord { + + private TestRunner testRunner; + private JsonTreeReader reader; + private Processor processor; + + @BeforeEach + public void setup() throws Exception { + processor = new EnrichGraphRecord(); + testRunner = TestRunners.newTestRunner(processor); + + GraphClientService mockGraphClientService = new MockCypherClientService(); + MockRecordWriter writer = new MockRecordWriter(); + reader = new JsonTreeReader(); + + testRunner.setProperty(CLIENT_SERVICE, "graphClient"); + testRunner.addControllerService("graphClient", mockGraphClientService); + testRunner.addControllerService("reader", reader); + testRunner.addControllerService("writer", writer); + testRunner.setProperty(READER_SERVICE, "reader"); + testRunner.setProperty(WRITER_SERVICE, "writer"); + testRunner.enableControllerService(writer); + testRunner.enableControllerService(reader); + testRunner.enableControllerService(mockGraphClientService); + } + + @Test + public void testSuccessfulNodeProcessing() { + Map attributes = new HashMap<>(); + attributes.put("id", "123"); + + String inputContent = "[{\"id\": \"123\", \"name\": \"Node1\"},{\"id\": \"789\", \"name\": \"Node2\"}]"; + testRunner.setProperty(IDENTIFIER_FIELD, "//id"); + testRunner.setProperty("name", "//name"); + testRunner.enqueue(inputContent.getBytes(), attributes); + + testRunner.run(); + + testRunner.assertTransferCount(EnrichGraphRecord.ORIGINAL, 1); + testRunner.assertTransferCount(EnrichGraphRecord.FAILURE, 0); + testRunner.assertTransferCount(EnrichGraphRecord.GRAPH, 2); + + MockFlowFile originalFlowFile = testRunner.getFlowFilesForRelationship(EnrichGraphRecord.ORIGINAL).get(0); + assertEquals("123", originalFlowFile.getAttribute("id")); + MockFlowFile successFlowFile = testRunner.getFlowFilesForRelationship(EnrichGraphRecord.GRAPH).get(0); + + try { + RecordReader recordReader = reader.createRecordReader(successFlowFile, successFlowFile.getContentStream(), new MockComponentLog("1", processor)); + Record record = recordReader.nextRecord(); + assertEquals("John Smith", record.getValue("name")); + assertEquals(40, record.getAsInt("age")); + } catch (Exception e) { + fail("Should not reach here"); + } + } + + @Test + public void testSuccessfulEdgeProcessing() { + Map attributes = new HashMap<>(); + attributes.put("id", "123"); + + String inputContent = "[{\"id\": \"123\", \"name\": \"Node1\", \"relationship\": \"ASSOCIATED_WITH\"}," + + "{\"id\": \"789\", \"name\": \"Node2\",\"relationship\": \"ASSOCIATED_WITH\"}]"; + testRunner.setProperty(IDENTIFIER_FIELD, "//relationship"); + testRunner.setProperty("name", "//name"); + testRunner.setProperty(NODE_TYPE, GraphClientService.EDGES_TYPE); + testRunner.enqueue(inputContent.getBytes(), attributes); + + testRunner.run(); + + testRunner.assertTransferCount(EnrichGraphRecord.ORIGINAL, 1); + testRunner.assertTransferCount(EnrichGraphRecord.FAILURE, 0); + testRunner.assertTransferCount(EnrichGraphRecord.GRAPH, 2); + + MockFlowFile successFlowFile = testRunner.getFlowFilesForRelationship(EnrichGraphRecord.GRAPH).get(0); + + try { + RecordReader recordReader = reader.createRecordReader(successFlowFile, successFlowFile.getContentStream(), new MockComponentLog("1", processor)); + Record record = recordReader.nextRecord(); + assertEquals("John Smith", record.getValue("name")); + assertEquals(40, record.getAsInt("age")); + assertEquals("ASSOCIATED_WITH", record.getValue("relationship")); + } catch (Exception e) { + fail("Should not reach here"); + } + } + + @Test + public void testNullIdentifierValue() { + Map attributes = new HashMap<>(); + attributes.put("id", "123"); + + // Two bad identifiers, one good + String inputContent = "[{\"id\": null, \"name\": \"Node1\"},{\"id\": null, \"name\": \"Node2\"},{\"id\": \"123\", \"name\": \"Node3\"}]"; + testRunner.setProperty(IDENTIFIER_FIELD, "//id"); + testRunner.setProperty("name", "//name"); + testRunner.enqueue(inputContent.getBytes(), attributes); + + testRunner.run(); + + testRunner.assertTransferCount(EnrichGraphRecord.ORIGINAL, 0); + testRunner.assertTransferCount(EnrichGraphRecord.FAILURE, 1); + testRunner.assertTransferCount(EnrichGraphRecord.GRAPH, 1); + + // Verify 2 failed records + MockFlowFile failedFlowFile = testRunner.getFlowFilesForRelationship(EnrichGraphRecord.FAILURE).get(0); + assertEquals("2", failedFlowFile.getAttribute("record.count")); + } + + @Test + public void testFailedProcessing() { + Map attributes = new HashMap<>(); + attributes.put("id", "null"); + + String inputContent = "[{\"id\": null, \"name\": \"Node1\"}]"; + testRunner.setProperty(IDENTIFIER_FIELD, "//id"); + testRunner.setProperty("name", "//name"); + + testRunner.enqueue(inputContent.getBytes(), attributes); + + testRunner.run(); + + testRunner.assertTransferCount(EnrichGraphRecord.ORIGINAL, 0); + testRunner.assertTransferCount(EnrichGraphRecord.FAILURE, 1); + testRunner.assertTransferCount(EnrichGraphRecord.GRAPH, 0); + } +} \ No newline at end of file diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestExecuteGraphQuery.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestExecuteGraphQuery.java index cfc65b77a9bb..a72b3d4f7efa 100644 --- a/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestExecuteGraphQuery.java +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-graph-processors/src/test/java/org/apache/nifi/processors/graph/TestExecuteGraphQuery.java @@ -85,9 +85,10 @@ private void testExecute(int success, int failure, int original) throws Exceptio assertNotNull(parsed); assertEquals(2, parsed.size()); for (Map result : parsed) { - assertEquals(2, result.size()); + assertEquals(3, result.size()); assertTrue(result.containsKey("name")); assertTrue(result.containsKey("age")); + assertTrue(result.containsKey("relationship")); } } } \ No newline at end of file diff --git a/nifi-extension-bundles/nifi-graph-bundle/nifi-neo4j-cypher-service/src/main/java/org/apache/nifi/graph/Neo4JCypherClientService.java b/nifi-extension-bundles/nifi-graph-bundle/nifi-neo4j-cypher-service/src/main/java/org/apache/nifi/graph/Neo4JCypherClientService.java index 4a89d8dbae95..06dd5c28782e 100644 --- a/nifi-extension-bundles/nifi-graph-bundle/nifi-neo4j-cypher-service/src/main/java/org/apache/nifi/graph/Neo4JCypherClientService.java +++ b/nifi-extension-bundles/nifi-graph-bundle/nifi-neo4j-cypher-service/src/main/java/org/apache/nifi/graph/Neo4JCypherClientService.java @@ -29,6 +29,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.util.Tuple; import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Config; import org.neo4j.driver.Driver; @@ -56,7 +57,8 @@ "the Neo4J driver that corresponds to most of the settings for this service can be found here: " + "https://neo4j.com/docs/driver-manual/current/client-applications/#driver-configuration. This service was created as a " + "result of the break in driver compatibility between Neo4J 3.X and 4.X and might be renamed in the future if and when " + - "Neo4J should break driver compatibility between 4.X and a future release.") + "Neo4J should break driver compatibility between 4.X and a future release. NOTE: When generating Cypher statements for node/edge properties, " + + "special characters in the property names will be changed to underscores ('_') in the graph") public class Neo4JCypherClientService extends AbstractControllerService implements GraphClientService { public static final PropertyDescriptor CONNECTION_URL = new PropertyDescriptor.Builder() .name("neo4j-connection-url") @@ -291,4 +293,96 @@ public Map executeQuery(String query, Map parame public String getTransitUrl() { return connectionUrl; } + + @Override + public String generateSetPropertiesStatement(final String componentType, + final List> identifiersAndValues, + final String nodeType, + final Map propertyMap) { + + StringBuilder queryBuilder = switch (componentType) { + case GraphClientService.NODES_TYPE -> getNodeQueryBuilder(identifiersAndValues, nodeType); + case GraphClientService.EDGES_TYPE -> getEdgeQueryBuilder(identifiersAndValues, nodeType); + default -> throw new ProcessException("Unsupported component type: " + componentType); + }; + + queryBuilder.append(")\n") + .append("ON MATCH SET "); + + List setClauses = new ArrayList<>(); + for (Map.Entry entry : propertyMap.entrySet()) { + StringBuilder setClause = new StringBuilder("n.") + .append(getNormalizedName(entry.getKey())) + .append(" = "); + if (entry.getValue() == null) { + setClause.append(" NULL"); + } else { + setClause.append("'") + .append(entry.getValue()) + .append("'"); + } + setClauses.add(setClause.toString()); + } + String setClauseString = String.join(", ", setClauses); + queryBuilder.append(setClauseString) + .append("\nON CREATE SET ") + .append(setClauseString); + + return queryBuilder.toString(); + } + + private static StringBuilder getNodeQueryBuilder(List> identifiersAndValues, String nodeType) { + StringBuilder queryBuilder = new StringBuilder("MERGE (n"); + if (nodeType != null && !nodeType.isEmpty()) { + queryBuilder.append(":").append(nodeType); + } + + buildMatchClause(identifiersAndValues, queryBuilder); + return queryBuilder; + } + + private static StringBuilder getEdgeQueryBuilder(List> identifiersAndValues, String edgeType) { + StringBuilder queryBuilder = new StringBuilder("MERGE (n)<-[e:"); + + if (edgeType == null || edgeType.isEmpty()) { + throw new ProcessException("Edge type must not be null or empty"); + } + queryBuilder.append(edgeType); + + buildMatchClause(identifiersAndValues, queryBuilder); + queryBuilder.append("]-> (x)"); + return queryBuilder; + } + + private static void buildMatchClause(List> identifiersAndValues, StringBuilder queryBuilder) { + if (!identifiersAndValues.isEmpty()) { + queryBuilder.append(" {"); + + List identifierNamesAndValues = new ArrayList<>(); + for (Tuple identifierAndValue : identifiersAndValues) { + if (identifierAndValue == null || identifierAndValue.getKey() == null || identifierAndValue.getValue() == null) { + throw new ProcessException("Identifiers and values must not be null"); + } + + final String identifierName = identifierAndValue.getKey(); + final Object identifierObject = identifierAndValue.getValue(); + if (identifierObject != null) { + identifierNamesAndValues.add(getNormalizedName(identifierName) + ": '" + identifierObject + "'"); + } + } + queryBuilder.append(String.join(", ", identifierNamesAndValues)) + .append("}"); + } + } + /* + merge(d:Disease {id:'EFO_0000279'})-[e:ASSOCIATED_WITH]-> (t:Target) on match set e.test = 'myLabel' + */ + + private static String getNormalizedName(String identifierName) { + StringBuilder normalizedName = new StringBuilder(identifierName.replaceAll("[^A-Za-z0-9_]", "_")); + if (Character.isDigit(normalizedName.charAt(0))) { + normalizedName.append("_").append(normalizedName); + } + return normalizedName.toString(); + } }