diff --git a/model/jpa/src/main/java/org/keycloak/models/jpa/session/JpaUserSessionPersisterProvider.java b/model/jpa/src/main/java/org/keycloak/models/jpa/session/JpaUserSessionPersisterProvider.java index 3e253302dd..6ee15f956a 100644 --- a/model/jpa/src/main/java/org/keycloak/models/jpa/session/JpaUserSessionPersisterProvider.java +++ b/model/jpa/src/main/java/org/keycloak/models/jpa/session/JpaUserSessionPersisterProvider.java @@ -50,6 +50,7 @@ import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import jakarta.persistence.LockModeType; +import org.keycloak.utils.StreamsUtil; import static org.keycloak.models.jpa.PaginationUtils.paginateQuery; import static org.keycloak.utils.StreamsUtil.closing; @@ -467,25 +468,48 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv * @return */ private Stream loadUserSessionsWithClientSessions(TypedQuery query, String offlineStr, boolean useExact) { - List userSessionAdapters = closing(query.getResultStream() - .map(this::toAdapter) - .filter(Objects::nonNull)) - .collect(Collectors.toList()); - Map sessionsById = userSessionAdapters.stream() - .collect(Collectors.toMap(UserSessionModel::getId, Function.identity())); + if (useExact) { - Set userSessionIds = sessionsById.keySet(); + // Take the results returned by the database in chunks and enrich them. + // The chunking avoids loading all the entries, as the caller usually adds pagination for the frontend. + return closing(StreamsUtil.chunkedStream(closing(query.getResultStream()).map(this::toAdapter).filter(Objects::nonNull), 100) + .flatMap(batchedUserSessions -> { + Set removedClientUUIDs = new HashSet<>(); - Set removedClientUUIDs = new HashSet<>(); + Map sessionsById = batchedUserSessions.stream() + .collect(Collectors.toMap(UserSessionModel::getId, Function.identity())); - if (!sessionsById.isEmpty()) { - TypedQuery queryClientSessions; - if (useExact) { - queryClientSessions = em.createNamedQuery("findClientSessionsOrderedByIdExact", PersistentClientSessionEntity.class); - queryClientSessions.setParameter("offline", offlineStr); - queryClientSessions.setParameter("userSessionIds", userSessionIds); - } else { + Set userSessionIds = sessionsById.keySet(); + + TypedQuery queryClientSessions; + queryClientSessions = em.createNamedQuery("findClientSessionsOrderedByIdExact", PersistentClientSessionEntity.class); + queryClientSessions.setParameter("offline", offlineStr); + queryClientSessions.setParameter("userSessionIds", userSessionIds); + + processClientSessions(sessionsById, removedClientUUIDs, queryClientSessions); + + for (String clientUUID : removedClientUUIDs) { + onClientRemoved(clientUUID); + } + + logger.tracef("Loaded %d batch of user sessions (offline=%s, sessionIds=%s)", batchedUserSessions.size(), offlineStr, sessionsById.keySet()); + + return batchedUserSessions.stream(); + }).map(UserSessionModel.class::cast)); + } else { + List userSessionAdapters = closing(query.getResultStream() + .map(this::toAdapter) + .filter(Objects::nonNull)) + .toList(); + + Map sessionsById = userSessionAdapters.stream() + .collect(Collectors.toMap(UserSessionModel::getId, Function.identity())); + + Set removedClientUUIDs = new HashSet<>(); + + if (!sessionsById.isEmpty()) { + TypedQuery queryClientSessions; String fromUserSessionId = userSessionAdapters.get(0).getId(); String toUserSessionId = userSessionAdapters.get(userSessionAdapters.size() - 1).getId(); @@ -493,28 +517,33 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv queryClientSessions.setParameter("offline", offlineStr); queryClientSessions.setParameter("fromSessionId", fromUserSessionId); queryClientSessions.setParameter("toSessionId", toUserSessionId); + + processClientSessions(sessionsById, removedClientUUIDs, queryClientSessions); } - closing(queryClientSessions.getResultStream()).forEach(clientSession -> { - OfflineUserSessionModel userSession = sessionsById.get(clientSession.getUserSessionId()); - // check if we have a user session for the client session - if (userSession != null) { - boolean added = addClientSessionToAuthenticatedClientSessionsIfPresent(userSession, clientSession); - if (!added) { - // client was removed in the meantime - removedClientUUIDs.add(clientSession.getClientId()); - } + for (String clientUUID : removedClientUUIDs) { + onClientRemoved(clientUUID); + } + + logger.tracef("Loaded %d user sessions (offline=%s, sessionIds=%s)", userSessionAdapters.size(), offlineStr, sessionsById.keySet()); + + return userSessionAdapters.stream().map(UserSessionModel.class::cast); + + } + } + + private void processClientSessions(Map sessionsById, Set removedClientUUIDs, TypedQuery queryClientSessions) { + closing(queryClientSessions.getResultStream()).forEach(clientSession -> { + OfflineUserSessionModel userSession = sessionsById.get(clientSession.getUserSessionId()); + // check if we have a user session for the client session + if (userSession != null) { + boolean added = addClientSessionToAuthenticatedClientSessionsIfPresent(userSession, clientSession); + if (!added) { + // client was removed in the meantime + removedClientUUIDs.add(clientSession.getClientId()); } - }); - } - - for (String clientUUID : removedClientUUIDs) { - onClientRemoved(clientUUID); - } - - logger.tracef("Loaded %d user sessions (offline=%s, sessionIds=%s)", userSessionAdapters.size(), offlineStr, sessionsById.keySet()); - - return userSessionAdapters.stream().map(UserSessionModel.class::cast); + } + }); } private boolean addClientSessionToAuthenticatedClientSessionsIfPresent(OfflineUserSessionModel userSession, PersistentClientSessionEntity clientSessionEntity) {