/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines;

import com.dataiku.dss.shadelib.javax.annotation.Nullable;
import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.Regressor;
import com.dataiku.scoring.pipelines.AbstractPipeline;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.RegressionPipeline;
import com.dataiku.scoring.pipelines.RegressionResult;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import com.dataiku.scoring.pipelines.uncertainty.Interval;
import com.dataiku.scoring.pipelines.uncertainty.PredictionIntervalModel;
import com.dataiku.scoring.util.RawObservation;
import java.util.Optional;

public class RegressionPipelineImpl
extends AbstractPipeline<Regressor, Double, RegressionResult>
implements RegressionPipeline {
    private static final long serialVersionUID = 0L;
    @Nullable
    private final PredictionIntervalModel predictionIntervalModel;

    public RegressionPipelineImpl(PreprocessingPipeline preprocessing, Regressor model, OverridesLayerBase<RegressionResult> overridesLayer) {
        this(preprocessing, model, overridesLayer, (PredictionIntervalModel)null);
    }

    public RegressionPipelineImpl(PreprocessingPipeline preprocessing, Regressor model, OverridesLayerBase<RegressionResult> overridesLayer, @Nullable PredictionIntervalModel predictionIntervalModel) {
        super(preprocessing, model, overridesLayer, Double.class);
        this.predictionIntervalModel = predictionIntervalModel;
        if (predictionIntervalModel != null) {
            this.columnComputers.add(new PredictionIntervalLowerColumnComputer());
            this.columnComputers.add(new PredictionIntervalUpperColumnComputer());
        }
    }

    @Override
    public Try<RegressionResult> getPredictionResults(RawObservation r) {
        RegressionResult result;
        this.checkInitialized();
        RawObservation originObservation = this.copyRawObservationForPostPredictIfNeededOrNull(r);
        Try<Vector> v = this.preprocessing.process(r);
        if (v.isError()) {
            return Try.failure(v.getMessage());
        }
        Try<Double> prediction = ((Regressor)this.model).predict(v.get());
        if (prediction.isError()) {
            return Try.failure(prediction.getMessage());
        }
        if (this.predictionIntervalModel != null) {
            Try<Interval> predictionInterval = this.predictionIntervalModel.computeInterval(v.get(), prediction.get());
            if (predictionInterval.isError()) {
                return Try.failure(predictionInterval.getMessage());
            }
            result = new RegressionResult(prediction.get(), predictionInterval.get());
        } else {
            result = new RegressionResult(prediction.get());
        }
        return this.postPredict(originObservation, result);
    }

    public static class PredictionIntervalLowerColumnComputer
    extends AbstractPipeline.AbstractColumnComputer<Double, Double, RegressionResult> {
        public static final String COLUMN_NAME = "prediction_interval_lower";

        public PredictionIntervalLowerColumnComputer() {
            super(COLUMN_NAME, Double.class);
        }

        @Override
        public Optional<Double> getOutputValue(RegressionResult result) {
            return result.getPredictionInterval().map(interval -> interval.lower);
        }
    }

    public static class PredictionIntervalUpperColumnComputer
    extends AbstractPipeline.AbstractColumnComputer<Double, Double, RegressionResult> {
        public static final String COLUMN_NAME = "prediction_interval_upper";

        public PredictionIntervalUpperColumnComputer() {
            super(COLUMN_NAME, Double.class);
        }

        @Override
        public Optional<Double> getOutputValue(RegressionResult result) {
            return result.getPredictionInterval().map(interval -> interval.upper);
        }
    }
}

