# -*- coding: utf-8 -*-
import os
import logging
from tempfile import NamedTemporaryFile
from io import BytesIO
from time import time

import pandas as pd
import numpy as np
from deepspeech import Model
from scipy.io import wavfile
from scipy.signal import resample
from tqdm import tqdm

import dataiku
from dataiku.customrecipe import get_recipe_config, get_input_names_for_role, get_output_names_for_role

from dku_io_utils import generate_path_list

# Recipe configuration, used further down to tune the DeepSpeech decoder
config = get_recipe_config()

# Read recipe inputs: one managed folder holds the audio files ("audios" role),
# another holds the model weights ("weights" role)
audio_folder_name = get_input_names_for_role("audios")[0]
audio_folder = dataiku.Folder(audio_folder_name)
weights_folder_name = get_input_names_for_role("weights")[0]
model_folder = dataiku.Folder(weights_folder_name)

# The folder type decides below whether the model is loaded straight from a
# path or downloaded to a temporary file first
model_folder_info = model_folder.get_info()
model_folder_type = model_folder_info.get("type")

# Model configuration
# Load the DeepSpeech model file ("model.pbmm") from the weights folder, then
# apply the optional decoder settings found in the recipe configuration.
logging.info("Loading DeepSpeech model from {}...".format(model_folder_type))
start = time()
model_file = "model.pbmm"
if model_folder_type == "Filesystem":
    # Local folder: the model can be loaded directly from its path on disk.
    model = Model(os.path.join(model_folder.get_path(), model_file))
else:
    # Any other folder type: Model() is given a file path, so the model is
    # first copied from the download stream into a local temporary file.
    with model_folder.get_download_stream(model_file) as stream:
        with NamedTemporaryFile(suffix=model_file) as temp:
            temp.write(stream.read())
            # The file is reopened by name below, so the bytes still held in
            # Python's write buffer must reach the disk first; without this
            # flush the model file can be read truncated.
            temp.flush()
            # NOTE(review): the temporary file is deleted when this block
            # exits; presumably the loaded model does not need the path
            # afterwards - confirm against the DeepSpeech version in use.
            model = Model(temp.name)
# Optional decoder tuning: only applied when the keys are present in the config.
if "BEAM_WIDTH" in config:
    model.setBeamWidth(config["BEAM_WIDTH"])
# Both scorer weights are set through a single call, so both must be provided.
if "WORD_COUNT_WEIGHT" in config and "LM_WEIGHT" in config:
    model.setScorerAlphaBeta(config["LM_WEIGHT"], config["WORD_COUNT_WEIGHT"])
# Sample rate reported by the model; audio with another rate is resampled to it.
deepspeech_samplerate = model.sampleRate()
logging.info("Loading DeepSpeech model: Done in {:.2f} seconds.".format(time() - start))

# Apply model to transcribe audio files
INT16_MIN = np.iinfo(np.int16).min
INT16_MAX = np.iinfo(np.int16).max


def _clip_to_int16(samples):
    """Round ``samples`` and saturate them to the int16 range before casting.

    A bare ``astype(np.int16)`` wraps out-of-range values around, so a small
    overshoot would otherwise flip sign instead of being clamped.
    """
    return np.clip(np.rint(samples), INT16_MIN, INT16_MAX).astype(np.int16)


def _to_int16_pcm(samples):
    """Rescale WAV samples to 16-bit PCM according to their dtype.

    Covers the dtypes documented for ``scipy.io.wavfile.read``: int16 is
    returned untouched, uint8 is re-centred on zero and scaled up, int32 is
    scaled down by 2**16, and floating-point data (nominally within
    [-1.0, 1.0]) is scaled up to full range.
    """
    if samples.dtype == np.int16:
        return samples
    as_float = samples.astype(np.float64)
    if samples.dtype.kind == "f":
        return _clip_to_int16(as_float * INT16_MAX)
    if samples.dtype == np.uint8:
        return _clip_to_int16((as_float - 128.0) * 256.0)
    if samples.dtype == np.int32:
        return _clip_to_int16(as_float / 65536.0)
    # Unexpected integer width: keep the values but saturate instead of wrapping.
    return _clip_to_int16(as_float)


# Only files with a .wav extension (case-insensitive) are transcribed.
input_path_list = [p for p in generate_path_list(audio_folder) if os.path.splitext(p)[1].lower() == ".wav"]
if not input_path_list:
    raise ValueError("No .WAV audio files in the input folder")
logging.info("Transcribing {:d} audio files...".format(len(input_path_list)))
start = time()
# Rows are collected in a list and turned into a DataFrame once at the end;
# growing a DataFrame row by row reallocates it on every iteration.
rows = []
for path in tqdm(input_path_list):
    # Free-text note listing every transformation applied to this file.
    comment = ""
    with audio_folder.get_download_stream(path) as stream:
        samplerate, audio = wavfile.read(BytesIO(stream.read()))
    if audio.ndim == 2:
        # Multi-channel audio is (samples, channels); only channel 0 is kept.
        comment += "Found {} channels, kept only the first one. ".format(audio.shape[1])
        audio = audio[:, 0]
    if audio.dtype != np.int16:
        # Must happen before resampling: afterwards the data is float64 on the
        # int16 scale and can no longer be told apart from a float WAV.
        comment += "Converted from {} to 16-bit PCM. ".format(audio.dtype)
        audio = _to_int16_pcm(audio)
    if samplerate != deepspeech_samplerate:
        comment += "Resampled from {} Hz to {} Hz. ".format(samplerate, deepspeech_samplerate)
        # int(...) keeps the sample count an integer on every Python version.
        target_length = int(round(audio.shape[0] * float(deepspeech_samplerate) / samplerate))
        # Resampling returns floats that may overshoot the int16 range.
        audio = _clip_to_int16(resample(audio, target_length))
    # The single-channel slice above is a strided view; hand the model a
    # contiguous int16 buffer, as the original astype() copy did.
    text = model.stt(np.ascontiguousarray(audio, dtype=np.int16))
    rows.append((path, text, comment))
output_df = pd.DataFrame(rows, columns=["path", "text", "comment"])
logging.info("Transcribing {:d} audio files: Done in {:.2f} seconds.".format(len(input_path_list), time() - start))

# Write recipe outputs: the transcription table goes to the dataset bound to
# the "text" output role, with the schema taken from the DataFrame
output_dataset_name = get_output_names_for_role("text")[0]
detected_texts = dataiku.Dataset(output_dataset_name)
detected_texts.write_with_schema(output_df)
