Skip to content

Latest commit

 

History

History
452 lines (378 loc) · 11.9 KB

File metadata and controls

452 lines (378 loc) · 11.9 KB
comments difficulty edit_url tags
true
困难
深度优先搜索
广度优先搜索
并查集
数组
哈希表

English Version

题目描述

给出了一个由 n 个节点组成的网络,用 n × n 个邻接矩阵图 graph 表示。在节点网络中,当 graph[i][j] = 1 时,表示节点 i 能够直接连接到另一个节点 j。 

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

如果从 initial 中移除某一节点能够最小化 M(initial), 返回该节点。如果有多个节点满足条件,就返回索引最小的节点。

请注意,如果某个节点已从受感染节点的列表 initial 中删除,它以后仍有可能因恶意软件传播而受到感染。

 

示例 1:

输入:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输出:0

示例 2:

输入:graph = [[1,0,0],[0,1,0],[0,0,1]], initial = [0,2]
输出:0

示例 3:

输入:graph = [[1,1,1],[1,1,1],[1,1,1]], initial = [1,2]
输出:1

 

提示:

  • n == graph.length
  • n == graph[i].length
  • 2 <= n <= 300
  • graph[i][j] == 0 或 1.
  • graph[i][j] == graph[j][i]
  • graph[i][i] == 1
  • 1 <= initial.length <= n
  • 0 <= initial[i] <= n - 1
  • initial 中所有整数均不重复

解法

方法一:并查集

根据题目描述,如果初始时有若干个节点属于同一个连通分量,那么一共可以分为三种情况:

  1. 这些节点中没有一个节点被感染
  2. 这些节点中只有一个节点被感染
  3. 这些节点中有多个节点被感染

我们要考虑的是,移除某个感染节点后,剩下的节点中被感染的节点数最少。

情况一没有被感染的节点,不需要考虑;情况二只有一个节点被感染,那么移除这个节点后,该连通分量中的其他节点都不会被感染;情况三有多个节点被感染,那么移除任意一个感染节点后,该连通分量中的其他节点还是会被感染,所以我们只需要考虑情况二。

我们利用并查集 $uf$ 维护节点的连通关系,用一个变量 $ans$ 记录答案,用一个变量 $mx$ 记录当前能减少感染的最大节点数,初始时 $ans = n$, $mx = 0$

然后遍历数组 $initial$,用一个哈希表或者一个长度为 $n$ 的数组 $cnt$ 统计每个连通分量中被感染节点的个数。

接下来,我们再遍历数组 $initial$,对于每个节点 $x$,我们找到其所在的连通分量的根节点 $root$,如果该连通分量中只有一个被感染节点,即 $cnt[root] = 1$,我们就更新答案,更新的条件是该连通分量中的节点数 $sz$ 大于 $mx$ 或者 $sz$ 等于 $mx$$x$ 的值小于 $ans$

最后,如果 $ans$ 没有被更新,说明所有的连通分量中都有多个被感染节点,那么我们返回 $initial$ 中的最小值,否则返回 $ans$

时间复杂度 $O(n^2 \times \alpha(n))$,空间复杂度 $O(n)$。其中 $n$ 是节点的个数,而 $\alpha(n)$ 是 Ackermann 函数的反函数。

Python3

class UnionFind:
    __slots__ = "p", "size"

    def __init__(self, n: int):
        self.p = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        if self.p[x] != x:
            self.p[x] = self.find(self.p[x])
        return self.p[x]

    def union(self, a: int, b: int) -> bool:
        pa, pb = self.find(a), self.find(b)
        if pa == pb:
            return False
        if self.size[pa] > self.size[pb]:
            self.p[pb] = pa
            self.size[pa] += self.size[pb]
        else:
            self.p[pa] = pb
            self.size[pb] += self.size[pa]
        return True

    def get_size(self, root: int) -> int:
        return self.size[root]


class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        n = len(graph)
        uf = UnionFind(n)
        for i in range(n):
            for j in range(i + 1, n):
                graph[i][j] and uf.union(i, j)
        cnt = Counter(uf.find(x) for x in initial)
        ans, mx = n, 0
        for x in initial:
            root = uf.find(x)
            if cnt[root] > 1:
                continue
            sz = uf.get_size(root)
            if sz > mx or (sz == mx and x < ans):
                ans = x
                mx = sz
        return min(initial) if ans == n else ans

Java

class UnionFind {
    private final int[] p;
    private final int[] size;

    public UnionFind(int n) {
        p = new int[n];
        size = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = i;
            size[i] = 1;
        }
    }

    public int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }

    public boolean union(int a, int b) {
        int pa = find(a), pb = find(b);
        if (pa == pb) {
            return false;
        }
        if (size[pa] > size[pb]) {
            p[pb] = pa;
            size[pa] += size[pb];
        } else {
            p[pa] = pb;
            size[pb] += size[pa];
        }
        return true;
    }

    public int size(int root) {
        return size[root];
    }
}

class Solution {
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        UnionFind uf = new UnionFind(n);
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                if (graph[i][j] == 1) {
                    uf.union(i, j);
                }
            }
        }
        int ans = n;
        int mi = n, mx = 0;
        int[] cnt = new int[n];
        for (int x : initial) {
            ++cnt[uf.find(x)];
            mi = Math.min(mi, x);
        }

        for (int x : initial) {
            int root = uf.find(x);
            if (cnt[root] == 1) {
                int sz = uf.size(root);
                if (sz > mx || (sz == mx && x < ans)) {
                    ans = x;
                    mx = sz;
                }
            }
        }
        return ans == n ? mi : ans;
    }
}

C++

class UnionFind {
public:
    UnionFind(int n) {
        p = vector<int>(n);
        size = vector<int>(n, 1);
        iota(p.begin(), p.end(), 0);
    }

    bool unite(int a, int b) {
        int pa = find(a), pb = find(b);
        if (pa == pb) {
            return false;
        }
        if (size[pa] > size[pb]) {
            p[pb] = pa;
            size[pa] += size[pb];
        } else {
            p[pa] = pb;
            size[pb] += size[pa];
        }
        return true;
    }

    int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }

    int getSize(int root) {
        return size[root];
    }

private:
    vector<int> p, size;
};

class Solution {
public:
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        UnionFind uf(n);
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                if (graph[i][j]) {
                    uf.unite(i, j);
                }
            }
        }
        int ans = n, mx = 0;
        vector<int> cnt(n);
        for (int x : initial) {
            ++cnt[uf.find(x)];
        }
        for (int x : initial) {
            int root = uf.find(x);
            if (cnt[root] == 1) {
                int sz = uf.getSize(root);
                if (sz > mx || (sz == mx && ans > x)) {
                    ans = x;
                    mx = sz;
                }
            }
        }
        return ans == n ? *min_element(initial.begin(), initial.end()) : ans;
    }
};

Go

type unionFind struct {
	p, size []int
}

func newUnionFind(n int) *unionFind {
	p := make([]int, n)
	size := make([]int, n)
	for i := range p {
		p[i] = i
		size[i] = 1
	}
	return &unionFind{p, size}
}

func (uf *unionFind) find(x int) int {
	if uf.p[x] != x {
		uf.p[x] = uf.find(uf.p[x])
	}
	return uf.p[x]
}

func (uf *unionFind) union(a, b int) bool {
	pa, pb := uf.find(a), uf.find(b)
	if pa == pb {
		return false
	}
	if uf.size[pa] > uf.size[pb] {
		uf.p[pb] = pa
		uf.size[pa] += uf.size[pb]
	} else {
		uf.p[pa] = pb
		uf.size[pb] += uf.size[pa]
	}
	return true
}

func (uf *unionFind) getSize(root int) int {
	return uf.size[root]
}

func minMalwareSpread(graph [][]int, initial []int) int {
	n := len(graph)
	uf := newUnionFind(n)
	for i := range graph {
		for j := i + 1; j < n; j++ {
			if graph[i][j] == 1 {
				uf.union(i, j)
			}
		}
	}
	cnt := make([]int, n)
	ans, mx := n, 0
	for _, x := range initial {
		cnt[uf.find(x)]++
	}
	for _, x := range initial {
		root := uf.find(x)
		if cnt[root] == 1 {
			sz := uf.getSize(root)
			if sz > mx || sz == mx && x < ans {
				ans, mx = x, sz
			}
		}
	}
	if ans == n {
		return slices.Min(initial)
	}
	return ans
}

TypeScript

class UnionFind {
    p: number[];
    size: number[];
    constructor(n: number) {
        this.p = Array(n)
            .fill(0)
            .map((_, i) => i);
        this.size = Array(n).fill(1);
    }

    find(x: number): number {
        if (this.p[x] !== x) {
            this.p[x] = this.find(this.p[x]);
        }
        return this.p[x];
    }

    union(a: number, b: number): boolean {
        const [pa, pb] = [this.find(a), this.find(b)];
        if (pa === pb) {
            return false;
        }
        if (this.size[pa] > this.size[pb]) {
            this.p[pb] = pa;
            this.size[pa] += this.size[pb];
        } else {
            this.p[pa] = pb;
            this.size[pb] += this.size[pa];
        }
        return true;
    }

    getSize(root: number): number {
        return this.size[root];
    }
}

function minMalwareSpread(graph: number[][], initial: number[]): number {
    const n = graph.length;
    const uf = new UnionFind(n);
    for (let i = 0; i < n; ++i) {
        for (let j = i + 1; j < n; ++j) {
            graph[i][j] && uf.union(i, j);
        }
    }
    let [ans, mx] = [n, 0];
    const cnt: number[] = Array(n).fill(0);
    for (const x of initial) {
        ++cnt[uf.find(x)];
    }
    for (const x of initial) {
        const root = uf.find(x);
        if (cnt[root] === 1) {
            const sz = uf.getSize(root);
            if (sz > mx || (sz === mx && x < ans)) {
                [ans, mx] = [x, sz];
            }
        }
    }
    return ans === n ? Math.min(...initial) : ans;
}