Design a stack data structure that supports the standard operations (push, pop, top) plus two max-related operations: peekMax, which returns the maximum element without removing it, and popMax, which removes and returns the maximum element (if the maximum occurs more than once, the occurrence closest to the top is removed). The goal is O(1) for top and O(log n) for push, pop, peekMax, and popMax.
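A brute-force model makes the required behaviour concrete, in particular which duplicate popMax removes. The sketch below (the class name NaiveMaxStack is illustrative, not part of the original solution) keeps the elements in a plain Python list. peekMax and popMax scan it in O(n), so it does not meet the target bounds, but it can serve as a reference when testing a faster implementation.

```python
class NaiveMaxStack:
    """Brute-force reference: a plain list scanned linearly for the maximum."""

    def __init__(self):
        self.items = []  # bottom of the stack at index 0, top at the end

    def push(self, x):
        self.items.append(x)

    def pop(self):
        return self.items.pop()

    def top(self):
        return self.items[-1]

    def peekMax(self):
        return max(self.items)  # O(n) scan

    def popMax(self):
        m = max(self.items)
        # Walk down from the top so that, among duplicates of the maximum,
        # the occurrence closest to the top is the one removed.
        i = len(self.items) - 1
        while self.items[i] != m:
            i -= 1
        return self.items.pop(i)


s = NaiveMaxStack()
for v in (3, 7, 2, 7, 1):
    s.push(v)
s.popMax()
print(s.items)  # prints: [3, 7, 2, 1]
```

The second 7 (index 3) is removed because it sits closer to the top than the first one.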
Key Insights
Use a doubly-linked list to support fast insertion and removal from the stack.
Maintain a balanced tree (or a sorted dictionary/multiset) keyed by value to locate the maximum element in O(log n) time.
Map each value to its corresponding node(s) in the doubly-linked list, so that when popMax is called, you can immediately remove the corresponding node from the list.
Because each node links to both of its neighbours, any node can be unlinked in O(1) once you hold a reference to it.
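The O(1) removal claim can be seen in a minimal sketch, assuming sentinel head and tail nodes (the ListNode and LinkedList names are hypothetical). Unlinking rewires only the removed node's two neighbours, so it costs the same wherever the node sits, provided the caller kept the reference returned by append.

```python
class ListNode:
    def __init__(self, val):
        self.val = val
        self.prev = None
        self.next = None


class LinkedList:
    """Doubly-linked list with sentinel head and tail nodes (a sketch)."""

    def __init__(self):
        self.head = ListNode(None)
        self.tail = ListNode(None)
        self.head.next = self.tail
        self.tail.prev = self.head

    def append(self, val):
        # Link a new node just before the tail and hand the reference back.
        node = ListNode(val)
        node.prev, node.next = self.tail.prev, self.tail
        self.tail.prev.next = node
        self.tail.prev = node
        return node

    def unlink(self, node):
        # Only the two neighbours are rewired, so this is O(1)
        # wherever the node sits; no traversal is needed.
        node.prev.next = node.next
        node.next.prev = node.prev
        node.prev = node.next = None
        return node.val

    def values(self):
        # Walk head -> tail; used here only to inspect the result.
        out, cur = [], self.head.next
        while cur is not self.tail:
            out.append(cur.val)
            cur = cur.next
        return out


lst = LinkedList()
a, b, c = lst.append(5), lst.append(1), lst.append(5)
lst.unlink(b)        # remove from the middle without searching
print(lst.values())  # prints: [5, 5]
```

Since the list is never walked to find the node, a value-to-nodes map is what turns the popMax lookup into a direct unlink.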
Space and Time Complexity
Time Complexity:
push: O(log n) due to insertion in the ordered structure.
pop: O(1) for removal from the doubly-linked list, plus O(log n) for updating the tree.
top: O(1)
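peekMax: O(1) if the ordered structure exposes its largest key directly (for example, the last element of a sorted list), otherwise O(log n) to reach the rightmost node of the tree.
popMax: O(log n) for the lookup and removal in the tree, plus O(1) for deleting the node from the doubly-linked list.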
Space Complexity:
O(n) for storing all nodes in the doubly-linked list and the tree structure mapping values to nodes.
Solution
We use two data structures:
A doubly-linked list to store the elements in stack order. This supports fast push and pop operations and allows O(1) removal when given a reference.
An ordered structure (such as a balanced BST, TreeMap in Java, or SortedDict from the third-party sortedcontainers package in Python) that maps each value to the list of linked-list nodes holding that value. This lets us find the maximum element quickly. Nodes are appended to each value's list in push order, so among duplicates the last node in the list is the one closest to the top, and that is the node popMax removes.
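The ordering argument for duplicates can be checked with a toy example, assuming integer push ids stand in for linked-list nodes (a simplification for illustration only). Each value's list receives ids in push order, so its last id always belongs to the copy nearest the top.

```python
from collections import defaultdict

# Push ids (0, 1, 2, ...) stand in for linked-list nodes in this illustration.
pushes = [5, 1, 5, 3, 5]
by_value = defaultdict(list)
for push_id, val in enumerate(pushes):
    by_value[val].append(push_id)  # ids arrive in ascending order

# A real solution reads the maximum from the ordered structure;
# max() over the keys is O(k) and is used here only for brevity.
max_val = max(by_value)
chosen = by_value[max_val].pop()  # last id = most recently pushed copy

print(max_val, chosen, by_value[max_val])  # prints: 5 4 [0, 2]
```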
The trick is to keep these two structures synchronized. When an element is pushed, add its node to both the doubly-linked list and the ordered structure. When an element is removed (either via pop or popMax), remove its node from both structures.
Code Solutions
# Python solution implementing the MaxStack data structure.
# Python has no built-in balanced BST; one can use the third-party
# "sortedcontainers" package or manually maintain a heap and a dictionary.
# For clarity, this version simulates an ordered map with a dictionary of
# node lists plus a sorted list of the distinct keys.
# Caveat: inserting into or removing from a Python list shifts elements, so
# those steps are O(n) in the worst case even though the position is found by
# binary search. In production, "SortedDict" from sortedcontainers is an option
# (its documentation lists these operations as approximately O(log n)).
import bisect
from collections import defaultdict


class Node:
    def __init__(self, val):
        self.val = val
        self.prev = None
        self.next = None


class DoubleLinkedList:
    def __init__(self):
        # Dummy head and tail nodes remove edge cases at both ends.
        self.head = Node(0)
        self.tail = Node(0)
        self.head.next = self.tail
        self.tail.prev = self.head

    def add_last(self, node):
        # Link the node in just before the tail (the top of the stack).
        node.prev = self.tail.prev
        node.next = self.tail
        self.tail.prev.next = node
        self.tail.prev = node

    def pop(self, node=None):
        # Unlink the given node in O(1); with no argument, unlink the last node.
        if node is None:
            node = self.tail.prev
        node.prev.next = node.next
        node.next.prev = node.prev
        return node

    def top(self):
        # Value of the last node, i.e. the top of the stack.
        return self.tail.prev.val


class MaxStack:
    def __init__(self):
        self.dll = DoubleLinkedList()
        # Maps each value to its nodes in push order; the last node in each
        # list is the occurrence closest to the top of the stack.
        self.valToNodes = defaultdict(list)
        # Distinct values currently on the stack, in ascending order.
        self.sortedKeys = []

    def push(self, x: int) -> None:
        node = Node(x)
        self.dll.add_last(node)
        self.valToNodes[x].append(node)
        # A value seen for the first time must also enter sortedKeys
        # (binary search for the position, then list insertion).
        if len(self.valToNodes[x]) == 1:
            bisect.insort(self.sortedKeys, x)

    def pop(self) -> int:
        node = self.dll.pop()
        x = node.val
        # The top node is the most recent node holding x: the last in its list.
        self.valToNodes[x].pop()
        if not self.valToNodes[x]:
            # No copies of x remain, so drop it from sortedKeys.
            self.sortedKeys.remove(x)
        return x

    def top(self) -> int:
        return self.dll.top()

    def peekMax(self) -> int:
        # The maximum is the last element of sortedKeys.
        return self.sortedKeys[-1]

    def popMax(self) -> int:
        max_val = self.peekMax()
        # Take the most recent node holding max_val (closest to the top).
        node = self.valToNodes[max_val].pop()
        if not self.valToNodes[max_val]:
            self.sortedKeys.pop()  # max_val is the last key
        # Unlink that node from the doubly-linked list in O(1).
        self.dll.pop(node)
        return max_val
# Example usage:
# stk = MaxStack()
# stk.push(5)
# stk.push(1)
# stk.push(5)
# print(stk.top())      # prints 5
# print(stk.popMax())   # prints 5
# print(stk.top())      # prints 1
# print(stk.peekMax())  # prints 5
# print(stk.pop())      # prints 1
# print(stk.top())      # prints 5
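The code comments mention maintaining a heap as an alternative to a sorted structure. One stdlib-only way to do that is lazy deletion, sketched below under assumed names (HeapMaxStack and its helper methods are hypothetical, and a set of removed ids stands in for the dictionary). Each push gets a unique increasing id; the heap orders entries by (-value, -id), so ties on value favour the most recent push. An element removed through one structure is recorded and skipped when it later surfaces in the other. In this swapped-in technique push and popMax are O(log n), peekMax is amortized O(log n), and pop and top are amortized O(1), where n includes stale heap entries not yet discarded.

```python
import heapq
import itertools


class HeapMaxStack:
    """Max stack built from a list, a heap and lazy deletion (a sketch)."""

    def __init__(self):
        self.stack = []               # (id, value) pairs, top of stack at the end
        self.heap = []                # (-value, -id): largest value, then newest, first
        self.removed = set()          # ids removed via one structure, still in the other
        self.ids = itertools.count()  # unique, increasing push ids

    def push(self, x):
        i = next(self.ids)
        self.stack.append((i, x))
        heapq.heappush(self.heap, (-x, -i))

    def _clean_stack(self):
        # Discard stack entries already removed by popMax.
        while self.stack and self.stack[-1][0] in self.removed:
            self.removed.discard(self.stack.pop()[0])

    def _clean_heap(self):
        # Discard heap entries already removed by pop.
        while self.heap and -self.heap[0][1] in self.removed:
            self.removed.discard(-heapq.heappop(self.heap)[1])

    def pop(self):
        self._clean_stack()
        i, x = self.stack.pop()
        self.removed.add(i)  # its heap entry is now stale
        return x

    def top(self):
        self._clean_stack()
        return self.stack[-1][1]

    def peekMax(self):
        self._clean_heap()
        return -self.heap[0][0]

    def popMax(self):
        self._clean_heap()
        neg_x, neg_i = heapq.heappop(self.heap)
        self.removed.add(-neg_i)  # its stack entry is now stale
        return -neg_x


stk = HeapMaxStack()
for v in (5, 1, 5):
    stk.push(v)
print(stk.popMax(), stk.top(), stk.peekMax())  # prints: 5 1 5
```

The trade-off is memory: stale entries linger until they reach the top of the heap or the stack, whereas the linked-list design removes each element from both structures immediately.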