LCA(lowest common ancestor, 最小共通祖先)
date: 2021-07-11 excerpt: LCA(lowest common ancestor, 最小共通祖先)について
LCA(lowest common ancestor, 最小共通祖先)について
概要
- グラフでいくつか共通の先祖を持つ最初のノードを発見するアルゴリズム
- 任意の二点間の距離を測るアルゴリズムとしても使える
- ダブリングを使用する方法の計算量
O(log n)
- オイラーツアーを使用する方法の計算量(わかりやすい)
O(n)
オイラーツアーを利用するLCAの計算
- 手順
- 行きがけ、帰りがけ、戻りがけでそれぞれ
深さの情報
を記録しながらログを記録する(オイラーツアー) - 2つの求めたいノードのオイラーツアー上の
インデックスを取得
する - 2つの範囲内の
最小の深さのノードがLCA(最小共通祖先)
である
- 行きがけ、帰りがけ、戻りがけでそれぞれ
具体例
tour_log = []
def dfs(node, depth=0):
# preorder
tour_log.append((node.val, node, depth))
if node.left:
dfs(node.left, depth+1)
# inorder
tour_log.append((node.val, node, depth))
if node.right:
dfs(node.right, depth+1)
# 帰りがけ
tour_log.append((node.val, node, depth))
dfs(root, 0)
pi, qi = None, None
for i in range(len(tour_log)):
if tour_log[i][0] == p.val:
pi = i
if tour_log[i][0] == q.val:
qi = i
if pi is not None and qi is not None:
break
pi, qi = sorted([pi, qi])
lca_node = min(tour_log[pi:qi+1], key=lambda x:x[2])[1]
ダブリングスニペットライブラリ
import collections
class LCA(object):
def __init__(self, G, N, root=0):
self.G = G
self.root = root
self.n = N
self.logn = (self.n - 1).bit_length()
self.depth = [-1 if i != root else 0 for i in range(self.n)]
self.parent = [[-1] * self.n for _ in range(self.logn)]
self.dfs()
self.doubling()
def dfs(self):
que = collections.deque([self.root])
while que:
u = que.pop()
for v in self.G[u]:
if self.depth[v] == -1:
self.depth[v] = self.depth[u] + 1
self.parent[0][v] = u
que.append(v)
def doubling(self):
parent = self.parent
for i in range(1, self.logn):
for v in range(self.n):
if parent[i - 1][v] != -1:
parent[i][v] = parent[i - 1][parent[i - 1][v]]
def get(self, u, v):
depth, parent = self.depth, self.parent
if depth[v] < depth[u]:
u, v = v, u
du, dv = depth[u], depth[v]
for i in range(self.logn): # depthの差分だけuを遡らせる
# if (dv - du) >> i & 1:
if (dv-du)&(1<<i) > 0 :
v = parent[i][v]
if u == v:
return u # 高さ揃えた時点で一致してたら終わり
for i in range(self.logn - 1, -1, -1): # そうでなければ上から二分探索
pu, pv = parent[i][u], parent[i][v]
if pu != pv:
u, v = pu, pv
return parent[0][u]
def distance(self, u, v):
return self.depth[u] + self.depth[v] - 2 * self.depth[self.get(u, v)] + 1
例; 典型例
問題
解答
# ライブラリ省略
N = int(input())
G = [[] for _ in range(N)]
for n in range(N - 1):
x, y = map(int,input().split())
x-=1; y -= 1
G[x].append(y)
G[y].append(x)
lca = LCA(G, N)
Q = int(input())
for q in range(Q):
a, b = map(int, input().split())
a-=1; b -= 1
print(lca.distance(a, b))