import time
from itertools import chain
from typing import overload, Optional

import numpy as np
from tabulate import tabulate

from dataiku.eda.computations.data_frame_store import DataFrameStore


class Context:
    """Hierarchical timing context used to profile nested computations.

    A root context is created with a ``df_store``. Child contexts are created
    through :meth:`sub` while the parent is running, and they inherit the
    parent's ``df_store``. Each context is meant to be used as a ``with``
    block. Entering records the start time (``time.time()``). Exiting records
    the end time and computes:

    - ``totaltime``: seconds elapsed between enter and exit,
    - ``childtime``: sum of the children's ``totaltime``,
    - ``owntime``: ``totaltime - childtime``.

    These three attributes only exist once the context has been exited.
    """

    @overload
    def __init__(self, *, parent: 'Context', name: str = ..., brackets: bool = ..., df_store: None = ...):
        ...

    @overload
    def __init__(self, *, parent: None = ..., name: str = ..., brackets: bool = ..., df_store: DataFrameStore):
        ...

    def __init__(self, parent: Optional['Context'] = None, name: str = "", brackets: bool = False, df_store: Optional[DataFrameStore] = None):
        """Create a context, either as a root (``df_store`` given) or as a child of ``parent``.

        :param parent: enclosing context; its ``fullname`` prefixes this one's.
        :param name: local name of this context.
        :param brackets: when True, the full name is ``parent[name]`` instead
            of ``parent.name``.
        :param df_store: data frame store. If omitted, it is taken from ``parent``.
        :raises ValueError: if neither ``df_store`` nor ``parent`` is given.
        """
        self.childs = []  # sub-contexts, in creation order (filled by sub())
        self.start = None
        self.end = None
        self.name = name
        self.parent = parent

        if parent is None:
            self.fullname = name
        elif brackets:
            self.fullname = "%s[%s]" % (parent.fullname, name)
        elif parent.fullname:
            self.fullname = "%s.%s" % (parent.fullname, name)
        else:
            # Parent has an empty name: avoid a leading "." in the full name.
            self.fullname = name

        if df_store is not None:
            self.df_store = df_store
        elif parent is not None:
            self.df_store = parent.df_store
        else:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Context cannot be created without either a dataframe store or a parent")

    def __enter__(self):
        """Start timing. A context can only be entered once."""
        assert self.start is None

        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing and compute total, children and own durations.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        assert self.start is not None

        self.end = time.time()
        self.totaltime = self.end - self.start
        # Requires every child to have been exited already (each one needs a totaltime).
        self.childtime = np.sum([child.totaltime for child in self.childs])
        self.owntime = self.totaltime - self.childtime

    def summary_table(self, quantile: float = 0.1, max_line: int = 10) -> str:
        """Render the slowest contexts of this tree as a text table.

        Rows are ranked by own time, i.e. time not spent in sub-contexts. The
        table keeps roughly the slowest ``quantile`` fraction of all
        contexts, capped at ``max_line`` rows. The slowest context is always
        shown. All contexts whose own time ties with the cut-off are kept, so
        ties may add extra rows.

        Must only be called after the context has been exited.
        """
        summary = self.summary()
        owntimes = sorted((line[1] for line in summary), reverse=True)
        # Clamp the wanted row count to [1, len(owntimes)] so the index below is
        # always valid, even for quantile >= 1 or max_line <= 0.
        keep = max(1, min(int(len(owntimes) * quantile), max_line, len(owntimes)))
        threshold = owntimes[keep - 1]
        cells = [[name, "%.1fms" % (own * 1000), "%.1fms" % (total * 1000)]
                 for name, own, total in summary if own >= threshold]
        return tabulate(cells, headers=["Computation (slowest, quantile=%s, limit=%s)" % (quantile, max_line), "Own", "Total"])

    def summary(self):
        """Return timing rows for this context and all of its descendants.

        Each row is ``[fullname, owntime, totaltime]``. A context that has
        children gets ``".*"`` appended to its name. Rows are in pre-order:
        this context first, then each child's rows in creation order.
        """
        child_summaries = list(chain.from_iterable(child.summary() for child in self.childs))
        wildcard = ".*" if child_summaries else ""
        return [[self.fullname + wildcard, self.owntime, self.totaltime]] + child_summaries

    def sub(self, name: str, brackets: bool = False) -> 'Context':
        """Create and register a child context; only allowed while this one is running.

        :param name: local name of the child.
        :param brackets: format the child's full name as ``parent[name]``.
        :returns: the new, not-yet-entered child context.
        """
        assert self.start is not None
        assert self.end is None

        sub_context = Context(parent=self, name=name, brackets=brackets)
        self.childs.append(sub_context)
        return sub_context
