# encoding: utf-8
import logging

from dataiku.base import dku_pickle
from dataiku.base.dku_pickle import PickleLoadException
from dataiku.base.model_plugin import prepare_for_plugin_from_params
from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.utils import dku_write_mode_for_pickling

logger = logging.getLogger(__name__)


# File names of the XGBoost-specific serialization. load_model_from_folder only takes the
# XGBoost loading path when both files are present in the folder.
XGBOOST_CLF_ATTRIBUTES_FILENAME = "xgboost_clf_attributes.json"
XGBOOST_BOOSTER_FILENAME = "xgboost_booster.bin"

# JSON file that from_pkl reads and hands to prepare_for_plugin_from_params when a first
# unpickling attempt fails.
MODELING_PARAMS_FILE = "rmodeling_params.json"


def load_model_from_folder(folder_context, is_prediction=True):
    """
    Load a model from a folder, preferring the XGBoost-specific files when both are present.

    :param folder_context: folder to read from; must provide ``isfile`` and ``get_file_path_to_read``
    :param bool is_prediction: selects the pickle fallback name: "clf.pkl" when truthy, "clusterer.pkl" otherwise
    :return: the result of ``load_xgboost_model`` when both XGBoost files exist, else the result of ``from_pkl``
    """
    has_attributes = folder_context.isfile(XGBOOST_CLF_ATTRIBUTES_FILENAME)
    has_booster = folder_context.isfile(XGBOOST_BOOSTER_FILENAME)

    # Backwards compatible XGBoost format: needs both the attributes file and the booster file
    if has_attributes and has_booster:
        logger.info("Found xgboost booster file, loading it into classifier")
        with folder_context.get_file_path_to_read(XGBOOST_CLF_ATTRIBUTES_FILENAME) as attributes_path:
            with folder_context.get_file_path_to_read(XGBOOST_BOOSTER_FILENAME) as booster_path:
                # Imported lazily, only once the XGBoost files have actually been found
                from dataiku.doctor.prediction.dku_xgboost import load_xgboost_model
                return load_xgboost_model(attributes_path, booster_path)

    # Only one of the two files is there: warn, then fall back to the pickle
    if has_attributes and not has_booster:
        logger.warning("Found 'xgboost_clf_attributes.json' but not 'xgboost_booster.bin'. Looking for 'clf.pkl' instead.")
    elif has_booster and not has_attributes:
        logger.warning("Found xgboost_booster.bin but not xgboost_clf_attributes.json. Looking for 'clf.pkl' instead.")

    # Every other case goes through the pickled model
    if is_prediction:
        fallback_pkl = "clf.pkl"
    else:
        fallback_pkl = "clusterer.pkl"
    return from_pkl(folder_context, fallback_pkl)


def dump_model_to_folder(clf, folder_context, pkl_filename=None, is_prediction=True):
    """
    Save a model into a folder.

    Instances of DkuXGBClassifier / DkuXGBRegressor are first saved with ``dump_xgboost_model``
    (backwards compatible format). The model is then always pickled with ``to_pkl``, whatever its type.

    :param clf: model to save
    :param folder_context: folder to write into
    :param pkl_filename: name of the pickle file; when falsy, defaults to "clf.pkl" if ``is_prediction``
                         is truthy and "clusterer.pkl" otherwise
    :param bool is_prediction: only used to pick the default pickle file name
    """
    try:
        from dataiku.doctor.prediction.dku_xgboost import DkuXGBClassifier
        from dataiku.doctor.prediction.dku_xgboost import DkuXGBRegressor
        from dataiku.doctor.prediction.dku_xgboost import dump_xgboost_model
        if isinstance(clf, (DkuXGBRegressor, DkuXGBClassifier)):
            # Save XGBoost models in backwards compatible format
            dump_xgboost_model(folder_context, clf)
    except ImportError as e:
        # The doctor is meant to work without xgboost in the code env, and in that case the user can't
        # have trained an XGBoost model anyway: that case is expected and stays silent.
        # Any other ImportError is still not fatal (the model is pickled below), but it is logged so
        # that a skipped XGBoost dump does not go unnoticed.
        reason = safe_unicode_str(e)
        missing_xgboost = "No module named 'xgboost'" in reason or "No module named xgboost" in reason
        if not missing_xgboost:
            logger.warning("Could not save model in XGBoost backwards compatible format "
                           "because of an unexpected import error: %s", reason)
    if not pkl_filename:
        pkl_filename = "clf.pkl" if is_prediction else "clusterer.pkl"
    # Always pickle model, even for XGBoost (for ensemble models compatibility)
    to_pkl(clf, folder_context, pkl_filename)


def _load_model_from_file(clf_file):
    """
    Unpickle a model from an open file, turning any failure into a PickleLoadException.

    :param clf_file: open file object to unpickle from; its ``name`` attribute is used in error messages
    :return: the object returned by ``dku_pickle.load``
    :raises PickleLoadException: whenever unpickling fails, with a message hinting at the likely cause
                                 when it can be recognized
    """
    try:
        from xgboost.core import XGBoostError
    except ImportError:
        # The doctor is meant to work without xgboost in the code env. The name must still be bound:
        # otherwise evaluating `except XGBoostError` below raises a NameError for any unpickling
        # failure, which hides the real error. An empty tuple makes that clause match nothing.
        XGBoostError = ()
    try:
        return dku_pickle.load(clf_file)
    except UnicodeDecodeError:
        raise PickleLoadException(u"Failed to unpickle {}. You might have been trying to load a model "
                                  u"saved in a python 2 code environment with a python 3 one.".format(clf_file.name))
    except XGBoostError:
        # TODO @xgboostUpgrade add plugin link
        raise PickleLoadException(u"Failed to load XGBoost model '{}'. The model has probably been trained with an older "
                                  u"XGBoost version. To score using this model, you can either retrain it, score using a code environment "
                                  u"with the older version, score using optimized Java scoring (not compatible with explanations), or upgrade "
                                  u"the model to the new forwards compatible format using the XGBoost model upgrade plugin (ask you administrator "
                                  u"to install it). See https://www.dataiku.com/product/plugins/xgboost-version-bump".format(clf_file.name))
    except Exception as e:
        # Recognize a few well-known failures from their message to give the user an actionable hint
        error_text = safe_unicode_str(e)
        if error_text == "non-string names in Numpy dtype unpickling":
            raise PickleLoadException(u"Failed to unpickle {}. You might have been trying to load a model "
                                      u"saved in a python 3 code environment with a python 2 one.".format(clf_file.name))
        elif error_text.startswith("No module named 'sklearn."):  # Notice the dot !
            import sklearn
            raise PickleLoadException(u"You are trying to load a model from {} trained with a different version of scikit-learn. \n"
                                      "Current context loads version {}".format(clf_file.name, sklearn.__version__))
        else:
            raise PickleLoadException(u"Failed to load model from file {}".format(clf_file.name))


def from_pkl(folder_context, pkl_filename="clf.pkl"):
    """
    Unpickle a model stored in a folder, retrying once after preparing the environment for plugins.

    :type folder_context: dataiku.base.folder_context.FolderContext
    :param str pkl_filename: file name of the pickled model
    :rtype: sklearn.base.BaseEstimator
    :raises PickleLoadException: when loading fails and there is no modeling params file to retry with,
                                 or when the retry fails as well
    """
    with folder_context.get_file_path_to_read(pkl_filename) as pkl_path:
        try:
            with open(pkl_path, "rb") as pkl_file:
                model = _load_model_from_file(pkl_file)
        except PickleLoadException as load_error:
            # Without modeling params there is nothing to prepare, so the first failure is final
            if not folder_context.isfile(MODELING_PARAMS_FILE):
                raise load_error
            # The failure may come from an environment not yet prepared for plugins:
            # prepare it from the modeling params, then give the pickle a second chance
            modeling_params = folder_context.read_json(MODELING_PARAMS_FILE)
            prepare_for_plugin_from_params(modeling_params)
            with open(pkl_path, "rb") as pkl_file:
                model = _load_model_from_file(pkl_file)
    return model


def to_pkl(clf, folder_context, pkl_filename="clf.pkl"):
    """
    Pickle a model into a file of the given folder.

    :param sklearn.base.BaseEstimator clf: classifier to pickle
    :type folder_context: dataiku.base.folder_context.FolderContext
    :param str pkl_filename: file name to give to the pickled model
    """
    with folder_context.get_file_path_to_write(pkl_filename) as target_path:
        # The open mode is provided by the doctor utils rather than hard-coded here
        write_mode = dku_write_mode_for_pickling()
        with open(target_path, write_mode) as pkl_file:
            dku_pickle.dump(clf, pkl_file)
