import plotly.express as px
import dataiku
from dash import dcc
from dash import html
import pandas as pd
import numpy as np
from dataiku import insights
from scenario_computation import run_what_if_exploration
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go

import math
import re
from decimal import Decimal
from styles import get_description_style, get_button_style, get_card_style, get_main_style
import math

# Suffixes for thousands, millions and billions used by ``millify``.
millnames = ['',' k',' M',' B']
# Colour cycle for the stacked "Total marketing budget" chart.
color_discrete_sequence=[
    "#875EAF",
    "#82BCE6",
    "#D1A9DA",
    "#23A373",
    "#3A69DA",
    "#8ABB4C",
    "#FBCD22",
    "#E96709",
    "#F8A217",
    "#C6302A"    
]


def millify(n):
    """Format ``n`` as a rounded whole number with a k/M/B suffix.

    The largest suffix whose rounded value stays below 1000 is chosen, so
    values that round up into the next unit are promoted (999_999 -> "1 M",
    not "1000 k"). Values beyond the billions keep the "B" suffix.
    Non-finite values (inf/nan) are formatted without a suffix instead of
    raising.

    >>> millify(1234567)
    '1 M'
    >>> millify(999999)
    '1 M'
    """
    n = float(n)
    if not math.isfinite(n):
        # math.floor/log10 would raise on inf/nan; show them as-is.
        return '{:.0f}'.format(n)

    last = len(millnames) - 1
    millidx = 0 if n == 0 else int(math.floor(math.log10(abs(n)) / 3))
    millidx = max(0, min(last, millidx))

    # Rounding to an integer can carry into the next unit (999.999 k -> "1000 k");
    # promote to the next suffix when that happens and one is available.
    if millidx < last and abs(round(n / 10 ** (3 * millidx))) >= 1000:
        millidx += 1

    return '{:.0f}{}'.format(n / 10 ** (3 * millidx), millnames[millidx])

# NOTE(review): `app` is not defined in this file; presumably it is the Dash
# app object injected by the Dataiku webapp runtime — confirm. Registers the
# Bootstrap theme stylesheet for the dash_bootstrap_components used below.
app.config.external_stylesheets = [dbc.themes.BOOTSTRAP]


def slider_marks(min_value, max_value):
    """Return RangeSlider marks for the two bounds (and 0 when strictly inside).

    Labels are signed percentages: negative values keep their "-" sign,
    positive values get a "+" and zero is plain "0%".

    >>> slider_marks(-100, 100)
    {-100: '-100%', 0: '0%', 100: '+100%'}
    """
    def label(value):
        return "0%" if value == 0 else f"{value:+g}%"

    points = [min_value]
    if min_value < 0 < max_value:
        points.append(0)
    points.append(max_value)
    return {point: label(point) for point in points}


def slider(min_value, max_value, default, slider_id, title, description):
    """Build a titled percentage RangeSlider row for one media channel.

    Parameters
    ----------
    min_value, max_value : number
        Bounds of the slider; the marks and tooltip show them as percentages.
    default : list
        Initial ``[low, high]`` selection of the range slider.
    slider_id : str
        Dash component id of the RangeSlider.
    title : str
        Heading shown above the slider.
    description : str
        Currently unused (its ``html.P`` rendering is disabled); kept so
        existing callers keep working.

    Returns
    -------
    dbc.Row
        A row containing the heading and the slider.
    """
    return dbc.Row(
        [
            dbc.Col(
                [
                    dbc.Row([html.H6(title, style={'margin-bottom': '1em'})]),
                    dbc.Row(
                        dcc.RangeSlider(
                            min_value,
                            max_value,
                            value=default,
                            id=slider_id,
                            tooltip={"placement": "bottom", "always_visible": True, "template": "{value} %"},
                            # Derived from the bounds instead of a fixed -100/0/+100,
                            # so marks never fall outside the slider's range.
                            marks=slider_marks(min_value, max_value),
                        )
                    ),
                ]
            )
        ],
        style={'margin-bottom': '1em'},
    )

project = dataiku.api_client().get_default_project()
# Fetch the project variables once: one API round-trip, and every value read
# below comes from the same snapshot.
variables = project.get_variables()
MEDIA_CHANNELS_SELECTED = variables.get('standard').get("media_channels_selected")

# Spend bounds from the project variables; the slider defaults scale them by 100.
spend_constraint_lower = variables['standard']["spend_constraint_lower"]
spend_constraint_upper = variables['standard']["spend_constraint_upper"]

    
# Placeholder KPI figure used by every indicator graph until the update
# callback replaces it: the value 0 with a relative delta against 0.
default_indicator = go.Figure(
    go.Indicator(
        mode="number+delta",
        value=0,
        delta={"reference": 0, "relative": True, "valueformat": "2.2%"},
    )
).update_layout(margin={"l": 0, "r": 0, "t": 0, "b": 10}, height=50)

# Help text shown at the top of the left-hand refinement panel.
panel_text = (
    "Use the sliders or type in values to refine media channel inputs. "
    "The default values are based on your optimal calculated value."
)


# One percentage RangeSlider per selected media channel, initialised from the
# project's spend constraints (converted to percent).
sliders = [
    slider(
        slider_id=channel,
        title=channel,
        description=f'{channel} slider',
        min_value=-100,
        max_value=100,
        default=[-spend_constraint_lower * 100, spend_constraint_upper * 100],
    )
    for channel in MEDIA_CHANNELS_SELECTED
]
rerun_button = html.Button('Apply Changes', id='button', style=get_button_style())

# Left-hand panel: heading, help text, then every channel slider.
panel = [
    html.H4("Mixed media refinement", style={"font-weight": "400", "font-size": "18px"}),
    html.P(panel_text, style={'font-size': "0.8rem"}),
    *sliders,
]

def _hidden_number_input(label, input_id, value, container_id):
    """Return a Div holding a labelled numeric input, hidden via ``display: none``.

    The constraint-selection callback toggles the Div's style to show it.
    """
    return html.Div(
        dbc.InputGroup(
            [
                dbc.InputGroupText(label),
                dcc.Input(id=input_id, type="number", value=value),
            ],
        ),
        id=container_id,
        style={'display': 'none'},
    )


# Constraint selector: a radio group ("none" by default) plus one numeric
# input per constraint type; both inputs start hidden.
combined_elements = html.Div([
    dbc.Label("Constraint"),
    dbc.RadioItems(
        id="constraint-radio-selection",
        options=[
            {"label": label, "value": value}
            for label, value in (
                ("None", "none"),
                ("Total Budget", "budget"),
                ("Target ROI", "roi"),
            )
        ],
        value="none",
        className="mb-3",
    ),
    _hidden_number_input(
        "Total Budget",
        "input_budget",
        variables['standard']["total_current_budget"],
        "budget-input-container",
    ),
    _hidden_number_input("Target ROI", "target_roi", 1.1, "roi-input-container"),
])

# Callback: show only the input container that matches the selected constraint.
@app.callback(
    Output('budget-input-container', 'style'),
    Output('roi-input-container', 'style'),
    Input('constraint-radio-selection', 'value')
)
def toggle_input_visibility(selection):
    """Return the (budget container, ROI container) style dicts for ``selection``.

    The container matching the selection ("budget" or "roi") is displayed as a
    block with a 15px bottom margin; the other is hidden. Any other value
    (e.g. "none") hides both.
    """
    def container_style(visible):
        if visible:
            return {'display': 'block', 'marginBottom': '15px'}
        return {'display': 'none'}

    return container_style(selection == 'budget'), container_style(selection == 'roi')


    
# Finish the left-hand panel: constraint selector below the sliders, then the
# "Apply Changes" button.
panel.extend([combined_elements, rerun_button])



# Tooltips bound (through their ``target`` ids) to the scenario headings of
# the sales card.
current_tooltip = dbc.Tooltip(
    "Current: Marketing spending and predicted sales according to the average spending over the historical period.",
    target="current-title-sales",
)
reference_tooltip = dbc.Tooltip(
    "Reference: Marketing spending and predicted sales according to the optimized scenario with fixed exploration bounds.",
    target="reference-title-sales",
)
custom_tooltip = dbc.Tooltip(
    "Custom: Marketing spending and predicted sales according to the optimized scenario with exploration bounds defined by the sliders in the webapp.",
    target="custom-title-sales",
)
main_size = "22px"  # font size of each card title
sub_size = "12px"   # font size of the Current/Reference/Custom headings


def _kpi_column(label, graph_id, title_id=None, tooltip=None):
    """Return an md=4 column: small heading, optional tooltip, placeholder indicator graph.

    ``title_id`` is only set on the heading when given, so headings without a
    tooltip carry no id.
    """
    heading_kwargs = {'style': {'font-size': sub_size}}
    if title_id is not None:
        heading_kwargs['id'] = title_id
    children = [html.H6(label, **heading_kwargs)]
    if tooltip is not None:
        children.append(tooltip)
    children.append(dcc.Graph(figure=default_indicator, id=graph_id))
    return dbc.Col(children, md=4)


def _kpi_card(title, columns):
    """Wrap a title and a row of KPI columns in a styled card."""
    body = dbc.CardBody([
        html.H4(title, style={'font-size': main_size}),
        dbc.Row(columns),
    ])
    return dbc.Card(body, style=get_card_style())


previous_sales_card = _kpi_card(
    "Total Sales Prediction",
    [
        _kpi_column("Current", "previous_sales_change",
                    title_id="current-title-sales", tooltip=current_tooltip),
        _kpi_column("Reference", "optimal_sales_change",
                    title_id="reference-title-sales", tooltip=reference_tooltip),
        _kpi_column("Custom", "custom_sales_change",
                    title_id="custom-title-sales", tooltip=custom_tooltip),
    ],
)
previous_budget_card = _kpi_card(
    "Total Marketing Budget",
    [
        _kpi_column("Current", "previous_spending_change"),
        _kpi_column("Reference", "optimal_spending_change"),
        _kpi_column("Custom", "custom_spending_change"),
    ],
)
roi_card = _kpi_card(
    "Total Marketing ROI",
    [
        _kpi_column("Current", "previous_roi_change"),
        _kpi_column("Reference", "optimal_roi_change"),
        _kpi_column("Custom", "custom_roi_change"),
    ],
)

def _kpi_indicator(value, reference):
    """Return a compact KPI figure showing ``value`` and its relative change vs ``reference``.

    The delta is formatted as a percentage ("2.2%"); the figure is 50px tall
    with no margins.
    """
    figure = go.Figure(
        go.Indicator(
            mode="number+delta",
            value=value,
            delta={"reference": reference, "relative": True, "valueformat": "2.2%", 'font': {'size': 12}},
            number={'font': {'size': 15}},
        )
    )
    figure.update_layout(margin=dict(l=0, r=0, t=0, b=0), height=50)
    return figure


def _budget_figures(df, current_total, reference_total, custom_total):
    """Return ``(stacked_totals_figure, share_by_channel_figure)``.

    ``df`` is the frame from ``run_what_if_exploration``; this reads its
    "Media Channels" column and the "Current/Reference/Custom Budget"
    columns. The ``*_total`` arguments are the per-scenario budget sums used to
    express each channel's budget as a percentage share.
    """
    # Wide frame for the stacked bar: one row per scenario, one column per
    # channel. Selecting the budget columns before transposing avoids the
    # object dtype a transpose of mixed-type columns can produce.
    totals = df.set_index('Media Channels')[["Current Budget", "Reference Budget", "Custom Budget"]].T
    totals = totals.rename(index={"Current Budget": "Current", "Custom Budget": "Custom", "Reference Budget": "Reference"})

    fig_stack = px.bar(
        totals,
        template="plotly_white",
        labels={"index": "", "value": "Total Budget"},
        color_discrete_sequence=color_discrete_sequence,
    )
    fig_stack.update_layout(
        title={'text': "Total marketing budget", 'x': 0.5, 'xanchor': 'center'},
        legend={"orientation": "h", "xanchor": "center", "x": 0.5, "y": -0.1},
    )

    # Long frame for the grouped bar: each channel's share (%) of its scenario total.
    shares = pd.DataFrame({
        "Media Channels": df["Media Channels"],
        "Current": df["Current Budget"] / current_total * 100,
        "Reference": df["Reference Budget"] / reference_total * 100,
        "Custom": df["Custom Budget"] / custom_total * 100,
    })
    long_shares = shares.melt(
        id_vars="Media Channels",
        value_vars=["Current", "Reference", "Custom"],
        var_name="Scenario",
        value_name="Budget Share",
    )
    long_shares["Budget Share"] = long_shares["Budget Share"].astype(float)

    fig_side = px.bar(
        long_shares,
        x="Media Channels",
        y="Budget Share",
        color="Scenario",
        barmode="group",
        template="plotly_white",
        labels={"Media Channels": "", "Budget Share": "Share of total budget %"},
        color_discrete_sequence=["black", "skyblue", "blue"],
    )
    fig_side.update_layout(
        title={'text': "Budget share by media channel", 'x': 0.5, 'xanchor': 'center'},
        legend={"orientation": "h", "xanchor": "center", "x": 0.5, "y": -0.1},
    )
    return fig_stack, fig_side


@app.callback(
    Output(component_id='left-graph', component_property='figure'),
    Output(component_id='right-graph', component_property='figure'),
    Output(component_id='previous_sales_change', component_property='figure'),
    Output(component_id='custom_sales_change', component_property='figure'),
    Output(component_id='optimal_sales_change', component_property='figure'),
    Output(component_id='previous_spending_change', component_property='figure'),
    Output(component_id='custom_spending_change', component_property='figure'),
    Output(component_id='optimal_spending_change', component_property='figure'),
    Output(component_id='previous_roi_change', component_property='figure'),
    Output(component_id='custom_roi_change', component_property='figure'),
    Output(component_id='optimal_roi_change', component_property='figure'),
    Input('button', 'n_clicks'),
    [State(component_id=media, component_property='value') for media in MEDIA_CHANNELS_SELECTED],
    State(component_id="input_budget", component_property='value'),
    State(component_id="target_roi", component_property='value'),
    State(component_id="constraint-radio-selection", component_property='value')
)
def update_output_div(n_clicks, *arg):
    """Re-run the what-if exploration and rebuild every chart and KPI figure.

    Parameters
    ----------
    n_clicks : int or None
        Click count of the "Apply Changes" button; only used as the trigger.
    *arg
        The value of each media-channel RangeSlider (in
        ``MEDIA_CHANNELS_SELECTED`` order), followed by the total-budget input,
        the target-ROI input and the selected constraint type.

    Returns
    -------
    tuple
        Eleven plotly figures, in the same order as the ``Output`` list above.
    """
    optimization_args = {
        "ranges": arg[:-3],
        "target_roi": arg[-2],
        "new_budget": arg[-3],
        "constraint_type": arg[-1],
    }

    # Persist the arguments on the project. Re-read the variables right
    # before writing: writing back the snapshot taken when the webapp started
    # would silently revert any project-variable change made since then.
    latest_variables = project.get_variables()
    latest_variables['standard']["optimization_args"] = optimization_args
    project.set_variables(latest_variables)
    # Keep the in-process copy in sync with what was persisted.
    variables['standard']["optimization_args"] = optimization_args

    df = run_what_if_exploration(
        ranges=optimization_args["ranges"],
        constraint_type=optimization_args["constraint_type"],
        target_roi=optimization_args["target_roi"],
        new_budget=optimization_args["new_budget"],
    )

    # Sales columns are read from their first row; the code treats each as one
    # scenario-wide value (assumption — confirm with run_what_if_exploration).
    custom_sales = df["Custom Sales"].iloc[0]
    reference_sales = df["Reference Sales"].iloc[0]
    current_sales = df["Current Sales"].iloc[0]

    # Budgets are per channel; sum them for the scenario totals.
    current_budget = df["Current Budget"].sum()
    custom_budget = df["Custom Budget"].sum()
    reference_budget = df["Reference Budget"].sum()

    fig_stack, fig_side = _budget_figures(df, current_budget, reference_budget, custom_budget)

    # All deltas are relative to the "Current" scenario.
    current_roi = current_sales / current_budget
    return (
        fig_stack,
        fig_side,
        _kpi_indicator(current_sales, current_sales),
        _kpi_indicator(custom_sales, current_sales),
        _kpi_indicator(reference_sales, current_sales),
        _kpi_indicator(current_budget, current_budget),
        _kpi_indicator(custom_budget, current_budget),
        _kpi_indicator(reference_budget, current_budget),
        _kpi_indicator(current_roi, current_roi),
        _kpi_indicator(custom_sales / custom_budget, current_roi),
        _kpi_indicator(reference_sales / reference_budget, current_roi),
    )

def _build_layout():
    """Assemble the page: a divider row, then the control panel beside the KPI cards and charts."""
    header = dbc.Row([html.Hr()])

    # Narrow left column holding the sliders, constraint inputs and button.
    control_column = dbc.Col(
        panel,
        md=2,
        style={'border-right': '1px solid',
               'border-right-color': '#e3e4e4'},
    )

    kpi_row = dbc.Row([
        dbc.Col(card, md=4)
        for card in (previous_sales_card, previous_budget_card, roi_card)
    ])

    graphs_row = dbc.Row([
        dbc.Col(dcc.Graph(id='left-graph'), md=6),
        dbc.Col(dcc.Graph(id='right-graph'), md=6),
    ])
    # dcc.Loading shows a circular spinner over the charts while they update.
    charts_row = dbc.Row([
        dcc.Loading(
            [graphs_row],
            style={'display': 'block',
                   'background': '#e6eef2',
                   'padding': '20px',
                   'border-radius': '5px',
                   'color': '#31708f'},
            type="circle",
        )
    ])

    results_column = dbc.Col([kpi_row, charts_row], md=10)

    return dbc.Container(
        [header, dbc.Row([control_column, results_column])],
        style=get_main_style(),
        fluid=True,
    )


app.layout = _build_layout()
