import six
import logging
from .base import AbstractCombinedNode
from datetime import timedelta, datetime
from ebu_tt_live.bindings._ebuttdt import LimitedClockTimingType, FullClockTimingType
from ebu_tt_live.bindings import _ebuttm as metadata
from ebu_tt_live.documents import EBUTT3Document
from ebu_tt_live.bindings.pyxb_utils import RecursiveOperation, StopBranchIteration
from ebu_tt_live.bindings.validation.timing import TimingValidationMixin
from ebu_tt_live.errors import UnexpectedSequenceIdentifierError
# Module-level logger; RetimingDelayNode.process_document uses it to report ignored duplicate documents.
log = logging.getLogger(__name__)
class RetimingDelayNode(AbstractCombinedNode):
    """Re-emit ``EBUTT3Document`` instances under a new sequence with shifted timing.

    Every accepted document is given the sequence identifier
    ``document_sequence``. It gets an ``appliedProcessing`` metadata entry and
    its timing shifted by ``fixed_delay`` seconds. It is then validated and
    handed to the producer carriage.
    """

    _document_sequence = None  # sequence identifier stamped on every emitted document
    _fixed_delay = None  # delay in seconds (fed to timedelta(seconds=...) by the timing helpers)
    _expects = EBUTT3Document
    _provides = EBUTT3Document

    def __init__(self, node_id, fixed_delay, document_sequence, consumer_carriage=None, producer_carriage=None):
        super(RetimingDelayNode, self).__init__(
            node_id=node_id,
            consumer_carriage=consumer_carriage,
            producer_carriage=producer_carriage,
        )
        self._document_sequence = document_sequence
        self._fixed_delay = fixed_delay

    def process_document(self, document, **kwargs):
        """Retime ``document`` and emit it; items that are not documents pass through untouched.

        :param document: incoming item; only items accepted by ``is_document``
            are retimed.
        :param kwargs: forwarded unchanged to ``producer_carriage.emit_data``.
        :raises UnexpectedSequenceIdentifierError: if ``document`` already
            carries this node's output sequence identifier.
        """
        if not self.is_document(document):
            self.producer_carriage.emit_data(data=document, **kwargs)
            return

        # A document already on the output sequence must not be re-processed.
        if document.sequence_identifier == self._document_sequence:
            raise UnexpectedSequenceIdentifierError()

        # A falsy result from check_if_document_seen() is treated as a duplicate.
        if not self.check_if_document_seen(document=document):
            log.warning(
                'Ignoring duplicate document: {}__{}'.format(
                    document.sequence_identifier,
                    document.sequence_number
                )
            )
            return

        self.limit_sequence_to_one(document)
        document.sequence_identifier = self._document_sequence
        self._record_retiming_metadata(document)

        body = document.binding.body
        if has_a_leaf_with_no_timing_path(body):
            # Some timed leaf has no begin-timed element above it: give the body a begin time.
            update_body_timing(body, document.time_base, self._fixed_delay)
        else:
            update_children_timing(document.binding, document.time_base, self._fixed_delay)

        document.validate()
        self.producer_carriage.emit_data(data=document, **kwargs)

    def _record_retiming_metadata(self, document):
        """Append an ``appliedProcessing`` entry for this retiming to the document's head metadata.

        The ``metadata`` / ``documentMetadata`` containers are created first if
        they are missing.
        """
        head = document.binding.head
        if head.metadata is None:
            head.metadata = metadata.headMetadata_type(
                metadata.documentMetadata()
            )
        if head.metadata.documentMetadata is None:
            head.metadata.documentMetadata = metadata.documentMetadata()
        processing_entry = metadata.appliedProcessing_type(
            process='retimed by ' + str(self._fixed_delay) + 's',
            generatedBy='retiming_delay_node_v1.0',
            sourceId=self.node_id,
            appliedDateTime=datetime.now()
        )
        head.metadata.documentMetadata.appliedProcessing.append(processing_entry)
class BufferDelayNode(AbstractCombinedNode):
    """Forward text payloads unchanged, passing a fixed ``delay`` to the producer carriage.

    The node does not modify the data. Each payload is handed to
    ``producer_carriage.emit_data`` together with ``delay=fixed_delay``.
    """

    _fixed_delay = None  # value passed as ``delay=`` to emit_data for every payload
    _expects = six.text_type
    _provides = six.text_type

    def __init__(self, node_id, fixed_delay, consumer_carriage=None, producer_carriage=None):
        super(BufferDelayNode, self).__init__(
            node_id=node_id,
            consumer_carriage=consumer_carriage,
            producer_carriage=producer_carriage,
        )
        self._fixed_delay = fixed_delay

    def process_document(self, document, **kwargs):
        """Emit ``document`` through the producer carriage with this node's delay.

        :param document: payload to forward.
        :param kwargs: forwarded to ``emit_data``. It must not contain ``data``
            or ``delay``, which this method supplies itself (Python would raise
            ``TypeError`` for the duplicate keyword).
        """
        carriage = self.producer_carriage
        carriage.emit_data(data=document, delay=self._fixed_delay, **kwargs)
def _shift_timing_attribute(value, attribute, timebase, delay_int):
    """Add ``delay_int`` seconds to ``value.<attribute>``, using the timing type for ``timebase``.

    ``'clock'`` uses ``LimitedClockTimingType`` and ``'media'`` uses
    ``FullClockTimingType``. For any other timebase the attribute is left
    untouched (nothing is assigned).
    """
    if timebase == 'clock':
        timing_type = LimitedClockTimingType
    elif timebase == 'media':
        timing_type = FullClockTimingType
    else:
        return
    # The delay is wrapped in the timing type before its timedelta is read back,
    # exactly as the previous inline code did, so any conversion the type
    # performs is applied to the delay as well.
    delay = timing_type(timedelta(seconds=delay_int))
    current = getattr(value, attribute)
    setattr(value, attribute, timing_type(current.timedelta + delay.timedelta))


def update_children_timing(element, timebase, delay_int):
    """Shift explicit ``begin``/``end`` times below ``element`` by ``delay_int`` seconds, in place.

    For every child listed by ``element.orderedContent()``:

    * a non-``None`` ``end`` is shifted;
    * a non-``None`` ``begin`` is shifted, and that child's descendants are
      NOT visited;
    * a child without a ``begin`` is recursed into, even if its ``end`` was
      just shifted.

    Recursion therefore stops at the first begin-timed element on each branch.
    NOTE(review): this presumably relies on descendant times being relative to
    the nearest begin-timed ancestor, so shifting that ancestor moves them too.
    Confirm against the EBU-TT timing model.

    :param element: binding element whose descendants are updated; objects
        without ``orderedContent`` are ignored.
    :param timebase: document time base; only ``'clock'`` and ``'media'``
        change any value.
    :param delay_int: delay in seconds, passed to ``timedelta(seconds=...)``.
    """
    if not hasattr(element, 'orderedContent'):
        return
    for child in element.orderedContent():
        value = child.value
        if hasattr(value, 'end') and value.end is not None:
            _shift_timing_attribute(value, 'end', timebase, delay_int)
        if hasattr(value, 'begin') and value.begin is not None:
            _shift_timing_attribute(value, 'begin', timebase, delay_int)
        else:
            # Only children without their own begin time are descended into.
            update_children_timing(value, timebase, delay_int)
def update_body_timing(body, timebase, delay_int):
    """Delay a body that has no begin time by setting ``begin`` to ``delay_int`` seconds.

    ``body.begin`` is set to the delay. When ``body.end`` is present it is
    shifted by the same amount. Values are built with ``LimitedClockTimingType``
    for the ``'clock'`` time base and ``FullClockTimingType`` for ``'media'``.
    Any other time base leaves the body unchanged.

    :param body: body binding element, modified in place.
    :param timebase: document time base.
    :param delay_int: delay in seconds, passed to ``timedelta(seconds=...)``.
    :raises AssertionError: if ``body`` already has a begin time.
    """
    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``. The AssertionError type is kept for existing callers.
    if getattr(body, 'begin', None) is not None:
        raise AssertionError("The body already has a begin time")

    if timebase == 'clock':
        timing_type = LimitedClockTimingType
    elif timebase == 'media':
        timing_type = FullClockTimingType
    else:
        return

    delay = timing_type(timedelta(seconds=delay_int))
    # begin is always set, whether or not the body has an end.
    body.begin = timing_type(delay.timedelta)
    if getattr(body, 'end', None) is not None:
        body.end = timing_type(body.end.timedelta + delay.timedelta)
def is_explicitly_timed(element):
    """Return True if ``element`` or any descendant has an explicit ``begin`` or ``end``.

    Descendants are reached through ``orderedContent()``. Objects without that
    method are treated as leaves.

    :param element: binding element (or any object) to inspect.
    :return: ``True`` if a non-``None`` ``begin`` or ``end`` is found,
        otherwise ``False``.
    """
    for attribute in ('begin', 'end'):
        if hasattr(element, attribute) and getattr(element, attribute) is not None:
            return True
    if not hasattr(element, 'orderedContent'):
        return False
    # any() stops at the first timed descendant, like the original loop did.
    return any(is_explicitly_timed(child.value) for child in element.orderedContent())
class UntimedPathFinder(RecursiveOperation):
    """Search a binding tree for a timed leaf that is not covered by any begin time.

    Only values that are ``TimingValidationMixin`` instances pass the
    ``filter`` given to ``RecursiveOperation``. While traversing, a stack of
    the begin-timed elements currently open is maintained. If
    ``_process_element`` sees a timed leaf while that stack is empty,
    ``path_found`` becomes True and the search is cut short.
    """

    _path_found = False  # becomes True once an uncovered timed leaf is found
    _timed_element_stack = None  # begin-timed elements currently being traversed

    def __init__(self, root_element):
        self._timed_element_stack = []
        super(UntimedPathFinder, self).__init__(
            root_element,
            filter=lambda value, element: isinstance(value, TimingValidationMixin)
        )

    def _is_begin_timed(self, value):
        """Return True when ``value`` has a non-``None`` begin time."""
        return value.begin is not None

    def _before_element(self, value, element=None, parent_binding=None, **kwargs):
        # Once a path is known, further branches are skipped
        # (StopBranchIteration is interpreted by RecursiveOperation).
        if self._path_found:
            raise StopBranchIteration()
        if self._is_begin_timed(value=value):
            self._timed_element_stack.append(value)

    def _after_element(self, value, element=None, parent_binding=None, **kwargs):
        # Mirror of _before_element: drop the begin-timed element pushed there.
        if self._is_begin_timed(value=value):
            self._timed_element_stack.pop()

    def _process_element(self, value, element=None, parent_binding=None, **kwargs):
        # A timed leaf with no begin-timed element on the stack is what we are looking for.
        if value.is_timed_leaf() and not self._timed_element_stack:
            self._path_found = True
            raise StopBranchIteration()

    def _process_non_element(self, value, non_element, parent_binding=None, **kwargs):
        # Non-element content is ignored by this search.
        pass

    @property
    def path_found(self):
        """True once a timed leaf was processed while no begin-timed element was on the stack."""
        return self._path_found
def has_a_leaf_with_no_timing_path(element):
    """Return whether ``element`` contains a timed leaf with no begin time on its path.

    Runs an ``UntimedPathFinder`` over ``element`` and reports its
    ``path_found`` result. The result is True when a leaf for which
    ``is_timed_leaf()`` holds was processed while no begin-timed element was
    open along its path. Whether the leaf's own begin time counts depends on
    the order in which ``RecursiveOperation`` invokes the finder's hooks.

    :param element: root binding element to search (the retiming node passes
        a document body).
    :return: the finder's ``path_found`` flag.
    """
    search = UntimedPathFinder(element)
    search.proceed()
    found = search.path_found
    return found