from typing import AnyStr, List, Union, Tuple
import pandas as pd
from shapely import geometry
import numpy as np
import logging
from time import perf_counter
from collections import defaultdict
from osm_dataset_enrichment.geometry_parser import GeometryParser
from osm_dataset_enrichment.osm_client import (
    OSMClient,
    CustomErrorHandler,
    custom_make_request,
)
from osm_dataset_enrichment.utils.cleaning_utils import (
    generate_unique_columns,
    generate_unique,
)
from .collector import Collector
from .writer import Writer
from apiclient.request_strategies import RequestStrategy


class OSMEnrichment:
    """Find and display the points of interest located in the geometries of a dataframe.

    Args:
        filters (List[AnyStr]): The tags and keys used to filter the points of interest.
        enrichments (List[AnyStr]): The extra information to display about the points of interest.
        request_by_batch (bool): Whether the collector and writer request by batch (True) or by row (False).

    Attributes:
        filters (List[AnyStr]): The tags and keys used to filter the points of interest.
        enrichments (List[AnyStr]): The extra information to display about the points of interest.
        keys (List[AnyStr]): ``filters`` followed by ``enrichments``.
        geometry_parser (GeometryParser): Parses the input polygons into Overpass format.
        collector (Collector): Collects the points of interest of the parsed geometries.
        writer (Writer): Builds the input/output rows from the collected points of interest.
        output_columns (List[AnyStr]): Names of the enrichment columns added to the output
            (filters, enrichments, "tags", "geopoint", "failure_response"), made unique against
            the input columns on each call to ``enrich_batch_df``.
    """

    def __init__(
        self,
        filters: List[AnyStr],
        enrichments: List[AnyStr],
        request_by_batch: bool,
    ):
        self.filters = filters
        self.enrichments = enrichments
        self.keys = self.filters + self.enrichments
        # instance to parse the polygons into overpass format
        self.geometry_parser = GeometryParser()
        self.collector = Collector(filters=self.filters, keys=self.keys, request_by_batch=request_by_batch)
        self.writer = Writer(filters=self.filters, keys=self.keys, request_by_batch=request_by_batch)
        # Enrichment column names before deduplication against the input columns. Kept apart from
        # self.output_columns so each call renames from the same base, not from the previous result.
        self._base_output_columns = self.filters + self.enrichments + ["tags", "geopoint", "failure_response"]
        self.output_columns = list(self._base_output_columns)

    def _set_output_column_names(self, existing_column_names: List[AnyStr]) -> None:
        """Set ``self.output_columns`` to the base enrichment column names made unique against
        ``existing_column_names``.

        The renaming always starts from the names built in ``__init__``, so the result only depends on
        the current input columns and not on earlier calls.
        """
        # Fallback for instances that were created without running this version of __init__.
        base_columns = getattr(self, "_base_output_columns", self.output_columns)
        self.output_columns = generate_unique_columns(df_columns=existing_column_names, columns=base_columns)

    def enrich_batch_df(self, df: pd.DataFrame, polygon_column: AnyStr) -> pd.DataFrame:
        """Search the points of interest of each geometry in ``df`` and return the enriched dataframe.

        Args:
            df: Input dataframe.
            polygon_column: Name of the column holding the geometries to parse.

        Returns:
            The rows with a parsed geometry, extended with the ``self.output_columns`` columns, followed
            by the rows whose geometry could not be parsed (their failure column holds the parser error).
        """
        # parse: rows whose geometry could not be parsed get a null parsed-geometry value
        df = self.geometry_parser.parse_df(df, polygon_column)
        df_parsed_geometry = df[df[self.geometry_parser.parsed_geometry_column].notnull()]
        # collect the points of interest of the parsed geometries
        datas = self.collector.collect_pois(df_parsed_geometry, self.geometry_parser.parsed_geometry_column)
        # write them as input rows / output rows
        input_list, output_list = self.writer.write_pois(datas, df_parsed_geometry)
        # must run before _invalid_geom_df, which reads the deduplicated failure column name
        self._set_output_column_names(existing_column_names=df.columns.tolist())
        df_invalid_geom = self._invalid_geom_df(df)
        return self.create_output_df(input_list, output_list, df_parsed_geometry, df_invalid_geom)

    def create_output_df(self, input_list, output_list, df_parsed_geometry, df_invalid_geom):
        """Build the output dataframe by concatenating:
        - the input columns of the rows with a parsed geometry (``input_list``, named after
          ``df_parsed_geometry.columns``),
        - the enrichment columns of those rows (``output_list``, named after ``self.output_columns``),
        - the rows whose geometry could not be parsed (``df_invalid_geom``).

        The helper columns ("index", the parsed geometry column and the parser error columns) are dropped.
        """
        input_df = pd.DataFrame(input_list, columns=df_parsed_geometry.columns)
        output_df = pd.DataFrame(output_list, columns=self.output_columns)
        # Both frames have a default range index, so the side-by-side concat pairs rows by position.
        output_df = pd.concat([input_df, output_df], axis=1)
        output_df = pd.concat([output_df, df_invalid_geom], sort=False)
        output_df = output_df.drop(
            columns=[
                "index",
                self.geometry_parser.parsed_geometry_column,
                self.geometry_parser.error_type_column,
                self.geometry_parser.error_message_column,
            ]
        )
        return output_df

    def _invalid_geom_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the rows of ``df`` whose geometry could not be parsed.

        The index is moved into an "index" column (dropped later by ``create_output_df``). The failure
        column holds the parser error type concatenated with the error message; its name is the last
        entry of ``self.output_columns`` ("failure_response", possibly renamed to avoid a clash).
        """
        parsed_column = self.geometry_parser.parsed_geometry_column
        df_invalid_geom = df[df[parsed_column].isnull()].reset_index()
        df_invalid_geom[parsed_column] = np.nan
        # output_list rows are written positionally against self.output_columns, so the deduplicated
        # list keeps the base order and its last entry is the failure-response column. Using it keeps
        # invalid and valid rows in the same column and never overwrites an input "failure_response".
        failure_column = self.output_columns[-1]
        # NOTE(review): no separator between error type and message — confirm this is intended.
        df_invalid_geom[failure_column] = df_invalid_geom[self.geometry_parser.error_type_column].str.cat(
            others=df_invalid_geom[self.geometry_parser.error_message_column]
        )
        return df_invalid_geom
