Prevent client cache stampede after invalidation of a client or on startup (#25217)

Closes #24202

Signed-off-by: Alexander Schwartz <aschwart@redhat.com>
This commit is contained in:
Alexander Schwartz 2023-12-05 16:01:37 +01:00 committed by GitHub
parent e69031d411
commit e4be3ed244
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1124,30 +1124,46 @@ public class RealmCacheSession implements CacheRealmProvider {
@Override @Override
public ClientModel getClientById(RealmModel realm, String id) { public ClientModel getClientById(RealmModel realm, String id) {
if (invalidations.contains(id) || listInvalidations.contains(realm.getId())) {
return getClientDelegate().getClientById(realm, id);
} else if (managedApplications.containsKey(id)) {
return managedApplications.get(id);
}
CachedClient cached = cache.get(id, CachedClient.class); CachedClient cached = cache.get(id, CachedClient.class);
if (cached != null && !cached.getRealm().equals(realm.getId())) { if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null; cached = null;
} }
boolean queryDB = invalidations.contains(id) || listInvalidations.contains(realm.getId()); ClientModel adapter;
if (queryDB) { // short-circuit if the client has been potentially invalidated
return getClientDelegate().getClientById(realm, id);
}
if (cached != null) { if (cached != null) {
logger.tracev("client by id cache hit: {0}", cached.getClientId()); logger.tracev("client by id cache hit: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
} else {
adapter = cache.computeSerialized(session, id, (key, keycloakSession) -> prepareCachedClientById(realm, id));
if (adapter == null) {
return adapter;
}
} }
managedApplications.put(id, adapter);
return adapter;
}
private ClientModel prepareCachedClientById(RealmModel realm, String id) {
CachedClient cached = cache.get(id, CachedClient.class);
ClientModel adapter;
if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null;
}
if (cached == null) { if (cached == null) {
Long loaded = cache.getCurrentRevision(id); Long loaded = cache.getCurrentRevision(id);
ClientModel model = getClientDelegate().getClientById(realm, id); ClientModel model = getClientDelegate().getClientById(realm, id);
if (model == null) return null; if (model == null) {
ClientModel adapter = cacheClient(realm, model, loaded); return null;
managedApplications.put(id, adapter); }
return adapter; adapter = cacheClient(realm, model, loaded);
} else if (managedApplications.containsKey(id)) { } else {
return managedApplications.get(id); logger.tracev("client by id cache hit after locking: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
} }
ClientModel adapter = validateCache(realm, cached);
managedApplications.put(id, adapter);
return adapter; return adapter;
} }
@ -1230,31 +1246,37 @@ public class RealmCacheSession implements CacheRealmProvider {
@Override @Override
public ClientModel getClientByClientId(RealmModel realm, String clientId) { public ClientModel getClientByClientId(RealmModel realm, String clientId) {
String cacheKey = getClientByClientIdCacheKey(clientId, realm.getId()); String cacheKey = getClientByClientIdCacheKey(clientId, realm.getId());
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class); if (invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId())) {
String id = null;
boolean queryDB = invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId());
if (queryDB) { // short-circuit if the client has been potentially invalidated
return getClientDelegate().getClientByClientId(realm, clientId); return getClientDelegate().getClientByClientId(realm, clientId);
} }
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
if (query != null) { if (query != null) {
logger.tracev("client by name cache hit: {0}", clientId); logger.tracev("client by name cache hit: {0}", clientId);
String id = query.getClients().iterator().next();
return getClientById(realm, id);
} else {
return cache.computeSerialized(session, cacheKey, (key, keycloakSession) -> prepareCachedClientByClientId(realm, clientId, key));
} }
}
private ClientModel prepareCachedClientByClientId(RealmModel realm, String clientId, String cacheKey) {
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
String id;
if (query == null) { if (query == null) {
Long loaded = cache.getCurrentRevision(cacheKey); Long loaded = cache.getCurrentRevision(cacheKey);
ClientModel model = getClientDelegate().getClientByClientId(realm, clientId); ClientModel model = getClientDelegate().getClientByClientId(realm, clientId);
if (model == null) return null; if (model == null) {
if (invalidations.contains(model.getId())) return model; return null;
}
id = model.getId(); id = model.getId();
query = new ClientListQuery(loaded, cacheKey, realm, id); query = new ClientListQuery(loaded, cacheKey, realm, id);
logger.tracev("adding client by name cache miss: {0}", clientId); logger.tracev("adding client by name cache miss: {0}", clientId);
cache.addRevisioned(query, startupRevision); cache.addRevisioned(query, startupRevision);
if (invalidations.contains(model.getId())) {
return model;
}
} else { } else {
id = query.getClients().iterator().next(); id = query.getClients().iterator().next();
if (invalidations.contains(id)) {
return getClientDelegate().getClientByClientId(realm, clientId);
}
} }
return getClientById(realm, id); return getClientById(realm, id);
} }