# encoding: utf-8
"""
Single-threaded host for a custom-code based predictor: loads a user-defined
Python function and serves calls to it over a socket link to the backend.
"""

import inspect
import time
import logging
import json

from dataiku.apinode import DkuCustomApiException, DkuCustomHttpResponse, DkuHttpRequestMetadata
from dataiku.base.socket_block_link import JavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import get_json_friendly_error
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

# Configure root logging: INFO level, one timestamped line per record.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
# Install the handler provided by dataiku.core.debugging (behavior defined there).
debugging.install_handler()

from dataiku.base.utils import watch_stdin, get_json_friendly_error
from dataiku.base.socket_block_link import JavaLink, parse_javalink_args

import os


class LoadedFunction(object):
    """Execute a user code file and expose one function defined in it.

    The code runs in a dedicated namespace (``self.ctx``) that is pre-seeded
    with ``folders`` (the resource folder paths) and ``DkuCustomApiException``.
    Extra names can be injected later with :meth:`with_additional_context`.
    """

    def __init__(self, code_file, function_name, data_folders=None):
        """Load ``code_file`` and select the function to serve.

        :param code_file: path of the Python file containing the user code
        :param function_name: name of the function to serve. It is only looked
            up when the namespace does not contain exactly one function object.
        :param data_folders: resource folder paths, exposed to the user code as
            ``folders``. Defaults to a new empty list.
        :raises Exception: if no function could be selected
        """
        # Use a fresh list per instance. A shared mutable default would let
        # user code that mutates ``folders`` affect later instances.
        if data_folders is None:
            data_folders = []

        with open(code_file, "r") as f:
            self.code = f.read()

        self.ctx = {
            "folders": data_folders,
            "DkuCustomApiException": DkuCustomApiException
        }

        # NOTE: this executes arbitrary code from code_file; it must come from a
        # trusted source.
        exec(self.code, self.ctx, self.ctx)

        # Any function object left in the namespace counts here, including
        # functions the user code imported. If there is exactly one, it is
        # served regardless of function_name.
        functions = [o for o in self.ctx.values() if inspect.isfunction(o)]
        self.f = functions[0] if len(functions) == 1 else self.ctx.get(function_name)

        if self.f is None:
            raise Exception('No function "%s" defined' % function_name)

    def with_additional_context(self, name, value):
        """Bind ``name`` to ``value`` in the namespace the user code runs in."""
        self.ctx[name] = value

    def call(self, params):
        """Call the selected function with ``params`` unpacked as keyword arguments."""
        return self.f(**params)


# socket-based connection to backend
def serve(port, secret, server_cert=None):
    """Connect to the backend and serve calls to a user-defined function.

    Protocol over the link:

    1. Read an initialization command (``functionName``, ``codeFilePath``,
       ``resourceFolderPaths``), load the function and reply ``{"ok": True}``.
    2. For each request read until ``read_json()`` returns ``None``, call the
       function with ``request["params"]``. Reply with the result and its
       execution time in microseconds.
    3. On normal termination, send an empty string. On any failure, send an
       empty string followed by a JSON error description.

    The link is closed before returning once the connection has been made.

    :param port: backend port passed to ``JavaLink``
    :param secret: secret passed to ``JavaLink``
    :param server_cert: optional server certificate passed to ``JavaLink``
    """
    link = JavaLink(port, secret, server_cert=server_cert)
    # initiate connection
    link.connect()

    # get work to do
    try:
        # retrieve the initialization info and initiate serving
        command = link.read_json()

        function_name = command.get('functionName')
        code_file = command.get('codeFilePath')
        data_folders = command.get('resourceFolderPaths', [])

        loaded_function = LoadedFunction(code_file, function_name, data_folders)

        logging.info("Predictor ready")
        link.send_json({"ok": True})

        # loop and process commands
        while True:
            request = link.read_json()
            if request is None:
                break
            _process_request(link, loaded_function, request)

        # send end of stream
        logging.info("Work done")
        link.send_string('')
    except DkuCustomApiException as e:
        logging.exception("Function code threw a user-defined DkuCustomApiException")
        link.send_string('')  # send null to mark failure
        link.send_json(get_json_friendly_error({'customHttpStatusCode': e.http_status_code}))
    except Exception:
        logging.exception("Function user code failed")
        link.send_string('')  # send null to mark failure
        link.send_json(get_json_friendly_error())
    finally:
        # done
        link.close()


def _process_request(link, loaded_function, request):
    """Run a single request through ``loaded_function`` and send the reply on ``link``."""
    params = request["params"]
    logging.info("Received request %s", json.dumps(params))

    used_api_key = request.get("usedAPIKey")
    if used_api_key is not None:
        # Publish the API key used for this request in the environment for
        # the duration of the call.
        os.environ["DKU_CURRENT_REQUEST_USED_API_KEY"] = used_api_key
    try:
        loaded_function.with_additional_context(
            "dku_http_request_metadata", DkuHttpRequestMetadata(request.get("httpRequestMetadata"))
        )

        before = time.time()
        response = loaded_function.call(params)
        after = time.time()
        exec_time = int(1000000 * (after - before))
        if isinstance(response, DkuCustomHttpResponse):
            response = response.to_dict()
        link.send_json(
            {
                "ok": True,
                "resp": response,
                "execTimeUS": exec_time,
            }
        )
    finally:
        # Clear the key even if the call failed. pop() avoids a KeyError when
        # user code already removed it; that error would otherwise trigger the
        # failure path after a success reply had already been sent.
        if used_api_key is not None:
            os.environ.pop("DKU_CURRENT_REQUEST_USED_API_KEY", None)

    logging.info("Done processing request %s in %s ms", json.dumps(params), exec_time / 1000)


if __name__ == "__main__":
    # Start the stdin watcher before connecting to the backend.
    watch_stdin()
    backend_port, backend_secret, backend_cert = parse_javalink_args()
    serve(backend_port, backend_secret, server_cert=backend_cert)
