import pandas as pd
import requests
import json
import logging

# Configures the root logger at import time (a no-op if the root logger
# already has handlers); every record is prefixed with "osm-e plugin".
# NOTE(review): configuring logging on import overrides the host
# application's defaults; consider moving this to the entry point.
logging.basicConfig(level=logging.INFO, format='osm-e plugin %(levelname)s - %(message)s')

# Public Overpass API endpoint queried by enrich_bounding_box. HTTPS keeps the
# query and the returned data from being read or altered in transit.
overpass_url = "https://overpass-api.de/api/interpreter"


def enrich_bounding_box(bbox, tags):
    _check_bbox(bbox)
    _check_tags(tags)
    tags = [t.strip().lower() for t in tags]
    tags = [t for t in tags if len(t) > 0]
    tags = set(tags)  # deduplicate
    print('Tags: ')
    print(tags)

    bbox_str = '(%s,%s,%s,%s)' % (bbox["minLat"], bbox["minLon"], bbox["maxLat"], bbox["maxLon"])
    overpass_query = "[out:json];("
    if len(tags) > 0:
        for t in tags:
            overpass_query += 'node' + '[' + t + ']' + bbox_str + ';' + 'way' + '[' + t + ']' + bbox_str + ';'
        overpass_query = overpass_query + ')' + ';out center;'
    else:
        overpass_query = overpass_query[:-1] + bbox_str + ';out center;'

    print('Query: ')
    print(overpass_query)

    response = requests.get(overpass_url, params={'data': overpass_query})
    data = response.json()
    return data


# Category assigned from an element's ``amenity`` value; values not listed
# leave the element uncategorized.
_AMENITY_CATEGORIES = {
    'cinema': 'entertainment',
    'theatre': 'entertainment',
    'pharmacy': 'shop',
    'car_rental': 'shop',
    'fuel': 'shop',
    'car_wash': 'shop',
    'bank': 'shop',
    'atm': 'shop',
    'fast_food': 'food',
    'cafe': 'food',
    'bar': 'food',
    'pub': 'food',
    'restaurant': 'food',
    'social_facility': 'public_service',
    'hospital': 'public_service',
    'police': 'public_service',
    'townhall': 'public_service',
    'kindergarten': 'public_service',
    'library': 'public_service',
    'bus_station': 'transport',
}

# The presence of any of these tag keys marks an element as entertainment.
_ENTERTAINMENT_KEYS = ('leisure', 'sport', 'tourism', 'historic')


def detect_categories(data):
    """Set a coarse ``category`` on each element of ``data['elements']`` in place.

    Elements whose tags match no rule are left without a ``category`` key.
    Returns the same ``data`` object.
    """
    for element in data['elements']:
        # Elements without a ``tags`` entry (e.g. untagged nodes) get no
        # category instead of raising KeyError.
        category = _category_for_tags(element.get('tags') or {})
        if category is not None:
            element['category'] = category
    return data


def _category_for_tags(tags):
    """Return the category for an OSM tag dict, or None if no rule matches.

    Rules are checked in priority order: shop, entertainment keys, railway,
    amenity.
    """
    if 'shop' in tags:
        return 'supermarket' if tags['shop'] == 'supermarket' else 'shop'
    if any(key in tags for key in _ENTERTAINMENT_KEYS):
        return 'entertainment'
    if 'railway' in tags:
        # A railway element that is not a station or tram stop stays
        # uncategorized, even if it also carries an amenity tag.
        return 'transport' if tags['railway'] in ('station', 'tram_stop') else None
    if 'amenity' in tags:
        return _AMENITY_CATEGORIES.get(tags['amenity'])
    return None


def make_grid(bbox, n_lat, n_lon):
    """Split a bounding box into an ``n_lat`` x ``n_lon`` grid of equal cells.

    Cells are enumerated row by row: the latitude index is the outer loop and
    the longitude index the inner one. Each DataFrame row describes one cell
    by its centre and its min/max latitude and longitude, followed by a
    1-based ``cell_grid_id``.
    """
    lat_step = (bbox["maxLat"] - bbox["minLat"]) / n_lat
    lon_step = (bbox["maxLon"] - bbox["minLon"]) / n_lon
    lat0, lon0 = bbox["minLat"], bbox["minLon"]

    cells = [(row, col) for row in range(n_lat) for col in range(n_lon)]
    columns = {
        'lat_square_center': [lat0 + lat_step / 2 + row * lat_step for row, _ in cells],
        'lon_square_center': [lon0 + lon_step / 2 + col * lon_step for _, col in cells],
        'lat_square_min': [lat0 + row * lat_step for row, _ in cells],
        'lon_square_min': [lon0 + col * lon_step for _, col in cells],
        'lat_square_max': [lat0 + (row + 1) * lat_step for row, _ in cells],
        'lon_square_max': [lon0 + (col + 1) * lon_step for _, col in cells],
    }
    grid = pd.DataFrame(columns)
    grid['cell_grid_id'] = range(1, len(cells) + 1)
    return grid


def aggregate_by_category_on_grid(data, grid):
    """Count categorized POIs per grid cell and attach each cell's polygon.

    Args:
        data: Overpass JSON with an ``elements`` list; the elements are
            categorized in place via ``detect_categories``.
        grid: DataFrame with the cell bound columns produced by ``make_grid``.
            It is not modified.

    Returns:
        One row per grid cell containing at least one categorized POI: the
        remaining grid columns (centre, ``cell_grid_id``), one
        ``category_<value>`` count column per category seen, and a WKT
        ``geom`` column built by ``_create_polygon``.
    """
    data = detect_categories(data)
    pois = _categorized_points(data['elements'])
    # One indicator column per category value, named ``category_<value>``.
    pois = pd.get_dummies(pois, columns=['category'])
    # Cross join every cell with every POI through a constant key; assign()
    # returns copies, so the caller's grid does not gain a ``key`` column.
    df = grid.assign(key=1).merge(pois.assign(key=1), on='key')
    # Half-open cells [min, max): for grids from make_grid a cell's max edge
    # and its neighbour's min edge are computed identically, so a POI lying
    # exactly on an internal grid line is counted once rather than dropped.
    inside = (
        (df.lat >= df.lat_square_min) & (df.lat < df.lat_square_max)
        & (df.lon >= df.lon_square_min) & (df.lon < df.lon_square_max)
    )
    df = df.loc[inside]
    cols = [c for c in df.columns if c not in ('id', 'lat', 'lon', 'key')]
    group_keys = [c for c in cols if 'category' not in c]
    res = df[cols].groupby(group_keys).sum().reset_index()
    return _create_polygon(res, cols)


def _categorized_points(elements):
    """Return an (id, lat, lon, category) DataFrame of elements having a category and a position."""
    rows = []
    for element in elements:
        if element.get('type') == 'node':
            lat, lon = element.get('lat'), element.get('lon')
        else:
            # Non-node elements are placed at their ``center`` entry (the
            # Overpass query uses ``out center``); missing -> skipped below.
            center = element.get('center') or {}
            lat, lon = center.get('lat'), center.get('lon')
        values = (element.get('id'), lat, lon, element.get('category'))
        if any(v is None for v in values):
            continue
        rows.append(values)
    return pd.DataFrame(rows, columns=['id', 'lat', 'lon', 'category'])


def _create_polygon(res, cols):
    cols2 = ['lon_square_min', 'lat_square_min', 'lon_square_max', 'lat_square_max']
    res['geom1'] = res[cols2].apply(lambda x: str(x[0]) + ' ' + str(x[1]), axis=1)
    res['geom2'] = res[cols2].apply(lambda x: str(x[0]) + ' ' + str(x[3]), axis=1)
    res['geom3'] = res[cols2].apply(lambda x: str(x[2]) + ' ' + str(x[3]), axis=1)
    res['geom4'] = res[cols2].apply(lambda x: str(x[2]) + ' ' + str(x[1]), axis=1)
    res['geom5'] = res[cols2].apply(lambda x: str(x[0]) + ' ' + str(x[1]), axis=1)
    res['geom'] = res[['geom1', 'geom2', 'geom3', 'geom4', 'geom5']].apply(
        lambda x: 'MULTIPOLYGON(((' + x[0] + ',' + x[1] + ',' + x[2] + ',' + x[3] + ',' + x[4] + ')))', axis=1)
    for c in cols2:
        cols.remove(c)
    res = res[cols + ['geom']]
    return res


def _check_bbox(bbox):
    print(bbox)
    if bbox is None:
        raise ValueError("Bounding box not specified")
    if not isinstance(bbox, dict):
        raise ValueError("Bounding box must be a dict")
    for k in ["minLat", "minLon", "maxLat", "maxLon"]:
        if k not in bbox:
            raise ValueError("Bounding box must have a " + k)
    if (bbox['minLat'] >= bbox['maxLat']):
        raise ValueError("minLat must be strictly inferior to maxLat")
    if (bbox['minLon'] >= bbox['maxLon']):
        raise ValueError("minLat must be strictly inferior to maxLon")


def _check_tags(tags):
    if tags is None:
        raise ValueError("Tags not specified")
