diff --git a/dsu.py b/dsu.py new file mode 100644 index 0000000000..03b363e79e --- /dev/null +++ b/dsu.py @@ -0,0 +1,55 @@ +class DisjointSetUnion: + def __init__(self, n): + self.parents = [i for i in range(n)] + self.ranks = [0] * n + + def find(self, x): + if self.parents[x] != x: + self.parents[x] = self.find(self.parents[x]) + return self.parents[x] + + def union(self, x, y): + root_x = self.find(x) + root_y = self.find(y) + + if root_x == root_y: + return + + if self.ranks[root_x] > self.ranks[root_y]: + self.parents[root_y] = root_x + elif self.ranks[root_x] < self.ranks[root_y]: + self.parents[root_x] = root_y + else: + self.parents[root_y] = root_x + self.ranks[root_x] += 1 + +class Graph: + def __init__(self, vertices, edges): + self.vertices = vertices + self.edges = edges + self.graph = [[] for _ in range(vertices)] + + def add_edge(self, u, v, w): + self.graph[u].append([v, w]) + self.graph[v].append([u, w]) + + def kruskal_mst(self): + result = [] + i, e = 0, 0 + + self.graph = sorted(self.graph, key=lambda item: item[1]) + + dsu = DisjointSetUnion(self.vertices) + + while e < self.vertices - 1: + u, w = self.graph[i] + i += 1 + + root_u = dsu.find(u) + + if root_u not in [dsu.find(v) for _, v in result]: + e += 1 + result.append((u, w)) + dsu.union(u, root_u) + + return result, sum([w for _, w in result]) \ No newline at end of file diff --git a/graph.py b/graph.py new file mode 100644 index 0000000000..330dd98bd8 --- /dev/null +++ b/graph.py @@ -0,0 +1,67 @@ +import heapq + +class Graph: + def __init__(self, vertices): + self.V = vertices + self.graph = [] + + def add_edge(self, u, v, w): + self.graph.append([u, v, w]) + + def find(self, parent, i): + if parent[i] == i: + return i + return self.find(parent, parent[i]) + + def union(self, parent, rank, x, y): + xroot = self.find(parent, x) + yroot = self.find(parent, y) + + if rank[xroot] < rank[yroot]: + parent[xroot] = yroot + elif rank[xroot] > rank[yroot]: + parent[yroot] = xroot + else: + parent[yroot] = xroot + rank[xroot] += 1 + + def kruskal_mst(self): + result = [] + i, e = 0, 0 + self.graph = sorted(self.graph, key=lambda item: item[2]) + parent = [] + rank = [] + + for node in range(self.V): + parent.append(node) + rank.append(0) + + while e < self.V - 1: + u, v, w = self.graph[i] + i = i + 1 + x = self.find(parent, u) + y = self.find(parent, v) + + if x != y: + e = e + 1 + result.append([u, v, w]) + self.union(parent, rank, x, y) + + return result, sum([weight for u, v, weight in result]) + + +if __name__ == "__main__": + g = Graph(4) + g.add_edge(0, 1, 10) + g.add_edge(1, 2, 5) + g.add_edge(2, 0, 2) + g.add_edge(1, 3, 15) + g.add_edge(3, 2, 1) + + mst, weight = g.kruskal_mst() + print("Edges in the MST are:") + + for u, v, w in mst: + print(f"{u} -- {v} == {w}") + + print(f"Minimum Spanning Tree Weight: {weight}") \ No newline at end of file diff --git a/kruskal.py b/kruskal.py new file mode 100644 index 0000000000..11662440cc --- /dev/null +++ b/kruskal.py @@ -0,0 +1,41 @@ +from typing import List, Tuple + +class DisjointSetUnion: + def __init__(self, n: int): + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x: int, y: int) -> None: + root_x = self.find(x) + root_y = self.find(y) + + if root_x == root_y: + return + + if self.rank[root_x] > self.rank[root_y]: + self.parent[root_y] = root_x + elif self.rank[root_x] < self.rank[root_y]: + self.parent[root_x] = root_y + else: + self.parent[root_y] = root_x + self.rank[root_x] += 1 + +def kruskal(n: int, edges: List[Tuple[int, int, int]]) -> Tuple[List[Tuple[int, int, int]], int]: + edges.sort(key=lambda edge: edge[2]) + dsu = DisjointSetUnion(n) + mst_edges = [] + total_weight = 0 + + for edge in edges: + x, y, weight = edge + if dsu.find(x) != dsu.find(y): + dsu.union(x, y) + mst_edges.append(edge) + total_weight += weight + + return mst_edges, total_weight \ No newline at end of file