diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java index be7de79443..da55815777 100644 --- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java +++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java @@ -84,6 +84,7 @@ import org.keycloak.dom.saml.v2.protocol.ExtensionsType; import org.keycloak.rotation.KeyLocator; import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator; import org.keycloak.saml.processing.core.util.XMLEncryptionUtil; +import org.keycloak.saml.validators.DestinationValidator; /** * @@ -97,6 +98,7 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic protected final SamlSessionStore sessionStore; protected final SamlDeployment deployment; protected AuthChallenge challenge; + private final DestinationValidator destinationValidator = DestinationValidator.forProtocolMap(null); public AbstractSamlAuthenticationHandler(HttpFacade facade, SamlDeployment deployment, SamlSessionStore sessionStore) { this.facade = facade; @@ -145,7 +147,7 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic holder = SAMLRequestParser.parseRequestPostBinding(samlRequest); } RequestAbstractType requestAbstractType = (RequestAbstractType) holder.getSamlObject(); - if (!requestUri.equals(requestAbstractType.getDestination().toString())) { + if (! destinationValidator.validate(requestUri, requestAbstractType.getDestination())) { log.error("expected destination '" + requestUri + "' got '" + requestAbstractType.getDestination() + "'"); return AuthOutcome.FAILED; } @@ -186,7 +188,7 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic } final StatusResponseType statusResponse = (StatusResponseType) holder.getSamlObject(); // validate destination - if (!requestUri.equals(statusResponse.getDestination())) { + if (! destinationValidator.validate(requestUri, statusResponse.getDestination())) { log.error("Request URI '" + requestUri + "' does not match SAML request destination '" + statusResponse.getDestination() + "'"); return AuthOutcome.FAILED; } diff --git a/saml-core/src/main/java/org/keycloak/saml/validators/DestinationValidator.java b/saml-core/src/main/java/org/keycloak/saml/validators/DestinationValidator.java new file mode 100644 index 0000000000..a160790335 --- /dev/null +++ b/saml-core/src/main/java/org/keycloak/saml/validators/DestinationValidator.java @@ -0,0 +1,136 @@ +/* + * Copyright 2018 Red Hat, Inc. and/or its affiliates + * and other contributors as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.keycloak.saml.validators; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Check that Destination field in SAML request/response is either unset or matches the expected one. + * @author hmlnarik + */ +public class DestinationValidator { + + private static final Pattern PROTOCOL_MAP_PATTERN = Pattern.compile("\\s*([a-zA-Z][a-zA-Z\\d+-.]*)\\s*=\\s*(\\d+)\\s*"); + private static final String[] DEFAULT_PROTOCOL_TO_PORT_MAP = new String[] { "http=80", "https=443" }; + + private final Map knownPorts; + private final Map knownProtocols; + + private DestinationValidator(Map knownPorts, Map knownProtocols) { + this.knownPorts = knownPorts; + this.knownProtocols = knownProtocols; + } + + public static DestinationValidator forProtocolMap(String[] protocolMappings) { + if (protocolMappings == null) { + protocolMappings = DEFAULT_PROTOCOL_TO_PORT_MAP; + } + + Map knownPorts = new HashMap<>(); + Map knownProtocols = new HashMap<>(); + + for (String protocolMapping : protocolMappings) { + Matcher m = PROTOCOL_MAP_PATTERN.matcher(protocolMapping); + if (m.matches()) { + Integer port = Integer.valueOf(m.group(2)); + String proto = m.group(1); + + knownPorts.put(proto, port); + knownProtocols.put(port, proto); + } + } + + return new DestinationValidator(knownPorts, knownProtocols); + } + + public boolean validate(String expectedDestination, String actualDestination) { + try { + return validate(expectedDestination == null ? null : URI.create(expectedDestination), actualDestination); + } catch (IllegalArgumentException ex) { + return false; + } + } + + public boolean validate(String expectedDestination, URI actualDestination) { + try { + return validate(expectedDestination == null ? null : URI.create(expectedDestination), actualDestination); + } catch (IllegalArgumentException ex) { + return false; + } + } + + public boolean validate(URI expectedDestination, String actualDestination) { + try { + return validate(expectedDestination, actualDestination == null ? null : URI.create(actualDestination)); + } catch (IllegalArgumentException ex) { + return false; + } + } + + public boolean validate(URI expectedDestination, URI actualDestination) { + if (actualDestination == null) { + return true; // destination is optional + } + + if (expectedDestination == null) { + return false; // expected destination is mandatory + } + + if (Objects.equals(expectedDestination, actualDestination)) { + return true; + } + + Integer portByScheme = knownPorts.get(expectedDestination.getScheme()); + String protocolByPort = knownProtocols.get(expectedDestination.getPort()); + + URI updatedUri = null; + try { + if (expectedDestination.getPort() < 0 && portByScheme != null) { + updatedUri = new URI( + expectedDestination.getScheme(), + expectedDestination.getUserInfo(), + expectedDestination.getHost(), + portByScheme, + expectedDestination.getPath(), + expectedDestination.getQuery(), + expectedDestination.getFragment() + ); + } else if (expectedDestination.getPort() >= 0 && Objects.equals(protocolByPort, expectedDestination.getScheme())) { + updatedUri = new URI( + expectedDestination.getScheme(), + expectedDestination.getUserInfo(), + expectedDestination.getHost(), + -1, + expectedDestination.getPath(), + expectedDestination.getQuery(), + expectedDestination.getFragment() + ); + } + } catch (URISyntaxException ex) { + return false; + } + + return Objects.equals(updatedUri, actualDestination); + } + +} diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java index f6cb9f1340..a65f86aac5 100755 --- a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java +++ b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java @@ -87,6 +87,7 @@ import java.util.List; import org.keycloak.rotation.HardcodedKeyLocator; import org.keycloak.rotation.KeyLocator; import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator; +import org.keycloak.saml.validators.DestinationValidator; import org.w3c.dom.Element; import java.util.*; @@ -111,6 +112,7 @@ public class SAMLEndpoint { protected SAMLIdentityProviderConfig config; protected IdentityProvider.AuthenticationCallback callback; protected SAMLIdentityProvider provider; + private final DestinationValidator destinationValidator; @Context private KeycloakSession session; @@ -122,11 +124,12 @@ public class SAMLEndpoint { private HttpHeaders headers; - public SAMLEndpoint(RealmModel realm, SAMLIdentityProvider provider, SAMLIdentityProviderConfig config, IdentityProvider.AuthenticationCallback callback) { + public SAMLEndpoint(RealmModel realm, SAMLIdentityProvider provider, SAMLIdentityProviderConfig config, IdentityProvider.AuthenticationCallback callback, DestinationValidator destinationValidator) { this.realm = realm; this.config = config; this.callback = callback; this.provider = provider; + this.destinationValidator = destinationValidator; } @GET @@ -238,7 +241,7 @@ public class SAMLEndpoint { SAMLDocumentHolder holder = extractRequestDocument(samlRequest); RequestAbstractType requestAbstractType = (RequestAbstractType) holder.getSamlObject(); // validate destination - if (requestAbstractType.getDestination() != null && !session.getContext().getUri().getAbsolutePath().equals(requestAbstractType.getDestination())) { + if (! destinationValidator.validate(session.getContext().getUri().getAbsolutePath(), requestAbstractType.getDestination())) { event.event(EventType.IDENTITY_PROVIDER_RESPONSE); event.detail(Details.REASON, "invalid_destination"); event.error(Errors.INVALID_SAML_RESPONSE); @@ -456,7 +459,7 @@ public class SAMLEndpoint { SAMLDocumentHolder holder = extractResponseDocument(samlResponse); StatusResponseType statusResponse = (StatusResponseType)holder.getSamlObject(); // validate destination - if (statusResponse.getDestination() != null && !session.getContext().getUri().getAbsolutePath().toString().equals(statusResponse.getDestination())) { + if (! destinationValidator.validate(session.getContext().getUri().getAbsolutePath(), statusResponse.getDestination())) { event.event(EventType.IDENTITY_PROVIDER_RESPONSE); event.detail(Details.REASON, "invalid_destination"); event.error(Errors.INVALID_SAML_RESPONSE); diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java index 5a9a4c7390..4069ead98b 100755 --- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java +++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java @@ -35,6 +35,7 @@ import org.keycloak.saml.*; import org.keycloak.saml.common.constants.GeneralConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator; +import org.keycloak.saml.validators.DestinationValidator; import org.keycloak.sessions.AuthenticationSessionModel; import javax.ws.rs.core.MediaType; @@ -50,13 +51,15 @@ import java.util.TreeSet; */ public class SAMLIdentityProvider extends AbstractIdentityProvider { protected static final Logger logger = Logger.getLogger(SAMLIdentityProvider.class); - public SAMLIdentityProvider(KeycloakSession session, SAMLIdentityProviderConfig config) { + private final DestinationValidator destinationValidator; + public SAMLIdentityProvider(KeycloakSession session, SAMLIdentityProviderConfig config, DestinationValidator destinationValidator) { super(session, config); + this.destinationValidator = destinationValidator; } @Override public Object callback(RealmModel realm, AuthenticationCallback callback, EventBuilder event) { - return new SAMLEndpoint(realm, this, getConfig(), callback); + return new SAMLEndpoint(realm, this, getConfig(), callback, destinationValidator); } @Override diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java index 0170424921..a0eb47150f 100755 --- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java +++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java @@ -16,6 +16,7 @@ */ package org.keycloak.broker.saml; +import org.keycloak.Config.Scope; import org.keycloak.broker.provider.AbstractIdentityProviderFactory; import org.keycloak.dom.saml.v2.metadata.EndpointType; import org.keycloak.dom.saml.v2.metadata.EntitiesDescriptorType; @@ -29,6 +30,7 @@ import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.exceptions.ParsingException; import org.keycloak.saml.common.util.DocumentUtil; import org.keycloak.saml.processing.core.parsers.saml.SAMLParser; +import org.keycloak.saml.validators.DestinationValidator; import org.w3c.dom.Element; import javax.xml.namespace.QName; @@ -44,6 +46,8 @@ public class SAMLIdentityProviderFactory extends AbstractIdentityProviderFactory public static final String PROVIDER_ID = "saml"; + private DestinationValidator destinationValidator; + @Override public String getName() { return "SAML v2.0"; @@ -51,7 +55,7 @@ public class SAMLIdentityProviderFactory extends AbstractIdentityProviderFactory @Override public SAMLIdentityProvider create(KeycloakSession session, IdentityProviderModel model) { - return new SAMLIdentityProvider(session, new SAMLIdentityProviderConfig(model)); + return new SAMLIdentityProvider(session, new SAMLIdentityProviderConfig(model), destinationValidator); } @Override @@ -159,4 +163,10 @@ public class SAMLIdentityProviderFactory extends AbstractIdentityProviderFactory return PROVIDER_ID; } + @Override + public void init(Scope config) { + super.init(config); + + this.destinationValidator = DestinationValidator.forProtocolMap(config.getArray("knownProtocols")); + } } diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java index 32212a3edf..19514042f9 100755 --- a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java +++ b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java @@ -18,7 +18,6 @@ package org.keycloak.protocol.saml; import org.keycloak.Config; -import org.keycloak.OAuth2Constants; import org.keycloak.events.EventBuilder; import org.keycloak.models.ClientModel; import org.keycloak.models.ClientScopeModel; @@ -33,18 +32,16 @@ import org.keycloak.protocol.saml.mappers.RoleListMapper; import org.keycloak.protocol.saml.mappers.UserPropertyAttributeStatementMapper; import org.keycloak.representations.idm.CertificateRepresentation; import org.keycloak.representations.idm.ClientRepresentation; -import org.keycloak.representations.idm.ClientScopeRepresentation; import org.keycloak.saml.SignatureAlgorithm; import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.processing.core.saml.v2.constants.X500SAMLProfileConstants; +import org.keycloak.saml.validators.DestinationValidator; import javax.xml.crypto.dsig.CanonicalizationMethod; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * @author Bill Burke @@ -52,29 +49,14 @@ import java.util.regex.Pattern; */ public class SamlProtocolFactory extends AbstractLoginProtocolFactory { - private static final Pattern PROTOCOL_MAP_PATTERN = Pattern.compile("\\s*([a-zA-Z][a-zA-Z\\d+-.]*)\\s*=\\s*(\\d+)\\s*"); - private static final String[] DEFAULT_PROTOCOL_TO_PORT_MAP = new String[] { "http=80", "https=443" }; - public static final String SCOPE_ROLE_LIST = "role_list"; private static final String ROLE_LIST_CONSENT_TEXT = "${samlRoleListScopeConsentText}"; - private final Map knownPorts = new HashMap<>(); - private final Map knownProtocols = new HashMap<>(); - - private void addToProtocolPortMaps(String protocolMapping) { - Matcher m = PROTOCOL_MAP_PATTERN.matcher(protocolMapping); - if (m.matches()) { - Integer port = Integer.valueOf(m.group(2)); - String proto = m.group(1); - - knownPorts.put(port, proto); - knownProtocols.put(proto, port); - } - } + private DestinationValidator destinationValidator; @Override public Object createProtocolEndpoint(RealmModel realm, EventBuilder event) { - return new SamlService(realm, event, knownProtocols, knownPorts); + return new SamlService(realm, event, destinationValidator); } @Override @@ -87,14 +69,7 @@ public class SamlProtocolFactory extends AbstractLoginProtocolFactory { //PicketLinkCoreSTS sts = PicketLinkCoreSTS.instance(); //sts.installDefaultConfiguration(); - String[] protocolMappings = config.getArray("knownProtocols"); - if (protocolMappings == null) { - protocolMappings = DEFAULT_PROTOCOL_TO_PORT_MAP; - } - - for (String protocolMapping : protocolMappings) { - addToProtocolPortMaps(protocolMapping); - } + this.destinationValidator = DestinationValidator.forProtocolMap(config.getArray("knownProtocols")); } @Override diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java index d3da17676a..b8a32f5c43 100755 --- a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java +++ b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java @@ -85,8 +85,8 @@ import org.keycloak.rotation.HardcodedKeyLocator; import org.keycloak.rotation.KeyLocator; import org.keycloak.saml.SPMetadataDescriptor; import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator; +import org.keycloak.saml.validators.DestinationValidator; import org.keycloak.sessions.AuthenticationSessionModel; -import java.util.Map; /** * Resource class for the saml connect token service @@ -98,13 +98,11 @@ public class SamlService extends AuthorizationEndpointBase { protected static final Logger logger = Logger.getLogger(SamlService.class); - private final Map knownPorts; - private final Map knownProtocols; + private final DestinationValidator destinationValidator; - public SamlService(RealmModel realm, EventBuilder event, Map knownPorts, Map knownProtocols) { + public SamlService(RealmModel realm, EventBuilder event, DestinationValidator destinationValidator) { super(realm, event); - this.knownPorts = knownPorts; - this.knownProtocols = knownProtocols; + this.destinationValidator = destinationValidator; } public abstract class BindingProtocol { @@ -147,7 +145,7 @@ public class SamlService extends AuthorizationEndpointBase { StatusResponseType statusResponse = (StatusResponseType) holder.getSamlObject(); // validate destination - if (statusResponse.getDestination() != null && !session.getContext().getUri().getAbsolutePath().toString().equals(statusResponse.getDestination())) { + if (! destinationValidator.validate(session.getContext().getUri().getAbsolutePath(), statusResponse.getDestination())) { event.detail(Details.REASON, "invalid_destination"); event.error(Errors.INVALID_SAML_LOGOUT_RESPONSE); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.INVALID_REQUEST); @@ -272,7 +270,7 @@ public class SamlService extends AuthorizationEndpointBase { event.error(Errors.INVALID_SAML_AUTHN_REQUEST); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.INVALID_REQUEST); } - if (! isValidDestination(requestAbstractType.getDestination())) { + if (! destinationValidator.validate(session.getContext().getUri().getAbsolutePath(), requestAbstractType.getDestination())) { event.detail(Details.REASON, "invalid_destination"); event.error(Errors.INVALID_SAML_AUTHN_REQUEST); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.INVALID_REQUEST); @@ -376,7 +374,7 @@ public class SamlService extends AuthorizationEndpointBase { event.error(Errors.INVALID_SAML_LOGOUT_REQUEST); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.INVALID_REQUEST); } - if (! isValidDestination(logoutRequest.getDestination())) { + if (! destinationValidator.validate(logoutRequest.getDestination(), session.getContext().getUri().getAbsolutePath())) { event.detail(Details.REASON, "invalid_destination"); event.error(Errors.INVALID_SAML_LOGOUT_REQUEST); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.INVALID_REQUEST); @@ -696,35 +694,10 @@ public class SamlService extends AuthorizationEndpointBase { @NoCache @Consumes({"application/soap+xml",MediaType.TEXT_XML}) public Response soapBinding(InputStream inputStream) { - SamlEcpProfileService bindingService = new SamlEcpProfileService(realm, event, knownPorts, knownProtocols); + SamlEcpProfileService bindingService = new SamlEcpProfileService(realm, event, destinationValidator); ResteasyProviderFactory.getInstance().injectProperties(bindingService); return bindingService.authenticate(inputStream); } - - private boolean isValidDestination(URI destination) { - if (destination == null) { - return true; // destination is optional - } - - URI expected = session.getContext().getUri().getAbsolutePath(); - - if (Objects.equals(expected, destination)) { - return true; - } - - Integer portByScheme = knownPorts.get(expected.getScheme()); - if (expected.getPort() < 0 && portByScheme != null) { - return Objects.equals(session.getContext().getUri().getRequestUriBuilder().port(portByScheme).build(), destination); - } - - String protocolByPort = knownProtocols.get(expected.getPort()); - if (expected.getPort() >= 0 && Objects.equals(protocolByPort, expected.getScheme())) { - return Objects.equals(session.getContext().getUri().getRequestUriBuilder().port(-1).build(), destination); - } - - return false; - } - } diff --git a/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java b/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java index 3a1ae98f1b..eac5d9809d 100755 --- a/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java +++ b/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java @@ -35,6 +35,7 @@ import org.keycloak.saml.common.constants.JBossSAMLConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.exceptions.ConfigurationException; import org.keycloak.saml.common.exceptions.ProcessingException; +import org.keycloak.saml.validators.DestinationValidator; import org.keycloak.sessions.AuthenticationSessionModel; import org.w3c.dom.Document; @@ -54,8 +55,8 @@ public class SamlEcpProfileService extends SamlService { private static final String NS_PREFIX_SAML_PROTOCOL = "samlp"; private static final String NS_PREFIX_SAML_ASSERTION = "saml"; - public SamlEcpProfileService(RealmModel realm, EventBuilder event, Map knownPorts, Map knownProtocols) { - super(realm, event, knownPorts, knownProtocols); + public SamlEcpProfileService(RealmModel realm, EventBuilder event, DestinationValidator destinationValidator) { + super(realm, event, destinationValidator); } public Response authenticate(InputStream inputStream) {