怎么推导出《115. 不同的子序列》的动态规划解法

今天的打卡题,是一道“困难”,先来看看题目:

给定一个字符串 s 和一个字符串 t ,计算在 s 的子序列中 t 出现的个数。

字符串的一个 子序列 是指,通过删除一些(也可以不删除)字符且不干扰剩余字符相对位置所组成的新字符串。(例如,”ACE” 是 “ABCDE” 的一个子序列,而 “AEC” 不是)

题目数据保证答案符合 32 位带符号整数范围。

  1. 不同的子序列 <– 传送门

具体的例子,我都略去了,因为博客的排版比较弱,没法很好展示。其实,这道题目一看,就隐含了动态规划的气味,比如,求解的是个数,而不是具体的位置,第二,隐含了大量的穷举,才能进行的统计。

不过,如果你跟我一样小白的话,就会发现,你根本写不出状态转移的方程的,因为这个状态转移方程非常地不显然。

当然了,无论你想怎么去解决这道题目,首先还是要自己看懂题目。我们看看如何确定 t 在 s 中出现的次数。举个例子,假如 s = ababc,t = abc,那么很显然,t 在 s 出现了三次。然后我们来理一下,怎么统计。首先,我们把 t 的第一个字符 a,在 s 中标定,然后,是 b 标定,最后是 c,当 a 固定在第一个字符的时候,b 其实有两个选择。这两个选择的 b 都对应着唯一的 c,所以有两次,而 a,还有另一个选择,但是选那个 a 的时候,b,c 都只有唯一选择了,所以,总次数就是 3 次。

我上面描述的这个过程,就是一种最直观的统计策略。逐个标定 t 的每个字符,而且要穷尽每个字符的多个位置选择。这个就是一个比较明显的回溯过程。所以,很自然,我写出了回溯算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution:
def numDistinct(self, s: str, t: str) -> int:
self.res = 0
self.dfs(s, t, 0)
return self.res

def dfs(self, s: str, t: str, p: int) -> None:
if len(t) == 1:
self.res += s[p:].count(t)
return

for i in range(p, len(s)):
if s[i] == t[0]:
self.dfs(s, t[1:], i + 1)

虽然是一道困难题,但是解法却很短小嘛,这个算法,每成功标定一个 t 的字符,就把 t 截短一截,然后继续向前扩展。这里有很低效的地方,比如每次截短 t,以及每次拷贝 s 之类的,容我偷个懒。

这个算法通过了 87% 的用例,说明,这是一个正确的算法,但是时间复杂度过高了。可以简单分析一下递归算法的时间复杂度,函数体内部,最坏情况循环大约是 len(s) - len(t) 次,而没深入一层,s 和 t 各减少了 1 (t 截短了 1 个字母,s 的下标后移了一个)。所以,然后把每层乘起来,综合复杂度是 O((m - n)!) 的复杂度。难怪过不了。

很容易想到,这个算法的优化空间,就是引入缓存,对中间结果进行记录,可以大幅减少搜索的数量。让我们来优化一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
def numDistinct(self, s: str, t: str) -> int:
self.s, self.t = s, t
self.m, self.n = len(s), len(t)
self.mem = {}
return self.dfs(0, 0)

def dfs(self, p: int, q: int) -> int:
if self.m - p < self.n - q:
return 0

# 缓存
k = (p, q)
v = self.mem.get(k, -1)
if v > 0:
return v

if q == self.n - 1:
v = self.s[p:].count(self.t[-1])
self.mem[k] = v
return v

total = 0
for i in range(p, self.m):
if self.m - i >= self.n - q and self.s[i] == self.t[q]:
total += self.dfs(i + 1, q + 1)

self.mem[k] = total
return total

然后,我优化出了这个算法:

  1. 不在使用字符串作为参数,只是使用下标,避免了字符串拷贝
  2. 改换了函数的返回值,使得方便去缓存
  3. 引入了缓存来记录结果
  4. 如果 s 长度小于 t 显然就不用继续搜索了,可以直接剪枝

这么多优化后,总共 62 个用例,过了 61 个,好消息是优化是有效的,效率提高了很多。坏消息是,仍然 TLE。这给我们最大的提示就是,非多项式时间的算法,无论怎么优化,在足够大的问题规模面前,就是无法有效运算。哈!

还是要使用递推方法,用动态规划来解,一个是不用递归,省去了栈的开销,另一个就是放弃了大量无效的搜索。但是写不出状态转移方程怎么办?不要急,仔细观察一下咱们优化后的算法。

为了避免拷贝字符串,我们用 p 和 q,记录了 s 和 t 的下标搜索位置,从上述算法里,我们能看到 dfs(p, q) 其实就是 dfs(i + 1, q + 1) 当 s[i] == t[q] 时,i 的取值范围是 [p, m)。这几乎就是状态转移方程了。

第一版状态转移方程:[
f(i, j) = \left {

\right.
]

现在这个方程基本上是没法写出代码来的,基本上是递归搜索方法的原本表达,一个是里面有个求和公式,就算写出来,算法的时间复杂度也是O(n^3)的,另一个就是,在 s[i] != t[j] 的时候,没有写出方程。这时候,我们就要耐心下来深入研究一下,怎么把这个状态方程给写完整了。

再来观察一下 dfs 的实现,在 for 循环里,套了一个 if,所以我们直接写出了 s[i] == t[j] 时候的情况,if 的 else 子句,代码里没写,其实就是 s[i] != t[j] 时的状态转移,我们发现,其实 else 的情况,就是跳过了这个下标,直接从下一个重复这个扫描过程了。相当于是调用了 dfs(i + 1, j) 了。这里可以品味一下。

然后在看怎么解决那个求和公式,其实解决起来也不难,就是咱们可以展开这个求和公式来看看:[

]

不知道这么解释,容易理解么?这里我们用了一个数列求和的小技巧,就是错位相减,然后得到了 f(i,j) 的递推公式,到此,我们就构造了我们的第二版状态转移方程:[
f(i, j) = \left {

\right.
]

有了状态转移公式,我们还要处理好边界条件,当 j = len(t) - 1 的时候,f(i, j) 的值,其实就是那个字符出现的次数。所以,我们可以把 i 从 0 到 len(s) 的值,初始化成在这个范围内搜索 t 的最后一个字符的出现次数。而实际上,更抽象地看,可以理解成,空串出现的次数。因为空串是任何串的子串,所以空串都会出现 1 次。当然,这个结论并不显然,这可能需要一些证明,我们可以这么猜测,如果是对的就万事大吉了。毕竟我们不是在做数学证明。

另外一个边界,就是递归代码里也用到了,就是 s 剩余的长度小于 t 剩余的长度时,就不用扫描了,显然结果是 0 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution:
def numDistinct(self, s: str, t: str) -> int:
if len(s) < len(t):
return 0
dp = [[0] * len(t) for _ in range(len(s))]

for i in range(len(s)):
dp[i][len(t) - 1] = s[i:].count(t[-1])

for i in range(len(s) - 2, -1, -1):
for j in range(len(t) - 2, -1, -1):
if s[i] == t[j] :
dp[i][j] = dp[i + 1][j + 1] + dp[i + 1][j]
else:
dp[i][j] = dp[i + 1][j]
return dp[0][0]

我写出了这样的实现。7 - 8 行,可以用我说的更抽象的形式去替代,这样效率会更高一点。少了 len(s) 次的字符串统计行为。这个代码经测试是正确的。

总结,我通过这篇文章,从分析题目,到写出朴素的回溯算法,然后进行性能优化,发现了状态转移,并优化了状态转移方程,最终写出了动态规划算法。状态转移方程的物理意义,在优化过后已经变得比较模糊了,有时候强行去解释的话,,听起来很绕,而且一下子想看懂,恐怕也非常难。学习动态规划,不用去纠结这一点,不然真的会陷在里面。