```#!/usr/bin/env python
"""Unit tests for fast tree."""

from cogent.util.unit_test import TestCase, main
from cogent.parse.tree import DndParser
from cogent.maths.unifrac.fast_tree import (count_envs, sum_env_dict,
index_envs, get_branch_lengths, index_tree, bind_to_array,
bind_to_parent_array, _is_parent_empty, delete_empty_parents,
traverse_reduce, bool_descendants, sum_descendants, fitch_descendants,
tip_distances, UniFracTreeNode, FitchCounter, FitchCounterDense,
permute_selected_rows, prep_items_for_jackknife, jackknife_bool,
jackknife_int, unifrac, unnormalized_unifrac, PD, G, unnormalized_G,
unifrac_matrix, unifrac_vector, PD_vector, weighted_unifrac,
weighted_unifrac_matrix, weighted_unifrac_vector, jackknife_array,
env_unique_fraction, unifrac_one_sample, weighted_one_sample)
from numpy import (arange, reshape, zeros, logical_or, array, sum, nonzero,
flatnonzero, newaxis)
from numpy.random import permutation

# Module metadata (authorship, release and maturity information).
__version__ = "1.5.3-dev"
__status__ = "Prototype"
__author__ = "Rob Knight and Micah Hamady"
__maintainer__ = "Rob Knight, Micah Hamady"
__credits__ = ["Rob Knight", "Micah Hamady"]

class fast_tree_tests(TestCase):
    """Tests of top-level functions in cogent.maths.unifrac.fast_tree.

    setUp builds several small trees (t1-t4) plus two tree/environment
    fixtures -- self.t with self.env_str, and self.old_t with
    self.old_env_str -- that are indexed into count arrays and
    branch-length arrays for the UniFrac-style tests.
    """

    def setUp(self):
        """Define a couple of standard trees"""
        self.t1 = DndParser('(((a,b),c),(d,e))', UniFracTreeNode)
        self.t2 = DndParser('(((a,b),(c,d)),(e,f))', UniFracTreeNode)
        self.t3 = DndParser('(((a,b,c),(d)),(e,f))', UniFracTreeNode)
        self.t4 = DndParser('((c)b,((f,g,h)e,i)d)', UniFracTreeNode)
        self.t4.Name = 'a'
        self.t_str = '((a:1,b:2):4,(c:3,(d:1,e:1):2):3)'

        self.t = DndParser(self.t_str, UniFracTreeNode)
        self.env_str = """
a   A   1
a   C   2
b   A   1
b   B   1
c   B   1
d   B   3
e   C   1"""
        self.env_counts = count_envs(self.env_str.splitlines())
        self.node_index, self.nodes = index_tree(self.t)
        (self.count_array, self.unique_envs, self.env_to_index,
            self.node_to_index) = index_envs(self.env_counts, self.node_index)
        self.branch_lengths = get_branch_lengths(self.node_index)

        self.old_t_str = '((org1:0.11,org2:0.22,(org3:0.12,org4:0.23)g:0.33)b:0.2,(org5:0.44,org6:0.55)c:0.3,org7:0.4)'

        self.old_t = DndParser(self.old_t_str, UniFracTreeNode)
        self.old_env_str = """
org1    env1    1
org1    env2    1
org2    env2    1
org3    env2    1
org4    env3    1
org5    env1    1
org6    env1    1
org7    env3    1
"""
        self.old_env_counts = count_envs(self.old_env_str.splitlines())
        self.old_node_index, self.old_nodes = index_tree(self.old_t)
        (self.old_count_array, self.old_unique_envs, self.old_env_to_index,
            self.old_node_to_index) = index_envs(self.old_env_counts,
                                                 self.old_node_index)
        self.old_branch_lengths = get_branch_lengths(self.old_node_index)

    # -- shared fixtures -------------------------------------------------

    def _model_env_matrix(self):
        """Return a fresh 9x3 0/1 matrix used with self.branch_lengths.

        There is one row per entry of self.branch_lengths and one column per
        environment. A new array is built on every call so that no test can
        leak in-place modifications into another.
        """
        return array([[1,0,1],[1,1,0],[0,1,0],[0,0,1],[0,1,0],[0,1,1],
            [1,1,1],[0,1,1],[1,1,1]])

    def _t3_tip_states(self):
        """Return (a, bindings) for self.t3 with tip rows filled in.

        a is an 11x3 array bound to the indexed self.t3. Rows 0,1,2,3,6,7
        get explicit states. Every other row is left at the junk value 99,
        so tests can check that the reduction overwrote it.
        """
        id_, child = index_tree(self.t3)
        a = zeros((11,3)) + 99    #fill with junk
        # bind before filling: the bindings refer to rows of `a`, and the
        # in-place row assignments below must be visible through them
        bindings = bind_to_array(child, a)
        a[0] = a[1] = a[2] = a[7] = [0,1,0]
        a[3] = [1,0,0]
        a[6] = [0,0,1]
        return a, bindings

    @staticmethod
    def _reverse_permutation(n):
        """Deterministic stand-in for a permutation function: n-1, ..., 0."""
        return range(n)[::-1]

    # -- tests -----------------------------------------------------------

    def test_traverse(self):
        """traverse should work iterative or recursive"""
        iterative = self.t4.traverse
        recursive = self.t4.traverse_recursive
        # compare both implementations for every self_before/self_after combo
        for before, after in [(False, False), (True, False),
                              (False, True), (True, True)]:
            obs = [n.Name for n in recursive(self_before=before,
                                             self_after=after)]
            exp = [n.Name for n in iterative(self_before=before,
                                             self_after=after)]
            self.assertEqual(obs, exp)

    def test_count_envs(self):
        """count_envs should return correct counts from lines"""
        envs = """
a   A   3   some other junk
a   B
a   C   1
b   A   2

skip
c   B
d
b   A   99
"""
        result = count_envs(envs.splitlines())
        # missing counts default to 1; lines without an env are skipped;
        # a repeated (taxon, env) pair keeps the last count seen
        self.assertEqual(result,
            {'a':{'A':3,'B':1,'C':1},'b':{'A':99},'c':{'B':1}})

    def test_sum_env_dict(self):
        """sum_env_dict should return correct counts from env_dict"""
        envs = """
a   A   3   some other junk
a   B
a   C   1
b   A   2

skip
c   B
d
b   A   99
"""
        result = count_envs(envs.splitlines())
        sum_ = sum_env_dict(result)
        # 3 + 1 + 1 + 99 + 1
        self.assertEqual(sum_, 105)

    def test_index_envs(self):
        """index_envs should map envs and taxa onto indices"""
        self.assertEqual(self.unique_envs, ['A','B','C'])
        self.assertEqual(self.env_to_index, {'A':0, 'B':1, 'C':2})
        self.assertEqual(self.node_to_index,
            {'a':0, 'b':1, 'c':4, 'd':2, 'e':3})
        self.assertEqual(self.count_array,
            array([[1,0,2],[1,1,0],[0,3,0],[0,0,1],
                [0,1,0],[0,0,0],[0,0,0],[0,0,0],[0,0,0]]))

    def test_get_branch_lengths(self):
        """get_branch_lengths should make array of branch lengths from index"""
        result = get_branch_lengths(self.node_index)
        self.assertEqual(result, array([1,2,1,1,3,2,4,3,0]))

    def test_env_unique_fraction(self):
        """should report unique fraction of bl in each env """
        # model tree (self.t / self.env_str)
        cur_count_array = self.count_array.copy()
        bound_indices = bind_to_array(self.nodes, cur_count_array)
        total_bl = sum(self.branch_lengths)
        bool_descendants(bound_indices)
        env_bl_sums, env_bl_ufracs = env_unique_fraction(self.branch_lengths,
            cur_count_array)
        self.assertEqual(total_bl, 17)
        # env A has 0 unique bl, B has 4, C has 1
        self.assertEqual(env_bl_sums, [0,4,1])
        self.assertFloatEqual(env_bl_ufracs, [0, 4/17.0, 1/17.0])

        # old tree (self.old_t / self.old_env_str)
        cur_count_array = self.old_count_array.copy()
        bound_indices = bind_to_array(self.old_nodes, cur_count_array)
        total_bl = sum(self.old_branch_lengths)
        bool_descendants(bound_indices)
        env_bl_sums, env_bl_ufracs = env_unique_fraction(
            self.old_branch_lengths, cur_count_array)
        self.assertFloatEqual(total_bl, 2.9)
        # env1 has 1.29 unique bl (org5, org6, c), env2 has 0.34
        # (org2, org3), env3 has 0.63 (org4, org7); org1, g and b are shared
        self.assertFloatEqual(env_bl_sums, [1.29, 0.34, 0.63])
        self.assertFloatEqual(env_bl_ufracs, [1.29/2.9, 0.34/2.9, 0.63/2.9])

    def test_index_tree(self):
        """index_tree should produce correct index and node map"""
        #test for first tree: contains singleton outgroup
        t1 = self.t1
        id_1, child_1 = index_tree(t1)
        nodes_1 = [n._leaf_index for n in t1.traverse(self_before=False,
            self_after=True)]
        self.assertEqual(nodes_1, [0,1,2,3,6,4,5,7,8])
        self.assertEqual(child_1, [(2,0,1),(6,2,3),(7,4,5),(8,6,7)])
        #test for second tree: strictly bifurcating
        t2 = self.t2
        id_2, child_2 = index_tree(t2)
        nodes_2 = [n._leaf_index for n in t2.traverse(self_before=False,
            self_after=True)]
        self.assertEqual(nodes_2, [0,1,4,2,3,5,8,6,7,9,10])
        self.assertEqual(child_2, [(4,0,1),(5,2,3),(8,4,5),(9,6,7),(10,8,9)])
        #test for third tree: contains trifurcation and single-child parent
        t3 = self.t3
        id_3, child_3 = index_tree(t3)
        nodes_3 = [n._leaf_index for n in t3.traverse(self_before=False,
            self_after=True)]
        self.assertEqual(nodes_3, [0,1,2,4,3,5,8,6,7,9,10])
        self.assertEqual(child_3, [(4,0,2),(5,3,3),(8,4,5),(9,6,7),(10,8,9)])

    def test_bind_to_array(self):
        """bind_to_array should return correct array ranges"""
        a = reshape(arange(33), (11,3))
        id_, child = index_tree(self.t3)
        bindings = bind_to_array(child, a)
        # (parent row, first child row, one past the last child row)
        expected = [(4, 0, 3), (5, 3, 4), (8, 4, 6), (9, 6, 8), (10, 8, 10)]
        self.assertEqual(len(bindings), len(expected))
        for binding, (parent, start, stop) in zip(bindings, expected):
            self.assertEqual(binding[0], a[parent])
            self.assertEqual(binding[1], a[start:stop])
            self.assertEqual(binding[1].shape, (stop - start, 3))

    def test_bind_to_parent_array(self):
        """bind_to_parent_array should bind tree to array correctly"""
        a = reshape(arange(33), (11,3))
        index_tree(self.t3)
        bindings = bind_to_parent_array(self.t3, a)
        # (child row, parent row) for each binding, in order
        expected = [(8,10),(4,8),(0,4),(1,4),(2,4),(5,8),(3,5),(9,10),
            (6,9),(7,9)]
        self.assertEqual(len(bindings), len(expected))
        for binding, (child_row, parent_row) in zip(bindings, expected):
            self.assertEqual(binding[0], a[child_row])
            self.assertEqual(binding[1], a[parent_row])

    def test_delete_empty_parents(self):
        """delete_empty_parents should remove empty parents from bound indices"""
        id_to_node, node_first_last = index_tree(self.t)
        bound_indices = bind_to_array(node_first_last, self.count_array[:,0:1])
        bool_descendants(bound_indices)
        self.assertEqual(len(bound_indices), 4)
        deleted = delete_empty_parents(bound_indices)
        self.assertEqual(len(deleted), 2)
        for d in deleted:
            self.assertEqual(d[0][0], 1)

    def test_traverse_reduce(self):
        """traverse_reduce should reduce array in traversal order."""
        a, bindings = self._t3_tip_states()
        traverse_reduce(bindings, logical_or.reduce)
        self.assertEqual(a,
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],
                [0,0,1],[0,1,0],[1,1,0],[0,1,1],[1,1,1]]))
        # re-reduce the same array: tip rows are unchanged, internal rows
        # are recomputed as sums of their children
        traverse_reduce(bindings, sum)
        self.assertEqual(a,
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,3,0],[1,0,0],
                [0,0,1],[0,1,0],[1,3,0],[0,1,1],[1,4,1]]))

    def test_bool_descendants(self):
        """bool_descendants should be true if any descendant true"""
        #self.t3 = DndParser('(((a,b,c),(d)),(e,f))', UniFracTreeNode)
        a, bindings = self._t3_tip_states()
        bool_descendants(bindings)
        self.assertEqual(a,
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],
                [0,0,1],[0,1,0],[1,1,0],[0,1,1],[1,1,1]]))

    def test_sum_descendants(self):
        """sum_descendants should sum total descendants w/ each state"""
        a, bindings = self._t3_tip_states()
        sum_descendants(bindings)
        self.assertEqual(a,
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,3,0],[1,0,0],
                [0,0,1],[0,1,0],[1,3,0],[0,1,1],[1,4,1]]))

    def test_fitch_descendants(self):
        """fitch_descendants should assign states by fitch parsimony, ret. #"""
        a, bindings = self._t3_tip_states()
        changes = fitch_descendants(bindings)
        self.assertEqual(changes, 2)
        self.assertEqual(a,
            array([[0,1,0],[0,1,0],[0,1,0],[1,0,0],[0,1,0],[1,0,0],
                [0,0,1],[0,1,0],[1,1,0],[0,1,1],[0,1,0]]))

    def test_fitch_descendants_missing_data(self):
        """fitch_descendants should work with missing data"""
        #tree and envs for testing missing values
        t_str = '(((a:1,b:2):4,(c:3,d:1):2):1,(e:2,f:1):3);'
        env_str = """a   A
b   B
c   D
d   C
e   C
f   D"""
        t = DndParser(t_str, UniFracTreeNode)
        node_index, nodes = index_tree(t)
        env_counts = count_envs(env_str.split('\n'))

        count_array, unique_envs, env_to_index, node_to_index = \
            index_envs(env_counts, node_index)

        #test just the AB pair: only tips a and b carry data in these columns
        ab_counts = count_array[:, 0:2]
        bindings = bind_to_array(nodes, ab_counts)
        changes = fitch_descendants(bindings, counter=FitchCounter)
        self.assertEqual(changes, 1)
        orig_result = ab_counts.copy()
        #FitchCounterDense is expected to report a different change count
        #(5 rather than 1) on this missing-data case; per the original note
        #that count is the known-incorrect parsimony result
        changes = fitch_descendants(bindings, counter=FitchCounterDense)
        self.assertEqual(changes, 5)
        new_result = ab_counts.copy()
        #check that the two versions fill the array with the same values
        self.assertEqual(orig_result, new_result)

    def test_tip_distances(self):
        """tip_distances should set tips to correct distances."""
        t = self.t
        bl = self.branch_lengths.copy()[:,newaxis]
        bindings = bind_to_parent_array(t, bl)
        tips = [n._leaf_index
            for n in t.traverse(self_before=False, self_after=True)
            if not n.Children]
        tip_distances(bl, bindings, tips)
        self.assertEqual(bl, array([5,6,6,6,6,0,0,0,0])[:,newaxis])

    def test_permute_selected_rows(self):
        """permute_selected_rows should switch just the selected rows in a"""
        orig = reshape(arange(8),(4,2))
        new = orig.copy()
        permute_selected_rows([0,2], orig, new, self._reverse_permutation)
        self.assertEqual(new, array([[4,5],[2,3],[0,1],[6,7]]))
        #make sure we didn't change orig
        self.assertEqual(orig, reshape(arange(8), (4,2)))

    def test_prep_items_for_jackknife(self):
        """prep_items_for_jackknife should expand indices of repeated counts"""
        a = array([0,1,0,1,2,0,3])
        #          0 1 2 3 4 5 6
        result = prep_items_for_jackknife(a)
        exp = array([1,3,4,4,6,6,6])
        self.assertEqual(result, exp)

    def test_jackknife_bool(self):
        """jackknife_bool should make a vector with right number of nonzeros"""
        orig_vec = array([0,0,1,0,1,1,0,1,1])
        orig_items = flatnonzero(orig_vec)
        length = len(orig_vec)
        result = jackknife_bool(orig_items, 3, length,
            self._reverse_permutation)
        self.assertEqual(result, array([0,0,0,0,0,1,0,1,1]))
        #returns the original if trying to take too many
        self.assertEqual(jackknife_bool(orig_items, 20, length), orig_vec)

    def test_jackknife_int(self):
        """jackknife_int should make a vector with right counts"""
        orig_vec = array([0,2,1,0,3,1])
        orig_items = array([1,1,2,4,4,4,5])
        #                   0 1 2 3 4 5 6
        # fixed permutation when asked for 7 items (len(orig_items));
        # any other size yields False
        fake_permutation = lambda a: a == 7 and array([4,6,3,1,2,6,5])
        result = jackknife_int(orig_items, 4, len(orig_vec), fake_permutation)
        self.assertEqual(result, array([0,1,0,0,2,1]))
        #returns the original if trying to take too many
        self.assertEqual(jackknife_int(orig_items, 20, len(orig_vec)),
            orig_vec)

    def test_jackknife_array(self):
        """jackknife_array should make a new array with right counts"""
        orig_vec1 = array([0,2,2,3,1])
        orig_vec2 = array([2,2,1,2,2])
        test_array = array([orig_vec1, orig_vec2])

        # NOTE(review): this uses the real (random) permutation, so only the
        # totals along the jackknifed axis are checked. TODO: a deterministic
        # fake permutation would allow checking exact values.
        perm_fn = permutation

        new_mat1 = jackknife_array(test_array, 1, axis=1,
            jackknife_f=jackknife_int, permutation_f=perm_fn)
        self.assertEqual(new_mat1.sum(axis=0), [1,1,1,1,1])

        new_mat2 = jackknife_array(test_array, 2, axis=1,
            jackknife_f=jackknife_int, permutation_f=perm_fn)
        self.assertEqual(new_mat2.sum(axis=0), [2,2,2,2,2])

        new_mat3 = jackknife_array(test_array, 2, axis=0,
            jackknife_f=jackknife_int, permutation_f=perm_fn)
        self.assertEqual(new_mat3.sum(axis=1), [2,2])

        # test that you get orig mat back if too many
        self.assertEqual(jackknife_array(test_array, 20, axis=1), test_array)

    def test_unifrac(self):
        """unifrac should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        self.assertEqual(unifrac(bl, m[:,0], m[:,1]), 10/16.0)
        self.assertEqual(unifrac(bl, m[:,0], m[:,2]), 8/13.0)
        self.assertEqual(unifrac(bl, m[:,1], m[:,2]), 8/17.0)

    def test_unnormalized_unifrac(self):
        """unnormalized unifrac should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        self.assertEqual(unnormalized_unifrac(bl, m[:,0], m[:,1]), 10/17.)
        self.assertEqual(unnormalized_unifrac(bl, m[:,0], m[:,2]), 8/17.)
        self.assertEqual(unnormalized_unifrac(bl, m[:,1], m[:,2]), 8/17.)

    def test_PD(self):
        """PD should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        self.assertEqual(PD(bl, m[:,0]), 7)
        self.assertEqual(PD(bl, m[:,1]), 15)
        self.assertEqual(PD(bl, m[:,2]), 11)

    def test_G(self):
        """G should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        self.assertEqual(G(bl, m[:,0], m[:,0]), 0)
        # G is asymmetric: swapping the arguments changes the result
        self.assertEqual(G(bl, m[:,0], m[:,1]), 1/16.0)
        self.assertEqual(G(bl, m[:,1], m[:,0]), 9/16.0)

    def test_unnormalized_G(self):
        """unnormalized_G should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        self.assertEqual(unnormalized_G(bl, m[:,0], m[:,0]), 0/17.)
        self.assertEqual(unnormalized_G(bl, m[:,0], m[:,1]), 1/17.)
        self.assertEqual(unnormalized_G(bl, m[:,1], m[:,0]), 9/17.)

    def test_unifrac_matrix(self):
        """unifrac_matrix should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        result = unifrac_matrix(bl, m)
        self.assertEqual(result, array([[0, 10/16.,8/13.],[10/16.,0,8/17.],
            [8/13.,8/17.,0]]))
        #should work if we tell it the measure is asymmetric
        result = unifrac_matrix(bl, m, is_symmetric=False)
        self.assertEqual(result, array([[0, 10/16.,8/13.],[10/16.,0,8/17.],
            [8/13.,8/17.,0]]))
        #should work if the measure really is asymmetric
        result = unifrac_matrix(bl,m,metric=unnormalized_G,is_symmetric=False)
        self.assertEqual(result, array([[0, 1/17.,2/17.],[9/17.,0,6/17.],
            [6/17.,2/17.,0]]))
        #should also match web site calculations (8/17 = 0.4706)
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        bool_descendants(bound_indices)
        result = unifrac_matrix(bl, envs)
        exp = array([[0, 0.6250, 0.6154], [0.6250, 0, 0.4706],
            [0.6154, 0.4706, 0]])
        assert (abs(result - exp)).max() < 0.001

    def test_unifrac_one_sample(self):
        """unifrac_one_sample should match unifrac_matrix"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        result = unifrac_matrix(bl, m)

        # symmetric metric: the one-sample vector equals both row and column
        for i in range(len(result)):
            one_sam_res = unifrac_one_sample(i, bl, m)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)

        #should work ok on asymmetric metrics
        result = unifrac_matrix(bl,m,metric=unnormalized_G,is_symmetric=False)

        # only require row for asym
        for i in range(len(result)):
            one_sam_res = unifrac_one_sample(i, bl, m, metric=unnormalized_G)
            self.assertEqual(result[i], one_sam_res)

    def test_unifrac_vector(self):
        """unifrac_vector should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        result = unifrac_vector(bl, m)
        self.assertFloatEqual(result, array([10./17,6./17,7./17]))

    def test_PD_vector(self):
        """PD_vector should return correct results for model tree"""
        m = self._model_env_matrix()
        bl = self.branch_lengths
        result = PD_vector(bl, m)
        self.assertFloatEqual(result, array([7,15,11]))

    def test_weighted_unifrac_matrix(self):
        """weighted unifrac matrix should ret correct results for model tree"""
        #should match web site calculations
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_matrix(bl, envs, tip_indices)
        exp = array([[0, 9.1, 4.5], [9.1, 0, 6.4], [4.5, 6.4, 0]])
        assert (abs(result - exp)).max() < 0.001
        #should work with branch length corrections
        td = bl.copy()[:,newaxis]
        tip_bindings = bind_to_parent_array(self.t, td)
        tips = [n._leaf_index for n in self.t.tips()]
        tip_distances(td, tip_bindings, tips)
        result = weighted_unifrac_matrix(bl, envs, tip_indices, bl_correct=True,
            tip_distances=td)
        exp = array([[0, 9.1/11.5, 4.5/(10.5+1./3)],
            [9.1/11.5, 0, 6.4/(11+1./3)],
            [4.5/(10.5+1./3), 6.4/(11+1./3), 0]])
        assert (abs(result - exp)).max() < 0.001

    def test_weighted_one_sample(self):
        """weighted one sample should match weighted matrix"""
        #should match web site calculations
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_matrix(bl, envs, tip_indices)
        for i in range(len(result)):
            one_sam_res = weighted_one_sample(i, bl, envs, tip_indices)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)

        #should work with branch length corrections
        td = bl.copy()[:,newaxis]
        tip_bindings = bind_to_parent_array(self.t, td)
        tips = [n._leaf_index for n in self.t.tips()]
        tip_distances(td, tip_bindings, tips)
        result = weighted_unifrac_matrix(bl, envs, tip_indices, bl_correct=True,
            tip_distances=td)
        for i in range(len(result)):
            one_sam_res = weighted_one_sample(i, bl, envs, tip_indices,
                bl_correct=True, tip_distances=td)
            self.assertEqual(result[i], one_sam_res)
            self.assertEqual(result[:,i], one_sam_res)

    def test_weighted_unifrac_vector(self):
        """weighted_unifrac_vector should ret correct results for model tree"""
        envs = self.count_array
        bound_indices = bind_to_array(self.nodes, envs)
        sum_descendants(bound_indices)
        bl = self.branch_lengths
        tip_indices = [n._leaf_index for n in self.t.tips()]
        result = weighted_unifrac_vector(bl, envs, tip_indices)
        # each term: |env fraction - overall fraction| * branch length
        self.assertFloatEqual(result[0], sum([
            abs(1./2 - 2./8)*1,
            abs(1./2 - 1./8)*2,
            abs(0 - 1./8)*3,
            abs(0 - 3./8)*1,
            abs(0 - 1./8)*1,
            abs(0 - 4./8)*2,
            abs(2./2 - 3./8)*4,
            abs(0. - 5./8)*3.]))

        self.assertFloatEqual(result[1], sum([
            abs(0-.6)*1,
            abs(.2-.2)*2,
            abs(.2-0)*3,
            abs(.6-0)*1,
            abs(0-.2)*1,
            abs(.6-.2)*2,
            abs(.2-.8)*4,
            abs(.8-.2)*3]))

        self.assertFloatEqual(result[2], sum([
            abs(2./3-1./7)*1,
            abs(0-2./7)*2,
            abs(0-1./7)*3,
            abs(0-3./7)*1,
            abs(1./3-0)*1,
            abs(1./3-3./7)*2,
            abs(2./3-3./7)*4,
            abs(1./3-4./7)*3]))

if __name__ == '__main__':
    # Executed as a script: hand control to the cogent unit_test runner.
    main()

```