Ensure that the EncryptedKey is passed to the DecryptionKeyLocator for SAML

Closes https://github.com/keycloak/keycloak/issues/22974
This commit is contained in:
rmartinc 2023-09-12 11:34:18 +02:00 committed by Marek Posolda
parent 48e4e973a4
commit f8a9e0134a
4 changed files with 129 additions and 54 deletions

View file

@ -49,6 +49,11 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
name: "config.wantAuthnRequestsSigned", name: "config.wantAuthnRequestsSigned",
}); });
const wantAssertionsEncrypted = useWatch({
control,
name: "config.wantAssertionsEncrypted",
});
const validateSignature = useWatch({ const validateSignature = useWatch({
control, control,
name: "config.validateSignature", name: "config.validateSignature",
@ -376,41 +381,6 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
)} )}
></Controller> ></Controller>
</FormGroup> </FormGroup>
<FormGroup
label={t("encryptionAlgorithm")}
labelIcon={
<HelpItem
helpText={t("encryptionAlgorithmHelp")}
fieldLabelId="identity-provider:encryptionAlgorithm"
/>
}
fieldId="kc-encryptionAlgorithm"
>
<Controller
name="config.encryptionAlgorithm"
defaultValue="RSA-OAEP"
control={control}
render={({ field }) => (
<Select
toggleId="kc-encryptionAlgorithm"
onToggle={(isExpanded) =>
setEncryptionAlgorithmDropdownOpen(isExpanded)
}
isOpen={encryptionAlgorithmDropdownOpen}
onSelect={(_, value) => {
field.onChange(value.toString());
setEncryptionAlgorithmDropdownOpen(false);
}}
selections={field.value}
variant={SelectVariant.single}
isDisabled={readOnly}
>
<SelectOption value="RSA-OAEP" />
<SelectOption value="RSA1_5" />
</Select>
)}
></Controller>
</FormGroup>
<FormGroup <FormGroup
label={t("samlSignatureKeyName")} label={t("samlSignatureKeyName")}
labelIcon={ labelIcon={
@ -461,6 +431,45 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
label="wantAssertionsEncrypted" label="wantAssertionsEncrypted"
isReadOnly={readOnly} isReadOnly={readOnly}
/> />
{wantAssertionsEncrypted === "true" && (
<FormGroup
label={t("encryptionAlgorithm")}
labelIcon={
<HelpItem
helpText={t("encryptionAlgorithmHelp")}
fieldLabelId="encryptionAlgorithm"
/>
}
fieldId="kc-encryptionAlgorithm"
>
<Controller
name="config.encryptionAlgorithm"
defaultValue="RSA-OAEP"
control={control}
render={({ field }) => (
<Select
toggleId="kc-encryptionAlgorithm"
onToggle={(isExpanded) =>
setEncryptionAlgorithmDropdownOpen(isExpanded)
}
isOpen={encryptionAlgorithmDropdownOpen}
onSelect={(_, value) => {
field.onChange(value.toString());
setEncryptionAlgorithmDropdownOpen(false);
}}
selections={field.value}
variant={SelectVariant.single}
isDisabled={readOnly}
>
<SelectOption value="RSA-OAEP" />
<SelectOption value="RSA1_5" />
</Select>
)}
></Controller>
</FormGroup>
)}
<SwitchField <SwitchField
field="config.forceAuthn" field="config.forceAuthn"
label="forceAuthentication" label="forceAuthentication"

View file

@ -20,6 +20,7 @@ import org.apache.xml.security.encryption.EncryptedData;
import org.apache.xml.security.encryption.EncryptedKey; import org.apache.xml.security.encryption.EncryptedKey;
import org.apache.xml.security.encryption.XMLCipher; import org.apache.xml.security.encryption.XMLCipher;
import org.apache.xml.security.encryption.XMLEncryptionException; import org.apache.xml.security.encryption.XMLEncryptionException;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.utils.EncryptionConstants; import org.apache.xml.security.utils.EncryptionConstants;
import org.keycloak.saml.common.PicketLinkLogger; import org.keycloak.saml.common.PicketLinkLogger;
@ -252,18 +253,6 @@ public class XMLEncryptionUtil {
if (encDataElement == null) if (encDataElement == null)
throw logger.domMissingElementError("No element representing the encrypted data found"); throw logger.domMissingElementError("No element representing the encrypted data found");
// Look at siblings for the key
Element encKeyElement = getNextElementNode(encDataElement.getNextSibling());
if (encKeyElement == null) {
// Search the enc data element for enc key
NodeList nodeList = encDataElement.getElementsByTagNameNS(EncryptionConstants.EncryptionSpecNS, EncryptionConstants._TAG_ENCRYPTEDKEY);
if (nodeList == null || nodeList.getLength() == 0)
throw logger.nullValueError("Encrypted Key not found in the enc data");
encKeyElement = (Element) nodeList.item(0);
}
XMLCipher cipher; XMLCipher cipher;
EncryptedData encryptedData; EncryptedData encryptedData;
EncryptedKey encryptedKey; EncryptedKey encryptedKey;
@ -271,8 +260,18 @@ public class XMLEncryptionUtil {
cipher = XMLCipher.getInstance(); cipher = XMLCipher.getInstance();
cipher.init(XMLCipher.DECRYPT_MODE, null); cipher.init(XMLCipher.DECRYPT_MODE, null);
encryptedData = cipher.loadEncryptedData(documentWithEncryptedElement, encDataElement); encryptedData = cipher.loadEncryptedData(documentWithEncryptedElement, encDataElement);
encryptedKey = cipher.loadEncryptedKey(documentWithEncryptedElement, encKeyElement); if (encryptedData.getKeyInfo() == null) {
} catch (XMLEncryptionException e1) { throw logger.domMissingElementError("No element representing KeyInfo found in the EncryptedData");
}
encryptedKey = encryptedData.getKeyInfo().itemEncryptedKey(0);
if (encryptedKey == null) {
// the encrypted key is not inside the encrypted data, locate it
Element encKeyElement = locateEncryptedKeyElement(encDataElement);
encryptedKey = cipher.loadEncryptedKey(documentWithEncryptedElement, encKeyElement);
encryptedData.getKeyInfo().add(encryptedKey);
}
} catch (XMLSecurityException e1) {
throw logger.processingError(e1); throw logger.processingError(e1);
} }
@ -325,6 +324,28 @@ public class XMLEncryptionUtil {
return decryptedDoc.getDocumentElement(); return decryptedDoc.getDocumentElement();
} }
/**
* Locates the EncryptedKey element once the EncryptedData element is found.
* A exception is thrown if not found.
*
* @param encDataElement The EncryptedData element found
* @return The EncryptedKey element
*/
private static Element locateEncryptedKeyElement(Element encDataElement) {
// Look at siblings for the key
Element encKeyElement = getNextElementNode(encDataElement.getNextSibling());
if (encKeyElement == null) {
// Search the enc data element for enc key
NodeList nodeList = encDataElement.getElementsByTagNameNS(EncryptionConstants.EncryptionSpecNS, EncryptionConstants._TAG_ENCRYPTEDKEY);
if (nodeList == null || nodeList.getLength() == 0)
throw logger.nullValueError("Encrypted Key not found in the enc data");
encKeyElement = (Element) nodeList.item(0);
}
return encKeyElement;
}
/** /**
* From the secret key, get the W3C XML Encryption URL * From the secret key, get the W3C XML Encryption URL
* *

View file

@ -153,6 +153,7 @@ public class SAMLDecryptionKeysLocator implements XMLEncryptionUtil.DecryptionKe
// Map keys to PrivateKey // Map keys to PrivateKey
return keysStream return keysStream
.map(KeyWrapper::getPrivateKey) .map(KeyWrapper::getPrivateKey)
.filter(Objects::nonNull)
.map(Key::getEncoded) .map(Key::getEncoded)
.map(encoded -> { .map(encoded -> {
try { try {

View file

@ -22,9 +22,12 @@ import java.security.KeyPair;
import java.security.KeyPairGenerator; import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException; import javax.crypto.NoSuchPaddingException;
import org.apache.xml.security.encryption.XMLCipher; import org.apache.xml.security.encryption.XMLCipher;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.utils.EncryptionConstants; import org.apache.xml.security.utils.EncryptionConstants;
import org.hamcrest.MatcherAssert; import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
@ -36,17 +39,19 @@ import org.keycloak.dom.saml.v2.assertion.AssertionType;
import org.keycloak.dom.saml.v2.assertion.NameIDType; import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.protocol.ResponseType; import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder;
import org.keycloak.protocol.saml.SAMLEncryptionAlgorithms;
import org.keycloak.saml.SAML2LoginResponseBuilder; import org.keycloak.saml.SAML2LoginResponseBuilder;
import org.keycloak.saml.SAMLRequestParser; import org.keycloak.saml.SAMLRequestParser;
import org.keycloak.saml.common.constants.JBossSAMLConstants;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.DocumentUtil; import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder; import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil; import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.services.DefaultKeycloakSession; import org.keycloak.services.DefaultKeycloakSession;
import org.keycloak.services.DefaultKeycloakSessionFactory; import org.keycloak.services.DefaultKeycloakSessionFactory;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
/** /**
* <p>Simple test class that checks SAML encryption with different algorithms. * <p>Simple test class that checks SAML encryption with different algorithms.
@ -56,8 +61,6 @@ import org.w3c.dom.Document;
*/ */
public class SamlEncryptionTest { public class SamlEncryptionTest {
private static final KeyPair rsaKeyPair;
static { static {
try { try {
KeyPairGenerator rsa = KeyPairGenerator.getInstance("RSA"); KeyPairGenerator rsa = KeyPairGenerator.getInstance("RSA");
@ -68,6 +71,17 @@ public class SamlEncryptionTest {
} }
} }
private static final KeyPair rsaKeyPair;
private static final XMLEncryptionUtil.DecryptionKeyLocator keyLocator = data -> {
try {
Assert.assertNotNull("EncryptedData does not contain KeyInfo", data.getKeyInfo());
Assert.assertNotNull("EncryptedData does not contain EncryptedKey", data.getKeyInfo().itemEncryptedKey(0));
return Collections.singletonList(rsaKeyPair.getPrivate());
} catch (XMLSecurityException e) {
throw new IllegalArgumentException("EncryptedData does not contain KeyInfo ", e);
}
};
@BeforeClass @BeforeClass
public static void beforeClass() { public static void beforeClass() {
Cipher cipher = null; Cipher cipher = null;
@ -86,6 +100,11 @@ public class SamlEncryptionTest {
} }
private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf) throws Exception { private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf) throws Exception {
testEncryption(pair, alg, keySize, keyWrapAlg, keyWrapHashMethod, keyWrapMgf, Function.identity());
}
private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg,
String keyWrapHashMethod, String keyWrapMgf, Function<Document,Document> transformer) throws Exception {
SAML2LoginResponseBuilder builder = new SAML2LoginResponseBuilder(); SAML2LoginResponseBuilder builder = new SAML2LoginResponseBuilder();
builder.requestID("requestId") builder.requestID("requestId")
.destination("http://localhost") .destination("http://localhost")
@ -120,18 +139,38 @@ public class SamlEncryptionTest {
Document samlDocument = builder.buildDocument(samlModel); Document samlDocument = builder.buildDocument(samlModel);
bindingBuilder.postBinding(samlDocument); bindingBuilder.postBinding(samlDocument);
samlDocument = transformer.apply(samlDocument);
String samlResponse = DocumentUtil.getDocumentAsString(samlDocument); String samlResponse = DocumentUtil.getDocumentAsString(samlDocument);
SAMLDocumentHolder holder = SAMLRequestParser.parseResponseDocument(samlResponse.getBytes(StandardCharsets.UTF_8)); SAMLDocumentHolder holder = SAMLRequestParser.parseResponseDocument(samlResponse.getBytes(StandardCharsets.UTF_8));
ResponseType responseType = (ResponseType) holder.getSamlObject(); ResponseType responseType = (ResponseType) holder.getSamlObject();
Assert.assertTrue("Assertion is not encrypted", AssertionUtil.isAssertionEncrypted(responseType)); Assert.assertTrue("Assertion is not encrypted", AssertionUtil.isAssertionEncrypted(responseType));
AssertionType assertion = AssertionUtil.getAssertion(holder, responseType, pair.getPrivate()); AssertionUtil.decryptAssertion(responseType, keyLocator);
AssertionType assertion = responseType.getAssertions().get(0).getAssertion();
Assert.assertEquals("issuer", assertion.getIssuer().getValue()); Assert.assertEquals("issuer", assertion.getIssuer().getValue());
MatcherAssert.assertThat(assertion.getSubject().getSubType().getBaseID(), Matchers.instanceOf(NameIDType.class)); MatcherAssert.assertThat(assertion.getSubject().getSubType().getBaseID(), Matchers.instanceOf(NameIDType.class));
NameIDType nameId = (NameIDType) assertion.getSubject().getSubType().getBaseID(); NameIDType nameId = (NameIDType) assertion.getSubject().getSubType().getBaseID();
Assert.assertEquals("nameId", nameId.getValue()); Assert.assertEquals("nameId", nameId.getValue());
} }
private Document moveEncryptedKeyToRetrievalMethod(Document doc) {
NodeList nodes = doc.getElementsByTagNameNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), JBossSAMLConstants.ENCRYPTED_KEY.get());
Element encKey = (Element) nodes.item(0);
Element keyInfo = (Element) encKey.getParentNode();
// remove the encKey, insert into EncryptedAssertion and substitute it with a RetrievalMethod
keyInfo.removeChild(encKey);
encKey.setAttribute("Id", "encryption-key-123");
keyInfo.getParentNode().getParentNode().appendChild(encKey);
Element retrievalMethod = doc.createElementNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), "xenc:RetrievalMethod");
retrievalMethod.setAttribute("Type", "http://www.w3.org/2001/04/xmlenc#EncryptedKey");
retrievalMethod.setAttribute("URI", "encryption-key-123");
keyInfo.appendChild(retrievalMethod);
return doc;
}
@Test @Test
public void testDefault() throws Exception { public void testDefault() throws Exception {
testEncryption(rsaKeyPair, null, -1, null, null, null); testEncryption(rsaKeyPair, null, -1, null, null, null);
@ -164,4 +203,9 @@ public class SamlEncryptionTest {
public void testRsaOaep11WithSha512AndMgfSha512() throws Exception { public void testRsaOaep11WithSha512AndMgfSha512() throws Exception {
testEncryption(rsaKeyPair, "AES", 256, XMLCipher.RSA_OAEP_11, XMLCipher.SHA512, EncryptionConstants.MGF1_SHA512); testEncryption(rsaKeyPair, "AES", 256, XMLCipher.RSA_OAEP_11, XMLCipher.SHA512, EncryptionConstants.MGF1_SHA512);
} }
@Test
public void testEncryptionWithRetrievalMethod() throws Exception {
testEncryption(rsaKeyPair, null, -1, null, null, null, this::moveEncryptedKeyToRetrievalMethod);
}
} }