KEYCLOAK-15898 Streamification of Keymanager

This commit is contained in:
Martin Kanis 2020-10-27 20:41:33 +01:00 committed by Hynek Mlnařík
parent 2fd6deaf63
commit 8d6577d66c
11 changed files with 244 additions and 162 deletions

View file

@ -21,17 +21,32 @@ import org.keycloak.crypto.KeyWrapper;
import org.keycloak.provider.Provider; import org.keycloak.provider.Provider;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/ */
public interface KeyProvider extends Provider { public interface KeyProvider extends Provider {
/** /**
* Returns the key * Returns the {@code KeyWrapper} for a {@code KeyProvider}.
* @return *
* @return Returns the {@code KeyWrapper} for a {@code KeyProvider}.
* @deprecated Use {@link #getKeysStream() getKeysStream} instead.
*/ */
List<KeyWrapper> getKeys(); @Deprecated
default List<KeyWrapper> getKeys() {
return getKeysStream().collect(Collectors.toList());
}
/**
* Returns the {@code KeyWrapper} for a {@code KeyProvider}.
*
* @return Returns the {@code KeyWrapper} for a {@code KeyProvider}.
*/
Stream<KeyWrapper> getKeysStream();
default void close() { default void close() {
} }

View file

@ -28,6 +28,8 @@ import java.security.PublicKey;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -38,9 +40,46 @@ public interface KeyManager {
KeyWrapper getKey(RealmModel realm, String kid, KeyUse use, String algorithm); KeyWrapper getKey(RealmModel realm, String kid, KeyUse use, String algorithm);
List<KeyWrapper> getKeys(RealmModel realm); /**
* Returns all {@code KeyWrapper} for the given realm.
*
* @param realm {@code RealmModel}.
* @return List of all {@code KeyWrapper} in the realm.
* @deprecated Use {@link #getKeysStream(RealmModel) getKeysStream} instead.
*/
@Deprecated
default List<KeyWrapper> getKeys(RealmModel realm) {
return getKeysStream(realm).collect(Collectors.toList());
}
List<KeyWrapper> getKeys(RealmModel realm, KeyUse use, String algorithm); /**
* Returns all {@code KeyWrapper} for the given realm.
* @param realm {@code RealmModel}.
* @return Stream of all {@code KeyWrapper} in the realm.
*/
Stream<KeyWrapper> getKeysStream(RealmModel realm);
/**
* Returns all {@code KeyWrapper} for the given realm that match given criteria.
* @param realm {@code RealmModel}.
* @param use {@code KeyUse}.
* @param algorithm {@code String}.
* @return List of all {@code KeyWrapper} in the realm.
* @deprecated Use {@link #getKeysStream(RealmModel, KeyUse, String) getKeysStream} instead.
*/
@Deprecated
default List<KeyWrapper> getKeys(RealmModel realm, KeyUse use, String algorithm) {
return getKeysStream(realm, use, algorithm).collect(Collectors.toList());
}
/**
* Returns all {@code KeyWrapper} for the given realm that match given criteria.
* @param realm {@code RealmModel}.
* @param use {@code KeyUse}.
* @param algorithm {@code String}.
* @return Stream of all {@code KeyWrapper} in the realm.
*/
Stream<KeyWrapper> getKeysStream(RealmModel realm, KeyUse use, String algorithm);
@Deprecated @Deprecated
ActiveRsaKey getActiveRsaKey(RealmModel realm); ActiveRsaKey getActiveRsaKey(RealmModel realm);

View file

@ -17,27 +17,41 @@
package org.keycloak.broker.saml; package org.keycloak.broker.saml;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.broker.provider.*; import org.keycloak.broker.provider.AbstractIdentityProvider;
import org.keycloak.broker.provider.AuthenticationRequest;
import org.keycloak.broker.provider.BrokeredIdentityContext;
import org.keycloak.broker.provider.IdentityBrokerException;
import org.keycloak.broker.provider.IdentityProviderDataMarshaller;
import org.keycloak.broker.provider.util.SimpleHttp; import org.keycloak.broker.provider.util.SimpleHttp;
import org.keycloak.common.util.PemUtils; import org.keycloak.common.util.PemUtils;
import org.keycloak.crypto.Algorithm;
import org.keycloak.crypto.KeyStatus; import org.keycloak.crypto.KeyStatus;
import org.keycloak.crypto.KeyUse;
import org.keycloak.dom.saml.v2.assertion.AssertionType; import org.keycloak.dom.saml.v2.assertion.AssertionType;
import org.keycloak.dom.saml.v2.assertion.AuthnStatementType; import org.keycloak.dom.saml.v2.assertion.AuthnStatementType;
import org.keycloak.dom.saml.v2.assertion.NameIDType; import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.assertion.SubjectType; import org.keycloak.dom.saml.v2.assertion.SubjectType;
import org.keycloak.dom.saml.v2.metadata.KeyTypes;
import org.keycloak.dom.saml.v2.protocol.AuthnRequestType; import org.keycloak.dom.saml.v2.protocol.AuthnRequestType;
import org.keycloak.dom.saml.v2.protocol.LogoutRequestType; import org.keycloak.dom.saml.v2.protocol.LogoutRequestType;
import org.keycloak.dom.saml.v2.protocol.ResponseType; import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.events.EventBuilder; import org.keycloak.events.EventBuilder;
import org.keycloak.keys.RsaKeyMetadata; import org.keycloak.models.FederatedIdentityModel;
import org.keycloak.models.*; import org.keycloak.models.KeyManager;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.protocol.oidc.OIDCLoginProtocol; import org.keycloak.protocol.oidc.OIDCLoginProtocol;
import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder; import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder;
import org.keycloak.protocol.saml.SamlService;
import org.keycloak.protocol.saml.SamlSessionUtils; import org.keycloak.protocol.saml.SamlSessionUtils;
import org.keycloak.protocol.saml.preprocessor.SamlAuthenticationPreprocessor; import org.keycloak.protocol.saml.preprocessor.SamlAuthenticationPreprocessor;
import org.keycloak.saml.*; import org.keycloak.saml.SAML2AuthnRequestBuilder;
import org.keycloak.saml.SAML2LogoutRequestBuilder;
import org.keycloak.saml.SAML2NameIDPolicyBuilder;
import org.keycloak.saml.SAML2RequestedAuthnContextBuilder;
import org.keycloak.saml.SPMetadataDescriptor;
import org.keycloak.saml.SamlProtocolExtensionsAwareBuilder.NodeGenerator; import org.keycloak.saml.SamlProtocolExtensionsAwareBuilder.NodeGenerator;
import org.keycloak.saml.SignatureAlgorithm;
import org.keycloak.saml.common.constants.GeneralConstants; import org.keycloak.saml.common.constants.GeneralConstants;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.exceptions.ConfigurationException; import org.keycloak.saml.common.exceptions.ConfigurationException;
@ -58,15 +72,14 @@ import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder; import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo; import javax.ws.rs.core.UriInfo;
import javax.xml.crypto.dsig.CanonicalizationMethod; import javax.xml.crypto.dsig.CanonicalizationMethod;
import javax.xml.parsers.ParserConfigurationException;
import java.net.URI; import java.net.URI;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Objects;
import java.util.TreeSet;
/** /**
* @author Pedro Igor * @author Pedro Igor
@ -324,21 +337,28 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider<SAMLIdentityP
String entityId = getEntityId(uriInfo, realm); String entityId = getEntityId(uriInfo, realm);
String nameIDPolicyFormat = getConfig().getNameIDPolicyFormat(); String nameIDPolicyFormat = getConfig().getNameIDPolicyFormat();
List<Element> signingKeys = new ArrayList<Element>(); List<Element> signingKeys = new LinkedList<>();
List<Element> encryptionKeys = new ArrayList<Element>(); List<Element> encryptionKeys = new LinkedList<>();
Set<RsaKeyMetadata> keys = new TreeSet<>((o1, o2) -> o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list session.keys().getKeysStream(realm, KeyUse.SIG, Algorithm.RS256)
? (int) (o2.getProviderPriority() - o1.getProviderPriority()) .filter(Objects::nonNull)
: (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1)); .filter(key -> key.getCertificate() != null)
keys.addAll(session.keys().getRsaKeys(realm)); .sorted(SamlService::compareKeys)
for (RsaKeyMetadata key : keys) { .forEach(key -> {
if (key == null || key.getCertificate() == null) continue; try {
Element element = SPMetadataDescriptor
.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()));
signingKeys.add(element);
signingKeys.add(SPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()))); if (key.getStatus() == KeyStatus.ACTIVE) {
encryptionKeys.add(element);
}
} catch (ParserConfigurationException e) {
logger.warn("Failed to export SAML SP Metadata!", e);
throw new RuntimeException(e);
}
});
if (key.getStatus() == KeyStatus.ACTIVE)
encryptionKeys.add(SPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate())));
}
String descriptor = SPMetadataDescriptor.getSPDescriptor(authnBinding, endpoint, endpoint, String descriptor = SPMetadataDescriptor.getSPDescriptor(authnBinding, endpoint, endpoint,
wantAuthnRequestsSigned, wantAssertionsSigned, wantAssertionsEncrypted, wantAuthnRequestsSigned, wantAssertionsSigned, wantAssertionsEncrypted,
entityId, nameIDPolicyFormat, signingKeys, encryptionKeys); entityId, nameIDPolicyFormat, signingKeys, encryptionKeys);

View file

@ -25,8 +25,7 @@ import org.keycloak.crypto.KeyWrapper;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.Collections; import java.util.stream.Stream;
import java.util.List;
public abstract class AbstractEcdsaKeyProvider implements KeyProvider { public abstract class AbstractEcdsaKeyProvider implements KeyProvider {
@ -51,8 +50,8 @@ public abstract class AbstractEcdsaKeyProvider implements KeyProvider {
protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model); protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model);
@Override @Override
public List<KeyWrapper> getKeys() { public Stream<KeyWrapper> getKeysStream() {
return Collections.singletonList(key); return Stream.of(key);
} }
protected KeyWrapper createKeyWrapper(KeyPair keyPair, String ecInNistRep) { protected KeyWrapper createKeyWrapper(KeyPair keyPair, String ecInNistRep) {

View file

@ -26,8 +26,7 @@ import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper; import org.keycloak.crypto.KeyWrapper;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import java.util.Collections; import java.util.stream.Stream;
import java.util.List;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -59,7 +58,7 @@ public abstract class AbstractGeneratedSecretKeyProvider implements KeyProvider
} }
@Override @Override
public List<KeyWrapper> getKeys() { public Stream<KeyWrapper> getKeysStream() {
KeyWrapper key = new KeyWrapper(); KeyWrapper key = new KeyWrapper();
key.setProviderId(model.getId()); key.setProviderId(model.getId());
@ -72,7 +71,7 @@ public abstract class AbstractGeneratedSecretKeyProvider implements KeyProvider
key.setStatus(status); key.setStatus(status);
key.setSecretKey(secretKey); key.setSecretKey(secretKey);
return Collections.singletonList(key); return Stream.of(key);
} }
@Override @Override

View file

@ -30,6 +30,7 @@ import java.security.KeyPair;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -60,8 +61,8 @@ public abstract class AbstractRsaKeyProvider implements KeyProvider {
protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model); protected abstract KeyWrapper loadKey(RealmModel realm, ComponentModel model);
@Override @Override
public List<KeyWrapper> getKeys() { public Stream<KeyWrapper> getKeysStream() {
return Collections.singletonList(key); return Stream.of(key);
} }
protected KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate) { protected KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate) {

View file

@ -31,8 +31,15 @@ import javax.crypto.SecretKey;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.PublicKey; import java.security.PublicKey;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.util.*; import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -77,15 +84,20 @@ public class DefaultKeyManager implements KeyManager {
} }
private KeyWrapper getActiveKey(List<KeyProvider> providers, RealmModel realm, KeyUse use, String algorithm) { private KeyWrapper getActiveKey(List<KeyProvider> providers, RealmModel realm, KeyUse use, String algorithm) {
for (KeyProvider p : providers) { Consumer<KeyWrapper> loggerConsumer = key -> {
for (KeyWrapper key : p .getKeys()) { if (logger.isTraceEnabled()) {
if (key.getStatus().isActive() && matches(key, use, algorithm)) { logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}",
if (logger.isTraceEnabled()) { realm.getName(), key.getKid(), algorithm, use.name());
logger.tracev("Active key found: realm={0} kid={1} algorithm={2} use={3}", realm.getName(), key.getKid(), algorithm, use.name()); }
} };
return key; for (KeyProvider p : providers) {
} Optional<KeyWrapper> keyWrapper = p.getKeysStream()
.filter(key -> key.getStatus().isActive() && matches(key, use, algorithm))
.peek(loggerConsumer)
.findFirst();
if (keyWrapper.isPresent()) {
return keyWrapper.get();
} }
} }
return null; return null;
@ -98,15 +110,21 @@ public class DefaultKeyManager implements KeyManager {
return null; return null;
} }
for (KeyProvider p : getProviders(realm)) { Consumer<KeyWrapper> loggerConsumer = key -> {
for (KeyWrapper key : p.getKeys()) { if (logger.isTraceEnabled()) {
if (key.getKid().equals(kid) && key.getStatus().isEnabled() && matches(key, use, algorithm)) { logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}",
if (logger.isTraceEnabled()) { realm.getName(), key.getKid(), algorithm, use.name());
logger.tracev("Found key: realm={0} kid={1} algorithm={2} use={3}", realm.getName(), key.getKid(), algorithm, use.name()); }
} };
return key; for (KeyProvider p : getProviders(realm)) {
} Optional<KeyWrapper> keyWrapper = p.getKeysStream()
.filter(key -> Objects.equals(key.getKid(), kid) && key.getStatus().isEnabled() && matches(key, use, algorithm))
.peek(loggerConsumer)
.findFirst();
if (keyWrapper.isPresent()) {
return keyWrapper.get();
} }
} }
@ -118,25 +136,15 @@ public class DefaultKeyManager implements KeyManager {
} }
@Override @Override
public List<KeyWrapper> getKeys(RealmModel realm, KeyUse use, String algorithm) { public Stream<KeyWrapper> getKeysStream(RealmModel realm, KeyUse use, String algorithm) {
List<KeyWrapper> keys = new LinkedList<>(); return getProviders(realm).stream()
for (KeyProvider p : getProviders(realm)) { .flatMap(p -> p.getKeysStream()
for (KeyWrapper key : p .getKeys()) { .filter(key -> key.getStatus().isEnabled() && matches(key, use, algorithm)));
if (key.getStatus().isEnabled() && matches(key, use, algorithm)) {
keys.add(key);
}
}
}
return keys;
} }
@Override @Override
public List<KeyWrapper> getKeys(RealmModel realm) { public Stream<KeyWrapper> getKeysStream(RealmModel realm) {
List<KeyWrapper> keys = new LinkedList<>(); return getProviders(realm).stream().flatMap(KeyProvider::getKeysStream);
for (KeyProvider p : getProviders(realm)) {
keys.addAll(p.getKeys());
}
return keys;
} }
@Override @Override
@ -191,49 +199,46 @@ public class DefaultKeyManager implements KeyManager {
@Override @Override
@Deprecated @Deprecated
public List<RsaKeyMetadata> getRsaKeys(RealmModel realm) { public List<RsaKeyMetadata> getRsaKeys(RealmModel realm) {
List<RsaKeyMetadata> keys = new LinkedList<>(); return getKeysStream(realm, KeyUse.SIG, Algorithm.RS256)
for (KeyWrapper key : getKeys(realm, KeyUse.SIG, Algorithm.RS256)) { .map(key -> {
RsaKeyMetadata m = new RsaKeyMetadata(); RsaKeyMetadata m = new RsaKeyMetadata();
m.setCertificate(key.getCertificate()); m.setCertificate(key.getCertificate());
m.setPublicKey((PublicKey) key.getPublicKey()); m.setPublicKey((PublicKey) key.getPublicKey());
m.setKid(key.getKid()); m.setKid(key.getKid());
m.setProviderId(key.getProviderId()); m.setProviderId(key.getProviderId());
m.setProviderPriority(key.getProviderPriority()); m.setProviderPriority(key.getProviderPriority());
m.setStatus(key.getStatus()); m.setStatus(key.getStatus());
return m;
keys.add(m); })
} .collect(Collectors.toList());
return keys;
} }
@Override @Override
public List<SecretKeyMetadata> getHmacKeys(RealmModel realm) { public List<SecretKeyMetadata> getHmacKeys(RealmModel realm) {
List<SecretKeyMetadata> keys = new LinkedList<>(); return getKeysStream(realm, KeyUse.SIG, Algorithm.HS256)
for (KeyWrapper key : getKeys(realm, KeyUse.SIG, Algorithm.HS256)) { .map(key -> {
SecretKeyMetadata m = new SecretKeyMetadata(); SecretKeyMetadata m = new SecretKeyMetadata();
m.setKid(key.getKid()); m.setKid(key.getKid());
m.setProviderId(key.getProviderId()); m.setProviderId(key.getProviderId());
m.setProviderPriority(key.getProviderPriority()); m.setProviderPriority(key.getProviderPriority());
m.setStatus(key.getStatus()); m.setStatus(key.getStatus());
return m;
keys.add(m); })
} .collect(Collectors.toList());
return keys;
} }
@Override @Override
public List<SecretKeyMetadata> getAesKeys(RealmModel realm) { public List<SecretKeyMetadata> getAesKeys(RealmModel realm) {
List<SecretKeyMetadata> keys = new LinkedList<>(); return getKeysStream(realm, KeyUse.ENC, Algorithm.AES)
for (KeyWrapper key : getKeys(realm, KeyUse.ENC, Algorithm.AES)) { .map(key -> {
SecretKeyMetadata m = new SecretKeyMetadata(); SecretKeyMetadata m = new SecretKeyMetadata();
m.setKid(key.getKid()); m.setKid(key.getKid());
m.setProviderId(key.getProviderId()); m.setProviderId(key.getProviderId());
m.setProviderPriority(key.getProviderPriority()); m.setProviderPriority(key.getProviderPriority());
m.setStatus(key.getStatus()); m.setStatus(key.getStatus());
return m;
keys.add(m); })
} .collect(Collectors.toList());
return keys;
} }
private boolean matches(KeyWrapper key, KeyUse use, String algorithm) { private boolean matches(KeyWrapper key, KeyUse use, String algorithm) {

View file

@ -25,7 +25,6 @@ import org.keycloak.OAuthErrorException;
import org.keycloak.common.ClientConnection; import org.keycloak.common.ClientConnection;
import org.keycloak.crypto.KeyType; import org.keycloak.crypto.KeyType;
import org.keycloak.crypto.KeyUse; import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.events.EventBuilder; import org.keycloak.events.EventBuilder;
import org.keycloak.forms.login.LoginFormsProvider; import org.keycloak.forms.login.LoginFormsProvider;
import org.keycloak.jose.jwk.JSONWebKeySet; import org.keycloak.jose.jwk.JSONWebKeySet;
@ -49,8 +48,7 @@ import org.keycloak.services.resources.Cors;
import org.keycloak.services.resources.RealmsResource; import org.keycloak.services.resources.RealmsResource;
import org.keycloak.services.util.CacheControlUtil; import org.keycloak.services.util.CacheControlUtil;
import java.util.LinkedList; import java.util.Objects;
import java.util.List;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.NotFoundException; import javax.ws.rs.NotFoundException;
@ -219,23 +217,22 @@ public class OIDCLoginProtocolService {
public Response certs() { public Response certs() {
checkSsl(); checkSsl();
List<JWK> keys = new LinkedList<>(); JWK[] jwks = session.keys().getKeysStream(realm)
for (KeyWrapper k : session.keys().getKeys(realm)) { .filter(k -> k.getStatus().isEnabled() && Objects.equals(k.getUse(), KeyUse.SIG) && k.getPublicKey() != null)
if (k.getStatus().isEnabled() && k.getUse().equals(KeyUse.SIG) && k.getPublicKey() != null) { .map(k -> {
JWKBuilder b = JWKBuilder.create().kid(k.getKid()).algorithm(k.getAlgorithm()); JWKBuilder b = JWKBuilder.create().kid(k.getKid()).algorithm(k.getAlgorithm());
if (k.getType().equals(KeyType.RSA)) { if (k.getType().equals(KeyType.RSA)) {
keys.add(b.rsa(k.getPublicKey(), k.getCertificate())); return b.rsa(k.getPublicKey(), k.getCertificate());
} else if (k.getType().equals(KeyType.EC)) { } else if (k.getType().equals(KeyType.EC)) {
keys.add(b.ec(k.getPublicKey())); return b.ec(k.getPublicKey());
} }
} return null;
} })
.filter(Objects::nonNull)
.toArray(JWK[]::new);
JSONWebKeySet keySet = new JSONWebKeySet(); JSONWebKeySet keySet = new JSONWebKeySet();
keySet.setKeys(jwks);
JWK[] k = new JWK[keys.size()];
k = keys.toArray(k);
keySet.setKeys(k);
Response.ResponseBuilder responseBuilder = Response.ok(keySet).cacheControl(CacheControlUtil.getDefaultCacheControl()); Response.ResponseBuilder responseBuilder = Response.ok(keySet).cacheControl(CacheControlUtil.getDefaultCacheControl());
return Cors.add(request, responseBuilder).allowedOrigins("*").auth().build(); return Cors.add(request, responseBuilder).allowedOrigins("*").auth().build();

View file

@ -31,8 +31,6 @@ import java.util.List;
import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamWriter; import javax.xml.stream.XMLStreamWriter;
import org.keycloak.saml.common.exceptions.ProcessingException; import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.processing.core.saml.v2.writers.SAMLMetadataWriter; import org.keycloak.saml.processing.core.saml.v2.writers.SAMLMetadataWriter;
@ -56,8 +54,8 @@ import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_
public class IDPMetadataDescriptor { public class IDPMetadataDescriptor {
public static String getIDPDescriptor(URI loginPostEndpoint, URI loginRedirectEndpoint, URI logoutEndpoint, public static String getIDPDescriptor(URI loginPostEndpoint, URI loginRedirectEndpoint, URI logoutEndpoint,
String entityId, boolean wantAuthnRequestsSigned, List<Element> signingCerts, List<Element> encryptionCerts) String entityId, boolean wantAuthnRequestsSigned, List<Element> signingCerts)
throws XMLStreamException, ProcessingException, ParserConfigurationException throws ProcessingException
{ {
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();

View file

@ -77,12 +77,11 @@ import javax.ws.rs.core.UriInfo;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.security.PublicKey; import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.stream.Collectors;
import java.util.TreeSet;
import org.keycloak.crypto.Algorithm; import org.keycloak.crypto.Algorithm;
import org.keycloak.crypto.KeyUse; import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper; import org.keycloak.crypto.KeyWrapper;
@ -93,6 +92,8 @@ import org.keycloak.saml.validators.DestinationValidator;
import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.sessions.AuthenticationSessionModel;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.xml.crypto.dsig.XMLSignature; import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.NodeList; import org.w3c.dom.NodeList;
@ -653,16 +654,18 @@ public class SamlService extends AuthorizationEndpointBase {
} }
public static String getIDPMetadataDescriptor(UriInfo uriInfo, KeycloakSession session, RealmModel realm) { public static String getIDPMetadataDescriptor(UriInfo uriInfo, KeycloakSession session, RealmModel realm) {
Set<KeyWrapper> keys = new TreeSet<>((o1, o2) -> o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list
? (int) (o2.getProviderPriority() - o1.getProviderPriority())
: (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1));
keys.addAll(session.keys().getKeys(realm, KeyUse.SIG, Algorithm.RS256));
try { try {
List<Element> signingKeys = new ArrayList<Element>(); List<Element> signingKeys = session.keys().getKeysStream(realm, KeyUse.SIG, Algorithm.RS256)
for (KeyWrapper key : keys) { .sorted(SamlService::compareKeys)
signingKeys.add(IDPMetadataDescriptor.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()))); .map(key -> {
} try {
return IDPMetadataDescriptor
.buildKeyInfoElement(key.getKid(), PemUtils.encodeCertificate(key.getCertificate()));
} catch (ParserConfigurationException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
return IDPMetadataDescriptor.getIDPDescriptor( return IDPMetadataDescriptor.getIDPDescriptor(
RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL), RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL),
@ -670,13 +673,19 @@ public class SamlService extends AuthorizationEndpointBase {
RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL), RealmsResource.protocolUrl(uriInfo).build(realm.getName(), SamlProtocol.LOGIN_PROTOCOL),
RealmsResource.realmBaseUrl(uriInfo).build(realm.getName()).toString(), RealmsResource.realmBaseUrl(uriInfo).build(realm.getName()).toString(),
true, true,
signingKeys, null); signingKeys);
} catch (Exception ex) { } catch (Exception ex) {
logger.error("Cannot generate IdP metadata", ex); logger.error("Cannot generate IdP metadata", ex);
return ""; return "";
} }
} }
public static int compareKeys(KeyWrapper o1, KeyWrapper o2) {
return o1.getStatus() == o2.getStatus() // Status can be only PASSIVE OR ACTIVE, push PASSIVE to end of list
? (int) (o2.getProviderPriority() - o1.getProviderPriority())
: (o1.getStatus() == KeyStatus.PASSIVE ? 1 : -1);
}
private boolean isClientProtocolCorrect(ClientModel clientModel) { private boolean isClientProtocolCorrect(ClientModel clientModel) {
if (SamlProtocol.LOGIN_PROTOCOL.equals(clientModel.getProtocol())) { if (SamlProtocol.LOGIN_PROTOCOL.equals(clientModel.getProtocol())) {
return true; return true;

View file

@ -20,10 +20,7 @@ package org.keycloak.services.resources.admin;
import org.jboss.resteasy.annotations.cache.NoCache; import org.jboss.resteasy.annotations.cache.NoCache;
import org.keycloak.common.util.PemUtils; import org.keycloak.common.util.PemUtils;
import org.keycloak.crypto.KeyWrapper; import org.keycloak.crypto.KeyWrapper;
import org.keycloak.jose.jws.AlgorithmType;
import org.keycloak.keys.SecretKeyMetadata;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeyManager;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.representations.idm.KeysMetadataRepresentation; import org.keycloak.representations.idm.KeysMetadataRepresentation;
import org.keycloak.services.resources.admin.permissions.AdminPermissionEvaluator; import org.keycloak.services.resources.admin.permissions.AdminPermissionEvaluator;
@ -32,9 +29,8 @@ import javax.ws.rs.GET;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.stream.Collectors;
/** /**
* @resource Key * @resource Key
@ -59,29 +55,33 @@ public class KeyResource {
auth.realm().requireViewRealm(); auth.realm().requireViewRealm();
KeysMetadataRepresentation keys = new KeysMetadataRepresentation(); KeysMetadataRepresentation keys = new KeysMetadataRepresentation();
keys.setKeys(new LinkedList<>());
keys.setActive(new HashMap<>()); keys.setActive(new HashMap<>());
for (KeyWrapper key : session.keys().getKeys(realm)) { List<KeysMetadataRepresentation.KeyMetadataRepresentation> realmKeys = session.keys().getKeysStream(realm)
KeysMetadataRepresentation.KeyMetadataRepresentation r = new KeysMetadataRepresentation.KeyMetadataRepresentation(); .map(key -> {
r.setProviderId(key.getProviderId()); if (key.getStatus().isActive()) {
r.setProviderPriority(key.getProviderPriority()); if (!keys.getActive().containsKey(key.getAlgorithm())) {
r.setKid(key.getKid()); keys.getActive().put(key.getAlgorithm(), key.getKid());
r.setStatus(key.getStatus() != null ? key.getStatus().name() : null); }
r.setType(key.getType()); }
r.setAlgorithm(key.getAlgorithm()); return toKeyMetadataRepresentation(key);
r.setPublicKey(key.getPublicKey() != null ? PemUtils.encodeKey(key.getPublicKey()) : null); })
r.setCertificate(key.getCertificate() != null ? PemUtils.encodeCertificate(key.getCertificate()) : null); .collect(Collectors.toList());
keys.getKeys().add(r); keys.setKeys(realmKeys);
if (key.getStatus().isActive()) {
if (!keys.getActive().containsKey(key.getAlgorithm())) {
keys.getActive().put(key.getAlgorithm(), key.getKid());
}
}
}
return keys; return keys;
} }
private KeysMetadataRepresentation.KeyMetadataRepresentation toKeyMetadataRepresentation(KeyWrapper key) {
KeysMetadataRepresentation.KeyMetadataRepresentation r = new KeysMetadataRepresentation.KeyMetadataRepresentation();
r.setProviderId(key.getProviderId());
r.setProviderPriority(key.getProviderPriority());
r.setKid(key.getKid());
r.setStatus(key.getStatus() != null ? key.getStatus().name() : null);
r.setType(key.getType());
r.setAlgorithm(key.getAlgorithm());
r.setPublicKey(key.getPublicKey() != null ? PemUtils.encodeKey(key.getPublicKey()) : null);
r.setCertificate(key.getCertificate() != null ? PemUtils.encodeCertificate(key.getCertificate()) : null);
return r;
}
} }