# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import numpy as np
import pandas as pd
import torch
from transformers import pipeline

# Sentence handed to the pipeline as `hypothesis_template`; the `{}`
# placeholder is where a candidate label goes.
HYPOTHESIS_TEMPLATE = "This product review is about {}."
# Number of texts sent to the pipeline per call.
BATCH_SIZE = 8

# Pipeline device index: 0 when CUDA is available, -1 otherwise.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
pipe = pipeline(model="facebook/bart-large-mnli", device=device)

# Load the "test" dataset; the candidate labels are the distinct values of
# its label_text column, in sorted order.
df = dataiku.Dataset("test").get_dataframe()
labels = sorted(np.unique(df["label_text"]))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Score every text in df against every candidate label, BATCH_SIZE texts per
# pipeline call. `scores` ends up with one row per text (in df row order) and
# one column per entry of `labels`.
score_rows = []
for start in range(0, len(df), BATCH_SIZE):
    # Positional slice, so the rows picked do not depend on df's index
    # labels; slicing past the end simply yields a shorter final batch.
    texts = df["text"].iloc[start:start + BATCH_SIZE].tolist()
    outputs = pipe(
        texts,
        candidate_labels=labels,
        hypothesis_template=HYPOTHESIS_TEMPLATE,
    )
    # The pipeline is documented as returning either a dict or a list of
    # dicts; normalise so a one-text batch is never iterated key by key.
    if isinstance(outputs, dict):
        outputs = [outputs]
    for output in outputs:
        # Reorder this text's scores by sorting on the returned label names,
        # which lines the columns up with the sorted `labels` list.
        order = np.argsort(output["labels"])
        score_rows.append(np.asarray(output["scores"])[order])

# Stack once at the end instead of re-copying the whole matrix for every
# text; keep the (0, n_labels) shape when there was nothing to score.
scores = np.vstack(score_rows) if score_rows else np.zeros((0, len(labels)))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Predicted label for each row: the candidate with the highest score.
best = np.argmax(scores, axis=1)
df["prediction"] = [labels[k] for k in best]

# One proba_<k> column per candidate label, where k is the label's position
# in `labels`. The frame reuses df's index so the column-wise concat pairs
# rows by position even if df does not carry a default RangeIndex.
proba_columns = [f"proba_{k}" for k in range(len(labels))]
scores_df = pd.DataFrame(scores, columns=proba_columns, index=df.index)
df = pd.concat((df, scores_df), axis=1)
dataiku.Dataset("test_scored_zero-shot_nli").write_with_schema(df)