Store user nested entities in Set instead of Map

This commit is contained in:
Michal Hajas 2022-01-07 10:04:49 +01:00 committed by Hynek Mlnařík
parent 9849df3757
commit ab9413b48c
6 changed files with 250 additions and 128 deletions

View file

@ -59,7 +59,7 @@ import javax.lang.model.type.TypeKind;
@SupportedSourceVersion(SourceVersion.RELEASE_8) @SupportedSourceVersion(SourceVersion.RELEASE_8)
public class GenerateEntityImplementationsProcessor extends AbstractGenerateEntityImplementationsProcessor { public class GenerateEntityImplementationsProcessor extends AbstractGenerateEntityImplementationsProcessor {
private final Collection<String> autogenerated = new TreeSet<>(); private static final Collection<String> autogenerated = new TreeSet<>();
private final Generator[] generators = new Generator[] { private final Generator[] generators = new Generator[] {
new ClonerGenerator(), new ClonerGenerator(),

View file

@ -279,7 +279,7 @@ public class MapFieldPredicates {
String providerId = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, "provider_id", op, values); String providerId = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_CLIENT_FEDERATION_LINK, "provider_id", op, values);
String providerIdS = new StorageId((String) providerId, "").getId(); String providerIdS = new StorageId((String) providerId, "").getId();
Function<MapUserEntity, ?> getter; Function<MapUserEntity, ?> getter;
getter = ue -> Optional.ofNullable(ue.getUserConsents()).orElseGet(Collections::emptyMap).values().stream().map(MapUserConsentEntity::getClientId).anyMatch(v -> v != null && v.startsWith(providerIdS)); getter = ue -> Optional.ofNullable(ue.getUserConsents()).orElseGet(Collections::emptySet).stream().map(MapUserConsentEntity::getClientId).anyMatch(v -> v != null && v.startsWith(providerIdS));
return mcb.fieldCompare(Boolean.TRUE::equals, getter); return mcb.fieldCompare(Boolean.TRUE::equals, getter);
} }
@ -445,7 +445,7 @@ public class MapFieldPredicates {
private static MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> checkUserClientConsent(MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> mcb, Operator op, Object[] values) { private static MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> checkUserClientConsent(MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> mcb, Operator op, Object[] values) {
String clientIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_id", op, values); String clientIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_id", op, values);
Function<MapUserEntity, ?> getter; Function<MapUserEntity, ?> getter;
getter = ue -> ue.getUserConsent(clientIdS); getter = ue -> ue.getUserConsent(clientIdS).orElse(null);
return mcb.fieldCompare(Operator.EXISTS, getter, null); return mcb.fieldCompare(Operator.EXISTS, getter, null);
} }
@ -453,7 +453,7 @@ public class MapFieldPredicates {
private static MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> checkUserConsentsWithClientScope(MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> mcb, Operator op, Object[] values) { private static MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> checkUserConsentsWithClientScope(MapModelCriteriaBuilder<Object, MapUserEntity, UserModel> mcb, Operator op, Object[] values) {
String clientScopeIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_scope_id", op, values); String clientScopeIdS = ensureEqSingleValue(UserModel.SearchableFields.CONSENT_FOR_CLIENT, "client_scope_id", op, values);
Function<MapUserEntity, ?> getter; Function<MapUserEntity, ?> getter;
getter = ue -> Optional.ofNullable(ue.getUserConsents()).orElseGet(Collections::emptyMap).values().stream().anyMatch(consent -> Optional.ofNullable(consent.getGrantedClientScopesIds()).orElseGet(Collections::emptySet).contains(clientScopeIdS)); getter = ue -> Optional.ofNullable(ue.getUserConsents()).orElseGet(Collections::emptySet).stream().anyMatch(consent -> Optional.ofNullable(consent.getGrantedClientScopesIds()).orElseGet(Collections::emptySet).contains(clientScopeIdS));
return mcb.fieldCompare(Boolean.TRUE::equals, getter); return mcb.fieldCompare(Boolean.TRUE::equals, getter);
} }
@ -469,15 +469,15 @@ public class MapFieldPredicates {
final Object idpAlias = values[0]; final Object idpAlias = values[0];
Function<MapUserEntity, ?> getter; Function<MapUserEntity, ?> getter;
if (values.length == 1) { if (values.length == 1) {
getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptyMap).values().stream() getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptySet).stream()
.anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider())); .anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider()));
} else if (idpAlias == null) { } else if (idpAlias == null) {
final Object idpUserId = values[1]; final Object idpUserId = values[1];
getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptyMap).values().stream() getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptySet).stream()
.anyMatch(aue -> Objects.equals(idpUserId, aue.getUserId())); .anyMatch(aue -> Objects.equals(idpUserId, aue.getUserId()));
} else { } else {
final Object idpUserId = values[1]; final Object idpUserId = values[1];
getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptyMap).values().stream() getter = ue -> Optional.ofNullable(ue.getFederatedIdentities()).orElseGet(Collections::emptySet).stream()
.anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider()) && Objects.equals(idpUserId, aue.getUserId())); .anyMatch(aue -> Objects.equals(idpAlias, aue.getIdentityProvider()) && Objects.equals(idpUserId, aue.getUserId()));
} }

View file

@ -29,8 +29,6 @@ import java.util.Comparator;
@DeepCloner.Root @DeepCloner.Root
public interface MapUserCredentialEntity extends UpdatableEntity { public interface MapUserCredentialEntity extends UpdatableEntity {
Comparator<MapUserCredentialEntity> ORDER_BY_PRIORITY = Comparator.comparing(MapUserCredentialEntity::getPriority);
public static MapUserCredentialEntity fromModel(CredentialModel model) { public static MapUserCredentialEntity fromModel(CredentialModel model) {
MapUserCredentialEntity credentialEntity = new MapUserCredentialEntityImpl(); MapUserCredentialEntity credentialEntity = new MapUserCredentialEntityImpl();
String id = model.getId() == null ? KeycloakModelUtils.generateId() : model.getId(); String id = model.getId() == null ? KeycloakModelUtils.generateId() : model.getId();
@ -72,7 +70,4 @@ public interface MapUserCredentialEntity extends UpdatableEntity {
String getCredentialData(); String getCredentialData();
void setCredentialData(String credentialData); void setCredentialData(String credentialData);
Integer getPriority();
void setPriority(Integer priority);
} }

View file

@ -17,6 +17,7 @@
package org.keycloak.models.map.user; package org.keycloak.models.map.user;
import org.jboss.logging.Logger;
import org.keycloak.models.map.annotations.GenerateEntityImplementations; import org.keycloak.models.map.annotations.GenerateEntityImplementations;
import org.keycloak.models.map.annotations.IgnoreForEntityImplementationGenerator; import org.keycloak.models.map.annotations.IgnoreForEntityImplementationGenerator;
import org.keycloak.models.map.common.AbstractEntity; import org.keycloak.models.map.common.AbstractEntity;
@ -26,12 +27,11 @@ import org.keycloak.models.map.common.UpdatableEntity;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Stream;
@GenerateEntityImplementations( @GenerateEntityImplementations(
inherits = "org.keycloak.models.map.user.MapUserEntity.AbstractUserEntity" inherits = "org.keycloak.models.map.user.MapUserEntity.AbstractUserEntity"
@ -41,22 +41,23 @@ public interface MapUserEntity extends UpdatableEntity, AbstractEntity, EntityWi
public abstract class AbstractUserEntity extends UpdatableEntity.Impl implements MapUserEntity { public abstract class AbstractUserEntity extends UpdatableEntity.Impl implements MapUserEntity {
private static final Logger LOG = Logger.getLogger(MapUserProvider.class);
private String id; private String id;
@Override @Override
public boolean isUpdated() { public boolean isUpdated() {
return this.updated return this.updated
|| Optional.ofNullable(getUserConsents()).orElseGet(Collections::emptyMap).values().stream().anyMatch(MapUserConsentEntity::isUpdated) || Optional.ofNullable(getUserConsents()).orElseGet(Collections::emptySet).stream().anyMatch(MapUserConsentEntity::isUpdated)
|| Optional.ofNullable(getCredentials()).orElseGet(Collections::emptyMap).values().stream().anyMatch(MapUserCredentialEntity::isUpdated) || Optional.ofNullable(getCredentials()).orElseGet(Collections::emptyList).stream().anyMatch(MapUserCredentialEntity::isUpdated)
|| Optional.ofNullable(getFederatedIdentities()).orElseGet(Collections::emptyMap).values().stream().anyMatch(MapUserFederatedIdentityEntity::isUpdated); || Optional.ofNullable(getFederatedIdentities()).orElseGet(Collections::emptySet).stream().anyMatch(MapUserFederatedIdentityEntity::isUpdated);
} }
@Override @Override
public void clearUpdatedFlag() { public void clearUpdatedFlag() {
this.updated = false; this.updated = false;
Optional.ofNullable(getUserConsents()).orElseGet(Collections::emptyMap).values().forEach(UpdatableEntity::clearUpdatedFlag); Optional.ofNullable(getUserConsents()).orElseGet(Collections::emptySet).forEach(UpdatableEntity::clearUpdatedFlag);
Optional.ofNullable(getCredentials()).orElseGet(Collections::emptyMap).values().forEach(UpdatableEntity::clearUpdatedFlag); Optional.ofNullable(getCredentials()).orElseGet(Collections::emptyList).forEach(UpdatableEntity::clearUpdatedFlag);
Optional.ofNullable(getFederatedIdentities()).orElseGet(Collections::emptyMap).values().forEach(UpdatableEntity::clearUpdatedFlag); Optional.ofNullable(getFederatedIdentities()).orElseGet(Collections::emptySet).forEach(UpdatableEntity::clearUpdatedFlag);
} }
@ -77,6 +78,97 @@ public interface MapUserEntity extends UpdatableEntity, AbstractEntity, EntityWi
this.setEmail(email); this.setEmail(email);
this.setEmailConstraint(email == null || duplicateEmailsAllowed ? KeycloakModelUtils.generateId() : email); this.setEmailConstraint(email == null || duplicateEmailsAllowed ? KeycloakModelUtils.generateId() : email);
} }
@Override
public Optional<MapUserConsentEntity> getUserConsent(String clientId) {
Set<MapUserConsentEntity> ucs = getUserConsents();
if (ucs == null || ucs.isEmpty()) return Optional.empty();
return ucs.stream().filter(uc -> Objects.equals(uc.getClientId(), clientId)).findFirst();
}
@Override
public Boolean removeUserConsent(String clientId) {
Set<MapUserConsentEntity> consents = getUserConsents();
boolean removed = consents != null && consents.removeIf(uc -> Objects.equals(uc.getClientId(), clientId));
this.updated |= removed;
return removed;
}
@Override
public Optional<MapUserCredentialEntity> getCredential(String id) {
List<MapUserCredentialEntity> uce = getCredentials();
if (uce == null || uce.isEmpty()) return Optional.empty();
return uce.stream().filter(uc -> Objects.equals(uc.getId(), id)).findFirst();
}
@Override
public Boolean removeCredential(String id) {
List<MapUserCredentialEntity> credentials = getCredentials();
boolean removed = credentials != null && credentials.removeIf(c -> Objects.equals(c.getId(), id));
this.updated |= removed;
return removed;
}
@Override
public Boolean moveCredential(String credentialId, String newPreviousCredentialId) {
// 1 - Get all credentials from the entity.
List<MapUserCredentialEntity> credentialsList = getCredentials();
// 2 - Find indexes of our and newPrevious credential
int ourCredentialIndex = -1;
int newPreviousCredentialIndex = -1;
MapUserCredentialEntity ourCredential = null;
int i = 0;
for (MapUserCredentialEntity credential : credentialsList) {
if (credentialId.equals(credential.getId())) {
ourCredentialIndex = i;
ourCredential = credential;
} else if(newPreviousCredentialId != null && newPreviousCredentialId.equals(credential.getId())) {
newPreviousCredentialIndex = i;
}
i++;
}
if (ourCredentialIndex == -1) {
LOG.warnf("Not found credential with id [%s] of user [%s]", credentialId, getUsername());
return false;
}
if (newPreviousCredentialId != null && newPreviousCredentialIndex == -1) {
LOG.warnf("Can't move up credential with id [%s] of user [%s]", credentialId, getUsername());
return false;
}
// 3 - Compute index where we move our credential
int toMoveIndex = newPreviousCredentialId==null ? 0 : newPreviousCredentialIndex + 1;
// 4 - Insert our credential to new position, remove it from the old position
if (toMoveIndex == ourCredentialIndex) return true;
credentialsList.add(toMoveIndex, ourCredential);
int indexToRemove = toMoveIndex < ourCredentialIndex ? ourCredentialIndex + 1 : ourCredentialIndex;
credentialsList.remove(indexToRemove);
this.updated = true;
return true;
}
@Override
public Optional<MapUserFederatedIdentityEntity> getFederatedIdentity(String identityProviderId) {
Set<MapUserFederatedIdentityEntity> fes = getFederatedIdentities();
if (fes == null || fes.isEmpty()) return Optional.empty();
return fes.stream().filter(fi -> Objects.equals(fi.getIdentityProvider(), identityProviderId)).findFirst();
}
@Override
public Boolean removeFederatedIdentity(String identityProviderId) {
Set<MapUserFederatedIdentityEntity> federatedIdentities = getFederatedIdentities();
boolean removed = federatedIdentities != null && federatedIdentities.removeIf(fi -> Objects.equals(fi.getIdentityProvider(), identityProviderId));
this.updated |= removed;
return removed;
}
} }
String getRealmId(); String getRealmId();
@ -119,20 +211,27 @@ public interface MapUserEntity extends UpdatableEntity, AbstractEntity, EntityWi
void addRequiredAction(String requiredAction); void addRequiredAction(String requiredAction);
void removeRequiredAction(String requiredAction); void removeRequiredAction(String requiredAction);
Map<String, MapUserCredentialEntity> getCredentials(); List<MapUserCredentialEntity> getCredentials();
void setCredential(String id, MapUserCredentialEntity credentialEntity); Optional<MapUserCredentialEntity> getCredential(String id);
Boolean removeCredential(String credentialId); void setCredentials(List<MapUserCredentialEntity> credentials);
MapUserCredentialEntity getCredential(String id); void addCredential(MapUserCredentialEntity credentialEntity);
Boolean removeCredential(MapUserCredentialEntity credentialEntity);
Boolean removeCredential(String id);
@IgnoreForEntityImplementationGenerator
Boolean moveCredential(String credentialId, String newPreviousCredentialId);
Map<String, MapUserFederatedIdentityEntity> getFederatedIdentities(); Set<MapUserFederatedIdentityEntity> getFederatedIdentities();
void setFederatedIdentities(Map<String, MapUserFederatedIdentityEntity> federatedIdentities); Optional<MapUserFederatedIdentityEntity> getFederatedIdentity(String identityProviderId);
void setFederatedIdentity(String id, MapUserFederatedIdentityEntity federatedIdentity); void setFederatedIdentities(Set<MapUserFederatedIdentityEntity> federatedIdentities);
MapUserFederatedIdentityEntity getFederatedIdentity(String federatedIdentity); void addFederatedIdentity(MapUserFederatedIdentityEntity federatedIdentity);
Boolean removeFederatedIdentity(String providerId); Boolean removeFederatedIdentity(MapUserFederatedIdentityEntity providerId);
Boolean removeFederatedIdentity(String identityProviderId);
Map<String, MapUserConsentEntity> getUserConsents(); Set<MapUserConsentEntity> getUserConsents();
MapUserConsentEntity getUserConsent(String clientId); Optional<MapUserConsentEntity> getUserConsent(String clientId);
void setUserConsent(String id, MapUserConsentEntity userConsentEntity); void setUserConsents(Set<MapUserConsentEntity> userConsentEntity);
void addUserConsent(MapUserConsentEntity userConsentEntity);
Boolean removeUserConsent(MapUserConsentEntity userConsentEntity);
Boolean removeUserConsent(String clientId); Boolean removeUserConsent(String clientId);
Set<String> getGroupsMembership(); Set<String> getGroupsMembership();

View file

@ -51,11 +51,9 @@ import org.keycloak.storage.UserStorageProvider;
import org.keycloak.storage.client.ClientStorageProvider; import org.keycloak.storage.client.ClientStorageProvider;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator;
import java.util.EnumMap; import java.util.EnumMap;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -79,8 +77,6 @@ import static org.keycloak.models.map.storage.criteria.DefaultModelCriteria.crit
public class MapUserProvider implements UserProvider.Streams, UserCredentialStore.Streams { public class MapUserProvider implements UserProvider.Streams, UserCredentialStore.Streams {
// Typical priority difference between 2 credentials
public static final int PRIORITY_DIFFERENCE = 10;
private static final Logger LOG = Logger.getLogger(MapUserProvider.class); private static final Logger LOG = Logger.getLogger(MapUserProvider.class);
private final KeycloakSession session; private final KeycloakSession session;
final MapKeycloakTransaction<MapUserEntity, UserModel> tx; final MapKeycloakTransaction<MapUserEntity, UserModel> tx;
@ -145,7 +141,7 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
getEntityById(realm, user.getId()) getEntityById(realm, user.getId())
.ifPresent(userEntity -> .ifPresent(userEntity ->
userEntity.setFederatedIdentity(socialLink.getIdentityProvider(), MapUserFederatedIdentityEntity.fromModel(socialLink))); userEntity.addFederatedIdentity(MapUserFederatedIdentityEntity.fromModel(socialLink)));
} }
@Override @Override
@ -175,7 +171,12 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public void updateFederatedIdentity(RealmModel realm, UserModel federatedUser, FederatedIdentityModel federatedIdentityModel) { public void updateFederatedIdentity(RealmModel realm, UserModel federatedUser, FederatedIdentityModel federatedIdentityModel) {
LOG.tracef("updateFederatedIdentity(%s, %s, %s)%s", realm, federatedUser.getId(), federatedIdentityModel.getIdentityProvider(), getShortStackTrace()); LOG.tracef("updateFederatedIdentity(%s, %s, %s)%s", realm, federatedUser.getId(), federatedIdentityModel.getIdentityProvider(), getShortStackTrace());
getEntityById(realm, federatedUser.getId()) getEntityById(realm, federatedUser.getId())
.ifPresent(entity -> entity.setFederatedIdentity(federatedIdentityModel.getIdentityProvider(), MapUserFederatedIdentityEntity.fromModel(federatedIdentityModel))); .flatMap(u -> u.getFederatedIdentity(federatedIdentityModel.getIdentityProvider()))
.ifPresent(fi -> {
fi.setUserId(federatedIdentityModel.getUserId());
fi.setUserName(federatedIdentityModel.getUserName());
fi.setToken(federatedIdentityModel.getToken());
});
} }
@Override @Override
@ -183,7 +184,6 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
LOG.tracef("getFederatedIdentitiesStream(%s, %s)%s", realm, user.getId(), getShortStackTrace()); LOG.tracef("getFederatedIdentitiesStream(%s, %s)%s", realm, user.getId(), getShortStackTrace());
return getEntityById(realm, user.getId()) return getEntityById(realm, user.getId())
.map(MapUserEntity::getFederatedIdentities) .map(MapUserEntity::getFederatedIdentities)
.map(Map::values)
.map(Collection::stream) .map(Collection::stream)
.orElseGet(Stream::empty) .orElseGet(Stream::empty)
.map(MapUserFederatedIdentityEntity::toModel); .map(MapUserFederatedIdentityEntity::toModel);
@ -193,7 +193,7 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public FederatedIdentityModel getFederatedIdentity(RealmModel realm, UserModel user, String socialProvider) { public FederatedIdentityModel getFederatedIdentity(RealmModel realm, UserModel user, String socialProvider) {
LOG.tracef("getFederatedIdentity(%s, %s, %s)%s", realm, user.getId(), socialProvider, getShortStackTrace()); LOG.tracef("getFederatedIdentity(%s, %s, %s)%s", realm, user.getId(), socialProvider, getShortStackTrace());
return getEntityById(realm, user.getId()) return getEntityById(realm, user.getId())
.map(userEntity -> userEntity.getFederatedIdentity(socialProvider)) .flatMap(userEntity -> userEntity.getFederatedIdentity(socialProvider))
.map(MapUserFederatedIdentityEntity::toModel) .map(MapUserFederatedIdentityEntity::toModel)
.orElse(null); .orElse(null);
} }
@ -225,14 +225,14 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
LOG.tracef("addConsent(%s, %s, %s)%s", realm, userId, consent, getShortStackTrace()); LOG.tracef("addConsent(%s, %s, %s)%s", realm, userId, consent, getShortStackTrace());
getEntityByIdOrThrow(realm, userId) getEntityByIdOrThrow(realm, userId)
.setUserConsent(consent.getClient().getId(), MapUserConsentEntity.fromModel(consent)); .addUserConsent(MapUserConsentEntity.fromModel(consent));
} }
@Override @Override
public UserConsentModel getConsentByClient(RealmModel realm, String userId, String clientInternalId) { public UserConsentModel getConsentByClient(RealmModel realm, String userId, String clientInternalId) {
LOG.tracef("getConsentByClient(%s, %s, %s)%s", realm, userId, clientInternalId, getShortStackTrace()); LOG.tracef("getConsentByClient(%s, %s, %s)%s", realm, userId, clientInternalId, getShortStackTrace());
return getEntityById(realm, userId) return getEntityById(realm, userId)
.map(userEntity -> userEntity.getUserConsent(clientInternalId)) .flatMap(userEntity -> userEntity.getUserConsent(clientInternalId))
.map(consent -> MapUserConsentEntity.toModel(realm, consent)) .map(consent -> MapUserConsentEntity.toModel(realm, consent))
.orElse(null); .orElse(null);
} }
@ -242,9 +242,8 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
LOG.tracef("getConsentByClientStream(%s, %s)%s", realm, userId, getShortStackTrace()); LOG.tracef("getConsentByClientStream(%s, %s)%s", realm, userId, getShortStackTrace());
return getEntityById(realm, userId) return getEntityById(realm, userId)
.map(MapUserEntity::getUserConsents) .map(MapUserEntity::getUserConsents)
.map(Map::values)
.map(Collection::stream) .map(Collection::stream)
.orElse(Stream.empty()) .orElseGet(Stream::empty)
.map(consent -> MapUserConsentEntity.toModel(realm, consent)); .map(consent -> MapUserConsentEntity.toModel(realm, consent));
} }
@ -253,10 +252,8 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
LOG.tracef("updateConsent(%s, %s, %s)%s", realm, userId, consent, getShortStackTrace()); LOG.tracef("updateConsent(%s, %s, %s)%s", realm, userId, consent, getShortStackTrace());
MapUserEntity user = getEntityByIdOrThrow(realm, userId); MapUserEntity user = getEntityByIdOrThrow(realm, userId);
MapUserConsentEntity userConsentEntity = user.getUserConsent(consent.getClient().getId()); MapUserConsentEntity userConsentEntity = user.getUserConsent(consent.getClient().getId())
if (userConsentEntity == null) { .orElseThrow(() -> new ModelException("Consent not found for client [" + consent.getClient().getId() + "] and user [" + userId + "]"));
throw new ModelException("Consent not found for client [" + consent.getClient().getId() + "] and user [" + userId + "]");
}
userConsentEntity.setGrantedClientScopesIds( userConsentEntity.setGrantedClientScopesIds(
consent.getGrantedClientScopes().stream() consent.getGrantedClientScopes().stream()
@ -447,7 +444,6 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
try (Stream<MapUserEntity> s = tx.read(withCriteria(mcb))) { try (Stream<MapUserEntity> s = tx.read(withCriteria(mcb))) {
s.map(MapUserEntity::getUserConsents) s.map(MapUserEntity::getUserConsents)
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(Map::values)
.flatMap(Collection::stream) .flatMap(Collection::stream)
.forEach(consent -> consent.removeGrantedClientScopesId(clientScopeId)); .forEach(consent -> consent.removeGrantedClientScopesId(clientScopeId));
} }
@ -474,9 +470,9 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
private Consumer<MapUserEntity> removeConsentsForExternalClient(String idPrefix) { private Consumer<MapUserEntity> removeConsentsForExternalClient(String idPrefix) {
return userEntity -> { return userEntity -> {
Map<String, MapUserConsentEntity> userConsents = userEntity.getUserConsents(); Set<MapUserConsentEntity> userConsents = userEntity.getUserConsents();
if (userConsents == null || userConsents.isEmpty()) return; if (userConsents == null || userConsents.isEmpty()) return;
List<String> consentClientIds = userConsents.values().stream() List<String> consentClientIds = userConsents.stream()
.map(MapUserConsentEntity::getClientId) .map(MapUserConsentEntity::getClientId)
.filter(clientId -> clientId != null && clientId.startsWith(idPrefix)) .filter(clientId -> clientId != null && clientId.startsWith(idPrefix))
.collect(Collectors.toList()); .collect(Collectors.toList());
@ -756,16 +752,13 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
} }
private Consumer<MapUserEntity> updateCredential(CredentialModel credentialModel) { private Consumer<MapUserEntity> updateCredential(CredentialModel credentialModel) {
return user -> { return user -> user.getCredential(credentialModel.getId()).ifPresent(c -> {
MapUserCredentialEntity credentialEntity = user.getCredential(credentialModel.getId()); c.setCreatedDate(credentialModel.getCreatedDate());
if (credentialEntity == null) return; c.setUserLabel(credentialModel.getUserLabel());
c.setType(credentialModel.getType());
credentialEntity.setCreatedDate(credentialModel.getCreatedDate()); c.setSecretData(credentialModel.getSecretData());
credentialEntity.setUserLabel(credentialModel.getUserLabel()); c.setCredentialData(credentialModel.getCredentialData());
credentialEntity.setType(credentialModel.getType()); });
credentialEntity.setSecretData(credentialModel.getSecretData());
credentialEntity.setCredentialData(credentialModel.getCredentialData());
};
} }
@Override @Override
@ -774,19 +767,11 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
MapUserEntity userEntity = getEntityByIdOrThrow(realm, user.getId()); MapUserEntity userEntity = getEntityByIdOrThrow(realm, user.getId());
MapUserCredentialEntity credentialEntity = MapUserCredentialEntity.fromModel(cred); MapUserCredentialEntity credentialEntity = MapUserCredentialEntity.fromModel(cred);
if (userEntity.getCredential(cred.getId()) != null) { if (userEntity.getCredential(cred.getId()).isPresent()) {
throw new ModelDuplicateException("A CredentialModel with given id already exists"); throw new ModelDuplicateException("A CredentialModel with given id already exists");
} }
Map<String, MapUserCredentialEntity> credentials = userEntity.getCredentials(); userEntity.addCredential(credentialEntity);
int priority = PRIORITY_DIFFERENCE;
if (credentials != null && !credentials.isEmpty()) {
priority = credentials.values().stream().max(MapUserCredentialEntity.ORDER_BY_PRIORITY).map(MapUserCredentialEntity::getPriority).orElse(0) + PRIORITY_DIFFERENCE;
}
credentialEntity.setPriority(priority);
userEntity.setCredential(credentialEntity.getId(), credentialEntity);
return MapUserCredentialEntity.toModel(credentialEntity); return MapUserCredentialEntity.toModel(credentialEntity);
} }
@ -806,7 +791,7 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
public CredentialModel getStoredCredentialById(RealmModel realm, UserModel user, String id) { public CredentialModel getStoredCredentialById(RealmModel realm, UserModel user, String id) {
LOG.tracef("getStoredCredentialById(%s, %s, %s)%s", realm, user.getId(), id, getShortStackTrace()); LOG.tracef("getStoredCredentialById(%s, %s, %s)%s", realm, user.getId(), id, getShortStackTrace());
return getEntityById(realm, user.getId()) return getEntityById(realm, user.getId())
.map(mapUserEntity -> mapUserEntity.getCredential(id)) .flatMap(mapUserEntity -> mapUserEntity.getCredential(id))
.map(MapUserCredentialEntity::toModel) .map(MapUserCredentialEntity::toModel)
.orElse(null); .orElse(null);
} }
@ -817,10 +802,8 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
return getEntityById(realm, user.getId()) return getEntityById(realm, user.getId())
.map(MapUserEntity::getCredentials) .map(MapUserEntity::getCredentials)
.map(Map::values)
.map(Collection::stream) .map(Collection::stream)
.orElseGet(Stream::empty) .orElseGet(Stream::empty)
.sorted(MapUserCredentialEntity.ORDER_BY_PRIORITY)
.map(MapUserCredentialEntity::toModel); .map(MapUserCredentialEntity::toModel);
} }
@ -841,60 +824,8 @@ public class MapUserProvider implements UserProvider.Streams, UserCredentialStor
@Override @Override
public boolean moveCredentialTo(RealmModel realm, UserModel user, String id, String newPreviousCredentialId) { public boolean moveCredentialTo(RealmModel realm, UserModel user, String id, String newPreviousCredentialId) {
LOG.tracef("moveCredentialTo(%s, %s, %s, %s)%s", realm, user, id, newPreviousCredentialId, getShortStackTrace());
MapUserEntity userEntity = getEntityByIdOrThrow(realm, user.getId()); return getEntityByIdOrThrow(realm, user.getId()).moveCredential(id, newPreviousCredentialId);
// 1 - Create new list and move everything to it.
Map<String, MapUserCredentialEntity> credentialEntityMap = userEntity.getCredentials();
List<MapUserCredentialEntity> newList = credentialEntityMap == null ? new LinkedList<>()
: credentialEntityMap.values().stream()
.sorted(MapUserCredentialEntity.ORDER_BY_PRIORITY)
.collect(Collectors.toList());
// 2 - Find indexes of our and newPrevious credential
int ourCredentialIndex = -1;
int newPreviousCredentialIndex = -1;
MapUserCredentialEntity ourCredential = null;
int i = 0;
for (MapUserCredentialEntity credential : newList) {
if (id.equals(credential.getId())) {
ourCredentialIndex = i;
ourCredential = credential;
} else if(newPreviousCredentialId != null && newPreviousCredentialId.equals(credential.getId())) {
newPreviousCredentialIndex = i;
}
i++;
}
if (ourCredentialIndex == -1) {
LOG.warnf("Not found credential with id [%s] of user [%s]", id, user.getUsername());
return false;
}
if (newPreviousCredentialId != null && newPreviousCredentialIndex == -1) {
LOG.warnf("Can't move up credential with id [%s] of user [%s]", id, user.getUsername());
return false;
}
// 3 - Compute index where we move our credential
int toMoveIndex = newPreviousCredentialId==null ? 0 : newPreviousCredentialIndex + 1;
// 4 - Insert our credential to new position, remove it from the old position
newList.add(toMoveIndex, ourCredential);
int indexToRemove = toMoveIndex < ourCredentialIndex ? ourCredentialIndex + 1 : ourCredentialIndex;
newList.remove(indexToRemove);
// 5 - newList contains credentials in requested order now. Iterate through whole list and change priorities accordingly.
int expectedPriority = 0;
for (MapUserCredentialEntity credential : newList) {
expectedPriority += PRIORITY_DIFFERENCE;
if (credential.getPriority() != expectedPriority) {
credential.setPriority(expectedPriority);
LOG.tracef("Priority of credential [%s] of user [%s] changed to [%d]", credential.getId(), user.getUsername(), expectedPriority);
}
}
return true;
} }
@Override @Override

View file

@ -0,0 +1,97 @@
/*
* Copyright 2022 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.user;
import org.junit.Before;
import org.junit.Test;
import org.hamcrest.Matchers;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import org.keycloak.models.map.common.DeepCloner;
import java.util.List;
import java.util.stream.Collectors;
public class MapUserEntityImplCredentialsOrderTest {
private MapUserEntity user;
private final static DeepCloner CLONER = new DeepCloner.Builder()
.constructor(MapUserCredentialEntityImpl.class, MapUserCredentialEntityImpl::new)
.build();
@Before
public void init() {
user = new MapUserEntityImpl(CLONER);
for (int i = 1; i <= 5; i++) {
MapUserCredentialEntity credentialModel = new MapUserCredentialEntityImpl();
credentialModel.setId(Integer.toString(i));
user.addCredential(credentialModel);
}
user.clearUpdatedFlag();
}
private void assertOrder(Integer... ids) {
List<Integer> currentList = user.getCredentials().stream().map(entity -> Integer.valueOf(entity.getId())).collect(Collectors.toList());
assertThat(currentList, Matchers.contains(ids));
}
@Test
public void testCorrectOrder() {
assertOrder(1, 2, 3, 4, 5);
}
@Test
public void testMoveToZero() {
user.moveCredential("3", null);
assertOrder(3, 1, 2, 4, 5);
assertThat(user.isUpdated(), is(true));
}
@Test
public void testMoveBack() {
user.moveCredential("4", "1");
assertOrder(1, 4, 2, 3, 5);
assertThat(user.isUpdated(), is(true));
}
@Test
public void testMoveForward() {
user.moveCredential("2", "4");
assertOrder(1, 3, 4, 2, 5);
assertThat(user.isUpdated(), is(true));
}
@Test
public void testSamePosition() {
user.moveCredential("2", "1");
assertOrder(1, 2, 3, 4, 5);
assertThat(user.isUpdated(), is(false));
}
@Test
public void testSamePositionZero() {
user.moveCredential("1", null);
assertOrder(1, 2, 3, 4, 5);
assertThat(user.isUpdated(), is(false));
}
}