Use references to obtain the signed elements in a signature (#188) (#33190)

Closes keycloak/keycloak-private#191
Closes #33116

Signed-off-by: rmartinc <rmartinc@redhat.com>
Co-authored-by: Ricardo Martin <rmartinc@redhat.com>
This commit is contained in:
Stian Thorgersen 2024-09-23 13:51:46 +02:00 committed by GitHub
parent af5eef57bf
commit d778a8551a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 27 deletions

View file

@ -59,7 +59,6 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.stream.XMLEventReader;
@ -315,7 +314,7 @@ public class AssertionUtil {
}
protected static Element getSignature(Element element) {
return DocumentUtil.getDirectChildElement(element, XMLSignature.XMLNS, "Signature");
return XMLSignatureUtil.getSignature(element);
}
/**

View file

@ -81,15 +81,22 @@ import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import javax.xml.crypto.AlgorithmMethod;
import javax.xml.crypto.Data;
import javax.xml.crypto.KeySelector;
import javax.xml.crypto.KeySelectorException;
import javax.xml.crypto.KeySelectorResult;
import javax.xml.crypto.NodeSetData;
import javax.xml.crypto.URIReferenceException;
import javax.xml.crypto.XMLCryptoContext;
import javax.xml.crypto.dom.DOMStructure;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.SecurityActions;
/**
@ -170,6 +177,52 @@ public class XMLSignatureUtil {
return xsf;
}
/**
* Returns the element that contains the signature for the passed element.
*
* @param element The element to search for the signature
* @return The signature element or null
*/
public static Element getSignature(Element element) {
Document doc = element.getOwnerDocument();
NodeList nl = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature");
if (element.getAttributeNode(JBossSAMLConstants.ID.get()) != null) {
// set the saml ID to be found
element.setIdAttribute(JBossSAMLConstants.ID.get(), true);
}
KeySelector nullSelector = new KeySelector() {
@Override
public KeySelectorResult select(KeyInfo ki, KeySelector.Purpose prps, AlgorithmMethod am, XMLCryptoContext xmlcc) throws KeySelectorException {
return () -> null;
}
};
try {
for (int i = 0; i < nl.getLength(); i++) {
Element signatureElement = (Element) nl.item(i);
DOMValidateContext valContext = new DOMValidateContext(nullSelector, signatureElement);
DOMStructure structure = new DOMStructure(signatureElement);
XMLSignature signature = fac.unmarshalXMLSignature(structure);
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext() && element.equals(it.next())) {
return signatureElement;
}
}
} catch (URIReferenceException e) {
logger.trace("Invalid URI reference in signature " + ref.getURI());
}
}
}
} catch (MarshalException e) {
logger.trace("Error unmarshalling signature", e);
}
return null;
}
/**
* Use this method to not include the KeyInfo in the signature
*
@ -404,7 +457,7 @@ public class XMLSignatureUtil {
* this way both assertions and the containing document are verified when signed.
*
* @param signedDoc
* @param publicKey
* @param locator
*
* @return
*
@ -428,39 +481,46 @@ public class XMLSignatureUtil {
if (locator == null)
throw logger.nullValueError("Public Key");
int signedAssertions = 0;
String assertionNameSpaceUri = null;
HashSet<Node> signedNodes = new HashSet<>();
for (int i = 0; i < nl.getLength(); i++) {
Node signatureNode = nl.item(i);
Node parent = signatureNode.getParentNode();
if (parent != null && JBossSAMLConstants.ASSERTION.get().equals(parent.getLocalName())) {
++signedAssertions;
if (assertionNameSpaceUri == null) {
assertionNameSpaceUri = parent.getNamespaceURI();
if (!validateSingleNode(signatureNode, locator, signedNodes)) {
return false;
}
}
if (signedNodes.contains(signedDoc.getDocumentElement())) {
logger.trace("All signatures are OK and root document is signed");
return true;
}
NodeList assertions = signedDoc.getElementsByTagNameNS(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get());
if (assertions.getLength() > 0) {
// if document is not fully signed check if all the assertions are signed
for (int i = 0; i < assertions.getLength(); i++) {
if (!signedNodes.contains(assertions.item(i))) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
// there are unsigned assertions mixed with signed ones
return false;
}
}
if (! validateSingleNode(signatureNode, locator)) return false;
logger.trace("Document not signed but all assertions are signed OK");
return true;
}
NodeList assertions = signedDoc.getElementsByTagNameNS(assertionNameSpaceUri, JBossSAMLConstants.ASSERTION.get());
if (signedAssertions > 0 && assertions != null && assertions.getLength() != signedAssertions) {
if (logger.isDebugEnabled()) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
}
// there are unsigned assertions mixed with signed ones
return false;
}
return true;
return false;
}
public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator) throws MarshalException, XMLSignatureException {
return validateSingleNode(signatureNode, locator, new HashSet<>());
}
public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator, Set<Node> signedNodes) throws MarshalException, XMLSignatureException {
KeySelectorUtilizingKeyNameHint sel = new KeySelectorUtilizingKeyNameHint(locator);
try {
if (validateUsingKeySelector(signatureNode, sel)) {
if (validateUsingKeySelector(signatureNode, sel, signedNodes)) {
return true;
}
if (sel.wasKeyLocated()) {
@ -477,7 +537,7 @@ public class XMLSignatureUtil {
for (Key key : locator) {
try {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key))) {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key), signedNodes)) {
return true;
}
} catch (XMLSignatureException ex) { // pass through MarshalException
@ -489,12 +549,26 @@ public class XMLSignatureUtil {
return false;
}
private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector) throws XMLSignatureException, MarshalException {
private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector, Set<Node> signedNodes) throws XMLSignatureException, MarshalException {
DOMValidateContext valContext = new DOMValidateContext(validationKeySelector, signatureNode);
XMLSignature signature = fac.unmarshalXMLSignature(valContext);
boolean coreValidity = signature.validate(valContext);
if (! coreValidity) {
if (coreValidity) {
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext()) {
signedNodes.add(it.next()); // add the first referenced object as signed element
}
}
} catch (URIReferenceException e) {
// ignored as signature was ok so reference can be obtained
}
}
} else {
if (logger.isTraceEnabled()) {
boolean sv = signature.getSignatureValue().validate(valContext);
logger.trace("Signature validation status: " + sv);

View file

@ -34,6 +34,7 @@ import org.keycloak.testsuite.util.Matchers;
import org.keycloak.testsuite.util.RealmBuilder;
import org.keycloak.testsuite.util.RoleBuilder;
import org.keycloak.testsuite.util.RolesBuilder;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.testsuite.util.SamlClient.Binding;
import org.keycloak.testsuite.util.SamlClientBuilder;
@ -66,6 +67,7 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.ASSERTION_NSURI;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_NSURI;
import org.keycloak.saml.common.util.DocumentUtil;
import static org.keycloak.testsuite.adapter.AbstractServletsAdapterTest.samlServletDeployment;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_NAME;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_PRIVATE_KEY;
@ -181,6 +183,20 @@ public class SamlSignatureTest extends AbstractAdapterTest {
originalSignature.appendChild(object);
object.appendChild(assertion);
}
public static void noDocumentSignatureOnlyOneAssertionSignedBelowResponse(Document document){
// remove the signature for the whole response
removeDocumentSignature(document);
// move the signature from the assertion to the response level
Element assertion = (Element) document.getElementsByTagNameNS(ASSERTION_NSURI.get(), "Assertion").item(0);
Element signature = (Element) assertion.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
assertion.removeChild(signature);
document.getDocumentElement().appendChild(signature);
// create a second assertion without signature
Element evilAssertion = (Element) assertion.cloneNode(true);
evilAssertion.setAttribute("ID", "_evil_assertion_ID");
document.getDocumentElement().insertBefore(evilAssertion, assertion);
}
}
@Page
@ -321,11 +337,21 @@ public class SamlSignatureTest extends AbstractAdapterTest {
}
}
private static void removeDocumentSignature(Document doc) throws DOMException {
Element responseSignature = (Element) doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
Assert.assertNotNull(doc.getDocumentElement().removeChild(responseSignature));
}
@Test
public void testNoChange() throws Exception {
testSamlResponseModifications(r -> {}, true);
}
@Test
public void testOnlyAssertionSignature() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeDocumentSignature, true);
}
@Test
public void testRemoveSignatures() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeAllSignatures, false);
@ -371,4 +397,8 @@ public class SamlSignatureTest extends AbstractAdapterTest {
testSamlResponseModifications(XSWHelpers::applyXSW8, false);
}
@Test
public void testNoDocumentSignatureOnlyOneAssertionSignedBelowResponse() throws Exception {
testSamlResponseModifications(XSWHelpers::noDocumentSignatureOnlyOneAssertionSignedBelowResponse, false);
}
}