#!/usr/bin/env python
"""Unit tests for fast tree."""
 
from cogent.util.unit_test import TestCase, main
from cogent.parse.tree import DndParser
from cogent.maths.unifrac.fast_tree import (count_envs, sum_env_dict, 
    index_envs, get_branch_lengths, index_tree, bind_to_array, 
    bind_to_parent_array, _is_parent_empty, delete_empty_parents,
    traverse_reduce, bool_descendants, sum_descendants, fitch_descendants, 
    tip_distances, UniFracTreeNode, FitchCounter, FitchCounterDense,
    permute_selected_rows, prep_items_for_jackknife, jackknife_bool, 
    jackknife_int, unifrac, unnormalized_unifrac, PD, G, unnormalized_G, 
    unifrac_matrix, unifrac_vector, PD_vector, weighted_unifrac, 
    weighted_unifrac_matrix, weighted_unifrac_vector, jackknife_array, 
    env_unique_fraction, unifrac_one_sample, weighted_one_sample)
from numpy import (arange, reshape, zeros, logical_or, array, sum, nonzero, 
    flatnonzero, newaxis)
from numpy.random import permutation    
 
__author__ = "Rob Knight and Micah Hamady"
__copyright__ = "Copyright 2007-2012, The Cogent Project"
__credits__ = ["Rob Knight", "Micah Hamady"]
__license__ = "GPL"
__version__ = "1.5.3-dev"
__maintainer__ = "Rob Knight, Micah Hamady"
__email__ = "rob@spot.colorado.edu, hamady@colorado.edu"
__status__ = "Prototype"
 
class fast_tree_tests(TestCase):
    """Tests of top-level functions"""
    def setUp(self):
        """Define a couple of standard trees"""
        self.t1 = DndParser('(((a,b),c),(d,e))', UniFracTreeNode)
        self.t2 = DndParser('(((a,b),(c,d)),(e,f))', UniFracTreeNode)
        self.t3 = DndParser('(((a,b,c),(d)),(e,f))', UniFracTreeNode)
        self.t4 = DndParser('((c)b,((f,g,h)e,i)d)', UniFracTreeNode)
        self.t4.Name = 'a'
        self.t_str = '((a:1,b:2):4,(c:3,(d:1,e:1):2):3)'
 
        self.t = DndParser(self.t_str, UniFracTreeNode)
        self.env_str = """
a   A   1
a   C   2
b   A   1
b   B   1
c   B   1
d   B   3
e   C   1"""
        self.env_counts = count_envs(self.env_str.splitlines())
        self.node_index, self.nodes = index_tree(self.t)
        self.count_array, self.unique_envs, self.env_to_index, \
            self.node_to_index = index_envs(self.env_counts, self.node_index)
        self.branch_lengths = get_branch_lengths(self.node_index)
 
        self.old_t_str = '((org1:0.11,org2:0.22,(org3:0.12,org4:0.23)g:0.33)b:0.2,(org5:0.44,org6:0.55)c:0.3,org7:0.4)'
 
 
        self.old_t = DndParser(self.old_t_str, UniFracTreeNode)
        self.old_env_str = """
org1    env1    1
org1    env2    1
org2    env2    1
org3    env2    1
org4    env3    1
org5    env1    1
org6    env1    1
org7    env3    1
"""
        self.old_env_counts = count_envs(self.old_env_str.splitlines())
        self.old_node_index, self.old_nodes = index_tree(self.old_t)
        self.old_count_array, self.old_unique_envs, self.old_env_to_index, \
            self.old_node_to_index = index_envs(self.old_env_counts, self.old_node_index)
        self.old_branch_lengths = get_branch_lengths(self.old_node_index)
 
 
 
 
    def test_traverse(self):
        """traverse should work iterative or recursive"""
        stti = self.t4.traverse
        stt = self.t4.traverse_recursive
        obs = [i.Name for i in stt(self_before=False, self_after=False)]
        exp = [i.Name for i in stti(self_before=False, self_after=False)]
        self.assertEqual(obs, exp)
        obs = [i.Name for i in stt(self_before=True, self_after=False)]
        exp = [i.Name for i in stti(self_before=True, self_after=False)]
        self.assertEqual(obs, exp)
        obs = [i.Name for i in stt(self_before=False, self_after=True)]
        exp = [i.Name for i in stti(self_before=False, self_after=True)]
        self.assertEqual(obs, exp)
        obs = [i.Name for i in stt(self_before=True, self_after=True)]
        exp = [i.Name for i in stti(self_before=True, self_after=True)]
        self.assertEqual(obs, exp)
 
    def test_count_envs(self):
        """count_envs should return correct counts from lines"""
        envs = """
a   A   3   some other junk
a   B 
a   C   1
b   A   2
 
skip
c   B
d
b   A   99
"""
        result = count_envs(envs.splitlines())
        self.assertEqual(result, \
            {'a':{'A':3,'B':1,'C':1},'b':{'A':99},'c':{'B':1}})
 
    def test_sum_env_dict(self):
        """sum_env_dict should return correct counts from env_dict"""
        envs = """
a   A   3   some other junk
a   B 
a   C   1
b   A   2
 
skip
c   B
d
b   A   99
"""
        result = count_envs(envs.splitlines())
        sum_ = sum_env_dict(result)
        self.assertEqual(sum_, 105) 
 
    def test_index_envs(self):
        """index_envs should map envs and taxa onto indices"""
        self.assertEqual(self.unique_envs, ['A','B','C'])
        self.assertEqual(self.env_to_index, {'A':0, 'B':1, 'C':2})
        self.assertEqual(self.node_to_index,{'a':0, 'b':1, 'c':4, 'd':2, 'e':3})
        self.assertEqual(self.count_array, \
            array([[1,0,2],[1,1,0],[0,3,0],[0,0,1], \
            [0,1,0],[0,0,0],[0,0,0],[0,0,0],[0,0,0]]))
 
    def test_get_branch_lengths(self):
        """get_branch_lengths should make array of branch lengths from index"""
        result = get_branch_lengths(self.node_index)
        self.assertEqual(result, array([1,2,1,1,3,2,4,3,0]))
 
    def test_env_unique_fraction(self):
        """should report unique fraction of bl in each env """
        # testing old unique fraction   
        cur_count_array = self.count_array.copy()
        bound_indices = bind_to_array(self.nodes, cur_count_array) 
        total_bl = sum(self.branch_lengths)
        bool_descendants(bound_indices)
        env_bl_sums, env_bl_ufracs = env_unique_fraction(self.branch_lengths, cur_count_array)
        # env A has 0 unique bl, B has 4, C has 1        
        self.assertEqual(env_bl_sums, [0,4,1])
        self.assertEqual(env_bl_ufracs, [0,4/17.0,1/17.0])
 
        cur_count_array = self.old_count_array.copy()
        bound_indices = bind_to_array(self.old_nodes, cur_count_array) 
        total_bl = sum(self.old_branch_lengths)
        bool_descendants(bound_indices)
 
        env_bl_sums, env_bl_ufracs = env_unique_fraction(self.old_branch_lengths, cur_count_array)
        # env A has 0 unique bl, B has 4, C has 1        
        self.assertEqual(env_bl_sums, env_bl_sums) 
        self.assertEqual(env_bl_sums, [1.29, 0.33999999999999997, 0.63])
        self.assertEqual(env_bl_ufracs, [1.29/2.9,0.33999999999999997/2.9, 0.63/2.9])
 
    def test_index_tree(self):
        """index_tree should produce correct index and node map"""
        #test for first tree: contains singleton outgroup
        t1 = self.t1
        id_1, child_1 = index_tree(t1)
        nodes_1 = [n._leaf_index for n in t1.traverse(self_before=False, \
            self_after=True)]
        self.assertEqual(nodes_1, [0,1,2,3,6,4,5,7,8])
        self.assertEqual(child_1, [(2,0,1),(6,2,3),(7,4,5),(8,6,7)])
        #test for second tree: strictly bifurcating
        t2 = self.t2
        id_2, child_2 = index_tree(t2)
        nodes_2 = [n._leaf_index for n in t2.traverse(self_before=False, \
            self_after=True)]
        self.assertEqual(nodes_2, [0,1,4,2,3,5,8,6,7,9,10])
        self.assertEqual(child_2, [(4,0,1),(5,2,3),(8,4,5),(9,6,7),(10,8,9)])
        #test for third tree: contains trifurcation and single-child parent
        t3 = self.t3
        id_3, child_3 = index_tree(t3)
        nodes_3 = [n._leaf_index for n in t3.traverse(self_before=False, \
            self_after=True)]
        self.assertEqual(nodes_3, [0,1,2,4,3,5,8,6,7,9,10])
        self.assertEqual(child_3, [(4,0,2),(5,3,3),(8,4,5),(9,6,7),(10,8,9)])
 
    def test_bind_to_array(self):
        """bind_to_array should return correct array ranges"""
        a = reshape(arange(33), (11,3))
        id_, child = index_tree(self.t3)
        bindings = bind_to_array(child, a)
        self.assertEqual(len(bindings), 5)
        self.assertEqual(bindings[0][0], a[4])
        self.assertEqual(bindings[0][1], a[0:3])
        self.assertEqual(bindings[0][1].shape, (3,3))
        self.assertEqual(bindings[1][0], a[5])
        self.assertEqual(bindings[1][1], a[3:4])
        self.assertEqual(bindings[1][1].shape, (1,3))
        self.assertEqual(bindings[2][0], a[8])
        self.assertEqual(bindings[2][1], a[4:6])
        self.assertEqual(bindings[2][1].shape, (2,3))
        self.assertEqual(bindings[3][0], a[9])
        self.assertEqual(bindings[3][1], a[6:8])
        self.assertEqual(bindings[3][1].shape, (2,3))
        self.assertEqual(bindings[4][0], a[10])
        self.assertEqual(bindings[4][1], a[8:10])
        self.assertEqual(bindings[4][1].shape, (2,3))
 
    def test_bind_to_parent_array(self):
        """bind_to_parent_array should bind tree to array correctly"""
        a = reshape(arange(33), (11,3))
        index_tree(self.t3)
        bindings = bind_to_parent_array(self.t3, a)
        self.assertEqual(len(bindings), 10)
        self.assertEqual(bindings[0][0], a[8])
        self.assertEqual(bindings[0][1], a[10])
        self.assertEqual(bindings[1][0], a[4])
        self.assertEqual(bindings[1][1], a[8])
        self.assertEqual(bindings[2][0], a[0])
        self.assertEqual(bindings[2][1], a[4])
        self.assertEqual(bindings[3][0], a[1])
        self.assertEqual(bindings[3][1], a[4])
        self.assertEqual(bindings[4][0], a[2])
        self.assertEqual(bindings[4][1], a[4])
        self.assertEqual(bindings[5][0], a[5])
        self.assertEqual(bindings[5][1], a[8])
        self.assertEqual(bindings[6][0], a[3])
        self.assertEqual(bindings[6][1], a[5])
        self.assertEqual(bindings[7][0], a[9])
        self.assertEqual(bindings[7][1], a[10])
        self.assertEqual(bindings[8][0], a[6])
        self.assertEqual(bindings[8][1], a[9])
        self.assertEqual(bindings[9][0], a[7])
        self.assertEqual(bindings[9][1], a[9])
 
    def test_delete_empty_parents(self):
        """delete_empty_parents should remove empty parents from bound indices"""
        id_to_node, node_first_last = index_tree(self.t)
        bound_indices = bind_to_array(node_first_last, self.count_array[:,0:1])
        bool_descendants(bound_indices)
        self.assertEqual(len(bound_indices), 4)
        deleted = delete_empty_parents(bound_indices)
        self.assertEqual(len(deleted), 2)
        for d in deleted:
            self.assertEqual(d[0][0], 1)
 
    def test_traverse_reduce(self):
        """traverse_reduce should reduce array in traversal order."""
        id_, child = index_tree(self.t3)
        a = zeros((11,3)) + 99    #fill with junk
        bindings = bind_to_array(child, a)
        #load in leaf envs
        a[0] = a[1] = a[2] = a[7] = [0,1,0]
        a[3] = [1,0,0]
        a[6] = [0,0,1]
        f = logical_or.reduce
        traverse_reduce(bindings, f)
        self.assertEqual(a,\
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],\
            [0,0,1],[0,1,0],[1,1,0],[0,1,1],[1,1,1]])
        )
        f = sum
        traverse_reduce(bindings, f)
        self.assertEqual( a, \
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,3,0],[1,0,0],\
            [0,0,1],[0,1,0],[1,3,0],[0,1,1],[1,4,1]])
        )
 
    def test_bool_descendants(self):
        """bool_descendants should be true if any descendant true"""
        #self.t3 = DndParser('(((a,b,c),(d)),(e,f))', UniFracTreeNode)
        id_, child = index_tree(self.t3)
        a = zeros((11,3)) + 99    #fill with junk
        bindings = bind_to_array(child, a)
        #load in leaf envs
        a[0] = a[1] = a[2] = a[7] = [0,1,0]
        a[3] = [1,0,0]
        a[6] = [0,0,1]
        bool_descendants(bindings)
        self.assertEqual(a, \
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],\
            [0,0,1],[0,1,0],[1,1,0],[0,1,1],[1,1,1]])
        )
 
    def test_sum_descendants(self):
        """sum_descendants should sum total descendants w/ each state"""
        id_, child = index_tree(self.t3)
        a = zeros((11,3)) + 99    #fill with junk
        bindings = bind_to_array(child, a)
        #load in leaf envs
        a[0] = a[1] = a[2] = a[7] = [0,1,0]
        a[3] = [1,0,0]
        a[6] = [0,0,1]
        sum_descendants(bindings)
        self.assertEqual(a, \
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,3,0],[1,0,0],\
            [0,0,1],[0,1,0],[1,3,0],[0,1,1],[1,4,1]])
        )
 
    def test_fitch_descendants(self):
        """fitch_descendants should assign states by fitch parsimony, ret. #"""
        id_, child = index_tree(self.t3)
        a = zeros((11,3)) + 99    #fill with junk
        bindings = bind_to_array(child, a)
        #load in leaf envs
        a[0] = a[1] = a[2] = a[7] = [0,1,0]
        a[3] = [1,0,0]
        a[6] = [0,0,1]
        changes = fitch_descendants(bindings)
        self.assertEqual(changes, 2)
        self.assertEqual(a, \
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],\
            [0,0,1],[0,1,0],[1,1,0],[0,1,1],[0,1,0]])
        )
 
    def test_fitch_descendants_missing_data(self):
        """fitch_descendants should work with missing data"""
        #tree and envs for testing missing values
        t_str = '(((a:1,b:2):4,(c:3,d:1):2):1,(e:2,f:1):3);'
        env_str = """a   A
b   B
c   D
d   C
e   C
f   D"""
        t = DndParser(t_str, UniFracTreeNode)
        node_index, nodes = index_tree(t)
        env_counts = count_envs(env_str.split('\n'))
 
        count_array, unique_envs, env_to_index, node_to_index = \
            index_envs(env_counts, node_index)    
 
        branch_lengths = get_branch_lengths(node_index)
        #test just the AB pair
        ab_counts = count_array[:, 0:2]
        bindings = bind_to_array(nodes, ab_counts)
        changes = fitch_descendants(bindings, counter=FitchCounter)
        self.assertEqual(changes, 1)
        orig_result = ab_counts.copy()
        #check that the original Fitch counter gives the expected 
        #incorrect parsimony result
        changes = fitch_descendants(bindings, counter=FitchCounterDense)
        self.assertEqual(changes, 5)
        new_result = ab_counts.copy()
        #check that the two versions fill the array with the same values
        self.assertEqual(orig_result, new_result)
 
    def test_tip_distances(self):
        """tip_distances should set tips to correct distances."""
        t = self.t
        bl = self.branch_lengths.copy()[:,newaxis]
        bindings = bind_to_parent_array(t, bl)
        tips = []
        for n in t.traverse(self_before=False, self_after=True):
            if not n.Children:
                tips.append(n._leaf_index)
        tip_distances(bl, bindings, tips)
        self.assertEqual(bl, array([5,6,6,6,6,0,0,0,0])[:,newaxis])
 
    def test_permute_selected_rows(self):
        """permute_selected_rows should switch just the selected rows in a"""
        orig = reshape(arange(8),(4,2))
        new = orig.copy()
        fake_permutation = lambda a: range(a)[::-1] #reverse order
        permute_selected_rows([0,2], orig, new, fake_permutation)
        self.assertEqual(new,  array([[4,5],[2,3],[0,1],[6,7]]))
        #make sure we didn't change orig
        self.assertEqual(orig, reshape(arange(8), (4,2)))
 
    def test_prep_items_for_jackknife(self):
        """prep_items_for_jackknife should expand indices of repeated counts"""
        a = array([0,1,0,1,2,0,3])
        #          0 1 2 3 4 5 6
        result = prep_items_for_jackknife(a)
        exp = array([1,3,4,4,6,6,6])
        self.assertEqual(result, exp)
 
    def test_jackknife_bool(self):
        """jackknife_bool should make a vector with right number of nonzeros"""
        fake_permutation = lambda a: range(a)[::-1] #reverse order
        orig_vec = array([0,0,1,0,1,1,0,1,1])
        orig_items = flatnonzero(orig_vec)
        length = len(orig_vec)
        result = jackknife_bool(orig_items, 3, len(orig_vec), fake_permutation)
        self.assertEqual(result, array([0,0,0,0,0,1,0,1,1]))
        #returns the original if trying to take too many
        self.assertEqual(jackknife_bool(orig_items, 20, len(orig_vec)), \
            orig_vec)
 
    def test_jackknife_int(self):
        """jackknife_int should make a vector with right counts"""
        orig_vec = array([0,2,1,0,3,1])
        orig_items = array([1,1,2,4,4,4,5])
        #                   0 1 2 3 4 5 6
        fake_permutation = lambda a: a == 7 and array([4,6,3,1,2,6,5])
        result = jackknife_int(orig_items, 4, len(orig_vec), fake_permutation)
        self.assertEqual(result, array([0,1,0,0,2,1]))
        #returns the original if trying to take too many
        self.assertEqual(jackknife_int(orig_items, 20, len(orig_vec)), \
            orig_vec)
 
    def test_jackknife_array(self):
        """jackknife_array should make a new array with right counts"""
 
        orig_vec1 = array([0,2,2,3,1])
        orig_vec2 = array([2,2,1,2,2])
        test_array = array([orig_vec1, orig_vec2])
 
        # implement this, just doing by eye now
        #perm_fn = fake_permutation
        perm_fn = permutation
 
        #print "need to test with fake permutation!!"
 
 
        new_mat1 = jackknife_array(test_array, 1, axis=1, jackknife_f=jackknife_int, permutation_f=permutation)  
        self.assertEqual(new_mat1.sum(axis=0), [1,1,1,1,1])
 
        new_mat2 = jackknife_array(test_array, 2, axis=1, jackknife_f=jackknife_int, permutation_f=permutation)  
        self.assertEqual(new_mat2.sum(axis=0), [2,2,2,2,2])
 
        new_mat3 = jackknife_array(test_array, 2, axis=0, jackknife_f=jackknife_int, permutation_f=permutation)  
        self.assertEqual(new_mat3.sum(axis=1), [2,2])
 
        # test that you get orig mat back if too many
        self.assertEqual(jackknife_array(test_array, 20, axis=1), test_array)
 
    def test_unifrac(self):
        """unifrac should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        self.assertEqual(unifrac(bl, m[:,0], m[:,1]), 10/16.0)
        self.assertEqual(unifrac(bl, m[:,0], m[:,2]), 8/13.0)
        self.assertEqual(unifrac(bl, m[:,1], m[:,2]), 8/17.0)
 
    def test_unnormalized_unifrac(self):
        """unnormalized unifrac should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        self.assertEqual(unnormalized_unifrac(bl, m[:,0], m[:,1]), 10/17.)
        self.assertEqual(unnormalized_unifrac(bl, m[:,0], m[:,2]), 8/17.)
        self.assertEqual(unnormalized_unifrac(bl, m[:,1], m[:,2]), 8/17.)
 
    def test_PD(self):
        """PD should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        self.assertEqual(PD(bl, m[:,0]), 7)
        self.assertEqual(PD(bl, m[:,1]), 15)
        self.assertEqual(PD(bl, m[:,2]), 11)
 
    def test_G(self):
        """G should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        self.assertEqual(G(bl, m[:,0], m[:,0]), 0)
        self.assertEqual(G(bl, m[:,0], m[:,1]), 1/16.0)
        self.assertEqual(G(bl, m[:,1], m[:,0]), 9/16.0)
 
    def test_unnormalized_G(self):
        """unnormalized_G should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        self.assertEqual(unnormalized_G(bl, m[:,0], m[:,0]), 0/17.)
        self.assertEqual(unnormalized_G(bl, m[:,0], m[:,1]), 1/17.)
        self.assertEqual(unnormalized_G(bl, m[:,1], m[:,0]), 9/17.)
 
    def test_unifrac_matrix(self):
        """unifrac_matrix should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        result = unifrac_matrix(bl, m)
        self.assertEqual(result, array([[0, 10/16.,8/13.],[10/16.,0,8/17.],\
            [8/13.,8/17.,0]]))
        #should work if we tell it the measure is asymmetric
        result = unifrac_matrix(bl, m, is_symmetric=False)
        self.assertEqual(result, array([[0, 10/16.,8/13.],[10/16.,0,8/17.],\
            [8/13.,8/17.,0]]))
        #should work if the measure really is asymmetric
        result = unifrac_matrix(bl,m,metric=unnormalized_G,is_symmetric=False)
        self.assertEqual(result, array([[0, 1/17.,2/17.],[9/17.,0,6/17.],\
            [6/17.,2/17.,0]]))
        #should also match web site calculations
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        bool_descendants(bound_indices)
        result = unifrac_matrix(bl, envs)
        exp = array([[0, 0.6250, 0.6154], [0.6250, 0, \
            0.4706], [0.6154, 0.4707, 0]])
        assert (abs(result - exp)).max() < 0.001
 
    def test_unifrac_one_sample(self):
        """unifrac_one_sample should match unifrac_matrix"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        result = unifrac_matrix(bl, m)
 
        for i in range(len(result)):
            one_sam_res = unifrac_one_sample(i, bl, m)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)
 
        #should work ok on asymmetric metrics
        result = unifrac_matrix(bl,m,metric=unnormalized_G,is_symmetric=False)
 
        for i in range(len(result)):
            one_sam_res = unifrac_one_sample(i, bl, m, metric=unnormalized_G)
            self.assertEqual(result[i], one_sam_res)
            # only require row for asym
            # self.assertEqual(result[:,i], one_sam_res)
 
    def test_unifrac_vector(self):
        """unifrac_vector should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        result = unifrac_vector(bl, m)
        self.assertFloatEqual(result, array([10./17,6./17,7./17]))
 
    def test_PD_vector(self):
        """PD_vector should return correct results for model tree"""
        m = array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],[1,1,1],\
            [0,1,1],[1,1,1]])
        bl = self.branch_lengths
        result = PD_vector(bl, m)
        self.assertFloatEqual(result, array([7,15,11]))
 
 
    def test_weighted_unifrac_matrix(self):
        """weighted unifrac matrix should ret correct results for model tree"""
        #should match web site calculations
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_matrix(bl, envs, tip_indices)
        exp = array([[0, 9.1, 4.5], [9.1, 0, \
            6.4], [4.5, 6.4, 0]])
        assert (abs(result - exp)).max() < 0.001
        #should work with branch length corrections
        td = bl.copy()[:,newaxis]
        tip_bindings = bind_to_parent_array(self.t, td)
        tips = [n._leaf_index for n in self.t.tips()]
        tip_distances(td, tip_bindings, tips)
        result = weighted_unifrac_matrix(bl, envs, tip_indices, bl_correct=True,
            tip_distances=td)
        exp = array([[0, 9.1/11.5, 4.5/(10.5+1./3)], [9.1/11.5, 0, \
            6.4/(11+1./3)], [4.5/(10.5+1./3), 6.4/(11+1./3), 0]])
        assert (abs(result - exp)).max() < 0.001
 
    def test_weighted_one_sample(self):
        """weighted one sample should match weighted matrix"""
        #should match web site calculations
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_matrix(bl, envs, tip_indices)
        for i in range(len(result)):
            one_sam_res = weighted_one_sample(i, bl, envs, tip_indices)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)
 
        #should work with branch length corrections
        td = bl.copy()[:,newaxis]
        tip_bindings = bind_to_parent_array(self.t, td)
        tips = [n._leaf_index for n in self.t.tips()]
        tip_distances(td, tip_bindings, tips)
        result = weighted_unifrac_matrix(bl, envs, tip_indices, bl_correct=True,
            tip_distances=td)
        for i in range(len(result)):
            one_sam_res = weighted_one_sample(i, bl, envs, tip_indices,
                bl_correct=True, tip_distances=td)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)
 
 
    def test_weighted_unifrac_vector(self):
        """weighted_unifrac_vector should ret correct results for model tree"""
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_vector(bl, envs, tip_indices)
        self.assertFloatEqual(result[0], sum([
            abs(1./2 - 2./8)*1,
            abs(1./2 - 1./8)*2,
            abs(0 - 1./8)*3,
            abs(0 - 3./8)*1,
            abs(0 - 1./8)*1,
            abs(0 - 4./8)*2,
            abs(2./2 - 3./8)*4,
            abs(0. - 5./8)*3.]))
 
        self.assertFloatEqual(result[1], sum([
            abs(0-.6)*1,
            abs(.2-.2)*2,
            abs(.2-0)*3,
            abs(.6-0)*1,
            abs(0-.2)*1,
            abs(.6-.2)*2,
            abs(.2-.8)*4,
            abs(.8-.2)*3]))
 
        self.assertFloatEqual(result[2], sum([
            abs(2./3-1./7)*1,
            abs(0-2./7)*2,
            abs(0-1./7)*3,
            abs(0-3./7)*1,
            abs(1./3-0)*1,
            abs(1./3-3./7)*2,
            abs(2./3-3./7)*4,
            abs(1./3-4./7)*3]))
 
 
if __name__ == '__main__':    #run if called from command-line
    main()