|
77 | 77 | import org.neo4j.gds.spanningtree.Prim;
|
78 | 78 | import org.neo4j.gds.spanningtree.SpanningTree;
|
79 | 79 | import org.neo4j.gds.spanningtree.SpanningTreeParameters;
|
| 80 | +import org.neo4j.gds.steiner.ShortestPathsSteinerAlgorithm; |
| 81 | +import org.neo4j.gds.steiner.SteinerTreeParameters; |
| 82 | +import org.neo4j.gds.steiner.SteinerTreeProgressTask; |
80 | 83 | import org.neo4j.gds.steiner.SteinerTreeResult;
|
81 | 84 | import org.neo4j.gds.termination.TerminationFlag;
|
82 | 85 | import org.neo4j.gds.traversal.RandomWalk;
|
@@ -781,13 +784,57 @@ CompletableFuture<SpanningTree> spanningTree(
|
781 | 784 | );
|
782 | 785 | }
|
783 | 786 |
|
784 |
| - CompletableFuture<SteinerTreeResult> steinerTree() { |
| 787 | + CompletableFuture<SteinerTreeResult> steinerTree( |
| 788 | + GraphName graphName, |
| 789 | + GraphParameters graphParameters, |
| 790 | + Optional<String> relationshipProperty, |
| 791 | + SteinerTreeParameters parameters, |
| 792 | + JobId jobId, |
| 793 | + boolean logProgress |
| 794 | + ) { |
785 | 795 | // Fetch the Graph the algorithm will operate on
|
| 796 | + var graph = graphStoreCatalogService.fetchGraphResources( |
| 797 | + graphName, |
| 798 | + graphParameters, |
| 799 | + relationshipProperty, |
| 800 | + new SourceNodeTargetNodesGraphStoreValidation(parameters.sourceNode(), parameters.targetNodes()), |
| 801 | + Optional.empty(), |
| 802 | + user, |
| 803 | + databaseId |
| 804 | + ).graph(); |
| 805 | + |
786 | 806 | // Create ProgressTracker
|
| 807 | + var progressTracker = progressTrackerFactory.create( |
| 808 | + SteinerTreeProgressTask.create(parameters, graph.nodeCount()), |
| 809 | + jobId, |
| 810 | + parameters.concurrency(), |
| 811 | + logProgress |
| 812 | + ); |
| 813 | + |
787 | 814 | // Create the algorithm
|
788 |
| - // Submit the algorithm for async computation |
| 815 | + var mappedSourceNodeId = graph.toMappedNodeId(parameters.sourceNode()); |
| 816 | + var mappedTargetNodeIds = parameters.targetNodes() |
| 817 | + .stream() |
| 818 | + .map(graph::safeToMappedNodeId) |
| 819 | + .toList(); |
789 | 820 |
|
790 |
| - return CompletableFuture.failedFuture(new RuntimeException("Not yet implemented")); |
| 821 | + var steinerTree = new ShortestPathsSteinerAlgorithm( |
| 822 | + graph, |
| 823 | + mappedSourceNodeId, |
| 824 | + mappedTargetNodeIds, |
| 825 | + parameters.delta(), |
| 826 | + parameters.concurrency(), |
| 827 | + parameters.applyRerouting(), |
| 828 | + executorService, |
| 829 | + progressTracker, |
| 830 | + terminationFlag |
| 831 | + ); |
| 832 | + |
| 833 | + // Submit the algorithm for async computation |
| 834 | + return algorithmCaller.run( |
| 835 | + steinerTree::compute, |
| 836 | + jobId |
| 837 | + ); |
791 | 838 | }
|
792 | 839 |
|
793 | 840 | CompletableFuture<TopologicalSortResult> topologicalSort() {
|
|
0 commit comments