# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
from datetime import datetime
from plotly_calplot import calplot
import plotly
from dataiku import insights
import plotly.graph_objs as go

# Load the scored events dataset into a pandas DataFrame
scored_events = dataiku.Dataset("last_events_scored")
df = scored_events.get_dataframe()

# Fetch the saved model and unwrap the fitted estimator behind its predictor.
# NOTE(review): `_clf` is a private attribute of the predictor, and `cph` is not
# referenced anywhere else in the visible part of this file - confirm it is needed.
model = dataiku.Model("FjRNGzZC")
predictor = model.get_predictor()
cph = predictor._clf.fitted_model

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read the reference end date from the project's local variables
project = dataiku.Project()
variables = project.get_variables()
local_variables = variables['local']
end_date = local_variables['end_date']

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Predicted failure date = reference end_date + predicted remaining life (in days),
# truncated to midnight so every failure predicted for the same day shares one timestamp
reference_date = pd.to_datetime(end_date)
remaining_life = pd.to_timedelta(df['prediction'], unit='d')
df['failure_time'] = (reference_date + remaining_life).dt.normalize()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Group by date and count the number of failures, and collect equipment IDs
def truncate_equipment_ids(equipment_ids, max_length=50):
    """Join equipment IDs into a single comma-separated label of bounded size.

    Args:
        equipment_ids: Iterable of identifiers; each one is converted with ``str``.
        max_length: Maximum number of characters kept from the joined string.

    Returns:
        The full ``", "``-joined string when it holds at most ``max_length``
        characters; otherwise its first ``max_length`` characters followed by
        ``"..."``. The cut is made per character, so it may fall in the middle
        of an identifier, and the returned text can be up to three characters
        longer than ``max_length``.
    """
    joined = ', '.join(str(equipment_id) for equipment_id in equipment_ids)
    if len(joined) <= max_length:
        return joined
    return joined[:max_length] + '...'

# One row per predicted failure day: how many equipments fail that day,
# plus a shortened, comma-separated list of their IDs
failures_by_day = df.groupby('failure_time')
daily_failures = failures_by_day.agg(
    failures=('equipment_id', 'size'),
    equipment_ids=('equipment_id', lambda ids: truncate_equipment_ids(list(ids))),
)
daily_failures = daily_failures.reset_index()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Detect the first and last predicted failure dates in the dataset.
# Dedicated names are used on purpose: `end_date` already holds the project
# variable read earlier, and rebinding it here would make the failure_time
# computation use the wrong reference date if that cell were re-run afterwards.
first_failure_date = daily_failures['failure_time'].min()
last_failure_date = daily_failures['failure_time'].max()

start_month = first_failure_date.month
end_month = last_failure_date.month

start_year = first_failure_date.year
end_year = last_failure_date.year

# Creating the plot: options shared by the single-year and multi-year layouts
calplot_options = dict(
    x="failure_time",
    y="failures",
    years_title=True,
    colorscale="Teal",
    month_lines_width=2,
    showscale=True,
    text="equipment_ids",
    name="Maintenance operations needed",
)

# Only narrow the displayed months when every failure falls within one year;
# when the data spans several years, calplot's default month range is used.
if start_year == end_year:
    calplot_options.update(start_month=start_month, end_month=end_month)

fig = calplot(daily_failures, **calplot_options)

for trace in fig.data:
    if not isinstance(trace, go.Heatmap):
        continue

    # Replace NaN values with 0 in the z data
    # (presumably the days that have no entry in daily_failures - TODO confirm)
    trace.z = np.nan_to_num(trace.z, nan=0)

    # Swap the colour stop at position 0 for light grey so that those
    # zero-valued cells are drawn in grey instead of the palette's first colour
    colorscale = [item for item in trace.colorscale if item[0] != 0]
    colorscale.insert(0, [0, '#eeeeef'])
    trace.colorscale = colorscale

# Render the figure as an HTML <div> and publish it as a Dataiku insight
graph_html = "<meta charset='UTF-8'>" + plotly.offline.plot(fig, output_type="div")
insights.save_data(id="Maintenance_schedule", payload=graph_html, content_type="text/html", label="Maintenance_schedule")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build the output table: reload the scored events (so this cell does not depend
# on earlier in-memory changes to df), keep the three reporting columns under
# display-friendly names, and list the equipments with the least remaining life first
df = (
    dataiku.Dataset("last_events_scored")
    .get_dataframe()[["equipment_id", "up_time", "prediction"]]
    .rename(columns={
        "equipment_id": "Equipment ID",
        "up_time": "Up Time",
        "prediction": "Remaining Useful Life",
    })
    .sort_values("Remaining Useful Life", ascending=True)
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Persist the result, letting the DataFrame define the output dataset's schema
output_dataset = dataiku.Dataset("remaining_useful_life")
output_dataset.write_with_schema(df)