import io
import numpy as np
import torch

def mean_pooling(model_output, attention_mask):
    """
    Average token embeddings over the unmasked positions of each sequence.

    ``model_output[0]`` is taken as the per-token embeddings. The mask is
    unsqueezed and expanded to that tensor's size, which implies it has
    shape (batch, seq_len) with embeddings of shape (batch, seq_len, hidden).
    That shape should be confirmed against the caller. Mask values act as
    per-token weights: with a 0/1 mask, padding positions are excluded from
    both the sum and the count. The count is clamped to at least 1e-9 so a
    fully masked row yields zeros instead of dividing by zero.
    """
    embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def compute_embeddings(model, tokenizer, sentences):
    """
    Compute mean-pooled embeddings for a list of strings.

    Args:
        model: model whose output's first element holds per-token embeddings
            (see ``mean_pooling``). If it exposes a ``device`` attribute, the
            tokenized inputs are moved there before the forward pass.
        tokenizer: callable accepting ``padding``, ``truncation`` and
            ``return_tensors='pt'`` and returning a mapping of tensors that
            includes ``'attention_mask'``.
        sentences: list of strings to embed.

    Returns:
        numpy.ndarray with one pooled embedding per sentence.
    """
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    device = getattr(model, 'device', None)
    if device is not None:
        # Inputs must be on the same device as the model weights, otherwise
        # the forward pass fails for a model placed on a GPU.
        encoded_input = {key: value.to(device) for key, value in encoded_input.items()}
    with torch.no_grad():
        model_output = model(**encoded_input)
        embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # NumPy cannot read GPU memory; .cpu() is a no-op for tensors already on CPU.
    return embeddings.cpu().numpy()

def save(folder, filename, arr):
    """
    Save a numpy array in a Dataiku folder, serialized with ``np.save``.

    Args:
        folder: folder object; only its ``upload_data(filename, data)`` method is used.
        filename: destination path inside the folder.
        arr: array (or array-like) to serialize.
    """
    # The context manager releases the buffer even if the upload raises.
    with io.BytesIO() as buf:
        np.save(buf, arr)
        folder.upload_data(filename, buf.getvalue())

def load(folder, filename):
    """
    Load a numpy array from a Dataiku folder.

    The download stream is read fully into memory and closed, then the bytes
    are deserialized with ``np.load`` through an in-memory buffer.
    """
    with folder.get_download_stream(filename) as stream:
        payload = stream.read()
    return np.load(io.BytesIO(payload))

def normalize(arr, eps=1e-6, axis=1):
    """
    L2-normalize the vectors of an array along ``axis`` (rows of a matrix by default).

    ``eps`` is added to every norm before dividing. All-zero vectors stay
    zero instead of becoming NaN. As a side effect, a vector with norm ``n``
    comes out with norm ``n / (n + eps)``, i.e. very slightly below 1.

    Args:
        arr: array to normalize; with the default ``axis=1`` it needs at least
            two dimensions.
        eps: value added to each norm (default 1e-6, the original constant).
        axis: axis holding the vector components; pass ``-1`` (or ``0``) to
            normalize a single 1-D vector.

    Returns:
        Array with the same shape as ``arr``.
    """
    norms = np.linalg.norm(arr, ord=2, axis=axis, keepdims=True)
    return arr / (eps + norms)