桁DPについて
- 覚えないと使いこなすのは難しそう
less
は上の桁の影響を受けるかどうか
colab
例; 100000
までの数字で3
が含まれる個数
import itertools
import numpy as np
A = 10**5
def digit_dp(x):
a = str(x)
n = len(a)
#配列は末から
dp = np.zeros((7, 2, n+1)).astype(int).tolist()
dp[0][0][0] = 1
# 条件に合わせてDP
# 未満フラグ(less)は上の桁から成約を受けていないかを調べる <=> lessのときi桁までは全スキャン、lessでないときi桁はa[i]桁までしか見れない
for i, less, has3 in itertools.product(range(n), (0,1), (0,1)):
max_d = 9 if less else int(a[i])
for d in range(max_d+1):
less_ = less or d < max_d
has3_ = has3 or d == 3
dp[i + 1][less_][has3_] += dp[i][less][has3]
#合致するものを合算
ret = 0
for less in (0, 1):
# ret += dp[n][less][1][0]
ret += dp[n][less][1]
return ret
print(digit_dp(A))
例; 100000
までの数字で(3
が含まれるか3
で割り切れる) かつ (8
で割り切れない) 数
import itertools
import numpy as np
A = 10**5
def digit_dp(x):
a = str(x)
n = len(a)
#配列は末から
#dp=[[[[[0] * 8 for _ in range(3)] for _ in range(2)] for _ in range(2)] for _ in range(n+1)]
dp = np.zeros((7, 2, 2, 3, n+1)).astype(int).tolist()
dp[0][0][0][0][0] = 1
print(np.array(dp).shape)
#条件に合わせてDP
for i, less, has3, mod3, mod8 in itertools.product(range(n), (0,1), (0,1), range(3), range(8)):
max_d = 9 if less else int(a[i])
for d in range(max_d+1):
less_ = less or d < max_d
has3_ = has3 or d == 3
mod3_ = (mod3 + d) % 3
mod8_ = (mod8*10 + d) % 8
dp[i + 1][less_][has3_][mod3_][mod8_] += dp[i][less][has3][mod3][mod8]
#合致するものを合算
ret = 0
for less, mod8 in product((0,1), range(1,8)):
ret += dp[n][less][1][0][mod8]
ret += dp[n][less][1][1][mod8]
ret += dp[n][less][1][2][mod8]
ret += dp[n][less][0][0][mod8]
return ret
print(digit_dp(A))
例; 数字の中にいくつ1
が含まれるか
問題
AtCoder Beginner Contest 029; D - 1
解説
1の数はどんどん増えるので最大でどの程度まで増えるかを管理するdpの次元を追加する必要がある
集計時に条件を満たす数をカウントしただけの結果が得られるので、数を掛けて計算する必要がある
解答
import itertools
import numpy as np
def digit_dp(x):
a = str(x)
n = len(a)
#配列は末から
dp = np.zeros((11, 2, 11)).astype(int).tolist()
dp[0][0][0] = 1
# 条件に合わせてDP
# 未満フラグ(less)は上の桁から成約を受けていないかを調べる <=> lessのときi桁までは全スキャン、lessでないときi桁はa[i]桁までしか見れない
for i, less, num1 in itertools.product(range(n), (0,1), range(10)):
max_d = 9 if less else int(a[i])
for d in range(max_d+1):
less_ = less or d < max_d
num1_ = num1 + int(d == 1)
dp[i + 1][less_][num1_] += dp[i][less][num1]
#合致するものを合算
ret = 0
for less, num1 in itertools.product((0, 1), range(10)):
'''num1になるものが何個あるのかを知りたいのでnum1を掛ける必要がある'''
ret += dp[n][less][num1]*num1
return ret
A = int(input())
print(digit_dp(A))