import sys

import dash
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
from dash import dash_table
from dash import dcc
from dash import html, callback_context
from dash.dependencies import Input, Output, State, MATCH, ALL

import dataiku
import dataikuapi

import generalized_linear_models
# Register the project library under its top-level name in sys.modules.
# NOTE(review): presumably required so the saved flow models, which may
# reference this module when they are deserialized, can be loaded — confirm.
sys.modules['generalized_linear_models'] = generalized_linear_models

app.config.external_stylesheets = [dbc.themes.BOOTSTRAP]

# Project variables choose the scoring backend. Read them once. use_api is
# kept as the raw string value: only "True" routes scoring to the API node.
_project_variables = dataiku.get_custom_variables()
use_api = _project_variables["use_api"]
api_node_url = _project_variables["api_node_url"]

if use_api == "True":
    # Remote scoring: client for the "claim_risk" service on the API node.
    client = dataikuapi.APINodeClient(api_node_url, "claim_risk")
else:
    # Local scoring: load each saved model from the flow by id, with its predictor.
    claim_frequency = dataiku.Model("aHJZVrBQ")
    claim_frequency_predictor = claim_frequency.get_predictor()
    claim_severity = dataiku.Model("j1Mpq3TM")
    claim_severity_predictor = claim_severity.get_predictor()
    pure_premium = dataiku.Model("PVKl9xt0")
    pure_premium_predictor = pure_premium.get_predictor()


def process_data(data):
    """Build the one-row feature frame expected by the flow models.

    The raw inputs are capped. Binned and log-transformed copies are added as
    extra columns: ``VehAgeBin``, ``DrivAgeBin``, ``LogDensity``,
    ``LogDensityBin``, ``LogBonusMalus`` and ``BonusMalusBin``.

    Args:
        data: mapping of feature name to a single scalar value. It must
            contain at least ``VehAge``, ``DrivAge``, ``VehPower``,
            ``Density`` and ``BonusMalus``; every other key is copied through.

    Returns:
        pandas.DataFrame: a single row holding the capped inputs plus the
        derived feature columns.
    """
    frame = pd.DataFrame({key: [value] for key, value in data.items()})

    frame['VehAge'] = frame['VehAge'].clip(upper=20)
    frame['DrivAge'] = frame['DrivAge'].clip(lower=20, upper=90)
    # VehPower is capped, then used as a category label (string).
    frame['VehPower'] = frame['VehPower'].clip(upper=9).astype(str)
    # right=False -> half-open intervals [a, b), matching the labels below.
    frame['VehAgeBin'] = pd.cut(frame['VehAge'], bins=[0, 1, 10, 100], right=False,
                                labels=['0 : 1', '1 : 10', '10 : 100'])
    frame['DrivAgeBin'] = pd.cut(frame['DrivAge'], bins=[20, 21, 26, 31, 41, 51, 71, 100], right=False,
                                 labels=['20 : 21', '21 : 26', '26 : 31', '31 : 41',
                                         '41 : 51', '51 : 71', '71 : 100'])

    # A density of 0 (allowed by the UI) would give log10 = -inf, which
    # matches no bin. Flooring the value at 1, for the log only, puts it in
    # bin '0'. The raw Density column is left untouched.
    frame['LogDensity'] = np.log10(frame['Density'].clip(lower=1))
    frame['LogDensityBin'] = pd.cut(frame['LogDensity'], right=False, bins=list(range(6)),
                                    labels=[str(i) for i in range(5)])

    frame['BonusMalus'] = frame['BonusMalus'].clip(upper=150)
    frame['LogBonusMalus'] = np.log10(frame['BonusMalus'])
    bonus_malus_edges = list(range(50, 160, 10))  # 50, 60, ..., 150
    # With right=False the last interval would be [140, 150), so the capped
    # value 150 fell outside every bin (NaN). Opening the top edge makes the
    # '140' bin cover [140, 150] and keeps the same labels.
    frame['BonusMalusBin'] = pd.cut(frame['BonusMalus'], right=False,
                                    bins=bonus_malus_edges[:-1] + [np.inf],
                                    labels=[str(edge) for edge in bonus_malus_edges[:-1]])
    return frame

def make_prediction(predictor, data):
    """Score ``data`` with ``predictor`` and hand back its raw output.

    This is a thin indirection over ``predictor.predict`` so that every
    scoring call goes through one entry point.
    """
    return predictor.predict(data)

def slider(min_value, max_value, step, marks, default, slider_id, title, description):
    """Build a labelled slider row: title and description on the left, slider on the right.

    Args:
        min_value: lowest selectable value.
        max_value: highest selectable value.
        step: increment between selectable values.
        marks: mapping of slider position to tick label. Non-dict labels are
            converted to strings, because Dash expects mark labels to be
            strings or ``{'label': ..., 'style': ...}`` dicts. ``None`` is
            passed through unchanged.
        default: initial value of the slider.
        slider_id: Dash component id of the ``dcc.Slider``.
        title: heading shown above the description.
        description: grey helper text shown under the title.

    Returns:
        dbc.Row: the label column and the slider column, side by side.
    """
    if marks is not None:
        marks = {position: label if isinstance(label, dict) else str(label)
                 for position, label in marks.items()}

    label_column = dbc.Col([html.H6(title, style={'marginBottom': '0em'}),
                            html.P(description, style={'color': '#8f8f8f', 'fontSize': '12px'})],
                           md=6)
    slider_column = dbc.Col(dcc.Slider(min_value,
                                       max_value,
                                       step,
                                       marks=marks,
                                       value=default,
                                       id=slider_id,
                                       tooltip={"placement": "bottom", "always_visible": True}),
                            md=6)
    return dbc.Row([label_column, slider_column], style={'marginBottom': '0em'})

def dropdown(options, default, dropdown_id, title, description):
    """Build a labelled dropdown row: title and description on the left, dropdown on the right.

    Args:
        options: iterable of choices. Each entry is used both as the displayed
            label and as the selected value.
        default: initially selected value.
        dropdown_id: Dash component id of the ``dcc.Dropdown``.
        title: heading shown above the description.
        description: grey helper text shown under the title.

    Returns:
        dbc.Row: the label column and the dropdown column, side by side.
    """
    label_column = dbc.Col([html.H6(title, style={'marginBottom': '0em'}),
                            html.P(description, style={'color': '#8f8f8f', 'fontSize': '12px'})],
                           md=6)
    dropdown_column = dbc.Col(dcc.Dropdown(id=dropdown_id,
                                           options=[{'label': option, 'value': option}
                                                    for option in options],
                                           value=default),
                              md=6, style={'fontSize': '12px'})
    return dbc.Row([label_column, dropdown_column], style={'marginBottom': '0em'})

# Numeric rating-factor sliders. Each marks dict maps a tick position (string
# key) to the number displayed under it.
vehicle_power_slider = slider(
    min_value=4, max_value=15, step=1,
    marks={str(tick): tick for tick in (4, 8, 12, 15)},
    default=4,
    slider_id='vehicle-power',
    title='Vehicle Power',
    description='power of the vehicle, between 4 and 15, the higher the more powerful',
)

vehicle_age_slider = slider(
    min_value=0, max_value=20, step=1,
    marks={str(tick): tick for tick in range(0, 21, 5)},
    default=0,
    slider_id='vehicle-age',
    title='Vehicle Age',
    description='age of the vehicle, in years, capped at 20 years old',
)

driver_age_slider = slider(
    min_value=18, max_value=99, step=1,
    marks={str(tick): tick for tick in (18, 30, 60, 90)},
    default=42,
    slider_id='driver-age',
    title='Driver Age',
    description="driver's age, in years",
)

bonus_malus_slider = slider(
    min_value=50, max_value=150, step=1,
    marks={str(tick): tick for tick in (50, 100, 150)},
    default=50,
    slider_id='bonus-malus',
    title='No-Claims Discount',
    description='No-Claims discount of the customer, the closer to 50 the better, capped at 150',
)

# Categorical vehicle attributes. Brand codes are B1-B6 and B10-B14.
vehicle_brand_dropdown = dropdown(
    options=[f'B{brand}' for brand in (*range(1, 7), *range(10, 15))],
    default='B1',
    dropdown_id='vehicle-brand',
    title='Vehicle Brand',
    description='brand of the vehicle',
)

vehicle_gas_dropdown = dropdown(
    options=['Regular', 'Diesel'],
    default='Regular',
    dropdown_id='vehicle-gas',
    title='Fuel Type',
    description='type of fuel used by vehicle',
)


# Population density slider. It starts at 1 rather than 0: the local
# preprocessing takes log10(Density), which is -inf at 0 and matches no
# LogDensity bin.
density_slider = slider(1, 27000, 1,
                        marks={'100': 100,
                               '1000': 1000,
                               '10000': 10000},
                        default=100,
                        slider_id='density',
                        title='Density',
                        description='population density of the city the customer lives in, in number of inhabitants per square-kilometer')

# Customer region, using the French regional divisions in place before 2016.
region_dropdown = dropdown(
    options=[
        'Alsace', 'Aquitaine', 'Auvergne', 'Basse-Normandie', 'Bourgogne',
        'Bretagne', 'Centre', 'Champagne-Ardenne', 'Corse', 'Franche-Comte',
        'Haute-Normandie', 'Ile-de-France', 'Languedoc-Roussillon', 'Limousin',
        'Midi-Pyrenees', 'Nord-Pas-de-Calais', 'Pays-de-la-Loire', 'Picardie',
        'Poitou-Charentes', "Provence-Alpes-Cote-D'Azur", 'Rhone-Alpes',
    ],
    default='Centre',
    dropdown_id='region',
    title='Region',
    description='region of the customer, to choose from the list of regions that were defined before 2016',
)

def render_prediction(name, value, description):
    """Build one result row for the predictions panel.

    Args:
        name: label of the prediction; a ": " suffix is appended.
        value: numeric prediction, shown rounded to two decimals.
        description: grey helper text shown under the label.

    Returns:
        dbc.Row: label/description column on the left, value on the right.
    """
    label_column = dbc.Col([html.H6(name + ": ", style={'textAlign': 'left', 'marginBottom': '0em'}),
                            html.P(description, style={'color': '#8f8f8f', 'fontSize': '12px'})],
                           md=6)
    value_column = dbc.Col(html.H6(str(round(value, 2)),
                                   style={'marginBottom': '2em', 'textAlign': 'left', 'color': '#3b99fc'}),
                           md=6)
    return dbc.Row([label_column, value_column])

def _predict_via_api(record):
    """Score ``record`` with the three endpoints of the API-node service.

    Every value is stringified before sending, which is the payload format
    this app uses for the API node.

    Returns:
        tuple: (claim frequency, claim severity, tweedie pure premium).
    """
    payload = {key: str(value) for key, value in record.items()}
    return tuple(client.predict_record(endpoint, payload)['result']['prediction']
                 for endpoint in ("claim_frequency", "claim_severity", "pure_premium"))


def _predict_locally(record):
    """Score ``record`` with the flow models after running ``process_data`` on it.

    Returns:
        tuple: (claim frequency, claim severity, tweedie pure premium).
    """
    features = process_data(record)
    predictors = (claim_frequency_predictor, claim_severity_predictor, pure_premium_predictor)
    return tuple(make_prediction(predictor, features)['prediction'].iloc[0]
                 for predictor in predictors)


@app.callback(Output('result', 'children'),
              Input('vehicle-power', 'value'),
              Input('vehicle-age', 'value'),
              Input('driver-age', 'value'),
              Input('bonus-malus', 'value'),
              Input('vehicle-brand', 'value'),
              Input('vehicle-gas', 'value'),
              Input('density', 'value'),
              Input('region', 'value'))
def render_results(vehicle_power, vehicle_age,
                   driver_age, bonus_malus,
                   vehicle_brand, vehicle_gas,
                   density, region):
    """Recompute and render the four predictions whenever an input changes.

    Exposure is fixed to 1 (one year of risk) and ClaimNb to 1. Scoring goes
    to the API node when the ``use_api`` project variable is "True";
    otherwise it uses the models loaded from the flow.

    Returns:
        dbc.Col: one row each for claim frequency, claim severity, compound
        pure premium (frequency * severity) and tweedie pure premium.

    Raises:
        dash.exceptions.PreventUpdate: when any input is empty, which leaves
        the previous results on screen.
    """
    inputs = (vehicle_power, vehicle_age, driver_age, bonus_malus,
              vehicle_brand, vehicle_gas, density, region)
    # A cleared dropdown yields None. It would otherwise be sent to the API
    # node as the literal string "None", or reach the local preprocessing as
    # a missing value.
    if any(value is None for value in inputs):
        raise dash.exceptions.PreventUpdate

    record_to_predict = {
        "Exposure": 1,
        "ClaimNb": 1,
        "VehPower": vehicle_power,
        "VehAge": vehicle_age,
        "DrivAge": driver_age,
        "BonusMalus": bonus_malus,
        "VehBrand": vehicle_brand,
        "VehGas": vehicle_gas,
        "Density": density,
        "Region": region
    }
    if use_api == "True":
        frequency, severity, tweedie_premium = _predict_via_api(record_to_predict)
    else:
        frequency, severity, tweedie_premium = _predict_locally(record_to_predict)

    return dbc.Col([
        render_prediction('Claim Number Prediction', frequency, 'expected yearly number of claims'),
        render_prediction('Claim Amount Prediction', severity, 'expected claim amount per claim, in euros'),
        render_prediction('Compound Pure Premium Prediction', frequency * severity, 'expected yearly claim amount using the compound model, in euros'),
        render_prediction('Tweedie Pure Premium Prediction', tweedie_premium, 'expected yearly claim amount using the tweedie model, in euros')])


# build your Dash app: input widgets on the left, predictions and their
# explanation on the right.

# Introductory text shown under the page title.
_INTRO_TEXT = (
    "On this interactive screen, the user can modify the input parameters of the models and check how they react. "
    "The compound model and the tweedie model are compared side by side. "
    "The models can either be called with API calls if the models have been pushed on an API node or directly from the models deployed on the flow."
)

# Text of the explanation panel below the predictions.
_EXPLANATION_TEXT = (
    "For these four predictions, the model features are the ones set on the left-hand side and "
    "exposure is set to 1 to express risk over a one year period. "
    "The number of claims is predicted using the Claim Frequency model, "
    "the output is the expected number of claims for the given parameters over a year. "
    "The claim amount is using the Claim Severity model, it represents the expected amount of a given claim. "
    "The compound pure premium prediction is simply the product of expected number of claims and expected claim amount, "
    "the two above predictions. It represents the total risk in monetary value of the customer over a year. "
    "The tweedie pure premium prediction computes the same quantity as the compound pure premium, "
    "but using the Pure Premium Tweedie model which aims at directly modeling frequency times severity, "
    "instead of separately."
)

# Left column: every input widget, separated from the results by a light border.
# Style keys are camelCase, as Dash/React expect.
_inputs_column = dbc.Col([vehicle_power_slider,
                          vehicle_age_slider,
                          driver_age_slider,
                          bonus_malus_slider,
                          vehicle_brand_dropdown,
                          vehicle_gas_dropdown,
                          density_slider,
                          region_dropdown],
                         md=6,
                         style={'borderRight': '1px solid',
                                'borderRightColor': '#e3e4e4'})

# Right column: the 'result' container filled by the render_results callback,
# then the explanation panel.
_results_column = dbc.Col([
    dbc.Container(id='result'),
    html.Div([html.H6("Explanation", style={'fontSize': '15px'}),
              html.P(_EXPLANATION_TEXT, style={'fontSize': '12px'})],
             style={'display': 'block',
                    'background': '#e6eef2',
                    'padding': '20px',
                    'borderRadius': '5px',
                    'color': '#31708f'}),
])

app.layout = dbc.Container(
    dbc.Col([
        html.H3("Claim Modeling App", style={'marginBottom': '0em'}),
        html.P(_INTRO_TEXT,
               style={'color': '#8f8f8f', 'fontSize': '14px', 'marginTop': '0em'}),
        html.Hr(),
        dbc.Row([_inputs_column, _results_column]),
    ]),
    style={'fontFamily': 'Helvetica Neue'},
)

