Filter first, then sort, and avoid atomics

Closes #20394
This commit is contained in:
Alexander Schwartz 2023-05-09 13:27:49 +02:00 committed by Marek Posolda
parent b63fccb062
commit cd9e0be9f0
6 changed files with 96 additions and 46 deletions

View file

@ -560,14 +560,6 @@ public final class KeycloakModelUtils {
});
}
public static String resolveFirstAttribute(GroupModel group, String name) {
String value = group.getFirstAttribute(name);
if (value != null) return value;
if (group.getParentId() == null) return null;
return resolveFirstAttribute(group.getParent(), name);
}
public static Collection<String> resolveAttribute(GroupModel group, String name, boolean aggregateAttrs) {
Set<String> values = group.getAttributeStream(name).collect(Collectors.toSet());
if ((values.isEmpty() || aggregateAttrs) && group.getParentId() != null) {
@ -587,7 +579,6 @@ public final class KeycloakModelUtils {
}
Stream<Collection<String>> attributes = user.getGroupsStream()
.map(group -> resolveAttribute(group, name, aggregateAttrs))
.filter(Objects::nonNull)
.filter(attr -> !attr.isEmpty());
if (!aggregateAttrs) {

View file

@ -30,6 +30,7 @@ import java.util.AbstractMap;
import java.util.Comparator;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;
/**
@ -134,6 +135,21 @@ public class ProtocolMapperUtils {
.sorted(Comparator.comparing(ProtocolMapperUtils::compare));
}
public static Stream<Entry<ProtocolMapperModel, ProtocolMapper>> getSortedProtocolMappers(KeycloakSession session, ClientSessionContext ctx, Predicate<Entry<ProtocolMapperModel, ProtocolMapper>> filter) {
KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
return ctx.getProtocolMappersStream()
.<Entry<ProtocolMapperModel, ProtocolMapper>>map(mapperModel -> {
ProtocolMapper mapper = (ProtocolMapper) sessionFactory.getProviderFactory(ProtocolMapper.class, mapperModel.getProtocolMapper());
if (mapper == null) {
return null;
}
return new AbstractMap.SimpleEntry<>(mapperModel, mapper);
})
.filter(Objects::nonNull)
.filter(filter)
.sorted(Comparator.comparing(ProtocolMapperUtils::compare));
}
public static int compare(Entry<ProtocolMapperModel, ProtocolMapper> entry) {
int priority = entry.getValue().getPriority();
return priority;

View file

@ -110,9 +110,8 @@ public class DockerAuthV2Protocol implements LoginProtocol {
// Next, allow mappers to decorate the token to add/remove scopes as appropriate
AtomicReference<DockerResponseToken> finalResponseToken = new AtomicReference<>(responseToken);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof DockerAuthV2AttributeMapper)
.filter(mapper -> ((DockerAuthV2AttributeMapper) mapper.getValue()).appliesTo(finalResponseToken.get()))
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper ->
mapper.getValue() instanceof DockerAuthV2AttributeMapper && ((DockerAuthV2AttributeMapper) mapper.getValue()).appliesTo(finalResponseToken.get()))
.forEach(mapper -> finalResponseToken.set(((DockerAuthV2AttributeMapper) mapper.getValue())
.transformDockerResponseToken(finalResponseToken.get(), mapper.getKey(), session, userSession, clientSession)));
responseToken = finalResponseToken.get();

View file

@ -17,6 +17,7 @@
package org.keycloak.protocol.oidc;
import java.util.Collections;
import java.util.HashMap;
import org.jboss.logging.Logger;
import org.keycloak.http.HttpRequest;
@ -47,6 +48,7 @@ import org.keycloak.models.ClientSessionContext;
import org.keycloak.models.Constants;
import org.keycloak.models.ImpersonationSessionNote;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ProtocolMapperModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.SingleUseObjectProvider;
@ -56,6 +58,7 @@ import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.models.utils.RoleUtils;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.ProtocolMapperUtils;
import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenMapper;
import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenResponseMapper;
@ -90,8 +93,12 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -740,36 +747,36 @@ public class TokenManager {
public AccessToken transformAccessToken(KeycloakSession session, AccessToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AtomicReference<AccessToken> finalToken = new AtomicReference<>(token);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof OIDCAccessTokenMapper)
.forEach(mapper -> finalToken.set(((OIDCAccessTokenMapper) mapper.getValue())
.transformAccessToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx)));
return finalToken.get();
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCAccessTokenMapper)
.collect(new TokenCollector<AccessToken>(token) {
@Override
protected AccessToken applyMapper(AccessToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
return ((OIDCAccessTokenMapper) mapper.getValue()).transformAccessToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
});
}
public AccessTokenResponse transformAccessTokenResponse(KeycloakSession session, AccessTokenResponse accessTokenResponse,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AtomicReference<AccessTokenResponse> finalResponseToken = new AtomicReference<>(accessTokenResponse);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof OIDCAccessTokenResponseMapper)
.forEach(mapper -> finalResponseToken.set(((OIDCAccessTokenResponseMapper) mapper.getValue())
.transformAccessTokenResponse(finalResponseToken.get(), mapper.getKey(), session, userSession, clientSessionCtx)));
return finalResponseToken.get();
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCAccessTokenResponseMapper)
.collect(new TokenCollector<AccessTokenResponse>(accessTokenResponse) {
@Override
protected AccessTokenResponse applyMapper(AccessTokenResponse token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
return ((OIDCAccessTokenResponseMapper) mapper.getValue()).transformAccessTokenResponse(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
});
}
public AccessToken transformUserInfoAccessToken(KeycloakSession session, AccessToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AtomicReference<AccessToken> finalToken = new AtomicReference<>(token);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof UserInfoTokenMapper)
.forEach(mapper -> finalToken.set(((UserInfoTokenMapper) mapper.getValue())
.transformUserInfoToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx)));
return finalToken.get();
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof UserInfoTokenMapper)
.collect(new TokenCollector<AccessToken>(token) {
@Override
protected AccessToken applyMapper(AccessToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
return ((UserInfoTokenMapper) mapper.getValue()).transformUserInfoToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
});
}
public Map<String, Object> generateUserInfoClaims(AccessToken userInfo, UserModel userModel) {
@ -863,14 +870,51 @@ public class TokenManager {
return claims;
}
public void transformIDToken(KeycloakSession session, IDToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
private abstract static class TokenCollector<T> implements Collector<Map.Entry<ProtocolMapperModel, ProtocolMapper>, TokenCollector<T>, T> {
AtomicReference<IDToken> finalToken = new AtomicReference<>(token);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof OIDCIDTokenMapper)
.forEach(mapper -> finalToken.set(((OIDCIDTokenMapper) mapper.getValue())
.transformIDToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx)));
private T token;
public TokenCollector(T token) {
this.token = token;
}
@Override
public Supplier<TokenCollector<T>> supplier() {
return () -> this;
}
@Override
public Function<TokenCollector<T>, T> finisher() {
return idTokenWrapper -> idTokenWrapper.token;
}
@Override
public Set<Collector.Characteristics> characteristics() {
return Collections.emptySet();
}
@Override
public BinaryOperator<TokenCollector<T>> combiner() {
return (tMutableWrapper, tMutableWrapper2) -> { throw new IllegalStateException("can't combine"); };
}
@Override
public BiConsumer<TokenCollector<T>, Map.Entry<ProtocolMapperModel, ProtocolMapper>> accumulator() {
return (idToken, mapper) -> idToken.token = applyMapper(idToken.token, mapper);
}
protected abstract T applyMapper(T token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper);
}
public IDToken transformIDToken(KeycloakSession session, IDToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCIDTokenMapper)
.collect(new TokenCollector<IDToken>(token) {
protected IDToken applyMapper(IDToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
return ((OIDCIDTokenMapper) mapper.getValue()).transformIDToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
});
}
protected AccessToken initToken(RealmModel realm, ClientModel client, UserModel user, UserSessionModel session,
@ -1175,7 +1219,7 @@ public class TokenManager {
}
if (isIdTokenAsDetachedSignature == false) {
transformIDToken(session, idToken, userSession, clientSessionCtx);
idToken = transformIDToken(session, idToken, userSession, clientSessionCtx);
}
return this;
}
@ -1257,7 +1301,7 @@ public class TokenManager {
if (userNotBefore > notBefore) notBefore = userNotBefore;
res.setNotBeforePolicy(notBefore);
transformAccessTokenResponse(session, res, userSession, clientSessionCtx);
res = transformAccessTokenResponse(session, res, userSession, clientSessionCtx);
// OIDC Financial API Read Only Profile : scope MUST be returned in the response from Token Endpoint
String responseScope = clientSessionCtx.getScopeString();

View file

@ -255,7 +255,7 @@ public class UserInfoEndpoint {
AccessToken userInfo = new AccessToken();
tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
userInfo = tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
Map<String, Object> claims = tokenManager.generateUserInfoClaims(userInfo, userModel);
Response.ResponseBuilder responseBuilder;

View file

@ -168,7 +168,7 @@ public class ClientScopeEvaluateResource {
AccessToken userInfo = new AccessToken();
TokenManager tokenManager = new TokenManager();
tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
userInfo = tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
return tokenManager.generateUserInfoClaims(userInfo, user);
});
}