|
| 1 | +// 2-SAT solver based on Kosaraju's algorithm. |
| 2 | +// Variables are 0-based. Positive variables are stored in vertices 2n, corresponding negative variables in 2n+1 |
| 3 | +// TODO: This is quite slow (3x-4x slower than Gabow's algorithm) |
| 4 | +struct TwoSat { |
| 5 | + int n; |
| 6 | + vector<vector<int> > adj, radj, scc; |
| 7 | + vector<int> sid, vis, val; |
| 8 | + stack<int> stk; |
| 9 | + int scnt; |
| 10 | + |
| 11 | + // n: number of variables, including negations |
| 12 | + TwoSat(int n): n(n), adj(n), radj(n), sid(n), vis(n), val(n, -1) {} |
| 13 | + |
| 14 | + // adds an implication |
| 15 | + void impl(int x, int y) { adj[x].push_back(y); radj[y].push_back(x); } |
| 16 | + // adds a disjunction |
| 17 | + void vee(int x, int y) { impl(x^1, y); impl(y^1, x); } |
| 18 | + // forces variables to be equal |
| 19 | + void eq(int x, int y) { impl(x, y); impl(y, x); impl(x^1, y^1); impl(y^1, x^1); } |
| 20 | + // forces variable to be true |
| 21 | + void tru(int x) { impl(x^1, x); } |
| 22 | + |
| 23 | + void dfs1(int x) { |
| 24 | + if (vis[x]++) return; |
| 25 | + for (int i = 0; i < adj[x].size(); i++) { |
| 26 | + dfs1(adj[x][i]); |
| 27 | + } |
| 28 | + stk.push(x); |
| 29 | + } |
| 30 | + |
| 31 | + void dfs2(int x) { |
| 32 | + if (!vis[x]) return; vis[x] = 0; |
| 33 | + sid[x] = scnt; scc.back().push_back(x); |
| 34 | + for (int i = 0; i < radj[x].size(); i++) { |
| 35 | + dfs2(radj[x][i]); |
| 36 | + } |
| 37 | + } |
| 38 | + |
| 39 | + // returns true if satisfiable, false otherwise |
| 40 | + // on completion, val[x] is the assigned value of variable x |
| 41 | + // note, val[x] = 0 implies val[x^1] = 1 |
| 42 | + bool two_sat() { |
| 43 | + scnt = 0; |
| 44 | + for (int i = 0; i < n; i++) { |
| 45 | + dfs1(i); |
| 46 | + } |
| 47 | + while (!stk.empty()) { |
| 48 | + int v = stk.top(); stk.pop(); |
| 49 | + if (vis[v]) { |
| 50 | + scc.push_back(vector<int>()); |
| 51 | + dfs2(v); |
| 52 | + scnt++; |
| 53 | + } |
| 54 | + } |
| 55 | + for (int i = 0; i < n; i += 2) { |
| 56 | + if (sid[i] == sid[i+1]) return false; |
| 57 | + } |
| 58 | + vector<int> must(scnt); |
| 59 | + for (int i = 0; i < scnt; i++) { |
| 60 | + for (int j = 0; j < scc[i].size(); j++) { |
| 61 | + val[scc[i][j]] = must[i]; |
| 62 | + must[sid[scc[i][j]^1]] = !must[i]; |
| 63 | + } |
| 64 | + } |
| 65 | + return true; |
| 66 | + } |
| 67 | +}; |
0 commit comments