Zunayed Ali Morsalin

# Single Value / Unival tree

Problem Statement:

Given a binary tree, find the number of unival subtrees. A unival tree is a tree in which every node has the same value.

Let's break this problem down into smaller subproblems.

• ex.1 shows just one node. A single node trivially has the same value everywhere, so it is a unival tree and the count is 1.

• ex.2: three nodes with identical values should return 3 unival subtrees. The two leaf nodes give 2 unival subtrees, and the whole tree is itself unival, so 2 + 1 = 3.

• ex.3: one of the leaf nodes has a different value, which invalidates the “univalness” of the whole tree. It should still return 2, one for each leaf node.
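A brute-force counter helps pin down the definition before we optimise anything. This is my own illustration, not the approach this post builds up to. For every node it checks from scratch whether the whole subtree below it shares one value, which costs O(n²) in the worst case. The `Node` class and the helper names are hypothetical. The `val`, `left_ptr` and `right_ptr` attributes are assumed to mirror the code used later. The node values for the three examples are made up, since only their shape matters.

```python
from typing import Optional


class Node:
    """Hypothetical binary tree node; attribute names mirror the post's code."""

    def __init__(self, val, left_ptr=None, right_ptr=None):
        self.val = val
        self.left_ptr = left_ptr
        self.right_ptr = right_ptr


def all_equal_to(node: Optional[Node], value) -> bool:
    """True if every node in the subtree rooted at `node` holds `value`."""
    if node is None:
        return True
    return (node.val == value
            and all_equal_to(node.left_ptr, value)
            and all_equal_to(node.right_ptr, value))


def brute_force_unival_count(root: Optional[Node]) -> int:
    """Count unival subtrees by re-checking each node's whole subtree."""
    if root is None:
        return 0
    here = 1 if all_equal_to(root, root.val) else 0
    return (here
            + brute_force_unival_count(root.left_ptr)
            + brute_force_unival_count(root.right_ptr))


# The three examples, with made-up values (only the shapes come from the post).
ex1 = Node(1)                        # a single node
ex2 = Node(1, Node(1), Node(1))      # three identical values
ex3 = Node(1, Node(1), Node(2))      # one leaf differs

print(brute_force_unival_count(ex1))  # 1
print(brute_force_unival_count(ex2))  # 3
print(brute_force_unival_count(ex3))  # 2
```

The brute-force version re-walks the same nodes many times. The recursive versions below avoid that by passing information up from the children instead.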

So we need to compare the root's value with the values of its left and right children. We also need to count the unival subtrees on the left and on the right.

So, working through our examples, we can write a recursive formula that adds up all the counts:

```python
def unival_count(root):
    if root is None:
        return 0

    # check if a leaf node
    if root.left_ptr is None and root.right_ptr is None:
        return 1

    # count the unival subtrees on each side
    count_left = unival_count(root.left_ptr)
    count_right = unival_count(root.right_ptr)
    total_count = count_left + count_right

    # compare the values of root to left and right
    if root.left_ptr and root.left_ptr.val != root.val:
        return total_count

    if root.right_ptr and root.right_ptr.val != root.val:
        return total_count

    # every existing child matches the root's value, so count this subtree too
    return total_count + 1
```

Doing it again for the second example, however, reveals an edge case. The function above only compares a node with its immediate children. This means it can count unival subtrees, but it cannot tell whether the whole subtree below a node is unival. That leaf node of 4 should invalidate the whole tree, and we need to pass that information up the stack. We can do that by returning a boolean alongside the total count. The boolean answers the question: was the subtree below us unival? If it was, we keep returning True. If not, we bubble that False value up.
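Here is a minimal demonstration of the edge case. It assumes a hypothetical `Node` class with the same `val`, `left_ptr` and `right_ptr` attributes. `naive_unival_count` re-creates the first version above. The tree is made up: a root 5 whose left child is 5, whose own left child is a leaf 4. The mismatching 4 therefore sits two levels below the root.

```python
from typing import Optional


class Node:
    # Hypothetical node class; attribute names follow the post's code.
    def __init__(self, val, left_ptr=None, right_ptr=None):
        self.val = val
        self.left_ptr = left_ptr
        self.right_ptr = right_ptr


def naive_unival_count(root: Optional[Node]) -> int:
    """The first version: compares each node only with its direct children."""
    if root is None:
        return 0
    if root.left_ptr is None and root.right_ptr is None:
        return 1
    total = naive_unival_count(root.left_ptr) + naive_unival_count(root.right_ptr)
    if root.left_ptr and root.left_ptr.val != root.val:
        return total
    if root.right_ptr and root.right_ptr.val != root.val:
        return total
    return total + 1


# root 5 -> left child 5 -> left child 4 (a made-up tree)
chain = Node(5, Node(5, Node(4)))
print(naive_unival_count(chain))  # 2, although only the leaf 4 is unival
```

The middle 5 correctly returns 1 because its child differs. The root, however, only compares itself with that middle 5. It sees a match and adds 1, so the result is 2 instead of the correct 1.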

So our tweaked and final code looks something like this:

```python
def findSingleValueTrees(root):
    count, is_unival = unival_count(root)
    return count


def unival_count(root):
    if root is None:
        return 0, True

    # a leaf is always a unival subtree
    if root.left_ptr is None and root.right_ptr is None:
        return 1, True

    count_left, is_unival_left = unival_count(root.left_ptr)
    count_right, is_unival_right = unival_count(root.right_ptr)
    total_count = count_left + count_right

    # If either side is not unival, this subtree can't be unival either,
    # so there is no point adding the +1: return early!
    if not is_unival_left or not is_unival_right:
        return total_count, False

    if root.left_ptr and root.left_ptr.val != root.val:
        return total_count, False

    if root.right_ptr and root.right_ptr.val != root.val:
        return total_count, False

    # Both left and right are equal
    return total_count + 1, True
```
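Here is a compact variant of the final version, offered as a sketch rather than the post's own code. `count_unival` and the `Node` class are hypothetical names, and the seven-node tree is only an illustration. The variant folds the four checks into a single boolean and drops the explicit leaf case.

```python
from typing import Optional, Tuple


class Node:
    # Hypothetical node class; attribute names follow the post's code.
    def __init__(self, val, left_ptr=None, right_ptr=None):
        self.val = val
        self.left_ptr = left_ptr
        self.right_ptr = right_ptr


def count_unival(root: Optional[Node]) -> Tuple[int, bool]:
    """Return (number of unival subtrees, whether root's subtree is unival)."""
    if root is None:
        # An empty subtree adds nothing and never breaks its parent's univalness.
        return 0, True
    left_count, left_ok = count_unival(root.left_ptr)
    right_count, right_ok = count_unival(root.right_ptr)
    is_unival = (
        left_ok and right_ok
        and (root.left_ptr is None or root.left_ptr.val == root.val)
        and (root.right_ptr is None or root.right_ptr.val == root.val)
    )
    return left_count + right_count + (1 if is_unival else 0), is_unival


#      0
#     / \
#    1   0
#       / \
#      1   0
#     / \
#    1   1
tree = Node(0, Node(1), Node(0, Node(1, Node(1), Node(1)), Node(0)))
print(count_unival(tree)[0])  # 5
```

An empty child returns `(0, True)`, so a leaf falls through to the general case and returns `(1, True)`. The explicit leaf check in the version above is therefore a shortcut rather than a requirement. Either way, each node is visited once, so the running time is O(n), and the recursion depth equals the height of the tree.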