import asyncio
import json
import logging

from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import AsyncIterator
from typing import List
from typing import Dict

from dataiku.base.batcher import Batcher
from dataiku.huggingface.pipeline import ModelPipeline
from dataiku.huggingface.types import SingleCommand, SingleResponse

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ModelPipelineBatching(ModelPipeline[SingleCommand, SingleResponse]):
    """Base class for the non-vLLM pipelines to implement.

    Single requests are handed to a ``Batcher`` that groups them by their
    parsed parameters (see ``_get_params``). Each batch is then executed
    synchronously on a worker thread of ``self.executor`` so that the event
    loop is not blocked while the model runs.

    Subclasses must implement ``_get_inputs``, ``_run_inference`` and
    ``_parse_response``; they may override ``_get_params``.
    """

    executor: ThreadPoolExecutor
    batcher: Batcher[SingleCommand, SingleResponse]

    def __init__(self, batch_size: int):
        """Create the worker pool and the request batcher.

        :param batch_size: forwarded as-is to the ``Batcher`` (presumably the
            maximum number of requests per batch -- confirm against ``Batcher``).
        """
        super().__init__()
        self.executor = ThreadPoolExecutor()
        self.batcher = Batcher[SingleCommand, SingleResponse](
            batch_size=batch_size,
            timeout=0.1,  # NOTE(review): presumably seconds -- confirm against Batcher
            process_batch=self._process_batch_async,
            # sort_keys makes the key independent of dict insertion order, so
            # requests with equal parameters always land in the same group.
            group_by=lambda request: json.dumps(self._get_params(request), sort_keys=True)
        )

    async def run_single_async(self, request: SingleCommand) -> AsyncIterator[SingleResponse]:
        """Submit ``request`` to the batcher and yield its single response."""
        yield await self.batcher.process(request)

    async def initialize_model(self):
        """No-op in this base class."""
        pass

    async def _process_batch_async(self, items: List[SingleCommand]) -> List[SingleResponse]:
        """Run one batch on the thread pool and await its responses.

        :param items: the requests making up the batch.
        :return: the list produced by ``_run_batch_sync`` for ``items``.
        """
        logger.info("Processing a batch of %s sequences", len(items))
        # We are inside a coroutine, so a running loop is guaranteed to exist.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._run_batch_sync, items)

    def _run_batch_sync(self, requests: List[SingleCommand]) -> List[SingleResponse]:
        """Run inference for a batch of requests (blocking).

        The parameters of the first request are used for the whole batch.

        :param requests: the requests of the batch.
        :return: one parsed response per request, in the same order; an empty
            list when ``requests`` is empty.
        :raises Exception: if the inference result does not match the requests
            (see ``_parse_responses``).
        """
        if not requests:
            # Nothing to run; also avoids indexing requests[0] below.
            return []

        inputs = self._get_inputs(requests)

        # Copy before adding 'batch_size': the dict returned by _get_params may
        # be shared by the subclass, and it is also used to build the group-by key.
        params = dict(self._get_params(requests[0]))
        params['batch_size'] = len(inputs)  # may differ from the batch_size from the backend

        logger.info("Running batch task with parameters: %s", params)
        # Lazy formatting: inputs are only stringified when DEBUG is enabled.
        logger.debug("Task inputs: %s", inputs)

        responses = self._run_inference(inputs, params)
        return self._parse_responses(responses, requests)

    # Default implementation, meant to be overridden in concrete implementations
    # The parsed params returned are
    #  - passed to the model in _run_batch_sync,
    #  - used as group-by key when batching the requests (serialized with json.dumps,
    #    so they must be JSON-serializable).
    def _get_params(self, request: SingleCommand) -> Dict:
        """Extract and parse the model settings from the request received from the DSS backend."""
        return {}

    @abstractmethod
    def _get_inputs(self, requests: List[SingleCommand]) -> List[str]:
        """Extract the model inputs from the requests received from the DSS backend."""
        raise NotImplementedError

    @abstractmethod
    def _run_inference(self, inputs: List[str], params: Dict) -> List:
        """Run the model's inference step.

        ``params`` holds the parsed settings plus a ``batch_size`` entry set to
        ``len(inputs)``. The result must be a list with one element per input.
        """
        raise NotImplementedError

    def _parse_responses(self, responses: List, requests: List[SingleCommand]) -> List[SingleResponse]:
        """Pair each raw response with its request and parse it.

        :param responses: raw inference output; must be a list as long as ``requests``.
        :param requests: the requests the responses answer, in the same order.
        :return: the parsed responses, in request order.
        :raises Exception: if ``responses`` is not a list or its length differs
            from ``requests``.
        """
        if not isinstance(responses, list) or len(responses) != len(requests):
            raise Exception("Unexpected answer from huggingface: " + str(responses))

        return [self._parse_response(response, request) for (response, request) in zip(responses, requests)]

    @abstractmethod
    def _parse_response(self, response: Any, request: SingleCommand) -> SingleResponse:
        """Transform the response to a single request from HF into the structure expected by DSS."""
        raise NotImplementedError