KEYCLOAK-7774 KEYCLOAK-8438 Errors when SSO authenticating to same client multiple times concurrently in more browser tabs

This commit is contained in:
mposolda 2018-11-14 11:52:46 +01:00 committed by Marek Posolda
parent 8af1ca8fc3
commit 6db1f60e27
18 changed files with 715 additions and 213 deletions

View file

@ -1,39 +0,0 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.representations;
import com.fasterxml.jackson.annotation.JsonProperty;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class CodeJWT extends JsonWebToken {
@JsonProperty("uss")
protected String userSessionId;
public String getUserSessionId() {
return userSessionId;
}
public CodeJWT userSessionId(String userSessionId) {
this.userSessionId = userSessionId;
return this;
}
}

View file

@ -17,6 +17,7 @@
package org.keycloak.models.sessions.infinispan;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
@ -44,28 +45,43 @@ public class InfinispanCodeToTokenStoreProvider implements CodeToTokenStoreProvi
this.codeCache = actionKeyCache;
}
@Override
public boolean putIfAbsent(UUID codeId) {
ActionTokenValueEntity tokenValue = new ActionTokenValueEntity(null);
int lifespanInSeconds = session.getContext().getRealm().getAccessCodeLifespan();
@Override
public void put(UUID codeId, int lifespanSeconds, Map<String, String> codeData) {
ActionTokenValueEntity tokenValue = new ActionTokenValueEntity(codeData);
try {
BasicCache<UUID, ActionTokenValueEntity> cache = codeCache.get();
ActionTokenValueEntity existing = cache.putIfAbsent(codeId, tokenValue, lifespanInSeconds, TimeUnit.SECONDS);
return existing == null;
cache.put(codeId, tokenValue, lifespanSeconds, TimeUnit.SECONDS);
} catch (HotRodClientException re) {
// No need to retry. The hotrod (remoteCache) has some retries in itself in case of some random network error happened.
// In case of lock conflict, we don't want to retry anyway as there was likely an attempt to use the code from different place.
if (logger.isDebugEnabled()) {
logger.debugf(re, "Failed when adding code %s", codeId);
}
return false;
throw re;
}
}
@Override
public Map<String, String> remove(UUID codeId) {
try {
BasicCache<UUID, ActionTokenValueEntity> cache = codeCache.get();
ActionTokenValueEntity existing = cache.remove(codeId);
return existing == null ? null : existing.getNotes();
} catch (HotRodClientException re) {
// No need to retry. The hotrod (remoteCache) has some retries in itself in case of some random network error happened.
// In case of lock conflict, we don't want to retry anyway as there was likely an attempt to remove the code from different place.
if (logger.isDebugEnabled()) {
logger.debugf(re, "Failed when removing code %s", codeId);
}
return null;
}
}
@Override
public void close() {

View file

@ -0,0 +1,187 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.cluster.infinispan;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.infinispan.Cache;
import org.jboss.logging.Logger;
import org.junit.Assert;
import org.keycloak.common.util.Time;
import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
import org.keycloak.models.sessions.infinispan.changes.SessionEntityWrapper;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import org.keycloak.models.sessions.infinispan.initializer.DistributedCacheConcurrentWritesTest;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class ConcurrencyDistributedRemoveSessionTest {
protected static final Logger logger = Logger.getLogger(ConcurrencyJDGRemoveSessionTest.class);
private static final int ITERATIONS = 10000;
private static final AtomicInteger errorsCounter = new AtomicInteger(0);
private static final AtomicInteger successfulListenerWrites = new AtomicInteger(0);
private static final AtomicInteger successfulListenerWrites2 = new AtomicInteger(0);
private static Map<String, AtomicInteger> removalCounts = new ConcurrentHashMap<>();
private static final UUID CLIENT_1_UUID = UUID.randomUUID();
public static void main(String[] args) throws Exception {
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache1 = DistributedCacheConcurrentWritesTest.createManager("node1").getCache(InfinispanConnectionProvider.USER_SESSION_CACHE_NAME);
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache2 = DistributedCacheConcurrentWritesTest.createManager("node2").getCache(InfinispanConnectionProvider.USER_SESSION_CACHE_NAME);
// Create caches, listeners and finally worker threads
Thread worker1 = createWorker(cache1, 1);
Thread worker2 = createWorker(cache2, 2);
Thread worker3 = createWorker(cache1, 1);
Thread worker4 = createWorker(cache2, 2);
// Create 100 initial sessions
for (int i=0 ; i<ITERATIONS ; i++) {
String sessionId = String.valueOf(i);
SessionEntityWrapper<UserSessionEntity> wrappedSession = createSessionEntity(sessionId);
cache1.put(sessionId, wrappedSession);
removalCounts.put(sessionId, new AtomicInteger(0));
}
logger.info("SESSIONS CREATED");
// Create 100 initial sessions
for (int i=0 ; i<ITERATIONS ; i++) {
String sessionId = String.valueOf(i);
SessionEntityWrapper loadedWrapper = cache2.get(sessionId);
Assert.assertNotNull("Loaded wrapper for key " + sessionId, loadedWrapper);
}
logger.info("SESSIONS AVAILABLE ON DC2");
long start = System.currentTimeMillis();
try {
worker1.start();
worker2.start();
worker3.start();
worker4.start();
worker1.join();
worker2.join();
worker3.join();
worker4.join();
logger.info("SESSIONS REMOVED");
Map<Integer, Integer> histogram = new HashMap<>();
for (Map.Entry<String, AtomicInteger> entry : removalCounts.entrySet()) {
int count = entry.getValue().get();
int current = histogram.get(count) == null ? 0 : histogram.get(count);
current++;
histogram.put(count, current);
}
logger.infof("Histogram: %s", histogram.toString());
logger.infof("Errors: %d", errorsCounter.get());
long took = System.currentTimeMillis() - start;
logger.infof("took %d ms", took);
} finally {
Thread.sleep(2000);
// Finish JVM
cache1.getCacheManager().stop();
cache2.getCacheManager().stop();
}
}
private static SessionEntityWrapper<UserSessionEntity> createSessionEntity(String sessionId) {
// Create 100 initial sessions
UserSessionEntity session = new UserSessionEntity();
session.setId(sessionId);
session.setRealmId("foo");
session.setBrokerSessionId("!23123123");
session.setBrokerUserId(null);
session.setUser("foo");
session.setLoginUsername("foo");
session.setIpAddress("123.44.143.178");
session.setStarted(Time.currentTime());
session.setLastSessionRefresh(Time.currentTime());
AuthenticatedClientSessionEntity clientSession = new AuthenticatedClientSessionEntity(UUID.randomUUID());
clientSession.setAuthMethod("saml");
clientSession.setAction("something");
clientSession.setTimestamp(1234);
session.getAuthenticatedClientSessions().put(CLIENT_1_UUID.toString(), clientSession.getId());
SessionEntityWrapper<UserSessionEntity> wrappedSession = new SessionEntityWrapper<>(session);
return wrappedSession;
}
private static Thread createWorker(Cache<String, SessionEntityWrapper<UserSessionEntity>> cache, int threadId) {
System.out.println("Retrieved cache: " + threadId);
return new CacheWorker(cache, threadId);
}
private static class CacheWorker extends Thread {
private final Cache<String, Object> cache;
private final int myThreadId;
private CacheWorker(Cache cache, int myThreadId) {
this.cache = cache;
this.myThreadId = myThreadId;
}
@Override
public void run() {
for (int i=0 ; i<ITERATIONS ; i++) {
String sessionId = String.valueOf(i);
Object o = cache.remove(sessionId);
if (o != null) {
removalCounts.get(sessionId).incrementAndGet();
}
}
}
}
}

View file

@ -18,7 +18,10 @@
package org.keycloak.cluster.infinispan;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.infinispan.Cache;
@ -30,6 +33,7 @@ import org.infinispan.client.hotrod.annotation.ClientListener;
import org.infinispan.client.hotrod.event.ClientCacheEntryCreatedEvent;
import org.infinispan.client.hotrod.event.ClientCacheEntryModifiedEvent;
import org.infinispan.client.hotrod.event.ClientCacheEntryRemovedEvent;
import org.infinispan.client.hotrod.exceptions.HotRodClientException;
import org.infinispan.context.Flag;
import org.infinispan.manager.EmbeddedCacheManager;
import org.infinispan.persistence.remote.configuration.RemoteStoreConfigurationBuilder;
@ -61,13 +65,13 @@ public class ConcurrencyJDGRemoveSessionTest {
private static RemoteCache remoteCache1;
private static RemoteCache remoteCache2;
private static final AtomicInteger failedReplaceCounter = new AtomicInteger(0);
private static final AtomicInteger failedReplaceCounter2 = new AtomicInteger(0);
private static final AtomicInteger errorsCounter = new AtomicInteger(0);
private static final AtomicInteger successfulListenerWrites = new AtomicInteger(0);
private static final AtomicInteger successfulListenerWrites2 = new AtomicInteger(0);
//private static Map<String, EntryInfo> state = new HashMap<>();
private static Map<String, AtomicInteger> removalCounts = new ConcurrentHashMap<>();
private static final UUID CLIENT_1_UUID = UUID.randomUUID();
@ -78,12 +82,16 @@ public class ConcurrencyJDGRemoveSessionTest {
// Create caches, listeners and finally worker threads
Thread worker1 = createWorker(cache1, 1);
Thread worker2 = createWorker(cache2, 2);
Thread worker3 = createWorker(cache1, 1);
Thread worker4 = createWorker(cache2, 2);
// Create 100 initial sessions
for (int i=0 ; i<ITERATIONS ; i++) {
String sessionId = String.valueOf(i);
SessionEntityWrapper<UserSessionEntity> wrappedSession = createSessionEntity(sessionId);
cache1.put(sessionId, wrappedSession);
removalCounts.put(sessionId, new AtomicInteger(0));
}
logger.info("SESSIONS CREATED");
@ -101,25 +109,44 @@ public class ConcurrencyJDGRemoveSessionTest {
long start = System.currentTimeMillis();
try {
// Just running in current thread
worker1.run();
worker1.start();
worker2.start();
worker3.start();
worker4.start();
worker1.join();
worker2.join();
worker3.join();
worker4.join();
logger.info("SESSIONS REMOVED");
Map<Integer, Integer> histogram = new HashMap<>();
for (Map.Entry<String, AtomicInteger> entry : removalCounts.entrySet()) {
int count = entry.getValue().get();
int current = histogram.get(count) == null ? 0 : histogram.get(count);
current++;
histogram.put(count, current);
}
logger.infof("Histogram: %s", histogram.toString());
logger.infof("Errors: %d", errorsCounter.get());
//Thread.sleep(5000);
// Doing it in opposite direction to ensure that newer are checked first.
// This us currently FAILING (expected) as listeners are executed asynchronously.
for (int i=ITERATIONS-1 ; i>=0 ; i--) {
String sessionId = String.valueOf(i);
logger.infof("Before call cache2.get: %s", sessionId);
SessionEntityWrapper loadedWrapper = cache2.get(sessionId);
Assert.assertNull("Loaded wrapper not null for key " + sessionId, loadedWrapper);
}
logger.info("SESSIONS NOT AVAILABLE ON DC2");
// for (int i=ITERATIONS-1 ; i>=0 ; i--) {
// String sessionId = String.valueOf(i);
//
// logger.infof("Before call cache2.get: %s", sessionId);
//
// SessionEntityWrapper loadedWrapper = cache2.get(sessionId);
// Assert.assertNull("Loaded wrapper not null for key " + sessionId, loadedWrapper);
// }
//
// logger.info("SESSIONS NOT AVAILABLE ON DC2");
long took = System.currentTimeMillis() - start;
logger.infof("took %d ms", took);
@ -271,19 +298,30 @@ public class ConcurrencyJDGRemoveSessionTest {
for (int i=0 ; i<ITERATIONS ; i++) {
String sessionId = String.valueOf(i);
remoteCache.remove(sessionId);
try {
Object o = remoteCache
.withFlags(org.infinispan.client.hotrod.Flag.FORCE_RETURN_VALUE)
.remove(sessionId);
logger.infof("Session %s removed on DC1", sessionId);
// Check if it's immediately seen that session is removed on 2nd DC
RemoteCache secondDCRemoteCache = myThreadId == 1 ? remoteCache2 : remoteCache1;
SessionEntityWrapper thatSession = (SessionEntityWrapper) secondDCRemoteCache.get(sessionId);
Assert.assertNull("Session with ID " + sessionId + " not removed on the other DC. ThreadID: " + myThreadId, thatSession);
// Also check that it's immediatelly removed on my DC
SessionEntityWrapper mySession = (SessionEntityWrapper) remoteCache.get(sessionId);
Assert.assertNull("Session with ID " + sessionId + " not removed on the other DC. ThreadID: " + myThreadId, mySession);
if (o != null) {
removalCounts.get(sessionId).incrementAndGet();
}
} catch (HotRodClientException hrce) {
errorsCounter.incrementAndGet();
}
//
//
// logger.infof("Session %s removed on DC1", sessionId);
//
// // Check if it's immediately seen that session is removed on 2nd DC
// RemoteCache secondDCRemoteCache = myThreadId == 1 ? remoteCache2 : remoteCache1;
// SessionEntityWrapper thatSession = (SessionEntityWrapper) secondDCRemoteCache.get(sessionId);
// Assert.assertNull("Session with ID " + sessionId + " not removed on the other DC. ThreadID: " + myThreadId, thatSession);
//
// // Also check that it's immediatelly removed on my DC
// SessionEntityWrapper mySession = (SessionEntityWrapper) remoteCache.get(sessionId);
// Assert.assertNull("Session with ID " + sessionId + " not removed on the other DC. ThreadID: " + myThreadId, mySession);
}
}

View file

@ -17,6 +17,7 @@
package org.keycloak.models;
import java.util.Map;
import java.util.UUID;
import org.keycloak.provider.Provider;
@ -30,5 +31,23 @@ import org.keycloak.provider.Provider;
*/
public interface CodeToTokenStoreProvider extends Provider {
boolean putIfAbsent(UUID codeId);
/**
* Stores the given data and guarantees that data should be available in the store for at least the time specified by {@param lifespanSeconds} parameter
* @param codeId
* @param lifespanSeconds
* @param codeData
* @return true if data were successfully put
*/
void put(UUID codeId, int lifespanSeconds, Map<String, String> codeData);
/**
* This method returns data just if removal was successful. Implementation should guarantee that "remove" is single-use. So if
* 2 threads (even on different cluster nodes or on different cross-dc nodes) calls "remove(123)" concurrently, then just one of them
* is allowed to succeed and return data back. It can't happen that both will succeed.
*
* @param codeId
* @return context data related to OAuth2 code. It returns null if there are not context data available.
*/
Map<String, String> remove(UUID codeId);
}

View file

@ -40,4 +40,11 @@ public interface ClientSessionContext {
Set<ProtocolMapperModel> getProtocolMappers();
String getScopeString();
void setAttribute(String name, Object value);
<T> T getAttribute(String attribute, Class<T> clazz);
String AUTHENTICATION_SESSION_ATTR = "AUTH_SESSION_ATTR";
}

View file

@ -126,8 +126,6 @@ public abstract class AuthorizationEndpointBase {
return response;
}
// Attach session once no requiredActions or other things are required
processor.attachSession();
} catch (Exception e) {
return processor.handleBrowserException(e);
}

View file

@ -29,7 +29,6 @@ import org.keycloak.events.EventType;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionContext;
import org.keycloak.models.TokenManager;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
@ -39,16 +38,21 @@ import org.keycloak.protocol.oidc.utils.OIDCResponseMode;
import org.keycloak.protocol.oidc.utils.OIDCResponseType;
import org.keycloak.representations.AccessTokenResponse;
import org.keycloak.representations.adapters.action.PushNotBeforeAction;
import org.keycloak.services.ErrorResponseException;
import org.keycloak.services.ServicesLogger;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.AuthenticationSessionManager;
import org.keycloak.services.managers.ClientSessionCode;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
import org.keycloak.services.managers.ResourceAdminManager;
import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.sessions.CommonClientSessionModel;
import org.keycloak.util.TokenUtil;
import java.io.IOException;
import java.net.URI;
import java.util.UUID;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;
@ -179,7 +183,6 @@ public class OIDCLoginProtocol implements LoginProtocol {
@Override
public Response authenticated(UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AuthenticatedClientSessionModel clientSession= clientSessionCtx.getClientSession();
ClientSessionCode<AuthenticatedClientSessionModel> accessCode = new ClientSessionCode<>(session, realm, clientSession);
String responseTypeParam = clientSession.getNote(OIDCLoginProtocol.RESPONSE_TYPE_PARAM);
String responseModeParam = clientSession.getNote(OIDCLoginProtocol.RESPONSE_MODE_PARAM);
@ -197,10 +200,27 @@ public class OIDCLoginProtocol implements LoginProtocol {
redirectUri.addParam(OAuth2Constants.SESSION_STATE, userSession.getId());
}
AuthenticationSessionModel authSession = clientSessionCtx.getAttribute(ClientSessionContext.AUTHENTICATION_SESSION_ATTR, AuthenticationSessionModel.class);
if (authSession == null) {
// Shouldn't happen if correctly used
throw new IllegalStateException("AuthenticationSession attachement not set in the ClientSessionContext");
}
String nonce = authSession.getClientNote(OIDCLoginProtocol.NONCE_PARAM);
clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, nonce);
// Standard or hybrid flow
String code = null;
if (responseType.hasResponseType(OIDCResponseType.CODE)) {
code = accessCode.getOrGenerateCode();
OAuth2Code codeData = new OAuth2Code(UUID.randomUUID(),
Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(),
nonce,
authSession.getClientNote(OAuth2Constants.SCOPE),
authSession.getClientNote(OIDCLoginProtocol.REDIRECT_URI_PARAM),
authSession.getClientNote(OIDCLoginProtocol.CODE_CHALLENGE_PARAM),
authSession.getClientNote(OIDCLoginProtocol.CODE_CHALLENGE_METHOD_PARAM));
code = OAuth2CodeParser.persistCode(session, clientSession, codeData);
redirectUri.addParam(OAuth2Constants.CODE, code);
}

View file

@ -180,6 +180,8 @@ public class TokenManager {
throw new OAuthErrorException(OAuthErrorException.INVALID_SCOPE, "Client no longer has requested consent from user");
}
clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, oldToken.getNonce());
// recreate token.
AccessToken newToken = createClientAccessToken(session, realm, client, user, userSession, clientSessionCtx);
verifyAccess(oldToken, newToken);
@ -433,7 +435,10 @@ public class TokenManager {
// Remove authentication session now
new AuthenticationSessionManager(session).removeAuthenticationSession(userSession.getRealm(), authSession, true);
return DefaultClientSessionContext.fromClientSessionAndClientScopeIds(clientSession, clientScopeIds);
ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndClientScopeIds(clientSession, clientScopeIds);
clientSessionCtx.setAttribute(ClientSessionContext.AUTHENTICATION_SESSION_ATTR, authSession);
return clientSessionCtx;
}
@ -614,7 +619,7 @@ public class TokenManager {
AuthenticatedClientSessionModel clientSession = clientSessionCtx.getClientSession();
token.issuer(clientSession.getNote(OIDCLoginProtocol.ISSUER));
token.setNonce(clientSession.getNote(OIDCLoginProtocol.NONCE_PARAM));
token.setNonce(clientSessionCtx.getAttribute(OIDCLoginProtocol.NONCE_PARAM, String.class));
token.setScope(clientSessionCtx.getScopeString());
// Best effort for "acr" value. Use 0 if clientSession was authenticated through cookie ( SSO )

View file

@ -75,7 +75,8 @@ import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.AuthenticationSessionManager;
import org.keycloak.services.managers.BruteForceProtector;
import org.keycloak.services.managers.ClientManager;
import org.keycloak.services.managers.ClientSessionCode;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
import org.keycloak.services.managers.RealmManager;
import org.keycloak.services.resources.Cors;
import org.keycloak.services.resources.IdentityBrokerService;
@ -275,8 +276,8 @@ public class TokenEndpoint {
throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_REQUEST, "Missing parameter: " + OAuth2Constants.CODE, Response.Status.BAD_REQUEST);
}
ClientSessionCode.ParseResult<AuthenticatedClientSessionModel> parseResult = ClientSessionCode.parseResult(code, null, session, realm, client, event, AuthenticatedClientSessionModel.class);
if (parseResult.isAuthSessionNotFound() || parseResult.isIllegalHash()) {
OAuth2CodeParser.ParseResult parseResult = OAuth2CodeParser.parseCode(session, code, realm, event);
if (parseResult.isIllegalCode()) {
AuthenticatedClientSessionModel clientSession = parseResult.getClientSession();
// Attempt to use same code twice should invalidate existing clientSession
@ -291,7 +292,7 @@ public class TokenEndpoint {
AuthenticatedClientSessionModel clientSession = parseResult.getClientSession();
if (parseResult.isExpiredToken()) {
if (parseResult.isExpiredCode()) {
event.error(Errors.EXPIRED_CODE);
throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_GRANT, "Code is expired", Response.Status.BAD_REQUEST);
}
@ -317,7 +318,8 @@ public class TokenEndpoint {
throw new CorsErrorResponseException(cors, OAuthErrorException.INVALID_GRANT, "User disabled", Response.Status.BAD_REQUEST);
}
String redirectUri = clientSession.getNote(OIDCLoginProtocol.REDIRECT_URI_PARAM);
OAuth2Code codeData = parseResult.getCodeData();
String redirectUri = codeData.getRedirectUriParam();
String redirectUriParam = formParams.getFirst(OAuth2Constants.REDIRECT_URI);
// KEYCLOAK-4478 Backwards compatibility with the adapters earlier than KC 3.4.2
@ -349,8 +351,8 @@ public class TokenEndpoint {
// https://tools.ietf.org/html/rfc7636#section-4.6
String codeVerifier = formParams.getFirst(OAuth2Constants.CODE_VERIFIER);
String codeChallenge = clientSession.getNote(OIDCLoginProtocol.CODE_CHALLENGE_PARAM);
String codeChallengeMethod = clientSession.getNote(OIDCLoginProtocol.CODE_CHALLENGE_METHOD_PARAM);
String codeChallenge = codeData.getCodeChallenge();
String codeChallengeMethod = codeData.getCodeChallengeMethod();
String authUserId = user.getId();
String authUsername = user.getUsername();
if (authUserId == null) {
@ -406,7 +408,7 @@ public class TokenEndpoint {
// Compute client scopes again from scope parameter. Check if user still has them granted
// (but in code-to-token request, it could just theoretically happen that they are not available)
String scopeParam = clientSession.getNote(OAuth2Constants.SCOPE);
String scopeParam = codeData.getScope();
Set<ClientScopeModel> clientScopes = TokenManager.getRequestedClientScopes(scopeParam, client);
if (!TokenManager.verifyConsentStillAvailable(session, user, client, clientScopes)) {
event.error(Errors.NOT_ALLOWED);
@ -415,6 +417,9 @@ public class TokenEndpoint {
ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndClientScopes(clientSession, clientScopes);
// Set nonce as an attribute in the ClientSessionContext. Will be used for the token generation
clientSessionCtx.setAttribute(OIDCLoginProtocol.NONCE_PARAM, codeData.getNonce());
AccessToken token = tokenManager.createClientAccessToken(session, realm, client, user, userSession, clientSessionCtx);
TokenManager.AccessTokenResponseBuilder responseBuilder = tokenManager.responseBuilder(realm, client, event, session, userSession, clientSessionCtx)

View file

@ -0,0 +1,128 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.protocol.oidc.utils;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
/**
* Data associated with the oauth2 code.
*
* Those data are typically valid just for the very short time - they're created at the point before we redirect to the application
* after successful and they're removed when application sends requests to the token endpoint (code-to-token endpoint) to exchange the
* single-use OAuth2 code parameter for those data.
*
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class OAuth2Code {
private static final String ID_NOTE = "id";
private static final String EXPIRATION_NOTE = "exp";
private static final String NONCE_NOTE = "nonce";
private static final String SCOPE_NOTE = "scope";
private static final String REDIRECT_URI_PARAM_NOTE = "redirectUri";
private static final String CODE_CHALLENGE_NOTE = "code_challenge";
private static final String CODE_CHALLENGE_METHOD_NOTE = "code_challenge_method";
private final UUID id;
private final int expiration;
private final String nonce;
private final String scope;
private final String redirectUriParam;
private final String codeChallenge;
private final String codeChallengeMethod;
public OAuth2Code(UUID id, int expiration, String nonce, String scope, String redirectUriParam,
String codeChallenge, String codeChallengeMethod) {
this.id = id;
this.expiration = expiration;
this.nonce = nonce;
this.scope = scope;
this.redirectUriParam = redirectUriParam;
this.codeChallenge = codeChallenge;
this.codeChallengeMethod = codeChallengeMethod;
}
private OAuth2Code(Map<String, String> data) {
id = UUID.fromString(data.get(ID_NOTE));
expiration = Integer.parseInt(data.get(EXPIRATION_NOTE));
nonce = data.get(NONCE_NOTE);
scope = data.get(SCOPE_NOTE);
redirectUriParam = data.get(REDIRECT_URI_PARAM_NOTE);
codeChallenge = data.get(CODE_CHALLENGE_NOTE);
codeChallengeMethod = data.get(CODE_CHALLENGE_METHOD_NOTE);
}
public static final OAuth2Code deserializeCode(Map<String, String> data) {
return new OAuth2Code(data);
}
public Map<String, String> serializeCode() {
Map<String, String> result = new HashMap<>();
result.put(ID_NOTE, id.toString());
result.put(EXPIRATION_NOTE, String.valueOf(expiration));
result.put(NONCE_NOTE, nonce);
result.put(SCOPE_NOTE, scope);
result.put(REDIRECT_URI_PARAM_NOTE, redirectUriParam);
result.put(CODE_CHALLENGE_NOTE, codeChallenge);
result.put(CODE_CHALLENGE_METHOD_NOTE, codeChallengeMethod);
return result;
}
public UUID getId() {
return id;
}
public int getExpiration() {
return expiration;
}
public String getNonce() {
return nonce;
}
public String getScope() {
return scope;
}
public String getRedirectUriParam() {
return redirectUriParam;
}
public String getCodeChallenge() {
return codeChallenge;
}
public String getCodeChallengeMethod() {
return codeChallengeMethod;
}
}

View file

@ -0,0 +1,195 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.protocol.oidc.utils;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Pattern;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Time;
import org.keycloak.events.Details;
import org.keycloak.events.EventBuilder;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.CodeToTokenStoreProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.services.managers.UserSessionCrossDCManager;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class OAuth2CodeParser {
private static final Logger logger = Logger.getLogger(OAuth2CodeParser.class);
private static final Pattern DOT = Pattern.compile("\\.");
/**
* Will persist the code to the cache and return the object with the codeData and code correctly set
*
* @param session
* @param clientSession
* @param codeData
* @return code parameter to be used in OAuth2 handshake
*/
public static String persistCode(KeycloakSession session, AuthenticatedClientSessionModel clientSession, OAuth2Code codeData) {
CodeToTokenStoreProvider codeStore = session.getProvider(CodeToTokenStoreProvider.class);
UUID key = codeData.getId();
if (key == null) {
throw new IllegalStateException("ID not present in the data");
}
Map<String, String> serialized = codeData.serializeCode();
codeStore.put(key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), serialized);
return key.toString() + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId();
}
/**
* Will parse the code and retrieve the corresponding OAuth2Code and AuthenticatedClientSessionModel. Will also check if code wasn't already
* used and if it wasn't expired. If it was already used (or other error happened during parsing), then returned parser will have "isIllegalHash"
* set to true. If it was expired, the parser will have "isExpired" set to true
*
* @param session
* @param code
* @param realm
* @param event
* @return
*/
public static ParseResult parseCode(KeycloakSession session, String code, RealmModel realm, EventBuilder event) {
ParseResult result = new ParseResult(code);
String[] parsed = DOT.split(code, 3);
if (parsed.length < 3) {
logger.warn("Invalid format of the code");
return result.illegalCode();
}
String userSessionId = parsed[1];
String clientUUID = parsed[2];
event.detail(Details.CODE_ID, userSessionId);
event.session(userSessionId);
// Parse UUID
UUID codeUUID;
try {
codeUUID = UUID.fromString(parsed[0]);
} catch (IllegalArgumentException re) {
logger.warn("Invalid format of the UUID in the code");
return result.illegalCode();
}
// Retrieve UserSession
UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID);
if (userSession == null) {
// Needed to track if code is invalid or was already used.
userSession = session.sessions().getUserSession(realm, userSessionId);
if (userSession == null) {
return result.illegalCode();
}
}
result.clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID);
CodeToTokenStoreProvider codeStore = session.getProvider(CodeToTokenStoreProvider.class);
Map<String, String> codeData = codeStore.remove(codeUUID);
// Either code not available or was already used
if (codeData == null) {
logger.warnf("Code '%s' already used for userSession '%s' and client '%s'.", codeUUID, userSessionId, clientUUID);
return result.illegalCode();
}
logger.tracef("Successfully verified code '%s'. User session: '%s', client: '%s'", codeUUID, userSessionId, clientUUID);
result.codeData = OAuth2Code.deserializeCode(codeData);
// Finally doublecheck if code is not expired
int currentTime = Time.currentTime();
if (currentTime > result.codeData.getExpiration()) {
return result.expiredCode();
}
return result;
}
public static class ParseResult {
private final String code;
private OAuth2Code codeData;
private AuthenticatedClientSessionModel clientSession;
private boolean isIllegalCode = false;
private boolean isExpiredCode = false;
private ParseResult(String code, OAuth2Code codeData, AuthenticatedClientSessionModel clientSession) {
this.code = code;
this.codeData = codeData;
this.clientSession = clientSession;
this.isIllegalCode = false;
this.isExpiredCode = false;
}
private ParseResult(String code) {
this.code = code;
}
public String getCode() {
return code;
}
public OAuth2Code getCodeData() {
return codeData;
}
public AuthenticatedClientSessionModel getClientSession() {
return clientSession;
}
public boolean isIllegalCode() {
return isIllegalCode;
}
public boolean isExpiredCode() {
return isExpiredCode;
}
private ParseResult illegalCode() {
this.isIllegalCode = true;
return this;
}
private ParseResult expiredCode() {
this.isExpiredCode = true;
return this;
}
}
}

View file

@ -38,7 +38,6 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.representations.CodeJWT;
import org.keycloak.sessions.CommonClientSessionModel;
import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.util.TokenUtil;
@ -60,10 +59,6 @@ class CodeGenerateUtil {
PARSERS.put(AuthenticationSessionModel.class, () -> {
return new AuthenticationSessionModelParser();
});
PARSERS.put(AuthenticatedClientSessionModel.class, () -> {
return new AuthenticatedClientSessionModelParser();
});
}
@ -166,119 +161,4 @@ class CodeGenerateUtil {
}
private static class AuthenticatedClientSessionModelParser implements ClientSessionParser<AuthenticatedClientSessionModel> {
private CodeJWT codeJWT;
@Override
public AuthenticatedClientSessionModel parseSession(String code, String tabId, KeycloakSession session, RealmModel realm, ClientModel client, EventBuilder event) {
SecretKey aesKey = session.keys().getActiveAesKey(realm).getSecretKey();
SecretKey hmacKey = session.keys().getActiveHmacKey(realm).getSecretKey();
try {
codeJWT = TokenUtil.jweDirectVerifyAndDecode(aesKey, hmacKey, code, CodeJWT.class);
} catch (JWEException jweException) {
logger.error("Exception during JWE Verification or decode", jweException);
return null;
}
event.detail(Details.CODE_ID, codeJWT.getUserSessionId());
event.session(codeJWT.getUserSessionId());
UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, codeJWT.getUserSessionId(), codeJWT.getIssuedFor());
if (userSession == null) {
// TODO:mposolda Temporary workaround needed to track if code is invalid or was already used. Will be good to remove once used OAuth codes are tracked through one-time cache
userSession = session.sessions().getUserSession(realm, codeJWT.getUserSessionId());
if (userSession == null) {
return null;
}
}
return userSession.getAuthenticatedClientSessionByClient(codeJWT.getIssuedFor());
}
@Override
public String retrieveCode(KeycloakSession session, AuthenticatedClientSessionModel clientSession) {
String actionId = KeycloakModelUtils.generateId();
CodeJWT codeJWT = new CodeJWT();
codeJWT.id(actionId);
codeJWT.issuedFor(clientSession.getClient().getId());
codeJWT.userSessionId(clientSession.getUserSession().getId());
RealmModel realm = clientSession.getRealm();
int issuedAt = Time.currentTime();
codeJWT.issuedAt(issuedAt);
codeJWT.expiration(issuedAt + realm.getAccessCodeLifespan());
SecretKey aesKey = session.keys().getActiveAesKey(realm).getSecretKey();
SecretKey hmacKey = session.keys().getActiveHmacKey(realm).getSecretKey();
if (logger.isTraceEnabled()) {
logger.tracef("Using AES key of length '%d' bytes and HMAC key of length '%d' bytes . Client: '%s', User Session: '%s'", aesKey.getEncoded().length,
hmacKey.getEncoded().length, clientSession.getClient().getClientId(), clientSession.getUserSession().getId());
}
try {
return TokenUtil.jweDirectEncode(aesKey, hmacKey, codeJWT);
} catch (JWEException jweEx) {
throw new RuntimeException(jweEx);
}
}
@Override
public boolean verifyCode(KeycloakSession session, String code, AuthenticatedClientSessionModel clientSession) {
if (codeJWT == null) {
throw new IllegalStateException("Illegal use. codeJWT not yet set");
}
UUID codeId = UUID.fromString(codeJWT.getId());
CodeToTokenStoreProvider singleUseCache = session.getProvider(CodeToTokenStoreProvider.class);
if (singleUseCache.putIfAbsent(codeId)) {
if (logger.isTraceEnabled()) {
logger.tracef("Added code '%s' to single-use cache. User session: %s, client: %s", codeJWT.getId(), codeJWT.getUserSessionId(), codeJWT.getIssuedFor());
}
return true;
} else {
logger.warnf("Code '%s' already used for userSession '%s' and client '%s'.", codeJWT.getId(), codeJWT.getUserSessionId(), codeJWT.getIssuedFor());
return false;
}
}
@Override
public void removeExpiredSession(KeycloakSession session, AuthenticatedClientSessionModel clientSession) {
throw new IllegalStateException("Not yet implemented");
}
@Override
public boolean isExpired(KeycloakSession session, String code, AuthenticatedClientSessionModel clientSession) {
return !codeJWT.isActive();
}
@Override
public int getTimestamp(AuthenticatedClientSessionModel clientSession) {
return clientSession.getTimestamp();
}
@Override
public void setTimestamp(AuthenticatedClientSessionModel clientSession, int timestamp) {
clientSession.setTimestamp(timestamp);
}
@Override
public String getClientNote(AuthenticatedClientSessionModel clientSession, String noteKey) {
return clientSession.getNote(noteKey);
}
}
}

View file

@ -17,7 +17,9 @@
package org.keycloak.services.util;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.jboss.logging.Logger;
@ -56,6 +58,8 @@ public class DefaultClientSessionContext implements ClientSessionContext {
// All roles of user expanded. It doesn't yet take into account permitted clientScopes
private Set<RoleModel> userRoles;
private Map<String, Object> attributes = new HashMap<>();
private DefaultClientSessionContext(AuthenticatedClientSessionModel clientSession, Set<String> clientScopeIds) {
this.clientSession = clientSession;
this.clientScopeIds = clientScopeIds;
@ -177,6 +181,19 @@ public class DefaultClientSessionContext implements ClientSessionContext {
}
@Override
public void setAttribute(String name, Object value) {
attributes.put(name, value);
}
@Override
public <T> T getAttribute(String name, Class<T> clazz) {
Object value = attributes.get(name);
return clazz.cast(value);
}
// Loading data
private Set<ClientScopeModel> loadClientScopes() {

View file

@ -760,6 +760,10 @@ public class OAuthClient {
return redirectUri;
}
public String getNonce() {
return nonce;
}
public String getLoginFormUrl() {
UriBuilder b = OIDCLoginProtocolService.authUrl(UriBuilder.fromUri(baseUrl));
if (responseType != null) {

View file

@ -48,6 +48,8 @@ import org.keycloak.OAuth2Constants;
import org.keycloak.admin.client.Keycloak;
import org.keycloak.admin.client.resource.ClientsResource;
import org.keycloak.admin.client.resource.RealmResource;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.representations.AccessToken;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.common.util.Retry;
@ -56,12 +58,14 @@ import org.keycloak.testsuite.util.ClientBuilder;
import org.keycloak.testsuite.util.OAuthClient;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.http.client.CookieStore;
import org.apache.http.impl.client.BasicCookieStore;
import org.hamcrest.Matchers;
import org.keycloak.util.JsonSerialization;
import static org.hamcrest.Matchers.containsString;
@ -321,6 +325,10 @@ public class ConcurrentLoginTest extends AbstractConcurrencyTest {
protected OAuthClient initialValue() {
OAuthClient oauth1 = new OAuthClient();
oauth1.init(driver);
// Add some randomness to nonce and redirectUri. Verify that login is successful and nonce will match
oauth1.nonce(KeycloakModelUtils.generateId());
oauth1.redirectUri(oauth.getRedirectUri() + "?some=" + new Random().nextInt(1024));
return oauth1;
}
};
@ -375,16 +383,25 @@ public class ConcurrentLoginTest extends AbstractConcurrencyTest {
accessResRef.set(accessRes);
// Refresh access + refresh token using refresh token
AtomicReference<OAuthClient.AccessTokenResponse> refreshResRef = new AtomicReference<>();
int invocationIndex = Retry.execute(() -> {
OAuthClient.AccessTokenResponse refreshRes = oauth1.doRefreshTokenRequest(accessResRef.get().getRefreshToken(), "password");
Assert.assertEquals("AccessTokenResponse: client: " + oauth1.getClientId() + ", error: '" + refreshRes.getError() + "' desc: '" + refreshRes.getErrorDescription() + "'",
200, refreshRes.getStatusCode());
refreshResRef.set(refreshRes);
}, retryCount, retryDelayMs);
retryHistogram[invocationIndex].incrementAndGet();
AccessToken token = JsonSerialization.readValue(new JWSInput(accessResRef.get().getAccessToken()).getContent(), AccessToken.class);
Assert.assertEquals("Invalid nonce.", token.getNonce(), oauth1.getNonce());
AccessToken refreshedToken = JsonSerialization.readValue(new JWSInput(refreshResRef.get().getAccessToken()).getContent(), AccessToken.class);
Assert.assertEquals("Invalid nonce.", refreshedToken.getNonce(), oauth1.getNonce());
if (userSessionId.get() == null) {
AccessToken token = oauth1.verifyToken(accessResRef.get().getAccessToken());
userSessionId.set(token.getSessionState());
}
}

View file

@ -86,7 +86,7 @@ public class FallbackKeyProviderTest extends AbstractKeycloakTest {
Assert.assertEquals(AppPage.RequestType.AUTH_RESPONSE, appPage.getRequestType());
providers = realmsResouce().realm("test").components().query(realmId, "org.keycloak.keys.KeyProvider");
assertProviders(providers, "fallback-RS256", "fallback-HS256", "fallback-AES");
assertProviders(providers, "fallback-RS256", "fallback-HS256");
}
@Test

View file

@ -136,6 +136,7 @@ public class RefreshTokenTest extends AbstractKeycloakTest {
@Test
public void refreshTokenRequest() throws Exception {
oauth.nonce("123456");
oauth.doLogin("test-user@localhost", "password");
EventRepresentation loginEvent = events.expectLogin().assertEvent();
@ -147,6 +148,8 @@ public class RefreshTokenTest extends AbstractKeycloakTest {
OAuthClient.AccessTokenResponse tokenResponse = oauth.doAccessTokenRequest(code, "password");
AccessToken token = oauth.verifyToken(tokenResponse.getAccessToken());
assertEquals("123456", token.getNonce());
String refreshTokenString = tokenResponse.getRefreshToken();
RefreshToken refreshToken = oauth.parseRefreshToken(refreshTokenString);
@ -200,6 +203,8 @@ public class RefreshTokenTest extends AbstractKeycloakTest {
Assert.assertNotEquals(tokenEvent.getDetails().get(Details.TOKEN_ID), refreshEvent.getDetails().get(Details.TOKEN_ID));
Assert.assertNotEquals(tokenEvent.getDetails().get(Details.REFRESH_TOKEN_ID), refreshEvent.getDetails().get(Details.UPDATED_REFRESH_TOKEN_ID));
assertEquals("123456", refreshedToken.getNonce());
setTimeOffset(0);
}
@Test