from typing import Optional
from dataiku.eda.types import Literal

import numpy as np

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.filtering.filter import Filter
from dataiku.eda.types import ClosedMode, IntervalFilterModel


class IntervalFilter(Filter):
    """Filter rows on whether a numeric column falls within ``[left, right]``.

    ``closed`` controls which bounds are inclusive:
    - 'BOTH' makes both bounds inclusive.
    - 'LEFT' makes only the left bound inclusive.
    - 'RIGHT' makes only the right bound inclusive.
    - Any other value leaves both bounds exclusive.
    """

    def __init__(self, column: str, left: float, right: float, closed: ClosedMode, name: Optional[str] = None):
        self.column, self.left, self.right = column, left, right
        self.closed: ClosedMode = closed
        self.name = name  # optional label; carried through serialize()

    @staticmethod
    def get_type() -> Literal["interval"]:
        """Return the type discriminator written by ``serialize``."""
        return "interval"

    @staticmethod
    def build(params: IntervalFilterModel) -> 'IntervalFilter':
        """Construct a filter from its model; only ``name`` may be absent."""
        return IntervalFilter(
            column=params["column"],
            left=params["left"],
            right=params["right"],
            closed=params["closed"],
            name=params.get("name"),
        )

    def apply(self, idf: ImmutableDataFrame, inverse: bool = False) -> ImmutableDataFrame:
        """Return the rows of ``idf`` inside the interval, or outside it when ``inverse``.

        A row counts as missing when its value is not finite according to
        ``np.isfinite``, i.e. NaN or +/-inf. Missing rows are never inside the
        interval: they are left out of the direct result and added to the
        inverse result.
        """
        finite = np.isfinite(idf.float_col(self.column))
        idf_finite = idf[finite]
        values = idf_finite.float_col(self.column)

        low_inclusive = self.closed in ('BOTH', 'LEFT')
        high_inclusive = self.closed in ('BOTH', 'RIGHT')
        above_low = (values >= self.left) if low_inclusive else (values > self.left)
        below_high = (values <= self.right) if high_inclusive else (values < self.right)
        inside = above_low & below_high

        if not inverse:
            # Inside the interval AND not missing.
            return idf_finite[inside]
        # Outside the interval OR missing.
        return idf_finite[~inside] | idf[~finite]

    def serialize(self) -> IntervalFilterModel:
        """Return this filter as a model dict that ``build`` can read back."""
        model: IntervalFilterModel = {
            "type": self.get_type(),
            "left": self.left,
            "right": self.right,
            "closed": self.closed,
            "name": self.name,
            "column": self.column
        }
        return model
