倍增 & Tarjan 求解LCA

什么是LCA?

假设我们有一棵树:

         1
      /      \
     2         3
   /   \      /
  4    5    6

对于 \(2\)\(6\) 的LCA,就是最近公共祖先,即为距离 \(2\)\(6\) 最近的两个节点公有的节点。怎么求呢?这里就有三种算法。

普通算法

我们可以把这一棵树存好,方式随便(这里展示使用邻接表),可以看到存好之后的树如下:

1 - 2 - 3
2 - 1 - 4 - 5
3 - 1 - 6
4 - 2
5 - 2
6 - 3

其中 \(4,5,6\) 均为叶子节点。现在我们假设要求 \(2\)\(6\) 的LCA,步骤如下:

  • 首先,因为 \(6\) 的深度 \(>2\) 所以我们要 \(6\) 先跳到 \(2\) 的高度。
  • 此时,我们的 \(6\) 节点来到了 \(3\) 节点,\(2\) 节点不变。
  • 现在,把 \(2\)\(3\) 节点同时上提。
  • 经过一次上提之后,两个节点都来到了 \(1\) 位置。那么 \(1\) 就是 \(2\)\(6\) 的LCA。
    算法实现如下:
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

const int MAXN = 500010;
vector<int> tree[MAXN];
int depth[MAXN];
int parent[MAXN];

// DFS预处理每个节点的深度和父节点
void dfs(int u, int p) {
    parent[u] = p;
    depth[u] = depth[p] + 1;
    for (int v : tree[u]) {
        if (v != p) {
            dfs(v, u);
        }
    }
}

// 暴力方法求LCA
int lca(int u, int v) {
    // 将两个节点提到同一深度
    while (depth[u] > depth[v]) u = parent[u];
    while (depth[v] > depth[u]) v = parent[v];
    // 然后一起向上找
    while (u != v) {
        u = parent[u];
        v = parent[v];
    }
    return u;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int N, M, S;
    cin >> N >> M >> S;
    
    // 建树
    for (int i = 1; i < N; ++i) {
        int x, y;
        cin >> x >> y;
        tree[x].push_back(y);
        tree[y].push_back(x);
    }
    
    // 预处理
    depth[0] = -1; // 根节点的父节点设为0,深度为-1+1=0
    dfs(S, 0);
    
    // 处理查询
    while (M--) {
        int a, b;
        cin >> a >> b;
        cout << lca(a, b) << '\n';
    }
    
    return 0;
}

可以得知,在最坏情况下(树是一条链),树的高度 \(dis\) 和询问次数 \(m\) 直接关系到查询,会变得很慢,时间复杂度 \(O(nm)\) ,在大多数题目中会被卡掉。那么有没有什么优化的办法呢?当然有。

倍增求解LCA

我们由上面的讲解可以知道,暴力处理LCA并不是一个好的算法。如何优化呢?暴力算法的每次操作只把节点提高了 \(1\) 次。导致上升得很慢,所以我们当然可以一次上升多个节点来满足快速上升高度的需求。那新的问题又来了,我们一次上升多少高度呢?这里就要涉及到一个“数学小常识”了,我们可以证明:,任意一个非 \(0\) 自然数可以被写作若干个 \(2\) 的幂之和。比如:\(10 = 2^3+2^1\)\(17 = 2^4 +2^0\) 。这其实也就是二进制转十进制的计算方法(扯远了)。所以我们可以得知,只要上提若干个 \(2\) 的幂次方步就能得到结果(为了快速一点,代码好写一点,我们通常从大的幂枚举到小的幂)。现在问题就简单了,我们最后的一部就只要求出 \(2^k\) 次幂是多少就可以了。这里我们开一个 dp[100000][20] 。其中 dp[i][j] 表示从节点 \(i\) 向上 \(2^j\) 到达的节点。推导 dp[i][j] 的公式如下:

\[dp_{i,j}=dp_{dp_{i,j-1},j-1} \]

最后加上我们暴力的解法就可以了:

#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

const int MAXN = 500010;
const int LOG = 20; // log2(500000) ≈ 19

vector<int> tree[MAXN];
int depth[MAXN];
int parent[MAXN][LOG]; // parent[u][k]表示u的2^k级祖先

// DFS预处理每个节点的深度和倍增数组
void dfs(int u, int p) {
    parent[u][0] = p;
    depth[u] = depth[p] + 1;
    
    // 预处理倍增数组
    for (int k = 1; k < LOG; ++k) {
        parent[u][k] = parent[parent[u][k-1]][k-1];
    }
    
    for (int v : tree[u]) {
        if (v != p) {
            dfs(v, u);
        }
    }
}

// 二进制倍增法求LCA
int lca(int u, int v) {
    // 确保u是较深的节点
    if (depth[u] < depth[v]) swap(u, v);
    
    // 将u提到与v同一深度
    for (int k = LOG-1; k >= 0; --k) {
        if (depth[parent[u][k]] >= depth[v]) {
            u = parent[u][k];
        }
    }
    
    // 如果此时u==v,说明v就是u的祖先
    if (u == v) return u;
    
    // 现在u和v在同一深度,一起向上找LCA
    for (int k = LOG-1; k >= 0; --k) {
        if (parent[u][k] != parent[v][k]) {
            u = parent[u][k];
            v = parent[v][k];
        }
    }
    
    // 此时u和v的父节点就是LCA
    return parent[u][0];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int N, M, S;
    cin >> N >> M >> S;
    
    // 建树
    for (int i = 1; i < N; ++i) {
        int x, y;
        cin >> x >> y;
        tree[x].push_back(y);
        tree[y].push_back(x);
    }
    
    // 初始化
    depth[0] = -1; // 虚拟根节点的深度设为-1,这样根节点S的深度为0
    
    // 预处理
    dfs(S, 0); // 0作为虚拟父节点
    
    // 处理查询
    while (M--) {
        int a, b;
        cin >> a >> b;
        cout << lca(a, b) << '\n';
    }
    
    return 0;
}

时间复杂度:预处理 \(O(n\log n)\) ,每次查询 \(O(\log n)\),总体时间复杂度 \(O(n \log n + m \log n)\)

Tarjan算法

Tarjan算法求LCA,应当是我认为的难度最高的,也是我最喜欢的解法。我们举个例子说明一下Tarjan算法,还是刚才的那张图:

         1
      /      \
     2         3
   /   \      /
  4    5    6

此时询问: \(2\)\(6\) 的LCA 是谁呀?
Tarjan:

  • 来到 \(1\) 节点,标记为访问过。
  • 来到 \(2\) 节点,标记为访问过。
  • 来到 \(4\) 节点,标记为访问过。
  • 处理 \(4\) 节点的所有询问:
    • 没有询问,跳过。
  • 来到 \(5\) 节点,标记为访问过
  • 处理 \(5\) 节点的所有询问:
    • 没有询问,跳过。
  • 处理 \(2\) 的所有访问:
    • 找到了与 \(6\) 的一条询问,但是 \(6\) 节点没有被访问过,无法处理。
  • 来到 \(3\) 节点,标记为访问过。
  • 来到 \(6\) 节点,标记为访问过。
  • 处理 \(6\) 节点的所有询问:
    • 找到了与 \(2\) 的一条访问。此时 \(6\) 的祖先节点是 \(1\) 。则这条询问是 \(1\)
  • 处理 \(3\) 的所有访问:
    • 没有询问,跳过。
      那这时有的人就要问了,怎么知道 \(6\) 的祖先是 \(1\) 的?我们添加一个并查集不就好了吗?每次访问一个节点,若子节点没有访问,那就将此节点指向子节点。将子树添加进来到一个集合中。因为下一次回溯来到当前节点的时候,一定是当前子树被处理过了。且如果询问的另一方已经被访问,更新的LCA一定是当前节点的祖先。
#include <iostream>
#include <vector>
#include <unordered_map>
using namespace std;

const int MAXN = 500010;

vector<int> tree[MAXN];
vector<pair<int, int>> queries[MAXN]; // 存储查询,queries[u] = {v, 查询编号}
int ans[MAXN]; // 存储每个查询的答案
int parent[MAXN]; // 并查集父节点
bool visited[MAXN]; // 标记节点是否已被访问

// 并查集查找函数
int find(int u) {
    if (parent[u] != u) {
        parent[u] = find(parent[u]); // 路径压缩
    }
    return parent[u];
}

// Tarjan算法主函数
void tarjan(int u) {
    visited[u] = true;
    parent[u] = u; // 初始化并查集
    
    // 遍历所有子节点
    for (int v : tree[u]) {
        if (!visited[v]) {
            tarjan(v);
            parent[v] = u; // 合并子树
        }
    }
    
    // 处理所有与u相关的查询
    for (auto [v, idx] : queries[u]) {
        if (visited[v]) {
            ans[idx] = find(v);
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int N, M, S;
    cin >> N >> M >> S;
    
    // 建树
    for (int i = 1; i < N; ++i) {
        int x, y;
        cin >> x >> y;
        tree[x].push_back(y);
        tree[y].push_back(x);
    }
    
    // 存储查询
    for (int i = 0; i < M; ++i) {
        int a, b;
        cin >> a >> b;
        // 双向存储查询
        queries[a].emplace_back(b, i);
        if (a != b) {
            queries[b].emplace_back(a, i);
        }
    }
    
    // 运行Tarjan算法
    tarjan(S);
    
    // 输出查询结果
    for (int i = 0; i < M; ++i) {
        cout << ans[i] << '\n';
    }
    
    return 0;
}

Tarjan算法的时间复杂度为 \(O(n + m(\alpha (n))\),其中 \((\alpha n )\) 通常小于二。

课后习题:

  • P3379 【模板】最近公共祖先(LCA)
  • P11477 [COCI 2024/2025 #3] 林卡树 / Stablo

The End.

来源链接:https://www.cnblogs.com/CheeseFunction/p/18841401

© 版权声明
THE END
支持一下吧
点赞13 分享
评论 抢沙发
头像
请文明发言!
提交
头像

昵称

取消
昵称表情代码快捷回复

    暂无评论内容