Add limit for authSessions per rootAuthSession in map storage

This commit is contained in:
Martin Kanis 2022-08-08 13:01:31 +02:00 committed by Hynek Mlnařík
parent 89795cfd7d
commit 57f2f4654a
5 changed files with 75 additions and 10 deletions

View file

@ -104,7 +104,7 @@ public class InfinispanAuthenticationSessionProviderFactory implements Authentic
.property() .property()
.name("authSessionsLimit") .name("authSessionsLimit")
.type("int") .type("int")
.helpText("The maximum number of concurrent authentication sessions.") .helpText("The maximum number of concurrent authentication sessions per RootAuthenticationSession.")
.defaultValue(DEFAULT_AUTH_SESSIONS_LIMIT) .defaultValue(DEFAULT_AUTH_SESSIONS_LIMIT)
.add() .add()
.build(); .build();

View file

@ -16,6 +16,7 @@
*/ */
package org.keycloak.models.map.authSession; package org.keycloak.models.map.authSession;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Base64Url; import org.keycloak.common.util.Base64Url;
import org.keycloak.common.util.SecretGenerator; import org.keycloak.common.util.SecretGenerator;
import org.keycloak.common.util.Time; import org.keycloak.common.util.Time;
@ -27,9 +28,11 @@ import org.keycloak.models.utils.SessionExpiration;
import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.sessions.AuthenticationSessionModel;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.keycloak.models.utils.SessionExpiration.getAuthSessionLifespan; import static org.keycloak.models.utils.SessionExpiration.getAuthSessionLifespan;
@ -39,8 +42,15 @@ import static org.keycloak.models.utils.SessionExpiration.getAuthSessionLifespan
*/ */
public class MapRootAuthenticationSessionAdapter extends AbstractRootAuthenticationSessionModel<MapRootAuthenticationSessionEntity> { public class MapRootAuthenticationSessionAdapter extends AbstractRootAuthenticationSessionModel<MapRootAuthenticationSessionEntity> {
public MapRootAuthenticationSessionAdapter(KeycloakSession session, RealmModel realm, MapRootAuthenticationSessionEntity entity) { private static final Logger LOG = Logger.getLogger(MapRootAuthenticationSessionAdapter.class);
private int authSessionsLimit;
private static Comparator<MapAuthenticationSessionEntity> TIMESTAMP_COMPARATOR = Comparator.comparingLong(MapAuthenticationSessionEntity::getTimestamp);
public MapRootAuthenticationSessionAdapter(KeycloakSession session, RealmModel realm, MapRootAuthenticationSessionEntity entity, int authSessionsLimit) {
super(session, realm, entity); super(session, realm, entity);
this.authSessionsLimit = authSessionsLimit;
} }
@Override @Override
@ -83,6 +93,18 @@ public class MapRootAuthenticationSessionAdapter extends AbstractRootAuthenticat
public AuthenticationSessionModel createAuthenticationSession(ClientModel client) { public AuthenticationSessionModel createAuthenticationSession(ClientModel client) {
Objects.requireNonNull(client, "The provided client can't be null!"); Objects.requireNonNull(client, "The provided client can't be null!");
Set<MapAuthenticationSessionEntity> authenticationSessions = entity.getAuthenticationSessions();
if (authenticationSessions != null && authenticationSessions.size() >= authSessionsLimit) {
String tabId = authenticationSessions.stream().min(TIMESTAMP_COMPARATOR).map(MapAuthenticationSessionEntity::getTabId).orElse(null);
if (tabId != null) {
LOG.debugf("Reached limit (%s) of active authentication sessions per a root authentication session. Removing oldest authentication session with TabId %s.", authSessionsLimit, tabId);
// remove the oldest authentication session
entity.removeAuthenticationSession(tabId);
}
}
MapAuthenticationSessionEntity authSessionEntity = new MapAuthenticationSessionEntityImpl(); MapAuthenticationSessionEntity authSessionEntity = new MapAuthenticationSessionEntityImpl();
authSessionEntity.setClientUUID(client.getId()); authSessionEntity.setClientUUID(client.getId());

View file

@ -53,10 +53,14 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
private static final Logger LOG = Logger.getLogger(MapRootAuthenticationSessionProvider.class); private static final Logger LOG = Logger.getLogger(MapRootAuthenticationSessionProvider.class);
private final KeycloakSession session; private final KeycloakSession session;
protected final MapKeycloakTransaction<MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> tx; protected final MapKeycloakTransaction<MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> tx;
private int authSessionsLimit;
public MapRootAuthenticationSessionProvider(KeycloakSession session, MapStorage<MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> sessionStore) { public MapRootAuthenticationSessionProvider(KeycloakSession session,
MapStorage<MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> sessionStore,
int authSessionsLimit) {
this.session = session; this.session = session;
this.tx = sessionStore.createTransaction(session); this.tx = sessionStore.createTransaction(session);
this.authSessionsLimit = authSessionsLimit;
session.getTransactionManager().enlistAfterCompletion(tx); session.getTransactionManager().enlistAfterCompletion(tx);
} }
@ -67,7 +71,7 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
tx.delete(origEntity.getId()); tx.delete(origEntity.getId());
return null; return null;
} else { } else {
return new MapRootAuthenticationSessionAdapter(session, realm, origEntity); return new MapRootAuthenticationSessionAdapter(session, realm, origEntity, authSessionsLimit);
} }
}; };
} }

View file

@ -16,25 +16,58 @@
*/ */
package org.keycloak.models.map.authSession; package org.keycloak.models.map.authSession;
import org.keycloak.Config;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.map.common.AbstractMapProviderFactory; import org.keycloak.models.map.common.AbstractMapProviderFactory;
import org.keycloak.provider.ProviderConfigProperty;
import org.keycloak.provider.ProviderConfigurationBuilder;
import org.keycloak.sessions.AuthenticationSessionProviderFactory; import org.keycloak.sessions.AuthenticationSessionProviderFactory;
import org.keycloak.sessions.RootAuthenticationSessionModel; import org.keycloak.sessions.RootAuthenticationSessionModel;
import java.util.List;
/** /**
* @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a> * @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a>
*/ */
public class MapRootAuthenticationSessionProviderFactory extends AbstractMapProviderFactory<MapRootAuthenticationSessionProvider, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> public class MapRootAuthenticationSessionProviderFactory extends AbstractMapProviderFactory<MapRootAuthenticationSessionProvider, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel>
implements AuthenticationSessionProviderFactory<MapRootAuthenticationSessionProvider> { implements AuthenticationSessionProviderFactory<MapRootAuthenticationSessionProvider> {
public static final String AUTH_SESSIONS_LIMIT = "authSessionsLimit";
public static final int DEFAULT_AUTH_SESSIONS_LIMIT = 300;
private int authSessionsLimit;
public MapRootAuthenticationSessionProviderFactory() { public MapRootAuthenticationSessionProviderFactory() {
super(RootAuthenticationSessionModel.class, MapRootAuthenticationSessionProvider.class); super(RootAuthenticationSessionModel.class, MapRootAuthenticationSessionProvider.class);
} }
@Override
public void init(Config.Scope config) {
super.init(config);
// get auth sessions limit from config or use default if not provided
int configInt = config.getInt(AUTH_SESSIONS_LIMIT, DEFAULT_AUTH_SESSIONS_LIMIT);
// use default if provided value is not a positive number
authSessionsLimit = (configInt <= 0) ? DEFAULT_AUTH_SESSIONS_LIMIT : configInt;
}
@Override
public List<ProviderConfigProperty> getConfigMetadata() {
return ProviderConfigurationBuilder.create()
.property()
.name("authSessionsLimit")
.type("int")
.helpText("The maximum number of concurrent authentication sessions per RootAuthenticationSession.")
.defaultValue(DEFAULT_AUTH_SESSIONS_LIMIT)
.add()
.build();
}
@Override @Override
public MapRootAuthenticationSessionProvider createNew(KeycloakSession session) { public MapRootAuthenticationSessionProvider createNew(KeycloakSession session) {
return new MapRootAuthenticationSessionProvider(session, getStorage(session)); return new MapRootAuthenticationSessionProvider(session, getStorage(session), authSessionsLimit);
} }
@Override @Override

View file

@ -25,7 +25,6 @@ import org.keycloak.models.ClientModel;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.sessions.infinispan.InfinispanAuthenticationSessionProviderFactory;
import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.sessions.AuthenticationSessionProvider; import org.keycloak.sessions.AuthenticationSessionProvider;
import org.keycloak.sessions.RootAuthenticationSessionModel; import org.keycloak.sessions.RootAuthenticationSessionModel;
@ -65,7 +64,6 @@ public class AuthenticationSessionTest extends KeycloakModelTest {
} }
@Test @Test
@RequireProvider(value = AuthenticationSessionProvider.class, only = InfinispanAuthenticationSessionProviderFactory.PROVIDER_ID)
public void testLimitAuthSessions() { public void testLimitAuthSessions() {
AtomicReference<String> rootAuthSessionId = new AtomicReference<>(); AtomicReference<String> rootAuthSessionId = new AtomicReference<>();
List<String> tabIds = withRealm(realmId, (session, realm) -> { List<String> tabIds = withRealm(realmId, (session, realm) -> {
@ -81,13 +79,21 @@ public class AuthenticationSessionTest extends KeycloakModelTest {
.collect(Collectors.toList()); .collect(Collectors.toList());
}); });
withRealm(realmId, (session, realm) -> { String tabId = withRealm(realmId, (session, realm) -> {
RootAuthenticationSessionModel ras = session.authenticationSessions().getRootAuthenticationSession(realm, rootAuthSessionId.get()); RootAuthenticationSessionModel ras = session.authenticationSessions().getRootAuthenticationSession(realm, rootAuthSessionId.get());
ClientModel client = realm.getClientByClientId("test-app"); ClientModel client = realm.getClientByClientId("test-app");
// create 301st auth session // create 301st auth session
AuthenticationSessionModel as = ras.createAuthenticationSession(client); return ras.createAuthenticationSession(client).getTabId();
Assert.assertEquals(as, ras.getAuthenticationSession(client, as.getTabId())); });
withRealm(realmId, (session, realm) -> {
RootAuthenticationSessionModel ras = session.authenticationSessions().getRootAuthenticationSession(realm, rootAuthSessionId.get());
ClientModel client = realm.getClientByClientId("test-app");
assertThat(ras.getAuthenticationSessions(), Matchers.aMapWithSize(300));
Assert.assertEquals(tabId, ras.getAuthenticationSession(client, tabId).getTabId());
// assert the first authentication session was deleted // assert the first authentication session was deleted
Assert.assertNull(ras.getAuthenticationSession(client, tabIds.get(0))); Assert.assertNull(ras.getAuthenticationSession(client, tabIds.get(0)));