Prevent client cache stampede after invalidation of a client or on startup (#25217)

Closes #24202

Signed-off-by: Alexander Schwartz <aschwart@redhat.com>
This commit is contained in:
Alexander Schwartz 2023-12-05 16:01:37 +01:00 committed by GitHub
parent e69031d411
commit e4be3ed244
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1124,30 +1124,46 @@ public class RealmCacheSession implements CacheRealmProvider {
@Override
public ClientModel getClientById(RealmModel realm, String id) {
if (invalidations.contains(id) || listInvalidations.contains(realm.getId())) {
return getClientDelegate().getClientById(realm, id);
} else if (managedApplications.containsKey(id)) {
return managedApplications.get(id);
}
CachedClient cached = cache.get(id, CachedClient.class);
if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null;
}
boolean queryDB = invalidations.contains(id) || listInvalidations.contains(realm.getId());
if (queryDB) { // short-circuit if the client has been potentially invalidated
return getClientDelegate().getClientById(realm, id);
}
ClientModel adapter;
if (cached != null) {
logger.tracev("client by id cache hit: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
} else {
adapter = cache.computeSerialized(session, id, (key, keycloakSession) -> prepareCachedClientById(realm, id));
if (adapter == null) {
return adapter;
}
}
managedApplications.put(id, adapter);
return adapter;
}
private ClientModel prepareCachedClientById(RealmModel realm, String id) {
CachedClient cached = cache.get(id, CachedClient.class);
ClientModel adapter;
if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null;
}
if (cached == null) {
Long loaded = cache.getCurrentRevision(id);
ClientModel model = getClientDelegate().getClientById(realm, id);
if (model == null) return null;
ClientModel adapter = cacheClient(realm, model, loaded);
managedApplications.put(id, adapter);
return adapter;
} else if (managedApplications.containsKey(id)) {
return managedApplications.get(id);
if (model == null) {
return null;
}
adapter = cacheClient(realm, model, loaded);
} else {
logger.tracev("client by id cache hit after locking: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
}
ClientModel adapter = validateCache(realm, cached);
managedApplications.put(id, adapter);
return adapter;
}
@ -1230,31 +1246,37 @@ public class RealmCacheSession implements CacheRealmProvider {
@Override
public ClientModel getClientByClientId(RealmModel realm, String clientId) {
String cacheKey = getClientByClientIdCacheKey(clientId, realm.getId());
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
String id = null;
boolean queryDB = invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId());
if (queryDB) { // short-circuit if the client has been potentially invalidated
if (invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId())) {
return getClientDelegate().getClientByClientId(realm, clientId);
}
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
if (query != null) {
logger.tracev("client by name cache hit: {0}", clientId);
String id = query.getClients().iterator().next();
return getClientById(realm, id);
} else {
return cache.computeSerialized(session, cacheKey, (key, keycloakSession) -> prepareCachedClientByClientId(realm, clientId, key));
}
}
private ClientModel prepareCachedClientByClientId(RealmModel realm, String clientId, String cacheKey) {
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
String id;
if (query == null) {
Long loaded = cache.getCurrentRevision(cacheKey);
ClientModel model = getClientDelegate().getClientByClientId(realm, clientId);
if (model == null) return null;
if (invalidations.contains(model.getId())) return model;
if (model == null) {
return null;
}
id = model.getId();
query = new ClientListQuery(loaded, cacheKey, realm, id);
logger.tracev("adding client by name cache miss: {0}", clientId);
cache.addRevisioned(query, startupRevision);
if (invalidations.contains(model.getId())) {
return model;
}
} else {
id = query.getClients().iterator().next();
if (invalidations.contains(id)) {
return getClientDelegate().getClientByClientId(realm, clientId);
}
}
return getClientById(realm, id);
}