# -*- coding: utf-8 -*-
# legume. Copyright 2009-2013 Dale Reidy. All rights reserved.
# See LICENSE for details.
 
__docformat__ = 'restructuredtext'
 
import struct
import random
import logging
from legume import netshared
from legume import timing as time
from legume.nevent import Event
from legume.pingsampler import PingSampler
from legume.bitfield import bitfield
from legume.bytebuffer import ByteBuffer
from legume import messages
import errno
 
# Minimum interval between Ping messages sent by Connection.update(), in the
# units returned by legume.timing.time() (presumably seconds - confirm).
PING_REQUEST_FREQUENCY = 2.0
# Simulated packet loss used by Connection._do_write(): a packet is dropped
# when random.randint(1, 100) is not greater than this value. 0 disables the
# simulation.
CONNECTION_LOSS = 0
 
class OutgoingMessage(object):
    '''
    One serialised message waiting in a Connection's outgoing list,
    together with the bookkeeping needed to decide when to (re)send it.
    '''

    def __init__(self, message_id, message_bytes, require_ack):
        # Timestamp of the most recent send attempt. It stays None until
        # the message has been placed into a packet; after that it holds
        # the time.time() value recorded when the message was last sent.
        # For messages that require an ack, Connection._create_packet
        # uses it to hold back a resend until the resend delay has passed.
        self.last_send_attempt_timestamp = None

        self.require_ack = require_ack
        self.message_bytes = message_bytes
        self.message_id = message_id

    @property
    def length(self):
        '''Size of message_bytes, in bytes.'''
        payload = self.message_bytes
        return len(payload)
 
class Connection(object):
 
    # Upper bound, in bytes, for a single queued message (transport header
    # included) and for the payload _create_packet assembles per packet.
    MTU = 1400
    # struct layout of the per-message transport header as packed by
    # send_message: message id (H), ordered sequence number (H), flags (B).
    MESSAGE_TRANSPORT_HEADER = 'HHB'
    # Number of most recently seen message ids kept for duplicate detection.
    RECENT_MESSAGE_LIST_SIZE = 1000
    # Lower bound on the delay between resends of an un-acked message.
    # NOTE(review): despite the _MS suffix the value is 0.01 and is added to
    # time.time() values in _create_packet - presumably seconds; confirm.
    MINIMUM_RESEND_DELAY_MS = 10 / 1000.0

    # Logger shared by every Connection instance.
    _log = logging.getLogger('legume.Connection')
 
    def __init__(self, parent=None, message_factory=None):
        '''
        Create the per-peer connection state.

        parent is the endpoint that owns this connection; update() uses its
        do_read, _socket, _address, is_server and timeout attributes.
        message_factory is used to create and identify messages; when it is
        None the parent's message_factory is used instead (so at least one
        of the two arguments must be supplied).
        '''
        if message_factory is None:
            self.message_factory = parent.message_factory
        else:
            self.message_factory = message_factory

        self.parent = parent
        self._last_receive_timestamp = time.time()
        self._last_send_timestamp = time.time()
        self._keep_alive_send_timestamp = time.time()
        self._keep_alive_message_id = 0

        # server: number of keepalives sent
        # client: number of keepalives received
        self._keepalive_count = 0

        self._ping_id = 0
        self._ping_send_timestamp = time.time()
        self._ping_meter = PingSampler()

        self.OnConnectRequestAccepted = Event()
        self.OnConnectRequestRejected = Event()
        self.OnConnectRequest = Event()
        self.OnError = Event()
        self.OnMessage = Event()
        self.OnDisconnect = Event()

        # Set to True by update() once a connect request has been rejected.
        # Initialised here so the attribute can always be read safely.
        self.pendingDisconnect = False

        # Packet instances to be processed go in here
        self._incoming_messages = []

        # List of OutgoingMessages
        self._outgoing = []

        # In-order packet instances that have arrived early
        self._incoming_out_of_sequence_messages = []

        # Both counters start at 0: send_message increments the outgoing
        # counter before using it, so the first ordered message carries
        # sequence number 1, which is what a receiver whose last delivered
        # sequence number is 0 expects next.
        self._incoming_ordered_sequence_number = 0
        self._outgoing_ordered_sequence_number = 0
        self._outgoing_message_id = 0
        self._recent_message_ids = []

        # Metrics
        self._in_bytes = 0
        self._out_bytes = 0
        self._in_packets = 0
        self._out_packets = 0
        self._in_messages = 0
        self._out_messages = 0

        # Default transport latency is high - This prevents spamming
        # of the network prior to obtaining a calculated latency.
        self._transport_latency = 0.3 # 0.1 = 100ms
 
    @property
    def out_buffer_bytes(self):
        return sum([len(o.message_bytes) for o in self._outgoing])
 
    @property
    def latency(self):
        return self._ping_meter.get_ping()
 
    @property
    def in_bytes(self):
        return self._in_bytes
 
    @property
    def out_bytes(self):
        return self._out_bytes
 
    @property
    def reorder_queue(self):
        return len(self._incoming_out_of_sequence_messages)
 
    @property
    def keepalive_count(self):
        return self._keepalive_count
 
    # ------------- Public Methods -------------
 
    def process_inbound_packet(self, data):
        self._in_packets += 1
        self._process_inbound_packet(data)
 
    def update(self):
        '''
        Pump the connection once.

        Reads pending socket data through parent.do_read, dispatches every
        message that became readable, flushes the outgoing list to the
        parent's socket and finally runs the ping / keep-alive / timeout
        housekeeping.

        Transport messages (connect handshake, keep-alive, ping/pong, ack,
        disconnect) are consumed here; any other message is passed to the
        OnMessage event. Failures are reported through the OnError event as
        (connection, description).
        '''
        try:
            self.parent.do_read(self._on_socket_data)
        except netshared.NetworkEndpointError:
            # Report through the OnError event (this class has no
            # raiseOnError method), using the same arguments as the
            # timeout error raised further down.
            self.OnError(self, 'Connection reset by peer')
            return

        # NOTE(review): ping samples are recorded multiplied by 1000 (see
        # the add_sample calls below) whereas the value is later added to
        # time.time() values in _create_packet - confirm the units returned
        # by PingSampler.get_ping().
        if self._ping_meter.has_estimate():
            self._transport_latency = self._ping_meter.get_ping()

        read_messages = self._update(
            self.parent._socket, self.parent._address)

        # Any readable message counts as a sign of life from the peer.
        if read_messages:
            self._last_receive_timestamp = time.time()

        is_a = self.message_factory.is_a

        for message in read_messages:

            if is_a(message, 'ConnectRequestAccepted'):
                self.OnConnectRequestAccepted(self, None)

            elif is_a(message, 'ConnectRequestRejected'):
                self.OnConnectRequestRejected(self, None)

            elif is_a(message, 'KeepAliveResponse'):
                # Only the response to the most recent request is usable
                # as a round-trip sample.
                if message.id.value == self._keep_alive_message_id:
                    self._ping_meter.add_sample(
                        (time.time() - self._keep_alive_send_timestamp) * 1000)
                else:
                    self._log.warning('Received old keep-alive, discarding')

            elif is_a(message, 'KeepAliveRequest'):
                self._keepalive_count += 1
                response = self.message_factory.get_by_name(
                    'KeepAliveResponse')()
                response.id.value = message.id.value
                self.send_message(response)

            elif is_a(message, 'Pong'):
                if message.id.value == self._ping_id:
                    self._ping_meter.add_sample(
                        (time.time() - self._ping_send_timestamp) * 1000)
                else:
                    self._log.warning('Received old Pong, discarding')

            elif is_a(message, 'Ping'):
                self._send_pong(message.id.value)

            elif is_a(message, 'Disconnected'):
                self._log.debug('Received `Disconnected` message')
                self.OnDisconnect(self, None)

            elif is_a(message, 'MessageAck'):
                self._process_message_ack(message.message_to_ack.value)

            elif is_a(message, 'ConnectRequest'):
                # Unless the connection request is explicitly denied then
                # a connection is made - OnConnectRequest may return None
                # if no event handlers are bound.
                accept = True

                if message.protocol.value != netshared.PROTOCOL_VERSION:
                    self._log.error('Invalid protocol version for client')
                    accept = False
                if self.OnConnectRequest(self.parent, message) is False:
                    accept = False

                if accept:
                    response = self.message_factory.get_by_name(
                        'ConnectRequestAccepted')
                    self.send_reliable_message(response())
                else:
                    response = self.message_factory.get_by_name(
                        'ConnectRequestRejected')
                    self.send_reliable_message(response())
                    self.pendingDisconnect = True
            else:
                self.OnMessage(self, message)

        if time.time() > self._ping_send_timestamp + PING_REQUEST_FREQUENCY:
            if self.parent.is_server:
                # On the server a ping also restarts the keep-alive timer.
                self._keep_alive_send_timestamp = time.time()
            self._send_ping()

        if self.parent.is_server:
            # Server sends keep alive requests...
            if ((time.time() - self._keep_alive_send_timestamp) >
               (self.parent.timeout/2)):
                self._send_keep_alive()

        # ...and both server and client give up once nothing has been
        # received for a full timeout period.
        if (time.time() - self._last_receive_timestamp) > self.parent.timeout:
            self._log.info('Connection has timed out')
            self.OnError(self, 'Connection timed out')
 
    def send_message(self, message, ordered=False, reliable=False):
        '''
        Queue a message for sending and specify the delivery options.

        message is an instance of a subclass of packets.BasePacket.
        ordered sends the message on the in-order channel; an ordered
        message is implicitly acknowledged and resent like a reliable one.
        reliable requests acknowledgement (and resends) without ordering.

        Returns the number of bytes added to the output queue for this
        message (header + message).
        Raises BufferError if header + message is larger than the MTU.
        '''
        self._last_send_timestamp = time.time()

        # The id is packed as an unsigned short ('H'), so wrap it in the
        # same way the keep-alive and ping ids are wrapped; otherwise
        # struct.pack raises struct.error on a long-lived connection.
        # (Assumes netshared.USHRT_MAX is the 'H' maximum - confirm.)
        self._outgoing_message_id += 1
        if self._outgoing_message_id > netshared.USHRT_MAX:
            self._outgoing_message_id = 0
        message_id = self._outgoing_message_id

        if ordered:
            # NOTE: not wrapped here - the receiver only accepts strictly
            # consecutive sequence numbers, so wrapping would need a
            # matching change on the receiving side.
            self._outgoing_ordered_sequence_number += 1
            inorder_sequence_number = self._outgoing_ordered_sequence_number
        else:
            # Unordered messages carry sequence number 0.
            inorder_sequence_number = 0

        packet_flags = bitfield()
        packet_flags[0] = int(ordered)
        packet_flags[1] = int(reliable)

        message_transport_header = struct.pack(
            '!'+self.MESSAGE_TRANSPORT_HEADER,
            message_id, inorder_sequence_number, int(packet_flags))

        message_bytes = message.get_packet_bytes()
        total_length = len(message_bytes)+len(message_transport_header)

        # Queue first: an oversized message raises BufferError here and
        # must not be counted in the outgoing byte total.
        self._add_message_bytes_to_output_list(
            message_id,
            message_transport_header+message_bytes,
            ordered or reliable)
        self._out_bytes += total_length

        self._log.debug('Packet data length = %s' % len(message_bytes))
        self._log.debug('Header length = %s' % len(message_transport_header))
        self._log.debug('Added %d byte %s packet in outgoing buffer' %
            (total_length, message.__class__.__name__))

        return total_length
 
    def send_reliable_message(self, message):
        '''
        Send a message that is guaranteed to be delivered.
        message is an instance of a subclass of packets.BasePacket
        '''
        self.send_message(message, False, True)
 
    def send_inorder_message(self, message):
        '''
        Send a message in the in-order channel. Any packets sent in-order will
        arrive in the order they were sent.
        message is an instance of a subclass of packets.BasePacket
        '''
        self.send_message(message, True)
 
    def has_outgoing_packets(self):
        '''
        Returns whether this buffer has any packets waiting to be sent.
        '''
        return len(self._outgoing) > 0
 
    # ------------- Private Methods -------------
 
    def _on_socket_data(self, data, addr):
        '''
        Callback handed to parent.do_read() by update(). Forwards the raw
        packet data to _process_inbound_packet; addr is accepted but not
        used.
        '''
        self._process_inbound_packet(data)
 
    def _send_keep_alive(self):
        '''
        Queue a KeepAliveRequest carrying the next keep-alive id (wrapping
        to 0 past netshared.USHRT_MAX), then record the send time and bump
        the keep-alive counter.
        '''
        next_id = self._keep_alive_message_id + 1
        if next_id > netshared.USHRT_MAX:
            next_id = 0
        self._keep_alive_message_id = next_id

        request_class = self.message_factory.get_by_name('KeepAliveRequest')
        request = request_class()
        request.id.value = self._keep_alive_message_id
        self.send_message(request)

        self._keep_alive_send_timestamp = time.time()
        self._keepalive_count += 1
 
    def _send_ping(self):
        '''
        Queue a Ping carrying the next ping id (wrapping to 0 past
        netshared.USHRT_MAX) and remember when it was sent so the matching
        Pong can be turned into a round-trip sample.
        '''
        next_id = self._ping_id + 1
        if next_id > netshared.USHRT_MAX:
            next_id = 0
        self._ping_id = next_id

        ping_class = self.message_factory.get_by_name('Ping')
        request = ping_class()
        request.id.value = self._ping_id
        self.send_message(request)

        self._ping_send_timestamp = time.time()
 
    def _send_pong(self, pingID):
        pong = self.message_factory.get_by_name('Pong')()
        pong.id.value = pingID
        self.send_message(pong)
 
    def _process_message_ack(self, message_id):
        for m in self._outgoing:
            if m.message_id == message_id:
                self._outgoing.remove(m)
                return
 
        self._log.warning('Got duplicate ACK for packet. message_id=%s' % (
            message_id))
 
 
    def _parse_packet(self, packet_bytes):
        '''
        Decode a raw udp packet into message instances.

        The packet is read as a sequence of (transport header, message)
        pairs until the buffer is exhausted. Each returned message is
        tagged with is_reliable, is_ordered, sequence_number and message_id
        taken from its transport header.
        '''
        reader = ByteBuffer(packet_bytes)
        decoded = []

        while not reader.is_empty():
            header = reader.read_struct(self.MESSAGE_TRANSPORT_HEADER)
            message_id, sequence_number, message_flags = header

            header_fields = messages.BaseMessage.read_header_from_byte_buffer(
                reader)
            message_class = self.message_factory.get_by_id(header_fields[0])
            message = message_class()
            message.read_from_byte_buffer(reader)

            # - These flags are for consumption by .update()
            flags = bitfield(message_flags)
            message.is_reliable = flags[1]
            message.is_ordered = flags[0]

            message.sequence_number = sequence_number
            message.message_id = message_id

            decoded.append(message)

        return decoded
 
    def _process_inbound_packet(self, packet_bytes):
        '''
        Pass raw udp packet data to this method.
        Returns the number of packets parsed and inserted into
        the .incoming list.
        '''
        self._log.debug('%d bytes of packet_bytes read' % len(packet_bytes))
        messages_to_read = self._parse_packet(packet_bytes)
 
        self._in_bytes += len(packet_bytes)
 
        self._log.debug('parsed %d messages from packet' % len(messages_to_read))
 
        for message in messages_to_read:
            if not message.message_id in self._recent_message_ids:
                self._log.debug('Message ordered flag %s' % str(message.is_ordered))
                if message.is_ordered:
                    if self._can_read_inorder_message(message.sequence_number):
                        self._insert_message(message)
                    else:
                        self._incoming_out_of_sequence_messages.append(message)
                        self._recent_message_ids.append(message.message_id)
 
                else:
                    self._insert_message(message)
 
        return len(messages_to_read)
 
    def _update(self, sock, address):
        '''
        Update this buffer by sending any messages in the output lists
        and read any messages which have been insert into the inputBuffer
        via the process_inbound_packet call.
 
        Returns a list of message instances of messages that were read.
        '''
        read_packets = self._do_read()
        self._truncate_recent_message_list()
        self._do_write(sock, address)
 
        return read_packets
 
    def _add_message_bytes_to_output_list(self, message_id,
                                     message_bytes, require_ack=False):
        if len(message_bytes) > self.MTU:
            raise BufferError('Packet is too large. size=%s, mtu=%s' % (
                len(message_bytes), self.MTU))
        else:
            self._outgoing.append(
                OutgoingMessage(message_id, message_bytes, require_ack))
 
    def _can_read_inorder_message(self, sequence_number):
        '''
        Can the in-order message with the specified sequence number be
        insert into the .incoming list for processing?
        '''
        return self._incoming_ordered_sequence_number == (sequence_number+1)
 
    def _create_packet(self):
        packet_size = 0
        packet_bytes = bytearray()
 
        sent_messages = []
 
        self._log.debug('%d packets pending' % len(self._outgoing))
 
        for message in self._outgoing:
 
            if message.require_ack:
 
                # a minimum resend delay is required for two reasons:
                # 1. deadlock with a 0ms latency connection causes _do_write
                #    to never exit as _create_packet always returns data.
                # 2. resending every 0ms is just plain stupid.
 
                t = time.time()
                resend_delay = max(self.MINIMUM_RESEND_DELAY_MS, self._transport_latency)
 
                self._log.debug('LSAT: %s' % str(message.last_send_attempt_timestamp))
                self._log.debug('l8nc: %s' % str(self._transport_latency))
                self._log.debug('time: %s' % t)
                self._log.debug('rsnd: %s' % resend_delay)
 
                if message.last_send_attempt_timestamp is not None:
                    if ((message.last_send_attempt_timestamp +
                      resend_delay) >= t):
                        self._log.debug('Waiting for ack.')
                        continue
 
            if packet_size + message.length <= self.MTU:
                self._log.debug('Added data message into UDP packet')
                packet_size += message.length
                packet_bytes += message.message_bytes
                message.last_send_attempt_timestamp = time.time()
                sent_messages.append(message)
            else:
                self._log.debug('packet at MTU limit.')
 
        for sent_message in sent_messages:
            # Packets that require an ack are only removed
            # from the outgoing list if an ack is received.
            if not sent_message.require_ack:
                self._log.info('Message %d doesnt require ack - removing' %
                    sent_message.message_id)
                self._outgoing.remove(sent_message)
            else:
                self._log.info(
                    'Message %d requires ack - waiting for response' %
                    sent_message.message_id)
 
        return packet_bytes
 
    def _do_read(self):
        unheld_messages = []
        for held_message in self._incoming_out_of_sequence_messages:
 
            if self._can_read_inorder_message(held_message.sequence_number):
                unheld_messages.append(held_message)
 
                self._incoming_inorder_sequence_number = held_message.sequence_number
                self._incoming_messages.append(held_message)
 
        for unheld_message in unheld_messages:
            self._incoming_out_of_sequence_messages.remove(unheld_message)
 
        for message in self._incoming_messages:
            self._log.debug('Incoming message:')
            self._log.debug('IsInOrder: %d' % message.is_ordered)
            self._log.debug('IsReliable:%d' % message.is_reliable)
            self._log.debug('MessageID :%d' % message.message_id)
            if message.is_ordered or message.is_reliable:
                ack_message = messages.MessageAck()
                ack_message.message_to_ack.value = message.message_id
                self.send_message(ack_message)
                self._log.info(
                  'Informing of reciept of message %d' % message.message_id)
 
        read_messages = self._incoming_messages
        self._incoming_messages = []
 
        return read_messages
 
    def _do_write(self, sock, address):
        # TODO: combine _create_packet into this method, eg:
        #   while has_messages:
        #     add_message_to_packet
        #     is packet_full?
        #       send_packet
        #   is packet_partially_full?
        #     send_packet
 
        while True:
            packet = self._create_packet()
            if not packet:
                break
 
            if ((CONNECTION_LOSS == 0) or (random.randint(1, 100) > CONNECTION_LOSS)):
                try:
                    bytes_sent = sock.sendto(packet, 0, address)
                    self._out_packets += 1
                except IOError as e:
                    # HACK: ewouldblocks are ignored and the packet is silently
                    # discarded. Packet sending should be re-written to
                    # only remove messages from the send queue if the socket
                    # operation completes successfully.
                    errornum = e[0]
                    if errornum != errno.EWOULDBLOCK:
                        raise
            else:
                self._log.info('Simulated packet loss')
 
            self._log.info('Sent UDP packet %d bytes in length' % len(packet))
 
    def _insert_message(self, message):
        self._incoming_messages.append(message)
        self._recent_message_ids.append(message.message_id)
        if message.is_ordered:
            self._incoming_ordered_sequence_number = message.sequence_number
 
    def _truncate_recent_message_list(self):
        '''
        Ensures that the recentMessageIDs list length is kept below
        MAX_RECENT_PACKET_LIST_SIZE. This method is called as part of this class'
        update method.
        '''
        if len(self._recent_message_ids) > self.RECENT_MESSAGE_LIST_SIZE:
            self._recent_message_ids = \
                self._recent_message_ids[-self.RECENT_MESSAGE_LIST_SIZE:]