package com.dataiku.exports;

import java.io.*;
import java.nio.file.*;
import java.util.*;

import com.dataiku.scoring.*;
import com.opencsv.CSVReader;
import com.opencsv.CSVWriter;

/**
 * Command-line runner that reads a CSV file, scores its rows with {@link ModelPredictor}, and
 * writes the rows together with the predictions to a timestamped CSV file in an output folder.
 *
 * <p>Usage: {@code ModelRunner [inputData [predictionLogsFolder]]}.
 */
public class ModelRunner {

    /** Input CSV used when no argument is given. */
    private static final String DEFAULT_INPUT_DATA = "input/input_data.csv";
    /** Output folder used when fewer than two arguments are given (created as a directory). */
    private static final String DEFAULT_PREDICTION_LOGS = "output/java/predictions.log";

    /** CSV file to score; its first row is treated as the header. */
    private final File inputData;
    /** Folder receiving the prediction output files. */
    private final File predictionLogs;
    private final ModelPredictor modelPredictor;

    /**
     * Resolves the input/output locations, creating the output folder if needed.
     *
     * <p>Problems (missing input, folder creation failure) are reported on stdout but are not
     * fatal here; a missing input file surfaces later as an I/O error from {@link #process()}.
     *
     * @param inputDataS path of the CSV file to score
     * @param predictionLogsS path of the folder that receives prediction files
     */
    private ModelRunner(String inputDataS, String predictionLogsS) {
        this.inputData = new File(inputDataS);
        this.predictionLogs = new File(predictionLogsS);

        if (!inputData.exists()) {
            System.out.println("Unable to access input data file at " + inputDataS);
        }
        // mkdirs() (not mkdir()) so missing parents such as "output/java" are created too; it is
        // called only once because a second call on a now-existing folder would return false.
        if (!predictionLogs.exists() && !predictionLogs.mkdirs()) {
            System.out.println("Error while creating prediction folder output " + predictionLogsS);
        }
        System.out.println("Using input data in '" + describe(inputData)
                + "' / output prediction logs to '" + describe(predictionLogs) + "'");
        this.modelPredictor = new ModelPredictor();
    }

    /** Returns the canonical path of {@code f}, falling back to its absolute path on failure. */
    private static String describe(File f) {
        try {
            return f.getCanonicalPath();
        } catch (IOException e) {
            // Canonicalization is only used for display; the absolute path is good enough.
            return f.getAbsolutePath();
        }
    }

    /**
     * Entry point.
     *
     * @param args optional {@code [inputData [predictionLogsFolder]]}; defaults are used for
     *     missing arguments and extra arguments are ignored
     */
    public static void main(String[] args) {
        try {
            System.out.println("Command-line arguments:");
            for (String arg : args) {
                System.out.println(arg);
            }
            String inputData = args.length >= 1 ? args[0] : DEFAULT_INPUT_DATA;
            String predictionLogs = args.length >= 2 ? args[1] : DEFAULT_PREDICTION_LOGS;

            ModelRunner runFlow = new ModelRunner(inputData, predictionLogs);
            runFlow.process();
        } catch (Exception e) {
            // Top-level boundary: report any failure instead of letting the JVM dump it raw.
            System.out.println("An error occurred: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Loads the input CSV, scores it and writes
     * {@code <predictionLogs>/<timestampMillis>-<inputFileName>}.
     *
     * @throws IOException if the input cannot be read, is empty, or the output cannot be written
     */
    private void process() throws IOException {
        System.out.println("Processing file " + inputData.getAbsolutePath());
        long timestamp = System.currentTimeMillis();

        // Load data from input file: first row is the header, the rest are data rows.
        String[] header;
        List<String[]> inputDataset = new ArrayList<>();
        try (Reader reader = Files.newBufferedReader(inputData.toPath());
             CSVReader csvReader = new CSVReader(reader)) {
            header = csvReader.readNext();
            if (header == null) {
                throw new IOException("Input data file " + inputData.getAbsolutePath() + " is empty");
            }
            String[] line;
            while ((line = csvReader.readNext()) != null) {
                inputDataset.add(line);
            }
        }
        System.out.println("Loaded input data with " + inputDataset.size() + " rows");

        // Make predictions.
        // NOTE(review): assumes predictBatch returns List<String[]> (required by writeAll below)
        // and appends prediction, proba_0, proba_1 to each row — confirm against ModelPredictor.
        List<String[]> outputDataset = modelPredictor.predictBatch(header, inputDataset);
        System.out.println("Predictions done, storing prediction logs");

        // Output predictions in output folder, UTF-8 encoded like the input reader.
        Path outputPath = new File(predictionLogs, timestamp + "-" + inputData.getName()).toPath();
        try (Writer outputFileWriter = Files.newBufferedWriter(outputPath);
             CSVWriter writer = new CSVWriter(outputFileWriter)) {
            List<String> outputHeader = new ArrayList<>(Arrays.asList(header));
            outputHeader.add("prediction");
            outputHeader.add("proba_0");
            outputHeader.add("proba_1");
            writer.writeNext(outputHeader.toArray(new String[0]));
            writer.writeAll(outputDataset);
        }
    }
}