分布式系统,程序语言,算法设计

数据结构与算法(三):拆分二叉搜索树

小引

二叉树(Binary Tree)是数据结构中很好玩的一种,可以把玩的地方非常之多。而二叉搜索树(Binary Serach Tree,下面简称 BST,当然也有叫二叉查找树、查找二叉树等等)又是其中常用的一种,它有很多有趣的性质

  1. 左皆小,右皆大。
  2. 中序遍历有序。
  3. 投影升序。

当然,加上平衡会引入更多的特性,这里先按下不表。今天先从个小题入手把玩一番。

作者:青藤木鸟 https://www.qtmuniao.com, 转载请注明出处

入题

给定一个二叉搜索树 t (树中没有相同值的节点)以及其中的一个节点的值 val*,请以 *val 为界,将 t 拆为两棵新的二叉树 sl,要求:

  1. val 扔掉即可。
  2. sl 仍然是二叉排序树。
  3. s 值皆小于 val ,树 l 值皆大于 val
  4. sl 须为原地(in-place)拆解,不能重新构造。

思考

一般我们一上来的思路是这样的:

先利用搜索树的性质,找到这个节点 val –> 这个点不要了,其左子树肯定放 s 中,右子树肯定放 l 中 –> 再考虑其父节点,如果其父节点是节点就好说了,balala

那如果父节点不是根节点呢?如果在很深的地方呢?没有父指针你如何进行回溯呢?

一般人面对三连问,直接就懵逼了。

一个简单解法

解法其实很简单,将思维逆向一下即可。即,我们仍是要寻找该节点,但是不是最后才思考拆分树,而是在找该节点的时候边找边拆分。即:

  1. 从根节点二分查找 val,会形成一条查找路径。
  2. 对于该路径上的节点 a:
    1. 如果 a > val,则 a 连同其右子树都大于 val。
    2. 如果 a < val,则 a 连同其右子树都小于 val。
    3. 如果 a == val(即找到该点),则其左子树小于 val,右子树大于 val。
  3. 自根向下,每次切分路径上一节点,连带相应分支。找到该节点时,分别切分下其左右分支。

即得到所有大于 val 的分支集合,和所有小于 val 的分支集合,val 被扔掉。

如下图,寻找树中 val = 11 节点示意图。

bst-split

下一个问题是,如何将切下来的树枝合到一块得到结果?

易证明(同学们可以思考一下),同侧切下来的“分枝”,都是可以通过“切口”合在一块。如上图中以 10 为根的分支可以合到 8 的右枝。

代码

最近写 Python 多,而且 Python 表达比较简洁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import typing as T

class TreeNode:
def __init__(self, val: int, left: 'TreeNode' = None, right: 'TreeNode' = None):
self.val = val
self.left = left
self.right = right

def split_bst(root: TreeNode, target: int) -> T.Tuple[TreeNode, TreeNode]:
small_root, large_root = None, None
if root is None:
return small_root, large_root

small_root, large_root = TreeNode(0), TreeNode(0) # dummy root
small_tmp, large_tmp = small_root, large_root

curr = root
while curr.val != target:
# curr root with left branch is small than target
if curr.val < target:
small_tmp.right = curr
small_tmp = curr
curr = curr.right
# curr root with right branch is larger than target
else:
large_tmp.left = curr
large_tmp = curr
curr = curr.left

small_tmp.right = curr.left
large_tmp.left = curr.right

return small_root.right, large_root.left

为了验证这段代码的正确性,给出构造搜索二叉树、验证搜索二叉树和打印二叉树的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# 给出上下边界,折半构造。
def construct_bst(start: int, end: int) -> TreeNode:
if start > end:
return None

if start == end:
return TreeNode(start)

mid = (start+ end) // 2
root = TreeNode(mid)
root.left = construct_bst(start, mid-1)
root.right = construct_bst(mid+1, end)

return root

# 利用中序遍历有序的性质
def valid_bst(root: TreeNode) -> bool:
prev: TreeNode = None
def valid(curr: TreeNode) -> bool:
if curr is None:
return True

if not valid(curr.left):
return False

nonlocal prev
if prev is not None and prev.val > curr.val:
return False

prev = curr
return valid(curr.right)

return valid(root)

# 逆中序遍历,空格个数 = 层深 * factor
def print_tree(root: TreeNode):
SPACING = 3

def print_util(curr: TreeNode, space: int):
if curr is None:
return

space += SPACING
print_util(curr.right, space)
print(' ' * space, curr.val)
print_util(curr.left, space)

if root is None:
print('Empty Tree')
else:
print_util(root, 0)

if __name__ == '__main__':
t = construct_bst(1, 15)
print_tree(t)
print('above tree is bst:', valid_bst(t))

for split_point in range(11):
t = construct_bst(0, 16)
print('='*50)
print('current split point is', split_point)
s, l = split_bst(t, split_point)
print_tree(s)
print('above tree is bst:', valid_bst(s))
print_tree(l)
print('above tree is bst:', valid_bst(l))

其中打印二叉树的实现是侧着打印,每一行输出一个值,每一列同属一层,算是一种 tricky 的简易打印方法(想想为什么)。可以歪着脑袋看:)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
## 截取两个典型结果:
current split point is 8
7
6
5
4
3
2
1
0
above tree is bst: True
16
15
14
13
12
11
10
9
above tree is bst: True
==============================
current split point is 9
8
7
6
5
4
3
2
1
0
above tree is bst: True
16
15
14
13
12
11
10
above tree is bst: True