import logging
import traceback
from abc import ABC, abstractmethod
from typing import Dict, Generic, List, Type, TypeVar
from dataiku.eda.types import Final

import numpy as np

from dataiku.eda.computations.context import Context
from dataiku.eda.computations.data_frame_store import DataFrameStore
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import EdaComputeError
from dataiku.eda.exceptions import InvalidResultError
from dataiku.eda.exceptions import UnknownObjectType
from dataiku.eda.types import ComputationModel, ComputationResultModel, ComputationTypeLiteral, EdaErrorCodes, FailedResultModel

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


# Type variable for the parameter model a computation is built from (see Computation.build);
# bounded by ComputationModel
ComputationModelType = TypeVar("ComputationModelType", bound=ComputationModel)


class Computation(ABC, Generic[ComputationModelType]):
    """Abstract base class of all computations.

    Concrete subclasses provide get_type(), apply() and build(), and are
    made discoverable through define(), which stores them in REGISTRY.
    Callers normally go through apply_safe(), which never lets an exception
    raised by apply() escape: failures are reported as a "failed" result.
    """

    # Maps a computation type identifier to the class registered for it by define()
    REGISTRY: Final[Dict[ComputationTypeLiteral, Type['Computation']]] = {}

    @staticmethod
    @abstractmethod
    def get_type() -> ComputationTypeLiteral:
        """Return the type identifier used as this class's key in REGISTRY."""
        raise NotImplementedError

    @abstractmethod
    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> ComputationResultModel:
        """Run the computation on ``idf`` within ``ctx`` and return its result.

        Implementations may raise; apply_safe() turns exceptions into failed results.
        """
        raise NotImplementedError

    def describe(self) -> str:
        """Return a short label for this computation (the class name by default).

        apply_safe() passes this label to ``ctx.sub()``.
        """
        return self.__class__.__name__

    @staticmethod
    def _require_result_checking() -> bool:
        """Tell apply_safe() whether the result must go through _check_and_fix_result()."""
        return True

    def apply_safe(self, idf: ImmutableDataFrame, ctx: Context) -> ComputationResultModel:
        """Run apply() and always return a result, even when it fails.

        The computation runs inside ``ctx.sub(self.describe())``. On success the
        result is validated/normalized (unless _require_result_checking() returns
        False). Any exception is converted into a failed result instead of
        propagating to the caller.

        :param idf: data frame handed over to apply()
        :param ctx: parent context; apply() receives the sub-context derived from it
        :return: the (checked) result of apply(), or a failed result
        """
        with ctx.sub(self.describe()) as sub:
            try:
                result = self.apply(idf, sub)
                if self._require_result_checking():
                    result = Computation._check_and_fix_result(result)
                return result

            except EdaComputeError as e:
                # Error directly produced by EDA: report its own code and message
                return Computation._failed_result(e.CODE, "%s" % e)

            except Exception as e:
                # Catch-all handler for cases where the exception hasn't been thrown explicitly by EDA
                # In this case, we are likely interested in the full stack trace
                traceback.print_exc()
                traceback.print_stack()
                logger.error(e)

                return Computation._failed_result(EdaComputeError.CODE, "Unexpected error: %s" % e)

    # Make sure computation results are well-formed:
    # - They must be JSON-serializable
    # - Numpy bool/int/float primitives are converted into Python primitives
    @staticmethod
    def _check_and_fix_result(obj):
        """Return a copy of ``obj`` made only of plain Python values.

        Dicts and lists are rebuilt recursively. Numpy scalars (bool, integer,
        floating) are unboxed. ``None``, ``bool``, ``int``, ``str`` and finite
        ``float`` values are returned as they are.

        :raises InvalidResultError: on a non-str dict key, a NaN/Inf float, or
            a value of any other type
        """
        # Containers: validate keys and recurse into values
        if isinstance(obj, dict):
            return {
                Computation._check_key(k): Computation._check_and_fix_result(v)
                for k, v in obj.items()
            }
        if isinstance(obj, list):
            return [Computation._check_and_fix_result(v) for v in obj]

        # Unbox Numpy primitives. np.bool_ has to be listed explicitly: it is not
        # a np.integer, and np.True_/np.False_ are not the True/False singletons.
        if isinstance(obj, (np.bool_, np.integer, np.floating)):
            obj = obj.item()

        if isinstance(obj, float):
            # NaN and +/-Inf are rejected
            if not np.isfinite(obj):
                raise InvalidResultError("Invalid NaN/Inf in result. Changing the settings may help.")
            return obj

        # Always valid primitives (bool is listed for readability, it is an int subclass)
        if obj is None or isinstance(obj, (bool, int, str)):
            return obj

        raise InvalidResultError("Output type is not serializable: %s" % obj.__class__)

    @staticmethod
    def _check_key(key: str) -> str:
        """Return ``key`` unchanged, or raise InvalidResultError if it is not a str."""
        if not isinstance(key, str):
            raise InvalidResultError("Keys must be str")
        return key

    @staticmethod
    @abstractmethod
    def build(params: ComputationModelType) -> 'Computation':
        """Build the computation described by ``params``.

        This base implementation is a dispatcher: it looks up the class
        registered for ``params["type"]`` and delegates to that class's own
        build(). Being abstract, it must be overridden by concrete subclasses.

        :raises UnknownObjectType: if the type is missing or not registered
        """
        try:
            computation_class = Computation.REGISTRY[params["type"]]
        except KeyError:
            raise UnknownObjectType("Unknown computation type: %s" % params.get("type"))
        return computation_class.build(params)

    @staticmethod
    def define(computation_class: Type['Computation']) -> None:
        """Register ``computation_class`` in REGISTRY under its get_type() identifier.

        A previous registration for the same identifier is overwritten.
        """
        Computation.REGISTRY[computation_class.get_type()] = computation_class

    @staticmethod
    def _failed_result(code: EdaErrorCodes, message: str) -> FailedResultModel:
        """Build the result dict reporting a failure with the given code and message."""
        return {
            "type": "failed",
            "code": code,
            "message": message
        }


class UnivariateComputation(Computation):
    """Computation parameterized by a single column."""

    def __init__(self, column: str):
        # Column this computation is bound to; also shown by describe()
        self.column = column

    def describe(self) -> str:
        """Return the label ``ClassName(column)``."""
        class_name = type(self).__name__
        return f"{class_name}({self.column!s})"


class BivariateComputation(Computation):
    """Computation parameterized by a pair of columns (x and y)."""

    def __init__(self, x_column: str, y_column: str):
        # Both columns are kept as given; describe() shows them
        self.x_column = x_column
        self.y_column = y_column

    def describe(self) -> str:
        """Return the label ``ClassName(x=<x_column>, y=<y_column>)``."""
        # The first placeholder is labelled "x=": it is filled with x_column
        return "%s(x=%s, y=%s)" % (self.__class__.__name__, self.x_column, self.y_column)


class MultivariateComputation(Computation):
    """Computation parameterized by a list of columns."""

    def __init__(self, columns: List[str]):
        # Columns this computation is bound to; also shown by describe()
        self.columns = columns

    def describe(self) -> str:
        """Return the label ``ClassName(col1, col2, ...)``."""
        joined_columns = ", ".join(self.columns)
        class_name = type(self).__name__
        return f"{class_name}({joined_columns})"
