/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.security.jwt;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.security.jwt.JWTError;
import com.dataiku.dip.security.jwt.RemoteJWKException;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JOSEException;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JWSAlgorithm;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JWSHeader;
import com.dataiku.dss.shadelib.com.nimbusds.jose.KeySourceException;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.ECKey;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWKMatcher;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWKSelector;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWKSet;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.RSAKey;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.CachingJWKSetSource;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.JWKSource;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.RefreshAheadCachingJWKSetSource;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.JWSKeySelector;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.SecurityContext;
import com.dataiku.dss.shadelib.com.nimbusds.jose.util.events.EventListener;
import java.security.Key;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.stereotype.Service;

/**
 * Helpers for verifying JWT signatures against a JWK source.
 *
 * <p>Provides a {@link JWSKeySelector} that looks verification keys up by the
 * {@code kid} header when present (falling back to an algorithm-based lookup),
 * and applies the DSS default rate-limit / cache / refresh-ahead settings to a
 * {@link JWKSourceBuilder}. Every setting can be overridden through a
 * {@code dku.oauth2.jwkurl.cache.*} application property.
 */
@Service
public class JwtVerificationService {
    private static final DKULogger logger = DKULogger.getLogger("dku.jwt");
    private static final String JWK_URL_CACHE_PROPERTY_CACHE_TIME_TO_LIVE_IN_MINUTES = "dku.oauth2.jwkurl.cache.timeToLiveInMinutes";
    private static final long DEFAULT_CACHE_TIME_TO_LIVE_IN_MINUTES = 60L;
    private static final String JWK_URL_CACHE_PROPERTY_CACHE_REFRESH_TIMEOUT_IN_SECONDS = "dku.oauth2.jwkurl.cache.refreshTimeoutInSeconds";
    private static final long DEFAULT_CACHE_REFRESH_TIMEOUT_IN_SECONDS = 15L;
    private static final String JWK_URL_CACHE_PROPERTY_CACHE_REFRESH_AHEAD_TIME_IN_MINUTES = "dku.oauth2.jwkurl.cache.refreshAHeadTimeInMinutes";
    private static final long DEFAULT_REFRESH_AHEAD_TIME_IN_MINUTES = 10L;
    private static final String JWK_URL_CACHE_PROPERTY_REFRESH_AHEAD_SCHEDULED = "dku.oauth2.jwkurl.cache.scheduled";
    private static final boolean DEFAULT_REFRESH_AHEAD_SCHEDULED = true;
    private static final String JWK_URL_CACHE_PROPERTY_REFRESH_AHEAD_EVENT_LOGGING_ENABLED = "dku.oauth2.jwkurl.cache.event.logging.enabled";
    private static final boolean DEFAULT_REFRESH_AHEAD_EVENT_LOGGING = true;
    private static final String JWK_URL_CACHE_PROPERTY_RATE_LIMIT = "dku.oauth2.jwkurl.cache.rateLimit";
    private static final boolean DEFAULT_RATE_LIMIT = false;
    private static final long MILLIS_PER_SECOND = 1000L;
    private static final long MILLIS_PER_MINUTE = 60000L;

    /**
     * Builds a key selector backed by the given JWK source.
     *
     * @param keySource source the verification keys are looked up in
     * @return a selector that resolves keys by {@code kid} when the JWS header
     *         carries one, and by algorithm otherwise
     */
    public JWSKeySelector<SecurityContext> selectKeysFromKidOrAlg(JWKSource<SecurityContext> keySource) {
        return (jwsHeader, securityContext) -> this.selectJWSKeys(keySource, jwsHeader, securityContext);
    }

    /**
     * Resolves the candidate verification keys for a JWS header.
     *
     * @param keySource       source the keys are looked up in
     * @param jwsHeader       header of the JWS being verified
     * @param securityContext context forwarded to the key source, may be {@code null}
     * @return a non-empty list of public keys
     * @throws RemoteJWKException if the algorithm is missing or is neither in the
     *                            EC nor in the RSA family, or if no key is found
     * @throws KeySourceException if the key source lookup fails
     */
    private List<Key> selectJWSKeys(JWKSource<SecurityContext> keySource, JWSHeader jwsHeader, SecurityContext securityContext) throws KeySourceException {
        JWSAlgorithm algorithm = jwsHeader.getAlgorithm();
        boolean supportedAlgorithm = algorithm != null
                && (JWSAlgorithm.Family.EC.contains(algorithm) || JWSAlgorithm.Family.RSA.contains(algorithm));
        if (!supportedAlgorithm) {
            throw new RemoteJWKException(JWTError.ERR_JWT_ALGO_NOT_SUPPORTED, "The JWT algorithm '" + algorithm + "' is not supported. Only EC or RSA are supported currently");
        }
        if (StringUtils.isBlank(jwsHeader.getKeyID())) {
            return this.selectKeysByAlgorithm(keySource, jwsHeader, securityContext);
        }
        return this.selectKeysByKeyId(keySource, jwsHeader, securityContext);
    }

    /**
     * Looks keys up by the {@code kid} of the header and converts every matching
     * EC or RSA JWK to its public key. JWKs of another type, or that cannot be
     * converted, are logged and skipped.
     */
    private List<Key> selectKeysByKeyId(JWKSource<SecurityContext> keySource, JWSHeader jwsHeader, SecurityContext securityContext) throws KeySourceException {
        String keyId = jwsHeader.getKeyID();
        JWKSelector kidSelector = new JWKSelector(new JWKMatcher.Builder().keyID(keyId).build());
        // The caller's security context is forwarded to the source, as in the algorithm-based lookup.
        List<Key> publicKeys = keySource.get(kidSelector, securityContext).stream().<Key>map(jwk -> {
            try {
                if (jwk instanceof ECKey) {
                    return ((ECKey) jwk).toECPublicKey();
                }
                if (jwk instanceof RSAKey) {
                    return ((RSAKey) jwk).toRSAPublicKey();
                }
                logger.warnV("The JWK found '%s' is neither a RSA or EC key, JWK instance of '%s'. We skip this key", new Object[]{jwk.toJSONString(), jwk.getClass()});
                return null;
            }
            catch (JOSEException e) {
                // Keep the failure reason in the log instead of dropping the exception silently.
                logger.warnV("Could not get public key from key '%s' of type '%s'. We skip this key: %s", new Object[]{jwk.toJSONString(), jwk.getClass(), e.getMessage()});
                return null;
            }
        }).filter(Objects::nonNull).collect(Collectors.toList());
        if (publicKeys.isEmpty()) {
            logger.warnV("Couldn't find the JWK corresponding to kid %s in remote jwk set", new Object[]{keyId});
            throw new RemoteJWKException(JWTError.ERR_JWK_NOT_FOUND, "Failed to find JWK for the following kid='" + keyId + "'");
        }
        return publicKeys;
    }

    /**
     * Looks keys up with a {@link JWSVerificationKeySelector} restricted to the
     * algorithm of the header. Used when the header carries no {@code kid}.
     */
    private List<Key> selectKeysByAlgorithm(JWKSource<SecurityContext> keySource, JWSHeader jwsHeader, SecurityContext securityContext) throws KeySourceException {
        JWSAlgorithm algorithm = jwsHeader.getAlgorithm();
        List<Key> keys = new JWSVerificationKeySelector<>(algorithm, keySource).selectJWSKeys(jwsHeader, securityContext);
        if (keys.isEmpty()) {
            logger.warnV("JWT has no KID and we couldn't find the JWK corresponding to the algorithm %s in remote jwk set", new Object[]{algorithm});
            throw new RemoteJWKException(JWTError.ERR_JWK_NOT_FOUND, "Failed to find JWK for the following alg='" + algorithm + "'");
        }
        return keys;
    }

    /**
     * Applies the DSS rate-limit, cache and refresh-ahead settings to a JWK
     * source builder. Each value is read from its {@code dku.oauth2.jwkurl.cache.*}
     * property and falls back to the matching {@code DEFAULT_*} constant.
     *
     * @param jwkSourceBuilder builder to configure (mutated in place)
     * @param jwksUri          URI of the JWK set, only used in the cache event log messages
     */
    public void dkuDefaultJwkSrcConfiguration(JWKSourceBuilder<SecurityContext> jwkSourceBuilder, String jwksUri) {
        // Properties are expressed in minutes / seconds; the builder is given milliseconds.
        long cacheTimeToLiveInMs = MILLIS_PER_MINUTE * DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_CACHE_TIME_TO_LIVE_IN_MINUTES, DEFAULT_CACHE_TIME_TO_LIVE_IN_MINUTES);
        long cacheRefreshTimeoutInMs = MILLIS_PER_SECOND * DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_CACHE_REFRESH_TIMEOUT_IN_SECONDS, DEFAULT_CACHE_REFRESH_TIMEOUT_IN_SECONDS);
        long cacheRefreshAHeadTimeInMs = MILLIS_PER_MINUTE * DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_CACHE_REFRESH_AHEAD_TIME_IN_MINUTES, DEFAULT_REFRESH_AHEAD_TIME_IN_MINUTES);
        boolean cacheRefreshAHeadScheduled = DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_REFRESH_AHEAD_SCHEDULED, DEFAULT_REFRESH_AHEAD_SCHEDULED);
        boolean rateLimit = DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_RATE_LIMIT, DEFAULT_RATE_LIMIT);
        boolean loggingEnabled = DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_REFRESH_AHEAD_EVENT_LOGGING_ENABLED, DEFAULT_REFRESH_AHEAD_EVENT_LOGGING);
        jwkSourceBuilder.rateLimited(rateLimit);
        jwkSourceBuilder.cache(cacheTimeToLiveInMs, cacheRefreshTimeoutInMs);
        // When event logging is disabled a no-op listener is registered instead.
        jwkSourceBuilder.refreshAheadCache(cacheRefreshAHeadTimeInMs, cacheRefreshAHeadScheduled, loggingEnabled ? JwtVerificationService.getEventListener(jwksUri) : event -> {});
    }

    /**
     * Creates a listener that logs the JWK set cache events of one source.
     *
     * @param jwksUri URI of the JWK set, included in the log messages
     * @return a listener logging every event at info level, plus a detail line
     *         whose level depends on the event type
     */
    private static EventListener<CachingJWKSetSource<SecurityContext>, SecurityContext> getEventListener(String jwksUri) {
        return event -> {
            String eventName = event.getClass().getSimpleName();
            logger.infoV("JWKSet Event '%s' for source URI: %s", new Object[]{eventName, jwksUri});
            // Raw casts: only non-generic accessors of the event classes are used.
            if (event instanceof CachingJWKSetSource.RefreshCompletedEvent) {
                JwtVerificationService.logJwkSet(((CachingJWKSetSource.RefreshCompletedEvent) event).getJWKSet());
            } else if (event instanceof RefreshAheadCachingJWKSetSource.ScheduledRefreshCompletedEvent) {
                JwtVerificationService.logJwkSet(((RefreshAheadCachingJWKSetSource.ScheduledRefreshCompletedEvent) event).getJWKSet());
            } else if (event instanceof CachingJWKSetSource.RefreshInitiatedEvent || event instanceof CachingJWKSetSource.WaitingForRefreshEvent) {
                logger.traceV("JWK set cache refresh event: %s", new Object[]{eventName});
            } else if (event instanceof CachingJWKSetSource.RefreshTimedOutEvent) {
                logger.error("JWK set cache refresh timed out");
            } else if (event instanceof RefreshAheadCachingJWKSetSource.ScheduledRefreshFailed) {
                logger.error("JWK set scheduled refresh failed", ((RefreshAheadCachingJWKSetSource.ScheduledRefreshFailed) event).getException());
            } else if (event instanceof CachingJWKSetSource.UnableToRefreshEvent || event instanceof RefreshAheadCachingJWKSetSource.UnableToRefreshAheadOfExpirationEvent) {
                logger.error("Unable to refresh the JWK set cache");
            } else {
                logger.traceV("JWK source event: %s", new Object[]{eventName});
            }
        };
    }

    /**
     * Logs the content of a refreshed JWK set: the full set at trace level and a
     * per-key summary (kid, type, use) at info level.
     *
     * @param jwkSet the set carried by a refresh-completed event, may be {@code null}
     */
    private static void logJwkSet(JWKSet jwkSet) {
        if (jwkSet == null) {
            logger.infoV("No JWKs Set attached to the event", new Object[0]);
            return;
        }
        logger.traceV("JWK sets: %s", new Object[]{jwkSet});
        String keySummaries = jwkSet.getKeys().stream()
                .map(jwk -> String.format("{kid: '%s', type: '%s', use: '%s'}", jwk.getKeyID(), jwk.getKeyType() != null ? jwk.getKeyType().getValue() : "unknown", jwk.getKeyUse() != null ? jwk.getKeyUse().identifier() : "none"))
                .collect(Collectors.joining(", "));
        logger.infoV("JWK Set contains %d keys: [%s]", new Object[]{jwkSet.getKeys().size(), keySummaries});
    }
}

