# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import os
import io
import pandas as pd
import numpy as np
import torch
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
from datasets import Dataset

# Load the training data and give every distinct label text an integer id.
df = dataiku.Dataset("train_small").get_dataframe()

# np.unique removes duplicates; sorted() turns the result into a plain list
# in ascending order, so ids are assigned in a reproducible order.
labels = sorted(np.unique(df["label_text"]))
label2label_id = {label_text: label_id for label_id, label_text in enumerate(labels)}

# Every value of the column is a key of the mapping (it was built from that
# same column), so a dict-based map assigns an id to each row.
df["label"] = df["label_text"].map(label2label_id)

# Write the id / text lookup table to the "labels" dataset.
labels_df = pd.DataFrame({
    "label": list(range(len(labels))),
    "label_text": labels,
})
dataiku.Dataset("labels").write_with_schema(labels_df)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Wrap the labelled dataframe in a `datasets.Dataset` for the trainer.
train_ds = Dataset.from_pandas(df)

# Handle on the Dataiku folder "Ey4Ge7PK"; the trained model is uploaded to it.
folder = dataiku.Folder("Ey4Ge7PK")

# Start from the pretrained sentence-transformers checkpoint.
base_checkpoint = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SetFitModel.from_pretrained(base_checkpoint)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Fine-tune the SetFit model on the training dataset using a cosine
# similarity loss.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=32,
    num_iterations=20,
    num_epochs=20,
)

trainer.train()

# Serialize the trained model in memory and upload the bytes as "model.pt".
# The context manager closes the buffer even if saving or uploading raises.
# NOTE(review): torch.save pickles the whole model object, so whoever loads
# "model.pt" presumably needs compatible torch / setfit versions - confirm.
with io.BytesIO() as buffer:
    torch.save(model, buffer)
    folder.upload_data("model.pt", buffer.getvalue())