# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import io
import logging
import dataiku
import numpy as np

import elasticsearch
from elasticsearch.helpers import scan

from project_utils import load, save, normalize, compute_embeddings

# Column names are configured as Dataiku custom variables; fetch them once.
custom_variables = dataiku.get_custom_variables()
id_label = custom_variables["id_label"]
text_label = custom_variables["text_label"]

# Load the documents (indexed by their id column) together with the
# precomputed embeddings and the ids stored alongside them.
embeddings_folder = dataiku.Folder("P4SttKJS")
df = dataiku.Dataset("data").get_dataframe().set_index(id_label)
corpus_embeddings = normalize(load(embeddings_folder, "embeddings.npy"))
corpus_ids = load(embeddings_folder, "ids.npy")

# Map each document id to its position in `corpus_ids`; that position is later
# used as the row index into `corpus_embeddings`.
idx2rank = {doc_id: rank for rank, doc_id in enumerate(corpus_ids)}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read the Elasticsearch credentials from the secrets attached to the current
# Dataiku user.
client = dataiku.api_client()
auth_info = client.get_auth_info(with_secrets=True)
# If the same key appears several times, the last occurrence wins.
user_secrets = {
    secret["key"]: secret["value"] for secret in auth_info["secrets"]
}

elastic_user = user_secrets.get("elastic_user")
elastic_password = user_secrets.get("elastic_password")
cloud_id = user_secrets.get("cloud_id")

# Fail explicitly (not with `assert`, which is stripped under `python -O`) and
# say which secrets are missing.
missing_secrets = [
    name
    for name, value in (
        ("elastic_user", elastic_user),
        ("elastic_password", elastic_password),
        ("cloud_id", cloud_id),
    )
    if value is None
]
if missing_secrets:
    raise RuntimeError(
        "Missing Dataiku user secret(s): " + ", ".join(missing_secrets)
    )

es = elasticsearch.Elasticsearch(
    cloud_id=cloud_id,
    basic_auth=(elastic_user, elastic_password),
    request_timeout=60,
    max_retries=5,
    retry_on_timeout=True,
)

index = 'semantic-search'

# Create the index on first use. Ask Elasticsearch about this one index only,
# rather than downloading the metadata of every index just to test membership.
if not es.indices.exists(index=index):
    mappings = {
        "properties": {
            # One vector per document, sized to the loaded embeddings.
            # NOTE(review): "dot_product" similarity expects unit-length
            # vectors; presumably that is why the embeddings go through
            # `normalize` when loaded -- confirm.
            "embedding": {
                "type": "dense_vector",
                "dims": corpus_embeddings.shape[1],
                "index": True,
                "similarity": "dot_product"
            },
            # Stored with the document but not searchable.
            "title": {
                "type": "text",
                "index": False
            },
            # The field name comes from the `text_label` custom variable.
            text_label: {
                "type": "text"
            },
            "organization": {
                "type": "keyword"
            },
            "date": {
                "type": "date"
            },
            "category": {
                "type": "keyword"
            }
        }
    }
    response = es.indices.create(
        index=index,
        mappings=mappings,
        wait_for_active_shards=0
    )

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Match every document so that the scan walks the whole index.
query = {"query": {"match_all": {}}}

# Ids currently present in the index, converted to ints.
indexed_ids = {
    int(hit["_id"])
    for hit in scan(es, index=index, query=query, stored_fields=[])
}

# Compare with the corpus: ids only in the index must be deleted, ids only in
# the corpus must be indexed.
corpus_ids = set(corpus_ids)
to_remove = indexed_ids - corpus_ids
to_index = list(corpus_ids - indexed_ids)

kept_count = len(indexed_ids) - len(to_remove)
logging.info(f"{len(to_index)} documents to index")
logging.info(f"{kept_count} indexed documents kept")
logging.info(f"{len(to_remove)} indexed documents discarded")

# Remove obsolete documents from the index (one request per document).
for idx in to_remove:
    es.delete(index=index, id=idx)

# Index the missing documents in batches through the bulk API.
batch_size = 1000
for start in range(0, len(to_index), batch_size):
    operations = []
    for idx in to_index[start:start + batch_size]:
        # Look the row up once instead of once per field.
        row = df.loc[idx]
        organization = row.organization
        # Each document is an action line followed by its source.
        operations.append({"create": {"_index": index, "_id": idx}})
        operations.append({
            text_label: row[text_label],
            'title': row.title,
            # NaN is the only value that differs from itself, so this yields
            # an empty list when the organization is missing.
            'organization': [organization] if organization == organization else [],
            # Keep the first 10 characters of the date's string form.
            'date': str(row.date)[:10],
            'category': row.category.split(", "),
            'embedding': list(corpus_embeddings[idx2rank[idx], :]),
        })
    resp = es.bulk(index=index, operations=operations)
    # The bulk API reports per-document failures in the response body instead
    # of raising, so they have to be checked explicitly.
    if resp["errors"]:
        failures = [
            item["create"]
            for item in resp["items"]
            if "error" in item.get("create", {})
        ]
        raise RuntimeError(
            f"Bulk indexing failed for {len(failures)} document(s) in the "
            f"batch starting at offset {start}; first failure(s): {failures[:1]}"
        )

# Make the deletions and insertions above visible to the count; without a
# refresh the count may still reflect the previous state of the index.
es.indices.refresh(index=index)

# Sanity check: the index must now hold exactly one document per corpus id.
# Raised explicitly because `assert` is stripped under `python -O`.
indexed_count = es.count(index=index)["count"]
if indexed_count != len(corpus_ids):
    raise RuntimeError(
        f"Index '{index}' holds {indexed_count} documents, "
        f"expected {len(corpus_ids)}"
    )