KEYCLOAK-18727 Improve user search query

This commit is contained in:
bal1imb 2021-07-22 06:09:28 -07:00 committed by Hynek Mlnařík
parent 80072b30cd
commit 9621d513b5
7 changed files with 311 additions and 84 deletions

View file

@ -55,6 +55,7 @@ import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Expression;
import javax.persistence.criteria.From;
import javax.persistence.criteria.Join;
import javax.persistence.criteria.JoinType;
import javax.persistence.criteria.Predicate;
@ -607,12 +608,20 @@ public class JpaUserProvider implements UserProvider.Streams, UserCredentialStor
@Override
public int getUsersCount(RealmModel realm, String search) {
TypedQuery<Long> query = em.createNamedQuery("searchForUserCount", Long.class);
query.setParameter("realmId", realm.getId());
query.setParameter("search", "%" + search.toLowerCase() + "%");
Long count = query.getSingleResult();
CriteriaBuilder builder = em.getCriteriaBuilder();
CriteriaQuery<Long> queryBuilder = builder.createQuery(Long.class);
Root<UserEntity> root = queryBuilder.from(UserEntity.class);
return count.intValue();
queryBuilder.select(builder.count(root));
List<Predicate> predicates = new ArrayList<>();
predicates.add(builder.equal(root.get("realmId"), realm.getId()));
predicates.add(builder.or(getSearchOptionPredicateArray(search, builder, root)));
queryBuilder.where(predicates.toArray(new Predicate[0]));
return em.createQuery(queryBuilder).getSingleResult().intValue();
}
@Override
@ -621,13 +630,23 @@ public class JpaUserProvider implements UserProvider.Streams, UserCredentialStor
return 0;
}
TypedQuery<Long> query = em.createNamedQuery("searchForUserCountInGroups", Long.class);
query.setParameter("realmId", realm.getId());
query.setParameter("search", "%" + search.toLowerCase() + "%");
query.setParameter("groupIds", groupIds);
Long count = query.getSingleResult();
CriteriaBuilder builder = em.getCriteriaBuilder();
CriteriaQuery<Long> queryBuilder = builder.createQuery(Long.class);
return count.intValue();
Root<UserGroupMembershipEntity> groupMembership = queryBuilder.from(UserGroupMembershipEntity.class);
Join<UserGroupMembershipEntity, UserEntity> userJoin = groupMembership.join("user");
queryBuilder.select(builder.count(userJoin));
List<Predicate> predicates = new ArrayList<>();
predicates.add(builder.equal(userJoin.get("realmId"), realm.getId()));
predicates.add(builder.or(getSearchOptionPredicateArray(search, builder, userJoin)));
predicates.add(groupMembership.get("groupId").in(groupIds));
queryBuilder.where(predicates.toArray(new Predicate[0]));
return em.createQuery(queryBuilder).getSingleResult().intValue();
}
@Override
@ -789,21 +808,10 @@ public class JpaUserProvider implements UserProvider.Streams, UserCredentialStor
switch (key) {
case UserModel.SEARCH:
List<Predicate> orPredicates = new ArrayList<>();
orPredicates
.add(builder.like(builder.lower(root.get(USERNAME)), "%" + value.toLowerCase() + "%"));
orPredicates.add(builder.like(builder.lower(root.get(EMAIL)), "%" + value.toLowerCase() + "%"));
orPredicates.add(builder.like(
builder.lower(builder.concat(builder.concat(
builder.coalesce(root.get(FIRST_NAME), builder.literal("")), " "),
builder.coalesce(root.get(LAST_NAME), builder.literal("")))),
"%" + value.toLowerCase() + "%"));
predicates.add(builder.or(orPredicates.toArray(new Predicate[0])));
for (String stringToSearch : value.trim().split("\\s+")) {
predicates.add(builder.or(getSearchOptionPredicateArray(stringToSearch, builder, root)));
}
break;
case USERNAME:
case FIRST_NAME:
case LAST_NAME:
@ -1033,6 +1041,40 @@ public class JpaUserProvider implements UserProvider.Streams, UserCredentialStor
}
}
private Predicate[] getSearchOptionPredicateArray(String value, CriteriaBuilder builder, From<?, UserEntity> from) {
value = value.toLowerCase();
List<Predicate> orPredicates = new ArrayList<>();
if (value.length() >= 2 && value.charAt(0) == '"' && value.charAt(value.length() - 1) == '"') {
// exact search
value = value.substring(1, value.length() - 1);
orPredicates.add(builder.equal(builder.lower(from.get(USERNAME)), value));
orPredicates.add(builder.equal(builder.lower(from.get(EMAIL)), value));
orPredicates.add(builder.equal(builder.lower(from.get(FIRST_NAME)), value));
orPredicates.add(builder.equal(builder.lower(from.get(LAST_NAME)), value));
} else {
if (value.length() >= 2 && value.charAt(0) == '*' && value.charAt(value.length() - 1) == '*') {
// infix search
value = "%" + value.substring(1, value.length() - 1) + "%";
} else {
// default to prefix search
if (value.length() > 0 && value.charAt(value.length() - 1) == '*') {
value = value.substring(0, value.length() - 1);
}
value += "%";
}
orPredicates.add(builder.like(builder.lower(from.get(USERNAME)), value));
orPredicates.add(builder.like(builder.lower(from.get(EMAIL)), value));
orPredicates.add(builder.like(builder.lower(from.get(FIRST_NAME)), value));
orPredicates.add(builder.like(builder.lower(from.get(LAST_NAME)), value));
}
return orPredicates.toArray(new Predicate[0]);
}
private UserEntity userInEntityManagerContext(String id) {
UserEntity user = em.getReference(UserEntity.class, id);
boolean isLoaded = em.getEntityManagerFactory().getPersistenceUnitUtil().isLoaded(user);

View file

@ -44,8 +44,6 @@ import java.util.LinkedList;
@NamedQueries({
@NamedQuery(name="getAllUsersByRealm", query="select u from UserEntity u where u.realmId = :realmId order by u.username"),
@NamedQuery(name="getAllUsersByRealmExcludeServiceAccount", query="select u from UserEntity u where u.realmId = :realmId and (u.serviceAccountClientLink is null) order by u.username"),
@NamedQuery(name="searchForUserCount", query="select count(u) from UserEntity u where u.realmId = :realmId and (u.serviceAccountClientLink is null) and " +
"( lower(u.username) like :search or lower(concat(coalesce(u.firstName, ''), ' ', coalesce(u.lastName, ''))) like :search or u.email like :search )"),
@NamedQuery(name="getRealmUserByUsername", query="select u from UserEntity u where u.username = :username and u.realmId = :realmId"),
@NamedQuery(name="getRealmUserByEmail", query="select u from UserEntity u where u.email = :email and u.realmId = :realmId"),
@NamedQuery(name="getRealmUserByLastName", query="select u from UserEntity u where u.lastName = :lastName and u.realmId = :realmId"),

View file

@ -41,8 +41,6 @@ import java.io.Serializable;
@NamedQuery(name="deleteUserGroupMembershipsByRealmAndLink", query="delete from UserGroupMembershipEntity mapping where mapping.user IN (select u from UserEntity u where u.realmId=:realmId and u.federationLink=:link)"),
@NamedQuery(name="deleteUserGroupMembershipsByGroup", query="delete from UserGroupMembershipEntity m where m.groupId = :groupId"),
@NamedQuery(name="deleteUserGroupMembershipsByUser", query="delete from UserGroupMembershipEntity m where m.user = :user"),
@NamedQuery(name="searchForUserCountInGroups", query="select count(m.user) from UserGroupMembershipEntity m where m.user.realmId = :realmId and (m.user.serviceAccountClientLink is null) and " +
"( lower(m.user.username) like :search or lower(concat(m.user.firstName, ' ', m.user.lastName)) like :search or m.user.email like :search ) and m.groupId in :groupIds"),
@NamedQuery(name="userCountInGroups", query="select count(m.user) from UserGroupMembershipEntity m where m.user.realmId = :realmId and m.groupId in :groupIds")
})
@Table(name="USER_GROUP_MEMBERSHIP")

View file

@ -44,6 +44,7 @@ class CriteriaOperator {
private static final Logger LOG = Logger.getLogger(CriteriaOperator.class.getSimpleName());
private static final Predicate<Object> ALWAYS_FALSE = o -> false;
private static final Predicate<Object> ALWAYS_TRUE = o -> true;
static {
OPERATORS.put(Operator.EQ, CriteriaOperator::eq);
@ -190,6 +191,11 @@ class CriteriaOperator {
Object value0 = getFirstArrayElement(value);
if (value0 instanceof String) {
String sValue = (String) value0;
if(Pattern.matches("^%+$", sValue)) {
return ALWAYS_TRUE;
}
boolean anyBeginning = sValue.startsWith("%");
boolean anyEnd = sValue.endsWith("%");
@ -210,6 +216,11 @@ class CriteriaOperator {
Object value0 = getFirstArrayElement(value);
if (value0 instanceof String) {
String sValue = (String) value0;
if(Pattern.matches("^%+$", sValue)) {
return ALWAYS_TRUE;
}
boolean anyBeginning = sValue.startsWith("%");
boolean anyEnd = sValue.endsWith("%");

View file

@ -47,6 +47,7 @@ import org.keycloak.models.map.storage.ModelCriteriaBuilder.Operator;
import org.keycloak.models.map.storage.criteria.DefaultModelCriteria;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.storage.StorageId;
import org.keycloak.storage.UserStorageManager;
import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.client.ClientStorageProvider;
@ -613,20 +614,10 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
switch (key) {
case UserModel.SEARCH:
for (String stringToSearch : value.trim().split("\\s+")) {
if (value.isEmpty()) {
continue;
}
final String s = exactSearch ? stringToSearch : ("%" + stringToSearch + "%");
mcb = mcb.or(
mcb.compare(SearchableFields.USERNAME, Operator.ILIKE, s),
mcb.compare(SearchableFields.EMAIL, Operator.ILIKE, s),
mcb.compare(SearchableFields.FIRST_NAME, Operator.ILIKE, s),
mcb.compare(SearchableFields.LAST_NAME, Operator.ILIKE, s)
);
for (String stringToSearch : value.split("\\s+")) {
mcb = addSearchToModelCriteria(stringToSearch, mcb);
}
break;
case USERNAME:
mcb = mcb.compare(SearchableFields.USERNAME, Operator.ILIKE, searchedString);
break;
@ -650,13 +641,14 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
break;
}
case UserModel.IDP_ALIAS: {
if (! attributes.containsKey(UserModel.IDP_USER_ID)) {
if (!attributes.containsKey(UserModel.IDP_USER_ID)) {
mcb = mcb.compare(SearchableFields.IDP_AND_USER, Operator.EQ, value);
}
break;
}
case UserModel.IDP_USER_ID: {
mcb = mcb.compare(SearchableFields.IDP_AND_USER, Operator.EQ, attributes.get(UserModel.IDP_ALIAS), value);
mcb = mcb.compare(SearchableFields.IDP_AND_USER, Operator.EQ, attributes.get(UserModel.IDP_ALIAS),
value);
break;
}
case UserModel.EXACT:
@ -678,12 +670,13 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
return Stream.empty();
}
final ResourceStore resourceStore = session.getProvider(AuthorizationProvider.class).getStoreFactory().getResourceStore();
final ResourceStore resourceStore =
session.getProvider(AuthorizationProvider.class).getStoreFactory().getResourceStore();
HashSet<String> authorizedGroups = new HashSet<>(userGroups);
authorizedGroups.removeIf(id -> {
Map<Resource.FilterOption, String[]> values = new EnumMap<>(Resource.FilterOption.class);
values.put(Resource.FilterOption.EXACT_NAME, new String[] { "group.resource." + id });
values.put(Resource.FilterOption.EXACT_NAME, new String[] {"group.resource." + id});
return resourceStore.findByResourceServer(values, null, 0, 1).isEmpty();
});
@ -832,4 +825,30 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public void close() {
}
private DefaultModelCriteria<UserModel> addSearchToModelCriteria(String value,
DefaultModelCriteria<UserModel> mcb) {
if (value.length() >= 2 && value.charAt(0) == '"' && value.charAt(value.length() - 1) == '"') {
// exact search
value = value.substring(1, value.length() - 1);
} else {
if (value.length() >= 2 && value.charAt(0) == '*' && value.charAt(value.length() - 1) == '*') {
// infix search
value = "%" + value.substring(1, value.length() - 1) + "%";
} else {
// default to prefix search
if (value.length() > 0 && value.charAt(value.length() - 1) == '*') {
value = value.substring(0, value.length() - 1);
}
value += "%";
}
}
return mcb.or(
mcb.compare(SearchableFields.USERNAME, Operator.ILIKE, value),
mcb.compare(SearchableFields.EMAIL, Operator.ILIKE, value),
mcb.compare(SearchableFields.FIRST_NAME, Operator.ILIKE, value),
mcb.compare(SearchableFields.LAST_NAME, Operator.ILIKE, value));
}
}

View file

@ -40,7 +40,6 @@ import org.keycloak.common.util.ObjectUtil;
import org.keycloak.credential.CredentialModel;
import org.keycloak.events.admin.OperationType;
import org.keycloak.events.admin.ResourceType;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.models.Constants;
import org.keycloak.models.LDAPConstants;
import org.keycloak.models.PasswordPolicy;
@ -93,7 +92,6 @@ import javax.mail.internet.MimeMessage;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;
import java.io.IOException;
@ -723,7 +721,7 @@ public class UserTest extends AbstractAdminTest {
createUser(user);
List<UserRepresentation> users = realm.users().search("wit", null, null);
List<UserRepresentation> users = realm.users().search("*wit*", null, null);
assertEquals(1, users.size());
}
@ -925,25 +923,131 @@ public class UserTest extends AbstractAdminTest {
}
@Test
public void search() {
createUsers();
public void infixSearch() {
List<String> userIds = createUsers();
List<UserRepresentation> users = realm.users().search("username1", null, null);
assertEquals(1, users.size());
// Username search
List<UserRepresentation> users = realm.users().search("*1*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("first1", null, null);
assertEquals(1, users.size());
users = realm.users().search("*y*", null, null);
assertThat(users.size(), is(0));
users = realm.users().search("last", null, null);
assertEquals(9, users.size());
users = realm.users().search("*name*", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("**", null, null);
assertThat(users, hasSize(9));
// First/Last name search
users = realm.users().search("*first1*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("*last*", null, null);
assertThat(users, hasSize(9));
// Email search
users = realm.users().search("*@localhost*", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("*1@local*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
}
@Test
public void count() {
public void prefixSearch() {
List<String> userIds = createUsers();
// Username search
List<UserRepresentation> users = realm.users().search("user", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("user*", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("name", null, null);
assertThat(users, hasSize(0));
users = realm.users().search("name*", null, null);
assertThat(users, hasSize(0));
users = realm.users().search("username1", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("username1*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search(null, null, null);
assertThat(users, hasSize(9));
users = realm.users().search("", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("*", null, null);
assertThat(users, hasSize(9));
// First/Last name search
users = realm.users().search("first1", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("first1*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("last", null, null);
assertThat(users, hasSize(9));
users = realm.users().search("last*", null, null);
assertThat(users, hasSize(9));
// Email search
users = realm.users().search("user1@local", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("user1@local*", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
}
@Test
public void circumfixSearchNotSupported() {
createUsers();
Integer count = realm.users().count();
assertEquals(9, count.intValue());
List<UserRepresentation> users = realm.users().search("u*name", null, null);
assertThat(users, hasSize(0));
}
@Test
public void exactSearch() {
List<String> userIds = createUsers();
// Username search
List<UserRepresentation> users = realm.users().search("\"username1\"", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
users = realm.users().search("\"user\"", null, null);
assertThat(users, hasSize(0));
users = realm.users().search("\"\"", null, null);
assertThat(users, hasSize(0));
// First/Last name search
users = realm.users().search("\"first1\"", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
// Email search
users = realm.users().search("\"user1@localhost\"", null, null);
assertThat(users, hasSize(1));
assertThat(userIds.get(0), equalTo(users.get(0).getId()));
}
@Test

View file

@ -18,7 +18,6 @@
package org.keycloak.testsuite.admin;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.keycloak.admin.client.Keycloak;
import org.keycloak.admin.client.resource.AuthorizationResource;
@ -49,7 +48,7 @@ import java.util.Optional;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
import static org.hamcrest.MatcherAssert.assertThat;
public class UsersTest extends AbstractAdminTest {
@ -111,25 +110,81 @@ public class UsersTest extends AbstractAdminTest {
@Test
public void countUsersBySearchWithViewPermission() {
createUser(realmId, "user1", "password", "user1FirstName", "user1LastName", "user1@example.com");
createUser(realmId, "user2", "password", "user2FirstName", "user2LastName", "user2@example.com");
//search all
assertThat(realm.users().count("user"), is(2));
//search first name
assertThat(realm.users().count("FirstName"), is(2));
assertThat(realm.users().count("user2FirstName"), is(1));
//search last name
assertThat(realm.users().count("LastName"), is(2));
assertThat(realm.users().count("user2LastName"), is(1));
//search in email
assertThat(realm.users().count("@example.com"), is(2));
assertThat(realm.users().count("user1@example.com"), is(1));
//search for something not existing
assertThat(realm.users().count("notExisting"), is(0));
//search for empty string
assertThat(realm.users().count(""), is(2));
//search not specified (defaults to simply /count)
assertThat(realm.users().count(null), is(2));
createUser(realmId, "user1", "password", "user1FirstName", "user1LastName", "user1@example.com", rep -> rep.setEmailVerified(true));
createUser(realmId, "user2", "password", "user2FirstName", "user2LastName", "user2@example.com", rep -> rep.setEmailVerified(false));
createUser(realmId, "user3", "password", "user3FirstName", "user3LastName", "user3@example.com", rep -> rep.setEmailVerified(true));
// Prefix search count
Integer count = realm.users().count("user");
assertThat(count, is(3));
count = realm.users().count("user*");
assertThat(count, is(3));
count = realm.users().count("er");
assertThat(count, is(0));
count = realm.users().count("");
assertThat(count, is(3));
count = realm.users().count("*");
assertThat(count, is(3));
count = realm.users().count("user2FirstName");
assertThat(count, is(1));
count = realm.users().count("user2First");
assertThat(count, is(1));
count = realm.users().count("user2First*");
assertThat(count, is(1));
count = realm.users().count("user1@example");
assertThat(count, is(1));
count = realm.users().count("user1@example*");
assertThat(count, is(1));
count = realm.users().count(null);
assertThat(count, is(3));
// Infix search count
count = realm.users().count("*user*");
assertThat(count, is(3));
count = realm.users().count("**");
assertThat(count, is(3));
count = realm.users().count("*foobar*");
assertThat(count, is(0));
count = realm.users().count("*LastName*");
assertThat(count, is(3));
count = realm.users().count("*FirstName*");
assertThat(count, is(3));
count = realm.users().count("*@example.com*");
assertThat(count, is(3));
// Exact search count
count = realm.users().count("\"user1\"");
assertThat(count, is(1));
count = realm.users().count("\"1\"");
assertThat(count, is(0));
count = realm.users().count("\"\"");
assertThat(count, is(0));
count = realm.users().count("\"user1FirstName\"");
assertThat(count, is(1));
count = realm.users().count("\"user1LastName\"");
assertThat(count, is(1));
count = realm.users().count("\"user1@example.com\"");
assertThat(count, is(1));
}
@Test
@ -184,13 +239,13 @@ public class UsersTest extends AbstractAdminTest {
//search all
assertThat(testRealmResource.users().count("user"), is(3));
//search first name
assertThat(testRealmResource.users().count("FirstName"), is(3));
assertThat(testRealmResource.users().count("*FirstName*"), is(3));
assertThat(testRealmResource.users().count("user2FirstName"), is(1));
//search last name
assertThat(testRealmResource.users().count("LastName"), is(3));
assertThat(testRealmResource.users().count("*LastName*"), is(3));
assertThat(testRealmResource.users().count("user2LastName"), is(1));
//search in email
assertThat(testRealmResource.users().count("@example.com"), is(3));
assertThat(testRealmResource.users().count("*@example.com*"), is(3));
assertThat(testRealmResource.users().count("user1@example.com"), is(1));
//search for something not existing
assertThat(testRealmResource.users().count("notExisting"), is(0));