Add caching when querying brokers by organization

Closes #32574

Signed-off-by: Martin Kanis <mkanis@redhat.com>
This commit is contained in:
Martin Kanis 2024-09-04 14:56:03 +02:00 committed by Pedro Igor
parent 23adc1e04e
commit ccb166d0e9
4 changed files with 200 additions and 0 deletions

View file

@ -0,0 +1,55 @@
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.cache.infinispan.idp;
import org.keycloak.models.RealmModel;
import org.keycloak.models.cache.infinispan.entities.AbstractRevisioned;
import org.keycloak.models.cache.infinispan.entities.InRealm;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
public class IdentityProviderListQuery extends AbstractRevisioned implements InRealm {
private final String realmId;
private final Map<String, Set<String>> searchKeys;
public IdentityProviderListQuery(Long revision, String id, RealmModel realm, String searchKey, Set<String> result) {
super(revision, id);
this.realmId = realm.getId();
this.searchKeys = new HashMap<>();
this.searchKeys.put(searchKey, result);
}
public IdentityProviderListQuery(Long revision, String id, RealmModel realm, String searchKey, Set<String> result, IdentityProviderListQuery previous) {
super(revision, id);
this.realmId = realm.getId();
this.searchKeys = new HashMap<>();
this.searchKeys.putAll(previous.searchKeys);
this.searchKeys.put(searchKey, result);
}
@Override
public String getRealm() {
return realmId;
}
public Set<String> getIDPs(String searchKey) {
return searchKeys.get(searchKey);
}
}

View file

@ -16,7 +16,11 @@
*/
package org.keycloak.models.cache.infinispan.idp;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.keycloak.common.Profile;
import org.keycloak.models.IdentityProviderMapperModel;
@ -28,6 +32,7 @@ import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.models.cache.infinispan.CachedCount;
import org.keycloak.models.cache.infinispan.RealmCacheManager;
import org.keycloak.models.cache.infinispan.RealmCacheSession;
import org.keycloak.organization.OrganizationProvider;
@ -35,15 +40,18 @@ public class InfinispanIdentityProviderStorageProvider implements IdentityProvid
private static final String IDP_COUNT_KEY_SUFFIX = ".idp.count";
private static final String IDP_ALIAS_KEY_SUFFIX = ".idp.alias";
private static final String IDP_ORG_ID_KEY_SUFFIX = ".idp.orgId";
private final KeycloakSession session;
private final IdentityProviderStorageProvider idpDelegate;
private final RealmCacheSession realmCache;
private final long startupRevision;
public InfinispanIdentityProviderStorageProvider(KeycloakSession session) {
this.session = session;
this.idpDelegate = session.getProvider(IdentityProviderStorageProvider.class, "jpa");
this.realmCache = (RealmCacheSession) session.getProvider(CacheRealmProvider.class);
this.startupRevision = realmCache.getCache().getCurrentCounter();
}
private static String cacheKeyIdpCount(RealmModel realm) {
@ -58,6 +66,10 @@ public class InfinispanIdentityProviderStorageProvider implements IdentityProvid
return realm.getId() + "." + alias + IDP_ALIAS_KEY_SUFFIX + "." + name;
}
public static String cacheKeyOrgId(RealmModel realm, String orgId) {
return realm.getId() + "." + orgId + IDP_ORG_ID_KEY_SUFFIX;
}
@Override
public IdentityProviderModel create(IdentityProviderModel model) {
registerCountInvalidation();
@ -140,6 +152,52 @@ public class InfinispanIdentityProviderStorageProvider implements IdentityProvid
return createOrganizationAwareIdentityProviderModel(cached.getIdentityProvider());
}
@Override
public Stream<IdentityProviderModel> getByOrganization(String orgId, Integer first, Integer max) {
RealmModel realm = getRealm();
String cacheKey = cacheKeyOrgId(realm, orgId);
// check if there is invalidation for this key or the organization was invalidated
if (isInvalid(cacheKey) || isInvalid(orgId)) {
return idpDelegate.getByOrganization(orgId, first, max).map(this::createOrganizationAwareIdentityProviderModel);
}
RealmCacheManager cache = realmCache.getCache();
IdentityProviderListQuery query = cache.get(cacheKey, IdentityProviderListQuery.class);
String searchKey = Optional.ofNullable(first).orElse(-1) + "." + Optional.ofNullable(max).orElse(-1);
Set<String> cached;
if (query == null) {
// not cached yet
Long loaded = cache.getCurrentRevision(cacheKey);
cached = idpDelegate.getByOrganization(orgId, first, max).map(IdentityProviderModel::getInternalId).collect(Collectors.toSet());
query = new IdentityProviderListQuery(loaded, cacheKey, realm, searchKey, cached);
cache.addRevisioned(query, startupRevision);
} else {
cached = query.getIDPs(searchKey);
if (cached == null) {
// there is a cache entry, but the current search is not yet cached
cache.invalidateObject(cacheKey);
Long loaded = cache.getCurrentRevision(cacheKey);
cached = idpDelegate.getByOrganization(orgId, first, max).map(IdentityProviderModel::getInternalId).collect(Collectors.toSet());
query = new IdentityProviderListQuery(loaded, cacheKey, realm, searchKey, cached, query);
cache.addRevisioned(query, cache.getCurrentCounter());
}
}
Set<IdentityProviderModel> identityProviders = new HashSet<>();
for (String id : cached) {
IdentityProviderModel idp = session.identityProviders().getById(id);
if (idp == null) {
realmCache.registerInvalidation(cacheKey);
return idpDelegate.getByOrganization(orgId, first, max).map(this::createOrganizationAwareIdentityProviderModel);
}
identityProviders.add(idp);
}
return identityProviders.stream();
}
@Override
public Stream<String> getByFlow(String flowId, String search, Integer first, Integer max) {
return idpDelegate.getByFlow(flowId, search, first, max);

View file

@ -30,6 +30,8 @@ import org.keycloak.models.cache.infinispan.CachedCount;
import org.keycloak.models.cache.infinispan.RealmCacheSession;
import org.keycloak.organization.OrganizationProvider;
import static org.keycloak.models.cache.infinispan.idp.InfinispanIdentityProviderStorageProvider.cacheKeyOrgId;
public class InfinispanOrganizationProvider implements OrganizationProvider {
private static final String ORG_COUNT_KEY_SUFFIX = ".org.count";
@ -298,6 +300,7 @@ public class InfinispanOrganizationProvider implements OrganizationProvider {
void registerOrganizationInvalidation(OrganizationModel organization) {
String id = organization.getId();
realmCache.registerInvalidation(cacheKeyOrgId(getRealm(), id));
realmCache.registerInvalidation(id);
organization.getDomains()
.map(OrganizationDomainModel::getName)

View file

@ -20,6 +20,7 @@ package org.keycloak.testsuite.organization.cache;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.keycloak.models.cache.infinispan.idp.InfinispanIdentityProviderStorageProvider.cacheKeyOrgId;
import static org.keycloak.models.cache.infinispan.organization.InfinispanOrganizationProvider.cacheKeyOrgMemberCount;
import java.util.List;
@ -30,6 +31,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.keycloak.common.Profile.Feature;
import org.keycloak.models.IdentityProviderStorageProvider;
import org.keycloak.models.OrganizationDomainModel;
import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel;
@ -37,7 +39,11 @@ import org.keycloak.models.UserModel;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.models.cache.infinispan.RealmCacheSession;
import org.keycloak.models.cache.infinispan.CachedCount;
import org.keycloak.models.cache.infinispan.idp.IdentityProviderListQuery;
import org.keycloak.organization.OrganizationProvider;
import org.keycloak.representations.idm.IdentityProviderRepresentation;
import org.keycloak.representations.idm.OrganizationDomainRepresentation;
import org.keycloak.representations.idm.OrganizationRepresentation;
import org.keycloak.representations.idm.UserRepresentation;
import org.keycloak.testsuite.arquillian.annotation.EnableFeature;
import org.keycloak.testsuite.organization.admin.AbstractOrganizationTest;
@ -279,4 +285,82 @@ public class OrganizationCacheTest extends AbstractOrganizationTest {
assertEquals(0, cached.getCount());
});
}
@Test
public void testCacheIDPByOrg() {
IdentityProviderRepresentation idpRep = testRealm().identityProviders().get("orga-identity-provider").toRepresentation();
idpRep.setInternalId(null);
idpRep.setOrganizationId(null);
idpRep.getConfig().remove(OrganizationModel.ORGANIZATION_DOMAIN_ATTRIBUTE);
idpRep.getConfig().put(OrganizationModel.BROKER_PUBLIC, Boolean.TRUE.toString());
for (int i = 0; i < 10; i++) {
final String alias = "org-idp-" + i;
idpRep.setAlias(alias);
testRealm().identityProviders().create(idpRep).close();
getCleanup().addCleanup(testRealm().identityProviders().get("alias")::remove);
}
String orgaId = testRealm().organizations().getAll().get(0).getId();
String orgbId = testRealm().organizations().getAll().get(1).getId();
for (int i = 0; i < 5; i++) {
final String aliasA = "org-idp-" + i;
final String aliasB = "org-idp-" + (i + 5);
testRealm().organizations().get(orgaId).identityProviders().addIdentityProvider(aliasA);
testRealm().organizations().get(orgbId).identityProviders().addIdentityProvider(aliasB);
}
getTestingClient().server(TEST_REALM_NAME).run((RunOnServer) session -> {
IdentityProviderStorageProvider idpProvider = session.getProvider(IdentityProviderStorageProvider.class);
RealmModel realm = session.getContext().getRealm();
String cachedKeyA = cacheKeyOrgId(realm, orgaId);
RealmCacheSession realmCache = (RealmCacheSession) session.getProvider(CacheRealmProvider.class);
IdentityProviderListQuery identityProviderListQuery = realmCache.getCache().get(cachedKeyA, IdentityProviderListQuery.class);
assertNull(identityProviderListQuery);
String cachedKeyB = cacheKeyOrgId(realm, orgbId);
identityProviderListQuery = realmCache.getCache().get(cachedKeyB, IdentityProviderListQuery.class);
assertNull(identityProviderListQuery);
idpProvider.getByOrganization(orgaId, null, null);
identityProviderListQuery = realmCache.getCache().get(cachedKeyA, IdentityProviderListQuery.class);
assertNotNull(identityProviderListQuery);
assertEquals(6, identityProviderListQuery.getIDPs("-1.-1").size());
idpProvider.getByOrganization(orgbId, 0, 2);
idpProvider.getByOrganization(orgbId, 2, 6);
identityProviderListQuery = realmCache.getCache().get(cachedKeyB, IdentityProviderListQuery.class);
assertNotNull(identityProviderListQuery);
assertEquals(2, identityProviderListQuery.getIDPs("0.2").size());
assertEquals(4, identityProviderListQuery.getIDPs("2.6").size());
});
// update orga which should invalidate getByOrganization IDP cache
OrganizationRepresentation rep = testRealm().organizations().get(orgaId).toRepresentation();
OrganizationDomainRepresentation orgDomainRep = new OrganizationDomainRepresentation();
orgDomainRep.setName("orgaa.org");
rep.addDomain(orgDomainRep);
testRealm().organizations().get(orgaId).update(rep).close();
// update an IDP that is associated with orgb, that should invalidate getByOrganization IDP cache
idpRep = testRealm().identityProviders().get("org-idp-5").toRepresentation();
idpRep.setDisplayName("something");
testRealm().identityProviders().get("org-idp-5").update(idpRep);
getTestingClient().server(TEST_REALM_NAME).run((RunOnServer) session -> {
IdentityProviderStorageProvider idpProvider = session.getProvider(IdentityProviderStorageProvider.class);
RealmModel realm = session.getContext().getRealm();
String cachedKeyA = cacheKeyOrgId(realm, orgaId);
RealmCacheSession realmCache = (RealmCacheSession) session.getProvider(CacheRealmProvider.class);
IdentityProviderListQuery identityProviderListQuery = realmCache.getCache().get(cachedKeyA, IdentityProviderListQuery.class);
assertNull(identityProviderListQuery);
String cachedKeyB = cacheKeyOrgId(realm, orgbId);
identityProviderListQuery = realmCache.getCache().get(cachedKeyB, IdentityProviderListQuery.class);
assertNull(identityProviderListQuery);
});
}
}