From 326d63ce7462cbcd94f2efde2eb9070c1a724220 Mon Sep 17 00:00:00 2001 From: Pedro Igor Date: Wed, 28 Feb 2024 10:39:51 -0300 Subject: [PATCH] Make sure group searches are cached and entries invalidate accordingly Closes #26983 Signed-off-by: Pedro Igor --- .../cache/infinispan/RealmCacheSession.java | 68 +++++++++++++------ .../infinispan/entities/GroupListQuery.java | 47 +++++++++++-- .../org/keycloak/models/GroupProvider.java | 2 - .../admin/concurrency/ConcurrencyTest.java | 16 ++--- 4 files changed, 94 insertions(+), 39 deletions(-) diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmCacheSession.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmCacheSession.java index 8bed4ba203..1c65253214 100755 --- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmCacheSession.java +++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmCacheSession.java @@ -32,6 +32,7 @@ import org.keycloak.storage.StorageId; import org.keycloak.storage.client.ClientStorageProviderModel; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -1016,40 +1017,63 @@ public class RealmCacheSession implements CacheRealmProvider { @Override public Stream getTopLevelGroupsStream(RealmModel realm, String search, Boolean exact, Integer first, Integer max) { - String cacheKey = getTopGroupsQueryCacheKey(realm.getId() + search + first + max); - boolean queryDB = invalidations.contains(cacheKey) || listInvalidations.contains(cacheKey) - || listInvalidations.contains(realm.getId()); - if (queryDB) { + String cacheKey = getTopGroupsQueryCacheKey(realm.getId()); + + if (hasInvalidation(realm, cacheKey)) { return getGroupDelegate().getTopLevelGroupsStream(realm, search, exact, first, max); } GroupListQuery query = cache.get(cacheKey, GroupListQuery.class); - if (Objects.nonNull(query)) { - logger.tracev("getTopLevelGroups cache hit: {0}", realm.getName()); - } + String searchKey = Optional.ofNullable(search).orElse("") + "." + Optional.ofNullable(first).orElse(-1) + "." + Optional.ofNullable(max).orElse(-1); + Set cached; if (Objects.isNull(query)) { + // not cached yet Long loaded = cache.getCurrentRevision(cacheKey); - List model = getGroupDelegate().getTopLevelGroupsStream(realm, search, exact, first, max).collect(Collectors.toList()); - if (model.isEmpty()) return Stream.empty(); - Set ids = new HashSet<>(); - for (GroupModel client : model) ids.add(client.getId()); - query = new GroupListQuery(loaded, cacheKey, realm, ids); + cached = getGroupDelegate().getTopLevelGroupsStream(realm, search, exact, first, max).map(GroupModel::getId).collect(Collectors.toSet()); + query = new GroupListQuery(loaded, cacheKey, realm, searchKey, cached); logger.tracev("adding realm getTopLevelGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey); cache.addRevisioned(query, startupRevision); - return model.stream(); - } - List list = new LinkedList<>(); - for (String id : query.getGroups()) { - GroupModel group = session.groups().getGroupById(realm, id); - if (Objects.isNull(group)) { - invalidations.add(cacheKey); - return getGroupDelegate().getTopLevelGroupsStream(realm); + } else { + logger.tracev("getTopLevelGroups cache hit: {0}", realm.getName()); + + cached = query.getGroups(searchKey); + + if (hasInvalidation(realm, cacheKey) || cached == null) { + // there is a cache entry, but the current search is not yet cached + cache.invalidateObject(cacheKey); + Long loaded = cache.getCurrentRevision(cacheKey); + cached = getGroupDelegate().getTopLevelGroupsStream(realm, search, exact, first, max).map(GroupModel::getId).collect(Collectors.toSet()); + query = new GroupListQuery(loaded, cacheKey, realm, searchKey, cached, query); + logger.tracev("adding realm getTopLevelGroups search cache miss: realm {0} key {1}", realm.getName(), searchKey); + cache.addRevisioned(query, cache.getCurrentCounter()); } - list.add(group); } - return list.stream().sorted(GroupModel.COMPARE_BY_NAME); + AtomicBoolean invalidate = new AtomicBoolean(false); + Stream groups = cached.stream() + .map((id) -> session.groups().getGroupById(realm, id)) + .takeWhile(group -> { + if (Objects.isNull(group)) { + invalidate.set(true); + return false; + } + return true; + }) + .sorted(GroupModel.COMPARE_BY_NAME); + + if (!invalidate.get()) { + return groups; + } + + invalidations.add(cacheKey); + + return getGroupDelegate().getTopLevelGroupsStream(realm, search, exact, first, max); + } + + private boolean hasInvalidation(RealmModel realm, String cacheKey) { + return invalidations.contains(cacheKey) || listInvalidations.contains(cacheKey) + || listInvalidations.contains(realm.getId()); } @Override diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/GroupListQuery.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/GroupListQuery.java index 1e0c664e1e..4c9b607f07 100755 --- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/GroupListQuery.java +++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/GroupListQuery.java @@ -2,27 +2,59 @@ package org.keycloak.models.cache.infinispan.entities; import org.keycloak.models.RealmModel; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * @author Bill Burke * @version $Revision: 1 $ */ public class GroupListQuery extends AbstractRevisioned implements GroupQuery { - private final Set groups; private final String realm; private final String realmName; + private Map> searchKeys; - public GroupListQuery(Long revisioned, String id, RealmModel realm, Set groups) { + public GroupListQuery(Long revisioned, String id, RealmModel realm, String searchKey, Set result) { super(revisioned, id); this.realm = realm.getId(); this.realmName = realm.getName(); - this.groups = groups; + this.searchKeys = new HashMap<>(); + this.searchKeys.put(searchKey, result); + } + + public GroupListQuery(Long revisioned, String id, RealmModel realm, String searchKey, Set result, GroupListQuery previous) { + super(revisioned, id); + this.realm = realm.getId(); + this.realmName = realm.getName(); + this.searchKeys = new HashMap<>(); + this.searchKeys.putAll(previous.searchKeys); + this.searchKeys.put(searchKey, result); + } + + public GroupListQuery(Long revisioned, String id, RealmModel realm, Set ids) { + super(revisioned, id); + this.realm = realm.getId(); + this.realmName = realm.getName(); + this.searchKeys = new HashMap<>(); + this.searchKeys.put(id, ids); } @Override public Set getGroups() { - return groups; + Collection> values = searchKeys.values(); + + if (values.isEmpty()) { + return Set.of(); + } + + return values.stream().flatMap(Set::stream).collect(Collectors.toSet()); + } + + public Set getGroups(String searchKey) { + return searchKeys.get(searchKey); } @Override @@ -30,6 +62,13 @@ public class GroupListQuery extends AbstractRevisioned implements GroupQuery { return realm; } + public Map> getSearchKeys() { + if (searchKeys == null) { + searchKeys = new HashMap<>(); + } + return searchKeys; + } + @Override public String toString() { return "GroupListQuery{" + diff --git a/server-spi/src/main/java/org/keycloak/models/GroupProvider.java b/server-spi/src/main/java/org/keycloak/models/GroupProvider.java index 7951fd9a37..a3880a9ce6 100644 --- a/server-spi/src/main/java/org/keycloak/models/GroupProvider.java +++ b/server-spi/src/main/java/org/keycloak/models/GroupProvider.java @@ -20,8 +20,6 @@ package org.keycloak.models; import org.keycloak.provider.Provider; import org.keycloak.storage.group.GroupLookupProvider; -import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; /** diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrencyTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrencyTest.java index 2a846f37bd..b3f0aaf9ef 100755 --- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrencyTest.java +++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/admin/concurrency/ConcurrencyTest.java @@ -235,17 +235,11 @@ public class ConcurrencyTest extends AbstractConcurrencyTest { c = realm.groups().group(id).toRepresentation(); assertNotNull(c); - boolean retry = true; - int i = 0; - do { - List groups = realm.groups().groups().stream() - .map(GroupRepresentation::getName) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - retry = !groups.contains(name); - i++; - } while(retry && i < 3); - assertFalse("Group " + name + " [" + id + "] " + " not found in group list", retry); + assertTrue("Group " + name + " [" + id + "] " + " not found in group list", + realm.groups().groups().stream() + .map(GroupRepresentation::getName) + .filter(Objects::nonNull) + .anyMatch(name::equals)); } }