KEYCLOAK-15514 Update AbstractStorageManager to check capability interface types
This commit is contained in:
parent
cb5c893d87
commit
3186f1b5a9
2 changed files with 45 additions and 20 deletions
|
@ -17,13 +17,16 @@
|
|||
package org.keycloak.storage;
|
||||
|
||||
import org.keycloak.Config;
|
||||
import org.keycloak.common.util.reflections.Types;
|
||||
import org.keycloak.component.ComponentFactory;
|
||||
import org.keycloak.component.ComponentModel;
|
||||
import org.keycloak.models.KeycloakSession;
|
||||
import org.keycloak.models.RealmModel;
|
||||
import org.keycloak.provider.Provider;
|
||||
import org.keycloak.provider.ProviderFactory;
|
||||
import org.keycloak.utils.ServicesUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -45,13 +48,15 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
|
|||
private static final Long STORAGE_PROVIDER_DEFAULT_TIMEOUT = 3000L;
|
||||
protected final KeycloakSession session;
|
||||
private final Class<ProviderType> providerTypeClass;
|
||||
private final Class<? extends ProviderFactory> factoryTypeClass;
|
||||
private final Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction;
|
||||
private final String configScope;
|
||||
private Long storageProviderTimeout;
|
||||
|
||||
public AbstractStorageManager(KeycloakSession session, Class<ProviderType> providerTypeClass, Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction, String configScope) {
|
||||
public AbstractStorageManager(KeycloakSession session, Class<? extends ProviderFactory> factoryTypeClass, Class<ProviderType> providerTypeClass, Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction, String configScope) {
|
||||
this.session = session;
|
||||
this.providerTypeClass = providerTypeClass;
|
||||
this.factoryTypeClass = factoryTypeClass;
|
||||
this.toStorageProviderModelTypeFunction = toStorageProviderModelTypeFunction;
|
||||
this.configScope = configScope;
|
||||
}
|
||||
|
@ -74,70 +79,88 @@ public abstract class AbstractStorageManager<ProviderType extends Provider,
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns stream of all storageProviders within the realm that implements the capabilityInterface.
|
||||
*
|
||||
* @param realm realm
|
||||
* @param capabilityInterface class of desired capabilityInterface.
|
||||
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
|
||||
* @return enabled storage providers for realm and @{code getProviderTypeClass()}
|
||||
*/
|
||||
protected Stream<ProviderType> getEnabledStorageProviders(RealmModel realm) {
|
||||
protected <T> Stream<T> getEnabledStorageProviders(RealmModel realm, Class<T> capabilityInterface) {
|
||||
return getStorageProviderModels(realm, providerTypeClass)
|
||||
.map(toStorageProviderModelTypeFunction)
|
||||
.filter(StorageProviderModelType::isEnabled)
|
||||
.sorted(StorageProviderModelType.comparator)
|
||||
.map(this::getStorageProviderInstance);
|
||||
.map(storageProviderModelType -> getStorageProviderInstance(storageProviderModelType, capabilityInterface))
|
||||
.filter(Objects::nonNull);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all enabled StorageProviders, applies applyFunction on each of them and then join the results together.
|
||||
* Gets all enabled StorageProviders that implements the capabilityInterface, applies applyFunction on each of
|
||||
* them and then join the results together.
|
||||
*
|
||||
* !! Each StorageProvider has a limited time to respond, if it fails to do it, empty stream is returned !!
|
||||
*
|
||||
* @param realm realm
|
||||
* @param capabilityInterface class of desired capabilityInterface.
|
||||
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
|
||||
* @param applyFunction function that is applied on StorageProviders
|
||||
* @param <R> result of applyFunction
|
||||
* @return a stream with all results from all StorageProviders
|
||||
*/
|
||||
protected <R> Stream<R> applyOnEnabledStorageProvidersWithTimeout(RealmModel realm, Function<ProviderType, ? extends Stream<R>> applyFunction) {
|
||||
return getEnabledStorageProviders(realm).flatMap(ServicesUtils.timeBound(session,
|
||||
protected <R, T> Stream<R> applyOnEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, ? extends Stream<R>> applyFunction) {
|
||||
return getEnabledStorageProviders(realm, capabilityInterface).flatMap(ServicesUtils.timeBound(session,
|
||||
getStorageProviderTimeout(), applyFunction));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an instance of provider with the providerId within the realm.
|
||||
* Returns an instance of provider with the providerId within the realm or null if storage provider with providerId
|
||||
* doesn't implement capabilityInterface.
|
||||
*
|
||||
* @param realm realm
|
||||
* @param providerId id of ComponentModel within database/storage
|
||||
* @return an instance of type CreatedProviderType
|
||||
* @param capabilityInterface class of desired capabilityInterface.
|
||||
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
|
||||
* @return an instance of type CreatedProviderType or null if storage provider with providerId doesn't implement capabilityInterface
|
||||
*/
|
||||
protected ProviderType getStorageProviderInstance(RealmModel realm, String providerId) {
|
||||
protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface) {
|
||||
ComponentModel componentModel = realm.getComponent(providerId);
|
||||
if (componentModel == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return getStorageProviderInstance(toStorageProviderModelTypeFunction.apply(componentModel));
|
||||
return getStorageProviderInstance(toStorageProviderModelTypeFunction.apply(componentModel), capabilityInterface);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an instance of provider for the model
|
||||
* Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
|
||||
* @param model StorageProviderModel obtained from database/storage
|
||||
* @return an instance of type CreatedProviderType
|
||||
* @param capabilityInterface class of desired capabilityInterface.
|
||||
* For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
|
||||
* @return an instance of type CreatedProviderType or null if storage provider based on the model doesn't implement capabilityInterface.
|
||||
*/
|
||||
protected ProviderType getStorageProviderInstance(StorageProviderModelType model) {
|
||||
if (model == null || !model.isEnabled()) {
|
||||
protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface) {
|
||||
if (model == null || !model.isEnabled() || capabilityInterface == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ProviderType instance = (ProviderType) session.getAttribute(model.getId());
|
||||
if (instance != null) return instance;
|
||||
if (instance != null && capabilityInterface.isAssignableFrom(instance.getClass())) return capabilityInterface.cast(instance);
|
||||
|
||||
ComponentFactory<? extends ProviderType, ProviderType> factory = getStorageProviderFactory(model.getProviderId());
|
||||
|
||||
if (!Types.supports(capabilityInterface, factory, factoryTypeClass)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
instance = factory.create(session, model);
|
||||
if (instance == null) {
|
||||
throw new IllegalStateException("StorageProvideFactory (of type " + factory.getClass().getName() + ") produced a null instance");
|
||||
}
|
||||
session.enlistForClose(instance);
|
||||
session.setAttribute(model.getId(), instance);
|
||||
return instance;
|
||||
return capabilityInterface.cast(instance);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.keycloak.models.RealmModel;
|
|||
import org.keycloak.models.RoleModel;
|
||||
import org.keycloak.storage.group.GroupLookupProvider;
|
||||
import org.keycloak.storage.group.GroupStorageProvider;
|
||||
import org.keycloak.storage.group.GroupStorageProviderFactory;
|
||||
import org.keycloak.storage.group.GroupStorageProviderModel;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
@ -30,7 +31,8 @@ import java.util.stream.Stream;
|
|||
public class GroupStorageManager extends AbstractStorageManager<GroupStorageProvider, GroupStorageProviderModel> implements GroupProvider {
|
||||
|
||||
public GroupStorageManager(KeycloakSession session) {
|
||||
super(session, GroupStorageProvider.class, GroupStorageProviderModel::new, "group");
|
||||
super(session, GroupStorageProviderFactory.class, GroupStorageProvider.class,
|
||||
GroupStorageProviderModel::new, "group");
|
||||
}
|
||||
|
||||
/* GROUP PROVIDER LOOKUP METHODS - implemented by group storage providers */
|
||||
|
@ -42,7 +44,7 @@ public class GroupStorageManager extends AbstractStorageManager<GroupStorageProv
|
|||
return session.groupLocalStorage().getGroupById(realm, id);
|
||||
}
|
||||
|
||||
GroupLookupProvider provider = getStorageProviderInstance(realm, storageId.getProviderId());
|
||||
GroupLookupProvider provider = getStorageProviderInstance(realm, storageId.getProviderId(), GroupLookupProvider.class);
|
||||
if (provider == null) return null;
|
||||
|
||||
return provider.getGroupById(realm, id);
|
||||
|
@ -59,8 +61,8 @@ public class GroupStorageManager extends AbstractStorageManager<GroupStorageProv
|
|||
@Override
|
||||
public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
|
||||
Stream<GroupModel> local = session.groupLocalStorage().searchForGroupByNameStream(realm, search, firstResult, maxResults);
|
||||
Stream<GroupModel> ext = applyOnEnabledStorageProvidersWithTimeout(realm,
|
||||
p -> ((GroupLookupProvider) p).searchForGroupByNameStream(realm, search, firstResult, maxResults));
|
||||
Stream<GroupModel> ext = applyOnEnabledStorageProvidersWithTimeout(realm, GroupLookupProvider.class,
|
||||
p -> p.searchForGroupByNameStream(realm, search, firstResult, maxResults));
|
||||
|
||||
return Stream.concat(local, ext);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue