import dataiku
import dataikuapi
import openai
import os

import re
from urllib.parse import urlparse, urlunparse
import tempfile

from dataikuapi.utils import DataikuException
from text2sql import get_answer
from langchain.tools import Tool
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from urllib.parse import urlparse
from langchain import PromptTemplate, LLMChain

# Newline character exposed as a named constant.
# NOTE(review): not referenced in the visible part of this module - presumably
# kept for use inside f-string expressions; confirm before removing.
NEW_LINE = "\n"


def create_tool_from_python_endpoint(
    url, description, name=None, arguments_names=None, api_key="", **kwargs
):
    """
    Create a tool that calls an arbitrary Python function endpoint.

    Args:
        url: URL of the endpoint. The service id is read from the 5th
            "/"-separated segment of the URL path and the endpoint id from
            the 6th one.
        description: Description of the tool given to the LLM.
        name: Name of the tool. It is also the endpoint id used when calling
            the API node. Defaults to the endpoint id found in `url`.
        arguments_names: Ordered names of the endpoint arguments. If None,
            the endpoint is called once without any argument and the names
            are extracted from the quoted substrings of the resulting error
            message (this could cause issues if calling the endpoint without
            any argument causes problems).
        api_key: API key passed to the API node client.
        **kwargs: Fixed arguments sent to the endpoint on every call.

    Returns:
        A LangChain `Tool` taking a single comma-separated string as input.
    """

    parsed_url = urlparse(url)
    path_parts = parsed_url.path.split("/")
    host = f"{parsed_url.scheme}://{parsed_url.hostname}"
    # `port` is None when the URL has no explicit port; appending it
    # unconditionally would yield an invalid "...:None" host.
    if parsed_url.port is not None:
        host = f"{host}:{parsed_url.port}"
    service = path_parts[4]
    if name is None:
        name = path_parts[5]
    client = dataikuapi.APINodeClient(host, service, api_key)

    if arguments_names is None:
        arguments_names = []
        try:
            # Probe call: it is expected to fail with an error message
            # listing the arguments. If it succeeds, the endpoint needs no
            # argument and the tool is still created (the probe result is
            # deliberately discarded).
            client.run_function(name)
        except DataikuException as e:
            quoted = re.findall(r"'(.*?)'", str(e))
            # The first quoted substring is not an argument name.
            arguments_names = [item.strip("'") for item in quoted[1:]]

    def run(endpoint_input: str):
        # Parse the input given by the LLM. This does not work if there is a
        # comma inside an argument.
        args = endpoint_input.split(",")
        if len(args) < len(arguments_names):
            return (
                f"Error: expected {len(arguments_names)} comma-separated "
                f"argument(s) ({', '.join(arguments_names)}) "
                f"but received {len(args)}"
            )
        # Build a fresh dict per call so that the fixed kwargs captured by
        # the closure are never mutated; extra values are ignored.
        call_kwargs = {**kwargs, **dict(zip(arguments_names, args))}
        try:
            return client.run_function(name, **call_kwargs).get("response")
        except DataikuException as e:
            return f"Error: {e}"

    tool = Tool.from_function(func=run, name=name, description=description)

    return tool


def create_tool_from_model_endpoint(
    url,
    description,
    name=None,
    post_process=lambda x: x,
    api_key="",
    key="id",
    value=None,
):
    """
    Create a tool that calls a model endpoint.

    Args:
        url: URL of the endpoint. The service id is read from the 5th
            "/"-separated segment of the URL path and the endpoint id from
            the 6th one.
        description: Description of the tool given to the LLM.
        name: Name of the tool. It is also the endpoint id used when calling
            the API node. Defaults to the endpoint id found in `url`.
        post_process: Function applied to the prediction before it is
            returned by the tool.
        api_key: API key passed to the API node client.
        key: Name of the single feature sent to the model.
        value: If None, the feature value is the tool input converted to an
            integer. Otherwise the tool input is ignored and `str(value)` is
            always sent.

    Returns:
        A LangChain `Tool` returning the post-processed prediction, or an
        "Error: ..." string if the input or the call is invalid.
    """
    parsed_url = urlparse(url)
    path_parts = parsed_url.path.split("/")
    host = f"{parsed_url.scheme}://{parsed_url.hostname}"
    # `port` is None when the URL has no explicit port; appending it
    # unconditionally would yield an invalid "...:None" host.
    if parsed_url.port is not None:
        host = f"{host}:{parsed_url.port}"
    service = path_parts[4]
    if name is None:
        name = path_parts[5]
    client = dataikuapi.APINodeClient(host, service, api_key)

    def run(endpoint_input):
        if value is None:
            try:
                # LLMs frequently wrap the identifier in quotes.
                feature_value = int(endpoint_input.strip().strip("'\""))
            except ValueError:
                # Report the problem to the LLM instead of crashing the
                # agent, consistently with the API errors below.
                return f"Error: '{key}' must be an integer, got {endpoint_input!r}"
        else:
            feature_value = str(value)

        try:
            return post_process(
                client.predict_record(name, {key: feature_value})["result"][
                    "prediction"
                ]
            )
        except DataikuException as e:
            return f"Error: {e}"

    tool = Tool.from_function(func=run, name=name, description=description)

    return tool


# Default description of the tool built by `create_tool_from_datasets`.
# The single `{}` placeholder is filled, through `str.format`, with one
# "<dataset name>: <short description>" line per dataset carrying the tag.
DATASETS_TOOL_DESCRIPTION = """This is a tool you can use to get information from datasets. The input is the question you want to ask the datasets and the output will be the answer based on the datasets.
The available datasets are:
{}
Include in the question all relevant contextual information.
The input is the question you want to ask the answer of, not the actual query.
If a SQL syntax error is returned, try to rephrase the question or provide additional contextual information.
"""


def create_tool_from_datasets(
    tag,
    name="query_datasets",
    description=None,
    datasets_restrictions=None,
    llm=None
):
    """
    Create a tool from datasets identified by a tag.
    The tools answers questions on the basis of the datasets with a text2SQL approach.

    Args:
        tag: Tag identifying the datasets of the default project to query.
        name: Name of the tool.
        description: Description of the tool given to the LLM. If None, it is
            built from `DATASETS_TOOL_DESCRIPTION` and the names and short
            descriptions of the tagged datasets.
        datasets_restrictions: Passed as is to `get_answer`. Defaults to an
            empty list.
        llm: Passed as is to `get_answer`.

    Returns:
        A LangChain `Tool` taking a question as input.
    """
    # Avoid a mutable default argument shared across calls.
    if datasets_restrictions is None:
        datasets_restrictions = []

    project = dataiku.api_client().get_default_project()

    if description is None:
        # Only list and inspect the datasets when their descriptions are
        # actually needed: it costs one API call per tagged dataset.
        tagged_datasets = [
            dataset for dataset in project.list_datasets() if tag in dataset["tags"]
        ]
        datasets_string = "\n".join(
            f"{ds['name']}: {project.get_dataset(ds['name']).get_settings().short_description}"
            for ds in tagged_datasets
        )
        description = DATASETS_TOOL_DESCRIPTION.format(datasets_string)

    def run(question):
        return get_answer(
            project,
            tag,
            question,
            llm,
            datasets_restrictions=datasets_restrictions
        )

    tool = Tool.from_function(
        func=run,
        name=name,
        description=description,
    )
    return tool


# Prompt of the default LLM chain built by `create_tool_from_managed_folder`.
# Its variables (`documents_descriptions`, `additional_instructions`,
# `question`, `summaries`) are the keyword arguments passed to
# `llm_chain.predict` when the tool is run.
RAG_PROMPT_TEMPLATE = """
Given the following extracts of documents and a question, create a final answer with references. 
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Scope of the documents: {documents_descriptions}
{additional_instructions}

QUESTION: {question}
=========
EXTRACTED PARTS:
{summaries}
=========
FINAL ANSWER:
"""

# Default description of the tool built by `create_tool_from_managed_folder`.
# The single `{}` placeholder is filled, through `str.format`, with the
# description of the documents (the managed folder description, or a generic
# sentence when the folder definition has none).
MANAGED_FOLDER_TOOL_DESCRIPTION = """This is a tool that uses documents to answer a question. 
The input is the question.
Scope of the documents: {}
If the tool is not helpful, try to continue without it.
"""


def create_tool_from_managed_folder(
    folder_id,
    additional_instructions,
    name="query_documents",
    num_chunks=5,
    description=None,
    llm_chain=None,
    filters=None,
    embeddings=None
):
    """
    Create a tool from a managed_folder.
    The managed folder should contain a LangChain FAISS object.
    The tool answers questions on the basis of the documents indexed in the FAISS object, with a retrieval augmented generation approach.

    Args:
        folder_id: Id of the managed folder containing the FAISS index files.
        additional_instructions: Text inserted in the prompt through the
            `additional_instructions` variable.
        name: Name of the tool.
        num_chunks: Number of chunks retrieved from the index for a question.
        description: Description of the tool given to the LLM. If None, it is
            built from `MANAGED_FOLDER_TOOL_DESCRIPTION` and the description
            of the managed folder.
        llm_chain: Chain used to generate the answer. It is called with the
            `documents_descriptions`, `additional_instructions`, `question`
            and `summaries` variables. Defaults to an `LLMChain` using
            `ChatOpenAI(temperature=0)` and `RAG_PROMPT_TEMPLATE`.
        filters: Metadata filter passed to the similarity search. Defaults to
            an empty dict.
        embeddings: Embeddings used to load and query the index. Defaults to
            `OpenAIEmbeddings()`.

    Returns:
        A LangChain `Tool` taking a question as input.
    """

    if embeddings is None:
        embeddings = OpenAIEmbeddings()

    if llm_chain is None:
        llm_chain = LLMChain(
            llm=ChatOpenAI(temperature=0),
            prompt=PromptTemplate.from_template(RAG_PROMPT_TEMPLATE),
        )

    # Avoid a mutable default argument shared across calls.
    if filters is None:
        filters = {}

    project = dataiku.api_client().get_default_project()
    managed_folder = project.get_managed_folder(folder_id)
    folder = dataiku.Folder(folder_id)

    # Fetch the folder definition only once.
    documents_descriptions = managed_folder.get_definition().get(
        "description", "Documents that can be used to answer questions"
    )

    if description is None:
        description = MANAGED_FOLDER_TOOL_DESCRIPTION.format(documents_descriptions)

    # The index files are copied to a local temporary directory because
    # FAISS.load_local reads from a local path. The directory is deleted once
    # the index is loaded.
    with tempfile.TemporaryDirectory() as temp_dir:
        for path in folder.list_paths_in_partition():
            local_path = os.path.join(temp_dir, os.path.basename(path))
            with folder.get_download_stream(path) as stream:
                with open(local_path, "wb") as local_file:
                    local_file.write(stream.read())
        index = FAISS.load_local(temp_dir, embeddings)

    def run(question: str):
        result_search = index.similarity_search_with_score(
            question, k=num_chunks, filter=filters
        )
        # Each result is a (document, score) pair; only the text is used.
        # An empty result simply yields an empty list of extracts.
        extracts = "\n\n\n".join(
            [document.page_content for document, _score in result_search]
        )

        return llm_chain.predict(
            documents_descriptions=documents_descriptions,
            additional_instructions=additional_instructions,
            question=question,
            summaries=extracts,
        )

    tool = Tool.from_function(func=run, name=name, description=description)
    return tool
