from abc import ABC, abstractmethod
from typing import Dict, Generic, Type, TypeVar
from dataiku.eda.types import Final

from dataiku.eda.computations.immutable_data_frame import FloatVector
from dataiku.eda.exceptions import UnknownObjectType
from dataiku.eda.types import CurveModel, CurveTypeLiteral, ParametrizedCurveModel


# Type variable for the `params` model accepted by `Curve.build`; any
# substitution must be a `CurveModel` (or a subtype of it).
CurveModelType = TypeVar("CurveModelType", bound=CurveModel)


class Curve(ABC, Generic[CurveModelType]):
    """Abstract base class for curves, doubling as a registry of curve classes.

    Subclasses are registered through ``Curve.define`` under the value returned
    by their ``get_type()``. ``Curve.build`` then looks up the class matching
    ``params["type"]`` and delegates the actual construction to that class.
    """

    # Maps a curve type identifier to the Curve subclass registered for it.
    REGISTRY: Final[Dict[CurveTypeLiteral, Type['Curve']]] = {}

    @staticmethod
    @abstractmethod
    def get_type() -> CurveTypeLiteral:
        """Return the type identifier used as this class's key in ``REGISTRY``."""
        raise NotImplementedError

    @abstractmethod
    def fit(self, x: FloatVector, y: FloatVector) -> 'ParametrizedCurve':
        """Fit this curve to the ``x`` / ``y`` vectors.

        The fitting procedure itself is defined by each subclass.

        Returns:
            The ``ParametrizedCurve`` resulting from the fit.
        """
        raise NotImplementedError

    @staticmethod
    def define(curve_class: Type['Curve']) -> None:
        """Register ``curve_class`` in ``REGISTRY`` under ``curve_class.get_type()``.

        Registering another class under an already-used type replaces the
        previous entry.
        """
        Curve.REGISTRY[curve_class.get_type()] = curve_class

    @staticmethod
    @abstractmethod
    def build(params: CurveModelType) -> 'Curve':
        """Build the curve described by ``params``.

        Called as ``Curve.build(params)``, this dispatches to the ``build`` of
        the class registered for ``params["type"]``. It is declared abstract so
        that every subclass has to provide its own ``build``.

        Raises:
            UnknownObjectType: if ``params`` has no ``"type"`` entry or no class
                is registered for that type.
            NotImplementedError: if the registered class did not override
                ``build``.
        """
        try:
            curve_class = Curve.REGISTRY[params["type"]]
        except KeyError:
            # The KeyError carries nothing the message below does not already
            # say, so suppress it instead of chaining a confusing
            # "during handling of the above exception" traceback.
            raise UnknownObjectType("Unknown curve type: %s" % params.get("type")) from None

        # A registered class that inherits this dispatcher unchanged would
        # re-enter it with the same params and recurse until RecursionError.
        if curve_class.build is Curve.build:
            raise NotImplementedError("%s must override build()" % curve_class.__name__)

        return curve_class.build(params)


class ParametrizedCurve(ABC):
    """Abstract interface of the object returned by ``Curve.fit``.

    Subclasses must implement both ``apply`` and ``serialize``.
    """

    @abstractmethod
    def apply(self, x: FloatVector) -> FloatVector:
        """Apply the curve to the vector ``x`` and return the resulting vector."""
        raise NotImplementedError()

    @abstractmethod
    def serialize(self) -> ParametrizedCurveModel:
        """Return this curve as a ``ParametrizedCurveModel``."""
        raise NotImplementedError()
