import shapely.wkt
from shapely import geometry
import numpy as np
from typing import AnyStr, List, Union, Tuple
import re
import pandas as pd
from osm_dataset_enrichment.utils.cleaning_utils import generate_unique
import logging
from time import perf_counter


class GeometryParser:
    """Parse WKT geometries into Overpass API ``poly`` filter strings.

    Only ``Polygon`` and ``MultiPolygon`` geometries are supported. For each row,
    three output columns are filled (their names are chosen by
    ``generate_parsing_columns``):

    * parsed geometry: a list of ``(poly:"...")`` strings, one per polygon,
      or NaN when nothing was parsed;
    * error type: ``""`` on success, otherwise one of the ``*_ERROR`` constants;
    * error message: ``""`` on success, otherwise a human-readable detail.
    """

    PARSE_ERROR = "Parse error: "
    UNSUPPORTED_GEOMETRY_ERROR = "Unsupported geometry: "

    def parse_df(self, df: pd.DataFrame, geometry_column: str) -> pd.DataFrame:
        """Parse every row of ``df`` into Overpass API format.

        :param df: input dataframe; it is not modified.
        :param geometry_column: name of the column holding WKT strings.
        :return: a new dataframe with the parsed-geometry, error-type and
            error-message columns appended.
        """
        logging.info("Parse %d geometries...", len(df))
        start = perf_counter()
        self.generate_parsing_columns(df.columns.tolist())
        if len(df) == 0:
            # On an empty frame DataFrame.apply does not add the columns that
            # `parse` produces, so add them explicitly to keep the output schema.
            output_columns = (
                self.parsed_geometry_column,
                self.error_type_column,
                self.error_message_column,
            )
            df = df.assign(**{column: pd.Series(dtype=object) for column in output_columns})
        else:
            df = df.apply(self.parse, args=(geometry_column,), axis=1)
        logging.info("Parse %d geometries: Done in %.2f seconds", len(df), perf_counter() - start)
        return df

    def generate_parsing_columns(self, existing_columns: List[str]) -> None:
        """Choose the output column names, made unique against ``existing_columns``.

        Sets ``error_type_column``, ``error_message_column`` and
        ``parsed_geometry_column`` on the instance; ``parse`` reads them.
        """
        self.error_type_column = generate_unique("error_type", existing_columns)
        self.error_message_column = generate_unique("error_message", existing_columns)
        self.parsed_geometry_column = generate_unique("parsed_geometry", existing_columns)

    def parse(self, row: pd.Series, geometry_column: str) -> pd.Series:
        """Parse the WKT geometry of one row into Overpass API format.

        :param row: a dataframe row; it is copied, not modified in place.
        :param geometry_column: key of the WKT string inside ``row``.
        :return: a copy of ``row`` with the three output columns set.
            Missing geometries (None / NaN / NA) yield NaN with no error; WKT
            that fails to load yields ``PARSE_ERROR``; empty or non-polygonal
            geometries yield ``UNSUPPORTED_GEOMETRY_ERROR``.
        """
        # pandas does not support functions that mutate the object passed by
        # DataFrame.apply, so work on a copy of the row.
        row = row.copy()
        wkt_value = row[geometry_column]
        row[self.parsed_geometry_column] = np.nan
        row[self.error_type_column] = ""
        row[self.error_message_column] = ""
        # An identity test against np.nan misses None, float("nan") and pd.NA.
        # Only scalars are checked so pd.isna never returns an array here.
        if pd.api.types.is_scalar(wkt_value) and pd.isna(wkt_value):
            return row
        try:
            shape = shapely.wkt.loads(wkt_value)
        except Exception as error:
            row[self.error_type_column] = self.PARSE_ERROR
            row[self.error_message_column] = f" {str(type(error))} : {str(error)}"
            return row
        if shape.geom_type not in ("Polygon", "MultiPolygon"):
            row[self.error_type_column] = self.UNSUPPORTED_GEOMETRY_ERROR
            row[self.error_message_column] = "only polygons and multipolygons accepted."
        elif shape.is_empty:
            # An empty polygon would produce an empty `(poly:"")` filter.
            row[self.error_type_column] = self.UNSUPPORTED_GEOMETRY_ERROR
            row[self.error_message_column] = "empty polygons are not accepted."
        elif shape.geom_type == "Polygon":
            row[self.parsed_geometry_column] = [self._parse_polygon(shape)]
        else:
            row[self.parsed_geometry_column] = self._parse_multipolygon(shape)
        return row

    def _parse_multipolygon(self, multipolygon: "shapely.geometry.multipolygon.MultiPolygon") -> List[str]:
        """Return one Overpass ``poly`` string per polygon of ``multipolygon``."""
        return [self._parse_polygon(polygon) for polygon in multipolygon.geoms]

    def _parse_polygon(self, polygon: "shapely.geometry.polygon.Polygon") -> str:
        """Return the Overpass ``poly`` string for a single polygon (holes included)."""
        polygon_coordinates = self._get_polygon_coordinates(polygon)
        polygon_coordinates_recomputed = self._recompute_polygon_coordinates(polygon_coordinates)
        return self._from_polygon_coordinates_to_overpass_geometry(polygon_coordinates_recomputed)

    @staticmethod
    def _get_polygon_coordinates(
        polygon: "shapely.geometry.polygon.Polygon",
    ) -> List[Union[np.ndarray, List[np.ndarray]]]:
        """Return ``[exterior coordinates, [hole coordinates, ...]]`` as numpy arrays."""
        return [
            np.asarray(polygon.exterior.coords),
            # Interior rings are LinearRing geometries: read their `.coords`
            # like the exterior. Shapely 2 geometries no longer expose the
            # NumPy array interface, so np.asarray(ring) would not give points.
            [np.asarray(hole.coords) for hole in polygon.interiors],
        ]

    @staticmethod
    def _recompute_polygon_coordinates(coordinates: List[Union[np.ndarray, List[np.ndarray]]]) -> np.ndarray:
        """Merge an exterior ring and its holes into one coordinate sequence.

        After the exterior, each hole is appended followed by a segment from the
        hole's first point back to the exterior's last point, so every hole is
        linked to the exterior by a zero-width bridge.
        """
        exterior, holes = coordinates
        recomputed_coordinates = [exterior]
        for hole in holes:
            recomputed_coordinates.append(hole)
            recomputed_coordinates.append(np.stack([hole[0], exterior[-1]], axis=0))
        return np.concatenate(recomputed_coordinates)

    def _from_polygon_coordinates_to_overpass_geometry(self, polygon: np.ndarray) -> str:
        """Format ``[[longitude, latitude], ...]`` as an Overpass ``(poly:"...")`` filter.

        The filter lists space-separated ``latitude longitude`` values.
        """
        latitude_longitude_pairs = self._reverse_polygon_coordinates(polygon)
        # Format each number explicitly instead of str()-ing the whole list:
        # with NumPy >= 2 a scalar's repr is "np.float64(...)", not a number.
        overpass_parameters = " ".join(
            f"{self._format_coordinate(latitude)} {self._format_coordinate(longitude)}"
            for latitude, longitude in latitude_longitude_pairs
        )
        return f'(poly:"{overpass_parameters}")'

    @staticmethod
    def _format_coordinate(value) -> str:
        """Render one coordinate, converting NumPy scalars to Python numbers first."""
        return str(value.item() if isinstance(value, np.generic) else value)

    @staticmethod
    def _reverse_polygon_coordinates(polygon) -> list:
        """Swap the first two values of every point.

        polygon must have format :
        [[param_1_1, param_2_1], [param_1_2, param_2_2], ...., [param_1_n, param_2_n]]
        Any further values per point (e.g. Z) are dropped.
        """
        new_coordinates = [[coordinates[1], coordinates[0]] for coordinates in polygon]
        return new_coordinates
