Maximum XOR of Two Non-Overlapping Subtrees

Number: 2623

Difficulty: Hard

Paid? Yes

Companies: Directi, Media.net


Problem Description

Given a tree with n nodes (labeled 0 to n–1) rooted at node 0, each node has an integer value. The subtree of a node consists of the node itself plus all its descendants. Choose two subtrees that share no node and maximize the score, defined as the bitwise XOR of the two subtree sums. If no two non-overlapping subtrees exist, return 0.


Key Insights

  • First, compute for each node its subtree sum using a DFS.
  • Use an Euler tour (recording tin and tout for every node) so that the subtree of a node appears as a contiguous interval.
  • Two subtrees (with roots u and v) are non-overlapping if and only if their Euler intervals do not intersect; equivalently, either tout[u] < tin[v] or tout[v] < tin[u].
  • Thus, with nodes sorted by tin, the valid partners of a fixed node i fall into two groups: “left” partners j with tout[j] < tin[i], whose intervals end before i’s interval begins, and “right” partners j with tin[j] > tout[i], whose intervals begin after i’s interval ends.
  • For fast maximum XOR queries, use a bitwise Trie in which each inserted number is stored as a fixed-width binary string, most significant bit first.
  • Because the valid partner set changes with the current node’s Euler indices, the solution splits into two parts:
      • A left pass: iterate in tin order, adding nodes whose interval has already ended (i.e. their tout is before the current node’s tin).
      • A right pass: for a given node, the valid right candidates are those whose tin is larger than the current node’s tout. This can be supported by building a persistent (or segment-tree-based) Trie keyed on the Euler ordering.
  • The final answer is the maximum of the XOR between the current node’s subtree sum and a valid candidate’s subtree sum from either left or right.
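The interval test in the insights above can be tried out directly. The following is a minimal sketch, assuming the example tree that appears at the end of this page and a DFS that visits children in adjacency-list order; the hard-coded `tin`/`tout` arrays and the helper names `disjoint` and `disjoint_pairs` are illustrative, not part of the original solution.

```python
# Euler intervals for the 6-node example tree (edges 0-1, 0-2, 1-3, 1-4, 2-5),
# assuming children are visited in adjacency-list order. Hard-coded for illustration.
tin = [0, 1, 4, 2, 3, 5]
tout = [5, 3, 5, 2, 3, 5]


def disjoint(u, v):
    """True when the subtrees of u and v share no node (their intervals do not intersect)."""
    return tout[u] < tin[v] or tout[v] < tin[u]


def disjoint_pairs(n):
    """List every unordered pair of nodes whose subtrees are non-overlapping."""
    return [(u, v) for u in range(n) for v in range(u + 1, n) if disjoint(u, v)]


if __name__ == "__main__":
    print(disjoint_pairs(6))
```

Node 0 never appears in a pair because its interval covers the whole tour, and a node is never paired with its own ancestor or descendant.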

Space and Time Complexity

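Time Complexity: O(n · B), where B is the number of bits needed for the largest subtree sum. The DFS is O(n) and each trie insertion or query is O(B); sorting the nodes by tout adds O(n log n), which does not dominate as long as log n ≤ B.

Space Complexity: O(n · B) for storing subtree information and the trie nodes.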
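The value of B depends on the subtree sums rather than on the individual node values. As a rough sketch, one way to pick it is Python's built-in `int.bit_length`; the helper name `bits_needed` and the sample sums below are assumptions made for illustration.

```python
def bits_needed(subtree_sums):
    """Smallest trie depth that can represent every (non-negative) subtree sum."""
    return max(1, max(subtree_sums).bit_length())


if __name__ == "__main__":
    # Subtree sums of the small example tree on this page: 5 bits are enough.
    print(bits_needed([26, 16, 8, 6, 2, 5]))
    # A hypothetical large sum: well beyond 32 bits.
    print(bits_needed([5 * 10**13]))
```

If the sums can grow past 2^32, a trie fixed at 32 levels silently drops the high bits, so sizing B from the data (or from the constraints) is the safer choice.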


Solution

We start by performing a DFS from the root to compute three values for each node:

  • subtreeSum – the sum of values in the subtree rooted at that node.
  • tin – the time when we first visit the node.
  • tout – the time when we finish processing the node.

Because the DFS Euler tour order yields that the subtree of a node corresponds to a contiguous range [tin, tout], two subtrees (with roots u and v) are non-overlapping if either tout[u] < tin[v] or tout[v] < tin[u].
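The three per-node values can also be computed without recursion, which avoids deep call stacks on path-like trees. This is a sketch under that assumption, not the page's own implementation; the function name `euler_tour` and its return shape are hypothetical.

```python
def euler_tour(n, edges, values, root=0):
    """Return (tin, tout, subtree_sum) using an explicit stack instead of recursion."""
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    parent = [-1] * n
    order = []              # nodes in preorder; tin[u] is u's position in this list
    tin = [0] * n
    stack = [root]
    while stack:
        u = stack.pop()
        tin[u] = len(order)
        order.append(u)
        for v in graph[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # Walk the preorder backwards so every child is folded into its parent
    # before the parent itself is finalized.
    size = [1] * n
    subtree_sum = list(values)
    tout = [0] * n
    for u in reversed(order):
        tout[u] = tin[u] + size[u] - 1   # last preorder index inside u's subtree
        p = parent[u]
        if p != -1:
            size[p] += size[u]
            subtree_sum[p] += subtree_sum[u]
    return tin, tout, subtree_sum


if __name__ == "__main__":
    edges = [[0, 1], [0, 2], [1, 3], [1, 4], [2, 5]]
    print(euler_tour(6, edges, [2, 8, 3, 6, 2, 5]))
```

The visiting order of siblings may differ from a recursive DFS, so individual tin/tout values can differ too; the contiguous-interval property and the subtree sums are unaffected.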

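The idea is to precompute an array "nodes" sorted by tin. Then, in a left pass, we maintain a Trie that contains exactly those nodes whose tout is less than the current node’s tin. This excludes ancestors, whose intervals are still open when the current node is visited. A pointer over a second copy of the nodes, sorted by tout, adds each node to the Trie once its interval has ended. We then query the Trie with the current node’s subtree sum to get the best XOR candidate.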
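To see the sweep in isolation, the sketch below runs the left pass with a plain list standing in for the Trie, so each query is a linear scan rather than an O(B) lookup. It only illustrates which candidates are visible at each step; the function name `left_pass_max_xor` is an assumption, and the sample arrays are the Euler intervals of this page's example tree under adjacency-order traversal.

```python
def left_pass_max_xor(tin, tout, sums):
    """Sweep nodes in tin order; candidates are nodes whose interval already ended."""
    n = len(tin)
    by_tin = sorted(range(n), key=lambda u: tin[u])
    by_tout = sorted(range(n), key=lambda u: tout[u])

    finished = []   # stand-in for the trie: sums of nodes with tout < current tin
    p = 0           # pointer into by_tout
    best = 0
    for u in by_tin:
        while p < n and tout[by_tout[p]] < tin[u]:
            finished.append(sums[by_tout[p]])
            p += 1
        for s in finished:
            best = max(best, s ^ sums[u])
    return best


if __name__ == "__main__":
    tin = [0, 1, 4, 2, 3, 5]
    tout = [5, 3, 5, 2, 3, 5]
    sums = [26, 16, 8, 6, 2, 5]
    print(left_pass_max_xor(tin, tout, sums))
```

Because the pointer only moves forward, every node is added to the candidate set exactly once across the whole sweep.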

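Similarly, in a right pass we need to quickly query nodes that come after the current node’s Euler interval, i.e. nodes with tin > current node’s tout. This is achieved by building a persistent Trie (or an offline segment tree of Tries) indexed by the tin order so that for any index L we can query the “suffix” Trie built from nodes in positions L..n–1. Note that non-overlap is symmetric: if v lies to the right of u, then u lies to the left of v, so every pair scored by the right pass is also scored by the left pass when it reaches the other node. The right pass therefore cannot change the answer and is described here for completeness.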
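A persistent Trie can be built with path copying: each insertion copies only the nodes on one root-to-leaf path and shares everything else with the previous version. The sketch below is one possible shape for the suffix versions described above, under the assumption that nodes are plain two-element lists and that an empty version is represented by `None`; the names `p_insert`, `p_query` and `build_suffix_versions` are hypothetical.

```python
BITS = 32  # assumed width; widen it if subtree sums can exceed 32 bits


def p_insert(root, num, bits=BITS):
    """Return the root of a new version containing num; the old version is untouched."""
    new_root = list(root) if root is not None else [None, None]
    node = new_root
    for i in range(bits - 1, -1, -1):
        b = (num >> i) & 1
        old_child = node[b]
        # Copy the child on the insertion path; the sibling pointer stays shared.
        node[b] = list(old_child) if old_child is not None else [None, None]
        node = node[b]
    return new_root


def p_query(root, num, bits=BITS):
    """Maximum XOR of num with any value stored in this version (0 if it is empty)."""
    if root is None:
        return 0
    node, best = root, 0
    for i in range(bits - 1, -1, -1):
        b = (num >> i) & 1
        if node[1 - b] is not None:
            best |= 1 << i
            node = node[1 - b]
        else:
            node = node[b]
    return best


def build_suffix_versions(sums_in_tin_order, bits=BITS):
    """versions[L] holds the sums at positions L..n-1; versions[n] is empty."""
    n = len(sums_in_tin_order)
    versions = [None] * (n + 1)
    for i in range(n - 1, -1, -1):
        versions[i] = p_insert(versions[i + 1], sums_in_tin_order[i], bits)
    return versions


if __name__ == "__main__":
    versions = build_suffix_versions([26, 16, 6, 2, 8, 5])
    print(p_query(versions[4], 16))  # best partner for 16 among positions 4..5
```

Each version costs O(B) extra nodes, so all n suffix versions together fit in O(n · B) space.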

The solution uses standard bit-trie operations (insert and query) on fixed-width binary representations. The width must cover the largest subtree sum, which can exceed 32 bits when many large values are added together. The “gotcha” is choosing which candidates enter the Trie: only nodes with tout less than the current node’s tin (or, for the right pass, only those with tin greater than the current node’s tout) are valid.
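For reference, the same insert and query operations can be written over flat lists instead of node objects, with the bit width passed in explicitly. This is an alternative sketch rather than the structure used in the listing below; the class name `FlatXorTrie` and its method names are assumptions.

```python
class FlatXorTrie:
    """Binary trie stored as a list of [child0, child1] index pairs; -1 means absent."""

    def __init__(self, bits=32):
        self.bits = bits
        self.children = [[-1, -1]]  # index 0 is the root

    def insert(self, num):
        node = 0
        for i in range(self.bits - 1, -1, -1):
            b = (num >> i) & 1
            if self.children[node][b] == -1:
                self.children[node][b] = len(self.children)
                self.children.append([-1, -1])
            node = self.children[node][b]

    def max_xor(self, num):
        """Greedy walk: at each level prefer the branch with the opposite bit."""
        if len(self.children) == 1:   # nothing inserted yet
            return 0
        node, best = 0, 0
        for i in range(self.bits - 1, -1, -1):
            b = (num >> i) & 1
            opposite = self.children[node][1 - b]
            if opposite != -1:
                best |= 1 << i
                node = opposite
            else:
                node = self.children[node][b]
        return best


if __name__ == "__main__":
    trie = FlatXorTrie(bits=5)
    for s in (6, 16, 2):
        trie.insert(s)
    print(trie.max_xor(8))
```

The greedy choice is safe because a 1 in a higher bit of the XOR outweighs any combination of lower bits.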

Below is a Python solution with commentary.


Code Solutions

# Python solution using two passes over the Euler tour:
# a left pass that moves a pointer over nodes sorted by tout, and
# a right pass that queries a trie of the nodes whose tin lies after the current node's tout.

# Bit width used by the trie. Subtree sums can exceed 32 bits once many large
# values are added together, so 48 is used here as an assumed upper bound;
# adjust it to the actual constraints.
BITS = 48

class TrieNode:
    def __init__(self):
        self.child = [None, None]

class Trie:
    def __init__(self):
        self.root = TrieNode()

    # Inserts the number into the trie, most significant bit first.
    def insert(self, num):
        node = self.root
        for i in range(BITS - 1, -1, -1):
            bit = (num >> i) & 1
            if not node.child[bit]:
                node.child[bit] = TrieNode()
            node = node.child[bit]

    # Returns the maximum XOR of num with any inserted number (0 if the trie is empty).
    def query(self, num):
        node = self.root
        if not (node.child[0] or node.child[1]):
            return 0
        max_xor = 0
        for i in range(BITS - 1, -1, -1):
            bit = (num >> i) & 1
            # Prefer the opposite bit: it sets bit i of the XOR result.
            if node.child[1-bit]:
                max_xor |= (1 << i)
                node = node.child[1-bit]
            else:
                node = node.child[bit]
        return max_xor

import sys
sys.setrecursionlimit(100000)

def dfs(u, parent, graph, values, tin, tout, subtreeSum, time):
    tin[u] = time[0]
    time[0] += 1
    total = values[u]
    for v in graph[u]:
        if v == parent:
            continue
        total += dfs(v, u, graph, values, tin, tout, subtreeSum, time)
    subtreeSum[u] = total
    tout[u] = time[0]-1
    return total

def maxXorNonOverlappingSubtrees(n, edges, values):
    # Build tree structure.
    graph = [[] for _ in range(n)]
    for u,v in edges:
        graph[u].append(v)
        graph[v].append(u)
    
    tin = [0]*n
    tout = [0]*n
    subtreeSum = [0]*n
    time = [0]
    dfs(0, -1, graph, values, tin, tout, subtreeSum, time)
    
    # Create list of nodes with their Euler times and subtree sums.
    # Each element is (tin, tout, sum, node)
    nodes = []
    for i in range(n):
        nodes.append((tin[i], tout[i], subtreeSum[i], i))
    # Sort by tin (note: tin values are 0..n-1 so this is the same order)
    nodes.sort(key=lambda x: x[0])
    
    # Also prepare a copy sorted by tout.
    nodes_by_tout = sorted(nodes, key=lambda x: x[1])
    
    ans = 0
    # Left pass: valid partner j is in prefix and has tout < current tin.
    leftTrie = Trie()
    p = 0  # pointer into nodes_by_tout
    for i in range(n):
        curr_tin, curr_tout, curr_sum, node_id = nodes[i]
        # Add all nodes from nodes_by_tout with tout < curr_tin
        while p < n and nodes_by_tout[p][1] < curr_tin:
            # Insert the subtree sum of that node.
            leftTrie.insert(nodes_by_tout[p][2])
            p += 1
        if p > 0:  # if leftTrie is not empty
            candidate = leftTrie.query(curr_sum)
            ans = max(ans, candidate)
    
    # Right pass: for a node, a valid partner j must have tin > curr_tout.
    # Sweep the nodes in decreasing tout order: as curr_tout shrinks, the set of
    # nodes with tin > curr_tout only grows, so a single trie fed by one pointer
    # over the tin order (walked from the back) is enough. Rebuilding a trie for
    # every query would cost O(n^2 * B) instead.
    # Non-overlap is symmetric, so this pass revisits pairs that the left pass
    # has already scored; it is kept to mirror the description above.
    rightTrie = Trie()
    q = n - 1  # pointer into nodes (sorted by tin), moving backwards
    for curr_tin, curr_tout, curr_sum, node_id in reversed(nodes_by_tout):
        # Add all nodes whose interval starts after the current interval ends.
        while q >= 0 and nodes[q][0] > curr_tout:
            rightTrie.insert(nodes[q][2])
            q -= 1
        if q < n - 1:  # if rightTrie is not empty
            candidate = rightTrie.query(curr_sum)
            ans = max(ans, candidate)
    return ans

# Example usage:
if __name__ == '__main__':
    n = 6
    edges = [[0,1],[0,2],[1,3],[1,4],[2,5]]
    values = [2,8,3,6,2,5]
    print(maxXorNonOverlappingSubtrees(n, edges, values))  # Expected output: 24
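
A quadratic brute force is a convenient cross-check for small inputs. The sketch below collects the node set of every subtree explicitly and tries all pairs; it is meant only for validation, and the function name `brute_force_max_xor` is an assumption.

```python
def brute_force_max_xor(n, edges, values, root=0):
    """Try every pair of subtrees; intended only for small trees."""
    graph = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    # Preorder with an explicit stack, recording each node's parent.
    parent = [-1] * n
    order = []
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in graph[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # Fold children into parents: node sets and sums of every subtree.
    members = [{u} for u in range(n)]
    sums = list(values)
    for u in reversed(order):
        p = parent[u]
        if p != -1:
            members[p] |= members[u]
            sums[p] += sums[u]

    best = 0
    for u in range(n):
        for v in range(u + 1, n):
            if members[u].isdisjoint(members[v]):
                best = max(best, sums[u] ^ sums[v])
    return best


if __name__ == "__main__":
    edges = [[0, 1], [0, 2], [1, 3], [1, 4], [2, 5]]
    print(brute_force_max_xor(6, edges, [2, 8, 3, 6, 2, 5]))
```

Comparing this against the trie-based function on random small trees is a quick way to catch mistakes in the sweep conditions or the bit width.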