[KEYCLOAK-16508] Complement methods for accessing user sessions with Stream variants

This commit is contained in:
Stefan Guilhen 2020-12-01 15:11:31 -03:00 committed by Hynek Mlnařík
parent edabbc9449
commit d6422e415c
20 changed files with 406 additions and 378 deletions

View file

@ -67,8 +67,6 @@ import java.io.Serializable;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -80,6 +78,7 @@ import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import java.util.stream.StreamSupport;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -270,25 +269,16 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
protected List<UserSessionModel> getUserSessions(RealmModel realm, Predicate<Map.Entry<String, SessionEntityWrapper<UserSessionEntity>>> predicate, boolean offline) { protected Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, Predicate<Map.Entry<String, SessionEntityWrapper<UserSessionEntity>>> predicate, boolean offline) {
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline); Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);
cache = CacheDecorators.skipCacheLoaders(cache); cache = CacheDecorators.skipCacheLoaders(cache);
Stream<Map.Entry<String, SessionEntityWrapper<UserSessionEntity>>> cacheStream = cache.entrySet().stream(); // return a stream that 'wraps' the infinispan cache stream so that the cache stream's elements are read one by one
// and then filtered/mapped locally to avoid serialization issues when trying to manipulate the cache stream directly.
List<UserSessionModel> resultSessions = new LinkedList<>(); return StreamSupport.stream(cache.entrySet().stream().spliterator(), true)
.filter(predicate)
Iterator<UserSessionEntity> itr = cacheStream.filter(predicate)
.map(Mappers.userSessionEntity()) .map(Mappers.userSessionEntity())
.iterator(); .map(entity -> this.wrap(realm, entity, offline));
while (itr.hasNext()) {
UserSessionEntity userSessionEntity = itr.next();
resultSessions.add(wrap(realm, userSessionEntity, offline));
}
return resultSessions;
} }
@Override @Override
@ -305,46 +295,45 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
@Override @Override
public List<UserSessionModel> getUserSessions(final RealmModel realm, UserModel user) { public Stream<UserSessionModel> getUserSessionsStream(final RealmModel realm, UserModel user) {
return getUserSessions(realm, UserSessionPredicate.create(realm.getId()).user(user.getId()), false); return getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).user(user.getId()), false);
} }
@Override @Override
public List<UserSessionModel> getUserSessionByBrokerUserId(RealmModel realm, String brokerUserId) { public Stream<UserSessionModel> getUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId) {
return getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerUserId(brokerUserId), false); return getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).brokerUserId(brokerUserId), false);
} }
@Override @Override
public UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) { public UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) {
List<UserSessionModel> userSessions = getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerSessionId(brokerSessionId), false); return this.getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).brokerSessionId(brokerSessionId), false)
return userSessions.isEmpty() ? null : userSessions.get(0); .findFirst().orElse(null);
} }
@Override @Override
public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) { public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client) {
return getUserSessions(realm, client, -1, -1); return getUserSessionsStream(realm, client, -1, -1);
} }
@Override @Override
public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) { public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client, int firstResult, int maxResults) {
return getUserSessions(realm, client, firstResult, maxResults, false); return getUserSessionsStream(realm, client, firstResult, maxResults, false);
} }
protected List<UserSessionModel> getUserSessions(final RealmModel realm, ClientModel client, int firstResult, int maxResults, final boolean offline) { protected Stream<UserSessionModel> getUserSessionsStream(final RealmModel realm, ClientModel client, int firstResult, int maxResults, final boolean offline) {
final String clientUuid = client.getId(); final String clientUuid = client.getId();
UserSessionPredicate predicate = UserSessionPredicate.create(realm.getId()).client(clientUuid); UserSessionPredicate predicate = UserSessionPredicate.create(realm.getId()).client(clientUuid);
return getUserSessionModels(realm, firstResult, maxResults, offline, predicate); return getUserSessionModels(realm, firstResult, maxResults, offline, predicate);
} }
protected List<UserSessionModel> getUserSessionModels(RealmModel realm, int firstResult, int maxResults, boolean offline, UserSessionPredicate predicate) { protected Stream<UserSessionModel> getUserSessionModels(RealmModel realm, int firstResult, int maxResults, boolean offline, UserSessionPredicate predicate) {
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline); Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);
cache = CacheDecorators.skipCacheLoaders(cache); cache = CacheDecorators.skipCacheLoaders(cache);
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessionCache = getClientSessionCache(offline); // return a stream that 'wraps' the infinispan cache stream so that the cache stream's elements are read one by one
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessionCacheDecorated = CacheDecorators.skipCacheLoaders(clientSessionCache); // and then filtered/mapped locally to avoid serialization issues when trying to manipulate the cache stream directly.
Stream<UserSessionEntity> stream = StreamSupport.stream(cache.entrySet().stream().spliterator(), true)
Stream<UserSessionEntity> stream = cache.entrySet().stream()
.filter(predicate) .filter(predicate)
.map(Mappers.userSessionEntity()) .map(Mappers.userSessionEntity())
.sorted(Comparators.userSessionLastSessionRefresh()); .sorted(Comparators.userSessionLastSessionRefresh());
@ -357,16 +346,7 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
stream = stream.limit(maxResults); stream = stream.limit(maxResults);
} }
final List<UserSessionModel> sessions = new LinkedList<>(); return stream.map(entity -> this.wrap(realm, entity, offline));
Iterator<UserSessionEntity> itr = stream.iterator();
while (itr.hasNext()) {
UserSessionEntity userSessionEntity = itr.next();
sessions.add(wrap(realm, userSessionEntity, offline));
}
return sessions;
} }
@Override @Override
@ -839,13 +819,13 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
@Override @Override
public UserSessionModel getOfflineUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) { public UserSessionModel getOfflineUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) {
List<UserSessionModel> userSessions = getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerSessionId(brokerSessionId), true); return this.getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).brokerSessionId(brokerSessionId), true)
return userSessions.isEmpty() ? null : userSessions.get(0); .findFirst().orElse(null);
} }
@Override @Override
public List<UserSessionModel> getOfflineUserSessionByBrokerUserId(RealmModel realm, String brokerUserId) { public Stream<UserSessionModel> getOfflineUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId) {
return getUserSessions(realm, UserSessionPredicate.create(realm.getId()).brokerUserId(brokerUserId), true); return getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).brokerUserId(brokerUserId), true);
} }
@Override @Override
@ -856,8 +836,6 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
} }
@Override @Override
public AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession) { public AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession) {
UserSessionAdapter userSessionAdapter = (offlineUserSession instanceof UserSessionAdapter) ? (UserSessionAdapter) offlineUserSession : UserSessionAdapter userSessionAdapter = (offlineUserSession instanceof UserSessionAdapter) ? (UserSessionAdapter) offlineUserSession :
@ -874,23 +852,8 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
@Override @Override
public List<UserSessionModel> getOfflineUserSessions(RealmModel realm, UserModel user) { public Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, UserModel user) {
List<UserSessionModel> userSessions = new LinkedList<>(); return this.getUserSessionsStream(realm, UserSessionPredicate.create(realm.getId()).user(user.getId()), true);
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = CacheDecorators.skipCacheLoaders(offlineSessionCache);
Iterator<UserSessionEntity> itr = cache.entrySet().stream()
.filter(UserSessionPredicate.create(realm.getId()).user(user.getId()))
.map(Mappers.userSessionEntity())
.iterator();
while (itr.hasNext()) {
UserSessionEntity userSessionEntity = itr.next();
UserSessionModel userSession = wrap(realm, userSessionEntity, true);
userSessions.add(userSession);
}
return userSessions;
} }
@Override @Override
@ -899,8 +862,8 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
@Override @Override
public List<UserSessionModel> getOfflineUserSessions(RealmModel realm, ClientModel client, int first, int max) { public Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, ClientModel client, int first, int max) {
return getUserSessions(realm, client, first, max, true); return getUserSessionsStream(realm, client, first, max, true);
} }

View file

@ -24,6 +24,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@ -40,10 +42,79 @@ public interface UserSessionProvider extends Provider {
String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState); String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState);
UserSessionModel getUserSession(RealmModel realm, String id); UserSessionModel getUserSession(RealmModel realm, String id);
List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user);
List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client); /**
List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults); * @deprecated Use {@link #getUserSessionsStream(RealmModel, ClientModel) getUserSessionsStream} instead.
List<UserSessionModel> getUserSessionByBrokerUserId(RealmModel realm, String brokerUserId); */
@Deprecated
default List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) {
return this.getUserSessionsStream(realm, user).collect(Collectors.toList());
}
/**
* Obtains the user sessions associated with the specified user.
*
* @param realm a reference to the realm.
* @param user the user whose sessions are being searched.
* @return a non-null {@link Stream} of user sessions.
*/
Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, UserModel user);
/**
* @deprecated Use {@link #getUserSessionsStream(RealmModel, ClientModel) getUserSessionsStream} instead.
*/
@Deprecated
default List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) {
return this.getUserSessionsStream(realm, client).collect(Collectors.toList());
}
/**
* Obtains the user sessions associated with the specified client.
*
* @param realm a reference to the realm.
* @param client the client whose user sessions are being searched.
* @return a non-null {@link Stream} of user sessions.
*/
Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client);
/**
* @deprecated Use {@link #getUserSessionsStream(RealmModel, ClientModel, int, int) getUserSessionsStream} instead.
*/
@Deprecated
default List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) {
return this.getUserSessionsStream(realm, client, firstResult, maxResults).collect(Collectors.toList());
}
/**
* Obtains the user sessions associated with the specified client, starting from the {@code firstResult} and containing
* at most {@code maxResults}.
*
* @param realm a reference tot he realm.
* @param client the client whose user sessions are being searched.
* @param firstResult first result to return. Ignored if negative.
* @param maxResults maximum number of results to return. Ignored if negative.
* @return a non-null {@link Stream} of user sessions.
*/
Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client, int firstResult, int maxResults);
/**
* @deprecated Use {@link #getUserSessionByBrokerUserIdStream(RealmModel, String) getUserSessionByBrokerUserIdStream}
* instead.
*/
@Deprecated
default List<UserSessionModel> getUserSessionByBrokerUserId(RealmModel realm, String brokerUserId) {
return this.getUserSessionByBrokerUserIdStream(realm, brokerUserId).collect(Collectors.toList());
}
/**
* Obtains the user sessions associated with the user that matches the specified {@code brokerUserId}.
*
* @param realm a reference to the realm.
* @param brokerUserId the id of the broker user whose sessions are being searched.
* @return a non-null {@link Stream} of user sessions.
*/
Stream<UserSessionModel> getUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId);
UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId); UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId);
/** /**
@ -88,12 +159,66 @@ public interface UserSessionProvider extends Provider {
/** Will automatically attach newly created offline client session to the offlineUserSession **/ /** Will automatically attach newly created offline client session to the offlineUserSession **/
AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession); AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession);
List<UserSessionModel> getOfflineUserSessions(RealmModel realm, UserModel user);
/**
* @deprecated Use {@link #getOfflineUserSessionsStream(RealmModel, UserModel) getOfflineUserSessionsStream} instead.
*/
@Deprecated
default List<UserSessionModel> getOfflineUserSessions(RealmModel realm, UserModel user) {
return this.getOfflineUserSessionsStream(realm, user).collect(Collectors.toList());
}
/**
* Obtains the offline user sessions associated with the specified user.
*
* @param realm a reference to the realm.
* @param user the user whose offline sessions are being searched.
* @return a non-null {@link Stream} of offline user sessions.
*/
Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, UserModel user);
UserSessionModel getOfflineUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId); UserSessionModel getOfflineUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId);
List<UserSessionModel> getOfflineUserSessionByBrokerUserId(RealmModel realm, String brokerUserId);
/**
* @deprecated Use {@link #getOfflineUserSessionByBrokerUserIdStream(RealmModel, String) getOfflineUserSessionByBrokerUserIdStream}
* instead.
*/
@Deprecated
default List<UserSessionModel> getOfflineUserSessionByBrokerUserId(RealmModel realm, String brokerUserId) {
return this.getOfflineUserSessionByBrokerUserIdStream(realm, brokerUserId).collect(Collectors.toList());
}
/**
* Obtains the offline user sessions associated with the user that matches the specified {@code brokerUserId}.
*
* @param realm a reference to the realm.
* @param brokerUserId the id of the broker user whose sessions are being searched.
* @return a non-null {@link Stream} of offline user sessions.
*/
Stream<UserSessionModel> getOfflineUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId);
long getOfflineSessionsCount(RealmModel realm, ClientModel client); long getOfflineSessionsCount(RealmModel realm, ClientModel client);
List<UserSessionModel> getOfflineUserSessions(RealmModel realm, ClientModel client, int first, int max);
/**
* @deprecated use {@link #getOfflineUserSessionsStream(RealmModel, ClientModel, int, int) getOfflineUserSessionsStream}
* instead.
*/
@Deprecated
default List<UserSessionModel> getOfflineUserSessions(RealmModel realm, ClientModel client, int first, int max) {
return this.getOfflineUserSessionsStream(realm, client, first, max).collect(Collectors.toList());
}
/**
* Obtains the offline user sessions associated with the specified client, starting from the {@code firstResult} and
* containing at most {@code maxResults}.
*
* @param realm a reference tot he realm.
* @param client the client whose user sessions are being searched.
* @param firstResult first result to return. Ignored if negative.
* @param maxResults maximum number of results to return. Ignored if negative.
* @return a non-null {@link Stream} of offline user sessions.
*/
Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, ClientModel client, int firstResult, int maxResults);
/** Triggered by persister during pre-load. It imports authenticatedClientSessions too **/ /** Triggered by persister during pre-load. It imports authenticatedClientSessions too **/
void importUserSessions(Collection<UserSessionModel> persistentUserSessions, boolean offline); void importUserSessions(Collection<UserSessionModel> persistentUserSessions, boolean offline);

View file

@ -47,7 +47,9 @@ import org.keycloak.sessions.AuthenticationSessionModel;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@ -129,12 +131,11 @@ public class UpdatePassword implements RequiredActionProvider, RequiredActionFac
if (getId().equals(authSession.getClientNote(Constants.KC_ACTION_EXECUTING)) if (getId().equals(authSession.getClientNote(Constants.KC_ACTION_EXECUTING))
&& "on".equals(formData.getFirst("logout-sessions"))) && "on".equals(formData.getFirst("logout-sessions")))
{ {
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user)
for (UserSessionModel s : sessions) { .filter(s -> !Objects.equals(s.getId(), authSession.getParentSession().getId()))
if (!s.getId().equals(authSession.getParentSession().getId())) { .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
AuthenticationManager.backchannelLogout(session, realm, s, session.getContext().getUri(), context.getConnection(), context.getHttpRequest().getHttpHeaders(), true); .forEach(s -> AuthenticationManager.backchannelLogout(session, realm, s, session.getContext().getUri(),
} context.getConnection(), context.getHttpRequest().getHttpHeaders(), true));
}
} }
try { try {

View file

@ -83,9 +83,15 @@ import javax.xml.namespace.QName;
import java.io.IOException; import java.io.IOException;
import java.security.Key; import java.security.Key;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.keycloak.protocol.saml.SamlPrincipalType; import org.keycloak.protocol.saml.SamlPrincipalType;
import org.keycloak.rotation.HardcodedKeyLocator; import org.keycloak.rotation.HardcodedKeyLocator;
@ -93,14 +99,14 @@ import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator; import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
import org.keycloak.saml.validators.ConditionsValidator; import org.keycloak.saml.validators.ConditionsValidator;
import org.keycloak.saml.validators.DestinationValidator; import org.keycloak.saml.validators.DestinationValidator;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import java.net.URI; import java.net.URI;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import org.w3c.dom.Element;
import java.util.*;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.xml.crypto.dsig.XMLSignature; import javax.xml.crypto.dsig.XMLSignature;
import org.w3c.dom.NodeList;
/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@ -296,22 +302,13 @@ public class SAMLEndpoint {
protected Response logoutRequest(LogoutRequestType request, String relayState) { protected Response logoutRequest(LogoutRequestType request, String relayState) {
String brokerUserId = config.getAlias() + "." + request.getNameID().getValue(); String brokerUserId = config.getAlias() + "." + request.getNameID().getValue();
if (request.getSessionIndex() == null || request.getSessionIndex().isEmpty()) { if (request.getSessionIndex() == null || request.getSessionIndex().isEmpty()) {
List<UserSessionModel> userSessions = session.sessions().getUserSessionByBrokerUserId(realm, brokerUserId); AtomicReference<LogoutRequestType> ref = new AtomicReference<>(request);
for (UserSessionModel userSession : userSessions) { session.sessions().getUserSessionByBrokerUserIdStream(realm, brokerUserId)
if (userSession.getState() == UserSessionModel.State.LOGGING_OUT || userSession.getState() == UserSessionModel.State.LOGGED_OUT) { .filter(userSession -> userSession.getState() != UserSessionModel.State.LOGGING_OUT &&
continue; userSession.getState() != UserSessionModel.State.LOGGED_OUT)
} .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
.forEach(processLogout(ref));
for(Iterator<SamlAuthenticationPreprocessor> it = SamlSessionUtils.getSamlAuthenticationPreprocessorIterator(session); it.hasNext();) { request = ref.get();
request = it.next().beforeProcessingLogoutRequest(request, userSession, null);
}
try {
AuthenticationManager.backchannelLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers, false);
} catch (Exception e) {
logger.warn("failed to do backchannel logout for userSession", e);
}
}
} else { } else {
for (String sessionIndex : request.getSessionIndex()) { for (String sessionIndex : request.getSessionIndex()) {
@ -369,6 +366,19 @@ public class SAMLEndpoint {
} }
private Consumer<UserSessionModel> processLogout(AtomicReference<LogoutRequestType> ref) {
return userSession -> {
for(Iterator<SamlAuthenticationPreprocessor> it = SamlSessionUtils.getSamlAuthenticationPreprocessorIterator(session); it.hasNext();) {
ref.set(it.next().beforeProcessingLogoutRequest(ref.get(), userSession, null));
}
try {
AuthenticationManager.backchannelLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers, false);
} catch (Exception e) {
logger.warn("failed to do backchannel logout for userSession", e);
}
};
}
private String getEntityId(UriInfo uriInfo, RealmModel realm) { private String getEntityId(UriInfo uriInfo, RealmModel realm) {
String configEntityId = config.getEntityId(); String configEntityId = config.getEntityId();
@ -578,11 +588,6 @@ public class SAMLEndpoint {
} }
return AuthenticationManager.finishBrowserLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers); return AuthenticationManager.finishBrowserLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers);
} }
} }
protected class PostBinding extends Binding { protected class PostBinding extends Binding {

View file

@ -65,8 +65,8 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder; import javax.ws.rs.core.UriBuilder;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.keycloak.models.UserSessionModel.State.LOGGED_OUT; import static org.keycloak.models.UserSessionModel.State.LOGGED_OUT;
@ -356,34 +356,29 @@ public class LogoutEndpoint {
BackchannelLogoutResponse backchannelLogoutResponse = new BackchannelLogoutResponse(); BackchannelLogoutResponse backchannelLogoutResponse = new BackchannelLogoutResponse();
backchannelLogoutResponse.setLocalLogoutSucceeded(true); backchannelLogoutResponse.setLocalLogoutSucceeded(true);
identityProviderAliases.forEach(identityProviderAlias -> { identityProviderAliases.forEach(identityProviderAlias -> {
List<UserSessionModel> userSessions = session.sessions().getUserSessionByBrokerUserId(realm,
identityProviderAlias + "." + federatedUserId);
if (logoutOfflineSessions) { if (logoutOfflineSessions) {
logoutOfflineUserSessions(identityProviderAlias + "." + federatedUserId); logoutOfflineUserSessions(identityProviderAlias + "." + federatedUserId);
} }
for (UserSessionModel userSession : userSessions) { session.sessions().getUserSessionByBrokerUserIdStream(realm, identityProviderAlias + "." + federatedUserId)
BackchannelLogoutResponse userBackchannelLogoutResponse; .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
userBackchannelLogoutResponse = logoutUserSession(userSession); .forEach(userSession -> {
BackchannelLogoutResponse userBackchannelLogoutResponse = this.logoutUserSession(userSession);
backchannelLogoutResponse.setLocalLogoutSucceeded(backchannelLogoutResponse.getLocalLogoutSucceeded() backchannelLogoutResponse.setLocalLogoutSucceeded(backchannelLogoutResponse.getLocalLogoutSucceeded()
&& userBackchannelLogoutResponse.getLocalLogoutSucceeded()); && userBackchannelLogoutResponse.getLocalLogoutSucceeded());
userBackchannelLogoutResponse.getClientResponses() userBackchannelLogoutResponse.getClientResponses()
.forEach(backchannelLogoutResponse::addClientResponses); .forEach(backchannelLogoutResponse::addClientResponses);
} });
}); });
return backchannelLogoutResponse; return backchannelLogoutResponse;
} }
private void logoutOfflineUserSessions(String brokerUserId) { private void logoutOfflineUserSessions(String brokerUserId) {
List<UserSessionModel> offlineUserSessions =
session.sessions().getOfflineUserSessionByBrokerUserId(realm, brokerUserId);
UserSessionManager userSessionManager = new UserSessionManager(session); UserSessionManager userSessionManager = new UserSessionManager(session);
for (UserSessionModel offlineUserSession : offlineUserSessions) { session.sessions().getOfflineUserSessionByBrokerUserIdStream(realm, brokerUserId).collect(Collectors.toList())
userSessionManager.revokeOfflineUserSession(offlineUserSession); .forEach(userSessionManager::revokeOfflineUserSession);
}
} }
private BackchannelLogoutResponse logoutUserSession(UserSessionModel userSession) { private BackchannelLogoutResponse logoutUserSession(UserSessionModel userSession) {

View file

@ -17,7 +17,8 @@
package org.keycloak.protocol.oidc.endpoints; package org.keycloak.protocol.oidc.endpoints;
import java.util.List; import java.util.Objects;
import java.util.stream.Collectors;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.POST; import javax.ws.rs.POST;
@ -36,7 +37,6 @@ import org.keycloak.events.Errors;
import org.keycloak.events.EventBuilder; import org.keycloak.events.EventBuilder;
import org.keycloak.events.EventType; import org.keycloak.events.EventType;
import org.keycloak.headers.SecurityHeadersProvider; import org.keycloak.headers.SecurityHeadersProvider;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
@ -223,14 +223,11 @@ public class TokenRevocationEndpoint {
if (TokenUtil.TOKEN_TYPE_OFFLINE.equals(token.getType())) { if (TokenUtil.TOKEN_TYPE_OFFLINE.equals(token.getType())) {
new UserSessionManager(session).revokeOfflineToken(user, client); new UserSessionManager(session).revokeOfflineToken(user, client);
} }
session.sessions().getUserSessionsStream(realm, user)
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); .map(userSession -> userSession.getAuthenticatedClientSessionByClient(client.getId()))
for (UserSessionModel userSession : userSessions) { .filter(Objects::nonNull)
AuthenticatedClientSessionModel clientSession = userSession.getAuthenticatedClientSessionByClient(client.getId()); .collect(Collectors.toList()) // collect to avoid concurrent modification as dettachClientSession removes the user sessions.
if (clientSession != null) { .forEach(clientSession -> TokenManager.dettachClientSession(session.sessions(), realm, clientSession));
org.keycloak.protocol.oidc.TokenManager.dettachClientSession(session.sessions(), realm, clientSession);
}
}
} }
private void revokeAccessToken() { private void revokeAccessToken() {

View file

@ -548,15 +548,15 @@ public class AuthenticationManager {
* @param headers * @param headers
*/ */
public static void backchannelLogoutUserFromClient(KeycloakSession session, RealmModel realm, UserModel user, ClientModel client, UriInfo uriInfo, HttpHeaders headers) { public static void backchannelLogoutUserFromClient(KeycloakSession session, RealmModel realm, UserModel user, ClientModel client, UriInfo uriInfo, HttpHeaders headers) {
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user)
for (UserSessionModel userSession : userSessions) { .map(userSession -> userSession.getAuthenticatedClientSessionByClient(client.getId()))
AuthenticatedClientSessionModel clientSession = userSession.getAuthenticatedClientSessionByClient(client.getId()); .filter(Objects::nonNull)
if (clientSession != null) { .collect(Collectors.toList()) // collect to avoid concurrent modification.
.forEach(clientSession -> {
backchannelLogoutClientSession(session, realm, clientSession, null, uriInfo, headers); backchannelLogoutClientSession(session, realm, clientSession, null, uriInfo, headers);
clientSession.setAction(AuthenticationSessionModel.Action.LOGGED_OUT.name()); clientSession.setAction(AuthenticationSessionModel.Action.LOGGED_OUT.name());
org.keycloak.protocol.oidc.TokenManager.dettachClientSession(session.sessions(), realm, clientSession); TokenManager.dettachClientSession(session.sessions(), realm, clientSession);
} });
}
} }
public static Response browserLogout(KeycloakSession session, public static Response browserLogout(KeycloakSession session,

View file

@ -37,7 +37,6 @@ import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.protocol.LoginProtocol; import org.keycloak.protocol.LoginProtocol;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper; import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.protocol.oidc.OIDCLoginProtocol; import org.keycloak.protocol.oidc.OIDCLoginProtocol;
@ -119,38 +118,6 @@ public class ResourceAdminManager {
return result; return result;
} }
public void logoutUser(RealmModel realm, UserModel user, KeycloakSession keycloakSession) {
keycloakSession.users().setNotBeforeForUser(realm, user, Time.currentTime());
List<UserSessionModel> userSessions = keycloakSession.sessions().getUserSessions(realm, user);
logoutUserSessions(realm, userSessions);
}
protected void logoutUserSessions(RealmModel realm, List<UserSessionModel> userSessions) {
// Map from "app" to clientSessions for this app
MultivaluedHashMap<String, AuthenticatedClientSessionModel> clientSessions = new MultivaluedHashMap<>();
for (UserSessionModel userSession : userSessions) {
putClientSessions(clientSessions, userSession);
}
logger.debugv("logging out {0} resources ", clientSessions.size());
//logger.infov("logging out resources: {0}", clientSessions);
for (Map.Entry<String, List<AuthenticatedClientSessionModel>> entry : clientSessions.entrySet()) {
if (entry.getValue().size() == 0) {
continue;
}
logoutClientSessions(realm, entry.getValue().get(0).getClient(), entry.getValue());
}
}
private void putClientSessions(MultivaluedHashMap<String, AuthenticatedClientSessionModel> clientSessions, UserSessionModel userSession) {
for (Map.Entry<String, AuthenticatedClientSessionModel> entry : userSession.getAuthenticatedClientSessions().entrySet()) {
clientSessions.add(entry.getKey(), entry.getValue());
}
}
public Response logoutClientSession(RealmModel realm, ClientModel resource, AuthenticatedClientSessionModel clientSession) { public Response logoutClientSession(RealmModel realm, ClientModel resource, AuthenticatedClientSessionModel clientSession) {
return logoutClientSessions(realm, resource, Arrays.asList(clientSession)); return logoutClientSessions(realm, resource, Arrays.asList(clientSession));
} }

View file

@ -30,10 +30,11 @@ import org.keycloak.models.UserSessionModel;
import org.keycloak.models.session.UserSessionPersisterProvider; import org.keycloak.models.session.UserSessionPersisterProvider;
import org.keycloak.services.ServicesLogger; import org.keycloak.services.ServicesLogger;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
@ -77,28 +78,27 @@ public class UserSessionManager {
} }
public Set<ClientModel> findClientsWithOfflineToken(RealmModel realm, UserModel user) { public Set<ClientModel> findClientsWithOfflineToken(RealmModel realm, UserModel user) {
List<UserSessionModel> userSessions = kcSession.sessions().getOfflineUserSessions(realm, user); return kcSession.sessions().getOfflineUserSessionsStream(realm, user)
Set<ClientModel> clients = new HashSet<>(); .flatMap(userSession -> userSession.getAuthenticatedClientSessions().keySet().stream())
for (UserSessionModel userSession : userSessions) { .map(clientUUID -> realm.getClientById(clientUUID))
Set<String> clientIds = userSession.getAuthenticatedClientSessions().keySet(); .collect(Collectors.toSet());
for (String clientUUID : clientIds) {
ClientModel client = realm.getClientById(clientUUID);
clients.add(client);
}
}
return clients;
} }
@Deprecated
public List<UserSessionModel> findOfflineSessions(RealmModel realm, UserModel user) { public List<UserSessionModel> findOfflineSessions(RealmModel realm, UserModel user) {
return kcSession.sessions().getOfflineUserSessions(realm, user); return this.findOfflineSessionsStream(realm, user).collect(Collectors.toList());
}
public Stream<UserSessionModel> findOfflineSessionsStream(RealmModel realm, UserModel user) {
return kcSession.sessions().getOfflineUserSessionsStream(realm, user);
} }
public boolean revokeOfflineToken(UserModel user, ClientModel client) { public boolean revokeOfflineToken(UserModel user, ClientModel client) {
RealmModel realm = client.getRealm(); RealmModel realm = client.getRealm();
List<UserSessionModel> userSessions = kcSession.sessions().getOfflineUserSessions(realm, user); AtomicBoolean anyRemoved = new AtomicBoolean(false);
boolean anyRemoved = false; kcSession.sessions().getOfflineUserSessionsStream(realm, user).collect(Collectors.toList())
for (UserSessionModel userSession : userSessions) { .forEach(userSession -> {
AuthenticatedClientSessionModel clientSession = userSession.getAuthenticatedClientSessionByClient(client.getId()); AuthenticatedClientSessionModel clientSession = userSession.getAuthenticatedClientSessionByClient(client.getId());
if (clientSession != null) { if (clientSession != null) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -109,11 +109,11 @@ public class UserSessionManager {
clientSession.detachFromUserSession(); clientSession.detachFromUserSession();
persister.removeClientSession(userSession.getId(), client.getId(), true); persister.removeClientSession(userSession.getId(), client.getId(), true);
checkOfflineUserSessionHasClientSessions(realm, user, userSession); checkOfflineUserSessionHasClientSessions(realm, user, userSession);
anyRemoved = true; anyRemoved.set(true);
}
} }
});
return anyRemoved; return anyRemoved.get();
} }
public void revokeOfflineUserSession(UserSessionModel userSession) { public void revokeOfflineUserSession(UserSessionModel userSession) {

View file

@ -322,7 +322,7 @@ public class AccountFormService extends AbstractSecuredLocalService {
@GET @GET
public Response sessionsPage() { public Response sessionsPage() {
if (auth != null) { if (auth != null) {
account.setSessions(session.sessions().getUserSessions(realm, auth.getUser())); account.setSessions(session.sessions().getUserSessionsStream(realm, auth.getUser()).collect(Collectors.toList()));
} }
return forwardToPage("sessions", AccountPages.SESSIONS); return forwardToPage("sessions", AccountPages.SESSIONS);
} }
@ -342,7 +342,6 @@ public class AccountFormService extends AbstractSecuredLocalService {
* lastName * lastName
* email * email
* *
* @param formData
* @return * @return
*/ */
@Path("/") @Path("/")
@ -427,10 +426,10 @@ public class AccountFormService extends AbstractSecuredLocalService {
// as time on the token will be same like notBefore // as time on the token will be same like notBefore
session.users().setNotBeforeForUser(realm, user, Time.currentTime() - 1); session.users().setNotBeforeForUser(realm, user, Time.currentTime() - 1);
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user)
for (UserSessionModel userSession : userSessions) { .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
AuthenticationManager.backchannelLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers, true); .forEach(userSession -> AuthenticationManager.backchannelLogout(session, realm, userSession,
} session.getContext().getUri(), clientConnection, headers, true));
UriBuilder builder = Urls.accountBase(session.getContext().getUri().getBaseUri()).path(AccountFormService.class, "sessionsPage"); UriBuilder builder = Urls.accountBase(session.getContext().getUri().getBaseUri()).path(AccountFormService.class, "sessionsPage");
String referrer = session.getContext().getUri().getQueryParameters().getFirst("referrer"); String referrer = session.getContext().getUri().getQueryParameters().getFirst("referrer");
@ -495,7 +494,6 @@ public class AccountFormService extends AbstractSecuredLocalService {
* totp - otp generated by authenticator * totp - otp generated by authenticator
* totpSecret - totp secret to register * totpSecret - totp secret to register
* *
* @param formData
* @return * @return
*/ */
@Path("totp") @Path("totp")
@ -567,7 +565,6 @@ public class AccountFormService extends AbstractSecuredLocalService {
* password-new * password-new
* pasword-confirm * pasword-confirm
* *
* @param formData
* @return * @return
*/ */
@Path("password") @Path("password")
@ -641,12 +638,9 @@ public class AccountFormService extends AbstractSecuredLocalService {
return account.setError(Response.Status.INTERNAL_SERVER_ERROR, ape.getMessage()).createResponse(AccountPages.PASSWORD); return account.setError(Response.Status.INTERNAL_SERVER_ERROR, ape.getMessage()).createResponse(AccountPages.PASSWORD);
} }
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user).filter(s -> !Objects.equals(s.getId(), auth.getSession().getId()))
for (UserSessionModel s : sessions) { .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
if (!s.getId().equals(auth.getSession().getId())) { .forEach(s -> AuthenticationManager.backchannelLogout(session, realm, s, session.getContext().getUri(), clientConnection, headers, true));
AuthenticationManager.backchannelLogout(session, realm, s, session.getContext().getUri(), clientConnection, headers, true);
}
}
event.event(EventType.UPDATE_PASSWORD).client(auth.getClient()).user(auth.getUser()).success(); event.event(EventType.UPDATE_PASSWORD).client(auth.getClient()).user(auth.getUser()).success();

View file

@ -33,7 +33,6 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserConsentModel; import org.keycloak.models.UserConsentModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.representations.account.ClientRepresentation; import org.keycloak.representations.account.ClientRepresentation;
import org.keycloak.representations.account.ConsentRepresentation; import org.keycloak.representations.account.ConsentRepresentation;
import org.keycloak.representations.account.ConsentScopeRepresentation; import org.keycloak.representations.account.ConsentScopeRepresentation;
@ -79,6 +78,7 @@ import java.util.Properties;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -401,51 +401,36 @@ public class AccountRestService {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@NoCache @NoCache
public List<ClientRepresentation> applications(@QueryParam("name") String name) { public Stream<ClientRepresentation> applications(@QueryParam("name") String name) {
checkAccountApiEnabled(); checkAccountApiEnabled();
auth.requireOneOf(AccountRoles.MANAGE_ACCOUNT, AccountRoles.VIEW_APPLICATIONS); auth.requireOneOf(AccountRoles.MANAGE_ACCOUNT, AccountRoles.VIEW_APPLICATIONS);
Set<ClientModel> clients = new HashSet<>(); Set<ClientModel> clients = new HashSet<>();
List<String> inUseClients = new LinkedList<>(); List<String> inUseClients = new LinkedList<>();
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, user); clients.addAll(session.sessions().getUserSessionsStream(realm, user)
for(UserSessionModel s : sessions) { .flatMap(s -> s.getAuthenticatedClientSessions().values().stream())
for (AuthenticatedClientSessionModel a : s.getAuthenticatedClientSessions().values()) { .map(AuthenticatedClientSessionModel::getClient)
ClientModel client = a.getClient(); .peek(client -> inUseClients.add(client.getClientId()))
clients.add(client); .collect(Collectors.toSet()));
inUseClients.add(client.getClientId());
}
}
List<String> offlineClients = new LinkedList<>(); List<String> offlineClients = new LinkedList<>();
List<UserSessionModel> offlineSessions = session.sessions().getOfflineUserSessions(realm, user); clients.addAll(session.sessions().getOfflineUserSessionsStream(realm, user)
for(UserSessionModel s : offlineSessions) { .flatMap(s -> s.getAuthenticatedClientSessions().values().stream())
for(AuthenticatedClientSessionModel a : s.getAuthenticatedClientSessions().values()) { .map(AuthenticatedClientSessionModel::getClient)
ClientModel client = a.getClient(); .peek(client -> offlineClients.add(client.getClientId()))
clients.add(client); .collect(Collectors.toSet()));
offlineClients.add(client.getClientId());
}
}
Map<String, UserConsentModel> consentModels = new HashMap<>(); Map<String, UserConsentModel> consentModels = new HashMap<>();
session.users().getConsentsStream(realm, user.getId()).forEach(consent -> { clients.addAll(session.users().getConsentsStream(realm, user.getId())
ClientModel client = consent.getClient(); .peek(consent -> consentModels.put(consent.getClient().getClientId(), consent))
clients.add(client); .map(UserConsentModel::getClient)
consentModels.put(client.getClientId(), consent); .collect(Collectors.toSet()));
});
realm.getAlwaysDisplayInConsoleClientsStream().forEach(clients::add); realm.getAlwaysDisplayInConsoleClientsStream().forEach(clients::add);
List<ClientRepresentation> apps = new LinkedList<>(); return clients.stream().filter(client -> !client.isBearerOnly() && client.getBaseUrl() != null && !client.getClientId().isEmpty())
for (ClientModel client : clients) { .filter(client -> matches(client, name))
if (client.isBearerOnly() || client.getBaseUrl() == null || client.getBaseUrl().isEmpty()) { .map(client -> modelToRepresentation(client, inUseClients, offlineClients, consentModels));
continue;
}
else if (matches(client, name)) {
apps.add(modelToRepresentation(client, inUseClients, offlineClients, consentModels));
}
}
return apps;
} }
private boolean matches(ClientModel client, String name) { private boolean matches(ClientModel client, String name) {

View file

@ -30,6 +30,7 @@ import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jboss.resteasy.annotations.cache.NoCache; import org.jboss.resteasy.annotations.cache.NoCache;
import org.jboss.resteasy.spi.HttpRequest; import org.jboss.resteasy.spi.HttpRequest;
@ -73,8 +74,8 @@ public class SessionResource {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@NoCache @NoCache
public List<SessionRepresentation> toRepresentation() { public Stream<SessionRepresentation> toRepresentation() {
return session.sessions().getUserSessions(realm, user).stream().map(this::toRepresentation).collect(Collectors.toList()); return session.sessions().getUserSessionsStream(realm, user).map(this::toRepresentation);
} }
/** /**
@ -88,9 +89,7 @@ public class SessionResource {
@NoCache @NoCache
public Collection<DeviceRepresentation> devices() { public Collection<DeviceRepresentation> devices() {
Map<String, DeviceRepresentation> reps = new HashMap<>(); Map<String, DeviceRepresentation> reps = new HashMap<>();
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user).forEach(s -> {
for (UserSessionModel s : sessions) {
DeviceRepresentation device = getAttachedDevice(s); DeviceRepresentation device = getAttachedDevice(s);
DeviceRepresentation rep = reps DeviceRepresentation rep = reps
.computeIfAbsent(device.getOs() + device.getOsVersion(), key -> { .computeIfAbsent(device.getOs() + device.getOsVersion(), key -> {
@ -114,7 +113,7 @@ public class SessionResource {
} }
rep.addSession(createSessionRepresentation(s, device)); rep.addSession(createSessionRepresentation(s, device));
} });
return reps.values(); return reps.values();
} }
@ -130,13 +129,9 @@ public class SessionResource {
@NoCache @NoCache
public Response logout(@QueryParam("current") boolean removeCurrent) { public Response logout(@QueryParam("current") boolean removeCurrent) {
auth.require(AccountRoles.MANAGE_ACCOUNT); auth.require(AccountRoles.MANAGE_ACCOUNT);
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user).filter(s -> removeCurrent || !isCurrentSession(s))
.collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
for (UserSessionModel s : userSessions) { .forEach(s -> AuthenticationManager.backchannelLogout(session, s, true));
if (removeCurrent || !isCurrentSession(s)) {
AuthenticationManager.backchannelLogout(session, s, true);
}
}
return Response.noContent().build(); return Response.noContent().build();
} }

View file

@ -50,10 +50,8 @@ import org.keycloak.representations.idm.UserRepresentation;
import org.keycloak.representations.idm.UserSessionRepresentation; import org.keycloak.representations.idm.UserSessionRepresentation;
import org.keycloak.services.ErrorResponse; import org.keycloak.services.ErrorResponse;
import org.keycloak.services.ErrorResponseException; import org.keycloak.services.ErrorResponseException;
import org.keycloak.services.clientpolicy.AdminClientRegisterContext;
import org.keycloak.services.clientpolicy.AdminClientUpdateContext; import org.keycloak.services.clientpolicy.AdminClientUpdateContext;
import org.keycloak.services.clientpolicy.ClientPolicyException; import org.keycloak.services.clientpolicy.ClientPolicyException;
import org.keycloak.services.clientpolicy.DefaultClientPolicyManager;
import org.keycloak.services.clientregistration.ClientRegistrationTokenUtils; import org.keycloak.services.clientregistration.ClientRegistrationTokenUtils;
import org.keycloak.services.clientregistration.policy.RegistrationAuth; import org.keycloak.services.clientregistration.policy.RegistrationAuth;
import org.keycloak.services.managers.ClientManager; import org.keycloak.services.managers.ClientManager;
@ -78,11 +76,12 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.core.Context; import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import static java.lang.Boolean.TRUE; import static java.lang.Boolean.TRUE;
@ -462,17 +461,13 @@ public class ClientResource {
@GET @GET
@NoCache @NoCache
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public List<UserSessionRepresentation> getUserSessions(@QueryParam("first") Integer firstResult, @QueryParam("max") Integer maxResults) { public Stream<UserSessionRepresentation> getUserSessions(@QueryParam("first") Integer firstResult, @QueryParam("max") Integer maxResults) {
auth.clients().requireView(client); auth.clients().requireView(client);
firstResult = firstResult != null ? firstResult : -1; firstResult = firstResult != null ? firstResult : -1;
maxResults = maxResults != null ? maxResults : Constants.DEFAULT_MAX_RESULTS; maxResults = maxResults != null ? maxResults : Constants.DEFAULT_MAX_RESULTS;
List<UserSessionRepresentation> sessions = new ArrayList<UserSessionRepresentation>(); return session.sessions().getUserSessionsStream(client.getRealm(), client, firstResult, maxResults)
for (UserSessionModel userSession : session.sessions().getUserSessions(client.getRealm(), client, firstResult, maxResults)) { .map(ModelToRepresentation::toRepresentation);
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(userSession);
sessions.add(rep);
}
return sessions;
} }
/** /**
@ -511,30 +506,14 @@ public class ClientResource {
@GET @GET
@NoCache @NoCache
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public List<UserSessionRepresentation> getOfflineUserSessions(@QueryParam("first") Integer firstResult, @QueryParam("max") Integer maxResults) { public Stream<UserSessionRepresentation> getOfflineUserSessions(@QueryParam("first") Integer firstResult, @QueryParam("max") Integer maxResults) {
auth.clients().requireView(client); auth.clients().requireView(client);
firstResult = firstResult != null ? firstResult : -1; firstResult = firstResult != null ? firstResult : -1;
maxResults = maxResults != null ? maxResults : Constants.DEFAULT_MAX_RESULTS; maxResults = maxResults != null ? maxResults : Constants.DEFAULT_MAX_RESULTS;
List<UserSessionRepresentation> sessions = new ArrayList<UserSessionRepresentation>();
List<UserSessionModel> userSessions = session.sessions().getOfflineUserSessions(client.getRealm(), client, firstResult, maxResults);
for (UserSessionModel userSession : userSessions) {
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(userSession);
// Update lastSessionRefresh with the timestamp from clientSession return session.sessions().getOfflineUserSessionsStream(client.getRealm(), client, firstResult, maxResults)
for (Map.Entry<String, AuthenticatedClientSessionModel> csEntry : userSession.getAuthenticatedClientSessions().entrySet()) { .map(this::toUserSessionRepresentation);
String clientUuid = csEntry.getKey();
AuthenticatedClientSessionModel clientSession = csEntry.getValue();
if (client.getId().equals(clientUuid)) {
rep.setLastAccess(Time.toMillis(clientSession.getTimestamp()));
break;
}
}
sessions.add(rep);
}
return sessions;
} }
/** /**
@ -701,4 +680,23 @@ public class ClientResource {
authorization().disable(); authorization().disable();
} }
} }
/**
* Converts the specified {@link UserSessionModel} into a {@link UserSessionRepresentation}.
*
* @param userSession the model to be converted.
* @return a reference to the constructed representation.
*/
private UserSessionRepresentation toUserSessionRepresentation(final UserSessionModel userSession) {
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(userSession);
// Update lastSessionRefresh with the timestamp from clientSession
Map.Entry<String, AuthenticatedClientSessionModel> result = userSession.getAuthenticatedClientSessions().entrySet().stream()
.filter(entry -> Objects.equals(client.getId(), entry.getKey()))
.findFirst().orElse(null);
if (result != null) {
rep.setLastAccess(Time.toMillis(result.getValue().getTimestamp()));
}
return rep;
}
} }

View file

@ -309,15 +309,9 @@ public class UserResource {
@GET @GET
@NoCache @NoCache
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public List<UserSessionRepresentation> getSessions() { public Stream<UserSessionRepresentation> getSessions() {
auth.users().requireView(user); auth.users().requireView(user);
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, user); return session.sessions().getUserSessionsStream(realm, user).map(ModelToRepresentation::toRepresentation);
List<UserSessionRepresentation> reps = new ArrayList<UserSessionRepresentation>();
for (UserSessionModel session : sessions) {
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(session);
reps.add(rep);
}
return reps;
} }
/** /**
@ -329,30 +323,15 @@ public class UserResource {
@GET @GET
@NoCache @NoCache
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public List<UserSessionRepresentation> getOfflineSessions(final @PathParam("clientUuid") String clientUuid) { public Stream<UserSessionRepresentation> getOfflineSessions(final @PathParam("clientUuid") String clientUuid) {
auth.users().requireView(user); auth.users().requireView(user);
ClientModel client = realm.getClientById(clientUuid); ClientModel client = realm.getClientById(clientUuid);
if (client == null) { if (client == null) {
throw new NotFoundException("Client not found"); throw new NotFoundException("Client not found");
} }
List<UserSessionModel> sessions = new UserSessionManager(session).findOfflineSessions(realm, user); return new UserSessionManager(session).findOfflineSessionsStream(realm, user)
List<UserSessionRepresentation> reps = new ArrayList<UserSessionRepresentation>(); .map(session -> toUserSessionRepresentation(session, clientUuid))
for (UserSessionModel session : sessions) { .filter(Objects::nonNull);
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(session);
// Update lastSessionRefresh with the timestamp from clientSession
AuthenticatedClientSessionModel clientSession = session.getAuthenticatedClientSessionByClient(clientUuid);
// Skip if userSession is not for this client
if (clientSession == null) {
continue;
}
rep.setLastAccess(clientSession.getTimestamp());
reps.add(rep);
}
return reps;
} }
/** /**
@ -503,10 +482,10 @@ public class UserResource {
session.users().setNotBeforeForUser(realm, user, Time.currentTime()); session.users().setNotBeforeForUser(realm, user, Time.currentTime());
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); session.sessions().getUserSessionsStream(realm, user)
for (UserSessionModel userSession : userSessions) { .collect(Collectors.toList()) // collect to avoid concurrent modification as backchannelLogout removes the user sessions.
AuthenticationManager.backchannelLogout(session, realm, userSession, session.getContext().getUri(), clientConnection, headers, true); .forEach(userSession -> AuthenticationManager.backchannelLogout(session, realm, userSession,
} session.getContext().getUri(), clientConnection, headers, true));
adminEvent.operation(OperationType.ACTION).resourcePath(session.getContext().getUri()).success(); adminEvent.operation(OperationType.ACTION).resourcePath(session.getContext().getUri()).success();
} }
@ -900,4 +879,22 @@ public class UserResource {
} }
} }
/**
* Converts the specified {@link UserSessionModel} into a {@link UserSessionRepresentation}.
*
* @param userSession the model to be converted.
* @param clientUuid the client's UUID.
* @return a reference to the constructed representation or {@code null} if the session is not associated with the specified
* client.
*/
private UserSessionRepresentation toUserSessionRepresentation(final UserSessionModel userSession, final String clientUuid) {
UserSessionRepresentation rep = ModelToRepresentation.toRepresentation(userSession);
// Update lastSessionRefresh with the timestamp from clientSession
AuthenticatedClientSessionModel clientSession = userSession.getAuthenticatedClientSessionByClient(clientUuid);
if (clientSession == null) {
return null;
}
rep.setLastAccess(clientSession.getTimestamp());
return rep;
}
} }

View file

@ -282,7 +282,7 @@ public class ImpersonationTest extends AbstractKeycloakTest {
final UserSessionNotesHolder notesHolder = testingClient.server("test").fetch(session -> { final UserSessionNotesHolder notesHolder = testingClient.server("test").fetch(session -> {
final RealmModel realm = session.realms().getRealmByName("test"); final RealmModel realm = session.realms().getRealmByName("test");
final UserModel user = session.users().getUserById(userId, realm); final UserModel user = session.users().getUserById(userId, realm);
final UserSessionModel userSession = session.sessions().getUserSessions(realm, user).get(0); final UserSessionModel userSession = session.sessions().getUserSessionsStream(realm, user).findFirst().get();
return new UserSessionNotesHolder(userSession.getNotes()); return new UserSessionNotesHolder(userSession.getNotes());
}, UserSessionNotesHolder.class); }, UserSessionNotesHolder.class);

View file

@ -135,8 +135,7 @@ final class BrokerRunOnServerUtil {
return (session) -> { return (session) -> {
RealmModel realm = session.realms().getRealmByName("consumer"); RealmModel realm = session.realms().getRealmByName("consumer");
UserModel user = session.users().getUserByUsername("testuser", realm); UserModel user = session.users().getUserByUsername("testuser", realm);
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, user); UserSessionModel sessions = session.sessions().getUserSessionsStream(realm, user).findFirst().get();
UserSessionModel sessions = userSessions.get(0);
assertEquals("sessionvalue", sessions.getNote("user-session-attr")); assertEquals("sessionvalue", sessions.getNote("user-session-attr"));
}; };
} }

View file

@ -42,6 +42,7 @@ import org.keycloak.testsuite.arquillian.annotation.ModelTest;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.hamcrest.core.Is.is; import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
@ -116,7 +117,8 @@ public class UserSessionInitializerTest extends AbstractTestRealmKeycloakTest {
assertThat("Count of offline sesions for client 'test-app'", currentSession.sessions().getOfflineSessionsCount(realm, testApp), is((long) 3)); assertThat("Count of offline sesions for client 'test-app'", currentSession.sessions().getOfflineSessionsCount(realm, testApp), is((long) 3));
assertThat("Count of offline sesions for client 'third-party'", currentSession.sessions().getOfflineSessionsCount(realm, thirdparty), is((long) 1)); assertThat("Count of offline sesions for client 'third-party'", currentSession.sessions().getOfflineSessionsCount(realm, thirdparty), is((long) 1));
List<UserSessionModel> loadedSessions = currentSession.sessions().getOfflineUserSessions(realm, testApp, 0, 10); List<UserSessionModel> loadedSessions = currentSession.sessions().getOfflineUserSessionsStream(realm, testApp, 0, 10)
.collect(Collectors.toList());
UserSessionProviderTest.assertSessions(loadedSessions, origSessions); UserSessionProviderTest.assertSessions(loadedSessions, origSessions);
assertSessionLoaded(loadedSessions, origSessions[0].getId(), currentSession.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "test-app", "third-party"); assertSessionLoaded(loadedSessions, origSessions[0].getId(), currentSession.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "test-app", "third-party");
@ -166,7 +168,8 @@ public class UserSessionInitializerTest extends AbstractTestRealmKeycloakTest {
ClientModel thirdparty = realm.getClientByClientId("third-party"); ClientModel thirdparty = realm.getClientByClientId("third-party");
assertThat("Count of offline sesions for client 'third-party'", currentSession.sessions().getOfflineSessionsCount(realm, thirdparty), is((long) 1)); assertThat("Count of offline sesions for client 'third-party'", currentSession.sessions().getOfflineSessionsCount(realm, thirdparty), is((long) 1));
List<UserSessionModel> loadedSessions = currentSession.sessions().getOfflineUserSessions(realm, thirdparty, 0, 10); List<UserSessionModel> loadedSessions = currentSession.sessions().getOfflineUserSessionsStream(realm, thirdparty, 0, 10)
.collect(Collectors.toList());
assertThat("Size of loaded Sessions", loadedSessions.size(), is(1)); assertThat("Size of loaded Sessions", loadedSessions.size(), is(1));
assertSessionLoaded(loadedSessions, origSessions[0].getId(), currentSession.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "third-party"); assertSessionLoaded(loadedSessions, origSessions[0].getId(), currentSession.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "third-party");

View file

@ -44,6 +44,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -102,10 +103,8 @@ public class UserSessionPersisterProviderTest extends AbstractTestRealmKeycloakT
// Persist 3 created userSessions and clientSessions as offline // Persist 3 created userSessions and clientSessions as offline
RealmModel realm = sessionWL22.realms().getRealm("test"); RealmModel realm = sessionWL22.realms().getRealm("test");
ClientModel testApp = realm.getClientByClientId("test-app"); ClientModel testApp = realm.getClientByClientId("test-app");
List<UserSessionModel> userSessions = sessionWL22.sessions().getUserSessions(realm, testApp); sessionWL22.sessions().getUserSessionsStream(realm, testApp).collect(Collectors.toList())
for (UserSessionModel userSessionLooper : userSessions) { .forEach(userSessionLooper -> persistUserSession(sessionWL22, userSessionLooper, true));
persistUserSession(sessionWL22, userSessionLooper, true);
}
}); });
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionWL2) -> { KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionWL2) -> {

View file

@ -49,6 +49,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -112,10 +113,8 @@ public class UserSessionProviderOfflineTest extends AbstractTestRealmKeycloakTes
// Key is userSession ID, values are client UUIDS // Key is userSession ID, values are client UUIDS
// Persist 3 created userSessions and clientSessions as offline // Persist 3 created userSessions and clientSessions as offline
ClientModel testApp = realm.getClientByClientId("test-app"); ClientModel testApp = realm.getClientByClientId("test-app");
List<UserSessionModel> userSessions = currentSession.sessions().getUserSessions(realm, testApp); currentSession.sessions().getUserSessionsStream(realm, testApp).collect(Collectors.toList())
for (UserSessionModel userSession : userSessions) { .forEach(userSession -> offlineSessions.put(userSession.getId(), createOfflineSessionIncludeClientSessions(currentSession, userSession)));
offlineSessions.put(userSession.getId(), createOfflineSessionIncludeClientSessions(currentSession, userSession));
}
}); });
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionCrud3) -> { KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionCrud3) -> {
@ -170,7 +169,8 @@ public class UserSessionProviderOfflineTest extends AbstractTestRealmKeycloakTes
Assert.assertEquals(2, currentSession.sessions().getOfflineSessionsCount(realm, testApp)); Assert.assertEquals(2, currentSession.sessions().getOfflineSessionsCount(realm, testApp));
Assert.assertEquals(1, currentSession.sessions().getOfflineSessionsCount(realm, thirdparty)); Assert.assertEquals(1, currentSession.sessions().getOfflineSessionsCount(realm, thirdparty));
List<UserSessionModel> thirdpartySessions = currentSession.sessions().getOfflineUserSessions(realm, thirdparty, 0, 10); List<UserSessionModel> thirdpartySessions = currentSession.sessions().getOfflineUserSessionsStream(realm, thirdparty, 0, 10)
.collect(Collectors.toList());
Assert.assertEquals(1, thirdpartySessions.size()); Assert.assertEquals(1, thirdpartySessions.size());
Assert.assertEquals("127.0.0.1", thirdpartySessions.get(0).getIpAddress()); Assert.assertEquals("127.0.0.1", thirdpartySessions.get(0).getIpAddress());
Assert.assertEquals("user1", thirdpartySessions.get(0).getUser().getUsername()); Assert.assertEquals("user1", thirdpartySessions.get(0).getUser().getUsername());
@ -203,7 +203,8 @@ public class UserSessionProviderOfflineTest extends AbstractTestRealmKeycloakTes
Assert.assertEquals(1, currentSession.sessions().getOfflineSessionsCount(realm, testApp)); Assert.assertEquals(1, currentSession.sessions().getOfflineSessionsCount(realm, testApp));
Assert.assertEquals(0, currentSession.sessions().getOfflineSessionsCount(realm, thirdparty)); Assert.assertEquals(0, currentSession.sessions().getOfflineSessionsCount(realm, thirdparty));
List<UserSessionModel> testAppSessions = currentSession.sessions().getOfflineUserSessions(realm, testApp, 0, 10); List<UserSessionModel> testAppSessions = currentSession.sessions().getOfflineUserSessionsStream(realm, testApp, 0, 10)
.collect(Collectors.toList());
Assert.assertEquals(1, testAppSessions.size()); Assert.assertEquals(1, testAppSessions.size());
Assert.assertEquals("127.0.0.3", testAppSessions.get(0).getIpAddress()); Assert.assertEquals("127.0.0.3", testAppSessions.get(0).getIpAddress());
@ -462,10 +463,8 @@ public class UserSessionProviderOfflineTest extends AbstractTestRealmKeycloakTes
// Persist 3 created userSessions and clientSessions as offline // Persist 3 created userSessions and clientSessions as offline
testApp[0] = realm.getClientByClientId("test-app"); testApp[0] = realm.getClientByClientId("test-app");
List<UserSessionModel> userSessions = currentSession.sessions().getUserSessions(realm, testApp[0]); currentSession.sessions().getUserSessionsStream(realm, testApp[0]).collect(Collectors.toList())
for (UserSessionModel userSession : userSessions) { .forEach(userSession -> offlineSessions.put(userSession.getId(), createOfflineSessionIncludeClientSessions(currentSession, userSession)));
offlineSessions.put(userSession.getId(), createOfflineSessionIncludeClientSessions(currentSession, userSession));
}
// Assert all previously saved offline sessions found // Assert all previously saved offline sessions found
for (Map.Entry<String, Set<String>> entry : offlineSessions.entrySet()) { for (Map.Entry<String, Set<String>> entry : offlineSessions.entrySet()) {

View file

@ -46,10 +46,12 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
@ -249,8 +251,10 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
} }
assertSessions(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)), sessions[0], sessions[1]); assertSessions(session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user1", realm))
assertSessions(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)), sessions[2]); .collect(Collectors.toList()), sessions[0], sessions[1]);
assertSessions(session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user2", realm))
.collect(Collectors.toList()), sessions[2]);
} }
@Test @Test
@ -262,20 +266,18 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
inheritClientConnection(session, kcSession); inheritClientConnection(session, kcSession);
createSessions(kcSession); createSessions(kcSession);
}); });
Map<String, Integer> clientSessionsKept = new HashMap<>(); Map<String, Integer> clientSessionsKept = session.sessions().getUserSessionsStream(realm,
for (UserSessionModel s : session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm))
session.users().getUserByUsername("user2", realm))) { .collect(Collectors.toMap(model -> model.getId(), model -> model.getAuthenticatedClientSessions().keySet().size()));
clientSessionsKept.put(s.getId(), s.getAuthenticatedClientSessions().keySet().size());
}
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession kcSession) -> { KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession kcSession) -> {
kcSession.sessions().removeUserSessions(realm, kcSession.users().getUserByUsername("user1", realm)); kcSession.sessions().removeUserSessions(realm, kcSession.users().getUserByUsername("user1", realm));
}); });
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty()); assertEquals(0, session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user1", realm))
List<UserSessionModel> userSessions = session.sessions().getUserSessions(realm, .count());
session.users().getUserByUsername("user2", realm)); List<UserSessionModel> userSessions = session.sessions().getUserSessionsStream(realm,
session.users().getUserByUsername("user2", realm)).collect(Collectors.toList());
assertSame(userSessions.size(), 1); assertSame(userSessions.size(), 1);
@ -309,8 +311,10 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
kcSession.sessions().removeUserSessions(realm); kcSession.sessions().removeUserSessions(realm);
}); });
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty()); assertEquals(0, session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user1", realm))
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty()); .count());
assertEquals(0, session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user2", realm))
.count());
} }
@Test @Test
@ -544,8 +548,10 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
} }
assertSessions(session.sessions().getUserSessions(realm, realm.getClientByClientId("test-app")), sessions[0], sessions[1], sessions[2]); assertSessions(session.sessions().getUserSessionsStream(realm, realm.getClientByClientId("test-app"))
assertSessions(session.sessions().getUserSessions(realm, realm.getClientByClientId("third-party")), sessions[0]); .collect(Collectors.toList()), sessions[0], sessions[1], sessions[2]);
assertSessions(session.sessions().getUserSessionsStream(realm, realm.getClientByClientId("third-party"))
.collect(Collectors.toList()), sessions[0]);
} }
@Test @Test
@ -663,7 +669,7 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
} }
private static void assertPaginatedSession(KeycloakSession session, RealmModel realm, ClientModel client, int start, int max, int expectedSize) { private static void assertPaginatedSession(KeycloakSession session, RealmModel realm, ClientModel client, int start, int max, int expectedSize) {
List<UserSessionModel> sessions = session.sessions().getUserSessions(realm, client, start, max); List<UserSessionModel> sessions = session.sessions().getUserSessionsStream(realm, client, start, max).collect(Collectors.toList());
String[] actualIps = new String[sessions.size()]; String[] actualIps = new String[sessions.size()];
for (int i = 0; i < actualIps.length; i++) { for (int i = 0; i < actualIps.length; i++) {
@ -773,11 +779,11 @@ public class UserSessionProviderTest extends AbstractTestRealmKeycloakTest {
session.userStorageManager().removeUser(realm, user1); session.userStorageManager().removeUser(realm, user1);
assertTrue(session.sessions().getUserSessions(realm, user1).isEmpty()); assertEquals(0, session.sessions().getUserSessionsStream(realm, user1).count());
session.getTransactionManager().commit(); session.getTransactionManager().commit();
assertFalse(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty()); assertNotEquals(0, session.sessions().getUserSessionsStream(realm, session.users().getUserByUsername("user2", realm)).count());
user1 = session.users().getUserByUsername("user1", realm); user1 = session.users().getUserByUsername("user1", realm);
user2 = session.users().getUserByUsername("user2", realm); user2 = session.users().getUserByUsername("user2", realm);