Change returning type to Set in MapClientEntity when obtaining protocol mappers

Closes #11136
This commit is contained in:
vramik 2022-06-24 08:10:21 +02:00 committed by Hynek Mlnařík
parent e3ece8244f
commit 91335ebaad
6 changed files with 63 additions and 38 deletions

View file

@ -24,12 +24,14 @@ import org.keycloak.models.map.storage.hotRod.common.AbstractHotRodEntity;
import org.keycloak.models.map.storage.hotRod.common.HotRodAttributeEntity; import org.keycloak.models.map.storage.hotRod.common.HotRodAttributeEntity;
import org.keycloak.models.map.storage.hotRod.common.HotRodPair; import org.keycloak.models.map.storage.hotRod.common.HotRodPair;
import org.keycloak.models.map.client.MapClientEntity; import org.keycloak.models.map.client.MapClientEntity;
import org.keycloak.models.map.client.MapProtocolMapperEntity;
import org.keycloak.models.map.storage.hotRod.common.UpdatableHotRodEntityDelegateImpl; import org.keycloak.models.map.storage.hotRod.common.UpdatableHotRodEntityDelegateImpl;
import java.util.Collection; import java.util.Collection;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -190,6 +192,20 @@ public class HotRodClientEntity extends AbstractHotRodEntity {
.filter(me -> Objects.equals(me.getValue(), defaultScope)) .filter(me -> Objects.equals(me.getValue(), defaultScope))
.map(Map.Entry::getKey); .map(Map.Entry::getKey);
} }
@Override
public Optional<MapProtocolMapperEntity> getProtocolMapper(String id) {
Set<MapProtocolMapperEntity> mappers = getProtocolMappers();
if (mappers == null || mappers.isEmpty()) return Optional.empty();
return mappers.stream().filter(m -> Objects.equals(m.getId(), id)).findFirst();
}
@Override
public void removeProtocolMapper(String id) {
HotRodClientEntity entity = getHotRodEntity();
entity.updated |= entity.protocolMappers != null && entity.protocolMappers.removeIf(m -> Objects.equals(m.id, id));
}
} }
@Override @Override

View file

@ -211,23 +211,13 @@ public class JpaClientEntity extends AbstractClientEntity implements JpaRootVers
} }
@Override @Override
public MapProtocolMapperEntity getProtocolMapper(String id) { public Set<MapProtocolMapperEntity> getProtocolMappers() {
return metadata.getProtocolMapper(id);
}
@Override
public Map<String, MapProtocolMapperEntity> getProtocolMappers() {
return metadata.getProtocolMappers(); return metadata.getProtocolMappers();
} }
@Override @Override
public void removeProtocolMapper(String id) { public void addProtocolMapper(MapProtocolMapperEntity mapping) {
metadata.removeProtocolMapper(id); metadata.addProtocolMapper(mapping);
}
@Override
public void setProtocolMapper(String id, MapProtocolMapperEntity mapping) {
metadata.setProtocolMapper(id, mapping);
} }
@Override @Override

View file

@ -524,9 +524,8 @@ public abstract class MapClientAdapter extends AbstractClientModel<MapClientEnti
@Override @Override
public Stream<ProtocolMapperModel> getProtocolMappersStream() { public Stream<ProtocolMapperModel> getProtocolMappersStream() {
final Map<String, MapProtocolMapperEntity> protocolMappers = entity.getProtocolMappers(); final Set<MapProtocolMapperEntity> protocolMappers = entity.getProtocolMappers();
return protocolMappers == null ? Stream.empty() : protocolMappers.values().stream().distinct() return protocolMappers == null ? Stream.empty() : protocolMappers.stream().distinct().map(pmUtils::toModel);
.map(pmUtils::toModel);
} }
@Override @Override
@ -544,7 +543,7 @@ public abstract class MapClientAdapter extends AbstractClientModel<MapClientEnti
pm.setConfig(new HashMap<>()); pm.setConfig(new HashMap<>());
} }
entity.setProtocolMapper(pm.getId(), pm); entity.addProtocolMapper(pm);
return pmUtils.toModel(pm); return pmUtils.toModel(pm);
} }
@ -560,23 +559,25 @@ public abstract class MapClientAdapter extends AbstractClientModel<MapClientEnti
public void updateProtocolMapper(ProtocolMapperModel mapping) { public void updateProtocolMapper(ProtocolMapperModel mapping) {
final String id = mapping == null ? null : mapping.getId(); final String id = mapping == null ? null : mapping.getId();
if (id != null) { if (id != null) {
entity.setProtocolMapper(id, MapProtocolMapperUtils.fromModel(mapping)); entity.getProtocolMapper(id).ifPresent((pmEntity) -> {
entity.removeProtocolMapper(id);
addProtocolMapper(mapping);
});
} }
} }
@Override @Override
public ProtocolMapperModel getProtocolMapperById(String id) { public ProtocolMapperModel getProtocolMapperById(String id) {
MapProtocolMapperEntity protocolMapper = entity.getProtocolMapper(id); return entity.getProtocolMapper(id).map(pmUtils::toModel).orElse(null);
return protocolMapper == null ? null : pmUtils.toModel(protocolMapper);
} }
@Override @Override
public ProtocolMapperModel getProtocolMapperByName(String protocol, String name) { public ProtocolMapperModel getProtocolMapperByName(String protocol, String name) {
final Map<String, MapProtocolMapperEntity> protocolMappers = entity.getProtocolMappers(); final Set<MapProtocolMapperEntity> protocolMappers = entity.getProtocolMappers();
if (! Objects.equals(protocol, safeGetProtocol())) { if (! Objects.equals(protocol, safeGetProtocol())) {
return null; return null;
} }
return protocolMappers == null ? null : protocolMappers.values().stream() return protocolMappers == null ? null : protocolMappers.stream()
.filter(pm -> Objects.equals(pm.getName(), name)) .filter(pm -> Objects.equals(pm.getName(), name))
.map(pmUtils::toModel) .map(pmUtils::toModel)
.findAny() .findAny()

View file

@ -59,13 +59,13 @@ public interface MapClientEntity extends AbstractEntity, UpdatableEntity, Entity
@Override @Override
public boolean isUpdated() { public boolean isUpdated() {
return this.updated return this.updated
|| Optional.ofNullable(getProtocolMappers()).orElseGet(Collections::emptyMap).values().stream().anyMatch(MapProtocolMapperEntity::isUpdated); || Optional.ofNullable(getProtocolMappers()).orElseGet(Collections::emptySet).stream().anyMatch(MapProtocolMapperEntity::isUpdated);
} }
@Override @Override
public void clearUpdatedFlag() { public void clearUpdatedFlag() {
this.updated = false; this.updated = false;
Optional.ofNullable(getProtocolMappers()).orElseGet(Collections::emptyMap).values().forEach(UpdatableEntity::clearUpdatedFlag); Optional.ofNullable(getProtocolMappers()).orElseGet(Collections::emptySet).forEach(UpdatableEntity::clearUpdatedFlag);
} }
@Override @Override
@ -75,6 +75,20 @@ public interface MapClientEntity extends AbstractEntity, UpdatableEntity, Entity
.filter(me -> Objects.equals(me.getValue(), defaultScope)) .filter(me -> Objects.equals(me.getValue(), defaultScope))
.map(Entry::getKey); .map(Entry::getKey);
} }
@Override
public Optional<MapProtocolMapperEntity> getProtocolMapper(String id) {
Set<MapProtocolMapperEntity> mappers = getProtocolMappers();
if (mappers == null || mappers.isEmpty()) return Optional.empty();
return mappers.stream().filter(mapper -> Objects.equals(mapper.getId(), id)).findFirst();
}
@Override
public void removeProtocolMapper(String id) {
Set<MapProtocolMapperEntity> mappers = getProtocolMappers();
this.updated |= mappers != null && mappers.removeIf(mapper -> Objects.equals(mapper.getId(), id));
}
} }
Map<String, Boolean> getClientScopes(); Map<String, Boolean> getClientScopes();
@ -82,9 +96,9 @@ public interface MapClientEntity extends AbstractEntity, UpdatableEntity, Entity
void setClientScope(String id, Boolean defaultScope); void setClientScope(String id, Boolean defaultScope);
void removeClientScope(String id); void removeClientScope(String id);
MapProtocolMapperEntity getProtocolMapper(String id); Optional<MapProtocolMapperEntity> getProtocolMapper(String id);
Map<String, MapProtocolMapperEntity> getProtocolMappers(); Set<MapProtocolMapperEntity> getProtocolMappers();
void setProtocolMapper(String id, MapProtocolMapperEntity mapping); void addProtocolMapper(MapProtocolMapperEntity mapping);
void removeProtocolMapper(String id); void removeProtocolMapper(String id);
void addRedirectUri(String redirectUri); void addRedirectUri(String redirectUri);

View file

@ -92,7 +92,7 @@ public class MapClientEntityClonerTest {
config.put("key1", "value1"); config.put("key1", "value1");
config.put("key2", "value2"); config.put("key2", "value2");
pmm.setConfig(config); pmm.setConfig(config);
newInstance.setProtocolMapper("pmm-id", pmm); newInstance.addProtocolMapper(pmm);
newInstance.setAttribute("attr", Arrays.asList("aa", "bb", "cc")); newInstance.setAttribute("attr", Arrays.asList("aa", "bb", "cc"));
MapClientEntity clonedInstance = CLONER.newInstance(MapClientEntity.class); MapClientEntity clonedInstance = CLONER.newInstance(MapClientEntity.class);
@ -108,10 +108,12 @@ public class MapClientEntityClonerTest {
assertThat(clonedInstance.getAttributes().get("attr"), not(sameInstance(newInstance.getAttributes().get("attr")))); assertThat(clonedInstance.getAttributes().get("attr"), not(sameInstance(newInstance.getAttributes().get("attr"))));
assertThat(clonedInstance.getProtocolMappers(), not(sameInstance(newInstance.getProtocolMappers()))); assertThat(clonedInstance.getProtocolMappers(), not(sameInstance(newInstance.getProtocolMappers())));
assertThat(clonedInstance.getProtocolMapper("pmm-id"), not(sameInstance(newInstance.getProtocolMapper("pmm-id")))); assertThat(clonedInstance.getProtocolMapper("pmm-id").isPresent(), is(true));
assertThat(clonedInstance.getProtocolMapper("pmm-id"), equalTo(newInstance.getProtocolMapper("pmm-id"))); assertThat(newInstance.getProtocolMapper("pmm-id").isPresent(), is(true));
assertThat(clonedInstance.getProtocolMapper("pmm-id").getConfig(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").getConfig()))); assertThat(clonedInstance.getProtocolMapper("pmm-id").get(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").get())));
assertThat(clonedInstance.getProtocolMapper("pmm-id").getConfig(), equalTo(newInstance.getProtocolMapper("pmm-id").getConfig())); assertThat(clonedInstance.getProtocolMapper("pmm-id").get(), equalTo(newInstance.getProtocolMapper("pmm-id").get()));
assertThat(clonedInstance.getProtocolMapper("pmm-id").get().getConfig(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").get().getConfig())));
assertThat(clonedInstance.getProtocolMapper("pmm-id").get().getConfig(), equalTo(newInstance.getProtocolMapper("pmm-id").get().getConfig()));
assertThat(clonedInstance.getAuthenticationFlowBindingOverrides(), nullValue()); assertThat(clonedInstance.getAuthenticationFlowBindingOverrides(), nullValue());
assertThat(clonedInstance.getRegistrationToken(), nullValue()); assertThat(clonedInstance.getRegistrationToken(), nullValue());
@ -130,7 +132,7 @@ public class MapClientEntityClonerTest {
config.put("key2", "value2"); config.put("key2", "value2");
pmm.setConfig(config); pmm.setConfig(config);
newInstance.setProtocolMapper("pmm-id", pmm); newInstance.addProtocolMapper(pmm);
newInstance.setAttribute("attr", Arrays.asList("aa", "bb", "cc")); newInstance.setAttribute("attr", Arrays.asList("aa", "bb", "cc"));
MapClientEntity clonedInstance = CLONER.newInstance(MapClientEntity.class); MapClientEntity clonedInstance = CLONER.newInstance(MapClientEntity.class);
@ -146,10 +148,12 @@ public class MapClientEntityClonerTest {
assertThat(clonedInstance.getAttributes().get("attr"), not(sameInstance(newInstance.getAttributes().get("attr")))); assertThat(clonedInstance.getAttributes().get("attr"), not(sameInstance(newInstance.getAttributes().get("attr"))));
assertThat(clonedInstance.getProtocolMappers(), not(sameInstance(newInstance.getProtocolMappers()))); assertThat(clonedInstance.getProtocolMappers(), not(sameInstance(newInstance.getProtocolMappers())));
assertThat(clonedInstance.getProtocolMapper("pmm-id"), not(sameInstance(newInstance.getProtocolMapper("pmm-id")))); assertThat(clonedInstance.getProtocolMapper("pmm-id").isPresent(), is(true));
assertThat(clonedInstance.getProtocolMapper("pmm-id"), equalTo(newInstance.getProtocolMapper("pmm-id"))); assertThat(newInstance.getProtocolMapper("pmm-id").isPresent(), is(true));
assertThat(clonedInstance.getProtocolMapper("pmm-id").getConfig(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").getConfig()))); assertThat(clonedInstance.getProtocolMapper("pmm-id").get(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").get())));
assertThat(clonedInstance.getProtocolMapper("pmm-id").getConfig(), equalTo(newInstance.getProtocolMapper("pmm-id").getConfig())); assertThat(clonedInstance.getProtocolMapper("pmm-id").get(), equalTo(newInstance.getProtocolMapper("pmm-id").get()));
assertThat(clonedInstance.getProtocolMapper("pmm-id").get().getConfig(), not(sameInstance(newInstance.getProtocolMapper("pmm-id").get().getConfig())));
assertThat(clonedInstance.getProtocolMapper("pmm-id").get().getConfig(), equalTo(newInstance.getProtocolMapper("pmm-id").get().getConfig()));
assertThat(clonedInstance.getAuthenticationFlowBindingOverrides(), nullValue()); assertThat(clonedInstance.getAuthenticationFlowBindingOverrides(), nullValue());
assertThat(clonedInstance.getRegistrationToken(), nullValue()); assertThat(clonedInstance.getRegistrationToken(), nullValue());

View file

@ -164,7 +164,7 @@ public class OIDCClientRegistrationProvider extends AbstractClientRegistrationPr
} else { } else {
return false; return false;
} }
}).forEach((ProtocolMapperModel mapping) -> { }).collect(Collectors.toList()).forEach((ProtocolMapperModel mapping) -> {
PairwiseSubMapperHelper.setSectorIdentifierUri(mapping, sectorIdentifierUri); PairwiseSubMapperHelper.setSectorIdentifierUri(mapping, sectorIdentifierUri);
clientModel.updateProtocolMapper(mapping); clientModel.updateProtocolMapper(mapping);
}); });