KEYCLOAK-17348 Add manual pagination into UserStorageManager#query

This commit is contained in:
mhajas 2021-05-12 13:09:05 +02:00 committed by Hynek Mlnařík
parent 8feefe94ac
commit f37a24dd91
7 changed files with 418 additions and 20 deletions

View file

@ -52,6 +52,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import static org.keycloak.models.utils.KeycloakModelUtils.runJobInTransaction;
@ -157,10 +159,19 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@FunctionalInterface
interface PaginatedQuery {
Stream<UserModel> query(Object provider);
Stream<UserModel> query(Object provider, Integer firstResult, Integer maxResults);
}
@FunctionalInterface
interface CountQuery {
int query(Object provider, Integer firstResult, Integer maxResult);
}
protected Stream<UserModel> query(PaginatedQuery pagedQuery, RealmModel realm, Integer firstResult, Integer maxResults) {
return query(pagedQuery, ((provider, first, max) -> (int) pagedQuery.query(provider, first, max).count()), realm, firstResult, maxResults);
}
protected Stream<UserModel> query(PaginatedQuery pagedQuery, CountQuery countQuery, RealmModel realm, Integer firstResult, Integer maxResults) {
if (maxResults != null && maxResults == 0) return Stream.empty();
Stream<Object> providersStream = Stream.concat(Stream.of((Object) localStorage()), getEnabledStorageProviders(realm, UserQueryProvider.class));
@ -170,7 +181,54 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
providersStream = Stream.concat(providersStream, Stream.of(federatedStorageProvider));
}
return paginatedStream(providersStream.flatMap(pagedQuery::query), firstResult, maxResults);
final AtomicInteger currentFirst;
if (firstResult == null || firstResult <= 0) { // We don't want to skip any users so we don't need to do firstResult filtering
currentFirst = new AtomicInteger(0);
} else {
AtomicBoolean droppingProviders = new AtomicBoolean(true);
currentFirst = new AtomicInteger(firstResult);
providersStream = providersStream
.filter(provider -> { // This is basically dropWhile
if (!droppingProviders.get()) return true; // We have already gathered enough users to pass firstResult number in previous providers, we can take all following providers
long expectedNumberOfUsersForProvider = countQuery.query(provider, 0, currentFirst.get() + 1); // check how many users we can obtain from this provider
if (expectedNumberOfUsersForProvider == currentFirst.get()) { // This provider provides exactly the amount of users we need for passing firstResult, we can set currentFirst to 0 and drop this provider
currentFirst.set(0);
droppingProviders.set(false);
return false;
}
if (expectedNumberOfUsersForProvider > currentFirst.get()) { // If we can obtain enough enough users from this provider to fulfill our need we can stop dropping providers
droppingProviders.set(false);
return true; // don't filter out this provider because we are going to return some users from it
}
// This provider cannot provide enough users to pass firstResult so we are going to filter it out and change firstResult for next provider
currentFirst.set((int) (currentFirst.get() - expectedNumberOfUsersForProvider));
return false;
});
}
// Actual user querying
if (maxResults == null || maxResults < 0) {
// No maxResult set, we want all users
return providersStream
.flatMap(provider -> pagedQuery.query(provider, currentFirst.getAndSet(0), null));
} else {
final AtomicInteger currentMax = new AtomicInteger(maxResults);
// Query users with currentMax variable counting how many users we return
return providersStream
.filter(provider -> currentMax.get() != 0) // If we reach currentMax == 0, we can skip querying all following providers
.flatMap(provider -> pagedQuery.query(provider, currentFirst.getAndSet(0), currentMax.get()))
.peek(userModel -> {
currentMax.updateAndGet(i -> i > 0 ? i - 1 : i);
});
}
}
// removeDuplicates method may cause concurrent issues, it should not be used on parallel streams
@ -260,12 +318,12 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@Override
public Stream<UserModel> getGroupMembersStream(final RealmModel realm, final GroupModel group, Integer firstResult, Integer maxResults) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).getGroupMembersStream(realm, group);
return ((UserQueryProvider)provider).getGroupMembersStream(realm, group, firstResultInQuery, maxResultsInQuery);
} else if (provider instanceof UserFederatedStorageProvider) {
return ((UserFederatedStorageProvider)provider).getMembershipStream(realm, group, -1, -1).
return ((UserFederatedStorageProvider)provider).getMembershipStream(realm, group, firstResultInQuery, maxResultsInQuery).
map(id -> getUserById(realm, id));
}
return Stream.empty();
@ -276,9 +334,9 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@Override
public Stream<UserModel> getRoleMembersStream(final RealmModel realm, final RoleModel role, Integer firstResult, Integer maxResults) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).getRoleMembersStream(realm, role);
return ((UserQueryProvider)provider).getRoleMembersStream(realm, role, firstResultInQuery, maxResultsInQuery);
}
return Stream.empty();
}, realm, firstResult, maxResults);
@ -298,9 +356,9 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@Override
public Stream<UserModel> getUsersStream(final RealmModel realm, Integer firstResult, Integer maxResults, final boolean includeServiceAccounts) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserProvider) { // it is local storage
return ((UserProvider) provider).getUsersStream(realm, includeServiceAccounts);
return ((UserProvider) provider).getUsersStream(realm, firstResultInQuery, maxResultsInQuery, includeServiceAccounts);
} else if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).getUsersStream(realm);
}
@ -352,26 +410,41 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@Override
public Stream<UserModel> searchForUserStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).searchForUserStream(realm, search);
return ((UserQueryProvider)provider).searchForUserStream(realm, search, firstResultInQuery, maxResultsInQuery);
}
return Stream.empty();
}, (provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).getUsersCount(realm, search);
}
return 0;
}, realm, firstResult, maxResults);
return importValidation(realm, results);
}
@Override
public Stream<UserModel> searchForUserStream(RealmModel realm, Map<String, String> attributes, Integer firstResult, Integer maxResults) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
if (attributes.containsKey(UserModel.SEARCH)) {
return ((UserQueryProvider)provider).searchForUserStream(realm, attributes.get(UserModel.SEARCH));
return ((UserQueryProvider)provider).searchForUserStream(realm, attributes.get(UserModel.SEARCH), firstResultInQuery, maxResultsInQuery);
} else {
return ((UserQueryProvider)provider).searchForUserStream(realm, attributes);
return ((UserQueryProvider)provider).searchForUserStream(realm, attributes, firstResultInQuery, maxResultsInQuery);
}
}
return Stream.empty();
},
(provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
if (attributes.containsKey(UserModel.SEARCH)) {
return ((UserQueryProvider)provider).getUsersCount(realm, attributes.get(UserModel.SEARCH));
} else {
return ((UserQueryProvider)provider).getUsersCount(realm, attributes);
}
}
return 0;
}
, realm, firstResult, maxResults);
return importValidation(realm, results);
@ -379,17 +452,17 @@ public class UserStorageManager extends AbstractStorageManager<UserStorageProvid
@Override
public Stream<UserModel> searchForUserByUserAttributeStream(RealmModel realm, String attrName, String attrValue) {
Stream<UserModel> results = query((provider) -> {
Stream<UserModel> results = query((provider, firstResultInQuery, maxResultsInQuery) -> {
if (provider instanceof UserQueryProvider) {
return ((UserQueryProvider)provider).searchForUserByUserAttributeStream(realm, attrName, attrValue);
return paginatedStream(((UserQueryProvider)provider).searchForUserByUserAttributeStream(realm, attrName, attrValue), firstResultInQuery, maxResultsInQuery);
} else if (provider instanceof UserFederatedStorageProvider) {
return ((UserFederatedStorageProvider)provider).getUsersByUserAttributeStream(realm, attrName, attrValue)
return paginatedStream(((UserFederatedStorageProvider)provider).getUsersByUserAttributeStream(realm, attrName, attrValue)
.map(id -> getUserById(realm, id))
.filter(Objects::nonNull);
.filter(Objects::nonNull), firstResultInQuery, maxResultsInQuery);
}
return Stream.empty();
}, realm,null, null);
}, realm, null, null);
// removeDuplicates method may cause concurrent issues, it should not be used on parallel streams
results = removeDuplicates(results);

View file

@ -33,6 +33,11 @@ import org.keycloak.storage.adapter.AbstractUserAdapterFederatedStorage;
import org.keycloak.storage.user.UserLookupProvider;
import org.keycloak.storage.user.UserQueryProvider;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
@ -47,11 +52,40 @@ import static org.keycloak.utils.StreamsUtil.paginatedStream;
*/
public class UserPropertyFileStorage implements UserLookupProvider.Streams, UserStorageProvider, UserQueryProvider.Streams, CredentialInputValidator {
public static final String SEARCH_METHOD = "searchForUserStream(RealmMode, String, Integer, Integer)";
public static final String COUNT_SEARCH_METHOD = "getUsersCount(RealmModel, String)";
protected Properties userPasswords;
protected ComponentModel model;
protected KeycloakSession session;
protected boolean federatedStorageEnabled;
public static Map<String, List<UserPropertyFileStorageCall>> storageCalls = new HashMap<>();
public static class UserPropertyFileStorageCall implements Serializable {
private final String method;
private final Integer first;
private final Integer max;
public UserPropertyFileStorageCall(String method, Integer first, Integer max) {
this.method = method;
this.first = first;
this.max = max;
}
public String getMethod() {
return method;
}
public Integer getFirst() {
return first;
}
public Integer getMax() {
return max;
}
}
public UserPropertyFileStorage(KeycloakSession session, ComponentModel model, Properties userPasswords) {
this.session = session;
this.model = model;
@ -59,6 +93,23 @@ public class UserPropertyFileStorage implements UserLookupProvider.Streams, User
this.federatedStorageEnabled = model.getConfig().containsKey("federatedStorage") && Boolean.valueOf(model.getConfig().getFirst("federatedStorage")).booleanValue();
}
private void addCall(String method, Integer first, Integer max) {
storageCalls.merge(model.getId(), new LinkedList<>(Collections.singletonList(new UserPropertyFileStorageCall(method, first, max))), (a, b) -> {
a.addAll(b);
return a;
});
}
private void addCall(String method) {
addCall(method, null, null);
}
@Override
public int getUsersCount(RealmModel realm, String search) {
addCall(COUNT_SEARCH_METHOD);
return (int) searchForUser(realm, search, null, null, username -> username.contains(search)).count();
}
@Override
public UserModel getUserById(RealmModel realm, String id) {
@ -159,6 +210,7 @@ public class UserPropertyFileStorage implements UserLookupProvider.Streams, User
@Override
public Stream<UserModel> searchForUserStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
addCall(SEARCH_METHOD, firstResult, maxResults);
return searchForUser(realm, search, firstResult, maxResults, username -> username.contains(search));
}

View file

@ -206,6 +206,20 @@
</properties>
</profile>
<profile>
<id>jpa-federation-file-storage</id>
<properties>
<keycloak.model.parameters>JpaFederation,TestsuiteUserFileStorage</keycloak.model.parameters>
</properties>
</profile>
<profile>
<id>jpa-federation-file-storage+infinispan</id>
<properties>
<keycloak.model.parameters>JpaFederation,TestsuiteUserFileStorage,Infinispan</keycloak.model.parameters>
</properties>
</profile>
<profile>
<id>jpa-federation+ldap</id>
<properties>

View file

@ -0,0 +1,161 @@
package org.keycloak.testsuite.model;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RealmProvider;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserProvider;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.UserStorageProviderFactory;
import org.keycloak.storage.UserStorageProviderModel;
import org.keycloak.testsuite.federation.UserPropertyFileStorage;
import org.keycloak.testsuite.federation.UserPropertyFileStorage.UserPropertyFileStorageCall;
import org.keycloak.testsuite.federation.UserPropertyFileStorageFactory;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assume.assumeThat;
/**
* @author mhajas
*/
@RequireProvider(UserProvider.class)
@RequireProvider(RealmProvider.class)
@RequireProvider(value = UserStorageProvider.class, only = UserPropertyFileStorageFactory.PROVIDER_ID)
public class UserPaginationTest extends KeycloakModelTest {
private String realmId;
private String userFederationId1;
private String userFederationId2;
@Override
public void createEnvironment(KeycloakSession s) {
RealmModel realm = s.realms().createRealm("realm");
realm.setDefaultRole(s.roles().addRealmRole(realm, Constants.DEFAULT_ROLES_ROLE_PREFIX + "-" + realm.getName()));
this.realmId = realm.getId();
getParameters(UserStorageProviderModel.class).forEach(fs -> inComittedTransaction(session -> {
assumeThat("Cannot handle more than 2 user federation provider", userFederationId2, Matchers.nullValue());
fs.setParentId(realmId);
ComponentModel res = realm.addComponentModel(fs);
if (userFederationId1 == null) {
userFederationId1 = res.getId();
} else {
userFederationId2 = res.getId();
}
log.infof("Added %s user federation provider: %s", fs.getName(), res.getId());
}));
}
@Override
public void cleanEnvironment(KeycloakSession s) {
s.realms().removeRealm(realmId);
}
@Test
public void testNoPaginationCalls() {
List<UserModel> list = withRealm(realmId, (session, realm) ->
session.users().searchForUserStream(realm,"", 0, Constants.DEFAULT_MAX_RESULTS) // Default values used in UsersResource
.collect(Collectors.toList()));
assertThat(list, hasSize(8));
expectedStorageCalls(
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, Constants.DEFAULT_MAX_RESULTS)),
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, Constants.DEFAULT_MAX_RESULTS - 4))
);
}
@Test
public void testPaginationStarting0() {
List<UserModel> list = withRealm(realmId, (session, realm) ->
session.users().searchForUserStream(realm,"", 0, 6)
.collect(Collectors.toList()));
assertThat(list, hasSize(6));
expectedStorageCalls(
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, 6)),
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, 2))
);
}
@Test
public void testPaginationFirstResultInFirstProvider() {
List<UserModel> list = withRealm(realmId, (session, realm) ->
session.users().searchForUserStream(realm,"", 1, 6)
.collect(Collectors.toList()));
assertThat(list, hasSize(6));
expectedStorageCalls(
Arrays.asList(new UserPropertyFileStorageCall(UserPropertyFileStorage.COUNT_SEARCH_METHOD, null, null), new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 1, 6)),
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, 3))
);
}
@Test
public void testPaginationFirstResultIsExactlyTheAmountOfUsersInTheFirstProvider() {
List<UserModel> list = withRealm(realmId, (session, realm) ->
session.users().searchForUserStream(realm,"", 4, 6)
.collect(Collectors.toList()));
assertThat(list, hasSize(4));
expectedStorageCalls(
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.COUNT_SEARCH_METHOD, null, null)),
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 0, 6))
);
}
@Test
public void testPaginationFirstResultIsInSecondProvider() {
List<UserModel> list = withRealm(realmId, (session, realm) ->
session.users().searchForUserStream(realm,"", 5, 6)
.collect(Collectors.toList()));
assertThat(list, hasSize(3));
expectedStorageCalls(
Collections.singletonList(new UserPropertyFileStorageCall(UserPropertyFileStorage.COUNT_SEARCH_METHOD, null, null)),
Arrays.asList(new UserPropertyFileStorageCall(UserPropertyFileStorage.COUNT_SEARCH_METHOD, null, null), new UserPropertyFileStorageCall(UserPropertyFileStorage.SEARCH_METHOD, 1, 6))
);
}
private void expectedStorageCalls(final List<UserPropertyFileStorageCall> roCalls, final List<UserPropertyFileStorageCall> rwCalls) {
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId1), hasSize(roCalls.size()));
int i = 0;
for (UserPropertyFileStorageCall call : roCalls) {
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId1).get(i).getMethod(), equalTo(call.getMethod()));
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId1).get(i).getFirst(), equalTo(call.getFirst()));
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId1).get(i).getMax(), equalTo(call.getMax()));
i++;
}
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId2), hasSize(rwCalls.size()));
i = 0;
for (UserPropertyFileStorageCall call : rwCalls) {
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId2).get(i).getMethod(), equalTo(call.getMethod()));
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId2).get(i).getFirst(), equalTo(call.getFirst()));
assertThat(UserPropertyFileStorage.storageCalls.get(userFederationId2).get(i).getMax(), equalTo(call.getMax()));
i++;
}
UserPropertyFileStorage.storageCalls.clear();
}
}

View file

@ -0,0 +1,90 @@
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.testsuite.model.parameters;
import com.google.common.collect.ImmutableSet;
import org.keycloak.common.util.MultivaluedHashMap;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.provider.Spi;
import org.keycloak.representations.idm.ComponentRepresentation;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.UserStorageProviderModel;
import org.keycloak.testsuite.federation.UserMapStorageFactory;
import org.keycloak.testsuite.federation.UserPropertyFileStorageFactory;
import org.keycloak.testsuite.model.KeycloakModelParameters;
import java.io.File;
import java.net.URISyntaxException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
/**
*
* @author hmlnarik
*/
public class TestsuiteUserFileStorage extends KeycloakModelParameters {
static final Set<Class<? extends Spi>> ALLOWED_SPIS = ImmutableSet.<Class<? extends Spi>>builder()
.build();
static final Set<Class<? extends ProviderFactory>> ALLOWED_FACTORIES = ImmutableSet.<Class<? extends ProviderFactory>>builder()
.add(UserPropertyFileStorageFactory.class)
.build();
private static final File CONFIG_DIR;
static {
try {
CONFIG_DIR = new File(TestsuiteUserFileStorage.class.getClassLoader().getResource("file-storage-provider").toURI());
} catch (URISyntaxException e) {
throw new RuntimeException("Cannot get resource directory");
}
}
public TestsuiteUserFileStorage() {
super(ALLOWED_SPIS, ALLOWED_FACTORIES);
}
@Override
public <T> Stream<T> getParameters(Class<T> clazz) {
if (UserStorageProviderModel.class.isAssignableFrom(clazz)) {
UserStorageProviderModel propProviderRO = new UserStorageProviderModel();
propProviderRO.setName("read-only-user-props");
propProviderRO.setProviderId(UserPropertyFileStorageFactory.PROVIDER_ID);
propProviderRO.setProviderType(UserStorageProvider.class.getName());
propProviderRO.setConfig(new MultivaluedHashMap<>());
propProviderRO.getConfig().putSingle("priority", Integer.toString(1));
propProviderRO.getConfig().putSingle("propertyFile",
CONFIG_DIR.getAbsolutePath() + File.separator + "read-only-user-password.properties");
UserStorageProviderModel propProviderRW = new UserStorageProviderModel();
propProviderRW.setName("user-props");
propProviderRW.setProviderId(UserPropertyFileStorageFactory.PROVIDER_ID);
propProviderRW.setProviderType(UserStorageProvider.class.getName());
propProviderRW.setConfig(new MultivaluedHashMap<>());
propProviderRW.getConfig().putSingle("priority", Integer.toString(2));
propProviderRW.getConfig().putSingle("propertyFile", CONFIG_DIR.getAbsolutePath() + File.separator + "user-password.properties");
propProviderRW.getConfig().putSingle("federatedStorage", "true");
return Stream.of((T) propProviderRO, (T) propProviderRW);
} else {
return super.getParameters(clazz);
}
}
}

View file

@ -0,0 +1,4 @@
tbrady=goat
rob=pw
jules=pw
danny=pw

View file

@ -0,0 +1,4 @@
thor=hammer
zeus=pw
apollo=pw
perseus=pw