import csv
try:
    import numpy
except ImportError: # pragma: no cover
    numpy = None
 
from ..common import TableObject, Table, InternalModuleError
 
 
def count_lines(fp):
    """Return the number of lines obtained by iterating over `fp`.

    `fp` is consumed by the iteration; an empty iterable gives 0.
    """
    return sum(1 for _ in fp)
 
 
# FIXME : test coverage for CSVTable
class CSVTable(TableObject):
    """Table backed by a CSV file, whose columns are read on demand.

    The constructor only looks at the beginning of the file (through
    `read_file()`) to determine the delimiter, the number of columns and the
    column names. Values are read from the file by `get_column()` and kept
    in `column_cache`.

    Attributes set by the constructor:
      * filename: path of the CSV file
      * columns: number of fields of the first row after the skipped lines
      * names: stripped fields of that row if it is a header, else None
      * delimiter, dialect, header_present: as returned by `read_file()`
      * skip_lines: number of lines preceding the data; this is the
        `skip_lines` argument, plus one if a header line is present
    """
    def __init__(self, csv_file, header_present, delimiter,
                 skip_lines=0, dialect=None, use_sniffer=True):
        """Inspect `csv_file` and record its layout.

        Arguments are passed on to `read_file()`; see there for their
        meaning and for the errors that can be raised.
        """
        self._rows = None  # computed lazily by the 'rows' property

        self.filename = csv_file
        self.skip_lines = skip_lines

        (self.columns, self.names, self.delimiter,
         self.header_present, self.dialect) = \
            self.read_file(csv_file, delimiter, header_present, skip_lines,
                           dialect, use_sniffer)
        if self.header_present:
            # The header line is not data, readers have to skip it too
            self.skip_lines += 1

        # Maps (index, numeric) to the value returned by get_column()
        self.column_cache = {}

    @staticmethod
    def _skip(fp, count):
        """Read and discard `count` lines from `fp`.

        Raises InternalModuleError if the file ends before `count` lines
        have been read.
        """
        for _ in xrange(count):
            if not fp.readline():
                raise InternalModuleError("skip_lines greater than "
                                          "the number of lines in the "
                                          "file")

    @staticmethod
    def read_file(filename, delimiter=None, header_present=True,
                  skip_lines=0, dialect=None, use_sniffer=True):
        """Read the first row of a CSV file to find out its layout.

        filename: path of the file to read.
        delimiter: field delimiter; if None, it is guessed with csv.Sniffer
            (which requires `use_sniffer`).
        header_present: whether the first row (after the skipped lines)
            holds the column names. If None and the delimiter is being
            sniffed, this is guessed with csv.Sniffer.has_header().
        skip_lines: number of lines to ignore at the top of the file.
        dialect: dialect passed to csv.reader(); replaced by the sniffed
            dialect when the delimiter is guessed.
        use_sniffer: allow the use of csv.Sniffer.

        Returns the tuple (column_count, column_names, delimiter,
        header_present, dialect), where column_names is None if there is no
        header.

        Raises InternalModuleError if no delimiter is available, if the file
        cannot be opened or read (IOError), if it has fewer lines than
        `skip_lines`, or if nothing is left to read after the skipped lines.
        A csv.Error raised by the sniffer on a non-empty sample is not
        caught here.
        """
        if delimiter is None and not use_sniffer:
            raise InternalModuleError("Must set delimiter if not using sniffer")

        try:
            with open(filename, 'rb') as fp:
                if use_sniffer and delimiter is None:
                    # Sniff on (at most) the 5 first lines after the skipped
                    # ones
                    CSVTable._skip(fp, skip_lines)
                    sample_lines = []
                    for _ in xrange(5):
                        line = fp.readline()
                        if not line:
                            break
                        sample_lines.append(line)
                    sample = ''.join(sample_lines)
                    if not sample:
                        raise InternalModuleError("No data found in file "
                                                  "after the skipped lines")
                    sniffer = csv.Sniffer()
                    dialect = sniffer.sniff(sample)
                    delimiter = dialect.delimiter
                    # cannot determine header without sniffing delimiter
                    if header_present is None:
                        header_present = sniffer.has_header(sample)
                    fp.seek(0)

                # Each skipped line is read before being checked, so that no
                # state left over from the sniffing step is relied upon
                CSVTable._skip(fp, skip_lines)

                if dialect is not None:
                    reader = csv.reader(fp, dialect=dialect)
                else:
                    reader = csv.reader(fp, delimiter=delimiter)
                try:
                    result = next(reader)
                except StopIteration:
                    raise InternalModuleError("No data found in file "
                                              "after the skipped lines")
                column_count = len(result)

                if header_present:
                    column_names = [name.strip() for name in result]
                else:
                    column_names = None
        except IOError:
            raise InternalModuleError("File does not exist")

        return column_count, column_names, delimiter, header_present, dialect

    def get_column(self, index, numeric=False):
        """Return the values of the column at position `index`.

        If `numeric` is true and numpy is available, the column is loaded
        with numpy.loadtxt() as a float32 array (using `self.delimiter`
        only, not the dialect). Otherwise the file is read with csv.reader()
        and a list of strings is returned, converted to floats if `numeric`
        is true.

        Results are cached per (index, numeric) pair.

        Raises InternalModuleError if the file has fewer lines than
        `self.skip_lines`.
        """
        key = (index, numeric)
        if key in self.column_cache:
            return self.column_cache[key]

        if numeric and numpy is not None:
            result = numpy.loadtxt(
                    self.filename,
                    dtype=numpy.float32,
                    delimiter=self.delimiter,
                    skiprows=self.skip_lines,
                    usecols=[index])
        else:
            with open(self.filename, 'rb') as fp:
                self._skip(fp, self.skip_lines)
                if self.dialect is not None:
                    reader = csv.reader(fp, dialect=self.dialect)
                else:
                    reader = csv.reader(fp, delimiter=self.delimiter)

                result = [row[index] for row in reader]
            if numeric:
                result = [float(e) for e in result]

        self.column_cache[key] = result
        return result

    @property
    def rows(self):
        """Number of lines in the file, minus `self.skip_lines`.

        The file is only read the first time this is accessed.
        """
        if self._rows is not None:
            return self._rows
        with open(self.filename, 'rb') as fp:
            self._rows = count_lines(fp)
        self._rows -= self.skip_lines
        return self._rows
 
 
class CSVFile(Table):
    """Reads a table from a CSV file.

    This module uses Python's csv module to read a table from a file. It is
    able to guess the actual format of the file in most cases, or you can use
    the 'delimiter', 'header_present' and 'skip_lines' ports to force how the
    file will be read.
    """
    _input_ports = [
        ('file', '(org.vistrails.vistrails.basic:File)'),
        ('delimiter',
         '(org.vistrails.vistrails.basic:String)',
         {'optional': True}),
        ('header_present',
         '(org.vistrails.vistrails.basic:Boolean)',
         {'optional': True, 'defaults': "['True']"}),
        ('sniff_header',
         '(org.vistrails.vistrails.basic:Boolean)',
         {'optional': True, 'defaults': "['True']"}),
        ('skip_lines',
         '(org.vistrails.vistrails.basic:Integer)',
         {'optional': True, 'defaults': "['0']"}),
        ('dialect',
         '(org.vistrails.vistrails.basic:String)',
         {'optional': True}),
    ]
    _output_ports = [
        ('column_count', '(org.vistrails.vistrails.basic:Integer)'),
        ('column_names', '(org.vistrails.vistrails.basic:List)'),
        ('value', Table),
    ]

    def compute(self):
        """Build a CSVTable from the input ports and set the outputs.

        An InternalModuleError raised while building the table is handed to
        its raise_module_error() method, together with this module.
        """
        path = self.get_input('file').name
        has_header = self.force_get_input('header_present', None)
        sep = self.force_get_input('delimiter', None)
        to_skip = self.get_input('skip_lines')
        csv_dialect = self.force_get_input('dialect', None)
        # The 'sniff_header' port is what enables CSVTable's sniffer
        sniff = self.get_input('sniff_header')

        try:
            table = CSVTable(path, has_header, sep,
                             skip_lines=to_skip,
                             dialect=csv_dialect,
                             use_sniffer=sniff)
        except InternalModuleError as err:
            err.raise_module_error(self)

        self.set_output('column_count', table.columns)
        self.set_output('column_names', table.names)
        self.set_output('value', table)
 
 
# Module classes defined in this file; presumably collected by the package
# to register them with VisTrails -- confirm against the package's init.
_modules = [CSVFile]
 
 
###############################################################################
 
from StringIO import StringIO
import unittest
from vistrails.tests.utils import execute, intercept_result
from ..identifiers import identifier
from ..common import ExtractColumn
 
 
class CSVTestCase(unittest.TestCase):
    """Runs small pipelines made of CSVFile and ExtractColumn modules."""

    @classmethod
    def setUpClass(cls):
        import os
        here = os.path.dirname(__file__)
        cls._test_dir = os.path.join(here, os.pardir, 'test_files')

    def test_csv_numeric(self):
        """Uses CSVFile and ExtractColumn to load a numeric array.
        """
        # Here we use a PythonSource just to check that a numpy array
        # can be passed on a List port
        modules = [
            ('read|CSVFile', identifier, [
                ('file', [('File', self._test_dir + '/test.csv')]),
            ]),
            ('ExtractColumn', identifier, [
                ('column_index', [('Integer', '1')]),
                ('column_name', [('String', 'col 2')]),
                ('numeric', [('Boolean', 'True')]),
            ]),
            ('PythonSource', 'org.vistrails.vistrails.basic', [
                ('source', [('String', '')]),
            ]),
        ]
        connections = [
            (0, 'value', 1, 'table'),
            (1, 'value', 2, 'l'),
        ]
        port_specs = [
            (2, 'input', 'l', 'org.vistrails.vistrails.basic:List'),
        ]

        with intercept_result(ExtractColumn, 'value') as results:
            with intercept_result(CSVFile, 'column_count') as columns:
                failed = execute(modules, connections,
                                 add_port_specs=port_specs)
                self.assertFalse(failed)

        self.assertEqual(columns, [3])
        self.assertEqual(len(results), 1)
        self.assertEqual(list(results[0]), [2.0, 3.0, 14.5])

    def test_csv_mismatch(self):
        """Uses CSVFile and ExtractColumn with mismatching columns.
        """
        modules = [
            ('read|CSVFile', identifier, [
                ('file', [('File', self._test_dir + '/test.csv')]),
            ]),
            ('ExtractColumn', identifier, [
                # index is wrong
                ('column_index', [('Integer', '0')]),
                ('column_name', [('String', 'col 2')]),
            ]),
        ]
        connections = [
            (0, 'value', 1, 'table'),
        ]
        failed = execute(modules, connections)
        self.assertTrue(failed)

    def test_csv_missing(self):
        """Uses CSVFile and ExtractColumn with a nonexisting column.
        """
        modules = [
            ('read|CSVFile', identifier, [
                ('file', [('File', self._test_dir + '/test.csv')]),
            ]),
            ('ExtractColumn', identifier, [
                ('column_name', [('String', 'col not here')]),
            ]),
        ]
        connections = [
            (0, 'value', 1, 'table'),
        ]
        failed = execute(modules, connections)
        self.assertTrue(failed)

    def test_csv_nonnumeric(self):
        """Uses CSVFile and ExtractColumn to load strings.
        """
        modules = [
            ('read|CSVFile', identifier, [
                ('file', [('File', self._test_dir + '/test.csv')]),
                ('header_present', [('Boolean', 'False')]),
            ]),
            ('ExtractColumn', identifier, [
                ('column_index', [('Integer', '2')]),
                ('numeric', [('Boolean', 'False')]),
            ]),
        ]
        connections = [
            (0, 'value', 1, 'table'),
        ]

        with intercept_result(ExtractColumn, 'value') as results:
            failed = execute(modules, connections)
            self.assertFalse(failed)

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0],
                         ['col moutarde', '4', 'not a number', '7'])
 
 
class TestCountlines(unittest.TestCase):
    """Checks count_lines() on small in-memory files."""

    def test_countlines(self):
        cases = [
            # Simple
            ("first\nsecond", 2),
            # With newline at EOF
            ("first\nsecond\n", 2),
            # Empty
            ("", 0),
            # Single newline
            ("\n", 1),
        ]
        for content, expected in cases:
            self.assertEqual(count_lines(StringIO(content)), expected)