Use references to obtain the signed elements in a signature (#188) (#33190)

Closes keycloak/keycloak-private#191
Closes #33116

Signed-off-by: rmartinc <rmartinc@redhat.com>
Co-authored-by: Ricardo Martin <rmartinc@redhat.com>
This commit is contained in:
Stian Thorgersen 2024-09-23 13:51:46 +02:00 committed by GitHub
parent af5eef57bf
commit d778a8551a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 27 deletions

View file

@ -59,7 +59,6 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.w3c.dom.Node; import org.w3c.dom.Node;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.datatype.XMLGregorianCalendar; import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.stream.XMLEventReader; import javax.xml.stream.XMLEventReader;
@ -315,7 +314,7 @@ public class AssertionUtil {
} }
protected static Element getSignature(Element element) { protected static Element getSignature(Element element) {
return DocumentUtil.getDirectChildElement(element, XMLSignature.XMLNS, "Signature"); return XMLSignatureUtil.getSignature(element);
} }
/** /**

View file

@ -81,15 +81,22 @@ import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
import javax.xml.crypto.AlgorithmMethod; import javax.xml.crypto.AlgorithmMethod;
import javax.xml.crypto.Data;
import javax.xml.crypto.KeySelector; import javax.xml.crypto.KeySelector;
import javax.xml.crypto.KeySelectorException; import javax.xml.crypto.KeySelectorException;
import javax.xml.crypto.KeySelectorResult; import javax.xml.crypto.KeySelectorResult;
import javax.xml.crypto.NodeSetData;
import javax.xml.crypto.URIReferenceException;
import javax.xml.crypto.XMLCryptoContext; import javax.xml.crypto.XMLCryptoContext;
import javax.xml.crypto.dom.DOMStructure; import javax.xml.crypto.dom.DOMStructure;
import org.keycloak.rotation.KeyLocator; import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.SecurityActions; import org.keycloak.saml.common.util.SecurityActions;
/** /**
@ -170,6 +177,52 @@ public class XMLSignatureUtil {
return xsf; return xsf;
} }
/**
* Returns the element that contains the signature for the passed element.
*
* @param element The element to search for the signature
* @return The signature element or null
*/
public static Element getSignature(Element element) {
Document doc = element.getOwnerDocument();
NodeList nl = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature");
if (element.getAttributeNode(JBossSAMLConstants.ID.get()) != null) {
// set the saml ID to be found
element.setIdAttribute(JBossSAMLConstants.ID.get(), true);
}
KeySelector nullSelector = new KeySelector() {
@Override
public KeySelectorResult select(KeyInfo ki, KeySelector.Purpose prps, AlgorithmMethod am, XMLCryptoContext xmlcc) throws KeySelectorException {
return () -> null;
}
};
try {
for (int i = 0; i < nl.getLength(); i++) {
Element signatureElement = (Element) nl.item(i);
DOMValidateContext valContext = new DOMValidateContext(nullSelector, signatureElement);
DOMStructure structure = new DOMStructure(signatureElement);
XMLSignature signature = fac.unmarshalXMLSignature(structure);
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext() && element.equals(it.next())) {
return signatureElement;
}
}
} catch (URIReferenceException e) {
logger.trace("Invalid URI reference in signature " + ref.getURI());
}
}
}
} catch (MarshalException e) {
logger.trace("Error unmarshalling signature", e);
}
return null;
}
/** /**
* Use this method to not include the KeyInfo in the signature * Use this method to not include the KeyInfo in the signature
* *
@ -404,7 +457,7 @@ public class XMLSignatureUtil {
* this way both assertions and the containing document are verified when signed. * this way both assertions and the containing document are verified when signed.
* *
* @param signedDoc * @param signedDoc
* @param publicKey * @param locator
* *
* @return * @return
* *
@ -428,39 +481,46 @@ public class XMLSignatureUtil {
if (locator == null) if (locator == null)
throw logger.nullValueError("Public Key"); throw logger.nullValueError("Public Key");
int signedAssertions = 0; HashSet<Node> signedNodes = new HashSet<>();
String assertionNameSpaceUri = null;
for (int i = 0; i < nl.getLength(); i++) { for (int i = 0; i < nl.getLength(); i++) {
Node signatureNode = nl.item(i); Node signatureNode = nl.item(i);
Node parent = signatureNode.getParentNode(); if (!validateSingleNode(signatureNode, locator, signedNodes)) {
if (parent != null && JBossSAMLConstants.ASSERTION.get().equals(parent.getLocalName())) { return false;
++signedAssertions; }
if (assertionNameSpaceUri == null) { }
assertionNameSpaceUri = parent.getNamespaceURI();
if (signedNodes.contains(signedDoc.getDocumentElement())) {
logger.trace("All signatures are OK and root document is signed");
return true;
}
NodeList assertions = signedDoc.getElementsByTagNameNS(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get());
if (assertions.getLength() > 0) {
// if document is not fully signed check if all the assertions are signed
for (int i = 0; i < assertions.getLength(); i++) {
if (!signedNodes.contains(assertions.item(i))) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
// there are unsigned assertions mixed with signed ones
return false;
} }
} }
logger.trace("Document not signed but all assertions are signed OK");
if (! validateSingleNode(signatureNode, locator)) return false; return true;
} }
NodeList assertions = signedDoc.getElementsByTagNameNS(assertionNameSpaceUri, JBossSAMLConstants.ASSERTION.get()); return false;
if (signedAssertions > 0 && assertions != null && assertions.getLength() != signedAssertions) {
if (logger.isDebugEnabled()) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
}
// there are unsigned assertions mixed with signed ones
return false;
}
return true;
} }
public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator) throws MarshalException, XMLSignatureException { public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator) throws MarshalException, XMLSignatureException {
return validateSingleNode(signatureNode, locator, new HashSet<>());
}
public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator, Set<Node> signedNodes) throws MarshalException, XMLSignatureException {
KeySelectorUtilizingKeyNameHint sel = new KeySelectorUtilizingKeyNameHint(locator); KeySelectorUtilizingKeyNameHint sel = new KeySelectorUtilizingKeyNameHint(locator);
try { try {
if (validateUsingKeySelector(signatureNode, sel)) { if (validateUsingKeySelector(signatureNode, sel, signedNodes)) {
return true; return true;
} }
if (sel.wasKeyLocated()) { if (sel.wasKeyLocated()) {
@ -477,7 +537,7 @@ public class XMLSignatureUtil {
for (Key key : locator) { for (Key key : locator) {
try { try {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key))) { if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key), signedNodes)) {
return true; return true;
} }
} catch (XMLSignatureException ex) { // pass through MarshalException } catch (XMLSignatureException ex) { // pass through MarshalException
@ -489,12 +549,26 @@ public class XMLSignatureUtil {
return false; return false;
} }
private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector) throws XMLSignatureException, MarshalException { private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector, Set<Node> signedNodes) throws XMLSignatureException, MarshalException {
DOMValidateContext valContext = new DOMValidateContext(validationKeySelector, signatureNode); DOMValidateContext valContext = new DOMValidateContext(validationKeySelector, signatureNode);
XMLSignature signature = fac.unmarshalXMLSignature(valContext); XMLSignature signature = fac.unmarshalXMLSignature(valContext);
boolean coreValidity = signature.validate(valContext); boolean coreValidity = signature.validate(valContext);
if (! coreValidity) { if (coreValidity) {
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext()) {
signedNodes.add(it.next()); // add the first referenced object as signed element
}
}
} catch (URIReferenceException e) {
// ignored as signature was ok so reference can be obtained
}
}
} else {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
boolean sv = signature.getSignatureValue().validate(valContext); boolean sv = signature.getSignatureValue().validate(valContext);
logger.trace("Signature validation status: " + sv); logger.trace("Signature validation status: " + sv);

View file

@ -34,6 +34,7 @@ import org.keycloak.testsuite.util.Matchers;
import org.keycloak.testsuite.util.RealmBuilder; import org.keycloak.testsuite.util.RealmBuilder;
import org.keycloak.testsuite.util.RoleBuilder; import org.keycloak.testsuite.util.RoleBuilder;
import org.keycloak.testsuite.util.RolesBuilder; import org.keycloak.testsuite.util.RolesBuilder;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.keycloak.testsuite.util.SamlClient.Binding; import org.keycloak.testsuite.util.SamlClient.Binding;
import org.keycloak.testsuite.util.SamlClientBuilder; import org.keycloak.testsuite.util.SamlClientBuilder;
@ -66,6 +67,7 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.ASSERTION_NSURI; import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.ASSERTION_NSURI;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_NSURI; import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_NSURI;
import org.keycloak.saml.common.util.DocumentUtil;
import static org.keycloak.testsuite.adapter.AbstractServletsAdapterTest.samlServletDeployment; import static org.keycloak.testsuite.adapter.AbstractServletsAdapterTest.samlServletDeployment;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_NAME; import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_NAME;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_PRIVATE_KEY; import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_PRIVATE_KEY;
@ -181,6 +183,20 @@ public class SamlSignatureTest extends AbstractAdapterTest {
originalSignature.appendChild(object); originalSignature.appendChild(object);
object.appendChild(assertion); object.appendChild(assertion);
} }
public static void noDocumentSignatureOnlyOneAssertionSignedBelowResponse(Document document){
// remove the signature for the whole response
removeDocumentSignature(document);
// move the signature from the assertion to the response level
Element assertion = (Element) document.getElementsByTagNameNS(ASSERTION_NSURI.get(), "Assertion").item(0);
Element signature = (Element) assertion.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
assertion.removeChild(signature);
document.getDocumentElement().appendChild(signature);
// create a second assertion without signature
Element evilAssertion = (Element) assertion.cloneNode(true);
evilAssertion.setAttribute("ID", "_evil_assertion_ID");
document.getDocumentElement().insertBefore(evilAssertion, assertion);
}
} }
@Page @Page
@ -321,11 +337,21 @@ public class SamlSignatureTest extends AbstractAdapterTest {
} }
} }
private static void removeDocumentSignature(Document doc) throws DOMException {
Element responseSignature = (Element) doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
Assert.assertNotNull(doc.getDocumentElement().removeChild(responseSignature));
}
@Test @Test
public void testNoChange() throws Exception { public void testNoChange() throws Exception {
testSamlResponseModifications(r -> {}, true); testSamlResponseModifications(r -> {}, true);
} }
@Test
public void testOnlyAssertionSignature() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeDocumentSignature, true);
}
@Test @Test
public void testRemoveSignatures() throws Exception { public void testRemoveSignatures() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeAllSignatures, false); testSamlResponseModifications(SamlSignatureTest::removeAllSignatures, false);
@ -371,4 +397,8 @@ public class SamlSignatureTest extends AbstractAdapterTest {
testSamlResponseModifications(XSWHelpers::applyXSW8, false); testSamlResponseModifications(XSWHelpers::applyXSW8, false);
} }
@Test
public void testNoDocumentSignatureOnlyOneAssertionSignedBelowResponse() throws Exception {
testSamlResponseModifications(XSWHelpers::noDocumentSignatureOnlyOneAssertionSignedBelowResponse, false);
}
} }