Merge pull request #1007 from patriot1burke/master

add saml mapper interfaces
This commit is contained in:
Bill Burke 2015-02-27 21:21:41 -05:00
commit 5a7ea3ba5d
6 changed files with 501 additions and 43 deletions

View file

@ -31,13 +31,14 @@ import static org.picketlink.common.util.StringUtil.isNotNull;
* <p/>
* Configuration Options:
*
* @author Anil.Saldhana@redhat.com
* @author bburke@redhat.com
*/
public class SALM2LoginResponseBuilder extends SAML2BindingBuilder<SALM2LoginResponseBuilder> {
public class SALM2LoginResponseBuilder {
protected static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger();
protected List<String> roles = new LinkedList<String>();
protected String destination;
protected String issuer;
protected String nameId;
protected String nameIdFormat;
protected boolean multiValuedRoles;
@ -53,6 +54,16 @@ public class SALM2LoginResponseBuilder extends SAML2BindingBuilder<SALM2LoginRes
return this;
}
public SALM2LoginResponseBuilder destination(String destination) {
this.destination = destination;
return this;
}
public SALM2LoginResponseBuilder issuer(String issuer) {
this.issuer = issuer;
return this;
}
public SALM2LoginResponseBuilder attribute(String name, Object value) {
if (value == null) {
attributes.remove(name);
@ -95,7 +106,7 @@ public class SALM2LoginResponseBuilder extends SAML2BindingBuilder<SALM2LoginRes
return this;
}
public SALM2LoginResponseBuilder multiValuedRoles(boolean multiValuedRoles) {
public SALM2LoginResponseBuilder multiValuedRoles(boolean multiValuedRoles) {
this.multiValuedRoles = multiValuedRoles;
return this;
}
@ -105,21 +116,24 @@ public class SALM2LoginResponseBuilder extends SAML2BindingBuilder<SALM2LoginRes
return this;
}
public RedirectBindingBuilder redirectBinding() throws ConfigurationException, ProcessingException {
Document samlResponseDocument = buildDocument();
return new RedirectBindingBuilder(samlResponseDocument);
}
public PostBindingBuilder postBinding() throws ConfigurationException, ProcessingException {
Document samlResponseDocument = buildDocument();
return new PostBindingBuilder(samlResponseDocument);
}
public Document buildDocument() throws ConfigurationException, ProcessingException {
public Document buildDocument(ResponseType responseType) throws ConfigurationException, ProcessingException {
Document samlResponseDocument = null;
try {
SAML2Response docGen = new SAML2Response();
samlResponseDocument = docGen.convert(responseType);
if (logger.isTraceEnabled()) {
logger.trace("SAML Response Document: " + DocumentUtil.asString(samlResponseDocument));
}
} catch (Exception e) {
throw logger.samlAssertionMarshallError(e);
}
return samlResponseDocument;
}
public ResponseType buildModel() throws ConfigurationException, ProcessingException {
ResponseType responseType = null;
SAML2Response saml2Response = new SAML2Response();
@ -167,19 +181,7 @@ public class SALM2LoginResponseBuilder extends SAML2BindingBuilder<SALM2LoginRes
AttributeStatementType attStatement = StatementUtil.createAttributeStatement(attributes);
assertion.addStatement(attStatement);
}
try {
samlResponseDocument = saml2Response.convert(responseType);
if (logger.isTraceEnabled()) {
logger.trace("SAML Response Document: " + DocumentUtil.asString(samlResponseDocument));
}
} catch (Exception e) {
throw logger.samlAssertionMarshallError(e);
}
if (encrypt) encryptDocument(samlResponseDocument);
return samlResponseDocument;
return responseType;
}
}

View file

@ -143,10 +143,6 @@ public class SAML2BindingBuilder<T extends SAML2BindingBuilder> {
return document;
}
public String htmlResponse() throws ProcessingException, ConfigurationException, IOException {
return buildHtml(encoded(), destination, false);
}
public Response request() throws ConfigurationException, ProcessingException, IOException {
return buildResponse(document, destination, true);
}
@ -188,7 +184,7 @@ public class SAML2BindingBuilder<T extends SAML2BindingBuilder> {
return response(destination, false);
}
public Response response(String redirectUri) throws ProcessingException, ConfigurationException, IOException {
return response(destination, false);
return response(redirectUri, false);
}
public Response request(String redirect) throws ProcessingException, ConfigurationException, IOException {

View file

@ -0,0 +1,364 @@
package org.keycloak.protocol.saml;
import org.jboss.logging.Logger;
import org.picketlink.common.constants.GeneralConstants;
import org.picketlink.common.constants.JBossSAMLConstants;
import org.picketlink.common.constants.JBossSAMLURIConstants;
import org.picketlink.common.exceptions.ConfigurationException;
import org.picketlink.common.exceptions.ProcessingException;
import org.picketlink.common.util.DocumentUtil;
import org.picketlink.identity.federation.api.saml.v2.sig.SAML2Signature;
import org.picketlink.identity.federation.core.util.XMLEncryptionUtil;
import org.picketlink.identity.federation.core.wstrust.WSTrustUtil;
import org.picketlink.identity.federation.web.util.PostBindingUtil;
import org.picketlink.identity.federation.web.util.RedirectBindingUtil;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.ws.rs.core.CacheControl;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;
import javax.xml.namespace.QName;
import java.io.IOException;
import java.net.URI;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.security.cert.X509Certificate;
import static org.keycloak.util.HtmlUtils.escapeAttribute;
import static org.picketlink.common.util.StringUtil.isNotNull;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class SAML2BindingBuilder2<T extends SAML2BindingBuilder2> {
protected static final Logger logger = Logger.getLogger(SAML2BindingBuilder2.class);
protected KeyPair signingKeyPair;
protected X509Certificate signingCertificate;
protected boolean sign;
protected boolean signAssertions;
protected SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.RSA_SHA1;
protected String relayState;
protected int encryptionKeySize = 128;
protected PublicKey encryptionPublicKey;
protected String encryptionAlgorithm = "AES";
protected boolean encrypt;
public T signDocument() {
this.sign = true;
return (T)this;
}
public T signAssertions() {
this.signAssertions = true;
return (T)this;
}
public T signWith(KeyPair keyPair) {
this.signingKeyPair = keyPair;
return (T)this;
}
public T signWith(PrivateKey privateKey, PublicKey publicKey) {
this.signingKeyPair = new KeyPair(publicKey, privateKey);
return (T)this;
}
public T signWith(KeyPair keyPair, X509Certificate cert) {
this.signingKeyPair = keyPair;
this.signingCertificate = cert;
return (T)this;
}
public T signWith(PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) {
this.signingKeyPair = new KeyPair(publicKey, privateKey);
this.signingCertificate = cert;
return (T)this;
}
public T signatureAlgorithm(SignatureAlgorithm alg) {
this.signatureAlgorithm = alg;
return (T)this;
}
public T encrypt(PublicKey publicKey) {
encrypt = true;
encryptionPublicKey = publicKey;
return (T)this;
}
public T encryptionAlgorithm(String alg) {
this.encryptionAlgorithm = alg;
return (T)this;
}
public T encryptionKeySize(int size) {
this.encryptionKeySize = size;
return (T)this;
}
public T relayState(String relayState) {
this.relayState = relayState;
return (T)this;
}
public class PostBindingBuilder {
protected Document document;
public PostBindingBuilder(Document document) throws ProcessingException {
if (encrypt) encryptDocument(document);
this.document = document;
if (signAssertions) {
signAssertion(document);
}
if (sign) {
signDocument(document);
}
}
public String encoded() throws ProcessingException, ConfigurationException, IOException {
byte[] responseBytes = org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil.getDocumentAsString(document).getBytes("UTF-8");
return PostBindingUtil.base64Encode(new String(responseBytes));
}
public Document getDocument() {
return document;
}
public Response request(String actionUrl) throws ConfigurationException, ProcessingException, IOException {
return buildResponse(document, actionUrl, true);
}
public Response response(String actionUrl) throws ConfigurationException, ProcessingException, IOException {
return buildResponse(document, actionUrl, false);
}
}
public class RedirectBindingBuilder {
protected Document document;
public RedirectBindingBuilder(Document document) throws ProcessingException {
if (encrypt) encryptDocument(document);
this.document = document;
if (signAssertions) {
signAssertion(document);
}
}
public Document getDocument() {
return document;
}
public URI responseUri(String redirectUri, boolean asRequest) throws ConfigurationException, ProcessingException, IOException {
String samlParameterName = GeneralConstants.SAML_RESPONSE_KEY;
if (asRequest) {
samlParameterName = GeneralConstants.SAML_REQUEST_KEY;
}
return generateRedirectUri(samlParameterName, redirectUri, document);
}
public Response response(String redirectUri) throws ProcessingException, ConfigurationException, IOException {
return response(redirectUri, false);
}
public Response request(String redirect) throws ProcessingException, ConfigurationException, IOException {
return response(redirect, true);
}
private Response response(String redirectUri, boolean asRequest) throws ProcessingException, ConfigurationException, IOException {
URI uri = responseUri(redirectUri, asRequest);
if (logger.isDebugEnabled()) logger.trace("redirect-binding uri: " + uri.toString());
CacheControl cacheControl = new CacheControl();
cacheControl.setNoCache(true);
return Response.status(302).location(uri)
.header("Pragma", "no-cache")
.header("Cache-Control", "no-cache, no-store").build();
}
}
private String getSAMLNSPrefix(Document samlResponseDocument) {
Node assertionElement = samlResponseDocument.getDocumentElement()
.getElementsByTagNameNS(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get()).item(0);
if (assertionElement == null) {
throw new IllegalStateException("Unable to find assertion in saml response document");
}
return assertionElement.getPrefix();
}
protected void encryptDocument(Document samlDocument) throws ProcessingException {
String samlNSPrefix = getSAMLNSPrefix(samlDocument);
try {
QName encryptedAssertionElementQName = new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(),
JBossSAMLConstants.ENCRYPTED_ASSERTION.get(), samlNSPrefix);
byte[] secret = WSTrustUtil.createRandomSecret(encryptionKeySize / 8);
SecretKey secretKey = new SecretKeySpec(secret, encryptionAlgorithm);
// encrypt the Assertion element and replace it with a EncryptedAssertion element.
XMLEncryptionUtil.encryptElement(new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(),
JBossSAMLConstants.ASSERTION.get(), samlNSPrefix), samlDocument, encryptionPublicKey,
secretKey, encryptionKeySize, encryptedAssertionElementQName, true);
} catch (Exception e) {
throw new ProcessingException("failed to encrypt", e);
}
}
protected void signDocument(Document samlDocument) throws ProcessingException {
String signatureMethod = signatureAlgorithm.getXmlSignatureMethod();
String signatureDigestMethod = signatureAlgorithm.getXmlSignatureDigestMethod();
SAML2Signature samlSignature = new SAML2Signature();
if (signatureMethod != null) {
samlSignature.setSignatureMethod(signatureMethod);
}
if (signatureDigestMethod != null) {
samlSignature.setDigestMethod(signatureDigestMethod);
}
Node nextSibling = samlSignature.getNextSiblingOfIssuer(samlDocument);
samlSignature.setNextSibling(nextSibling);
if (signingCertificate != null) {
samlSignature.setX509Certificate(signingCertificate);
}
samlSignature.signSAMLDocument(samlDocument, signingKeyPair);
}
protected void signAssertion(Document samlDocument) throws ProcessingException {
Element originalAssertionElement = DocumentUtil.getChildElement(samlDocument.getDocumentElement(), new QName(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get()));
if (originalAssertionElement == null) return;
Node clonedAssertionElement = originalAssertionElement.cloneNode(true);
Document temporaryDocument;
try {
temporaryDocument = DocumentUtil.createDocument();
} catch (ConfigurationException e) {
throw new ProcessingException(e);
}
temporaryDocument.adoptNode(clonedAssertionElement);
temporaryDocument.appendChild(clonedAssertionElement);
signDocument(temporaryDocument);
samlDocument.adoptNode(clonedAssertionElement);
Element parentNode = (Element) originalAssertionElement.getParentNode();
parentNode.replaceChild(clonedAssertionElement, originalAssertionElement);
}
protected Response buildResponse(Document responseDoc, String actionUrl, boolean asRequest) throws ProcessingException, ConfigurationException, IOException {
String str = buildHtmlPostResponse(responseDoc, actionUrl, asRequest);
CacheControl cacheControl = new CacheControl();
cacheControl.setNoCache(true);
return Response.ok(str, MediaType.TEXT_HTML_TYPE)
.header("Pragma", "no-cache")
.header("Cache-Control", "no-cache, no-store").build();
}
protected String buildHtmlPostResponse(Document responseDoc, String actionUrl, boolean asRequest) throws ProcessingException, ConfigurationException, IOException {
byte[] responseBytes = DocumentUtil.getDocumentAsString(responseDoc).getBytes("UTF-8");
String samlResponse = PostBindingUtil.base64Encode(new String(responseBytes));
return buildHtml(samlResponse, actionUrl, asRequest);
}
protected String buildHtml(String samlResponse, String actionUrl, boolean asRequest) {
StringBuilder builder = new StringBuilder();
String key = GeneralConstants.SAML_RESPONSE_KEY;
if (asRequest) {
key = GeneralConstants.SAML_REQUEST_KEY;
}
builder.append("<HTML>");
builder.append("<HEAD>");
builder.append("<TITLE>HTTP Post Binding Response (Response)</TITLE>");
builder.append("</HEAD>");
builder.append("<BODY Onload=\"document.forms[0].submit()\">");
builder.append("<FORM METHOD=\"POST\" ACTION=\"" + actionUrl + "\">");
builder.append("<INPUT TYPE=\"HIDDEN\" NAME=\"" + key + "\"" + " VALUE=\"" + samlResponse + "\"/>");
if (isNotNull(relayState)) {
builder.append("<INPUT TYPE=\"HIDDEN\" NAME=\"RelayState\" " + "VALUE=\"" + escapeAttribute(relayState) + "\"/>");
}
builder.append("<NOSCRIPT>");
builder.append("<P>JavaScript is disabled. We strongly recommend to enable it. Click the button below to continue.</P>");
builder.append("<INPUT TYPE=\"SUBMIT\" VALUE=\"CONTINUE\" />");
builder.append("</NOSCRIPT>");
builder.append("</FORM></BODY></HTML>");
return builder.toString();
}
protected String base64Encoded(Document document) throws ConfigurationException, ProcessingException, IOException {
String documentAsString = org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil.getDocumentAsString(document);
logger.debugv("saml docment: {0}", documentAsString);
byte[] responseBytes = documentAsString.getBytes("UTF-8");
return RedirectBindingUtil.deflateBase64URLEncode(responseBytes);
}
protected URI generateRedirectUri(String samlParameterName, String redirectUri, Document document) throws ConfigurationException, ProcessingException, IOException {
UriBuilder builder = UriBuilder.fromUri(redirectUri)
.replaceQuery(null)
.queryParam(samlParameterName, base64Encoded(document));
if (relayState != null) {
builder.queryParam("RelayState", relayState);
}
if (sign) {
builder.queryParam(GeneralConstants.SAML_SIG_ALG_REQUEST_KEY, signatureAlgorithm.getJavaSignatureAlgorithm());
URI uri = builder.build();
String rawQuery = uri.getRawQuery();
Signature signature = signatureAlgorithm.createSignature();
byte[] sig = new byte[0];
try {
signature.initSign(signingKeyPair.getPrivate());
signature.update(rawQuery.getBytes("UTF-8"));
sig = signature.sign();
} catch (Exception e) {
throw new ProcessingException(e);
}
String encodedSig = RedirectBindingUtil.base64URLEncode(sig);
builder.queryParam(GeneralConstants.SAML_SIGNATURE_REQUEST_KEY, encodedSig);
}
return builder.build();
}
public RedirectBindingBuilder redirectBinding(Document document) throws ProcessingException {
return new RedirectBindingBuilder(document);
}
public PostBindingBuilder postBinding(Document document) throws ProcessingException {
return new PostBindingBuilder(document);
}
}

View file

@ -9,11 +9,15 @@ import org.keycloak.models.ClaimMask;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.ProtocolMapperModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.protocol.LoginProtocol;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.saml.mappers.SAMLLoginResponseMapper;
import org.keycloak.services.managers.ClientSessionCode;
import org.keycloak.services.managers.ResourceAdminManager;
import org.keycloak.services.resources.RealmsResource;
@ -25,13 +29,16 @@ import org.picketlink.common.exceptions.ConfigurationException;
import org.picketlink.common.exceptions.ParsingException;
import org.picketlink.common.exceptions.ProcessingException;
import org.picketlink.identity.federation.core.saml.v2.constants.X500SAMLProfileConstants;
import org.picketlink.identity.federation.saml.v2.protocol.ResponseType;
import org.picketlink.identity.federation.web.handlers.saml2.SAML2LogOutHandler;
import org.w3c.dom.Document;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;
import java.io.IOException;
import java.security.PublicKey;
import java.util.Set;
import java.util.UUID;
/**
@ -243,7 +250,6 @@ public class SamlProtocol implements LoginProtocol {
SALM2LoginResponseBuilder builder = new SALM2LoginResponseBuilder();
builder.requestID(requestID)
.relayState(relayState)
.destination(redirectUri)
.issuer(responseIssuer)
.requestIssuer(clientSession.getClient().getClientId())
@ -260,19 +266,33 @@ public class SamlProtocol implements LoginProtocol {
builder.roles(roleModel.getName());
}
}
if (!includeAuthnStatement(client)) {
builder.disableAuthnStatement(true);
}
Document samlDocument = null;
try {
ResponseType samlModel = builder.buildModel();
samlModel = transformLoginResponse(session, samlModel, client, userSession, clientSession);
samlDocument = builder.buildDocument(samlModel);
} catch (Exception e) {
logger.error("failed", e);
return Flows.forwardToSecurityFailurePage(session, realm, uriInfo, "Failed to process response");
}
SAML2BindingBuilder2 bindingBuilder = new SAML2BindingBuilder2();
bindingBuilder.relayState(relayState);
if (requiresRealmSignature(client)) {
builder.signatureAlgorithm(getSignatureAlgorithm(client))
bindingBuilder.signatureAlgorithm(getSignatureAlgorithm(client))
.signWith(realm.getPrivateKey(), realm.getPublicKey(), realm.getCertificate())
.signDocument();
}
if (requiresAssertionSignature(client)) {
builder.signatureAlgorithm(getSignatureAlgorithm(client))
bindingBuilder.signatureAlgorithm(getSignatureAlgorithm(client))
.signWith(realm.getPrivateKey(), realm.getPublicKey(), realm.getCertificate())
.signAssertions();
}
if (!includeAuthnStatement(client)) {
builder.disableAuthnStatement(true);
}
if (requiresEncryption(client)) {
PublicKey publicKey = null;
try {
@ -281,13 +301,13 @@ public class SamlProtocol implements LoginProtocol {
logger.error("failed", e);
return Flows.forwardToSecurityFailurePage(session, realm, uriInfo, "Failed to process response");
}
builder.encrypt(publicKey);
bindingBuilder.encrypt(publicKey);
}
try {
if (isPostBinding(clientSession)) {
return builder.postBinding().response();
return bindingBuilder.postBinding(samlDocument).response(redirectUri);
} else {
return builder.redirectBinding().response();
return bindingBuilder.redirectBinding(samlDocument).response(redirectUri);
}
} catch (Exception e) {
logger.error("failed", e);
@ -337,6 +357,24 @@ public class SamlProtocol implements LoginProtocol {
}
}
public ResponseType transformLoginResponse(KeycloakSession session, ResponseType response, ClientModel client,
UserSessionModel userSession, ClientSessionModel clientSession) {
Set<ProtocolMapperModel> mappings = client.getProtocolMappers();
KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
for (ProtocolMapperModel mapping : mappings) {
if (!mapping.getProtocol().equals(SamlProtocol.LOGIN_PROTOCOL)) continue;
ProtocolMapper mapper = (ProtocolMapper)sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper());
if (mapper == null || !(mapper instanceof SAMLLoginResponseMapper)) continue;
response = ((SAMLLoginResponseMapper)mapper).transformLoginResponse(response, mapping, session, userSession, clientSession);
}
return response;
}
@Override

View file

@ -0,0 +1,40 @@
package org.keycloak.protocol.saml.mappers;
import org.keycloak.Config;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
import org.keycloak.protocol.saml.SamlProtocol;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public abstract class AbstractSAMLProtocolMapper implements ProtocolMapper {
@Override
public String getProtocol() {
return SamlProtocol.LOGIN_PROTOCOL;
}
@Override
public void close() {
}
@Override
public final ProtocolMapper create(KeycloakSession session) {
throw new RuntimeException("UNSUPPORTED METHOD");
}
@Override
public void init(Config.Scope config) {
}
@Override
public void postInit(KeycloakSessionFactory factory) {
}
}

View file

@ -0,0 +1,18 @@
package org.keycloak.protocol.saml.mappers;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ProtocolMapperModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.representations.AccessToken;
import org.picketlink.identity.federation.saml.v2.protocol.ResponseType;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public interface SAMLLoginResponseMapper {
ResponseType transformLoginResponse(ResponseType response, ProtocolMapperModel mappingModel, KeycloakSession session,
UserSessionModel userSession, ClientSessionModel clientSession);
}