#!/usr/bin/env python
"""Estimating pairwise distances between sequences.
from cogent.util import parallel, table, warning, progress_display as UI
from cogent.maths.stats.util import Numbers
from cogent import LoadSeqs, LoadTree
from warnings import warn
__author__ = "Gavin Huttley"
__copyright__ = "Copyright 2007-2012, The Cogent Project"
__credits__ = ["Gavin Huttley", "Peter Maxwell", "Matthew Wakefield"]
__license__ = "GPL"
__version__ = "1.5.3"
__maintainer__ = "Gavin Huttley"
__email__ = "gavin.huttley@anu.edu.au"
__status__ = "Production"
class EstimateDistances(object):
    """Base class used for estimating pairwise distances between sequences.
    Can also estimate other parameters from pairs."""
    def __init__(self, seqs, submodel, threeway=False, motif_probs = None,
                do_pair_align=False, rigorous_align=False, est_params=None,
            - seqs: an Alignment or SeqCollection instance with > 1 sequence
            - submodel: substitution model object Predefined models can
              be imported from cogent.evolve.models
            - threeway: a boolean flag for using threeway comparisons to
              estimate distances. default False. Ignored if do_pair_align is
            - do_pair_align: if the input sequences are to be pairwise aligned
              first and then the distance will be estimated. A pair HMM based
              on the submodel will be used.
            - rigorous_align: if True the pairwise alignments are actually
              numerically optimised, otherwise the current substitution model
              settings are used. This slows down estimation considerably.
            - est_params: substitution model parameters to save estimates from
              in addition to length (distance)
            - modify_lf: a callback function for that takes a likelihood
              function (with alignment set) and modifies it. Can be used to
              configure local_params, set bounds, optimise using a restriction
              for faster performance.
        Note: Unless you know a priori your alignment will be flush ended
        (meaning no sequence has terminal gaps) it is advisable to construct a
        substitution model that recodes gaps. Otherwise the terminal gaps will
        significantly bias the estimation of branch lengths when using
        if do_pair_align:
            self.__threeway = False
            # whether pairwise is to be estimated from 3-way
            self.__threeway = [threeway, False][do_pair_align]
        self.__seq_collection = seqs
        self.__seqnames = seqs.getSeqNames()
        self.__motif_probs = motif_probs
        # the following may be pairs or three way combinations
        self.__combination_aligns = None
        self._do_pair_align = do_pair_align
        self._rigorous_align = rigorous_align
        # substitution model stuff
        self.__sm = submodel
        self._modify_lf = modify_lf
        # store for the results
        self.__param_ests = {}
        self.__est_params = list(est_params or [])
        self.__run = False # a flag indicating whether estimation completed
        # whether we're on the master CPU or not
        self._on_master_cpu = parallel.getCommunicator().Get_rank() == 0
    def __str__(self):
        return str(self.getTable())
    def __make_pairwise_comparison_sets(self):
        comps = []
        names = self.__seq_collection.getSeqNames()
        n = len(names)
        for i in range(0, n - 1):
            for j in range(i + 1, n):
                comps.append((names[i], names[j]))
        return comps
    def __make_threeway_comparison_sets(self):
        comps = []
        names = self.__seq_collection.getSeqNames()
        n = len(names)
        for i in range(0, n - 2):
            for j in range(i + 1, n - 1):
                for k in range(j + 1, n):
                    comps.append((names[i], names[j], names[k]))
        return comps
    def __make_pair_alignment(self, seqs, opt_kwargs):
        lf = self.__sm.makeLikelihoodFunction(\
        # allow user to modify the lf config
        if self._modify_lf:
            lf = self._modify_lf(lf)
        if self._rigorous_align:
        lnL = lf.getLogLikelihood()
        (vtLnL, aln) = lnL.edge.getViterbiScoreAndAlignment()
        return aln
    def __doset(self, sequence_names, dist_opt_args, aln_opt_args, ui):
        # slice the alignment
        seqs = self.__seq_collection.takeSeqs(sequence_names)
        if self._do_pair_align:
            ui.display('Aligning', progress=0.0, current=.5)
            align = self.__make_pair_alignment(seqs, aln_opt_args)
            ui.display('', progress=.5, current=.5)
            align = seqs
            ui.display('', progress=0.0, current=1.0)
        # note that we may want to consider removing the redundant gaps
        # create the tree object
        tree = LoadTree(tip_names = sequence_names)
        # make the parameter controller
        lf = self.__sm.makeLikelihoodFunction(tree)
        if not self.__threeway:
            lf.setParamRule('length', is_independent = False)
        if self.__motif_probs:
        # allow user modification of lf using the modify_lf
        if self._modify_lf:
            lf = self._modify_lf(lf)
        # get the statistics
        stats_dict = lf.getParamValueDict(['edge'], 
                params=['length'] + self.__est_params)
        # if two-way, grab first distance only
        if not self.__threeway:
            result = {'length': stats_dict['length'].values()[0] * 2.0}
            result = {'length': stats_dict['length']}
        # include any other params requested
        for param in self.__est_params:
            result[param] = stats_dict[param].values()[0]
        return result
    def run(self, dist_opt_args=None, aln_opt_args=None, ui=None, **kwargs):
        """Start estimating the distances between sequences. Distance estimation
        is done using the Powell local optimiser. This can be changed using the
        dist_opt_args and aln_opt_args.
            - show_progress: whether to display progress. More detailed progress
              information from individual optimisation is controlled by the
            - dist_opt_args, aln_opt_args: arguments for the optimise method for
              the distance estimation and alignment estimation respectively."""
        if 'local' in kwargs:
              warn("local argument ignored, provide it to dist_opt_args or"\
              " aln_opt_args", DeprecationWarning, stacklevel=2)
        dist_opt_args = dist_opt_args or {}
        aln_opt_args = aln_opt_args or {}
        # set the optimiser defaults
        dist_opt_args['local'] = dist_opt_args.get('local', True)
        aln_opt_args['local'] = aln_opt_args.get('local', True)
        # generate the list of unique sequence sets (pairs or triples) to be
        # analysed
        if self.__threeway:
            combination_aligns = self.__make_threeway_comparison_sets()
            desc = "triplet "
            combination_aligns = self.__make_pairwise_comparison_sets()
            desc = "pair "
        labels = [desc + ','.join(names) for names in combination_aligns]
        def _one_alignment(comp):
            result = self.__doset(comp, dist_opt_args, aln_opt_args)
            return (comp, result)
        for (comp, value) in ui.imap(_one_alignment, combination_aligns,
            self.__param_ests[comp] = value
    def getPairwiseParam(self, param, summary_function="mean"):
        """Return the pairwise statistic estimates as a dictionary keyed by
        (seq1, seq2)
            - param: name of a parameter in est_params or 'length'
            - summary_function: a string naming the function used for
              estimating param from threeway distances. Valid values are 'mean'
              (default) and 'median'."""
        summary_func = summary_function.capitalize()
        pairwise_stats = {}
        assert param in self.__est_params + ['length'], \
                "unrecognised param %s" % param
        if self.__threeway and param == 'length':
            pairwise = self.__make_pairwise_comparison_sets()
            # get all the distances involving this pair
            for a, b in pairwise:
                values = Numbers()
                for comp_names, param_vals in self.__param_ests.items():
                    if a in comp_names and b in comp_names:
                        values.append(param_vals[param][a] + \
                pairwise_stats[(a,b)] = getattr(values, summary_func)
            # no additional processing of the distances is required
            for comp_names, param_vals in self.__param_ests.items():
                pairwise_stats[comp_names] = param_vals[param]
        return pairwise_stats
    def getPairwiseDistances(self,summary_function="mean", **kwargs):
        """Return the pairwise distances as a dictionary keyed by (seq1, seq2).
        Convenience interface to getPairwiseParam.
            - summary_function: a string naming the function used for
              estimating param from threeway distances. Valid values are 'mean'
              (default) and 'median'.
        return self.getPairwiseParam('length',summary_function=summary_function,
    def getParamValues(self, param, **kwargs):
        """Returns a Numbers object with all estimated values of param.
            - param: name of a parameter in est_params or 'length'
            - **kwargs: arguments passed to getPairwiseParam"""
        ests = self.getPairwiseParam(param, **kwargs)
        return Numbers(ests.values())
    def getTable(self,summary_function="mean", **kwargs):
        """returns a Table instance of the distance matrix.
            - summary_function: a string naming the function used for
              estimating param from threeway distances. Valid values are 'mean'
              (default) and 'median'."""
        d = \
        if not d:
            d = {}
            for s1 in self.__seqnames:
                for s2 in self.__seqnames:
                    if s1 == s2:
                        d[(s1,s2)] = 'Not Done'
        twoD = []
        for s1 in self.__seqnames:
            row = [s1]
            for s2 in self.__seqnames:
                if s1 == s2:
                except KeyError:
        T = table.Table(['Seq1 \ Seq2'] + self.__seqnames, twoD, row_ids = True,
                        missing_data = "*")
        return T
    def getNewickTrees(self):
        """Returns a list of Newick format trees for supertree methods."""
        trees = []
        for comp_names, param_vals in self.__param_ests.items():
            tips = []
            for name in comp_names:
                tips.append(repr(name)+":%s" % param_vals[name])
        return trees
    def writeToFile(self, filename, summary_function="mean", format='phylip',
        """Save the pairwise distances to a file using phylip format. Other
        formats can be obtained by getting to a Table.  If running in parallel,
        the master CPU writes out.
            - filename: where distances will be written, required.
            - summary_function: a string naming the function used for
              estimating param from threeway distances. Valid values are 'mean'
              (default) and 'median'.
            - format: output format of distance matrix
        if self._on_master_cpu:
             # only write output from 0th node
             table = self.getTable(summary_function=summary_function, **kwargs)
             table.writeToFile(filename, format=format)