# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
import tensorflow as tf
import tensorflow_probability as tfp
import arviz as az
import io
import pickle
from dataiku import insights

from meridian import constants
from meridian.data import load
from meridian.data import test_utils
from meridian.model import model
from meridian.model import spec
from meridian.model import prior_distribution
from meridian.analysis import optimizer
from meridian.analysis import analyzer
from meridian.analysis import visualizer
from meridian.analysis import summarizer
from meridian.analysis import formatter

# Handle on the current Dataiku project and a snapshot of its variables.
# `variables` is modified further down and written back with project.set_variables().
_dss_client = dataiku.api_client()
project = _dss_client.get_default_project()
variables = project.get_variables()



# Read recipe inputs
geo_all_channels = dataiku.Dataset("input_prepared")
geo_all_channels_df = geo_all_channels.get_dataframe()

# Paid-media input columns and the Meridian channel name each one maps to.
# These dicts are the single source of truth for the `media` / `media_spend`
# column lists passed to CoordToColumns below, so the column lists and the
# channel mappings cannot drift apart (dict key order is insertion order).
correct_media_to_channel = {
    'Channel0_impression': 'Channel_0',
    'Channel1_impression': 'Channel_1',
    'Channel2_impression': 'Channel_2',
    'Channel3_impression': 'Channel_3',
    'Channel4_impression': 'Channel_4',
}
correct_media_spend_to_channel = {
    'Channel0_spend': 'Channel_0',
    'Channel1_spend': 'Channel_1',
    'Channel2_spend': 'Channel_2',
    'Channel3_spend': 'Channel_3',
    'Channel4_spend': 'Channel_4',
}

# Map each Meridian coordinate to the column(s) of the prepared dataset.
coord_to_columns = load.CoordToColumns(
    time='time',
    geo='geo',
    controls=['GQV', 'Competitor_Sales'],
    population='population',
    kpi='conversions',
    revenue_per_kpi='revenue_per_conversion',
    media=list(correct_media_to_channel),
    media_spend=list(correct_media_spend_to_channel),
    organic_media=['Organic_channel0_impression'],
    non_media_treatments=['Promo'],
)

# Build Meridian's input data from the in-memory dataframe. The KPI is
# conversions (not revenue), hence kpi_type='non_revenue'.
loader = load.DataFrameDataLoader(
    df=geo_all_channels_df,
    kpi_type='non_revenue',
    coord_to_columns=coord_to_columns,
    media_to_channel=correct_media_to_channel,
    media_spend_to_channel=correct_media_spend_to_channel,
)

data = loader.load()

# Model configuration: one LogNormal ROI prior shared by every paid media channel.
roi_mu = 0.2     # Mu for ROI prior for each media channel.
roi_sigma = 0.9  # Sigma for ROI prior for each media channel.
prior = prior_distribution.PriorDistribution(
    roi_m=tfp.distributions.LogNormal(roi_mu, roi_sigma, name=constants.ROI_M)
)
model_spec = spec.ModelSpec(prior=prior)

mmm = model.Meridian(input_data=data, model_spec=model_spec)

# Seed prior sampling as well as posterior sampling so that repeated recipe
# runs on the same data produce the same draws.
sampling_seed = 1
mmm.sample_prior(500, seed=sampling_seed)
mmm.sample_posterior(
    n_chains=7, n_adapt=500, n_burnin=500, n_keep=1000, seed=sampling_seed
)

# HTML model-fit summary over the evaluation window set in the project variables.
# NOTE(review): `_gen_model_results_summary` is a private (underscore) Summarizer
# method, so its signature may change between Meridian releases.
mmm_summarizer = summarizer.Summarizer(mmm)
model_summary = mmm_summarizer._gen_model_results_summary(
    variables['standard']["min_date_eval_app"],
    variables['standard']["max_date_eval_app"],
)
encoded_summary = model_summary.encode('utf-8')

# Write recipe outputs
meridian_data = dataiku.Folder("ajsbTNVY")

with meridian_data.get_writer("model_summary_meridian_data.html") as w:
    w.write(encoded_summary)

# Persist the fitted model. pickle.dumps produces the same bytes as dumping into
# a BytesIO, without a buffer that would be left open if pickling raised.
# NOTE: only unpickle model.pkl from this trusted folder -- loading a pickle can
# execute arbitrary code.
with meridian_data.get_writer("model.pkl") as w:
    w.write(pickle.dumps(mmm))
# Budget optimisation over the evaluation window, bounded by the configured
# spend constraints. Results are stored back into the project variables.
std_vars = variables['standard']  # alias: writes below mutate `variables` in place

budget_optimizer = optimizer.BudgetOptimizer(mmm)
optimization_results = budget_optimizer.optimize(
    # (start, end) pair of the evaluation window.
    selected_times=(std_vars["min_date_eval_app"], std_vars["max_date_eval_app"]),
    spend_constraint_lower=std_vars["spend_constraint_lower"],
    spend_constraint_upper=std_vars["spend_constraint_upper"],
    batch_size=500,
    gtol=0.001,
)

optimized = optimization_results.optimized_data
current = optimization_results.nonoptimized_data

# Convert to plain Python int/str so the values can be stored as project variables.
std_vars["opt_target"] = int(optimized.total_incremental_outcome)
std_vars["opt_spend"] = [int(s) for s in optimized.spend.values]

std_vars["current_spend"] = [int(s) for s in current.spend.values]
std_vars["current_sales"] = int(current.total_incremental_outcome)

std_vars["total_current_budget"] = int(sum(current.spend.values))
std_vars["media_channels_selected"] = [str(c) for c in current.channel.values]

# One (lower, upper) range per channel the optimizer reported. The *100 turns the
# constraints into percentages (assumes they are fractions, e.g. 0.3 -> 30 --
# confirm against the variable definitions).
lower_pct = -std_vars["spend_constraint_lower"] * 100
upper_pct = std_vars["spend_constraint_upper"] * 100
optimization_args = {
    "ranges": [(lower_pct, upper_pct) for _ in std_vars["media_channels_selected"]],
    "target_roi": None,
    "new_budget": None,
    "constraint_type": "none",
}

std_vars["optimization_args"] = optimization_args

project.set_variables(variables)



# HTML optimisation report, written to the output folder.
# NOTE(review): `_gen_optimization_summary` is a private (underscore) Meridian
# method and may change between releases.
optim_file = optimization_results._gen_optimization_summary()

bb = optim_file.encode('utf-8')
with meridian_data.get_writer("optim_file_meridian_data.html") as w:
    w.write(bb)

# Publish both HTML reports as static dashboard insights. save_data documents its
# payload as bytes-oriented data (or a base64 string with encoding="base64"), so
# pass the UTF-8 encoded HTML rather than the str.
for insight_id, html_bytes in (
    ("model_summary", encoded_summary),
    ("optim_file", bb),
):
    insights.save_data(
        insight_id,
        payload=html_bytes,
        content_type="text/html",
        label=None,
        project_key=None,
        encoding=None,
    )