From 9dd9c9b37ff91971af385f44eb80a304ddb3990b Mon Sep 17 00:00:00 2001 From: Bill Burke Date: Fri, 27 Feb 2015 20:16:34 -0500 Subject: [PATCH] add saml mapper interfaces --- .../saml/SALM2LoginResponseBuilder.java | 60 +-- .../protocol/saml/SAML2BindingBuilder.java | 6 +- .../protocol/saml/SAML2BindingBuilder2.java | 364 ++++++++++++++++++ .../keycloak/protocol/saml/SamlProtocol.java | 56 ++- .../mappers/AbstractSAMLProtocolMapper.java | 40 ++ .../saml/mappers/SAMLLoginResponseMapper.java | 18 + 6 files changed, 501 insertions(+), 43 deletions(-) create mode 100755 saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SAML2BindingBuilder2.java create mode 100755 saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/AbstractSAMLProtocolMapper.java create mode 100755 saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/SAMLLoginResponseMapper.java diff --git a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SALM2LoginResponseBuilder.java b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SALM2LoginResponseBuilder.java index 288e2bdb07..782adbc8a7 100755 --- a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SALM2LoginResponseBuilder.java +++ b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SALM2LoginResponseBuilder.java @@ -31,13 +31,14 @@ import static org.picketlink.common.util.StringUtil.isNotNull; *

* Configuration Options: * - * @author Anil.Saldhana@redhat.com * @author bburke@redhat.com */ -public class SALM2LoginResponseBuilder extends SAML2BindingBuilder { +public class SALM2LoginResponseBuilder { protected static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger(); protected List roles = new LinkedList(); + protected String destination; + protected String issuer; protected String nameId; protected String nameIdFormat; protected boolean multiValuedRoles; @@ -53,6 +54,16 @@ public class SALM2LoginResponseBuilder extends SAML2BindingBuilder { return document; } - public String htmlResponse() throws ProcessingException, ConfigurationException, IOException { - return buildHtml(encoded(), destination, false); - - } public Response request() throws ConfigurationException, ProcessingException, IOException { return buildResponse(document, destination, true); } @@ -188,7 +184,7 @@ public class SAML2BindingBuilder { return response(destination, false); } public Response response(String redirectUri) throws ProcessingException, ConfigurationException, IOException { - return response(destination, false); + return response(redirectUri, false); } public Response request(String redirect) throws ProcessingException, ConfigurationException, IOException { diff --git a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SAML2BindingBuilder2.java b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SAML2BindingBuilder2.java new file mode 100755 index 0000000000..d5cc00d735 --- /dev/null +++ b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SAML2BindingBuilder2.java @@ -0,0 +1,364 @@ +package org.keycloak.protocol.saml; + +import org.jboss.logging.Logger; +import org.picketlink.common.constants.GeneralConstants; +import org.picketlink.common.constants.JBossSAMLConstants; +import org.picketlink.common.constants.JBossSAMLURIConstants; +import org.picketlink.common.exceptions.ConfigurationException; +import org.picketlink.common.exceptions.ProcessingException; +import org.picketlink.common.util.DocumentUtil; +import org.picketlink.identity.federation.api.saml.v2.sig.SAML2Signature; +import org.picketlink.identity.federation.core.util.XMLEncryptionUtil; +import org.picketlink.identity.federation.core.wstrust.WSTrustUtil; +import org.picketlink.identity.federation.web.util.PostBindingUtil; +import org.picketlink.identity.federation.web.util.RedirectBindingUtil; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.Node; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import javax.ws.rs.core.CacheControl; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriBuilder; +import javax.xml.namespace.QName; +import java.io.IOException; +import java.net.URI; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.Signature; +import java.security.cert.X509Certificate; + +import static org.keycloak.util.HtmlUtils.escapeAttribute; +import static org.picketlink.common.util.StringUtil.isNotNull; + +/** + * @author Bill Burke + * @version $Revision: 1 $ + */ +public class SAML2BindingBuilder2 { + protected static final Logger logger = Logger.getLogger(SAML2BindingBuilder2.class); + + protected KeyPair signingKeyPair; + protected X509Certificate signingCertificate; + protected boolean sign; + protected boolean signAssertions; + protected SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.RSA_SHA1; + protected String relayState; + protected int encryptionKeySize = 128; + protected PublicKey encryptionPublicKey; + protected String encryptionAlgorithm = "AES"; + protected boolean encrypt; + + public T signDocument() { + this.sign = true; + return (T)this; + } + + public T signAssertions() { + this.signAssertions = true; + return (T)this; + } + + public T signWith(KeyPair keyPair) { + this.signingKeyPair = keyPair; + return (T)this; + } + + public T signWith(PrivateKey privateKey, PublicKey publicKey) { + this.signingKeyPair = new KeyPair(publicKey, privateKey); + return (T)this; + } + + public T signWith(KeyPair keyPair, X509Certificate cert) { + this.signingKeyPair = keyPair; + this.signingCertificate = cert; + return (T)this; + } + + public T signWith(PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) { + this.signingKeyPair = new KeyPair(publicKey, privateKey); + this.signingCertificate = cert; + return (T)this; + } + + public T signatureAlgorithm(SignatureAlgorithm alg) { + this.signatureAlgorithm = alg; + return (T)this; + } + + public T encrypt(PublicKey publicKey) { + encrypt = true; + encryptionPublicKey = publicKey; + return (T)this; + } + + public T encryptionAlgorithm(String alg) { + this.encryptionAlgorithm = alg; + return (T)this; + } + + public T encryptionKeySize(int size) { + this.encryptionKeySize = size; + return (T)this; + } + + public T relayState(String relayState) { + this.relayState = relayState; + return (T)this; + } + + public class PostBindingBuilder { + protected Document document; + + public PostBindingBuilder(Document document) throws ProcessingException { + if (encrypt) encryptDocument(document); + this.document = document; + if (signAssertions) { + signAssertion(document); + } + if (sign) { + signDocument(document); + } + } + + public String encoded() throws ProcessingException, ConfigurationException, IOException { + byte[] responseBytes = org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil.getDocumentAsString(document).getBytes("UTF-8"); + return PostBindingUtil.base64Encode(new String(responseBytes)); + } + public Document getDocument() { + return document; + } + + public Response request(String actionUrl) throws ConfigurationException, ProcessingException, IOException { + return buildResponse(document, actionUrl, true); + } + public Response response(String actionUrl) throws ConfigurationException, ProcessingException, IOException { + return buildResponse(document, actionUrl, false); + } + } + + + public class RedirectBindingBuilder { + protected Document document; + + public RedirectBindingBuilder(Document document) throws ProcessingException { + if (encrypt) encryptDocument(document); + this.document = document; + if (signAssertions) { + signAssertion(document); + } + } + + public Document getDocument() { + return document; + } + public URI responseUri(String redirectUri, boolean asRequest) throws ConfigurationException, ProcessingException, IOException { + String samlParameterName = GeneralConstants.SAML_RESPONSE_KEY; + + if (asRequest) { + samlParameterName = GeneralConstants.SAML_REQUEST_KEY; + } + + return generateRedirectUri(samlParameterName, redirectUri, document); + } + public Response response(String redirectUri) throws ProcessingException, ConfigurationException, IOException { + return response(redirectUri, false); + } + + public Response request(String redirect) throws ProcessingException, ConfigurationException, IOException { + return response(redirect, true); + } + + private Response response(String redirectUri, boolean asRequest) throws ProcessingException, ConfigurationException, IOException { + URI uri = responseUri(redirectUri, asRequest); + if (logger.isDebugEnabled()) logger.trace("redirect-binding uri: " + uri.toString()); + CacheControl cacheControl = new CacheControl(); + cacheControl.setNoCache(true); + return Response.status(302).location(uri) + .header("Pragma", "no-cache") + .header("Cache-Control", "no-cache, no-store").build(); + } + + } + + + + private String getSAMLNSPrefix(Document samlResponseDocument) { + Node assertionElement = samlResponseDocument.getDocumentElement() + .getElementsByTagNameNS(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get()).item(0); + + if (assertionElement == null) { + throw new IllegalStateException("Unable to find assertion in saml response document"); + } + + return assertionElement.getPrefix(); + } + + protected void encryptDocument(Document samlDocument) throws ProcessingException { + String samlNSPrefix = getSAMLNSPrefix(samlDocument); + + try { + QName encryptedAssertionElementQName = new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(), + JBossSAMLConstants.ENCRYPTED_ASSERTION.get(), samlNSPrefix); + + byte[] secret = WSTrustUtil.createRandomSecret(encryptionKeySize / 8); + SecretKey secretKey = new SecretKeySpec(secret, encryptionAlgorithm); + + // encrypt the Assertion element and replace it with a EncryptedAssertion element. + XMLEncryptionUtil.encryptElement(new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(), + JBossSAMLConstants.ASSERTION.get(), samlNSPrefix), samlDocument, encryptionPublicKey, + secretKey, encryptionKeySize, encryptedAssertionElementQName, true); + } catch (Exception e) { + throw new ProcessingException("failed to encrypt", e); + } + + } + + protected void signDocument(Document samlDocument) throws ProcessingException { + String signatureMethod = signatureAlgorithm.getXmlSignatureMethod(); + String signatureDigestMethod = signatureAlgorithm.getXmlSignatureDigestMethod(); + SAML2Signature samlSignature = new SAML2Signature(); + + if (signatureMethod != null) { + samlSignature.setSignatureMethod(signatureMethod); + } + + if (signatureDigestMethod != null) { + samlSignature.setDigestMethod(signatureDigestMethod); + } + + Node nextSibling = samlSignature.getNextSiblingOfIssuer(samlDocument); + + samlSignature.setNextSibling(nextSibling); + + if (signingCertificate != null) { + samlSignature.setX509Certificate(signingCertificate); + } + + samlSignature.signSAMLDocument(samlDocument, signingKeyPair); + } + + protected void signAssertion(Document samlDocument) throws ProcessingException { + Element originalAssertionElement = DocumentUtil.getChildElement(samlDocument.getDocumentElement(), new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get())); + if (originalAssertionElement == null) return; + Node clonedAssertionElement = originalAssertionElement.cloneNode(true); + Document temporaryDocument; + + try { + temporaryDocument = DocumentUtil.createDocument(); + } catch (ConfigurationException e) { + throw new ProcessingException(e); + } + + temporaryDocument.adoptNode(clonedAssertionElement); + temporaryDocument.appendChild(clonedAssertionElement); + + signDocument(temporaryDocument); + + samlDocument.adoptNode(clonedAssertionElement); + + Element parentNode = (Element) originalAssertionElement.getParentNode(); + + parentNode.replaceChild(clonedAssertionElement, originalAssertionElement); + } + + + protected Response buildResponse(Document responseDoc, String actionUrl, boolean asRequest) throws ProcessingException, ConfigurationException, IOException { + String str = buildHtmlPostResponse(responseDoc, actionUrl, asRequest); + + CacheControl cacheControl = new CacheControl(); + cacheControl.setNoCache(true); + return Response.ok(str, MediaType.TEXT_HTML_TYPE) + .header("Pragma", "no-cache") + .header("Cache-Control", "no-cache, no-store").build(); + } + + protected String buildHtmlPostResponse(Document responseDoc, String actionUrl, boolean asRequest) throws ProcessingException, ConfigurationException, IOException { + byte[] responseBytes = DocumentUtil.getDocumentAsString(responseDoc).getBytes("UTF-8"); + String samlResponse = PostBindingUtil.base64Encode(new String(responseBytes)); + + return buildHtml(samlResponse, actionUrl, asRequest); + } + + protected String buildHtml(String samlResponse, String actionUrl, boolean asRequest) { + StringBuilder builder = new StringBuilder(); + + String key = GeneralConstants.SAML_RESPONSE_KEY; + + if (asRequest) { + key = GeneralConstants.SAML_REQUEST_KEY; + } + + builder.append(""); + builder.append(""); + + builder.append("HTTP Post Binding Response (Response)"); + builder.append(""); + builder.append(""); + + builder.append("

"); + builder.append(""); + + if (isNotNull(relayState)) { + builder.append(""); + } + + builder.append(""); + + builder.append("
"); + + return builder.toString(); + } + + protected String base64Encoded(Document document) throws ConfigurationException, ProcessingException, IOException { + String documentAsString = org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil.getDocumentAsString(document); + logger.debugv("saml docment: {0}", documentAsString); + byte[] responseBytes = documentAsString.getBytes("UTF-8"); + + return RedirectBindingUtil.deflateBase64URLEncode(responseBytes); + } + + + protected URI generateRedirectUri(String samlParameterName, String redirectUri, Document document) throws ConfigurationException, ProcessingException, IOException { + UriBuilder builder = UriBuilder.fromUri(redirectUri) + .replaceQuery(null) + .queryParam(samlParameterName, base64Encoded(document)); + if (relayState != null) { + builder.queryParam("RelayState", relayState); + } + + if (sign) { + builder.queryParam(GeneralConstants.SAML_SIG_ALG_REQUEST_KEY, signatureAlgorithm.getJavaSignatureAlgorithm()); + URI uri = builder.build(); + String rawQuery = uri.getRawQuery(); + Signature signature = signatureAlgorithm.createSignature(); + byte[] sig = new byte[0]; + try { + signature.initSign(signingKeyPair.getPrivate()); + signature.update(rawQuery.getBytes("UTF-8")); + sig = signature.sign(); + } catch (Exception e) { + throw new ProcessingException(e); + } + String encodedSig = RedirectBindingUtil.base64URLEncode(sig); + builder.queryParam(GeneralConstants.SAML_SIGNATURE_REQUEST_KEY, encodedSig); + } + return builder.build(); + } + + public RedirectBindingBuilder redirectBinding(Document document) throws ProcessingException { + return new RedirectBindingBuilder(document); + } + + public PostBindingBuilder postBinding(Document document) throws ProcessingException { + return new PostBindingBuilder(document); + } + + +} diff --git a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java index ffdb39d83b..dad37338fc 100755 --- a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java +++ b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java @@ -9,11 +9,15 @@ import org.keycloak.models.ClaimMask; import org.keycloak.models.ClientModel; import org.keycloak.models.ClientSessionModel; import org.keycloak.models.KeycloakSession; +import org.keycloak.models.KeycloakSessionFactory; +import org.keycloak.models.ProtocolMapperModel; import org.keycloak.models.RealmModel; import org.keycloak.models.RoleModel; import org.keycloak.models.UserModel; import org.keycloak.models.UserSessionModel; import org.keycloak.protocol.LoginProtocol; +import org.keycloak.protocol.ProtocolMapper; +import org.keycloak.protocol.saml.mappers.SAMLLoginResponseMapper; import org.keycloak.services.managers.ClientSessionCode; import org.keycloak.services.managers.ResourceAdminManager; import org.keycloak.services.resources.RealmsResource; @@ -25,13 +29,16 @@ import org.picketlink.common.exceptions.ConfigurationException; import org.picketlink.common.exceptions.ParsingException; import org.picketlink.common.exceptions.ProcessingException; import org.picketlink.identity.federation.core.saml.v2.constants.X500SAMLProfileConstants; +import org.picketlink.identity.federation.saml.v2.protocol.ResponseType; import org.picketlink.identity.federation.web.handlers.saml2.SAML2LogOutHandler; +import org.w3c.dom.Document; import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.Response; import javax.ws.rs.core.UriInfo; import java.io.IOException; import java.security.PublicKey; +import java.util.Set; import java.util.UUID; /** @@ -243,7 +250,6 @@ public class SamlProtocol implements LoginProtocol { SALM2LoginResponseBuilder builder = new SALM2LoginResponseBuilder(); builder.requestID(requestID) - .relayState(relayState) .destination(redirectUri) .issuer(responseIssuer) .requestIssuer(clientSession.getClient().getClientId()) @@ -260,19 +266,33 @@ public class SamlProtocol implements LoginProtocol { builder.roles(roleModel.getName()); } } + if (!includeAuthnStatement(client)) { + builder.disableAuthnStatement(true); + } + + Document samlDocument = null; + try { + ResponseType samlModel = builder.buildModel(); + samlModel = transformLoginResponse(session, samlModel, client, userSession, clientSession); + samlDocument = builder.buildDocument(samlModel); + } catch (Exception e) { + logger.error("failed", e); + return Flows.forwardToSecurityFailurePage(session, realm, uriInfo, "Failed to process response"); + } + + SAML2BindingBuilder2 bindingBuilder = new SAML2BindingBuilder2(); + bindingBuilder.relayState(relayState); + if (requiresRealmSignature(client)) { - builder.signatureAlgorithm(getSignatureAlgorithm(client)) + bindingBuilder.signatureAlgorithm(getSignatureAlgorithm(client)) .signWith(realm.getPrivateKey(), realm.getPublicKey(), realm.getCertificate()) .signDocument(); } if (requiresAssertionSignature(client)) { - builder.signatureAlgorithm(getSignatureAlgorithm(client)) + bindingBuilder.signatureAlgorithm(getSignatureAlgorithm(client)) .signWith(realm.getPrivateKey(), realm.getPublicKey(), realm.getCertificate()) .signAssertions(); } - if (!includeAuthnStatement(client)) { - builder.disableAuthnStatement(true); - } if (requiresEncryption(client)) { PublicKey publicKey = null; try { @@ -281,13 +301,13 @@ public class SamlProtocol implements LoginProtocol { logger.error("failed", e); return Flows.forwardToSecurityFailurePage(session, realm, uriInfo, "Failed to process response"); } - builder.encrypt(publicKey); + bindingBuilder.encrypt(publicKey); } try { if (isPostBinding(clientSession)) { - return builder.postBinding().response(); + return bindingBuilder.postBinding(samlDocument).response(redirectUri); } else { - return builder.redirectBinding().response(); + return bindingBuilder.redirectBinding(samlDocument).response(redirectUri); } } catch (Exception e) { logger.error("failed", e); @@ -337,6 +357,24 @@ public class SamlProtocol implements LoginProtocol { } } + public ResponseType transformLoginResponse(KeycloakSession session, ResponseType response, ClientModel client, + UserSessionModel userSession, ClientSessionModel clientSession) { + Set mappings = client.getProtocolMappers(); + KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory(); + for (ProtocolMapperModel mapping : mappings) { + if (!mapping.getProtocol().equals(SamlProtocol.LOGIN_PROTOCOL)) continue; + + ProtocolMapper mapper = (ProtocolMapper)sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper()); + if (mapper == null || !(mapper instanceof SAMLLoginResponseMapper)) continue; + response = ((SAMLLoginResponseMapper)mapper).transformLoginResponse(response, mapping, session, userSession, clientSession); + + + + } + return response; + } + + @Override diff --git a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/AbstractSAMLProtocolMapper.java b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/AbstractSAMLProtocolMapper.java new file mode 100755 index 0000000000..3d42ef799b --- /dev/null +++ b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/AbstractSAMLProtocolMapper.java @@ -0,0 +1,40 @@ +package org.keycloak.protocol.saml.mappers; + +import org.keycloak.Config; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.KeycloakSessionFactory; +import org.keycloak.protocol.ProtocolMapper; +import org.keycloak.protocol.oidc.OIDCLoginProtocol; +import org.keycloak.protocol.saml.SamlProtocol; + +/** + * @author Bill Burke + * @version $Revision: 1 $ + */ +public abstract class AbstractSAMLProtocolMapper implements ProtocolMapper { + + + @Override + public String getProtocol() { + return SamlProtocol.LOGIN_PROTOCOL; + } + + @Override + public void close() { + + } + + @Override + public final ProtocolMapper create(KeycloakSession session) { + throw new RuntimeException("UNSUPPORTED METHOD"); + } + + @Override + public void init(Config.Scope config) { + } + + @Override + public void postInit(KeycloakSessionFactory factory) { + + } +} diff --git a/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/SAMLLoginResponseMapper.java b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/SAMLLoginResponseMapper.java new file mode 100755 index 0000000000..306877b058 --- /dev/null +++ b/saml/saml-protocol/src/main/java/org/keycloak/protocol/saml/mappers/SAMLLoginResponseMapper.java @@ -0,0 +1,18 @@ +package org.keycloak.protocol.saml.mappers; + +import org.keycloak.models.ClientSessionModel; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.ProtocolMapperModel; +import org.keycloak.models.UserSessionModel; +import org.keycloak.representations.AccessToken; +import org.picketlink.identity.federation.saml.v2.protocol.ResponseType; + +/** + * @author Bill Burke + * @version $Revision: 1 $ + */ +public interface SAMLLoginResponseMapper { + + ResponseType transformLoginResponse(ResponseType response, ProtocolMapperModel mappingModel, KeycloakSession session, + UserSessionModel userSession, ClientSessionModel clientSession); +}