import dataiku
import pandas as pd
from dku_utils.projects.deployed_models.deployed_model_commons import get_deployed_model_id


def _get_deployed_model_predictor(project_containing_models, deployed_model_name):
    """Build a predictor for the saved model deployed under ``deployed_model_name``.

    :param project_containing_models: project handle exposing ``project_key``;
        used to resolve the deployed model id and to locate the saved model.
    :param deployed_model_name: name handed to ``get_deployed_model_id``.
    :returns: the predictor returned by ``dataiku.Model.get_predictor()``.
    """
    deployed_model_id = get_deployed_model_id(
        project_containing_models, deployed_model_name
    )
    deployed_model = dataiku.Model(
        deployed_model_id,
        project_key=project_containing_models.project_key,
        # Read the model without requiring it to be a declared flow input.
        ignore_flow=True,
    )
    return deployed_model.get_predictor()


def score_demand_forecast_dataset(
    project_containing_models,
    demand_forecast_modeling_problem,
    input_dataset_name,
    demand_forecast_dataset_name,
    read_chunk_size=100000,
    old_flag=False,
):
    """Score an input dataset chunk by chunk and write the predictions out.

    Each chunk of ``input_dataset_name`` is split on its
    ``product_forecast_type`` column into ``"regular_forecast"`` and
    ``"cold_start_forecast"`` rows. Rows with any other value are dropped.
    Each subset is scored with its deployed model when
    ``demand_forecast_modeling_problem`` mentions ``"regular"`` or
    ``"cold_start"``; otherwise an empty frame stands in for its predictions.
    The concatenated predictions (input columns + ``prediction``) are appended
    to ``demand_forecast_dataset_name``, whose schema is taken from the first
    scored chunk.

    :param project_containing_models: project handle holding the deployed
        ``demand_forecast_model`` / ``cold_start_demand_forecast_model``.
    :param demand_forecast_modeling_problem: value tested with ``in`` for the
        substrings ``"regular"`` and ``"cold_start"``.
    :param input_dataset_name: name of the dataset to score.
    :param demand_forecast_dataset_name: name of the output dataset.
    :param read_chunk_size: number of records read per chunk.
    :param old_flag: legacy mode. Every input row is treated as a regular
        forecast row (no ``product_forecast_type`` split), so cold-start
        scoring is skipped.
    :returns: None.
    """
    input_dataset = dataiku.Dataset(input_dataset_name)
    demand_forecast_dataset = dataiku.Dataset(demand_forecast_dataset_name)

    score_regular = "regular" in demand_forecast_modeling_problem
    # Legacy inputs are not split by forecast type, so there are no cold-start
    # rows to score (previously this path crashed with a NameError).
    score_cold_start = (
        "cold_start" in demand_forecast_modeling_problem and not old_flag
    )

    # Load each predictor once instead of once per chunk.
    regular_demand_forecast_predictor = (
        _get_deployed_model_predictor(
            project_containing_models, "demand_forecast_model"
        )
        if score_regular
        else None
    )
    cold_start_demand_forecast_predictor = (
        _get_deployed_model_predictor(
            project_containing_models, "cold_start_demand_forecast_model"
        )
        if score_cold_start
        else None
    )

    first_record_index = 1
    schema_written = False
    # The context manager guarantees the writer is closed even if scoring fails.
    with demand_forecast_dataset.get_writer() as demand_forecast_dataset_writer:
        for dataframe_chunk in input_dataset.iter_dataframes(
            chunksize=read_chunk_size
        ):
            last_record_index = first_record_index + len(dataframe_chunk) - 1
            print(
                "Reading dataset '{}' data by chunks: Record n° '{}' to '{}'".format(
                    input_dataset_name, first_record_index, last_record_index
                )
            )

            if old_flag:
                regular_forecast_df = dataframe_chunk
                cold_start_forecast_df = None
            else:
                regular_forecast_df = dataframe_chunk[
                    dataframe_chunk["product_forecast_type"] == "regular_forecast"
                ].copy()
                cold_start_forecast_df = dataframe_chunk[
                    dataframe_chunk["product_forecast_type"] == "cold_start_forecast"
                ].copy()

            # Column layout used for the empty placeholder of an unscored subset.
            predictions_df_columns = list(regular_forecast_df.columns) + ["prediction"]

            if regular_demand_forecast_predictor is not None:
                print("Scoring data with the 'regular' demand forecast model ...")
                regular_forecast_predictions_df = (
                    regular_demand_forecast_predictor.predict(
                        regular_forecast_df, with_input_cols=True, with_prediction=True
                    )
                )
                print("Data successfully scored with the 'regular' demand forecast model!")
            else:
                regular_forecast_predictions_df = pd.DataFrame(
                    columns=predictions_df_columns
                )

            if cold_start_demand_forecast_predictor is not None:
                print("Scoring data with the 'cold_start' demand forecast model ...")
                cold_start_forecast_predictions_df = (
                    cold_start_demand_forecast_predictor.predict(
                        cold_start_forecast_df, with_input_cols=True, with_prediction=True
                    )
                )
                print(
                    "Data successfully scored with the 'cold_start' demand forecast model!"
                )
            else:
                cold_start_forecast_predictions_df = pd.DataFrame(
                    columns=predictions_df_columns
                )

            demand_forecast_df = pd.concat(
                [regular_forecast_predictions_df, cold_start_forecast_predictions_df]
            )
            print(
                "Writing prediction data in dataset '{}'... ".format(
                    demand_forecast_dataset_name
                )
            )
            # The output schema is derived from the first chunk only.
            if not schema_written:
                demand_forecast_dataset.write_schema_from_dataframe(demand_forecast_df)
                schema_written = True
            demand_forecast_dataset_writer.write_dataframe(demand_forecast_df)

            first_record_index = last_record_index + 1
