Do not persist client sessions of transient user sessions
Closes #12357
This commit is contained in:
parent
e856a62fb2
commit
1b3a76d0af
2 changed files with 72 additions and 20 deletions
|
@ -68,6 +68,10 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
* Storage for transient user sessions which lifespan is limited to one request.
|
||||
*/
|
||||
private final Map<String, MapUserSessionEntity> transientUserSessions = new HashMap<>();
|
||||
/**
|
||||
* Storage for client sessions where parent is transient user session. Lifespan is limited to one request.
|
||||
*/
|
||||
private final Map<String, MapAuthenticatedClientSessionEntity> transientClientSessions = new HashMap<>();
|
||||
|
||||
public MapUserSessionProvider(KeycloakSession session, MapStorage<MapUserSessionEntity, UserSessionModel> userSessionStore,
|
||||
MapStorage<MapAuthenticatedClientSessionEntity, AuthenticatedClientSessionModel> clientSessionStore) {
|
||||
|
@ -82,12 +86,14 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
private Function<MapUserSessionEntity, UserSessionModel> userEntityToAdapterFunc(RealmModel realm) {
|
||||
// Clone entity before returning back, to avoid giving away a reference to the live object to the caller
|
||||
return (origEntity) -> {
|
||||
if (origEntity == null) return null;
|
||||
long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L;
|
||||
if (expiration <= Time.currentTimeMillis()) {
|
||||
if (Objects.equals(origEntity.getPersistenceState(), TRANSIENT)) {
|
||||
if (TRANSIENT == origEntity.getPersistenceState()) {
|
||||
transientUserSessions.remove(origEntity.getId());
|
||||
} else {
|
||||
userSessionTx.delete(origEntity.getId());
|
||||
}
|
||||
userSessionTx.delete(origEntity.getId());
|
||||
return null;
|
||||
} else {
|
||||
return new MapUserSessionAdapter(session, realm, origEntity) {
|
||||
|
@ -112,10 +118,14 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
UserSessionModel userSession) {
|
||||
// Clone entity before returning back, to avoid giving away a reference to the live object to the caller
|
||||
return origEntity -> {
|
||||
if (origEntity == null) return null;
|
||||
long expiration = origEntity.getExpiration() != null ? origEntity.getExpiration() : 0L;
|
||||
if (expiration <= Time.currentTimeMillis()) {
|
||||
userSession.removeAuthenticatedClientSessions(Arrays.asList(origEntity.getClientId()));
|
||||
clientSessionTx.delete(origEntity.getId());
|
||||
// if a client session is found among transient ones we can skip call to store
|
||||
if (transientClientSessions.remove(origEntity.getId()) == null) {
|
||||
clientSessionTx.delete(origEntity.getId());
|
||||
}
|
||||
return null;
|
||||
} else {
|
||||
return new MapAuthenticatedClientSessionAdapter(session, realm, client, userSession, origEntity) {
|
||||
|
@ -123,7 +133,10 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
public void detachFromUserSession() {
|
||||
this.userSession = null;
|
||||
|
||||
clientSessionTx.delete(entity.getId());
|
||||
// if a client session is found among transient ones we can skip call to store
|
||||
if (transientClientSessions.remove(entity.getId()) == null) {
|
||||
clientSessionTx.delete(entity.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -144,20 +157,27 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
|
||||
@Override
|
||||
public AuthenticatedClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession) {
|
||||
LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace());
|
||||
|
||||
MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId());
|
||||
|
||||
if (userSessionEntity == null) {
|
||||
throw new IllegalStateException("User session entity does not exist: " + userSession.getId());
|
||||
}
|
||||
|
||||
MapAuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionEntityInstance(null, userSession.getId(),
|
||||
realm.getId(), client.getId(), false);
|
||||
String started = entity.getTimestamp() != null ? String.valueOf(TimeAdapter.fromMilliSecondsToSeconds(entity.getTimestamp())) : String.valueOf(0);
|
||||
entity.setNote(AuthenticatedClientSessionModel.STARTED_AT_NOTE, started);
|
||||
setClientSessionExpiration(entity, realm, client);
|
||||
|
||||
LOG.tracef("createClientSession(%s, %s, %s)%s", realm, client, userSession, getShortStackTrace());
|
||||
|
||||
entity = clientSessionTx.create(entity);
|
||||
|
||||
MapUserSessionEntity userSessionEntity = getUserSessionById(userSession.getId());
|
||||
|
||||
if (userSessionEntity == null) {
|
||||
throw new IllegalStateException("User session entity does not exist: " + userSession.getId());
|
||||
if (TRANSIENT == userSessionEntity.getPersistenceState()) {
|
||||
if (entity.getId() == null) {
|
||||
entity.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
transientClientSessions.put(entity.getId(), entity);
|
||||
} else {
|
||||
entity = clientSessionTx.create(entity);
|
||||
}
|
||||
|
||||
userSessionEntity.setAuthenticatedClientSession(client.getId(), entity.getId());
|
||||
|
@ -177,6 +197,12 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
return null;
|
||||
}
|
||||
|
||||
MapAuthenticatedClientSessionEntity entity = transientClientSessions.get(clientSessionId);
|
||||
|
||||
if (entity != null) {
|
||||
return clientEntityToAdapterFunc(client.getRealm(), client, userSession).apply(entity);
|
||||
}
|
||||
|
||||
DefaultModelCriteria<AuthenticatedClientSessionModel> mcb = criteria();
|
||||
mcb = mcb.compare(AuthenticatedClientSessionModel.SearchableFields.ID, Operator.EQ, clientSessionId)
|
||||
.compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.EQ, userSession.getId())
|
||||
|
@ -204,20 +230,18 @@ public class MapUserSessionProvider implements UserSessionProvider {
|
|||
String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState) {
|
||||
LOG.tracef("createUserSession(%s, %s, %s, %s)%s", id, realm, loginUsername, persistenceState, getShortStackTrace());
|
||||
|
||||
MapUserSessionEntity entity;
|
||||
if (Objects.equals(persistenceState, TRANSIENT)) {
|
||||
if (id == null) {
|
||||
id = UUID.randomUUID().toString();
|
||||
}
|
||||
entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
|
||||
MapUserSessionEntity entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
|
||||
rememberMe, brokerSessionId, brokerUserId, false);
|
||||
|
||||
if (TRANSIENT == persistenceState) {
|
||||
if (id == null) {
|
||||
entity.setId(UUID.randomUUID().toString());
|
||||
}
|
||||
transientUserSessions.put(entity.getId(), entity);
|
||||
} else {
|
||||
if (id != null && userSessionTx.read(id) != null) {
|
||||
throw new ModelDuplicateException("User session exists: " + id);
|
||||
}
|
||||
entity = createUserSessionEntityInstance(id, realm.getId(), user.getId(), loginUsername, ipAddress, authMethod,
|
||||
rememberMe, brokerSessionId, brokerUserId, false);
|
||||
entity = userSessionTx.create(entity);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ import java.util.List;
|
|||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.notNullValue;
|
||||
import static org.hamcrest.CoreMatchers.nullValue;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.keycloak.testsuite.model.session.UserSessionPersisterProviderTest.createClients;
|
||||
|
@ -221,4 +222,31 @@ public class UserSessionProviderModelTest extends KeycloakModelTest {
|
|||
|
||||
assertThat(withRealm(realmId, (session, realm) -> session.sessions().getClientSession(origSessions[0], realm.getClientByClientId("test-app"), testAppClientSessionId, false)), nullValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testClientSessionIsNotPersistedForTransientUserSession() {
|
||||
Object[] transientUserSessionWithClientSessionId = inComittedTransaction(session -> {
|
||||
RealmModel realm = session.realms().getRealm(realmId);
|
||||
UserSessionModel userSession = session.sessions().createUserSession(null, realm, session.users().getUserByUsername(realm, "user1"), "user1", "127.0.0.1", "form", false, null, null, UserSessionModel.SessionPersistenceState.TRANSIENT);
|
||||
|
||||
ClientModel testApp = realm.getClientByClientId("test-app");
|
||||
AuthenticatedClientSessionModel clientSession = session.sessions().createClientSession(realm, testApp, userSession);
|
||||
|
||||
// assert the client sessions are present
|
||||
assertThat(session.sessions().getClientSession(userSession, testApp, clientSession.getId(), false), notNullValue());
|
||||
Object[] result = new Object[2];
|
||||
result[0] = userSession;
|
||||
result[1] = clientSession.getId();
|
||||
return result;
|
||||
});
|
||||
|
||||
inComittedTransaction(session -> {
|
||||
RealmModel realm = session.realms().getRealm(realmId);
|
||||
ClientModel testApp = realm.getClientByClientId("test-app");
|
||||
UserSessionModel userSession = (UserSessionModel) transientUserSessionWithClientSessionId[0];
|
||||
String clientSessionId = (String) transientUserSessionWithClientSessionId[1];
|
||||
// in new transaction transient session should not be present
|
||||
assertThat(session.sessions().getClientSession(userSession, testApp, clientSessionId, false), nullValue());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue