KEYCLOAK-18370 Introduce QueryParameters

This commit is contained in:
mhajas 2021-06-30 16:14:16 +02:00 committed by Hynek Mlnařík
parent 7d26b245de
commit dc1c9b944f
26 changed files with 520 additions and 352 deletions

View file

@ -40,6 +40,7 @@ import java.util.function.Predicate;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
/** /**
* @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a> * @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a>
@ -144,7 +145,7 @@ public class MapRootAuthenticationSessionProvider<K> implements AuthenticationSe
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.TIMESTAMP, Operator.LT, expired); .compare(SearchableFields.TIMESTAMP, Operator.LT, expired);
long deletedCount = tx.delete(sessionStore.getKeyConvertor().yieldNewUniqueKey(), mcb); long deletedCount = tx.delete(sessionStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
LOG.debugf("Removed %d expired authentication sessions for realm '%s'", deletedCount, realm.getName()); LOG.debugf("Removed %d expired authentication sessions for realm '%s'", deletedCount, realm.getName());
} }
@ -155,7 +156,7 @@ public class MapRootAuthenticationSessionProvider<K> implements AuthenticationSe
ModelCriteriaBuilder<RootAuthenticationSessionModel> mcb = sessionStore.createCriteriaBuilder() ModelCriteriaBuilder<RootAuthenticationSessionModel> mcb = sessionStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
sessionStore.delete(mcb); sessionStore.delete(withCriteria(mcb));
} }
@Override @Override

View file

@ -35,7 +35,6 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator; import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.EnumMap; import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -45,6 +44,8 @@ import java.util.stream.Collectors;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
import static org.keycloak.utils.StreamsUtil.distinctByKey; import static org.keycloak.utils.StreamsUtil.distinctByKey;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.utils.StreamsUtil.paginatedStream;
@ -90,7 +91,7 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
.toArray(ModelCriteriaBuilder[]::new) .toArray(ModelCriteriaBuilder[]::new)
); );
return tx.getCount(mcb); return tx.getCount(withCriteria(mcb));
} }
@Override @Override
@ -108,8 +109,8 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
if (scopeId != null) { if (scopeId != null) {
mcb = mcb.compare(SearchableFields.SCOPE_ID, Operator.EQ, scopeId); mcb = mcb.compare(SearchableFields.SCOPE_ID, Operator.EQ, scopeId);
} }
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Permission ticket for resource server: '" + resourceServer.getId() throw new ModelDuplicateException("Permission ticket for resource server: '" + resourceServer.getId()
+ ", Resource: " + resourceId + ", owner: " + owner + ", scopeId: " + scopeId + " already exists."); + ", Resource: " + resourceId + ", owner: " + owner + ", scopeId: " + scopeId + " already exists.");
} }
@ -142,8 +143,8 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
public PermissionTicket findById(String id, String resourceServerId) { public PermissionTicket findById(String id, String resourceServerId) {
LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace()); LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(PermissionTicket.SearchableFields.ID, Operator.EQ, id)) .compare(SearchableFields.ID, Operator.EQ, id)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -153,7 +154,7 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
public List<PermissionTicket> findByResourceServer(String resourceServerId) { public List<PermissionTicket> findByResourceServer(String resourceServerId) {
LOG.tracef("findByResourceServer(%s)%s", resourceServerId, getShortStackTrace()); LOG.tracef("findByResourceServer(%s)%s", resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId)) return tx.read(withCriteria(forResourceServer(resourceServerId)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -162,8 +163,8 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
public List<PermissionTicket> findByOwner(String owner, String resourceServerId) { public List<PermissionTicket> findByOwner(String owner, String resourceServerId) {
LOG.tracef("findByOwner(%s, %s)%s", owner, resourceServerId, getShortStackTrace()); LOG.tracef("findByOwner(%s, %s)%s", owner, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.OWNER, Operator.EQ, owner)) .compare(SearchableFields.OWNER, Operator.EQ, owner)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -172,8 +173,8 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
public List<PermissionTicket> findByResource(String resourceId, String resourceServerId) { public List<PermissionTicket> findByResource(String resourceId, String resourceServerId) {
LOG.tracef("findByResource(%s, %s)%s", resourceId, resourceServerId, getShortStackTrace()); LOG.tracef("findByResource(%s, %s)%s", resourceId, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.RESOURCE_ID, Operator.EQ, resourceId)) .compare(SearchableFields.RESOURCE_ID, Operator.EQ, resourceId)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -182,8 +183,8 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
public List<PermissionTicket> findByScope(String scopeId, String resourceServerId) { public List<PermissionTicket> findByScope(String scopeId, String resourceServerId) {
LOG.tracef("findByScope(%s, %s)%s", scopeId, resourceServerId, getShortStackTrace()); LOG.tracef("findByScope(%s, %s)%s", scopeId, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.SCOPE_ID, Operator.EQ, scopeId)) .compare(SearchableFields.SCOPE_ID, Operator.EQ, scopeId)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -212,11 +213,9 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
.toArray(ModelCriteriaBuilder[]::new) .toArray(ModelCriteriaBuilder[]::new)
); );
Comparator<? super MapPermissionTicketEntity<K>> c = Comparator.comparing(MapPermissionTicketEntity::getId); return tx.read(withCriteria(mcb).pagination(firstResult, maxResult, SearchableFields.ID))
return paginatedStream(tx.read(mcb)
.sorted(c), firstResult, maxResult)
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private ModelCriteriaBuilder<PermissionTicket> filterEntryToModelCriteriaBuilder(Map.Entry<PermissionTicket.FilterOption, String> entry) { private ModelCriteriaBuilder<PermissionTicket> filterEntryToModelCriteriaBuilder(Map.Entry<PermissionTicket.FilterOption, String> entry) {
@ -297,12 +296,11 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
.findById(ticket.getResourceId(), ticket.getResourceServerId()); .findById(ticket.getResourceId(), ticket.getResourceServerId());
} }
return paginatedStream(tx.read(mcb) return paginatedStream(tx.read(withCriteria(mcb).orderBy(SearchableFields.RESOURCE_ID, ASCENDING))
.filter(distinctByKey(MapPermissionTicketEntity::getResourceId)) .filter(distinctByKey(MapPermissionTicketEntity::getResourceId))
.sorted(MapPermissionTicketEntity.COMPARE_BY_RESOURCE_ID) .map(ticketResourceMapper)
.map(ticketResourceMapper) .filter(Objects::nonNull), first, max)
.filter(Objects::nonNull), first, max) .collect(Collectors.toList());
.collect(Collectors.toList());
} }
@Override @Override
@ -310,11 +308,10 @@ public class MapPermissionTicketStore<K extends Comparable<K>> implements Permis
ModelCriteriaBuilder<PermissionTicket> mcb = permissionTicketStore.createCriteriaBuilder() ModelCriteriaBuilder<PermissionTicket> mcb = permissionTicketStore.createCriteriaBuilder()
.compare(SearchableFields.OWNER, Operator.EQ, owner); .compare(SearchableFields.OWNER, Operator.EQ, owner);
return paginatedStream(tx.read(mcb) return paginatedStream(tx.read(withCriteria(mcb).orderBy(SearchableFields.RESOURCE_ID, ASCENDING))
.filter(distinctByKey(MapPermissionTicketEntity::getResourceId)) .filter(distinctByKey(MapPermissionTicketEntity::getResourceId)), first, max)
.sorted(MapPermissionTicketEntity.COMPARE_BY_RESOURCE_ID), first, max) .map(ticket -> authorizationProvider.getStoreFactory().getResourceStore()
.map(ticket -> authorizationProvider.getStoreFactory().getResourceStore() .findById(ticket.getResourceId(), ticket.getResourceServerId()))
.findById(ticket.getResourceId(), ticket.getResourceServerId())) .collect(Collectors.toList());
.collect(Collectors.toList());
} }
} }

View file

@ -42,7 +42,7 @@ import java.util.stream.Collectors;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapPolicyStore<K> implements PolicyStore { public class MapPolicyStore<K> implements PolicyStore {
@ -86,7 +86,7 @@ public class MapPolicyStore<K> implements PolicyStore {
ModelCriteriaBuilder<Policy> mcb = forResourceServer(resourceServer.getId()) ModelCriteriaBuilder<Policy> mcb = forResourceServer(resourceServer.getId())
.compare(SearchableFields.NAME, Operator.EQ, representation.getName()); .compare(SearchableFields.NAME, Operator.EQ, representation.getName());
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Policy with name '" + representation.getName() + "' for " + resourceServer.getId() + " already exists"); throw new ModelDuplicateException("Policy with name '" + representation.getName() + "' for " + resourceServer.getId() + " already exists");
} }
@ -111,8 +111,8 @@ public class MapPolicyStore<K> implements PolicyStore {
public Policy findById(String id, String resourceServerId) { public Policy findById(String id, String resourceServerId) {
LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace()); LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.ID, Operator.EQ, id)) .compare(SearchableFields.ID, Operator.EQ, id)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -122,8 +122,8 @@ public class MapPolicyStore<K> implements PolicyStore {
public Policy findByName(String name, String resourceServerId) { public Policy findByName(String name, String resourceServerId) {
LOG.tracef("findByName(%s, %s)%s", name, resourceServerId, getShortStackTrace()); LOG.tracef("findByName(%s, %s)%s", name, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.NAME, Operator.EQ, name)) .compare(SearchableFields.NAME, Operator.EQ, name)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -133,7 +133,7 @@ public class MapPolicyStore<K> implements PolicyStore {
public List<Policy> findByResourceServer(String id) { public List<Policy> findByResourceServer(String id) {
LOG.tracef("findByResourceServer(%s)%s", id, getShortStackTrace()); LOG.tracef("findByResourceServer(%s)%s", id, getShortStackTrace());
return tx.read(forResourceServer(id)) return tx.read(withCriteria(forResourceServer(id)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -153,12 +153,12 @@ public class MapPolicyStore<K> implements PolicyStore {
mcb = mcb.compare(SearchableFields.OWNER, Operator.NOT_EXISTS); mcb = mcb.compare(SearchableFields.OWNER, Operator.NOT_EXISTS);
} }
return paginatedStream(tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResult, SearchableFields.NAME))
.sorted(MapPolicyEntity.COMPARE_BY_NAME), firstResult, maxResult) .map(MapPolicyEntity<K>::getId)
.map(MapPolicyEntity<K>::getId) .map(policyStore.getKeyConvertor()::keyToString)
.map(K::toString) // We need to go through cache
.map(id -> authorizationProvider.getStoreFactory().getPolicyStore().findById(id, resourceServerId)) // We need to go through cache .map(id -> authorizationProvider.getStoreFactory().getPolicyStore().findById(id, resourceServerId))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private ModelCriteriaBuilder<Policy> filterEntryToModelCriteriaBuilder(Map.Entry<Policy.FilterOption, String[]> entry) { private ModelCriteriaBuilder<Policy> filterEntryToModelCriteriaBuilder(Map.Entry<Policy.FilterOption, String[]> entry) {
@ -205,24 +205,24 @@ public class MapPolicyStore<K> implements PolicyStore {
public void findByResource(String resourceId, String resourceServerId, Consumer<Policy> consumer) { public void findByResource(String resourceId, String resourceServerId, Consumer<Policy> consumer) {
LOG.tracef("findByResource(%s, %s, %s)%s", resourceId, resourceServerId, consumer, getShortStackTrace()); LOG.tracef("findByResource(%s, %s, %s)%s", resourceId, resourceServerId, consumer, getShortStackTrace());
tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(Policy.SearchableFields.RESOURCE_ID, Operator.EQ, resourceId)) .compare(SearchableFields.RESOURCE_ID, Operator.EQ, resourceId)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(consumer); .forEach(consumer);
} }
@Override @Override
public void findByResourceType(String type, String resourceServerId, Consumer<Policy> policyConsumer) { public void findByResourceType(String type, String resourceServerId, Consumer<Policy> policyConsumer) {
tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.CONFIG, Operator.LIKE, (Object[]) new String[] {"defaultResourceType", type})) .compare(SearchableFields.CONFIG, Operator.LIKE, (Object[]) new String[]{"defaultResourceType", type})))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(policyConsumer); .forEach(policyConsumer);
} }
@Override @Override
public List<Policy> findByScopeIds(List<String> scopeIds, String resourceServerId) { public List<Policy> findByScopeIds(List<String> scopeIds, String resourceServerId) {
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.SCOPE_ID, Operator.IN, scopeIds)) .compare(SearchableFields.SCOPE_ID, Operator.IN, scopeIds)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -240,22 +240,22 @@ public class MapPolicyStore<K> implements PolicyStore {
mcb = mcb.compare(SearchableFields.RESOURCE_ID, Operator.NOT_EXISTS) mcb = mcb.compare(SearchableFields.RESOURCE_ID, Operator.NOT_EXISTS)
.compare(SearchableFields.CONFIG, Operator.NOT_EXISTS, (Object[]) new String[] {"defaultResourceType"}); .compare(SearchableFields.CONFIG, Operator.NOT_EXISTS, (Object[]) new String[] {"defaultResourceType"});
} }
tx.read(mcb).map(this::entityToAdapter).forEach(consumer); tx.read(withCriteria(mcb)).map(this::entityToAdapter).forEach(consumer);
} }
@Override @Override
public List<Policy> findByType(String type, String resourceServerId) { public List<Policy> findByType(String type, String resourceServerId) {
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.TYPE, Operator.EQ, type)) .compare(SearchableFields.TYPE, Operator.EQ, type)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@Override @Override
public List<Policy> findDependentPolicies(String id, String resourceServerId) { public List<Policy> findDependentPolicies(String id, String resourceServerId) {
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.ASSOCIATED_POLICY_ID, Operator.EQ, id)) .compare(SearchableFields.ASSOCIATED_POLICY_ID, Operator.EQ, id)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View file

@ -33,7 +33,6 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator; import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -42,7 +41,7 @@ import java.util.stream.Collectors;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapResourceStore<K extends Comparable<K>> implements ResourceStore { public class MapResourceStore<K extends Comparable<K>> implements ResourceStore {
@ -86,7 +85,7 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
.compare(SearchableFields.NAME, Operator.EQ, name) .compare(SearchableFields.NAME, Operator.EQ, name)
.compare(SearchableFields.OWNER, Operator.EQ, owner); .compare(SearchableFields.OWNER, Operator.EQ, owner);
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Resource with name '" + name + "' for " + resourceServer.getId() + " already exists for request owner " + owner); throw new ModelDuplicateException("Resource with name '" + name + "' for " + resourceServer.getId() + " already exists for request owner " + owner);
} }
@ -113,8 +112,8 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
public Resource findById(String id, String resourceServerId) { public Resource findById(String id, String resourceServerId) {
LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace()); LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.ID, Operator.EQ, id)) .compare(SearchableFields.ID, Operator.EQ, id)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -127,12 +126,11 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
private void findByOwnerFilter(String ownerId, String resourceServerId, Consumer<Resource> consumer, int firstResult, int maxResult) { private void findByOwnerFilter(String ownerId, String resourceServerId, Consumer<Resource> consumer, int firstResult, int maxResult) {
LOG.tracef("findByOwnerFilter(%s, %s, %s, %d, %d)%s", ownerId, resourceServerId, consumer, firstResult, maxResult, getShortStackTrace()); LOG.tracef("findByOwnerFilter(%s, %s, %s, %d, %d)%s", ownerId, resourceServerId, consumer, firstResult, maxResult, getShortStackTrace());
Comparator<? super MapResourceEntity<K>> c = Comparator.comparing(MapResourceEntity::getId);
paginatedStream(tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId).compare(SearchableFields.OWNER, Operator.EQ, ownerId))
.compare(SearchableFields.OWNER, Operator.EQ, ownerId)) .pagination(firstResult, maxResult, SearchableFields.ID)
.sorted(c), firstResult, maxResult) ).map(this::entityToAdapter)
.map(this::entityToAdapter) .forEach(consumer);
.forEach(consumer);
} }
@Override @Override
@ -147,9 +145,9 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
@Override @Override
public List<Resource> findByUri(String uri, String resourceServerId) { public List<Resource> findByUri(String uri, String resourceServerId) {
LOG.tracef("findByUri(%s, %s)%s", uri, resourceServerId, getShortStackTrace()); LOG.tracef("findByUri(%s, %s)%s", uri, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.URI, Operator.EQ, uri)) .compare(SearchableFields.URI, Operator.EQ, uri)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -158,7 +156,7 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
public List<Resource> findByResourceServer(String resourceServerId) { public List<Resource> findByResourceServer(String resourceServerId) {
LOG.tracef("findByResourceServer(%s)%s", resourceServerId, getShortStackTrace()); LOG.tracef("findByResourceServer(%s)%s", resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId)) return tx.read(withCriteria(forResourceServer(resourceServerId)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -172,8 +170,7 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
.toArray(ModelCriteriaBuilder[]::new) .toArray(ModelCriteriaBuilder[]::new)
); );
return paginatedStream(tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResult, SearchableFields.NAME))
.sorted(MapResourceEntity.COMPARE_BY_NAME), firstResult, maxResult)
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -210,8 +207,8 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
public void findByScope(List<String> scopes, String resourceServerId, Consumer<Resource> consumer) { public void findByScope(List<String> scopes, String resourceServerId, Consumer<Resource> consumer) {
LOG.tracef("findByScope(%s, %s, %s)%s", scopes, resourceServerId, consumer, getShortStackTrace()); LOG.tracef("findByScope(%s, %s, %s)%s", scopes, resourceServerId, consumer, getShortStackTrace());
tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.SCOPE_ID, Operator.IN, scopes)) .compare(SearchableFields.SCOPE_ID, Operator.IN, scopes)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(consumer); .forEach(consumer);
} }
@ -224,9 +221,9 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
@Override @Override
public Resource findByName(String name, String ownerId, String resourceServerId) { public Resource findByName(String name, String ownerId, String resourceServerId) {
LOG.tracef("findByName(%s, %s, %s)%s", name, ownerId, resourceServerId, getShortStackTrace()); LOG.tracef("findByName(%s, %s, %s)%s", name, ownerId, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.OWNER, Operator.EQ, ownerId) .compare(SearchableFields.OWNER, Operator.EQ, ownerId)
.compare(SearchableFields.NAME, Operator.EQ, name)) .compare(SearchableFields.NAME, Operator.EQ, name)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -235,8 +232,8 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
@Override @Override
public void findByType(String type, String resourceServerId, Consumer<Resource> consumer) { public void findByType(String type, String resourceServerId, Consumer<Resource> consumer) {
LOG.tracef("findByType(%s, %s, %s)%s", type, resourceServerId, consumer, getShortStackTrace()); LOG.tracef("findByType(%s, %s, %s)%s", type, resourceServerId, consumer, getShortStackTrace());
tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.TYPE, Operator.EQ, type)) .compare(SearchableFields.TYPE, Operator.EQ, type)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(consumer); .forEach(consumer);
} }
@ -252,7 +249,7 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
mcb = mcb.compare(SearchableFields.OWNER, Operator.EQ, owner); mcb = mcb.compare(SearchableFields.OWNER, Operator.EQ, owner);
} }
tx.read(mcb) tx.read(withCriteria(mcb))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(consumer); .forEach(consumer);
} }
@ -260,9 +257,9 @@ public class MapResourceStore<K extends Comparable<K>> implements ResourceStore
@Override @Override
public void findByTypeInstance(String type, String resourceServerId, Consumer<Resource> consumer) { public void findByTypeInstance(String type, String resourceServerId, Consumer<Resource> consumer) {
LOG.tracef("findByTypeInstance(%s, %s, %s)%s", type, resourceServerId, consumer, getShortStackTrace()); LOG.tracef("findByTypeInstance(%s, %s, %s)%s", type, resourceServerId, consumer, getShortStackTrace());
tx.read(forResourceServer(resourceServerId) tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(SearchableFields.OWNER, Operator.NE, resourceServerId) .compare(SearchableFields.OWNER, Operator.NE, resourceServerId)
.compare(SearchableFields.TYPE, Operator.EQ, type)) .compare(SearchableFields.TYPE, Operator.EQ, type)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.forEach(consumer); .forEach(consumer);
} }

View file

@ -39,7 +39,7 @@ import java.util.stream.Collectors;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapScopeStore<K> implements ScopeStore { public class MapScopeStore<K> implements ScopeStore {
@ -84,7 +84,7 @@ public class MapScopeStore<K> implements ScopeStore {
ModelCriteriaBuilder<Scope> mcb = forResourceServer(resourceServer.getId()) ModelCriteriaBuilder<Scope> mcb = forResourceServer(resourceServer.getId())
.compare(SearchableFields.NAME, Operator.EQ, name); .compare(SearchableFields.NAME, Operator.EQ, name);
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Scope with name '" + name + "' for " + resourceServer.getId() + " already exists"); throw new ModelDuplicateException("Scope with name '" + name + "' for " + resourceServer.getId() + " already exists");
} }
@ -109,8 +109,8 @@ public class MapScopeStore<K> implements ScopeStore {
public Scope findById(String id, String resourceServerId) { public Scope findById(String id, String resourceServerId) {
LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace()); LOG.tracef("findById(%s, %s)%s", id, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId) return tx.read(withCriteria(forResourceServer(resourceServerId)
.compare(Scope.SearchableFields.ID, Operator.EQ, id)) .compare(SearchableFields.ID, Operator.EQ, id)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -120,8 +120,8 @@ public class MapScopeStore<K> implements ScopeStore {
public Scope findByName(String name, String resourceServerId) { public Scope findByName(String name, String resourceServerId) {
LOG.tracef("findByName(%s, %s)%s", name, resourceServerId, getShortStackTrace()); LOG.tracef("findByName(%s, %s)%s", name, resourceServerId, getShortStackTrace());
return tx.read(forResourceServer(resourceServerId).compare(Scope.SearchableFields.NAME, return tx.read(withCriteria(forResourceServer(resourceServerId).compare(SearchableFields.NAME,
Operator.EQ, name)) Operator.EQ, name)))
.findFirst() .findFirst()
.map(this::entityToAdapter) .map(this::entityToAdapter)
.orElse(null); .orElse(null);
@ -131,7 +131,7 @@ public class MapScopeStore<K> implements ScopeStore {
public List<Scope> findByResourceServer(String id) { public List<Scope> findByResourceServer(String id) {
LOG.tracef("findByResourceServer(%s)%s", id, getShortStackTrace()); LOG.tracef("findByResourceServer(%s)%s", id, getShortStackTrace());
return tx.read(forResourceServer(id)) return tx.read(withCriteria(forResourceServer(id)))
.map(this::entityToAdapter) .map(this::entityToAdapter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@ -155,7 +155,8 @@ public class MapScopeStore<K> implements ScopeStore {
} }
} }
return paginatedStream(tx.read(mcb).map(this::entityToAdapter), firstResult, maxResult) return tx.read(withCriteria(mcb).pagination(firstResult, maxResult, SearchableFields.NAME))
.collect(Collectors.toList()); .map(this::entityToAdapter)
.collect(Collectors.toList());
} }
} }

View file

@ -24,8 +24,6 @@ import java.util.Objects;
public class MapPermissionTicketEntity<K> implements AbstractEntity<K> { public class MapPermissionTicketEntity<K> implements AbstractEntity<K> {
public static final Comparator<MapPermissionTicketEntity<?>> COMPARE_BY_RESOURCE_ID = Comparator.comparing(MapPermissionTicketEntity::getResourceId);
private final K id; private final K id;
private String owner; private String owner;
private String requester; private String requester;

View file

@ -30,8 +30,6 @@ import java.util.Objects;
public class MapPolicyEntity<K> implements AbstractEntity<K> { public class MapPolicyEntity<K> implements AbstractEntity<K> {
public static final Comparator<MapPolicyEntity<?>> COMPARE_BY_NAME = Comparator.comparing(MapPolicyEntity::getName);
private final K id; private final K id;
private String name; private String name;
private String description; private String description;

View file

@ -29,8 +29,6 @@ import java.util.Set;
public class MapResourceEntity<K> implements AbstractEntity<K> { public class MapResourceEntity<K> implements AbstractEntity<K> {
public static final Comparator<MapResourceEntity<?>> COMPARE_BY_NAME = Comparator.comparing(MapResourceEntity::getName);
private final K id; private final K id;
private String name; private String name;
private String displayName; private String displayName;

View file

@ -28,7 +28,7 @@ import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import java.util.Comparator;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -45,9 +45,11 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import org.keycloak.models.ClientScopeModel; import org.keycloak.models.ClientScopeModel;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
import org.keycloak.protocol.oidc.OIDCLoginProtocol; import org.keycloak.protocol.oidc.OIDCLoginProtocol;
import java.util.HashSet; import java.util.HashSet;
import static org.keycloak.utils.StreamsUtil.paginatedStream;
public class MapClientProvider<K> implements ClientProvider { public class MapClientProvider<K> implements ClientProvider {
@ -57,8 +59,6 @@ public class MapClientProvider<K> implements ClientProvider {
private final MapStorage<K, MapClientEntity<K>, ClientModel> clientStore; private final MapStorage<K, MapClientEntity<K>, ClientModel> clientStore;
private final ConcurrentMap<K, ConcurrentMap<String, Integer>> clientRegisteredNodesStore; private final ConcurrentMap<K, ConcurrentMap<String, Integer>> clientRegisteredNodesStore;
private static final Comparator<MapClientEntity> COMPARE_BY_CLIENT_ID = Comparator.comparing(MapClientEntity::getClientId);
public MapClientProvider(KeycloakSession session, MapStorage<K, MapClientEntity<K>, ClientModel> clientStore, ConcurrentMap<K, ConcurrentMap<String, Integer>> clientRegisteredNodesStore) { public MapClientProvider(KeycloakSession session, MapStorage<K, MapClientEntity<K>, ClientModel> clientStore, ConcurrentMap<K, ConcurrentMap<String, Integer>> clientRegisteredNodesStore) {
this.session = session; this.session = session;
this.clientStore = clientStore; this.clientStore = clientStore;
@ -126,7 +126,11 @@ public class MapClientProvider<K> implements ClientProvider {
@Override @Override
public Stream<ClientModel> getClientsStream(RealmModel realm, Integer firstResult, Integer maxResults) { public Stream<ClientModel> getClientsStream(RealmModel realm, Integer firstResult, Integer maxResults) {
return paginatedStream(getClientsStream(realm), firstResult, maxResults); ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.CLIENT_ID))
.map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -134,10 +138,8 @@ public class MapClientProvider<K> implements ClientProvider {
ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.CLIENT_ID, ASCENDING))
.sorted(COMPARE_BY_CLIENT_ID) .map(entityToAdapterFunc(realm));
.map(entityToAdapterFunc(realm))
;
} }
@Override @Override
@ -220,7 +222,7 @@ public class MapClientProvider<K> implements ClientProvider {
ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
return this.clientStore.getCount(mcb); return this.clientStore.getCount(withCriteria(mcb));
} }
@Override @Override
@ -248,7 +250,7 @@ public class MapClientProvider<K> implements ClientProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CLIENT_ID, Operator.ILIKE, clientId); .compare(SearchableFields.CLIENT_ID, Operator.ILIKE, clientId);
return tx.read(mcb) return tx.read(withCriteria(mcb))
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.findFirst() .findFirst()
.orElse(null) .orElse(null)
@ -265,10 +267,8 @@ public class MapClientProvider<K> implements ClientProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CLIENT_ID, Operator.ILIKE, "%" + clientId + "%"); .compare(SearchableFields.CLIENT_ID, Operator.ILIKE, "%" + clientId + "%");
Stream<MapClientEntity<K>> s = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.CLIENT_ID))
.sorted(COMPARE_BY_CLIENT_ID); .map(entityToAdapterFunc(realm));
return paginatedStream(s, firstResult, maxResults).map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -280,10 +280,8 @@ public class MapClientProvider<K> implements ClientProvider {
mcb = mcb.compare(SearchableFields.ATTRIBUTE, Operator.EQ, entry.getKey(), entry.getValue()); mcb = mcb.compare(SearchableFields.ATTRIBUTE, Operator.EQ, entry.getKey(), entry.getValue());
} }
Stream<MapClientEntity<K>> s = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.CLIENT_ID))
.sorted(COMPARE_BY_CLIENT_ID); .map(entityToAdapterFunc(realm));
return paginatedStream(s, firstResult, maxResults).map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -344,7 +342,7 @@ public class MapClientProvider<K> implements ClientProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ENABLED, Operator.EQ, Boolean.TRUE); .compare(SearchableFields.ENABLED, Operator.EQ, Boolean.TRUE);
try (Stream<MapClientEntity<K>> st = tx.read(mcb)) { try (Stream<MapClientEntity<K>> st = tx.read(withCriteria(mcb))) {
return st return st
.filter(mce -> mce.getRedirectUris() != null && ! mce.getRedirectUris().isEmpty()) .filter(mce -> mce.getRedirectUris() != null && ! mce.getRedirectUris().isEmpty())
.collect(Collectors.toMap( .collect(Collectors.toMap(
@ -358,7 +356,7 @@ public class MapClientProvider<K> implements ClientProvider {
ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder() ModelCriteriaBuilder<ClientModel> mcb = clientStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.SCOPE_MAPPING_ROLE, Operator.EQ, role.getId()); .compare(SearchableFields.SCOPE_MAPPING_ROLE, Operator.EQ, role.getId());
try (Stream<MapClientEntity<K>> toRemove = tx.read(mcb)) { try (Stream<MapClientEntity<K>> toRemove = tx.read(withCriteria(mcb))) {
toRemove toRemove
.map(clientEntity -> session.clients().getClientById(realm, clientEntity.getId().toString())) .map(clientEntity -> session.clients().getClientById(realm, clientEntity.getId().toString()))
.filter(Objects::nonNull) .filter(Objects::nonNull)

View file

@ -17,7 +17,6 @@
package org.keycloak.models.map.clientscope; package org.keycloak.models.map.clientscope;
import java.util.Comparator;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -38,6 +37,8 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapClientScopeProvider<K> implements ClientScopeProvider { public class MapClientScopeProvider<K> implements ClientScopeProvider {
@ -46,8 +47,6 @@ public class MapClientScopeProvider<K> implements ClientScopeProvider {
private final MapKeycloakTransaction<K, MapClientScopeEntity<K>, ClientScopeModel> tx; private final MapKeycloakTransaction<K, MapClientScopeEntity<K>, ClientScopeModel> tx;
private final MapStorage<K, MapClientScopeEntity<K>, ClientScopeModel> clientScopeStore; private final MapStorage<K, MapClientScopeEntity<K>, ClientScopeModel> clientScopeStore;
private static final Comparator<MapClientScopeEntity> COMPARE_BY_NAME = Comparator.comparing(MapClientScopeEntity::getName);
public MapClientScopeProvider(KeycloakSession session, MapStorage<K, MapClientScopeEntity<K>, ClientScopeModel> clientScopeStore) { public MapClientScopeProvider(KeycloakSession session, MapStorage<K, MapClientScopeEntity<K>, ClientScopeModel> clientScopeStore) {
this.session = session; this.session = session;
this.clientScopeStore = clientScopeStore; this.clientScopeStore = clientScopeStore;
@ -79,8 +78,7 @@ public class MapClientScopeProvider<K> implements ClientScopeProvider {
ModelCriteriaBuilder<ClientScopeModel> mcb = clientScopeStore.createCriteriaBuilder() ModelCriteriaBuilder<ClientScopeModel> mcb = clientScopeStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.NAME, ASCENDING))
.sorted(COMPARE_BY_NAME)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@ -91,7 +89,7 @@ public class MapClientScopeProvider<K> implements ClientScopeProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.NAME, Operator.EQ, name); .compare(SearchableFields.NAME, Operator.EQ, name);
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Client scope with name '" + name + "' in realm " + realm.getName()); throw new ModelDuplicateException("Client scope with name '" + name + "' in realm " + realm.getName());
} }

View file

@ -30,15 +30,18 @@ import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder; import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator; import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import java.util.Comparator; import org.keycloak.models.map.storage.QueryParameters;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.UnaryOperator; import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapGroupProvider<K> implements GroupProvider { public class MapGroupProvider<K> implements GroupProvider {
@ -88,10 +91,10 @@ public class MapGroupProvider<K> implements GroupProvider {
@Override @Override
public Stream<GroupModel> getGroupsStream(RealmModel realm) { public Stream<GroupModel> getGroupsStream(RealmModel realm) {
return getGroupsStreamInternal(realm, null); return getGroupsStreamInternal(realm, null, null);
} }
private Stream<GroupModel> getGroupsStreamInternal(RealmModel realm, UnaryOperator<ModelCriteriaBuilder<GroupModel>> modifier) { private Stream<GroupModel> getGroupsStreamInternal(RealmModel realm, UnaryOperator<ModelCriteriaBuilder<GroupModel>> modifier, UnaryOperator<QueryParameters<GroupModel>> queryParametersModifier) {
LOG.tracef("getGroupsStream(%s)%s", realm, getShortStackTrace()); LOG.tracef("getGroupsStream(%s)%s", realm, getShortStackTrace());
ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder() ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
@ -100,9 +103,13 @@ public class MapGroupProvider<K> implements GroupProvider {
mcb = modifier.apply(mcb); mcb = modifier.apply(mcb);
} }
return tx.read(mcb) QueryParameters<GroupModel> queryParameters = withCriteria(mcb).orderBy(SearchableFields.NAME, ASCENDING);
if (queryParametersModifier != null) {
queryParameters = queryParametersModifier.apply(queryParameters);
}
return tx.read(queryParameters)
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.sorted(GroupModel.COMPARE_BY_NAME)
; ;
} }
@ -116,11 +123,8 @@ public class MapGroupProvider<K> implements GroupProvider {
mcb = mcb.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"); mcb = mcb.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%");
} }
Stream<GroupModel> groupModelStream = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(first, max, SearchableFields.NAME))
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm));
.sorted(Comparator.comparing(GroupModel::getName));
return paginatedStream(groupModelStream, first, max);
} }
@Override @Override
@ -133,7 +137,7 @@ public class MapGroupProvider<K> implements GroupProvider {
mcb = mcb.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null); mcb = mcb.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null);
} }
return tx.getCount(mcb); return tx.getCount(withCriteria(mcb));
} }
@Override @Override
@ -142,51 +146,56 @@ public class MapGroupProvider<K> implements GroupProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"); .compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%");
return tx.getCount(mcb); return tx.getCount(withCriteria(mcb));
} }
@Override @Override
public Stream<GroupModel> getGroupsByRoleStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) { public Stream<GroupModel> getGroupsByRoleStream(RealmModel realm, RoleModel role, Integer firstResult, Integer maxResults) {
LOG.tracef("getGroupsByRole(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace()); LOG.tracef("getGroupsByRole(%s, %s, %d, %d)%s", realm, role, firstResult, maxResults, getShortStackTrace());
Stream<GroupModel> groupModelStream = getGroupsStreamInternal(realm, return getGroupsStreamInternal(realm,
(ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId()) (ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId()),
qp -> qp.offset(firstResult).limit(maxResults)
); );
return paginatedStream(groupModelStream, firstResult, maxResults);
} }
@Override @Override
public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm) { public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm) {
LOG.tracef("getTopLevelGroupsStream(%s)%s", realm, getShortStackTrace()); LOG.tracef("getTopLevelGroupsStream(%s)%s", realm, getShortStackTrace());
return getGroupsStreamInternal(realm, return getGroupsStreamInternal(realm,
(ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null) (ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.PARENT_ID, Operator.NOT_EXISTS),
null
); );
} }
@Override @Override
public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm, Integer firstResult, Integer maxResults) { public Stream<GroupModel> getTopLevelGroupsStream(RealmModel realm, Integer firstResult, Integer maxResults) {
Stream<GroupModel> groupModelStream = getTopLevelGroupsStream(realm); LOG.tracef("getTopLevelGroupsStream(%s, %s, %s)%s", realm, firstResult, maxResults, getShortStackTrace());
return getGroupsStreamInternal(realm,
return paginatedStream(groupModelStream, firstResult, maxResults); (ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.PARENT_ID, Operator.NOT_EXISTS),
qp -> qp.offset(firstResult).limit(maxResults)
);
} }
@Override @Override
public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) { public Stream<GroupModel> searchForGroupByNameStream(RealmModel realm, String search, Integer firstResult, Integer maxResults) {
LOG.tracef("searchForGroupByNameStream(%s, %s, %d, %d)%s", realm, search, firstResult, maxResults, getShortStackTrace()); LOG.tracef("searchForGroupByNameStream(%s, %s, %d, %d)%s", realm, search, firstResult, maxResults, getShortStackTrace());
Stream<GroupModel> groupModelStream = getGroupsStreamInternal(realm,
(ModelCriteriaBuilder<GroupModel> mcb) -> mcb.compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%")
);
final Stream<String> groups = paginatedStream(groupModelStream.map(GroupModel::getId), firstResult, maxResults);
return groups.map(id -> { ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
GroupModel groupById = session.groups().getGroupById(realm,id); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
while (Objects.nonNull(groupById.getParentId())) { .compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%");
groupById = session.groups().getGroupById(realm, groupById.getParentId());
}
return groupById; return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.NAME))
}).sorted(GroupModel.COMPARE_BY_NAME).distinct(); .map(MapGroupEntity::getId)
.map(groupStore.getKeyConvertor()::keyToString)
.map(id -> {
GroupModel groupById = session.groups().getGroupById(realm, id);
while (Objects.nonNull(groupById.getParentId())) {
groupById = session.groups().getGroupById(realm, groupById.getParentId());
}
return groupById;
}).sorted(GroupModel.COMPARE_BY_NAME).distinct();
} }
@Override @Override
@ -201,7 +210,7 @@ public class MapGroupProvider<K> implements GroupProvider {
.compare(SearchableFields.PARENT_ID, Operator.EQ, parentId) .compare(SearchableFields.PARENT_ID, Operator.EQ, parentId)
.compare(SearchableFields.NAME, Operator.EQ, name); .compare(SearchableFields.NAME, Operator.EQ, name);
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("Group with name '" + name + "' in realm " + realm.getName() + " already exists for requested parent" ); throw new ModelDuplicateException("Group with name '" + name + "' in realm " + realm.getName() + " already exists for requested parent" );
} }
@ -243,7 +252,7 @@ public class MapGroupProvider<K> implements GroupProvider {
session.users().preRemove(realm, group); session.users().preRemove(realm, group);
realm.removeDefaultGroup(group); realm.removeDefaultGroup(group);
group.getSubGroupsStream().forEach(subGroup -> session.groups().removeGroup(realm, subGroup)); group.getSubGroupsStream().collect(Collectors.toSet()).forEach(subGroup -> session.groups().removeGroup(realm, subGroup));
// TODO: ^^^^^^^ Up to here // TODO: ^^^^^^^ Up to here
@ -268,7 +277,7 @@ public class MapGroupProvider<K> implements GroupProvider {
.compare(SearchableFields.PARENT_ID, Operator.EQ, parentId) .compare(SearchableFields.PARENT_ID, Operator.EQ, parentId)
.compare(SearchableFields.NAME, Operator.EQ, group.getName()); .compare(SearchableFields.NAME, Operator.EQ, group.getName());
try (Stream<MapGroupEntity<K>> possibleSiblings = tx.read(mcb)) { try (Stream<MapGroupEntity<K>> possibleSiblings = tx.read(withCriteria(mcb))) {
if (possibleSiblings.findAny().isPresent()) { if (possibleSiblings.findAny().isPresent()) {
throw new ModelDuplicateException("Parent already contains subgroup named '" + group.getName() + "'"); throw new ModelDuplicateException("Parent already contains subgroup named '" + group.getName() + "'");
} }
@ -290,7 +299,7 @@ public class MapGroupProvider<K> implements GroupProvider {
.compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null) .compare(SearchableFields.PARENT_ID, Operator.EQ, (Object) null)
.compare(SearchableFields.NAME, Operator.EQ, subGroup.getName()); .compare(SearchableFields.NAME, Operator.EQ, subGroup.getName());
try (Stream<MapGroupEntity<K>> possibleSiblings = tx.read(mcb)) { try (Stream<MapGroupEntity<K>> possibleSiblings = tx.read(withCriteria(mcb))) {
if (possibleSiblings.findAny().isPresent()) { if (possibleSiblings.findAny().isPresent()) {
throw new ModelDuplicateException("There is already a top level group named '" + subGroup.getName() + "'"); throw new ModelDuplicateException("There is already a top level group named '" + subGroup.getName() + "'");
} }
@ -304,7 +313,7 @@ public class MapGroupProvider<K> implements GroupProvider {
ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder() ModelCriteriaBuilder<GroupModel> mcb = groupStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId()); .compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId());
try (Stream<MapGroupEntity<K>> toRemove = tx.read(mcb)) { try (Stream<MapGroupEntity<K>> toRemove = tx.read(withCriteria(mcb))) {
toRemove toRemove
.map(groupEntity -> session.groups().getGroupById(realm, groupEntity.getId().toString())) .map(groupEntity -> session.groups().getGroupById(realm, groupEntity.getId().toString()))
.forEach(groupModel -> groupModel.deleteRoleMapping(role)); .forEach(groupModel -> groupModel.deleteRoleMapping(role));

View file

@ -29,6 +29,7 @@ import java.util.function.Function;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
/** /**
* @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a> * @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a>
@ -66,7 +67,7 @@ public class MapUserLoginFailureProvider<K> implements UserLoginFailureProvider
LOG.tracef("getUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace()); LOG.tracef("getUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace());
return userLoginFailureTx.read(mcb) return userLoginFailureTx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(userLoginFailureEntityToAdapterFunc(realm)) .map(userLoginFailureEntityToAdapterFunc(realm))
.orElse(null); .orElse(null);
@ -80,7 +81,7 @@ public class MapUserLoginFailureProvider<K> implements UserLoginFailureProvider
LOG.tracef("addUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace()); LOG.tracef("addUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace());
MapUserLoginFailureEntity<K> userLoginFailureEntity = userLoginFailureTx.read(mcb).findFirst().orElse(null); MapUserLoginFailureEntity<K> userLoginFailureEntity = userLoginFailureTx.read(withCriteria(mcb)).findFirst().orElse(null);
if (userLoginFailureEntity == null) { if (userLoginFailureEntity == null) {
userLoginFailureEntity = new MapUserLoginFailureEntity<>(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), realm.getId(), userId); userLoginFailureEntity = new MapUserLoginFailureEntity<>(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), realm.getId(), userId);
@ -99,7 +100,7 @@ public class MapUserLoginFailureProvider<K> implements UserLoginFailureProvider
LOG.tracef("removeUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace()); LOG.tracef("removeUserLoginFailure(%s, %s)%s", realm, userId, getShortStackTrace());
userLoginFailureTx.delete(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userLoginFailureTx.delete(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -109,7 +110,7 @@ public class MapUserLoginFailureProvider<K> implements UserLoginFailureProvider
LOG.tracef("removeAllUserLoginFailures(%s)%s", realm, getShortStackTrace()); LOG.tracef("removeAllUserLoginFailures(%s)%s", realm, getShortStackTrace());
userLoginFailureTx.delete(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userLoginFailureTx.delete(userLoginFailureStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override

View file

@ -38,6 +38,8 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator; import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapRealmProvider<K> implements RealmProvider { public class MapRealmProvider<K> implements RealmProvider {
@ -110,7 +112,7 @@ public class MapRealmProvider<K> implements RealmProvider {
ModelCriteriaBuilder<RealmModel> mcb = realmStore.createCriteriaBuilder() ModelCriteriaBuilder<RealmModel> mcb = realmStore.createCriteriaBuilder()
.compare(SearchableFields.NAME, Operator.EQ, name); .compare(SearchableFields.NAME, Operator.EQ, name);
K realmId = tx.read(mcb) K realmId = tx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(MapRealmEntity<K>::getId) .map(MapRealmEntity<K>::getId)
.orElse(null); .orElse(null);
@ -132,9 +134,8 @@ public class MapRealmProvider<K> implements RealmProvider {
} }
private Stream<RealmModel> getRealmsStream(ModelCriteriaBuilder<RealmModel> mcb) { private Stream<RealmModel> getRealmsStream(ModelCriteriaBuilder<RealmModel> mcb) {
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.NAME, ASCENDING))
.map(this::entityToAdapter) .map(this::entityToAdapter);
.sorted(RealmModel.COMPARE_BY_NAME);
} }
@Override @Override
@ -174,7 +175,7 @@ public class MapRealmProvider<K> implements RealmProvider {
ModelCriteriaBuilder<RealmModel> mcb = realmStore.createCriteriaBuilder() ModelCriteriaBuilder<RealmModel> mcb = realmStore.createCriteriaBuilder()
.compare(SearchableFields.CLIENT_INITIAL_ACCESS, Operator.EXISTS); .compare(SearchableFields.CLIENT_INITIAL_ACCESS, Operator.EXISTS);
tx.read(mcb) tx.read(withCriteria(mcb))
.map(e -> registerEntityForChanges(tx, e)) .map(e -> registerEntityForChanges(tx, e))
.forEach(MapRealmEntity<K>::removeExpiredClientInitialAccesses); .forEach(MapRealmEntity<K>::removeExpiredClientInitialAccesses);
} }

View file

@ -25,14 +25,15 @@ import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import java.util.Comparator;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import static org.keycloak.common.util.StackUtil.getShortStackTrace; import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
import org.keycloak.models.RoleContainerModel; import org.keycloak.models.RoleContainerModel;
import org.keycloak.models.RoleModel.SearchableFields; import org.keycloak.models.RoleModel.SearchableFields;
@ -47,19 +48,6 @@ public class MapRoleProvider<K> implements RoleProvider {
final MapKeycloakTransaction<K, MapRoleEntity<K>, RoleModel> tx; final MapKeycloakTransaction<K, MapRoleEntity<K>, RoleModel> tx;
private final MapStorage<K, MapRoleEntity<K>, RoleModel> roleStore; private final MapStorage<K, MapRoleEntity<K>, RoleModel> roleStore;
private static final Comparator<MapRoleEntity<?>> COMPARE_BY_NAME = new Comparator<MapRoleEntity<?>>() {
@Override
public int compare(MapRoleEntity<?> o1, MapRoleEntity<?> o2) {
String r1 = o1 == null ? null : o1.getName();
String r2 = o2 == null ? null : o2.getName();
return r1 == r2 ? 0
: r1 == null ? -1
: r2 == null ? 1
: r1.compareTo(r2);
}
};
public MapRoleProvider(KeycloakSession session, MapStorage<K, MapRoleEntity<K>, RoleModel> roleStore) { public MapRoleProvider(KeycloakSession session, MapStorage<K, MapRoleEntity<K>, RoleModel> roleStore) {
this.session = session; this.session = session;
this.roleStore = roleStore; this.roleStore = roleStore;
@ -99,7 +87,12 @@ public class MapRoleProvider<K> implements RoleProvider {
@Override @Override
public Stream<RoleModel> getRealmRolesStream(RealmModel realm, Integer first, Integer max) { public Stream<RoleModel> getRealmRolesStream(RealmModel realm, Integer first, Integer max) {
return paginatedStream(getRealmRolesStream(realm), first, max); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IS_CLIENT_ROLE, Operator.NE, true);
return tx.read(withCriteria(mcb).pagination(first, max, SearchableFields.NAME))
.map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -108,8 +101,7 @@ public class MapRoleProvider<K> implements RoleProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IS_CLIENT_ROLE, Operator.NE, true); .compare(SearchableFields.IS_CLIENT_ROLE, Operator.NE, true);
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.NAME, ASCENDING))
.sorted(COMPARE_BY_NAME)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@ -136,7 +128,12 @@ public class MapRoleProvider<K> implements RoleProvider {
@Override @Override
public Stream<RoleModel> getClientRolesStream(ClientModel client, Integer first, Integer max) { public Stream<RoleModel> getClientRolesStream(ClientModel client, Integer first, Integer max) {
return paginatedStream(getClientRolesStream(client), first, max); ModelCriteriaBuilder<RoleModel> mcb = roleStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId());
return tx.read(withCriteria(mcb).pagination(first, max, SearchableFields.NAME))
.map(entityToAdapterFunc(client.getRealm()));
} }
@Override @Override
@ -145,8 +142,7 @@ public class MapRoleProvider<K> implements RoleProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId()); .compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId());
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.NAME, ASCENDING))
.sorted(COMPARE_BY_NAME)
.map(entityToAdapterFunc(client.getRealm())); .map(entityToAdapterFunc(client.getRealm()));
} }
@Override @Override
@ -197,7 +193,7 @@ public class MapRoleProvider<K> implements RoleProvider {
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, name); .compare(SearchableFields.NAME, Operator.ILIKE, name);
String roleId = tx.read(mcb) String roleId = tx.read(withCriteria(mcb))
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.map(RoleModel::getId) .map(RoleModel::getId)
.findFirst() .findFirst()
@ -218,7 +214,7 @@ public class MapRoleProvider<K> implements RoleProvider {
.compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId()) .compare(SearchableFields.CLIENT_ID, Operator.EQ, client.getId())
.compare(SearchableFields.NAME, Operator.ILIKE, name); .compare(SearchableFields.NAME, Operator.ILIKE, name);
String roleId = tx.read(mcb) String roleId = tx.read(withCriteria(mcb))
.map(entityToAdapterFunc(client.getRealm())) .map(entityToAdapterFunc(client.getRealm()))
.map(RoleModel::getId) .map(RoleModel::getId)
.findFirst() .findFirst()
@ -254,10 +250,8 @@ public class MapRoleProvider<K> implements RoleProvider {
roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%") roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%")
); );
Stream<MapRoleEntity<K>> s = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(first, max, SearchableFields.NAME))
.sorted(COMPARE_BY_NAME); .map(entityToAdapterFunc(realm));
return paginatedStream(s.map(entityToAdapterFunc(realm)), first, max);
} }
@Override @Override
@ -272,10 +266,8 @@ public class MapRoleProvider<K> implements RoleProvider {
roleStore.createCriteriaBuilder().compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"), roleStore.createCriteriaBuilder().compare(SearchableFields.NAME, Operator.ILIKE, "%" + search + "%"),
roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%") roleStore.createCriteriaBuilder().compare(SearchableFields.DESCRIPTION, Operator.ILIKE, "%" + search + "%")
); );
Stream<MapRoleEntity<K>> s = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(first, max, SearchableFields.NAME))
.sorted(COMPARE_BY_NAME); .map(entityToAdapterFunc(client.getRealm()));
return paginatedStream(s,first, max).map(entityToAdapterFunc(client.getRealm()));
} }
@Override @Override

View file

@ -58,20 +58,19 @@ public interface MapKeycloakTransaction<K, V extends AbstractEntity<K>, M> exten
* transaction by methods {@link MapKeycloakTransaction#create}, {@link MapKeycloakTransaction#update}, * transaction by methods {@link MapKeycloakTransaction#create}, {@link MapKeycloakTransaction#update},
* {@link MapKeycloakTransaction#delete}, etc. * {@link MapKeycloakTransaction#delete}, etc.
* *
* @param mcb criteria to filter values * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* @return values that fulfill the given criteria, that are updated based on changes in the current transaction * @return values that fulfill the given criteria, that are updated based on changes in the current transaction
*/ */
Stream<V> read(ModelCriteriaBuilder<M> mcb); Stream<V> read(QueryParameters<M> queryParameters);
/** /**
* Returns a number of values present in the underlying storage that fulfill the given criteria with respect to * Returns a number of values present in the underlying storage that fulfill the given criteria with respect to
* changes done in the current transaction. * changes done in the current transaction.
* *
* @param mcb criteria to filter values * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* @return number of values present in the storage that fulfill the given criteria * @return number of values present in the storage that fulfill the given criteria
*/ */
long getCount(ModelCriteriaBuilder<M> mcb); long getCount(QueryParameters<M> queryParameters);
/** /**
* Instructs this transaction to force-update the {@code value} associated with the identifier {@code value.getId()} in the * Instructs this transaction to force-update the {@code value} associated with the identifier {@code value.getId()} in the
@ -116,8 +115,9 @@ public interface MapKeycloakTransaction<K, V extends AbstractEntity<K>, M> exten
* *
* @param artificialKey key to record the transaction with, must be a key that does not exist in this transaction to * @param artificialKey key to record the transaction with, must be a key that does not exist in this transaction to
* prevent collisions with other operations in this transaction * prevent collisions with other operations in this transaction
* @param mcb criteria to delete values * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* @return number of removed objects (might return {@code -1} if not supported)
*/ */
long delete(K artificialKey, ModelCriteriaBuilder<M> mcb); long delete(K artificialKey, QueryParameters<M> queryParameters);
} }

View file

@ -20,12 +20,14 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.map.common.AbstractEntity; import org.keycloak.models.map.common.AbstractEntity;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
/** /**
* Implementation of this interface interacts with a persistence storage storing various entities, e.g. users, realms. * Implementation of this interface interacts with a persistence storage storing various entities, e.g. users, realms.
* It contains basic object CRUD operations as well as bulk {@link #read(org.keycloak.models.map.storage.ModelCriteriaBuilder)} * It contains basic object CRUD operations as well as bulk {@link #read(org.keycloak.models.map.storage.QueryParameters)}
* and bulk {@link #delete(org.keycloak.models.map.storage.ModelCriteriaBuilder)} operations, * and bulk {@link #delete(org.keycloak.models.map.storage.QueryParameters)} operations,
* and operation for determining the number of the objects satisfying given criteria * and operation for determining the number of the objects satisfying given criteria
* ({@link #getCount(org.keycloak.models.map.storage.ModelCriteriaBuilder)}). * ({@link #getCount(org.keycloak.models.map.storage.QueryParameters)}).
* *
* @author hmlnarik * @author hmlnarik
* @param <K> Type of the primary key. Various storages can * @param <K> Type of the primary key. Various storages can
@ -62,29 +64,27 @@ public interface MapStorage<K, V extends AbstractEntity<K>, M> {
* Returns stream of objects satisfying given {@code criteria} from the storage. * Returns stream of objects satisfying given {@code criteria} from the storage.
* The criteria are specified in the given criteria builder based on model properties. * The criteria are specified in the given criteria builder based on model properties.
* *
* @param criteria Criteria filtering out the object, originally obtained * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* from {@link #createCriteriaBuilder()} method of this object.
* If {@code null}, it returns an empty stream.
* @return Stream of objects. Never returns {@code null}. * @return Stream of objects. Never returns {@code null}.
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created * @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object. * by the {@link #createCriteriaBuilder()} method of this object.
*/ */
Stream<V> read(ModelCriteriaBuilder<M> criteria); Stream<V> read(QueryParameters<M> queryParameters);
/** /**
* Returns the number of objects satisfying given {@code criteria} from the storage. * Returns the number of objects satisfying given {@code criteria} from the storage.
* The criteria are specified in the given criteria builder based on model properties. * The criteria are specified in the given criteria builder based on model properties.
* *
* @param criteria * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* @return Number of objects. Never returns {@code null}. * @return Number of objects. Never returns {@code null}.
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created * @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object. * by the {@link #createCriteriaBuilder()} method of this object.
*/ */
long getCount(ModelCriteriaBuilder<M> criteria); long getCount(QueryParameters<M> queryParameters);
/** /**
* Updates the object with the key of the {@code value}'s ID in the storage if it already exists. * Updates the object with the key of the {@code value}'s ID in the storage if it already exists.
* @param key Primary key of the object to update *
* @param value Updated value * @param value Updated value
* @throws NullPointerException if the object or its {@code id} is {@code null} * @throws NullPointerException if the object or its {@code id} is {@code null}
* @see AbstractEntity#getId() * @see AbstractEntity#getId()
@ -100,12 +100,12 @@ public interface MapStorage<K, V extends AbstractEntity<K>, M> {
/** /**
* Deletes objects that match the given criteria. * Deletes objects that match the given criteria.
* @param criteria * @param queryParameters parameters for the query like firstResult, maxResult, requested ordering, etc.
* @return Number of removed objects (might return {@code -1} if not supported) * @return Number of removed objects (might return {@code -1} if not supported)
* @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created * @throws IllegalStateException If {@code criteria} is not compatible, i.e. has not been originally created
* by the {@link #createCriteriaBuilder()} method of this object. * by the {@link #createCriteriaBuilder()} method of this object.
*/ */
long delete(ModelCriteriaBuilder<M> criteria); long delete(QueryParameters<M> queryParameters);
/** /**
@ -122,7 +122,6 @@ public interface MapStorage<K, V extends AbstractEntity<K>, M> {
* @return See description. Never returns {@code null} * @return See description. Never returns {@code null}
*/ */
ModelCriteriaBuilder<M> createCriteriaBuilder(); ModelCriteriaBuilder<M> createCriteriaBuilder();
/** /**
* Creates a {@code MapKeycloakTransaction} object that tracks a new transaction related to this storage. * Creates a {@code MapKeycloakTransaction} object that tracks a new transaction related to this storage.

View file

@ -0,0 +1,139 @@
package org.keycloak.models.map.storage;
import org.keycloak.storage.SearchableModelField;
import java.util.LinkedList;
import java.util.List;
import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
/**
* Wraps together parameters for querying storage e.g. number of results to return, requested order or filtering criteria
*
* @param <M> Provide entity specific type checking, for example, when we create {@code QueryParameters}
* instance for Users, M is equal to UserModel, hence we are not able, for example, to order result by a
* {@link SearchableModelField} defined for clients in {@link org.keycloak.models.ClientModel}.
*/
public class QueryParameters<M> {
private Integer offset;
private Integer limit;
private final List<OrderBy<M>> orderBy = new LinkedList<>();
private ModelCriteriaBuilder<M> mcb;
public QueryParameters() {
}
public QueryParameters(ModelCriteriaBuilder<M> mcb) {
this.mcb = mcb;
}
/**
* Creates a new {@code QueryParameters} instance initialized with {@link ModelCriteriaBuilder}
*
* @param mcb filtering criteria
* @param <M> model type
* @return a new {@code QueryParameters} instance
*/
public static <M> QueryParameters<M> withCriteria(ModelCriteriaBuilder<M> mcb) {
return new QueryParameters<>(mcb);
}
/**
* Sets pagination (offset, limit and orderBy) parameters to {@code QueryParameters}
*
* @param offset
* @param limit
* @param orderByAscField
* @return this object
*/
public QueryParameters<M> pagination(Integer offset, Integer limit, SearchableModelField<M> orderByAscField) {
this.offset = offset;
this.limit = limit;
this.orderBy.add(new OrderBy<>(orderByAscField, ASCENDING));
return this;
}
/**
* Sets orderBy parameter; can be called repeatedly; fields are stored in a list where the first field has highest
* priority when determining order; e.g. the second field is compared only when values for the first field are equal
*
* @param searchableModelField
* @return this object
*/
public QueryParameters<M> orderBy(SearchableModelField<M> searchableModelField, Order order) {
orderBy.add(new OrderBy<>(searchableModelField, order));
return this;
}
/**
* Sets offset parameter
*
* @param offset
* @return
*/
public QueryParameters<M> offset(Integer offset) {
this.offset = offset;
return this;
}
/**
* Sets limit parameter
*
* @param limit
* @return
*/
public QueryParameters<M> limit(Integer limit) {
this.limit = limit;
return this;
}
public Integer getOffset() {
return offset;
}
public Integer getLimit() {
return limit;
}
public ModelCriteriaBuilder<M> getModelCriteriaBuilder() {
return mcb;
}
public List<OrderBy<M>> getOrderBy() {
return orderBy;
}
/**
* Enum for ascending or descending ordering
*/
public enum Order {
ASCENDING,
DESCENDING
}
/**
* Wrapper class for a field with its {@code Order}, ascending or descending
*
* @param <M>
*/
public static class OrderBy<M> {
private final SearchableModelField<M> modelField;
private final Order order;
public OrderBy(SearchableModelField<M> modelField, Order order) {
this.modelField = modelField;
this.order = order;
}
public SearchableModelField<M> getModelField() {
return modelField;
}
public Order getOrder() {
return order;
}
}
}

View file

@ -30,6 +30,8 @@ import org.jboss.logging.Logger;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder; import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.utils.StreamsUtil;
public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>, M> implements MapKeycloakTransaction<K, V, M> { public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>, M> implements MapKeycloakTransaction<K, V, M> {
@ -133,11 +135,11 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
* Returns the stream of records that match given criteria and includes changes made in this transaction, i.e. * Returns the stream of records that match given criteria and includes changes made in this transaction, i.e.
* the result contains updates and excludes records that have been deleted in this transaction. * the result contains updates and excludes records that have been deleted in this transaction.
* *
* @param mcb * @param queryParameters
* @return * @return
*/ */
@Override @Override
public Stream<V> read(ModelCriteriaBuilder<M> mcb) { public Stream<V> read(QueryParameters<M> queryParameters) {
Predicate<? super V> filterOutAllBulkDeletedObjects = tasks.values().stream() Predicate<? super V> filterOutAllBulkDeletedObjects = tasks.values().stream()
.filter(BulkDeleteOperation.class::isInstance) .filter(BulkDeleteOperation.class::isInstance)
.map(BulkDeleteOperation.class::cast) .map(BulkDeleteOperation.class::cast)
@ -145,7 +147,9 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
.reduce(Predicate::and) .reduce(Predicate::and)
.orElse(v -> true); .orElse(v -> true);
Stream<V> updatedAndNotRemovedObjectsStream = this.map.read(mcb) ModelCriteriaBuilder<M> mcb = queryParameters.getModelCriteriaBuilder();
Stream<V> updatedAndNotRemovedObjectsStream = this.map.read(queryParameters)
.filter(filterOutAllBulkDeletedObjects) .filter(filterOutAllBulkDeletedObjects)
.map(this::getUpdated) // If the object has been removed, tx.get will return null, otherwise it will return me.getValue() .map(this::getUpdated) // If the object has been removed, tx.get will return null, otherwise it will return me.getValue()
.filter(Objects::nonNull); .filter(Objects::nonNull);
@ -159,12 +163,17 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
updatedAndNotRemovedObjectsStream updatedAndNotRemovedObjectsStream
); );
return res; if (!queryParameters.getOrderBy().isEmpty()) {
res = res.sorted(MapFieldPredicates.getComparator(queryParameters.getOrderBy().stream()));
}
return StreamsUtil.paginatedStream(res, queryParameters.getOffset(), queryParameters.getLimit());
} }
@Override @Override
public long getCount(ModelCriteriaBuilder<M> mcb) { public long getCount(QueryParameters<M> queryParameters) {
return read(mcb).count(); return read(queryParameters).count();
} }
@Override @Override
@ -210,11 +219,11 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
@Override @Override
public long delete(K artificialKey, ModelCriteriaBuilder<M> mcb) { public long delete(K artificialKey, QueryParameters<M> queryParameters) {
log.tracef("Adding operation DELETE_BULK"); log.tracef("Adding operation DELETE_BULK");
// Remove all tasks that create / update / delete objects deleted by the bulk removal. // Remove all tasks that create / update / delete objects deleted by the bulk removal.
final BulkDeleteOperation bdo = new BulkDeleteOperation(mcb); final BulkDeleteOperation bdo = new BulkDeleteOperation(queryParameters);
Predicate<V> filterForNonDeletedObjects = bdo.getFilterForNonDeletedObjects(); Predicate<V> filterForNonDeletedObjects = bdo.getFilterForNonDeletedObjects();
long res = 0; long res = 0;
for (Iterator<Entry<K, MapTaskWithValue>> it = tasks.entrySet().iterator(); it.hasNext();) { for (Iterator<Entry<K, MapTaskWithValue>> it = tasks.entrySet().iterator(); it.hasNext();) {
@ -355,29 +364,29 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
private class BulkDeleteOperation extends MapTaskWithValue { private class BulkDeleteOperation extends MapTaskWithValue {
private final ModelCriteriaBuilder<M> mcb; private final QueryParameters<M> queryParameters;
public BulkDeleteOperation(ModelCriteriaBuilder<M> mcb) { public BulkDeleteOperation(QueryParameters<M> queryParameters) {
super(null); super(null);
this.mcb = mcb; this.queryParameters = queryParameters;
} }
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void execute() { public void execute() {
map.delete(mcb); map.delete(queryParameters);
} }
public Predicate<V> getFilterForNonDeletedObjects() { public Predicate<V> getFilterForNonDeletedObjects() {
if (! (mcb instanceof MapModelCriteriaBuilder)) { if (! (queryParameters.getModelCriteriaBuilder() instanceof MapModelCriteriaBuilder)) {
return t -> true; return t -> true;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final MapModelCriteriaBuilder<K, V, M> mmcb = (MapModelCriteriaBuilder<K, V, M>) mcb; final MapModelCriteriaBuilder<K, V, M> mmcb = (MapModelCriteriaBuilder<K, V, M>) queryParameters.getModelCriteriaBuilder();
Predicate<? super V> entityFilter = mmcb.getEntityFilter(); Predicate<? super V> entityFilter = mmcb.getEntityFilter();
Predicate<? super K> keyFilter = ((MapModelCriteriaBuilder) mcb).getKeyFilter(); Predicate<? super K> keyFilter = mmcb.getKeyFilter();
return v -> v == null || ! (keyFilter.test(v.getId()) && entityFilter.test(v)); return v -> v == null || ! (keyFilter.test(v.getId()) && entityFilter.test(v));
} }
@ -387,7 +396,7 @@ public class ConcurrentHashMapKeycloakTransaction<K, V extends AbstractEntity<K>
} }
private long getCount() { private long getCount() {
return map.getCount(mcb); return map.getCount(queryParameters);
} }
} }
} }

View file

@ -21,18 +21,23 @@ import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.common.AbstractEntity; import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.storage.MapStorage; import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelCriteriaBuilder; import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.storage.SearchableModelField; import org.keycloak.storage.SearchableModelField;
import java.util.Comparator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import org.keycloak.models.map.storage.StringKeyConvertor; import org.keycloak.models.map.storage.StringKeyConvertor;
import java.util.Iterator; import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import java.util.Objects; import java.util.Objects;
import java.util.function.Predicate; import java.util.function.Predicate;
import static org.keycloak.utils.StreamsUtil.paginatedStream;
/** /**
* *
* @author hmlnarik * @author hmlnarik
@ -74,10 +79,11 @@ public class ConcurrentHashMapStorage<K, V extends AbstractEntity<K>, M> impleme
} }
@Override @Override
public long delete(ModelCriteriaBuilder<M> criteria) { public long delete(QueryParameters<M> queryParameters) {
long res; ModelCriteriaBuilder<M> criteria = queryParameters.getModelCriteriaBuilder();
if (criteria == null) { if (criteria == null) {
res = store.size(); long res = store.size();
store.clear(); store.clear();
return res; return res;
} }
@ -88,15 +94,21 @@ public class ConcurrentHashMapStorage<K, V extends AbstractEntity<K>, M> impleme
} }
Predicate<? super K> keyFilter = b.getKeyFilter(); Predicate<? super K> keyFilter = b.getKeyFilter();
Predicate<? super V> entityFilter = b.getEntityFilter(); Predicate<? super V> entityFilter = b.getEntityFilter();
res = 0; Stream<Entry<K, V>> storeStream = store.entrySet().stream();
for (Iterator<Entry<K, V>> iterator = store.entrySet().iterator(); iterator.hasNext();) { final AtomicLong res = new AtomicLong(0);
Entry<K, V> next = iterator.next();
if (keyFilter.test(next.getKey()) && entityFilter.test(next.getValue())) { if (!queryParameters.getOrderBy().isEmpty()) {
res++; Comparator<V> comparator = MapFieldPredicates.getComparator(queryParameters.getOrderBy().stream());
iterator.remove(); storeStream = storeStream.sorted((entry1, entry2) -> comparator.compare(entry1.getValue(), entry2.getValue()));
}
} }
return res;
paginatedStream(storeStream.filter(next -> keyFilter.test(next.getKey()) && entityFilter.test(next.getValue()))
, queryParameters.getOffset(), queryParameters.getLimit())
.peek(item -> {res.incrementAndGet();})
.map(Entry::getKey)
.forEach(store::remove);
return res.get();
} }
@Override @Override
@ -117,7 +129,9 @@ public class ConcurrentHashMapStorage<K, V extends AbstractEntity<K>, M> impleme
} }
@Override @Override
public Stream<V> read(ModelCriteriaBuilder<M> criteria) { public Stream<V> read(QueryParameters<M> queryParameters) {
ModelCriteriaBuilder<M> criteria = queryParameters.getModelCriteriaBuilder();
if (criteria == null) { if (criteria == null) {
return Stream.empty(); return Stream.empty();
} }
@ -135,8 +149,8 @@ public class ConcurrentHashMapStorage<K, V extends AbstractEntity<K>, M> impleme
} }
@Override @Override
public long getCount(ModelCriteriaBuilder<M> criteria) { public long getCount(QueryParameters<M> queryParameters) {
return read(criteria).count(); return read(queryParameters).count();
} }
} }

View file

@ -57,6 +57,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
/** /**
* *
* @author hmlnarik * @author hmlnarik
@ -171,7 +173,7 @@ public class ConcurrentHashMapStorageProviderFactory implements AmphibianProvide
LOG.debugf("Storing contents to %s", f.getCanonicalPath()); LOG.debugf("Storing contents to %s", f.getCanonicalPath());
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final ModelCriteriaBuilder readAllCriteria = store.createCriteriaBuilder(); final ModelCriteriaBuilder readAllCriteria = store.createCriteriaBuilder();
Serialization.MAPPER.writeValue(f, store.read(readAllCriteria)); Serialization.MAPPER.writeValue(f, store.read(withCriteria(readAllCriteria)));
} else { } else {
LOG.debugf("Not storing contents of %s because directory not set", mapName); LOG.debugf("Not storing contents of %s because directory not set", mapName);
} }

View file

@ -43,7 +43,10 @@ import org.keycloak.models.map.group.MapGroupEntity;
import org.keycloak.models.map.loginFailure.MapUserLoginFailureEntity; import org.keycloak.models.map.loginFailure.MapUserLoginFailureEntity;
import org.keycloak.models.map.realm.MapRealmEntity; import org.keycloak.models.map.realm.MapRealmEntity;
import org.keycloak.models.map.role.MapRoleEntity; import org.keycloak.models.map.role.MapRoleEntity;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.storage.SearchableModelField; import org.keycloak.storage.SearchableModelField;
import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc; import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc;
@ -56,11 +59,11 @@ import org.keycloak.sessions.RootAuthenticationSessionModel;
import org.keycloak.storage.StorageId; import org.keycloak.storage.StorageId;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Stream;
import org.keycloak.models.map.storage.CriterionNotSupportedException; import org.keycloak.models.map.storage.CriterionNotSupportedException;
import static org.keycloak.models.UserSessionModel.CORRESPONDING_SESSION_ID; import static org.keycloak.models.UserSessionModel.CORRESPONDING_SESSION_ID;
@ -89,10 +92,11 @@ public class MapFieldPredicates {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final Map<Class<?>, Map> PREDICATES = new HashMap<>(); private static final Map<Class<?>, Map> PREDICATES = new HashMap<>();
private static final Map<SearchableModelField<?>, Comparator<?>> COMPARATORS = new HashMap<>();
static { static {
put(REALM_PREDICATES, RealmModel.SearchableFields.NAME, MapRealmEntity::getName); put(REALM_PREDICATES, RealmModel.SearchableFields.NAME, MapRealmEntity::getName);
put(REALM_PREDICATES, RealmModel.SearchableFields.CLIENT_INITIAL_ACCESS, MapRealmEntity::getClientInitialAccesses); putIncomparable(REALM_PREDICATES, RealmModel.SearchableFields.CLIENT_INITIAL_ACCESS, MapRealmEntity::getClientInitialAccesses);
put(REALM_PREDICATES, RealmModel.SearchableFields.COMPONENT_PROVIDER_TYPE, MapFieldPredicates::checkRealmsWithComponentType); put(REALM_PREDICATES, RealmModel.SearchableFields.COMPONENT_PROVIDER_TYPE, MapFieldPredicates::checkRealmsWithComponentType);
put(CLIENT_PREDICATES, ClientModel.SearchableFields.REALM_ID, MapClientEntity::getRealmId); put(CLIENT_PREDICATES, ClientModel.SearchableFields.REALM_ID, MapClientEntity::getRealmId);
@ -136,7 +140,7 @@ public class MapFieldPredicates {
put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.REALM_ID, MapRootAuthenticationSessionEntity::getRealmId); put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.REALM_ID, MapRootAuthenticationSessionEntity::getRealmId);
put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.TIMESTAMP, MapRootAuthenticationSessionEntity::getTimestamp); put(AUTHENTICATION_SESSION_PREDICATES, RootAuthenticationSessionModel.SearchableFields.TIMESTAMP, MapRootAuthenticationSessionEntity::getTimestamp);
put(AUTHZ_RESOURCE_SERVER_PREDICATES, ResourceServer.SearchableFields.ID, MapResourceServerEntity::getId); put(AUTHZ_RESOURCE_SERVER_PREDICATES, ResourceServer.SearchableFields.ID, predicateForKeyField(MapResourceServerEntity::getId));
put(AUTHZ_RESOURCE_PREDICATES, Resource.SearchableFields.ID, predicateForKeyField(MapResourceEntity::getId)); put(AUTHZ_RESOURCE_PREDICATES, Resource.SearchableFields.ID, predicateForKeyField(MapResourceEntity::getId));
put(AUTHZ_RESOURCE_PREDICATES, Resource.SearchableFields.NAME, MapResourceEntity::getName); put(AUTHZ_RESOURCE_PREDICATES, Resource.SearchableFields.NAME, MapResourceEntity::getName);
@ -207,9 +211,16 @@ public class MapFieldPredicates {
PREDICATES.put(UserLoginFailureModel.class, USER_LOGIN_FAILURE_PREDICATES); PREDICATES.put(UserLoginFailureModel.class, USER_LOGIN_FAILURE_PREDICATES);
} }
private static <K, V extends AbstractEntity<K>, M> void put( private static <K, V extends AbstractEntity<K>, M, L extends Comparable<L>> void put(
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> map, Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> map,
SearchableModelField<M> field, Function<V, Object> extractor) { SearchableModelField<M> field, Function<V, L> extractor) {
COMPARATORS.put(field, Comparator.comparing(extractor));
map.put(field, (mcb, op, values) -> mcb.fieldCompare(op, extractor, values));
}
private static <K, V extends AbstractEntity<K>, M> void putIncomparable(
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> map,
SearchableModelField<M> field, Function<V, Object> extractor) {
map.put(field, (mcb, op, values) -> mcb.fieldCompare(op, extractor, values)); map.put(field, (mcb, op, values) -> mcb.fieldCompare(op, extractor, values));
} }
@ -219,7 +230,7 @@ public class MapFieldPredicates {
map.put(field, function); map.put(field, function);
} }
private static <V extends AbstractEntity<?>> Function<V, Object> predicateForKeyField(Function<V, Object> extractor) { private static <V extends AbstractEntity<?>> Function<V, String> predicateForKeyField(Function<V, Object> extractor) {
return entity -> { return entity -> {
Object o = extractor.apply(entity); Object o = extractor.apply(entity);
return o == null ? null : o.toString(); return o == null ? null : o.toString();
@ -483,10 +494,35 @@ public class MapFieldPredicates {
protected static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> basePredicates(SearchableModelField<M> idField) { protected static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> basePredicates(SearchableModelField<M> idField) {
Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates = new HashMap<>(); Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates = new HashMap<>();
fieldPredicates.put(idField, (o, op, values) -> o.idCompare(op, values)); fieldPredicates.put(idField, MapModelCriteriaBuilder::idCompare);
return fieldPredicates; return fieldPredicates;
} }
public static <K, V extends AbstractEntity<K>, M> Comparator<V> getComparator(QueryParameters.OrderBy<M> orderBy) {
SearchableModelField<M> searchableModelField = orderBy.getModelField();
QueryParameters.Order order = orderBy.getOrder();
@SuppressWarnings("unchecked")
Comparator<V> comparator = (Comparator<V>) COMPARATORS.get(searchableModelField);
if (comparator == null) {
throw new IllegalArgumentException("Comparator for field " + searchableModelField.getName() + " is not configured.");
}
if (order == QueryParameters.Order.DESCENDING) {
return comparator.reversed();
}
return comparator;
}
@SuppressWarnings("unchecked")
public static <K, V extends AbstractEntity<K>, M> Comparator<V> getComparator(Stream<QueryParameters.OrderBy<M>> ordering) {
return (Comparator<V>) ordering.map(MapFieldPredicates::getComparator)
.reduce(Comparator::thenComparing)
.orElseThrow(() -> new IllegalArgumentException("Cannot create comparator for " + ordering));
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> getPredicates(Class<M> clazz) { public static <K, V extends AbstractEntity<K>, M> Map<SearchableModelField<M>, UpdatePredicatesFunc<K, V, M>> getPredicates(Class<M> clazz) {
return PREDICATES.get(clazz); return PREDICATES.get(clazz);

View file

@ -23,12 +23,15 @@ import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.storage.MapKeycloakTransaction; import org.keycloak.models.map.storage.MapKeycloakTransaction;
import org.keycloak.models.map.storage.ModelCriteriaBuilder; import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator; import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.models.map.storage.StringKeyConvertor; import org.keycloak.models.map.storage.StringKeyConvertor;
import org.keycloak.models.map.userSession.MapAuthenticatedClientSessionEntity; import org.keycloak.models.map.userSession.MapAuthenticatedClientSessionEntity;
import org.keycloak.models.map.userSession.MapUserSessionEntity; import org.keycloak.models.map.userSession.MapUserSessionEntity;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
/** /**
* User session storage with a naive implementation of referential integrity in client to user session relation, restricted to * User session storage with a naive implementation of referential integrity in client to user session relation, restricted to
* ON DELETE CASCADE functionality. * ON DELETE CASCADE functionality.
@ -49,17 +52,19 @@ public class UserSessionConcurrentHashMapStorage<K> extends ConcurrentHashMapSto
} }
@Override @Override
public long delete(K artificialKey, ModelCriteriaBuilder<UserSessionModel> mcb) { public long delete(K artificialKey, QueryParameters<UserSessionModel> queryParameters) {
Set<K> ids = read(mcb).map(AbstractEntity::getId).collect(Collectors.toSet()); ModelCriteriaBuilder<UserSessionModel> mcb = queryParameters.getModelCriteriaBuilder();
Set<K> ids = read(queryParameters).map(AbstractEntity::getId).collect(Collectors.toSet());
ModelCriteriaBuilder<AuthenticatedClientSessionModel> csMcb = clientSessionStore.createCriteriaBuilder().compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.IN, ids); ModelCriteriaBuilder<AuthenticatedClientSessionModel> csMcb = clientSessionStore.createCriteriaBuilder().compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.IN, ids);
clientSessionTr.delete(artificialKey, csMcb); clientSessionTr.delete(artificialKey, withCriteria(csMcb));
return super.delete(artificialKey, mcb); return super.delete(artificialKey, queryParameters);
} }
@Override @Override
public void delete(K key) { public void delete(K key) {
ModelCriteriaBuilder<AuthenticatedClientSessionModel> csMcb = clientSessionStore.createCriteriaBuilder().compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.EQ, key); ModelCriteriaBuilder<AuthenticatedClientSessionModel> csMcb = clientSessionStore.createCriteriaBuilder().compare(AuthenticatedClientSessionModel.SearchableFields.USER_SESSION_ID, Operator.EQ, key);
clientSessionTr.delete(key, csMcb); clientSessionTr.delete(key, withCriteria(csMcb));
super.delete(key); super.delete(key);
} }

View file

@ -65,8 +65,6 @@ public class MapUserEntity<K> implements AbstractEntity<K> {
private String serviceAccountClientLink; private String serviceAccountClientLink;
private int notBefore; private int notBefore;
static Comparator<MapUserEntity<?>> COMPARE_BY_USERNAME = Comparator.comparing(MapUserEntity::getUsername);
/** /**
* Flag signalizing that any of the setters has been meaningfully used. * Flag signalizing that any of the setters has been meaningfully used.
*/ */

View file

@ -70,7 +70,8 @@ import static org.keycloak.models.UserModel.FIRST_NAME;
import static org.keycloak.models.UserModel.LAST_NAME; import static org.keycloak.models.UserModel.LAST_NAME;
import static org.keycloak.models.UserModel.USERNAME; import static org.keycloak.models.UserModel.USERNAME;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.utils.StreamsUtil.paginatedStream; import static org.keycloak.models.map.storage.QueryParameters.Order.ASCENDING;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialStore.Streams { public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialStore.Streams {
@ -173,7 +174,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialProvider); .compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialProvider);
tx.read(mcb) tx.read(withCriteria(mcb))
.map(e -> registerEntityForChanges(tx, e)) .map(e -> registerEntityForChanges(tx, e))
.forEach(userEntity -> userEntity.removeFederatedIdentity(socialProvider)); .forEach(userEntity -> userEntity.removeFederatedIdentity(socialProvider));
} }
@ -208,8 +209,8 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder() ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialLink.getIdentityProvider(), socialLink.getUserId()); .compare(SearchableFields.IDP_AND_USER, Operator.EQ, socialLink.getIdentityProvider(), socialLink.getUserId());
return tx.read(mcb) return tx.read(withCriteria(mcb))
.collect(Collectors.collectingAndThen( .collect(Collectors.collectingAndThen(
Collectors.toList(), Collectors.toList(),
list -> { list -> {
@ -298,7 +299,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, client.getRealm().getId())
.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.EQ, client.getId()); .compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.EQ, client.getId());
return tx.read(mcb) return tx.read(withCriteria(mcb))
.collect(Collectors.collectingAndThen( .collect(Collectors.collectingAndThen(
Collectors.toList(), Collectors.toList(),
list -> { list -> {
@ -321,7 +322,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.USERNAME, Operator.EQ, username); .compare(SearchableFields.USERNAME, Operator.EQ, username);
if (tx.getCount(mcb) > 0) { if (tx.getCount(withCriteria(mcb)) > 0) {
throw new ModelDuplicateException("User with username '" + username + "' in realm " + realm.getName() + " already exists" ); throw new ModelDuplicateException("User with username '" + username + "' in realm " + realm.getName() + " already exists" );
} }
@ -362,7 +363,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder() ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
tx.delete(userStore.getKeyConvertor().yieldNewUniqueKey(), mcb); tx.delete(userStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -372,7 +373,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId); .compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId);
tx.delete(userStore.getKeyConvertor().yieldNewUniqueKey(), mcb); tx.delete(userStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -382,7 +383,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId); .compare(SearchableFields.FEDERATION_LINK, Operator.EQ, storageProviderId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.map(e -> registerEntityForChanges(tx, e)) s.map(e -> registerEntityForChanges(tx, e))
.forEach(userEntity -> userEntity.setFederationLink(null)); .forEach(userEntity -> userEntity.setFederationLink(null));
} }
@ -396,7 +397,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, roleId); .compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, roleId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.map(e -> registerEntityForChanges(tx, e)) s.map(e -> registerEntityForChanges(tx, e))
.forEach(userEntity -> userEntity.removeRolesMembership(roleId)); .forEach(userEntity -> userEntity.removeRolesMembership(roleId));
} }
@ -410,7 +411,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, groupId); .compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, groupId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.map(e -> registerEntityForChanges(tx, e)) s.map(e -> registerEntityForChanges(tx, e))
.forEach(userEntity -> userEntity.removeGroupsMembership(groupId)); .forEach(userEntity -> userEntity.removeGroupsMembership(groupId));
} }
@ -424,7 +425,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CONSENT_FOR_CLIENT, Operator.EQ, clientId); .compare(SearchableFields.CONSENT_FOR_CLIENT, Operator.EQ, clientId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.map(e -> registerEntityForChanges(tx, e)) s.map(e -> registerEntityForChanges(tx, e))
.forEach(userEntity -> userEntity.removeUserConsent(clientId)); .forEach(userEntity -> userEntity.removeUserConsent(clientId));
} }
@ -444,7 +445,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, clientScope.getRealm().getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, clientScope.getRealm().getId())
.compare(SearchableFields.CONSENT_WITH_CLIENT_SCOPE, Operator.EQ, clientScopeId); .compare(SearchableFields.CONSENT_WITH_CLIENT_SCOPE, Operator.EQ, clientScopeId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.flatMap(MapUserEntity::getUserConsents) s.flatMap(MapUserEntity::getUserConsents)
.forEach(consent -> consent.removeGrantedClientScopesIds(clientScopeId)); .forEach(consent -> consent.removeGrantedClientScopesIds(clientScopeId));
} }
@ -462,7 +463,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, Operator.EQ, componentId); .compare(SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, Operator.EQ, componentId);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
String providerIdS = new StorageId(componentId, "").getId(); String providerIdS = new StorageId(componentId, "").getId();
s.forEach(removeConsentsForExternalClient(providerIdS)); s.forEach(removeConsentsForExternalClient(providerIdS));
} }
@ -490,7 +491,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder() ModelCriteriaBuilder<UserModel> mcb = userStore.createCriteriaBuilder()
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()); .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId());
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
s.map(e -> registerEntityForChanges(tx, e)) s.map(e -> registerEntityForChanges(tx, e))
.forEach(entity -> entity.addRolesMembership(roleId)); .forEach(entity -> entity.addRolesMembership(roleId));
} }
@ -510,7 +511,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.USERNAME, Operator.ILIKE, username); .compare(SearchableFields.USERNAME, Operator.ILIKE, username);
try (Stream<MapUserEntity<K>> s = tx.read(mcb)) { try (Stream<MapUserEntity<K>> s = tx.read(withCriteria(mcb))) {
return s.findFirst() return s.findFirst()
.map(entityToAdapterFunc(realm)).orElse(null); .map(entityToAdapterFunc(realm)).orElse(null);
} }
@ -523,7 +524,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.EMAIL, Operator.EQ, email); .compare(SearchableFields.EMAIL, Operator.EQ, email);
List<MapUserEntity<K>> usersWithEmail = tx.read(mcb) List<MapUserEntity<K>> usersWithEmail = tx.read(withCriteria(mcb))
.filter(userEntity -> Objects.equals(userEntity.getEmail(), email)) .filter(userEntity -> Objects.equals(userEntity.getEmail(), email))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (usersWithEmail.isEmpty()) return null; if (usersWithEmail.isEmpty()) return null;
@ -571,7 +572,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS); mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS);
} }
return (int) tx.getCount(mcb); return (int) tx.getCount(withCriteria(mcb));
} }
@Override @Override
@ -584,9 +585,8 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS); mcb = mcb.compare(SearchableFields.SERVICE_ACCOUNT_CLIENT, Operator.NOT_EXISTS);
} }
return paginatedStream(tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.USERNAME))
.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults) .map(entityToAdapterFunc(realm));
.map(entityToAdapterFunc(realm));
} }
@Override @Override
@ -701,10 +701,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
mcb = mcb.compare(SearchableFields.ASSIGNED_GROUP, Operator.IN, authorizedGroups); mcb = mcb.compare(SearchableFields.ASSIGNED_GROUP, Operator.IN, authorizedGroups);
} }
Stream<MapUserEntity<K>> usersStream = tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.USERNAME))
.sorted(MapUserEntity.COMPARE_BY_USERNAME); // Sort before paginating
return paginatedStream(usersStream, firstResult, maxResults) // paginate if necessary
.map(entityToAdapterFunc(realm)) .map(entityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -716,7 +713,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, group.getId()); .compare(SearchableFields.ASSIGNED_GROUP, Operator.EQ, group.getId());
return paginatedStream(tx.read(mcb).sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.USERNAME))
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@ -727,8 +724,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ATTRIBUTE, Operator.EQ, attrName, attrValue); .compare(SearchableFields.ATTRIBUTE, Operator.EQ, attrName, attrValue);
return tx.read(mcb) return tx.read(withCriteria(mcb).orderBy(SearchableFields.USERNAME, ASCENDING))
.sorted(MapUserEntity.COMPARE_BY_USERNAME)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }
@ -756,8 +752,7 @@ public class MapUserProvider<K> implements UserProvider.Streams, UserCredentialS
.compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId()) .compare(SearchableFields.REALM_ID, Operator.EQ, realm.getId())
.compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId()); .compare(SearchableFields.ASSIGNED_ROLE, Operator.EQ, role.getId());
return paginatedStream(tx.read(mcb) return tx.read(withCriteria(mcb).pagination(firstResult, maxResults, SearchableFields.USERNAME))
.sorted(MapUserEntity.COMPARE_BY_USERNAME), firstResult, maxResults)
.map(entityToAdapterFunc(realm)); .map(entityToAdapterFunc(realm));
} }

View file

@ -33,7 +33,6 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -48,9 +47,9 @@ import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import static org.keycloak.models.UserSessionModel.CORRESPONDING_SESSION_ID; import static org.keycloak.models.UserSessionModel.CORRESPONDING_SESSION_ID;
import static org.keycloak.models.UserSessionModel.SessionPersistenceState.TRANSIENT; import static org.keycloak.models.UserSessionModel.SessionPersistenceState.TRANSIENT;
import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges; import static org.keycloak.models.map.common.MapStorageUtils.registerEntityForChanges;
import static org.keycloak.models.map.storage.QueryParameters.withCriteria;
import static org.keycloak.models.map.userSession.SessionExpiration.setClientSessionExpiration; import static org.keycloak.models.map.userSession.SessionExpiration.setClientSessionExpiration;
import static org.keycloak.models.map.userSession.SessionExpiration.setUserSessionExpiration; import static org.keycloak.models.map.userSession.SessionExpiration.setUserSessionExpiration;
import static org.keycloak.utils.StreamsUtil.paginatedStream;
/** /**
* @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a> * @author <a href="mailto:mkanis@redhat.com">Martin Kanis</a>
@ -195,7 +194,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
.compare(AuthenticatedClientSessionModel.SearchableFields.CLIENT_ID, ModelCriteriaBuilder.Operator.EQ, client.getId()) .compare(AuthenticatedClientSessionModel.SearchableFields.CLIENT_ID, ModelCriteriaBuilder.Operator.EQ, client.getId())
.compare(AuthenticatedClientSessionModel.SearchableFields.IS_OFFLINE, ModelCriteriaBuilder.Operator.EQ, offline); .compare(AuthenticatedClientSessionModel.SearchableFields.IS_OFFLINE, ModelCriteriaBuilder.Operator.EQ, offline);
return clientSessionTx.read(mcb) return clientSessionTx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(clientEntityToAdapterFunc(client.getRealm(), client, userSession)) .map(clientEntityToAdapterFunc(client.getRealm(), client, userSession))
.orElse(null); .orElse(null);
@ -258,7 +257,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
ModelCriteriaBuilder<UserSessionModel> mcb = realmAndOfflineCriteriaBuilder(realm, false) ModelCriteriaBuilder<UserSessionModel> mcb = realmAndOfflineCriteriaBuilder(realm, false)
.compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uuid); .compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uuid);
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.orElse(null); .orElse(null);
@ -271,7 +270,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getUserSessionsStream(%s, %s)%s", realm, user, getShortStackTrace()); LOG.tracef("getUserSessionsStream(%s, %s)%s", realm, user, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -283,7 +282,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getUserSessionsStream(%s, %s)%s", realm, client, getShortStackTrace()); LOG.tracef("getUserSessionsStream(%s, %s)%s", realm, client, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -291,8 +290,16 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
@Override @Override
public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client, public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client,
Integer firstResult, Integer maxResults) { Integer firstResult, Integer maxResults) {
return paginatedStream(getUserSessionsStream(realm, client) LOG.tracef("getUserSessionsStream(%s, %s, %s, %s)%s", realm, client, firstResult, maxResults, getShortStackTrace());
.sorted(Comparator.comparing(UserSessionModel::getLastSessionRefresh)), firstResult, maxResults);
ModelCriteriaBuilder<UserSessionModel> mcb = realmAndOfflineCriteriaBuilder(realm, false)
.compare(UserSessionModel.SearchableFields.CLIENT_ID, ModelCriteriaBuilder.Operator.EQ, client.getId());
return userSessionTx.read(withCriteria(mcb).pagination(firstResult, maxResults,
UserSessionModel.SearchableFields.LAST_SESSION_REFRESH))
.map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull);
} }
@Override @Override
@ -302,7 +309,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getUserSessionByBrokerUserIdStream(%s, %s)%s", realm, brokerUserId, getShortStackTrace()); LOG.tracef("getUserSessionByBrokerUserIdStream(%s, %s)%s", realm, brokerUserId, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -314,7 +321,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getUserSessionByBrokerSessionId(%s, %s)%s", realm, brokerSessionId, getShortStackTrace()); LOG.tracef("getUserSessionByBrokerSessionId(%s, %s)%s", realm, brokerSessionId, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.orElse(null); .orElse(null);
@ -347,7 +354,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getActiveUserSessions(%s, %s)%s", realm, client, getShortStackTrace()); LOG.tracef("getActiveUserSessions(%s, %s)%s", realm, client, getShortStackTrace());
return userSessionTx.getCount(mcb); return userSessionTx.getCount(withCriteria(mcb));
} }
@Override @Override
@ -356,7 +363,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getActiveClientSessionStats(%s, %s)%s", realm, offline, getShortStackTrace()); LOG.tracef("getActiveClientSessionStats(%s, %s)%s", realm, offline, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(UserSessionModel::getAuthenticatedClientSessions) .map(UserSessionModel::getAuthenticatedClientSessions)
@ -375,7 +382,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("removeUserSession(%s, %s)%s", realm, session, getShortStackTrace()); LOG.tracef("removeUserSession(%s, %s)%s", realm, session, getShortStackTrace());
userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -386,7 +393,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("removeUserSessions(%s, %s)%s", realm, user, getShortStackTrace()); LOG.tracef("removeUserSessions(%s, %s)%s", realm, user, getShortStackTrace());
userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -405,7 +412,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("removeUserSessions(%s)%s", realm, getShortStackTrace()); LOG.tracef("removeUserSessions(%s)%s", realm, getShortStackTrace());
userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
} }
@Override @Override
@ -462,7 +469,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
UK uk = userSessionStore.getKeyConvertor().fromString(userSession.getNote(CORRESPONDING_SESSION_ID)); UK uk = userSessionStore.getKeyConvertor().fromString(userSession.getNote(CORRESPONDING_SESSION_ID));
mcb = realmAndOfflineCriteriaBuilder(realm, true) mcb = realmAndOfflineCriteriaBuilder(realm, true)
.compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uk); .compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uk);
userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), mcb); userSessionTx.delete(userSessionStore.getKeyConvertor().yieldNewUniqueKey(), withCriteria(mcb));
userSession.removeNote(CORRESPONDING_SESSION_ID); userSession.removeNote(CORRESPONDING_SESSION_ID);
} }
} }
@ -496,7 +503,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getOfflineUserSessionsStream(%s, %s)%s", realm, user, getShortStackTrace()); LOG.tracef("getOfflineUserSessionsStream(%s, %s)%s", realm, user, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -508,7 +515,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getOfflineUserSessionByBrokerSessionId(%s, %s)%s", realm, brokerSessionId, getShortStackTrace()); LOG.tracef("getOfflineUserSessionByBrokerSessionId(%s, %s)%s", realm, brokerSessionId, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.findFirst() .findFirst()
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.orElse(null); .orElse(null);
@ -521,7 +528,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getOfflineUserSessionByBrokerUserIdStream(%s, %s)%s", realm, brokerUserId, getShortStackTrace()); LOG.tracef("getOfflineUserSessionByBrokerUserIdStream(%s, %s)%s", realm, brokerUserId, getShortStackTrace());
return userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull); .filter(Objects::nonNull);
} }
@ -533,7 +540,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getOfflineSessionsCount(%s, %s)%s", realm, client, getShortStackTrace()); LOG.tracef("getOfflineSessionsCount(%s, %s)%s", realm, client, getShortStackTrace());
return userSessionTx.getCount(mcb); return userSessionTx.getCount(withCriteria(mcb));
} }
@Override @Override
@ -544,10 +551,10 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
LOG.tracef("getOfflineUserSessionsStream(%s, %s, %s, %s)%s", realm, client, firstResult, maxResults, getShortStackTrace()); LOG.tracef("getOfflineUserSessionsStream(%s, %s, %s, %s)%s", realm, client, firstResult, maxResults, getShortStackTrace());
return paginatedStream(userSessionTx.read(mcb) return userSessionTx.read(withCriteria(mcb).pagination(firstResult, maxResults,
UserSessionModel.SearchableFields.LAST_SESSION_REFRESH))
.map(userEntityToAdapterFunc(realm)) .map(userEntityToAdapterFunc(realm))
.filter(Objects::nonNull) .filter(Objects::nonNull);
.sorted(Comparator.comparing(UserSessionModel::getLastSessionRefresh)), firstResult, maxResults);
} }
@Override @Override
@ -595,7 +602,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
.compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uuid); .compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uuid);
// check if it's an offline user session // check if it's an offline user session
MapUserSessionEntity<UK> userSessionEntity = userSessionTx.read(mcb).findFirst().orElse(null); MapUserSessionEntity<UK> userSessionEntity = userSessionTx.read(withCriteria(mcb)).findFirst().orElse(null);
if (userSessionEntity != null) { if (userSessionEntity != null) {
if (userSessionEntity.isOffline()) { if (userSessionEntity.isOffline()) {
return Stream.of(userSessionEntity); return Stream.of(userSessionEntity);
@ -604,7 +611,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
// no session found by the given ID, try to find by corresponding session ID // no session found by the given ID, try to find by corresponding session ID
mcb = realmAndOfflineCriteriaBuilder(realm, true) mcb = realmAndOfflineCriteriaBuilder(realm, true)
.compare(UserSessionModel.SearchableFields.CORRESPONDING_SESSION_ID, ModelCriteriaBuilder.Operator.EQ, userSessionId); .compare(UserSessionModel.SearchableFields.CORRESPONDING_SESSION_ID, ModelCriteriaBuilder.Operator.EQ, userSessionId);
return userSessionTx.read(mcb); return userSessionTx.read(withCriteria(mcb));
} }
// it's online user session so lookup offline user session by corresponding session id reference // it's online user session so lookup offline user session by corresponding session id reference
@ -613,7 +620,7 @@ public class MapUserSessionProvider<UK, CK> implements UserSessionProvider {
UK uk = userSessionStore.getKeyConvertor().fromStringSafe(offlineUserSessionId); UK uk = userSessionStore.getKeyConvertor().fromStringSafe(offlineUserSessionId);
mcb = realmAndOfflineCriteriaBuilder(realm, true) mcb = realmAndOfflineCriteriaBuilder(realm, true)
.compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uk); .compare(UserSessionModel.SearchableFields.ID, ModelCriteriaBuilder.Operator.EQ, uk);
return userSessionTx.read(mcb); return userSessionTx.read(withCriteria(mcb));
} }
return Stream.empty(); return Stream.empty();

View file

@ -40,31 +40,6 @@ public class SearchableModelField<M> {
return fieldClass; return fieldClass;
} }
@Override
public int hashCode() {
int hash = 5;
hash = 83 * hash + Objects.hashCode(this.name);
return hash;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
final SearchableModelField<?> other = (SearchableModelField<?>) obj;
if ( ! Objects.equals(this.name, other.name)) {
return false;
}
return true;
}
@Override @Override
public String toString() { public String toString() {
return "SearchableModelField " + name + " @ " + getClass().getTypeParameters()[0].getTypeName(); return "SearchableModelField " + name + " @ " + getClass().getTypeParameters()[0].getTypeName();