Better handling for SAML signatures in POST and REDIRECT bindings

Closes https://github.com/keycloak/keycloak/issues/17456
This commit is contained in:
rmartinc 2023-03-13 11:22:39 +01:00 committed by Pedro Igor
parent 5e7793b64d
commit cab7e50410
12 changed files with 755 additions and 199 deletions

View file

@ -19,6 +19,17 @@ package org.keycloak.adapters.saml.profile;
import static org.keycloak.adapters.saml.SamlPrincipal.DEFAULT_ROLE_ATTRIBUTE_NAME;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.namespace.QName;
import org.jboss.logging.Logger;
import org.keycloak.adapters.saml.AbstractInitiateLogin;
import org.keycloak.adapters.saml.OnSessionCreated;
@ -36,6 +47,7 @@ import org.keycloak.common.VerificationException;
import org.keycloak.common.util.Base64;
import org.keycloak.common.util.KeycloakUriBuilder;
import org.keycloak.common.util.MultivaluedHashMap;
import org.keycloak.dom.saml.v2.SAML2Object;
import org.keycloak.dom.saml.v2.assertion.AssertionType;
import org.keycloak.dom.saml.v2.assertion.AttributeStatementType;
import org.keycloak.dom.saml.v2.assertion.AttributeType;
@ -43,12 +55,14 @@ import org.keycloak.dom.saml.v2.assertion.AuthnStatementType;
import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.assertion.StatementAbstractType;
import org.keycloak.dom.saml.v2.assertion.SubjectType;
import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
import org.keycloak.dom.saml.v2.protocol.LogoutRequestType;
import org.keycloak.dom.saml.v2.protocol.RequestAbstractType;
import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.dom.saml.v2.protocol.StatusCodeType;
import org.keycloak.dom.saml.v2.protocol.StatusResponseType;
import org.keycloak.dom.saml.v2.protocol.StatusType;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.BaseSAML2BindingBuilder;
import org.keycloak.saml.SAML2AuthnRequestBuilder;
import org.keycloak.saml.SAMLRequestParser;
@ -62,32 +76,15 @@ import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.processing.api.saml.v2.sig.SAML2Signature;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil;
import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
import org.keycloak.saml.processing.core.util.RedirectBindingSignatureUtil;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.saml.processing.web.util.PostBindingUtil;
import org.keycloak.saml.validators.ConditionsValidator;
import org.keycloak.saml.validators.DestinationValidator;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import java.io.IOException;
import java.net.URI;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyManagementException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.*;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.namespace.QName;
import org.keycloak.dom.saml.v2.SAML2Object;
import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.saml.validators.ConditionsValidator;
import org.keycloak.saml.validators.DestinationValidator;
import javax.xml.crypto.dsig.XMLSignature;
import org.w3c.dom.NodeList;
/**
@ -676,7 +673,7 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.getFromXmlMethod(decodedAlgorithm);
if (! validateRedirectBindingSignature(signatureAlgorithm, rawQueryBytes, decodedSignature, keyLocator, keyId)) {
if (!RedirectBindingSignatureUtil.validateRedirectBindingSignature(signatureAlgorithm, rawQueryBytes, decodedSignature, keyLocator, keyId)) {
throw new VerificationException("Invalid query param signature");
}
} catch (Exception e) {
@ -684,67 +681,6 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
}
}
private boolean validateRedirectBindingSignature(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature, KeyLocator locator, String keyId)
throws KeyManagementException, VerificationException {
try {
Key key;
try {
key = locator.getKey(keyId);
boolean keyLocated = key != null;
if (keyLocated) {
return validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key);
}
} catch (KeyManagementException ex) {
}
} catch (SignatureException ex) {
log.debug("Verification failed for key %s: %s", keyId, ex);
log.trace(ex);
}
if (locator instanceof Iterable) {
Iterable<Key> availableKeys = (Iterable<Key>) locator;
log.trace("Trying hard to validate XML signature using all available keys.");
for (Key key : availableKeys) {
try {
if (validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key)) {
return true;
}
} catch (SignatureException ex) {
log.debug("Verification failed: %s", ex);
}
}
}
return false;
}
private boolean validateRedirectBindingSignatureForKey(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature, Key key)
throws SignatureException {
if (key == null) {
return false;
}
if (! (key instanceof PublicKey)) {
log.warnf("Unusable key for signature validation: %s", key);
return false;
}
Signature signature = sigAlg.createSignature(); // todo plugin signature alg
try {
signature.initVerify((PublicKey) key);
} catch (InvalidKeyException ex) {
log.warnf(ex, "Unusable key for signature validation: %s", key);
return false;
}
signature.update(rawQueryBytes);
return signature.verify(decodedSignature);
}
protected boolean isAutodetectedBearerOnly(HttpFacade.Request request) {
if (!deployment.isAutodetectBearerOnly()) return false;

View file

@ -19,12 +19,14 @@ package org.keycloak.adapters.saml.rotation;
import java.security.Key;
import java.security.KeyManagementException;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import javax.security.auth.x500.X500Principal;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyName;
import org.apache.http.client.HttpClient;
@ -36,9 +38,6 @@ import org.keycloak.common.util.Time;
import org.keycloak.dom.saml.v2.metadata.KeyTypes;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.processing.api.util.KeyInfoTools;
import java.security.cert.CertificateException;
import java.util.UUID;
import javax.security.auth.x500.X500Principal;
/**
* This class defines a {@link KeyLocator} that looks up public keys and certificates in IdP's
@ -48,7 +47,7 @@ import javax.security.auth.x500.X500Principal;
*
* @author hmlnarik
*/
public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<PublicKey> {
public class SamlDescriptorPublicKeyLocator implements KeyLocator {
private static final Logger LOG = Logger.getLogger(SamlDescriptorPublicKeyLocator.class);
@ -67,7 +66,8 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
*/
private final String descriptorUrl;
private final Map<String, PublicKey> publicKeyCache = new ConcurrentHashMap<>();
private final Map<String, Key> publicKeyCacheByName = new ConcurrentHashMap<>();
private final Map<KeyHash, Key> publicKeyCacheByKey = new ConcurrentHashMap<>();
private final HttpClient client;
@ -90,21 +90,32 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
LOG.debugf("Invalid key id: %s", kid);
return null;
}
return getKey(kid, publicKeyCacheByName);
}
LOG.tracef("Requested key id: %s", kid);
@Override
public Key getKey(Key key) throws KeyManagementException {
if (key == null) {
return null;
}
return getKey(new KeyHash(key), publicKeyCacheByKey);
}
private <T> Key getKey(T key, Map<T, Key> cache) throws KeyManagementException {
LOG.tracef("Requested key: %s", key);
int currentTime = Time.currentTime();
PublicKey res;
Key res;
if (currentTime > this.lastRequestTime + this.cacheEntryTtl) {
LOG.debugf("Performing regular cache cleanup.");
res = refreshCertificateCacheAndGet(kid);
res = refreshCertificateCacheAndGet(key, cache, currentTime);
} else {
res = publicKeyCache.get(kid);
res = cache.get(key);
if (res == null) {
if (currentTime > this.lastRequestTime + this.minTimeBetweenDescriptorRequests) {
res = refreshCertificateCacheAndGet(kid);
res = refreshCertificateCacheAndGet(key, cache, currentTime);
} else {
LOG.debugf("Won't send request to realm SAML descriptor url, timeout not expired. Last request time was %d", lastRequestTime);
}
@ -117,13 +128,15 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
@Override
public synchronized void refreshKeyCache() {
LOG.info("Forcing key cache cleanup and refresh.");
this.publicKeyCache.clear();
refreshCertificateCacheAndGet(null);
this.publicKeyCacheByName.clear();
this.publicKeyCacheByKey.clear();
refreshCertificateCacheAndGet(null, this.publicKeyCacheByKey, Time.currentTime());
}
private synchronized PublicKey refreshCertificateCacheAndGet(String kid) {
if (this.descriptorUrl == null) {
return null;
private synchronized <T> Key refreshCertificateCacheAndGet(T key, Map<T, Key> cache, int currentTime) {
if (this.descriptorUrl == null || currentTime <= this.lastRequestTime + this.minTimeBetweenDescriptorRequests) {
// no descriptor or updated time too short
return key == null ? null : cache.get(key);
}
this.lastRequestTime = Time.currentTime();
@ -145,7 +158,8 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
LOG.debugf("Certificates retrieved from server, filling public key cache");
// Only clear cache after it is certain that the SAML descriptor has been read successfully
this.publicKeyCache.clear();
this.publicKeyCacheByName.clear();
this.publicKeyCacheByKey.clear();
for (KeyInfo ki : signingCerts) {
KeyName keyName = KeyInfoTools.getKeyName(ki);
@ -161,17 +175,18 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
if (keyName != null) {
LOG.tracef("Registering signing certificate %s", keyName.getName());
this.publicKeyCache.put(keyName.getName(), x509certificate.getPublicKey());
this.publicKeyCacheByName.put(keyName.getName(), x509certificate.getPublicKey());
this.publicKeyCacheByKey.put(new KeyHash(x509certificate.getPublicKey()), x509certificate.getPublicKey());
} else {
final X500Principal principal = x509certificate.getSubjectX500Principal();
String name = (principal == null ? "unnamed" : principal.getName())
+ "@" + x509certificate.getSerialNumber() + "$" + UUID.randomUUID();
this.publicKeyCache.put(name, x509certificate.getPublicKey());
String name = (principal == null ? "unnamed" : principal.getName()) + "@" + x509certificate.getSerialNumber() + "$" + UUID.randomUUID();
this.publicKeyCacheByName.put(name, x509certificate.getPublicKey());
this.publicKeyCacheByKey.put(new KeyHash(x509certificate.getPublicKey()), x509certificate.getPublicKey());
LOG.tracef("Adding certificate %s without a specific key name: %s", name, x509certificate);
}
}
return (kid == null ? null : this.publicKeyCache.get(kid));
return key == null ? null : cache.get(key);
}
@Override
@ -180,11 +195,13 @@ public class SamlDescriptorPublicKeyLocator implements KeyLocator, Iterable<Publ
}
@Override
public Iterator<PublicKey> iterator() {
if (this.publicKeyCache.isEmpty()) {
refreshCertificateCacheAndGet(null);
public Iterator<Key> iterator() {
int currentTime = Time.currentTime();
if (currentTime > this.lastRequestTime + this.cacheEntryTtl) {
LOG.debugf("Performing regular cache cleanup.");
refreshCertificateCacheAndGet(null, publicKeyCacheByName, currentTime);
}
return this.publicKeyCache.values().iterator();
return this.publicKeyCacheByKey.values().iterator();
}
}

View file

@ -0,0 +1,182 @@
/*
* Copyright 2023 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.adapters.saml.rotation;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.io.StringWriter;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamWriter;
import org.apache.http.impl.client.HttpClients;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.keycloak.common.util.Time;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.SPMetadataDescriptor;
import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.common.util.StaxUtil;
import org.keycloak.saml.processing.core.util.XMLSignatureUtil;
import org.w3c.dom.Element;
/**
*
* @author rmartinc
*/
public class SamlDescriptorPublicKeyLocatorTest {
private static final String DESCRIPTOR_PREFIX =
"<EntityDescriptor ID=\"_46a4ff39-ad96-499d-91d9-040588865218\" entityID=\"http://adfs.server.url/adfs/services/trust\" xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" xmlns:ds=\"http://www.w3.org/2000/09/xmldsig#\">" +
"<IDPSSODescriptor protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\" WantAuthnRequestsSigned=\"true\">" +
"<KeyDescriptor use=\"signing\">";
private static final String DESCRIPTOR_SUFFIX =
"</KeyDescriptor>" +
"<SingleLogoutService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"https://adfs.server.url/adfs/ls/\"/>" +
"<SingleLogoutService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"https://adfs.server.url/adfs/ls/\"/>" +
"<NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</NameIDFormat>" +
"<NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:persistent</NameIDFormat>" +
"<NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>" +
"<SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"https://adfs.server.url/adfs/ls/\"/>" +
"<SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"https://adfs.server.url/adfs/ls/\"/>" +
"</IDPSSODescriptor>" +
"<ContactPerson contactType=\"support\"/>" +
"</EntityDescriptor>";
private static final String SAMPLE_CERTIFICATE_RSA_1 = "MIICmzCCAYMCBgGGc0gbfzANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjMwMjIxMDkyMDUwWhcNMzMwMjIxMDkyMjMwWjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCfc15pW/NOT3CM92q7BUB3pyTdA1h0WFG+JM2JjrNyZEbxsycYXS84QlaaEl/qT0wshIFQnv6bD1jy604V9W+7luK6Q/cOoQyRCiI70CVy4kB73sqT8Lgrfux6zWJeZ0lMO14sPq6eJLhWNBGxbGvJtUgBAdv5TIjf8yaHCV+yo4rc83T6Pd1sfTlRrURnokPD+hy+BbCEVj9350vYiyTRSvUD+e1wG1BIyZ/IA572p15rS69PP+qAuBBE8QF42bI56ZTsU+tXxwSX2nPqVbLD61tb1BFXfrHkArRiLe/Dte7xAmArynWs62ZI1q52REVWik1dzzy+VpJ7lef7vgtJAgMBAAEwDQYJKoZIhvcNAQELBQADggEBADB5DXugTWEYrw/ic/Jqz+aKXlz+QJvP5JEOVMnfKQLfHx+6760ubCwqJstA8HL6z8DWQUWWylwhfFv15nW/tgawbYLGHiq0NfB3/T6u3hswAPff9ZNvviL0L8CtPXpgPE5MzUEyPRIl/ExW/a7oNlo3rOPE6vA2xEG5h24f9xVdT5hGT5wRTm/e64ZT+umpWs2HnGjRcvdEKZhQPGfKrfdzNn1DVobbGSuy7P64lPWRJ/DxrhMwVkOyfZ+XoIGavS/yLQt01KjIrqtmUZOwHE5FRM/B58doGZn/zNpxq0tb7t9sxWIcW6wyZyieTAO7D9D84Qw8EBwKlbtsfS0oSZw=";
private static final String SAMPLE_CERTIFICATE_RSA_2 = "MIICmzCCAYMCBgGGc0gb+jANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjMwMjIxMDkyMDUxWhcNMzMwMjIxMDkyMjMxWjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDOgZKCSPgYFaBCLrhaX4jBjgTqdYemJPLyR3gAq3GhO/KVVj5i3lOJYLPE3TdyxowxpvnqJK5zIgLv954y7cbah5wbyfdcFf/qa/RvEDAVb1c3gs+7e5uEoiAWgvARQbcduuO8U/rerlgF3eN0WLjIjcz8yncLmMvd+AhjOAqs3AmKrlEADeABTRq454gXjrD8x3bZwRvC67ZOdK32WpfIG9u58WABDYHWavQ8aetcs1uuwbNl7Tmi0heEtgBd8q2y3BJmn31NXmRobLwNuILEN8sujMKf6iaISA50gh0TCUYSbzeeQ6DrqHBlOA8azpuwka4pQyr+R22MDdrItTc3AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAKuc82PlWzQbevzd/FvbutsEX5Tdf4Nojd+jOvcP6NiDtImWojzgN+SSAKTtmCz3ToBxjJbI4UjhovjWN4e4ygEWksBw6YYYR9ZGCJ7Z3EZzyREojvZeF/H0lQqB3BgnjI38HBpRgCpZm3H6+1UoJtMOW2sU8jorG/k1qvXrx2Y3bZvj/6wixVnzjiFzagb3cIUzv9c7ZWlexaR2Bg0k4kQ5TFwyzYCE136nl8xPqoDd8Nc4fQEPI7wLYMGglmbLFlGvdz3IJ7XRparYJRm4wlznQ43GL2x2KGBu8JipgbA7+u6F84oqf3vOC/PozWXzVCn08e6gqBY3YdZcs6sA3qY=";
private static HttpServer server;
private static final Map<String, String> signingCertificates = new HashMap<>();
private static final X509Certificate cert1;
private static final X509Certificate cert2;
static {
try {
cert1 = XMLSignatureUtil.getX509CertificateFromKeyInfoString(SAMPLE_CERTIFICATE_RSA_1);
cert2 = XMLSignatureUtil.getX509CertificateFromKeyInfoString(SAMPLE_CERTIFICATE_RSA_2);
} catch (ProcessingException e) {
throw new IllegalStateException(e);
}
}
@BeforeClass
public static void startHttpServer() throws IOException {
server = HttpServer.create(new InetSocketAddress(8280), 0);
server.createContext("/", new MyHandler());
server.setExecutor(null); // creates a default executor
server.start();
}
@AfterClass
public static void stopHttpServer() {
server.stop(0);
}
@After
public void resetTest() {
signingCertificates.clear();
Time.setOffset(0);
}
@Test
public void testKeyName() throws Exception {
signingCertificates.put("cert1", SAMPLE_CERTIFICATE_RSA_1);
KeyLocator locator = new SamlDescriptorPublicKeyLocator("http://localhost:8280", 10, 30, HttpClients.createDefault());
Assert.assertEquals(cert1.getPublicKey(), locator.getKey("cert1"));
Assert.assertNull(locator.getKey("cert2"));
signingCertificates.put("cert2", SAMPLE_CERTIFICATE_RSA_2);
Assert.assertNull(locator.getKey("cert2")); // not allowed to refresh
Time.setOffset(11);
signingCertificates.put("cert2", SAMPLE_CERTIFICATE_RSA_2);
Assert.assertEquals(cert2.getPublicKey(), locator.getKey("cert2"));
}
@Test
public void testCertificateKey() throws Exception {
signingCertificates.put("cert1", SAMPLE_CERTIFICATE_RSA_1);
KeyLocator locator = new SamlDescriptorPublicKeyLocator("http://localhost:8280", 10, 30, HttpClients.createDefault());
Assert.assertEquals(cert1.getPublicKey(), locator.getKey(cert1.getPublicKey()));
Assert.assertNull(locator.getKey("cert2"));
signingCertificates.put("cert2", SAMPLE_CERTIFICATE_RSA_2);
Assert.assertNull(locator.getKey(cert2.getPublicKey())); // not allowed to refresh
Time.setOffset(11);
signingCertificates.put("cert2", SAMPLE_CERTIFICATE_RSA_2);
Assert.assertEquals(cert2.getPublicKey(), locator.getKey(cert2.getPublicKey()));
}
@Test
public void testIteration() throws Exception {
signingCertificates.put("cert1", SAMPLE_CERTIFICATE_RSA_1);
KeyLocator locator = new SamlDescriptorPublicKeyLocator("http://localhost:8280", 10, 30, HttpClients.createDefault());
Set<Key> keys = StreamSupport.stream(locator.spliterator(), false).collect(Collectors.toSet());
Assert.assertTrue(keys.contains(cert1.getPublicKey()));
signingCertificates.put("cert2", SAMPLE_CERTIFICATE_RSA_2);
// not refreshed
keys = StreamSupport.stream(locator.spliterator(), false).collect(Collectors.toSet());
Assert.assertFalse(keys.contains(cert2.getPublicKey()));
Time.setOffset(11);
// still not refreshed, iterator waits for ttl
keys = StreamSupport.stream(locator.spliterator(), false).collect(Collectors.toSet());
Assert.assertFalse(keys.contains(cert2.getPublicKey()));
Time.setOffset(31);
// now should be refreshed
keys = StreamSupport.stream(locator.spliterator(), false).collect(Collectors.toSet());
Assert.assertTrue(keys.contains(cert2.getPublicKey()));
}
private static class MyHandler implements HttpHandler {
@Override
public void handle(HttpExchange t) throws IOException {
try {
StringBuilder sb = new StringBuilder(DESCRIPTOR_PREFIX);
for (Map.Entry<String, String> entry : signingCertificates.entrySet()) {
StringWriter sw = new StringWriter();
XMLStreamWriter writer = StaxUtil.getXMLStreamWriter(sw);
Element e = SPMetadataDescriptor.buildKeyInfoElement(entry.getKey(), entry.getValue());
StaxUtil.writeDOMElement(writer, e);
writer.close();
sb.append(sw.toString());
}
sb.append(DESCRIPTOR_SUFFIX);
byte[] bytes = sb.toString().getBytes(StandardCharsets.UTF_8);
t.getResponseHeaders().add("Content-Type", "application/xml;charset=UTF-8");
t.sendResponseHeaders(200, bytes.length);
try ( OutputStream os = t.getResponseBody()) {
os.write(bytes);
}
} catch (ParserConfigurationException | XMLStreamException | ProcessingException e) {
throw new IOException(e);
}
}
}
}

View file

@ -47,6 +47,18 @@ public class CompositeKeyLocator implements KeyLocator, Iterable<Key> {
return null;
}
@Override
public Key getKey(Key key) throws KeyManagementException {
for (KeyLocator keyLocator : keyLocators) {
Key k = keyLocator.getKey(key);
if (k != null) {
return k;
}
}
return null;
}
@Override
public void refreshKeyCache() {
for (KeyLocator keyLocator : keyLocators) {

View file

@ -21,34 +21,60 @@ import java.security.Key;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Key locator that always returns a specified key.
* Key locator for a bunch of keys. It can be initializaed with or without
* key names.
*
* @author <a href="mailto:hmlnarik@redhat.com">Hynek Mlnařík</a>
*/
public class HardcodedKeyLocator implements KeyLocator, Iterable<Key> {
private final Collection<? extends Key> keys;
private final Map<String, ? extends Key> byName;
private final Map<KeyHash, ? extends Key> byKey;
public HardcodedKeyLocator(Key key) {
this.keys = Collections.singleton(key);
Objects.requireNonNull(key, "Key must not be null");
this.byName = Collections.emptyMap();
this.byKey = Collections.singletonMap(new KeyHash(key), key);
}
public HardcodedKeyLocator(Collection<? extends Key> keys) {
if (keys == null) {
throw new NullPointerException("keys");
}
this.keys = new LinkedList<>(keys);
Objects.requireNonNull(keys, "Keys must not be null");
this.byName = Collections.emptyMap();
this.byKey = Collections.unmodifiableMap(keys.stream().collect(
Collectors.toMap(k -> new KeyHash(k), k -> k)));
}
public HardcodedKeyLocator(Map<String, ? extends Key> keys) {
Objects.requireNonNull(keys, "Keys must not be null");
this.byName = Collections.unmodifiableMap(keys);
this.byKey = Collections.unmodifiableMap(keys.values().stream().collect(
Collectors.toMap(k -> new KeyHash(k), k -> k)));
}
@Override
public Key getKey(String kid) {
if (this.keys.size() == 1) {
return this.keys.iterator().next();
} else {
if (this.byKey.size() == 1) {
return this.byKey.values().iterator().next();
} else if (kid == null) {
return null;
} else {
return this.byName.get(kid);
}
}
@Override
public Key getKey(Key key) {
if (this.byKey.size() == 1) {
return this.byKey.values().iterator().next();
} else if (key == null) {
return null;
} else {
return this.byKey.get(new KeyHash(key));
}
}
@ -59,11 +85,11 @@ public class HardcodedKeyLocator implements KeyLocator, Iterable<Key> {
@Override
public String toString() {
return "hardcoded keys, count: " + this.keys.size();
return "hardcoded keys, count: " + this.byKey.size();
}
@Override
public Iterator<Key> iterator() {
return (Iterator<Key>) Collections.unmodifiableCollection(keys).iterator();
return (Iterator<Key>) byKey.values().iterator();
}
}

View file

@ -18,7 +18,17 @@
package org.keycloak.rotation;
import java.security.Key;
import java.security.KeyException;
import java.security.KeyManagementException;
import java.security.MessageDigest;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import javax.xml.crypto.XMLStructure;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyName;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.crypto.dsig.keyinfo.X509Data;
/**
* This interface defines a method for obtaining a security key by ID.
@ -30,21 +40,108 @@ import java.security.KeyManagementException;
*
* @author <a href="mailto:hmlnarik@redhat.com">Hynek Mlnařík</a>
*/
public interface KeyLocator {
public interface KeyLocator extends Iterable<Key> {
/**
* Returns a key with a particular ID.
* @param kid Key ID
* @param configuration Configuration
* @return key, which should be used for verify signature on given "input"
* @throws KeyManagementException
*/
Key getKey(String kid) throws KeyManagementException;
/**
* Method that checks if the key passed is inside the locator.
* @param key The key to search
* @return The same key or null if it's not in the locator
* @throws KeyManagementException
*/
default Key getKey(Key key) throws KeyManagementException {
if (key == null) {
return null;
}
for (Key k : this) {
if (k.getAlgorithm().equals(key.getAlgorithm()) && MessageDigest.isEqual(k.getEncoded(), key.getEncoded())) {
return key;
}
}
return null;
}
/**
* Returns the key in the locator that is represented by the KeyInfo
* dsig structure. The default implementation just iterates and returns
* the first KeyName, X509Data or PublicKey that is in the locator.
* @param info The KeyInfo to search
* @return The key found or null
* @throws KeyManagementException
*/
default Key getKey(KeyInfo info) throws KeyManagementException {
if (info == null) {
return null;
}
Key key = null;
for (XMLStructure xs : (List<XMLStructure>) info.getContent()) {
if (xs instanceof KeyName) {
key = getKey(((KeyName) xs).getName());
} else if (xs instanceof X509Data) {
for (Object content : ((X509Data) xs).getContent()) {
if (content instanceof X509Certificate) {
key = getKey(((X509Certificate) content).getPublicKey());
if (key != null) {
return key;
}
// only the first X509Certificate is the signer
// the rest are just part of the chain
break;
}
}
} else if (xs instanceof KeyValue) {
try {
key = getKey(((KeyValue) xs).getPublicKey());
} catch (KeyException e) {
throw new KeyManagementException(e);
}
}
if (key != null) {
return key;
}
}
return null;
}
/**
* If this key locator caches keys in any way, forces this cache cleanup
* and refreshing the keys.
*/
void refreshKeyCache();
/**
* Helper class that facilitates the hash of a Key to be located easier.
*/
public static class KeyHash {
private final Key key;
private final int keyHash;
public KeyHash(Key key) {
this.key = key;
this.keyHash = Arrays.hashCode(key.getEncoded());
}
@Override
public int hashCode() {
return keyHash;
}
@Override
public boolean equals(Object o) {
if (o instanceof KeyHash) {
KeyHash other = (KeyHash) o;
return keyHash == other.keyHash &&
key.getAlgorithm().equals(other.key.getAlgorithm()) &&
MessageDigest.isEqual(key.getEncoded(), other.key.getEncoded());
}
return false;
}
}
}

View file

@ -0,0 +1,95 @@
/*
* Copyright 2023 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.saml.processing.core.util;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyManagementException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import org.jboss.logging.Logger;
import org.keycloak.common.VerificationException;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.SignatureAlgorithm;
/**
*
* @author rmartinc
*/
public class RedirectBindingSignatureUtil {
private static final Logger log = Logger.getLogger(RedirectBindingSignatureUtil.class);
private RedirectBindingSignatureUtil (){
// utility class
}
public static boolean validateRedirectBindingSignature(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature,
KeyLocator locator, String keyId) throws KeyManagementException, VerificationException {
try {
try {
Key key = locator.getKey(keyId);
if (key != null) {
return validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key);
}
} catch (KeyManagementException ex) {
}
} catch (SignatureException ex) {
log.debug("Verification failed for key %s: %s", keyId, ex);
log.trace(ex);
}
log.trace("Trying hard to validate XML signature using all available keys.");
for (Key key : locator) {
try {
if (validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key)) {
return true;
}
} catch (SignatureException ex) {
log.debug("Verification failed: %s", ex);
}
}
return false;
}
public static boolean validateRedirectBindingSignatureForKey(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature, Key key)
throws SignatureException {
if (key == null) {
return false;
}
if (!(key instanceof PublicKey)) {
log.warnf("Unusable key for signature validation: %s", key);
return false;
}
Signature signature = sigAlg.createSignature(); // todo plugin signature alg
try {
signature.initVerify((PublicKey) key);
} catch (InvalidKeyException ex) {
log.warnf(ex, "Unusable key for signature validation: %s", key);
return false;
}
signature.update(rawQueryBytes);
return signature.verify(decodedSignature);
}
}

View file

@ -88,10 +88,8 @@ import javax.xml.crypto.KeySelector;
import javax.xml.crypto.KeySelectorException;
import javax.xml.crypto.KeySelectorResult;
import javax.xml.crypto.XMLCryptoContext;
import javax.xml.crypto.dsig.keyinfo.KeyName;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.common.util.SecurityActions;
import org.keycloak.saml.processing.api.util.KeyInfoTools;
/**
* Utility for XML Signature <b>Note:</b> You can change the canonicalization method type by using the system property
@ -138,9 +136,7 @@ public class XMLSignatureUtil {
@Override
public KeySelectorResult select(KeyInfo keyInfo, KeySelector.Purpose purpose, AlgorithmMethod method, XMLCryptoContext context) throws KeySelectorException {
try {
KeyName keyNameEl = KeyInfoTools.getKeyName(keyInfo);
this.keyName = keyNameEl == null ? null : keyNameEl.getName();
final Key key = locator.getKey(keyName);
final Key key = locator.getKey(keyInfo);
this.keyLocated = key != null;
return new KeySelectorResult() {
@Override public Key getKey() {
@ -158,24 +154,6 @@ public class XMLSignatureUtil {
}
}
private static class KeySelectorPresetKey extends KeySelector {
private final Key key;
public KeySelectorPresetKey(Key key) {
this.key = key;
}
@Override
public KeySelectorResult select(KeyInfo keyInfo, KeySelector.Purpose purpose, AlgorithmMethod method, XMLCryptoContext context) {
return new KeySelectorResult() {
@Override public Key getKey() {
return key;
}
};
}
}
private static XMLSignatureFactory getXMLSignatureFactory() {
XMLSignatureFactory xsf = null;
@ -494,20 +472,16 @@ public class XMLSignatureUtil {
logger.trace("Could not validate signature using ds:KeyInfo/ds:KeyName hint.");
if (locator instanceof Iterable) {
Iterable<Key> availableKeys = (Iterable<Key>) locator;
logger.trace("Trying hard to validate XML signature using all available keys.");
logger.trace("Trying hard to validate XML signature using all available keys.");
for (Key key : availableKeys) {
try {
if (validateUsingKeySelector(signatureNode, new KeySelectorPresetKey(key))) {
return true;
}
} catch (XMLSignatureException ex) { // pass through MarshalException
logger.debug("Verification failed: " + ex);
logger.trace(ex);
for (Key key : locator) {
try {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key))) {
return true;
}
} catch (XMLSignatureException ex) { // pass through MarshalException
logger.debug("Verification failed: " + ex);
logger.trace(ex);
}
}
@ -738,7 +712,7 @@ public class XMLSignatureUtil {
signature.sign(dsc);
}
private static KeyInfo createKeyInfo(String keyName, PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
public static KeyInfo createKeyInfo(String keyName, PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
KeyInfoFactory keyInfoFactory = fac.getKeyInfoFactory();
List<XMLStructure> items = new LinkedList<>();

View file

@ -0,0 +1,142 @@
/*
* Copyright 2023 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.rotation;
import java.security.Key;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.processing.core.util.XMLSignatureUtil;
/**
*
* @author rmartinc
*/
public class HardcodedKeyLocatorTest {
private static final String SAMPLE_CERTIFICATE_RSA_1 = "MIICmzCCAYMCBgGGc0gbfzANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjMwMjIxMDkyMDUwWhcNMzMwMjIxMDkyMjMwWjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCfc15pW/NOT3CM92q7BUB3pyTdA1h0WFG+JM2JjrNyZEbxsycYXS84QlaaEl/qT0wshIFQnv6bD1jy604V9W+7luK6Q/cOoQyRCiI70CVy4kB73sqT8Lgrfux6zWJeZ0lMO14sPq6eJLhWNBGxbGvJtUgBAdv5TIjf8yaHCV+yo4rc83T6Pd1sfTlRrURnokPD+hy+BbCEVj9350vYiyTRSvUD+e1wG1BIyZ/IA572p15rS69PP+qAuBBE8QF42bI56ZTsU+tXxwSX2nPqVbLD61tb1BFXfrHkArRiLe/Dte7xAmArynWs62ZI1q52REVWik1dzzy+VpJ7lef7vgtJAgMBAAEwDQYJKoZIhvcNAQELBQADggEBADB5DXugTWEYrw/ic/Jqz+aKXlz+QJvP5JEOVMnfKQLfHx+6760ubCwqJstA8HL6z8DWQUWWylwhfFv15nW/tgawbYLGHiq0NfB3/T6u3hswAPff9ZNvviL0L8CtPXpgPE5MzUEyPRIl/ExW/a7oNlo3rOPE6vA2xEG5h24f9xVdT5hGT5wRTm/e64ZT+umpWs2HnGjRcvdEKZhQPGfKrfdzNn1DVobbGSuy7P64lPWRJ/DxrhMwVkOyfZ+XoIGavS/yLQt01KjIrqtmUZOwHE5FRM/B58doGZn/zNpxq0tb7t9sxWIcW6wyZyieTAO7D9D84Qw8EBwKlbtsfS0oSZw=";
private static final String SAMPLE_CERTIFICATE_RSA_2 = "MIICmzCCAYMCBgGGc0gb+jANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjMwMjIxMDkyMDUxWhcNMzMwMjIxMDkyMjMxWjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDOgZKCSPgYFaBCLrhaX4jBjgTqdYemJPLyR3gAq3GhO/KVVj5i3lOJYLPE3TdyxowxpvnqJK5zIgLv954y7cbah5wbyfdcFf/qa/RvEDAVb1c3gs+7e5uEoiAWgvARQbcduuO8U/rerlgF3eN0WLjIjcz8yncLmMvd+AhjOAqs3AmKrlEADeABTRq454gXjrD8x3bZwRvC67ZOdK32WpfIG9u58WABDYHWavQ8aetcs1uuwbNl7Tmi0heEtgBd8q2y3BJmn31NXmRobLwNuILEN8sujMKf6iaISA50gh0TCUYSbzeeQ6DrqHBlOA8azpuwka4pQyr+R22MDdrItTc3AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAKuc82PlWzQbevzd/FvbutsEX5Tdf4Nojd+jOvcP6NiDtImWojzgN+SSAKTtmCz3ToBxjJbI4UjhovjWN4e4ygEWksBw6YYYR9ZGCJ7Z3EZzyREojvZeF/H0lQqB3BgnjI38HBpRgCpZm3H6+1UoJtMOW2sU8jorG/k1qvXrx2Y3bZvj/6wixVnzjiFzagb3cIUzv9c7ZWlexaR2Bg0k4kQ5TFwyzYCE136nl8xPqoDd8Nc4fQEPI7wLYMGglmbLFlGvdz3IJ7XRparYJRm4wlznQ43GL2x2KGBu8JipgbA7+u6F84oqf3vOC/PozWXzVCn08e6gqBY3YdZcs6sA3qY=";
private static final X509Certificate cert1;
private static final X509Certificate cert2;
static {
try {
cert1 = XMLSignatureUtil.getX509CertificateFromKeyInfoString(SAMPLE_CERTIFICATE_RSA_1);
cert2 = XMLSignatureUtil.getX509CertificateFromKeyInfoString(SAMPLE_CERTIFICATE_RSA_2);
} catch (ProcessingException e) {
throw new IllegalStateException(e);
}
}
private static KeyLocator createLocatorWithName(X509Certificate... cert) {
Map<String, Key> tmp = new HashMap<>();
for (int i = 1; i <= cert.length; i++) {
tmp.put("cert" + i, cert[i - 1].getPublicKey());
}
return new HardcodedKeyLocator(tmp);
}
private static KeyLocator createLocatorWithoutName(X509Certificate... cert) {
return new HardcodedKeyLocator(Arrays.stream(cert).map(X509Certificate::getPublicKey).collect(Collectors.toList()));
}
@Test
public void testCertificateWithTwoCertificatesWithName() throws Exception {
KeyLocator locator = createLocatorWithName(cert1, cert2);
KeyInfo info = XMLSignatureUtil.createKeyInfo(null, null, cert1);
Key found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo(null, null, cert2);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert2.getPublicKey(), found);
}
@Test
public void testKeyWithTwoCertificatesWithName() throws Exception {
KeyLocator locator = createLocatorWithName(cert1, cert2);
KeyInfo info = XMLSignatureUtil.createKeyInfo(null, cert1.getPublicKey(), null);
Key found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo(null, cert2.getPublicKey(), null);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert2.getPublicKey(), found);
}
@Test
public void testKeyNameWithTwoCertificatesWithName() throws Exception {
KeyLocator locator = createLocatorWithName(cert1, cert2);
KeyInfo info = XMLSignatureUtil.createKeyInfo("cert1", null, null);
Key found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo("cert2", null, null);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert2.getPublicKey(), found);
}
@Test
public void testKeyNameWithTwoCertificatesWithoutName() throws Exception {
KeyLocator locator = createLocatorWithoutName(cert1, cert2);
KeyInfo info = XMLSignatureUtil.createKeyInfo("cert1", null, null);
Key found = locator.getKey(info);
Assert.assertNull(found);
info = XMLSignatureUtil.createKeyInfo("cert1", cert1.getPublicKey(), null);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo("cert2", null, cert2);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert2.getPublicKey(), found);
}
@Test
public void testKeyNameWithOneCertificatesWithoutName() throws Exception {
//hardcoded locator with one cert is always returned
KeyLocator locator = createLocatorWithoutName(cert1);
KeyInfo info = XMLSignatureUtil.createKeyInfo("cert1", null, null);
Key found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo("cert1", cert1.getPublicKey(), null);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
info = XMLSignatureUtil.createKeyInfo("cert2", null, cert2);
found = locator.getKey(info);
Assert.assertNotNull(found);
Assert.assertEquals(cert1.getPublicKey(), found);
}
}

View file

@ -20,16 +20,25 @@ package org.keycloak.protocol.saml;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.security.Key;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.UriInfo;
import org.jboss.logging.Logger;
import org.keycloak.common.VerificationException;
import org.keycloak.common.util.PemUtils;
import org.keycloak.dom.saml.v2.SAML2Object;
import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.protocol.ArtifactResponseType;
import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
import org.keycloak.dom.saml.v2.protocol.RequestAbstractType;
import org.keycloak.dom.saml.v2.protocol.StatusCodeType;
import org.keycloak.dom.saml.v2.protocol.StatusResponseType;
import org.keycloak.dom.saml.v2.protocol.StatusType;
import org.keycloak.models.ClientModel;
import org.keycloak.rotation.HardcodedKeyLocator;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.SignatureAlgorithm;
import org.keycloak.saml.common.constants.GeneralConstants;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
@ -41,28 +50,13 @@ import org.keycloak.saml.common.util.StaxUtil;
import org.keycloak.saml.processing.api.saml.v2.request.SAML2Request;
import org.keycloak.saml.processing.api.saml.v2.sig.SAML2Signature;
import org.keycloak.saml.processing.core.saml.v2.common.IDGenerator;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.XMLTimeUtil;
import org.keycloak.saml.processing.core.saml.v2.writers.SAMLResponseWriter;
import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
import org.keycloak.saml.processing.core.util.RedirectBindingSignatureUtil;
import org.keycloak.saml.processing.web.util.RedirectBindingUtil;
import org.w3c.dom.Document;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.UriInfo;
import org.keycloak.dom.saml.v2.SAML2Object;
import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
import org.keycloak.dom.saml.v2.protocol.RequestAbstractType;
import org.keycloak.dom.saml.v2.protocol.StatusResponseType;
import org.keycloak.rotation.HardcodedKeyLocator;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
import java.security.PublicKey;
import java.security.Signature;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import org.w3c.dom.Element;
/**
@ -177,15 +171,8 @@ public class SamlProtocolUtils {
String decodedAlgorithm = RedirectBindingUtil.urlDecode(encodedParams.getFirst(GeneralConstants.SAML_SIG_ALG_REQUEST_KEY));
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.getFromXmlMethod(decodedAlgorithm);
Signature validator = signatureAlgorithm.createSignature(); // todo plugin signature alg
Key key = locator.getKey(keyId);
if (key instanceof PublicKey) {
validator.initVerify((PublicKey) key);
validator.update(rawQuery.getBytes("UTF-8"));
} else {
throw new VerificationException("Invalid key locator for signature verification");
}
if (!validator.verify(decodedSignature)) {
if (!RedirectBindingSignatureUtil.validateRedirectBindingSignature(signatureAlgorithm,
rawQuery.getBytes("UTF-8"), decodedSignature, locator, keyId)) {
throw new VerificationException("Invalid query param signature");
}
} catch (Exception e) {

View file

@ -87,6 +87,8 @@ import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;
@ -705,14 +707,27 @@ public class SamlClient {
// if the public key is passed verify the signature of the redirect URI
try {
KeyLocator locator = new KeyLocator() {
private final Key key = org.keycloak.testsuite.util.KeyUtils.publicKeyFromString(realmPublicKey);
@Override
public Key getKey(String kid) throws KeyManagementException {
return org.keycloak.testsuite.util.KeyUtils.publicKeyFromString(realmPublicKey);
return this.key;
}
@Override
public Key getKey(Key key) throws KeyManagementException {
return this.key;
}
@Override
public void refreshKeyCache() {
}
@Override
public Iterator<Key> iterator() {
return Collections.singleton(this.key).iterator();
}
};
SamlProtocolUtils.verifyRedirectSignature(documentHolder, locator, encodedParams,
samlResponse != null? GeneralConstants.SAML_RESPONSE_KEY : GeneralConstants.SAML_REQUEST_KEY);

View file

@ -1,22 +1,14 @@
package org.keycloak.testsuite.broker;
import org.keycloak.broker.saml.SAMLIdentityProviderConfig;
import org.keycloak.common.util.MultivaluedHashMap;
import org.keycloak.crypto.Algorithm;
import org.keycloak.crypto.KeyUse;
import org.keycloak.dom.saml.v2.protocol.AuthnRequestType;
import org.keycloak.jose.jwe.JWEConstants;
import org.keycloak.keys.Attributes;
import org.keycloak.keys.GeneratedRsaEncKeyProviderFactory;
import org.keycloak.keys.KeyProvider;
import org.keycloak.models.IdentityProviderSyncMode;
import org.keycloak.models.utils.DefaultKeyProviders;
import org.keycloak.protocol.saml.SamlConfigAttributes;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.representations.idm.ComponentExportRepresentation;
import org.keycloak.representations.idm.IdentityProviderRepresentation;
import org.keycloak.representations.idm.KeysMetadataRepresentation;
import org.keycloak.representations.idm.RealmRepresentation;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.processing.api.saml.v2.request.SAML2Request;
@ -58,6 +50,7 @@ import org.w3c.dom.Element;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
@ -500,4 +493,84 @@ public class KcSamlSignedBrokerTest extends AbstractBrokerTest {
.execute();
}
}
@Test
public void testSignatureDataTwoCertificatesPostBinding() throws Exception {
// Check two certifcates work with POST binding
String badCert = KeyUtils.findActiveSigningKey(adminClient.realm(bc.consumerRealmName()), Algorithm.RS256).getCertificate();
String goodCert = KeyUtils.findActiveSigningKey(adminClient.realm(bc.providerRealmName()), Algorithm.RS256).getCertificate();
try (Closeable clientUpdater = ClientAttributeUpdater.forClient(adminClient, bc.providerRealmName(), bc.getIDPClientIdInProviderRealm())
.setAttribute(SamlConfigAttributes.SAML_ENCRYPT, "false")
.setAttribute(SamlConfigAttributes.SAML_SERVER_SIGNATURE, "true")
.setAttribute(SamlConfigAttributes.SAML_ASSERTION_SIGNATURE, "false")
.setAttribute(SamlConfigAttributes.SAML_CLIENT_SIGNATURE_ATTRIBUTE, "false")
.update();
Closeable idpUpdater = new IdentityProviderAttributeUpdater(identityProviderResource)
.setAttribute(SAMLIdentityProviderConfig.VALIDATE_SIGNATURE, "true")
.setAttribute(SAMLIdentityProviderConfig.WANT_ASSERTIONS_SIGNED, "false")
.setAttribute(SAMLIdentityProviderConfig.WANT_ASSERTIONS_ENCRYPTED, "false")
.setAttribute(SAMLIdentityProviderConfig.WANT_AUTHN_REQUESTS_SIGNED, "true")
.setAttribute(SAMLIdentityProviderConfig.SIGNING_CERTIFICATE_KEY, badCert + "," + goodCert)
.update();
)
{
// Build the login request document
AuthnRequestType loginRep = SamlClient.createLoginRequestDocument(AbstractSamlTest.SAML_CLIENT_ID_SALES_POST, getConsumerRoot() + "/sales-post/saml", null);
Document doc = SAML2Request.convert(loginRep);
new SamlClientBuilder()
.authnRequest(getConsumerSamlEndpoint(bc.consumerRealmName()), doc, Binding.POST)
.build() // Request to consumer IdP
.login().idp(bc.getIDPAlias()).build()
.processSamlResponse(Binding.POST).build() // AuthnRequest to producer IdP
.login().user(bc.getUserLogin(), bc.getUserPassword()).build()
.processSamlResponse(Binding.POST) // Response from producer IdP
.build()
// first-broker flow: if valid request, it displays an update profile page on consumer realm
.execute(currentResponse -> assertThat(currentResponse, bodyHC(containsString("Update Account Information"))));
}
}
@Test
public void testSignatureDataTwoCertificatesRedirectBinding() throws Exception {
// Check two certifcates work with REDIRECT binding
String badCert = KeyUtils.findActiveSigningKey(adminClient.realm(bc.consumerRealmName()), Algorithm.RS256).getCertificate();
String goodCert = KeyUtils.findActiveSigningKey(adminClient.realm(bc.providerRealmName()), Algorithm.RS256).getCertificate();
try (Closeable clientProviderUpdater = ClientAttributeUpdater.forClient(adminClient, bc.providerRealmName(), bc.getIDPClientIdInProviderRealm())
.setAttribute(SamlConfigAttributes.SAML_ENCRYPT, "false")
.setAttribute(SamlConfigAttributes.SAML_SERVER_SIGNATURE, "true")
.setAttribute(SamlConfigAttributes.SAML_ASSERTION_SIGNATURE, "false")
.setAttribute(SamlConfigAttributes.SAML_CLIENT_SIGNATURE_ATTRIBUTE, "false")
.setAttribute(SamlConfigAttributes.SAML_FORCE_POST_BINDING, "false")
.update();
Closeable clientConsumerUpdater = ClientAttributeUpdater.forClient(adminClient, bc.providerRealmName(), bc.getIDPClientIdInProviderRealm())
.setAttribute(SamlConfigAttributes.SAML_FORCE_POST_BINDING, "false")
.update();
Closeable idpUpdater = new IdentityProviderAttributeUpdater(identityProviderResource)
.setAttribute(SAMLIdentityProviderConfig.VALIDATE_SIGNATURE, "true")
.setAttribute(SAMLIdentityProviderConfig.WANT_ASSERTIONS_SIGNED, "false")
.setAttribute(SAMLIdentityProviderConfig.WANT_ASSERTIONS_ENCRYPTED, "false")
.setAttribute(SAMLIdentityProviderConfig.WANT_AUTHN_REQUESTS_SIGNED, "true")
.setAttribute(SAMLIdentityProviderConfig.POST_BINDING_AUTHN_REQUEST, "false")
.setAttribute(SAMLIdentityProviderConfig.POST_BINDING_RESPONSE, "false")
.setAttribute(SAMLIdentityProviderConfig.SIGNING_CERTIFICATE_KEY, badCert + "," + goodCert)
.update();
)
{
// Build the login request document
AuthnRequestType loginRep = SamlClient.createLoginRequestDocument(AbstractSamlTest.SAML_CLIENT_ID_SALES_POST, getConsumerRoot() + "/sales-post/saml", null);
Document doc = SAML2Request.convert(loginRep);
new SamlClientBuilder()
.authnRequest(getConsumerSamlEndpoint(bc.consumerRealmName()), doc, Binding.REDIRECT)
.build() // Request to consumer IdP
.login().idp(bc.getIDPAlias()).build()
.processSamlResponse(Binding.REDIRECT).build() // AuthnRequest to producer IdP
.login().user(bc.getUserLogin(), bc.getUserPassword()).build()
.processSamlResponse(Binding.REDIRECT) // Response from producer IdP
.build()
// first-broker flow: if valid request, it displays an update profile page on consumer realm
.execute(currentResponse -> assertThat(currentResponse, bodyHC(containsString("Update Account Information"))));
}
}
}