import logging
import time
import threading

from dataiku.doctor.utils import unix_time_millis

logger = logging.getLogger(__name__)


class AbstractResultStore(object):
    """
    Interface shared by all the result stores used during a hyper parameter search

    Every method raises NotImplementedError here and has to be overridden by the concrete stores
    """

    def append_split_results(self, new_points):
        """
        Record several per-split results (one hyper parameter evaluated on individual splits)
        """
        raise NotImplementedError()

    def find_split_result(self, split_id, parameters):
        """
        Retrieve the result obtained for one hyper parameter on one split

        Allows the search to skip the points that were already computed when it is paused/resumed
        """
        raise NotImplementedError()

    def init_result_file(self, n_candidates, real_nthreads, n_splits, evaluation_metric, timeout):
        """
        Set up the storage of the aggregated results before the search starts
        """
        raise NotImplementedError()

    def append_aggregated_result(self, aggregated_result):
        """
        Record the aggregated result of one hyper parameter evaluated on all the splits
        """
        raise NotImplementedError()

    def update_final_grid_size(self):
        """
        Align the 'gridSize' with the number of points that were actually explored
        """
        raise NotImplementedError()

    def save_current_progress(self):
        """
        Dump the per-split results to a json file, so that the search can be resumed later
        """
        raise NotImplementedError()


class OnDiskResultStore(AbstractResultStore):
    """
    Result store backed by two json files located in the model folder
    - "grid_search_done_py.json" holds the per-split results
    - "grid_search_scores.json" holds the aggregated results

    Every public method grabs an internal lock before touching the in-memory state or the files,
    so that a single instance can be shared by several search workers
    """

    PER_SPLIT_RESULTS_FILENAME = 'grid_search_done_py.json'
    AGGREGATED_RESULTS_FILENAME = 'grid_search_scores.json'
    MIN_WRITE_DELAY_IN_SECONDS = 1

    def __init__(self, model_folder_context):
        """
        :param model_folder_context: folder accessor, used through isfile(), read_json() and write_json()
        """
        self._lock = threading.Lock()
        self._last_write_epoch_time_in_seconds = 0
        self._scores = {}  # Stays empty until init_result_file() is called
        self._model_folder_context = model_folder_context
        # Reload the per-split results that are already in the folder (if any)
        self._points = self._get_per_split_results_from_file()

    def append_split_results(self, new_points):
        """
        Keep in memory the per-split results which are not known yet (nothing is written to disk here)
        """
        with self._lock:
            for candidate in new_points:
                for known in self._points:
                    if known.get('split_id') == candidate['split_id'] and known["parameters"] == candidate["parameters"]:
                        # Same split & same parameters: already inserted (happens during resume)
                        break
                else:
                    self._points.append(candidate)

    def find_split_result(self, split_id, parameters):
        """
        Return the first stored per-split result matching both the split and the parameters, or None
        """
        with self._lock:
            # Note: 'split_id' is read leniently because older DSS versions did not store it,
            #        and the 'resume' feature must keep working with their files
            matching_points = (
                known for known in self._points
                if known.get('split_id') == split_id and known["parameters"] == parameters
            )
            return next(matching_points, None)

    def init_result_file(self, n_candidates, real_nthreads, n_splits, evaluation_metric, timeout):
        """
        Create the aggregated results file, unless it already exists in which case its content is reloaded
        """
        with self._lock:
            scores_file_name = OnDiskResultStore.AGGREGATED_RESULTS_FILENAME
            if self._model_folder_context.isfile(scores_file_name):
                # The file is already there, meaning that the search is being resumed => only reload it
                self._scores = self._get_aggregated_results_from_file()
            else:
                self._scores = {
                    'startedAt': unix_time_millis(),
                    'gridSize': n_candidates,
                    'nThreads': real_nthreads,
                    'nSplits': n_splits,
                    'metric': evaluation_metric,
                    'timeout': timeout,
                    'gridPoints': []
                }
                self._model_folder_context.write_json(scores_file_name, self._scores)

    def update_final_grid_size(self):
        """
        Less points than initially planned may end up being explored (eg. when the strategy generates duplicates)
        This method sets 'gridSize' to the number of stored aggregated results, then writes the file
        """
        with self._lock:
            explored_points = self._scores["gridPoints"]
            self._scores["gridSize"] = len(explored_points)
            self._save_aggregated_results_to_file()

    def append_aggregated_result(self, aggregated_result):
        """
        Store the aggregated result of one hyper parameter evaluated on all the splits

        Results whose parameters are already stored are ignored. The file is rewritten only when
        at least MIN_WRITE_DELAY_IN_SECONDS elapsed since the previous write done by this method
        """
        with self._lock:
            if any(known['parameters'] == aggregated_result['parameters'] for known in self._scores['gridPoints']):
                # Already inserted
                return
            self._scores["gridPoints"].append(aggregated_result)

            current_epoch_time = time.time()
            elapsed_since_last_write = current_epoch_time - self._last_write_epoch_time_in_seconds
            write_is_due = elapsed_since_last_write >= OnDiskResultStore.MIN_WRITE_DELAY_IN_SECONDS
            if write_is_due:
                self._save_aggregated_results_to_file()
                self._last_write_epoch_time_in_seconds = current_epoch_time

    def save_current_progress(self):
        """
        Write both the aggregated and the per-split results to their respective json files
        """
        with self._lock:
            self._save_aggregated_results_to_file()
            self._save_per_split_results_to_file()

    def _get_aggregated_results_from_file(self):
        """
        Read the content of the aggregated results file
        """
        scores_file_name = OnDiskResultStore.AGGREGATED_RESULTS_FILENAME
        return self._model_folder_context.read_json(scores_file_name)

    def _get_per_split_results_from_file(self):
        """
        Read the content of the per-split results file, or return an empty list when there is no such file
        """
        points_file_name = OnDiskResultStore.PER_SPLIT_RESULTS_FILENAME
        if self._model_folder_context.isfile(points_file_name):
            return self._model_folder_context.read_json(points_file_name)
        return []

    def _save_aggregated_results_to_file(self):
        """
        Write the aggregated results to their file (the caller must hold the lock)
        """
        assert self._lock.locked()
        scores_file_name = OnDiskResultStore.AGGREGATED_RESULTS_FILENAME
        self._model_folder_context.write_json(scores_file_name, self._scores)

    def _save_per_split_results_to_file(self):
        """
        Write the per-split results to their file (the caller must hold the lock)
        """
        assert self._lock.locked()
        points_file_name = OnDiskResultStore.PER_SPLIT_RESULTS_FILENAME
        self._model_folder_context.write_json(points_file_name, self._points)


class NoopResultStore(AbstractResultStore):
    """
    Result store which keeps nothing: writes are dropped and lookups never find anything
    """

    def append_split_results(self, new_points):
        """
        Ignore the per-split results
        """
        return None

    def find_split_result(self, split_id, parameters):
        """
        Always return None, since no result is ever stored
        """
        return None

    def init_result_file(self, n_candidates, real_nthreads, n_splits, evaluation_metric, timeout):
        """
        Do nothing (there is no result file to set up)
        """
        return None

    def append_aggregated_result(self, aggregated_result):
        """
        Ignore the aggregated result
        """
        return None

    def update_final_grid_size(self):
        """
        Do nothing (no 'gridSize' is tracked)
        """
        return None

    def save_current_progress(self):
        """
        Do nothing (there is no progress to save)
        """
        return None
