KEYCLOAK-1881 KeyLocator implementation for SAML descriptor
This commit is contained in:
parent
057cc37b60
commit
10deac0b06
4 changed files with 361 additions and 0 deletions
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright 2016 Red Hat, Inc. and/or its affiliates
|
||||
* and other contributors as indicated by the @author tags.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.keycloak.adapters.saml.descriptor.parsers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import javax.xml.crypto.MarshalException;
|
||||
import javax.xml.crypto.dom.DOMStructure;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
|
||||
import javax.xml.parsers.DocumentBuilder;
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
import javax.xml.parsers.ParserConfigurationException;
|
||||
import javax.xml.xpath.XPath;
|
||||
import javax.xml.xpath.XPathConstants;
|
||||
import javax.xml.xpath.XPathExpression;
|
||||
import javax.xml.xpath.XPathExpressionException;
|
||||
import javax.xml.xpath.XPathFactory;
|
||||
import org.keycloak.common.util.MultivaluedHashMap;
|
||||
import org.keycloak.saml.common.constants.JBossSAMLConstants;
|
||||
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
|
||||
import org.keycloak.saml.common.exceptions.ParsingException;
|
||||
import org.keycloak.saml.processing.core.util.NamespaceContext;
|
||||
import org.w3c.dom.Document;
|
||||
import org.w3c.dom.Element;
|
||||
import org.w3c.dom.Node;
|
||||
import org.w3c.dom.NodeList;
|
||||
import org.xml.sax.SAXException;
|
||||
|
||||
/**
|
||||
* Goes through the given XML file and extracts names, certificates and keys from the KeyInfo elements.
|
||||
* @author hmlnarik
|
||||
*/
|
||||
public class SamlDescriptorIDPKeysExtractor {
|
||||
|
||||
private static final NamespaceContext NS_CONTEXT = new NamespaceContext();
|
||||
static {
|
||||
NS_CONTEXT.addNsUriPair("m", JBossSAMLURIConstants.METADATA_NSURI.get());
|
||||
NS_CONTEXT.addNsUriPair("dsig", JBossSAMLURIConstants.XMLDSIG_NSURI.get());
|
||||
}
|
||||
|
||||
private final KeyInfoFactory kif = KeyInfoFactory.getInstance();
|
||||
|
||||
private final XPathFactory xPathfactory = XPathFactory.newInstance();
|
||||
private final XPath xpath = xPathfactory.newXPath();
|
||||
{
|
||||
xpath.setNamespaceContext(NS_CONTEXT);
|
||||
}
|
||||
|
||||
public MultivaluedHashMap<String, KeyInfo> parse(InputStream stream) throws ParsingException {
|
||||
MultivaluedHashMap<String, KeyInfo> res = new MultivaluedHashMap<>();
|
||||
|
||||
try {
|
||||
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
|
||||
factory.setNamespaceAware(true);
|
||||
DocumentBuilder builder = factory.newDocumentBuilder();
|
||||
Document doc = builder.parse(stream);
|
||||
|
||||
XPathExpression expr = xpath.compile("/m:EntitiesDescriptor/m:EntityDescriptor/m:IDPSSODescriptor/m:KeyDescriptor");
|
||||
NodeList keyDescriptors = (NodeList) expr.evaluate(doc, XPathConstants.NODESET);
|
||||
for (int i = 0; i < keyDescriptors.getLength(); i ++) {
|
||||
Node keyDescriptor = keyDescriptors.item(i);
|
||||
Element keyDescriptorEl = (Element) keyDescriptor;
|
||||
KeyInfo ki = processKeyDescriptor(keyDescriptorEl);
|
||||
if (ki != null) {
|
||||
String use = keyDescriptorEl.getAttribute(JBossSAMLConstants.USE.get());
|
||||
res.add(use, ki);
|
||||
}
|
||||
}
|
||||
} catch (SAXException | IOException | ParserConfigurationException | MarshalException | XPathExpressionException e) {
|
||||
throw new ParsingException("Error parsing SAML descriptor", e);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private KeyInfo processKeyDescriptor(Element keyDescriptor) throws MarshalException {
|
||||
NodeList childNodes = keyDescriptor.getElementsByTagNameNS(JBossSAMLURIConstants.XMLDSIG_NSURI.get(), JBossSAMLConstants.KEY_INFO.get());
|
||||
|
||||
if (childNodes.getLength() == 0) {
|
||||
return null;
|
||||
}
|
||||
Node keyInfoNode = childNodes.item(0);
|
||||
return (keyInfoNode == null) ? null : kif.unmarshalKeyInfo(new DOMStructure(keyInfoNode));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
* Copyright 2016 Red Hat, Inc. and/or its affiliates
|
||||
* and other contributors as indicated by the @author tags.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.keycloak.adapters.saml.rotation;
|
||||
|
||||
import java.security.Key;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.PublicKey;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyName;
|
||||
import org.apache.http.client.HttpClient;
|
||||
import org.jboss.logging.Logger;
|
||||
import org.keycloak.adapters.cloned.HttpAdapterUtils;
|
||||
import org.keycloak.adapters.cloned.HttpClientAdapterException;
|
||||
import org.keycloak.common.util.MultivaluedHashMap;
|
||||
import org.keycloak.common.util.Time;
|
||||
import org.keycloak.dom.saml.v2.metadata.KeyTypes;
|
||||
import org.keycloak.rotation.KeyLocator;
|
||||
import org.keycloak.saml.processing.api.util.KeyInfoTools;
|
||||
|
||||
/**
|
||||
* This class defines a {@link KeyLocator} that looks up public keys and certificates in IdP's
|
||||
* SAML descriptor (i.e. http://{host}/auth/realms/{realm}/protocol/saml/descriptor).
|
||||
*
|
||||
* Based on {@code JWKPublicKeyLocator}.
|
||||
*
|
||||
* @author hmlnarik
|
||||
*/
|
||||
public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<PublicKey> {
|
||||
|
||||
private static final Logger LOG = Logger.getLogger(SamlDescriptorPublicKeyLocator.class);
|
||||
|
||||
/**
|
||||
* Time between two subsequent requests (in seconds).
|
||||
*/
|
||||
private final int minTimeBetweenDescriptorRequests;
|
||||
|
||||
/**
|
||||
* Time to live for cache entries (in seconds).
|
||||
*/
|
||||
private final int cacheEntryTtl;
|
||||
|
||||
/**
|
||||
* Target descriptor URL.
|
||||
*/
|
||||
private final String descriptorUrl;
|
||||
|
||||
private final Map<String, PublicKey> publicKeyCache = new ConcurrentHashMap<>();
|
||||
|
||||
private final HttpClient client;
|
||||
|
||||
private volatile int lastRequestTime = 0;
|
||||
|
||||
public SamlDescriptorPublicKeyLocator(String descriptorUrl, int minTimeBetweenDescriptorRequests, int cacheEntryTtl, HttpClient httpClient) {
|
||||
this.minTimeBetweenDescriptorRequests = minTimeBetweenDescriptorRequests <= 0
|
||||
? 20
|
||||
: minTimeBetweenDescriptorRequests;
|
||||
|
||||
this.descriptorUrl = descriptorUrl;
|
||||
this.cacheEntryTtl = cacheEntryTtl;
|
||||
|
||||
this.client = httpClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Key getKey(String kid) throws KeyManagementException {
|
||||
if (kid == null) {
|
||||
LOG.debugf("Invalid key id: %s", kid);
|
||||
return null;
|
||||
}
|
||||
|
||||
LOG.tracef("Requested key id: %s", kid);
|
||||
|
||||
int currentTime = Time.currentTime();
|
||||
|
||||
PublicKey res;
|
||||
if (currentTime > this.lastRequestTime + this.cacheEntryTtl) {
|
||||
LOG.debugf("Performing regular cache cleanup.");
|
||||
res = refreshCertificateCacheAndGet(kid);
|
||||
} else {
|
||||
res = publicKeyCache.get(kid);
|
||||
|
||||
if (res == null) {
|
||||
if (currentTime > this.lastRequestTime + this.minTimeBetweenDescriptorRequests) {
|
||||
res = refreshCertificateCacheAndGet(kid);
|
||||
} else {
|
||||
LOG.debugf("Won't send request to realm SAML descriptor url, timeout not expired. Last request time was %d", lastRequestTime);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void refreshKeyCache() {
|
||||
LOG.info("Forcing key cache cleanup and refresh.");
|
||||
this.publicKeyCache.clear();
|
||||
refreshCertificateCacheAndGet(null);
|
||||
}
|
||||
|
||||
private synchronized PublicKey refreshCertificateCacheAndGet(String kid) {
|
||||
if (this.descriptorUrl == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
this.lastRequestTime = Time.currentTime();
|
||||
|
||||
LOG.debugf("Refreshing public key cache from %s", this.descriptorUrl);
|
||||
List<KeyInfo> signingCerts;
|
||||
try {
|
||||
MultivaluedHashMap<String, KeyInfo> certs = HttpAdapterUtils.downloadKeysFromSamlDescriptor(client, this.descriptorUrl);
|
||||
signingCerts = certs.get(KeyTypes.SIGNING.value());
|
||||
} catch (HttpClientAdapterException ex) {
|
||||
LOG.error("Could not refresh certificates from the server", ex);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (signingCerts == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
LOG.debugf("Certificates retrieved from server, filling public key cache");
|
||||
|
||||
// Only clear cache after it is certain that the SAML descriptor has been read successfully
|
||||
this.publicKeyCache.clear();
|
||||
|
||||
for (KeyInfo ki : signingCerts) {
|
||||
KeyName keyName = KeyInfoTools.getKeyName(ki);
|
||||
X509Certificate x509certificate = KeyInfoTools.getX509Certificate(ki);
|
||||
if (x509certificate != null && keyName != null) {
|
||||
LOG.tracef("Registering signing certificate %s", keyName.getName());
|
||||
this.publicKeyCache.put(keyName.getName(), x509certificate.getPublicKey());
|
||||
} else {
|
||||
LOG.tracef("Ignoring certificate %s: %s", keyName, x509certificate);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return (kid == null ? null : this.publicKeyCache.get(kid));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Keys retrieved from SAML descriptor at " + descriptorUrl;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<PublicKey> iterator() {
|
||||
if (this.publicKeyCache.isEmpty()) {
|
||||
refreshCertificateCacheAndGet(null);
|
||||
}
|
||||
|
||||
return this.publicKeyCache.values().iterator();
|
||||
}
|
||||
}
|
|
@ -26,26 +26,51 @@ public class Time {
|
|||
|
||||
private static int offset;
|
||||
|
||||
/**
|
||||
* Returns current time in seconds adjusted by adding {@link #offset) seconds.
|
||||
* @return see description
|
||||
*/
|
||||
public static int currentTime() {
|
||||
return ((int) (System.currentTimeMillis() / 1000)) + offset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns current time in milliseconds adjusted by adding {@link #offset) seconds.
|
||||
* @return see description
|
||||
*/
|
||||
public static long currentTimeMillis() {
|
||||
return System.currentTimeMillis() + (offset * 1000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns {@link Date} object, its value set to time
|
||||
* @param time Time in milliseconds since the epoch
|
||||
* @return see description
|
||||
*/
|
||||
public static Date toDate(int time) {
|
||||
return new Date(((long) time ) * 1000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns time in milliseconds for a time in seconds. No adjustment is made to the parameter.
|
||||
* @param time Time in seconds since the epoch
|
||||
* @return Time in milliseconds
|
||||
*/
|
||||
public static long toMillis(int time) {
|
||||
return ((long) time) * 1000;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Time offset in seconds that will be added to {@link #currentTime()} and {@link #currentTimeMillis()}.
|
||||
*/
|
||||
public static int getOffset() {
|
||||
return offset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets time offset in seconds that will be added to {@link #currentTime()} and {@link #currentTimeMillis()}.
|
||||
* @param offset Offset (in seconds)
|
||||
*/
|
||||
public static void setOffset(int offset) {
|
||||
Time.offset = offset;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright 2016 Red Hat, Inc. and/or its affiliates
|
||||
* and other contributors as indicated by the @author tags.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.keycloak.saml.processing.api.util;
|
||||
|
||||
import java.security.cert.X509Certificate;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
|
||||
import javax.xml.crypto.dsig.keyinfo.KeyName;
|
||||
import javax.xml.crypto.dsig.keyinfo.X509Data;
|
||||
|
||||
/**
|
||||
* Tools for {@link KeyInfo} object manipulation.
|
||||
* @author hmlnarik
|
||||
*/
|
||||
public class KeyInfoTools {
|
||||
|
||||
/**
|
||||
* Returns the first object of the given class from the given Iterable.
|
||||
* @param <T>
|
||||
* @param objects
|
||||
* @param clazz
|
||||
* @return The object or {@code null} if not found.
|
||||
*/
|
||||
public static <T> T getContent(Iterable<Object> objects, Class<T> clazz) {
|
||||
for (Object o : objects) {
|
||||
if (clazz.isInstance(o)) {
|
||||
return (T) o;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public static KeyName getKeyName(KeyInfo keyInfo) {
|
||||
return getContent(keyInfo.getContent(), KeyName.class);
|
||||
}
|
||||
|
||||
public static X509Data getX509Data(KeyInfo keyInfo) {
|
||||
return getContent(keyInfo.getContent(), X509Data.class);
|
||||
}
|
||||
|
||||
public static X509Certificate getX509Certificate(KeyInfo keyInfo) {
|
||||
X509Data d = getX509Data(keyInfo);
|
||||
return d == null ? null : getContent(d.getContent(), X509Certificate.class);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue