#!/usr/bin/env python
"""Unit tests for fast unifrac."""
from __future__ import division
from numpy import array, logical_not, argsort
from cogent.util.unit_test import TestCase, main
from cogent.parse.tree import DndParser
from cogent.maths.unifrac.fast_tree import (count_envs, index_tree, index_envs,
    get_branch_lengths)
from cogent.maths.unifrac.fast_unifrac import (reshape_by_name,
    meta_unifrac, shuffle_tipnames, weight_equally, weight_by_num_tips,
    weight_by_branch_length, weight_by_num_seqs, get_all_env_names,
    consolidate_skipping_missing_matrices, consolidate_missing_zero,
    consolidate_missing_one, consolidate_skipping_missing_values,
    UniFracTreeNode, mcarlo_sig, num_comps, fast_unifrac,
    fast_unifrac_whole_tree, PD_whole_tree, PD_generic_whole_tree,
    TEST_ON_TREE, TEST_ON_ENVS, TEST_ON_PAIRWISE, shared_branch_length,
    shared_branch_length_to_root, fast_unifrac_one_sample)
from numpy.random import permutation

__author__ = "Rob Knight and Micah Hamady"
__copyright__ = "Copyright 2007-2012, The Cogent Project"
__credits__ = ["Rob Knight", "Micah Hamady", "Daniel McDonald",
               "Justin Kuczynski"]
__license__ = "GPL"
__version__ = "1.5.3-dev"
__maintainer__ = "Rob Knight, Micah Hamady"
__email__ = "rob@spot.colorado.edu, hamady@colorado.edu"
__status__ = "Prototype"


class unifrac_tests(TestCase):
    """Tests of top-level functions.

    Environment tables are written as triple-quoted text with one
    whitespace-separated ``tip env count`` record per line and are fed to
    count_envs() via ``str.splitlines()``.
    """

    def setUp(self):
        """Define some standard trees."""
        # Primary five-tip tree and the environments its tips occur in.
        self.t_str = '((a:1,b:2):4,(c:3,(d:1,e:1):2):3)'
        self.t = DndParser(self.t_str, UniFracTreeNode)
        self.env_str = """
        a A 1
        a C 2
        b A 1
        b B 1
        c B 1
        d B 3
        e C 1"""
        self.env_counts = count_envs(self.env_str.splitlines())

        # Only tips a and e are listed; environment B never appears.
        self.missing_env_str = """
        a A 1
        a C 2
        e C 1"""
        self.missing_env_counts = count_envs(self.missing_env_str.splitlines())

        # Every tip except e is absent from self.t.
        self.extra_tip_str = """
        q A 1
        w C 2
        e A 1
        r B 1
        t B 1
        y B 3
        u C 1"""
        self.extra_tip_counts = count_envs(self.extra_tip_str.splitlines())

        # No listed tip is present in self.t.
        self.wrong_tip_str = """
        q A 1
        w C 2
        r B 1
        t B 1
        y B 3
        u C 1"""
        self.wrong_tip_counts = count_envs(self.wrong_tip_str.splitlines())

        # Second tree/environment pair, used by the weighting and
        # meta_unifrac tests.
        self.t2_str = '(((a:1,b:1):1,c:5):2,d:4)'
        self.t2 = DndParser(self.t2_str, UniFracTreeNode)
        self.env2_str = """
        a B 1
        b A 1
        c A 2
        c C 2
        d B 1
        d C 1"""
        self.env2_counts = count_envs(self.env2_str.splitlines())

        self.trees = [self.t, self.t2]
        self.envs = [self.env_counts, self.env2_counts]
        self.mc_1 = array([.5, .4, .3, .2, .1, .6, .7, .8, .9, 1.0])

        # from old EnvsNode tests
        self.old_t_str = '((org1:0.11,org2:0.22,(org3:0.12,org4:0.23)g:0.33)b:0.2,(org5:0.44,org6:0.55)c:0.3,org7:0.4)'
        self.old_t = DndParser(self.old_t_str, UniFracTreeNode)
        self.old_env_str = """
        org1 env1 1
        org1 env2 1
        org2 env2 1
        org3 env2 1
        org4 env3 1
        org5 env1 1
        org6 env1 1
        org7 env3 1
        """
        self.old_env_counts = count_envs(self.old_env_str.splitlines())
        self.old_node_index, self.old_nodes = index_tree(self.old_t)
        (self.old_count_array, self.old_unique_envs, self.old_env_to_index,
         self.old_node_to_index) = index_envs(self.old_env_counts,
                                              self.old_node_index)
        self.old_branch_lengths = get_branch_lengths(self.old_node_index)

    def test_shared_branch_length(self):
        """Should return the correct shared branch length by env"""
        t_str = "(((a:1,b:2):3,c:4),(d:5,e:6,f:7):8);"
        envs = """
        a A 1
        b A 1
        c A 1
        d A 1
        e A 1
        f B 1
        """
        env_counts = count_envs(envs.splitlines())
        t = DndParser(t_str, UniFracTreeNode)

        # third argument 1: results keyed by single-environment tuples
        exp = {('A',): 21.0, ('B',): 7.0}
        obs = shared_branch_length(t, env_counts, 1)
        self.assertEqual(obs, exp)

        # third argument 2: results keyed by environment pairs
        exp = {('A', 'B'): 8.0}
        obs = shared_branch_length(t, env_counts, 2)
        self.assertEqual(obs, exp)

        # only two environments exist, so asking for 3 must be rejected
        self.assertRaises(ValueError, shared_branch_length, t, env_counts, 3)

    def test_shared_branch_length_to_root(self):
        """Should return the correct shared branch length by env to root"""
        t_str = "(((a:1,b:2):3,c:4),(d:5,e:6,f:7):8);"
        envs = """
        a A 1
        b A 1
        c A 1
        d A 1
        e A 1
        f B 1
        """
        env_counts = count_envs(envs.splitlines())
        t = DndParser(t_str, UniFracTreeNode)
        exp = {'A': 29.0, 'B': 15.0}
        obs = shared_branch_length_to_root(t, env_counts)
        self.assertEqual(obs, exp)

    def test_fast_unifrac(self):
        """Should calc unifrac values for whole tree."""
        #Note: results not tested for correctness here as detailed tests
        #in fast_tree module.
        # These three calls only need to complete without raising.
        fast_unifrac(self.t, self.env_counts)
        fast_unifrac(self.t, self.missing_env_counts)
        fast_unifrac(self.t, self.extra_tip_counts)
        # No overlap at all between env tips and tree tips is an error.
        self.assertRaises(ValueError, fast_unifrac, self.t,
                          self.wrong_tip_counts)

    def test_fast_unifrac_one_sample(self):
        """fu one sample should match whole unifrac result, for env 'B'"""
        # first get full unifrac matrix
        res = fast_unifrac(self.t, self.env_counts)
        dmtx, env_order = res['distance_matrix']
        dmtx_vec = dmtx[env_order.index('B')]
        # sort by env name so both vectors are in the same order
        dmtx_vec = dmtx_vec[argsort(env_order)]

        # then get one sample unifrac vector
        one_sam_dvec, one_sam_env_order = \
            fast_unifrac_one_sample('B', self.t, self.env_counts)
        one_sam_dvec = one_sam_dvec[argsort(one_sam_env_order)]
        self.assertFloatEqual(one_sam_dvec, dmtx_vec)

    def test_fast_unifrac_one_sample2(self):
        """fu one sam should match whole weighted unifrac result, for env 'B'"""
        # first get full unifrac matrix
        res = fast_unifrac(self.t, self.env_counts, weighted=True)
        dmtx, env_order = res['distance_matrix']
        dmtx_vec = dmtx[env_order.index('B')]
        # sort by env name so both vectors are in the same order
        dmtx_vec = dmtx_vec[argsort(env_order)]

        # then get one sample unifrac vector
        one_sam_dvec, one_sam_env_order = \
            fast_unifrac_one_sample('B', self.t, self.env_counts,
                                    weighted=True)
        one_sam_dvec = one_sam_dvec[argsort(one_sam_env_order)]
        self.assertFloatEqual(one_sam_dvec, dmtx_vec)

    def test_fast_unifrac_one_sample3(self):
        """fu one sam should match missing env unifrac result, for env 'C'

        Env 'B' does not occur in the missing-env table, so requesting it
        must raise ValueError.
        """
        # first get full unifrac matrix
        res = fast_unifrac(self.t, self.missing_env_counts, weighted=False)
        dmtx, env_order = res['distance_matrix']
        dmtx_vec = dmtx[env_order.index('C')]
        # sort by env name so both vectors are in the same order
        dmtx_vec = dmtx_vec[argsort(env_order)]

        # then get one sample unifrac vector
        one_sam_dvec, one_sam_env_order = \
            fast_unifrac_one_sample('C', self.t, self.missing_env_counts,
                                    weighted=False)
        one_sam_dvec = one_sam_dvec[argsort(one_sam_env_order)]
        self.assertFloatEqual(one_sam_dvec, dmtx_vec)

        # and should raise valueerror when 'B'
        self.assertRaises(ValueError, fast_unifrac_one_sample, 'B', self.t,
                          self.missing_env_counts, weighted=False)

    def test_fast_unifrac_whole_tree(self):
        """should correctly compute one p-val for whole tree"""
        # should test with fake permutation but
        # using same as old envs node for now
        result = []
        num_to_do = 10
        for _ in range(num_to_do):
            real_ufracs, sim_ufracs = fast_unifrac_whole_tree(
                self.old_t, self.old_env_counts, 1000,
                permutation_f=permutation)
            sim_totals = [sum(x) for x in sim_ufracs]
            rawp, corp = mcarlo_sig(sum(real_ufracs), sim_totals, 1,
                                    tail='high')
            result.append(rawp)
        # stochastic: compares the mean of the raw p-values against 0.047
        self.assertSimilarMeans(result, 0.047)

    def test_unifrac_explicit(self):
        """unifrac should correctly compute correct values.

        environment M contains only tips not in tree, tip j is in no envs
        values were calculated by hand
        """
        t1 = DndParser('((a:1,b:2):4,((c:3, j:17),(d:1,e:1):2):3)',
                       UniFracTreeNode)
        # note c,j is len 0 node
        #           /-------- /-a
        # ---------|          \-b
        #          |          /-------- /-c
        #           \--------|          \-j
        #                     \-------- /-d
        #                               \-e
        env_str = """
        a A 1
        a C 2
        b A 1
        b B 1
        c B 1
        d B 3
        e C 1
        m M 88"""
        env_counts = count_envs(env_str.splitlines())
        # true division is in effect (from __future__ import division)
        expected = (array([[0, 10/16, 8/13],
                           [10/16, 0, 8/17],
                           [8/13, 8/17, 0]]),
                    ['A', 'B', 'C'])
        self.assertFloatEqual(fast_unifrac(t1, env_counts)['distance_matrix'],
                              expected)

        # changing tree topology relative to c,j tips shouldn't change
        # anything (here one unit of length moves from the c and j tips
        # onto their parent node)
        t2 = DndParser('((a:1,b:2):4,((c:2, j:16):1,(d:1,e:1):2):3)',
                       UniFracTreeNode)
        self.assertFloatEqual(fast_unifrac(t2, env_counts)['distance_matrix'],
                              expected)

    def test_unifrac_make_subtree(self):
        """unifrac result should not depend on make_subtree

        environment M contains only tips not in tree, tip j, k is in no envs
        one clade is missing entirely
        values were calculated by hand
        we also test that we still have a valid tree at the end
        """
        t1 = DndParser('((a:1,b:2):4,((c:3, (j:1,k:2)mt:17),(d:1,e:1):2):3)',
                       UniFracTreeNode)
        # note the node joining c and mt has no branch length
        #           /-------- /-a
        # ---------|          \-b
        #          |          /-------- /-c
        #           \--------|          \mt------ /-j
        #                    |                    \-k
        #                     \-------- /-d
        #                               \-e
        #
        env_str = """
        a A 1
        a C 2
        b A 1
        b B 1
        c B 1
        d B 3
        e C 1
        m M 88"""
        env_counts = count_envs(env_str.splitlines())
        # true division is in effect (from __future__ import division)
        expected = (array([[0, 10/16, 8/13],
                           [10/16, 0, 8/17],
                           [8/13, 8/17, 0]]),
                    ['A', 'B', 'C'])

        self.assertFloatEqual(
            fast_unifrac(t1, env_counts,
                         make_subtree=False)['distance_matrix'],
            expected)
        self.assertFloatEqual(
            fast_unifrac(t1, env_counts,
                         make_subtree=True)['distance_matrix'],
            expected)

        # changing tree topology relative to c,j tips shouldn't change anything
        t2 = DndParser(
            '((a:1,b:2):4,((c:2, (j:1,k:2)mt:17):1,(d:1,e:1):2):3)',
            UniFracTreeNode)
        self.assertFloatEqual(
            fast_unifrac(t2, env_counts,
                         make_subtree=False)['distance_matrix'],
            expected)
        self.assertFloatEqual(
            fast_unifrac(t2, env_counts,
                         make_subtree=True)['distance_matrix'],
            expected)

        # ensure we haven't meaningfully changed the tree
        # by passing it to unifrac: t3 is a fresh parse of t1's string
        t3 = DndParser('((a:1,b:2):4,((c:3, (j:1,k:2)mt:17),(d:1,e:1):2):3)',
                       UniFracTreeNode)
        t1_tips = sorted(tip.Name for tip in t1.tips())
        t3_tips = sorted(tip.Name for tip in t3.tips())
        self.assertEqual(t1_tips, t3_tips)

        tipj3 = t3.getNodeMatchingName('j')
        tipb3 = t3.getNodeMatchingName('b')
        tipj1 = t1.getNodeMatchingName('j')
        tipb1 = t1.getNodeMatchingName('b')
        self.assertFloatEqual(tipj1.distance(tipb1), tipj3.distance(tipb3))

    def test_PD_whole_tree(self):
        """PD_whole_tree should correctly compute PD for test tree.

        environment M contains only tips not in tree, tip j is in no envs
        """
        t1 = DndParser('((a:1,b:2):4,((c:3, j:17),(d:1,e:1):2):3)',
                       UniFracTreeNode)
        env_str = """
        a A 1
        a C 2
        b A 1
        b B 1
        c B 1
        d B 3
        e C 1
        m M 88"""
        env_counts = count_envs(env_str.splitlines())
        self.assertEqual(PD_whole_tree(t1, env_counts),
                         (['A', 'B', 'C'], array([7., 15., 11.])))

    def test_PD_generic_whole_tree(self):
        """PD_generic_whole_tree should correctly compute PD for test tree."""
        # Use locals only: the test data is self-contained and the setUp
        # fixtures are left untouched.
        t1 = DndParser('((a:1,b:2):4,(c:3,(d:1,e:1):2):3)',
                       UniFracTreeNode)
        env_str = """
        a A 1
        a C 2
        b A 1
        b B 1
        c B 1
        d B 3
        e C 1"""
        env_counts = count_envs(env_str.splitlines())
        self.assertEqual(PD_generic_whole_tree(t1, env_counts),
                         (['A', 'B', 'C'], array([7., 15., 11.])))

    def test_mcarlo_sig(self):
        """test_mcarlo_sig should calculate monte carlo sig high/low"""
        # each expected value is a (raw, corrected) pair; the third argument
        # is the number of comparisons
        self.assertEqual(mcarlo_sig(.5, self.mc_1, 1, 'high'),
                         (5.0/10, 5.0/10))
        self.assertEqual(mcarlo_sig(.5, self.mc_1, 1, 'low'),
                         (4.0/10, 4.0/10))
        self.assertEqual(mcarlo_sig(.5, self.mc_1, 5, 'high'),
                         (5.0/10, 1.0))
        self.assertEqual(mcarlo_sig(.5, self.mc_1, 5, 'low'),
                         (4.0/10, 1.0))
        # values outside the simulated range give a "<=" string bound
        self.assertEqual(mcarlo_sig(0, self.mc_1, 1, 'low'),
                         (0.0, "<=%.1e" % (1.0/10)))
        self.assertEqual(mcarlo_sig(100, self.mc_1, 10, 'high'),
                         (0.0, "<=%.1e" % (1.0/10)))

    def test_num_comps(self):
        """num_comps(n) should equal 1 + 2 + ... + (n-1)"""
        for n in (5, 15, 10000, 1833):
            self.assertEqual(num_comps(n), sum(range(1, n)))

    def test_shuffle_tipnames(self):
        """shuffle_tipnames should return copy of tree w/ labels permuted"""
        #Note: this should never fail but is technically still stochastic
        #5! is 120 so repeating 5 times should fail about 1 in 10^10.
        for _ in range(5):
            try:
                t = DndParser(self.t_str)
                result = shuffle_tipnames(t)
                orig_names = [n.Name for n in t.tips()]
                new_names = [n.Name for n in result.tips()]
                # NOTE(review): presumably assertIsPermutation also fails
                # when the order is unchanged -- confirm in
                # cogent.util.unit_test.
                self.assertIsPermutation(orig_names, new_names)
                return
            except AssertionError:
                continue
        # call form of raise is valid on both Python 2 and Python 3
        raise AssertionError("Produced same permutation in 5 tries: broken?")

    def test_weight_equally(self):
        """weight_equally should return unit weight per tree"""
        self.assertEqual(weight_equally(self.trees, self.envs),
                         array([1, 1]))

    def test_weight_by_num_tips(self):
        """weight_by_num_tips should return tips per tree"""
        self.assertEqual(weight_by_num_tips(self.trees, self.envs),
                         array([5, 4]))

    def test_weight_by_branch_length(self):
        """weight_by_branch_length should return branch length per tree"""
        self.assertEqual(weight_by_branch_length(self.trees, self.envs),
                         array([17, 14]))

    def test_weight_by_num_seqs(self):
        """weight_by_num_seqs should return num seqs per tree"""
        self.assertEqual(weight_by_num_seqs(self.trees, self.envs),
                         array([10, 8]))

    def test_get_all_env_names(self):
        """get_all_env_names should get all names from counts"""
        self.assertEqual(get_all_env_names(self.env_counts), set('ABC'))

    def test_consolidate_skipping_missing_matrices(self):
        """consolidate_skipping_missing_matrices should skip those missing data"""
        m1 = array([[1, 2], [3, 4]])
        m2 = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        m3 = array([[2, 2, 2], [3, 3, 3], [4, 4, 4]])
        matrices = [m1, m2, m3]
        # a real list (not a map object) so it can be indexed and re-read
        env_names = [list(names) for names in ['AB', 'ABC', 'ABC']]
        weights = [1, 2, 3]
        all_names = list('ABC')
        result = consolidate_skipping_missing_matrices(matrices, env_names,
                                                       weights, all_names)
        # m1 lacks env C and is dropped; weights 2 and 3 renormalise to .4/.6
        self.assertFloatEqual(result, .4*m2 + .6*m3)

    def test_consolidate_missing_zero(self):
        """consolidate_missing_zero should fill missing values to zero"""
        m1 = array([[1, 2], [3, 4]])
        m2 = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        m3 = array([[2, 2, 2], [3, 3, 3], [4, 4, 4]])
        matrices = [m1, m2, m3]
        # a real list (not a map object) so it can be indexed and re-read
        env_names = [list(names) for names in ['AB', 'ABC', 'ABC']]
        weights = array([1, 2, 3], float)
        weights /= weights.sum()
        all_names = list('ABC')
        # m1 padded out to 3x3 with zeros for the missing env C
        transformed_m1 = array([[1, 2, 0], [3, 4, 0], [0, 0, 0]])
        result = consolidate_missing_zero(matrices, env_names, weights,
                                          all_names)
        self.assertFloatEqual(
            result, (1/6.)*transformed_m1 + (2/6.)*m2 + (3/6.)*m3)

    def test_consolidate_missing_one(self):
        """consolidate_missing_one should fill missing off-diags to one"""
        m1 = array([[1, 2], [3, 4]])
        m2 = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        m3 = array([[2, 2, 2], [3, 3, 3], [4, 4, 4]])
        matrices = [m1, m2, m3]
        # a real list (not a map object) so it can be indexed and re-read
        env_names = [list(names) for names in ['AB', 'ABC', 'ABC']]
        weights = array([1, 2, 3], float)
        weights /= weights.sum()
        all_names = list('ABC')
        # m1 padded out to 3x3: ones off the diagonal, zero on it, for env C
        transformed_m1 = array([[1, 2, 1], [3, 4, 1], [1, 1, 0]])
        result = consolidate_missing_one(matrices, env_names, weights,
                                         all_names)
        self.assertFloatEqual(
            result, (1/6.)*transformed_m1 + (2/6.)*m2 + (3/6.)*m3)

    def test_consolidate_skipping_missing_values(self):
        """consolidate_skipping_missing_values should average over filled values"""
        m1 = array([[1, 2], [3, 4]])
        m2 = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        m3 = array([[2, 2, 2], [3, 3, 3], [4, 4, 4]])
        matrices = [m1, m2, m3]
        # a real list (not a map object) so it can be indexed and re-read
        env_names = [list(names) for names in ['AB', 'ABC', 'ABC']]
        weights = array([1., 2, 3])
        weights /= weights.sum()
        all_names = list('ABC')
        # cells present in all three matrices use weights 1/6, 2/6, 3/6;
        # cells involving env C (absent from m1) use 2/5 and 3/5
        expected = array([[1/6.*1 + 2/6.*1 + 3/6.*2,
                           1/6.*2 + 2/6.*2 + 3/6.*2,
                           2/5.*3 + 3/5.*2],
                          [1/6.*3 + 2/6.*4 + 3/6.*3,
                           1/6.*4 + 2/6.*5 + 3/6.*3,
                           2/5.*6 + 3/5.*3],
                          [2/5.*7 + 3/5.*4,
                           2/5.*8 + 3/5.*4,
                           2/5.*9 + 3/5.*4]])
        result = consolidate_skipping_missing_values(matrices, env_names,
                                                     weights, all_names)
        self.assertFloatEqual(result, expected)

    def test_reshape_by_name(self):
        """reshape_by_name should reshape matrix from old to new names"""
        old = array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        old_names = 'ABC'
        new_names = 'xCyBA'
        # rows/cols for the unknown names x and y are all zero
        exp = array([[0, 0, 0, 0, 0],
                     [0, 8, 0, 7, 6],
                     [0, 0, 0, 0, 0],
                     [0, 5, 0, 4, 3],
                     [0, 2, 0, 1, 0]])
        self.assertEqual(reshape_by_name(old, old_names, new_names), exp)

        # masked variant: zeroing the masked cells must give the same matrix
        result = reshape_by_name(old, old_names, new_names, masked=True)
        result.fill_value = 0
        self.assertEqual(result._data * logical_not(result._mask), exp)

    def test_meta_unifrac(self):
        """meta_unifrac should give correct result on sample trees"""
        tree_list = [self.t, self.t2]
        envs_list = [self.env_counts, self.env2_counts]
        result = meta_unifrac(tree_list, envs_list, weight_equally,
                              modes=["distance_matrix"])
        u1_distances = array([[0, 10/16., 8/13.],
                              [10/16., 0, 8/17.],
                              [8/13., 8/17., 0]])
        u2_distances = array([[0, 11/14., 6/13.],
                              [11/14., 0, 7/13.],
                              [6/13., 7/13., 0]])
        # equal weighting: the plain mean of the two per-tree matrices
        exp = (u1_distances + u2_distances)/2
        self.assertFloatEqual(result['distance_matrix'], (exp, list('ABC')))


if __name__ == '__main__':
    main()