
/**
 Maintains a set of (key, value) pairs. A linear
 ordering on the keys is assumed via a compareTo
 method. We also assume that every key in the set
 is unique.
 **/


public class AVLTreeMap<K extends Comparable, V>
{
    // root of the tree; null when the map is empty
    AVLTreeNode<K, V> root;
    // size is the number of (key, value) pairs currently stored.
    // NOTE(review): height is not read or written anywhere in this class.
    int height, size;

    // these are for remembering what needs to be returned in the
    // recursive put, remove and removeMin methods
    V retValue;
    K minKey;
    V minValue;

    // for testing purposes: the key most recently visited by testOrder
    K lastKey;



    /**
     * Returns the number of (key, value) pairs stored in the map.
     */
    public int size()
    {
        return size;
    }

    /**
     * Returns true if the map holds no (key, value) pairs.
     */
    public boolean isEmpty()
    {
        return size == 0;
    }

    /**
     * Looks up the value stored with key k.
     *
     * @param k key to look up
     * @return the value associated with k, or null if k is not in the map
     */
    public V get(K k)
    {
        return get(k, root);
    }

    /**
     * Recursively searches the subtree rooted at w for key k.
     *
     * @param k key to look up
     * @param w root of the subtree to search; may be null
     * @return the value associated with k, or null if k is not in the subtree
     */
    public V get(K k, AVLTreeNode<K, V> w)
    {
        if ( w == null ) return null;

        int cmp = w.getKey().compareTo( k );
        if ( cmp == 0 ) return w.getValue();

        // w's key is smaller than k, so k can only be in the right subtree.
        // This must match the ordering convention used by put and remove.
        if ( cmp < 0 )
            return get( k, w.getRight() );
        else
            return get( k, w.getLeft() );
    }

    /**
     * Associates v with k. If k is already present, its value is replaced.
     * Otherwise a new node is added and the size grows by one.
     *
     * @param k key to insert or update
     * @param v value to store
     * @return the previous value for k, or null if k was not present
     */
    public V put(K k,V v)
    {
        root = put(k, v, root);
        return retValue;
    }

    /**
     * Recursive put. Adds or updates k in the subtree rooted at w and
     * rebalances on the way back up.
     *
     * Records the previous value (or null) in retValue.
     *
     * @return the (possibly new) root of the subtree
     */
    private AVLTreeNode<K, V> put( K k, V v, AVLTreeNode<K, V> w )
    {
        // key wasn't found. Make a new node
        if ( w == null )
        {
            retValue = null;
            size++;
            w = new AVLTreeNode<K, V>(k, v, null, null);
        }
        else
        {
            int cmp = w.getKey().compareTo( k );

            // key already exists in the tree - replace value
            if ( cmp == 0 )
            {
                retValue = w.getValue();
                w.setValue( v );
            }
            // need to keep searching
            else if ( cmp < 0 )
            {
                w.right = put( k, v, w.right );
            }
            else
            {
                w.left = put( k, v, w.left );
            }
        }

        // begin AVL code
        // setHeight() is assumed to recompute w's cached height from its children
        w.setHeight();

        if ( !w.isBalanced() )
        {
            w = rebalance( w );
        }
        // end AVL code
        return w;
    }

    /**
     * Performs a trinode restructuring at the unbalanced node z.
     *
     * @param z a node for which isBalanced() returned false
     * @return the new root of the restructured subtree
     */
    private AVLTreeNode<K, V> rebalance( AVLTreeNode<K, V> z )
    {
        // find the tallest child and grandchild of z
        AVLTreeNode<K, V> y = z.tallerChild();
        if ( y == null ) System.out.println(" Null child of unbalanced node ");
        // NOTE(review): the argument z is presumably used by AVLTreeNode to break
        // ties between equally tall children; confirm against AVLTreeNode.
        AVLTreeNode<K, V> x = y.tallerChild( z );
        if ( x == null ) System.out.println(" Null grandchild of unbalanced node ");

        AVLTreeNode<K, V> a, b, c, t0, t1, t2, t3;

        // a is set to be the smallest of {x,y,z}, c the largest and
        // b in the middle. Then t0, t1, t2, t3 are set to be the four
        // subtrees that hang off of x, y, and z. They are named in increasing
        // order. That is, all the keys in t0 are smaller than the keys in t1, etc.
        // There are four cases....
        if ( z.getKey().compareTo( y.getKey() ) < 0 )
        {
            // y is z's right child
            a = z;
            t0 = z.getLeft();
            if ( y.getKey().compareTo( x.getKey() ) < 0 )
            {
                // x is y's right child (right-right)
                b = y;
                c = x;
                t1 = y.getLeft();
                t2 = x.getLeft();
                t3 = x.getRight();
            }
            else
            {
                // x is y's left child (right-left)
                c = y;
                b = x;
                t1 = x.getLeft();
                t2 = x.getRight();
                t3 = y.getRight();
            }
        }
        else
        {
            // y is z's left child
            c = z;
            t3 = z.getRight();
            if ( y.getKey().compareTo( x.getKey() ) > 0 )
            {
                // x is y's left child (left-left)
                b = y;
                a = x;
                t0 = x.getLeft();
                t1 = x.getRight();
                t2 = y.getRight();
            }
            else
            {
                // x is y's right child (left-right)
                a = y;
                b = x;
                t0 = y.getLeft();
                t1 = x.getLeft();
                t2 = x.getRight();
            }
        }

        // now restructure so that b is the new root of the subtree and
        // a and c are the children of b.
        restructure( a, b, c, t0, t1, t2, t3 );

        return b;
    }

    /**
     * Rewires a, b, c and the four subtrees t0..t3 so that b becomes the root.
     *
     * After the call, a is b's left child and c is b's right child. The four
     * subtrees hang off a and c in key order. Heights are recomputed
     * bottom-up: a and c before b.
     */
    private void restructure( AVLTreeNode<K, V> a, AVLTreeNode<K, V> b, AVLTreeNode<K, V> c,
                              AVLTreeNode<K, V> t0, AVLTreeNode<K, V> t1, AVLTreeNode<K, V> t2,
                              AVLTreeNode<K, V> t3 )
    {
        b.setLeft( a );
        b.setRight( c );
        a.setLeft( t0 );
        a.setRight( t1 );
        c.setLeft( t2 );
        c.setRight( t3 );

        // the heights of a, b, and c may have changed; b depends on a and c,
        // so it is recomputed last.
        a.setHeight();
        c.setHeight();
        b.setHeight();
    }

    /**
     * Prints every node in key order between " BEGIN " and " END " markers.
     */
    public void printTree()
    {
        System.out.println(" BEGIN ");
        printTree( root );
        System.out.println(" END ");
    }

    // recursive inorder traversal of the nodes.
    // each visit to a node outputs its key, the keys of its left and
    // right child ("X" when absent) and its height.
    private void printTree( AVLTreeNode<K, V> x )
    {
        if ( x == null ) return;

        String leftString, rightString;

        printTree( x.getLeft() );

        if  ( x.getLeft() == null )
            leftString = "X";
        else leftString = x.getLeft().getKey().toString();

        if  ( x.getRight() == null )
            rightString = "X";
        else rightString = x.getRight().getKey().toString();

        System.out.println( " " + x.getKey().toString() + " " + leftString
                                + " " + rightString + " " + x.getHeight() );

        printTree( x.getRight() );
    }

    /**
     * Removes the pair with key k, if present. The size shrinks by one only
     * when a pair is actually removed.
     *
     * @param k key to remove
     * @return the value that was associated with k, or null if k was not present
     */
    public V remove(K k)
    {
        root = remove(k, root);
        return retValue;
    }

    /**
     * Recursive remove. Removes k from the subtree rooted at w and rebalances
     * on the way back up.
     *
     * Records the removed value (or null) in retValue.
     *
     * @return the root of the resulting subtree (null if it became empty)
     */
    private AVLTreeNode<K, V> remove(K k, AVLTreeNode<K, V> w)
    {
        // key not found - return null
        if ( w == null )
        {
            retValue = null;
            return null;
        }

        int cmp = w.getKey().compareTo( k );

        // keep searching in the left subtree
        if ( cmp > 0 )
        {
            w.left = remove( k, w.left );
        }
        // keep searching in the right subtree
        else if ( cmp < 0 )
        {
            w.right = remove( k, w.right );
        }
        // found the node to be removed
        else
        {
            // remember the removed value in both cases below, before w's
            // value is overwritten by its successor
            retValue = w.getValue();
            size--;

            // easy case - node to be removed has at most one child
            if ( w.getRight() == null || w.getLeft() == null )
            {
                w = ( w.getRight() == null ? w.getLeft() : w.getRight() );
            }
            // replace with minimum node in right subtree
            else
            {
                w.setRight( removeMin( w.right ) );
                w.setValue( minValue );
                w.setKey( minKey );
            }
        }

        // begin AVL code
        if ( w != null )
        {
            w.setHeight();
            if ( !w.isBalanced() )
            {
                w = rebalance( w );
            }
        }
        // end AVL code
        return w;
    }

    // removes the smallest node in the subtree rooted at w (w must be non-null).
    // the data for the removed node is stored in the member
    // variables minValue and minKey. Returns the root of the resulting subtree.
    private AVLTreeNode<K, V> removeMin( AVLTreeNode<K, V> w )
    {
        if ( w.getLeft() == null )
        {
            minValue = w.getValue();
            minKey = w.getKey();
            w = w.getRight();
        }
        else
            w.setLeft( removeMin( w.getLeft() ) );

        // begin AVL code
        if ( w != null )
        {
            w.setHeight();
            if ( !w.isBalanced() )
            {
                w = rebalance( w );
            }
        }
        // end AVL code
        return w;

    }

    /**
     * Checks that an in-order walk visits keys in strictly increasing order.
     *
     * @return true if the tree is in correct binary search order
     */
    public boolean testOrder()
    {
        lastKey = null;
        return testOrder( root );
    }

    // in-order walk; lastKey holds the previously visited key
    private boolean testOrder( AVLTreeNode<K, V> w )
    {
        if ( w == null ) return true;
        boolean test = true;
        test = test && testOrder( w.left );
        test = test && ( lastKey == null || lastKey.compareTo( w.getKey() ) < 0 );
        lastKey = w.getKey();
        test = test && testOrder( w.right );
        return test;
    }

    /**
     * Checks two things for every node: the cached height matches a fresh
     * recomputation, and isBalanced() holds.
     *
     * Note that this calls setHeight() on every node.
     */
    public boolean testHeight()
    {
        return testHeight( root );
    }

    // post-order walk so children are checked before their parent
    private boolean testHeight( AVLTreeNode<K, V> w )
    {
        if ( w == null ) return true;
        boolean test = true;
        test = test && testHeight( w.left );
        test = test && testHeight( w.right );
        int oldHeight = w.getHeight();
        w.setHeight();
        test = test && ( oldHeight == w.getHeight() ) && w.isBalanced();
        return test;
    }
}
