KEYCLOAK-15522 Use AbstractStorageManager in UserStorageManager

This commit is contained in:
mhajas 2020-09-18 17:50:58 +02:00 committed by Hynek Mlnařík
parent a2efb84e00
commit 4556e858ad
13 changed files with 731 additions and 643 deletions

View file

@ -482,8 +482,14 @@ public class JpaUserFederatedStorageProvider implements
TypedQuery<String> query = em.createNamedQuery("fedgroupMembership", String.class)
.setParameter("realmId", realm.getId())
.setParameter("groupId", group.getId());
query.setFirstResult(firstResult);
query.setMaxResults(max);
if (firstResult != -1) {
query.setFirstResult(firstResult);
}
if (max != -1) {
query.setMaxResults(max);
}
return query.getResultList();
}

View file

@ -25,14 +25,16 @@ import org.keycloak.models.UserModel;
import org.keycloak.models.cache.CachedUserModel;
import org.keycloak.models.cache.OnUserCache;
import org.keycloak.models.cache.UserCache;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.storage.AbstractStorageManager;
import org.keycloak.storage.StorageId;
import org.keycloak.storage.UserStorageManager;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.UserStorageProviderFactory;
import org.keycloak.storage.UserStorageProviderModel;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
@ -44,11 +46,10 @@ import java.util.stream.Stream;
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class UserCredentialStoreManager implements UserCredentialManager, OnUserCache {
protected KeycloakSession session;
public class UserCredentialStoreManager extends AbstractStorageManager<UserStorageProvider, UserStorageProviderModel> implements UserCredentialManager, OnUserCache {
public UserCredentialStoreManager(KeycloakSession session) {
this.session = session;
super(session, UserStorageProviderFactory.class, UserStorageProvider.class, UserStorageProviderModel::new, "user");
}
protected UserCredentialStore getStoreForUser(UserModel user) {
@ -145,29 +146,16 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
if (!isValid(user)) {
return false;
}
List<CredentialInput> toValidate = new LinkedList<>();
toValidate.addAll(inputs);
if (!StorageId.isLocalStorage(user)) {
String providerId = StorageId.resolveProviderId(user);
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, providerId);
if (provider instanceof CredentialInputValidator) {
if (!UserStorageManager.isStorageProviderEnabled(realm, providerId)) return false;
Iterator<CredentialInput> it = toValidate.iterator();
while (it.hasNext()) {
CredentialInput input = it.next();
CredentialInputValidator validator = (CredentialInputValidator) provider;
if (validator.supportsCredentialType(input.getType()) && validator.isValid(realm, user, input)) {
it.remove();
}
}
}
} else {
if (user.getFederationLink() != null) {
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, user.getFederationLink());
if (provider instanceof CredentialInputValidator) {
if (!UserStorageManager.isStorageProviderEnabled(realm, user.getFederationLink())) return false;
validate(realm, user, toValidate, ((CredentialInputValidator)provider));
}
List<CredentialInput> toValidate = new LinkedList<>(inputs);
String providerId = StorageId.isLocalStorage(user) ? user.getFederationLink() : StorageId.resolveProviderId(user);
if (providerId != null) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
if (model == null || !model.isEnabled()) return false;
CredentialInputValidator validator = getStorageProviderInstance(model, CredentialInputValidator.class);
if (validator != null) {
validate(realm, user, toValidate, validator);
}
}
@ -180,13 +168,7 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
}
private void validate(RealmModel realm, UserModel user, List<CredentialInput> toValidate, CredentialInputValidator validator) {
Iterator<CredentialInput> it = toValidate.iterator();
while (it.hasNext()) {
CredentialInput input = it.next();
if (validator.supportsCredentialType(input.getType()) && validator.isValid(realm, user, input)) {
it.remove();
}
}
toValidate.removeIf(input -> validator.supportsCredentialType(input.getType()) && validator.isValid(realm, user, input));
}
public static <T> Stream<T> getCredentialProviders(KeycloakSession session, Class<T> type) {
@ -198,25 +180,16 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
@Override
public boolean updateCredential(RealmModel realm, UserModel user, CredentialInput input) {
if (!StorageId.isLocalStorage(user)) {
String providerId = StorageId.resolveProviderId(user);
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, providerId);
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, providerId)) return false;
CredentialInputUpdater updater = (CredentialInputUpdater) provider;
if (updater.supportsCredentialType(input.getType())) {
if (updater.updateCredential(realm, user, input)) return true;
}
String providerId = StorageId.isLocalStorage(user) ? user.getFederationLink() : StorageId.resolveProviderId(user);
if (!StorageId.isLocalStorage(user)) throwExceptionIfInvalidUser(user);
}
} else {
throwExceptionIfInvalidUser(user);
if (user.getFederationLink() != null) {
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, user.getFederationLink());
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, user.getFederationLink())) return false;
if (((CredentialInputUpdater) provider).updateCredential(realm, user, input)) return true;
}
if (providerId != null) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
if (model == null || !model.isEnabled()) return false;
CredentialInputUpdater updater = getStorageProviderInstance(model, CredentialInputUpdater.class);
if (updater != null && updater.supportsCredentialType(input.getType())) {
if (updater.updateCredential(realm, user, input)) return true;
}
}
@ -227,27 +200,16 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
@Override
public void disableCredentialType(RealmModel realm, UserModel user, String credentialType) {
if (!StorageId.isLocalStorage(user)) {
String providerId = StorageId.resolveProviderId(user);
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, providerId);
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, providerId)) return;
CredentialInputUpdater updater = (CredentialInputUpdater) provider;
if (updater.supportsCredentialType(credentialType)) {
updater.disableCredentialType(realm, user, credentialType);
}
String providerId = StorageId.isLocalStorage(user) ? user.getFederationLink() : StorageId.resolveProviderId(user);
if (!StorageId.isLocalStorage(user)) throwExceptionIfInvalidUser(user);
if (providerId != null) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
if (model == null || !model.isEnabled()) return;
CredentialInputUpdater updater = getStorageProviderInstance(model, CredentialInputUpdater.class);
if (updater.supportsCredentialType(credentialType)) {
updater.disableCredentialType(realm, user, credentialType);
}
} else {
throwExceptionIfInvalidUser(user);
if (user.getFederationLink() != null) {
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, user.getFederationLink());
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, user.getFederationLink())) return;
((CredentialInputUpdater) provider).disableCredentialType(realm, user, credentialType);
}
}
}
getCredentialProviders(session, CredentialInputUpdater.class)
@ -258,23 +220,13 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
@Override
public Set<String> getDisableableCredentialTypes(RealmModel realm, UserModel user) {
Set<String> types = new HashSet<>();
if (!StorageId.isLocalStorage(user)) {
String providerId = StorageId.resolveProviderId(user);
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, providerId);
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, providerId)) return Collections.EMPTY_SET;
CredentialInputUpdater updater = (CredentialInputUpdater) provider;
types.addAll(updater.getDisableableCredentialTypes(realm, user));
}
} else {
if (user.getFederationLink() != null) {
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, user.getFederationLink());
if (provider instanceof CredentialInputUpdater) {
if (!UserStorageManager.isStorageProviderEnabled(realm, user.getFederationLink())) return Collections.EMPTY_SET;
types.addAll(((CredentialInputUpdater) provider).getDisableableCredentialTypes(realm, user));
}
}
String providerId = StorageId.isLocalStorage(user) ? user.getFederationLink() : StorageId.resolveProviderId(user);
if (providerId != null) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
if (model == null || !model.isEnabled()) return Collections.EMPTY_SET;
CredentialInputUpdater updater = getStorageProviderInstance(model, CredentialInputUpdater.class);
if (updater != null) types.addAll(updater.getDisableableCredentialTypes(realm, user));
}
types.addAll(getCredentialProviders(session, CredentialInputUpdater.class)
@ -308,25 +260,15 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
private UserStorageCredentialConfigured isConfiguredThroughUserStorage(RealmModel realm, UserModel user, String type) {
if (!StorageId.isLocalStorage(user)) {
String providerId = StorageId.resolveProviderId(user);
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, providerId);
if (provider instanceof CredentialInputValidator) {
if (!UserStorageManager.isStorageProviderEnabled(realm, providerId)) return UserStorageCredentialConfigured.USER_STORAGE_DISABLED;
CredentialInputValidator validator = (CredentialInputValidator) provider;
if (validator.supportsCredentialType(type) && validator.isConfiguredFor(realm, user, type)) {
return UserStorageCredentialConfigured.CONFIGURED;
}
}
} else {
if (user.getFederationLink() != null) {
UserStorageProvider provider = UserStorageManager.getStorageProvider(session, realm, user.getFederationLink());
if (provider instanceof CredentialInputValidator) {
if (!UserStorageManager.isStorageProviderEnabled(realm, user.getFederationLink())) return UserStorageCredentialConfigured.USER_STORAGE_DISABLED;
if (((CredentialInputValidator) provider).isConfiguredFor(realm, user, type)) return UserStorageCredentialConfigured.CONFIGURED;
}
}
String providerId = StorageId.isLocalStorage(user) ? user.getFederationLink() : StorageId.resolveProviderId(user);
if (providerId != null) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
if (model == null || !model.isEnabled()) return UserStorageCredentialConfigured.USER_STORAGE_DISABLED;
CredentialInputValidator validator = getStorageProviderInstance(model, CredentialInputValidator.class);
if (validator.supportsCredentialType(type) && validator.isConfiguredFor(realm, user, type)) {
return UserStorageCredentialConfigured.CONFIGURED;
}
}
return UserStorageCredentialConfigured.NOT_CONFIGURED;
@ -340,22 +282,14 @@ public class UserCredentialStoreManager implements UserCredentialManager, OnUser
@Override
public CredentialValidationOutput authenticate(KeycloakSession session, RealmModel realm, CredentialInput input) {
CredentialValidationOutput output = authenticate(
UserStorageManager.getEnabledStorageProviders(session, realm, CredentialAuthentication.class),
realm, input);
Stream<CredentialAuthentication> credentialAuthenticationStream = getEnabledStorageProviders(realm, CredentialAuthentication.class);
credentialAuthenticationStream = Stream.concat(credentialAuthenticationStream,
getCredentialProviders(session, CredentialAuthentication.class));
return (output != null) ? output : authenticate(getCredentialProviders(session, CredentialAuthentication.class),
realm, input);
}
public CredentialValidationOutput authenticate(Stream<CredentialAuthentication> storageProviders,
RealmModel realm, CredentialInput input) {
return storageProviders
.filter(auth -> auth.supportsCredentialAuthenticationFor(input.getType()))
.map(auth -> auth.authenticate(realm, input))
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
return credentialAuthenticationStream
.filter(credentialAuthentication -> credentialAuthentication.supportsCredentialAuthenticationFor(input.getType()))
.map(credentialAuthentication -> credentialAuthentication.authenticate(realm, input))
.findFirst().orElse(null);
}
@Override

View file

@ -16,6 +16,7 @@
*/
package org.keycloak.storage;
import org.jboss.logging.Logger;
import org.keycloak.Config;
import org.keycloak.common.util.reflections.Types;
import org.keycloak.component.ComponentFactory;
@ -27,6 +28,7 @@ import org.keycloak.provider.ProviderFactory;
import org.keycloak.utils.ServicesUtils;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
@ -41,6 +43,8 @@ import java.util.stream.Stream;
public abstract class AbstractStorageManager<ProviderType extends Provider,
StorageProviderModelType extends CacheableStorageProviderModel> {
private static final Logger LOG = Logger.getLogger(AbstractStorageManager.class);
/**
* Timeouts are used as time boundary for obtaining models from an external storage. Default value is set
* to 3000 milliseconds and it's configurable.
@ -91,7 +95,7 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
.map(toStorageProviderModelTypeFunction)
.filter(StorageProviderModelType::isEnabled)
.sorted(StorageProviderModelType.comparator)
.map(storageProviderModelType -> getStorageProviderInstance(storageProviderModelType, capabilityInterface))
.map(storageProviderModelType -> getStorageProviderInstance(storageProviderModelType, capabilityInterface, false))
.filter(Objects::nonNull);
}
@ -108,9 +112,48 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
* @param <R> result of applyFunction
* @return a stream with all results from all StorageProviders
*/
protected <R, T> Stream<R> applyOnEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, ? extends Stream<R>> applyFunction) {
return getEnabledStorageProviders(realm, capabilityInterface).flatMap(ServicesUtils.timeBound(session,
getStorageProviderTimeout(), applyFunction));
protected <R, T> Stream<R> flatMapEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, ? extends Stream<R>> applyFunction) {
return getEnabledStorageProviders(realm, capabilityInterface)
.flatMap(ServicesUtils.timeBound(session, getStorageProviderTimeout(), applyFunction));
}
/**
* Gets all enabled StorageProviders that implements the capabilityInterface, applies applyFunction on each of
* them and returns the stream.
*
* !! Each StorageProvider has a limited time to respond, if it fails to do it, null is returned !!
*
* @param realm realm
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @param applyFunction function that is applied on StorageProviders
* @param <R> Result of applyFunction
* @return First result from StorageProviders
*/
protected <R, T> Stream<R> mapEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, R> applyFunction) {
return getEnabledStorageProviders(realm, capabilityInterface)
.map(ServicesUtils.timeBoundOne(session, getStorageProviderTimeout(), applyFunction))
.filter(Objects::nonNull);
}
/**
* Gets all enabled StorageProviders that implements the capabilityInterface and call applyFunction on each
*
* !! Each StorageProvider has a limited time for consuming !!
*
* @param realm realm
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @param consumer function that is applied on StorageProviders
*/
protected <T> void consumeEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Consumer<T> consumer) {
getEnabledStorageProviders(realm, capabilityInterface)
.forEachOrdered(ServicesUtils.consumeWithTimeBound(session, getStorageProviderTimeout(), consumer));
}
protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface) {
return getStorageProviderInstance(realm, providerId, capabilityInterface, false);
}
/**
@ -123,24 +166,50 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @return an instance of type CreatedProviderType or null if storage provider with providerId doesn't implement capabilityInterface
*/
protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface) {
protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface, boolean includeDisabled) {
if (providerId == null || capabilityInterface == null) return null;
return getStorageProviderInstance(getStorageProviderModel(realm, providerId), capabilityInterface, includeDisabled);
}
/**
* Returns an instance of StorageProvider model corresponding realm and providerId
* @param realm Realm.
* @param providerId Id of desired provider.
* @return An instance of type StorageProviderModelType
*/
protected StorageProviderModelType getStorageProviderModel(RealmModel realm, String providerId) {
ComponentModel componentModel = realm.getComponent(providerId);
if (componentModel == null) {
return null;
}
return getStorageProviderInstance(toStorageProviderModelTypeFunction.apply(componentModel), capabilityInterface);
return toStorageProviderModelTypeFunction.apply(componentModel);
}
/**
* Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
*
* @param model StorageProviderModel obtained from database/storage
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @return an instance of type CreatedProviderType or null if storage provider based on the model doesn't implement capabilityInterface.
* @param <T> Required capability interface type
* @return an instance of type T or null if storage provider based on the model doesn't exist or doesn't implement the capabilityInterface.
*/
protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface) {
if (model == null || !model.isEnabled() || capabilityInterface == null) {
return getStorageProviderInstance(model, capabilityInterface, false);
}
/**
* Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
*
* @param model StorageProviderModel obtained from database/storage
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @param includeDisabled If set to true, the method will return also disabled providers.
* @return an instance of type T or null if storage provider based on the model doesn't exist or doesn't implement the capabilityInterface.
*/
protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface, boolean includeDisabled) {
if (model == null || (!model.isEnabled() && !includeDisabled) || capabilityInterface == null) {
return null;
}
@ -149,7 +218,10 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
if (instance != null && capabilityInterface.isAssignableFrom(instance.getClass())) return capabilityInterface.cast(instance);
ComponentFactory<? extends ProviderType, ProviderType> factory = getStorageProviderFactory(model.getProviderId());
if (factory == null) {
LOG.warnv("Configured StorageProvider {0} of provider id {1} does not exist", model.getName(), model.getProviderId());
return null;
}
if (!Types.supports(capabilityInterface, factory, factoryTypeClass)) {
return null;
}

View file

@ -61,7 +61,7 @@ public class GroupStorageManager extends AbstractStorageManager<GroupStorageProv
@Override
public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
Stream<GroupModel> local = session.groupLocalStorage().searchForGroupByNameStream(realm, search, firstResult, maxResults);
Stream<GroupModel> ext = applyOnEnabledStorageProvidersWithTimeout(realm, GroupLookupProvider.class,
Stream<GroupModel> ext = flatMapEnabledStorageProvidersWithTimeout(realm, GroupLookupProvider.class,
p -> p.searchForGroupByNameStream(realm, search, firstResult, maxResults));
return Stream.concat(local, ext);

View file

@ -24,15 +24,14 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.utils.ModelToRepresentation;
import org.keycloak.representations.idm.GroupRepresentation;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import static org.keycloak.common.util.StackUtil.getShortStackTrace;
/**
* Utility class for general helper methods used across the keycloak-services.
*/
@ -45,14 +44,90 @@ public class ServicesUtils {
Function<T, ? extends Stream<R>> func) {
ExecutorService executor = session.getProvider(ExecutorsProvider.class).getExecutor("storage-provider-threads");
return p -> {
Callable<? extends Stream<R>> c = () -> func.apply(p);
Future<? extends Stream<R>> future = executor.submit(c);
// We are running another thread here, which serves as a time checking thread. When timeout is hit, the time
// checking thread will send interrupted flag to main thread, which can cause interruption of func execution.
// To support interruption func implementation should react to interrupt flag.
// If func doesn't check the interrupted flag, the execution won't be interrupted and can take more time
// than the threshold given by timeout variable
Future<?> timeCheckingThread = executor.submit(timeWarningRunnable(timeout, Thread.currentThread()));
try {
return future.get(timeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
future.cancel(true);
logger.debug("Function failed to return on time.", e);
return Stream.empty();
// We cannot run func in different than main thread, because main thread have, for example, EntityManager
// transaction context. If we run any operation on EntityManager in a different thread, it will fail
// with a transaction doesn't exist error
return func.apply(p);
} finally {
timeCheckingThread.cancel(true);
if (Thread.interrupted()) {
logger.warnf("Execution with object [%s] exceeded specified time limit %d. %s", p, timeout, getShortStackTrace());
}
}
};
}
public static <T, R> Function<? super T, R> timeBoundOne(KeycloakSession session,
long timeout,
Function<T, R> func) {
ExecutorService executor = session.getProvider(ExecutorsProvider.class).getExecutor("storage-provider-threads");
return p -> {
// We are running another thread here, which serves as a time checking thread. When timeout is hit, the time
// checking thread will send interrupted flag to main thread, which can cause interruption of func execution.
// To support interruption func implementation should react to interrupt flag.
// If func doesn't check the interrupted flag, the execution won't be interrupted and can take more time
// than the threshold given by timeout variable
Future<?> warningThreadFuture = executor.submit(timeWarningRunnable(timeout, Thread.currentThread()));
try {
// We cannot run func in different than main thread, because main thread have, for example, EntityManager
// transaction context. If we run any operation on EntityManager in a different thread, it will fail
// with a transaction doesn't exist error
return func.apply(p);
} finally {
warningThreadFuture.cancel(true);
if (Thread.interrupted()) {
logger.warnf("Execution with object [%s] exceeded specified time limit %d. %s", p, timeout, getShortStackTrace());
}
}
};
}
public static <T> Consumer<? super T> consumeWithTimeBound(KeycloakSession session,
long timeout,
Consumer<T> func) {
ExecutorService executor = session.getProvider(ExecutorsProvider.class).getExecutor("storage-provider-threads");
return p -> {
// We are running another thread here, which serves as a time checking thread. When timeout is hit, the time
// checking thread will send interrupted flag to main thread, which can cause interruption of func execution.
// To support interruption func implementation should react to interrupt flag.
// If func doesn't check the interrupted flag, the execution won't be interrupted and can take more time
// than the threshold given by timeout variable
Future<?> warningThreadFuture = executor.submit(timeWarningRunnable(timeout, Thread.currentThread()));
try {
// We cannot run func in different than main thread, because main thread have, for example, EntityManager
// transaction context. If we run any operation on EntityManager in a different thread, it will fail
// with a transaction doesn't exist error
func.accept(p);
} finally {
warningThreadFuture.cancel(true);
if (Thread.interrupted()) {
logger.warnf("Execution with object [%s] exceeded specified time limit %d. %s", p, timeout, getShortStackTrace());
}
}
};
}
private static Runnable timeWarningRunnable(long timeout, Thread mainThread) {
return new Runnable() {
@Override
public void run() {
try {
Thread.sleep(timeout);
} catch (InterruptedException exception) {
return; // Do not interrupt if warning thread was interrupted (== main thread finished execution in time)
}
mainThread.interrupt();
}
};
}

View file

@ -84,6 +84,7 @@ public class HardcodedClientStorageProvider implements ClientStorageProvider, Cl
Thread.sleep(5000l);
} catch (InterruptedException ex) {
Logger.getLogger(HardcodedClientStorageProvider.class).warn(ex.getCause());
return Stream.empty();
}
if (clientId != null && this.clientId.toLowerCase().contains(clientId.toLowerCase())) {
return Stream.of(new ClientAdapter(realm));

View file

@ -58,6 +58,7 @@ public class HardcodedGroupStorageProvider implements GroupStorageProvider {
Thread.sleep(5000l);
} catch (InterruptedException ex) {
Logger.getLogger(HardcodedGroupStorageProvider.class).warn(ex.getCause());
return Stream.empty();
}
if (search != null && this.groupName.toLowerCase().contains(search.toLowerCase())) {
return Stream.of(new HardcodedGroupAdapter(realm));

View file

@ -64,6 +64,7 @@ public class HardcodedRoleStorageProvider implements RoleStorageProvider {
Thread.sleep(5000l);
} catch (InterruptedException ex) {
Logger.getLogger(HardcodedClientStorageProvider.class).warn(ex.getCause());
return Stream.empty();
}
if (search != null && this.roleName.toLowerCase().contains(search.toLowerCase())) {
return Stream.of(new HardcodedRoleAdapter(realm));

View file

@ -156,7 +156,7 @@ public class UserPropertyFileStorage implements UserLookupProvider, UserStorageP
@Override
public List<UserModel> searchForUser(Map<String, String> attributes, RealmModel realm) {
return Collections.EMPTY_LIST;
return searchForUser(attributes, realm, 0, Integer.MAX_VALUE - 1);
}
@Override
@ -201,7 +201,7 @@ public class UserPropertyFileStorage implements UserLookupProvider, UserStorageP
@Override
public List<UserModel> searchForUser(String search, RealmModel realm) {
return getUsers(realm, 0, Integer.MAX_VALUE - 1);
return searchForUser(search, realm, 0, Integer.MAX_VALUE - 1);
}
@Override

View file

@ -10,7 +10,6 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.LDAPConstants;
import org.keycloak.models.RealmModel;
import org.keycloak.representations.idm.ComponentRepresentation;
import org.keycloak.storage.UserStorageManager;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.testsuite.admin.concurrency.AbstractConcurrencyTest;
import org.keycloak.testsuite.federation.UserMapStorage;
@ -45,7 +44,7 @@ public abstract class AbstractUserStorageDirtyDeletionTest extends AbstractConcu
public static void remove20UsersFromStorageProvider(KeycloakSession session) {
assertThat(REMOVED_USERS_COUNT, Matchers.lessThan(NUM_USERS));
final RealmModel realm = session.realms().getRealm(TEST_REALM_NAME);
UserStorageManager.getEnabledStorageProviders(session, realm, UserMapStorage.class)
UserStorageProvidersTestUtils.getEnabledStorageProviders(session, realm, UserMapStorage.class)
.forEachOrdered((UserMapStorage userMapStorage) -> {
Set<String> users = new HashSet<>(userMapStorage.getUsernames());
users.stream()

View file

@ -0,0 +1,88 @@
package org.keycloak.testsuite.federation.storage;
import org.jboss.logging.Logger;
import org.keycloak.common.util.reflections.Types;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ModelException;
import org.keycloak.models.RealmModel;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.UserStorageProviderFactory;
import org.keycloak.storage.UserStorageProviderModel;
import java.util.stream.Stream;
public class UserStorageProvidersTestUtils {
private static final Logger logger = Logger.getLogger(UserStorageProvidersTestUtils.class);
public static boolean isStorageProviderEnabled(RealmModel realm, String providerId) {
UserStorageProviderModel model = getStorageProviderModel(realm, providerId);
return model.isEnabled();
}
public static Stream<UserStorageProviderModel> getStorageProviders(RealmModel realm) {
return realm.getUserStorageProvidersStream();
}
private static UserStorageProviderFactory getUserStorageProviderFactory(UserStorageProviderModel model, KeycloakSession session) {
return (UserStorageProviderFactory) session.getKeycloakSessionFactory()
.getProviderFactory(UserStorageProvider.class, model.getProviderId());
}
public static <T> Stream<T> getEnabledStorageProviders(KeycloakSession session, RealmModel realm, Class<T> type) {
return getStorageProviders(realm, session, type)
.filter(UserStorageProviderModel::isEnabled)
.map(model -> type.cast(getStorageProviderInstance(session, model, getUserStorageProviderFactory(model, session))));
}
public static UserStorageProvider getStorageProviderInstance(KeycloakSession session, UserStorageProviderModel model, UserStorageProviderFactory factory) {
UserStorageProvider instance = (UserStorageProvider)session.getAttribute(model.getId());
if (instance != null) return instance;
instance = factory.create(session, model);
if (instance == null) {
throw new IllegalStateException("UserStorageProvideFactory (of type " + factory.getClass().getName() + ") produced a null instance");
}
session.enlistForClose(instance);
session.setAttribute(model.getId(), instance);
return instance;
}
public static <T> Stream<UserStorageProviderModel> getStorageProviders(RealmModel realm, KeycloakSession session, Class<T> type) {
return realm.getUserStorageProvidersStream()
.filter(model -> {
UserStorageProviderFactory factory = getUserStorageProviderFactory(model, session);
if (factory == null) {
logger.warnv("Configured UserStorageProvider {0} of provider id {1} does not exist in realm {2}",
model.getName(), model.getProviderId(), realm.getName());
return false;
} else {
return Types.supports(type, factory, UserStorageProviderFactory.class);
}
});
}
public static UserStorageProvider getStorageProvider(KeycloakSession session, RealmModel realm, String componentId) {
ComponentModel model = realm.getComponent(componentId);
if (model == null) return null;
UserStorageProviderModel storageModel = new UserStorageProviderModel(model);
UserStorageProviderFactory factory = (UserStorageProviderFactory)session.getKeycloakSessionFactory().getProviderFactory(UserStorageProvider.class, model.getProviderId());
if (factory == null) {
throw new ModelException("Could not find UserStorageProviderFactory for: " + model.getProviderId());
}
return getStorageProviderInstance(session, storageModel, factory);
}
public static <T> Stream<T> getStorageProviders(KeycloakSession session, RealmModel realm, Class<T> type) {
return getStorageProviders(realm, session, type)
.map(model -> type.cast(getStorageProviderInstance(session, model, getUserStorageProviderFactory(model, session))));
}
public static UserStorageProviderModel getStorageProviderModel(RealmModel realm, String componentId) {
ComponentModel model = realm.getComponent(componentId);
if (model == null) return null;
return new UserStorageProviderModel(model);
}
}

View file

@ -66,8 +66,11 @@ import java.util.stream.Collectors;
import static java.util.Calendar.DAY_OF_WEEK;
import static java.util.Calendar.HOUR_OF_DAY;
import static java.util.Calendar.MINUTE;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.keycloak.models.UserModel.RequiredAction.UPDATE_PROFILE;
@ -422,8 +425,8 @@ public class UserStorageTest extends AbstractAuthTest {
// test searchForUser
List<UserRepresentation> users = testRealmResource().users().search("tbrady", 0, Integer.MAX_VALUE);
Assert.assertTrue(users.size() == 1);
Assert.assertTrue(users.get(0).getUsername().equals("tbrady"));
assertThat(users, hasSize(1));
assertThat(users.get(0).getUsername(), equalTo("tbrady"));
// test getGroupMembers()
GroupRepresentation g = new GroupRepresentation();
@ -478,9 +481,9 @@ public class UserStorageTest extends AbstractAuthTest {
@Test
public void testQueryExactMatch() {
Assert.assertThat(testRealmResource().users().search("a", true), Matchers.hasSize(0));
Assert.assertThat(testRealmResource().users().search("apollo", true), Matchers.hasSize(1));
Assert.assertThat(testRealmResource().users().search("tbrady", true), Matchers.hasSize(1));
Assert.assertThat(testRealmResource().users().search("a", true), hasSize(0));
Assert.assertThat(testRealmResource().users().search("apollo", true), hasSize(1));
Assert.assertThat(testRealmResource().users().search("tbrady", true), hasSize(1));
}
private void setDailyEvictionTime(int hour, int minutes) {