import logging
from time import perf_counter
import pandas as pd
import numpy as np
from enum import Enum
from typing import List, AnyStr
from collections import defaultdict
from .osm_client import (
    OSMClient,
    CustomErrorHandler,
    custom_make_request,
)
from apiclient.request_strategies import RequestStrategy
from apiclient import JsonResponseHandler


class Collector:
    """Collect points of interest (POIs) for geometries already parsed into Overpass format.

    Each geometry is a list of Overpass polygon filter strings. A single polygon is a
    one-element list; a multipolygon is several polygons searched at the same time.
    """

    # Value of the Overpass QL `[timeout:...]` global setting (seconds).
    TIME_OUT = "2000"

    def __init__(self, keys, filters, request_by_batch=False):
        """
        Args:
            keys: tag names extracted from each POI by `_extract_tags` (POIs to collect + enrichments).
            filters: Overpass tag filters selecting which POIs to collect (e.g. "amenity=cafe").
            request_by_batch: if True, all geometries are sent in a single Overpass request;
                otherwise one request is sent per geometry.
        """
        self.request_by_batch = request_by_batch
        self.keys = keys  # pois to collect + enrichments
        self.filters = filters  # type of pois to collect
        # NOTE: this patches the RequestStrategy class itself, so it affects every
        # RequestStrategy user in the process, not just this collector.
        RequestStrategy._make_request = custom_make_request
        self.osm_client = OSMClient(error_handler=CustomErrorHandler, response_handler=JsonResponseHandler)

    def collect_pois(self, df, geometry_column) -> dict:
        """Query Overpass for POIs inside the geometries of `df[geometry_column]`.

        Returns:
            In batch mode: the single response dict of the combined request.
            Otherwise: a dict mapping each row index of `df` to that row's response dict.
            Every response dict carries a "failure_response" entry (NaN on success,
            the error message on failure).

        Raises:
            ValueError: if no non-empty filter is configured.
        """
        geometries = df[geometry_column]
        if not self.request_by_batch:
            # One request per geometry; each geometry is wrapped as a batch of size one.
            return geometries.apply(lambda geometry: self._batch_collect_pois([geometry])).to_dict()

        n_geometries = len(geometries)
        logging.info("Collect pois on %d geometries...", n_geometries)
        start = perf_counter()
        datas = self._batch_collect_pois(geometries.tolist())
        logging.info(
            "Collecting pois on %d geometries: Done in %.2f seconds", n_geometries, perf_counter() - start
        )
        return datas

    def _batch_collect_pois(self, batch_overpass_geom: List[List[str]]) -> dict:
        """Build one Overpass query covering every geometry in the batch, send it, and return the response.

        Args:
            batch_overpass_geom: geometries in Overpass format; each is a list of polygon strings.

        Returns:
            The decoded response with "failure_response" set to NaN on success, or
            {"elements": [], "failure_response": <error message>} if the request (or
            post-processing of its result) raised.

        Raises:
            ValueError: if no non-empty filter is configured.
        """
        tags = self._get_set_of_tags(self.filters)
        if not tags:
            raise ValueError(
                "No valid filters were found. Please specify at least one filter in the 'Type of POIs' field"
            )

        parts = [f"[out:json][timeout:{self.TIME_OUT}];"]
        for geom in batch_overpass_geom:
            # One union statement per geometry, followed by its own output statements,
            # so the response contains one result group per geometry.
            parts.append("(")
            parts.extend(self._query_by_tag(tag, geom) for tag in tags)
            parts.append(");out tags center;out count;")
        overpass_query = "".join(parts)

        try:
            data = self.osm_client.post_pois(overpass_query)
            data["failure_response"] = np.nan
        except Exception as error:
            # Deliberate best-effort: a failed request yields an empty result instead of
            # aborting the whole collection, but it is no longer silent.
            logging.warning("Overpass request failed: %s", error)
            data = {"elements": [], "failure_response": str(error)}
        return data

    def _extract_tags(self, poi_tags: dict) -> List[str]:
        """Return the string value of each tag in `self.keys` for one POI.

        Missing tags (or a `None` tag mapping) yield empty strings.
        """
        poi_tags = poi_tags or {}
        return [str(poi_tags.get(tag, "")) for tag in self.keys]

    @staticmethod
    def _get_set_of_tags(tags: List[str]) -> List[str]:
        """Return the filters stripped and lower-cased, without empty entries or duplicates.

        First-seen order is kept. A `None` filter list yields an empty list.
        """
        cleaned = (t.strip().lower() for t in (tags or []))
        return list(dict.fromkeys(t for t in cleaned if t))

    @staticmethod
    def _query_by_tag(tag: str, overpass_geometry: List[str]) -> str:
        """Part of the query linked to a given tag and a list of geometries in Overpass format

        Args:
            tag (str): The tag to filter POIs with
            overpass_geometry (list[str]): list of polygons to search POIs into.
                If len(overpass_geometry) == 1 : polygon
                If len(overpass_geometry) > 1 : multipolygon represented as a list of polygons to search into at the same time

        Returns:
            str: one `node[...]` statement per polygon, to be placed inside a union
        """
        return "".join(f"node[{tag}]{poly};" for poly in overpass_geometry)
