354. 俄罗斯套娃信封问题

这是一道打卡题,没想到难度是“困难”,每次看到红色的“困难”,我还是菊花一紧的,感觉很多脑细胞要死亡了。

题目规则很简单,就是大信封能套小信封,然后,问最多能套多少层。信封用 [w, h] 来标记宽度和长度。比如在宽度和长度都更大的前提下才能套入。

因为有两个维度,就显得有些复杂,我的第一个直觉就是排序。至少先按照其中一个维度比如宽度,进行排序。完事后,对长度开始穷举。穷举自然就是回溯法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution:
def maxEnvelopes(self, envelopes: List[List[int]]) -> int:
self.envelopes = sorted(envelopes, key = lambda x : x[0])
self.l = len(envelopes)
if self.l < 2:
return self.l
self.visited = [0] * self.l
self.max_env = 0
self.dfs(0, 0, 0)
return self.max_env

def dfs(self, start: int, maxW: int, maxH: int):
if start == self.l:
self.max_env = max(self.max_env, sum(self.visited))
else:
for i in range(start, self.l):
env = self.envelopes[i]
if env[0] <= maxW or env[1] <= maxH:
continue
self.visited[i] = 1
self.dfs(i + 1, env[0], env[1])
self.visited[i] = 0
self.max_env = max(self.max_env, sum(self.visited))

我写出了这样的代码,回溯的时候,用一个状态组记住一个信封有没有被套入,用 maxW 和 maxH 记录套到此时此刻,最大的 (w, h) 是多少,这样好判断下一个能不能套。因为要严格递增,w 又是有序的,那么我们可以把那些 w 相等的信封跳过,相当于剪枝了。所有信封都检查过,或者后面已经没有任何满足条件的信封时,就搜索到了结局,visited 里用到的信封总数,就是最长值。

第一个信封选择,理论上有 N 种,而第二个信封选择,理论上有 N - 1 种。只有 w 值相等时候才能剪枝。这就看出来,上面的算法,时间复杂度是 O( n! )。果然是无法 AC 的,超出了时间限制。

我的智慧估计就到这里了,能写出来不错了,还没有能力想到更短的算法。然后看了官方解法。我们设想,w 有序后,我们彻底摒弃 w 这个维度的影响,单纯看 h 这个维度的影响,应该怎么做呢?其实就是求剩下 h 的序列里,最长严格递增子序列。但是,这里其实有个问题,就是如果在 w 值相等的情况下,h 的最长严格递增子序列长度是错的。解决方法也很简单,在 w 相等的时候,h 按降序排列,就好了,这样就无法形成严格递增子序列了,只能在 w 不同的情况下增加子序列长度。

到了这一步,就可以设计一个动态规划的算法,其实,这里也是递推法。状态怎么转换的呢?第 i 个状态,就是前 i 个数构成的最长严格递增子序列长度。这时候,新来的第 i + 1 个数,就要去和前 i 个数里的每个数去比,如果更大,就能得到一个比当前这个状态长 1 的子序列,这里最大的一个,就是第 i + 1 个状态的值。太绕了,举个例子:

对于一个序列 6, 5, 1, 2,第 1 个状态是 1,一个数字的子串长度只能是 1,第 2 个状态是 1,因为 5 比 6 小,所以这两个数字构成的最长递增子串长度还是 1,以此类推,第 3 个还是 1,第 4 个是 2,因为 2 比 1 大,所以可以构成一个 2 个数字的递增子串,此时的状态数组就是 1,1,1,2。设想,现在下一个数字是 7。7 比 6 大,所以能构成 2 个数字的递增子串,对于后续的 5 和 1 来说也是如此,但是对于 2 来说,之前 2 已经可以构成一个长度为 2 的子串了,这时候来了 7 就能得到 3 个数字的最长子串了,所以下一个状态的值是 3。扫描到 7 的时候,状态数组变成了:1,1,1,2,3。基于这个原理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
def maxEnvelopes(self, envelopes: List[List[int]]) -> int:
self.l = len(envelopes)
if self.l < 2:
return self.l

def cmp(x, y):
if x[0] < y[0]:
return -1
elif x[0] > y[0]:
return 1
else:
if x[1] < y[1]:
return 1
elif x[1] > y[1]:
return -1
else:
return 0
self.envelopes = sorted(envelopes, key = functools.cmp_to_key(cmp))

dp = [1] * self.l
for i in range(self.l):
for j in range(i):
if self.envelopes[j][1] < self.envelopes[i][1]:
dp[i] = max(dp[i], dp[j] + 1)
return max(dp)

忽略上面乱七八糟的 cmp 函数,就是为了排序而已。关注 21 - 25 行。就是上面说的推导下一个状态的过程。这里算法,我们可以看到,每扩展一个新状态,需要扫描前面 n - 1 个状态。所以,时间复杂度是 O( n^2 ),这比起 O( n! ) 已经好了很多了。然而,悲摧的是,仍然不能 AC,还是超时,我看了这次超时的用例,更长了~

这个递推法比回溯法优秀到哪里了呢?其实在回溯法里,每出现一个子串,我们不但计算出了子串的长度,还计算出了子串的排列。visited 数组就给出了哪几个信封够成了该子串。而这本是没必要算出来的。在递推法里,我们只关心能构成的子串中长度最大的那个的长度是多少,其他都不关心了。所以节省了很多。

再来看最后一个解法,官方解法的说明,说实在的太抽象了,都是用的代数,我看不懂。我自己想了个。我们朴素来想想,如果我们想要套得更多层,到底怎么挑选信封比较有效呢?比如第一个信封是 (w:1, h:1),这时候第二个信封有两个选择(w:2, h:5)和 (w:2, h:2),你选哪个呢?不傻的话,都会选后者吧,因为 h 更小,意味着将来可以被下一个信封套进去的概率更大。

这个新的动态规划算法,状态这样设计的,假如,我想组成一个层数为 i 的套娃,那么用到的最小的那个信封高度,记为第 i 个状态值。注意这个算法仍然建立在 w 已经是升序,而 h 是降序的基础上。这时候,新来一个信封。如果它的高度,比 i 大,那只能套在 i 的外面,其高度变成第 i + 1 个状态。如果其高度比 i 小的时候呢?注意我们的策略,组成一个套娃的时候,尽可能选满足条件最小的一个信封,将来套入新信封的概率才会变大。所以,如果其高度比 i 小的话,理论上我们应该把第 i 个状态的值调小,也即没有扩展新状态。不过,我们真应该更新第 i 个么?如果前面还有比这个值大的呢?那这个肯定是套不上去的,所以得往前找一个位置,正好是前一个状态比它小,后一个状态比它大,然后把比它大的那个给更新了。这个位置可以用 二分查找 来搜索。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
def maxEnvelopes(self, envelopes: List[List[int]]) -> int:
if not envelopes:
return 0

n = len(envelopes)
envelopes.sort(key=lambda x: (x[0], -x[1]))

f = [envelopes[0][1]]
for i in range(1, n):
if (num := envelopes[i][1]) > f[-1]:
f.append(num)
else:
index = bisect.bisect_left(f, num)
f[index] = num

return len(f)

于是,我们得到了这个算法,这是我抄的,不是我自己手撸的。我们可以看看这个算法的复杂度是什么,排序是 O( n logn ),二分查找是 O( logn ),所以动态规划的过程,复杂度还是 O( n logn ),总的时间复杂度是 O( n logn )。比上一个算法又进步了。

那么这个算法到底节省了什么东西呢?上一个算法里,出现了 i 个数字的时候,我们记录了每个数字能组成最长子序列长度。还拿那个 6, 5, 1, 2 来举例,其状态数组是:1, 1, 1, 2。这里我们看到,6,5,1 三个数字,其能形成的最长子序列都是 1,被记录了 3 次。其实,6 和 5 如果能跟后续新出现的数字,比如 7 ,组成更长子序列的话,1 也一样可以,因为 1 更小,适应性更强。在新算法里,我们就把 6,5 的状态值,都给抛弃了,只保留了 1。这样,如果新出现一个数字的时候,显然我们需要遍历的状态就更少了。于是将 O( n^2 ) 节省到了 O( n logn ),终于 AC 了。

问题总结:

第一,我对 Python 的排序还是太不熟悉了,查了文档,还是写出了那么 Ugly 的排序,其实 List 自己就有 sort 方法,而且,对于 List 是有默认多级排序的,不用自己写 Comparator,尤其我还不熟悉 Python3 怎么利用 Comparator;

第二,才知道,还有个类库 biset,可以直接二分查找,比如这个题的考点不是写二分,完全可以直接用现成算法,来集中注意力研究场景。