# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import io
import numpy as np
import torch
from datasets import Dataset

# Dataiku inputs: the managed folder that stores the serialized model, the
# rows to score, and the dataset listing the label names.
folder = dataiku.Folder("Ey4Ge7PK")
df = dataiku.Dataset("test").get_dataframe()
labels_df = dataiku.Dataset("labels").get_dataframe()

# Label names in sorted order; a prediction's argmax column index is looked
# up in this list. sorted() accepts the Series directly, so no list() copy.
# NOTE(review): this assumes the model's class indices follow the sorted
# order of `label_text` (e.g. labels were encoded by sorted position at
# training time) and that `label_text` has no duplicates — confirm against
# the training recipe.
labels = sorted(labels_df.label_text)

# Wrap the dataframe as a Hugging Face Dataset so the text column can be
# fed to the model.
test_ds = Dataset.from_pandas(df)

# Deserialize the trained model from the managed folder onto the CPU.
# The download stream is read fully into an in-memory buffer before
# torch.load, presumably because torch.load needs a seekable file object
# and the folder stream may not be seekable — TODO confirm.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# model files from a trusted source. Newer PyTorch releases default to
# weights_only=True, which may reject a fully pickled model object — check
# against the installed torch version.
cpu_device = torch.device('cpu')
with folder.get_download_stream("model.pt") as model_stream:
    model_buffer = io.BytesIO(model_stream.read())
    model = torch.load(model_buffer, map_location=cpu_device)

# Force the model body onto the CPU as well. This overrides a private
# attribute; presumably the body remembers the device it was trained on
# (e.g. a GPU) and map_location alone does not reset it — verify if the
# library version changes.
model.model_body._target_device = cpu_device

# Per-class probabilities for every test text, as a NumPy array.
result = model.predict_proba(test_ds["text"]).numpy()

# Attach one probability column per class (proba_0, proba_1, ...) plus the
# predicted label name to every row, then write the scored dataset.
#
# Columns are assigned whole and positionally instead of cell by cell via
# df.loc[index_label, ...]:
#   * label-based .loc writes hit every row sharing an index label, so a
#     duplicated index silently overwrote earlier rows' scores;
#   * per-cell .loc writes cost O(rows * classes) slow pandas calls.
n_classes = result.shape[1]
if len(labels) != n_classes:
    # A width mismatch means argmax positions cannot be mapped to label
    # names reliably; fail loudly instead of mislabelling or crashing later
    # with an opaque IndexError.
    raise ValueError(
        f"Model returned {n_classes} class probabilities but {len(labels)} "
        "labels were loaded; cannot map predictions to label names."
    )

# Cast to float64: the original row-by-row enlargement produced float64
# columns, so this keeps the written schema unchanged.
proba = result.astype(np.float64)
for j in range(n_classes):
    df[f"proba_{j}"] = proba[:, j]
df["prediction"] = [labels[k] for k in np.argmax(result, axis=1)]

dataiku.Dataset("test_scored_few-shot_setfit").write_with_schema(df)