```"""
TreeTools.py - tools for dealing with domain trees.
====================================================

:Author:
:Release: \$Id\$
:Date: |today|
:Tags: Python

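A tree is a list of nodes, in which node 0 is the root.

Each node is a list of the form:

[level, parent, left_child, right_child, [ranges]]

A child index of 0 means 'no child'. An object-oriented representation
would have been cleaner, but was considered overkill.

Besides these domain-tree helpers, the module also provides utilities
for Bio.Nexus phylogenetic trees (newick/nexus conversion, rerooting,
pruning and gene/species tree reconciliation).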

"""

import Bio
from Bio.Nexus.Nexus import Nexus
from Bio.Nexus.Trees import Tree

import Tree
import sys, string, re, StringIO

from types import *

import Intervalls

##--------------------------------------------------------
def SetChildren(tree):
    """sets children correctly.

    Recomputes the left_child / right_child fields (indices 2 and 3) of
    every node from the parent fields (index 1), in place. For each parent,
    the first child found in list order becomes its left child and the next
    one its right child; 0 means 'no child'. The root (node 0) is never
    registered as a child. With more than two children, the right child is
    the last one encountered.
    """
    ## clear all children first, so that the result does not depend on
    ## whether a parent is stored before or after its children
    for node in tree:
        node[2] = 0
        node[3] = 0

    for node_id in range(1, len(tree)):
        parent = tree[node_id][1]
        if tree[parent][2]:
            tree[parent][3] = node_id
        else:
            tree[parent][2] = node_id

##--------------------------------------------------------
def CollapseTree(tree):
    """remove all empty levels in the tree and renumber nodes
    so that they are continuous.

    A non-root node is dropped when its parent has only one child or when
    its own ranges are empty; nodes below a dropped node are attached to
    the dropped node's replacement. Surviving nodes are renumbered in list
    order, and the root stays node 0.

    Side effect: calls SetChildren(tree), which rewrites the child fields
    of the input tree in place. The returned tree has all child fields 0.
    """
    SetChildren(tree)

    collapsed = [[0, 0, 0, 0, tree[0][4]]]
    old2new = {0: 0}

    for old_id in range(1, len(tree)):
        _level, parent, _left, _right, ranges = tree[old_id]
        parent_node = tree[parent]

        only_child = parent_node[2] == 0 or parent_node[3] == 0
        if only_child or len(ranges) == 0:
            # fold this node into whatever its parent became
            old2new[old_id] = old2new[parent]
            continue

        new_parent = old2new[parent]
        old2new[old_id] = len(collapsed)
        collapsed.append(
            [collapsed[new_parent][0] + 1, new_parent, 0, 0, ranges])

    return collapsed

##--------------------------------------------------------
def RemoveEmptyNodes(tree):
    """remove all empty nodes from the tree and
    renumber nodes, so that they are continuous.

    The root (node 0) is always kept. A non-root node without ranges is
    dropped, and its descendants are re-attached to the dropped node's
    replacement. Levels are recomputed from the new parents, and all child
    fields of the returned tree are 0. The input tree is not modified.
    """
    pruned = [[0, 0, 0, 0, tree[0][4]]]
    old2new = {0: 0}

    for old_id, node in enumerate(tree[1:], 1):
        _level, parent, _left, _right, ranges = node
        new_parent = old2new[parent]
        if ranges:
            old2new[old_id] = len(pruned)
            pruned.append([pruned[new_parent][0] + 1, new_parent, 0, 0, ranges])
        else:
            old2new[old_id] = new_parent

    return pruned

##--------------------------------------------------------
def PrintTree(tree):
    """print tree: one line per node on stdout.

    Each line is the node index, a tab, then the node's fields converted
    with str() and separated by tabs.
    """
    for index, node in enumerate(tree):
        print("%i\t" % index + "\t".join(map(str, node)))

##--------------------------------------------------------
def RemoveSmallDomains(tree, min_segment=10):
    """remove small ranges from the tree.

    Every range of a non-root node covering fewer than min_segment
    positions (xto - xfrom + 1 < min_segment) is moved to the node's
    sibling. The sibling's ranges are then merged with
    Intervalls.CombineIntervallsDistance(..., 1). Nodes left without ranges
    are finally removed with RemoveEmptyNodes, whose result is returned.

    The input tree is modified in place.

    Note: this breaks down, if both children are below
    the minimum segment size!
    """
    for node in range(1, len(tree)):
        parent = tree[node][1]
        ranges = tree[node][4]

        # the sibling that receives small segments; invariant per node
        if tree[parent][2] == node:
            other_child = tree[parent][3]
        else:
            other_child = tree[parent][2]

        new_ranges = []
        for xfrom, xto in ranges:
            if xto - xfrom + 1 >= min_segment:
                new_ranges.append((xfrom, xto))
            else:
                # NOTE(review): if the node has no sibling, other_child is 0
                # and the segment is moved to the root.
                tree[other_child][4].append((xfrom, xto))
                ## combine intervalls now, otherwise the segments will be
                ## moved back and forth
                tree[other_child][4] = Intervalls.CombineIntervallsDistance(
                    tree[other_child][4], 1)

        tree[node][4] = new_ranges

    return RemoveEmptyNodes(tree)

##--------------------------------------------------------
def TruncateTree(tree, min_segment=10):
    """truncate tree, if a segment is less than a certain size.

    A non-root node is truncated when it has no ranges, or when any of its
    ranges satisfies xto - xfrom <= min_segment. Truncation empties the
    ranges of the node and of its direct children (as recorded in the
    node's child fields). Nodes are visited in index order, so an emptied
    child listed later is itself truncated when reached, which presumably
    cascades down the tree if children follow their parents.

    The input tree is modified in place; the result of RemoveEmptyNodes is
    returned.

    NOTE(review): the size test differs from RemoveSmallDomains, which
    keeps a range when xto - xfrom + 1 >= min_segment.
    """
    for node_id in range(1, len(tree)):
        _level, _parent, left_child, right_child, ranges = tree[node_id]

        short_segments = [(xfrom, xto) for xfrom, xto in ranges
                          if xto - xfrom <= min_segment]
        if ranges and not short_segments:
            continue

        # clear this node and its direct children
        tree[node_id][4] = []
        for child in (left_child, right_child):
            if child:
                tree[child][4] = []

    return RemoveEmptyNodes(tree)

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Newick2Nexus(infile):
    """convert newick formatted tree(s) into a nexus object.

    infile may be an open file (anything with a readlines method), a list
    or tuple of lines, or a single string containing the tree(s).

    Multiple trees are separated by a semicolon. Tree names can
    be given by fasta-style separators, i.e., lines starting with
    '>'. Lines starting with '#' are ignored. Unnamed trees are called
    'tree<n>', and '=' in names is replaced by '_'. A final tree without a
    trailing semicolon is accepted.

    If the token [&&NHX is found in the tree, it is assumed to be
    output from njtree and support values are added. Support values are
    taken from the 'B=' field of each NHX comment and written in front of
    the branch length (':support:branchlength'); the NHX comments
    themselves are removed.

    Raises ValueError if no tree is found.
    """
    lines = ["#NEXUS\nBegin trees;\n"]

    ## build one line per tree
    if hasattr(infile, "readlines"):
        tlines = infile.readlines()
    elif isinstance(infile, (tuple, list)):
        tlines = infile
    else:
        tlines = [infile, ]

    def add_tree(tree_id, parts, ntrees):
        """append a nexus 'tree' statement built from the collected parts."""
        if not parts:
            return
        if not tree_id:
            tree_id = "tree%i" % ntrees
        tree_id = re.sub("=", "_", tree_id)

        s = "".join(parts)
        # drop the terminating ';' - the last tree of the input may lack it
        if s.endswith(";"):
            s = s[:-1]

        if s.find("[&&NHX") >= 0:
            ## process njtree trees with bootstrap values
            fragments = []
            last = 0
            for x in re.finditer(r"(:[-0-9.]+\[&&NHX[^\]]*\])", s):
                fragments.append(s[last:x.start()])
                frag = s[x.start():x.end()]
                bl = ":%s" % re.search(":([-0-9.]+)", frag).groups()[0]

                rx = re.search("B=([-0-9.]+)", frag)
                if rx:
                    support = ":%s" % rx.groups()[0]
                else:
                    support = ""

                fragments.append("%s%s" % (support, bl))
                last = x.end()

            fragments.append(s[last:])
            s = "".join(fragments)
            s = re.sub(r"\[&&NHX[^\]]*\]", "", s)

        lines.append("tree '%s' = %s;\n" % (tree_id, s))

    parts = []
    tree_id = None
    ntrees = 0

    for line in tlines:
        line = line.strip()
        if not line:
            continue
        if line[0] == "#":
            continue
        if line[0] == ">":
            # a new header ends a pending tree that lacked its ';'
            if parts:
                add_tree(tree_id, parts, ntrees)
                parts = []
                ntrees += 1
            tree_id = line[1:]
            continue

        line = re.sub(r"\s", "", line).strip()
        parts.append(line)
        if line[-1] == ";":
            add_tree(tree_id, parts, ntrees)
            parts, tree_id = [], None
            ntrees += 1

    ## treat special case of trees without trailing semicolon
    add_tree(tree_id, parts, ntrees)

    lines.append("End;")

    ## previously, a string was ok, now a string
    ## is interpreted as a filename.
    nexus = Nexus(StringIO.StringIO("".join(lines)))

    if len(nexus.trees) == 0:
        raise ValueError("no tree found in file %s" % str(infile))

    ## remove starting/ending ' from name
    for tree in nexus.trees:
        tree.name = tree.name[1:-1]

    Tree.updateNexus(nexus)

    return nexus

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Nexus2Newick(nexus, with_branchlengths=True, with_names=False,
                 write_all_taxa=False):
    """Return all trees of a nexus object as newick text.

    Trees are separated by newlines. If with_names is set, each tree is
    preceded by a fasta-style '>name' line. with_branchlengths and
    write_all_taxa are forwarded to Tree2Newick.
    """
    def render(tree):
        header = [">%s" % tree.name] if with_names else []
        return header + [Tree2Newick(tree,
                                     with_branch_lengths=with_branchlengths,
                                     write_all_taxa=write_all_taxa)]

    return "\n".join(part for tree in nexus.trees for part in render(tree))

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Tree2Newick(tree, with_branch_lengths=True, write_all_taxa=False):
    """convert tree to newick format.

    tree.to_string() yields a nexus statement of the form
    'tree <name> = <newick>'. Everything after the first '=' is returned,
    with surrounding whitespace stripped.
    """
    s = tree.to_string(branchlengths_only=with_branch_lengths,
                       write_all_taxa=write_all_taxa)
    # split only on the first '=': the newick part itself may contain '='
    # (e.g. in quoted taxon names or comments)
    return s.split("=", 1)[1].strip()

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Newick2Tree(txt):
    """convert newick formatted text to a tree.

    txt is passed unchanged to Newick2Nexus (a string, a list of lines or
    an open file). Only the first of the parsed trees is returned.
    """
    return Newick2Nexus(txt).trees[0]

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def WriteNexus(nexus, **kwargs):
    """write nexus file format.

    Returns (does not write) a string containing a nexus 'trees' block
    with every tree of nexus. kwargs are passed on to each tree's
    to_string().
    """
    lines = ["#NEXUS\nBegin trees;\n"]
    for t in nexus.trees:
        lines.append("%s\n" % t.to_string(**kwargs))
    lines.append("\nEnd;\n")
    return "\n".join(lines)

##--------------------------------------------------------
def GetTaxa(tree):
    """Return the taxon of every terminal node, in get_terminals() order."""
    return [tree.node(leaf).get_data().taxon for leaf in tree.get_terminals()]

##--------------------------------------------------------
def GetTaxonomicNames(tree):
    """Alias of GetTaxa: the taxon names at the leaves of tree."""
    names = GetTaxa(tree)
    return names

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def MapTaxa(tree, map_old2new, remove_unknown=False):
    """map taxa in tree.

    Every node in tree.chain (terminal or internal) that carries a taxon is
    renamed via map_old2new. Taxa missing from the map keep their name, or
    are pruned from the tree afterwards when remove_unknown is set.

    NOTE(review): a second MapTaxa(tree, mapping) is defined later in this
    module and replaces this definition at import time.
    """
    missing = []
    for _node_id, node in tree.chain.items():
        taxon = node.data.taxon
        if not taxon:
            continue
        try:
            node.data.taxon = map_old2new[taxon]
        except KeyError:
            missing.append(taxon)

    if not remove_unknown:
        return
    for taxon in missing:
        tree.prune(taxon)

##--------------------------------------------------------
def Branchlength2Support(tree):
    """Copy each node's data.branchlength into data.support.

    The branch lengths themselves are left unchanged (they are not reset to
    0.0). This is needed when support values have been stored as branch
    lengths (e.g. by paup) and have thus been read in as branch lengths.
    """
    for node_id in tree.chain:
        data = tree.node(node_id).data
        data.support = data.branchlength

##-------------------------------------------------------------------------
def Species2Genes(nexus, map_species2genes):
    """convert a species tree to a gene tree.

    Modifies every tree of nexus in place. For each leaf whose taxon is a
    key of map_species2genes:

    * with exactly one gene, the leaf is simply renamed to that gene;
    * otherwise a new leaf (with default NodeData) is attached below the
      former leaf for each gene, and the former leaf's taxon is cleared
      (set to None).

    Leaves of species not in the map are left unchanged.
    """
    for tree in nexus.trees:
        # get_terminals() is evaluated once, so new leaves are not revisited
        for nx in tree.get_terminals():
            data = tree.node(nx).get_data()
            if data.taxon not in map_species2genes:
                continue
            genes = map_species2genes[data.taxon]
            if len(genes) == 1:
                data.taxon = genes[0]
                continue
            for g in genes:
                d = Bio.Nexus.Trees.NodeData(taxon=g)
                tree.add(Bio.Nexus.Nodes.Node(d), nx)
            data.taxon = None

##-------------------------------------------------------------------------
def Genes2Species(nexus, map_gene2species):
    """convert a gene tree into a species tree.

    In every tree of nexus, each leaf whose taxon is a key of
    map_gene2species is renamed to the mapped species, in place. Other
    leaves are left unchanged.
    """
    for tree in nexus.trees:
        for leaf_id in tree.get_terminals():
            leaf_data = tree.node(leaf_id).get_data()
            gene = leaf_data.taxon
            if gene not in map_gene2species:
                continue
            leaf_data.taxon = map_gene2species[gene]

##-------------------------------------------------------------------------
def BuildMapSpecies2Genes(genes, pattern_species="^([^|]+)[|]"):
    """build a map of species to genes (and back) from gene identifiers.

    The species of a gene is the first group matched by pattern_species
    (by default everything before the first '|').

    Returns (map_species2genes, map_gene2species): species -> list of its
    genes in input order, and gene -> species.
    """
    species_rx = re.compile(pattern_species)
    species2genes, gene2species = {}, {}
    for gene in genes:
        sp = species_rx.search(gene).groups()[0]
        species2genes.setdefault(sp, []).append(gene)
        gene2species[gene] = sp
    return species2genes, gene2species

##-------------------------------------------------------------------------
def GetMonophyleticPairs(tree):
    """build list of monophyletic pairs in tree.

    Returns a list of (node_id_x, node_id_y, taxon_x, taxon_y) tuples, one
    for each pair of leaves (x before y in tree.get_terminals()) for which
    tree.is_monophyletic((taxon_x, taxon_y)) is not -1. Each pair is
    reported once.
    """
    leaves = tree.get_terminals()
    taxa = [tree.node(leaf).get_data().taxon for leaf in leaves]

    pairs = []
    for x in range(len(leaves) - 1):
        for y in range(x + 1, len(leaves)):
            if tree.is_monophyletic((taxa[x], taxa[y])) != -1:
                pairs.append((leaves[x], leaves[y], taxa[x], taxa[y]))

    return pairs

##-------------------------------------------------------------------------
def GetTaxaForSpecies(tree, species, pattern_species="^([^|]+)[|]"):
    """get all taxa of a given species.

    The species of a taxon is the first group matched by pattern_species.
    Taxa are returned in GetTaxa order.
    """
    species_rx = re.compile(pattern_species)
    return [taxon for taxon in GetTaxa(tree)
            if species_rx.search(taxon).groups()[0] == species]
##-------------------------------------------------------------------------
def IsMonophyleticForSpecies(tree,
                             species,
                             pattern_species="^([^|]+)[|]"):
    """check if a tree is monophyletic for a species.

    Side effect: the tree is rerooted with the species' taxa as outgroup.
    """
    members = GetTaxaForSpecies(tree, species, pattern_species)
    tree.root_with_outgroup(members)
    clade = tree.is_monophyletic(members)
    return clade != -1

##-------------------------------------------------------------------------
def IsMonophyleticForTaxa(tree,
                          taxa,
                          support=None):
    """check if a tree is monophyletic for a list of taxa.

    The tree is first rerooted with taxa as outgroup (in place). If support
    is given (and non-zero), the support of the first child of the new
    root must also be at least support.
    """
    tree.root_with_outgroup(taxa)
    clade = tree.is_monophyletic(taxa)
    if clade == -1:
        return False
    if not support:
        return True
    first_child = tree.node(tree.root).succ[0]
    return tree.node(first_child).data.support >= support

##-------------------------------------------------------------------------
def GetLeaves(tree, node):
    """get leaves in tree below node.

    Calls tree.root_with_outgroup((node,)) first, then returns
    tree.get_taxa(node).

    NOTE(review): node is a node id, while root_with_outgroup is used with
    taxon names elsewhere in this module; confirm the reroot has the
    intended effect.
    """
    outgroup = (node,)
    tree.root_with_outgroup(outgroup)
    leaves = tree.get_taxa(node)
    return leaves

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def IsSingleSpecies(tree, node, pattern_species="^([^|]+)[|]"):
    """check if list of taxa below node contains the same species.

    Uses GetLeaves (which may reroot the tree). The species of each leaf
    is the first group matched by pattern_species.
    """
    species_rx = re.compile(pattern_species)
    found = set()
    for taxon in GetLeaves(tree, node):
        found.add(species_rx.search(taxon).groups()[0])
    return len(found) == 1

##-------------------------------------------------------------------------
def CountDuplications(tree, species,
                      pattern_species="^([^|]+)[|]"):
    """count the number duplications for a given species.

    Do not check for monophyly versus species.

    Returns one tuple per pair of taxa of species:
    (taxon_x, taxon_y, is_single_species, is_monophyletic, distance).
    Side effect: the tree is rerooted repeatedly.
    """
    taxa = GetTaxaForSpecies(tree, species, pattern_species)
    ids = [tree.search_taxon(taxon) for taxon in taxa]

    result = []
    for i in range(len(ids) - 1):
        for j in range(i + 1, len(ids)):
            ancestor = tree.common_ancestor(ids[i], ids[j])
            single = IsSingleSpecies(tree, ancestor, pattern_species)
            pair = (taxa[i], taxa[j])
            tree.root_with_outgroup(pair)
            mono = tree.is_monophyletic(pair) != -1
            result.append((taxa[i], taxa[j], single, mono,
                           tree.distance(ids[i], ids[j])))
    return result

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Transcript2GeneTree(tree,
                        map_transcript2gene,
                        map_gene2transcripts):
    """convert a transcript tree into a gene tree.

    NOT IMPLEMENTED: always raises NotImplementedError.

    supply a map for mapping transcripts to genes.

    The procedure for converting a transcript tree into a gene tree:

    If there are two genes, and they are monophyletic, no matter how many
    transcripts, the order is as follows:

    1 Merge all nodes into two, one for each gene.

    2 The distance between the genes is the minimum distance observed between
    two transcripts from different genes. Half of this will be set as the
    branch length from the gene leaves.

    If this is not possible for a set of genes, the procedure will fail and not
    return a gene tree.
    """
    # TODO: implement the procedure above. An earlier unfinished draft
    # mapped transcripts to genes with MapTaxa and sorted the leaves by
    # taxon, but never produced a result.
    raise NotImplementedError()

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def MapTaxa(tree, mapping, remove_unknown=False):
    """map taxon names of the terminal nodes in tree.

    Each leaf whose taxon is a key of mapping is renamed. Other leaves keep
    their name, unless remove_unknown is set, in which case they are pruned
    from the tree after renaming.

    This definition replaces an earlier MapTaxa(tree, map_old2new,
    remove_unknown=False) in this module. The optional remove_unknown keeps
    calls written against that signature working.
    """
    unknown = []
    for nx in tree.get_terminals():
        data = tree.node(nx).get_data()
        if data.taxon in mapping:
            data.taxon = mapping[data.taxon]
        elif data.taxon:
            unknown.append(data.taxon)

    if remove_unknown:
        for taxon in unknown:
            tree.prune(taxon)

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetCommonAncestor(tree, taxa):
    """retrieve common ancestor for a list of taxa.

    Reroots tree (in place) with taxa as outgroup. If taxa are not
    monophyletic, -1 is returned. Otherwise the first node on the path
    from the root to the leaf of taxa[0] is returned, i.e.
    tree.trace(tree.root, leaf)[0].
    """
    tree.root_with_outgroup(taxa)
    if tree.is_monophyletic(taxa) == -1:
        return -1

    leaf = tree.search_taxon(taxa[0])
    # presumably trace() excludes its start node (as in Bio.Nexus), so this
    # is the root's child on the way to the outgroup clade
    return tree.trace(tree.root, leaf)[0]

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Nop(x):
    """Ignore the argument and return True.

    Default callback for TreeDFS; as descend_condition it means
    'always descend'.
    """
    return True


def TreeDFS(tree, node_id,
            pre_function=Nop,
            descend_condition=Nop,
            post_function=Nop):
    """Depth-first tree traversal starting at node_id.

    Apply functions pre_function at first and
    post_function at last visit of a node.

    descend_condition(node_id) is evaluated once for every node that has
    children; if it returns a false value, the children of that node are
    not visited. It is not called for leaves.
    """
    pre_function(node_id)
    children = tree.node(node_id).succ
    if children and descend_condition(node_id):
        for child in children:
            TreeDFS(tree, child, pre_function, descend_condition, post_function)
    post_function(node_id)

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetMaxIndex(tree):
    """Return tree.id, the tree's node id counter (last id assigned).

    Used in this module to size lists indexed by node id.
    """
    max_index = tree.id
    return max_index

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetNumChildren(tree):
    """Count the leaves below every node.

    Returns a list indexed by node id, of length GetMaxIndex(tree) + 1. A
    leaf counts 1, and an internal node the sum over its children. Ids not
    reached from tree.root stay 0.
    """
    leaf_counts = [0] * (GetMaxIndex(tree) + 1)

    def accumulate(node_id):
        children = tree.node(node_id).succ
        if children == []:
            leaf_counts[node_id] = 1
            return
        leaf_counts[node_id] += sum(leaf_counts[c] for c in children)

    # post-order, so children are counted before their parent
    TreeDFS(tree, tree.root, post_function=accumulate)
    return leaf_counts

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetBranchLengths(tree):
    """return an array with minimum and maximum branch length.

    For every node id, the minimum and maximum summed branch length from
    that node down to any leaf below it (adding data.branchlength of each
    child on the way). Leaves, and ids not reached from the root, hold 0.

    Returns (min_sums, max_sums), both indexed by node id.
    """
    size = GetMaxIndex(tree) + 1
    max_sums = [0] * size
    min_sums = [0] * size

    def update(node_id):
        children = tree.node(node_id).succ
        if children == []:
            return
        depths = [(min_sums[c] + tree.node(c).data.branchlength,
                   max_sums[c] + tree.node(c).data.branchlength)
                  for c in children]
        min_sums[node_id] = min(lo for lo, _hi in depths)
        max_sums[node_id] = max(hi for _lo, hi in depths)

    TreeDFS(tree, tree.root, post_function=update)

    return min_sums, max_sums

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Reroot( tree, taxa ):
    """reroot tree with taxa - the list of
    taxa does not need to be monophyletic.

    First marks which subtrees contain any of taxa, then picks an outgroup
    subtree and reroots tree in place with root_with_outgroup.

    * If every child of the current root contains some of taxa, subtrees
      without any taxa are candidates; the one with the most leaves (see
      GetNumChildren) becomes the outgroup, and a root child other than
      that outgroup is returned.
    * Otherwise the search follows the taxa down and marks nodes all of
      whose children contain taxa; the marked subtree with the most leaves
      becomes the outgroup, and its node id is returned.

    If no candidate is found, the tree is left unchanged and tree.root is
    returned.
    """

    nnodes = GetMaxIndex(tree) + 1

    # within_taxa[i]: does the subtree below node i contain any of taxa?
    within_taxa = [False] * nnodes

    def update_taxa( node_id ):
        n = tree.node(node_id)

        s = n.succ
        if s == []:
            within_taxa[node_id] = n.data.taxon in taxa
        else:
            for ss in s:
                within_taxa[node_id] |= within_taxa[ss]

    # post-order: children are evaluated before their parent
    TreeDFS( tree, tree.root,
             post_function = update_taxa)

    ## go down the tree - get largest branch
    ## in each subtree left/right from root
    ## that contain all matches to taxa
    ## Check if root spans taxa
    all_true = True
    for n in tree.node(tree.root).succ:
        all_true &= within_taxa[n]

    ## if root spans taxa get largest subtrees with no matches to taxa
    if all_true:
        def has_taxa( node_id ):
            # mark a subtree without any taxa and do not descend into it
            if not within_taxa[node_id]:
                extra_subtree[node_id] = True
                return False
            else:
                return True
    else:
        ## if root does not span taxa, get smallest subtree including
        ## all taxa
        def has_taxa( node_id ):
            """check if all children contain taxa."""
            if not within_taxa[node_id]: return False
            # descend while some child has no taxa; stop and mark the node
            # once every child contains taxa
            for n in tree.node(node_id).succ:
                if not within_taxa[n]:
                    return True
            extra_subtree[node_id] = True
            return False

    # do the search
    # (extra_subtree is bound only here; has_taxa looks it up when TreeDFS
    # calls it, so the late definition is fine)
    extra_subtree = [False] * nnodes
    TreeDFS( tree, tree.root,
             descend_condition = has_taxa )

    # Python 2: filter returns a list, which is sorted in place below
    nodes = filter( lambda x: extra_subtree[x], range(nnodes))

    if len(nodes) == 0:
        ## no rerooting, if all or no taxa within tree
        return tree.root
    elif len(nodes) > 1:
        ## several candidates: order by number of leaves, largest first
        nchildren = GetNumChildren( tree )
        nodes.sort( lambda x,y: cmp(nchildren[x],nchildren[y]))
        nodes.reverse()

    subtree_node = nodes[0]

    ## todo: write rerooting function from node
    # the parameter taxa is reused here for the outgroup's taxa
    taxa = tree.get_taxa( subtree_node )
    result = tree.root_with_outgroup( taxa )

    if result == -1:
        # NOTE(review): TreeError is not imported in the visible part of
        # this module; unless it is defined elsewhere, this raises NameError.
        raise TreeError('error while rerooting tree')

    ## return node_id with all taxa
    if all_true:
        # NOTE(review): falls through and returns None if the root has no
        # child other than subtree_node.
        for n in tree.node(tree.root).succ:
            if n != subtree_node:
                return n
    else:
        return subtree_node

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetSubsets(tree, node=None, with_decoration=True):
    """return subsets below a certain node including
    their height (distance from leaves) and branchlength

    The subsets are listed in post-order, one per node, starting at node
    (default: tree.root).

    With decoration, each entry is (taxa, height, branchlength). A leaf
    yields ((taxon,), 0, branchlength). An internal node's taxa is a list,
    and its height is the mean over its children of
    (child height + child branchlength).

    Without decoration, each entry is just the list of taxa.
    """
    if node is None:
        node = tree.root
    current = tree.node(node)
    data = current.data

    if not with_decoration:
        if current.succ == []:
            return [[data.taxon, ]]
        collected, members = [], []
        for child in current.succ:
            below = GetSubsets(tree, child, with_decoration=False)
            members += below[-1]
            collected += below
        return collected + [members]

    if current.succ == []:
        return [((data.taxon,), 0, data.branchlength,)]

    members, height = [], 0
    branchlength = data.branchlength
    collected = []
    for child in current.succ:
        below = GetSubsets(tree, child)
        subset, child_height, child_length = below[-1]
        members += subset
        height += child_height + child_length
        collected += below
    height /= len(current.succ)

    return collected + [(members, height, branchlength)]

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def CountBranchPoints(tree, taxa):
    """count the number branch points together with their
    distances for a given list of taxa.

    The tree is rerooted with taxa as outgroup. Returns None if taxa are
    not monophyletic; otherwise returns GetSubsets() below the common
    ancestor found by GetCommonAncestor (which reroots again).
    """
    tree.root_with_outgroup(taxa)

    if tree.is_monophyletic(taxa) == -1:
        return None

    ## retrieve all subsets with their branchlengths
    return GetSubsets(tree, GetCommonAncestor(tree, taxa))

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def IsCompatible(tree1, tree2):
    """check if two trees are compatible.

    Returns (flag, reason). reason is "leaves" if the numbers of terminals
    differ, "taxa" if a taxon of tree1 is missing from tree2, "topology" if
    tree1.is_compatible(tree2, 0) returns a true value, and "" if the trees
    are compatible.

    note: this will delete support information - unless the leaf counts
    differ, every support value of both trees is set to 1.0.
    """
    if len(tree1.get_terminals()) != len(tree2.get_terminals()):
        return False, "leaves"

    names1 = GetTaxonomicNames(tree1)
    names2 = GetTaxonomicNames(tree2)

    for t in (tree2, tree1):
        for node_id in t.chain.keys():
            t.node(node_id).data.support = 1.0

    if any(name not in names2 for name in names1):
        return False, "taxa"

    if tree1.is_compatible(tree2, 0):
        return False, "topology"
    return True, ""

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Tree2Graph(tree):
    """return tree as a list of edges in a graph.

    Each edge is a tuple (node_id, parent_id, branchlength), one for every
    node that has a parent, i.e. edges point from child to parent. Note
    that Graph2Tree expects links pointing from parent to child.
    """
    links = []
    for node_id1, node1 in tree.chain.items():
        if node1.prev is not None:
            links.append((node_id1, node1.prev, node1.get_data().branchlength))
    return links

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Graph2Tree(links, label_ancestral_nodes=False):
    """build tree from list of links.

    links is a list of (parent, child, branchlength) tuples, where parent
    and child are arbitrary (hashable) node labels. Assumption is that
    links always point from parent to child. The parent of the first link
    is placed on the root node created by Tree.Tree().

    Terminal nodes (all nodes, if label_ancestral_nodes is set) get their
    label as taxon. The node without a parent becomes tree.root. An empty
    links list yields the freshly constructed tree.
    """
    tree = Tree.Tree()
    if not links:
        return tree

    ## map of labels to node ids in tree, and back
    first = links[0][0]
    map_node2id = {first: tree.root}
    map_id2node = {tree.root: first}

    def get_id(label):
        """return the node id for label, creating the node if necessary."""
        if label in map_node2id:
            return map_node2id[label]
        node_id = len(tree.chain)
        node = Bio.Nexus.Nodes.Node(Bio.Nexus.Trees.NodeData())
        node.id = node_id
        tree.chain[node_id] = node
        map_node2id[label] = node_id
        map_id2node[node_id] = label
        return node_id

    for parent, child, branchlength in links:
        p = get_id(parent)
        c = get_id(child)
        tree.chain[p].succ.append(c)
        tree.chain[c].prev = p
        tree.chain[c].data.branchlength = branchlength

    ## set taxon names for children and find root
    for i, n in tree.chain.items():
        if n.prev is None:
            tree.root = i
        if n.succ == [] or label_ancestral_nodes:
            n.data.taxon = map_id2node[i]

    ## set pointer to last id
    tree.id = len(tree.chain) - 1

    return tree

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetAllNodes(tree):
    """Return the ids of all nodes in tree.chain."""
    node_ids = tree.chain.keys()
    return node_ids

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetDistancesBetweenTaxa(tree, taxa1, taxa2):
    """get pairwise distances between the leaves of taxa1 and taxa2.

    Returns a list of (taxon1, taxon2, distance) tuples, one for every pair
    of a leaf with a taxon in taxa1 and a leaf with a taxon in taxa2, where
    distance = tree.distance(leaf1, leaf2). No averaging is done. A taxon
    present in both lists is only used on the taxa1 side.
    """
    ## get sets with terminal nodes in taxa
    a, b = [], []
    for x in tree.get_terminals():
        t = tree.node(x).data.taxon
        if t in taxa1:
            a.append((t, x))
        elif t in taxa2:
            b.append((t, x))

    return [(ta, tb, tree.distance(aa, bb)) for ta, aa in a for tb, bb in b]

##-------------------------------------------------------------------------
def PruneTerminal(tree, taxon):
    """Prunes a terminal taxon from the tree.

    Delegates to tree.prune(taxon) and returns its result. That result is
    presumably the id of the node preceding the pruned leaf, as in
    Bio.Nexus.Trees; confirm against the Tree class in use.

    As described for Bio.Nexus.Trees: if taxon is from a bifurcation, the
    connecting node will be collapsed and its branchlength added to the
    remaining terminal node. This might no longer be a meaningful value.
    """
    return tree.prune(taxon)

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def add_children(old_tree, new_tree, old_id, new_id):
    """Recursively copy the descendants of old_id in old_tree below new_id
    in new_tree.

    The new nodes share their data objects with the original nodes.
    """
    for child in old_tree.node(old_id).succ:
        duplicate = Bio.Nexus.Nodes.Node(old_tree.node(child).data)
        copied_id = new_tree.add(duplicate, new_id)
        add_children(old_tree, new_tree, child, copied_id)

def GetSubtree(tree, node_id):
    """return a copy of tree from node_id downwards.

    The copy takes weight, rooted and name from tree, and node node_id
    becomes its root. Node data objects are shared with tree, not
    duplicated.
    """
    subtree = Tree.Tree(weight=tree.weight,
                        rooted=tree.rooted,
                        name=tree.name)

    ## automatically adds a root, so substitute it
    n = Bio.Nexus.Nodes.Node(tree.node(node_id).data)
    n.id = subtree.root
    subtree.chain[subtree.root] = n

    ## copy everything below node_id
    add_children(tree, subtree, node_id, subtree.root)

    return subtree

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def Unroot(tree):
    """unroot tree.

    If the root has a single child, the root is removed and that child
    becomes the new root. A bifurcating root is then turned into a
    multifurcation: a root child with more than one child of its own is
    collapsed, and the other root child gets the summed branch length of
    both root children. Nothing is collapsed if both root children are
    terminal, or if the root does not have exactly two children. In that
    case the tree is treated as already unrooted.

    tree.rooted is set to False when a child was collapsed.
    """
    root_node = tree.node(tree.root)

    ## check if root has a single child - if so, remove this root
    if len(root_node.succ) == 1:
        new_root = root_node.succ[0]
        tree.kill(tree.root)
        tree.root = new_root
        root_node = tree.node(tree.root)
        # the new root must not keep pointing to the deleted node
        root_node.prev = None

    ## only a bifurcating root can be removed
    if len(root_node.succ) != 2:
        return

    left, right = root_node.succ
    if len(tree.node(left).succ) == 0 and len(tree.node(right).succ) == 0:
        return

    ## calculate branch length along branch that has
    ## been split by root. This value needs to be assigned
    ## to the remaining childs branchlength.
    n = sum(tree.node(x).data.branchlength for x in root_node.succ)

    if len(tree.node(left).succ) > 1:
        x, y = left, right
    else:
        x, y = right, left

    tree.collapse(x)
    tree.node(y).data.branchlength = n
    tree.rooted = False

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetSize(tree):
    """return the length of the tree, i.e. the maximum node id + 1.

    Useful to size containers indexed by node id during tree traversal.
    """
    largest_id = max(tree.chain.keys())
    return largest_id + 1

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def PruneTree(tree, taxa, keep_distance_to_root=False):
    """prune tree: keep only those taxa in list.

    Every leaf whose taxon is not in taxa is pruned, in place. If the root
    is then left with a single child, that child becomes the new root and
    the old root is removed. The new root's branch length is reset to 0.0
    unless keep_distance_to_root is set.
    """
    for leaf_id in tree.get_terminals():
        name = tree.node(leaf_id).get_data().taxon
        if name in taxa:
            continue
        tree.prune(name)

    old_root = tree.root
    old_root_node = tree.node(tree.root)

    ## if one complete side of the root has been removed, collapse it
    if len(old_root_node.succ) != 1:
        return

    # tree.collapse() does not work on the root, so promote the child by hand
    new_root = old_root_node.succ[0]
    new_root_node = tree.node(new_root)
    new_root_node.prev = None
    if not keep_distance_to_root:
        new_root_node.data.branchlength = 0.0
    tree.root = new_root
    tree.kill(old_root)

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
def GetNodeMap(tree1, tree2):
    """map nodes between tree1 and tree2.

    Both trees must be identical (tree1.is_identical(tree2)). For each node
    id i of tree1, entry i of the result is the id of the node of tree2
    spanning the same taxa (tree2.is_monophyletic(tree1.get_taxa(i))). The
    list is sized by the largest node id, so unused ids map to 0.

    Raises ValueError if the trees are not identical, or if a clade of tree1
    is not monophyletic in tree2.
    """
    if not tree1.is_identical(tree2):
        raise ValueError("trees are not the same")

    node_ids = list(tree1.chain.keys())
    # node ids need not be contiguous (e.g. after pruning), so size the
    # list by the largest id rather than by the number of nodes
    map_a2b = [0] * (max(node_ids) + 1 if node_ids else 0)

    for i in node_ids:
        o = tree2.is_monophyletic(tree1.get_taxa(i))
        if o == -1:
            raise ValueError("trees are not congruent.")
        map_a2b[i] = o

    return map_a2b

class NodeType:
mType = "Generic"
mDescription = "Generic"
def __init__(self, node1, node2 ):
self.mSpeciesNode = node1
self.mGeneNode = node2
def __str__(self):
return "\t".join( (self.mType, str(self.mSpeciesNode), str(self.mGeneNode)))

class NodeTypeSpeciation(NodeType):

mType="Speciation"
mDescription="Speciation event"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeSpeciationDeletion(NodeType):

mType="SpeciationDeletion"
mDescription="Speciation event, but in one sub-branch, deletions occured."
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeDuplication(NodeType):
mType = "Duplication"
mDescription = "Duplication event"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeOutparalogs(NodeType):
mType = "Outparalogs"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeDuplicationDeletion(NodeType):
mType = "DuplicationDeletion"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeDuplicationLineage(NodeType):
mType = "DuplicationLineage"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeDuplicationInconsistency(NodeType):
mType = "DuplicationInconsistency"
def __init__(self, *args, **kwargs ):
NodeType.__init__(self, *args, **kwargs)

class NodeTypeTranscripts(NodeType):
    """Node joining alternative transcripts; constructor and mDescription inherited."""

    mType = "Transcripts"

class NodeTypeInconsistency(NodeType):
    """Inconsistent node; constructor and mDescription inherited from NodeType."""

    mType = "Inconsistency"

class NodeTypeInconsistentTranscripts(NodeType):
    """Transcripts of the same gene are joined, but they are not monophyletic."""

    mType = "InconsistentTranscripts"

    def __init__(self, *args, **kwargs ):
        NodeType.__init__(self, *args, **kwargs)


class NodeTypeMasked(NodeType):
    """Node whose classification is masked.

    ReconciliateByRio assigns this type to nodes with branches shorter than
    its min_branch_length and to subtrees it masks to avoid over-counting
    speciation events below alternative-transcript nodes.
    """

    mType = "Masked"

    def __init__(self, *args, **kwargs ):
        NodeType.__init__(self, *args, **kwargs)

class NodeTypeLeaf(NodeType):
    """Leaf (external) gene-tree node; constructor and mDescription inherited."""

    mType = "Leaf"

##-------------------------------------------------------------------------
def ReconciliateByRio( gene_tree, species_tree,
                       extract_species,
                       extract_gene = None,
                       outgroup_species = None,
                       min_branch_length = 0.0 ):
    """Assign a node type (speciation, duplication, ...) to each gene-tree node.

    Reconciles the gene tree G with the species tree S. Both trees are
    rooted and binary.

    extract_species(taxon) returns the species of a gene-tree OTU; it must
    match the taxon names used in the species tree.

    If extract_gene is given (extract_gene(taxon) -> gene), nodes joining
    alternative transcripts of the same gene are labelled as transcript
    nodes.

    If outgroup_species is given, any node that does not map to the root
    of S and has an outgroup species in either subtree is labelled
    out-paralog. Because every ancestor of such a node also contains the
    outgroup species, the out-paralog label propagates upwards.

    If min_branch_length is set, a node is masked when its own branch or
    either child branch is shorter than min_branch_length, because the
    topology there might be unreliable.

    Initialization:

      Number nodes in S in pre-order traversal (root = 1), such
      that child nodes are always larger than parent nodes.

      For each external node g of G, set M(g) to the number of the
      external node in S with the matching species name.

    Recursion:

      Visit each internal node g of G in post-order traversal (i.e.
      from leaves to root)::

        set a = M(g1)  # g1 = first child of current node g
        set b = M(g2)  # g2 = second child of current node g

        while a != b:
            if a > b:
                set a = parent of node a in species tree
            else:
                set b = parent of node b in species tree
        set M(g) = a

        if M(g) == M(g1) or M(g) == M(g2):
            g is duplication
        else:
            g is speciation

    Extensions to the basic algorithm:

      * If M(g) is the root of S, g is out-paralog when the species sets
        of its two subtrees overlap, otherwise it is the final speciation.
      * Duplications are sub-classified (lineage specific, clean, with
        deletions in one branch, inconsistent). A duplication candidate
        whose subtrees share no species is labelled an inconsistency.
      * Speciations where species expected below M(g) are missing are
        labelled speciation with deletion.
      * Alternative transcripts: transcripts that span genes from other
        species are permitted, as long as for every species present in
        both subtrees the same genes occur on both sides. To avoid
        over-counting of speciation events, the subtree with fewer species
        below such a node is masked.

    Returns a list indexed by gene-tree node id holding one NodeType
    instance per visited node (unvisited ids stay None).

    Raises ValueError if the gene tree is not binary, KeyError if a species
    of the gene tree is not a leaf of the species tree, and re-raises an
    AttributeError from extract_gene with the offending taxa.
    """

    ########################################################################
    ## Initialization
    nnodes_genetree = max( gene_tree.chain.keys() ) + 1
    nnodes_speciestree = max( species_tree.chain.keys() ) + 1

    ## species-tree node id -> pre-order number
    N = [0] * (nnodes_speciestree + 1)
    ## pre-order number -> pre-order number of the parent (0 for the root)
    map_N2Parent  = [0] * (nnodes_speciestree + 2)
    ## pre-order number -> species-tree node id
    map_N2node_id = [0] * (nnodes_speciestree + 2)
    ## species name -> pre-order number of its leaf
    map_species2N = {}
    ## Mapping function: gene-tree node id -> pre-order number in species tree
    M = [0] * nnodes_genetree
    ## result: node type for each gene-tree node
    node_types = [None] * nnodes_genetree

    ## relabel the species tree in pre-order, starting at 1
    counter = [1]

    def init( node_id ):
        node = species_tree.node(node_id)
        if not node.succ:
            map_species2N[node.data.taxon] = counter[0]
        map_N2node_id[counter[0]] = node_id
        N[node_id] = counter[0]
        # Compare with None explicitly: node id 0 is a valid parent (it is
        # the root id of a freshly built tree). A truth test would leave the
        # children of node 0 without a parent number.
        if node.prev is not None:
            map_N2Parent[counter[0]] = N[node.prev]
        counter[0] += 1

    TreeDFS( species_tree, species_tree.root,
             pre_function = init )

    if outgroup_species:
        outgroups = set(outgroup_species)
    else:
        outgroups = None

    for x in gene_tree.get_terminals():
        M[x] = map_species2N[extract_species(gene_tree.node(x).data.taxon)]

    ########################################################################
    ## Recursion for masking a subtree: relabel every internal node as masked,
    ## keeping its species and gene node.
    def mask_subtree( node_id ):
        node = gene_tree.node(node_id)
        if not node.succ:
            return
        node_types[node_id] = NodeTypeMasked( node_types[node_id].mSpeciesNode,
                                              node_types[node_id].mGeneNode )

    def gene_set( taxa ):
        """set of (species, gene) tuples for taxa."""
        try:
            return set( zip( map(extract_species, taxa), map(extract_gene, taxa) ) )
        except AttributeError:
            raise AttributeError( "could not parse %s" % (",".join(taxa)) )

    ########################################################################
    ## Recursion for updating assignments
    def update( g ):

        ng = gene_tree.node(g)
        if not ng.succ:
            node_types[g] = NodeTypeLeaf( map_N2node_id[M[g]], g )
            return

        if len(ng.succ) != 2:
            gene_tree.display()
            raise ValueError( "warning: not a binary tree." )

        g1, g2 = ng.succ
        a = M[g1]
        b = M[g2]

        ## walk up the species tree until both mappings meet; children
        ## always have larger numbers than their parents
        while a != b:
            if a > b:
                a = map_N2Parent[a]
            else:
                b = map_N2Parent[b]

        M[g] = a
        species_node = map_N2node_id[M[g]]

        taxa1 = gene_tree.get_taxa(g1)
        taxa2 = gene_tree.get_taxa(g2)
        species1 = set(map(extract_species, taxa1))
        species2 = set(map(extract_species, taxa2))

        if min_branch_length:
            # mask nodes with short branches on this node or to both children
            min_bl = min( ng.data.branchlength,
                          min([ gene_tree.node(x).data.branchlength for x in ng.succ ]) )
            if min_bl < min_branch_length:
                node_types[g] = NodeTypeMasked( species_node, g )
                return

        if M[g] == 1:
            ## check if species sets of subtrees are overlapping
            ## If they are, then we have outparalogs, otherwise, it is the final
            ## speciation event
            if species1.intersection( species2 ):
                node_types[g] = NodeTypeOutparalogs( species_node, g )
            else:
                node_types[g] = NodeTypeSpeciation( species_node, g )

        elif outgroups and (species1.intersection( outgroups ) or species2.intersection( outgroups )):
            node_types[g] = NodeTypeOutparalogs( species_node, g )

        elif M[g] == M[g1] or M[g] == M[g2]:
            ## additional check: check if species sets of subtrees are actually overlapping
            if species1.intersection( species2 ):

                ## check for (inconsistent) alternative transcripts
                if extract_gene:
                    genes1 = gene_set( taxa1 )
                    genes2 = gene_set( taxa2 )

                    if genes1.intersection(genes2):
                        if len(genes1) == len(genes2) and len(genes1) == 1:
                            ## transcripts of the same gene are joined
                            node_types[g] = NodeTypeTranscripts( species_node, g )
                        else:
                            ## for every species present in both subtrees, both
                            ## subtrees must contain exactly the same genes
                            is_ok = True
                            for species in species1.intersection( species2 ):
                                sg1 = set( x for x in genes1 if x[0] == species )
                                sg2 = set( x for x in genes2 if x[0] == species )
                                if sg1 != sg2:
                                    is_ok = False

                            if is_ok:
                                # all transcripts are from the same gene(s):
                                # mark node as alternative transcript one and
                                # mask the subtree with fewer species to avoid
                                # over-counting of speciation events
                                node_types[g] = NodeTypeTranscripts( species_node, g )
                                if len(species1) > len(species2):
                                    TreeDFS( gene_tree, g2, post_function = mask_subtree )
                                else:
                                    TreeDFS( gene_tree, g1, post_function = mask_subtree )
                            else:
                                # transcripts of the same gene are joined, but
                                # they are not monophyletic
                                node_types[g] = NodeTypeInconsistentTranscripts( species_node, g )

                if node_types[g] is None:
                    if len(species1) == len(species2) and len(species1) == 1:
                        ## lineage specific duplication: only one species is involved
                        node_types[g] = NodeTypeDuplicationLineage( species_node, g )
                    elif len(species1) == len(species2) and not species1.difference(species2):
                        ## clean duplication: species sets do completely overlap
                        node_types[g] = NodeTypeDuplication( species_node, g )
                    elif not species1.difference(species2) or not species2.difference(species1):
                        ## duplication, but with only deletions in one branch
                        node_types[g] = NodeTypeDuplicationDeletion( species_node, g )
                    else:
                        node_types[g] = NodeTypeDuplicationInconsistency( species_node, g )
            else:
                ## duplication by mapping, but subtrees share no species
                if ng.data.branchlength < min_branch_length:
                    node_types[g] = NodeTypeMasked( species_node, g )
                else:
                    node_types[g] = NodeTypeInconsistency( species_node, g )
        else:
            ## speciation: check for species expected under the species tree
            ## that are missing in the gene tree
            expected_species = set(species_tree.get_taxa( species_node ))
            observed_species = species1.union( species2 )
            if len(observed_species.intersection(expected_species)) < len(expected_species):
                node_types[g] = NodeTypeSpeciationDeletion( species_node, g )
            else:
                node_types[g] = NodeTypeSpeciation( species_node, g )

    TreeDFS( gene_tree, gene_tree.root,
             post_function = update )

    return node_types

##-------------------------------------------------------------------------
##-------------------------------------------------------------------------
##-------------------------------------------------------------------
## Count duplications
def CountDuplications( gene_tree, species_tree, node_types,
                       extract_species,
                       extract_gene = None):
    """Print a report of all non-speciation nodes in a reconciled gene tree.

    Despite its name, the function neither counts nor returns anything; it
    writes to stdout. For every entry of node_types that is set and whose
    ``mType`` is not "Speciation", it prints:

      * the node type line, followed by the average of the two values
        (min/max branch length) that GetBranchLengths reports for the
        gene-tree node,
      * for each child of the species-tree node, its taxa (indented),
      * for each child of the gene-tree node, its taxa (indented).

    extract_species and extract_gene are accepted for interface
    compatibility but are not used.
    """

    ########################################################################
    ## Get additional data for branch lengths
    min_branch_lengths, max_branch_lengths = GetBranchLengths( gene_tree )
    indent = "\t" * 5

    for node_type in node_types:
        # Use == here: in ``mType not in ("Speciation")`` the parentheses do
        # not create a tuple, so ``in`` would perform a substring test.
        if not node_type or node_type.mType == "Speciation":
            continue

        gene_node = node_type.mGeneNode
        mean_length = (min_branch_lengths[gene_node] + max_branch_lengths[gene_node]) / 2
        print( "\t".join( map(str, (node_type, mean_length)) ) )

        for s in species_tree.node(node_type.mSpeciesNode).succ:
            print( indent + ",".join( species_tree.get_taxa( s ) ) )
        for s in gene_tree.node(gene_node).succ:
            print( indent + ",".join( gene_tree.get_taxa( s ) ) )

##------------------------------------------------------------------------------------------
##------------------------------------------------------------------------------------------
##------------------------------------------------------------------------------------------
def GetParentNodeWhereTrue( node_id, tree, stop_function ):
    """Walk from node_id towards the root until stop_function is true.

    stop_function(node_id) is evaluated for each node on the way up,
    starting with node_id itself, but never for the root: when the root is
    reached the walk ends there regardless.

    Returns a tuple (node_id, distance), where distance is the sum of the
    branch lengths of the nodes that were left behind. The branch of the
    node the walk stops at is not included.
    """

    current = node_id
    travelled = 0
    while True:
        node = tree.node(current)
        if node.prev is None or stop_function(current):
            break
        travelled += node.data.branchlength
        current = node.prev

    return current, travelled

##------------------------------------------------------------------------------------------
##------------------------------------------------------------------------------------------
##------------------------------------------------------------------------------------------
def GetChildNodesWhereTrue( node_id, tree, stop_function ):
    """Walk down from node_id and collect the nodes where the walk stops.

    The walk stops at a node if it is a leaf or if stop_function(node) is
    true (stop_function is not called for leaves). Nodes are reported in
    depth-first order, children in their stored order.

    Returns a list of (node_id, distance) tuples; distance is the sum of
    branch lengths from node_id down to the reported node, not counting
    the branch of node_id itself.
    """

    found = []
    start = tree.node(node_id)
    # stack of (node id, distance accumulated above that node); the start
    # offset cancels the start node's own branch length
    pending = [ (node_id, -start.data.branchlength) ]

    while pending:
        current, above = pending.pop()
        node = tree.node(current)
        distance = above + node.data.branchlength

        if not node.succ or stop_function(current):
            found.append( (current, distance) )
            continue

        # push in reverse so that the first child is expanded first
        for child in reversed(node.succ):
            pending.append( (child, distance) )

    return found

def GetDistanceToRoot( tree ):
    """Return a list with the accumulated branch length from the root to each node.

    The list is indexed by node id and has GetSize(tree) entries. Values
    are filled through TreeDFS's pre_function hook starting at tree.root,
    presumably so that a parent is processed before its children. The
    entry of a node without a parent holds that node's own branch length.
    """

    result = [ 0 ] * GetSize( tree )

    def _accumulate( node_id ):
        node = tree.node( node_id )
        parent_id = node.prev
        if not parent_id:
            # NOTE(review): this truth test also treats a parent with node
            # id 0 as "no parent", so children of node 0 get only their own
            # branch length - confirm this is intended.
            result[node_id] = node.data.branchlength
            return
        result[node_id] += result[parent_id] + node.data.branchlength

    TreeDFS( tree, tree.root, pre_function = _accumulate )

    return result

def traverseGraph( graph, start, block = None):
    """Breadth-first traversal of graph from start that never enters nodes in block.

    graph: dict mapping each node to an iterable of neighbouring nodes.
    start: node to start from; it is always part of the result, even if it
        is listed in block.
    block: optional collection of nodes that must not be entered.

    Returns a dict whose keys are the reached nodes (all values are 1).
    """
    import collections

    blocked = set(block) if block else set()

    # Mark nodes when they are queued, not when they are dequeued, so that
    # in graphs with cycles a node is queued (and expanded) only once.
    visited = {start: 1}
    queue = collections.deque([start])

    while queue:
        v = queue.popleft()
        for n in graph[v]:
            if n not in visited and n not in blocked:
                visited[n] = 1
                queue.append(n)

    return visited

def getPattern( tree, nodes, map_taxon2position ):
    """Return a "0"/"1" string marking which taxa occur among nodes.

    Position map_taxon2position[taxon] is "1" if some node in nodes carries
    that taxon. Nodes whose taxon is None are ignored.
    """

    flags = ["0"] * len(map_taxon2position)
    for node_id in nodes:
        taxon = tree.node( node_id ).get_data().taxon
        if taxon is None:
            continue
        flags[map_taxon2position[taxon]] = "1"

    return "".join(flags)

def convertTree2Graph( tree ):
    """Convert a tree into an undirected adjacency-list graph.

    Returns (graph, edges): graph maps every node id in tree.chain (and
    every child id) to a list of neighbouring node ids; edges lists one
    (parent, child) tuple per branch.
    """

    adjacency = {}
    branches = []
    for parent_id, node in tree.chain.items():
        adjacency.setdefault(parent_id, [])
        for child_id in node.succ:
            adjacency.setdefault(child_id, []).append(parent_id)
            adjacency[parent_id].append(child_id)
            branches.append( (parent_id, child_id) )

    return adjacency, branches

def calculatePatternsFromTree( tree, sort_order ):
    """Return the taxon bipartition patterns induced by the branches of a tree.

    For every branch (a, b) two patterns are produced. The first holds the
    taxa reachable from a without entering b; the second holds the taxa
    reachable from b without entering a. A final pattern containing all
    taxa is appended. Each pattern is a "0"/"1" string with one position
    per entry of sort_order.
    """

    position_of = dict( (taxon, index) for index, taxon in enumerate(sort_order) )

    graph, edges = convertTree2Graph(tree)
    patterns = []
    for a, b in edges:
        for here, there in ((a, b), (b, a)):
            reached = traverseGraph( graph, here, [there,] )
            patterns.append( getPattern( tree, reached.keys(), position_of ) )

    patterns.append( "1" * len(sort_order) )
    return patterns

```