from dataiku.eda.types import Literal

from numpy.polynomial import Polynomial

from dataiku.eda.computations.immutable_data_frame import FloatVector
from dataiku.eda.curves.curve import Curve
from dataiku.eda.curves.curve import ParametrizedCurve
from dataiku.eda.types import PolynomialCurveModel, ParametrizedPolynomialCurveModel


class PolynomialCurve(Curve):
    """Curve family of polynomials of a fixed degree, fitted by least squares."""

    def __init__(self, degree: int):
        # Degree handed to ``Polynomial.fit`` when fitting.
        self.degree = degree

    @staticmethod
    def get_type() -> Literal["polynomial"]:
        """Return the type tag identifying this curve family."""
        return "polynomial"

    @staticmethod
    def build(params: PolynomialCurveModel) -> 'PolynomialCurve':
        """Instantiate a curve from its model, reading the ``"degree"`` entry."""
        requested_degree = params["degree"]
        return PolynomialCurve(degree=requested_degree)

    def fit(self, x: FloatVector, y: FloatVector) -> 'ParametrizedPolynomial':
        """Fit a polynomial of degree ``self.degree`` to the points ``(x, y)``.

        ``Polynomial.fit`` solves the least-squares problem in a shifted/scaled
        domain. ``convert()`` maps the result back to the default domain so the
        coefficients (lowest degree first) apply directly to raw ``x`` values.
        """
        fitted_in_scaled_domain = Polynomial.fit(x, y, deg=self.degree)
        standard_coefs = fitted_in_scaled_domain.convert().coef
        return ParametrizedPolynomial(standard_coefs)


class ParametrizedPolynomial(ParametrizedCurve):
    """A concrete polynomial, stored as coefficients in increasing degree order.

    ``coefs[i]`` is the coefficient of ``x**i``, i.e. the polynomial is
    ``coefs[0] + coefs[1]*x + coefs[2]*x^2 + ...``.
    """

    def __init__(self, coefs: FloatVector):
        self.coefs = coefs

    def serialize(self) -> ParametrizedPolynomialCurveModel:
        """Return a serializable model of this polynomial.

        Each coefficient is converted to a built-in ``float`` so the model holds
        only plain Python values. Numpy scalars such as ``float32`` or ``int64``
        are not accepted by the standard ``json`` encoder.
        """
        return {
            "type": PolynomialCurve.get_type(),
            # coefs[0] + coefs[1]*x + coefs[2]*x^2 + ...
            "coefs": [float(c) for c in self.coefs],
        }

    def apply(self, x: FloatVector) -> FloatVector:
        """Evaluate the polynomial at ``x`` (elementwise when ``x`` is an array)."""
        return Polynomial(self.coefs)(x)
