Fixes collection comparison ignoring order

Use of containsAll() does not permit to compare if 2 lists are equals
(ignoring order)
Previous implementation of CollectionUtil.collectionEquals(...) was not taking care of specific cases where you can have [ A, A, B ] and [ A, B, B ] and complexity was O(n²)
Using Map, complexity is now O(n)

Closes #9920
This commit is contained in:
Francis PEROT 2022-02-02 12:33:10 +01:00 committed by Hynek Mlnařík
parent 340d8da197
commit 623aaf1e8b
5 changed files with 57 additions and 34 deletions

View file

@ -18,7 +18,9 @@
package org.keycloak.common.util;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
/**
* @author <a href="mailto:jeroen.rosenberg@gmail.com">Jeroen Rosenberg</a>
@ -43,16 +45,26 @@ public class CollectionUtil {
// Return true if all items from col1 are in col2 and viceversa. Order is not taken into account
public static <T> boolean collectionEquals(Collection<T> col1, Collection<T> col2) {
if (col1.size() != col2.size()) {
if (col1.size()!=col2.size()) {
return false;
}
for (T item : col1) {
if (!col2.contains(item)) {
Map<T, Integer> countMap = new HashMap<>();
for(T o : col1) {
Integer v = countMap.get(o);
countMap.put(o, v==null ? 1 : v+1);
}
for(T o : col2) {
Integer v = countMap.get(o);
if (v==null) {
return false;
}
countMap.put(o, v-1);
}
for(Integer count : countMap.values()) {
if (count!=0) {
return false;
}
}
return true;
}

View file

@ -123,7 +123,7 @@ public class MultivaluedHashMap<K, V> extends HashMap<K, List<V>>
for (Map.Entry<K, List<V>> e : entrySet()) {
List<V> list = e.getValue();
List<V> olist = omap.get(e.getKey());
if (!(list.size() == olist.size() && list.containsAll(olist) && olist.containsAll(list))) {
if (!CollectionUtil.collectionEquals(list, olist)) {
return false;
}
}

View file

@ -4,6 +4,7 @@ import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -13,29 +14,29 @@ import static org.hamcrest.MatcherAssert.assertThat;
public class CollectionUtilTest {
@Test
public void joinInputNoneOutputEmpty() {
final ArrayList<String> strings = new ArrayList<>();
final String retval = CollectionUtil.join(strings, ",");
Assert.assertEquals("", retval);
}
@Test
public void joinInputNoneOutputEmpty() {
final ArrayList<String> strings = new ArrayList<>();
final String retval = CollectionUtil.join(strings, ",");
Assert.assertEquals("", retval);
}
@Test
public void joinInput2SeparatorNull() {
final ArrayList<String> strings = new ArrayList<>();
strings.add("foo");
strings.add("bar");
final String retval = CollectionUtil.join(strings, null);
Assert.assertEquals("foonullbar", retval);
}
@Test
public void joinInput2SeparatorNull() {
final ArrayList<String> strings = new ArrayList<>();
strings.add("foo");
strings.add("bar");
final String retval = CollectionUtil.join(strings, null);
Assert.assertEquals("foonullbar", retval);
}
@Test
public void joinInput1SeparatorNotNull() {
final ArrayList<String> strings = new ArrayList<>();
strings.add("foo");
final String retval = CollectionUtil.join(strings, ",");
Assert.assertEquals("foo", retval);
}
@Test
public void joinInput1SeparatorNotNull() {
final ArrayList<String> strings = new ArrayList<>();
strings.add("foo");
final String retval = CollectionUtil.join(strings, ",");
Assert.assertEquals("foo", retval);
}
@Test
public void joinInput2SeparatorNotNull() {
@ -68,4 +69,12 @@ public class CollectionUtilTest {
assertThat(CollectionUtil.isEmpty(set), is(false));
assertThat(CollectionUtil.isNotEmpty(set), is(true));
}
@Test
public void equalsCollectionTest() {
Assert.assertFalse(CollectionUtil.collectionEquals(Arrays.asList(1, 3, 2), Arrays.asList(1, 3)));
Assert.assertFalse(CollectionUtil.collectionEquals(Arrays.asList("A", "C"), Arrays.asList("A", "C", "B")));
Assert.assertFalse(CollectionUtil.collectionEquals(Arrays.asList(1, 3, 2, 3), Arrays.asList(1, 2, 3, 2)));
Assert.assertTrue(CollectionUtil.collectionEquals(Arrays.asList(1, 3, 3), Arrays.asList(3, 1, 3)));
}
}

View file

@ -27,6 +27,7 @@ import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.keycloak.common.util.CollectionUtil;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ModelException;
import org.keycloak.models.UserModel;
@ -113,13 +114,13 @@ public final class DefaultUserProfile implements UserProfile {
List<String> currentValue = user.getAttributeStream(name).filter(Objects::nonNull).collect(Collectors.toList());
List<String> updatedValue = attribute.getValue().stream().filter(Objects::nonNull).collect(Collectors.toList());
if (currentValue.size() != updatedValue.size() || !currentValue.containsAll(updatedValue)) {
if (!CollectionUtil.collectionEquals(currentValue, updatedValue)) {
user.setAttribute(name, updatedValue);
if(UserModel.EMAIL.equals(name) && metadata.getContext().isResetEmailVerified()) {
if (UserModel.EMAIL.equals(name) && metadata.getContext().isResetEmailVerified()) {
user.setEmailVerified(false);
}
for (AttributeChangeListener listener : changeListener) {
listener.onChange(name, user, currentValue);
}
@ -138,10 +139,10 @@ public final class DefaultUserProfile implements UserProfile {
if (this.attributes.isReadOnly(attr)) {
continue;
}
List<String> currentValue = user.getAttributeStream(attr).filter(Objects::nonNull).collect(Collectors.toList());
user.removeAttribute(attr);
for (AttributeChangeListener listener : changeListener) {
listener.onChange(attr, user, currentValue);
}

View file

@ -21,6 +21,7 @@ import static org.keycloak.validate.Validators.notBlankValidator;
import java.util.List;
import java.util.stream.Collectors;
import org.keycloak.common.util.CollectionUtil;
import org.keycloak.models.UserModel;
import org.keycloak.userprofile.AttributeContext;
import org.keycloak.userprofile.UserProfileAttributeValidationContext;
@ -64,7 +65,7 @@ public class ImmutableAttributeValidator implements SimpleValidator {
List<String> currentValue = user.getAttributeStream(inputHint).collect(Collectors.toList());
List<String> values = (List<String>) input;
if (!(currentValue.containsAll(values) && currentValue.size() == values.size())) {
if (!CollectionUtil.collectionEquals(currentValue, values)) {
if (currentValue.isEmpty() && !notBlankValidator().validate(values).isValid()) {
return context;
}