Skip to content

Commit 5889a4c

Browse files
committed
Implement async SteinerTree
1 parent c0356d7 commit 5889a4c

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

algorithms-compute-facade/src/main/java/org/neo4j/gds/pathfinding/PathFindingComputeFacade.java

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@
7777
import org.neo4j.gds.spanningtree.Prim;
7878
import org.neo4j.gds.spanningtree.SpanningTree;
7979
import org.neo4j.gds.spanningtree.SpanningTreeParameters;
80+
import org.neo4j.gds.steiner.ShortestPathsSteinerAlgorithm;
81+
import org.neo4j.gds.steiner.SteinerTreeParameters;
82+
import org.neo4j.gds.steiner.SteinerTreeProgressTask;
8083
import org.neo4j.gds.steiner.SteinerTreeResult;
8184
import org.neo4j.gds.termination.TerminationFlag;
8285
import org.neo4j.gds.traversal.RandomWalk;
@@ -781,13 +784,57 @@ CompletableFuture<SpanningTree> spanningTree(
781784
);
782785
}
783786

784-
CompletableFuture<SteinerTreeResult> steinerTree() {
787+
CompletableFuture<SteinerTreeResult> steinerTree(
788+
GraphName graphName,
789+
GraphParameters graphParameters,
790+
Optional<String> relationshipProperty,
791+
SteinerTreeParameters parameters,
792+
JobId jobId,
793+
boolean logProgress
794+
) {
785795
// Fetch the Graph the algorithm will operate on
796+
var graph = graphStoreCatalogService.fetchGraphResources(
797+
graphName,
798+
graphParameters,
799+
relationshipProperty,
800+
new SourceNodeTargetNodesGraphStoreValidation(parameters.sourceNode(), parameters.targetNodes()),
801+
Optional.empty(),
802+
user,
803+
databaseId
804+
).graph();
805+
786806
// Create ProgressTracker
807+
var progressTracker = progressTrackerFactory.create(
808+
SteinerTreeProgressTask.create(parameters, graph.nodeCount()),
809+
jobId,
810+
parameters.concurrency(),
811+
logProgress
812+
);
813+
787814
// Create the algorithm
788-
// Submit the algorithm for async computation
815+
var mappedSourceNodeId = graph.toMappedNodeId(parameters.sourceNode());
816+
var mappedTargetNodeIds = parameters.targetNodes()
817+
.stream()
818+
.map(graph::safeToMappedNodeId)
819+
.toList();
789820

790-
return CompletableFuture.failedFuture(new RuntimeException("Not yet implemented"));
821+
var steinerTree = new ShortestPathsSteinerAlgorithm(
822+
graph,
823+
mappedSourceNodeId,
824+
mappedTargetNodeIds,
825+
parameters.delta(),
826+
parameters.concurrency(),
827+
parameters.applyRerouting(),
828+
executorService,
829+
progressTracker,
830+
terminationFlag
831+
);
832+
833+
// Submit the algorithm for async computation
834+
return algorithmCaller.run(
835+
steinerTree::compute,
836+
jobId
837+
);
791838
}
792839

793840
CompletableFuture<TopologicalSortResult> topologicalSort() {

algorithms-compute-facade/src/test/java/org/neo4j/gds/pathfinding/PathFindingComputeFacadeTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.neo4j.gds.pcst.PCSTParameters;
6060
import org.neo4j.gds.spanningtree.PrimOperators;
6161
import org.neo4j.gds.spanningtree.SpanningTreeParameters;
62+
import org.neo4j.gds.steiner.SteinerTreeParameters;
6263
import org.neo4j.gds.termination.TerminationFlag;
6364
import org.neo4j.gds.traversal.RandomWalkParameters;
6465
import org.neo4j.gds.traversal.TraversalParameters;
@@ -472,4 +473,28 @@ void spanningTree() {
472473
assertThat(future.join()).isNotNull();
473474
}
474475

476+
@Test
477+
void steinerTree() {
478+
var future = facade.steinerTree(
479+
new GraphName("foo"),
480+
new GraphParameters(
481+
List.of(NodeLabel.of("Node")),
482+
List.of(RelationshipType.of("REL")),
483+
true,
484+
Optional.empty()
485+
),
486+
Optional.empty(),
487+
new SteinerTreeParameters(
488+
new Concurrency(4),
489+
idFunction.of("a"),
490+
List.of(idFunction.of("c")),
491+
2.0,
492+
false
493+
),
494+
jobIdMock,
495+
true
496+
);
497+
assertThat(future.join()).isNotNull();
498+
}
499+
475500
}

0 commit comments

Comments
 (0)