Recitation03 Notes

From 6.006 Wiki

Jump to: navigation, search

Binary Search Tree Implementation

# Recitation03 - WF3
# Updated: 09/17/08
#
# An implementation of a binary search tree using
# Python objects.
#
# Users should instantiate and call methods on the BST
# object rather than on BSTnode objects, because a BSTnode
# may change position or even its key during a tree
# operation (e.g. delete swaps keys between nodes), whereas
# the BST object stays stable.
#
# Keys inserted into the BST are assumed to be unique
# and must be mutually comparable with == and <.
#

class BST(object):
    """
    A binary search tree. For the most part, this class is just
    a wrapper for BSTnode.

    The wrapper owns the root reference, so it can handle operations
    that replace the root (such as deleting the root node), which a
    BSTnode cannot do on its own.
    """

    def __init__(self):
        """
        Constructs an empty BST.
        """
        self.root = None

    def insert(self, t):
        """
        Inserts the key t into the BST.

        Returns the newly created BSTnode holding t.
        """
        if self.root is None:
            self.root = BSTnode(None, t)
            return self.root
        return self.root.insert(t)

    def find(self, t):
        """
        Returns the BSTnode with the key t, or None if t is not in the tree.
        """
        if self.root is None:
            return None
        return self.root.find(t)

    def next_larger(self, t):
        """
        Returns the node holding the next larger key after t, or None if
        t is not in the tree or t is the largest key.
        """
        node = self.find(t)
        if node is None:
            return None
        return node.next_larger()

    def delete(self, t):
        """
        Deletes the key t from the BST.

        Returns the BSTnode that was unlinked from the tree, or None if t
        is not in the tree (including when the tree is empty). When the
        node holding t has two children, keys are swapped with its
        in-order successor and the successor node is the one unlinked.
        """
        node = self.find(t)
        if node is None:
            # Nothing to delete; previously this crashed on an empty tree
            # because None == None sent it into the root branch.
            return None

        # Deleting the root is handled here rather than in BSTnode.delete()
        # because removing a root with at most one child must update
        # self.root -- possibly to None when the root has no children.
        if node is self.root:
            if node.left is not None and node.right is not None:
                # Two children: the successor is the minimum of the right
                # subtree (never the root itself) and has no left child, so
                # BSTnode.delete() can unlink it after the key swap.
                successor = node.next_larger()
                node.key, successor.key = successor.key, node.key
                return successor.delete()
            # At most one child: promote it (or None for a leaf root).
            child = node.left if node.right is None else node.right
            self.root = child
            if child is not None:
                child.parent = None
            return node

        return node.delete()

    def to_orderedlist(self):
        """
        Returns a list of keys in the BST in sorted order ([] if empty).
        """
        if self.root is None:
            return []
        return self.root.to_orderedlist()

    def __str__(self):
        """
        Returns an ASCII drawing of the tree, or '<empty tree>' if empty.
        """
        if self.root is None: return '<empty tree>'
        def recurse(node):
            if node is None: return [], 0, 0
            label = str(node.key)
            left_lines, left_pos, left_width = recurse(node.left)
            right_lines, right_pos, right_width = recurse(node.right)
            middle = max(right_pos + left_width - left_pos + 1, len(label), 2)
            pos = left_pos + middle // 2
            width = left_pos + middle + right_width - right_pos
            while len(left_lines) < len(right_lines):
                left_lines.append(' ' * left_width)
            while len(right_lines) < len(left_lines):
                right_lines.append(' ' * right_width)
            if (middle - len(label)) % 2 == 1 and node.parent is not None and \
               node is node.parent.left and len(label) < middle:
                label += '.'
            label = label.center(middle, '.')
            if label[0] == '.': label = ' ' + label[1:]
            if label[-1] == '.': label = label[:-1] + ' '
            lines = [' ' * left_pos + label + ' ' * (right_width - right_pos),
                     ' ' * left_pos + '/' + ' ' * (middle-2) +
                     '\\' + ' ' * (right_width - right_pos)] + \
              [left_line + ' ' * (width - left_width - right_width) +
               right_line
               for left_line, right_line in zip(left_lines, right_lines)]
            return lines, pos, width
        return '\n'.join(recurse(self.root) [0])
    

class BSTnode(object):
    """
    A node in a binary search tree.

    Each node stores its key and links to its parent and its left and
    right children (None when absent). Keys less than a node's key go to
    its left subtree; all other keys go to the right subtree.

    Traversals are iterative, so a degenerate (chain-shaped) tree deeper
    than Python's recursion limit is still handled.
    """

    def __init__(self, parent, t):
        self.key = t
        self.parent = parent
        self.left = None
        self.right = None

    def find(self, t):
        """
        Returns the node in this subtree whose key equals t, or None.
        """
        node = self
        while node is not None:
            if t == node.key:
                return node
            node = node.left if t < node.key else node.right
        return None

    def insert(self, t):
        """
        Inserts key t into this subtree and returns the new leaf node.

        Keys not less than a node's key go right, so a duplicate key is
        placed in the right subtree.
        """
        node = self
        while True:
            if t < node.key:
                if node.left is None:
                    node.left = BSTnode(node, t)
                    return node.left
                node = node.left
            else:
                if node.right is None:
                    node.right = BSTnode(node, t)
                    return node.right
                node = node.right

    def get_min(self):
        """
        Returns the node with the smallest key in this subtree.
        """
        node = self
        while node.left is not None:
            node = node.left
        return node

    def next_larger(self):
        """
        Returns the in-order successor of this node, or None if this node
        holds the largest key in the whole tree.
        """
        # Case 1: the successor is the minimum of the right subtree.
        if self.right is not None:
            return self.right.get_min()
        # Case 2: climb while we are a right child; the first ancestor
        # reached from its left side is the successor (None at the top).
        node = self
        while node.parent is not None and node is node.parent.right:
            node = node.parent
        return node.parent

    def delete(self):
        """
        Unlinks this node's key from the tree and returns the node that
        was removed.

        With two children, the key is swapped with the in-order successor
        and the successor node is removed instead. Otherwise this node is
        replaced in its parent by its only child (or None for a leaf); that
        path requires self.parent to be set -- BST.delete handles the root.
        """
        if self.left is not None and self.right is not None:
            # The successor is the min of the right subtree, so it has no
            # left child and the recursive call takes the simple path.
            successor = self.next_larger()
            self.key, successor.key = successor.key, self.key
            return successor.delete()

        child = self.left if self.right is None else self.right
        parent = self.parent
        if parent.left is self:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent
        return self

    def to_orderedlist(self):
        """
        Returns the keys of this subtree in sorted (in-order) order.
        """
        keys = []
        pending = []
        node = self
        # Iterative in-order walk: push the left spine, visit, go right.
        while pending or node is not None:
            while node is not None:
                pending.append(node)
                node = node.left
            node = pending.pop()
            keys.append(node.key)
            node = node.right
        return keys

    def __repr__(self):
        parent_key = None if self.parent is None else self.parent.key
        left_key = None if self.left is None else self.left.key
        right_key = None if self.right is None else self.right.key
        return "{Key: %s ; Parent: %s ; Left: %s ; Right: %s}" \
               % (self.key, parent_key, left_key, right_key)
            

###################

if __name__ == '__main__':
    # Demo: build a sample tree from twelve distinct keys. Nothing is
    # printed; inspect T interactively (for example with `python -i`).
    T = BST()
    L = [10, 5, 8, 7, 4, 11, 1, 3, 2, 9, 6, 12]
    for item in L:
        T.insert(item)

Personal tools