from typing import Any, Callable, Dict, Generic, List, Optional, Type, Union

from answers.solutions.vector_search.models import (
    ConditionDictType,
    ExpressionList,
    ListConditional,
    NumericConditional,
    RawExpression,
    T,
    ValueConditional,
)
from qdrant_client.http.models.models import AnyVariants, ValueVariants
from qdrant_client.models import FieldCondition


class GenericVectorQuery(Generic[T]):
    """Build a vector-store query dict from a raw, dictionary-based expression.

    The raw expression uses lowercase operator names (``"eq"``, ``"in"``,
    ``"and"``, ...). Each one is translated through the matching attribute of
    the supplied operator class (``EQ``, ``IN``, ``AND``, ...) and the result
    is accumulated in ``self.query`` under a key such as ``"filter"``.

    The ``*_cond`` methods, ``and_op``/``or_op`` and the ``supports_hybrid`` /
    ``hybrid_key`` properties are the pieces a backend-specific subclass would
    override; the return annotations of the ``*_cond`` methods already allow
    for non-dict results.
    """

    def __init__(self, vec_operator: Type[T]) -> None:
        """Accepts a VecOperator subclass instead of an instance."""
        self._query: Dict[str, Any] = {}
        self._vec_operator: Type[T] = vec_operator
        # Raw operator name -> bound method that builds the condition.
        self.operator_methods = {
            "eq": self.eq_cond,
            "ne": self.ne_cond,
            "gt": self.gt_cond,
            "gte": self.gte_cond,
            "lt": self.lt_cond,
            "lte": self.lte_cond,
            "in": self.in_cond,
            "nin": self.nin_cond,
        }

    @property
    def query(self) -> Dict[str, Any]:
        """The query dict accumulated so far (the live object, not a copy)."""
        return self._query

    @property
    def vec_operator(self) -> Type[T]:
        """The operator class whose attributes name the backend operators."""
        return self._vec_operator

    @property
    def supports_hybrid(self) -> bool:
        """Whether ``invoke`` accepts a ``"hybrid"`` section; ``False`` here."""
        return False

    @property
    def hybrid_key(self) -> str:
        """Query key that receives the parsed ``"hybrid"`` section; empty here."""
        return ""

    def eq_cond(
        self, field: str, value: Union[ValueConditional, ValueVariants]
    ) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {EQ: value}}``.

        Raises:
            ValueError: if ``value`` is a list.
        """
        if isinstance(value, list):
            raise ValueError("Value must be a single value for EQ operator")
        return {field: {self.vec_operator.EQ: value}}

    def ne_cond(
        self, field: str, value: Union[ValueConditional, ValueVariants]
    ) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {NE: value}}``.

        Raises:
            ValueError: if ``value`` is a list.
        """
        if isinstance(value, list):
            raise ValueError("Value must be a single value for NE operator")
        return {field: {self.vec_operator.NE: value}}

    def gt_cond(self, field: str, value: NumericConditional) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {GT: value}}``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        if isinstance(value, (list, str)):
            raise ValueError("Value must be a single numerical value for GT operator")
        return {field: {self.vec_operator.GT: value}}

    def gte_cond(self, field: str, value: NumericConditional) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {GTE: value}}``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        if isinstance(value, (list, str)):
            raise ValueError("Value must be a single numerical value for GTE operator")
        return {field: {self.vec_operator.GTE: value}}

    def lt_cond(self, field: str, value: NumericConditional) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {LT: value}}``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        if isinstance(value, (list, str)):
            raise ValueError("Value must be a single numerical value for LT operator")
        return {field: {self.vec_operator.LT: value}}

    def lte_cond(self, field: str, value: NumericConditional) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {LTE: value}}``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        if isinstance(value, (list, str)):
            raise ValueError("Value must be a single numerical value for LTE operator")
        return {field: {self.vec_operator.LTE: value}}

    def in_cond(
        self, field: str, values: Union[AnyVariants, List[Union[str, int, float]]]
    ) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {IN: values}}``.

        Raises:
            ValueError: if ``values`` is a bare ``str``, ``int`` or ``float``.
        """
        if isinstance(values, (str, int, float)):
            raise ValueError("Values must be a list for IN operator")
        return {field: {self.vec_operator.IN: values}}

    def nin_cond(
        self, field: str, values: Union[ListConditional, AnyVariants]
    ) -> Union[ConditionDictType, str, FieldCondition]:
        """Return ``{field: {NIN: values}}``.

        Raises:
            ValueError: if ``values`` is a bare ``str``, ``int`` or ``float``.
        """
        if isinstance(values, (str, int, float)):
            raise ValueError("Values must be a list for NIN operator")
        return {field: {self.vec_operator.NIN: values}}

    def verify_filter_key(self, key: str) -> None:
        """Ensure ``self.query[key]`` exists; a missing or falsy entry becomes ``{}``."""
        if not self._query.get(key):
            self._query[key] = {}

    def or_op(self, expression_list: Optional[ExpressionList], key: str = "filter") -> None:
        """Store ``expression_list`` under the OR operator inside ``self.query[key]``."""
        self.verify_filter_key(key)
        self._query[key].update({self.vec_operator.OR: expression_list})

    def and_op(self, expression_list: Optional[ExpressionList], key: str = "filter") -> None:
        """Store ``expression_list`` under the AND operator inside ``self.query[key]``."""
        self.verify_filter_key(key)
        self._query[key].update({self.vec_operator.AND: expression_list})

    def parse_conditional(
        self, field: str, conditional: dict
    ) -> Union[ConditionDictType, str, FieldCondition]:
        """Build the condition for ``field`` from a ``{operator: argument}`` mapping.

        Only the first entry of ``conditional`` is used.

        Raises:
            ValueError: if ``conditional`` is empty, if the operator is not a
                key of ``operator_methods``, or if the selected ``*_cond``
                method rejects the argument.
        """
        if not conditional:
            # Without this guard next() would leak a bare StopIteration,
            # which is easily swallowed by an enclosing iterator.
            raise ValueError(f"Empty conditional for field: {field}")
        op, arg = next(iter(conditional.items()))
        method = self.operator_methods.get(op)
        if method is None:
            raise ValueError(f"Unknown operator: {op}")
        return method(field, arg)  # type: ignore

    def parse_expression(self, operator: str, value: str) -> Union[Dict[str, Callable[[str, Any], str]], str]:
        """Return a one-entry dict pairing the handler for ``operator`` with ``value``.

        Raises:
            ValueError: if ``operator`` is not a key of ``operator_methods``.
        """
        if not (op := self.operator_methods.get(operator)):
            raise ValueError(f"operator {operator} not a recognised operator")
        # NOTE(review): ``op`` is the bound ``*_cond`` method, so the dict key
        # is a callable rather than an operator string, which does not match
        # the declared return type -- confirm the intended shape.
        # TODO: mypy can't understand that op here so we might need an explicit type hint
        return {op: value}  # type: ignore

    def parse_raw_expression(
        self, raw_expression: Union[ConditionDictType, str], expression_list: Optional[ExpressionList] = None, key: str = "filter"
    ) -> Optional[ExpressionList]:
        """Parse ``raw_expression`` into ``self.query[key]`` or into ``expression_list``.

        Entries of a dict ``raw_expression`` are handled as follows:

        * ``"and"`` / ``"or"``: every member of the value is parsed into a
          fresh list, which is handed to ``and_op`` / ``or_op`` for ``key``.
        * dict value: treated as ``{field: {operator: argument}}`` and built
          with ``parse_conditional``.
        * str value: built with ``parse_expression``.

        For the last two, the result is appended to ``expression_list`` when
        one is supplied, otherwise it replaces ``self.query[key]``. Both
        return immediately, so only the first such entry of a mapping is
        processed. A non-dict ``raw_expression`` is ignored.

        Returns:
            The list being accumulated, or ``None`` when the result was
            written to ``self.query`` instead.

        Raises:
            ValueError: if an entry's value is neither a dict nor a str, or if
                a condition cannot be built.
        """
        if isinstance(raw_expression, dict):
            for k, v in raw_expression.items():
                # AND / OR
                if k in ("and", "or"):
                    # NOTE(review): a group met while ``expression_list`` is
                    # already a list (a nested group) discards that list and
                    # returns None to the enclosing group. Left unchanged
                    # because the right nested shape depends on how
                    # and_op/or_op are overridden -- confirm and fix there.
                    expression_list = []
                    for cond in v:
                        # Pass ``key`` down so a nested group writes to the
                        # same query key as its parent, not the default.
                        expression_list = self.parse_raw_expression(cond, expression_list, key)
                    combine = self.and_op if k == "and" else self.or_op
                    combine(expression_list=expression_list, key=key)
                    expression_list = None
                # operator
                elif isinstance(v, dict):
                    if expression_list is None:
                        self._query[key] = self.parse_conditional(k, v)
                        return None
                    elif isinstance(expression_list, list):
                        expression_list.append(self.parse_conditional(k, v))
                        return expression_list
                # expression
                elif isinstance(v, str):
                    if expression_list is None:
                        self._query[key] = self.parse_expression(k, v)
                        return None
                    elif isinstance(expression_list, list):
                        expression_list.append(self.parse_expression(k, v))
                        return expression_list
                else:
                    raise ValueError(f"Unsupported expression format for expression {raw_expression}")
        return expression_list

    def invoke(self, raw_expression) -> RawExpression:
        """Parse a raw expression and return the accumulated query.

        ``raw_expression["filter"]`` is always parsed into the ``"filter"``
        key. When ``raw_expression`` is a dict with a ``"hybrid"`` entry, that
        entry is parsed first into ``self.hybrid_key``.

        Raises:
            ValueError: if ``"hybrid"`` is present but ``supports_hybrid`` is
                false, or if an expression is malformed.
            KeyError: if ``raw_expression`` has no ``"filter"`` entry.
        """
        # Check the type first so a non-container payload never reaches ``in``.
        if isinstance(raw_expression, dict) and "hybrid" in raw_expression:
            if not self.supports_hybrid:
                raise ValueError("This vector DB does not support hybrid filtering")
            self._query[self.hybrid_key] = {}
            self.parse_raw_expression(raw_expression["hybrid"], key=self.hybrid_key)
        # NOTE(review): "filter" is required even when "hybrid" is supplied --
        # confirm whether a hybrid-only expression should be accepted.
        self.parse_raw_expression(raw_expression["filter"], key="filter")
        return self.query