union find(disjoint set)
date: 2020-10-01 excerpt: union find(disjoint set)
union find(disjoint set)
概要
- 与えられた列に対して、リンク関係から2つ以上に分けられるとき、効率的に分割するアルゴリズム
- 具体的な挙動
-1
や特定の数で初期化して根構造を再帰でたどってrootを計算していくというもの- クラス分けや、所属分けなどで効率的
- カスタマイズ要素多いのでどこをどういじるか
- 高さを求める等のオプションが付くことがある
- 閉路の検出でも使うことができ、閉路が存在するとき、union時にparentの衝突が発生するのでそれを利用する
シンプルな実装例
以下の例では、最大のノードへのリンクを求めるというものになる
競プロによる例
- ルートノードの参照料を知りたい場合もサポートできるように拡張したもの
-1
で初期化することがより多い- 参考
- AtCoder Beginner Contest 120; D - Decayed Bridges
- https://atcoder.jp/contests/abc177/tasks/abc177_d
import collections
class UnionFind:
def __init__(self, n):
self.n = n
self.parents = [-1] * n
self.has_cycles = [0] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
# 積極的aggregation
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
# 同じノード同士ならばなにもしない
# ここに閉路情報を入れることができる
if x == y:
self.has_cycles[x] = 1
return
# 既知の親子で小さいものが左に来るべき
if self.parents[x] > self.parents[y]:
x, y = y, x
# rootノードを負の値で参照量をカウントしたいため、このような+=が入っている
self.parents[x] += self.parents[y]
# rootノードでなければ、正のindex値を入れる
self.parents[y] = x
def size(self, x):
# rootノードの参照料を保存したものを取り出している
return -self.parents[self.find(x)]
def same(self, x, y):
# 同じルートを持つか
return self.find(x) == self.find(y)
def roots(self):
# どのノードがrootとなるか
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
# グループの個数
return len(self.roots())
def all_group_members(self) -> "Tuple[GroupMember, GroupCycle]":
# rootをkeyに子をvalueのlistに, 閉路情報も返す
group_members = collections.defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
group_cycle = collections.defaultdict(bool)
for group, members in group_members.items():
group_cycle[group] = True if any([self.has_cycles[member] for member in members]) else False
return group_members, group_cycle
例; グループ間の行き来の量がわかると平衡状態かどうかを判定できる例
問題
説明
- 変化可能かは閉区間を求めることで判明する
- 閉区間はunion findで知ることができる
解答
例; 閉路の検出
問題
解説
dfsでも閉路チェックができるがコードをまとめたいときにはunion findが便利
解答
例; ドット状のグラフの結合判定
ドット上のものもunion findが適応可能なグラフであると気づけると早い
問題
解答
例; swap法則とUnionFind
問題
解説
- swap可能になる ≡ ネットワーク的に結合する ≡ UnionFindで結合状態を知ることができる
解答
例; UnionFildを複雑な手続きで利用して解答を得る例
問題
解説
- 友達のグラフでufを初期化する
- ブロック関係にあるときは、
友達関係になりうる同じグループである
かつブロック関係である
とき、候補から除外する - ブロック処理のプロセスにてもufを用いるので複雑
解答
例; すべての最短経路の最大のコストの和
- 何度も見返したがunionfindが適応例だとはなかなか発想が至らない
問題
解説
- 最短パスに含まれるパスの最大のコストのすべての和
- コストが小さい順でソートして、
コスト * group_size(l) * group_size(r)
を累積していくと答えになる
解答
例; 重み付きunion find
問題
解説
- 重み付きunion findは実装がそれ専用になる
解答
class WeightedUnionFind():
def __init__(self, n):
self.n = n
# 各親要素の番号を格納 rootの場合は、そのグループの要素数
self.parents = [-1] * n
self.diff_weight = [0] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
# 根を見つけると同時に、他の要素の親を根に変更(経路圧縮)
r = self.find(self.parents[x])
# 親を遡りながら、重みの累積和を取る
self.diff_weight[x] += self.diff_weight[self.parents[x]]
self.parents[x] = r
return self.parents[x]
def weight(self, x):
# 経路圧縮
self.find(x)
return self.diff_weight[x]
def diff(self, x, y):
return self.weight(y) - self.weight(x)
def union(self, x, y, w):
# xとyそれぞれについて、rootとの重み差分を補正
w += self.weight(x)
w -= self.weight(y)
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
w = -w
self.parents[x] += self.parents[y]
self.parents[y] = x
# x が y の親になるので、x と y の差分を diff_weight[y] に記録
self.diff_weight[y] = w
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
N,M=map(int,input().split())
wuf = WeightedUnionFind(n=N)
for _ in range(M):
l, r, d = map(int,input().split())
l-=1; r-=1;
if wuf.same(l, r):
if wuf.diff(l,r) != d:
print("No")
exit()
else:
wuf.union(l, r, d)
print("Yes")