# See examples on https://developer.dataiku.com/latest/concepts-and-examples/llm-mesh.html#fine-tuning

## Variables

# Required settings: fill in both values before running.
base_model_name = ""  # base LLM to fine-tune
connection_name = ""  # connection that will expose the fine-tuned LLM

# stop right away if a required setting was left empty
assert base_model_name, "please specify a base LLM, it must be available on HuggingFace hub"
assert connection_name, "please specify a connection name, the fine-tuned LLM will be available from this connection"

# names of the prompt / completion columns; both must exist in the input dataset
prompt_column = "prompt"
completion_column = "completion"

## Code

import datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

from dataiku import recipe
from dataiku.llm.finetuning import formatters

# Convert the recipe's Dataiku input datasets into transformers datasets,
# reading only the prompt and completion columns.
columns = [prompt_column, completion_column]

# first input: training data (always expected)
training_dataset = recipe.get_inputs()[0]
train_dataset = datasets.Dataset.from_pandas(
    training_dataset.get_dataframe(columns=columns)
)

# second input, when there is one: validation data
if len(recipe.get_inputs()) > 1:
    validation_dataset = recipe.get_inputs()[1]
    eval_dataset = datasets.Dataset.from_pandas(
        validation_dataset.get_dataframe(columns=columns)
    )
else:
    validation_dataset = None
    eval_dataset = None

# load the base model and its tokenizer by name
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Fall back to the end-of-sequence token for padding only when the tokenizer
# does not define a pad token of its own. Overwriting a dedicated pad token
# would make padding indistinguishable from end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# formatter built from the prompt and completion column names; it is handed to
# the trainer below as its formatting function
formatting_func = formatters.InstructPromptFormatter(*columns)

# Fine-tune with SFTTrainer inside a new fine-tuned LLM version of the
# recipe's first output, then record training metadata on that version.
output_saved_model = recipe.get_outputs()[0]
with output_saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
    # Customize here. Requirement: put a transformers model in safetensors format into finetuned_llm_version.working_directory.

    # evaluate during training only when an evaluation dataset is available
    if eval_dataset:
        evaluation_strategy = "steps"
    else:
        evaluation_strategy = "no"

    training_args = SFTConfig(
        # the trainer writes the model straight into the version's working directory
        output_dir=finetuned_llm_version.working_directory,
        save_safetensors=True,
        num_train_epochs=1,
        logging_steps=1,
        # cap sequences at 1024 tokens, or less if the tokenizer's limit is lower
        max_length=min(tokenizer.model_max_length, 1024),
        eval_strategy=evaluation_strategy,
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_func,
        args=training_args,
    )
    trainer.train()
    trainer.save_model()

    # describe the run on the fine-tuned LLM version
    version_config = finetuned_llm_version.config
    version_config["trainingDataset"] = training_dataset.short_name
    if validation_dataset:
        version_config["validationDataset"] = validation_dataset.short_name
    version_config["promptColumn"] = prompt_column
    version_config["completionColumn"] = completion_column
    version_config["batchSize"] = trainer.state.train_batch_size
    version_config["eventLog"] = trainer.state.log_history
