"""
Template for intervaltree

WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
"""

from pandas._libs.algos import is_monotonic

# Fused scalar types used for the query points handed to the tree.
# float64 is a member of both groups below.
ctypedef fused int_scalar_t:
    int64_t
    float64_t

ctypedef fused uint_scalar_t:
    uint64_t
    float64_t

# Union of the two groups above; this is the element type of the targets
# taken by IntervalTree.get_indexer / get_indexer_non_unique.
ctypedef fused scalar_t:
    int_scalar_t
    uint_scalar_t

# ----------------------------------------------------------------------
# IntervalTree
# ----------------------------------------------------------------------

cdef class IntervalTree(IntervalMixin):
    """A centered interval tree

    Based off the algorithm described on Wikipedia:
    https://en.wikipedia.org/wiki/Interval_tree

    we are emulating the IndexEngine interface
    """
    cdef readonly:
        # interval bounds, cast to a common dtype, with the entries whose
        # left bound is NaN removed
        ndarray left, right
        # top node of the tree; its class is looked up in NODE_CLASSES
        IntervalNode root
        # common dtype of ``left`` and ``right`` (np.result_type)
        object dtype
        str closed
        # lazily filled caches behind the properties of the same name
        object _is_overlapping, _left_sorter, _right_sorter
        # number of entries dropped because their left bound was NaN
        Py_ssize_t _na_count

    def __init__(self, left, right, closed='right', leaf_size=100):
        """
        Parameters
        ----------
        left, right : np.ndarray[ndim=1]
            Left and right bounds for each interval. Entries whose left
            bound is NaN are dropped (and counted) before the tree is
            built, so the nodes never contain NaN.
        closed : {'left', 'right', 'both', 'neither'}, optional
            Whether the intervals are closed on the left-side, right-side, both
            or neither. Defaults to 'right'.
        leaf_size : int, optional
            Parameter that controls when the tree switches from creating nodes
            to brute-force search. Tune this parameter to optimize query
            performance.

        Raises
        ------
        ValueError
            If ``closed`` is not one of the four accepted options.
        """
        if closed not in ['left', 'right', 'both', 'neither']:
            # f-string rather than ``%``: with ``%`` a tuple passed as
            # ``closed`` would be unpacked as the format arguments and raise
            # TypeError instead of the intended ValueError
            raise ValueError(f"invalid option for 'closed': {closed}")

        left = np.asarray(left)
        right = np.asarray(right)
        self.dtype = np.result_type(left, right)
        self.left = np.asarray(left, dtype=self.dtype)
        self.right = np.asarray(right, dtype=self.dtype)

        # positions in the original (unfiltered) input; these are the values
        # reported by the query methods
        indices = np.arange(len(left), dtype='int64')

        self.closed = closed

        # GH 23352: ensure no nan in nodes
        mask = ~np.isnan(self.left)
        self._na_count = len(mask) - mask.sum()
        self.left = self.left[mask]
        self.right = self.right[mask]
        indices = indices[mask]

        # one specialized node class exists per (dtype, closed) combination
        node_cls = NODE_CLASSES[str(self.dtype), closed]
        self.root = node_cls(self.left, self.right, indices, leaf_size)

    @property
    def left_sorter(self) -> np.ndarray:
        """How to sort the left labels; this is used for binary search
        """
        if self._left_sorter is None:
            # lexsort uses the last key as the primary one: order by left
            # bound, ties broken by right bound
            values = [self.right, self.left]
            self._left_sorter = np.lexsort(values)
        return self._left_sorter

    @property
    def right_sorter(self) -> np.ndarray:
        """How to sort the right labels
        """
        if self._right_sorter is None:
            self._right_sorter = np.argsort(self.right)
        return self._right_sorter

    @property
    def is_overlapping(self) -> bool:
        """
        Determine if the IntervalTree contains overlapping intervals.
        Cached as self._is_overlapping.
        """
        if self._is_overlapping is not None:
            return self._is_overlapping

        # <= when both sides closed since endpoints can overlap
        op = le if self.closed == 'both' else lt

        # overlap if start of current interval < end of previous interval
        # (current and previous in terms of sorted order by left/start side)
        current = self.left[self.left_sorter[1:]]
        previous = self.right[self.left_sorter[:-1]]
        self._is_overlapping = bool(op(current, previous).any())

        return self._is_overlapping

    @property
    def is_monotonic_increasing(self) -> bool:
        """
        Return True if the IntervalTree is monotonic increasing (only equal or
        increasing values), else False
        """
        # any dropped NaN entry makes the tree non-monotonic
        if self._na_count > 0:
            return False

        # the check is run on the sorter itself, not on the bounds; the first
        # element of the result is what gets returned
        sort_order = self.left_sorter
        return is_monotonic(sort_order, False)[0]

    def get_indexer(self, scalar_t[:] target) -> np.ndarray:
        """Return the positions corresponding to unique intervals that overlap
        with the given array of scalar targets.

        Targets that match no interval get -1. A target that matches more
        than one interval raises KeyError.
        """

        # TODO: write get_indexer_intervals
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result

        result = Int64Vector()
        # length of ``result`` before the current target was queried
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            if result.data.n == old_len:
                # the query appended nothing
                result.append(-1)
            elif result.data.n > old_len + 1:
                raise KeyError(
                    'indexer does not intersect a unique set of intervals')
            old_len = result.data.n
        return result.to_array().astype('intp')

    def get_indexer_non_unique(self, scalar_t[:] target):
        """Return the positions corresponding to intervals that overlap with
        the given array of scalar targets. Non-unique positions are repeated.

        Returns a pair of intp arrays: the matched positions (with -1 for a
        target that matched nothing) and the positions in ``target`` of the
        targets that matched nothing.
        """
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result, missing

        result = Int64Vector()
        missing = Int64Vector()
        # length of ``result`` before the current target was queried
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            if result.data.n == old_len:
                # the query appended nothing
                result.append(-1)
                missing.append(i)
            old_len = result.data.n
        return (result.to_array().astype('intp'),
                missing.to_array().astype('intp'))

    def __repr__(self) -> str:
        return (f'<IntervalTree[{self.dtype},{self.closed}]: '
                f'{self.root.n_elements} elements>')

    # compat with IndexEngine interface
    def clear_mapping(self) -> None:
        pass


cdef take(ndarray source, ndarray indices):
    """Return the elements of a 1D ndarray located at the given positions
    """
    # the last argument is the axis to take along
    picked = PyArray_Take(source, indices, 0)
    return picked


cdef sort_values_and_indices(all_values, all_indices, subset):
    """Select ``subset`` from the values and indices, then order both by value

    Returns the selected values sorted, and the selected indices rearranged
    the same way.
    """
    subset_values = take(all_values, subset)
    subset_indices = take(all_indices, subset)
    # a single argsort of the values drives the reordering of both arrays
    order = PyArray_ArgSort(subset_values, 0, NPY_QUICKSORT)
    return take(subset_values, order), take(subset_indices, order)


# ----------------------------------------------------------------------
# Nodes
# ----------------------------------------------------------------------

@cython.internal
cdef class IntervalNode:
    # Attributes shared by every specialized node class; the child nodes and
    # the pivot read below are declared on the subclasses.
    cdef readonly:
        int64_t n_elements, n_center, leaf_size
        bint is_leaf_node

    def __repr__(self) -> str:
        if not self.is_leaf_node:
            n_left = self.left_node.n_elements
            n_right = self.right_node.n_elements
            # whatever did not go to a child overlaps the pivot
            n_center = self.n_elements - n_left - n_right
            return (
                f"<{type(self).__name__}: "
                f"pivot {self.pivot}, {self.n_elements} elements "
                f"({n_left} left, {n_right} right, {n_center} overlapping)>"
            )

        return (
            f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
        )

    def counts(self):
        """
        Report the element counts below this node, for debugging.

        A leaf gives its number of elements; any other node gives
        ``(n_center, (left_counts, right_counts))``, recursively.
        """
        if self.is_leaf_node:
            return self.n_elements

        center_count = len(self.center_left_values)
        left_counts = self.left_node.counts()
        right_counts = self.right_node.counts()
        return (center_count, (left_counts, right_counts))


# we need specialized nodes and leaves to optimize for different dtype and
# closed values

{{py:

# (closed, cmp_left, cmp_right): the operators that spell
# `left cmp_left point cmp_right right` for each kind of closure
closed_variants = [
    ('left', '<=', '<'),
    ('right', '<', '<='),
    ('both', '<=', '<='),
    ('neither', '<', '<'),
]

# one entry per generated node class
nodes = []
for dtype in ['float64', 'int64', 'uint64']:
    # prefix of the fused scalar type accepted as the query point
    if dtype.startswith('float'):
        fused_prefix = ''
    elif dtype.startswith('uint'):
        fused_prefix = 'uint_'
    elif dtype.startswith('int'):
        fused_prefix = 'int_'

    for closed, cmp_left, cmp_right in closed_variants:
        # strict and non-strict comparison swapped
        cmp_left_converse = '<=' if cmp_left == '<' else '<'
        cmp_right_converse = '<=' if cmp_right == '<' else '<'
        nodes.append((dtype, dtype.title(),
                      closed, closed.title(),
                      cmp_left,
                      cmp_right,
                      cmp_left_converse,
                      cmp_right_converse,
                      fused_prefix))

}}

# Registry of the specialized node classes, keyed by (dtype name, closed).
# It is filled in by the template loop below and read by
# IntervalTree.__init__ to choose the class of the root node.
NODE_CLASSES = {}

{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
      cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}


@cython.internal
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
    """Node of an IntervalTree, either terminal (leaf) or non-terminal

    A non-terminal node categorizes intervals by those that fall to the left,
    those that fall to the right, and those that overlap with the pivot.
    A terminal node simply stores its intervals and is searched linearly.
    """
    cdef readonly:
        # children; only set on non-terminal nodes
        {{dtype_title}}Closed{{closed_title}}IntervalNode left_node, right_node
        # center_*: bounds of the intervals overlapping the pivot, sorted by
        # left and by right bound respectively (non-terminal nodes only);
        # left/right: bounds of all intervals (terminal nodes only)
        {{dtype}}_t[:] center_left_values, center_right_values, left, right
        # original positions matching the arrays above
        int64_t[:] center_left_indices, center_right_indices, indices
        # extreme bounds over every interval below this node
        {{dtype}}_t min_left, max_right
        {{dtype}}_t pivot

    def __init__(self,
                 ndarray[{{dtype}}_t, ndim=1] left,
                 ndarray[{{dtype}}_t, ndim=1] right,
                 ndarray[int64_t, ndim=1] indices,
                 int64_t leaf_size):
        """
        Parameters
        ----------
        left, right : ndarray[{{dtype}}_t, ndim=1]
            Bounds of the intervals to store in this node and below it.
        indices : ndarray[int64_t, ndim=1]
            Value reported by ``query`` for each interval.
        leaf_size : int64_t
            A node holding at most this many intervals becomes a terminal
            node; larger nodes are split around a pivot.
        """

        self.n_elements = len(left)
        self.leaf_size = leaf_size

        # min_left and max_right are used to speed-up query by skipping
        # query on sub-nodes. If this node has size 0, query is cheap,
        # so these values don't matter.
        if left.size > 0:
            self.min_left = left.min()
            self.max_right = right.max()
        else:
            self.min_left = 0
            self.max_right = 0

        if self.n_elements <= leaf_size:
            # make this a terminal (leaf) node
            self.is_leaf_node = True
            self.left = left
            self.right = right
            self.indices = indices
            self.n_center = 0
        else:
            # calculate a pivot so we can create child nodes
            self.is_leaf_node = False
            # halve each bound before adding rather than halving the sum
            self.pivot = np.median(left / 2 + right / 2)
            if np.isinf(self.pivot):
                # an infinite pivot is replaced by 0, then moved to a bound
                # of the data when 0 lies outside [min(left), max(right)]
                self.pivot = cython.cast({{dtype}}_t, 0)
                if self.pivot > np.max(right):
                    self.pivot = np.max(left)
                if self.pivot < np.min(left):
                    self.pivot = np.min(right)

            left_set, right_set, center_set = self.classify_intervals(
                left, right)

            self.left_node = self.new_child_node(left, right,
                                                 indices, left_set)
            self.right_node = self.new_child_node(left, right,
                                                  indices, right_set)

            # the center intervals are kept twice, once ordered by left
            # bound and once by right bound, so that query can stop scanning
            # at the first non-match on either side
            self.center_left_values, self.center_left_indices = \
                sort_values_and_indices(left, indices, center_set)
            self.center_right_values, self.center_right_indices = \
                sort_values_and_indices(right, indices, center_set)
            self.n_center = len(self.center_left_indices)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    cdef classify_intervals(self, {{dtype}}_t[:] left, {{dtype}}_t[:] right):
        """Classify the given intervals based upon whether they fall to the
        left, right, or overlap with this node's pivot.

        Returns three int64 arrays of positions into ``left``/``right``:
        (left of the pivot, right of the pivot, overlapping the pivot).
        """
        cdef:
            Int64Vector left_ind, right_ind, overlapping_ind
            Py_ssize_t i

        left_ind = Int64Vector()
        right_ind = Int64Vector()
        overlapping_ind = Int64Vector()

        for i in range(self.n_elements):
            # the converse operators make an interval count as overlapping
            # exactly when it contains the pivot
            if right[i] {{cmp_right_converse}} self.pivot:
                left_ind.append(i)
            elif self.pivot {{cmp_left_converse}} left[i]:
                right_ind.append(i)
            else:
                overlapping_ind.append(i)

        return (left_ind.to_array(),
                right_ind.to_array(),
                overlapping_ind.to_array())

    cdef new_child_node(self,
                        ndarray[{{dtype}}_t, ndim=1] left,
                        ndarray[{{dtype}}_t, ndim=1] right,
                        ndarray[int64_t, ndim=1] indices,
                        ndarray[int64_t, ndim=1] subset):
        """Create a new child node holding the intervals at the positions
        given by ``subset``, with the same leaf size as this node.
        """
        left = take(left, subset)
        right = take(right, subset)
        indices = take(indices, subset)
        return {{dtype_title}}Closed{{closed_title}}IntervalNode(
            left, right, indices, self.leaf_size)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    @cython.initializedcheck(False)
    cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
        """Recursively query this node and its sub-nodes for intervals that
        overlap with the query point.

        The index of every matching interval is appended to ``result``.
        A point that compares as neither below, above, nor equal to the
        pivot (NaN) matches nothing.
        """
        cdef:
            int64_t[:] indices
            {{dtype}}_t[:] values
            Py_ssize_t i

        if self.is_leaf_node:
            # Once we get down to a certain size, it doesn't make sense to
            # continue the binary tree structure. Instead, we use linear
            # search.
            for i in range(self.n_elements):
                if self.left[i] {{cmp_left}} point {{cmp_right}} self.right[i]:
                    result.append(self.indices[i])
        else:
            # There are child nodes. Based on comparing our query to the pivot,
            # look at the center values, then go to the relevant child.
            if point < self.pivot:
                # center intervals reach the pivot on the right, so only
                # their left bound needs checking; values are sorted by left
                # bound, hence the scan stops at the first failure
                values = self.center_left_values
                indices = self.center_left_indices
                for i in range(self.n_center):
                    if not values[i] {{cmp_left}} point:
                        break
                    result.append(indices[i])
                # skip the child when the point lies past all of its intervals
                if point {{cmp_right}} self.left_node.max_right:
                    self.left_node.query(result, point)
            elif point > self.pivot:
                # mirror image: scan the right bounds from the largest down
                values = self.center_right_values
                indices = self.center_right_indices
                for i in range(self.n_center - 1, -1, -1):
                    if not point {{cmp_right}} values[i]:
                        break
                    result.append(indices[i])
                if self.right_node.min_left {{cmp_left}} point:
                    self.right_node.query(result, point)
            elif point == self.pivot:
                # every center interval contains the pivot
                result.extend(self.center_left_indices)
            # no branch taken: the point is unordered relative to the pivot
            # (NaN); as in the leaf search above, it matches no interval


# register the class generated above under its (dtype name, closed) key,
# which is how IntervalTree.__init__ looks it up
NODE_CLASSES['{{dtype}}',
             '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode

{{endfor}}