diff --git a/include/kaminpar-shm/kaminpar.h b/include/kaminpar-shm/kaminpar.h index d129ec09..c3fe4126 100644 --- a/include/kaminpar-shm/kaminpar.h +++ b/include/kaminpar-shm/kaminpar.h @@ -694,6 +694,9 @@ class Graph { [[nodiscard]] bool sorted() const; + void set_level(int level); + [[nodiscard]] int level() const; + [[nodiscard]] AbstractGraph *underlying_graph(); [[nodiscard]] const AbstractGraph *underlying_graph() const; @@ -707,6 +710,7 @@ class Graph { private: std::unique_ptr _underlying_graph; + int _level = 0; }; [[nodiscard]] Graph compress( diff --git a/kaminpar-shm/coarsening/abstract_cluster_coarsener.cc b/kaminpar-shm/coarsening/abstract_cluster_coarsener.cc index 0add8285..2b475f6d 100644 --- a/kaminpar-shm/coarsening/abstract_cluster_coarsener.cc +++ b/kaminpar-shm/coarsening/abstract_cluster_coarsener.cc @@ -95,8 +95,7 @@ bool AbstractClusterCoarsener::keep_allocated_memory() const { return level() >= _c_ctx.clustering.max_mem_free_coarsening_level; } -void AbstractClusterCoarsener::compute_clustering_for_current_graph( - StaticArray &clustering +void AbstractClusterCoarsener::compute_clustering_for_current_graph(StaticArray &clustering ) { const bool free_allocated_memory = !keep_allocated_memory(); const NodeWeight total_node_weight = current().total_node_weight(); @@ -174,9 +173,13 @@ void AbstractClusterCoarsener::contract_current_graph_and_push(StaticArrayget().set_level(level() + 1); + return c_graph; + }()); auto project_communities = [&](const std::size_t fine_n, const NodeID *fine_ptr, diff --git a/kaminpar-shm/datastructures/graph.cc b/kaminpar-shm/datastructures/graph.cc index f85febc0..97d382c2 100644 --- a/kaminpar-shm/datastructures/graph.cc +++ b/kaminpar-shm/datastructures/graph.cc @@ -16,6 +16,7 @@ namespace kaminpar::shm { Graph::Graph(std::unique_ptr graph) : _underlying_graph(std::move(graph)) {} + Graph::~Graph() = default; Graph::Graph(Graph &&) noexcept = default; @@ -53,6 +54,14 @@ bool Graph::sorted() const { return _underlying_graph->sorted(); } +void Graph::set_level(const int level) { + _level = level; +} + +int Graph::level() const { + return _level; +} + AbstractGraph *Graph::underlying_graph() { return _underlying_graph.get(); } diff --git a/kaminpar-shm/refinement/jet/jet_refiner.cc b/kaminpar-shm/refinement/jet/jet_refiner.cc index 8d6e150e..06822071 100644 --- a/kaminpar-shm/refinement/jet/jet_refiner.cc +++ b/kaminpar-shm/refinement/jet/jet_refiner.cc @@ -37,7 +37,8 @@ template class JetRefinerImpl { SCOPED_TIMER("Jet Refiner"); SCOPED_TIMER("Initialization"); - const bool is_coarse_level = p_graph.graph().n() < _ctx.partition.n; + const bool is_coarse_level = p_graph.graph().level() > 0; + if (is_coarse_level) { _num_rounds = _ctx.refinement.jet.num_rounds_on_coarse_level; _initial_gain_temp = _ctx.refinement.jet.initial_gain_temp_on_coarse_level;