diff --git a/server-spi-private/src/main/java/org/keycloak/keys/KeyProvider.java b/server-spi-private/src/main/java/org/keycloak/keys/KeyProvider.java index 9870ff33b1..0379bdb09c 100644 --- a/server-spi-private/src/main/java/org/keycloak/keys/KeyProvider.java +++ b/server-spi-private/src/main/java/org/keycloak/keys/KeyProvider.java @@ -21,17 +21,32 @@ import org.keycloak.crypto.KeyWrapper; import org.keycloak.provider.Provider; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * @author Stian Thorgersen */ public interface KeyProvider extends Provider { + /** - * Returns the key - * @return + * Returns the {@code KeyWrapper} for a {@code KeyProvider}. + * + * @return Returns the {@code KeyWrapper} for a {@code KeyProvider}. + * @deprecated Use {@link #getKeysStream() getKeysStream} instead. */ - List getKeys(); + @Deprecated + default List getKeys() { + return getKeysStream().collect(Collectors.toList()); + } + + /** + * Returns the {@code KeyWrapper} for a {@code KeyProvider}. + * + * @return Returns the {@code KeyWrapper} for a {@code KeyProvider}. + */ + Stream getKeysStream(); default void close() { } diff --git a/server-spi/src/main/java/org/keycloak/models/KeyManager.java b/server-spi/src/main/java/org/keycloak/models/KeyManager.java index f7a9b407e7..216d37d9fb 100644 --- a/server-spi/src/main/java/org/keycloak/models/KeyManager.java +++ b/server-spi/src/main/java/org/keycloak/models/KeyManager.java @@ -28,6 +28,8 @@ import java.security.PublicKey; import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * @author Stian Thorgersen @@ -38,9 +40,46 @@ public interface KeyManager { KeyWrapper getKey(RealmModel realm, String kid, KeyUse use, String algorithm); - List getKeys(RealmModel realm); + /** + * Returns all {@code KeyWrapper} for the given realm. + * + * @param realm {@code RealmModel}. + * @return List of all {@code KeyWrapper} in the realm. + * @deprecated Use {@link #getKeysStream(RealmModel) getKeysStream} instead. + */ + @Deprecated + default List getKeys(RealmModel realm) { + return getKeysStream(realm).collect(Collectors.toList()); + } - List getKeys(RealmModel realm, KeyUse use, String algorithm); + /** + * Returns all {@code KeyWrapper} for the given realm. + * @param realm {@code RealmModel}. + * @return Stream of all {@code KeyWrapper} in the realm. + */ + Stream getKeysStream(RealmModel realm); + + /** + * Returns all {@code KeyWrapper} for the given realm that match given criteria. + * @param realm {@code RealmModel}. + * @param use {@code KeyUse}. + * @param algorithm {@code String}. + * @return List of all {@code KeyWrapper} in the realm. + * @deprecated Use {@link #getKeysStream(RealmModel, KeyUse, String) getKeysStream} instead. + */ + @Deprecated + default List getKeys(RealmModel realm, KeyUse use, String algorithm) { + return getKeysStream(realm, use, algorithm).collect(Collectors.toList()); + } + + /** + * Returns all {@code KeyWrapper} for the given realm that match given criteria. + * @param realm {@code RealmModel}. + * @param use {@code KeyUse}. + * @param algorithm {@code String}. + * @return Stream of all {@code KeyWrapper} in the realm. + */ + Stream getKeysStream(RealmModel realm, KeyUse use, String algorithm); @Deprecated ActiveRsaKey getActiveRsaKey(RealmModel realm); diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java index 012d2e63e2..b312397094 100755 --- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java +++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java @@ -17,27 +17,41 @@ package org.keycloak.broker.saml; import org.jboss.logging.Logger; -import org.keycloak.broker.provider.*; +import org.keycloak.broker.provider.AbstractIdentityProvider; +import org.keycloak.broker.provider.AuthenticationRequest; +import org.keycloak.broker.provider.BrokeredIdentityContext; +import org.keycloak.broker.provider.IdentityBrokerException; +import org.keycloak.broker.provider.IdentityProviderDataMarshaller; import org.keycloak.broker.provider.util.SimpleHttp; import org.keycloak.common.util.PemUtils; +import org.keycloak.crypto.Algorithm; import org.keycloak.crypto.KeyStatus; +import org.keycloak.crypto.KeyUse; import org.keycloak.dom.saml.v2.assertion.AssertionType; import org.keycloak.dom.saml.v2.assertion.AuthnStatementType; import org.keycloak.dom.saml.v2.assertion.NameIDType; import org.keycloak.dom.saml.v2.assertion.SubjectType; -import org.keycloak.dom.saml.v2.metadata.KeyTypes; import org.keycloak.dom.saml.v2.protocol.AuthnRequestType; import org.keycloak.dom.saml.v2.protocol.LogoutRequestType; import org.keycloak.dom.saml.v2.protocol.ResponseType; import org.keycloak.events.EventBuilder; -import org.keycloak.keys.RsaKeyMetadata; -import org.keycloak.models.*; +import org.keycloak.models.FederatedIdentityModel; +import org.keycloak.models.KeyManager; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.RealmModel; +import org.keycloak.models.UserSessionModel; import org.keycloak.protocol.oidc.OIDCLoginProtocol; import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder; +import org.keycloak.protocol.saml.SamlService; import org.keycloak.protocol.saml.SamlSessionUtils; import org.keycloak.protocol.saml.preprocessor.SamlAuthenticationPreprocessor; -import org.keycloak.saml.*; +import org.keycloak.saml.SAML2AuthnRequestBuilder; +import org.keycloak.saml.SAML2LogoutRequestBuilder; +import org.keycloak.saml.SAML2NameIDPolicyBuilder; +import org.keycloak.saml.SAML2RequestedAuthnContextBuilder; +import org.keycloak.saml.SPMetadataDescriptor; import org.keycloak.saml.SamlProtocolExtensionsAwareBuilder.NodeGenerator; +import org.keycloak.saml.SignatureAlgorithm; import org.keycloak.saml.common.constants.GeneralConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.exceptions.ConfigurationException; @@ -58,15 +72,14 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.UriBuilder; import javax.ws.rs.core.UriInfo; import javax.xml.crypto.dsig.CanonicalizationMethod; +import javax.xml.parsers.ParserConfigurationException; import java.net.URI; import java.security.KeyPair; -import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedList; import java.util.Iterator; import java.util.List; -import java.util.Set; -import java.util.TreeSet; +import java.util.Objects; /** * @author Pedro Igor @@ -324,21 +337,28 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider signingKeys = new ArrayList(); - List encryptionKeys = new ArrayList(); + List signingKeys = new LinkedList<>(); + List encryptionKeys = new LinkedList<>(); - Set keys = new TreeSet<>((o1, o2) -> o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list - ? (int) (o2.getProviderPriority() - o1.getProviderPriority()) - : (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1)); - keys.addAll(session.keys().getRsaKeys(realm)); - for (RsaKeyMetadata key : keys) { - if (key == null || key.getCertificate() == null) continue; + session.keys().getKeysStream(realm, KeyUse.SIG, Algorithm.RS256) + .filter(Objects::nonNull) + .filter(key -> key.getCertificate() != null) + .sorted(SamlService::compareKeys) + .forEach(key -> { + try { + Element element = SPMetadataDescriptor + .buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate())); + signingKeys.add(element); - signingKeys.add(SPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()))); + if (key.getStatus() == KeyStatus.ACTIVE) { + encryptionKeys.add(element); + } + } catch (ParserConfigurationException e) { + logger.warn("Failed to export SAML SP Metadata!", e); + throw new RuntimeException(e); + } + }); - if (key.getStatus() == KeyStatus.ACTIVE) - encryptionKeys.add(SPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()))); - } String descriptor = SPMetadataDescriptor.getSPDescriptor(authnBinding, endpoint, endpoint, wantAuthnRequestsSigned, wantAssertionsSigned, wantAssertionsEncrypted, entityId, nameIDPolicyFormat, signingKeys, encryptionKeys); diff --git a/services/src/main/java/org/keycloak/keys/AbstractEcdsaKeyProvider.java b/services/src/main/java/org/keycloak/keys/AbstractEcdsaKeyProvider.java index e178ab16ba..647adb3c30 100644 --- a/services/src/main/java/org/keycloak/keys/AbstractEcdsaKeyProvider.java +++ b/services/src/main/java/org/keycloak/keys/AbstractEcdsaKeyProvider.java @@ -25,8 +25,7 @@ import org.keycloak.crypto.KeyWrapper; import org.keycloak.models.RealmModel; import java.security.KeyPair; -import java.util.Collections; -import java.util.List; +import java.util.stream.Stream; public abstract class AbstractEcdsaKeyProvider implements KeyProvider { @@ -51,8 +50,8 @@ public abstract class AbstractEcdsaKeyProvider implements KeyProvider { protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model); @Override - public List getKeys() { - return Collections.singletonList(key); + public Stream getKeysStream() { + return Stream.of(key); } protected KeyWrapper createKeyWrapper(KeyPair keyPair, String ecInNistRep) { diff --git a/services/src/main/java/org/keycloak/keys/AbstractGeneratedSecretKeyProvider.java b/services/src/main/java/org/keycloak/keys/AbstractGeneratedSecretKeyProvider.java index 3f34bf405a..51ed75b8b7 100644 --- a/services/src/main/java/org/keycloak/keys/AbstractGeneratedSecretKeyProvider.java +++ b/services/src/main/java/org/keycloak/keys/AbstractGeneratedSecretKeyProvider.java @@ -26,8 +26,7 @@ import org.keycloak.crypto.KeyUse; import org.keycloak.crypto.KeyWrapper; import javax.crypto.SecretKey; -import java.util.Collections; -import java.util.List; +import java.util.stream.Stream; /** * @author Stian Thorgersen @@ -59,7 +58,7 @@ public abstract class AbstractGeneratedSecretKeyProvider implements KeyProvider } @Override - public List getKeys() { + public Stream getKeysStream() { KeyWrapper key = new KeyWrapper(); key.setProviderId(model.getId()); @@ -72,7 +71,7 @@ public abstract class AbstractGeneratedSecretKeyProvider implements KeyProvider key.setStatus(status); key.setSecretKey(secretKey); - return Collections.singletonList(key); + return Stream.of(key); } @Override diff --git a/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java b/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java index 940d47d2a5..94def6a72f 100644 --- a/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java +++ b/services/src/main/java/org/keycloak/keys/AbstractRsaKeyProvider.java @@ -30,6 +30,7 @@ import java.security.KeyPair; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.List; +import java.util.stream.Stream; /** * @author Stian Thorgersen @@ -60,8 +61,8 @@ public abstract class AbstractRsaKeyProvider implements KeyProvider { protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model); @Override - public List getKeys() { - return Collections.singletonList(key); + public Stream getKeysStream() { + return Stream.of(key); } protected KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate) { diff --git a/services/src/main/java/org/keycloak/keys/DefaultKeyManager.java b/services/src/main/java/org/keycloak/keys/DefaultKeyManager.java index efc6639340..411888fab1 100644 --- a/services/src/main/java/org/keycloak/keys/DefaultKeyManager.java +++ b/services/src/main/java/org/keycloak/keys/DefaultKeyManager.java @@ -31,8 +31,15 @@ import javax.crypto.SecretKey; import java.security.PrivateKey; import java.security.PublicKey; import java.security.cert.Certificate; -import java.util.*; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * @author Stian Thorgersen @@ -77,15 +84,20 @@ public class DefaultKeyManager implements KeyManager { } private KeyWrapper getActiveKey(List providers, RealmModel realm, KeyUse use, String algorithm) { - for (KeyProvider p : providers) { - for (KeyWrapper key : p .getKeys()) { - if (key.getStatus().isActive() && matches(key, use, algorithm)) { - if (logger.isTraceEnabled()) { - logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}", realm.getName(), key.getKid(), algorithm, use.name()); - } + Consumer loggerConsumer = key -> { + if (logger.isTraceEnabled()) { + logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}", + realm.getName(), key.getKid(), algorithm, use.name()); + } + }; - return key; - } + for (KeyProvider p : providers) { + Optional keyWrapper = p.getKeysStream() + .filter(key -> key.getStatus().isActive() && matches(key, use, algorithm)) + .peek(loggerConsumer) + .findFirst(); + if (keyWrapper.isPresent()) { + return keyWrapper.get(); } } return null; @@ -98,15 +110,21 @@ public class DefaultKeyManager implements KeyManager { return null; } - for (KeyProvider p : getProviders(realm)) { - for (KeyWrapper key : p.getKeys()) { - if (key.getKid().equals(kid) && key.getStatus().isEnabled() && matches(key, use, algorithm)) { - if (logger.isTraceEnabled()) { - logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}", realm.getName(), key.getKid(), algorithm, use.name()); - } + Consumer loggerConsumer = key -> { + if (logger.isTraceEnabled()) { + logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}", + realm.getName(), key.getKid(), algorithm, use.name()); + } + }; - return key; - } + for (KeyProvider p : getProviders(realm)) { + Optional keyWrapper = p.getKeysStream() + .filter(key -> Objects.equals(key.getKid(), kid) && key.getStatus().isEnabled() && matches(key, use, algorithm)) + .peek(loggerConsumer) + .findFirst(); + + if (keyWrapper.isPresent()) { + return keyWrapper.get(); } } @@ -118,25 +136,15 @@ public class DefaultKeyManager implements KeyManager { } @Override - public List getKeys(RealmModel realm, KeyUse use, String algorithm) { - List keys = new LinkedList<>(); - for (KeyProvider p : getProviders(realm)) { - for (KeyWrapper key : p .getKeys()) { - if (key.getStatus().isEnabled() && matches(key, use, algorithm)) { - keys.add(key); - } - } - } - return keys; + public Stream getKeysStream(RealmModel realm, KeyUse use, String algorithm) { + return getProviders(realm).stream() + .flatMap(p -> p.getKeysStream() + .filter(key -> key.getStatus().isEnabled() && matches(key, use, algorithm))); } @Override - public List getKeys(RealmModel realm) { - List keys = new LinkedList<>(); - for (KeyProvider p : getProviders(realm)) { - keys.addAll(p.getKeys()); - } - return keys; + public Stream getKeysStream(RealmModel realm) { + return getProviders(realm).stream().flatMap(KeyProvider::getKeysStream); } @Override @@ -191,49 +199,46 @@ public class DefaultKeyManager implements KeyManager { @Override @Deprecated public List getRsaKeys(RealmModel realm) { - List keys = new LinkedList<>(); - for (KeyWrapper key : getKeys(realm, KeyUse.SIG, Algorithm.RS256)) { - RsaKeyMetadata m = new RsaKeyMetadata(); - m.setCertificate(key.getCertificate()); - m.setPublicKey((PublicKey) key.getPublicKey()); - m.setKid(key.getKid()); - m.setProviderId(key.getProviderId()); - m.setProviderPriority(key.getProviderPriority()); - m.setStatus(key.getStatus()); - - keys.add(m); - } - return keys; + return getKeysStream(realm, KeyUse.SIG, Algorithm.RS256) + .map(key -> { + RsaKeyMetadata m = new RsaKeyMetadata(); + m.setCertificate(key.getCertificate()); + m.setPublicKey((PublicKey) key.getPublicKey()); + m.setKid(key.getKid()); + m.setProviderId(key.getProviderId()); + m.setProviderPriority(key.getProviderPriority()); + m.setStatus(key.getStatus()); + return m; + }) + .collect(Collectors.toList()); } @Override public List getHmacKeys(RealmModel realm) { - List keys = new LinkedList<>(); - for (KeyWrapper key : getKeys(realm, KeyUse.SIG, Algorithm.HS256)) { - SecretKeyMetadata m = new SecretKeyMetadata(); - m.setKid(key.getKid()); - m.setProviderId(key.getProviderId()); - m.setProviderPriority(key.getProviderPriority()); - m.setStatus(key.getStatus()); - - keys.add(m); - } - return keys; + return getKeysStream(realm, KeyUse.SIG, Algorithm.HS256) + .map(key -> { + SecretKeyMetadata m = new SecretKeyMetadata(); + m.setKid(key.getKid()); + m.setProviderId(key.getProviderId()); + m.setProviderPriority(key.getProviderPriority()); + m.setStatus(key.getStatus()); + return m; + }) + .collect(Collectors.toList()); } @Override public List getAesKeys(RealmModel realm) { - List keys = new LinkedList<>(); - for (KeyWrapper key : getKeys(realm, KeyUse.ENC, Algorithm.AES)) { - SecretKeyMetadata m = new SecretKeyMetadata(); - m.setKid(key.getKid()); - m.setProviderId(key.getProviderId()); - m.setProviderPriority(key.getProviderPriority()); - m.setStatus(key.getStatus()); - - keys.add(m); - } - return keys; + return getKeysStream(realm, KeyUse.ENC, Algorithm.AES) + .map(key -> { + SecretKeyMetadata m = new SecretKeyMetadata(); + m.setKid(key.getKid()); + m.setProviderId(key.getProviderId()); + m.setProviderPriority(key.getProviderPriority()); + m.setStatus(key.getStatus()); + return m; + }) + .collect(Collectors.toList()); } private boolean matches(KeyWrapper key, KeyUse use, String algorithm) { diff --git a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java index 579985410f..323e9baf3f 100644 --- a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java @@ -25,7 +25,6 @@ import org.keycloak.OAuthErrorException; import org.keycloak.common.ClientConnection; import org.keycloak.crypto.KeyType; import org.keycloak.crypto.KeyUse; -import org.keycloak.crypto.KeyWrapper; import org.keycloak.events.EventBuilder; import org.keycloak.forms.login.LoginFormsProvider; import org.keycloak.jose.jwk.JSONWebKeySet; @@ -49,8 +48,7 @@ import org.keycloak.services.resources.Cors; import org.keycloak.services.resources.RealmsResource; import org.keycloak.services.util.CacheControlUtil; -import java.util.LinkedList; -import java.util.List; +import java.util.Objects; import javax.ws.rs.GET; import javax.ws.rs.NotFoundException; @@ -219,23 +217,22 @@ public class OIDCLoginProtocolService { public Response certs() { checkSsl(); - List keys = new LinkedList<>(); - for (KeyWrapper k : session.keys().getKeys(realm)) { - if (k.getStatus().isEnabled() && k.getUse().equals(KeyUse.SIG) && k.getPublicKey() != null) { - JWKBuilder b = JWKBuilder.create().kid(k.getKid()).algorithm(k.getAlgorithm()); - if (k.getType().equals(KeyType.RSA)) { - keys.add(b.rsa(k.getPublicKey(), k.getCertificate())); - } else if (k.getType().equals(KeyType.EC)) { - keys.add(b.ec(k.getPublicKey())); - } - } - } + JWK[] jwks = session.keys().getKeysStream(realm) + .filter(k -> k.getStatus().isEnabled() && Objects.equals(k.getUse(), KeyUse.SIG) && k.getPublicKey() != null) + .map(k -> { + JWKBuilder b = JWKBuilder.create().kid(k.getKid()).algorithm(k.getAlgorithm()); + if (k.getType().equals(KeyType.RSA)) { + return b.rsa(k.getPublicKey(), k.getCertificate()); + } else if (k.getType().equals(KeyType.EC)) { + return b.ec(k.getPublicKey()); + } + return null; + }) + .filter(Objects::nonNull) + .toArray(JWK[]::new); JSONWebKeySet keySet = new JSONWebKeySet(); - - JWK[] k = new JWK[keys.size()]; - k = keys.toArray(k); - keySet.setKeys(k); + keySet.setKeys(jwks); Response.ResponseBuilder responseBuilder = Response.ok(keySet).cacheControl(CacheControlUtil.getDefaultCacheControl()); return Cors.add(request, responseBuilder).allowedOrigins("*").auth().build(); diff --git a/services/src/main/java/org/keycloak/protocol/saml/IDPMetadataDescriptor.java b/services/src/main/java/org/keycloak/protocol/saml/IDPMetadataDescriptor.java index 1f95c81c69..0774b1d459 100644 --- a/services/src/main/java/org/keycloak/protocol/saml/IDPMetadataDescriptor.java +++ b/services/src/main/java/org/keycloak/protocol/saml/IDPMetadataDescriptor.java @@ -31,8 +31,6 @@ import java.util.List; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; -import javax.xml.parsers.ParserConfigurationException; -import javax.xml.stream.XMLStreamException; import javax.xml.stream.XMLStreamWriter; import org.keycloak.saml.common.exceptions.ProcessingException; import org.keycloak.saml.processing.core.saml.v2.writers.SAMLMetadataWriter; @@ -56,8 +54,8 @@ import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_ public class IDPMetadataDescriptor { public static String getIDPDescriptor(URI loginPostEndpoint, URI loginRedirectEndpoint, URI logoutEndpoint, - String entityId, boolean wantAuthnRequestsSigned, List signingCerts, List encryptionCerts) - throws XMLStreamException, ProcessingException, ParserConfigurationException + String entityId, boolean wantAuthnRequestsSigned, List signingCerts) + throws ProcessingException { StringWriter sw = new StringWriter(); diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java index f535acee71..088b5e1830 100755 --- a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java +++ b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java @@ -77,12 +77,11 @@ import javax.ws.rs.core.UriInfo; import java.io.InputStream; import java.net.URI; import java.security.PublicKey; -import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; +import java.util.stream.Collectors; + import org.keycloak.crypto.Algorithm; import org.keycloak.crypto.KeyUse; import org.keycloak.crypto.KeyWrapper; @@ -93,6 +92,8 @@ import org.keycloak.saml.validators.DestinationValidator; import org.keycloak.sessions.AuthenticationSessionModel; import javax.ws.rs.core.MultivaluedMap; import javax.xml.crypto.dsig.XMLSignature; +import javax.xml.parsers.ParserConfigurationException; + import org.w3c.dom.Document; import org.w3c.dom.NodeList; @@ -653,16 +654,18 @@ public class SamlService extends AuthorizationEndpointBase { } public static String getIDPMetadataDescriptor(UriInfo uriInfo, KeycloakSession session, RealmModel realm) { - Set keys = new TreeSet<>((o1, o2) -> o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list - ? (int) (o2.getProviderPriority() - o1.getProviderPriority()) - : (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1)); - keys.addAll(session.keys().getKeys(realm, KeyUse.SIG, Algorithm.RS256)); - try { - List signingKeys = new ArrayList(); - for (KeyWrapper key : keys) { - signingKeys.add(IDPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()))); - } + List signingKeys = session.keys().getKeysStream(realm, KeyUse.SIG, Algorithm.RS256) + .sorted(SamlService::compareKeys) + .map(key -> { + try { + return IDPMetadataDescriptor + .buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate())); + } catch (ParserConfigurationException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); return IDPMetadataDescriptor.getIDPDescriptor( RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL), @@ -670,13 +673,19 @@ public class SamlService extends AuthorizationEndpointBase { RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL), RealmsResource.realmBaseUrl(uriInfo).build(realm.getName()).toString(), true, - signingKeys, null); + signingKeys); } catch (Exception ex) { logger.error("Cannot generate IdP metadata", ex); return ""; } } + public static int compareKeys(KeyWrapper o1, KeyWrapper o2) { + return o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list + ? (int) (o2.getProviderPriority() - o1.getProviderPriority()) + : (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1); + } + private boolean isClientProtocolCorrect(ClientModel clientModel) { if (SamlProtocol.LOGIN_PROTOCOL.equals(clientModel.getProtocol())) { return true; diff --git a/services/src/main/java/org/keycloak/services/resources/admin/KeyResource.java b/services/src/main/java/org/keycloak/services/resources/admin/KeyResource.java index b47de30bb0..123570c081 100644 --- a/services/src/main/java/org/keycloak/services/resources/admin/KeyResource.java +++ b/services/src/main/java/org/keycloak/services/resources/admin/KeyResource.java @@ -20,10 +20,7 @@ package org.keycloak.services.resources.admin; import org.jboss.resteasy.annotations.cache.NoCache; import org.keycloak.common.util.PemUtils; import org.keycloak.crypto.KeyWrapper; -import org.keycloak.jose.jws.AlgorithmType; -import org.keycloak.keys.SecretKeyMetadata; import org.keycloak.models.KeycloakSession; -import org.keycloak.models.KeyManager; import org.keycloak.models.RealmModel; import org.keycloak.representations.idm.KeysMetadataRepresentation; import org.keycloak.services.resources.admin.permissions.AdminPermissionEvaluator; @@ -32,9 +29,8 @@ import javax.ws.rs.GET; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; -import java.util.Map; +import java.util.stream.Collectors; /** * @resource Key @@ -59,29 +55,33 @@ public class KeyResource { auth.realm().requireViewRealm(); KeysMetadataRepresentation keys = new KeysMetadataRepresentation(); - keys.setKeys(new LinkedList<>()); keys.setActive(new HashMap<>()); - for (KeyWrapper key : session.keys().getKeys(realm)) { - KeysMetadataRepresentation.KeyMetadataRepresentation r = new KeysMetadataRepresentation.KeyMetadataRepresentation(); - r.setProviderId(key.getProviderId()); - r.setProviderPriority(key.getProviderPriority()); - r.setKid(key.getKid()); - r.setStatus(key.getStatus() != null ? key.getStatus().name() : null); - r.setType(key.getType()); - r.setAlgorithm(key.getAlgorithm()); - r.setPublicKey(key.getPublicKey() != null ? PemUtils.encodeKey(key.getPublicKey()) : null); - r.setCertificate(key.getCertificate() != null ? PemUtils.encodeCertificate(key.getCertificate()) : null); - keys.getKeys().add(r); - - if (key.getStatus().isActive()) { - if (!keys.getActive().containsKey(key.getAlgorithm())) { - keys.getActive().put(key.getAlgorithm(), key.getKid()); - } - } - } + List realmKeys = session.keys().getKeysStream(realm) + .map(key -> { + if (key.getStatus().isActive()) { + if (!keys.getActive().containsKey(key.getAlgorithm())) { + keys.getActive().put(key.getAlgorithm(), key.getKid()); + } + } + return toKeyMetadataRepresentation(key); + }) + .collect(Collectors.toList()); + keys.setKeys(realmKeys); return keys; } + private KeysMetadataRepresentation.KeyMetadataRepresentation toKeyMetadataRepresentation(KeyWrapper key) { + KeysMetadataRepresentation.KeyMetadataRepresentation r = new KeysMetadataRepresentation.KeyMetadataRepresentation(); + r.setProviderId(key.getProviderId()); + r.setProviderPriority(key.getProviderPriority()); + r.setKid(key.getKid()); + r.setStatus(key.getStatus() != null ? key.getStatus().name() : null); + r.setType(key.getType()); + r.setAlgorithm(key.getAlgorithm()); + r.setPublicKey(key.getPublicKey() != null ? PemUtils.encodeKey(key.getPublicKey()) : null); + r.setCertificate(key.getCertificate() != null ? PemUtils.encodeCertificate(key.getCertificate()) : null); + return r; + } }