区間dpについて
- ある区間の中に成立する組み合わせがいくつかるかを数え上げるようなものに対して有用
- pythonだと遅いのでTLEのリスクが高い
例; ある文字列Sから部分集合の文字列Tから何回取り除けるか
問題
解説
- 区間dpのテンプレらしい
- コードの意味を理解するのが大変難しかった(なぜdpの値の3倍が差分と一致する必要があるのか等の理解が浅い)
解答
S= str(input())
N = len(S)
T = 'iwi'
if N < 3:
print(0)
exit()
dp = [[0]*(N+1) for i in range(N+1)]
#dp[i][j]: 区間[i, j)でTを取り除ける最大値
for i in range(N-2):
u = S[i:i+3]
if u == T:
dp[i][i+3] = 1
for d in range(4, N+1):
for i in range(N+1-d):
j = i+d
# 1. s[i]またはs[j-1]を取り除かない場合
dp[i][j] = max(dp[i][j], dp[i+1][j], dp[i][j-1])
# 2. 区間を分けて取り除く場合
for k in range(i+1, j):
dp[i][j] = max(dp[i][j], dp[i][k]+dp[k][j])
# 3. あるkが存在して、[i+1, k), [k+1, j-1)がすべて取り除かれ、
# s[i]+s[k]+s[j-1]がTと一致する場合
if S[i] == T[0] and S[j-1] == T[2]:
for k in range(i+1, j-1):
# ここの意味が不明瞭
if dp[i+1][k]*3 == k-(i+1) and dp[k+1][j-1]*3 == j-1-(k+1) and S[k] == T[1]:
dp[i][j] = max(dp[i][j], dp[i+1][k]+dp[k+1][j]+1)
# print(dp)
print(dp[0][N])
例; 重量の差が1以下のだるま落として最大いくつ落とせるか
問題
解答
import sys; sys.setrecursionlimit(10**9)
def wrap():
N = int(input())
if N == 0:
exit(0)
*W,=map(int,input().split())
dp = [[-1]*(N+2) for _ in range(N+2)]
def rec(l = 0, r = N): # メモ化再帰
if (r - l) <= 1:
return 0
if (r - l) == 2:
if abs(W[l] - W[l + 1]) <= 1: # 隣り合う2つの差が1以下
return 2
else:
return 0
if dp[l][r] != -1:
return dp[l][r] # 既に計算済みならその値を使う
# 1. 全部取り除けるとき
if abs(W[l] - W[r-1]) <= 1 and rec(l + 1, r - 1) == r - l - 2:
dp[l][r] = max(dp[l][r], r - l)
# 2. そうでない時
for i in range(l+1, r-1+1):
dp[l][r] = max(dp[l][r], rec(l, i) + rec(i, r))
return dp[l][r]
print(rec())
while True:
wrap()