Prevent cache stampede on realms

Closes #21521
This commit is contained in:
Alexander Schwartz 2023-07-07 22:26:21 +02:00
parent 07c27336aa
commit 9b3effb4b8
2 changed files with 56 additions and 19 deletions

View file

@ -19,6 +19,7 @@ package org.keycloak.models.cache.infinispan;
import org.infinispan.Cache; import org.infinispan.Cache;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.cache.infinispan.events.InvalidationEvent; import org.keycloak.models.cache.infinispan.events.InvalidationEvent;
import org.keycloak.models.cache.infinispan.entities.Revisioned; import org.keycloak.models.cache.infinispan.entities.Revisioned;
import org.keycloak.models.cache.infinispan.events.RealmCacheInvalidationEvent; import org.keycloak.models.cache.infinispan.events.RealmCacheInvalidationEvent;
@ -28,6 +29,8 @@ import org.keycloak.models.cache.infinispan.stream.InClientPredicate;
import org.keycloak.models.cache.infinispan.stream.InRealmPredicate; import org.keycloak.models.cache.infinispan.stream.InRealmPredicate;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -36,6 +39,8 @@ public class RealmCacheManager extends CacheManager {
private static final Logger logger = Logger.getLogger(RealmCacheManager.class); private static final Logger logger = Logger.getLogger(RealmCacheManager.class);
private final ConcurrentHashMap<String, String> cacheInteractions = new ConcurrentHashMap<>();
@Override @Override
protected Logger getLogger() { protected Logger getLogger() {
return logger; return logger;
@ -109,7 +114,6 @@ public class RealmCacheManager extends CacheManager {
addInvalidations(InClientPredicate.create().client(clientUUID), invalidations); addInvalidations(InClientPredicate.create().client(clientUUID), invalidations);
} }
@Override @Override
protected void addInvalidationsFromEvent(InvalidationEvent event, Set<String> invalidations) { protected void addInvalidationsFromEvent(InvalidationEvent event, Set<String> invalidations) {
invalidations.add(event.getId()); invalidations.add(event.getId());
@ -117,4 +121,23 @@ public class RealmCacheManager extends CacheManager {
((RealmCacheInvalidationEvent) event).addInvalidations(this, invalidations); ((RealmCacheInvalidationEvent) event).addInvalidations(this, invalidations);
} }
/**
* Compute a cached realm and ensure that this happens only once with the current Keycloak instance.
* Use this to avoid concurrent preparations of a realm in parallel threads. This helps to break the load on
* a stampede after a server has started, were a lot of requests come in for the same realm that hasn't been cached yet.
* Instead of each request loading the realm in parallel, this lets the first request load the realm, and all
* other requests will use the cached realm, which is much more efficient.
*/
public RealmAdapter computeSerialized(KeycloakSession session, String id, BiFunction<String, KeycloakSession, RealmAdapter> compute) {
// this locking is only to ensure that if there is a computation for the same id in the "synchronized" block below,
// it will have the same object instance to lock the current execution until the other is finished.
Object lock = cacheInteractions.computeIfAbsent(id, s -> id);
try {
synchronized (lock) {
return compute.apply(id, session);
}
} finally {
cacheInteractions.remove(lock);
}
}
} }

View file

@ -405,26 +405,38 @@ public class RealmCacheSession implements CacheRealmProvider {
@Override @Override
public RealmModel getRealm(String id) { public RealmModel getRealm(String id) {
CachedRealm cached = cache.get(id, CachedRealm.class); if (invalidations.contains(id)) {
if (cached != null) {
logger.tracev("by id cache hit: {0}", cached.getName());
}
boolean wasCached = false;
if (cached == null) {
Long loaded = cache.getCurrentRevision(id);
RealmModel model = getRealmDelegate().getRealm(id);
if (model == null) return null;
if (invalidations.contains(id)) return model;
cached = new CachedRealm(loaded, model);
cache.addRevisioned(cached, startupRevision);
wasCached =true;
} else if (invalidations.contains(id)) {
return getRealmDelegate().getRealm(id); return getRealmDelegate().getRealm(id);
} else if (managedRealms.containsKey(id)) { } else if (managedRealms.containsKey(id)) {
return managedRealms.get(id); return managedRealms.get(id);
} }
RealmAdapter adapter = new RealmAdapter(session, cached, this); CachedRealm cached = cache.get(id, CachedRealm.class);
if (wasCached) { RealmAdapter adapter;
if (cached != null) {
logger.tracev("by id cache hit: {0}", cached.getName());
adapter = new RealmAdapter(session, cached, this);
} else {
adapter = cache.computeSerialized(session, id, this::prepareCachedRealm);
if (adapter == null) {
return null;
}
}
managedRealms.put(id, adapter);
return adapter;
}
private RealmAdapter prepareCachedRealm(String id, KeycloakSession session) {
CachedRealm cached = cache.get(id, CachedRealm.class);
RealmAdapter adapter;
if (cached == null) {
Long loaded = cache.getCurrentRevision(id);
RealmModel model = getRealmDelegate().getRealm(id);
if (model == null) {
return null;
}
cached = new CachedRealm(loaded, model);
cache.addRevisioned(cached, startupRevision);
adapter = new RealmAdapter(session, cached, this);
CachedRealmModel.RealmCachedEvent event = new CachedRealmModel.RealmCachedEvent() { CachedRealmModel.RealmCachedEvent event = new CachedRealmModel.RealmCachedEvent() {
@Override @Override
public CachedRealmModel getRealm() { public CachedRealmModel getRealm() {
@ -437,8 +449,10 @@ public class RealmCacheSession implements CacheRealmProvider {
} }
}; };
session.getKeycloakSessionFactory().publish(event); session.getKeycloakSessionFactory().publish(event);
} else {
adapter = new RealmAdapter(session, cached, this);
logger.tracev("by id cache hit after locking: {0}", cached.getName());
} }
managedRealms.put(id, adapter);
return adapter; return adapter;
} }
@ -597,7 +611,7 @@ public class RealmCacheSession implements CacheRealmProvider {
client.getRolesStream().forEach(role -> { client.getRolesStream().forEach(role -> {
roleRemovalInvalidations(role.getId(), role.getName(), client.getId()); roleRemovalInvalidations(role.getId(), role.getName(), client.getId());
}); });
if (client.isServiceAccountsEnabled()) { if (client.isServiceAccountsEnabled()) {
UserModel serviceAccount = session.users().getServiceAccount(client); UserModel serviceAccount = session.users().getServiceAccount(client);