Filter first, then sort, and avoid atomics

Closes #20394
This commit is contained in:
Alexander Schwartz 2023-05-09 13:27:49 +02:00 committed by Marek Posolda
parent b63fccb062
commit cd9e0be9f0
6 changed files with 96 additions and 46 deletions

View file

@ -560,14 +560,6 @@ public final class KeycloakModelUtils {
}); });
} }
public static String resolveFirstAttribute(GroupModel group, String name) {
String value = group.getFirstAttribute(name);
if (value != null) return value;
if (group.getParentId() == null) return null;
return resolveFirstAttribute(group.getParent(), name);
}
public static Collection<String> resolveAttribute(GroupModel group, String name, boolean aggregateAttrs) { public static Collection<String> resolveAttribute(GroupModel group, String name, boolean aggregateAttrs) {
Set<String> values = group.getAttributeStream(name).collect(Collectors.toSet()); Set<String> values = group.getAttributeStream(name).collect(Collectors.toSet());
if ((values.isEmpty() || aggregateAttrs) && group.getParentId() != null) { if ((values.isEmpty() || aggregateAttrs) && group.getParentId() != null) {
@ -587,7 +579,6 @@ public final class KeycloakModelUtils {
} }
Stream<Collection<String>> attributes = user.getGroupsStream() Stream<Collection<String>> attributes = user.getGroupsStream()
.map(group -> resolveAttribute(group, name, aggregateAttrs)) .map(group -> resolveAttribute(group, name, aggregateAttrs))
.filter(Objects::nonNull)
.filter(attr -> !attr.isEmpty()); .filter(attr -> !attr.isEmpty());
if (!aggregateAttrs) { if (!aggregateAttrs) {

View file

@ -30,6 +30,7 @@ import java.util.AbstractMap;
import java.util.Comparator; import java.util.Comparator;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Objects; import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream; import java.util.stream.Stream;
/** /**
@ -134,6 +135,21 @@ public class ProtocolMapperUtils {
.sorted(Comparator.comparing(ProtocolMapperUtils::compare)); .sorted(Comparator.comparing(ProtocolMapperUtils::compare));
} }
public static Stream<Entry<ProtocolMapperModel, ProtocolMapper>> getSortedProtocolMappers(KeycloakSession session, ClientSessionContext ctx, Predicate<Entry<ProtocolMapperModel, ProtocolMapper>> filter) {
KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
return ctx.getProtocolMappersStream()
.<Entry<ProtocolMapperModel, ProtocolMapper>>map(mapperModel -> {
ProtocolMapper mapper = (ProtocolMapper) sessionFactory.getProviderFactory(ProtocolMapper.class, mapperModel.getProtocolMapper());
if (mapper == null) {
return null;
}
return new AbstractMap.SimpleEntry<>(mapperModel, mapper);
})
.filter(Objects::nonNull)
.filter(filter)
.sorted(Comparator.comparing(ProtocolMapperUtils::compare));
}
public static int compare(Entry<ProtocolMapperModel, ProtocolMapper> entry) { public static int compare(Entry<ProtocolMapperModel, ProtocolMapper> entry) {
int priority = entry.getValue().getPriority(); int priority = entry.getValue().getPriority();
return priority; return priority;

View file

@ -110,9 +110,8 @@ public class DockerAuthV2Protocol implements LoginProtocol {
// Next, allow mappers to decorate the token to add/remove scopes as appropriate // Next, allow mappers to decorate the token to add/remove scopes as appropriate
AtomicReference<DockerResponseToken> finalResponseToken = new AtomicReference<>(responseToken); AtomicReference<DockerResponseToken> finalResponseToken = new AtomicReference<>(responseToken);
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx) ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper ->
.filter(mapper -> mapper.getValue() instanceof DockerAuthV2AttributeMapper) mapper.getValue() instanceof DockerAuthV2AttributeMapper && ((DockerAuthV2AttributeMapper) mapper.getValue()).appliesTo(finalResponseToken.get()))
.filter(mapper -> ((DockerAuthV2AttributeMapper) mapper.getValue()).appliesTo(finalResponseToken.get()))
.forEach(mapper -> finalResponseToken.set(((DockerAuthV2AttributeMapper) mapper.getValue()) .forEach(mapper -> finalResponseToken.set(((DockerAuthV2AttributeMapper) mapper.getValue())
.transformDockerResponseToken(finalResponseToken.get(), mapper.getKey(), session, userSession, clientSession))); .transformDockerResponseToken(finalResponseToken.get(), mapper.getKey(), session, userSession, clientSession)));
responseToken = finalResponseToken.get(); responseToken = finalResponseToken.get();

View file

@ -17,6 +17,7 @@
package org.keycloak.protocol.oidc; package org.keycloak.protocol.oidc;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.http.HttpRequest; import org.keycloak.http.HttpRequest;
@ -47,6 +48,7 @@ import org.keycloak.models.ClientSessionContext;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.ImpersonationSessionNote; import org.keycloak.models.ImpersonationSessionNote;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ProtocolMapperModel;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.SingleUseObjectProvider; import org.keycloak.models.SingleUseObjectProvider;
@ -56,6 +58,7 @@ import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider; import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.models.utils.RoleUtils; import org.keycloak.models.utils.RoleUtils;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.ProtocolMapperUtils; import org.keycloak.protocol.ProtocolMapperUtils;
import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenMapper; import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenMapper;
import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenResponseMapper; import org.keycloak.protocol.oidc.mappers.OIDCAccessTokenResponseMapper;
@ -90,8 +93,12 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -740,36 +747,36 @@ public class TokenManager {
public AccessToken transformAccessToken(KeycloakSession session, AccessToken token, public AccessToken transformAccessToken(KeycloakSession session, AccessToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) { UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCAccessTokenMapper)
AtomicReference<AccessToken> finalToken = new AtomicReference<>(token); .collect(new TokenCollector<AccessToken>(token) {
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx) @Override
.filter(mapper -> mapper.getValue() instanceof OIDCAccessTokenMapper) protected AccessToken applyMapper(AccessToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
.forEach(mapper -> finalToken.set(((OIDCAccessTokenMapper) mapper.getValue()) return ((OIDCAccessTokenMapper) mapper.getValue()).transformAccessToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
.transformAccessToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx))); }
return finalToken.get(); });
} }
public AccessTokenResponse transformAccessTokenResponse(KeycloakSession session, AccessTokenResponse accessTokenResponse, public AccessTokenResponse transformAccessTokenResponse(KeycloakSession session, AccessTokenResponse accessTokenResponse,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) { UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AtomicReference<AccessTokenResponse> finalResponseToken = new AtomicReference<>(accessTokenResponse); return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCAccessTokenResponseMapper)
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx) .collect(new TokenCollector<AccessTokenResponse>(accessTokenResponse) {
.filter(mapper -> mapper.getValue() instanceof OIDCAccessTokenResponseMapper) @Override
.forEach(mapper -> finalResponseToken.set(((OIDCAccessTokenResponseMapper) mapper.getValue()) protected AccessTokenResponse applyMapper(AccessTokenResponse token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
.transformAccessTokenResponse(finalResponseToken.get(), mapper.getKey(), session, userSession, clientSessionCtx))); return ((OIDCAccessTokenResponseMapper) mapper.getValue()).transformAccessTokenResponse(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
return finalResponseToken.get(); });
} }
public AccessToken transformUserInfoAccessToken(KeycloakSession session, AccessToken token, public AccessToken transformUserInfoAccessToken(KeycloakSession session, AccessToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) { UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof UserInfoTokenMapper)
AtomicReference<AccessToken> finalToken = new AtomicReference<>(token); .collect(new TokenCollector<AccessToken>(token) {
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx) @Override
.filter(mapper -> mapper.getValue() instanceof UserInfoTokenMapper) protected AccessToken applyMapper(AccessToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
.forEach(mapper -> finalToken.set(((UserInfoTokenMapper) mapper.getValue()) return ((UserInfoTokenMapper) mapper.getValue()).transformUserInfoToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
.transformUserInfoToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx))); }
return finalToken.get(); });
} }
public Map<String, Object> generateUserInfoClaims(AccessToken userInfo, UserModel userModel) { public Map<String, Object> generateUserInfoClaims(AccessToken userInfo, UserModel userModel) {
@ -863,14 +870,51 @@ public class TokenManager {
return claims; return claims;
} }
public void transformIDToken(KeycloakSession session, IDToken token, private abstract static class TokenCollector<T> implements Collector<Map.Entry<ProtocolMapperModel, ProtocolMapper>, TokenCollector<T>, T> {
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
AtomicReference<IDToken> finalToken = new AtomicReference<>(token); private T token;
ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx)
.filter(mapper -> mapper.getValue() instanceof OIDCIDTokenMapper) public TokenCollector(T token) {
.forEach(mapper -> finalToken.set(((OIDCIDTokenMapper) mapper.getValue()) this.token = token;
.transformIDToken(finalToken.get(), mapper.getKey(), session, userSession, clientSessionCtx))); }
@Override
public Supplier<TokenCollector<T>> supplier() {
return () -> this;
}
@Override
public Function<TokenCollector<T>, T> finisher() {
return idTokenWrapper -> idTokenWrapper.token;
}
@Override
public Set<Collector.Characteristics> characteristics() {
return Collections.emptySet();
}
@Override
public BinaryOperator<TokenCollector<T>> combiner() {
return (tMutableWrapper, tMutableWrapper2) -> { throw new IllegalStateException("can't combine"); };
}
@Override
public BiConsumer<TokenCollector<T>, Map.Entry<ProtocolMapperModel, ProtocolMapper>> accumulator() {
return (idToken, mapper) -> idToken.token = applyMapper(idToken.token, mapper);
}
protected abstract T applyMapper(T token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper);
}
public IDToken transformIDToken(KeycloakSession session, IDToken token,
UserSessionModel userSession, ClientSessionContext clientSessionCtx) {
return ProtocolMapperUtils.getSortedProtocolMappers(session, clientSessionCtx, mapper -> mapper.getValue() instanceof OIDCIDTokenMapper)
.collect(new TokenCollector<IDToken>(token) {
protected IDToken applyMapper(IDToken token, Map.Entry<ProtocolMapperModel, ProtocolMapper> mapper) {
return ((OIDCIDTokenMapper) mapper.getValue()).transformIDToken(token, mapper.getKey(), session, userSession, clientSessionCtx);
}
});
} }
protected AccessToken initToken(RealmModel realm, ClientModel client, UserModel user, UserSessionModel session, protected AccessToken initToken(RealmModel realm, ClientModel client, UserModel user, UserSessionModel session,
@ -1175,7 +1219,7 @@ public class TokenManager {
} }
if (isIdTokenAsDetachedSignature == false) { if (isIdTokenAsDetachedSignature == false) {
transformIDToken(session, idToken, userSession, clientSessionCtx); idToken = transformIDToken(session, idToken, userSession, clientSessionCtx);
} }
return this; return this;
} }
@ -1257,7 +1301,7 @@ public class TokenManager {
if (userNotBefore > notBefore) notBefore = userNotBefore; if (userNotBefore > notBefore) notBefore = userNotBefore;
res.setNotBeforePolicy(notBefore); res.setNotBeforePolicy(notBefore);
transformAccessTokenResponse(session, res, userSession, clientSessionCtx); res = transformAccessTokenResponse(session, res, userSession, clientSessionCtx);
// OIDC Financial API Read Only Profile : scope MUST be returned in the response from Token Endpoint // OIDC Financial API Read Only Profile : scope MUST be returned in the response from Token Endpoint
String responseScope = clientSessionCtx.getScopeString(); String responseScope = clientSessionCtx.getScopeString();

View file

@ -255,7 +255,7 @@ public class UserInfoEndpoint {
AccessToken userInfo = new AccessToken(); AccessToken userInfo = new AccessToken();
tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx); userInfo = tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
Map<String, Object> claims = tokenManager.generateUserInfoClaims(userInfo, userModel); Map<String, Object> claims = tokenManager.generateUserInfoClaims(userInfo, userModel);
Response.ResponseBuilder responseBuilder; Response.ResponseBuilder responseBuilder;

View file

@ -168,7 +168,7 @@ public class ClientScopeEvaluateResource {
AccessToken userInfo = new AccessToken(); AccessToken userInfo = new AccessToken();
TokenManager tokenManager = new TokenManager(); TokenManager tokenManager = new TokenManager();
tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx); userInfo = tokenManager.transformUserInfoAccessToken(session, userInfo, userSession, clientSessionCtx);
return tokenManager.generateUserInfoClaims(userInfo, user); return tokenManager.generateUserInfoClaims(userInfo, user);
}); });
} }