import dataiku
import pandas as pd, json
from dataiku import pandasutils as pdu
import kafka_wrapper

# Project-level custom variables: the key of the project that hosts the
# deployed model, and the id of that model.
_custom_vars = dataiku.get_custom_variables()
PQC_project_key = _custom_vars["PQC_project_key"]
model_key = _custom_vars["trained_model_key"]

# Read recipe inputs: a native Kafka consumer bound to the streaming endpoint.
input_stream = dataiku.StreamingEndpoint("process-data-stream")
kafka_consumer = kafka_wrapper.get_native_kafka_consumer(input_stream)

# Fetch the deployed model, addressed as "<project_key>.<model_id>",
# and obtain a predictor for scoring incoming events.
deployed_model = dataiku.Model(".".join((PQC_project_key, model_key)))
predictor = deployed_model.get_predictor()


# Events reception.
# Only the FIRST event delivered by the consumer is scored per run: the loop
# breaks after one iteration, and the rest of the script writes that record.
# Each enriched event is collected as a one-row DataFrame in `frames`.
frames = []
for f_event in kafka_consumer:
    # Extract the event data (json.loads accepts both str and bytes payloads).
    print('Receiving event:')
    print(f_event.value)
    f_event_data = json.loads(f_event.value)
    df = pd.DataFrame.from_records([f_event_data])
    # Make the prediction on the single-row frame.
    pred = predictor.predict(df)
    # Add the prediction result to the event. Positional access (.iloc) is
    # used because label 0 is not guaranteed to exist in the prediction index.
    df['prediction'] = pred['prediction'].iloc[0]
    df['proba_0'] = pred['proba_0'].iloc[0]
    df['proba_1'] = pred['proba_1'].iloc[0]
    frames.append(df)
    break

if not frames:
    # The consumer ended without delivering any event; fail with a clear
    # message instead of a NameError on the unbound `df` further down.
    raise RuntimeError(
        "No event received from 'process-data-stream'; nothing to enrich or write."
    )

# Buffer of records to write. pd.concat replaces DataFrame.append, which was
# removed in pandas 2.0.
df_buff = pd.concat(frames)

# Output dataset. Setting "appendMode" on the spec item makes each write add
# rows instead of overwriting the existing content.
recipe_output = dataiku.Dataset("enriched-process-data")
recipe_output.spec_item["appendMode"] = True


def _to_dss_type(dtype):
    """Return the DSS column type name for a pandas dtype.

    Matching is done on the lower-cased dtype name so that pandas' nullable
    extension dtypes ("Int64", "Float64", "boolean") map like their numpy
    counterparts. Any dtype not recognised below (object, string, category,
    timedelta, ...) is stored as 'string', so the schema never carries a raw
    pandas dtype name.
    """
    dtype_name = str(dtype).lower()
    if 'date' in dtype_name:      # datetime64[ns], datetime64[ns, <tz>]
        return 'date'
    if 'bool' in dtype_name:      # bool, boolean -> previously leaked as 'bool'
        return 'boolean'
    if 'float' in dtype_name:     # float32, float64, Float64
        return 'double'
    if 'int' in dtype_name:       # int64, uint8, Int64, ...
        return 'bigint'
    return 'string'


# Setting the schema before writing: one DSS column per DataFrame column,
# in DataFrame column order.
dtypes = {name: _to_dss_type(dtype) for name, dtype in df_buff.dtypes.items()}
dtypes_formatted = [{"name": name, "type": dss_type} for name, dss_type in dtypes.items()]

print("dtypes_formatted:", dtypes_formatted)
print("df_buff cols:", df_buff.columns)

## set schema:
from dku_utils.core import get_current_project_and_variables
from dku_utils.datasets.dataset_commons import set_dataset_schema

project, variables = get_current_project_and_variables()

set_dataset_schema(project, "enriched-process-data", dtypes_formatted)

# Write the buffered rows through a continuous writer, then checkpoint it
# (with an empty state string) and log the writer's state.
with recipe_output.get_continuous_writer("enriched-process-data") as writer:
    writer.write_dataframe(df_buff)
    writer.checkpoint("")
    print("state:", writer.get_state())
    

# After every write, ask DSS to recompute the output dataset's record count
# so its metrics reflect the rows just appended.
client = dataiku.api_client()
project = client.get_default_project()
ds = project.get_dataset("enriched-process-data")

# Only the record-count metric is refreshed.
metrics_to_refresh = ['records:COUNT_RECORDS']
ds.compute_metrics(metric_ids=metrics_to_refresh)

# Drop the reference to the last event frame.
del df
