LeetCode: Recover Binary Search Tree

Problem Statement

Solution:

One approach, which uses O(n) extra space, is to store for each node N pointers to the nodes with the minimum and maximum values in the sub-tree rooted at N. Let's denote the minimum node of the sub-tree rooted at N by N.min and the maximum node by N.max. Then the sub-tree rooted at N has a "defect" if:

N.val < N.left.max.val and/or N.val > N.right.min.val

This is intuitive, as either of the above conditions violates the binary search tree property.

Figure: Correct binary search tree (min and max values shown beside each node).

The maximum value on the left sub-tree should be less than the root and the minimum value on the right sub-tree should be greater than the root.

Figure: Nodes 20 and 21 swapped.

In the BST above, nodes 20 and 21 are swapped. Note that for node 21, the condition:

N.val > N.right.min.val

holds, because N.val = 21 and N.right.min.val = 20. But the condition:

N.val < N.left.max.val

does not hold, because N.val = 21 and N.left.max.val = 18.
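To make the two conditions concrete, here is a small brute-force sketch. It assumes a minimal hypothetical `TreeNode` class and a hypothetical three-node tree (21 with left child 18 and right child 20) that matches the values quoted above; it is not necessarily the tree from the figure. The helpers `subtree_min`, `subtree_max` and `defect_conditions` are illustrative names, and they recompute the min and max on every call, so this is not the O(n) approach itself.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical, for illustration)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_min(node):
    """Node holding the smallest value in the sub-tree rooted at node."""
    best = node
    for child in (node.left, node.right):
        if child is not None:
            cand = subtree_min(child)
            if cand.val < best.val:
                best = cand
    return best


def subtree_max(node):
    """Node holding the largest value in the sub-tree rooted at node."""
    best = node
    for child in (node.left, node.right):
        if child is not None:
            cand = subtree_max(child)
            if cand.val > best.val:
                best = cand
    return best


def defect_conditions(node):
    """Return (N.val < N.left.max.val, N.val > N.right.min.val)."""
    left_bad = node.left is not None and node.val < subtree_max(node.left).val
    right_bad = node.right is not None and node.val > subtree_min(node.right).val
    return left_bad, right_bad


# Hypothetical tree after swapping 20 and 21: 21 with left child 18, right child 20
root = TreeNode(21, TreeNode(18), TreeNode(20))
print(defect_conditions(root))  # (False, True)
```

Only the right-hand condition fires, which is exactly the situation described for node 21.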

The following code computes the min-max values. We store them in a map keyed by node, because we assume that the TreeNode class is defined elsewhere and cannot be modified, so we cannot add min and max fields to it.

```python
def get_min_max(self, root, min_max_map):
    # Recursively compute the (min node, max node) pair of each sub-tree
    if root.left is not None:
        min_left, max_left = self.get_min_max(root.left, min_max_map)
    else:
        min_left, max_left = None, None

    if root.right is not None:
        min_right, max_right = self.get_min_max(root.right, min_max_map)
    else:
        min_right, max_right = None, None

    # a: node holding the smallest value in the sub-tree rooted at root
    a = root

    if min_left is not None and a.val > min_left.val:
        a = min_left
    if min_right is not None and a.val > min_right.val:
        a = min_right

    # b: node holding the largest value in the sub-tree rooted at root
    b = root

    if max_left is not None and b.val < max_left.val:
        b = max_left
    if max_right is not None and b.val < max_right.val:
        b = max_right

    # Store the pair, keyed by the root node of the sub-tree
    min_max_map[root] = (a, b)

    return a, b
```

If exactly one of the conditions holds, we swap the value of the root N with N.left.max or N.right.min, depending on which condition holds. If both conditions hold, we swap N.left.max with N.right.min instead. If neither holds, the swapped pair lies entirely inside one of the sub-trees, so we recurse into the left and right sub-trees.

Figure: Nodes 13 and 23 swapped.

In the BST above, nodes 13 and 23 have been swapped. For the root node 15, both conditions hold, because 15 > 13 and 15 < 23. In such cases, we swap the nodes with values 13 and 23 and leave the root node untouched.
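The swap rule can be written as a small pure function. This is a sketch: `choose_swap` is a hypothetical name, and it takes only the three values involved (the root value, the largest value in the left sub-tree and the smallest value in the right sub-tree, with `None` for a missing sub-tree). It reports which pair of nodes should be swapped, or `None` when there is no defect at this node.

```python
from typing import Optional


def choose_swap(root_val: int, left_max: Optional[int], right_min: Optional[int]) -> Optional[str]:
    """Apply the swap rule at one node (hypothetical helper for illustration)."""
    x = left_max is not None and root_val < left_max     # N.val < N.left.max.val
    y = right_min is not None and root_val > right_min   # N.val > N.right.min.val
    if x and y:
        return "left.max <-> right.min"   # leave the root untouched
    if x:
        return "root <-> left.max"
    if y:
        return "root <-> right.min"
    return None  # no defect here: the swapped pair is inside one sub-tree


print(choose_swap(15, 23, 13))  # left.max <-> right.min
print(choose_swap(21, 18, 20))  # root <-> right.min
```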

The nodes are adjusted by the following code:

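```python
def swap(self, node1, node2):
    # Exchange the values (not the nodes) of two tree nodes
    if node1 is not None and node2 is not None:
        node1.val, node2.val = node2.val, node1.val

def recover(self, root, min_max_map):
    if root is None:
        return

    # p: largest node in the left sub-tree, q: smallest node in the right sub-tree
    p = min_max_map[root.left][1] if root.left is not None else None
    q = min_max_map[root.right][0] if root.right is not None else None

    # x: N.val < N.left.max.val, y: N.val > N.right.min.val
    x = p is not None and root.val < p.val
    y = q is not None and root.val > q.val

    if x and y:
        # Both conditions hold: swap the two sub-tree nodes, leave the root alone
        self.swap(p, q)
    elif x:
        self.swap(root, p)
    elif y:
        self.swap(root, q)
    else:
        # No defect at this node: the swapped pair lies inside one sub-tree
        self.recover(root.left, min_max_map)
        self.recover(root.right, min_max_map)
```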
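Putting the two pieces together, the following is a self-contained sketch of the O(n)-space version that can be run outside LeetCode. The `TreeNode` class, the `recoverTree` entry point and the `inorder` helper are assumptions for illustration (LeetCode supplies its own `TreeNode`), and `get_min_max` is a compact equivalent of the code above. The test tree is hypothetical: 15 at the root, 10 (with right child 13) on the left and 20 (with right child 23) on the right, with 13 and 23 swapped.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


class Solution:
    def swap(self, node1, node2):
        if node1 is not None and node2 is not None:
            node1.val, node2.val = node2.val, node1.val

    def get_min_max(self, root, min_max_map):
        # Store (min node, max node) for every sub-tree, keyed by its root node
        a = b = root
        for child in (root.left, root.right):
            if child is not None:
                lo, hi = self.get_min_max(child, min_max_map)
                if lo.val < a.val:
                    a = lo
                if hi.val > b.val:
                    b = hi
        min_max_map[root] = (a, b)
        return a, b

    def recover(self, root, min_max_map):
        if root is None:
            return
        p = min_max_map[root.left][1] if root.left is not None else None
        q = min_max_map[root.right][0] if root.right is not None else None
        x = p is not None and root.val < p.val
        y = q is not None and root.val > q.val
        if x and y:
            self.swap(p, q)
        elif x:
            self.swap(root, p)
        elif y:
            self.swap(root, q)
        else:
            self.recover(root.left, min_max_map)
            self.recover(root.right, min_max_map)

    def recoverTree(self, root):
        # Hypothetical entry point: build the map, then fix the tree in place
        if root is not None:
            min_max_map = {}
            self.get_min_max(root, min_max_map)
            self.recover(root, min_max_map)


def inorder(node):
    return inorder(node.left) + [node.val] + inorder(node.right) if node else []


# 13 and 23 swapped: 15 -> left 10 (right child 23), right 20 (right child 13)
root = TreeNode(15, TreeNode(10, None, TreeNode(23)), TreeNode(20, None, TreeNode(13)))
Solution().recoverTree(root)
print(inorder(root))  # [10, 13, 15, 20, 23]
```

Using the node objects themselves as dictionary keys works because plain Python objects hash by identity unless the class overrides `__eq__`/`__hash__`.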

To achieve the same result without the map, i.e. with O(1) extra space apart from the recursion stack, we need to combine the 'get_min_max' and 'recover' steps into a single function.

The logic of the code does not change. The only difference is that 'get_min_max' additionally tracks the sub-trees where defects occurred, so that we do not need a second traversal (the 'recover' method) to find them.

Below is the modified 'get_min_max' method:

```python
def get_min_max(self, root):
    # Returns (min node, max node, has_defect, last_defect) for the sub-tree at root.
    # last_defect is the pair of nodes to swap for the highest pending defect.
    has_defect, last_defect = 0, (None, None)

    if root.left is not None:
        min_left, max_left, has_defect_left, last_defect_left = self.get_min_max(root.left)
    else:
        min_left, max_left, has_defect_left, last_defect_left = None, None, 0, (None, None)

    if root.right is not None:
        min_right, max_right, has_defect_right, last_defect_right = self.get_min_max(root.right)
    else:
        min_right, max_right, has_defect_right, last_defect_right = None, None, 0, (None, None)

    # A defect at this level replaces any deeper defect found on the same side
    if max_left is not None and max_left.val > root.val:
        last_defect_left = (root, max_left)
        has_defect_left = 1

    if min_right is not None and min_right.val < root.val:
        last_defect_right = (root, min_right)
        has_defect_right = 1

    if has_defect_left == 1 and has_defect_right == 1:
        # Defects on both sides: the swapped pair straddles root, so fix it now
        self.swap(last_defect_left[1], last_defect_right[1])
        has_defect, last_defect = 0, (None, None)

    elif has_defect_left == 1:
        # Defer the swap: an ancestor may reveal the correct pair
        has_defect, last_defect = 1, last_defect_left

    elif has_defect_right == 1:
        has_defect, last_defect = 1, last_defect_right

    # a: node holding the smallest value in this sub-tree
    a = root

    if min_left is not None and a.val > min_left.val:
        a = min_left
    if min_right is not None and a.val > min_right.val:
        a = min_right

    # b: node holding the largest value in this sub-tree
    b = root

    if max_left is not None and b.val < max_left.val:
        b = max_left
    if max_right is not None and b.val < max_right.val:
        b = max_right

    return a, b, has_defect, last_defect

def recoverTree(self, root):
    if root is None:
        return

    a, b, has_defect, last_defect = self.get_min_max(root)

    # A deferred defect that reached the root holds the pair to swap
    if has_defect == 1:
        self.swap(last_defect[0], last_defect[1])
```
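The single-pass version can be checked with a self-contained sketch. It is a compact restatement of the code above, under a few assumptions: `TreeNode`, `NO_DEFECT`, `build` and `nodes_inorder` are hypothetical helpers, the min/max selection uses Python's built-in `min`/`max` with a `key` function instead of the chained `if` statements, and the check itself (swap every pair of nodes in a balanced 7-node BST, then recover) is only an illustration, not a proof.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


NO_DEFECT = (None, None)


class Solution:
    def swap(self, node1, node2):
        if node1 is not None and node2 is not None:
            node1.val, node2.val = node2.val, node1.val

    def get_min_max(self, root):
        # Returns (min node, max node, has_defect, last_defect) for this sub-tree
        if root.left is not None:
            min_left, max_left, has_defect_left, last_defect_left = self.get_min_max(root.left)
        else:
            min_left, max_left, has_defect_left, last_defect_left = None, None, 0, NO_DEFECT
        if root.right is not None:
            min_right, max_right, has_defect_right, last_defect_right = self.get_min_max(root.right)
        else:
            min_right, max_right, has_defect_right, last_defect_right = None, None, 0, NO_DEFECT

        # A defect at this level replaces any deeper defect on the same side
        if max_left is not None and max_left.val > root.val:
            has_defect_left, last_defect_left = 1, (root, max_left)
        if min_right is not None and min_right.val < root.val:
            has_defect_right, last_defect_right = 1, (root, min_right)

        if has_defect_left and has_defect_right:
            # Both sides broken: the swapped pair straddles root, so fix it now
            self.swap(last_defect_left[1], last_defect_right[1])
            has_defect, last_defect = 0, NO_DEFECT
        elif has_defect_left:
            has_defect, last_defect = 1, last_defect_left    # defer to ancestors
        elif has_defect_right:
            has_defect, last_defect = 1, last_defect_right   # defer to ancestors
        else:
            has_defect, last_defect = 0, NO_DEFECT

        a = min((n for n in (root, min_left, min_right) if n is not None), key=lambda n: n.val)
        b = max((n for n in (root, max_left, max_right) if n is not None), key=lambda n: n.val)
        return a, b, has_defect, last_defect

    def recoverTree(self, root):
        if root is None:
            return
        _, _, has_defect, last_defect = self.get_min_max(root)
        if has_defect == 1:
            self.swap(last_defect[0], last_defect[1])


def build(vals):
    # Balanced BST from a sorted list (hypothetical test helper)
    if not vals:
        return None
    mid = len(vals) // 2
    return TreeNode(vals[mid], build(vals[:mid]), build(vals[mid + 1:]))


def nodes_inorder(node):
    return nodes_inorder(node.left) + [node] + nodes_inorder(node.right) if node else []


# Swap every pair of nodes in a 7-node BST and check the tree is recovered
ok = True
for i in range(7):
    for j in range(i + 1, 7):
        root = build(list(range(1, 8)))
        nodes = nodes_inorder(root)
        nodes[i].val, nodes[j].val = nodes[j].val, nodes[i].val
        Solution().recoverTree(root)
        ok = ok and [n.val for n in nodes_inorder(root)] == list(range(1, 8))
print(ok)  # True
```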

We track the sub-tree(s) where the defect occurred using the 'last_defect' variable.

Observe that if a tree T rooted at N has a defect, then there could be at least one other sub-tree inside it that also has a defect. While recursing bottom-up, we should not swap nodes immediately on observing a defect, because an ancestor node might be the one that needs to be swapped with the defective node.

Figure: Nodes 2 and 4 swapped.

When we swap node 2 with node 4, note that there is a defect in the sub-tree rooted at 3 as well as in the sub-tree rooted at 2. Thus, during the bottom-up recursion, swapping 3 with 4 would be incorrect, since its ancestor node 2 also has a defect, and that defect gives the correct fix.
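The effect of deferring can be seen on a hypothetical chain tree, since the figure itself is not available here: the BST 2 → 3 → 4 (each node the right child of the previous one) with the values 2 and 4 swapped. The brute-force helpers `values`, `has_defect`, `is_bst` and `build` are illustrative names, not part of the solution. Both the root and the middle node report a defect, but only the fix at the higher one restores the BST.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def values(node):
    """In-order values of the sub-tree rooted at node (brute force)."""
    return values(node.left) + [node.val] + values(node.right) if node else []


def has_defect(node):
    """True if N.val < N.left.max.val or N.val > N.right.min.val."""
    left, right = values(node.left), values(node.right)
    return (bool(left) and node.val < max(left)) or (bool(right) and node.val > min(right))


def is_bst(node):
    vals = values(node)
    return vals == sorted(vals)


def build():
    # Hypothetical chain 2 -> 3 -> 4 (right children) after swapping 2 and 4
    bottom = TreeNode(2)
    middle = TreeNode(3, None, bottom)
    return TreeNode(4, None, middle), middle, bottom


root, middle, bottom = build()
print(has_defect(root), has_defect(middle))  # True True

# Eager fix at the lower defect: swap the middle node with the node below it
middle.val, bottom.val = bottom.val, middle.val
print(is_bst(root))  # False

# Deferred fix at the highest defect: swap the root with the bottom node
root, middle, bottom = build()
root.val, bottom.val = bottom.val, root.val
print(is_bst(root))  # True
```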

In general, we can't swap any node values until we have seen all the nodes up to the root node.

Figure: Nodes 18 and 23 swapped within a sub-tree.

The only exception is when both the left and the right sub-trees of a node contain a defect: in that case we can safely swap the two nodes as soon as we discover them. Both sub-trees of node 20 have a defect, so we can safely swap them immediately instead of deferring the swap until we have seen the root node.

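The run-time complexity of the above code is O(n), since each node is visited once. Because we no longer store the min-max values in a map, the extra space is O(1) apart from the recursion stack, which takes O(h) space for a tree of height h (O(n) in the worst case of a skewed tree).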
