KEYCLOAK-15514 Update AbstractStorageManager to check capability interface types

This commit is contained in:
mhajas 2020-09-09 13:45:16 +02:00 committed by Hynek Mlnařík
parent cb5c893d87
commit 3186f1b5a9
2 changed files with 45 additions and 20 deletions

View file

@ -17,13 +17,16 @@
package org.keycloak.storage;
import org.keycloak.Config;
import org.keycloak.common.util.reflections.Types;
import org.keycloak.component.ComponentFactory;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.utils.ServicesUtils;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;
@ -45,13 +48,15 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
private static final Long STORAGE_PROVIDER_DEFAULT_TIMEOUT = 3000L;
protected final KeycloakSession session;
private final Class<ProviderType> providerTypeClass;
private final Class<? extends ProviderFactory> factoryTypeClass;
private final Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction;
private final String configScope;
private Long storageProviderTimeout;
public AbstractStorageManager(KeycloakSession session, Class<ProviderType> providerTypeClass, Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction, String configScope) {
public AbstractStorageManager(KeycloakSession session, Class<? extends ProviderFactory> factoryTypeClass, Class<ProviderType> providerTypeClass, Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction, String configScope) {
this.session = session;
this.providerTypeClass = providerTypeClass;
this.factoryTypeClass = factoryTypeClass;
this.toStorageProviderModelTypeFunction = toStorageProviderModelTypeFunction;
this.configScope = configScope;
}
@ -74,70 +79,88 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
}
/**
* Returns stream of all storageProviders within the realm that implements the capabilityInterface.
*
* @param realm realm
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @return enabled storage providers for realm and @{code getProviderTypeClass()}
*/
protected Stream<ProviderType> getEnabledStorageProviders(RealmModel realm) {
protected <T> Stream<T> getEnabledStorageProviders(RealmModel realm, Class<T> capabilityInterface) {
return getStorageProviderModels(realm, providerTypeClass)
.map(toStorageProviderModelTypeFunction)
.filter(StorageProviderModelType::isEnabled)
.sorted(StorageProviderModelType.comparator)
.map(this::getStorageProviderInstance);
.map(storageProviderModelType -> getStorageProviderInstance(storageProviderModelType, capabilityInterface))
.filter(Objects::nonNull);
}
/**
* Gets all enabled StorageProviders, applies applyFunction on each of them and then join the results together.
* Gets all enabled StorageProviders that implements the capabilityInterface, applies applyFunction on each of
* them and then join the results together.
*
* !! Each StorageProvider has a limited time to respond, if it fails to do it, empty stream is returned !!
*
* @param realm realm
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @param applyFunction function that is applied on StorageProviders
* @param <R> result of applyFunction
* @return a stream with all results from all StorageProviders
*/
protected <R> Stream<R> applyOnEnabledStorageProvidersWithTimeout(RealmModel realm, Function<ProviderType, ? extends Stream<R>> applyFunction) {
return getEnabledStorageProviders(realm).flatMap(ServicesUtils.timeBound(session,
protected <R, T> Stream<R> applyOnEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, ? extends Stream<R>> applyFunction) {
return getEnabledStorageProviders(realm, capabilityInterface).flatMap(ServicesUtils.timeBound(session,
getStorageProviderTimeout(), applyFunction));
}
/**
* Returns an instance of provider with the providerId within the realm.
* Returns an instance of provider with the providerId within the realm or null if storage provider with providerId
* doesn't implement capabilityInterface.
*
* @param realm realm
* @param providerId id of ComponentModel within database/storage
* @return an instance of type CreatedProviderType
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @return an instance of type CreatedProviderType or null if storage provider with providerId doesn't implement capabilityInterface
*/
protected ProviderType getStorageProviderInstance(RealmModel realm, String providerId) {
protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface) {
ComponentModel componentModel = realm.getComponent(providerId);
if (componentModel == null) {
return null;
}
return getStorageProviderInstance(toStorageProviderModelTypeFunction.apply(componentModel));
return getStorageProviderInstance(toStorageProviderModelTypeFunction.apply(componentModel), capabilityInterface);
}
/**
* Returns an instance of provider for the model
* Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
* @param model StorageProviderModel obtained from database/storage
* @return an instance of type CreatedProviderType
* @param capabilityInterface class of desired capabilityInterface.
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
* @return an instance of type CreatedProviderType or null if storage provider based on the model doesn't implement capabilityInterface.
*/
protected ProviderType getStorageProviderInstance(StorageProviderModelType model) {
if (model == null || !model.isEnabled()) {
protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface) {
if (model == null || !model.isEnabled() || capabilityInterface == null) {
return null;
}
@SuppressWarnings("unchecked")
ProviderType instance = (ProviderType) session.getAttribute(model.getId());
if (instance != null) return instance;
if (instance != null && capabilityInterface.isAssignableFrom(instance.getClass())) return capabilityInterface.cast(instance);
ComponentFactory<? extends ProviderType, ProviderType> factory = getStorageProviderFactory(model.getProviderId());
if (!Types.supports(capabilityInterface, factory, factoryTypeClass)) {
return null;
}
instance = factory.create(session, model);
if (instance == null) {
throw new IllegalStateException("StorageProvideFactory (of type " + factory.getClass().getName() + ") produced a null instance");
}
session.enlistForClose(instance);
session.setAttribute(model.getId(), instance);
return instance;
return capabilityInterface.cast(instance);
}
/**

View file

@ -23,6 +23,7 @@ import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel;
import org.keycloak.storage.group.GroupLookupProvider;
import org.keycloak.storage.group.GroupStorageProvider;
import org.keycloak.storage.group.GroupStorageProviderFactory;
import org.keycloak.storage.group.GroupStorageProviderModel;
import java.util.stream.Stream;
@ -30,7 +31,8 @@ import java.util.stream.Stream;
public class GroupStorageManager extends AbstractStorageManager<GroupStorageProvider, GroupStorageProviderModel> implements GroupProvider {
public GroupStorageManager(KeycloakSession session) {
super(session, GroupStorageProvider.class, GroupStorageProviderModel::new, "group");
super(session, GroupStorageProviderFactory.class, GroupStorageProvider.class,
GroupStorageProviderModel::new, "group");
}
/* GROUP PROVIDER LOOKUP METHODS - implemented by group storage providers */
@ -42,7 +44,7 @@ public class GroupStorageManager extends AbstractStorageManager<GroupStorageProv
return session.groupLocalStorage().getGroupById(realm, id);
}
GroupLookupProvider provider = getStorageProviderInstance(realm, storageId.getProviderId());
GroupLookupProvider provider = getStorageProviderInstance(realm, storageId.getProviderId(), GroupLookupProvider.class);
if (provider == null) return null;
return provider.getGroupById(realm, id);
@ -59,8 +61,8 @@ public class GroupStorageManager extends AbstractStorageManager<GroupStorageProv
@Override
public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
Stream<GroupModel> local = session.groupLocalStorage().searchForGroupByNameStream(realm, search, firstResult, maxResults);
Stream<GroupModel> ext = applyOnEnabledStorageProvidersWithTimeout(realm,
p -> ((GroupLookupProvider) p).searchForGroupByNameStream(realm, search, firstResult, maxResults));
Stream<GroupModel> ext = applyOnEnabledStorageProvidersWithTimeout(realm, GroupLookupProvider.class,
p -> p.searchForGroupByNameStream(realm, search, firstResult, maxResults));
return Stream.concat(local, ext);
}