from typing import List, Union

from answers.solutions.vector_search.generic_vector_search import (
    GenericVectorQuery,
    ListConditional,
    RawExpression,
    ValueConditional,
)
from answers.solutions.vector_search.models import VecOperator
from qdrant_client.http.models import AnyVariants, ValueVariants
from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchExcept, MatchValue, Range


class QdrantOperators(VecOperator):
    """Map generic filter operators to the Qdrant construct that implements them.

    Each value names either a Qdrant match model (``MatchValue``,
    ``MatchExcept``, ``MatchAny``), a ``Range`` bound keyword
    (``gt``/``gte``/``lt``/``lte``), or a ``Filter`` clause keyword
    (``must``/``should``).
    """

    EQ = "MatchValue"
    # NOTE(review): NE and NIN share the value "MatchExcept". If VecOperator is
    # an Enum, NIN becomes an alias of NE (the same member) — confirm that
    # operator dispatch by name/value still reaches the intended handler.
    NE = "MatchExcept"
    # Comparison operators are expressed as bounds of a single Qdrant Range.
    # The original note suggests (date)time values could also be supported by
    # converting them while the query is built — TODO confirm and implement.
    GT = "gt"  # Range bound
    GTE = "gte"  # Range bound
    LT = "lt"  # Range bound
    LTE = "lte"  # Range bound
    # Set-membership operators.
    IN = "MatchAny" 
    NIN = "MatchExcept"  # takes a list of excluded values
    # Logical grouping operators map onto Filter clause keywords.
    AND = "must"
    OR = "should"
    # BETWEEN has no entry here; QdrantVectorQuery registers it by name
    # ("between") as a dedicated method.

    # Not implemented yet, but supported by Qdrant: geo filters (e.g. radius),
    # values-count filters (e.g. only points with exactly 2 tags), and
    # is-empty checks.


# Type of a single numeric operand accepted by the range-based conditions.
NumericConditional = Union[int, float]

class QdrantVectorQuery(GenericVectorQuery):
    """Translate generic filter expressions into Qdrant filter objects.

    Leaf conditions are built as ``FieldCondition`` objects keyed on
    ``metadata.<field>``. ``and``/``or`` groups are wrapped in
    ``Filter(must=...)`` / ``Filter(should=...)`` and stored in ``self._query``
    under the key being parsed.
    """

    def __init__(self):
        super().__init__(QdrantOperators)
        # BETWEEN has no single Qdrant operator value, so it is dispatched by
        # name to a dedicated method.
        self.operator_methods.update({"between": self.between_cond})

    @property
    def supports_hybrid(self) -> bool:
        """Whether :meth:`invoke` accepts a ``hybrid`` section."""
        return True

    @property
    def hybrid_key(self) -> str:
        """Key in the built query under which the hybrid section is stored."""
        # NOTE(review): "where_document" is Chroma's naming; confirm that the
        # Qdrant caller actually consumes this key.
        return "where_document"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _metadata_key(field: str) -> str:
        """Return the payload path under which metadata field ``field`` lives."""
        return f"metadata.{field}"

    @staticmethod
    def _ensure_scalar_number(message: str, *values) -> None:
        """Raise ``ValueError(message)`` if any value is a list or a string."""
        if any(isinstance(value, (list, str)) for value in values):
            raise ValueError(message)

    def _range_cond(self, field: str, **bounds: NumericConditional) -> FieldCondition:
        """Build a range condition; unspecified Range bounds default to None."""
        return FieldCondition(key=self._metadata_key(field), range=Range(**bounds))

    # ------------------------------------------------------- match conditions

    def eq_cond(self, field: str, value: Union[ValueConditional, ValueVariants]) -> FieldCondition:
        """Match points whose ``field`` equals ``value``.

        Raises:
            ValueError: if ``value`` is a list.
        """
        if isinstance(value, list):
            raise ValueError("Value must be a single value for EQ operator")
        return FieldCondition(  # type: ignore
            key=self._metadata_key(field),
            match=MatchValue(value=value),  # type: ignore
        )

    def ne_cond(self, field: str, value: Union[ValueConditional, ValueVariants]) -> FieldCondition:
        """Match points whose ``field`` is not equal to ``value``.

        Raises:
            ValueError: if ``value`` is a list.
        """
        if isinstance(value, list):
            raise ValueError("Value must be a single value for NE operator")
        return FieldCondition(  # type: ignore
            key=self._metadata_key(field),
            # MatchExcept has no ``value`` field: it takes a list under the
            # reserved-word alias "except". A one-element list means "!= value".
            match=MatchExcept(**{"except": [value]}),  # type: ignore
        )

    def in_cond(self, field: str, values: Union[AnyVariants, List[Union[str, int, float]]]) -> FieldCondition:
        """Match points whose ``field`` equals any of ``values``.

        Raises:
            ValueError: if ``values`` is a scalar (str/int/float) rather than a list.
        """
        if isinstance(values, (str, int, float)):
            raise ValueError("Values must be a list for IN operator")
        return FieldCondition(  # type: ignore
            key=self._metadata_key(field),
            match=MatchAny(any=values),  # type: ignore
        )

    def nin_cond(self, field: str, values: Union[ListConditional, AnyVariants]) -> FieldCondition:
        """Match points whose ``field`` equals none of ``values``.

        Raises:
            ValueError: if ``values`` is a scalar (str/int/float) rather than a list.
        """
        if isinstance(values, (str, int, float)):
            raise ValueError("Values must be a list for NIN operator")
        return FieldCondition(  # type: ignore
            key=self._metadata_key(field),
            match=MatchExcept(**{"except": values}),  # type: ignore
        )

    # ------------------------------------------------------- range conditions

    def gt_cond(self, field: str, value: NumericConditional) -> FieldCondition:
        """Match points whose ``field`` is strictly greater than ``value``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        self._ensure_scalar_number("Value must be a single numerical value for GT operator", value)
        return self._range_cond(field, gt=value)

    def gte_cond(self, field: str, value: NumericConditional) -> FieldCondition:
        """Match points whose ``field`` is greater than or equal to ``value``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        self._ensure_scalar_number("Value must be a single numerical value for GTE operator", value)
        return self._range_cond(field, gte=value)

    def lt_cond(self, field: str, value: NumericConditional) -> FieldCondition:
        """Match points whose ``field`` is strictly less than ``value``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        self._ensure_scalar_number("Value must be a single numerical value for LT operator", value)
        return self._range_cond(field, lt=value)

    def lte_cond(self, field: str, value: NumericConditional) -> FieldCondition:
        """Match points whose ``field`` is less than or equal to ``value``.

        Raises:
            ValueError: if ``value`` is a list or a string.
        """
        self._ensure_scalar_number("Value must be a single numerical value for LTE operator", value)
        return self._range_cond(field, lte=value)

    def between_cond(self, field: str, from_: NumericConditional, upto: NumericConditional) -> FieldCondition:
        """Match points whose ``field`` lies in the closed interval ``[from_, upto]``.

        Raises:
            ValueError: if either bound is a list or a string.
        """
        self._ensure_scalar_number(
            "Values must be a single numerical value for BETWEEN operator", from_, upto
        )
        # The lower bound goes on gte and the upper bound on lte (inclusive,
        # like SQL BETWEEN). Putting upto on the lower side would describe an
        # empty interval whenever from_ < upto.
        return self._range_cond(field, gte=from_, lte=upto)

    # ------------------------------------------------------ logical operators

    def and_op(self, expression_list, key: str = "filter") -> None:
        """Store ``Filter(must=expression_list)`` under ``self._query[key]``."""
        self._query[key] = Filter(must=expression_list)

    def or_op(self, expression_list, key: str = "filter") -> None:
        """Store ``Filter(should=expression_list)`` under ``self._query[key]``."""
        self._query[key] = Filter(should=expression_list)

    # ----------------------------------------------------------- entry point

    def invoke(self, raw_expression) -> RawExpression:
        """Parse ``raw_expression`` into the query and return ``self.query``.

        ``raw_expression`` is expected to be a dict with a ``"filter"`` entry
        and an optional ``"hybrid"`` entry. The caller's dict is not modified.

        Raises:
            ValueError: if a hybrid section is given but hybrid filtering is
                not supported.
            KeyError: if ``raw_expression`` has no ``"filter"`` entry.
        """
        # Check the type first so non-container inputs don't fail on ``in``.
        if isinstance(raw_expression, dict) and "hybrid" in raw_expression:
            if not self.supports_hybrid:
                raise ValueError("Value 'hybrid' not supported. This vector DB does not support hybrid filtering")
            self._query[self.hybrid_key] = {}
            self.parse_raw_expression(raw_expression["hybrid"], key=self.hybrid_key)

        filter_expression = raw_expression["filter"]
        # A bare condition is wrapped in an implicit AND so the top level is
        # always a logical group. Rebinding a local keeps the caller's input intact.
        if "and" not in filter_expression and "or" not in filter_expression:
            filter_expression = {"and": [filter_expression]}
        self.parse_raw_expression(filter_expression, key="filter")
        return self.query
