• Facebook
  • Twitter
  • Reddit
  • StumbleUpon
  • Digg
  • email

# particle_worker.py
 
# By Nathan C. Hearn
#    August 1, 2007
#
# Default worker class
 
 
from particle_reader import ParticleReader
 
 
# Message tags labelling the three messages exchanged per particle between
# neighbouring MPI ranks: the particle tag, its time list, and its data rows.
MPI_Tag_tag = 0
MPI_Tag_times = 1
MPI_Tag_part_data = 2

# Negative value sent in place of a particle tag to mark the end of the run.
MPI_Signal_end = -1
 
 
class ParticleWorker:
    """Collect per-particle time histories from a numbered series of files.

    Each file is opened with ``ParticleReader`` using the name
    ``file_prefix`` followed by a zero-padded, four-digit file index.  For
    every particle tag the worker gathers the requested variables from each
    file that contains the particle and writes one text file per particle,
    named ``output_prefix`` + seven-digit tag + ``.txt``.

    If a module named ``mpi`` can be imported, the files are divided among
    the ranks and the ranks form a pipeline: rank 0 starts each particle
    with the data from its own files and sends it to rank 1, every
    following rank appends the data from its files and forwards the result,
    and the last rank writes the output file.  Without ``mpi`` a single
    process does all of the work.

    NOTE(review): the ``mpi`` module is assumed to expose ``rank``,
    ``size``, ``send(obj, dest, tag)`` and ``recv(tag=...)`` returning an
    ``(obj, status)`` pair, as the calls below require -- confirm against
    the MPI binding in use.
    """

    def __init__(self, file_prefix=None, start_index=None, end_index=None,
                 part_data_names=None, output_prefix=None):
        """Detect MPI and configure the worker; see ``reset`` for arguments."""

        self.__clear_state()

        self.__use_mpi = False
        self.__mpi_rank = 0
        self.__mpi_size = 0

        self.__setup_mpi()

        self.reset(file_prefix, start_index, end_index, part_data_names,
                   output_prefix)

    def reset(self, file_prefix=None, start_index=None, end_index=None,
              part_data_names=None, output_prefix=None):
        """Discard the current configuration and optionally load a new one.

        Arguments:
            file_prefix: leading part of every input file name.
            start_index: index of the first file of the whole series.
            end_index: index of the last file of the whole series
                (inclusive).
            part_data_names: sequence of variable names to extract.
            output_prefix: leading part of the output file names; defaults
                to ``str(file_prefix)``.

        With the first four arguments all left as None the worker is simply
        cleared.  Supplying only some of them raises RuntimeError.

        Raises:
            RuntimeError: if only some of the required arguments are given,
                or if a requested variable is not present in the data.
            ValueError: if ``end_index`` is less than ``start_index``.
        """

        self.__clear_state()

        required = (file_prefix, start_index, end_index, part_data_names)

        if all(arg is None for arg in required):
            return

        if any(arg is None for arg in required):
            raise RuntimeError("All arguments must be specified")

        self.__setup_indexes(start_index, end_index)

        self.__file_prefix = file_prefix

        if output_prefix is None:
            self.__output_prefix = str(file_prefix)
        else:
            self.__output_prefix = output_prefix

        for offset in range(self.__num_files):
            filename = self.__build_filename(self.__start_file_index + offset)

            reader = ParticleReader(filename)

            self.__part_data.append(reader)
            self.__part_data_times.append(reader.get_sim_time())

        self.__num_variables = len(part_data_names)

        # NOTE: HERE WE ASSUME THE DATA APPEARS IN THE SAME ORDER EACH FILE,
        # so the column indexes are looked up once, in the last file loaded.

        if self.__part_data:
            reference = self.__part_data[-1]
        else:
            # An MPI rank can end up with no files when there are more ranks
            # than files.  It still needs the variable names (the last rank
            # writes the column header) but never uses the column indexes.
            reference = None

        for name in part_data_names:

            if reference is not None:

                if not reference.variable_present(name):
                    raise RuntimeError("Variable [ " + name + " ] not present")

                self.__part_data_indexes.append(
                    reference.get_variable_index(name))

            self.__part_data_names.append(name)

    def get_init_tags(self):
        """Return the tags of all particles in this worker's first file.

        An empty list is returned when no files are loaded.
        """

        if not self.__part_data:
            return list()

        dataset = self.__part_data[0]

        return [dataset.get_particle_tag(index)
                for index in range(dataset.get_num_particles())]

    def collect_part_data(self, particle_tag):
        """Gather the selected variables for one particle from local files.

        Returns a pair ``(times, part_data)``: ``times`` holds the
        simulation time of every local file containing the particle, and
        ``part_data`` holds, for each of those files, a list with the value
        of each selected variable in the order the names were given.
        """

        times = list()
        part_data = list()

        for dataset, sim_time in zip(self.__part_data,
                                     self.__part_data_times):

            if not dataset.particle_present(particle_tag):
                continue

            record = dataset.get_particle_data_tag(particle_tag)

            times.append(sim_time)
            part_data.append([record[data_index]
                              for data_index in self.__part_data_indexes])

        return times, part_data

    def run_loop(self, tag_list=None):
        """Process every particle and write (or forward) its history.

        On rank 0, or without MPI, the particles in ``tag_list`` are
        processed; when ``tag_list`` is None the tags of the first local
        file are used.  On any other rank ``tag_list`` is ignored: the rank
        receives partial results from the previous rank until a negative
        tag arrives, appends the data of its own files and hands the result
        on.
        """

        if (self.__mpi_rank == 0) or (not self.__use_mpi):

            if tag_list is None:
                tag_list = self.get_init_tags()

            for tag in tag_list:
                times, part_data = self.collect_part_data(tag)
                self.__handle_data(tag, times, part_data)

            # Tell the downstream ranks that no more particles will follow.
            self.__handle_data(MPI_Signal_end)

            return

        while True:

            tag, times, part_data = self.__receive_data()

            if tag < 0:
                # Pass the end signal down the pipeline, then stop.
                self.__handle_data(tag)
                break

            new_times, new_part_data = self.collect_part_data(tag)

            times.extend(new_times)
            part_data.extend(new_part_data)

            self.__handle_data(tag, times, part_data)

    def __clear_state(self):
        """Reset every file- and variable-related attribute to empty."""

        self.__file_prefix = None
        self.__output_prefix = None

        self.__start_file_index = 0
        self.__num_files = 0

        self.__part_data = list()
        self.__part_data_times = list()

        self.__num_variables = 0
        self.__part_data_names = list()
        self.__part_data_indexes = list()

    def __build_filename(self, index):
        """Return the input file name for file number ``index``."""

        return self.__file_prefix + ("%(idx)04d" % {"idx": index})

    def __build_output_filename(self, tag):
        """Return the output file name for the particle with tag ``tag``."""

        return (self.__output_prefix + ("%(tag)07d" % {"tag": tag})
                + ".txt")

    def __handle_data(self, tag, times=None, part_data=None):
        """Write the particle's data, or send it to the next MPI rank.

        The last rank (or the only process when MPI is unused) writes the
        output file for a non-negative ``tag`` and ignores a negative one.
        Every other rank sends ``tag`` to rank + 1, followed by ``times``
        and ``part_data`` when ``tag`` is non-negative.
        """

        if ((not self.__use_mpi)
                or (self.__mpi_rank == (self.__mpi_size - 1))):

            if tag >= 0:
                self.__write_data(tag, times, part_data)

            return

        import mpi

        next_rank = self.__mpi_rank + 1

        mpi.send(tag, next_rank, MPI_Tag_tag)

        if tag >= 0:
            mpi.send(times, next_rank, MPI_Tag_times)
            mpi.send(part_data, next_rank, MPI_Tag_part_data)

    def __receive_data(self):
        """Receive one particle's partial data from an upstream rank.

        Returns ``(tag, times, part_data)``.  ``times`` and ``part_data``
        are None when the received tag is negative.  Without MPI, or on
        rank 0, nothing is received and ``(-1, None, None)`` is returned.
        """

        tag = -1

        times = None
        part_data = None

        if self.__use_mpi and (self.__mpi_rank > 0):

            import mpi

            tag, status = mpi.recv(tag=MPI_Tag_tag)

            if tag >= 0:
                times, status = mpi.recv(tag=MPI_Tag_times)
                part_data, status = mpi.recv(tag=MPI_Tag_part_data)

        return tag, times, part_data

    def __write_data(self, tag, times, part_data):
        """Write one particle's time history to its output text file.

        The file starts with a ``# Particle [ tag ]`` line, a blank line
        and a ``# Time name1 name2 ...`` header, followed by one
        space-separated line per time and a final blank line.

        Raises:
            RuntimeError: if ``times`` and ``part_data`` differ in length.
        """

        if len(part_data) != len(times):
            raise RuntimeError("Particle data does not match times")

        num_variables = self.__num_variables

        header = "# Time" + "".join(
            " " + name for name in self.__part_data_names[:num_variables])

        lines = ["# Particle [ " + str(tag) + " ]\n\n", header + "\n"]

        for sim_time, timestep_data in zip(times, part_data):
            fields = [str(sim_time)]
            fields.extend(str(value)
                          for value in timestep_data[:num_variables])
            lines.append(" ".join(fields) + "\n")

        lines.append("\n")

        # The context manager closes the file even if a write fails.
        with open(self.__build_output_filename(tag), "w") as output:
            output.writelines(lines)

    def __setup_mpi(self):
        """Enable MPI mode when the ``mpi`` module can be imported."""

        try:
            import mpi
        except ImportError:
            self.__use_mpi = False
            return

        self.__use_mpi = True

        self.__mpi_rank = mpi.rank
        self.__mpi_size = mpi.size

    def __setup_indexes(self, start_index, end_index):
        """Choose which files of ``start_index..end_index`` this rank reads.

        Without MPI (or with a single rank) the whole range is used.
        Otherwise the files are split into contiguous runs, one per rank,
        in rank order.

        Raises:
            ValueError: if ``end_index`` is less than ``start_index``.
        """

        total_files = (end_index + 1) - start_index

        if total_files < 1:
            raise ValueError("Last file index can not be less than the "
                             + "first index")

        if (not self.__use_mpi) or (self.__mpi_size == 1):

            self.__start_file_index = start_index
            self.__num_files = total_files

            return

        default_num_files, leftover_files = divmod(total_files,
                                                   self.__mpi_size)

        # The first leftover_files processes get (default_num_files + 1)
        # files, while the remaining processes get default_num_files files.

        mpi_rank = self.__mpi_rank

        # Number of one-file extras handed to the ranks before this one.
        extras_before = min(mpi_rank, leftover_files)

        self.__start_file_index = (start_index
                                   + (mpi_rank * default_num_files)
                                   + extras_before)

        if mpi_rank < leftover_files:
            self.__num_files = default_num_files + 1
        else:
            self.__num_files = default_num_files
 
 
if __name__ == "__main__":

    # Command-line driver: build a worker for the requested file series and
    # variables, then process every particle.
    #
    #   argv[1]   filename prefix
    #   argv[2]   index of the first file
    #   argv[3]   index of the last file
    #   argv[4:]  one or more variable names

    from sys import argv, exit, stdout, stderr

    min_args = 5
    argc = len(argv)

    if not (argc >= min_args):
        usage = ("\n  Usage: " + argv[0] + " filename_prefix "
                 + "start_file_index end_file_index\n"
                 + "      data_name1 [ data_name2 ... ]\n\n")
        stderr.write(usage)
        exit(1)

    filename_prefix = str(argv[1])
    start_index = int(argv[2])
    end_index = int(argv[3])

    # Everything after the three fixed arguments is a variable name.
    data_names = [str(arg) for arg in argv[4:]]

    worker = ParticleWorker(filename_prefix, start_index, end_index,
                            data_names)

    worker.run_loop()