Do not persist client sessions of transient user sessions

Closes #12357
This commit is contained in:
vramik 2022-06-06 15:52:00 +02:00 committed by Hynek Mlnařík
parent e856a62fb2
commit 1b3a76d0af
2 changed files with 72 additions and 20 deletions

View file

@ -68,6 +68,10 @@ public class MapUserSessionProvider implements UserSessionProvider {
* Storage for transient user sessions which lifespan is limited to one request.
*/
private final Map<String, MapUserSessionEntity> transientUserSessions = new HashMap<>();
/**
* Storage for client sessions where parent is transient user session. Lifespan is limited to one request.
*/
private final Map<String, MapAuthenticatedClientSessionEntity> transientClientSessions = new HashMap<>();
public MapUserSessionProvider(KeycloakSession session, MapStorage<MapUserSessionEntity, UserSessionModel> userSessionStore,
MapStorage<MapAuthenticatedClientSessionEntity, AuthenticatedClientSessionModel> clientSessionStore) {
@ -82,12 +86,14 @@ public class MapUserSessionProvider implements UserSessionProvider {
private Function<MapUserSessionEntity, UserSessionModel> userEntityToAdapterFunc(RealmModel realm) {
// Clone entity before returning back, to avoid giving away a reference to the live object to the caller
return (origEntity) -> {
if (origEntity == null) return null;
long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L;
if (expiration <= Time.currentTimeMillis()) {
if (Objects.equals(origEntity.getPersistenceState(), TRANSIENT)) {
if (TRANSIENT == origEntity.getPersistenceState()) {
transientUserSessions.remove(origEntity.getId());
} else {
userSessionTx.delete(origEntity.getId());
}
userSessionTx.delete(origEntity.getId());
return null;
} else {
return new MapUserSessionAdapter(session, realm, origEntity) {
@ -112,10 +118,14 @@ public class MapUserSessionProvider implements UserSessionProvider {
UserSessionModel userSession) {
// Clone entity before returning back, to avoid giving away a reference to the live object to the caller
return origEntity -> {
if (origEntity == null) return null;
long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L;
if (expiration <= Time.currentTimeMillis()) {
userSession.removeAuthenticatedClientSessions(Arrays.asList(origEntity.getClientId()));
clientSessionTx.delete(origEntity.getId());
// if a client session is found among transient ones we can skip call to store
if (transientClientSessions.remove(origEntity.getId()) == null) {
clientSessionTx.delete(origEntity.getId());
}
return null;
} else {
return new MapAuthenticatedClientSessionAdapter(session, realm, client, userSession, origEntity) {
@ -123,7 +133,10 @@ public class MapUserSessionProvider implements UserSessionProvider {
public void detachFromUserSession() {
this.userSession = null;
clientSessionTx.delete(entity.getId());
// if a client session is found among transient ones we can skip call to store
if (transientClientSessions.remove(entity.getId()) == null) {
clientSessionTx.delete(entity.getId());
}
}
@Override
@ -144,20 +157,27 @@ public class MapUserSessionProvider implements UserSessionProvider {
@Override
public AuthenticatedClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession) {
LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace());
MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId());
if (userSessionEntity == null) {
throw new IllegalStateException("User session entity does not exist: " + userSession.getId());
}
MapAuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionEntityInstance(null, userSession.getId(),
realm.getId(), client.getId(), false);
String started = entity.getTimestamp() != null ? String.valueOf(TimeAdapter.fromMilliSecondsToSeconds(entity.getTimestamp())) : String.valueOf(0);
entity.setNote(AuthenticatedClientSessionModel.STARTED_AT_NOTE, started);
setClientSessionExpiration(entity, realm, client);
LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace());
entity = clientSessionTx.create(entity);
MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId());
if (userSessionEntity == null) {
throw new IllegalStateException("User session entity does not exist: " + userSession.getId());
if (TRANSIENT == userSessionEntity.getPersistenceState()) {
if (entity.getId() == null) {
entity.setId(UUID.randomUUID().toString());
}
transientClientSessions.put(entity.getId(), entity);
} else {
entity = clientSessionTx.create(entity);
}
userSessionEntity.setAuthenticatedClientSession(client.getId(), entity.getId());
@ -177,6 +197,12 @@ public class MapUserSessionProvider implements UserSessionProvider {
return null;
}
MapAuthenticatedClientSessionEntity entity = transientClientSessions.get(clientSessionId);
if (entity != null) {
return clientEntityToAdapterFunc(client.getRealm(), client, userSession).apply(entity);
}
DefaultModelCriteria<AuthenticatedClientSessionModel> mcb = criteria();
mcb = mcb.compare(AuthenticatedClientSessionModel.SearchableFields.ID, Operator.EQ, clientSessionId)
.compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.EQ, userSession.getId())
@ -204,20 +230,18 @@ public class MapUserSessionProvider implements UserSessionProvider {
String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState) {
LOG.tracef("createUserSession(%s, %s, %s, %s)%s", id, realm, loginUsername, persistenceState, getShortStackTrace());
MapUserSessionEntity entity;
if (Objects.equals(persistenceState, TRANSIENT)) {
if (id == null) {
id = UUID.randomUUID().toString();
}
entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
MapUserSessionEntity entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
rememberMe, brokerSessionId, brokerUserId, false);
if (TRANSIENT == persistenceState) {
if (id == null) {
entity.setId(UUID.randomUUID().toString());
}
transientUserSessions.put(entity.getId(), entity);
} else {
if (id != null && userSessionTx.read(id) != null) {
throw new ModelDuplicateException("User session exists: " + id);
}
entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
rememberMe, brokerSessionId, brokerUserId, false);
entity = userSessionTx.create(entity);
}

View file

@ -43,6 +43,7 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.keycloak.testsuite.model.session.UserSessionPersisterProviderTest.createClients;
@ -221,4 +222,31 @@ public class UserSessionProviderModelTest extends KeycloakModelTest {
assertThat(withRealm(realmId, (session, realm) -> session.sessions().getClientSession(origSessions[0], realm.getClientByClientId("test-app"), testAppClientSessionId, false)), nullValue());
}
@Test
public void testClientSessionIsNotPersistedForTransientUserSession() {
Object[] transientUserSessionWithClientSessionId = inComittedTransaction(session -> {
RealmModel realm = session.realms().getRealm(realmId);
UserSessionModel userSession = session.sessions().createUserSession(null, realm, session.users().getUserByUsername(realm, "user1"), "user1", "127.0.0.1", "form", false, null, null, UserSessionModel.SessionPersistenceState.TRANSIENT);
ClientModel testApp = realm.getClientByClientId("test-app");
AuthenticatedClientSessionModel clientSession = session.sessions().createClientSession(realm, testApp, userSession);
// assert the client sessions are present
assertThat(session.sessions().getClientSession(userSession, testApp, clientSession.getId(), false), notNullValue());
Object[] result = new Object[2];
result[0] = userSession;
result[1] = clientSession.getId();
return result;
});
inComittedTransaction(session -> {
RealmModel realm = session.realms().getRealm(realmId);
ClientModel testApp = realm.getClientByClientId("test-app");
UserSessionModel userSession = (UserSessionModel) transientUserSessionWithClientSessionId[0];
String clientSessionId = (String) transientUserSessionWithClientSessionId[1];
// in new transaction transient session should not be present
assertThat(session.sessions().getClientSession(userSession, testApp, clientSessionId, false), nullValue());
});
}
}