From 1b3a76d0afd3cafee6dbf47516d6d5a40498bf97 Mon Sep 17 00:00:00 2001 From: vramik Date: Mon, 6 Jun 2022 15:52:00 +0200 Subject: [PATCH] Do not persist client sessions of transient user sessions Closes #12357 --- .../userSession/MapUserSessionProvider.java | 64 +++++++++++++------ .../session/UserSessionProviderModelTest.java | 28 ++++++++ 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/model/map/src/main/java/org/keycloak/models/map/userSession/MapUserSessionProvider.java b/model/map/src/main/java/org/keycloak/models/map/userSession/MapUserSessionProvider.java index c254e604bd..e707fe040c 100644 --- a/model/map/src/main/java/org/keycloak/models/map/userSession/MapUserSessionProvider.java +++ b/model/map/src/main/java/org/keycloak/models/map/userSession/MapUserSessionProvider.java @@ -68,6 +68,10 @@ public class MapUserSessionProvider implements UserSessionProvider { * Storage for transient user sessions which lifespan is limited to one request. */ private final Map transientUserSessions = new HashMap<>(); + /** + * Storage for client sessions where parent is transient user session. Lifespan is limited to one request. + */ + private final Map transientClientSessions = new HashMap<>(); public MapUserSessionProvider(KeycloakSession session, MapStorage userSessionStore, MapStorage clientSessionStore) { @@ -82,12 +86,14 @@ public class MapUserSessionProvider implements UserSessionProvider { private Function userEntityToAdapterFunc(RealmModel realm) { // Clone entity before returning back, to avoid giving away a reference to the live object to the caller return (origEntity) -> { + if (origEntity == null) return null; long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L; if (expiration <= Time.currentTimeMillis()) { - if (Objects.equals(origEntity.getPersistenceState(), TRANSIENT)) { + if (TRANSIENT == origEntity.getPersistenceState()) { transientUserSessions.remove(origEntity.getId()); + } else { + userSessionTx.delete(origEntity.getId()); } - userSessionTx.delete(origEntity.getId()); return null; } else { return new MapUserSessionAdapter(session, realm, origEntity) { @@ -112,10 +118,14 @@ public class MapUserSessionProvider implements UserSessionProvider { UserSessionModel userSession) { // Clone entity before returning back, to avoid giving away a reference to the live object to the caller return origEntity -> { + if (origEntity == null) return null; long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L; if (expiration <= Time.currentTimeMillis()) { userSession.removeAuthenticatedClientSessions(Arrays.asList(origEntity.getClientId())); - clientSessionTx.delete(origEntity.getId()); + // if a client session is found among transient ones we can skip call to store + if (transientClientSessions.remove(origEntity.getId()) == null) { + clientSessionTx.delete(origEntity.getId()); + } return null; } else { return new MapAuthenticatedClientSessionAdapter(session, realm, client, userSession, origEntity) { @@ -123,7 +133,10 @@ public class MapUserSessionProvider implements UserSessionProvider { public void detachFromUserSession() { this.userSession = null; - clientSessionTx.delete(entity.getId()); + // if a client session is found among transient ones we can skip call to store + if (transientClientSessions.remove(entity.getId()) == null) { + clientSessionTx.delete(entity.getId()); + } } @Override @@ -144,20 +157,27 @@ public class MapUserSessionProvider implements UserSessionProvider { @Override public AuthenticatedClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession) { + LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace()); + + MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId()); + + if (userSessionEntity == null) { + throw new IllegalStateException("User session entity does not exist: " + userSession.getId()); + } + MapAuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionEntityInstance(null, userSession.getId(), realm.getId(), client.getId(), false); String started = entity.getTimestamp() != null ? String.valueOf(TimeAdapter.fromMilliSecondsToSeconds(entity.getTimestamp())) : String.valueOf(0); entity.setNote(AuthenticatedClientSessionModel.STARTED_AT_NOTE, started); setClientSessionExpiration(entity, realm, client); - LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace()); - - entity = clientSessionTx.create(entity); - - MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId()); - - if (userSessionEntity == null) { - throw new IllegalStateException("User session entity does not exist: " + userSession.getId()); + if (TRANSIENT == userSessionEntity.getPersistenceState()) { + if (entity.getId() == null) { + entity.setId(UUID.randomUUID().toString()); + } + transientClientSessions.put(entity.getId(), entity); + } else { + entity = clientSessionTx.create(entity); } userSessionEntity.setAuthenticatedClientSession(client.getId(), entity.getId()); @@ -177,6 +197,12 @@ public class MapUserSessionProvider implements UserSessionProvider { return null; } + MapAuthenticatedClientSessionEntity entity = transientClientSessions.get(clientSessionId); + + if (entity != null) { + return clientEntityToAdapterFunc(client.getRealm(), client, userSession).apply(entity); + } + DefaultModelCriteria mcb = criteria(); mcb = mcb.compare(AuthenticatedClientSessionModel.SearchableFields.ID, Operator.EQ, clientSessionId) .compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.EQ, userSession.getId()) @@ -204,20 +230,18 @@ public class MapUserSessionProvider implements UserSessionProvider { String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState) { LOG.tracef("createUserSession(%s, %s, %s, %s)%s", id, realm, loginUsername, persistenceState, getShortStackTrace()); - MapUserSessionEntity entity; - if (Objects.equals(persistenceState, TRANSIENT)) { - if (id == null) { - id = UUID.randomUUID().toString(); - } - entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod, + MapUserSessionEntity entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod, rememberMe, brokerSessionId, brokerUserId, false); + + if (TRANSIENT == persistenceState) { + if (id == null) { + entity.setId(UUID.randomUUID().toString()); + } transientUserSessions.put(entity.getId(), entity); } else { if (id != null && userSessionTx.read(id) != null) { throw new ModelDuplicateException("User session exists: " + id); } - entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod, - rememberMe, brokerSessionId, brokerUserId, false); entity = userSessionTx.create(entity); } diff --git a/testsuite/model/src/test/java/org/keycloak/testsuite/model/session/UserSessionProviderModelTest.java b/testsuite/model/src/test/java/org/keycloak/testsuite/model/session/UserSessionProviderModelTest.java index dc2db81c71..e522ddc9dd 100644 --- a/testsuite/model/src/test/java/org/keycloak/testsuite/model/session/UserSessionProviderModelTest.java +++ b/testsuite/model/src/test/java/org/keycloak/testsuite/model/session/UserSessionProviderModelTest.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.keycloak.testsuite.model.session.UserSessionPersisterProviderTest.createClients; @@ -221,4 +222,31 @@ public class UserSessionProviderModelTest extends KeycloakModelTest { assertThat(withRealm(realmId, (session, realm) -> session.sessions().getClientSession(origSessions[0], realm.getClientByClientId("test-app"), testAppClientSessionId, false)), nullValue()); } + + @Test + public void testClientSessionIsNotPersistedForTransientUserSession() { + Object[] transientUserSessionWithClientSessionId = inComittedTransaction(session -> { + RealmModel realm = session.realms().getRealm(realmId); + UserSessionModel userSession = session.sessions().createUserSession(null, realm, session.users().getUserByUsername(realm, "user1"), "user1", "127.0.0.1", "form", false, null, null, UserSessionModel.SessionPersistenceState.TRANSIENT); + + ClientModel testApp = realm.getClientByClientId("test-app"); + AuthenticatedClientSessionModel clientSession = session.sessions().createClientSession(realm, testApp, userSession); + + // assert the client sessions are present + assertThat(session.sessions().getClientSession(userSession, testApp, clientSession.getId(), false), notNullValue()); + Object[] result = new Object[2]; + result[0] = userSession; + result[1] = clientSession.getId(); + return result; + }); + + inComittedTransaction(session -> { + RealmModel realm = session.realms().getRealm(realmId); + ClientModel testApp = realm.getClientByClientId("test-app"); + UserSessionModel userSession = (UserSessionModel) transientUserSessionWithClientSessionId[0]; + String clientSessionId = (String) transientUserSessionWithClientSessionId[1]; + // in new transaction transient session should not be present + assertThat(session.sessions().getClientSession(userSession, testApp, clientSessionId, false), nullValue()); + }); + } }