import traceback

"""
Main entry point of EDA compute engine implementation
This is a server implementing commands defined in the PythonKernelProtocol Java class
"""
import logging

from typing import Optional

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging
from dataiku.eda import builtins
from dataiku.eda.computations.computation import Computation
from dataiku.eda.computations.data_frame_store import DataFrameStore, DataStreamId
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.types import ComputationModel

# Module-level logger, named after this module, used by the protocol handlers.
logger = logging.getLogger(__name__)


class EDAProtocol:
    """
    Command loop of the EDA kernel.

    Reads JSON commands from the link, dispatches each one to the matching
    ``_handle_*`` method and writes the JSON (or stream) response back on the
    same link.
    """

    # Runs once, when the class body is evaluated (i.e. at import time).
    # NOTE(review): presumably registers the built-in computations - confirm.
    builtins.load()

    def __init__(self, link: JavaLink):
        """
        :param link: link used to exchange commands and data with the caller
        """
        # Dataset currently loaded in the kernel (None until "LoadDataset" is handled)
        self.idf: Optional[ImmutableDataFrame] = None
        self.link = link
        self._stopped: bool = False
        # this instance of the dataframe store is kernel-wide
        self.df_store = DataFrameStore()

    def _handle_load_dataset(self, dss_schema):
        """
        Receive the dataset as a CSV stream and keep it as the current dataframe.

        Sends "WaitingForData" before reading the stream and "DatasetReceived"
        once the dataframe has been built.

        :param dss_schema: schema passed along to ImmutableDataFrame.from_csv
        """
        logger.info("Loading dataset...")
        self.link.send_json({"type": "WaitingForData"})
        self.idf = ImmutableDataFrame.from_csv(self.link.read_stream(), dss_schema)
        self.link.send_json({"type": "DatasetReceived"})

    def _handle_computation(self, computation_params: ComputationModel):
        """
        Run a computation on the loaded dataset and send back its result.

        :param computation_params: description of the computation to build
        :raises Exception: if no dataset has been loaded yet
        """
        if self.idf is None:
            raise Exception("Dataset is not loaded")

        logger.info("Executing computation...")
        ctx = Context(df_store=self.df_store)
        computation = Computation.build(computation_params)

        with ctx:
            result = computation.apply_safe(self.idf, ctx)

        logger.info("Total computation time: {:.3f}ms".format(1000 * ctx.totaltime))
        if ctx.totaltime > 1:
            # Avoid spamming when it was super fast
            logger.info(ctx.summary_table())

        self.link.send_json({"type": "ComputationResult", "result": result})

    def _handle_fetch_data_stream(self, payload):
        """
        Stream the content of a stored data stream back on the link.

        Failures are not propagated: they are reported in the "error" field of
        the "DataStreamEnded" message so that the command loop keeps running.

        :param payload: serialized data stream id, parsed with DataStreamId.parse
        """
        try:
            data_stream_id = DataStreamId.parse(payload)
            with self.link.send_stream() as output_stream:
                self.df_store.write(data_stream_id, output_stream)

            self.link.send_json({"type": "DataStreamEnded"})

        except Exception as e:
            traceback.print_exc()
            traceback.print_stack()
            logger.error(e)

            self.link.send_json({
                "type": "DataStreamEnded",
                "error": "Unexpected error: {}".format(e),
            })

    def _handle_clear_data_streams(self, payload):
        """
        Remove the given data streams from the dataframe store.

        :param payload: iterable of serialized data stream ids
        """
        data_stream_ids = [DataStreamId.parse(i) for i in payload]
        self.df_store.clear(data_stream_ids)
        self.link.send_json({"type": "DataStreamCleared"})

    def _handle_close(self):
        """
        Flag the command loop as stopped and acknowledge with "KernelServerStopped".
        """
        self._stopped = True
        logger.info("Kernel server stopped")
        self.link.send_json({"type": "KernelServerStopped"})

    def start(self):
        """
        Read and dispatch commands until "StopKernelServer" is received.

        :raises Exception: if a command with an unknown type is received
        """
        self._stopped = False

        while not self._stopped:
            command = self.link.read_json()
            command_type = command["type"]
            if command_type == "LoadDataset":
                self._handle_load_dataset(command['schema'])
            elif command_type == "Compute":
                self._handle_computation(command["computation"])
            elif command_type == "FetchDataStream":
                self._handle_fetch_data_stream(command["dataStreamId"])
            elif command_type == "ClearDataStream":
                self._handle_clear_data_streams(command["dataStreamIds"])
            elif command_type == "StopKernelServer":
                self._handle_close()
            else:
                # Report the actual problem instead of the misleading
                # "Dataset is not loaded" message
                raise Exception("Unknown command type: {}".format(command_type))


def serve(port, secret, server_cert=None):
    """
    Open a JavaLink with the given parameters and run the EDA protocol on it.

    The link is closed when the protocol loop exits, whether it returned
    normally or raised.
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()

    protocol = EDAProtocol(java_link)
    try:
        protocol.start()
    finally:
        java_link.close()


if __name__ == "__main__":
    # Configure logging before anything else gets a chance to log
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
    debugging.install_handler()

    watch_stdin()

    # Connection parameters come from the command line
    port, secret, server_cert = parse_javalink_args()
    serve(port, secret, server_cert=server_cert)
