(function() {
'use strict';

const app = angular.module('dataiku.analysis.mlcore');

// Meta-learners available for causal prediction tasks.
// `displayNames` and `rawNames` list the same learners in the same order.
app.constant('CAUSAL_META_LEARNERS', (function() {
    const metaLearners = [
        { rawName: "S_LEARNER", displayName: "S-learner" },
        { rawName: "T_LEARNER", displayName: "T-learner" },
        { rawName: "X_LEARNER", displayName: "X-learner" },
    ];

    return {
        displayNames: metaLearners.map(learner => learner.displayName),
        rawNames: metaLearners.map(learner => learner.rawName),
    };
})());

// Base controller for causal prediction ML tasks.
// After `deferredAfterInitMlTaskDesign` resolves, it chains on `CachedAPICalls.pmlGuessPolicies`
// and then recomputes the algorithms for the current design. Any failure is reported in the scope.
app.controller("CausalPMLTaskBaseController", function($scope, $controller, CachedAPICalls, AlgorithmsSettingsService) {
    $controller("_PMLTrainSessionController", { $scope });

    function waitForGuessPolicies() {
        return CachedAPICalls.pmlGuessPolicies;
    }

    function applyAlgorithms() {
        $scope.setAlgorithms($scope.mlTaskDesign);
    }

    $scope.deferredAfterInitMlTaskDesign
        .then(waitForGuessPolicies)
        .then(applyAlgorithms)
        .catch(setErrorInScope.bind($scope));
});

app.controller("CausalPMLTaskDesignController", function($scope, $controller, $stateParams, $filter, DataikuAPI, CreateModalFromTemplate, CAUSAL_META_LEARNERS, Dialogs, AlgorithmsSettingsService) {
    // Instantiate the train-session and tabular design controllers on this same scope,
    // so that what they define on `$scope` is available below.
    $controller("_PMLTrainSessionController", {$scope: $scope});
    $controller("_TabularPMLTaskDesignController", {$scope: $scope});

    // Raw meta-learner identifiers (e.g. "S_LEARNER"), exposed on the scope
    $scope.META_LEARNERS = CAUSAL_META_LEARNERS.rawNames;

    /**
     * Selects a causal method and picks the algorithm that goes with it.
     * Does nothing when `method` is already the selected one.
     *
     * @param {string} method - 'META_LEARNER' or 'CAUSAL_FOREST'; any other value only updates uiState.
     */
    $scope.selectMethod = function(method) {
        const uiState = $scope.uiState;
        if (uiState.selectedCausalMethod === method) {
            return;
        }
        uiState.selectedCausalMethod = method;

        switch (method) {
            case 'META_LEARNER': {
                // Default algorithm is chosen among the algorithms currently shown
                const shownAlgorithms = $scope.algorithms[$scope.mlTaskDesign.backendType]
                    .filter(algo => $scope.showAlgorithm(algo));
                const defaultAlgorithm = AlgorithmsSettingsService.getDefaultAlgorithm($scope.mlTaskDesign, shownAlgorithms);
                $scope.setSelectedAlgorithm(defaultAlgorithm);
                break;
            }
            case 'CAUSAL_FOREST':
                $scope.setSelectedAlgorithm('causal_forest');
                break;
        }
    };

    /**
     * Adds `metaLearner` to the design's selected meta-learners, or removes it if already selected.
     * The `meta_learners` array is modified in place.
     */
    $scope.toggleMetaLearner = function(metaLearner) {
        const selectedMetaLearners = $scope.mlTaskDesign.modeling.meta_learners;
        const position = selectedMetaLearners.indexOf(metaLearner);
        if (position === -1) {
            selectedMetaLearners.push(metaLearner);
            return;
        }
        selectedMetaLearners.splice(position, 1);
    };

    // Number of (meta-learner, base learner) combinations: selected meta-learners x available base learners
    $scope.getNbMetaLearnerAlgorithms = function() {
        const selectedMetaLearners = $scope.mlTaskDesign.modeling.meta_learners;
        return selectedMetaLearners.length * getNbBaseLearners();
    };

    // Same count as above, restricted to base learners that are enabled in the design
    $scope.getNbEnabledMetaLearnerAlgorithms = function() {
        const selectedMetaLearners = $scope.mlTaskDesign.modeling.meta_learners;
        return selectedMetaLearners.length * $scope.getNbEnabledBaseLearners();
    };

    // Number of base learners that are both available and enabled
    $scope.getNbEnabledBaseLearners = () => getNbBaseLearners(true);

    /**
     * Counts the algorithms of the current backend that can be used as base learners of a meta-learner.
     * An algorithm counts when its optional `condition()` passes and it supports the 'META_LEARNER' method.
     *
     * @param {boolean} [onlyEnabled] - when truthy, only count algorithms whose modeling params are enabled
     * @returns {number} 0 when the algorithms are not available yet
     */
    function getNbBaseLearners(onlyEnabled) {
        const backendAlgorithms = ($scope.algorithms && $scope.algorithms[$scope.mlTaskDesign.backendType]) || [];

        let nbBaseLearners = 0;
        backendAlgorithms.forEach(function(algo) {
            if (algo.condition && !algo.condition()) return;
            if (algo.supportedCausalMethod !== 'META_LEARNER') return;
            if (onlyEnabled && !$scope.getAlgorithmModeling(algo.algKey).enabled) return;
            nbBaseLearners++;
        });
        return nbBaseLearners;
    }

    // A causal task is either a regression or a binary classification:
    // tells whether the current design is the binary classification kind.
    $scope.isCausalClassification = () => $scope.mlTaskDesign.predictionType === 'CAUSAL_BINARY_CLASSIFICATION';

    // Keep only the prediction types flagged as causal
    $scope.predictionTypes = $scope.predictionTypes.filter(function(predictionType) {
        return predictionType.causal;
    });

    /**
     * Returns a new object holding the per-feature preprocessing settings of every feature
     * whose role is not "TARGET". The settings objects themselves are shared, not copied.
     */
    $scope.getFeaturesExceptTarget = function() {
        const perFeature = $scope.mlTaskDesign.preprocessing.per_feature;
        const nonTargetFeatures = {};
        Object.keys(perFeature).forEach(function(featureName) {
            if (perFeature[featureName].role !== "TARGET") {
                nonTargetFeatures[featureName] = perFeature[featureName];
            }
        });
        return nonTargetFeatures;
    };

    // Label for the design's control value: an empty string is shown as "<Empty>"
    $scope.displayControlValue = function() {
        const controlValue = $scope.mlTaskDesign.controlValue;
        if (controlValue === "") {
            return "<Empty>";
        }
        return controlValue;
    };

    // Label for a treatment value: an empty string is shown as "<Empty>"
    $scope.displayTreatmentValue = function(treatmentValue) {
        if (treatmentValue === "") {
            return "<Empty>";
        }
        return treatmentValue;
    };

    /**
     * Handler for a change of the treatment variable in the UI.
     * Saves pending settings if any, then opens the "change core params" modal for the
     * `treatmentVariable` key. The treatment column values are reloaded when the modal closes.
     */
    $scope.onChangeTreatmentVariable = function() {
        if (!$scope.uiState.treatmentVariable) {
            return;
        }

        if ($scope.dirtySettings()) {
            $scope.saveSettings();
        }

        const setupModalScope = function(modalScope) {
            modalScope.paramKey = "treatmentVariable";
            modalScope.onCloseCallback = () => {
                refreshTreatmentColValues();
            };
        };
        CreateModalFromTemplate(
            "/templates/analysis/prediction/change-core-params-modal.html",
            $scope,
            "PMLChangeBasicParamsModal",
            setupModalScope
        );
    };

    /**
     * Handler for a change of the positive outcome class.
     * Remaps the target so that the positive class maps to 1 and every other class to 0,
     * then recomputes the average outcomes of the stats table.
     */
    $scope.onChangePositiveOutcomeClass = function() {
        $scope.mlTaskDesign.preprocessing.target_remapping.forEach(function(remapping) {
            const isPositiveClass = remapping.sourceValue === $scope.mlTaskDesign.positiveClass;
            remapping.mappedValue = isPositiveClass ? 1 : 0;
        });
        $scope.outcomeStats.averageOutcomes = getAverageOutcomes();
    };

    /**
     * Computes the upper bound of each color bucket of the stats table.
     * The [min, max] range of the colored cells is split into as many equal steps as there are
     * colors in TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE.
     *
     * Colored cells are the per-treatment counts in multi-treatment mode,
     * and the control / treated counts otherwise.
     *
     * @returns {number[]} one upper bound per color, in increasing order
     */
    function setBoundsForColorScale() {
        const isMultiTreatment = $scope.mlTaskDesign.enableMultiTreatment;

        let lowest = Number.MAX_VALUE;
        let highest = -Number.MAX_VALUE;
        $scope.outcomeStats.rowCountsByOutcomeBins.forEach((row) => {
            const cellValues = isMultiTreatment ? Object.values(row.perTreatment) : [row.control, row.treated];
            lowest = Math.min(lowest, ...cellValues);
            highest = Math.max(highest, ...cellValues);
        });

        const nbColors = TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE.length;
        const step = (highest - lowest) / nbColors;
        return TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE.map((_, colorIdx) => lowest + (colorIdx + 1) * step);
    }

    // Cell background colors of the treatment stats table, from lightest to darkest
    // (digital blue palette, from lighten-4 to darken-1)
    const TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE = [
        "#C4E0FE",
        "#9DCCFE",
        "#76B8FD",
        "#58A8FC",
        "#3B99FC",
        "#3591FC",
    ];

    // Upper bound of each color bucket (one per entry of the color scale).
    // Stays undefined until the outcome stats have been loaded successfully.
    let boundsForColorScale;

    /**
     * Returns the background color of a stats table cell.
     *
     * @param {number} value - value displayed in the cell
     * @returns {string|undefined} a color of TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE,
     *     or undefined while the bucket bounds are not available
     */
    $scope.getCellColor = function(value) {
        if (!boundsForColorScale) return;

        // Falsy values (0, undefined, NaN...) always get the lightest color
        if (!value) return TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE[0];

        // Never walk past the last bucket: a value above the last bound (e.g. a cell that was not
        // part of the min/max computation, or degenerate bounds) gets the darkest color
        // instead of an undefined one.
        const lastIdx = TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE.length - 1;
        let idx = 0;
        while (idx < lastIdx && value > boundsForColorScale[idx]) {
            idx++;
        }
        return TREATMENT_COLUMN_STATS_TABLE_COLOR_SCALE[idx];
    };

    /**
     * Computes the average outcome for the control group, the treated group, both together,
     * and each treatment value of the design.
     *
     * - Classification: share of rows whose outcome is the positive class. A treatment value
     *   with no entry in the positive class counts gets 0.
     * - Regression: sum of the outcomes divided by the number of rows.
     *
     * @returns {{control: number, treated: number, total: number, perTreatment: Object}}
     */
    function getAverageOutcomes() {
        const rowCounts = $scope.outcomeStats.totalRowCounts;
        const isClassification = $scope.isCausalClassification();

        // Numerators of the averages: positive-class row counts, or outcome sums
        let numerators;
        if (isClassification) {
            const positiveClassIdx = $scope.outcomeStats.outcomeClasses.indexOf($scope.mlTaskDesign.positiveClass);
            numerators = $scope.outcomeStats.rowCountsByOutcomeBins[positiveClassIdx];
        } else {
            numerators = $scope.outcomeStats.outcomeSums;
        }

        const perTreatment = {};
        $scope.mlTaskDesign.treatmentValues.forEach((treatment) => {
            const hasNoPositiveRow = isClassification && !(treatment in numerators.perTreatment);
            perTreatment[treatment] = hasNoPositiveRow
                ? 0
                : numerators.perTreatment[treatment] / rowCounts.perTreatment[treatment];
        });

        return {
            control: numerators.control / rowCounts.control,
            treated: numerators.treated / rowCounts.treated,
            total: (numerators.control + numerators.treated) / (rowCounts.control + rowCounts.treated),
            perTreatment,
        };
    }

    /**
     * Returns the labels of the outcome bins of the stats table:
     * the outcome classes for a classification, otherwise four "[lower, upper]" intervals
     * built from consecutive pairs of `outcomeIntervalBounds`.
     */
    function getOutcomeBins() {
        if ($scope.isCausalClassification()) {
            return $scope.outcomeStats.outcomeClasses;
        }

        const bounds = $scope.outcomeStats.outcomeIntervalBounds;
        const formatBound = $filter('nicePrecision');
        const NB_BINS = 4; // the four interquartile ranges

        const bins = [];
        for (let binIdx = 0; binIdx < NB_BINS; binIdx++) {
            const lower = formatBound(bounds[binIdx]);
            const upper = formatBound(bounds[binIdx + 1]);
            bins.push(`[${lower}, ${upper}]`);
        }
        return bins;
    }

    /**
     * Returns the design's treatment values without the control value.
     * When missing treatment values are dropped, empty values are excluded as well.
     *
     * Comparisons are deliberately loose (`==`), as in the rest of this function's history:
     * values that only differ by type (e.g. 1 and "1") are considered equal.
     */
    function getTreatmentValuesNoControl() {
        const design = $scope.mlTaskDesign;
        const dropMissingValues = design.preprocessing.dropMissingTreatmentValues;

        return design.treatmentValues.filter(function(treatmentValue) {
            if (treatmentValue == design.controlValue) {
                return false;
            }
            return !(dropMissingValues && treatmentValue == "");
        });
    }

    /**
     * Loads the outcome stats per treatment state from the backend and derives everything the
     * stats table needs from them: average outcomes, outcome bin labels, non-control treatment
     * values and color bucket bounds. Backend errors are reported in the scope.
     */
    $scope.fillOutcomeStatsPerTreatmentStateTable = function() {
        const design = $scope.mlTaskDesign;

        const onStatsLoaded = function(stats) {
            // outcomeStats must be set first: the helpers below read it from the scope
            $scope.outcomeStats = stats;
            $scope.outcomeStats.averageOutcomes = getAverageOutcomes();
            $scope.outcomeBins = getOutcomeBins();
            $scope.treatmentValuesNoControl = getTreatmentValuesNoControl();
            boundsForColorScale = setBoundsForColorScale();
        };

        DataikuAPI.analysis.pml.getOutcomeStatsPerTreatmentState(
            $stateParams.projectKey,
            $stateParams.analysisId,
            $stateParams.mlTaskId,
            design.controlValue,
            design.preprocessing.dropMissingTreatmentValues
        ).success(onStatsLoaded).error(setErrorInScope.bind($scope));
    };

    /**
     * Reloads the values of the treatment column from the backend, stores them in the design
     * and recomputes the list of non-control treatment values.
     * Backend failures are reported in the scope through `setErrorInScope`, like the other
     * backend calls of this file, instead of being silently dropped.
     */
    function refreshTreatmentColValues() {
        DataikuAPI.analysis.pml.getTreatmentColValues($stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId)
            .success(function(data) {
                $scope.mlTaskDesign.treatmentValues = data;
                $scope.treatmentValuesNoControl = getTreatmentValuesNoControl();
            })
            .error(setErrorInScope.bind($scope));
    }

    // Label for a value of the treatment column: an empty string is shown as '<Empty>'
    $scope.displayTreatmentColValues = function(value) {
        if (value === '') {
            return '<Empty>';
        }
        return value;
    };

    // True when the design's control value is not among the known treatment values
    $scope.controlValueMissingFromDesignSample = function() {
        const design = $scope.mlTaskDesign;
        const isKnownValue = design.treatmentValues.includes(design.controlValue);
        return !isKnownValue;
    };

    // Truthy when the treated group has a lower average outcome than the control group.
    // Falsy (the missing stats value itself) while the outcome stats are not loaded.
    $scope.warnNegativeEstimatedATE = function() {
        const stats = $scope.outcomeStats;
        if (!stats) {
            return stats;
        }
        const { treated, control } = stats.averageOutcomes;
        return treated - control < 0;
    };

    /**
     * Handler for a change of the "propensity modeling enabled" toggle (`uiState.enablePropensityModeling`).
     *
     * When propensity modeling is currently enabled in the design and something depends on it
     * (inverse propensity weighting of the metrics, or enabled propensity-based diagnostics),
     * the user must confirm: on confirmation, propensity modeling and its dependents are disabled
     * and the settings are saved; otherwise the toggle is put back in sync with the design.
     * In every other case the toggle value is simply copied to the design.
     */
    $scope.onChangePropensityEnabled = function() {
        const PROPENSITY_DIAGNOSTIC_TYPES = ["ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS", "ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS"];
        const design = $scope.mlTaskDesign;

        const usesInversePropensityWeights = design.modeling.metrics.causalWeighting === "INVERSE_PROPENSITY";
        const hasEnabledPropensityDiagnostic = design.diagnosticsSettings.settings
            .some(diagnostic => diagnostic.enabled && PROPENSITY_DIAGNOSTIC_TYPES.includes(diagnostic.type));
        const usesPropensityDiagnostics = design.diagnosticsSettings.enabled && hasEnabledPropensityDiagnostic;

        const needsConfirmation = design.modeling.propensityModeling.enabled
            && (usesInversePropensityWeights || usesPropensityDiagnostics);
        if (!needsConfirmation) {
            design.modeling.propensityModeling.enabled = $scope.uiState.enablePropensityModeling;
            return;
        }

        // List what is going to be disabled along with propensity modeling
        const impactedFeatures = [];
        if (usesInversePropensityWeights) {
            impactedFeatures.push("inverse propensity weighting of the metrics");
        }
        if (usesPropensityDiagnostics) {
            impactedFeatures.push("diagnostics relying on propensity");
        }
        const warningMessage = "Current " + impactedFeatures.join(" and ") + " will be disabled.";

        const disablePropensityModeling = function() {
            const currentDesign = $scope.mlTaskDesign;
            currentDesign.modeling.propensityModeling.enabled = false;
            // Disable inverse propensity weighted metrics
            currentDesign.modeling.metrics.causalWeighting = "NO_WEIGHTING";
            // Disable propensity-based diagnostics
            currentDesign.diagnosticsSettings.settings.forEach(function(diagnostic) {
                if (PROPENSITY_DIAGNOSTIC_TYPES.includes(diagnostic.type)) {
                    diagnostic.enabled = false;
                }
            });
            $scope.saveSettings();
        };

        const restoreToggle = function() {
            // dialog is closed
            $scope.uiState.enablePropensityModeling = $scope.mlTaskDesign.modeling.propensityModeling.enabled;
        };

        Dialogs.confirmAlert(
            $scope,
            "Disabling Treatment Analysis",
            "Are you sure you want to disable Treatment Analysis (propensity modeling)?",
            warningMessage,
            "WARNING"
        ).then(disablePropensityModeling, restoreToggle);
    };

    // Weighting options for the causal metrics, presumably as [value, label] pairs:
    // the first entries are the `causalWeighting` values used in this controller.
    $scope.causalMetricsWeightingOptions = [
        ["NO_WEIGHTING", "No weighting"],
        ["INVERSE_PROPENSITY", "Inverse propensity weighting"],
    ];

    // Each time `mlTaskDesign` changes (and is set): mirror its core params into uiState,
    // reload the stats table and the treatment column values, and set the search strategies.
    $scope.$watch('mlTaskDesign', function(design) {
        if (!design) {
            return;
        }

        const uiState = $scope.uiState;
        uiState.predictionType = design.predictionType;
        uiState.treatmentVariable = design.treatmentVariable;
        uiState.enablePropensityModeling = design.modeling.propensityModeling.enabled;

        $scope.fillOutcomeStatsPerTreatmentStateTable();
        refreshTreatmentColValues();

        // backendType assumed to be PY_MEMORY
        uiState.hyperparamSearchStrategies = [
            ["GRID", "Grid search"],
            ["RANDOM", "Random search"],
            ["BAYESIAN", "Bayesian search"],
        ];
    });
});

// Illustrative chart explaining a causal metric (Qini, AUUC or net uplift).
// The curves are static dummy data: only the highlighted elements depend on the bindings.
//   - metric: "QINI", "AUUC" or "NET_UPLIFT" (any other truthy value draws the chart with no highlight)
//   - netUpliftPoint: position of the net uplift segment, as a fraction (multiplied by 100 to get the x coordinate)
app.directive("causalMetricsChart", function() {
    const PALETTE = {
        blueHighlight: "#1F77B4",  // @custom-ml-chart-blue
        grey: "#666",              // @grey-lighten-3
        lighterBlue: "#AEC7E8",    // lighter color paired with @custom-ml-chart-blue in d3 category 20 scale
        randomLine: "#FF7F0E",     // @custom-ml-chart-orange
    };

    // Builds the default (thin, grey) line series of a model curve
    function buildModelCurveSeries(name, data) {
        return {
            color: PALETTE.grey,
            data,
            emphasis: { disabled: true },
            lineStyle: { width: 1 },
            name,
            symbol: 'none',
            type: 'line',
        };
    }

    return {
        scope: {
            metric: "<",
            netUpliftPoint: "<"
        },
        template: `<div class="h100">
            <div block-api-error />
            <ng2-lazy-echart [options]="chartOptions" ng-if="chartOptions"></ng2-lazy-echart>
        </div>`,
        restrict: 'A',
        link: function(scope) {
            function updateChart() {
                const { metric, netUpliftPoint } = scope;
                if (!metric || netUpliftPoint === undefined) {
                    return;
                }
                const isNetUplift = metric === "NET_UPLIFT";

                // static dummy data, for illustration purpose only
                const upliftModelData = [ [0, 0], [20, .8], [40, 1], [60, 1.15], [80, 1.15], [100, 1] ];
                const qiniModelData = [ [0, 0], [20, .85], [40, .95], [60, 1.05], [80, 1.025], [100, 1] ];
                const randomAssignmentData = upliftModelData.map((_, idx) => [ idx * 20, idx * 20 / 100 ]);

                // Current model curves
                const upliftSeries = buildModelCurveSeries("Uplift", upliftModelData);
                const qiniSeries = buildModelCurveSeries("Qini", qiniModelData);

                // For QINI / AUUC, highlight the matching curve and fill the area between it and the
                // random line: the curve is expressed relative to the random line and stacked on it.
                const curveWithArea = metric === "QINI" ? qiniSeries : (metric === "AUUC" ? upliftSeries : undefined);
                if (curveWithArea) {
                    Object.assign(curveWithArea, {
                        areaStyle: { color: PALETTE.lighterBlue },
                        color: PALETTE.blueHighlight,
                        data: curveWithArea.data.map(([x, y]) => [ x, y - x / 100 ]),
                        stack: "area-under-curve",
                    });
                    curveWithArea.lineStyle.width = 2;
                }

                // Random model line
                const randomModelSeries = {
                    color: PALETTE.randomLine,
                    data: randomAssignmentData,
                    emphasis: { disabled: true },
                    lineStyle: { type: 'dashed' },
                    name: "Random",
                    stack: isNetUplift ? "" : "area-under-curve",
                    symbol: 'none',
                    type: 'line',
                };

                // Net uplift segment: vertical line between the uplift curve and the random line.
                // Its top is found by linear interpolation between the two surrounding curve points.
                const modelUpliftPointX = netUpliftPoint * 100;
                const positionInSegments = modelUpliftPointX / 20;
                const segmentStartIdx = Math.floor(positionInSegments);
                const segmentParameter = positionInSegments - segmentStartIdx;
                const segmentStart = upliftModelData[segmentStartIdx];
                const segmentEnd = upliftModelData[segmentStartIdx + 1] || upliftModelData[upliftModelData.length - 1];
                const modelUpliftPointY = (1 - segmentParameter) * segmentStart[1] + segmentParameter * segmentEnd[1];

                const xAxisOffset = 10; // to align the two x axes we use
                const segmentX = modelUpliftPointX + xAxisOffset;

                const netUpliftSeries = {
                    data: [
                        [ segmentX, modelUpliftPointY ], // higher point (on the model curve)
                        [ segmentX, netUpliftPoint ],    // lower point (on the random line)
                    ],
                    emphasis: { disabled: true },
                    name: "Net uplift",
                    type: 'line',
                    symbol: 'none',
                    xAxisIndex: 1,
                    lineStyle: isNetUplift
                        ? { color: PALETTE.blueHighlight }
                        : { color: PALETTE.grey, width: 2, type : "dotted" },
                };

                if (isNetUplift) {
                    // Highlight the segment, its top point and the uplift curve
                    netUpliftSeries.markPoint = {
                        data: [{ coord: [ segmentX, modelUpliftPointY ] }],
                        emphasis: { disabled: true },
                        itemStyle: { color: PALETTE.blueHighlight },
                        symbol: "circle",
                        symbolSize: 6,
                    };
                    upliftSeries.color = PALETTE.blueHighlight;
                }

                const categoryXAxis = {
                    axisTick: { show: false },
                    name: "Fraction of the test observations, sorted by decreasing predicted individual effect\n(% of total test observations)",
                    nameGap: 32,
                    nameLocation: "middle",
                    type: 'category',
                    position: 'bottom'
                };
                // Hidden numeric axis, used to position the net uplift segment
                const valueXAxis = {
                    axisLabel: { show: false },
                    axisTick: { show: false },
                    interval: 20,
                    max: 120,
                    min: 0,
                    position: 'bottom',
                    type: 'value',
                };

                scope.chartOptions = {
                    animation: false,
                    grid: {bottom: 64, top: 40, left: 32, right: 8},
                    legend: {
                        data: ["Uplift", "Qini", "Random", "Net uplift"],
                        itemStyle: { opacity: 0 },
                        selectedMode: false,
                        x: "center",
                        y: 16,
                    },
                    textStyle: { fontFamily: 'SourceSansPro' },
                    xAxis: [ categoryXAxis, valueXAxis ],
                    yAxis: {
                        axisLabel: { show: false },
                        axisTick: { show: false },
                        name: "Cumulative effect",
                        nameLocation: "middle",
                        type: "value",
                    },
                    series: [ randomModelSeries, upliftSeries, qiniSeries, netUpliftSeries ]
                };
            }

            scope.$watchGroup(["metric", "netUpliftPoint"], updateChart);
        }
    };
});

// Controller of the pre-train modal of causal prediction tasks.
app.controller("CausalPMLTaskPreTrainModal", function($scope, $controller, $stateParams, DataikuAPI, Logger, WT1) {
    $controller("PMLTaskPreTrainModal", { $scope });
    $controller("_TabularPMLTaskPreTrainBase", { $scope });

    $scope.train = $scope._doTrainThenResolveModal;

    // Sends the "prediction-train" usage event describing the current design
    function reportTrainEvent() {
        const design = $scope.mlTaskDesign;

        // Every entry of the modeling params that carries a truthy `enabled` flag
        const enabledAlgorithms = {};
        $.each(design.modeling, function(algorithmKey, algorithmParams) {
            if (algorithmParams.enabled) {
                enabledAlgorithms[algorithmKey] = algorithmParams;
            }
        });

        WT1.event("prediction-train", {
            backendType: design.backendType,
            taskType: design.taskType,
            predictionType: design.predictionType,
            guessPolicy: design.guessPolicy,
            feature_selection_params: JSON.stringify(design.preprocessing.feature_selection_params),
            metaLearners: design.modeling.meta_learners,
            algorithms: JSON.stringify(enabledAlgorithms),
            hasSessionName: !!$scope.uiState.userSessionName,
            hasSessionDescription: !!$scope.uiState.userSessionDescription,
            gridSearchParams: JSON.stringify(design.modeling.gridSearchParams),
            runsOnKubernetes: $scope.hasSelectedK8sContainer(),
        });
    }

    /**
     * Starts the training session.
     * Usage reporting is best effort: a failure there is logged and never prevents the training.
     *
     * @returns the result of the `trainStart` backend call, with errors reported in the scope
     */
    $scope._doTrain = function() {
        try {
            reportTrainEvent();
        } catch (e) {
            Logger.error('Failed to report mltask info', e);
        }

        const { projectKey, analysisId, mlTaskId } = $stateParams;
        const { userSessionName, userSessionDescription, forceRefresh } = $scope.uiState;
        return DataikuAPI.analysis.pml
            .trainStart(projectKey, analysisId, mlTaskId, userSessionName, userSessionDescription, forceRefresh, true)
            .error(setErrorInScope.bind($scope));
    };

    // Truthy when there is something to display: pre-train messages, or CUDA enabled for XGBoost
    $scope.displayMessages = function() {
        const nbMessages = $scope.preTrainStatus.messages.length;
        if (nbMessages) {
            return nbMessages;
        }
        return $scope.mlTaskDesign.modeling.xgboost.enable_cuda;
    };
    // TODO @casual: add specific warnings
    // e.g. if meta-learner selected but no base learner
});

// Results controller of causal prediction tasks: only instantiates the tabular PML result controller on this scope.
app.controller("CausalPMLTaskResultController", function($scope, $controller) {
    $controller("_TabularPMLTaskResultController", { $scope });
});

// Controller of the report of a causal model: handles the navigation between the global model
// and its per-treatment views when the model has a multi-valued treatment.
app.controller("CausalPMLModelReportController", function($scope, $state, $stateParams, MLRoutingService) {

    // State transition options: force a reload when the report is displayed for an insight
    function getTransitionOptions() {
        return { reload: !!$stateParams.insightId };
    }

    // Truthy when multi-treatment is enabled and the treatment has more than two values
    $scope.isMultiValueTreatment = function() {
        const coreParams = $scope.modelData.coreParams;
        return coreParams.enable_multi_treatment && coreParams.treatment_values.length > 2;
    };

    // Navigates to the summary tab of the global model (no treatment selected)
    $scope.goToGlobalModel = function() {
        const summaryState = '^.' + MLRoutingService.getPredictionReportSummaryTab(false, false);
        const stateParams = { fullModelId: $scope.modelData.fullModelId, treatment: null };
        $state.go(summaryState, stateParams, getTransitionOptions());
    };

    /**
     * Navigates to the view of one treatment of the model.
     *
     * @param {string} [treatment] - when falsy (switch from the overall model to per treatment),
     *     the first non-control treatment value is used
     */
    $scope.goToPerTreatmentModel = function(treatment) {
        const targetTreatment = treatment || $scope.getNonControlTreatmentValues()[0];
        $scope.uiState.currentTreatmentName = targetTreatment;
        $state.go('.', { fullModelId: $scope.modelData.fullModelId, treatment: targetTreatment }, getTransitionOptions());
    };

    // Truthy-checks: multi-valued treatment and no treatment selected in the state params
    $scope.isOnGlobalModel = function() {
        return $scope.isMultiValueTreatment() && !$stateParams.treatment;
    };

    /**
     * Returns the treatment values of the model without the control value.
     * Empty values are removed when missing treatment values were dropped,
     * and replaced by "<Empty>" otherwise.
     */
    $scope.getNonControlTreatmentValues = function() {
        const coreParams = $scope.modelData.coreParams;
        const nonControlValues = coreParams.treatment_values.filter(value => value !== coreParams.control_value);

        if ($scope.modelData.preprocessing.drop_missing_treatment_values) {
            return nonControlValues.filter(value => value !== "");
        }
        return nonControlValues.map(value => (value === "" ? "<Empty>" : value));
    };

    // Initialize the current treatment from the state params (null on the global model)
    if ($scope.isMultiValueTreatment()) {
        $scope.uiState.currentTreatmentName = $stateParams.treatment ? $stateParams.treatment : null;
    }
});

})();
