KEYCLOAK-16118 Replace MapStorage.entrySet() with search by criteria

* Add model class parameter to MapStorage
* Add shortcut read(id) method to MapKeycloakTransaction
This commit is contained in:
Hynek Mlnarik 2020-12-15 15:45:29 +01:00 committed by Hynek Mlnařík
parent 6c07679446
commit 78c05d2da2
33 changed files with 1687 additions and 608 deletions

View file

@ -31,7 +31,7 @@ public class StackUtil {
return getShortStackTrace("\n "); return getShortStackTrace("\n ");
} }
private static final Pattern IGNORED = Pattern.compile("sun\\.|java\\.(lang|util|stream)\\.|org\\.jboss\\.logging."); private static final Pattern IGNORED = Pattern.compile("sun\\.|java\\.(lang|util|stream)\\.|org\\.jboss\\.(arquillian|logging).|org.apache.maven.surefire");
private static final StringBuilder EMPTY = new StringBuilder(0); private static final StringBuilder EMPTY = new StringBuilder(0);
/** /**

View file

@ -668,4 +668,9 @@ public class ClientAdapter implements ClientModel, CachedObject {
public int hashCode() { public int hashCode() {
return getId().hashCode(); return getId().hashCode();
} }
@Override
public String toString() {
return String.format("%s@%08x", getClientId(), System.identityHashCode(this));
}
} }

View file

@ -28,7 +28,7 @@ import java.util.function.Supplier;
public class DefaultLazyLoader<S, D> implements LazyLoader<S, D> { public class DefaultLazyLoader<S, D> implements LazyLoader<S, D> {
private final Function<S, D> loader; private final Function<S, D> loader;
private Supplier<D> fallback; private final Supplier<D> fallback;
private D data; private D data;
public DefaultLazyLoader(Function<S, D> loader, Supplier<D> fallback) { public DefaultLazyLoader(Function<S, D> loader, Supplier<D> fallback) {

View file

@ -753,8 +753,9 @@ public class ClientAdapter implements ClientModel, JpaModel<ClientEntity> {
return getId().hashCode(); return getId().hashCode();
} }
@Override
public String toString() { public String toString() {
return getClientId(); return String.format("%s@%08x", getClientId(), System.identityHashCode(this));
} }
} }

View file

@ -23,7 +23,6 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.sessions.AuthenticationSessionProvider;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;

View file

@ -26,18 +26,19 @@ import org.keycloak.models.RealmModel;
import org.keycloak.models.map.common.Serialization; import org.keycloak.models.map.common.Serialization;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.utils.RealmInfoUtil; import org.keycloak.models.utils.RealmInfoUtil;
import org.keycloak.sessions.AuthenticationSessionCompoundId; import org.keycloak.sessions.AuthenticationSessionCompoundId;
import org.keycloak.sessions.AuthenticationSessionProvider; import org.keycloak.sessions.AuthenticationSessionProvider;
import org.keycloak.sessions.RootAuthenticationSessionModel; import org.keycloak.sessions.RootAuthenticationSessionModel;
import java.util.List; import org.keycloak.sessions.RootAuthenticationSessionModel.SearchableFields;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
@ -48,16 +49,16 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
private static final Logger LOG = Logger.getLogger(MapRootAuthenticationSessionProvider.class); private static final Logger LOG = Logger.getLogger(MapRootAuthenticationSessionProvider.class);
private final KeycloakSession session; private final KeycloakSession session;
protected final MapKeycloakTransaction<UUID, MapRootAuthenticationSessionEntity> tx; protected final MapKeycloakTransaction<UUID, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> tx;
private final MapStorage<UUID, MapRootAuthenticationSessionEntity> sessionStore; private final MapStorage<UUID, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> sessionStore;
private static final Predicate<MapRootAuthenticationSessionEntity> ALWAYS_FALSE = role -> false; private static final Predicate<MapRootAuthenticationSessionEntity> ALWAYS_FALSE = role -> false;
private static final String AUTHENTICATION_SESSION_EVENTS = "AUTHENTICATION_SESSION_EVENTS"; private static final String AUTHENTICATION_SESSION_EVENTS = "AUTHENTICATION_SESSION_EVENTS";
public MapRootAuthenticationSessionProvider(KeycloakSession session, MapStorage<UUID, MapRootAuthenticationSessionEntity> sessionStore) { public MapRootAuthenticationSessionProvider(KeycloakSession session, MapStorage<UUID, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> sessionStore) {
this.session = session; this.session = session;
this.sessionStore = sessionStore; this.sessionStore = sessionStore;
this.tx = new MapKeycloakTransaction<>(sessionStore); this.tx = sessionStore.createTransaction();
session.getTransactionManager().enlistAfterCompletion(tx); session.getTransactionManager().enlistAfterCompletion(tx);
} }
@ -100,7 +101,7 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
entity.setRealmId(realm.getId()); entity.setRealmId(realm.getId());
entity.setTimestamp(Time.currentTime()); entity.setTimestamp(Time.currentTime());
if (tx.read(entity.getId(), sessionStore::read) != null) { if (tx.read(entity.getId()) != null) {
throw new ModelDuplicateException("Root authentication session exists: " + entity.getId()); throw new ModelDuplicateException("Root authentication session exists: " + entity.getId());
} }
@ -118,7 +119,7 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
LOG.tracef("getRootAuthenticationSession(%s, %s)%s", realm.getName(), authenticationSessionId, getShortStackTrace()); LOG.tracef("getRootAuthenticationSession(%s, %s)%s", realm.getName(), authenticationSessionId, getShortStackTrace());
MapRootAuthenticationSessionEntity entity = tx.read(UUID.fromString(authenticationSessionId), sessionStore::read); MapRootAuthenticationSessionEntity entity = tx.read(UUID.fromString(authenticationSessionId));
return (entity == null || !entityRealmFilter(realm.getId()).test(entity)) return (entity == null || !entityRealmFilter(realm.getId()).test(entity))
? null ? null
: entityToAdapterFunc(realm).apply(entity); : entityToAdapterFunc(realm).apply(entity);
@ -137,25 +138,22 @@ public class MapRootAuthenticationSessionProvider implements AuthenticationSessi
int expired = Time.currentTime() - RealmInfoUtil.getDettachedClientSessionLifespan(realm); int expired = Time.currentTime() - RealmInfoUtil.getDettachedClientSessionLifespan(realm);
List<UUID> sessionIds = sessionStore.entrySet().stream() ModelCriteriaBuilder<RootAuthenticationSessionModel> mcb = sessionStore.createCriteriaBuilder()
.filter(entity -> entityRealmFilter(realm.getId()).test(entity.getValue())) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(entity -> entity.getValue().getTimestamp() < expired) .compare(SearchableFields.TIMESTAMP, Operator.LT, expired);
.map(Map.Entry::getKey)
.collect(Collectors.toList());
LOG.debugf("Removed %d expired authentication sessions for realm '%s'", sessionIds.size(), realm.getName()); long deletedCount = tx.delete(UUID.randomUUID(), mcb);
sessionIds.forEach(tx::delete); LOG.debugf("Removed %d expired authentication sessions for realm '%s'", deletedCount, realm.getName());
} }
@Override @Override
public void onRealmRemoved(RealmModel realm) { public void onRealmRemoved(RealmModel realm) {
Objects.requireNonNull(realm, "The provided realm can't be null!"); Objects.requireNonNull(realm, "The provided realm can't be null!");
sessionStore.entrySet().stream() ModelCriteriaBuilder<RootAuthenticationSessionModel> mcb = sessionStore.createCriteriaBuilder()
.filter(entity -> entityRealmFilter(realm.getId()).test(entity.getValue())) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
.map(Map.Entry::getKey)
.collect(Collectors.toList()) sessionStore.delete(mcb);
.forEach(tx::delete);
} }
@Override @Override

View file

@ -24,6 +24,7 @@ import org.keycloak.models.map.storage.MapStorageProvider;
import org.keycloak.sessions.AuthenticationSessionProvider; import org.keycloak.sessions.AuthenticationSessionProvider;
import org.keycloak.sessions.AuthenticationSessionProviderFactory; import org.keycloak.sessions.AuthenticationSessionProviderFactory;
import org.keycloak.sessions.RootAuthenticationSessionModel;
import java.util.UUID; import java.util.UUID;
/** /**
@ -32,12 +33,12 @@ import java.util.UUID;
public class MapRootAuthenticationSessionProviderFactory extends AbstractMapProviderFactory<AuthenticationSessionProvider> public class MapRootAuthenticationSessionProviderFactory extends AbstractMapProviderFactory<AuthenticationSessionProvider>
implements AuthenticationSessionProviderFactory { implements AuthenticationSessionProviderFactory {
private MapStorage<UUID, MapRootAuthenticationSessionEntity> store; private MapStorage<UUID, MapRootAuthenticationSessionEntity, RootAuthenticationSessionModel> store;
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class); MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class);
this.store = sp.getStorage("sessions", UUID.class, MapRootAuthenticationSessionEntity.class); this.store = sp.getStorage("sessions", UUID.class, MapRootAuthenticationSessionEntity.class, RootAuthenticationSessionModel.class);
} }
@Override @Override

View file

@ -541,4 +541,9 @@ public abstract class MapClientAdapter extends AbstractClientModel<MapClientEnti
.findAny() .findAny()
.orElse(null); .orElse(null);
} }
@Override
public String toString() {
return String.format("%s@%08x", getClientId(), System.identityHashCode(this));
}
} }

View file

@ -20,6 +20,7 @@ package org.keycloak.models.map.client;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientModel.ClientUpdatedEvent; import org.keycloak.models.ClientModel.ClientUpdatedEvent;
import org.keycloak.models.ClientModel.SearchableFields;
import org.keycloak.models.ClientProvider; import org.keycloak.models.ClientProvider;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ModelDuplicateException; import org.keycloak.models.ModelDuplicateException;
@ -39,6 +40,8 @@ import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.utils.StreamsUtil.paginatedStream;
@ -47,28 +50,17 @@ public class MapClientProvider implements ClientProvider {
private static final Logger LOG = Logger.getLogger(MapClientProvider.class); private static final Logger LOG = Logger.getLogger(MapClientProvider.class);
private static final Predicate<MapClientEntity> ALWAYS_FALSE = c -> { return false; }; private static final Predicate<MapClientEntity> ALWAYS_FALSE = c -> { return false; };
private final KeycloakSession session; private final KeycloakSession session;
final MapKeycloakTransaction<UUID, MapClientEntity> tx; final MapKeycloakTransaction<UUID, MapClientEntity, ClientModel> tx;
private final MapStorage<UUID, MapClientEntity> clientStore; private final MapStorage<UUID, MapClientEntity, ClientModel> clientStore;
private final ConcurrentMap<UUID, ConcurrentMap<String, Integer>> clientRegisteredNodesStore; private final ConcurrentMap<UUID, ConcurrentMap<String, Integer>> clientRegisteredNodesStore;
private static final Comparator<MapClientEntity> COMPARE_BY_CLIENT_ID = new Comparator<MapClientEntity>() { private static final Comparator<MapClientEntity> COMPARE_BY_CLIENT_ID = Comparator.comparing(MapClientEntity::getClientId);
@Override
public int compare(MapClientEntity o1, MapClientEntity o2) {
String c1 = o1 == null ? null : o1.getClientId();
String c2 = o2 == null ? null : o2.getClientId();
return c1 == c2 ? 0
: c1 == null ? -1
: c2 == null ? 1
: c1.compareTo(c2);
} public MapClientProvider(KeycloakSession session, MapStorage<UUID, MapClientEntity, ClientModel> clientStore, ConcurrentMap<UUID, ConcurrentMap<String, Integer>> clientRegisteredNodesStore) {
};
public MapClientProvider(KeycloakSession session, MapStorage<UUID, MapClientEntity> clientStore, ConcurrentMap<UUID, ConcurrentMap<String, Integer>> clientRegisteredNodesStore) {
this.session = session; this.session = session;
this.clientStore = clientStore; this.clientStore = clientStore;
this.clientRegisteredNodesStore = clientRegisteredNodesStore; this.clientRegisteredNodesStore = clientRegisteredNodesStore;
this.tx = new MapKeycloakTransaction<>(clientStore); this.tx = clientStore.createTransaction();
session.getTransactionManager().enlist(tx); session.getTransactionManager().enlist(tx);
} }
@ -135,17 +127,12 @@ public class MapClientProvider implements ClientProvider {
return paginatedStream(getClientsStream(realm), firstResult, maxResults); return paginatedStream(getClientsStream(realm), firstResult, maxResults);
} }
private Stream<MapClientEntity> getNotRemovedUpdatedClientsStream() {
Stream<MapClientEntity> updatedAndNotRemovedClientsStream = clientStore.entrySet().stream()
.map(tx::getUpdated) // If the client has been removed, tx.get will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull);
return Stream.concat(tx.createdValuesStream(), updatedAndNotRemovedClientsStream);
}
@Override @Override
public Stream<ClientModel> getClientsStream(RealmModel realm) { public Stream<ClientModel> getClientsStream(RealmModel realm) {
return getNotRemovedUpdatedClientsStream() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.filter(entityRealmFilter(realm)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
return tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_CLIENT_ID) .sorted(COMPARE_BY_CLIENT_ID)
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
; ;
@ -165,7 +152,7 @@ public class MapClientProvider implements ClientProvider {
entity.setClientId(clientId); entity.setClientId(clientId);
entity.setEnabled(true); entity.setEnabled(true);
entity.setStandardFlowEnabled(true); entity.setStandardFlowEnabled(true);
if (tx.read(entity.getId(), clientStore::read) != null) { if (tx.read(entity.getId()) != null) {
throw new ModelDuplicateException("Client exists: " + id); throw new ModelDuplicateException("Client exists: " + id);
} }
tx.create(entity.getId(), entity); tx.create(entity.getId(), entity);
@ -228,9 +215,10 @@ public class MapClientProvider implements ClientProvider {
@Override @Override
public long getClientsCount(RealmModel realm) { public long getClientsCount(RealmModel realm) {
return this.getNotRemovedUpdatedClientsStream() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.filter(entityRealmFilter(realm)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
.count();
return this.clientStore.getCount(mcb);
} }
@Override @Override
@ -241,7 +229,7 @@ public class MapClientProvider implements ClientProvider {
LOG.tracef("getClientById(%s, %s)%s", realm, id, getShortStackTrace()); LOG.tracef("getClientById(%s, %s)%s", realm, id, getShortStackTrace());
MapClientEntity entity = tx.read(UUID.fromString(id), clientStore::read); MapClientEntity entity = tx.read(UUID.fromString(id));
return (entity == null || ! entityRealmFilter(realm).test(entity)) return (entity == null || ! entityRealmFilter(realm).test(entity))
? null ? null
: entityToAdapterFunc(realm).apply(entity); : entityToAdapterFunc(realm).apply(entity);
@ -254,11 +242,11 @@ public class MapClientProvider implements ClientProvider {
} }
LOG.tracef("getClientByClientId(%s, %s)%s", realm, clientId, getShortStackTrace()); LOG.tracef("getClientByClientId(%s, %s)%s", realm, clientId, getShortStackTrace());
String clientIdLower = clientId.toLowerCase(); ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CLIENT_ID, Operator.ILIKE, clientId);
return getNotRemovedUpdatedClientsStream() return tx.getUpdatedNotRemoved(mcb)
.filter(entityRealmFilter(realm))
.filter(entity -> entity.getClientId() != null && Objects.equals(entity.getClientId().toLowerCase(), clientIdLower))
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.findFirst() .findFirst()
.orElse(null) .orElse(null)
@ -270,10 +258,12 @@ public class MapClientProvider implements ClientProvider {
if (clientId == null) { if (clientId == null) {
return Stream.empty(); return Stream.empty();
} }
String clientIdLower = clientId.toLowerCase();
Stream<MapClientEntity> s = getNotRemovedUpdatedClientsStream() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.filter(entityRealmFilter(realm)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(entity -> entity.getClientId() != null && entity.getClientId().toLowerCase().contains(clientIdLower)) .compare(SearchableFields.CLIENT_ID, Operator.ILIKE, "%" + clientId + "%");
Stream<MapClientEntity> s = tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_CLIENT_ID); .sorted(COMPARE_BY_CLIENT_ID);
return paginatedStream(s, firstResult, maxResults).map(entityToAdapterFunc(realm)); return paginatedStream(s, firstResult, maxResults).map(entityToAdapterFunc(realm));

View file

@ -16,6 +16,7 @@
*/ */
package org.keycloak.models.map.client; package org.keycloak.models.map.client;
import org.keycloak.models.ClientModel;
import org.keycloak.models.map.common.AbstractMapProviderFactory; import org.keycloak.models.map.common.AbstractMapProviderFactory;
import org.keycloak.models.ClientProvider; import org.keycloak.models.ClientProvider;
import org.keycloak.models.ClientProviderFactory; import org.keycloak.models.ClientProviderFactory;
@ -35,12 +36,12 @@ public class MapClientProviderFactory extends AbstractMapProviderFactory<ClientP
private final ConcurrentHashMap<UUID, ConcurrentMap<String, Integer>> REGISTERED_NODES_STORE = new ConcurrentHashMap<>(); private final ConcurrentHashMap<UUID, ConcurrentMap<String, Integer>> REGISTERED_NODES_STORE = new ConcurrentHashMap<>();
private MapStorage<UUID, MapClientEntity> store; private MapStorage<UUID, MapClientEntity, ClientModel> store;
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class); MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class);
this.store = sp.getStorage("clients", UUID.class, MapClientEntity.class); this.store = sp.getStorage("clients", UUID.class, MapClientEntity.class, ClientModel.class);
} }

View file

@ -20,12 +20,17 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.fasterxml.jackson.datatype.jdk8.StreamSerializer;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
/** /**
* *
@ -33,20 +38,24 @@ import java.util.concurrent.ConcurrentHashMap;
*/ */
public class Serialization { public class Serialization {
public static final ObjectMapper MAPPER = new ObjectMapper(); public static final ObjectMapper MAPPER = new ObjectMapper()
.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
.enable(SerializationFeature.INDENT_OUTPUT)
.setSerializationInclusion(JsonInclude.Include.NON_NULL)
.setVisibility(PropertyAccessor.ALL, Visibility.NONE)
.setVisibility(PropertyAccessor.FIELD, Visibility.ANY)
.addMixIn(AbstractEntity.class, IgnoreUpdatedMixIn.class);
public static final ConcurrentHashMap<Class<?>, ObjectReader> READERS = new ConcurrentHashMap<>(); public static final ConcurrentHashMap<Class<?>, ObjectReader> READERS = new ConcurrentHashMap<>();
public static final ConcurrentHashMap<Class<?>, ObjectWriter> WRITERS = new ConcurrentHashMap<>(); public static final ConcurrentHashMap<Class<?>, ObjectWriter> WRITERS = new ConcurrentHashMap<>();
abstract class IgnoreUpdatedMixIn { @JsonIgnore public abstract boolean isUpdated(); } abstract class IgnoreUpdatedMixIn { @JsonIgnore public abstract boolean isUpdated(); }
static { static {
MAPPER.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); JavaType type = TypeFactory.unknownType();
MAPPER.enable(SerializationFeature.INDENT_OUTPUT); JavaType streamType = MAPPER.getTypeFactory().constructParametricType(Stream.class, type);
MAPPER.setSerializationInclusion(JsonInclude.Include.NON_NULL); SimpleModule module = new SimpleModule().addSerializer(new StreamSerializer(streamType, type));
MAPPER.setVisibility(PropertyAccessor.ALL, Visibility.NONE); MAPPER.registerModule(module);
MAPPER.setVisibility(PropertyAccessor.FIELD, Visibility.ANY);
MAPPER.addMixIn(AbstractEntity.class, IgnoreUpdatedMixIn.class);
} }

View file

@ -19,6 +19,7 @@ package org.keycloak.models.map.group;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.models.GroupModel; import org.keycloak.models.GroupModel;
import org.keycloak.models.GroupModel.SearchableFields;
import org.keycloak.models.GroupProvider; import org.keycloak.models.GroupProvider;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ModelDuplicateException; import org.keycloak.models.ModelDuplicateException;
@ -28,11 +29,13 @@ import org.keycloak.models.map.common.Serialization;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import java.util.Comparator; import java.util.Comparator;
import java.util.Objects; import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.UnaryOperator;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
@ -41,15 +44,14 @@ import static org.keycloak.utils.StreamsUtil.paginatedStream;
public class MapGroupProvider implements GroupProvider { public class MapGroupProvider implements GroupProvider {
private static final Logger LOG = Logger.getLogger(MapGroupProvider.class); private static final Logger LOG = Logger.getLogger(MapGroupProvider.class);
private static final Predicate<MapGroupEntity> ALWAYS_FALSE = c -> { return false; };
private final KeycloakSession session; private final KeycloakSession session;
final MapKeycloakTransaction<UUID, MapGroupEntity> tx; final MapKeycloakTransaction<UUID, MapGroupEntity, GroupModel> tx;
private final MapStorage<UUID, MapGroupEntity> groupStore; private final MapStorage<UUID, MapGroupEntity, GroupModel> groupStore;
public MapGroupProvider(KeycloakSession session, MapStorage<UUID, MapGroupEntity> groupStore) { public MapGroupProvider(KeycloakSession session, MapStorage<UUID, MapGroupEntity, GroupModel> groupStore) {
this.session = session; this.session = session;
this.groupStore = groupStore; this.groupStore = groupStore;
this.tx = new MapKeycloakTransaction<>(groupStore); this.tx = groupStore.createTransaction();
session.getTransactionManager().enlist(tx); session.getTransactionManager().enlist(tx);
} }
@ -64,14 +66,6 @@ public class MapGroupProvider implements GroupProvider {
return origEntity -> new MapGroupAdapter(session, realm, registerEntityForChanges(origEntity)); return origEntity -> new MapGroupAdapter(session, realm, registerEntityForChanges(origEntity));
} }
private Predicate<MapGroupEntity> entityRealmFilter(RealmModel realm) {
if (realm == null || realm.getId() == null) {
return MapGroupProvider.ALWAYS_FALSE;
}
String realmId = realm.getId();
return entity -> Objects.equals(realmId, entity.getRealmId());
}
@Override @Override
public GroupModel getGroupById(RealmModel realm, String id) { public GroupModel getGroupById(RealmModel realm, String id) {
if (id == null) { if (id == null) {
@ -88,28 +82,28 @@ public class MapGroupProvider implements GroupProvider {
return null; return null;
} }
MapGroupEntity entity = tx.read(uid, groupStore::read); MapGroupEntity entity = tx.read(uid);
return (entity == null || ! entityRealmFilter(realm).test(entity)) String realmId = realm.getId();
return (entity == null || ! Objects.equals(realmId, entity.getRealmId()))
? null ? null
: entityToAdapterFunc(realm).apply(entity); : entityToAdapterFunc(realm).apply(entity);
} }
private Stream<MapGroupEntity> getNotRemovedUpdatedGroupsStream() {
Stream<MapGroupEntity> updatedAndNotRemovedGroupsStream = groupStore.entrySet().stream()
.map(tx::getUpdated) // If the group has been removed, tx.get will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull);
return Stream.concat(tx.createdValuesStream(), updatedAndNotRemovedGroupsStream);
}
private Stream<MapGroupEntity> getUnsortedGroupEntitiesStream(RealmModel realm) {
return getNotRemovedUpdatedGroupsStream()
.filter(entityRealmFilter(realm));
}
@Override @Override
public Stream<GroupModel> getGroupsStream(RealmModel realm) { public Stream<GroupModel> getGroupsStream(RealmModel realm) {
return getGroupsStreamInternal(realm, null);
}
private Stream<GroupModel> getGroupsStreamInternal(RealmModel realm, UnaryOperator<ModelCriteriaBuilder<GroupModel>> modifier) {
LOG.tracef("getGroupsStream(%s)%s", realm, getShortStackTrace()); LOG.tracef("getGroupsStream(%s)%s", realm, getShortStackTrace());
return getUnsortedGroupEntitiesStream(realm) ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
if (modifier != null) {
mcb = modifier.apply(mcb);
}
return tx.getUpdatedNotRemoved(mcb)
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.sorted(GroupModel.COMPARE_BY_NAME) .sorted(GroupModel.COMPARE_BY_NAME)
; ;
@ -117,38 +111,49 @@ public class MapGroupProvider implements GroupProvider {
@Override @Override
public Stream<GroupModel> getGroupsStream(RealmModel realm, Stream<String> ids, String search, Integer first, Integer max) { public Stream<GroupModel> getGroupsStream(RealmModel realm, Stream<String> ids, String search, Integer first, Integer max) {
Stream<GroupModel> groupModelStream = ids.map(id -> session.groups().getGroupById(realm, id)) ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.sorted(Comparator.comparing(GroupModel::getName)); .compare(SearchableFields.ID, Operator.IN, ids.map(UUID::fromString))
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
if (search != null) { if (search != null) {
String s = search.toLowerCase(); mcb = mcb.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%");
groupModelStream = groupModelStream.filter(groupModel -> groupModel.getName().toLowerCase().contains(s));
} }
Stream<GroupModel> groupModelStream = tx.getUpdatedNotRemoved(mcb)
.map(entityToAdapterFunc(realm))
.sorted(Comparator.comparing(GroupModel::getName));
return paginatedStream(groupModelStream, first, max); return paginatedStream(groupModelStream, first, max);
} }
@Override @Override
public Long getGroupsCount(RealmModel realm, Boolean onlyTopGroups) { public Long getGroupsCount(RealmModel realm, Boolean onlyTopGroups) {
LOG.tracef("getGroupsCount(%s, %s)%s", realm, onlyTopGroups, getShortStackTrace()); LOG.tracef("getGroupsCount(%s, %s)%s", realm, onlyTopGroups, getShortStackTrace());
Stream<MapGroupEntity> groupModelStream = getUnsortedGroupEntitiesStream(realm); ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
if (onlyTopGroups) { if (Objects.equals(onlyTopGroups, Boolean.TRUE)) {
groupModelStream = groupModelStream.filter(groupEntity -> Objects.isNull(groupEntity.getParentId())); mcb = mcb.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null);
} }
return groupModelStream.count(); return tx.getCount(mcb);
} }
@Override @Override
public Long getGroupsCountByNameContaining(RealmModel realm, String search) { public Long getGroupsCountByNameContaining(RealmModel realm, String search) {
return searchForGroupByNameStream(realm, search, null, null).count(); ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%");
return tx.getCount(mcb);
} }
@Override @Override
public Stream<GroupModel> getGroupsByRoleStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) { public Stream<GroupModel> getGroupsByRoleStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) {
LOG.tracef("getGroupsByRole(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace()); LOG.tracef("getGroupsByRole(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace());
Stream<GroupModel> groupModelStream = getGroupsStream(realm).filter(groupModel -> groupModel.hasRole(role)); Stream<GroupModel> groupModelStream = getGroupsStreamInternal(realm,
(ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId())
);
return paginatedStream(groupModelStream, firstResult, maxResults); return paginatedStream(groupModelStream, firstResult, maxResults);
} }
@ -156,8 +161,9 @@ public class MapGroupProvider implements GroupProvider {
@Override @Override
public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm) { public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm) {
LOG.tracef("getTopLevelGroupsStream(%s)%s", realm, getShortStackTrace()); LOG.tracef("getTopLevelGroupsStream(%s)%s", realm, getShortStackTrace());
return getGroupsStream(realm) return getGroupsStreamInternal(realm,
.filter(groupModel -> Objects.isNull(groupModel.getParentId())); (ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null)
);
} }
@Override @Override
@ -171,8 +177,9 @@ public class MapGroupProvider implements GroupProvider {
@Override @Override
public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) { public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
LOG.tracef("searchForGroupByNameStream(%s, %s, %d, %d)%s", realm, search, firstResult, maxResults, getShortStackTrace()); LOG.tracef("searchForGroupByNameStream(%s, %s, %d, %d)%s", realm, search, firstResult, maxResults, getShortStackTrace());
Stream<GroupModel> groupModelStream = getGroupsStream(realm) Stream<GroupModel> groupModelStream = getGroupsStreamInternal(realm,
.filter(groupModel -> groupModel.getName().contains(search)); (ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%")
);
return paginatedStream(groupModelStream, firstResult, maxResults); return paginatedStream(groupModelStream, firstResult, maxResults);
@ -184,17 +191,20 @@ public class MapGroupProvider implements GroupProvider {
final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id); final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id);
// Check Db constraint: uniqueConstraints = { @UniqueConstraint(columnNames = {"REALM_ID", "PARENT_GROUP", "NAME"})} // Check Db constraint: uniqueConstraints = { @UniqueConstraint(columnNames = {"REALM_ID", "PARENT_GROUP", "NAME"})}
if (getUnsortedGroupEntitiesStream(realm) String parentId = toParent == null ? null : toParent.getId();
.anyMatch(groupEntity -> ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
Objects.equals(groupEntity.getParentId(), toParent == null ? null : toParent.getId()) && .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
Objects.equals(groupEntity.getName(), name))) { .compare(SearchableFields.PARENT_ID, Operator.EQ, parentId)
.compare(SearchableFields.NAME, Operator.EQ, name);
if (tx.getCount(mcb) > 0) {
throw new ModelDuplicateException("Group with name '" + name + "' in realm " + realm.getName() + " already exists for requested parent" ); throw new ModelDuplicateException("Group with name '" + name + "' in realm " + realm.getName() + " already exists for requested parent" );
} }
MapGroupEntity entity = new MapGroupEntity(entityId, realm.getId()); MapGroupEntity entity = new MapGroupEntity(entityId, realm.getId());
entity.setName(name); entity.setName(name);
entity.setParentId(toParent == null ? null : toParent.getId()); entity.setParentId(toParent == null ? null : toParent.getId());
if (tx.read(entity.getId(), groupStore::read) != null) { if (tx.read(entity.getId()) != null) {
throw new ModelDuplicateException("Group exists: " + entityId); throw new ModelDuplicateException("Group exists: " + entityId);
} }
tx.create(entity.getId(), entity); tx.create(entity.getId(), entity);
@ -249,12 +259,16 @@ public class MapGroupProvider implements GroupProvider {
} }
String parentId = toParent == null ? null : toParent.getId(); String parentId = toParent == null ? null : toParent.getId();
Stream<MapGroupEntity> possibleSiblings = getUnsortedGroupEntitiesStream(realm) ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.filter(mapGroupEntity -> Objects.equals(mapGroupEntity.getParentId(), parentId)); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.PARENT_ID, Operator.EQ, parentId)
.compare(SearchableFields.NAME, Operator.EQ, group.getName());
if (possibleSiblings.map(MapGroupEntity::getName).anyMatch(Predicate.isEqual(group.getName()))) { try (Stream<MapGroupEntity> possibleSiblings = tx.getUpdatedNotRemoved(mcb)) {
if (possibleSiblings.findAny().isPresent()) {
throw new ModelDuplicateException("Parent already contains subgroup named '" + group.getName() + "'"); throw new ModelDuplicateException("Parent already contains subgroup named '" + group.getName() + "'");
} }
}
if (group.getParentId() != null) { if (group.getParentId() != null) {
group.getParent().removeChild(group); group.getParent().removeChild(group);
@ -267,12 +281,16 @@ public class MapGroupProvider implements GroupProvider {
public void addTopLevelGroup(RealmModel realm, GroupModel subGroup) { public void addTopLevelGroup(RealmModel realm, GroupModel subGroup) {
LOG.tracef("addTopLevelGroup(%s, %s)%s", realm, subGroup, getShortStackTrace()); LOG.tracef("addTopLevelGroup(%s, %s)%s", realm, subGroup, getShortStackTrace());
Stream<MapGroupEntity> possibleSiblings = getUnsortedGroupEntitiesStream(realm) ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.filter(mapGroupEntity -> mapGroupEntity.getParentId() == null); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null)
.compare(SearchableFields.NAME, Operator.EQ, subGroup.getName());
if (possibleSiblings.map(MapGroupEntity::getName).anyMatch(Predicate.isEqual(subGroup.getName()))) { try (Stream<MapGroupEntity> possibleSiblings = tx.getUpdatedNotRemoved(mcb)) {
if (possibleSiblings.findAny().isPresent()) {
throw new ModelDuplicateException("There is already a top level group named '" + subGroup.getName() + "'"); throw new ModelDuplicateException("There is already a top level group named '" + subGroup.getName() + "'");
} }
}
subGroup.setParent(null); subGroup.setParent(null);
} }
@ -280,12 +298,15 @@ public class MapGroupProvider implements GroupProvider {
@Override @Override
public void preRemove(RealmModel realm, RoleModel role) { public void preRemove(RealmModel realm, RoleModel role) {
LOG.tracef("preRemove(%s, %s)%s", realm, role, getShortStackTrace()); LOG.tracef("preRemove(%s, %s)%s", realm, role, getShortStackTrace());
final String roleId = role.getId(); ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
getUnsortedGroupEntitiesStream(realm) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(groupEntity -> groupEntity.getGrantedRoles().contains(roleId)) .compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId());
try (Stream<MapGroupEntity> toRemove = tx.getUpdatedNotRemoved(mcb)) {
toRemove
.map(groupEntity -> session.groups().getGroupById(realm, groupEntity.getId().toString())) .map(groupEntity -> session.groups().getGroupById(realm, groupEntity.getId().toString()))
.forEach(groupModel -> groupModel.deleteRoleMapping(role)); .forEach(groupModel -> groupModel.deleteRoleMapping(role));
} }
}
@Override @Override
public void close() { public void close() {

View file

@ -17,6 +17,7 @@
package org.keycloak.models.map.group; package org.keycloak.models.map.group;
import org.keycloak.models.GroupModel;
import org.keycloak.models.GroupProvider; import org.keycloak.models.GroupProvider;
import org.keycloak.models.GroupProviderFactory; import org.keycloak.models.GroupProviderFactory;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
@ -33,12 +34,12 @@ import java.util.UUID;
*/ */
public class MapGroupProviderFactory extends AbstractMapProviderFactory<GroupProvider> implements GroupProviderFactory { public class MapGroupProviderFactory extends AbstractMapProviderFactory<GroupProvider> implements GroupProviderFactory {
private MapStorage<UUID, MapGroupEntity> store; private MapStorage<UUID, MapGroupEntity, GroupModel> store;
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class); MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class);
this.store = sp.getStorage("groups", UUID.class, MapGroupEntity.class); this.store = sp.getStorage("groups", UUID.class, MapGroupEntity.class, GroupModel.class);
} }

View file

@ -38,16 +38,18 @@ import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.utils.StreamsUtil.paginatedStream;
import org.keycloak.models.RoleContainerModel; import org.keycloak.models.RoleContainerModel;
import org.keycloak.models.RoleModel.SearchableFields;
import org.keycloak.models.RoleProvider; import org.keycloak.models.RoleProvider;
import org.keycloak.models.map.common.StreamUtils; import org.keycloak.models.map.common.StreamUtils;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
public class MapRoleProvider implements RoleProvider { public class MapRoleProvider implements RoleProvider {
private static final Logger LOG = Logger.getLogger(MapRoleProvider.class); private static final Logger LOG = Logger.getLogger(MapRoleProvider.class);
private static final Predicate<MapRoleEntity> ALWAYS_FALSE = role -> { return false; };
private final KeycloakSession session; private final KeycloakSession session;
final MapKeycloakTransaction<UUID, MapRoleEntity> tx; final MapKeycloakTransaction<UUID, MapRoleEntity, RoleModel> tx;
private final MapStorage<UUID, MapRoleEntity> roleStore; private final MapStorage<UUID, MapRoleEntity, RoleModel> roleStore;
private static final Comparator<MapRoleEntity> COMPARE_BY_NAME = new Comparator<MapRoleEntity>() { private static final Comparator<MapRoleEntity> COMPARE_BY_NAME = new Comparator<MapRoleEntity>() {
@Override @Override
@ -62,10 +64,10 @@ public class MapRoleProvider implements RoleProvider {
} }
}; };
public MapRoleProvider(KeycloakSession session, MapStorage<UUID, MapRoleEntity> roleStore) { public MapRoleProvider(KeycloakSession session, MapStorage<UUID, MapRoleEntity, RoleModel> roleStore) {
this.session = session; this.session = session;
this.roleStore = roleStore; this.roleStore = roleStore;
this.tx = new MapKeycloakTransaction<>(roleStore); this.tx = roleStore.createTransaction();
session.getTransactionManager().enlist(tx); session.getTransactionManager().enlist(tx);
} }
@ -81,31 +83,6 @@ public class MapRoleProvider implements RoleProvider {
return res; return res;
} }
private Predicate<MapRoleEntity> entityRealmFilter(RealmModel realm) {
if (realm == null || realm.getId() == null) {
return MapRoleProvider.ALWAYS_FALSE;
}
String realmId = realm.getId();
return entity -> Objects.equals(realmId, entity.getRealmId());
}
private Predicate<MapRoleEntity> entityClientFilter(ClientModel client) {
if (client == null || client.getId() == null) {
return MapRoleProvider.ALWAYS_FALSE;
}
String clientId = client.getId();
return entity -> entity.isClientRole() &&
Objects.equals(clientId, entity.getClientId());
}
private Stream<MapRoleEntity> getNotRemovedUpdatedRolesStream(RealmModel realm) {
Stream<MapRoleEntity> updatedAndNotRemovedRolesStream = roleStore.entrySet().stream()
.map(tx::getUpdated) // If the role has been removed, tx.get will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull);
return Stream.concat(tx.createdValuesStream(), updatedAndNotRemovedRolesStream)
.filter(entityRealmFilter(realm));
}
@Override @Override
public RoleModel addRealmRole(RealmModel realm, String id, String name) { public RoleModel addRealmRole(RealmModel realm, String id, String name) {
if (getRealmRole(realm, name) != null) { if (getRealmRole(realm, name) != null) {
@ -114,12 +91,12 @@ public class MapRoleProvider implements RoleProvider {
final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id); final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id);
LOG.tracef("addRealmRole(%s, %s, %s)%s", realm.getName(), id, name, getShortStackTrace()); LOG.tracef("addRealmRole(%s, %s, %s)%s", realm, id, name, getShortStackTrace());
MapRoleEntity entity = new MapRoleEntity(entityId, realm.getId()); MapRoleEntity entity = new MapRoleEntity(entityId, realm.getId());
entity.setName(name); entity.setName(name);
entity.setRealmId(realm.getId()); entity.setRealmId(realm.getId());
if (tx.read(entity.getId(), roleStore::read) != null) { if (tx.read(entity.getId()) != null) {
throw new ModelDuplicateException("Role exists: " + id); throw new ModelDuplicateException("Role exists: " + id);
} }
tx.create(entity.getId(), entity); tx.create(entity.getId(), entity);
@ -133,16 +110,15 @@ public class MapRoleProvider implements RoleProvider {
@Override @Override
public Stream<RoleModel> getRealmRolesStream(RealmModel realm) { public Stream<RoleModel> getRealmRolesStream(RealmModel realm) {
return getNotRemovedUpdatedRolesStream(realm) ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.filter(this::isRealmRole) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IS_CLIENT_ROLE, Operator.NE, true);
return tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_NAME) .sorted(COMPARE_BY_NAME)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
private boolean isRealmRole(MapRoleEntity role) {
return ! role.isClientRole();
}
@Override @Override
public RoleModel addClientRole(ClientModel client, String id, String name) { public RoleModel addClientRole(ClientModel client, String id, String name) {
if (getClientRole(client, name) != null) { if (getClientRole(client, name) != null) {
@ -151,13 +127,13 @@ public class MapRoleProvider implements RoleProvider {
final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id); final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id);
LOG.tracef("addClientRole(%s, %s, %s)%s", client.getClientId(), id, name, getShortStackTrace()); LOG.tracef("addClientRole(%s, %s, %s)%s", client, id, name, getShortStackTrace());
MapRoleEntity entity = new MapRoleEntity(entityId, client.getRealm().getId()); MapRoleEntity entity = new MapRoleEntity(entityId, client.getRealm().getId());
entity.setName(name); entity.setName(name);
entity.setClientRole(true); entity.setClientRole(true);
entity.setClientId(client.getId()); entity.setClientId(client.getId());
if (tx.read(entity.getId(), roleStore::read) != null) { if (tx.read(entity.getId()) != null) {
throw new ModelDuplicateException("Role exists: " + id); throw new ModelDuplicateException("Role exists: " + id);
} }
tx.create(entity.getId(), entity); tx.create(entity.getId(), entity);
@ -171,8 +147,11 @@ public class MapRoleProvider implements RoleProvider {
@Override @Override
public Stream<RoleModel> getClientRolesStream(ClientModel client) { public Stream<RoleModel> getClientRolesStream(ClientModel client) {
return getNotRemovedUpdatedRolesStream(client.getRealm()) ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.filter(entityClientFilter(client)) .compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId());
return tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_NAME) .sorted(COMPARE_BY_NAME)
.map(entityToAdapterFunc(client.getRealm())); .map(entityToAdapterFunc(client.getRealm()));
} }
@ -184,16 +163,23 @@ public class MapRoleProvider implements RoleProvider {
session.users().preRemove(realm, role); session.users().preRemove(realm, role);
ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IS_CLIENT_ROLE, Operator.EQ, false)
.compare(SearchableFields.IS_COMPOSITE_ROLE, Operator.EQ, false);
//remove role from realm-roles composites //remove role from realm-roles composites
try (Stream<MapRoleEntity> baseStream = getNotRemovedUpdatedRolesStream(realm) try (Stream<MapRoleEntity> baseStream = tx.getUpdatedNotRemoved(mcb)) {
.filter(this::isRealmRole)
.filter(MapRoleEntity::isComposite)) {
StreamUtils.leftInnerJoinIterable(baseStream, MapRoleEntity::getCompositeRoles) StreamUtils.leftInnerJoinIterable(baseStream, MapRoleEntity::getCompositeRoles)
.filter(pair -> role.getId().equals(pair.getV())) .filter(pair -> role.getId().equals(pair.getV()))
.collect(Collectors.toSet()) .collect(Collectors.toSet())
.forEach(pair -> { .forEach(pair -> {
MapRoleEntity origEntity = pair.getK(); MapRoleEntity origEntity = pair.getK();
//
// TODO: Investigate what this is for - the return value is ignored
//
registerEntityForChanges(origEntity); registerEntityForChanges(origEntity);
origEntity.removeCompositeRole(role.getId()); origEntity.removeCompositeRole(role.getId());
}); });
@ -202,15 +188,22 @@ public class MapRoleProvider implements RoleProvider {
//remove role from client-roles composites //remove role from client-roles composites
session.clients().getClientsStream(realm).forEach(client -> { session.clients().getClientsStream(realm).forEach(client -> {
client.deleteScopeMapping(role); client.deleteScopeMapping(role);
try (Stream<MapRoleEntity> baseStream = getNotRemovedUpdatedRolesStream(client.getRealm()) ModelCriteriaBuilder<RoleModel> mcbClient = roleStore.createCriteriaBuilder()
.filter(entityClientFilter(client)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(MapRoleEntity::isComposite)) { .compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId())
.compare(SearchableFields.IS_COMPOSITE_ROLE, Operator.EQ, false);
try (Stream<MapRoleEntity> baseStream = tx.getUpdatedNotRemoved(mcbClient)) {
StreamUtils.leftInnerJoinIterable(baseStream, MapRoleEntity::getCompositeRoles) StreamUtils.leftInnerJoinIterable(baseStream, MapRoleEntity::getCompositeRoles)
.filter(pair -> role.getId().equals(pair.getV())) .filter(pair -> role.getId().equals(pair.getV()))
.collect(Collectors.toSet()) .collect(Collectors.toSet())
.forEach(pair -> { .forEach(pair -> {
MapRoleEntity origEntity = pair.getK(); MapRoleEntity origEntity = pair.getK();
//
// TODO: Investigate what this is for - the return value is ignored
//
registerEntityForChanges(origEntity); registerEntityForChanges(origEntity);
origEntity.removeCompositeRole(role.getId()); origEntity.removeCompositeRole(role.getId());
}); });
@ -253,12 +246,13 @@ public class MapRoleProvider implements RoleProvider {
if (name == null) { if (name == null) {
return null; return null;
} }
LOG.tracef("getRealmRole(%s, %s)%s", realm.getName(), name, getShortStackTrace()); LOG.tracef("getRealmRole(%s, %s)%s", realm, name, getShortStackTrace());
String roleNameLower = name.toLowerCase(); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, name);
String roleId = getNotRemovedUpdatedRolesStream(realm) String roleId = tx.getUpdatedNotRemoved(mcb)
.filter(entity -> entity.getName()!= null && Objects.equals(entity.getName().toLowerCase(), roleNameLower))
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.map(RoleModel::getId) .map(RoleModel::getId)
.findFirst() .findFirst()
@ -272,13 +266,14 @@ public class MapRoleProvider implements RoleProvider {
if (name == null) { if (name == null) {
return null; return null;
} }
LOG.tracef("getClientRole(%s, %s)%s", client.getClientId(), name, getShortStackTrace()); LOG.tracef("getClientRole(%s, %s)%s", client, name, getShortStackTrace());
String roleNameLower = name.toLowerCase(); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, name);
String roleId = getNotRemovedUpdatedRolesStream(client.getRealm()) String roleId = tx.getUpdatedNotRemoved(mcb)
.filter(entityClientFilter(client))
.filter(entity -> entity.getName()!= null && Objects.equals(entity.getName().toLowerCase(), roleNameLower))
.map(entityToAdapterFunc(client.getRealm())) .map(entityToAdapterFunc(client.getRealm()))
.map(RoleModel::getId) .map(RoleModel::getId)
.findFirst() .findFirst()
@ -289,14 +284,15 @@ public class MapRoleProvider implements RoleProvider {
@Override @Override
public RoleModel getRoleById(RealmModel realm, String id) { public RoleModel getRoleById(RealmModel realm, String id) {
if (id == null) { if (id == null || realm == null || realm.getId() == null) {
return null; return null;
} }
LOG.tracef("getRoleById(%s, %s)%s", realm.getName(), id, getShortStackTrace()); LOG.tracef("getRoleById(%s, %s)%s", realm, id, getShortStackTrace());
MapRoleEntity entity = tx.read(UUID.fromString(id), roleStore::read); MapRoleEntity entity = tx.read(UUID.fromString(id));
return (entity == null || ! entityRealmFilter(realm).test(entity)) String realmId = realm.getId();
return (entity == null || ! Objects.equals(realmId, entity.getRealmId()))
? null ? null
: entityToAdapterFunc(realm).apply(entity); : entityToAdapterFunc(realm).apply(entity);
} }
@ -306,12 +302,14 @@ public class MapRoleProvider implements RoleProvider {
if (search == null) { if (search == null) {
return Stream.empty(); return Stream.empty();
} }
String searchLower = search.toLowerCase(); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
Stream<MapRoleEntity> s = getNotRemovedUpdatedRolesStream(realm) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(entity -> .or(
(entity.getName() != null && entity.getName().toLowerCase().contains(searchLower)) || roleStore.createCriteriaBuilder().compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"),
(entity.getDescription() != null && entity.getDescription().toLowerCase().contains(searchLower)) roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%")
) );
Stream<MapRoleEntity> s = tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_NAME); .sorted(COMPARE_BY_NAME);
return paginatedStream(s.map(entityToAdapterFunc(realm)), first, max); return paginatedStream(s.map(entityToAdapterFunc(realm)), first, max);
@ -322,13 +320,14 @@ public class MapRoleProvider implements RoleProvider {
if (search == null) { if (search == null) {
return Stream.empty(); return Stream.empty();
} }
String searchLower = search.toLowerCase(); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
Stream<MapRoleEntity> s = getNotRemovedUpdatedRolesStream(client.getRealm()) .compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.filter(entityClientFilter(client)) .compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId())
.filter(entity -> .or(
(entity.getName() != null && entity.getName().toLowerCase().contains(searchLower)) || roleStore.createCriteriaBuilder().compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"),
(entity.getDescription() != null && entity.getDescription().toLowerCase().contains(searchLower)) roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%")
) );
Stream<MapRoleEntity> s = tx.getUpdatedNotRemoved(mcb)
.sorted(COMPARE_BY_NAME); .sorted(COMPARE_BY_NAME);
return paginatedStream(s,first, max).map(entityToAdapterFunc(client.getRealm())); return paginatedStream(s,first, max).map(entityToAdapterFunc(client.getRealm()));

View file

@ -20,6 +20,7 @@ import java.util.UUID;
import org.keycloak.models.map.common.AbstractMapProviderFactory; import org.keycloak.models.map.common.AbstractMapProviderFactory;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory; import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RoleModel;
import org.keycloak.models.RoleProvider; import org.keycloak.models.RoleProvider;
import org.keycloak.models.RoleProviderFactory; import org.keycloak.models.RoleProviderFactory;
import org.keycloak.models.map.storage.MapStorageProvider; import org.keycloak.models.map.storage.MapStorageProvider;
@ -27,12 +28,12 @@ import org.keycloak.models.map.storage.MapStorage;
public class MapRoleProviderFactory extends AbstractMapProviderFactory<RoleProvider> implements RoleProviderFactory { public class MapRoleProviderFactory extends AbstractMapProviderFactory<RoleProvider> implements RoleProviderFactory {
private MapStorage<UUID, MapRoleEntity> store; private MapStorage<UUID, MapRoleEntity, RoleModel> store;
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class); MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class);
this.store = sp.getStorage("roles", UUID.class, MapRoleEntity.class); this.store = sp.getStorage("roles", UUID.class, MapRoleEntity.class, RoleModel.class);
} }

View file

@ -0,0 +1,239 @@
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
*
* @author hmlnarik
*/
class CriteriaOperator {
private static final EnumMap<Operator, Function<Object[], Predicate<Object>>> OPERATORS = new EnumMap<>(Operator.class);
private static final Logger LOG = Logger.getLogger(CriteriaOperator.class.getSimpleName());
private static final Predicate<Object> ALWAYS_FALSE = o -> false;
static {
OPERATORS.put(Operator.EQ, CriteriaOperator::eq);
OPERATORS.put(Operator.NE, CriteriaOperator::ne);
OPERATORS.put(Operator.EXISTS, CriteriaOperator::exists);
OPERATORS.put(Operator.NOT_EXISTS, CriteriaOperator::notExists);
OPERATORS.put(Operator.LT, CriteriaOperator::lt);
OPERATORS.put(Operator.LE, CriteriaOperator::le);
OPERATORS.put(Operator.GT, CriteriaOperator::gt);
OPERATORS.put(Operator.GE, CriteriaOperator::ge);
OPERATORS.put(Operator.IN, CriteriaOperator::in);
OPERATORS.put(Operator.LIKE, CriteriaOperator::like);
OPERATORS.put(Operator.ILIKE, CriteriaOperator::ilike);
// Check that all operators are covered
EnumSet<Operator> s = EnumSet.allOf(Operator.class);
s.removeAll(OPERATORS.keySet());
if (! s.isEmpty()) {
throw new IllegalStateException("Some operators are not implemented: " + s);
}
}
/**
* Returns a predicate {@code P(x)} for comparing {@code value} and {@code x} as {@code x OP value}.
* <b>Implementation note:</b> Note that this may mean reverse logic to e.g. {@link Comparable#compareTo}.
* @param operator
* @param value
* @return
*/
public static Predicate<Object> predicateFor(Operator op, Object[] value) {
final Function<Object[], Predicate<Object>> funcToGetPredicate = OPERATORS.get(op);
if (funcToGetPredicate == null) {
throw new IllegalArgumentException("Unknown operator: " + op);
}
return funcToGetPredicate.apply(value);
}
private static Object getFirstArrayElement(Object[] value) throws IllegalStateException {
if (value == null || value.length != 1) {
throw new IllegalStateException("Invalid argument: " + Arrays.toString(value));
}
return value[0];
}
public static Predicate<Object> eq(Object[] value) {
Object value0 = getFirstArrayElement(value);
return new Predicate<Object>() {
@Override public boolean test(Object v) { return Objects.equals(v, value0); }
};
}
public static Predicate<Object> ne(Object[] value) {
Object value0 = getFirstArrayElement(value);
return new Predicate<Object>() {
@Override public boolean test(Object v) { return ! Objects.equals(v, value0); }
};
}
public static Predicate<Object> exists(Object[] value) {
if (value != null && value.length != 0) {
throw new IllegalStateException("Invalid argument: " + Arrays.toString(value));
}
return Objects::nonNull;
}
public static Predicate<Object> notExists(Object[] value) {
if (value != null && value.length != 0) {
throw new IllegalStateException("Invalid argument: " + Arrays.toString(value));
}
return Objects::isNull;
}
public static Predicate<Object> in(Object[] value) {
if (value == null || value.length == 0) {
return ALWAYS_FALSE;
}
final Collection<?> operand;
if (value.length == 1) {
final Object value0 = value[0];
if (value0 instanceof Collection) {
operand = (Collection) value0;
} else if (value0 instanceof Stream) {
try (Stream valueS = (Stream) value0) {
operand = (Set) valueS.collect(Collectors.toSet());
}
} else {
operand = Collections.singleton(value0);
}
} else {
operand = new HashSet(Arrays.asList(value));
}
return operand.isEmpty() ? ALWAYS_FALSE : new Predicate<Object>() {
@Override public boolean test(Object v) { return operand.contains(v); }
};
}
public static Predicate<Object> lt(Object[] value) {
return getComparisonPredicate(ComparisonPredicateImpl.Op.LT, value);
}
public static Predicate<Object> le(Object[] value) {
return getComparisonPredicate(ComparisonPredicateImpl.Op.LE, value);
}
public static Predicate<Object> gt(Object[] value) {
return getComparisonPredicate(ComparisonPredicateImpl.Op.GT, value);
}
public static Predicate<Object> ge(Object[] value) {
return getComparisonPredicate(ComparisonPredicateImpl.Op.GE, value);
}
private static Predicate<Object> getComparisonPredicate(ComparisonPredicateImpl.Op op, Object[] value) throws IllegalArgumentException {
Object value0 = getFirstArrayElement(value);
if (value0 instanceof Comparable) {
Comparable cValue = (Comparable) value0;
return new ComparisonPredicateImpl(op, cValue);
} else {
throw new IllegalArgumentException("Incomparable argument for comparison operation: " + value0);
}
}
public static Predicate<Object> like(Object[] value) {
Object value0 = getFirstArrayElement(value);
if (value0 instanceof String) {
String sValue = (String) value0;
boolean anyBeginning = sValue.startsWith("%");
boolean anyEnd = sValue.endsWith("%");
Pattern pValue = Pattern.compile(
(anyBeginning ? ".*" : "")
+ Pattern.quote(sValue.substring(anyBeginning ? 1 : 0, sValue.length() - (anyEnd ? 1 : 0)))
+ (anyEnd ? ".*" : ""),
Pattern.DOTALL
);
return o -> {
return o instanceof String && pValue.matcher((String) o).matches();
};
}
return ALWAYS_FALSE;
}
public static Predicate<Object> ilike(Object[] value) {
Object value0 = getFirstArrayElement(value);
if (value0 instanceof String) {
String sValue = (String) value0;
boolean anyBeginning = sValue.startsWith("%");
boolean anyEnd = sValue.endsWith("%");
Pattern pValue = Pattern.compile(
(anyBeginning ? ".*" : "")
+ Pattern.quote(sValue.substring(anyBeginning ? 1 : 0, sValue.length() - (anyEnd ? 1 : 0)))
+ (anyEnd ? ".*" : ""),
Pattern.CASE_INSENSITIVE + Pattern.DOTALL
);
return o -> {
return o instanceof String && pValue.matcher((String) o).matches();
};
}
return ALWAYS_FALSE;
}
private static class ComparisonPredicateImpl implements Predicate<Object> {
private static enum Op {
LT { @Override boolean isComparisonTrue(int compareToValue) { return compareToValue > 0; } },
LE { @Override boolean isComparisonTrue(int compareToValue) { return compareToValue >= 0; } },
GT { @Override boolean isComparisonTrue(int compareToValue) { return compareToValue < 0; } },
GE { @Override boolean isComparisonTrue(int compareToValue) { return compareToValue <= 0; } },
;
abstract boolean isComparisonTrue(int compareToValue);
}
private final Op op;
private final Comparable cValue;
public ComparisonPredicateImpl(Op op, Comparable cValue) {
this.op = op;
this.cValue = cValue;
}
@Override
public boolean test(Object o) {
try {
return o != null && op.isComparisonTrue(cValue.compareTo(o));
} catch (ClassCastException ex) {
LOG.log(Level.WARNING, "Incomparable argument type for comparison operation: {0}", cValue.getClass().getSimpleName());
return false;
}
}
}
}

View file

@ -0,0 +1,55 @@
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.storage.SearchableModelField;
/**
*
* @author hmlnarik
*/
public class CriterionNotSupportedException extends RuntimeException {
private final SearchableModelField field;
private final Operator op;
public CriterionNotSupportedException(SearchableModelField field, Operator op) {
super("Criterion not supported: operator: " + op + ", field: " + field);
this.field = field;
this.op = op;
}
public CriterionNotSupportedException(SearchableModelField field, Operator op, String message) {
super(message);
this.field = field;
this.op = op;
}
public CriterionNotSupportedException(SearchableModelField field, Operator op, String message, Throwable cause) {
super(message, cause);
this.field = field;
this.op = op;
}
public SearchableModelField getField() {
return field;
}
public Operator getOp() {
return op;
}
}

View file

@ -0,0 +1,243 @@
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage;
import org.keycloak.models.ClientModel;
import org.keycloak.models.GroupModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.map.authSession.AbstractRootAuthenticationSessionEntity;
import org.keycloak.models.map.client.AbstractClientEntity;
import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.group.AbstractGroupEntity;
import org.keycloak.models.map.role.AbstractRoleEntity;
import org.keycloak.storage.SearchableModelField;
import java.util.HashMap;
import java.util.Map;
import org.keycloak.models.map.storage.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.map.user.AbstractUserEntity;
import org.keycloak.models.map.user.UserConsentEntity;
import org.keycloak.sessions.RootAuthenticationSessionModel;
import org.keycloak.storage.StorageId;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
/**
*
* @author hmlnarik
*/
public class MapFieldPredicates {
public static final Map<SearchableModelField<ClientModel>, UpdatePredicatesFunc<Object, AbstractClientEntity<Object>, ClientModel>> CLIENT_PREDICATES = basePredicates(ClientModel.SearchableFields.ID);
public static final Map<SearchableModelField<GroupModel>, UpdatePredicatesFunc<Object, AbstractGroupEntity<Object>, GroupModel>> GROUP_PREDICATES = basePredicates(GroupModel.SearchableFields.ID);
public static final Map<SearchableModelField<RoleModel>, UpdatePredicatesFunc<Object, AbstractRoleEntity<Object>, RoleModel>> ROLE_PREDICATES = basePredicates(RoleModel.SearchableFields.ID);
public static final Map<SearchableModelField<UserModel>, UpdatePredicatesFunc<Object, AbstractUserEntity<Object>, UserModel>> USER_PREDICATES = basePredicates(UserModel.SearchableFields.ID);
public static final Map<SearchableModelField<RootAuthenticationSessionModel>, UpdatePredicatesFunc<Object, AbstractRootAuthenticationSessionEntity<Object>, RootAuthenticationSessionModel>> AUTHENTICATION_SESSION_PREDICATES = basePredicates(RootAuthenticationSessionModel.SearchableFields.ID);
@SuppressWarnings("unchecked")
private static final Map<Class<?>, Map> PREDICATES = new HashMap<>();
static {
put(CLIENT_PREDICATES, ClientModel.SearchableFields.REALM_ID, AbstractClientEntity::getRealmId);
put(CLIENT_PREDICATES, ClientModel.SearchableFields.CLIENT_ID, AbstractClientEntity::getClientId);
put(GROUP_PREDICATES, GroupModel.SearchableFields.REALM_ID, AbstractGroupEntity::getRealmId);
put(GROUP_PREDICATES, GroupModel.SearchableFields.NAME, AbstractGroupEntity::getName);
put(GROUP_PREDICATES, GroupModel.SearchableFields.PARENT_ID, AbstractGroupEntity::getParentId);
put(GROUP_PREDICATES, GroupModel.SearchableFields.ASSIGNED_ROLE, MapFieldPredicates::checkGrantedGroupRole);
put(ROLE_PREDICATES, RoleModel.SearchableFields.REALM_ID, AbstractRoleEntity::getRealmId);
put(ROLE_PREDICATES, RoleModel.SearchableFields.CLIENT_ID, AbstractRoleEntity::getClientId);
put(ROLE_PREDICATES, RoleModel.SearchableFields.DESCRIPTION, AbstractRoleEntity::getDescription);
put(ROLE_PREDICATES, RoleModel.SearchableFields.NAME, AbstractRoleEntity::getName);
put(ROLE_PREDICATES, RoleModel.SearchableFields.IS_CLIENT_ROLE, AbstractRoleEntity::isClientRole);
put(ROLE_PREDICATES, RoleModel.SearchableFields.IS_COMPOSITE_ROLE, AbstractRoleEntity::isComposite);
put(USER_PREDICATES, UserModel.SearchableFields.REALM_ID, AbstractUserEntity::getRealmId);
put(USER_PREDICATES, UserModel.SearchableFields.USERNAME, AbstractUserEntity::getUsername);
put(USER_PREDICATES, UserModel.SearchableFields.FIRST_NAME, AbstractUserEntity::getFirstName);
put(USER_PREDICATES, UserModel.SearchableFields.LAST_NAME, AbstractUserEntity::getLastName);
put(USER_PREDICATES, UserModel.SearchableFields.EMAIL, AbstractUserEntity::getEmail);
put(USER_PREDICATES, UserModel.SearchableFields.ENABLED, AbstractUserEntity::isEnabled);
put(USER_PREDICATES, UserModel.SearchableFields.EMAIL_VERIFIED, AbstractUserEntity::isEmailVerified);
put(USER_PREDICATES, UserModel.SearchableFields.FEDERATION_LINK, AbstractUserEntity::getFederationLink);
put(USER_PREDICATES, UserModel.SearchableFields.ATTRIBUTE, MapFieldPredicates::checkUserAttributes);
put(USER_PREDICATES, UserModel.SearchableFields.IDP_AND_USER, MapFieldPredicates::getUserIdpAliasAtIdentityProviderPredicate);
put(USER_PREDICATES, UserModel.SearchableFields.ASSIGNED_ROLE, MapFieldPredicates::checkGrantedUserRole);
put(USER_PREDICATES, UserModel.SearchableFields.ASSIGNED_GROUP, MapFieldPredicates::checkUserGroup);
put(USER_PREDICATES, UserModel.SearchableFields.CONSENT_FOR_CLIENT, MapFieldPredicates::checkUserClientConsent);
put(USER_PREDICATES, UserModel.SearchableFields.CONSENT_WITH_CLIENT_SCOPE, MapFieldPredicates::checkUserConsentsWithClientScope);
put(USER_PREDICATES, UserModel.SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, MapFieldPredicates::getUserConsentClientFederationLink);
put(USER_PREDICATES, UserModel.SearchableFields.SERVICE_ACCOUNT_CLIENT, AbstractUserEntity::getServiceAccountClientLink);
put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.REALM_ID, AbstractRootAuthenticationSessionEntity::getRealmId);
put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.TIMESTAMP, AbstractRootAuthenticationSessionEntity::getTimestamp);
}
static {
PREDICATES.put(ClientModel.class, CLIENT_PREDICATES);
PREDICATES.put(RoleModel.class, ROLE_PREDICATES);
PREDICATES.put(GroupModel.class, GROUP_PREDICATES);
PREDICATES.put(UserModel.class, USER_PREDICATES);
PREDICATES.put(RootAuthenticationSessionModel.class, AUTHENTICATION_SESSION_PREDICATES);
}
private static <K, V extends AbstractEntity<K>, M> void put(
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> map,
SearchableModelField<M> field, Function<V, Object> extractor) {
map.put(field, (mcb, op, values) -> mcb.fieldCompare(op, extractor, values));
}
private static <K, V extends AbstractEntity<K>, M> void put(
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> map,
SearchableModelField<M> field, UpdatePredicatesFunc<K, V, M> function) {
map.put(field, function);
}
private static String ensureEqSingleValue(SearchableModelField<?> field, String parameterName, Operator op, Object[] values) throws CriterionNotSupportedException {
if (op != Operator.EQ) {
throw new CriterionNotSupportedException(field, op);
}
if (values == null || values.length != 1) {
throw new CriterionNotSupportedException(field, op, "Invalid arguments, expected (" + parameterName + "), got: " + Arrays.toString(values));
}
final Object ob = values[0];
if (! (ob instanceof String)) {
throw new CriterionNotSupportedException(field, op, "Invalid arguments, expected (String role_id), got: " + Arrays.toString(values));
}
String s = (String) ob;
return s;
}
private static MapModelCriteriaBuilder<Object, AbstractGroupEntity<Object>, GroupModel> checkGrantedGroupRole(MapModelCriteriaBuilder<Object, AbstractGroupEntity<Object>, GroupModel> mcb, Operator op, Object[] values) {
String roleIdS = ensureEqSingleValue(GroupModel.SearchableFields.ASSIGNED_ROLE, "role_id", op, values);
Function<AbstractGroupEntity<Object>, ?> getter;
getter = ge -> ge.getGrantedRoles().contains(roleIdS);
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> getUserConsentClientFederationLink(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
String providerId = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, "provider_id", op, values);
String providerIdS = new StorageId((String) providerId, "").getId();
Function<AbstractUserEntity<Object>, ?> getter;
getter = ue -> ue.getUserConsents().map(UserConsentEntity::getClientId).anyMatch(v -> v != null && v.startsWith(providerIdS));
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> checkUserAttributes(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
if (values == null || values.length <= 1) {
throw new CriterionNotSupportedException(UserModel.SearchableFields.ATTRIBUTE, op, "Invalid arguments, expected (attribute_name, ...), got: " + Arrays.toString(values));
}
final Object attrName = values[0];
if (! (attrName instanceof String)) {
throw new CriterionNotSupportedException(UserModel.SearchableFields.ATTRIBUTE, op, "Invalid arguments, expected (String attribute_name), got: " + Arrays.toString(values));
}
String attrNameS = (String) attrName;
Function<AbstractUserEntity<Object>, ?> getter;
Object[] realValues = new Object[values.length - 1];
System.arraycopy(values, 1, realValues, 0, values.length - 1);
Predicate<Object> valueComparator = CriteriaOperator.predicateFor(op, realValues);
getter = ue -> {
final List<String> attrs = ue.getAttribute(attrNameS);
return attrs != null && attrs.stream().anyMatch(valueComparator);
};
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> checkGrantedUserRole(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
String roleIdS = ensureEqSingleValue(UserModel.SearchableFields.ASSIGNED_ROLE, "role_id", op, values);
Function<AbstractUserEntity<Object>, ?> getter;
getter = ue -> ue.getRolesMembership().contains(roleIdS);
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> checkUserGroup(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
Function<AbstractUserEntity<Object>, ?> getter;
if (op == Operator.IN && values != null && values.length == 1 && (values[0] instanceof Collection)) {
Collection<?> c = (Collection<?>) values[0];
getter = ue -> ue.getGroupsMembership().stream().anyMatch(c::contains);
} else {
String groupIdS = ensureEqSingleValue(UserModel.SearchableFields.ASSIGNED_GROUP, "group_id", op, values);
getter = ue -> ue.getGroupsMembership().contains(groupIdS);
}
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> checkUserClientConsent(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
String clientIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_id", op, values);
Function<AbstractUserEntity<Object>, ?> getter;
getter = ue -> ue.getUserConsent(clientIdS);
return mcb.fieldCompare(Operator.EXISTS, getter, null);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> checkUserConsentsWithClientScope(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
String clientScopeIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_scope_id", op, values);
Function<AbstractUserEntity<Object>, ?> getter;
getter = ue -> ue.getUserConsents().anyMatch(consent -> consent.getGrantedClientScopesIds().contains(clientScopeIdS));
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
private static MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> getUserIdpAliasAtIdentityProviderPredicate(MapModelCriteriaBuilder<Object, AbstractUserEntity<Object>, UserModel> mcb, Operator op, Object[] values) {
if (op != Operator.EQ) {
throw new CriterionNotSupportedException(UserModel.SearchableFields.IDP_AND_USER, op);
}
if (values == null || values.length == 0 || values.length > 2) {
throw new CriterionNotSupportedException(UserModel.SearchableFields.IDP_AND_USER, op, "Invalid arguments, expected (idp_alias) or (idp_alias, idp_user), got: " + Arrays.toString(values));
}
final Object idpAlias = values[0];
Function<AbstractUserEntity<Object>, ?> getter;
if (values.length == 1) {
getter = ue -> ue.getFederatedIdentities()
.anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider()));
} else if (idpAlias == null) {
final Object idpUserId = values[1];
getter = ue -> ue.getFederatedIdentities()
.anyMatch(aue -> Objects.equals(idpUserId, aue.getUserId()));
} else {
final Object idpUserId = values[1];
getter = ue -> ue.getFederatedIdentities()
.anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider()) && Objects.equals(idpUserId, aue.getUserId()));
}
return mcb.fieldCompare(Boolean.TRUE::equals, getter);
}
protected static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> basePredicates(SearchableModelField<M> idField) {
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates = new HashMap<>();
fieldPredicates.put(idField, (o, op, values) -> o.idCompare(op, values));
return fieldPredicates;
}
@SuppressWarnings("unchecked")
public static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> getPredicates(Class<M> clazz) {
return PREDICATES.get(clazz);
}
}

View file

@ -18,58 +18,32 @@ package org.keycloak.models.map.storage;
import org.keycloak.models.KeycloakTransaction; import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.map.common.AbstractEntity;
import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
public class MapKeycloakTransaction<K, V> implements KeycloakTransaction { public class MapKeycloakTransaction<K, V extends AbstractEntity<K>, M> implements KeycloakTransaction {
private final static Logger log = Logger.getLogger(MapKeycloakTransaction.class); private final static Logger log = Logger.getLogger(MapKeycloakTransaction.class);
private enum MapOperation { private enum MapOperation {
CREATE { CREATE, UPDATE, DELETE,
@Override
protected <K, V> MapTaskWithValue<K, V> taskFor(K key, V value) {
return new MapTaskWithValue<K, V>(value) {
@Override public void execute(MapStorage<K, V> map) { map.create(key, getValue()); }
@Override public MapOperation getOperation() { return CREATE; }
};
}
},
UPDATE {
@Override
protected <K, V> MapTaskWithValue<K, V> taskFor(K key, V value) {
return new MapTaskWithValue<K, V>(value) {
@Override public void execute(MapStorage<K, V> map) { map.update(key, getValue()); }
@Override public MapOperation getOperation() { return UPDATE; }
};
}
},
DELETE {
@Override
protected <K, V> MapTaskWithValue<K, V> taskFor(K key, V value) {
return new MapTaskWithValue<K, V>(null) {
@Override public void execute(MapStorage<K, V> map) { map.delete(key); }
@Override public MapOperation getOperation() { return DELETE; }
};
}
},
;
protected abstract <K, V> MapTaskWithValue<K, V> taskFor(K key, V value);
} }
private boolean active; private boolean active;
private boolean rollback; private boolean rollback;
private final Map<K, MapTaskWithValue<K, V>> tasks = new LinkedHashMap<>(); private final Map<K, MapTaskWithValue> tasks = new LinkedHashMap<>();
private final MapStorage<K, V> map; private final MapStorage<K, V, M> map;
public MapKeycloakTransaction(MapStorage<K, V> map) { public MapKeycloakTransaction(MapStorage<K, V, M> map) {
this.map = map; this.map = map;
} }
@ -86,8 +60,8 @@ public class MapKeycloakTransaction<K, V> implements KeycloakTransaction {
throw new RuntimeException("Rollback only!"); throw new RuntimeException("Rollback only!");
} }
for (MapTaskWithValue<K, V> value : tasks.values()) { for (MapTaskWithValue value : tasks.values()) {
value.execute(map); value.execute();
} }
} }
@ -114,83 +88,175 @@ public class MapKeycloakTransaction<K, V> implements KeycloakTransaction {
/** /**
* Adds a given task if not exists for the given key * Adds a given task if not exists for the given key
*/ */
private void addTask(MapOperation op, K key, V value) { protected void addTask(K key, MapTaskWithValue task) {
log.tracef("Adding operation %s for %s @ %08x", op, key, System.identityHashCode(value)); log.tracef("Adding operation %s for %s @ %08x", task.getOperation(), key, System.identityHashCode(task.getValue()));
K taskKey = key; K taskKey = key;
tasks.merge(taskKey, op.taskFor(key, value), MapTaskCompose::new); tasks.merge(taskKey, task, MapTaskCompose::new);
} }
// This is for possibility to lookup for session by id, which was created in this transaction // This is for possibility to lookup for session by id, which was created in this transaction
public V read(K key) {
return read(key, map::read);
}
public V read(K key, Function<K, V> defaultValueFunc) { public V read(K key, Function<K, V> defaultValueFunc) {
MapTaskWithValue<K, V> current = tasks.get(key); MapTaskWithValue current = tasks.get(key);
if (current != null) { // If the key exists, then it has entered the "tasks" after bulk delete that could have
// removed it, so looking through bulk deletes is irrelevant
if (tasks.containsKey(key)) {
return current.getValue(); return current.getValue();
} }
return defaultValueFunc.apply(key); // If the key does not exist, then it would be read fresh from the storage, but then it
// could have been removed by some bulk delete in the existing tasks. Check it.
final V value = defaultValueFunc.apply(key);
for (MapTaskWithValue val : tasks.values()) {
if (val instanceof MapKeycloakTransaction.BulkDeleteOperation) {
final BulkDeleteOperation delOp = (BulkDeleteOperation) val;
if (! delOp.getFilterForNonDeletedObjects().test(value)) {
return null;
}
}
} }
public V getUpdated(Map.Entry<K, V> keyDefaultValue) { return value;
MapTaskWithValue<K, V> current = tasks.get(keyDefaultValue.getKey());
if (current != null) {
return current.getValue();
} }
return keyDefaultValue.getValue(); /**
* Returns the stream of records that match given criteria and includes changes made in this transaction, i.e.
* the result contains updates and excludes records that have been deleted in this transaction.
*
* Note that returned stream might not reflect on the bulk delete. This is known limitation that can be fixed if necessary.
*
* @param mcb
* @return
*/
public Stream<V> getUpdatedNotRemoved(ModelCriteriaBuilder<M> mcb) {
Predicate<? super V> filterOutAllBulkDeletedObjects = tasks.values().stream()
.filter(BulkDeleteOperation.class::isInstance)
.map(BulkDeleteOperation.class::cast)
.map(BulkDeleteOperation::getFilterForNonDeletedObjects)
.reduce(Predicate::and)
.orElse(v -> true);
Stream<V> updatedAndNotRemovedObjectsStream = this.map.read(mcb)
.filter(filterOutAllBulkDeletedObjects)
.map(this::getUpdated) // If the object has been removed, tx.get will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull);
// In case of created values stored in MapKeycloakTransaction, we need filter those according to the filter
MapModelCriteriaBuilder<K, V, M> mapMcb = mcb.unwrap(MapModelCriteriaBuilder.class);
Stream<V> res = mapMcb == null
? updatedAndNotRemovedObjectsStream
: Stream.concat(
createdValuesStream(mapMcb.getKeyFilter(), mapMcb.getEntityFilter()),
updatedAndNotRemovedObjectsStream
);
return res;
}
/**
* Returns the stream of records that match given criteria and includes changes made in this transaction, i.e.
* the result contains updates and excludes records that have been deleted in this transaction.
*
* @param mcb
* @return
*/
public long getCount(ModelCriteriaBuilder<M> mcb) {
return getUpdatedNotRemoved(mcb).count();
}
/**
* Returns a updated version of the {@code orig} object as updated in this transaction.
* If the underlying store handles transactions on its own, this can return {@code orig} directly.
* @param orig
* @return The {@code orig} object as visible from this transaction, or {@code null} if the object has been removed.
*/
public V getUpdated(V orig) {
MapTaskWithValue current = orig == null ? null : tasks.get(orig.getId());
return current == null ? orig : current.getValue();
} }
public void update(K key, V value) { public void update(K key, V value) {
addTask(MapOperation.UPDATE, key, value); addTask(key, new UpdateOperation(key, value));
} }
public void create(K key, V value) { public void create(K key, V value) {
addTask(MapOperation.CREATE, key, value); addTask(key, new CreateOperation(key, value));
} }
public void updateIfChanged(K key, V value, Predicate<V> shouldPut) { public void updateIfChanged(K key, V value, Predicate<V> shouldPut) {
log.tracef("Adding operation UPDATE_IF_CHANGED for %s @ %08x", key, System.identityHashCode(value)); log.tracef("Adding operation UPDATE_IF_CHANGED for %s @ %08x", key, System.identityHashCode(value));
K taskKey = key; K taskKey = key;
MapTaskWithValue<K, V> op = new MapTaskWithValue<K, V>(value) { MapTaskWithValue op = new MapTaskWithValue(value) {
@Override @Override
public void execute(MapStorage<K, V> map) { public void execute() {
if (shouldPut.test(getValue())) { if (shouldPut.test(getValue())) {
map.update(key, getValue()); map.update(key, getValue());
} }
} }
@Override public MapOperation getOperation() { return MapOperation.UPDATE; } @Override public MapOperation getOperation() { return MapOperation.UPDATE; }
}; };
tasks.merge(taskKey, op, MapKeycloakTransaction::merge); tasks.merge(taskKey, op, this::merge);
} }
public void delete(K key) { public void delete(K key) {
addTask(MapOperation.DELETE, key, null); addTask(key, new DeleteOperation(key));
} }
public Stream<V> valuesStream() { /**
return this.tasks.values().stream() * Bulk removal of items.
.map(MapTaskWithValue<K,V>::getValue) *
.filter(Objects::nonNull); * @param artificialKey Key to record the transaction with, must be a key that does not exist in this transaction to
* prevent collision with other operations in this transaction
* @param mcb
*/
public long delete(K artificialKey, ModelCriteriaBuilder<M> mcb) {
log.tracef("Adding operation DELETE_BULK");
// Remove all tasks that create / update / delete objects deleted by the bulk removal.
final BulkDeleteOperation bdo = new BulkDeleteOperation(mcb);
Predicate<V> filterForNonDeletedObjects = bdo.getFilterForNonDeletedObjects();
long res = 0;
for (Iterator<Entry<K, MapTaskWithValue>> it = tasks.entrySet().iterator(); it.hasNext();) {
Entry<K, MapTaskWithValue> me = it.next();
if (! filterForNonDeletedObjects.test(me.getValue().getValue())) {
log.tracef(" [DELETE_BULK] removing %s", me.getKey());
it.remove();
res++;
}
} }
public Stream<V> createdValuesStream() { tasks.put(artificialKey, bdo);
return this.tasks.values().stream()
return res + bdo.getCount();
}
private Stream<V> createdValuesStream(Predicate<? super K> keyFilter, Predicate<? super V> entityFilter) {
return this.tasks.entrySet().stream()
.filter(me -> keyFilter.test(me.getKey()))
.map(Map.Entry::getValue)
.filter(v -> v.containsCreate() && ! v.isReplace()) .filter(v -> v.containsCreate() && ! v.isReplace())
.map(MapTaskWithValue<K,V>::getValue) .map(MapTaskWithValue::getValue)
.filter(Objects::nonNull); .filter(Objects::nonNull)
.filter(entityFilter)
// make a snapshot
.collect(Collectors.toList()).stream();
} }
private static <K, V> MapTaskWithValue<K, V> merge(MapTaskWithValue<K, V> oldValue, MapTaskWithValue<K, V> newValue) { private MapTaskWithValue merge(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
switch (newValue.getOperation()) { switch (newValue.getOperation()) {
case DELETE: case DELETE:
return oldValue.containsCreate() ? null : newValue; return oldValue.containsCreate() ? null : newValue;
default: default:
return new MapTaskCompose<>(oldValue, newValue); return new MapTaskCompose(oldValue, newValue);
} }
} }
private static abstract class MapTaskWithValue<K, V> { protected abstract class MapTaskWithValue {
protected final V value; protected final V value;
public MapTaskWithValue(V value) { public MapTaskWithValue(V value) {
@ -214,24 +280,24 @@ public class MapKeycloakTransaction<K, V> implements KeycloakTransaction {
} }
public abstract MapOperation getOperation(); public abstract MapOperation getOperation();
public abstract void execute(MapStorage<K,V> map); public abstract void execute();
} }
private static class MapTaskCompose<K, V> extends MapTaskWithValue<K, V> { private class MapTaskCompose extends MapTaskWithValue {
private final MapTaskWithValue<K, V> oldValue; private final MapTaskWithValue oldValue;
private final MapTaskWithValue<K, V> newValue; private final MapTaskWithValue newValue;
public MapTaskCompose(MapTaskWithValue<K, V> oldValue, MapTaskWithValue<K, V> newValue) { public MapTaskCompose(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
super(null); super(null);
this.oldValue = oldValue; this.oldValue = oldValue;
this.newValue = newValue; this.newValue = newValue;
} }
@Override @Override
public void execute(MapStorage<K, V> map) { public void execute() {
oldValue.execute(map); oldValue.execute();
newValue.execute(map); newValue.execute();
} }
@Override @Override
@ -257,7 +323,81 @@ public class MapKeycloakTransaction<K, V> implements KeycloakTransaction {
@Override @Override
public boolean isReplace() { public boolean isReplace() {
return (newValue.getOperation() == MapOperation.CREATE && oldValue.containsRemove()) || return (newValue.getOperation() == MapOperation.CREATE && oldValue.containsRemove()) ||
(oldValue instanceof MapTaskCompose && ((MapTaskCompose) oldValue).isReplace()); (oldValue instanceof MapKeycloakTransaction.MapTaskCompose && ((MapTaskCompose) oldValue).isReplace());
}
}
private class CreateOperation extends MapTaskWithValue {
private final K key;
public CreateOperation(K key, V value) {
super(value);
this.key = key;
}
@Override public void execute() { map.create(key, getValue()); }
@Override public MapOperation getOperation() { return MapOperation.CREATE; }
}
private class UpdateOperation extends MapTaskWithValue {
private final K key;
public UpdateOperation(K key, V value) {
super(value);
this.key = key;
}
@Override public void execute() { map.update(key, getValue()); }
@Override public MapOperation getOperation() { return MapOperation.UPDATE; }
}
private class DeleteOperation extends MapTaskWithValue {
private final K key;
public DeleteOperation(K key) {
super(null);
this.key = key;
}
@Override public void execute() { map.delete(key); }
@Override public MapOperation getOperation() { return MapOperation.DELETE; }
}
private class BulkDeleteOperation extends MapTaskWithValue {
private final ModelCriteriaBuilder<M> mcb;
public BulkDeleteOperation(ModelCriteriaBuilder<M> mcb) {
super(null);
this.mcb = mcb;
}
@Override
@SuppressWarnings("unchecked")
public void execute() {
map.delete(mcb);
}
public Predicate<V> getFilterForNonDeletedObjects() {
if (! (mcb instanceof MapModelCriteriaBuilder)) {
return t -> true;
}
@SuppressWarnings("unchecked")
final MapModelCriteriaBuilder<K, V, M> mmcb = (MapModelCriteriaBuilder<K, V, M>) mcb;
Predicate<? super V> entityFilter = mmcb.getEntityFilter();
Predicate<? super K> keyFilter = ((MapModelCriteriaBuilder) mcb).getKeyFilter();
return v -> v != null && ! (keyFilter.test(v.getId()) && entityFilter.test(v));
}
@Override
public MapOperation getOperation() {
return MapOperation.DELETE;
}
private long getCount() {
return map.getCount(mcb);
} }
} }
} }

View file

@ -0,0 +1,142 @@
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage;
import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.storage.SearchableModelField;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;
/**
*
* @author hmlnarik
*/
public class MapModelCriteriaBuilder<K, V extends AbstractEntity<K>, M> implements ModelCriteriaBuilder<M> {
@FunctionalInterface
public static interface UpdatePredicatesFunc<K, V extends AbstractEntity<K>, M> {
MapModelCriteriaBuilder<K, V, M> apply(MapModelCriteriaBuilder<K, V, M> builder, Operator op, Object[] params);
}
private static final Predicate<Object> ALWAYS_TRUE = (e) -> true;
private static final Predicate<Object> ALWAYS_FALSE = (e) -> false;
private final Predicate<? super K> keyFilter;
private final Predicate<? super V> entityFilter;
private final Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates;
public MapModelCriteriaBuilder(Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates) {
this(fieldPredicates, ALWAYS_TRUE, ALWAYS_TRUE);
}
private MapModelCriteriaBuilder(Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates, Predicate<? super K> indexReadFilter, Predicate<? super V> sequentialReadFilter) {
this.fieldPredicates = fieldPredicates;
this.keyFilter = indexReadFilter;
this.entityFilter = sequentialReadFilter;
}
@Override
public MapModelCriteriaBuilder<K, V, M> compare(SearchableModelField<M> modelField, Operator op, Object... values) {
UpdatePredicatesFunc<K, V, M> method = fieldPredicates.get(modelField);
if (method == null) {
throw new IllegalArgumentException("Filter not implemented for field " + modelField);
}
return method.apply(this, op, values);
}
@SafeVarargs
@SuppressWarnings("unchecked")
@Override
public final MapModelCriteriaBuilder<K, V, M> and(ModelCriteriaBuilder<M>... builders) {
Predicate<? super K> resIndexFilter = Stream.of(builders).map(MapModelCriteriaBuilder.class::cast).map(MapModelCriteriaBuilder::getKeyFilter).reduce(keyFilter, Predicate::and);
Predicate<V> resEntityFilter = Stream.of(builders).map(MapModelCriteriaBuilder.class::cast).map(MapModelCriteriaBuilder::getEntityFilter).reduce(entityFilter, Predicate::and);
return new MapModelCriteriaBuilder<>(fieldPredicates, resIndexFilter, resEntityFilter);
}
@SafeVarargs
@SuppressWarnings("unchecked")
@Override
public final MapModelCriteriaBuilder<K, V, M> or(ModelCriteriaBuilder<M>... builders) {
Predicate<? super K> resIndexFilter = Stream.of(builders).map(MapModelCriteriaBuilder.class::cast).map(MapModelCriteriaBuilder::getKeyFilter).reduce(ALWAYS_FALSE, Predicate::or);
Predicate<V> resEntityFilter = Stream.of(builders).map(MapModelCriteriaBuilder.class::cast).map(MapModelCriteriaBuilder::getEntityFilter).reduce(ALWAYS_FALSE, Predicate::or);
return new MapModelCriteriaBuilder<>(
fieldPredicates,
v -> keyFilter.test(v) && resIndexFilter.test(v),
v -> entityFilter.test(v) && resEntityFilter.test(v)
);
}
@SuppressWarnings("unchecked")
@Override
public MapModelCriteriaBuilder<K, V, M> not(ModelCriteriaBuilder<M> builder) {
MapModelCriteriaBuilder<K, V, M> b = builder.unwrap(MapModelCriteriaBuilder.class);
if (b == null) {
throw new ClassCastException("Incompatible class: " + builder.getClass());
}
Predicate<? super K> resIndexFilter = b.getKeyFilter() == ALWAYS_TRUE ? ALWAYS_TRUE : b.getKeyFilter().negate();
Predicate<? super V> resEntityFilter = b.getEntityFilter() == ALWAYS_TRUE ? ALWAYS_TRUE : b.getEntityFilter().negate();
return new MapModelCriteriaBuilder<>(
fieldPredicates,
v -> keyFilter.test(v) && ! resIndexFilter.test(v),
v -> entityFilter.test(v) && ! resEntityFilter.test(v)
);
}
public Predicate<? super K> getKeyFilter() {
return keyFilter;
}
public Predicate<? super V> getEntityFilter() {
return entityFilter;
}
protected MapModelCriteriaBuilder<K, V, M> idCompare(Operator op, Object[] values) {
switch (op) {
case LT:
case LE:
case GT:
case GE:
case EQ:
case NE:
case EXISTS:
case NOT_EXISTS:
case IN:
return new MapModelCriteriaBuilder<>(fieldPredicates, this.keyFilter.and(CriteriaOperator.predicateFor(op, values)), this.entityFilter);
default:
throw new AssertionError("Invalid operator: " + op);
}
}
protected MapModelCriteriaBuilder<K, V, M> fieldCompare(Operator op, Function<V, ?> getter, Object[] values) {
Predicate<Object> valueComparator = CriteriaOperator.predicateFor(op, values);
return fieldCompare(valueComparator, getter);
}
protected MapModelCriteriaBuilder<K, V, M> fieldCompare(Predicate<Object> valueComparator, Function<V, ?> getter) {
final Predicate<? super V> resEntityFilter;
if (entityFilter == ALWAYS_FALSE) {
resEntityFilter = ALWAYS_FALSE;
} else {
final Predicate<V> p = v -> valueComparator.test(getter.apply(v));
resEntityFilter = p.and(entityFilter);
}
return new MapModelCriteriaBuilder<>(fieldPredicates, this.keyFilter, resEntityFilter);
}
}

View file

@ -16,29 +16,41 @@
*/ */
package org.keycloak.models.map.storage; package org.keycloak.models.map.storage;
import java.util.Map; import org.keycloak.models.map.common.AbstractEntity;
import java.util.Set;
import java.util.stream.Stream; import java.util.stream.Stream;
/** /**
* Implementation of this interface interacts with a persistence storage storing various entities, e.g. users, realms.
* It contains basic object CRUD operations as well as bulk {@link #read(org.keycloak.models.map.storage.ModelCriteriaBuilder)}
* and bulk {@link #delete(org.keycloak.models.map.storage.ModelCriteriaBuilder)} operations,
* and operation for determining the number of the objects satisfying given criteria
* ({@link #getCount(org.keycloak.models.map.storage.ModelCriteriaBuilder)}).
* *
* @author hmlnarik * @author hmlnarik
* @param <K> Type of the primary key. Various storages can
* @param <V> Type of the stored values that contains all the data stripped of session state. In other words, in the entities
* there are only IDs and mostly primitive types / {@code String}, never references to {@code *Model} instances.
* See the {@code Abstract*Entity} classes in this module.
* @param <M> Type of the {@code *Model} corresponding to the stored value, e.g. {@code UserModel}. This is used for
* filtering via model fields in {@link ModelCriteriaBuilder} which is necessary to abstract from physical
* layout and thus to support no-downtime upgrade.
*/ */
public interface MapStorage<K, V> { public interface MapStorage<K, V extends AbstractEntity<K>, M> {
/** /**
* Creates an object in the store identified by given {@code key}. * Creates an object in the store identified by given {@code key}.
* @param key Key of the object as seen in the logical level * @param key Key of the object as seen in the logical level
* @param value Entity * @param value Entity
* @return Reference to the entity created in the store * @return Reference to the entity created in the store
* @throws NullPointerException if object or its {@code id} is {@code null} * @throws NullPointerException if object or its {@code key} is {@code null}
*/ */
V create(K key, V value); V create(K key, V value);
/** /**
* Returns object with the given {@code key} from the storage or {@code null} if object does not exist. * Returns object with the given {@code key} from the storage or {@code null} if object does not exist.
* @param key Must not be {@code null}. * @param key Key of the object. Must not be {@code null}.
* @return See description * @return See description
* @throws NullPointerException if the {@code key} is {@code null}
*/ */
V read(K key); V read(K key);
@ -46,14 +58,30 @@ public interface MapStorage<K, V> {
* Returns stream of objects satisfying given {@code criteria} from the storage. * Returns stream of objects satisfying given {@code criteria} from the storage.
* The criteria are specified in the given criteria builder based on model properties. * The criteria are specified in the given criteria builder based on model properties.
* *
* @param criteria * @param criteria Criteria filtering out the object, originally obtained
* from {@link #createCriteriaBuilder()} method of this object.
* If {@code null}, it returns an empty stream.
* @return Stream of objects. Never returns {@code null}. * @return Stream of objects. Never returns {@code null}.
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object.
*/ */
Stream<V> read(ModelCriteriaBuilder criteria); Stream<V> read(ModelCriteriaBuilder<M> criteria);
/**
* Returns the number of objects satisfying given {@code criteria} from the storage.
* The criteria are specified in the given criteria builder based on model properties.
*
* @param criteria
* @return Number of objects. Never returns {@code null}.
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object.
*/
long getCount(ModelCriteriaBuilder<M> criteria);
/** /**
* Updates the object with the given {@code id} in the storage if it already exists. * Updates the object with the given {@code id} in the storage if it already exists.
* @param id * @param key Primary key of the object to update
* @param value Updated value
* @throws NullPointerException if object or its {@code id} is {@code null} * @throws NullPointerException if object or its {@code id} is {@code null}
*/ */
V update(K key, V value); V update(K key, V value);
@ -61,8 +89,19 @@ public interface MapStorage<K, V> {
/** /**
* Deletes object with the given {@code key} from the storage, if exists, no-op otherwise. * Deletes object with the given {@code key} from the storage, if exists, no-op otherwise.
* @param key * @param key
* @return Returns {@code true} if the object has been deleted or result cannot be determined, {@code false} otherwise.
*/ */
V delete(K key); boolean delete(K key);
/**
* Deletes objects that match the given criteria.
* @param criteria
* @return Number of removed objects (might return {@code -1} if not supported)
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object.
*/
long delete(ModelCriteriaBuilder<M> criteria);
/** /**
* Returns criteria builder for the storage engine. * Returns criteria builder for the storage engine.
@ -77,9 +116,17 @@ public interface MapStorage<K, V> {
* *
* @return See description * @return See description
*/ */
ModelCriteriaBuilder createCriteriaBuilder(); ModelCriteriaBuilder<M> createCriteriaBuilder();
@Deprecated
Set<Map.Entry<K,V>> entrySet(); /**
* Creates a {@code MapKeycloakTransaction} object that tracks a new transaction related to this storage.
* In case of JPA or similar, the transaction object might be supplied by the container (via JTA) or
* shared same across storages accessing the same database within the same session; in other cases
* (e.g. plain map) a separate transaction handler might be created per each storage.
*
* @return See description.
*/
public MapKeycloakTransaction<K, V, M> createTransaction();
} }

View file

@ -39,5 +39,5 @@ public interface MapStorageProvider extends Provider, ProviderFactory<MapStorage
* @param flags * @param flags
* @return * @return
*/ */
<K, V extends AbstractEntity<K>> MapStorage<K, V> getStorage(String name, Class<K> keyType, Class<V> valueType, Flag... flags); <K, V extends AbstractEntity<K>, M> MapStorage<K, V, M> getStorage(String name, Class<K> keyType, Class<V> valueType, Class<M> modelType, Flag... flags);
} }

View file

@ -51,7 +51,7 @@ import org.keycloak.storage.SearchableModelField;
* *
* @author hmlnarik * @author hmlnarik
*/ */
public interface ModelCriteriaBuilder { public interface ModelCriteriaBuilder<M> {
/** /**
* The operators are very basic ones for this use case. In the real scenario, * The operators are very basic ones for this use case. In the real scenario,
@ -94,8 +94,15 @@ public interface ModelCriteriaBuilder {
* </ul> * </ul>
*/ */
ILIKE, ILIKE,
/** Operator for belonging into a set of values */ /**
IN * Operator for belonging into a collection of values. Operand in {@code value}
* can be an array (via an implicit conversion of the vararg), a {@link Collection} or a {@link Stream}.
*/
IN,
/** Is not null */
EXISTS,
/** Is null */
NOT_EXISTS,
} }
/** /**
@ -108,8 +115,9 @@ public interface ModelCriteriaBuilder {
* @param op Operator * @param op Operator
* @param value Additional operands of the operator. * @param value Additional operands of the operator.
* @return * @return
* @throws CriterionNotSupported If the operator is not supported for the given field.
*/ */
ModelCriteriaBuilder compare(SearchableModelField modelField, Operator op, Object... value); ModelCriteriaBuilder<M> compare(SearchableModelField<M> modelField, Operator op, Object... value);
/** /**
* Creates and returns a new instance of {@code ModelCriteriaBuilder} that * Creates and returns a new instance of {@code ModelCriteriaBuilder} that
@ -126,8 +134,9 @@ public interface ModelCriteriaBuilder {
* ); * );
* </pre> * </pre>
* *
* @throws CriterionNotSupported If the operator is not supported for the given field.
*/ */
ModelCriteriaBuilder and(ModelCriteriaBuilder... builders); ModelCriteriaBuilder<M> and(ModelCriteriaBuilder<M>... builders);
/** /**
* Creates and returns a new instance of {@code ModelCriteriaBuilder} that * Creates and returns a new instance of {@code ModelCriteriaBuilder} that
@ -143,8 +152,10 @@ public interface ModelCriteriaBuilder {
* cb.compare(FIELD1, EQ, 3).compare(FIELD2, EQ, 4) * cb.compare(FIELD1, EQ, 3).compare(FIELD2, EQ, 4)
* ); * );
* </pre> * </pre>
*
* @throws CriterionNotSupported If the operator is not supported for the given field.
*/ */
ModelCriteriaBuilder or(ModelCriteriaBuilder... builders); ModelCriteriaBuilder<M> or(ModelCriteriaBuilder<M>... builders);
/** /**
* Creates and returns a new instance of {@code ModelCriteriaBuilder} that * Creates and returns a new instance of {@code ModelCriteriaBuilder} that
@ -155,21 +166,21 @@ public interface ModelCriteriaBuilder {
* *
* @param builder * @param builder
* @return * @return
* @throws CriterionNotSupported If the operator is not supported for the given field.
*/ */
ModelCriteriaBuilder not(ModelCriteriaBuilder builder); ModelCriteriaBuilder<M> not(ModelCriteriaBuilder<M> builder);
/** /**
* Returns this object cast to the given class. * Returns this object cast to the given class, or {@code null} if the class cannot be cast to that {@code clazz}.
* @param <T> * @param <T>
* @param clazz * @param clazz
* @return * @return
* @throws ClassCastException When this instance cannot be converted to the given {@code clazz}.
*/ */
default <T extends ModelCriteriaBuilder> T unwrap(Class<T> clazz) { default <T extends ModelCriteriaBuilder> T unwrap(Class<T> clazz) {
if (clazz.isInstance(this)) { if (clazz.isInstance(this)) {
return clazz.cast(this); return clazz.cast(this);
} else { } else {
throw new ClassCastException("Incompatible class: " + clazz); return null;
} }
} }

View file

@ -0,0 +1,132 @@
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage.chm;
import org.keycloak.models.map.storage.MapModelCriteriaBuilder;
import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.storage.MapFieldPredicates;
import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.storage.SearchableModelField;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Stream;
import org.keycloak.models.map.storage.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Predicate;
/**
*
* @author hmlnarik
*/
public class ConcurrentHashMapStorage<K, V extends AbstractEntity<K>, M> implements MapStorage<K, V, M> {
private final ConcurrentMap<K, V> store = new ConcurrentHashMap<>();
private final Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates;
@SuppressWarnings("unchecked")
public ConcurrentHashMapStorage(Class<M> modelClass) {
this.fieldPredicates = MapFieldPredicates.getPredicates(modelClass);
}
@Override
public V create(K key, V value) {
return store.putIfAbsent(key, value);
}
@Override
public V read(K key) {
Objects.requireNonNull(key, "Key must be non-null");
return store.get(key);
}
@Override
public V update(K key, V value) {
return store.replace(key, value);
}
@Override
public boolean delete(K key) {
return store.remove(key) != null;
}
@Override
public long delete(ModelCriteriaBuilder<M> criteria) {
long res;
if (criteria == null) {
res = store.size();
store.clear();
return res;
}
MapModelCriteriaBuilder<K, V, M> b = criteria.unwrap(MapModelCriteriaBuilder.class);
if (b == null) {
throw new IllegalStateException("Incompatible class: " + criteria.getClass());
}
Predicate<? super K> keyFilter = b.getKeyFilter();
Predicate<? super V> entityFilter = b.getEntityFilter();
res = 0;
for (Iterator<Entry<K, V>> iterator = store.entrySet().iterator(); iterator.hasNext();) {
Entry<K, V> next = iterator.next();
if (keyFilter.test(next.getKey()) && entityFilter.test(next.getValue())) {
res++;
iterator.remove();
}
}
return res;
}
@Override
public ModelCriteriaBuilder<M> createCriteriaBuilder() {
return new MapModelCriteriaBuilder<>(fieldPredicates);
}
@Override
public MapKeycloakTransaction<K, V, M> createTransaction() {
return new MapKeycloakTransaction<>(this);
}
@Override
public Stream<V> read(ModelCriteriaBuilder<M> criteria) {
if (criteria == null) {
return Stream.empty();
}
Stream<Entry<K, V>> stream = store.entrySet().stream();
MapModelCriteriaBuilder<K, V, M> b = criteria.unwrap(MapModelCriteriaBuilder.class);
if (b == null) {
throw new IllegalStateException("Incompatible class: " + criteria.getClass());
}
Predicate<? super K> keyFilter = b.getKeyFilter();
Predicate<? super V> entityFilter = b.getEntityFilter();
stream = stream.filter(me -> keyFilter.test(me.getKey()) && entityFilter.test(me.getValue()));
return stream.map(Map.Entry::getValue);
}
@Override
public long getCount(ModelCriteriaBuilder<M> criteria) {
return read(criteria).count();
}
}

View file

@ -14,28 +14,23 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package org.keycloak.models.map.storage; package org.keycloak.models.map.storage.chm;
import org.keycloak.Config.Scope; import org.keycloak.Config.Scope;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory; import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.map.common.AbstractEntity; import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.common.Serialization; import org.keycloak.models.map.common.Serialization;
import org.keycloak.storage.SearchableModelField;
import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.JavaType;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.models.map.storage.MapStorageProvider;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
/** /**
* *
@ -43,123 +38,11 @@ import org.jboss.logging.Logger;
*/ */
public class ConcurrentHashMapStorageProvider implements MapStorageProvider { public class ConcurrentHashMapStorageProvider implements MapStorageProvider {
public static class ConcurrentHashMapStorage<K, V> implements MapStorage<K, V> { public static final String PROVIDER_ID = "concurrenthashmap";
private final ConcurrentMap<K, V> store = new ConcurrentHashMap<>();
@Override
public V create(K key, V value) {
return store.putIfAbsent(key, value);
}
@Override
public V read(K key) {
return store.get(key);
}
@Override
public V update(K key, V value) {
return store.replace(key, value);
}
@Override
public V delete(K key) {
return store.remove(key);
}
@Override
public ModelCriteriaBuilder createCriteriaBuilder() {
return new MapModelCriteriaBuilder(null);
}
@Override
public Set<Entry<K, V>> entrySet() {
return store.entrySet();
}
@Override
public Stream<V> read(ModelCriteriaBuilder criteria) {
throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates.
}
}
private static class MapModelCriteriaBuilder<M> implements ModelCriteriaBuilder {
@FunctionalInterface
public interface TriConsumer<A extends MapModelCriteriaBuilder<?>,B,C> { A apply(A a, B b, C c); }
private static final Predicate<Object> ALWAYS_TRUE = e -> true;
private static final Predicate<Object> ALWAYS_FALSE = e -> false;
private final Predicate<? super String> indexFilter;
private final Predicate<? super M> modelFilter;
private final Map<String, TriConsumer<MapModelCriteriaBuilder<M>, Operator, Object>> fieldPredicates;
public MapModelCriteriaBuilder(Map<String, TriConsumer<MapModelCriteriaBuilder<M>, Operator, Object>> fieldPredicates) {
this(fieldPredicates, ALWAYS_TRUE, ALWAYS_TRUE);
}
private MapModelCriteriaBuilder(Map<String, TriConsumer<MapModelCriteriaBuilder<M>, Operator, Object>> fieldPredicates,
Predicate<? super String> indexReadFilter, Predicate<? super M> sequentialReadFilter) {
this.fieldPredicates = fieldPredicates;
this.indexFilter = indexReadFilter;
this.modelFilter = sequentialReadFilter;
}
@Override
public ModelCriteriaBuilder compare(SearchableModelField modelField, Operator op, Object... value) {
throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates.
}
@Override
public MapModelCriteriaBuilder<M> and(ModelCriteriaBuilder... builders) {
Predicate<? super String> resIndexFilter = Stream.of(builders)
.map(MapModelCriteriaBuilder.class::cast)
.map(MapModelCriteriaBuilder::getIndexFilter)
.reduce(ALWAYS_TRUE, (p1, p2) -> p1.and(p2));
Predicate<? super M> resModelFilter = Stream.of(builders)
.map(MapModelCriteriaBuilder.class::cast)
.map(MapModelCriteriaBuilder::getModelFilter)
.reduce(ALWAYS_TRUE, (p1, p2) -> p1.and(p2));
return new MapModelCriteriaBuilder<>(fieldPredicates, resIndexFilter, resModelFilter);
}
@Override
public MapModelCriteriaBuilder<M> or(ModelCriteriaBuilder... builders) {
Predicate<? super String> resIndexFilter = Stream.of(builders)
.map(MapModelCriteriaBuilder.class::cast)
.map(MapModelCriteriaBuilder::getIndexFilter)
.reduce(ALWAYS_FALSE, (p1, p2) -> p1.or(p2));
Predicate<? super M> resModelFilter = Stream.of(builders)
.map(MapModelCriteriaBuilder.class::cast)
.map(MapModelCriteriaBuilder::getModelFilter)
.reduce(ALWAYS_FALSE, (p1, p2) -> p1.or(p2));
return new MapModelCriteriaBuilder<>(fieldPredicates, resIndexFilter, resModelFilter);
}
@Override
public MapModelCriteriaBuilder<M> not(ModelCriteriaBuilder builder) {
MapModelCriteriaBuilder b = builder.unwrap(MapModelCriteriaBuilder.class);
Predicate<? super String> resIndexFilter = b.getIndexFilter() == ALWAYS_TRUE ? ALWAYS_TRUE : b.getIndexFilter().negate();
Predicate<? super M> resModelFilter = b.getModelFilter() == ALWAYS_TRUE ? ALWAYS_TRUE : b.getModelFilter().negate();
return new MapModelCriteriaBuilder<>(fieldPredicates, resIndexFilter, resModelFilter);
}
public Predicate<? super String> getIndexFilter() {
return indexFilter;
}
public Predicate<? super M> getModelFilter() {
return modelFilter;
}
}
private static final String PROVIDER_ID = "concurrenthashmap";
private static final Logger LOG = Logger.getLogger(ConcurrentHashMapStorageProvider.class); private static final Logger LOG = Logger.getLogger(ConcurrentHashMapStorageProvider.class);
private final ConcurrentHashMap<String, ConcurrentHashMapStorage<?,?>> storages = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, ConcurrentHashMapStorage<?,?,?>> storages = new ConcurrentHashMap<>();
private File storageDirectory; private File storageDirectory;
@ -189,13 +72,15 @@ public class ConcurrentHashMapStorageProvider implements MapStorageProvider {
storages.forEach(this::storeMap); storages.forEach(this::storeMap);
} }
private void storeMap(String fileName, ConcurrentHashMapStorage<?, ?> store) { private void storeMap(String fileName, ConcurrentHashMapStorage<?, ?, ?> store) {
if (fileName != null) { if (fileName != null) {
File f = getFile(fileName); File f = getFile(fileName);
try { try {
if (storageDirectory != null && storageDirectory.exists()) { if (storageDirectory != null && storageDirectory.exists()) {
LOG.debugf("Storing contents to %s", f.getCanonicalPath()); LOG.debugf("Storing contents to %s", f.getCanonicalPath());
Serialization.MAPPER.writeValue(f, store.entrySet().stream().map(Map.Entry::getValue)); @SuppressWarnings("unchecked")
final ModelCriteriaBuilder readAllCriteria = store.createCriteriaBuilder();
Serialization.MAPPER.writeValue(f, store.read(readAllCriteria));
} else { } else {
LOG.debugf("Not storing contents of %s because directory %s does not exist", fileName, this.storageDirectory); LOG.debugf("Not storing contents of %s because directory %s does not exist", fileName, this.storageDirectory);
} }
@ -205,8 +90,9 @@ public class ConcurrentHashMapStorageProvider implements MapStorageProvider {
} }
} }
private <K, V extends AbstractEntity<K>> ConcurrentHashMapStorage<K, V> loadMap(String fileName, Class<V> valueType, EnumSet<Flag> flags) { private <K, V extends AbstractEntity<K>, M> ConcurrentHashMapStorage<K, V, M> loadMap(String fileName,
ConcurrentHashMapStorage<K, V> store = new ConcurrentHashMapStorage<>(); Class<V> valueType, Class<M> modelType, EnumSet<Flag> flags) {
ConcurrentHashMapStorage<K, V, M> store = new ConcurrentHashMapStorage<>(modelType);
if (! flags.contains(Flag.INITIALIZE_EMPTY)) { if (! flags.contains(Flag.INITIALIZE_EMPTY)) {
final File f = getFile(fileName); final File f = getFile(fileName);
@ -233,9 +119,10 @@ public class ConcurrentHashMapStorageProvider implements MapStorageProvider {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <K, V extends AbstractEntity<K>> ConcurrentHashMapStorage<K, V> getStorage(String name, Class<K> keyType, Class<V> valueType, Flag... flags) { public <K, V extends AbstractEntity<K>, M> ConcurrentHashMapStorage<K, V, M> getStorage(String name,
Class<K> keyType, Class<V> valueType, Class<M> modelType, Flag... flags) {
EnumSet<Flag> f = flags == null || flags.length == 0 ? EnumSet.noneOf(Flag.class) : EnumSet.of(flags[0], flags); EnumSet<Flag> f = flags == null || flags.length == 0 ? EnumSet.noneOf(Flag.class) : EnumSet.of(flags[0], flags);
return (ConcurrentHashMapStorage<K, V>) storages.computeIfAbsent(name, n -> loadMap(name, valueType, f)); return (ConcurrentHashMapStorage<K, V, M>) storages.computeIfAbsent(name, n -> loadMap(name, valueType, modelType, f));
} }
private File getFile(String fileName) { private File getFile(String fileName) {

View file

@ -17,9 +17,9 @@
package org.keycloak.models.map.user; package org.keycloak.models.map.user;
import org.apache.commons.lang.StringUtils;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.authorization.AuthorizationProvider; import org.keycloak.authorization.AuthorizationProvider;
import org.keycloak.authorization.model.Resource;
import org.keycloak.authorization.store.ResourceStore; import org.keycloak.authorization.store.ResourceStore;
import org.keycloak.common.util.Time; import org.keycloak.common.util.Time;
import org.keycloak.component.ComponentModel; import org.keycloak.component.ComponentModel;
@ -39,17 +39,19 @@ import org.keycloak.models.RequiredActionProviderModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.UserConsentModel; import org.keycloak.models.UserConsentModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserModel.SearchableFields;
import org.keycloak.models.UserProvider; import org.keycloak.models.UserProvider;
import org.keycloak.models.map.common.Serialization; import org.keycloak.models.map.common.Serialization;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.storage.StorageId; import org.keycloak.storage.StorageId;
import org.keycloak.storage.UserStorageProvider; import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.client.ClientStorageProvider; import org.keycloak.storage.client.ClientStorageProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -75,13 +77,13 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
private static final Logger LOG = Logger.getLogger(MapUserProvider.class); private static final Logger LOG = Logger.getLogger(MapUserProvider.class);
private static final Predicate<MapUserEntity> ALWAYS_FALSE = c -> { return false; }; private static final Predicate<MapUserEntity> ALWAYS_FALSE = c -> { return false; };
private final KeycloakSession session; private final KeycloakSession session;
final MapKeycloakTransaction<UUID, MapUserEntity> tx; final MapKeycloakTransaction<UUID, MapUserEntity, UserModel> tx;
private final MapStorage<UUID, MapUserEntity> userStore; private final MapStorage<UUID, MapUserEntity, UserModel> userStore;
public MapUserProvider(KeycloakSession session, MapStorage<UUID, MapUserEntity> store) { public MapUserProvider(KeycloakSession session, MapStorage<UUID, MapUserEntity, UserModel> store) {
this.session = session; this.session = session;
this.userStore = store; this.userStore = store;
this.tx = new MapKeycloakTransaction<>(userStore); this.tx = userStore.createTransaction();
session.getTransactionManager().enlist(tx); session.getTransactionManager().enlist(tx);
} }
@ -134,7 +136,7 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
} }
private Optional<MapUserEntity> getEntityById(RealmModel realm, UUID id) { private Optional<MapUserEntity> getEntityById(RealmModel realm, UUID id) {
MapUserEntity mapUserEntity = tx.read(id, userStore::read); MapUserEntity mapUserEntity = tx.read(id);
if (mapUserEntity != null && entityRealmFilter(realm).test(mapUserEntity)) { if (mapUserEntity != null && entityRealmFilter(realm).test(mapUserEntity)) {
return Optional.of(mapUserEntity); return Optional.of(mapUserEntity);
} }
@ -146,18 +148,6 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
return getEntityById(realm, id).map(this::registerEntityForChanges); return getEntityById(realm, id).map(this::registerEntityForChanges);
} }
private Stream<MapUserEntity> getNotRemovedUpdatedUsersStream() {
Stream<MapUserEntity> updatedAndNotRemovedUsersStream = userStore.entrySet().stream()
.map(tx::getUpdated) // If the group has been removed, tx.read will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull);
return Stream.concat(tx.createdValuesStream(), updatedAndNotRemovedUsersStream);
}
private Stream<MapUserEntity> getUnsortedUserEntitiesStream(RealmModel realm) {
return getNotRemovedUpdatedUsersStream()
.filter(entityRealmFilter(realm));
}
@Override @Override
public void addFederatedIdentity(RealmModel realm, UserModel user, FederatedIdentityModel socialLink) { public void addFederatedIdentity(RealmModel realm, UserModel user, FederatedIdentityModel socialLink) {
if (user == null || user.getId() == null) { if (user == null || user.getId() == null) {
@ -182,7 +172,11 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public void preRemove(RealmModel realm, IdentityProviderModel provider) { public void preRemove(RealmModel realm, IdentityProviderModel provider) {
String socialProvider = provider.getAlias(); String socialProvider = provider.getAlias();
LOG.tracef("preRemove[RealmModel realm, IdentityProviderModel provider](%s, %s)%s", realm, socialProvider, getShortStackTrace()); LOG.tracef("preRemove[RealmModel realm, IdentityProviderModel provider](%s, %s)%s", realm, socialProvider, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialProvider);
tx.getUpdatedNotRemoved(mcb)
.map(this::registerEntityForChanges) .map(this::registerEntityForChanges)
.forEach(userEntity -> userEntity.removeFederatedIdentity(socialProvider)); .forEach(userEntity -> userEntity.removeFederatedIdentity(socialProvider));
} }
@ -214,9 +208,11 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public UserModel getUserByFederatedIdentity(RealmModel realm, FederatedIdentityModel socialLink) { public UserModel getUserByFederatedIdentity(RealmModel realm, FederatedIdentityModel socialLink) {
LOG.tracef("getUserByFederatedIdentity(%s, %s)%s", realm, socialLink, getShortStackTrace()); LOG.tracef("getUserByFederatedIdentity(%s, %s)%s", realm, socialLink, getShortStackTrace());
return getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.nonNull(userEntity.getFederatedIdentity(socialLink.getIdentityProvider()))) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.filter(userEntity -> Objects.equals(userEntity.getFederatedIdentity(socialLink.getIdentityProvider()).getUserId(), socialLink.getUserId())) .compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialLink.getIdentityProvider(), socialLink.getUserId());
return tx.getUpdatedNotRemoved(mcb)
.collect(Collectors.collectingAndThen( .collect(Collectors.collectingAndThen(
Collectors.toList(), Collectors.toList(),
list -> { list -> {
@ -301,8 +297,11 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public UserModel getServiceAccount(ClientModel client) { public UserModel getServiceAccount(ClientModel client) {
LOG.tracef("getServiceAccount(%s)%s", client.getId(), getShortStackTrace()); LOG.tracef("getServiceAccount(%s)%s", client.getId(), getShortStackTrace());
return getUnsortedUserEntitiesStream(client.getRealm()) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.equals(userEntity.getServiceAccountClientLink(), client.getId())) .compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.EQ, client.getId());
return tx.getUpdatedNotRemoved(mcb)
.collect(Collectors.collectingAndThen( .collect(Collectors.collectingAndThen(
Collectors.toList(), Collectors.toList(),
list -> { list -> {
@ -321,14 +320,17 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public UserModel addUser(RealmModel realm, String id, String username, boolean addDefaultRoles, boolean addDefaultRequiredActions) { public UserModel addUser(RealmModel realm, String id, String username, boolean addDefaultRoles, boolean addDefaultRequiredActions) {
LOG.tracef("addUser(%s, %s, %s, %s, %s)%s", realm, id, username, addDefaultRoles, addDefaultRequiredActions, getShortStackTrace()); LOG.tracef("addUser(%s, %s, %s, %s, %s)%s", realm, id, username, addDefaultRoles, addDefaultRequiredActions, getShortStackTrace());
if (getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.anyMatch(userEntity -> Objects.equals(userEntity.getUsername(), username))) { .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.USERNAME, Operator.EQ, username);
if (tx.getCount(mcb) > 0) {
throw new ModelDuplicateException("User with username '" + username + "' in realm " + realm.getName() + " already exists" ); throw new ModelDuplicateException("User with username '" + username + "' in realm " + realm.getName() + " already exists" );
} }
final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id); final UUID entityId = id == null ? UUID.randomUUID() : UUID.fromString(id);
if (tx.read(entityId, userStore::read) != null) { if (tx.read(entityId) != null) {
throw new ModelDuplicateException("User exists: " + entityId); throw new ModelDuplicateException("User exists: " + entityId);
} }
@ -360,58 +362,76 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public void preRemove(RealmModel realm) { public void preRemove(RealmModel realm) {
LOG.tracef("preRemove[RealmModel](%s)%s", realm, getShortStackTrace()); LOG.tracef("preRemove[RealmModel](%s)%s", realm, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.map(MapUserEntity::getId) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
.forEach(tx::delete);
tx.delete(UUID.randomUUID(), mcb);
} }
@Override @Override
public void removeImportedUsers(RealmModel realm, String storageProviderId) { public void removeImportedUsers(RealmModel realm, String storageProviderId) {
LOG.tracef("removeImportedUsers(%s, %s)%s", realm, storageProviderId, getShortStackTrace()); LOG.tracef("removeImportedUsers(%s, %s)%s", realm, storageProviderId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.equals(userEntity.getFederationLink(), storageProviderId)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(MapUserEntity::getId) .compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId);
.forEach(tx::delete);
tx.delete(UUID.randomUUID(), mcb);
} }
@Override @Override
public void unlinkUsers(RealmModel realm, String storageProviderId) { public void unlinkUsers(RealmModel realm, String storageProviderId) {
LOG.tracef("unlinkUsers(%s, %s)%s", realm, storageProviderId, getShortStackTrace()); LOG.tracef("unlinkUsers(%s, %s)%s", realm, storageProviderId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.equals(userEntity.getFederationLink(), storageProviderId)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(this::registerEntityForChanges) .compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.map(this::registerEntityForChanges)
.forEach(userEntity -> userEntity.setFederationLink(null)); .forEach(userEntity -> userEntity.setFederationLink(null));
} }
}
@Override @Override
public void preRemove(RealmModel realm, RoleModel role) { public void preRemove(RealmModel realm, RoleModel role) {
String roleId = role.getId(); String roleId = role.getId();
LOG.tracef("preRemove[RoleModel](%s, %s)%s", realm, roleId, getShortStackTrace()); LOG.tracef("preRemove[RoleModel](%s, %s)%s", realm, roleId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> userEntity.getRolesMembership().contains(roleId)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(this::registerEntityForChanges) .compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, roleId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.map(this::registerEntityForChanges)
.forEach(userEntity -> userEntity.removeRolesMembership(roleId)); .forEach(userEntity -> userEntity.removeRolesMembership(roleId));
} }
}
@Override @Override
public void preRemove(RealmModel realm, GroupModel group) { public void preRemove(RealmModel realm, GroupModel group) {
String groupId = group.getId(); String groupId = group.getId();
LOG.tracef("preRemove[GroupModel](%s, %s)%s", realm, groupId, getShortStackTrace()); LOG.tracef("preRemove[GroupModel](%s, %s)%s", realm, groupId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> userEntity.getGroupsMembership().contains(groupId)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(this::registerEntityForChanges) .compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, groupId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.map(this::registerEntityForChanges)
.forEach(userEntity -> userEntity.removeGroupsMembership(groupId)); .forEach(userEntity -> userEntity.removeGroupsMembership(groupId));
} }
}
@Override @Override
public void preRemove(RealmModel realm, ClientModel client) { public void preRemove(RealmModel realm, ClientModel client) {
String clientId = client.getId(); String clientId = client.getId();
LOG.tracef("preRemove[ClientModel](%s, %s)%s", realm, clientId, getShortStackTrace()); LOG.tracef("preRemove[ClientModel](%s, %s)%s", realm, clientId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.nonNull(userEntity.getUserConsent(clientId))) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(this::registerEntityForChanges) .compare(SearchableFields.CONSENT_FOR_CLIENT, Operator.EQ, clientId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.map(this::registerEntityForChanges)
.forEach(userEntity -> userEntity.removeUserConsent(clientId)); .forEach(userEntity -> userEntity.removeUserConsent(clientId));
} }
}
@Override @Override
public void preRemove(ProtocolMapperModel protocolMapper) { public void preRemove(ProtocolMapperModel protocolMapper) {
@ -423,11 +443,15 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
String clientScopeId = clientScope.getId(); String clientScopeId = clientScope.getId();
LOG.tracef("preRemove[ClientScopeModel](%s)%s", clientScopeId, getShortStackTrace()); LOG.tracef("preRemove[ClientScopeModel](%s)%s", clientScopeId, getShortStackTrace());
getUnsortedUserEntitiesStream(clientScope.getRealm()) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.map(this::registerEntityForChanges) .compare(SearchableFields.REALM_ID, Operator.EQ, clientScope.getRealm().getId())
.flatMap(AbstractUserEntity::getUserConsents) .compare(SearchableFields.CONSENT_WITH_CLIENT_SCOPE, Operator.EQ, clientScopeId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.flatMap(AbstractUserEntity::getUserConsents)
.forEach(consent -> consent.removeGrantedClientScopesIds(clientScopeId)); .forEach(consent -> consent.removeGrantedClientScopesIds(clientScopeId));
} }
}
@Override @Override
public void preRemove(RealmModel realm, ComponentModel component) { public void preRemove(RealmModel realm, ComponentModel component) {
@ -437,23 +461,27 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
removeImportedUsers(realm, componentId); removeImportedUsers(realm, componentId);
} }
if (component.getProviderType().equals(ClientStorageProvider.class.getName())) { if (component.getProviderType().equals(ClientStorageProvider.class.getName())) {
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.forEach(removeConsentsForExternalClient(componentId)); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, Operator.EQ, componentId);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
String providerIdS = new StorageId(componentId, "").getId();
s.forEach(removeConsentsForExternalClient(providerIdS));
}
} }
} }
private Consumer<MapUserEntity> removeConsentsForExternalClient(String componentId) { private Consumer<MapUserEntity> removeConsentsForExternalClient(String idPrefix) {
return userEntity -> { return userEntity -> {
List<UserConsentEntity> consentModels = userEntity.getUserConsents() List<String> consentClientIds = userEntity.getUserConsents()
.filter(consent -> .map(UserConsentEntity::getClientId)
Objects.equals(new StorageId(consent.getClientId()).getProviderId(), componentId)) .filter(clientId -> clientId != null && clientId.startsWith(idPrefix))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (consentModels.size() > 0) { if (! consentClientIds.isEmpty()) {
userEntity = registerEntityForChanges(userEntity); userEntity = registerEntityForChanges(userEntity);
for (UserConsentEntity consentEntity : consentModels) { consentClientIds.forEach(userEntity::removeUserConsent);
userEntity.removeUserConsent(consentEntity.getClientId());
}
} }
}; };
} }
@ -462,10 +490,14 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public void grantToAllUsers(RealmModel realm, RoleModel role) { public void grantToAllUsers(RealmModel realm, RoleModel role) {
String roleId = role.getId(); String roleId = role.getId();
LOG.tracef("grantToAllUsers(%s, %s)%s", realm, roleId, getShortStackTrace()); LOG.tracef("grantToAllUsers(%s, %s)%s", realm, roleId, getShortStackTrace());
getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.map(this::registerEntityForChanges) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
s.map(this::registerEntityForChanges)
.forEach(entity -> entity.addRolesMembership(roleId)); .forEach(entity -> entity.addRolesMembership(roleId));
} }
}
@Override @Override
public UserModel getUserById(RealmModel realm, String id) { public UserModel getUserById(RealmModel realm, String id) {
@ -476,20 +508,26 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public UserModel getUserByUsername(RealmModel realm, String username) { public UserModel getUserByUsername(RealmModel realm, String username) {
if (username == null) return null; if (username == null) return null;
final String usernameLowercase = username.toLowerCase();
LOG.tracef("getUserByUsername(%s, %s)%s", realm, username, getShortStackTrace()); LOG.tracef("getUserByUsername(%s, %s)%s", realm, username, getShortStackTrace());
return getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.equals(userEntity.getUsername(), usernameLowercase)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.findFirst() .compare(SearchableFields.USERNAME, Operator.ILIKE, username);
try (Stream<MapUserEntity> s = tx.getUpdatedNotRemoved(mcb)) {
return s.findFirst()
.map(entityToAdapterFunc(realm)).orElse(null); .map(entityToAdapterFunc(realm)).orElse(null);
} }
}
@Override @Override
public UserModel getUserByEmail(RealmModel realm, String email) { public UserModel getUserByEmail(RealmModel realm, String email) {
LOG.tracef("getUserByEmail(%s, %s)%s", realm, email, getShortStackTrace()); LOG.tracef("getUserByEmail(%s, %s)%s", realm, email, getShortStackTrace());
List<MapUserEntity> usersWithEmail = getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> Objects.equals(userEntity.getEmail(), email.toLowerCase())) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.EMAIL, Operator.EQ, email);
List<MapUserEntity> usersWithEmail = tx.getUpdatedNotRemoved(mcb)
.filter(userEntity -> Objects.equals(userEntity.getEmail(), email))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (usersWithEmail.isEmpty()) return null; if (usersWithEmail.isEmpty()) return null;
@ -524,25 +562,28 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public int getUsersCount(RealmModel realm, boolean includeServiceAccount) { public int getUsersCount(RealmModel realm, boolean includeServiceAccount) {
LOG.tracef("getUsersCount(%s, %s)%s", realm, includeServiceAccount, getShortStackTrace()); LOG.tracef("getUsersCount(%s, %s)%s", realm, includeServiceAccount, getShortStackTrace());
Stream<MapUserEntity> unsortedUserEntitiesStream = getUnsortedUserEntitiesStream(realm); ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
if (!includeServiceAccount) { if (! includeServiceAccount) {
unsortedUserEntitiesStream = unsortedUserEntitiesStream mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS);
.filter(userEntity -> Objects.isNull(userEntity.getServiceAccountClientLink()));
} }
return (int) unsortedUserEntitiesStream.count(); return (int) tx.getCount(mcb);
} }
@Override @Override
public Stream<UserModel> getUsersStream(RealmModel realm, Integer firstResult, Integer maxResults, boolean includeServiceAccounts) { public Stream<UserModel> getUsersStream(RealmModel realm, Integer firstResult, Integer maxResults, boolean includeServiceAccounts) {
LOG.tracef("getUsersStream(%s, %d, %d, %s)%s", realm, firstResult, maxResults, includeServiceAccounts, getShortStackTrace()); LOG.tracef("getUsersStream(%s, %d, %d, %s)%s", realm, firstResult, maxResults, includeServiceAccounts, getShortStackTrace());
Stream<MapUserEntity> usersStream = getUnsortedUserEntitiesStream(realm); ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
if (!includeServiceAccounts) { .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
usersStream = usersStream.filter(userEntity -> Objects.isNull(userEntity.getServiceAccountClientLink()));
if (! includeServiceAccounts) {
mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS);
} }
return paginatedStream(usersStream.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults) return paginatedStream(tx.getUpdatedNotRemoved(mcb)
.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@ -564,11 +605,12 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public Stream<UserModel> searchForUserStream(RealmModel realm, Map<String, String> attributes, Integer firstResult, Integer maxResults) { public Stream<UserModel> searchForUserStream(RealmModel realm, Map<String, String> attributes, Integer firstResult, Integer maxResults) {
LOG.tracef("searchForUserStream(%s, %s, %d, %d)%s", realm, attributes, firstResult, maxResults, getShortStackTrace()); LOG.tracef("searchForUserStream(%s, %s, %d, %d)%s", realm, attributes, firstResult, maxResults, getShortStackTrace());
/* Find all predicates based on attributes map */
List<Predicate<MapUserEntity>> predicatesList = new ArrayList<>();
if (!session.getAttributeOrDefault(UserModel.INCLUDE_SERVICE_ACCOUNT, true)) { ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
predicatesList.add(userEntity -> Objects.isNull(userEntity.getServiceAccountClientLink())); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
if (! session.getAttributeOrDefault(UserModel.INCLUDE_SERVICE_ACCOUNT, true)) {
mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS);
} }
final boolean exactSearch = Boolean.parseBoolean(attributes.getOrDefault(UserModel.EXACT, Boolean.FALSE.toString())); final boolean exactSearch = Boolean.parseBoolean(attributes.getOrDefault(UserModel.EXACT, Boolean.FALSE.toString()));
@ -579,81 +621,86 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
if (value == null) { if (value == null) {
continue; continue;
} }
value = value.trim();
final String searchedString = value.toLowerCase(); final String searchedString = exactSearch ? value : ("%" + value + "%");
Function<Function<MapUserEntity, String>, Predicate<MapUserEntity>> containsOrExactPredicate =
func -> {
return userEntity -> testContainsOrExact(func.apply(userEntity), searchedString, exactSearch);
};
switch (key) { switch (key) {
case UserModel.SEARCH: case UserModel.SEARCH:
List<Predicate<MapUserEntity>> orPredicates = new ArrayList<>(); for (String stringToSearch : value.trim().split("\\s+")) {
orPredicates.add(userEntity -> StringUtils.containsIgnoreCase(userEntity.getUsername(), searchedString)); if (value.isEmpty()) {
orPredicates.add(userEntity -> StringUtils.containsIgnoreCase(userEntity.getEmail(), searchedString)); continue;
orPredicates.add(userEntity -> StringUtils.containsIgnoreCase(concatFirstNameLastName(userEntity), searchedString)); }
final String s = exactSearch ? stringToSearch : ("%" + stringToSearch + "%");
predicatesList.add(orPredicates.stream().reduce(Predicate::or).orElse(t -> false)); mcb = mcb.or(
userStore.createCriteriaBuilder().compare(SearchableFields.USERNAME, Operator.ILIKE, s),
userStore.createCriteriaBuilder().compare(SearchableFields.EMAIL, Operator.ILIKE, s),
userStore.createCriteriaBuilder().compare(SearchableFields.FIRST_NAME, Operator.ILIKE, s),
userStore.createCriteriaBuilder().compare(SearchableFields.LAST_NAME, Operator.ILIKE, s)
);
}
break; break;
case USERNAME: case USERNAME:
predicatesList.add(containsOrExactPredicate.apply(MapUserEntity::getUsername)); mcb = mcb.compare(SearchableFields.USERNAME, Operator.ILIKE, searchedString);
break; break;
case FIRST_NAME: case FIRST_NAME:
predicatesList.add(containsOrExactPredicate.apply(MapUserEntity::getFirstName)); mcb = mcb.compare(SearchableFields.FIRST_NAME, Operator.ILIKE, searchedString);
break; break;
case LAST_NAME: case LAST_NAME:
predicatesList.add(containsOrExactPredicate.apply(MapUserEntity::getLastName)); mcb = mcb.compare(SearchableFields.LAST_NAME, Operator.ILIKE, searchedString);
break; break;
case EMAIL: case EMAIL:
predicatesList.add(containsOrExactPredicate.apply(MapUserEntity::getEmail)); mcb = mcb.compare(SearchableFields.EMAIL, Operator.ILIKE, searchedString);
break; break;
case EMAIL_VERIFIED: { case EMAIL_VERIFIED: {
boolean booleanValue = Boolean.parseBoolean(searchedString); boolean booleanValue = Boolean.parseBoolean(value);
predicatesList.add(userEntity -> Objects.equals(userEntity.isEmailVerified(), booleanValue)); mcb = mcb.compare(SearchableFields.EMAIL_VERIFIED, Operator.EQ, booleanValue);
break; break;
} }
case UserModel.ENABLED: { case UserModel.ENABLED: {
boolean booleanValue = Boolean.parseBoolean(searchedString); boolean booleanValue = Boolean.parseBoolean(value);
predicatesList.add(userEntity -> Objects.equals(userEntity.isEnabled(), booleanValue)); mcb = mcb.compare(SearchableFields.ENABLED, Operator.EQ, booleanValue);
break; break;
} }
case UserModel.IDP_ALIAS: { case UserModel.IDP_ALIAS: {
predicatesList.add(mapUserEntity -> Objects.nonNull(mapUserEntity.getFederatedIdentity(value))); if (! attributes.containsKey(UserModel.IDP_USER_ID)) {
mcb = mcb.compare(SearchableFields.IDP_AND_USER, Operator.EQ, value);
}
break; break;
} }
case UserModel.IDP_USER_ID: { case UserModel.IDP_USER_ID: {
predicatesList.add(mapUserEntity -> mapUserEntity.getFederatedIdentities() mcb = mcb.compare(SearchableFields.IDP_AND_USER, Operator.EQ, attributes.get(UserModel.IDP_ALIAS), value);
.anyMatch(idp -> Objects.equals(idp.getUserId(), value)));
break; break;
} }
} }
} }
// Only return those results that the current user is authorized to view,
// i.e. there is an intersection of groups with view permission of the current
// user (passed in via UserModel.GROUPS attribute), the groups for the returned
// users, and the respective group resource available from the authorization provider
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Set<String> userGroups = (Set<String>) session.getAttribute(UserModel.GROUPS); Set<String> userGroups = (Set<String>) session.getAttribute(UserModel.GROUPS);
if (userGroups != null && userGroups.size() > 0) { if (userGroups != null) {
final ResourceStore resourceStore = session.getProvider(AuthorizationProvider.class).getStoreFactory() if (userGroups.isEmpty()) {
.getResourceStore(); return Stream.empty();
final Predicate<String> resourceByGroupIdExists = id -> resourceStore
.findByResourceServer(Collections.singletonMap("name", new String[] { "group.resource." + id }),
null, 0, 1).size() == 1;
predicatesList.add(userEntity -> {
return userEntity.getGroupsMembership()
.stream()
.filter(userGroups::contains)
.anyMatch(resourceByGroupIdExists);
});
} }
// Prepare resulting predicate final ResourceStore resourceStore = session.getProvider(AuthorizationProvider.class).getStoreFactory().getResourceStore();
Predicate<MapUserEntity> resultingPredicate = predicatesList.stream()
.reduce(Predicate::and) // Combine all predicates with and
.orElse(t -> true); // If there is no predicate in predicatesList, return all users
Stream<MapUserEntity> usersStream = getUnsortedUserEntitiesStream(realm) // Get stream of all users in the realm HashSet<String> authorizedGroups = new HashSet<>(userGroups);
.filter(resultingPredicate) // Apply all predicates to userStream authorizedGroups.removeIf(id -> {
Map<String, String[]> values = new HashMap<>();
values.put(Resource.EXACT_NAME, new String[] { "true" });
values.put("name", new String[] { "group.resource." + id });
return resourceStore.findByResourceServer(values, null, 0, 1).isEmpty();
});
mcb = mcb.compare(SearchableFields.ASSIGNED_GROUP, Operator.IN, authorizedGroups);
}
Stream<MapUserEntity> usersStream = tx.getUpdatedNotRemoved(mcb)
.sorted(AbstractUserEntity.COMPARE_BY_USERNAME); // Sort before paginating .sorted(AbstractUserEntity.COMPARE_BY_USERNAME); // Sort before paginating
return paginatedStream(usersStream, firstResult, maxResults) // paginate if necessary return paginatedStream(usersStream, firstResult, maxResults) // paginate if necessary
@ -661,45 +708,27 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
private String concatFirstNameLastName(MapUserEntity entity) {
StringBuilder stringBuilder = new StringBuilder();
if (entity.getFirstName() != null) {
stringBuilder.append(entity.getFirstName());
}
stringBuilder.append(" ");
if (entity.getLastName() != null) {
stringBuilder.append(entity.getLastName());
}
return stringBuilder.toString();
}
private boolean testContainsOrExact(String testedString, String searchedString, boolean exactMatch) {
if (exactMatch) {
return StringUtils.equalsIgnoreCase(testedString, searchedString);
} else {
return StringUtils.containsIgnoreCase(testedString, searchedString);
}
}
@Override @Override
public Stream<UserModel> getGroupMembersStream(RealmModel realm, GroupModel group, Integer firstResult, Integer maxResults) { public Stream<UserModel> getGroupMembersStream(RealmModel realm, GroupModel group, Integer firstResult, Integer maxResults) {
LOG.tracef("getGroupMembersStream(%s, %s, %d, %d)%s", realm, group.getId(), firstResult, maxResults, getShortStackTrace()); LOG.tracef("getGroupMembersStream(%s, %s, %d, %d)%s", realm, group.getId(), firstResult, maxResults, getShortStackTrace());
return paginatedStream(getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> userEntity.getGroupsMembership().contains(group.getId())) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults) .compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, group.getId());
return paginatedStream(tx.getUpdatedNotRemoved(mcb).sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@Override @Override
public Stream<UserModel> searchForUserByUserAttributeStream(RealmModel realm, String attrName, String attrValue) { public Stream<UserModel> searchForUserByUserAttributeStream(RealmModel realm, String attrName, String attrValue) {
LOG.tracef("searchForUserByUserAttributeStream(%s, %s, %s)%s", realm, attrName, attrValue, getShortStackTrace()); LOG.tracef("searchForUserByUserAttributeStream(%s, %s, %s)%s", realm, attrName, attrValue, getShortStackTrace());
return getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(userEntity -> userEntity.getAttribute(attrName).contains(attrValue)) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.map(entityToAdapterFunc(realm)) .compare(SearchableFields.ATTRIBUTE, Operator.EQ, attrName, attrValue);
.sorted(UserModel.COMPARE_BY_USERNAME);
return tx.getUpdatedNotRemoved(mcb)
.sorted(MapUserEntity.COMPARE_BY_USERNAME)
.map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -722,8 +751,11 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public Stream<UserModel> getRoleMembersStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) { public Stream<UserModel> getRoleMembersStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) {
LOG.tracef("getRoleMembersStream(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace()); LOG.tracef("getRoleMembersStream(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace());
return paginatedStream(getUnsortedUserEntitiesStream(realm) ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.filter(entity -> entity.getRolesMembership().contains(role.getId())) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId());
return paginatedStream(tx.getUpdatedNotRemoved(mcb)
.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults) .sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }

View file

@ -19,6 +19,7 @@ package org.keycloak.models.map.user;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory; import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserProvider; import org.keycloak.models.UserProvider;
import org.keycloak.models.UserProviderFactory; import org.keycloak.models.UserProviderFactory;
import org.keycloak.models.map.common.AbstractMapProviderFactory; import org.keycloak.models.map.common.AbstractMapProviderFactory;
@ -33,12 +34,12 @@ import java.util.UUID;
*/ */
public class MapUserProviderFactory extends AbstractMapProviderFactory<UserProvider> implements UserProviderFactory { public class MapUserProviderFactory extends AbstractMapProviderFactory<UserProvider> implements UserProviderFactory {
private MapStorage<UUID, MapUserEntity> store; private MapStorage<UUID, MapUserEntity, UserModel> store;
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class); MapStorageProvider sp = (MapStorageProvider) factory.getProviderFactory(MapStorageProvider.class);
this.store = sp.getStorage("users", UUID.class, MapUserEntity.class); this.store = sp.getStorage("users", UUID.class, MapUserEntity.class, UserModel.class);
} }

View file

@ -23,6 +23,7 @@ import java.util.Set;
import org.keycloak.common.util.ObjectUtil; import org.keycloak.common.util.ObjectUtil;
import org.keycloak.provider.ProviderEvent; import org.keycloak.provider.ProviderEvent;
import org.keycloak.provider.ProviderEventManager; import org.keycloak.provider.ProviderEventManager;
import org.keycloak.storage.SearchableModelField;
/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@ -36,6 +37,12 @@ public interface ClientModel extends ClientScopeModel, RoleContainerModel, Prot
String PUBLIC_KEY = "publicKey"; String PUBLIC_KEY = "publicKey";
String X509CERTIFICATE = "X509Certificate"; String X509CERTIFICATE = "X509Certificate";
public static class SearchableFields {
public static final SearchableModelField<ClientModel> ID = new SearchableModelField<>("id", String.class);
public static final SearchableModelField<ClientModel> REALM_ID = new SearchableModelField<>("realmId", String.class);
public static final SearchableModelField<ClientModel> CLIENT_ID = new SearchableModelField<>("clientId", String.class);
}
interface ClientCreationEvent extends ProviderEvent { interface ClientCreationEvent extends ProviderEvent {
ClientModel getCreatedClient(); ClientModel getCreatedClient();
} }

View file

@ -19,6 +19,7 @@ package org.keycloak.models;
import org.keycloak.provider.ProviderEvent; import org.keycloak.provider.ProviderEvent;
import org.keycloak.storage.SearchableModelField;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -31,6 +32,20 @@ import java.util.stream.Stream;
* @version $Revision: 1 $ * @version $Revision: 1 $
*/ */
public interface GroupModel extends RoleMapperModel { public interface GroupModel extends RoleMapperModel {
public static class SearchableFields {
public static final SearchableModelField<GroupModel> ID = new SearchableModelField<>("id", String.class);
public static final SearchableModelField<GroupModel> REALM_ID = new SearchableModelField<>("realmId", String.class);
/** Parent group ID */
public static final SearchableModelField<GroupModel> PARENT_ID = new SearchableModelField<>("parentGroupId", String.class);
public static final SearchableModelField<GroupModel> NAME = new SearchableModelField<>("name", String.class);
/**
* Field for comparison with roles granted to this group.
* A role can be checked for belonging only via EQ operator. Role is referred by their ID
*/
public static final SearchableModelField<GroupModel> ASSIGNED_ROLE = new SearchableModelField<>("assignedRole", String.class);
}
interface GroupRemovedEvent extends ProviderEvent { interface GroupRemovedEvent extends ProviderEvent {
RealmModel getRealm(); RealmModel getRealm();
GroupModel getGroup(); GroupModel getGroup();

View file

@ -17,6 +17,7 @@
package org.keycloak.models; package org.keycloak.models;
import org.keycloak.storage.SearchableModelField;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -28,6 +29,18 @@ import java.util.stream.Stream;
* @version $Revision: 1 $ * @version $Revision: 1 $
*/ */
public interface RoleModel { public interface RoleModel {
public static class SearchableFields {
public static final SearchableModelField<RoleModel> ID = new SearchableModelField<>("id", String.class);
public static final SearchableModelField<RoleModel> REALM_ID = new SearchableModelField<>("realmId", String.class);
/** If client role, ID of the client (not the clientId) */
public static final SearchableModelField<RoleModel> CLIENT_ID = new SearchableModelField<>("clientId", String.class);
public static final SearchableModelField<RoleModel> NAME = new SearchableModelField<>("name", String.class);
public static final SearchableModelField<RoleModel> DESCRIPTION = new SearchableModelField<>("description", String.class);
public static final SearchableModelField<RoleModel> IS_CLIENT_ROLE = new SearchableModelField<>("isClientRole", Boolean.class);
public static final SearchableModelField<RoleModel> IS_COMPOSITE_ROLE = new SearchableModelField<>("isCompositeRole", Boolean.class);
}
String getName(); String getName();
String getDescription(); String getDescription();

View file

@ -19,6 +19,7 @@ package org.keycloak.models;
import org.keycloak.provider.ProviderEvent; import org.keycloak.provider.ProviderEvent;
import org.keycloak.storage.SearchableModelField;
import java.util.Comparator; import java.util.Comparator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
@ -48,6 +49,50 @@ public interface UserModel extends RoleMapperModel {
Comparator<UserModel> COMPARE_BY_USERNAME = Comparator.comparing(UserModel::getUsername, String.CASE_INSENSITIVE_ORDER); Comparator<UserModel> COMPARE_BY_USERNAME = Comparator.comparing(UserModel::getUsername, String.CASE_INSENSITIVE_ORDER);
public static class SearchableFields {
public static final SearchableModelField<UserModel> ID = new SearchableModelField<>("id", String.class);
public static final SearchableModelField<UserModel> REALM_ID = new SearchableModelField<>("realmId", String.class);
public static final SearchableModelField<UserModel> USERNAME = new SearchableModelField<>("username", String.class);
public static final SearchableModelField<UserModel> FIRST_NAME = new SearchableModelField<>("firstName", String.class);
public static final SearchableModelField<UserModel> LAST_NAME = new SearchableModelField<>("lastName", String.class);
public static final SearchableModelField<UserModel> EMAIL = new SearchableModelField<>("email", String.class);
public static final SearchableModelField<UserModel> ENABLED = new SearchableModelField<>("enabled", Boolean.class);
public static final SearchableModelField<UserModel> EMAIL_VERIFIED = new SearchableModelField<>("emailVerified", Boolean.class);
public static final SearchableModelField<UserModel> FEDERATION_LINK = new SearchableModelField<>("federationLink", String.class);
/**
* This field can only searched either for users coming from an IDP, then the operand is (idp_alias),
* or as user coming from a particular IDP with given username there, then the operand is a pair (idp_alias, idp_user_id).
* It is also possible to search regardless of {@code idp_alias}, then the pair is {@code (null, idp_user_id)}.
*/
public static final SearchableModelField<UserModel> IDP_AND_USER = new SearchableModelField<>("idpAlias:idpUserId", String.class);
public static final SearchableModelField<UserModel> ASSIGNED_ROLE = new SearchableModelField<>("assignedRole", String.class);
public static final SearchableModelField<UserModel> ASSIGNED_GROUP = new SearchableModelField<>("assignedGroup", String.class);
/**
* Search for users that have consent set for a particular client.
*/
public static final SearchableModelField<UserModel> CONSENT_FOR_CLIENT = new SearchableModelField<>("clientConsent", String.class);
/**
* Search for users that have consent set for a particular client that originates in the given client provider.
*/
public static final SearchableModelField<UserModel> CONSENT_CLIENT_FEDERATION_LINK = new SearchableModelField<>("clientConsentFederationLink", String.class);
/**
* Search for users that have consent that has given client scope.
*/
public static final SearchableModelField<UserModel> CONSENT_WITH_CLIENT_SCOPE = new SearchableModelField<>("consentWithClientScope", String.class);
/**
* ID of the client corresponding to the service account
*/
public static final SearchableModelField<UserModel> SERVICE_ACCOUNT_CLIENT = new SearchableModelField<>("serviceAccountClientId", String.class);
/**
* Search for attribute value. The parameters is a pair {@code (attribute_name, values...)} where {@code attribute_name}
* is always checked for equality, and the value (which can be any numbert of values, none for operators like EXISTS
* or potentially many for e.g. IN) is checked per the operator.
*/
public static final SearchableModelField<UserModel> ATTRIBUTE = new SearchableModelField<>("attribute", String[].class);
}
interface UserRemovedEvent extends ProviderEvent { interface UserRemovedEvent extends ProviderEvent {
RealmModel getRealm(); RealmModel getRealm();
UserModel getUser(); UserModel getUser();

View file

@ -21,6 +21,7 @@ import java.util.Map;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.storage.SearchableModelField;
/** /**
* Represents usually one browser session with potentially many browser tabs. Every browser tab is represented by * Represents usually one browser session with potentially many browser tabs. Every browser tab is represented by
@ -30,6 +31,12 @@ import org.keycloak.models.RealmModel;
*/ */
public interface RootAuthenticationSessionModel { public interface RootAuthenticationSessionModel {
public static class SearchableFields {
public static final SearchableModelField<RootAuthenticationSessionModel> ID = new SearchableModelField<>("id", String.class);
public static final SearchableModelField<RootAuthenticationSessionModel> REALM_ID = new SearchableModelField<>("realmId", String.class);
public static final SearchableModelField<RootAuthenticationSessionModel> TIMESTAMP = new SearchableModelField<>("timestamp", Long.class);
}
/** /**
* Returns id of the root authentication session. * Returns id of the root authentication session.
* @return {@code String} * @return {@code String}

View file

@ -16,6 +16,8 @@
*/ */
package org.keycloak.storage; package org.keycloak.storage;
import java.util.Objects;
/** /**
* *
* @author hmlnarik * @author hmlnarik
@ -37,4 +39,34 @@ public class SearchableModelField<M> {
public Class<?> getFieldType() { public Class<?> getFieldType() {
return fieldClass; return fieldClass;
} }
@Override
public int hashCode() {
int hash = 5;
hash = 83 * hash + Objects.hashCode(this.name);
return hash;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
final SearchableModelField<?> other = (SearchableModelField<?>) obj;
if ( ! Objects.equals(this.name, other.name)) {
return false;
}
return true;
}
@Override
public String toString() {
return "SearchableModelField " + name + " @ " + getClass().getTypeParameters()[0].getTypeName();
}
} }