木鸟杂记 (Woodpecker Notes)

Large-Scale Data Systems

Data Structures and Algorithms (III): Splitting a Binary Search Tree

Introduction

The binary tree is a fascinating data structure with many interesting properties. The binary search tree (BST for short below; also called an ordered or sorted binary tree) is one of its most commonly used variants, and it has several useful properties of its own:

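  1. Every value in a node's left subtree is smaller than the node, and every value in its right subtree is larger.
  2. An in-order traversal yields the values in sorted order.
  3. Projecting the nodes onto a horizontal line gives the values in ascending order from left to right.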
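To see property 2 in action, here is a minimal sketch: it builds a small BST with plain (unbalanced) insertion and checks that an in-order traversal returns the values sorted. The helpers `insert` and `inorder` and the sample values are illustrative assumptions, not part of the original post.

```python
from typing import List, Optional


class TreeNode:
    """Minimal BST node (same shape as the TreeNode used later in the post)."""

    def __init__(self, val: int, left: Optional['TreeNode'] = None,
                 right: Optional['TreeNode'] = None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Hypothetical helper: plain BST insertion without rebalancing."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Visit the left subtree, then the node, then the right subtree."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


root = None
for v in [8, 3, 10, 1, 6, 14, 4, 7, 13]:
    root = insert(root, v)

print(inorder(root))  # [1, 3, 4, 6, 7, 8, 10, 13, 14]
```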

Of course, adding balance introduces even more properties, but let’s set that aside for now. Today, let’s start with a small problem.

Author: 木鸟杂记 (Woodpecker Notes), https://www.qtmuniao.com. Please credit the source when reposting.

Problem Statement

Given a binary search tree t (no duplicate values) and a node value val in it, split t into two new binary trees s and l using val as the boundary, with the following requirements:

  1. Simply discard val.
  2. Trees s and l must still be binary search trees.
  3. All values in tree s are less than val, and all values in tree l are greater than val.
  4. s and l must be split in-place; do not reconstruct the trees.
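The four requirements can be turned into a mechanical check. The sketch below is one possible way to do it, assuming you collect the original nodes before splitting; `nodes` and `meets_requirements` are hypothetical helper names. It relies on the fact that, with distinct values, a binary tree is a BST exactly when its in-order traversal is strictly increasing, and it uses object identity (`id`) to confirm that s and l reuse the original nodes.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: Optional['TreeNode'] = None,
                 right: Optional['TreeNode'] = None):
        self.val = val
        self.left = left
        self.right = right


def nodes(root: Optional[TreeNode]) -> List[TreeNode]:
    """Node objects of a tree in in-order (left, node, right)."""
    if root is None:
        return []
    return nodes(root.left) + [root] + nodes(root.right)


def meets_requirements(original: List[TreeNode], val: int,
                       s: Optional[TreeNode], l: Optional[TreeNode]) -> bool:
    """Hypothetical checker for requirements 1-4 of the problem statement."""
    values = [n.val for n in original]
    s_nodes, l_nodes = nodes(s), nodes(l)
    # Requirements 1-3: in-order of s is exactly the values below val and
    # in-order of l exactly the values above val. val appears in neither,
    # and a strictly increasing in-order sequence means each result is a BST.
    if [n.val for n in s_nodes] != sorted(v for v in values if v < val):
        return False
    if [n.val for n in l_nodes] != sorted(v for v in values if v > val):
        return False
    # Requirement 4: every node in s and l must be one of the original nodes.
    original_ids = {id(n) for n in original}
    return all(id(n) in original_ids for n in s_nodes + l_nodes)


# Usage: split the three-node tree 2(1, 3) at its root by hand.
one, three = TreeNode(1), TreeNode(3)
root = TreeNode(2, one, three)
before = nodes(root)  # capture the nodes before any pointers change
print(meets_requirements(before, 2, one, three))          # True
print(meets_requirements(before, 2, TreeNode(1), three))  # False: new node, not in place
```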

Thinking

Usually, our first thought goes like this:

First, use the BST property to find the node holding val, then discard it. Its left subtree certainly goes into s and its right subtree certainly goes into l. Next, consider its parent: if the parent is the root, it is easy, and so on and so forth.

But what if the parent is not the root? What if the node sits deep in the tree? And without parent pointers, how do you backtrack?

These three questions are enough to leave most people stuck.

A Simple Solution

The solution is actually quite simple: reverse your thinking. We still need to find the node, but instead of splitting the tree only after the node is found, we split while searching for it. Specifically:

  1. Binary-search for val starting from the root; the visited nodes form a search path.
  2. For each node a on this path:
    1. If a > val, then a and its right subtree are all greater than val.
    2. If a < val, then a and its left subtree are all less than val.
    3. If a == val (i.e., we have found the node), then its left subtree is less than val and its right subtree is greater than val.
  3. Walking from the root downward, at each step cut off the current path node together with the corresponding branch. When the node is found, cut off its left and right branches separately.
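As a quick illustration of the classification above (without cutting anything yet), the sketch below walks the search path and records which side each node falls on. `build_balanced` mirrors the post's `construct_bst`, and `classify_path` is a hypothetical name. On the tree built from 1..15 (the same tree the post's test code prints first), searching for 11 visits 8, 12, 10 and then 11.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: Optional['TreeNode'] = None,
                 right: Optional['TreeNode'] = None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(start: int, end: int) -> Optional[TreeNode]:
    """Balanced BST over [start, end], built by halving like construct_bst."""
    if start > end:
        return None
    mid = (start + end) // 2
    return TreeNode(mid, build_balanced(start, mid - 1),
                    build_balanced(mid + 1, end))


def classify_path(root: Optional[TreeNode], val: int) -> List[Tuple[int, str]]:
    """Walk the search path for val and label each node's side (no mutation)."""
    path = []
    curr = root
    while curr is not None:
        if curr.val < val:
            # curr and its left branch are all smaller: cut to the small side
            path.append((curr.val, 'small'))
            curr = curr.right
        elif curr.val > val:
            # curr and its right branch are all larger: cut to the large side
            path.append((curr.val, 'large'))
            curr = curr.left
        else:
            # found: left branch -> small side, right branch -> large side
            path.append((curr.val, 'found'))
            break
    return path


print(classify_path(build_balanced(1, 15), 11))
# [(8, 'small'), (12, 'large'), (10, 'small'), (11, 'found')]
```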

This gives us all the branches greater than val and all the branches less than val, with val discarded.

The figure below illustrates searching for val = 11 in the tree.

[Figure: bst-split]

The next question is, how do we merge the cut branches to get the result?

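It is not hard to see that branches cut to the same side can all be joined together through the "cut points". Suppose a is a node cut to the small side: every node after a on the search path was reached by first going right from a, so it lies in a's right subtree and is greater than a. Attaching the next small-side node (together with its left branch) as a's right child therefore preserves BST order; the large side works symmetrically through left children. For example, in the figure above, the branch rooted at 10 can be attached as the right branch of 8.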

Code

I’ve been writing a lot of Python lately, and Python is quite concise:

```python
import typing as T


class TreeNode:
    def __init__(self, val: int, left: T.Optional['TreeNode'] = None,
                 right: T.Optional['TreeNode'] = None):
        self.val = val
        self.left = left
        self.right = right


def split_bst(root: T.Optional[TreeNode],
              target: int) -> T.Tuple[T.Optional[TreeNode], T.Optional[TreeNode]]:
    if root is None:
        return None, None

    # Dummy roots: s hangs off small_root.right, l hangs off large_root.left
    small_root, large_root = TreeNode(0), TreeNode(0)
    small_tmp, large_tmp = small_root, large_root

    curr = root
    while curr is not None and curr.val != target:
        # curr and its left branch are all smaller than target
        if curr.val < target:
            small_tmp.right = curr
            small_tmp = curr
            curr = curr.right
        # curr and its right branch are all larger than target
        else:
            large_tmp.left = curr
            large_tmp = curr
            curr = curr.left

    # Found target: its left branch joins s and its right branch joins l;
    # the target node itself is discarded. (The problem guarantees target is
    # in the tree; if it were absent, curr is None and both sides end here.)
    small_tmp.right = curr.left if curr is not None else None
    large_tmp.left = curr.right if curr is not None else None

    return small_root.right, large_root.left
```
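The same cut-and-attach idea can also be written recursively, which some readers may find easier to reason about. This is a sketch of an alternative, not the post's implementation; `split_bst_recursive` and the `inorder` helper are hypothetical names, and it assumes, as the problem states, that target is in the tree.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: Optional['TreeNode'] = None,
                 right: Optional['TreeNode'] = None):
        self.val = val
        self.left = left
        self.right = right


def split_bst_recursive(root: Optional[TreeNode],
                        target: int) -> Tuple[Optional[TreeNode], Optional[TreeNode]]:
    """Return (small, large); the node holding target is dropped."""
    if root is None:
        return None, None
    if root.val == target:
        # Found: the left branch is all smaller, the right branch all larger.
        return root.left, root.right
    if root.val < target:
        # root and its left branch belong to the small side; split the right
        # branch and re-attach its small part as root's new right child.
        small, large = split_bst_recursive(root.right, target)
        root.right = small
        return root, large
    # root.val > target: root and its right branch belong to the large side.
    small, large = split_bst_recursive(root.left, target)
    root.left = large
    return small, root


def inorder(root: Optional[TreeNode]) -> List[int]:
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


# Usage on the tree 4(2(1, 3), 6(5, 7)).
t = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
             TreeNode(6, TreeNode(5), TreeNode(7)))
s, l = split_bst_recursive(t, 2)
print(inorder(s), inorder(l))  # [1] [3, 4, 5, 6, 7]
```

One trade-off to keep in mind: each recursive call uses a stack frame, so the recursion depth equals the tree height. For a badly unbalanced tree this can exceed Python's default recursion limit of 1000, which the iterative version with dummy roots avoids.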

To verify the correctness of this code, here are functions to construct a BST, validate a BST, and print a binary tree:

```python
# Build a balanced BST over [start, end] by repeatedly halving the range.
def construct_bst(start: int, end: int) -> T.Optional[TreeNode]:
    if start > end:
        return None

    if start == end:
        return TreeNode(start)

    mid = (start + end) // 2
    root = TreeNode(mid)
    root.left = construct_bst(start, mid - 1)
    root.right = construct_bst(mid + 1, end)

    return root


# A tree with distinct values is a BST iff its in-order traversal is strictly increasing.
def valid_bst(root: T.Optional[TreeNode]) -> bool:
    prev: T.Optional[TreeNode] = None

    def valid(curr: T.Optional[TreeNode]) -> bool:
        nonlocal prev
        if curr is None:
            return True

        if not valid(curr.left):
            return False

        if prev is not None and prev.val >= curr.val:
            return False

        prev = curr
        return valid(curr.right)

    return valid(root)


# Reverse in-order traversal (right, node, left); indentation = depth * SPACING
def print_tree(root: T.Optional[TreeNode]):
    SPACING = 3

    def print_util(curr: T.Optional[TreeNode], space: int):
        if curr is None:
            return

        space += SPACING
        print_util(curr.right, space)
        print(' ' * space, curr.val)
        print_util(curr.left, space)

    if root is None:
        print('Empty Tree')
    else:
        print_util(root, 0)


if __name__ == '__main__':
    t = construct_bst(1, 15)
    print_tree(t)
    print('above tree is bst:', valid_bst(t))

    for split_point in range(11):
        t = construct_bst(0, 16)
        print('=' * 50)
        print('current split point is', split_point)
        s, l = split_bst(t, split_point)
        print_tree(s)
        print('above tree is bst:', valid_bst(s))
        print_tree(l)
        print('above tree is bst:', valid_bst(l))
```
print_tree prints the tree sideways: each line holds one value, and each column of indentation corresponds to one level of the tree. It is a hacky but simple way to print a tree (think about why it works). Tilt your head to view it :).

An excerpt of two typical results:

```text
current split point is 8
             7
          6
       5
          4
    3
          2
       1
          0
above tree is bst: True
             16
          15
       14
          13
    12
          11
       10
          9
above tree is bst: True
==================================================
current split point is 9
    8
                7
             6
          5
             4
       3
             2
          1
             0
above tree is bst: True
             16
          15
       14
          13
    12
          11
       10
above tree is bst: True
```

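I am 青藤木鸟, a programmer who enjoys photography and focuses on large-scale data systems. You are welcome to follow my WeChat official account "木鸟杂记" for more articles on distributed systems, storage, and databases. After following, reply "资料" to get a distributed-database study guide I have compiled, or reply "优惠券" to get a 20%-off coupon for my paid column on large-scale data systems, 《系统日知录》.

We also have discussion groups on distributed systems and databases: add me on WeChat (ID: qtmuniao) and I will invite you. When adding me, please include the note "分布式系统群" (distributed systems group). If you would rather not join a group, there is also a distributed systems and databases forum (click here) where you are welcome to hang out.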
