diff --git a/core/src/main/java/org/keycloak/crypto/KeyWrapper.java b/core/src/main/java/org/keycloak/crypto/KeyWrapper.java index 130e8c3677..ace1930ab1 100644 --- a/core/src/main/java/org/keycloak/crypto/KeyWrapper.java +++ b/core/src/main/java/org/keycloak/crypto/KeyWrapper.java @@ -16,6 +16,7 @@ */ package org.keycloak.crypto; +import java.util.List; import javax.crypto.SecretKey; import java.security.Key; import java.security.cert.X509Certificate; @@ -33,6 +34,7 @@ public class KeyWrapper { private Key publicKey; private Key privateKey; private X509Certificate certificate; + private List certificateChain; public String getProviderId() { return providerId; @@ -122,4 +124,12 @@ public class KeyWrapper { this.certificate = certificate; } + public List getCertificateChain() { + return certificateChain; + } + + public void setCertificateChain(List certificateChain) { + this.certificateChain = certificateChain; + } + } diff --git a/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java b/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java index 5298d932ba..90f5b9f8d9 100644 --- a/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java +++ b/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java @@ -17,6 +17,8 @@ package org.keycloak.jose.jwk; +import java.util.Collections; +import java.util.List; import org.keycloak.common.util.Base64Url; import org.keycloak.common.util.KeyUtils; import org.keycloak.common.util.PemUtils; @@ -65,10 +67,14 @@ public class JWKBuilder { } public JWK rsa(Key key) { - return rsa(key, (X509Certificate)null); + return rsa(key, (List) null); } public JWK rsa(Key key, X509Certificate certificate) { + return rsa(key, Collections.singletonList(certificate)); + } + + public JWK rsa(Key key, List certificates) { RSAPublicKey rsaKey = (RSAPublicKey) key; RSAPublicJWK k = new RSAPublicJWK(); @@ -80,9 +86,13 @@ public class JWKBuilder { k.setPublicKeyUse(DEFAULT_PUBLIC_KEY_USE); k.setModulus(Base64Url.encode(toIntegerBytes(rsaKey.getModulus()))); k.setPublicExponent(Base64Url.encode(toIntegerBytes(rsaKey.getPublicExponent()))); - - if (certificate != null) { - k.setX509CertificateChain(new String [] {PemUtils.encodeCertificate(certificate)}); + + if (certificates != null && !certificates.isEmpty()) { + String[] certificateChain = new String[certificates.size()]; + for (int i = 0; i < certificates.size(); i++) { + certificateChain[i] = PemUtils.encodeCertificate(certificates.get(i)); + } + k.setX509CertificateChain(certificateChain); } return k; diff --git a/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java b/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java index 9bf7c3a42b..41dde853a8 100644 --- a/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java +++ b/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java @@ -17,9 +17,10 @@ package org.keycloak.jose.jwk; +import java.util.Arrays; +import java.util.List; import org.junit.Test; import org.keycloak.common.util.Base64Url; -import org.keycloak.common.util.CertificateUtils; import org.keycloak.common.util.KeyUtils; import org.keycloak.common.util.PemUtils; import org.keycloak.crypto.JavaAlgorithm; @@ -37,6 +38,7 @@ import java.security.cert.X509Certificate; import java.security.spec.ECGenParameterSpec; import static org.junit.Assert.*; +import static org.keycloak.common.util.CertificateUtils.*; /** * @author Stian Thorgersen @@ -47,7 +49,7 @@ public class JWKTest { public void publicRs256() throws Exception { KeyPair keyPair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); PublicKey publicKey = keyPair.getPublic(); - X509Certificate certificate = CertificateUtils.generateV1SelfSignedCertificate(keyPair, "Test"); + X509Certificate certificate = generateV1SelfSignedCertificate(keyPair, "Test"); JWK jwk = JWKBuilder.create().kid(KeyUtils.createKeyId(publicKey)).algorithm("RS256").rsa(publicKey, certificate); @@ -78,6 +80,47 @@ public class JWKTest { verify(data, sign, JavaAlgorithm.RS256, publicKeyFromJwk); } + @Test + public void publicRs256Chain() throws Exception { + KeyPair keyPair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); + PublicKey publicKey = keyPair.getPublic(); + List certificates = Arrays.asList(generateV1SelfSignedCertificate(keyPair, "Test"), generateV1SelfSignedCertificate(keyPair, "Intermediate")); + + JWK jwk = JWKBuilder.create().kid(KeyUtils.createKeyId(publicKey)).algorithm("RS256").rsa(publicKey, certificates); + + assertNotNull(jwk.getKeyId()); + assertEquals("RSA", jwk.getKeyType()); + assertEquals("RS256", jwk.getAlgorithm()); + assertEquals("sig", jwk.getPublicKeyUse()); + + assertTrue(jwk instanceof RSAPublicJWK); + assertNotNull(((RSAPublicJWK) jwk).getModulus()); + assertNotNull(((RSAPublicJWK) jwk).getPublicExponent()); + assertNotNull(((RSAPublicJWK) jwk).getX509CertificateChain()); + + String[] expectedChain = new String[certificates.size()]; + for (int i = 0; i < certificates.size(); i++) { + expectedChain[i] = PemUtils.encodeCertificate(certificates.get(i)); + } + + assertArrayEquals(expectedChain, ((RSAPublicJWK) jwk).getX509CertificateChain()); + assertNotNull(((RSAPublicJWK) jwk).getSha1x509Thumbprint()); + assertEquals(PemUtils.generateThumbprint(((RSAPublicJWK) jwk).getX509CertificateChain(), "SHA-1"), ((RSAPublicJWK) jwk).getSha1x509Thumbprint()); + assertNotNull(((RSAPublicJWK) jwk).getSha256x509Thumbprint()); + assertEquals(PemUtils.generateThumbprint(((RSAPublicJWK) jwk).getX509CertificateChain(), "SHA-256"), ((RSAPublicJWK) jwk).getSha256x509Thumbprint()); + + String jwkJson = JsonSerialization.writeValueAsString(jwk); + + PublicKey publicKeyFromJwk = JWKParser.create().parse(jwkJson).toPublicKey(); + + // Parse + assertArrayEquals(publicKey.getEncoded(), publicKeyFromJwk.getEncoded()); + + byte[] data = "Some test string".getBytes(StandardCharsets.UTF_8); + byte[] sign = sign(data, JavaAlgorithm.RS256, keyPair.getPrivate()); + verify(data, sign, JavaAlgorithm.RS256, publicKeyFromJwk); + } + @Test public void publicEs256() throws Exception { Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider()); diff --git a/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java b/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java index 94def6a72f..075f20d55f 100644 --- a/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java +++ b/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java @@ -19,11 +19,7 @@ package org.keycloak.keys; import org.keycloak.common.util.KeyUtils; import org.keycloak.component.ComponentModel; -import org.keycloak.crypto.Algorithm; -import org.keycloak.crypto.KeyStatus; -import org.keycloak.crypto.KeyType; -import org.keycloak.crypto.KeyUse; -import org.keycloak.crypto.KeyWrapper; +import org.keycloak.crypto.*; import org.keycloak.models.RealmModel; import java.security.KeyPair; @@ -66,6 +62,10 @@ public abstract class AbstractRsaKeyProvider implements KeyProvider { } protected KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate) { + return createKeyWrapper(keyPair, certificate, Collections.emptyList()); + } + + protected KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate, List certificateChain) { KeyWrapper key = new KeyWrapper(); key.setProviderId(model.getId()); @@ -80,6 +80,14 @@ public abstract class AbstractRsaKeyProvider implements KeyProvider { key.setPublicKey(keyPair.getPublic()); key.setCertificate(certificate); + if (!certificateChain.isEmpty()) { + if (certificate != null && !certificate.equals(certificateChain.get(0))) { + // just in case the chain does not contain the end-user certificate + certificateChain.add(0, certificate); + } + key.setCertificateChain(certificateChain); + } + return key; } diff --git a/services/src/main/java/org/keycloak/keys/JavaKeystoreKeyProvider.java b/services/src/main/java/org/keycloak/keys/JavaKeystoreKeyProvider.java index d085ba55b0..5cf134d332 100644 --- a/services/src/main/java/org/keycloak/keys/JavaKeystoreKeyProvider.java +++ b/services/src/main/java/org/keycloak/keys/JavaKeystoreKeyProvider.java @@ -27,6 +27,7 @@ import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; +import java.security.GeneralSecurityException; import java.security.KeyPair; import java.security.KeyStore; import java.security.KeyStoreException; @@ -34,8 +35,20 @@ import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.PublicKey; import java.security.UnrecoverableKeyException; +import java.security.cert.CertPath; +import java.security.cert.CertPathValidator; import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.PKIXParameters; +import java.security.cert.TrustAnchor; import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; /** * @author Stian Thorgersen @@ -52,17 +65,18 @@ public class JavaKeystoreKeyProvider extends AbstractRsaKeyProvider { KeyStore keyStore = KeyStore.getInstance("JKS"); keyStore.load(is, model.get(JavaKeystoreKeyProviderFactory.KEYSTORE_PASSWORD_KEY).toCharArray()); - PrivateKey privateKey = (PrivateKey) keyStore.getKey(model.get(JavaKeystoreKeyProviderFactory.KEY_ALIAS_KEY), model.get(JavaKeystoreKeyProviderFactory.KEY_PASSWORD_KEY).toCharArray()); + String keyAlias = model.get(JavaKeystoreKeyProviderFactory.KEY_ALIAS_KEY); + PrivateKey privateKey = (PrivateKey) keyStore.getKey(keyAlias, model.get(JavaKeystoreKeyProviderFactory.KEY_PASSWORD_KEY).toCharArray()); PublicKey publicKey = KeyUtils.extractPublicKey(privateKey); KeyPair keyPair = new KeyPair(publicKey, privateKey); - X509Certificate certificate = (X509Certificate) keyStore.getCertificate(model.get(JavaKeystoreKeyProviderFactory.KEY_ALIAS_KEY)); + X509Certificate certificate = (X509Certificate) keyStore.getCertificate(keyAlias); if (certificate == null) { certificate = CertificateUtils.generateV1SelfSignedCertificate(keyPair, realm.getName()); } - return createKeyWrapper(keyPair, certificate); + return createKeyWrapper(keyPair, certificate, loadCertificateChain(keyStore, keyAlias)); } catch (KeyStoreException kse) { throw new RuntimeException("KeyStore error on server. " + kse.getMessage(), kse); } catch (FileNotFoundException fnfe) { @@ -75,7 +89,49 @@ public class JavaKeystoreKeyProvider extends AbstractRsaKeyProvider { throw new RuntimeException("Certificate error on server. " + ce.getMessage(), ce); } catch (UnrecoverableKeyException uke) { throw new RuntimeException("Keystore on server can not be recovered. " + uke.getMessage(), uke); + } catch (GeneralSecurityException gse) { + throw new RuntimeException("Invalid certificate chain. Check the order of certificates.", gse); } } + private List loadCertificateChain(KeyStore keyStore, String keyAlias) throws GeneralSecurityException { + List chain = Optional.ofNullable(keyStore.getCertificateChain(keyAlias)) + .map(certificates -> Arrays.stream(certificates) + .map(X509Certificate.class::cast) + .collect(Collectors.toList())) + .orElseGet(Collections::emptyList); + + validateCertificateChain(chain); + + return chain; + } + + /** + *

Validates the giving certificate chain represented by {@code certificates}. If the list of certificates is empty + * or does not have at least 2 certificates (end-user certificate plus intermediary/root CAs) this method does nothing. + * + *

It should not be possible to import to keystores invalid chains though. So this is just an additional check + * that we can reuse later for other purposes when the cert chain is also provided manually, in PEM. + * + * @param certificates + */ + private void validateCertificateChain(List certificates) throws GeneralSecurityException { + if (certificates == null || certificates.isEmpty()) { + return; + } + + Set anchors = new HashSet<>(); + + // consider the last certificate in the chain as the most trusted cert + anchors.add(new TrustAnchor(certificates.get(certificates.size() - 1), null)); + + PKIXParameters params = new PKIXParameters(anchors); + + params.setRevocationEnabled(false); + + CertPath certPath = CertificateFactory.getInstance("X.509").generateCertPath(certificates); + CertPathValidator validator = CertPathValidator.getInstance(CertPathValidator.getDefaultType()); + + validator.validate(certPath, params); + } } diff --git a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java index 560f0e3965..e6bb3865b6 100644 --- a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java @@ -17,6 +17,10 @@ package org.keycloak.protocol.oidc; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; +import java.util.Optional; import org.jboss.logging.Logger; import org.jboss.resteasy.annotations.cache.NoCache; import org.jboss.resteasy.spi.HttpRequest; @@ -226,8 +230,11 @@ public class OIDCLoginProtocolService { .filter(k -> k.getStatus().isEnabled() && Objects.equals(k.getUse(), KeyUse.SIG) && k.getPublicKey() != null) .map(k -> { JWKBuilder b = JWKBuilder.create().kid(k.getKid()).algorithm(k.getAlgorithm()); + List certificates = Optional.ofNullable(k.getCertificateChain()) + .filter(certs -> !certs.isEmpty()) + .orElseGet(() -> Collections.singletonList(k.getCertificate())); if (k.getType().equals(KeyType.RSA)) { - return b.rsa(k.getPublicKey(), k.getCertificate()); + return b.rsa(k.getPublicKey(), certificates); } else if (k.getType().equals(KeyType.EC)) { return b.ec(k.getPublicKey()); } diff --git a/testsuite/integration-arquillian/tests/base/src/test/resources/org/keycloak/testsuite/keys/keystore.jks b/testsuite/integration-arquillian/tests/base/src/test/resources/org/keycloak/testsuite/keys/keystore.jks index ad54a7bcb7..4bdf272220 100644 Binary files a/testsuite/integration-arquillian/tests/base/src/test/resources/org/keycloak/testsuite/keys/keystore.jks and b/testsuite/integration-arquillian/tests/base/src/test/resources/org/keycloak/testsuite/keys/keystore.jks differ