/*
 * Decompiled with CFR 0.152.
 */
package io.gravitee.policy.jws;

import io.gravitee.gateway.api.ExecutionContext;
import io.gravitee.gateway.api.Request;
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.http.stream.TransformableRequestStreamBuilder;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.exception.TransformationException;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import io.gravitee.policy.api.annotations.OnRequestContent;
import io.gravitee.policy.jws.configuration.JWSPolicyConfiguration;
import io.gravitee.policy.jws.utils.JsonUtils;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.SignatureException;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.impl.DefaultClaims;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.Key;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.cert.CRLException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509CRL;
import java.security.cert.X509CRLEntry;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.bind.DatatypeConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import sun.security.x509.CRLDistributionPointsExtension;
import sun.security.x509.DistributionPoint;
import sun.security.x509.GeneralNames;
import sun.security.x509.URIName;
import sun.security.x509.X509CertImpl;

/**
 * Gateway policy that replaces a JWS request body with the JSON of its verified claims.
 *
 * <p>The signing key is resolved from the gateway {@link Environment} property
 * {@code policy.jws.kid.<kid>} (falling back to kid {@code default}). That property holds
 * either an SSH RSA public key or a path to a {@code .pem} file. The token must carry an
 * {@code x5c} header whose first certificate has the same RSA public key. Depending on the
 * configuration, the certificate's validity period and CRL revocation status are also checked.
 */
public class JWSPolicy {
    private static final Logger LOGGER = LoggerFactory.getLogger(JWSPolicy.class);
    private static final String DEFAULT_KID = "default";
    private static final String PUBLIC_KEY_PROPERTY = "policy.jws.kid.%s";
    private static final Pattern SSH_PUB_KEY = Pattern.compile("ssh-(rsa|dsa) ([A-Za-z0-9/+]+=*) (.*)");
    private static final String PEM_EXTENSION = ".pem";
    private static final String BEGIN_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----";
    private static final String END_PRIVATE_KEY = "-----END PRIVATE KEY-----";
    private static final String BEGIN_RSA_PRIVATE_KEY = "-----BEGIN RSA PRIVATE KEY-----";
    private static final String END_RSA_PRIVATE_KEY = "-----END RSA PRIVATE KEY-----";
    private static final String JOSE_JSON_TYP = "JOSE+JSON";
    private static final String JSON_TYP = "JSON";
    private static final String[] AUTHORIZED_TYPES = new String[]{JSON_TYP, JOSE_JSON_TYP};
    private static final String JSON_CTY = "json";
    private static final String APPLICATION_PREFIX = "application/";
    /** SSH wire-format prefix of an RSA key: a uint32 length of 7 followed by the bytes "ssh-rsa". */
    private static final byte[] SSH_RSA_PREFIX = new byte[]{0, 0, 0, 7, 115, 115, 104, 45, 114, 115, 97};

    private final JWSPolicyConfiguration jwsPolicyConfiguration;

    public JWSPolicy(JWSPolicyConfiguration jwsPolicyConfiguration) {
        this.jwsPolicyConfiguration = jwsPolicyConfiguration;
    }

    /**
     * Buffers the request body and replaces it with the output of {@link #map}.
     * The rewritten body has content type {@code application/json}.
     */
    @OnRequestContent
    public ReadWriteStream onRequestContent(Request request, ExecutionContext executionContext, PolicyChain policyChain) {
        return TransformableRequestStreamBuilder.on(request)
                .chain(policyChain)
                .contentType("application/json")
                .transform(this.map(executionContext, policyChain))
                .build();
    }

    /**
     * Builds the body transformer: it parses the body as a JWS and returns its claims as JSON.
     *
     * <p>Token and certificate validation failures fail the chain with {@code 401 Unauthorized}
     * and return {@code null}. Any other failure is rethrown as a {@link TransformationException}.
     */
    Function<Buffer, Buffer> map(ExecutionContext executionContext, PolicyChain policyChain) {
        return input -> {
            try {
                DefaultClaims jwtClaims = this.validateJsonWebToken(input.toString(), executionContext);
                return Buffer.buffer(JsonUtils.writeValueAsString(jwtClaims));
            } catch (ExpiredJwtException | MalformedJwtException | SignatureException | UnsupportedJwtException
                    | IllegalArgumentException | CertificateException ex) {
                LOGGER.error("Failed to decode JWS token", ex);
                policyChain.streamFailWith(PolicyResult.failure(401, "Unauthorized"));
                return null;
            } catch (Exception ex) {
                LOGGER.error("Error occurs while decoding JWS token", ex);
                throw new TransformationException("Unable to apply JWS decode: " + ex.getMessage(), ex);
            }
        };
    }

    /**
     * Verifies the JWS signature and its headers, then returns the claims.
     *
     * @throws MalformedJwtException if {@code typ}/{@code cty} are not JSON or {@code x5c} is missing/invalid
     * @throws SignatureException if the x5c certificate key differs from the configured key
     * @throws CertificateException if the certificate cannot be parsed, is out of its validity
     *         period (when configured) or is revoked (when configured)
     */
    private DefaultClaims validateJsonWebToken(String jwt, ExecutionContext executionContext) throws CertificateException {
        SigningKeyResolver signingKeyResolver = this.getSigningKeyResolverByGatewaySettings(executionContext);
        JwtParser jwtParser = Jwts.parser();
        jwtParser.setSigningKeyResolver(signingKeyResolver);
        Jws<Claims> token = jwtParser.parseClaimsJws(jwt);
        JwsHeader header = token.getHeader();

        // getType()/getContentType() stringify the header value instead of blindly casting it.
        String type = header.getType();
        if (type != null && !type.isEmpty()
                && !Arrays.asList(AUTHORIZED_TYPES).contains(type.toUpperCase(Locale.ROOT))) {
            throw new MalformedJwtException("Only " + Arrays.toString(AUTHORIZED_TYPES)
                    + " JWS typ header are authorized but was " + type);
        }

        String cty = header.getContentType();
        if (cty != null && !cty.isEmpty()) {
            // Accept both "json" and "application/json" (case-insensitive).
            String normalizedCty = cty.toLowerCase(Locale.ROOT).replace(APPLICATION_PREFIX, "");
            if (!JSON_CTY.equals(normalizedCty)) {
                throw new MalformedJwtException("Only json JWS cty header is authorized but was " + normalizedCty);
            }
        }

        X509Certificate cert = extractCertificateFromX5CHeader(readX5CHeader(header));

        // The key the parser verified with must be the same RSA key as the x5c certificate's key.
        Key resolvedKey = signingKeyResolver.resolveSigningKey(header, token.getBody());
        PublicKey certificateKey = cert.getPublicKey();
        if (!(resolvedKey instanceof RSAPublicKey) || !(certificateKey instanceof RSAPublicKey)) {
            throw new SignatureException("Only RSA public keys can be compared with the x5c certificate");
        }
        RSAPublicKey givenPublicKey = (RSAPublicKey) resolvedKey;
        RSAPublicKey certificatePublicKey = (RSAPublicKey) certificateKey;
        if (!certificatePublicKey.getPublicExponent().equals(givenPublicKey.getPublicExponent())) {
            throw new SignatureException("Certificate public key exponent is different compare to the given public key exponent");
        }
        if (!certificatePublicKey.getModulus().equals(givenPublicKey.getModulus())) {
            throw new SignatureException("Certificate public key modulus is different compare to the given public key modulus");
        }

        if (this.jwsPolicyConfiguration.isCheckCertificateValidity()) {
            cert.checkValidity();
        }
        if (this.jwsPolicyConfiguration.isCheckCertificateRevocation()) {
            this.validateCRLSFromCertificate(cert);
        }
        return (DefaultClaims) token.getBody();
    }

    /**
     * Reads the {@code x5c} header as a {@code String[]}.
     *
     * @throws MalformedJwtException if the header is absent, is not a non-empty list,
     *         or contains a non-string entry
     */
    private static String[] readX5CHeader(JwsHeader header) {
        Object x5cValue = header.get("x5c");
        if (!(x5cValue instanceof List) || ((List<?>) x5cValue).isEmpty()) {
            throw new MalformedJwtException("X5C JWS Header is missing");
        }
        List<?> x5cList = (List<?>) x5cValue;
        String[] x5c = new String[x5cList.size()];
        for (int i = 0; i < x5c.length; i++) {
            Object entry = x5cList.get(i);
            if (!(entry instanceof String)) {
                throw new MalformedJwtException("X5C JWS Header must only contain base64 encoded certificates");
            }
            x5c[i] = (String) entry;
        }
        return x5c;
    }

    /**
     * Builds a resolver that looks up the public key by the token's {@code kid} in the gateway
     * settings.
     *
     * <p>It returns {@code null} when the property is unset or blank, or when the value is neither
     * an SSH public key nor a {@code .pem} path. It also returns {@code null} when the PEM file
     * cannot be loaded.
     */
    private SigningKeyResolver getSigningKeyResolverByGatewaySettings(final ExecutionContext executionContext) {
        return new SigningKeyResolverAdapter() {
            @Override
            public Key resolveSigningKey(JwsHeader header, Claims claims) {
                String keyId = header.getKeyId();
                if (keyId == null || keyId.isEmpty()) {
                    keyId = DEFAULT_KID;
                }
                Environment env = (Environment) executionContext.getComponent(Environment.class);
                String publicKey = env.getProperty(String.format(PUBLIC_KEY_PROPERTY, keyId));
                if (publicKey == null || publicKey.trim().isEmpty()) {
                    return null;
                }
                if (SSH_PUB_KEY.matcher(publicKey).matches()) {
                    return parsePublicKey(publicKey);
                }
                if (publicKey.endsWith(PEM_EXTENSION)) {
                    try {
                        return extractPublicKeyFromPEMFile(publicKey);
                    } catch (Exception e) {
                        LOGGER.error("Failed to load PEM file", e);
                        return null;
                    }
                }
                return null;
            }
        };
    }

    /**
     * Checks {@code certificate} against the CRLs listed in its CRL Distribution Points extension.
     *
     * <p>Each distribution point's URI names are tried in order until a CRL lists the serial.
     * An error is raised only when the last name tried in the last distribution point failed.
     *
     * <p>NOTE(review): the downloaded CRL's signature is not verified, the HTTP fetch has no
     * connect/read timeout, and nothing is cached. Consider addressing all three.
     *
     * @param serialNumber serial to look up; {@code null} means the certificate's own serial
     * @throws CertificateException if the extension is missing or unreadable, the certificate is
     *         revoked, or the final CRL lookup failed
     */
    public void validateCRLSFromCertificate(X509Certificate certificate, BigInteger serialNumber) throws CertificateException {
        X509CertImpl x509Cert = (X509CertImpl) certificate;
        CRLDistributionPointsExtension crlDistroExtension = x509Cert.getCRLDistributionPointsExtension();
        if (crlDistroExtension == null) {
            throw new CertificateException("Failed to find CRL distribution points for the given certificate");
        }

        List<?> distributionPoints;
        try {
            distributionPoints = (List<?>) crlDistroExtension.get("points");
        } catch (IOException ex) {
            throw new CertificateException("Failed to get CRL distribution points", ex);
        }

        BigInteger serialToCheck = serialNumber != null ? serialNumber : certificate.getSerialNumber();
        X509CRLEntry revokedCertificate = null;
        Iterator<?> iterator = distributionPoints.iterator();
        while (iterator.hasNext() && revokedCertificate == null) {
            GeneralNames distroNames = ((DistributionPoint) iterator.next()).getFullName();
            // Tracks whether the most recent lookup in this distribution point failed.
            boolean hasError = false;
            if (distroNames != null) {
                for (int i = 0; i < distroNames.size() && revokedCertificate == null; ++i) {
                    Object generalName = distroNames.get(i).getName();
                    if (!(generalName instanceof URIName)) {
                        // Only URI names can be fetched; treat anything else as a failed lookup.
                        hasError = true;
                        continue;
                    }
                    URI uri = ((URIName) generalName).getURI();
                    try {
                        revokedCertificate = fetchRevokedEntry(uri, serialToCheck);
                        hasError = false;
                    } catch (IOException | CRLException | CertificateException ex) {
                        LOGGER.warn("Unable to check certificate revocation using CRL at {}", uri, ex);
                        hasError = true;
                    }
                }
            }
            if (hasError && !iterator.hasNext()) {
                throw new CertificateException("An error has occurred while checking if certificate was revoked");
            }
        }
        if (revokedCertificate != null) {
            throw new CertificateException("Certificate has been revoked");
        }
    }

    /** Downloads the CRL at {@code crlUri} and returns its entry for {@code serial}, or {@code null}. */
    private static X509CRLEntry fetchRevokedEntry(URI crlUri, BigInteger serial)
            throws IOException, CRLException, CertificateException {
        URLConnection connection = new URL(crlUri.toString()).openConnection();
        try (InputStream inStream = connection.getInputStream()) {
            X509CRL crl = (X509CRL) certificateFactory().generateCRL(inStream);
            return crl.getRevokedCertificate(serial);
        }
    }

    private void validateCRLSFromCertificate(X509Certificate certificate) throws CertificateException {
        this.validateCRLSFromCertificate(certificate, null);
    }

    /**
     * Parses an OpenSSH public key line ({@code ssh-rsa <base64> <comment>}).
     *
     * @return the RSA key, or {@code null} if {@code key} does not match the SSH key pattern
     * @throws IllegalArgumentException if the key is not RSA or its encoding is invalid
     */
    static RSAPublicKey parsePublicKey(String key) {
        Matcher m = SSH_PUB_KEY.matcher(key);
        if (!m.matches()) {
            return null;
        }
        String alg = m.group(1);
        if (!"rsa".equalsIgnoreCase(alg)) {
            throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + alg);
        }
        return parseSSHPublicKey(m.group(2));
    }

    /** Decodes the base64 SSH wire-format blob: "ssh-rsa" prefix, then exponent e, then modulus n. */
    private static RSAPublicKey parseSSHPublicKey(String encKey) {
        ByteArrayInputStream in = new ByteArrayInputStream(Base64.getDecoder().decode(encKey));
        byte[] prefix = new byte[SSH_RSA_PREFIX.length];
        try {
            if (in.read(prefix) != prefix.length || !Arrays.equals(SSH_RSA_PREFIX, prefix)) {
                throw new IllegalArgumentException("SSH key prefix not found");
            }
            BigInteger e = new BigInteger(readBigInteger(in));
            BigInteger n = new BigInteger(readBigInteger(in));
            return createPublicKey(n, e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds an RSA public key from modulus {@code n} and public exponent {@code e}. */
    static RSAPublicKey createPublicKey(BigInteger n, BigInteger e) {
        try {
            return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e));
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Parses the first (leaf) certificate of an {@code x5c} chain.
     *
     * @throws CertificateException if the chain is empty or the certificate cannot be parsed
     */
    static X509Certificate extractCertificateFromX5CHeader(String[] x5c) throws CertificateException {
        if (x5c == null || x5c.length == 0 || x5c[0] == null) {
            throw new CertificateException("x5c header does not contain any certificate");
        }
        return extractCertificate(new ByteArrayInputStream(DatatypeConverter.parseBase64Binary(x5c[0])));
    }

    /**
     * Reads the certificate in {@code pemFile} and returns its public key.
     *
     * <p>Private-key BEGIN/END markers are stripped first. The X.509 parser then skips the bare key
     * body and finds the certificate's own {@code -----BEGIN} marker.
     */
    static PublicKey extractPublicKeyFromPEMFile(String pemFile) throws IOException, CertificateException {
        String fileContent = new String(Files.readAllBytes(Paths.get(pemFile)), StandardCharsets.UTF_8);
        fileContent = fileContent.replace(BEGIN_PRIVATE_KEY, "");
        fileContent = fileContent.replace(END_PRIVATE_KEY, "");
        fileContent = fileContent.replace(BEGIN_RSA_PRIVATE_KEY, "");
        fileContent = fileContent.replace(END_RSA_PRIVATE_KEY, "");
        X509Certificate cert = extractCertificate(new ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8)));
        return cert.getPublicKey();
    }

    static X509Certificate extractCertificate(InputStream inputStream) throws CertificateException {
        return (X509Certificate) certificateFactory().generateCertificate(inputStream);
    }

    static CertificateFactory certificateFactory() throws CertificateException {
        return CertificateFactory.getInstance("X.509");
    }

    /** Reads one SSH length-prefixed field: a big-endian uint32 length followed by that many bytes. */
    private static byte[] readBigInteger(ByteArrayInputStream in) throws IOException {
        byte[] lengthBytes = new byte[4];
        if (in.read(lengthBytes) != 4) {
            throw new IOException("Expected length data as 4 bytes");
        }
        // Java bytes are signed: mask each one. Otherwise a byte >= 0x80 sign-extends and produces a
        // negative length (e.g. the 129-byte modulus of a 1024-bit key).
        int length = ((lengthBytes[0] & 0xFF) << 24)
                | ((lengthBytes[1] & 0xFF) << 16)
                | ((lengthBytes[2] & 0xFF) << 8)
                | (lengthBytes[3] & 0xFF);
        if (length < 0 || length > in.available()) {
            throw new IOException("Expected " + length + " key bytes");
        }
        byte[] value = new byte[length];
        if (in.read(value) != length) {
            throw new IOException("Expected " + length + " key bytes");
        }
        return value;
    }
}

