Load client sessions in chunks from the database (#32185)

Closes #32180

Signed-off-by: Alexander Schwartz <aschwart@redhat.com>
This commit is contained in:
Alexander Schwartz 2024-08-16 17:00:57 +02:00 committed by GitHub
parent b0dfef0c60
commit 74fec50ac5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -50,6 +50,7 @@ import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jakarta.persistence.LockModeType;
import org.keycloak.utils.StreamsUtil;
import static org.keycloak.models.jpa.PaginationUtils.paginateQuery;
import static org.keycloak.utils.StreamsUtil.closing;
@ -467,25 +468,48 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv
* @return
*/
private Stream<UserSessionModel> loadUserSessionsWithClientSessions(TypedQuery<PersistentUserSessionEntity> query, String offlineStr, boolean useExact) {
List<OfflineUserSessionModel> userSessionAdapters = closing(query.getResultStream()
.map(this::toAdapter)
.filter(Objects::nonNull))
.collect(Collectors.toList());
Map<String, OfflineUserSessionModel> sessionsById = userSessionAdapters.stream()
.collect(Collectors.toMap(UserSessionModel::getId, Function.identity()));
if (useExact) {
Set<String> userSessionIds = sessionsById.keySet();
// Take the results returned by the database in chunks and enrich them.
// The chunking avoids loading all the entries, as the caller usually adds pagination for the frontend.
return closing(StreamsUtil.chunkedStream(closing(query.getResultStream()).map(this::toAdapter).filter(Objects::nonNull), 100)
.flatMap(batchedUserSessions -> {
Set<String> removedClientUUIDs = new HashSet<>();
Set<String> removedClientUUIDs = new HashSet<>();
Map<String, OfflineUserSessionModel> sessionsById = batchedUserSessions.stream()
.collect(Collectors.toMap(UserSessionModel::getId, Function.identity()));
if (!sessionsById.isEmpty()) {
TypedQuery<PersistentClientSessionEntity> queryClientSessions;
if (useExact) {
queryClientSessions = em.createNamedQuery("findClientSessionsOrderedByIdExact", PersistentClientSessionEntity.class);
queryClientSessions.setParameter("offline", offlineStr);
queryClientSessions.setParameter("userSessionIds", userSessionIds);
} else {
Set<String> userSessionIds = sessionsById.keySet();
TypedQuery<PersistentClientSessionEntity> queryClientSessions;
queryClientSessions = em.createNamedQuery("findClientSessionsOrderedByIdExact", PersistentClientSessionEntity.class);
queryClientSessions.setParameter("offline", offlineStr);
queryClientSessions.setParameter("userSessionIds", userSessionIds);
processClientSessions(sessionsById, removedClientUUIDs, queryClientSessions);
for (String clientUUID : removedClientUUIDs) {
onClientRemoved(clientUUID);
}
logger.tracef("Loaded %d batch of user sessions (offline=%s, sessionIds=%s)", batchedUserSessions.size(), offlineStr, sessionsById.keySet());
return batchedUserSessions.stream();
}).map(UserSessionModel.class::cast));
} else {
List<OfflineUserSessionModel> userSessionAdapters = closing(query.getResultStream()
.map(this::toAdapter)
.filter(Objects::nonNull))
.toList();
Map<String, OfflineUserSessionModel> sessionsById = userSessionAdapters.stream()
.collect(Collectors.toMap(UserSessionModel::getId, Function.identity()));
Set<String> removedClientUUIDs = new HashSet<>();
if (!sessionsById.isEmpty()) {
TypedQuery<PersistentClientSessionEntity> queryClientSessions;
String fromUserSessionId = userSessionAdapters.get(0).getId();
String toUserSessionId = userSessionAdapters.get(userSessionAdapters.size() - 1).getId();
@ -493,28 +517,33 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv
queryClientSessions.setParameter("offline", offlineStr);
queryClientSessions.setParameter("fromSessionId", fromUserSessionId);
queryClientSessions.setParameter("toSessionId", toUserSessionId);
processClientSessions(sessionsById, removedClientUUIDs, queryClientSessions);
}
closing(queryClientSessions.getResultStream()).forEach(clientSession -> {
OfflineUserSessionModel userSession = sessionsById.get(clientSession.getUserSessionId());
// check if we have a user session for the client session
if (userSession != null) {
boolean added = addClientSessionToAuthenticatedClientSessionsIfPresent(userSession, clientSession);
if (!added) {
// client was removed in the meantime
removedClientUUIDs.add(clientSession.getClientId());
}
for (String clientUUID : removedClientUUIDs) {
onClientRemoved(clientUUID);
}
logger.tracef("Loaded %d user sessions (offline=%s, sessionIds=%s)", userSessionAdapters.size(), offlineStr, sessionsById.keySet());
return userSessionAdapters.stream().map(UserSessionModel.class::cast);
}
}
private void processClientSessions(Map<String, OfflineUserSessionModel> sessionsById, Set<String> removedClientUUIDs, TypedQuery<PersistentClientSessionEntity> queryClientSessions) {
closing(queryClientSessions.getResultStream()).forEach(clientSession -> {
OfflineUserSessionModel userSession = sessionsById.get(clientSession.getUserSessionId());
// check if we have a user session for the client session
if (userSession != null) {
boolean added = addClientSessionToAuthenticatedClientSessionsIfPresent(userSession, clientSession);
if (!added) {
// client was removed in the meantime
removedClientUUIDs.add(clientSession.getClientId());
}
});
}
for (String clientUUID : removedClientUUIDs) {
onClientRemoved(clientUUID);
}
logger.tracef("Loaded %d user sessions (offline=%s, sessionIds=%s)", userSessionAdapters.size(), offlineStr, sessionsById.keySet());
return userSessionAdapters.stream().map(UserSessionModel.class::cast);
}
});
}
private boolean addClientSessionToAuthenticatedClientSessionsIfPresent(OfflineUserSessionModel userSession, PersistentClientSessionEntity clientSessionEntity) {