diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java index 6421d7bcd1..df2a7664f0 100755 --- a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java +++ b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java @@ -66,6 +66,7 @@ import org.keycloak.saml.common.util.DocumentUtil; import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder; import org.keycloak.saml.processing.core.saml.v2.constants.X500SAMLProfileConstants; import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil; +import org.keycloak.saml.processing.core.util.XMLEncryptionUtil; import org.keycloak.saml.processing.core.util.XMLSignatureUtil; import org.keycloak.saml.processing.web.util.PostBindingUtil; import org.keycloak.services.ErrorPage; @@ -89,7 +90,9 @@ import jakarta.ws.rs.core.UriInfo; import javax.xml.namespace.QName; import java.io.IOException; import java.security.Key; +import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; @@ -458,14 +461,19 @@ public class SAMLEndpoint { if (assertionIsEncrypted) { try { + XMLEncryptionUtil.DecryptionKeyLocator decryptionKeyLocator = new SAMLDecryptionKeysLocator(session, realm, config.getEncryptionAlgorithm()); /* This code is deprecated and will be removed in Keycloak 24 */ if (DEPRECATED_ENCRYPTION) { KeyManager.ActiveRsaKey keys = session.keys().getActiveRsaKey(realm); - assertionElement = AssertionUtil.decryptAssertion(responseType, keys.getPrivateKey()); - } else { - /* End of deprecated code */ - assertionElement = AssertionUtil.decryptAssertion(responseType, new SAMLDecryptionKeysLocator(session, realm, config.getEncryptionAlgorithm())); + final XMLEncryptionUtil.DecryptionKeyLocator tmp = decryptionKeyLocator; + decryptionKeyLocator = data -> { + List result = new ArrayList<>(tmp.getKeys(data)); + result.add(keys.getPrivateKey()); + return result; + }; } + /* End of deprecated code */ + assertionElement = AssertionUtil.decryptAssertion(responseType, decryptionKeyLocator); } catch (ProcessingException ex) { logger.warnf(ex, "Not possible to decrypt SAML assertion. Please check realm keys of usage ENC in the realm '%s' and make sure there is a key able to decrypt the assertion encrypted by identity provider '%s'", realm.getName(), config.getAlias()); throw new WebApplicationException(ex, Response.Status.BAD_REQUEST); @@ -511,14 +519,19 @@ public class SAMLEndpoint { if (AssertionUtil.isIdEncrypted(responseType)) { try { + XMLEncryptionUtil.DecryptionKeyLocator decryptionKeyLocator = new SAMLDecryptionKeysLocator(session, realm, config.getEncryptionAlgorithm()); /* This code is deprecated and will be removed in Keycloak 24 */ if (DEPRECATED_ENCRYPTION) { KeyManager.ActiveRsaKey keys = session.keys().getActiveRsaKey(realm); - AssertionUtil.decryptId(responseType, data -> Collections.singletonList(keys.getPrivateKey())); - } else { - /* End of deprecated code */ - AssertionUtil.decryptId(responseType, new SAMLDecryptionKeysLocator(session, realm, config.getEncryptionAlgorithm())); + final XMLEncryptionUtil.DecryptionKeyLocator tmp = decryptionKeyLocator; + decryptionKeyLocator = data -> { + List result = new ArrayList<>(tmp.getKeys(data)); + result.add(keys.getPrivateKey()); + return result; + }; } + /* End of deprecated code */ + AssertionUtil.decryptId(responseType, decryptionKeyLocator); } catch (ProcessingException ex) { logger.warnf(ex, "Not possible to decrypt SAML encryptedId. Please check realm keys of usage ENC in the realm '%s' and make sure there is a key able to decrypt the encryptedId encrypted by identity provider '%s'", realm.getName(), config.getAlias()); throw new WebApplicationException(ex, Response.Status.BAD_REQUEST);