# LeetCode : Count of Range Sum

Problem Statement

Solution :

This one looks like a very simple problem at first glance, but I found it quite tricky to implement. The straightforward solution is to pre-compute the prefix sums S(i), i.e. the sum of all integers from index 0 to index i, for every i, and then compute every range sum S(i, j), the sum of all integers from index i to index j. This follows from the simple formula S(i, j) = S(j) - S(i-1), with S(-1) = 0. Once we have S(i, j) for every pair i <= j, we simply check whether it lies within the range [lower, upper]. There are O(N^2) pairs (i, j), and for each pair computing S(i, j) and checking it against [lower, upper] takes O(1), so the run-time complexity of this approach is O(N^2).
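A minimal Python sketch of this O(N^2) baseline (the name `count_range_sum_naive` is ours, not the author's). It stores a leading 0 in the prefix array so that S(-1) = 0 needs no special case:

```python
def count_range_sum_naive(nums, lower, upper):
    """O(N^2) baseline: check every range sum S(i, j) = S(j) - S(i-1)."""
    # prefix[k] holds nums[0] + ... + nums[k-1]; prefix[0] = 0 plays the role of S(-1)
    prefix = [0]
    for x in nums:
        prefix.append(prefix[-1] + x)

    count = 0
    n = len(nums)
    for i in range(n):
        for j in range(i, n):
            # S(i, j) = S(j) - S(i-1) = prefix[j + 1] - prefix[i]
            if lower <= prefix[j + 1] - prefix[i] <= upper:
                count += 1
    return count


print(count_range_sum_naive([-2, 5, -1], -2, 2))  # 3
```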

Can we do better than that? The answer is yes.

Let us take an example to illustrate how we can solve this problem in the better time complexity of O(N*logN). Let the input array be :

'nums' = [-2, 5, 1, -1, 0, 2, -3, 2, 4, -2] and the range be [-5, 5]

Let's say that we compute the prefix sums S(i) as before, then the prefix sums will be as follows :

[-2, 3, 4, 3, 3, 5, 2, 4, 8, 6]

Then sort the sums. The sorted sums would look like :

[-2, 2, 3, 3, 3, 4, 4, 5, 6, 8]

Along with the prefix sums, we would also like to maintain the indices of the sums, i.e. the sorted sums would be a list of tuples :

H = [(0, -2), (6, 2), (1, 3), (3, 3), (4, 3), (2, 4), (7, 4), (5, 5), (9, 6), (8, 8)]

where the first element of a tuple is the index in the unsorted prefix sum array and the second element is the prefix sum itself.

[Figure: Balanced Binary Tree. The numbers inside the nodes represent (position, value); the numbers outside each node represent (left_size, right_size).]
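The prefix sums and the list H can be produced in a few lines of Python (a sketch; `sorted_prefix_sums` is a hypothetical helper name). Because `sorted` is stable, equal sums keep their index order, which is exactly the order shown in H:

```python
from itertools import accumulate


def sorted_prefix_sums(nums):
    """Return the prefix sums S(i) and the list H of (index, S(i)) sorted by S(i)."""
    prefix = list(accumulate(nums))  # S(i) = nums[0] + ... + nums[i]
    # sorted() is stable, so equal sums stay in their original index order
    H = sorted(enumerate(prefix), key=lambda t: t[1])
    return prefix, H


prefix, H = sorted_prefix_sums([-2, 5, 1, -1, 0, 2, -3, 2, 4, -2])
print(prefix)  # [-2, 3, 4, 3, 3, 5, 2, 4, 8, 6]
print(H)       # [(0, -2), (6, 2), (1, 3), (3, 3), (4, 3), (2, 4), (7, 4), (5, 5), (9, 6), (8, 8)]
```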

To find all valid range sums starting at the first number in the array 'nums', i.e. -2, find the indices i and j in the above sorted array such that all prefix sums between i and j (inclusive) lie in the range [lower, upper], i.e.

lower <= H[k][1] <= upper, for every k in [i, j] (inclusive)

In our example, lower = -5 and upper = 5. Thus for i = 0 and j = 7, the sorted prefix sums (from (0, -2) to (5, 5)) lie within the required bounds, and the count of valid range sums is j - i + 1 = 8.

One can easily find the indices i and j using a binary search on the sorted prefix sums array. For 'i', we search for the position at which H[i - 1][1] < 'lower' but H[i][1] >= 'lower'. Similarly, for 'j', we search for the position at which H[j + 1][1] > 'upper' but H[j][1] <= 'upper'.
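Python's `bisect` module performs exactly these two searches: `bisect_left(values, lower)` returns the first position whose value is >= lower (our i), and `bisect_right(values, upper) - 1` is the last position whose value is <= upper (our j). A minimal sketch for the first number, assuming a separate list `values` that holds only the sorted prefix sums:

```python
from bisect import bisect_left, bisect_right

# Second components of H, i.e. the sorted prefix sums of the example
values = [-2, 2, 3, 3, 3, 4, 4, 5, 6, 8]
lower, upper = -5, 5

i = bisect_left(values, lower)       # first k with values[k] >= lower
j = bisect_right(values, upper) - 1  # last k with values[k] <= upper
print(i, j, j - i + 1)  # 0 7 8
```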

For the next number in the array 'nums', i.e. 5, we cannot use the prefix sums as they are. We need to subtract the contribution of the first number, i.e. -2, from the prefix sums.

Thus, while searching for the indices i and j for the number 5, we need to implicitly add +2 to the prefix sums during the binary search. The time complexity of finding i and j stays O(logN), but observe that the indices i and j do not exclude the prefix sums of the numbers located before 5 in the 'nums' array.

For the 'nums' array [-2, 5, 1, -1, 0, 2, -3, 2, 4, -2], the range sums beginning at the 2nd number, 5, must not use the prefix sum at position 0 (the one ending at the 1st number, -2): after adding +2 it does not correspond to any range that starts at 5, so counting it would be wrong. In general, range sums starting at index k must exclude the prefix sums at positions 0, 1, 2, ..., k-1.

For the number 5 at position 1 in 'nums', the search gives i = 0 and j = 4, because H[5][1] + 2 = 4 + 2 = 6 > 'upper'. But i = 0 is not valid, as its position H[0][0] = 0 < 1.

The above binary search implementation does not take care of this.
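The problem is easy to reproduce with `bisect` (our own illustration). With the shifted bounds [lower - 2, upper - 2], the search for the number 5 selects five entries of H, and one of them has position 0. Filtering the window by position gives the correct count of 4 here, but that scan is O(N) per number in the worst case:

```python
from bisect import bisect_left, bisect_right

# H from the example: (index in 'nums', prefix sum), sorted by prefix sum
H = [(0, -2), (6, 2), (1, 3), (3, 3), (4, 3), (2, 4), (7, 4), (5, 5), (9, 6), (8, 8)]
values = [v for _, v in H]
lower, upper = -5, 5

k, offset = 1, 2  # range sums starting at nums[1] = 5; offset = -nums[0] = +2
# lower <= S(j) + offset <= upper  <=>  lower - offset <= S(j) <= upper - offset
i = bisect_left(values, lower - offset)
j = bisect_right(values, upper - offset) - 1
window = H[i:j + 1]
valid = sum(1 for pos, _ in window if pos >= k)  # drop prefix sums before position k

print(i, j, len(window), valid)  # 0 4 5 4
```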

One way to handle this is to delete the prefix sum at position k once the range sums starting at k have been counted. But deleting from an array of size N is O(N), so the run-time complexity would again be O(N^2).
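For reference, a minimal sketch of this delete-as-you-go idea on a plain sorted Python list (our own illustration, not the author's code; `count_range_sum_sorted_list` is a hypothetical name). It is correct, because after k deletions the list holds exactly the prefix sums at positions k, k+1, ..., but every `list.pop` shifts up to N elements, so the worst case is O(N^2):

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate


def count_range_sum_sorted_list(nums, lower, upper):
    prefix = list(accumulate(nums))
    values = sorted(prefix)  # prefix sums that can still end a range
    count, offset = 0, 0     # offset = -(sum of nums before index k)
    for k, x in enumerate(nums):
        # Count S(j), j >= k, with lower <= S(j) + offset <= upper
        count += bisect_right(values, upper - offset) - bisect_left(values, lower - offset)
        # Ranges starting at k + 1 or later cannot end at position k: drop S(k).
        # Removing any copy of an equal value is fine, since only values are compared.
        values.pop(bisect_left(values, prefix[k]))  # O(N) element shift
        offset -= x
    return count


print(count_range_sum_sorted_list([-2, 5, -1], -2, 2))  # 3
```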

Alternative Approach :

Create a balanced binary search tree out of the sorted prefix sums array.

In addition to the prefix sum value and the index (and the left and right pointers), we will also keep the following attributes for each node :

• Parent node pointer
• Size of the left and right sub-trees for each node.
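```python
class BST(object):
    def __init__(self, value=None, position=None):
        self.value, self.position = value, position  # prefix sum, index in 'nums'
        self.left, self.right = None, None
        # Sub-tree sizes; both counts include the node itself (see assign_sizes)
        self.left_size, self.right_size = 0, 0
        self.parent = None
```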

Let's define the functions for inserting the sorted prefix sums array into a balanced binary search tree.

```python
class Solution(object):
    def create_bst_from_sorted_arr(self, sorted_arr, left, right, node_pos_map):
        # sorted_arr holds (index, prefix_sum) tuples sorted by prefix_sum
        if left <= right:
            mid = (left + right) // 2
            position, value = sorted_arr[mid]

            node = BST(value, position)
            node_pos_map[position] = node  # index in 'nums' -> BST node

            a = self.create_bst_from_sorted_arr(sorted_arr, left, mid - 1, node_pos_map)
            b = self.create_bst_from_sorted_arr(sorted_arr, mid + 1, right, node_pos_map)

            if a is not None:
                a.parent = node

            if b is not None:
                b.parent = node

            node.left, node.right = a, b

            return node
        else:
            return None

    def countRangeSum(self, nums, lower, upper):
        # sums[idx] = (idx, nums[0] + ... + nums[idx])
        sums = []
        for idx in range(len(nums)):
            if idx == 0:
                sums.append((idx, nums[idx]))
            else:
                sums.append((idx, nums[idx] + sums[idx - 1][1]))

        sums = sorted(sums, key=lambda k: k[1])

        node_pos_map = dict()

        root = self.create_bst_from_sorted_arr(sums, 0, len(sums) - 1, node_pos_map)
```

Note that we are using a map 'node_pos_map' that maps an index in the original 'nums' array to the corresponding node in the BST. This will be useful when we want to delete a node from the BST.

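Sorting the prefix sums takes O(N*logN) time; building the balanced BST from the sorted array then takes O(N), since the recursion creates exactly one node per element.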
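A quick, self-contained check (a sketch with hypothetical names `Node`, `build`, `height` and `inorder`, mirroring create_bst_from_sorted_arr): building the tree from the example's H creates one node per element, its in-order traversal gives back H, and its height is 4 = ceil(log2(11)):

```python
import math


class Node:
    def __init__(self, value, position):
        self.value, self.position = value, position
        self.left = self.right = self.parent = None


def build(sorted_arr, left, right):
    """Same midpoint recursion as create_bst_from_sorted_arr; one node per element."""
    if left > right:
        return None
    mid = (left + right) // 2
    position, value = sorted_arr[mid]
    node = Node(value, position)
    node.left = build(sorted_arr, left, mid - 1)
    node.right = build(sorted_arr, mid + 1, right)
    for child in (node.left, node.right):
        if child is not None:
            child.parent = node
    return node


def height(node):
    return 0 if node is None else 1 + max(height(node.left), height(node.right))


def inorder(node):
    if node is None:
        return []
    return inorder(node.left) + [(node.position, node.value)] + inorder(node.right)


H = [(0, -2), (6, 2), (1, 3), (3, 3), (4, 3), (2, 4), (7, 4), (5, 5), (9, 6), (8, 8)]
root = build(H, 0, len(H) - 1)
print(height(root), math.ceil(math.log2(len(H) + 1)))  # 4 4
print(inorder(root) == H)                              # True
```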

Okay, now we have inserted our prefix sums into a balanced BST, but we have not yet assigned the sizes of the left and right sub-trees for each node. Let's do that using the following recursive function.

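```python
def assign_sizes(self, node):
    # Post-order traversal: the children's sizes are computed before the node's
    if node is not None:
        self.assign_sizes(node.left)
        self.assign_sizes(node.right)

        # left_size = (nodes in the left sub-tree) + 1 for the node itself;
        # the left child's sub-tree holds left_size + right_size - 1 nodes
        if node.left is not None:
            node.left_size = node.left.left_size + node.left.right_size
        else:
            node.left_size = 1

        if node.right is not None:
            node.right_size = node.right.left_size + node.right.right_size
        else:
            node.right_size = 1
```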

Assigning the left and right sub-tree sizes to all nodes takes O(N) time.

The sizes are calculated such that both the left sub-tree size and the right sub-tree size include the node itself. Thus the total size of the sub-tree rooted at a node 'n' is :

n.size = n.left_size + n.right_size - 1, although we do not store 'size' separately
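A compact variant of assign_sizes (a sketch, not the author's code) that returns each sub-tree's size while filling in left_size and right_size with the same convention. The four-node tree is a hypothetical example chosen to make the numbers easy to check:

```python
class Node:
    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right
        self.left_size = self.right_size = 0


def assign_sizes(node):
    """Post-order pass; returns the size of the sub-tree rooted at node."""
    if node is None:
        return 0
    # Each count is the child's sub-tree size plus the node itself
    node.left_size = assign_sizes(node.left) + 1
    node.right_size = assign_sizes(node.right) + 1
    return node.left_size + node.right_size - 1  # n.size = left_size + right_size - 1


root = Node(3, Node(2, Node(1)), Node(5))
print(assign_sizes(root), root.left_size, root.right_size)  # 4 3 2
```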

After we have defined the sizes, it's time to compute the count of valid range sums.

The idea is that, for each nums[i], we need to find two nodes 'm' and 'n' with the following properties :

m.value + offset >= lower, while the in-order predecessor of 'm' (if any) has value + offset < lower

n.value + offset <= upper, while the in-order successor of 'n' (if any) has value + offset > upper

where 'offset' is the negative of the sum of the elements of 'nums' located before nums[i], i.e. for nums[0] the offset is 0, for nums[1] the offset is -nums[0], for nums[2] the offset is -(nums[0] + nums[1]), and so on.
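Concretely, the offsets are the negated prefix sums shifted one place to the right. A small sketch (`offsets_for` is a hypothetical helper; in practice the offset can simply be updated as the loop advances):

```python
def offsets_for(nums):
    """offsets[i] = -(nums[0] + ... + nums[i-1]); offsets[0] = 0."""
    offsets, running = [], 0
    for x in nums:
        offsets.append(-running)  # negated sum of everything before x
        running += x
    return offsets


print(offsets_for([-2, 5, 1, -1]))  # [0, 2, -3, -4]
```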

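We also need to find the lowest common ancestor of the nodes 'm' and 'n'. It can be found by descending from the root: if a node's value is below 'lower - offset', both 'm' and 'n' lie in its right sub-tree; if it is above 'upper - offset', both lie in its left sub-tree. The first node whose value lies in [lower - offset, upper - offset] is the lowest common ancestor.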

Once we have found the lowest common ancestor T of the nodes 'm' and 'n', we can calculate the count of valid range sums by counting how many nodes lie between 'm' and 'n' in the sub-tree rooted at T.

```python
def get_common_root(self, root, val1, val2):
    # val1 = upper - offset, val2 = lower - offset: return the first node on
    # the search path whose value lies in [val2, val1] (None if there is none)
    if root is not None:
        if root.value <= val1 and root.value >= val2:
            return root
        else:
            if root.value < val1 and root.value < val2:
                return self.get_common_root(root.right, val1, val2)
            else:
                return self.get_common_root(root.left, val1, val2)
    else:
        return root

def get_position_less(self, root, common_root, offset, upper):
    # Counts common_root plus the nodes of its right sub-tree with value + offset <= upper
    if root is None:
        return 0
    else:
        if root.value + offset <= upper:
            if root is common_root:
                return 1 + self.get_position_less(root.right, common_root, offset, upper)
            else:
                # root and its whole left sub-tree qualify
                return root.left_size + self.get_position_less(root.right, common_root, offset, upper)
        else:
            return self.get_position_less(root.left, common_root, offset, upper)

def get_position_more(self, root, common_root, offset, lower):
    # Counts common_root plus the nodes of its left sub-tree with value + offset >= lower
    if root is None:
        return 0
    else:
        if root.value + offset >= lower:
            if root is common_root:
                return 1 + self.get_position_more(root.left, common_root, offset, lower)
            else:
                # root and its whole right sub-tree qualify
                return root.right_size + self.get_position_more(root.left, common_root, offset, lower)
        else:
            return self.get_position_more(root.right, common_root, offset, lower)
```

The time complexity for each of the above functions is O(logN).

The above functions can be called as follows :

```python
common_root = self.get_common_root(root, upper - offset, lower - offset)
x = self.get_position_less(common_root, common_root, offset, upper)
y = self.get_position_more(common_root, common_root, offset, lower)
out = max(0, x + y - 1)
count += out
```

The variable 'x' stores how many nodes have values between T.value and n.value, whereas 'y' stores how many nodes have values between m.value and T.value, where T is the lowest common ancestor of m and n. Since both counts include T itself, the total is x + y - 1.
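A self-contained sketch of this counting step on a static tree (no deletions yet), written iteratively. It folds the offset into the bounds, i.e. it searches for values in [lower - offset, upper - offset], and fills in the sizes while building the tree; both are our simplifications of get_common_root, get_position_less and get_position_more, not the author's exact code:

```python
class Node:
    def __init__(self, value):
        self.value = value
        self.left = self.right = None
        self.left_size = self.right_size = 1  # both counts include the node itself


def build(values, lo, hi):
    """Balanced BST over sorted values[lo..hi], sizes filled in directly."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    node = Node(values[mid])
    node.left = build(values, lo, mid - 1)
    node.right = build(values, mid + 1, hi)
    node.left_size = mid - lo + 1   # left sub-tree + node
    node.right_size = hi - mid + 1  # right sub-tree + node
    return node


def count_in_range(root, lo_val, hi_val):
    """Number of nodes with lo_val <= value <= hi_val, computed as x + y - 1."""
    # T: first node on the search path whose value lies in [lo_val, hi_val]
    t = root
    while t is not None and not (lo_val <= t.value <= hi_val):
        t = t.right if t.value < lo_val else t.left
    if t is None:
        return 0
    # x: T plus nodes in T's right sub-tree with value <= hi_val
    x, node = 1, t.right
    while node is not None:
        if node.value <= hi_val:
            x, node = x + node.left_size, node.right
        else:
            node = node.left
    # y: T plus nodes in T's left sub-tree with value >= lo_val
    y, node = 1, t.left
    while node is not None:
        if node.value >= lo_val:
            y, node = y + node.right_size, node.left
        else:
            node = node.right
    return x + y - 1


values = [-2, 2, 3, 3, 3, 4, 4, 5, 6, 8]  # sorted prefix sums of the example
root = build(values, 0, len(values) - 1)
print(count_in_range(root, -5, 5))  # 8: offset 0, range [-5, 5]
print(count_in_range(root, -7, 3))  # 5: offset +2, before any deletion
```

With offset +2 it still reports 5 rather than 4, because the prefix sum at position 0 has not been deleted yet; removing it is the job of the deletion step.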

Once we have computed the count of valid range sums for nums[i], we need to delete the node corresponding to index 'i', i.e. the node 'node_pos_map[i]'. Observe that after deleting a node, we only need to adjust the left and right sub-tree sizes of the affected nodes.

In fact, the number of nodes whose sizes need adjusting after a deletion is O(logN), because the affected nodes all lie on a single path towards the root, whose length is at most the height of the tree.

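```python
def adjust_sizes(self, node):
    # Walk from 'node' up to the root, decrementing the size on the side that
    # contains 'node'. Assumption: called before delete_node() unlinks 'node'.
    while node is not None and node.parent is not None:
        if node.parent.left is node:
            node.parent.left_size -= 1
        else:
            node.parent.right_size -= 1
        node = node.parent
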
def delete_node(self, root, node, node_pos_map):
    new_node = None

    if node.left is not None and node.right is not None:
        # Two children: find the in-order successor (leftmost node of the
        # right sub-tree), decrementing the sizes on the path leading to it
        temp = node.right
        node.right_size -= 1

        while temp is not None and temp.left is not None:
            temp.left_size -= 1
            temp = temp.left

        # Copy the successor into 'node' and remap its index to this object
        node.value, node.position = temp.value, temp.position
        node_pos_map[temp.position] = node

        # From here on, physically remove the successor instead
        node = temp
        new_node = temp.right

    elif node.left is None and node.right is not None:
        new_node = node.right

    elif node.right is None and node.left is not None:
        new_node = node.left

    # Splice 'new_node' into the place of 'node'
    if node.parent is not None:
        if node.parent.left is node:
            node.parent.left = new_node
        else:
            node.parent.right = new_node

    else:
        root = new_node

    if new_node is not None:
        new_node.parent = node.parent

    return root
```

Deletion of a node from a BST is also O(logN).

Note that we also need to update 'node_pos_map' in case we replace the node to be deleted by its in-order successor: the successor's value and position are copied into the existing node object, so the successor's index must now map to that node.
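For reference, here is a compact end-to-end sketch of the whole approach (our own consolidation, not the author's final code; all names are hypothetical). It differs in two details: sizes are set during construction, and delete() walks from the physically removed node up to the root to fix the sizes, which covers both the ancestors of the deleted node and the path to its in-order successor:

```python
class Node:
    def __init__(self, value, position, parent):
        self.value, self.position, self.parent = value, position, parent
        self.left = self.right = None
        self.left_size = self.right_size = 1  # both counts include the node


def build(h, lo, hi, node_pos_map, parent=None):
    """Balanced BST over h[lo..hi], where h holds (index, prefix sum) sorted by sum."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    position, value = h[mid]
    node = Node(value, position, parent)
    node_pos_map[position] = node
    node.left = build(h, lo, mid - 1, node_pos_map, node)
    node.right = build(h, mid + 1, hi, node_pos_map, node)
    node.left_size, node.right_size = mid - lo + 1, hi - mid + 1
    return node


def count_in_range(root, lo_val, hi_val):
    """Nodes with lo_val <= value <= hi_val, as x + y - 1 around the split node T."""
    t = root
    while t is not None and not (lo_val <= t.value <= hi_val):
        t = t.right if t.value < lo_val else t.left
    if t is None:
        return 0
    x, node = 1, t.right
    while node is not None:
        if node.value <= hi_val:
            x, node = x + node.left_size, node.right
        else:
            node = node.left
    y, node = 1, t.left
    while node is not None:
        if node.value >= lo_val:
            y, node = y + node.right_size, node.left
        else:
            node = node.right
    return x + y - 1


def delete(root, node, node_pos_map):
    """Unlink node (via its in-order successor if it has two children); return the root."""
    if node.left is not None and node.right is not None:
        succ = node.right
        while succ.left is not None:
            succ = succ.left
        node.value, node.position = succ.value, succ.position
        node_pos_map[node.position] = node  # successor's index now lives here
        node = succ
    # Every ancestor loses one node on the removed node's side: O(height)
    child, parent = node, node.parent
    while parent is not None:
        if parent.left is child:
            parent.left_size -= 1
        else:
            parent.right_size -= 1
        child, parent = parent, parent.parent
    new_node = node.left if node.left is not None else node.right
    if new_node is not None:
        new_node.parent = node.parent
    if node.parent is None:
        return new_node
    if node.parent.left is node:
        node.parent.left = new_node
    else:
        node.parent.right = new_node
    return root


def count_range_sum(nums, lower, upper):
    prefix, running = [], 0
    for x in nums:
        running += x
        prefix.append(running)
    h = sorted(enumerate(prefix), key=lambda t: t[1])
    node_pos_map = {}
    root = build(h, 0, len(h) - 1, node_pos_map)
    count, offset = 0, 0  # offset = -(sum of nums before idx)
    for idx, x in enumerate(nums):
        count += count_in_range(root, lower - offset, upper - offset)
        root = delete(root, node_pos_map.pop(idx), node_pos_map)
        offset -= x
    return count


print(count_range_sum([-2, 5, -1], -2, 2))  # 3
```

Each iteration costs O(height); since deletions never make the tree taller, the height stays O(logN) and the loop is O(N*logN) after sorting.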

Following are the calls to the Solution methods defined above for computing the count of range sums. Once we find the count for one of the numbers, we delete the node corresponding to that number and update the offset.

```python
count, offset = 0, 0

for idx in range(len(nums)):
    common_root = self.get_common_root(root, upper - offset, lower - offset)
    x = self.get_position_less(common_root, common_root, offset, upper)
    y = self.get_position_more(common_root, common_root, offset, lower)

    out = max(0, x + y - 1)
    count += out