/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.utils.json_schema;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.tuple.Pair;

/**
 * Adapts a JSON schema to the structured-output dialect accepted by a given LLM {@link Provider}.
 *
 * <p>The caller's schema is never modified: {@link #enhance(JsonObject, Provider)} works on a deep
 * copy. It repeatedly applies the provider-specific rewrites, reference inlining and
 * unused-definition cleanup until the schema stops changing, or until {@link #MAX_ITERATIONS}
 * passes have been made. Results are memoized in a bounded, process-wide cache (200 entries,
 * expiring one hour after write) unless the {@code dku.llm.json_schema_cache.enabled} parameter is
 * {@code false}.
 */
public class JSONSchemaCompatibilityEnhancer {
    /** Maximum number of rewrite passes before giving up on reaching a fixed point. */
    public static final int MAX_ITERATIONS = 8;
    private final boolean useCache;
    /** Schema being transformed; replaced by a private deep copy at the start of {@link #enhance()}. */
    private JsonObject root;
    private final Provider provider;
    /** Messages already logged by this instance, so that each warning is emitted at most once. */
    private final Set<String> emittedWarnings = new HashSet<>();
    /*
     * Keyed on the serialized input schema rather than on the JsonObject itself, for two reasons:
     * - JsonObject is mutable, and a caller mutating its schema would silently change the stored key.
     * - JsonObject.equals() ignores member order, whereas member order is significant for the
     *   enhanced output (property ordering, discriminator placement).
     */
    private static final Cache<Pair<String, Provider>, JsonObject> CACHE = CacheBuilder.newBuilder().maximumSize(200L).expireAfterWrite(1L, TimeUnit.HOURS).build();
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.json_schema.compat");

    private JSONSchemaCompatibilityEnhancer(JsonObject root, Provider provider) {
        this.root = root;
        this.provider = provider;
        this.useCache = ApplicationConfigurator.getParams().getBoolParam("dku.llm.json_schema_cache.enabled", true);
    }

    /**
     * Returns a version of {@code root} adapted to the structured-output constraints of {@code provider}.
     *
     * @param root the schema to adapt; it is never modified
     * @param provider the target provider
     * @return a fresh schema object that the caller may freely modify; for
     *     {@link Provider#PASSTHROUGH}, {@code root} itself is returned unchanged
     */
    public static JsonObject enhance(JsonObject root, Provider provider) {
        return new JSONSchemaCompatibilityEnhancer(root, provider).enhance();
    }

    /** Logs {@code message} at WARN level, unless this instance has already logged it. */
    private void warn(String message) {
        if (this.emittedWarnings.add(message)) {
            logger.warn((Object)message);
        }
    }

    /**
     * Runs the rewrite passes on a deep copy of {@link #root} until a fixed point is reached or
     * {@link #MAX_ITERATIONS} passes have run, going through the cache when enabled.
     */
    private JsonObject enhance() {
        if (this.provider == Provider.PASSTHROUGH) {
            return this.root;
        }
        Pair<String, Provider> cacheKey = null;
        if (this.useCache) {
            cacheKey = Pair.of(this.root.toString(), this.provider);
            JsonObject cached = CACHE.getIfPresent(cacheKey);
            if (cached != null) {
                // Hand out a private copy: a caller mutating its result must not corrupt the cache.
                return (JsonObject)JSON.deepCopy((Object)cached);
            }
        }
        this.root = (JsonObject)JSON.deepCopy((Object)this.root);
        logger.info((Object)("Enhancing JSON schema for mode: " + String.valueOf((Object)this.provider)));
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Original schema: \n" + JSON.pretty((Object)this.root)));
        }
        int iterations;
        for (iterations = 0; iterations < MAX_ITERATIONS; ++iterations) {
            JsonObject previous = (JsonObject)JSON.deepCopy((Object)this.root);
            this.fixSchemaRecursive(this.root);
            Map<String, Set<String>> defsGraph = this.collectAllDefsAndTheirDependencies(this.root);
            this.resolveRefsRecursive(this.root, defsGraph);
            this.clearUnusedDefinitions(this.root);
            this.ensureRootIsObject(this.root);
            // A pass that changed nothing means further passes would not change anything either.
            if (Objects.equals(previous, this.root)) break;
        }
        if (iterations == MAX_ITERATIONS) {
            this.warn("Enhanced schema reached the maximum number of iterations (" + MAX_ITERATIONS + "). The schema may not have been fully transformed.");
        }
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Enhanced schema (" + iterations + " iterations): \n" + JSON.pretty((Object)this.root)));
        }
        if (cacheKey != null) {
            // Store a separate copy so that the object returned to this caller is not shared with the cache.
            CACHE.put(cacheKey, (JsonObject)JSON.deepCopy((Object)this.root));
        }
        return this.root;
    }

    /**
     * Removes the entries of the top-level {@code $defs} / {@code definitions} objects that are not
     * reachable from the rest of the schema, and drops those containers once they are empty.
     *
     * <p>Reachability is transitive: a definition referenced only by unreachable definitions (for
     * example a self-recursive definition that nothing else uses) is removed too.
     */
    private void clearUnusedDefinitions(JsonObject jsonSchema) {
        Set<String> usedRefs = JSONSchemaCompatibilityEnhancer.collectReachableRefs(jsonSchema);
        for (String key : List.of("$defs", "definitions")) {
            JsonElement definitionsElm = jsonSchema.get(key);
            if (definitionsElm == null || !definitionsElm.isJsonObject()) continue;
            JsonObject definitions = definitionsElm.getAsJsonObject();
            List<String> definitionsToRemove = new ArrayList<>();
            for (String defName : definitions.keySet()) {
                if (!usedRefs.contains("#/" + key + "/" + defName)) {
                    definitionsToRemove.add(defName);
                }
            }
            for (String defToRemove : definitionsToRemove) {
                definitions.remove(defToRemove);
            }
            if (definitions.size() == 0) {
                jsonSchema.remove(key);
            }
        }
    }

    /**
     * Collects every {@code $ref} value reachable from the part of {@code jsonSchema} outside its
     * top-level {@code $defs} / {@code definitions} objects. References that name a top-level
     * definition are followed transitively.
     */
    private static Set<String> collectReachableRefs(JsonObject jsonSchema) {
        // Shallow view of the schema without its definition containers (elements are shared, not copied).
        JsonObject body = new JsonObject();
        for (String key : jsonSchema.keySet()) {
            if (key.equals("$defs") || key.equals("definitions")) continue;
            body.add(key, jsonSchema.get(key));
        }
        Set<String> reachable = new HashSet<>();
        JSONSchemaCompatibilityEnhancer.collectAllRefs(body, reachable);
        List<String> pending = new ArrayList<>(reachable);
        while (!pending.isEmpty()) {
            String ref = pending.remove(pending.size() - 1);
            JsonObject def = JSONSchemaCompatibilityEnhancer.findTopLevelDefinition(jsonSchema, ref);
            if (def == null) continue;
            Set<String> defRefs = new HashSet<>();
            JSONSchemaCompatibilityEnhancer.collectAllRefs(def, defRefs);
            for (String defRef : defRefs) {
                if (reachable.add(defRef)) {
                    pending.add(defRef);
                }
            }
        }
        return reachable;
    }

    /**
     * Looks up {@code ref} (of the form {@code #/$defs/Name} or {@code #/definitions/Name}) among the
     * top-level definitions of {@code jsonSchema}. Returns null if the reference has another shape,
     * or if the definition is missing or not an object.
     */
    private static JsonObject findTopLevelDefinition(JsonObject jsonSchema, String ref) {
        for (String key : List.of("$defs", "definitions")) {
            String prefix = "#/" + key + "/";
            if (!ref.startsWith(prefix)) continue;
            JsonElement defs = jsonSchema.get(key);
            if (defs == null || !defs.isJsonObject()) {
                return null;
            }
            JsonElement def = defs.getAsJsonObject().get(ref.substring(prefix.length()));
            return def != null && def.isJsonObject() ? def.getAsJsonObject() : null;
        }
        return null;
    }

    /**
     * Adds to {@code refs} the value of every {@code $ref} key found anywhere in {@code jsonSchema},
     * descending into nested objects and into objects held in arrays. Non-primitive {@code $ref}
     * values (e.g. JSON null) are ignored.
     */
    private static void collectAllRefs(JsonObject jsonSchema, Set<String> refs) {
        for (String key : jsonSchema.keySet()) {
            JsonElement value = jsonSchema.get(key);
            if (value.isJsonObject()) {
                JSONSchemaCompatibilityEnhancer.collectAllRefs(value.getAsJsonObject(), refs);
            } else if (value.isJsonArray()) {
                for (JsonElement element : value.getAsJsonArray()) {
                    if (element.isJsonObject()) {
                        JSONSchemaCompatibilityEnhancer.collectAllRefs(element.getAsJsonObject(), refs);
                    }
                }
            } else if (key.equals("$ref") && value.isJsonPrimitive()) {
                refs.add(value.getAsString());
            }
        }
    }

    /**
     * Builds a map from each top-level definition ({@code #/$defs/X} or {@code #/definitions/X}) to
     * the set of references it depends on, directly or transitively.
     *
     * <p>A definition is recursive if and only if its own reference appears in its set.
     */
    private Map<String, Set<String>> collectAllDefsAndTheirDependencies(JsonObject jsonSchema) {
        Map<String, Set<String>> defsGraph = new HashMap<>();
        for (String key : List.of("$defs", "definitions")) {
            JsonElement definitions = jsonSchema.get(key);
            if (definitions == null || !definitions.isJsonObject()) continue;
            for (String defName : definitions.getAsJsonObject().keySet()) {
                JsonElement def = definitions.getAsJsonObject().get(defName);
                if (def == null || !def.isJsonObject()) continue;
                Set<String> dependencies = new HashSet<>();
                JSONSchemaCompatibilityEnhancer.collectAllRefs(def.getAsJsonObject(), dependencies);
                defsGraph.put("#/" + key + "/" + defName, dependencies);
            }
        }
        // Transitive closure: keep propagating dependency sets until a full pass adds nothing.
        boolean changed;
        do {
            changed = false;
            for (Set<String> dependencies : defsGraph.values()) {
                // Iterate over a snapshot because the set grows while we walk it.
                for (String dependency : new ArrayList<>(dependencies)) {
                    Set<String> transitive = defsGraph.get(dependency);
                    if (transitive == null) continue;
                    // Non-short-circuiting so that every addAll of the pass actually runs.
                    changed |= dependencies.addAll(transitive);
                }
            }
        } while (changed);
        return defsGraph;
    }

    /**
     * Applies one pass of rewrites to {@code jsonSchema} and, depth-first, to its nested schemas:
     * definitions, {@code items}, {@code properties}, then {@code allOf}/{@code anyOf} entries.
     * Provider-specific rewrites come first, then the rewrites shared by all non-passthrough providers.
     */
    private void fixSchemaRecursive(JsonObject jsonSchema) {
        this.visitDefinitions(jsonSchema);
        this.visitArray(jsonSchema);
        this.visitProperties(jsonSchema);
        this.visitUnionAndIntersection(jsonSchema);
        if (this.provider == Provider.OPENAI) {
            this.setAdditionalPropertiesFalse(jsonSchema);
            this.removePropertyOrdering(jsonSchema);
            this.convertNullableToUnion(jsonSchema);
            this.reorderPropertiesToMoveDiscriminatorToFirstPosition(jsonSchema);
        }
        if (this.provider == Provider.GEMINI) {
            this.removeJsonSchemaMetaAttributes(jsonSchema);
            this.addTypeToEnumOfStrings(jsonSchema);
            this.removeAdditionalProperties(jsonSchema);
            this.convertUnionWithNullToNullable(jsonSchema);
            this.injectPropertyOrdering(jsonSchema);
            this.distributeAllPropertiesToEachAnyOfChild(jsonSchema);
        }
        if (this.provider == Provider.NOVA) {
            this.removeJsonSchemaMetaAttributes(jsonSchema);
        }
        if (this.provider == Provider.MISTRAL) {
            this.setAdditionalPropertiesFalse(jsonSchema);
            this.removeSchemaFormatFields(jsonSchema);
        }
        this.transformNonRequiredPropertiesIntoNullableProperties(jsonSchema);
        this.transformConstIntoEnum(jsonSchema);
        this.mergeNestedAnyOf(jsonSchema);
        this.flattenUnionAndIntersectionOfOne(jsonSchema);
        this.removeDefaultValuesIfNull(jsonSchema);
    }

    /** Recurses into every object-valued entry of {@code properties}. */
    private void visitProperties(JsonObject jsonSchema) {
        JsonElement properties = jsonSchema.get("properties");
        if (properties == null || !properties.isJsonObject()) {
            return;
        }
        for (String key : properties.getAsJsonObject().keySet()) {
            JsonElement value = properties.getAsJsonObject().get(key);
            if (value == null || !value.isJsonObject()) continue;
            this.fixSchemaRecursive(value.getAsJsonObject());
        }
    }

    /** Recurses into every object-valued entry of {@code $defs} and {@code definitions}. */
    private void visitDefinitions(JsonObject jsonSchema) {
        for (String key : List.of("$defs", "definitions")) {
            JsonElement defs = jsonSchema.get(key);
            if (defs == null || !defs.isJsonObject()) continue;
            for (String defName : defs.getAsJsonObject().keySet()) {
                JsonElement def = defs.getAsJsonObject().get(defName);
                if (def == null || !def.isJsonObject()) continue;
                this.fixSchemaRecursive(def.getAsJsonObject());
            }
        }
    }

    /**
     * Marks every object-valued property that is not listed in {@code required} with
     * {@code "nullable": true}, then rewrites {@code required} to list all properties.
     * Does nothing if {@code required} is present but not an array. Non-primitive entries of
     * {@code required} are ignored.
     */
    private void transformNonRequiredPropertiesIntoNullableProperties(JsonObject jsonSchema) {
        JsonElement properties = jsonSchema.get("properties");
        if (properties == null || !properties.isJsonObject()) {
            return;
        }
        JsonElement requiredProperties = jsonSchema.get("required");
        if (requiredProperties == null) {
            requiredProperties = new JsonArray();
        }
        if (!requiredProperties.isJsonArray()) {
            return;
        }
        Set<String> requiredPropertiesSet = new HashSet<>();
        for (JsonElement requiredProperty : requiredProperties.getAsJsonArray()) {
            // Malformed entry (object/array/null): ignore it instead of failing the whole enhancement.
            if (!requiredProperty.isJsonPrimitive()) continue;
            requiredPropertiesSet.add(requiredProperty.getAsString());
        }
        for (String key : properties.getAsJsonObject().keySet()) {
            JsonElement value = properties.getAsJsonObject().get(key);
            if (value == null || !value.isJsonObject() || requiredPropertiesSet.contains(key)) continue;
            value.getAsJsonObject().addProperty("nullable", Boolean.valueOf(true));
        }
        jsonSchema.add("required", (JsonElement)properties.getAsJsonObject().keySet().stream().collect(JsonArray::new, JsonArray::add, JsonArray::addAll));
    }

    /**
     * Adds a {@code propertyOrdering} array listing the property names in their current order,
     * unless the schema already has one.
     */
    private void injectPropertyOrdering(JsonObject jsonSchema) {
        JsonElement properties = jsonSchema.get("properties");
        if (properties == null || !properties.isJsonObject() || jsonSchema.has("propertyOrdering")) {
            return;
        }
        JsonArray propertyOrder = new JsonArray();
        for (String key : properties.getAsJsonObject().keySet()) {
            propertyOrder.add(new JsonPrimitive(key));
        }
        jsonSchema.add("propertyOrdering", propertyOrder);
    }

    /** Sets {@code additionalProperties} to {@code false} on schemas whose type is "object". */
    private void setAdditionalPropertiesFalse(JsonObject jsonSchema) {
        JsonElement type = jsonSchema.get("type");
        if (type != null && type.isJsonPrimitive() && "object".equals(type.getAsString())) {
            jsonSchema.add("additionalProperties", new JsonPrimitive(Boolean.valueOf(false)));
        }
    }

    /** Removes the {@code format} keyword, whatever the schema type. */
    private void removeSchemaFormatFields(JsonObject jsonSchema) {
        jsonSchema.remove("format");
    }

    /** Replaces {@code "const": v} with the equivalent {@code "enum": [v]}. */
    private void transformConstIntoEnum(JsonObject jsonSchema) {
        JsonElement constValue = jsonSchema.get("const");
        if (constValue == null) {
            return;
        }
        JsonArray enumArray = new JsonArray();
        enumArray.add(constValue);
        jsonSchema.add("enum", enumArray);
        jsonSchema.remove("const");
    }

    /** Removes the given keywords, but only from schemas whose type is "object". */
    private void removeFields(JsonObject jsonSchema, String ... fields) {
        JsonElement type = jsonSchema.get("type");
        if (type != null && type.isJsonPrimitive() && "object".equals(type.getAsString())) {
            for (String field : fields) {
                jsonSchema.remove(field);
            }
        }
    }

    /** Removes {@code $schema} and {@code $id} from object-typed schemas. */
    private void removeJsonSchemaMetaAttributes(JsonObject jsonSchema) {
        this.removeFields(jsonSchema, "$schema", "$id");
    }

    /** Removes {@code additionalProperties} from object-typed schemas. */
    private void removeAdditionalProperties(JsonObject jsonSchema) {
        this.removeFields(jsonSchema, "additionalProperties");
    }

    /** Removes {@code propertyOrdering} from object-typed schemas. */
    private void removePropertyOrdering(JsonObject jsonSchema) {
        this.removeFields(jsonSchema, "propertyOrdering");
    }

    /** Recurses into {@code items} when it is a single schema object (tuple arrays are not visited). */
    private void visitArray(JsonObject jsonSchema) {
        JsonElement items = jsonSchema.get("items");
        if (items != null && items.isJsonObject()) {
            this.fixSchemaRecursive(items.getAsJsonObject());
        }
    }

    /**
     * If {@code anyOf} contains a {@code {"type": "null"}} entry, removes the first such entry and
     * marks the schema with {@code "nullable": true}.
     */
    private void convertUnionWithNullToNullable(JsonObject jsonSchema) {
        JsonElement anyOf = jsonSchema.get("anyOf");
        if (anyOf == null || !anyOf.isJsonArray()) {
            return;
        }
        JsonElement nullEntry = null;
        for (JsonElement element : anyOf.getAsJsonArray()) {
            if (!element.isJsonObject()) continue;
            JsonElement type = element.getAsJsonObject().get("type");
            if (type != null && type.isJsonPrimitive() && type.getAsString().equals("null")) {
                nullEntry = element;
                break;
            }
        }
        if (nullEntry != null) {
            anyOf.getAsJsonArray().remove(nullEntry);
            jsonSchema.addProperty("nullable", Boolean.valueOf(true));
        }
    }

    /**
     * For a tagged union ({@code anyOf} whose object variants share a property that holds a distinct
     * single-value {@code enum} in each variant), moves that discriminator property to the first
     * position in each variant and drops the variant's {@code propertyOrdering}.
     *
     * <p>Variants given as {@code $ref} are resolved against the root definitions and reordered in
     * place there. Properties whose schema is not an object are never treated as discriminators.
     */
    public void reorderPropertiesToMoveDiscriminatorToFirstPosition(JsonObject jsonSchema) {
        JsonElement anyOf = jsonSchema.get("anyOf");
        if (anyOf == null || !anyOf.isJsonArray()) {
            return;
        }
        List<JsonObject> subschemas = new ArrayList<>();
        for (JsonElement element : anyOf.getAsJsonArray()) {
            if (!element.isJsonObject()) continue;
            JsonElement ref = element.getAsJsonObject().get("$ref");
            if (ref != null && ref.isJsonPrimitive()) {
                JsonObject resolved = this.resolveRef(ref.getAsString());
                if (resolved != null) {
                    subschemas.add(resolved);
                }
                continue;
            }
            subschemas.add(element.getAsJsonObject());
        }
        subschemas.removeIf(subschema -> {
            JsonElement properties = subschema.get("properties");
            return properties == null || !properties.isJsonObject();
        });
        if (subschemas.size() <= 1) {
            return;
        }
        // Properties present (as object schemas) in every variant, in the first variant's order.
        List<Pair<String, List<JsonObject>>> commonProperties = new ArrayList<>();
        for (String key : subschemas.get(0).get("properties").getAsJsonObject().keySet()) {
            boolean isCommon = true;
            List<JsonObject> commonValues = new ArrayList<>();
            for (JsonObject subschema : subschemas) {
                JsonElement value = subschema.get("properties").getAsJsonObject().get(key);
                if (value == null || !value.isJsonObject()) {
                    isCommon = false;
                    break;
                }
                commonValues.add(value.getAsJsonObject());
            }
            if (isCommon) {
                commonProperties.add(Pair.of(key, commonValues));
            }
        }
        if (commonProperties.isEmpty()) {
            return;
        }
        String discriminator = null;
        for (Pair<String, List<JsonObject>> commonProperty : commonProperties) {
            boolean isDiscriminator = true;
            Set<JsonElement> seenValues = new HashSet<>();
            for (JsonObject commonValue : commonProperty.getRight()) {
                JsonElement enumValue = commonValue.get("enum");
                // Each variant must pin the property to exactly one value, distinct across variants.
                if (enumValue == null || !enumValue.isJsonArray() || enumValue.getAsJsonArray().size() != 1
                        || !seenValues.add(enumValue.getAsJsonArray().get(0))) {
                    isDiscriminator = false;
                    break;
                }
            }
            if (isDiscriminator) {
                discriminator = commonProperty.getLeft();
                break;
            }
        }
        if (discriminator == null) {
            return;
        }
        for (JsonObject subschema : subschemas) {
            JsonObject properties = subschema.get("properties").getAsJsonObject();
            if (!properties.has(discriminator) || properties.size() <= 1 || discriminator.equals(properties.keySet().iterator().next())) continue;
            JsonObject newProperties = new JsonObject();
            newProperties.add(discriminator, properties.get(discriminator));
            for (String key : properties.keySet()) {
                if (key.equals(discriminator)) continue;
                newProperties.add(key, properties.get(key));
            }
            subschema.add("properties", newProperties);
            subschema.remove("propertyOrdering");
        }
    }

    /**
     * When the schema has an {@code anyOf}, copies every sibling keyword into each object entry of
     * {@code anyOf} (overwriting the entry's own value for that keyword), then removes the siblings so
     * that only {@code anyOf} remains. The copied values are shared between entries, not cloned.
     */
    private void distributeAllPropertiesToEachAnyOfChild(JsonObject jsonSchema) {
        JsonElement anyOf = jsonSchema.get("anyOf");
        if (anyOf == null || !anyOf.isJsonArray()) {
            return;
        }
        for (JsonElement element : anyOf.getAsJsonArray()) {
            if (!element.isJsonObject()) continue;
            JsonObject obj = element.getAsJsonObject();
            for (String key : jsonSchema.keySet()) {
                if (key.equals("anyOf")) continue;
                obj.add(key, jsonSchema.get(key));
            }
        }
        for (String key : new ArrayList<>(jsonSchema.keySet())) {
            if (key.equals("anyOf")) continue;
            jsonSchema.remove(key);
        }
    }

    /**
     * Rewrites a schema carrying {@code "nullable": true} into
     * {@code {"title": ..., "anyOf": [<schema without nullable/title>, {"type": "null"}]}}.
     * The {@code title}, if any, stays on the wrapper.
     */
    private void convertNullableToUnion(JsonObject jsonSchema) {
        JsonElement nullable = jsonSchema.get("nullable");
        if (nullable == null || !nullable.isJsonPrimitive() || !nullable.getAsBoolean()) {
            return;
        }
        JsonObject subSchema = new JsonObject();
        for (String key : new ArrayList<>(jsonSchema.keySet())) {
            if (key.equals("title")) continue;
            JsonElement value = jsonSchema.remove(key);
            if (!key.equals("nullable")) {
                subSchema.add(key, value);
            }
        }
        JsonObject jsonObjectNull = new JsonObject();
        jsonObjectNull.addProperty("type", "null");
        JsonArray anyOf = new JsonArray();
        anyOf.add(subSchema);
        anyOf.add(jsonObjectNull);
        jsonSchema.add("anyOf", anyOf);
    }

    /** Recurses into every object entry of {@code allOf} and {@code anyOf}. */
    private void visitUnionAndIntersection(JsonObject jsonSchema) {
        for (String key : List.of("allOf", "anyOf")) {
            JsonElement value = jsonSchema.get(key);
            if (value == null || !value.isJsonArray()) continue;
            for (int i = 0; i < value.getAsJsonArray().size(); ++i) {
                JsonElement item = value.getAsJsonArray().get(i);
                if (item == null || !item.isJsonObject()) continue;
                this.fixSchemaRecursive(item.getAsJsonObject());
            }
        }
    }

    /**
     * Replaces an {@code allOf}/{@code anyOf} holding a single object schema with that schema's
     * keywords, which are merged into the parent (overwriting keywords of the same name).
     */
    private void flattenUnionAndIntersectionOfOne(JsonObject jsonSchema) {
        for (String key : List.of("allOf", "anyOf")) {
            JsonElement value = jsonSchema.get(key);
            if (value == null || !value.isJsonArray() || value.getAsJsonArray().size() != 1) continue;
            JsonElement firstEntry = value.getAsJsonArray().get(0);
            if (firstEntry == null || !firstEntry.isJsonObject()) continue;
            jsonSchema.remove(key);
            JsonObject only = firstEntry.getAsJsonObject();
            for (String subKey : only.keySet()) {
                jsonSchema.add(subKey, only.get(subKey));
            }
        }
    }

    /**
     * Inlines the entries of any {@code anyOf} nested directly inside an {@code anyOf} entry.
     * A nested entry's other keywords are discarded, and non-object entries of the outer
     * {@code anyOf} are dropped.
     */
    private void mergeNestedAnyOf(JsonObject jsonSchema) {
        JsonElement anyOf = jsonSchema.get("anyOf");
        if (anyOf == null || !anyOf.isJsonArray()) {
            return;
        }
        JsonArray newArray = new JsonArray();
        for (JsonElement element : anyOf.getAsJsonArray()) {
            if (!element.isJsonObject()) continue;
            JsonObject obj = element.getAsJsonObject();
            JsonElement nestedAnyOf = obj.get("anyOf");
            if (nestedAnyOf != null && nestedAnyOf.isJsonArray()) {
                newArray.addAll(nestedAnyOf.getAsJsonArray());
            } else {
                newArray.add(obj);
            }
        }
        jsonSchema.add("anyOf", newArray);
    }

    /**
     * Inlines {@code $ref}s to top-level definitions, bottom-up, according to provider rules:
     * <ul>
     *   <li>GEMINI: every resolvable reference is inlined, except references to recursive definitions
     *       (a warning is logged and they are left as-is);</li>
     *   <li>OPENAI: only references that have sibling keywords are inlined;</li>
     *   <li>other providers: nothing is inlined.</li>
     * </ul>
     * Keywords already present next to the {@code $ref} take precedence over the definition's.
     * Definitions that are themselves a {@code $ref} are not inlined.
     */
    private void resolveRefsRecursive(JsonObject jsonSchema, Map<String, Set<String>> defsGraph) {
        for (String containerKey : List.of("properties", "$defs", "definitions")) {
            JsonElement container = jsonSchema.get(containerKey);
            if (container == null || !container.isJsonObject()) continue;
            for (String key : container.getAsJsonObject().keySet()) {
                JsonElement element = container.getAsJsonObject().get(key);
                if (!element.isJsonObject()) continue;
                this.resolveRefsRecursive(element.getAsJsonObject(), defsGraph);
            }
        }
        JsonElement items = jsonSchema.get("items");
        if (items != null && items.isJsonObject()) {
            this.resolveRefsRecursive(items.getAsJsonObject(), defsGraph);
        }
        for (String arrayKey : List.of("allOf", "anyOf")) {
            JsonElement array = jsonSchema.get(arrayKey);
            if (array == null || !array.isJsonArray()) continue;
            for (JsonElement element : array.getAsJsonArray()) {
                if (!element.isJsonObject()) continue;
                this.resolveRefsRecursive(element.getAsJsonObject(), defsGraph);
            }
        }
        JsonElement ref = jsonSchema.get("$ref");
        if (ref == null || !ref.isJsonPrimitive() || !ref.getAsJsonPrimitive().isString()) {
            return;
        }
        String refValue = ref.getAsString();
        Set<String> defDeps = defsGraph.get(refValue);
        if (defDeps == null) {
            return;
        }
        boolean isRecursive = defDeps.contains(refValue);
        if (this.provider == Provider.GEMINI) {
            if (isRecursive) {
                this.warn("Gemini does not support recursive schemas, can't resolve reference: " + refValue);
                return;
            }
        } else if (this.provider == Provider.OPENAI) {
            // A bare {"$ref": ...} is accepted as-is; only refs with sibling keywords are inlined.
            if (jsonSchema.size() <= 1) {
                return;
            }
        } else {
            return;
        }
        JsonObject resolved = this.resolveRef(refValue);
        if (resolved == null) {
            this.warn("Could not resolve reference: " + refValue);
            return;
        }
        if (resolved.has("$ref")) {
            return;
        }
        resolved = (JsonObject)JSON.deepCopy((Object)resolved);
        for (String key : resolved.keySet()) {
            if (jsonSchema.has(key)) continue;
            jsonSchema.add(key, resolved.get(key));
        }
        jsonSchema.remove("$ref");
    }

    /** For OPENAI, adds {@code "type": "object"} to the schema when it declares no type. */
    private void ensureRootIsObject(JsonObject jsonSchema) {
        if (this.provider == Provider.OPENAI && !jsonSchema.has("type")) {
            jsonSchema.addProperty("type", "object");
        }
    }

    /** Adds {@code "type": "string"} to an untyped schema whose non-empty {@code enum} holds only strings. */
    private void addTypeToEnumOfStrings(JsonObject jsonSchema) {
        if (jsonSchema.has("type")) {
            return;
        }
        JsonElement enumArray = jsonSchema.get("enum");
        if (enumArray == null || !enumArray.isJsonArray() || enumArray.getAsJsonArray().isEmpty()) {
            return;
        }
        for (JsonElement element : enumArray.getAsJsonArray()) {
            if (!element.isJsonPrimitive() || !element.getAsJsonPrimitive().isString()) {
                return;
            }
        }
        jsonSchema.addProperty("type", "string");
    }

    /** Removes {@code "default": null}. */
    private void removeDefaultValuesIfNull(JsonObject jsonSchema) {
        JsonElement defaultValue = jsonSchema.get("default");
        if (defaultValue != null && defaultValue.isJsonNull()) {
            jsonSchema.remove("default");
        }
    }

    /**
     * Resolves {@code #/$defs/Name} or {@code #/definitions/Name} against {@link #root}.
     *
     * @return the (live, not copied) definition object, or null if the reference has any other shape
     *     (a warning is logged), or if the definition is missing or not an object
     */
    private JsonObject resolveRef(String ref) {
        if (!ref.startsWith("#/")) {
            this.warn("Unsupported reference: " + ref + " (only reference to definitions are supported)");
            return null;
        }
        String[] pathSegments = ref.substring(2).split("/");
        if (pathSegments.length != 2 || !"definitions".equals(pathSegments[0]) && !"$defs".equals(pathSegments[0])) {
            this.warn("Unsupported reference: " + ref + " (only reference to definitions are supported)");
            return null;
        }
        JsonElement defs = this.root.get(pathSegments[0]);
        if (defs == null || !defs.isJsonObject()) {
            return null;
        }
        JsonElement resolvedElement = defs.getAsJsonObject().get(pathSegments[1]);
        if (resolvedElement == null || !resolvedElement.isJsonObject()) {
            return null;
        }
        return resolvedElement.getAsJsonObject();
    }

    /** Target LLM provider; {@link #PASSTHROUGH} disables every transformation. */
    public static enum Provider {
        OPENAI,
        GEMINI,
        NOVA,
        MISTRAL,
        PASSTHROUGH;

    }
}

