組み合わせ
date: 2021-04-01 excerpt: 組み合わせ(combination, choose)の様々な実装方法
組み合わせ様々な実装方法
scipyでの実装方法
- floatで値が帰るので注意
from scipy.special import comb
print(comb(10, 7)) # 120.0
mathでの実装方法
import math
math.comb(10,7) # 120
Pure Pythonでの実装方法
def binomial(n, k):
if not 0 <= k <= n:
return 0
b = 1
for t in range(min(k, n-k)):
b *= n
b //= t+1
n -= 1
return b
Pure Pythonでの実装方法(大きな値にならない時)
from functools import reduce
def comb(n, k):
numer = reduce(lambda x,y:x*y, range(n - k + 1, n))
denom = reduce(lambda x,y:x*y, range(1, k + 1))
return numer // denom
mod付き組み合わせ
- 何度も計算しないであれば最速
mod = 10 ** 9 + 7
def mod_cmb(n, a, mod):
nca = 1
for i in range(a):
nca *= (n - i) * pow(a - i, mod - 2, mod)
nca %= mod
return nca
print(mod_cmb(10, 7, mod)) # 120
メモ化付き組み合わせ
- 何度もcombinationの計算結果を求められるとき
def cmb(n, r, p):
if (r < 0) or (n < r):
return 0
r = min(r, n - r)
return fact[n] * factinv[r] * factinv[n-r] % p
p = 10 ** 9 + 7
N = 10 ** 6 # N は必要分だけ用意する
fact = [1, 1] # fact[n] = (n! mod p)
factinv = [1, 1] # factinv[n] = ((n!)^(-1) mod p)
inv = [0, 1] # factinv 計算用
for i in range(2, N + 1):
fact.append((fact[-1] * i) % p)
inv.append((-inv[p % i] * (p // i)) % p)
factinv.append((factinv[-1] * inv[-1]) % p)
K個のお菓子をN人に最低1個以上分配する
- Kのお菓子を並べて、K-1個の隙間にN-1個のセパレータを入れるイメージ
(K-1)C(N-1)
が数式になる
combination with replacement
重複許可のcombination
- 選択される要素が排反的でない
- 特に、動的なリスト等を作成して、フィルターにするなどはアルゴリズム的にも計算量的にも大変重いので、これを使わないと解けないなどがある
具体例
以下の問題ではメモリの制約がなければ問題なく解けるものであるが、このサイズのメモリを許容できないことによって、combination_with_replacement
が必要になる
- https://atcoder.jp/contests/abc165/tasks/abc165_c
NG
# このアプローチはメモリーエラー
s=[[1]]
#長さNまで
import copy
while True:
n = []
for s1 in s:
for append in range(s1[-1], M+1):
n1 = copy.copy(s1)
n1.append(append)
n.append(n1)
if len(n[0]) > M:
break
s = n
#for s1 in s:
# print(s1)
iidp = [list(map(int,input().split())) for q in range(Q)]
al = 0
for s1 in s:
a = 0
for i1, i2, d, p in iidp:
if s1[i2-1] - s1[i1-1] == d:
a += p
al = max(a,al)
print(al)
OK
N,M,Q=map(int,input().split())
iidp = [list(map(int,input().split())) for q in range(Q)]
import itertools
al = 0
for s1 in itertools.combinations_with_replacement(list(range(1,M+1)), N):
a = 0
for i1, i2, d, p in iidp:
if s1[i2-1] - s1[i1-1] == d:
a += p
al = max(a,al)
print(al)