diff --git a/core/src/main/java/org/keycloak/representations/CodeJWT.java b/core/src/main/java/org/keycloak/representations/CodeJWT.java deleted file mode 100644 index df43deebf4..0000000000 --- a/core/src/main/java/org/keycloak/representations/CodeJWT.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2017 Red Hat, Inc. and/or its affiliates - * and other contributors as indicated by the @author tags. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.keycloak.representations; - -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * @author Marek Posolda - */ -public class CodeJWT extends JsonWebToken { - - @JsonProperty("uss") - protected String userSessionId; - - public String getUserSessionId() { - return userSessionId; - } - - public CodeJWT userSessionId(String userSessionId) { - this.userSessionId = userSessionId; - return this; - } - -} diff --git a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanCodeToTokenStoreProvider.java b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanCodeToTokenStoreProvider.java index 463b777bfc..68d94d00c5 100644 --- a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanCodeToTokenStoreProvider.java +++ b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanCodeToTokenStoreProvider.java @@ -17,6 +17,7 @@ package org.keycloak.models.sessions.infinispan; +import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -44,28 +45,43 @@ public class InfinispanCodeToTokenStoreProvider implements CodeToTokenStoreProvi this.codeCache = actionKeyCache; } - @Override - public boolean putIfAbsent(UUID codeId) { - ActionTokenValueEntity tokenValue = new ActionTokenValueEntity(null); - int lifespanInSeconds = session.getContext().getRealm().getAccessCodeLifespan(); + @Override + public void put(UUID codeId, int lifespanSeconds, Map codeData) { + ActionTokenValueEntity tokenValue = new ActionTokenValueEntity(codeData); try { BasicCache cache = codeCache.get(); - ActionTokenValueEntity existing = cache.putIfAbsent(codeId, tokenValue, lifespanInSeconds, TimeUnit.SECONDS); - return existing == null; + cache.put(codeId, tokenValue, lifespanSeconds, TimeUnit.SECONDS); } catch (HotRodClientException re) { // No need to retry. The hotrod (remoteCache) has some retries in itself in case of some random network error happened. - // In case of lock conflict, we don't want to retry anyway as there was likely an attempt to use the code from different place. if (logger.isDebugEnabled()) { logger.debugf(re, "Failed when adding code %s", codeId); } - return false; + throw re; } - } + + @Override + public Map remove(UUID codeId) { + try { + BasicCache cache = codeCache.get(); + ActionTokenValueEntity existing = cache.remove(codeId); + return existing == null ? null : existing.getNotes(); + } catch (HotRodClientException re) { + // No need to retry. The hotrod (remoteCache) has some retries in itself in case of some random network error happened. + // In case of lock conflict, we don't want to retry anyway as there was likely an attempt to remove the code from different place. + if (logger.isDebugEnabled()) { + logger.debugf(re, "Failed when removing code %s", codeId); + } + + return null; + } + } + + @Override public void close() { diff --git a/model/infinispan/src/test/java/org/keycloak/cluster/infinispan/ConcurrencyDistributedRemoveSessionTest.java b/model/infinispan/src/test/java/org/keycloak/cluster/infinispan/ConcurrencyDistributedRemoveSessionTest.java new file mode 100644 index 0000000000..bac63a0911 --- /dev/null +++ b/model/infinispan/src/test/java/org/keycloak/cluster/infinispan/ConcurrencyDistributedRemoveSessionTest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates + * and other contributors as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.keycloak.cluster.infinispan; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.infinispan.Cache; +import org.jboss.logging.Logger; +import org.junit.Assert; +import org.keycloak.common.util.Time; +import org.keycloak.connections.infinispan.InfinispanConnectionProvider; +import org.keycloak.models.sessions.infinispan.changes.SessionEntityWrapper; +import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity; +import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity; +import org.keycloak.models.sessions.infinispan.initializer.DistributedCacheConcurrentWritesTest; + +/** + * @author Marek Posolda + */ +public class ConcurrencyDistributedRemoveSessionTest { + + + protected static final Logger logger = Logger.getLogger(ConcurrencyJDGRemoveSessionTest.class); + + private static final int ITERATIONS = 10000; + + private static final AtomicInteger errorsCounter = new AtomicInteger(0); + + private static final AtomicInteger successfulListenerWrites = new AtomicInteger(0); + private static final AtomicInteger successfulListenerWrites2 = new AtomicInteger(0); + + private static Map removalCounts = new ConcurrentHashMap<>(); + + + private static final UUID CLIENT_1_UUID = UUID.randomUUID(); + + public static void main(String[] args) throws Exception { + Cache> cache1 = DistributedCacheConcurrentWritesTest.createManager("node1").getCache(InfinispanConnectionProvider.USER_SESSION_CACHE_NAME); + Cache> cache2 = DistributedCacheConcurrentWritesTest.createManager("node2").getCache(InfinispanConnectionProvider.USER_SESSION_CACHE_NAME); + + // Create caches, listeners and finally worker threads + Thread worker1 = createWorker(cache1, 1); + Thread worker2 = createWorker(cache2, 2); + Thread worker3 = createWorker(cache1, 1); + Thread worker4 = createWorker(cache2, 2); + + // Create 100 initial sessions + for (int i=0 ; i wrappedSession = createSessionEntity(sessionId); + cache1.put(sessionId, wrappedSession); + + removalCounts.put(sessionId, new AtomicInteger(0)); + } + + logger.info("SESSIONS CREATED"); + + // Create 100 initial sessions + for (int i=0 ; i histogram = new HashMap<>(); + for (Map.Entry entry : removalCounts.entrySet()) { + int count = entry.getValue().get(); + + int current = histogram.get(count) == null ? 0 : histogram.get(count); + current++; + histogram.put(count, current); + } + + logger.infof("Histogram: %s", histogram.toString()); + logger.infof("Errors: %d", errorsCounter.get()); + + long took = System.currentTimeMillis() - start; + logger.infof("took %d ms", took); + + + } finally { + Thread.sleep(2000); + + // Finish JVM + cache1.getCacheManager().stop(); + cache2.getCacheManager().stop(); + } + } + + + private static SessionEntityWrapper createSessionEntity(String sessionId) { + // Create 100 initial sessions + UserSessionEntity session = new UserSessionEntity(); + session.setId(sessionId); + session.setRealmId("foo"); + session.setBrokerSessionId("!23123123"); + session.setBrokerUserId(null); + session.setUser("foo"); + session.setLoginUsername("foo"); + session.setIpAddress("123.44.143.178"); + session.setStarted(Time.currentTime()); + session.setLastSessionRefresh(Time.currentTime()); + + AuthenticatedClientSessionEntity clientSession = new AuthenticatedClientSessionEntity(UUID.randomUUID()); + clientSession.setAuthMethod("saml"); + clientSession.setAction("something"); + clientSession.setTimestamp(1234); + session.getAuthenticatedClientSessions().put(CLIENT_1_UUID.toString(), clientSession.getId()); + + SessionEntityWrapper wrappedSession = new SessionEntityWrapper<>(session); + return wrappedSession; + } + + + private static Thread createWorker(Cache> cache, int threadId) { + System.out.println("Retrieved cache: " + threadId); + return new CacheWorker(cache, threadId); + } + + + private static class CacheWorker extends Thread { + + private final Cache cache; + + private final int myThreadId; + + private CacheWorker(Cache cache, int myThreadId) { + this.cache = cache; + this.myThreadId = myThreadId; + } + + + @Override + public void run() { + + for (int i=0 ; i state = new HashMap<>(); + private static Map removalCounts = new ConcurrentHashMap<>(); + private static final UUID CLIENT_1_UUID = UUID.randomUUID(); @@ -78,12 +82,16 @@ public class ConcurrencyJDGRemoveSessionTest { // Create caches, listeners and finally worker threads Thread worker1 = createWorker(cache1, 1); Thread worker2 = createWorker(cache2, 2); + Thread worker3 = createWorker(cache1, 1); + Thread worker4 = createWorker(cache2, 2); // Create 100 initial sessions for (int i=0 ; i wrappedSession = createSessionEntity(sessionId); cache1.put(sessionId, wrappedSession); + + removalCounts.put(sessionId, new AtomicInteger(0)); } logger.info("SESSIONS CREATED"); @@ -101,25 +109,44 @@ public class ConcurrencyJDGRemoveSessionTest { long start = System.currentTimeMillis(); try { - // Just running in current thread - worker1.run(); + worker1.start(); + worker2.start(); + worker3.start(); + worker4.start(); + + worker1.join(); + worker2.join(); + worker3.join(); + worker4.join(); logger.info("SESSIONS REMOVED"); + Map histogram = new HashMap<>(); + for (Map.Entry entry : removalCounts.entrySet()) { + int count = entry.getValue().get(); + + int current = histogram.get(count) == null ? 0 : histogram.get(count); + current++; + histogram.put(count, current); + } + + logger.infof("Histogram: %s", histogram.toString()); + logger.infof("Errors: %d", errorsCounter.get()); + //Thread.sleep(5000); // Doing it in opposite direction to ensure that newer are checked first. // This us currently FAILING (expected) as listeners are executed asynchronously. - for (int i=ITERATIONS-1 ; i>=0 ; i--) { - String sessionId = String.valueOf(i); - - logger.infof("Before call cache2.get: %s", sessionId); - - SessionEntityWrapper loadedWrapper = cache2.get(sessionId); - Assert.assertNull("Loaded wrapper not null for key " + sessionId, loadedWrapper); - } - - logger.info("SESSIONS NOT AVAILABLE ON DC2"); +// for (int i=ITERATIONS-1 ; i>=0 ; i--) { +// String sessionId = String.valueOf(i); +// +// logger.infof("Before call cache2.get: %s", sessionId); +// +// SessionEntityWrapper loadedWrapper = cache2.get(sessionId); +// Assert.assertNull("Loaded wrapper not null for key " + sessionId, loadedWrapper); +// } +// +// logger.info("SESSIONS NOT AVAILABLE ON DC2"); long took = System.currentTimeMillis() - start; logger.infof("took %d ms", took); @@ -271,19 +298,30 @@ public class ConcurrencyJDGRemoveSessionTest { for (int i=0 ; i codeData); + + + /** + * This method returns data just if removal was successful. Implementation should guarantee that "remove" is single-use. So if + * 2 threads (even on different cluster nodes or on different cross-dc nodes) calls "remove(123)" concurrently, then just one of them + * is allowed to succeed and return data back. It can't happen that both will succeed. + * + * @param codeId + * @return context data related to OAuth2 code. It returns null if there are not context data available. + */ + Map remove(UUID codeId); } diff --git a/server-spi/src/main/java/org/keycloak/models/ClientSessionContext.java b/server-spi/src/main/java/org/keycloak/models/ClientSessionContext.java index 332079701c..34cec958c6 100644 --- a/server-spi/src/main/java/org/keycloak/models/ClientSessionContext.java +++ b/server-spi/src/main/java/org/keycloak/models/ClientSessionContext.java @@ -40,4 +40,11 @@ public interface ClientSessionContext { Set getProtocolMappers(); String getScopeString(); + + void setAttribute(String name, Object value); + + T getAttribute(String attribute, Class clazz); + + + String AUTHENTICATION_SESSION_ATTR = "AUTH_SESSION_ATTR"; } diff --git a/services/src/main/java/org/keycloak/protocol/AuthorizationEndpointBase.java b/services/src/main/java/org/keycloak/protocol/AuthorizationEndpointBase.java index 3b3a769220..c024076bb8 100755 --- a/services/src/main/java/org/keycloak/protocol/AuthorizationEndpointBase.java +++ b/services/src/main/java/org/keycloak/protocol/AuthorizationEndpointBase.java @@ -126,8 +126,6 @@ public abstract class AuthorizationEndpointBase { return response; } - // Attach session once no requiredActions or other things are required - processor.attachSession(); } catch (Exception e) { return processor.handleBrowserException(e); } diff --git a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocol.java b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocol.java index 93e3f7df07..0d14c0bbbc 100755 --- a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocol.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocol.java @@ -29,7 +29,6 @@ import org.keycloak.events.EventType; import org.keycloak.models.AuthenticatedClientSessionModel; import org.keycloak.models.ClientModel; import org.keycloak.models.ClientSessionContext; -import org.keycloak.models.TokenManager; import org.keycloak.models.KeycloakSession; import org.keycloak.models.RealmModel; import org.keycloak.models.UserSessionModel; @@ -39,16 +38,21 @@ import org.keycloak.protocol.oidc.utils.OIDCResponseMode; import org.keycloak.protocol.oidc.utils.OIDCResponseType; import org.keycloak.representations.AccessTokenResponse; import org.keycloak.representations.adapters.action.PushNotBeforeAction; +import org.keycloak.services.ErrorResponseException; import org.keycloak.services.ServicesLogger; import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationSessionManager; -import org.keycloak.services.managers.ClientSessionCode; +import org.keycloak.protocol.oidc.utils.OAuth2Code; +import org.keycloak.protocol.oidc.utils.OAuth2CodeParser; import org.keycloak.services.managers.ResourceAdminManager; import org.keycloak.sessions.AuthenticationSessionModel; +import org.keycloak.sessions.CommonClientSessionModel; import org.keycloak.util.TokenUtil; import java.io.IOException; import java.net.URI; +import java.util.UUID; + import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.Response; import javax.ws.rs.core.UriBuilder; @@ -179,7 +183,6 @@ public class OIDCLoginProtocol implements LoginProtocol { @Override public Response authenticated(UserSessionModel userSession, ClientSessionContext clientSessionCtx) { AuthenticatedClientSessionModel clientSession= clientSessionCtx.getClientSession(); - ClientSessionCode accessCode = new ClientSessionCode<>(session, realm, clientSession); String responseTypeParam = clientSession.getNote(OIDCLoginProtocol.RESPONSE_TYPE_PARAM); String responseModeParam = clientSession.getNote(OIDCLoginProtocol.RESPONSE_MODE_PARAM); @@ -197,10 +200,27 @@ public class OIDCLoginProtocol implements LoginProtocol { redirectUri.addParam(OAuth2Constants.SESSION_STATE, userSession.getId()); } + AuthenticationSessionModel authSession = clientSessionCtx.getAttribute(ClientSessionContext.AUTHENTICATION_SESSION_ATTR, AuthenticationSessionModel.class); + if (authSession == null) { + // Shouldn't happen if correctly used + throw new IllegalStateException("AuthenticationSession attachement not set in the ClientSessionContext"); + } + + String nonce = authSession.getClientNote(OIDCLoginProtocol.NONCE_PARAM); + clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, nonce); + // Standard or hybrid flow String code = null; if (responseType.hasResponseType(OIDCResponseType.CODE)) { - code = accessCode.getOrGenerateCode(); + OAuth2Code codeData = new OAuth2Code(UUID.randomUUID(), + Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(), + nonce, + authSession.getClientNote(OAuth2Constants.SCOPE), + authSession.getClientNote(OIDCLoginProtocol.REDIRECT_URI_PARAM), + authSession.getClientNote(OIDCLoginProtocol.CODE_CHALLENGE_PARAM), + authSession.getClientNote(OIDCLoginProtocol.CODE_CHALLENGE_METHOD_PARAM)); + + code = OAuth2CodeParser.persistCode(session, clientSession, codeData); redirectUri.addParam(OAuth2Constants.CODE, code); } diff --git a/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java b/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java index 1311635471..08a86e244d 100755 --- a/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java @@ -180,6 +180,8 @@ public class TokenManager { throw new OAuthErrorException(OAuthErrorException.INVALID_SCOPE, "Client no longer has requested consent from user"); } + clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, oldToken.getNonce()); + // recreate token. AccessToken newToken = createClientAccessToken(session, realm, client, user, userSession, clientSessionCtx); verifyAccess(oldToken, newToken); @@ -433,7 +435,10 @@ public class TokenManager { // Remove authentication session now new AuthenticationSessionManager(session).removeAuthenticationSession(userSession.getRealm(), authSession, true); - return DefaultClientSessionContext.fromClientSessionAndClientScopeIds(clientSession, clientScopeIds); + ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndClientScopeIds(clientSession, clientScopeIds); + clientSessionCtx.setAttribute(ClientSessionContext.AUTHENTICATION_SESSION_ATTR, authSession); + + return clientSessionCtx; } @@ -614,7 +619,7 @@ public class TokenManager { AuthenticatedClientSessionModel clientSession = clientSessionCtx.getClientSession(); token.issuer(clientSession.getNote(OIDCLoginProtocol.ISSUER)); - token.setNonce(clientSession.getNote(OIDCLoginProtocol.NONCE_PARAM)); + token.setNonce(clientSessionCtx.getAttribute(OIDCLoginProtocol.NONCE_PARAM, String.class)); token.setScope(clientSessionCtx.getScopeString()); // Best effort for "acr" value. Use 0 if clientSession was authenticated through cookie ( SSO ) diff --git a/services/src/main/java/org/keycloak/protocol/oidc/endpoints/TokenEndpoint.java b/services/src/main/java/org/keycloak/protocol/oidc/endpoints/TokenEndpoint.java index ed4817ad68..5279c4bb91 100644 --- a/services/src/main/java/org/keycloak/protocol/oidc/endpoints/TokenEndpoint.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/endpoints/TokenEndpoint.java @@ -75,7 +75,8 @@ import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationSessionManager; import org.keycloak.services.managers.BruteForceProtector; import org.keycloak.services.managers.ClientManager; -import org.keycloak.services.managers.ClientSessionCode; +import org.keycloak.protocol.oidc.utils.OAuth2Code; +import org.keycloak.protocol.oidc.utils.OAuth2CodeParser; import org.keycloak.services.managers.RealmManager; import org.keycloak.services.resources.Cors; import org.keycloak.services.resources.IdentityBrokerService; @@ -275,8 +276,8 @@ public class TokenEndpoint { throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_REQUEST, "Missing parameter: " + OAuth2Constants.CODE, Response.Status.BAD_REQUEST); } - ClientSessionCode.ParseResult parseResult = ClientSessionCode.parseResult(code, null, session, realm, client, event, AuthenticatedClientSessionModel.class); - if (parseResult.isAuthSessionNotFound() || parseResult.isIllegalHash()) { + OAuth2CodeParser.ParseResult parseResult = OAuth2CodeParser.parseCode(session, code, realm, event); + if (parseResult.isIllegalCode()) { AuthenticatedClientSessionModel clientSession = parseResult.getClientSession(); // Attempt to use same code twice should invalidate existing clientSession @@ -291,7 +292,7 @@ public class TokenEndpoint { AuthenticatedClientSessionModel clientSession = parseResult.getClientSession(); - if (parseResult.isExpiredToken()) { + if (parseResult.isExpiredCode()) { event.error(Errors.EXPIRED_CODE); throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_GRANT, "Code is expired", Response.Status.BAD_REQUEST); } @@ -317,7 +318,8 @@ public class TokenEndpoint { throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_GRANT, "User disabled", Response.Status.BAD_REQUEST); } - String redirectUri = clientSession.getNote(OIDCLoginProtocol.REDIRECT_URI_PARAM); + OAuth2Code codeData = parseResult.getCodeData(); + String redirectUri = codeData.getRedirectUriParam(); String redirectUriParam = formParams.getFirst(OAuth2Constants.REDIRECT_URI); // KEYCLOAK-4478 Backwards compatibility with the adapters earlier than KC 3.4.2 @@ -349,8 +351,8 @@ public class TokenEndpoint { // https://tools.ietf.org/html/rfc7636#section-4.6 String codeVerifier = formParams.getFirst(OAuth2Constants.CODE_VERIFIER); - String codeChallenge = clientSession.getNote(OIDCLoginProtocol.CODE_CHALLENGE_PARAM); - String codeChallengeMethod = clientSession.getNote(OIDCLoginProtocol.CODE_CHALLENGE_METHOD_PARAM); + String codeChallenge = codeData.getCodeChallenge(); + String codeChallengeMethod = codeData.getCodeChallengeMethod(); String authUserId = user.getId(); String authUsername = user.getUsername(); if (authUserId == null) { @@ -406,7 +408,7 @@ public class TokenEndpoint { // Compute client scopes again from scope parameter. Check if user still has them granted // (but in code-to-token request, it could just theoretically happen that they are not available) - String scopeParam = clientSession.getNote(OAuth2Constants.SCOPE); + String scopeParam = codeData.getScope(); Set clientScopes = TokenManager.getRequestedClientScopes(scopeParam, client); if (!TokenManager.verifyConsentStillAvailable(session, user, client, clientScopes)) { event.error(Errors.NOT_ALLOWED); @@ -415,6 +417,9 @@ public class TokenEndpoint { ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndClientScopes(clientSession, clientScopes); + // Set nonce as an attribute in the ClientSessionContext. Will be used for the token generation + clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, codeData.getNonce()); + AccessToken token = tokenManager.createClientAccessToken(session, realm, client, user, userSession, clientSessionCtx); TokenManager.AccessTokenResponseBuilder responseBuilder = tokenManager.responseBuilder(realm, client, event, session, userSession, clientSessionCtx) diff --git a/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2Code.java b/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2Code.java new file mode 100644 index 0000000000..fd9159922d --- /dev/null +++ b/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2Code.java @@ -0,0 +1,128 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates + * and other contributors as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.keycloak.protocol.oidc.utils; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +/** + * Data associated with the oauth2 code. + * + * Those data are typically valid just for the very short time - they're created at the point before we redirect to the application + * after successful and they're removed when application sends requests to the token endpoint (code-to-token endpoint) to exchange the + * single-use OAuth2 code parameter for those data. + * + * @author Marek Posolda + */ +public class OAuth2Code { + + private static final String ID_NOTE = "id"; + private static final String EXPIRATION_NOTE = "exp"; + private static final String NONCE_NOTE = "nonce"; + private static final String SCOPE_NOTE = "scope"; + private static final String REDIRECT_URI_PARAM_NOTE = "redirectUri"; + private static final String CODE_CHALLENGE_NOTE = "code_challenge"; + private static final String CODE_CHALLENGE_METHOD_NOTE = "code_challenge_method"; + + private final UUID id; + + private final int expiration; + + private final String nonce; + + private final String scope; + + private final String redirectUriParam; + + private final String codeChallenge; + + private final String codeChallengeMethod; + + + public OAuth2Code(UUID id, int expiration, String nonce, String scope, String redirectUriParam, + String codeChallenge, String codeChallengeMethod) { + this.id = id; + this.expiration = expiration; + this.nonce = nonce; + this.scope = scope; + this.redirectUriParam = redirectUriParam; + this.codeChallenge = codeChallenge; + this.codeChallengeMethod = codeChallengeMethod; + } + + + private OAuth2Code(Map data) { + id = UUID.fromString(data.get(ID_NOTE)); + expiration = Integer.parseInt(data.get(EXPIRATION_NOTE)); + nonce = data.get(NONCE_NOTE); + scope = data.get(SCOPE_NOTE); + redirectUriParam = data.get(REDIRECT_URI_PARAM_NOTE); + codeChallenge = data.get(CODE_CHALLENGE_NOTE); + codeChallengeMethod = data.get(CODE_CHALLENGE_METHOD_NOTE); + } + + + public static final OAuth2Code deserializeCode(Map data) { + return new OAuth2Code(data); + } + + + public Map serializeCode() { + Map result = new HashMap<>(); + + result.put(ID_NOTE, id.toString()); + result.put(EXPIRATION_NOTE, String.valueOf(expiration)); + result.put(NONCE_NOTE, nonce); + result.put(SCOPE_NOTE, scope); + result.put(REDIRECT_URI_PARAM_NOTE, redirectUriParam); + result.put(CODE_CHALLENGE_NOTE, codeChallenge); + result.put(CODE_CHALLENGE_METHOD_NOTE, codeChallengeMethod); + + return result; + } + + + public UUID getId() { + return id; + } + + public int getExpiration() { + return expiration; + } + + public String getNonce() { + return nonce; + } + + public String getScope() { + return scope; + } + + public String getRedirectUriParam() { + return redirectUriParam; + } + + public String getCodeChallenge() { + return codeChallenge; + } + + public String getCodeChallengeMethod() { + return codeChallengeMethod; + } +} diff --git a/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2CodeParser.java b/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2CodeParser.java new file mode 100644 index 0000000000..d565789782 --- /dev/null +++ b/services/src/main/java/org/keycloak/protocol/oidc/utils/OAuth2CodeParser.java @@ -0,0 +1,195 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates + * and other contributors as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.keycloak.protocol.oidc.utils; + +import java.util.Map; +import java.util.UUID; +import java.util.regex.Pattern; + +import org.jboss.logging.Logger; +import org.keycloak.common.util.Time; +import org.keycloak.events.Details; +import org.keycloak.events.EventBuilder; +import org.keycloak.models.AuthenticatedClientSessionModel; +import org.keycloak.models.CodeToTokenStoreProvider; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.RealmModel; +import org.keycloak.models.UserSessionModel; +import org.keycloak.services.managers.UserSessionCrossDCManager; + +/** + * @author Marek Posolda + */ +public class OAuth2CodeParser { + + private static final Logger logger = Logger.getLogger(OAuth2CodeParser.class); + + private static final Pattern DOT = Pattern.compile("\\."); + + + /** + * Will persist the code to the cache and return the object with the codeData and code correctly set + * + * @param session + * @param clientSession + * @param codeData + * @return code parameter to be used in OAuth2 handshake + */ + public static String persistCode(KeycloakSession session, AuthenticatedClientSessionModel clientSession, OAuth2Code codeData) { + CodeToTokenStoreProvider codeStore = session.getProvider(CodeToTokenStoreProvider.class); + + UUID key = codeData.getId(); + if (key == null) { + throw new IllegalStateException("ID not present in the data"); + } + + Map serialized = codeData.serializeCode(); + codeStore.put(key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), serialized); + return key.toString() + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId(); + } + + + /** + * Will parse the code and retrieve the corresponding OAuth2Code and AuthenticatedClientSessionModel. Will also check if code wasn't already + * used and if it wasn't expired. If it was already used (or other error happened during parsing), then returned parser will have "isIllegalHash" + * set to true. If it was expired, the parser will have "isExpired" set to true + * + * @param session + * @param code + * @param realm + * @param event + * @return + */ + public static ParseResult parseCode(KeycloakSession session, String code, RealmModel realm, EventBuilder event) { + ParseResult result = new ParseResult(code); + + String[] parsed = DOT.split(code, 3); + if (parsed.length < 3) { + logger.warn("Invalid format of the code"); + return result.illegalCode(); + } + + String userSessionId = parsed[1]; + String clientUUID = parsed[2]; + + event.detail(Details.CODE_ID, userSessionId); + event.session(userSessionId); + + // Parse UUID + UUID codeUUID; + try { + codeUUID = UUID.fromString(parsed[0]); + } catch (IllegalArgumentException re) { + logger.warn("Invalid format of the UUID in the code"); + return result.illegalCode(); + } + + // Retrieve UserSession + UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID); + if (userSession == null) { + // Needed to track if code is invalid or was already used. + userSession = session.sessions().getUserSession(realm, userSessionId); + if (userSession == null) { + return result.illegalCode(); + } + } + + result.clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID); + + CodeToTokenStoreProvider codeStore = session.getProvider(CodeToTokenStoreProvider.class); + Map codeData = codeStore.remove(codeUUID); + + // Either code not available or was already used + if (codeData == null) { + logger.warnf("Code '%s' already used for userSession '%s' and client '%s'.", codeUUID, userSessionId, clientUUID); + return result.illegalCode(); + } + + logger.tracef("Successfully verified code '%s'. User session: '%s', client: '%s'", codeUUID, userSessionId, clientUUID); + + result.codeData = OAuth2Code.deserializeCode(codeData); + + // Finally doublecheck if code is not expired + int currentTime = Time.currentTime(); + if (currentTime > result.codeData.getExpiration()) { + return result.expiredCode(); + } + + return result; + } + + + public static class ParseResult { + + private final String code; + private OAuth2Code codeData; + private AuthenticatedClientSessionModel clientSession; + + private boolean isIllegalCode = false; + private boolean isExpiredCode = false; + + + private ParseResult(String code, OAuth2Code codeData, AuthenticatedClientSessionModel clientSession) { + this.code = code; + this.codeData = codeData; + this.clientSession = clientSession; + + this.isIllegalCode = false; + this.isExpiredCode = false; + } + + + private ParseResult(String code) { + this.code = code; + } + + + public String getCode() { + return code; + } + + public OAuth2Code getCodeData() { + return codeData; + } + + public AuthenticatedClientSessionModel getClientSession() { + return clientSession; + } + + public boolean isIllegalCode() { + return isIllegalCode; + } + + public boolean isExpiredCode() { + return isExpiredCode; + } + + + private ParseResult illegalCode() { + this.isIllegalCode = true; + return this; + } + + + private ParseResult expiredCode() { + this.isExpiredCode = true; + return this; + } + } + +} diff --git a/services/src/main/java/org/keycloak/services/managers/CodeGenerateUtil.java b/services/src/main/java/org/keycloak/services/managers/CodeGenerateUtil.java index 488e21d4d5..686524d0d6 100644 --- a/services/src/main/java/org/keycloak/services/managers/CodeGenerateUtil.java +++ b/services/src/main/java/org/keycloak/services/managers/CodeGenerateUtil.java @@ -38,7 +38,6 @@ import org.keycloak.models.KeycloakSession; import org.keycloak.models.RealmModel; import org.keycloak.models.UserSessionModel; import org.keycloak.models.utils.KeycloakModelUtils; -import org.keycloak.representations.CodeJWT; import org.keycloak.sessions.CommonClientSessionModel; import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.util.TokenUtil; @@ -60,10 +59,6 @@ class CodeGenerateUtil { PARSERS.put(AuthenticationSessionModel.class, () -> { return new AuthenticationSessionModelParser(); }); - - PARSERS.put(AuthenticatedClientSessionModel.class, () -> { - return new AuthenticatedClientSessionModelParser(); - }); } @@ -166,119 +161,4 @@ class CodeGenerateUtil { } - private static class AuthenticatedClientSessionModelParser implements ClientSessionParser { - - private CodeJWT codeJWT; - - @Override - public AuthenticatedClientSessionModel parseSession(String code, String tabId, KeycloakSession session, RealmModel realm, ClientModel client, EventBuilder event) { - SecretKey aesKey = session.keys().getActiveAesKey(realm).getSecretKey(); - SecretKey hmacKey = session.keys().getActiveHmacKey(realm).getSecretKey(); - - try { - codeJWT = TokenUtil.jweDirectVerifyAndDecode(aesKey, hmacKey, code, CodeJWT.class); - } catch (JWEException jweException) { - logger.error("Exception during JWE Verification or decode", jweException); - return null; - } - - event.detail(Details.CODE_ID, codeJWT.getUserSessionId()); - event.session(codeJWT.getUserSessionId()); - - UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, codeJWT.getUserSessionId(), codeJWT.getIssuedFor()); - if (userSession == null) { - // TODO:mposolda Temporary workaround needed to track if code is invalid or was already used. Will be good to remove once used OAuth codes are tracked through one-time cache - userSession = session.sessions().getUserSession(realm, codeJWT.getUserSessionId()); - if (userSession == null) { - return null; - } - } - - return userSession.getAuthenticatedClientSessionByClient(codeJWT.getIssuedFor()); - - } - - - @Override - public String retrieveCode(KeycloakSession session, AuthenticatedClientSessionModel clientSession) { - String actionId = KeycloakModelUtils.generateId(); - - CodeJWT codeJWT = new CodeJWT(); - codeJWT.id(actionId); - codeJWT.issuedFor(clientSession.getClient().getId()); - codeJWT.userSessionId(clientSession.getUserSession().getId()); - - RealmModel realm = clientSession.getRealm(); - - int issuedAt = Time.currentTime(); - codeJWT.issuedAt(issuedAt); - codeJWT.expiration(issuedAt + realm.getAccessCodeLifespan()); - - SecretKey aesKey = session.keys().getActiveAesKey(realm).getSecretKey(); - SecretKey hmacKey = session.keys().getActiveHmacKey(realm).getSecretKey(); - - if (logger.isTraceEnabled()) { - logger.tracef("Using AES key of length '%d' bytes and HMAC key of length '%d' bytes . Client: '%s', User Session: '%s'", aesKey.getEncoded().length, - hmacKey.getEncoded().length, clientSession.getClient().getClientId(), clientSession.getUserSession().getId()); - } - - try { - return TokenUtil.jweDirectEncode(aesKey, hmacKey, codeJWT); - } catch (JWEException jweEx) { - throw new RuntimeException(jweEx); - } - } - - - @Override - public boolean verifyCode(KeycloakSession session, String code, AuthenticatedClientSessionModel clientSession) { - if (codeJWT == null) { - throw new IllegalStateException("Illegal use. codeJWT not yet set"); - } - - UUID codeId = UUID.fromString(codeJWT.getId()); - CodeToTokenStoreProvider singleUseCache = session.getProvider(CodeToTokenStoreProvider.class); - - if (singleUseCache.putIfAbsent(codeId)) { - - if (logger.isTraceEnabled()) { - logger.tracef("Added code '%s' to single-use cache. User session: %s, client: %s", codeJWT.getId(), codeJWT.getUserSessionId(), codeJWT.getIssuedFor()); - } - - return true; - } else { - logger.warnf("Code '%s' already used for userSession '%s' and client '%s'.", codeJWT.getId(), codeJWT.getUserSessionId(), codeJWT.getIssuedFor()); - return false; - } - } - - - @Override - public void removeExpiredSession(KeycloakSession session, AuthenticatedClientSessionModel clientSession) { - throw new IllegalStateException("Not yet implemented"); - } - - - @Override - public boolean isExpired(KeycloakSession session, String code, AuthenticatedClientSessionModel clientSession) { - return !codeJWT.isActive(); - } - - @Override - public int getTimestamp(AuthenticatedClientSessionModel clientSession) { - return clientSession.getTimestamp(); - } - - @Override - public void setTimestamp(AuthenticatedClientSessionModel clientSession, int timestamp) { - clientSession.setTimestamp(timestamp); - } - - @Override - public String getClientNote(AuthenticatedClientSessionModel clientSession, String noteKey) { - return clientSession.getNote(noteKey); - } - } - - } diff --git a/services/src/main/java/org/keycloak/services/util/DefaultClientSessionContext.java b/services/src/main/java/org/keycloak/services/util/DefaultClientSessionContext.java index 2c127ed2fc..74d14a6b9d 100644 --- a/services/src/main/java/org/keycloak/services/util/DefaultClientSessionContext.java +++ b/services/src/main/java/org/keycloak/services/util/DefaultClientSessionContext.java @@ -17,7 +17,9 @@ package org.keycloak.services.util; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import org.jboss.logging.Logger; @@ -56,6 +58,8 @@ public class DefaultClientSessionContext implements ClientSessionContext { // All roles of user expanded. It doesn't yet take into account permitted clientScopes private Set userRoles; + private Map attributes = new HashMap<>(); + private DefaultClientSessionContext(AuthenticatedClientSessionModel clientSession, Set clientScopeIds) { this.clientSession = clientSession; this.clientScopeIds = clientScopeIds; @@ -177,6 +181,19 @@ public class DefaultClientSessionContext implements ClientSessionContext { } + @Override + public void setAttribute(String name, Object value) { + attributes.put(name, value); + } + + + @Override + public T getAttribute(String name, Class clazz) { + Object value = attributes.get(name); + return clazz.cast(value); + } + + // Loading data private Set loadClientScopes() { diff --git a/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java b/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java index f7d8c5538f..a5c7fa5c0f 100644 --- a/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java +++ b/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java @@ -760,6 +760,10 @@ public class OAuthClient { return redirectUri; } + public String getNonce() { + return nonce; + } + public String getLoginFormUrl() { UriBuilder b = OIDCLoginProtocolService.authUrl(UriBuilder.fromUri(baseUrl)); if (responseType != null) { diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrentLoginTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrentLoginTest.java index 40a81afbe7..28d78156f7 100644 --- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrentLoginTest.java +++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrentLoginTest.java @@ -48,6 +48,8 @@ import org.keycloak.OAuth2Constants; import org.keycloak.admin.client.Keycloak; import org.keycloak.admin.client.resource.ClientsResource; import org.keycloak.admin.client.resource.RealmResource; +import org.keycloak.jose.jws.JWSInput; +import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.representations.AccessToken; import org.keycloak.representations.idm.ClientRepresentation; import org.keycloak.common.util.Retry; @@ -56,12 +58,14 @@ import org.keycloak.testsuite.util.ClientBuilder; import org.keycloak.testsuite.util.OAuthClient; import java.util.Arrays; import java.util.LinkedHashMap; +import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.apache.http.client.CookieStore; import org.apache.http.impl.client.BasicCookieStore; import org.hamcrest.Matchers; +import org.keycloak.util.JsonSerialization; import static org.hamcrest.Matchers.containsString; @@ -321,6 +325,10 @@ public class ConcurrentLoginTest extends AbstractConcurrencyTest { protected OAuthClient initialValue() { OAuthClient oauth1 = new OAuthClient(); oauth1.init(driver); + + // Add some randomness to nonce and redirectUri. Verify that login is successful and nonce will match + oauth1.nonce(KeycloakModelUtils.generateId()); + oauth1.redirectUri(oauth.getRedirectUri() + "?some=" + new Random().nextInt(1024)); return oauth1; } }; @@ -375,16 +383,25 @@ public class ConcurrentLoginTest extends AbstractConcurrencyTest { accessResRef.set(accessRes); // Refresh access + refresh token using refresh token + AtomicReference refreshResRef = new AtomicReference<>(); + int invocationIndex = Retry.execute(() -> { OAuthClient.AccessTokenResponse refreshRes = oauth1.doRefreshTokenRequest(accessResRef.get().getRefreshToken(), "password"); Assert.assertEquals("AccessTokenResponse: client: " + oauth1.getClientId() + ", error: '" + refreshRes.getError() + "' desc: '" + refreshRes.getErrorDescription() + "'", 200, refreshRes.getStatusCode()); + + refreshResRef.set(refreshRes); }, retryCount, retryDelayMs); retryHistogram[invocationIndex].incrementAndGet(); + AccessToken token = JsonSerialization.readValue(new JWSInput(accessResRef.get().getAccessToken()).getContent(), AccessToken.class); + Assert.assertEquals("Invalid nonce.", token.getNonce(), oauth1.getNonce()); + + AccessToken refreshedToken = JsonSerialization.readValue(new JWSInput(refreshResRef.get().getAccessToken()).getContent(), AccessToken.class); + Assert.assertEquals("Invalid nonce.", refreshedToken.getNonce(), oauth1.getNonce()); + if (userSessionId.get() == null) { - AccessToken token = oauth1.verifyToken(accessResRef.get().getAccessToken()); userSessionId.set(token.getSessionState()); } } diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/keys/FallbackKeyProviderTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/keys/FallbackKeyProviderTest.java index 233c8ca326..3273fcaade 100644 --- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/keys/FallbackKeyProviderTest.java +++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/keys/FallbackKeyProviderTest.java @@ -86,7 +86,7 @@ public class FallbackKeyProviderTest extends AbstractKeycloakTest { Assert.assertEquals(AppPage.RequestType.AUTH_RESPONSE, appPage.getRequestType()); providers = realmsResouce().realm("test").components().query(realmId, "org.keycloak.keys.KeyProvider"); - assertProviders(providers, "fallback-RS256", "fallback-HS256", "fallback-AES"); + assertProviders(providers, "fallback-RS256", "fallback-HS256"); } @Test diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/RefreshTokenTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/RefreshTokenTest.java index d74c9858a6..9aecb0e945 100755 --- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/RefreshTokenTest.java +++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/RefreshTokenTest.java @@ -136,6 +136,7 @@ public class RefreshTokenTest extends AbstractKeycloakTest { @Test public void refreshTokenRequest() throws Exception { + oauth.nonce("123456"); oauth.doLogin("test-user@localhost", "password"); EventRepresentation loginEvent = events.expectLogin().assertEvent(); @@ -147,6 +148,8 @@ public class RefreshTokenTest extends AbstractKeycloakTest { OAuthClient.AccessTokenResponse tokenResponse = oauth.doAccessTokenRequest(code, "password"); AccessToken token = oauth.verifyToken(tokenResponse.getAccessToken()); + assertEquals("123456", token.getNonce()); + String refreshTokenString = tokenResponse.getRefreshToken(); RefreshToken refreshToken = oauth.parseRefreshToken(refreshTokenString); @@ -200,6 +203,8 @@ public class RefreshTokenTest extends AbstractKeycloakTest { Assert.assertNotEquals(tokenEvent.getDetails().get(Details.TOKEN_ID), refreshEvent.getDetails().get(Details.TOKEN_ID)); Assert.assertNotEquals(tokenEvent.getDetails().get(Details.REFRESH_TOKEN_ID), refreshEvent.getDetails().get(Details.UPDATED_REFRESH_TOKEN_ID)); + assertEquals("123456", refreshedToken.getNonce()); + setTimeOffset(0); } @Test