########################################################################
#
# See AUTHORS.txt for a list of people who contributed.
# See LICENSE.txt for license information.
#
########################################################################
 
import numpy
from Atom import Atom
from Lattice import Lattice
 
class Molecule(list):
    """Molecule --> group of atoms sharing one lattice.

    Molecule class is inherited from Python list.  It contains
    a list of Atom instances.  Molecule overloads the append, insert,
    extend, setitem and setslice methods so that the lattice attribute
    of the stored atoms gets set to the lattice of the molecule.

    Data members:
        description -- molecule description; generated from the chemical
                       formula when it has not been set
        lattice     -- coordinate system (instance of Lattice)
    """
    def __init__(self, atoms=None, lattice=None, description=None, filename=None):
        """Define a group of atoms in a specified lattice.

        atoms    -- list of Atom instances to be included in this Molecule.
                    When atoms argument is an existing Molecule instance,
                    the new Molecule is its copy: attributes are copied
                    shallowly, the lattice and the atoms are duplicated.
                    None (the default) stands for no atoms.
        lattice  -- instance of Lattice defining coordinate systems.
                    When None, the lattice copied from a source Molecule
                    or a new default Lattice is used.
        description -- string description of the structure.  When None,
                    a copy keeps the description of its source and any
                    other Molecule generates it on first access.
        filename -- optional, name of a file to load the structure from.
                    Overrides atoms argument when specified.

        Molecule(mol)     create a copy of Molecule instance.

        Because Molecule is inherited from a list it can use list expansions,
        for example:
            oxygen_atoms = [a for a in mol if a.symbol == "O"]

        The creation time is stored as a string in the date attribute.
        """
        self._labels = {}
        self._labels_cached = False
        # None replaces a mutable default argument shared between calls
        if atoms is None:
            atoms = []
        is_copy = isinstance(atoms, Molecule)
        if is_copy:
            mol = atoms
            # create a shallow copy of all source attributes
            self.__dict__.update(mol.__dict__)
            # make a deep copy of source lattice
            self.lattice = Lattice(mol.lattice)
            # duplicate the atoms too, otherwise assigning them below would
            # re-point the atoms of the source molecule to the new lattice
            atoms = [Atom(a) for a in mol]
        # an explicit description always wins; without one a copy keeps the
        # description that was just taken over from its source
        if description is not None or not is_copy:
            self.description = description
        # check if data should be loaded from file
        if filename is not None:
            self.read(filename)
            # read() reinitializes this instance, so an explicit description
            # has to be applied again after loading
            if description is not None:
                self.description = description
        # otherwise assign list of atoms to self
        else:
            self[:] = atoms
        # override from lattice argument; the argument is not type-checked
        if lattice is None:
            if not self.lattice:
                self.lattice = Lattice()
        else:
            self.lattice = lattice
        import time
        self.date = time.ctime()
#    def __str__(self):
#        """simple string representation"""
#        s_lattice = "lattice=%s" % self.lattice
#        s_atoms = '\n'.join([str(a) for a in self])
#        return s_lattice + '\n' + s_atoms
 
    def getChemicalFormula(self):
        atoms = self
        counts = {}
        for atom in atoms:
            e = atom.symbol
            if e in counts: counts[e]+=1
            else: counts[e]=1
            continue
        elems = counts.keys()
        elems.sort()
        chemFormRaw = ''.join( '%s_%s ' % (e, counts[e]) for e in elems )
        return chemFormRaw.strip()
 
    def getSpecies(self):
        speciesList = []
        for atom in self:
            if atom.symbol in speciesList:
                pass
            else:
                speciesList.append(atom.symbol)
        return speciesList
 
    def addNewAtom(self, *args, **kwargs):
        """Create a new Atom and add it to the end of this Molecule.

        All arguments are forwarded to the Atom constructor, except for
        the lattice keyword, which is always replaced with the lattice of
        this Molecule.  The cached atom labels are invalidated.

        No return value.
        """
        kwargs.update(lattice=self.lattice)
        new_atom = Atom(*args, **kwargs)
        # bypass the overloaded append, the atom already has the lattice
        list.append(self, new_atom)
        self._uncache('labels')
 
    def getLastAtom(self):
        """Return Reference to the last Atom in this structure.
        """
        last_atom = self[-1]
        return last_atom
 
    def getAtom(self, id):
        """Reference to internal Atom specified by the identifier.
 
        id  -- zero based index or a string label formatted as
               "%(element)s%(order)i", for example "Na1", "Cl1"
 
        Return Atom instance.
        Raise ValueError for invalid id.
 
        See also getLabels().
        """
        try:
            if type(id) is int:
                rv = self[id]
            else:
                if not self._labels_cached or id not in self._labels:
                    self._update_labels()
                rv = self._labels[id]
        except (IndexError, KeyError):
            emsg = "Invalid atom identifier %r." % id
            raise ValueError(emsg)
        return rv
 
    def getLabels(self):
        """List of unique string labels for all atoms in this structure.
 
        Return a list.
        """
        elnum = {}
        labels = []
        for a in self:
            elnum[a.symbol] = elnum.get(a.symbol, 0) + 1
            alabel = a.symbol + str(elnum[a.symbol])
            labels.append(alabel)
        return labels
 
    def getPosition(self, siteId):
        """Returns the (fractional) position of a site."""
        return self._siteIds[siteId].getPosition()
 
    def setPositions(self, positions):
        """Sets the (fractional) positions of the sites in the unit cell."""
        assert(len(positions) == self.getNumSites())
        for isite in range(self.getNumSites()):
            self._sites[isite].setPosition(positions[isite])
 
    def generateDescription(self):
        if self._description==None:
            self._description = self.getChemicalFormula()#+' in '+str(self.lattice)
        return self._description
    def setDescription(self, desc):
        self._description = desc
    description = property(generateDescription, setDescription, "structure description")
 
################################################    
# property methods
################################################
#Notes:
# for now these are done in the style of diffraction
# eventually will be done in Jiao's style with metaclasses
 
    # fractional xyz
    def _get_xyz(self):
        return [atom.xyz.tolist()  for atom in self[:]]
    def _set_xyz(self, xyzList):
        for atom,xyz in zip(self, xyzList):
            atom.xyz = xyz
    xyz = property(_get_xyz, _set_xyz, doc =
        """fractional coordinates of all atoms""" )  
 
    # xyz_cartn
    def _get_xyz_cartn(self):
        return [atom.xyz_cartn.tolist() for atom in self[:]]
    def _set_xyz_cartn(self, xyzList):
        for atom,xyz_cartn in zip(self, xyzList):
            atom.xyz_cartn = xyz_cartn
    xyz_cartn = property(_get_xyz_cartn, _set_xyz_cartn, doc =
        """absolute Cartesian coordinates of all atoms""" )   
 
    # symbols
    def _get_symbols(self):
        return [atom.symbol for atom in self[:]]
    def _set_symbols(self, symbolList):
        for atom,symbol in zip(self, symbolList):
            atom.symbol = symbol
    symbols = property(_get_symbols, _set_symbols, doc =
        """symbols of all atoms""" )  
 
    # forces
    def _get_forces(self):
        return [atom.force for atom in self]
    def _set_forces(self, forceList):
        for atom,force in zip(self, forceList):
            atom.force = force
    forces = property(_get_forces, _set_forces, doc =
        """forces on all atoms""" )   
 
    # charges
    def _get_charges(self):
        return [atom.charge for atom in self]
    def _set_charges(self, chargeList):
        for atom,charge in zip(self, chargeList):
            atom.charge = charge
    charges = property(_get_charges, _set_charges, doc =
        """charges on all atoms in electron units""" )   
 
################################################################################################    
# geometry and symmetry methods--these should be farmed out to Geometry class which does all this--see vimm
################################################################################################ 
 
    def distance(self, id0, id1):
        """Distance between 2 atoms, no periodic boundary conditions.
 
        id0 -- zero based index of the first atom or a string label
               such as "Na1"
        id1 -- zero based index or string label of the second atom.
 
        Return float.
        Raise ValueError for invalid arguments.
        """
        a0 = self.getAtom(id0)
        a1 = self.getAtom(id1)
        return self.lattice.dist(a0.xyz, a1.xyz)
 
    def angle(self, a0, a1, a2):
        """angle at atom a1 in degrees"""
        u10 = a0.xyz - a1.xyz
        u12 = a2.xyz - a1.xyz
        return self.lattice.angle(u10, u12)
 
    def computeDistances(self, maxdist=30, latticeRange=[2,2,2]):
        """ unitcell.computeDistances(self, [nx,ny,nz]):
        builds up a Big multiple dictionary, namely
        self.distances[atom1][atom2][(DX,DY,DZ)]
        (DX,DY,DZ) are integer numbers specifying the cell containing atom2,
        relatively to atom1.
        DX,DY,DZ run from -nx to nx, -ny to ny, -nz to nz, respectively."""
        distances = {}
        idlist = self.getLabels()
        #idlist = self.getIds()
        for iA in range(0, len(idlist)):
            idA = idlist[iA]
            distances[idA] = {}
            for iB in range(0, len(idlist)):
                idB = idlist[iB]
                distances[idA][idB]={}
                for tx in range(-latticeRange[0],latticeRange[0]+1):
                    for ty in range(-latticeRange[1],latticeRange[1]+1):
                        for tz in range(-latticeRange[2],latticeRange[2]+1):
                            posA = self.getCartesianPosition(idA)
                            posB = self.getCartesianPosition(idB) + numpy.dot([tx,ty,tz], self._lattice)
                            dist = numpy.sqrt(numpy.sum( (posB-posA) * (posB-posA) ))
                            if(dist<maxdist):
                                distances[idA][idB][(tx,ty,tz)] = dist
        self._distances = distances
        return distances
 
###########################
# IO
###########################
    def read(self, filename, format='auto'):
        """Load structure from a file, any original data may become lost.

        filename -- file to be loaded
        format   -- all structure formats are defined in Parsers submodule,
                    when format == 'auto' all Parsers are tried one by one

        After parsing, this instance is reinitialized and takes over the
        attributes and atoms of the parsed structure, if there is one.

        Return instance of data Parser used to process file.  This
        can be inspected for information related to particular format.
        """
        from Parsers import getParser
        parser = getParser(format)
        parsed = parser.parseFile(filename)
        # parsing succeeded, start from a clean state; Molecule.__init__ is
        # named explicitly to avoid calling __init__ from a derived class
        Molecule.__init__(self)
        if parsed is not None:
            self.__dict__.update(parsed.__dict__)
            self[:] = parsed
        if not self.description:
            self.generateDescription()
        return parser
 
    def readStr(self, s, format='auto'):
        """Load structure from a string, any original data become lost.

        s        -- string with structure definition
        format   -- all structure formats are defined in Parsers submodule,
                    when format == 'auto' all Parsers are tried one by one

        After parsing, this instance is reinitialized and takes over the
        attributes and atoms of the parsed structure, if there is one.

        Return instance of data Parser used to process input string.  This
        can be inspected for information related to particular format.
        """
        from Parsers import getParser
        parser = getParser(format)
        parsed = parser.parse(s)
        # parsing succeeded, start from a clean state; Molecule.__init__ is
        # named explicitly to avoid calling __init__ from a derived class
        Molecule.__init__(self)
        if parsed is not None:
            self.__dict__.update(parsed.__dict__)
            self[:] = parsed
        return parser
 
    def write(self, filename, format):
        """Save molecule to file in the specified format

        filename -- name of the file to be written; it is also stored
                    in the filename attribute of the parser
        format   -- name of the structure format passed to getParser

        The molecule is converted with the tostring method of the parser
        and the result is written to a file opened in binary mode.

        No return value.

        Note: available structure formats can be obtained by:
            from Parsers import formats
        """
        from Parsers import getParser
        p = getParser(format)
        p.filename = filename
        # convert before opening, so that a failed conversion does not
        # leave an empty file behind
        s = p.tostring(self)
        f = open(filename, 'wb')
        try:
            f.write(s)
        finally:
            # close the file even when the write fails
            f.close()
        return
 
    def writeStr(self, format):
        """Return string representation of the molecule in specified format.

        format -- name of the structure format passed to getParser

        Return the result of the tostring method of the parser.

        Note: available structure formats can be obtained by:
            from Parsers import formats
        """
        from Parsers import getParser
        parser = getParser(format)
        return parser.tostring(self)
 
    ##############################################################################
    # overloaded list methods
    ##############################################################################
 
    def append(self, a, copy=True):
        """Append atom to a structure and update its lattice attribute.
 
        a    -- instance of Atom
        copy -- flag for appending a copy of a.
                When False, append a and update a.owner.
 
        No return value.
        """
        self._uncache('labels')
        adup = copy and Atom(a) or a
        adup.lattice = self.lattice
        list.append(self, adup)
        return
 
    def insert(self, idx, a, copy=True):
        """Insert atom a before position idx in this Structure.
 
        idx  -- position in atom list
        a    -- instance of Atom
        copy -- flag for inserting a copy of a.
                When False, append a and update a.lattice.
 
        No return value.
        """
        self._uncache('labels')
        adup = copy and Atom(a) or a
        adup.lattice = self.lattice
        list.insert(self, idx, adup)
        return
 
    def extend(self, atoms, copy=True):
        """Extend Structure by appending copies from a list of atoms.
 
        atoms -- list of Atom instances
        copy  -- flag for extending with copies of Atom instances.
                 When False extend with atoms and update their lattice
                 attributes.
 
        No return value.
        """
        self._uncache('labels')
        if copy:    adups = [Atom(a) for a in atoms]
        else:       adups = atoms
        for a in adups: a.lattice = self.lattice
        list.extend(self, adups)
        return
 
    def __setitem__(self, idx, a, copy=True):
        """Set idx-th atom to a.
 
        idx  -- index of atom in this Structure
        a    -- instance of Atom
        copy -- flag for setting to a copy of a.
                When False, set to a and update a.lattice.
 
        No return value.
        """
        self._uncache('labels')
        adup = copy and Atom(a) or a
        adup.lattice = self.lattice
        list.__setitem__(self, idx, adup)
        return
 
    def __setslice__(self, lo, hi, atoms, copy=False):
        """Set Structure slice from lo to hi-1 to the sequence of atoms.
 
        lo    -- low index for the slice
        hi    -- high index of the slice
        atoms -- sequence of Atom instances
        copy  -- flag for using copies of Atom instances.  When False, set
                 to existing instances and update their lattice attributes.
 
        No return value.
        """
        self._uncache('labels')
        if copy:    
            adups = [Atom(a) for a in atoms]
        else:       
            adups = atoms
        for a in adups: a.lattice = self.lattice
        list.__setslice__(self, lo, hi, adups)
 
    ####################################################################
    # property handlers
    ####################################################################
 
    # lattice
    def _get_lattice(self):
        if not hasattr(self, '_lattice'):
            self._lattice = Lattice()
        return self._lattice
    def _set_lattice(self, value):
        for a in self:  a.lattice = value
        self._lattice = value
        return
    lattice = property(_get_lattice, _set_lattice, doc =
        "Coordinate system for this Structure.")
 
    ####################################################################
    # protected methods
    ####################################################################
 
    def _update_labels(self):
        """Update the _labels dictionary of unique string labels of atoms.
 
        No return value.
        """
        kv = zip(self.getLabels(), self[:])
        self._labels = dict(kv)
        self._labels_cached = True
 
    def _uncache(self, *args):
        """Reset cached flag for a list of internal attributes.
 
        *args -- list of strings, currently supported are "labels"
 
        No return value.
        Raise AttributeError for any invalid args.
        """
        for a in args:
            attrname = "_" + a + "_cached"
            setattr(self, attrname, False)
 
if __name__=='__main__':
    # Smoke test: print an empty molecule, then one built from three atoms.
    # print is called with a single parenthesized argument, which gives the
    # same output whether print is a statement or a function.
    m = Molecule()
    print(m)
    o1 = Atom('O',[0.0, 0.0, 0.0])
    print(o1.xyz_cartn)
    h1 = Atom('H',[0.5, 0.0, 0.0])
    h2 = Atom('H',[0.0, 0.5, 0.0])
    m = Molecule([o1,h1,h2])
    print(m)