#!/usr/bin/env python
import sys
import os
import cPickle
import numpy
import time
from string import strip
from collections import defaultdict

from common import (ncbi, argparse, PhyloTree, Tree, SVG_COLORS, faces,
                    treeview, NodeStyle, TreeStyle, color, print_table)

__DESCRIPTION__ = ("Calculates the consensus of a tree with the NCBI taxonomy."
                   " The analysis can be visualized over the tree, in"
                   " which broken clades are shown.")

# Taxon-name -> colour cache used by ncbi_layout().  Loading is best-effort:
# a missing or unreadable cache file simply starts with an empty cache.
# NOTE(review): unpickling can execute arbitrary code; only keep a trusted
# ncbi_colors.pkl in the working directory.
try:
    with open("ncbi_colors.pkl") as _colors_fh:
        name2color = cPickle.load(_colors_fh)
except Exception:
    name2color = {}
else:
    # stderr, so the message never mixes with a result table on stdout
    print >>sys.stderr, "loaded cached color information"

# Module-level holder for the parsed command-line options; analyze_subtrees()
# reads it to decide whether broken branches must be annotated for drawing.
args = None


def npr_layout(node):
    """ETE layout function drawing NPR annotations stored as node features.

    Leaves get their name (and an aligned sequence when ``sequence`` is set).
    Nodes carrying ``alg_type``/``treemerger_rf`` get those values as labels.
    Internal nodes whose support is low enough get a translucent red circle.
    Nodes with an ``improve`` value get a large coloured circle: orange when
    negative, blue when zero, green otherwise.
    """
    if node.is_leaf():
        name = faces.AttrFace("name", fsize=12)
        faces.add_face_to_node(name, node, 0, position="branch-right")
        if hasattr(node, "sequence"):
            seq_face = faces.SeqFace(node.sequence, [])
            faces.add_face_to_node(seq_face, node, 0, position="aligned")

    if "alg_type" in node.features:
        faces.add_face_to_node(faces.AttrFace("alg_type", fsize=8),
                               node, 0, position="branch-top")
        ttype = faces.AttrFace("tree_type", fsize=8, fgcolor="DarkBlue")
        faces.add_face_to_node(ttype, node, 0, position="branch-top")
        node.img_style["size"] = 20
        node.img_style["fgcolor"] = "red"

    if "treemerger_rf" in node.features:
        faces.add_face_to_node(faces.AttrFace("treemerger_rf", fsize=8),
                               node, 0, position="branch-bottom")

    # Radius grows by 50px per unit of missing support; halos of 1px or less
    # are not worth drawing.
    support_radius = (1.0 - node.support) * 50
    if not node.is_leaf() and support_radius > 1:
        support_face = faces.CircleFace(support_radius, "red")
        faces.add_face_to_node(support_face, node, 0, position="float-behind")
        support_face.opacity = 0.25
        faces.add_face_to_node(faces.AttrFace("support", fsize=8),
                               node, 0, position="branch-bottom")

    if "highlighted" in node.features:
        node.img_style["bgcolor"] = "LightCyan"

    if "npr_iter" in node.features:
        node.img_style["size"] = 50

    if "improve" in node.features:
        improve = float(node.improve)
        # named improve_color so the imported color() helper is not shadowed
        if improve < 0:
            improve_color = "orange"
        elif improve == 0:
            improve_color = "blue"
        else:
            improve_color = "green"
        improve_face = faces.CircleFace(200, improve_color)
        faces.add_face_to_node(improve_face, node, 0, position="float-behind")


def ncbi_layout(node):
    """ETE layout: everything npr_layout() draws plus NCBI lineage marks.

    Leaves show one coloured, aligned cell per lineage name (colours come
    from the module-level ``name2color`` cache, which is extended with a
    random colour for unseen names) followed by the ``spname`` feature.
    Internal nodes list their ``broken_groups`` (red) and ``broken_levels``
    (blue).  Nodes with a ``changed`` feature get an indianred background
    when it is "yes" and a white one otherwise.
    """
    npr_layout(node)
    if node.is_leaf():
        tax_pos = 10
        if hasattr(node, "lineage"):
            # zip() stops at the shorter of lineage / named_lineage.
            for _tax, tax_name in zip(node.lineage, node.named_lineage):
                f = faces.TextFace("%10s" %tax_name, fsize=15)
                try:
                    bg_color = name2color[tax_name]
                except KeyError:
                    # name2color is mutated in place (never rebound), so no
                    # ``global`` statement is required.
                    name2color[tax_name] = bg_color = treeview.main.random_color()
                f.background.color = bg_color
                faces.add_face_to_node(f, node, tax_pos, position="aligned")
                tax_pos += 1

        f = faces.AttrFace("spname", fsize=15)
        faces.add_face_to_node(f, node, 10, position="branch-right")
    else:
        if getattr(node, "broken_groups", None):
            for broken in node.broken_groups:
                f = faces.TextFace(broken, fsize=10, fgcolor="red")
                faces.add_face_to_node(f, node, 1, position="branch-bottom")
        if getattr(node, "broken_levels", None):
            for broken in node.broken_levels:
                f = faces.TextFace(broken, fsize=10, fgcolor="blue")
                faces.add_face_to_node(f, node, 1, position="branch-bottom")

    if hasattr(node, "changed"):
        if node.changed == "yes":
            node.img_style["bgcolor"] = "indianred"
        else:
            node.img_style["bgcolor"] = "white"


# Legacy per-node taxonomy consistency check, kept commented out for
# reference; it appears to be superseded by ncbi.get_broken_branches().
#def analyze_tracks(t, n2content):
#    counterdict = lambda: defaultdict(int)
#    node2track = defaultdict(counterdict)
#    taxcounter = defaultdict(int)
#    tax2name = {}
#    for node, leaves in n2content.iteritems():
#        node.add_features(broken_levels=[])
#        if node.is_leaf():
#            for index, tax in enumerate(node.lineage):
#                taxcounter[tax] += 1
#                tax2name[tax] = node.named_lineage[index]
#        else:
#            for lf in leaves:
#                for index, tax in enumerate(lf.lineage):
#                    node2track[node][tax] += 1
#
#    mono = set(taxcounter.keys())
#    non_mono = set()
#    non_mono_sizes = []
#    broken_branches = 0
#    for node, taxa in node2track.iteritems():
#        for tax, num in taxa.iteritems():
#            if taxcounter[tax]
#                    != num and len(n2content[node]) != num:
#                if 0:
#                    print "max:", taxcounter[tax]
#                    print "in this node:", num
#                    print "node size:", len(n2content[node])
#                    print tax2name[tax]
#                    print "..."
#                    raw_input()
#                mono.discard(tax)
#                node.broken_levels.append(tax2name[tax])
#                non_mono_sizes.append(taxcounter[tax])
#                broken_branches += 1
#                if tax not in non_mono:
#                    non_mono.add(tax)
#
#    return mono, non_mono, broken_branches, non_mono_sizes, tax2name


def analyze_subtrees(t, subtrees, reft=None, show_progress=False):
    """Compare each subtree with the NCBI taxonomy and summarise the result.

    :param t: the full tree (not used by the analysis; kept for API
        compatibility).
    :param subtrees: iterable of subtrees, e.g. from ``t.split_by_dups()``.
    :param reft: optional reference tree.  When given, the normalised
        Robinson-Foulds distance of every informative subtree (leaves
        compared through their ``realname``, copied into ``spcode``) is
        summed into ``total_rf``.
    :param show_progress: print a running subtree counter to stderr.
    :returns: ``(valid_subtrees, broken_subtrees, ncbi_mistakes,
        all_broken_branches, total_rf, broken_group_names,
        broken_group_sizes)``; ``broken_group_names`` is a set of names.

    Single-leaf subtrees are skipped.  When the module-level ``args``
    requests drawing (``show_tree`` or ``render``), every broken branch gets
    a ``broken_groups`` attribute holding the names of the broken clades.
    """
    ncbi_mistakes = 0
    valid_subtrees = 0
    broken_groups = set()
    broken_subtrees = 0
    total_rf = 0
    broken_group_sizes = []
    all_broken_branches = 0
    # ncbi.get_broken_branches() hands back a taxid->name map per subtree;
    # merge them so taxids broken in earlier subtrees can be named at the end.
    all_tax2name = {}
    # ``args`` stays None unless the parsed options were stored globally.
    annotate = args is not None and (args.show_tree or args.render)

    for count, subt in enumerate(subtrees):
        if show_progress:
            print >>sys.stderr, "\r", count, " ",
            sys.stderr.flush()

        n2content = subt.get_cached_content()
        if len(n2content[subt]) <= 1:
            continue  # a single leaf carries no topology to check
        valid_subtrees += 1

        if reft is not None:
            # n2content is a dict keyed by node: iterate it, don't call it.
            for _n in n2content:
                if _n.is_leaf():
                    _n.spcode = _n.realname
            rf, rf_max = subt.robinson_foulds(reft, attr_t1="spcode")[:2]
            if rf_max:  # no comparable partitions -> nothing to add
                total_rf += float(rf) / rf_max

        (broken_branches, broken_clades, broken_clade_sizes,
         tax2name) = ncbi.get_broken_branches(subt, n2content)
        all_tax2name.update(tax2name)
        ncbi_mistakes += len(broken_clades)
        all_broken_branches += len(broken_branches)
        broken_group_sizes.extend(broken_clade_sizes)
        if broken_clades:
            broken_subtrees += 1
            broken_groups.update(broken_clades)

        if annotate:
            for branch in broken_branches:
                branch.broken_groups = set([tax2name[e] for e in broken_clades])

    if show_progress:
        print >>sys.stderr, "\nDone"
    return (valid_subtrees, broken_subtrees, ncbi_mistakes,
            all_broken_branches, total_rf,
            set([all_tax2name[b] for b in broken_groups]),
            broken_group_sizes)


def annotate_tree_with_taxa(t, name2taxa_file, tax2name=None, tax2track=None):
    """Set ``taxid``/``species`` on every leaf of *t*, then call ncbi.annotate_tree.

    :param t: tree whose leaves are annotated in place.
    :param name2taxa_file: optional path to a tab-separated ``name<TAB>taxid``
        file.  When falsy, every leaf name is used as its own taxid.
    :param tax2name: passed through to ``ncbi.annotate_tree``.
    :param tax2track: passed through to ``ncbi.annotate_tree``.
    :returns: the result of ``ncbi.annotate_tree`` (main code unpacks it as
        ``tax2name, tax2track``).
    :raises ValueError: on a non-blank mapping line without exactly two
        tab-separated fields.

    Leaves missing from the mapping get taxid ``1`` and are counted in a
    warning printed to stderr.
    """
    if name2taxa_file:
        names2taxid = {}
        with open(name2taxa_file) as fh:
            for lineno, line in enumerate(fh, 1):
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                fields = map(strip, line.split("\t"))
                if len(fields) != 2:
                    raise ValueError("%s:%d: expected 'name<TAB>taxid', got %r"
                                     % (name2taxa_file, lineno, line))
                names2taxid[fields[0]] = fields[1]
    else:
        names2taxid = dict((n.name, n.name) for n in t.iter_leaves())

    not_found = 0
    for n in t.iter_leaves():
        n.add_features(taxid=names2taxid.get(n.name, 1))
        n.add_features(species=n.taxid)
        if n.taxid == 1:
            not_found += 1
    if not_found:
        print >>sys.stderr, "WARNING: %s nodes were not found within NCBI taxonomy!!" %not_found

    return ncbi.annotate_tree(t, tax2name, tax2track)


def tree_compare(t1, t2):
    """Mark every node of *t1* with a ``changed`` feature ("yes"/"no").

    A node is "no" (unchanged) when some node of *t2* spans exactly the same
    set of leaf names, and "yes" otherwise.
    """
    t2_clades = set(frozenset(leaf.name for leaf in content)
                    for content in t2.get_cached_content().itervalues())
    for node, content in t1.get_cached_content().iteritems():
        clade = frozenset(leaf.name for leaf in content)
        node.add_feature("changed", "no" if clade in t2_clades else "yes")


def _load_pickle(path):
    """Return the object unpickled from *path*.

    NOTE(review): unpickling runs arbitrary code -- only load trusted files.
    """
    # Text mode, consistent with how this script writes its (default,
    # protocol 0) pickles; the pickle docs ask to stay consistent for it.
    with open(path) as fh:
        return cPickle.load(fh)


def _dump_pickle(obj, path):
    """Pickle *obj* into *path* (default protocol, text mode) and close it."""
    with open(path, "w") as fh:
        cPickle.dump(obj, fh)


def _ncbi_tree_style():
    """Return the TreeStyle shared by --explore, --show and --render."""
    ts = TreeStyle()
    ts.force_topology = True
    ts.show_leaf_name = False
    ts.layout_fn = ncbi_layout
    ts.mode = "r"
    return ts


def _speciation_rf(t, reft):
    """Size-weighted RF of *t*'s speciation subtrees against *reft*.

    Subtrees compare through the ``taxid`` feature.  Subtrees rejected by
    ``robinson_foulds`` (ValueError) or with no comparable partitions
    (max RF of 0) are skipped.

    :returns: ``(nsubtrees, ndups, rf, rf_med, rf_std, rf_max,
        common_names)``; ``common_names`` is taken from the last compared
        subtree.
    """
    nsubtrees, ndups, subtrees = t.get_speciation_trees(map_features=["taxid"])
    print >>sys.stderr, nsubtrees, "subtrees", ndups, "duplications"
    weighted_rf = []
    rf_max = 0.0
    sum_size = 0.0
    common_names = 0
    for ii, subt in enumerate(subtrees):
        # progress on stderr so it never interleaves with the result table
        print >>sys.stderr, "\r%d" %ii,
        sys.stderr.flush()
        try:
            partial_rf = subt.robinson_foulds(reft, attr_t1="taxid")
        except ValueError:
            continue
        if not partial_rf[1]:
            continue  # nothing comparable; would divide by zero
        sptree_size = len(set([n.taxid for n in subt.iter_leaves()]))
        sum_size += sptree_size
        weighted_rf.append((partial_rf[0] / float(partial_rf[1])) * sptree_size)
        common_names = len(partial_rf[3])
        rf_max = max(rf_max, partial_rf[1])

    rf = numpy.sum(weighted_rf) / float(sum_size)  # Treeko dist
    rf_std = numpy.std(weighted_rf)
    rf_med = numpy.median(weighted_rf)
    return nsubtrees, ndups, rf, rf_med, rf_std, rf_max, common_names


def _run_analysis(opts, get_taxid, out):
    """Analyse every input tree, writing one result row per tree to *out*.

    After each row, clades that stopped/started being broken relative to the
    previous tree are written to *out*.  A summary table is printed through
    print_table() at the end.
    """
    target_trees = []
    if opts.tree_list_file:
        with open(opts.tree_list_file) as fh:
            target_trees = [line.strip() for line in fh if line.strip()]
    if opts.target_tree:
        target_trees += opts.target_tree

    tax2name = _load_pickle(opts.tax2name) if opts.tax2name else {}
    tax2track = _load_pickle(opts.tax2track) if opts.tax2track else {}
    print >>sys.stderr, len(tax2track), len(tax2name)

    # Must list the same columns, in the same order, as ``iter_values``.
    header = ("TargetTree", "Subtrees", "Ndups", "Broken subtrees",
              "Broken clades", "Broken branches", "Clade sizes", "RF (avg)",
              "RF (med)", "RF (std)", "RF (max)", "Shared tips")
    print >>out, '|'.join([h.ljust(15) for h in header])

    if opts.ref_tree:
        print >>sys.stderr, "Reading ref tree from", opts.ref_tree
        reft = Tree(opts.ref_tree, format=1)
    else:
        reft = None

    prev_tree = None
    prev_broken = set()
    entries = []
    ncbi.connect_database()
    for tfile in target_trees:
        t = PhyloTree(tfile, sp_naming_function=None)
        if get_taxid:
            for n in t.iter_leaves():
                n.name = get_taxid(n.name)

        if opts.outgroup:
            if len(opts.outgroup) == 1:
                out_node = t & opts.outgroup[0]
            else:
                out_node = t.get_common_ancestor(opts.outgroup)
                if set(out_node.get_leaf_names()) ^ set(opts.outgroup):
                    raise ValueError("Outgroup is not monophyletic")
            t.set_outgroup(out_node)
        t.ladderize()

        if prev_tree is not None:
            tree_compare(t, prev_tree)
        prev_tree = t

        if opts.tax_info:
            tax2name, tax2track = annotate_tree_with_taxa(t, opts.tax_info, tax2name, tax2track)
            if opts.dump_tax_info:
                _dump_pickle(tax2track, "tax2track.pkl")
                _dump_pickle(tax2name, "tax2name.pkl")
                print >>sys.stderr, "Tax info written into pickle files"
        else:
            # Without a mapping file, leaf names are used directly as taxids.
            tax2name, tax2track = annotate_tree_with_taxa(t, None, tax2name, tax2track)

        if opts.rf_only:
            subtrees = []
            broken_subtrees = ncbi_mistakes = broken_branches = 0
            broken_clades = set()  # a set: it is diffed against prev_broken
            broken_sizes = []
        else:
            # --is_sptree assumes no duplications, i.e. a single subtree.
            subtrees = [t] if opts.is_sptree else t.split_by_dups()
            (_valid, broken_subtrees, ncbi_mistakes, broken_branches,
             _total_rf, broken_clades, broken_sizes) = analyze_subtrees(t, subtrees)

        ndups = 0
        nsubtrees = len(subtrees)
        rf = rf_max = rf_std = rf_med = 0
        common_names = 0
        if reft is not None and len(subtrees) == 1:
            rf, rf_max = t.robinson_foulds(reft, attr_t1="realname")[:2]
            rf_med = rf
        elif reft is not None:
            (nsubtrees, ndups, rf, rf_med, rf_std, rf_max,
             common_names) = _speciation_rf(t, reft)

        sizes_info = "%0.1f/%0.1f +- %0.1f" %(
            numpy.mean(broken_sizes), numpy.median(broken_sizes), numpy.std(broken_sizes))
        iter_values = [os.path.basename(tfile), nsubtrees, ndups,
                       broken_subtrees, ncbi_mistakes, broken_branches,
                       sizes_info, rf, rf_med, rf_std, rf_max, common_names]
        print >>out, '|'.join([str(x).strip().ljust(15) for x in iter_values])

        fixed = sorted(prev_broken - broken_clades)
        new_problems = sorted(broken_clades - prev_broken)
        fixed_string = color(', '.join(fixed), "green") if fixed else ""
        problems_string = color(', '.join(new_problems), "red") if new_problems else ""
        if fixed:
            out.write(" Fixed clades: %s\n" %fixed_string)
        if new_problems:
            out.write(" New broken: %s\n" %problems_string)
        prev_broken = broken_clades
        entries.append([os.path.basename(tfile), nsubtrees, ndups,
                        broken_subtrees, ncbi_mistakes, broken_branches,
                        sizes_info, fixed_string, problems_string])
        out.flush()

        if opts.show_tree or opts.render:
            ts = _ncbi_tree_style()
            t.dist = 0
            if opts.show_tree:
                print >>sys.stderr, "Showing tree..."
                t.show(tree_style=ts)
            else:
                t.render("img.svg", tree_style=ts, dpi=300)
            print >>sys.stderr, "dumping color config"
            _dump_pickle(name2color, "ncbi_colors.pkl")

        if opts.dump:
            _dump_pickle(t, "ncbi_analysis.pkl")

    print
    print
    HEADER = ("TargetTree", "Subtrees", "Ndups", "Broken subtrees",
              "Broken clades", "Broken branches", "Clade sizes",
              "Fixed Groups", "New Broken Clades")
    print_table(entries, max_col_width = 50, row_line=True, header=HEADER)


def main(argv):
    """Command-line entry point; *argv* excludes the program name."""
    # analyze_subtrees() reads the options through the module-level ``args``;
    # without this declaration the assignment below would bind a local only.
    global args

    parser = argparse.ArgumentParser(description=__DESCRIPTION__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument("--show", dest="show_tree",
                        action="store_true",
                        help="""Display tree after the analysis.""")
    parser.add_argument("--render", dest="render",
                        action="store_true",
                        help="""Render tree.""")
    parser.add_argument("--dump", dest="dump",
                        action="store_true",
                        help="""Dump analysis""")
    parser.add_argument("--explore", dest="explore",
                        type=str,
                        help="""Reads a previously analyzed tree and visualize it""")

    # Not marked required so that --explore works on its own; the
    # requirement is enforced after parsing for every other mode.
    input_args = parser.add_mutually_exclusive_group()
    input_args.add_argument("-t", "--tree", dest="target_tree", nargs="+",
                            type=str,
                            help="""Tree file in newick format""")
    input_args.add_argument("-tf", dest="tree_list_file",
                            type=str,
                            help="File with the list of tree files")

    parser.add_argument("--tax", dest="tax_info", type=str,
                        help="If the taxid attribute is not set in the"
                        " newick file for all leaf nodes, a tab file file"
                        " with the translation of name and taxid can be"
                        " provided with this option.")
    parser.add_argument("--sp_delimiter", dest="sp_delimiter", type=str,
                        help="If taxid is part of the leaf name, delimiter used to split the string")
    parser.add_argument("--sp_field", dest="sp_field", type=int, default=0,
                        help="field position for taxid after splitting leaf names")
    parser.add_argument("--ref", dest="ref_tree", type=str,
                        help="Uses ref tree to compute robinson foulds"
                        " distances of the different subtrees")
    parser.add_argument("--rf-only", dest="rf_only",
                        action = "store_true",
                        help="Skip ncbi consensus analysis")
    parser.add_argument("--outgroup", dest="outgroup",
                        type=str, nargs="+",
                        help="A list of node names defining the trees outgroup")
    parser.add_argument("--is_sptree", dest="is_sptree",
                        action = "store_true",
                        help="Assumes no duplication nodes in the tree")
    parser.add_argument("-o", dest="output", type=str,
                        help="Writes result into a file")
    parser.add_argument("--tax2name", dest="tax2name", type=str,
                        help="")
    parser.add_argument("--tax2track", dest="tax2track", type=str,
                        help="")
    parser.add_argument("--dump_tax_info", dest="dump_tax_info", action="store_true",
                        help="")

    args = parser.parse_args(argv)
    if not args.explore and not (args.target_tree or args.tree_list_file):
        parser.error("one of the arguments -t/--tree -tf is required")

    if args.sp_delimiter:
        delimiter, field = args.sp_delimiter, args.sp_field
        GET_TAXID = lambda x: x.split(delimiter)[field]
    else:
        GET_TAXID = None

    if args.explore:
        print >>sys.stderr, "Reading tree from file:", args.explore
        t = _load_pickle(args.explore)
        t.show(tree_style=_ncbi_tree_style())
        print >>sys.stderr, "dumping color config"
        _dump_pickle(name2color, "ncbi_colors.pkl")
        sys.exit()

    if args.output:
        OUT = open(args.output, "w")
    else:
        OUT = sys.stdout
    print >>sys.stderr, "Dumping results into", OUT
    try:
        _run_analysis(args, GET_TAXID, OUT)
    finally:
        # close the results file even if an input tree fails to load
        if args.output:
            OUT.close()


if __name__ == '__main__':
    main(sys.argv[1:])