Unexpected invalid_grant error on offline session refresh when client session is not in the cache

Closes #9959

Co-authored-by: Martin Kanis <mkanis@redhat.com>
Co-authored-by: Lex Cao <lexcao@foxmail.com>
This commit is contained in:
Martin Kanis 2023-03-06 17:03:38 +01:00 committed by Michal Hajas
parent ce1e0a65e7
commit 5e7793b64d
7 changed files with 116 additions and 9 deletions

View file

@ -387,13 +387,33 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
@Override @Override
public AuthenticatedClientSessionAdapter getClientSession(UserSessionModel userSession, ClientModel client, String clientSessionId, boolean offline) { public AuthenticatedClientSessionAdapter getClientSession(UserSessionModel userSession, ClientModel client, String clientSessionId, boolean offline) {
return getClientSession(userSession, client, clientSessionId == null ? null : UUID.fromString(clientSessionId), offline); if (clientSessionId == null) {
return null;
}
AuthenticatedClientSessionEntity clientSessionEntityFromCache = getClientSessionEntity(UUID.fromString(clientSessionId), offline);
if (clientSessionEntityFromCache != null) {
return wrap(userSession, client, clientSessionEntityFromCache, offline);
}
// offline client session lookup in the persister
if (offline) {
log.debugf("Offline client session is not found in cache, try to load from db, userSession [%s] clientSessionId [%s] clientId [%s]", userSession.getId(), clientSessionId, client.getClientId());
return getClientSessionEntityFromPersistenceProvider(userSession, client, true);
}
return null;
} }
@Override private AuthenticatedClientSessionAdapter getClientSessionEntityFromPersistenceProvider(UserSessionModel userSession, ClientModel client, boolean offline) {
public AuthenticatedClientSessionAdapter getClientSession(UserSessionModel userSession, ClientModel client, UUID clientSessionId, boolean offline) { UserSessionPersisterProvider persister = session.getProvider(UserSessionPersisterProvider.class);
AuthenticatedClientSessionEntity entity = getClientSessionEntity(clientSessionId, offline); AuthenticatedClientSessionModel clientSession = persister.loadClientSession(session.getContext().getRealm(), client, userSession, offline);
return wrap(userSession, client, entity, offline);
if (clientSession == null) {
return null;
}
return importClientSession((UserSessionAdapter) userSession, clientSession, getTransaction(offline), getClientSessionTransaction(offline), offline);
} }
private AuthenticatedClientSessionEntity getClientSessionEntity(UUID id, boolean offline) { private AuthenticatedClientSessionEntity getClientSessionEntity(UUID id, boolean offline) {

View file

@ -91,7 +91,7 @@ public class UserSessionAdapter implements UserSessionModel {
// Check if client still exists // Check if client still exists
ClientModel client = realm.getClientById(key); ClientModel client = realm.getClientById(key);
if (client != null) { if (client != null) {
final AuthenticatedClientSessionAdapter clientSession = provider.getClientSession(this, client, value, offline); final AuthenticatedClientSessionAdapter clientSession = provider.getClientSession(this, client, value.toString(), offline);
if (clientSession != null) { if (clientSession != null) {
result.put(key, clientSession); result.put(key, clientSession);
} }

View file

@ -382,6 +382,30 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv
return loadUserSessionsWithClientSessions(query, offlineStr, false); return loadUserSessionsWithClientSessions(query, offlineStr, false);
} }
@Override
public AuthenticatedClientSessionModel loadClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, boolean offline) {
TypedQuery<PersistentClientSessionEntity> query;
StorageId clientStorageId = new StorageId(client.getId());
if (clientStorageId.isLocal()) {
query = em.createNamedQuery("findClientSessionsByUserSessionAndClient", PersistentClientSessionEntity.class);
query.setParameter("clientId", client.getId());
} else {
query = em.createNamedQuery("findClientSessionsByUserSessionAndExternalClient", PersistentClientSessionEntity.class);
query.setParameter("clientStorageProvider", clientStorageId.getProviderId());
query.setParameter("externalClientId", clientStorageId.getExternalId());
}
String offlineStr = offlineToString(offline);
query.setParameter("userSessionId", userSession.getId());
query.setParameter("offline", offlineStr);
query.setMaxResults(1);
return closing(query.getResultStream())
.map(entity -> toAdapter(realm, userSession, entity))
.findFirst()
.orElse(null);
}
/** /**
* *
* @param query * @param query
@ -479,17 +503,25 @@ public class JpaUserSessionPersisterProvider implements UserSessionPersisterProv
return new PersistentUserSessionAdapter(session, model, realm, entity.getUserId(), clientSessions); return new PersistentUserSessionAdapter(session, model, realm, entity.getUserId(), clientSessions);
} }
private PersistentAuthenticatedClientSessionAdapter toAdapter(RealmModel realm, PersistentUserSessionAdapter userSession, PersistentClientSessionEntity entity) { private PersistentAuthenticatedClientSessionAdapter toAdapter(RealmModel realm, UserSessionModel userSession, PersistentClientSessionEntity entity) {
String clientId = entity.getClientId(); String clientId = entity.getClientId();
if (isExternalClient(entity)) { if (isExternalClient(entity)) {
clientId = getExternalClientId(entity); clientId = getExternalClientId(entity);
} }
// can be null if client is not found anymore
ClientModel client = realm.getClientById(clientId); ClientModel client = realm.getClientById(clientId);
PersistentClientSessionModel model = new PersistentClientSessionModel(); PersistentClientSessionModel model = new PersistentClientSessionModel();
model.setClientId(clientId); model.setClientId(clientId);
model.setUserSessionId(userSession.getId()); model.setUserSessionId(userSession.getId());
model.setUserId(userSession.getUserId());
UserModel user = userSession.getUser();
if (user != null) {
model.setUserId(user.getId());
}
else if (userSession instanceof PersistentUserSessionAdapter) {
model.setUserId(((PersistentUserSessionAdapter) userSession).getUserId());
}
model.setTimestamp(entity.getTimestamp()); model.setTimestamp(entity.getTimestamp());
model.setData(entity.getData()); model.setData(entity.getData());
return new PersistentAuthenticatedClientSessionAdapter(session, model, realm, client, userSession); return new PersistentAuthenticatedClientSessionAdapter(session, model, realm, client, userSession);

View file

@ -47,7 +47,9 @@ import java.io.Serializable;
@NamedQuery(name="findClientSessionsOrderedByIdInterval", query="select sess from PersistentClientSessionEntity sess where sess.offline = :offline and sess.userSessionId >= :fromSessionId and sess.userSessionId <= :toSessionId order by sess.userSessionId"), @NamedQuery(name="findClientSessionsOrderedByIdInterval", query="select sess from PersistentClientSessionEntity sess where sess.offline = :offline and sess.userSessionId >= :fromSessionId and sess.userSessionId <= :toSessionId order by sess.userSessionId"),
@NamedQuery(name="findClientSessionsOrderedByIdExact", query="select sess from PersistentClientSessionEntity sess where sess.offline = :offline and sess.userSessionId IN (:userSessionIds)"), @NamedQuery(name="findClientSessionsOrderedByIdExact", query="select sess from PersistentClientSessionEntity sess where sess.offline = :offline and sess.userSessionId IN (:userSessionIds)"),
@NamedQuery(name="findClientSessionsCountByClient", query="select count(sess) from PersistentClientSessionEntity sess where sess.offline = :offline and sess.clientId = :clientId"), @NamedQuery(name="findClientSessionsCountByClient", query="select count(sess) from PersistentClientSessionEntity sess where sess.offline = :offline and sess.clientId = :clientId"),
@NamedQuery(name="findClientSessionsCountByExternalClient", query="select count(sess) from PersistentClientSessionEntity sess where sess.offline = :offline and sess.clientStorageProvider = :clientStorageProvider and sess.externalClientId = :externalClientId") @NamedQuery(name="findClientSessionsCountByExternalClient", query="select count(sess) from PersistentClientSessionEntity sess where sess.offline = :offline and sess.clientStorageProvider = :clientStorageProvider and sess.externalClientId = :externalClientId"),
@NamedQuery(name="findClientSessionsByUserSessionAndClient", query="select sess from PersistentClientSessionEntity sess where sess.userSessionId=:userSessionId and sess.offline = :offline and sess.clientId=:clientId"),
@NamedQuery(name="findClientSessionsByUserSessionAndExternalClient", query="select sess from PersistentClientSessionEntity sess where sess.userSessionId=:userSessionId and sess.offline = :offline and sess.clientStorageProvider = :clientStorageProvider and sess.externalClientId = :externalClientId")
}) })
@Table(name="OFFLINE_CLIENT_SESSION") @Table(name="OFFLINE_CLIENT_SESSION")
@Entity @Entity

View file

@ -131,6 +131,11 @@ public class DisabledUserSessionPersisterProvider implements UserSessionPersiste
return Stream.empty(); return Stream.empty();
} }
@Override
public AuthenticatedClientSessionModel loadClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, boolean offline) {
return null;
}
@Override @Override
public int getUserSessionsCount(boolean offline) { public int getUserSessionsCount(boolean offline) {
return 0; return 0;

View file

@ -109,6 +109,16 @@ public interface UserSessionPersisterProvider extends Provider {
Stream<UserSessionModel> loadUserSessionsStream(Integer firstResult, Integer maxResults, boolean offline, Stream<UserSessionModel> loadUserSessionsStream(Integer firstResult, Integer maxResults, boolean offline,
String lastUserSessionId); String lastUserSessionId);
/**
* Loads client session from the db by provided user session and client.
* @param realm RealmModel Realm for the associated client session.
* @param client ClientModel Client used for the creation of client session.
* @param userSession UserSessionModel User session for the associated client session.
* @param offline boolean Flag that indicates the client session should be online/offline.
* @return Client session according the provided criteria or {@code null} if not found.
*/
AuthenticatedClientSessionModel loadClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, boolean offline);
/** /**
* Retrieves the count of user sessions for all realms. * Retrieves the count of user sessions for all realms.
* *

View file

@ -17,6 +17,7 @@
package org.keycloak.testsuite.model.session; package org.keycloak.testsuite.model.session;
import org.hamcrest.Matchers;
import org.infinispan.Cache; import org.infinispan.Cache;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Assume; import org.junit.Assume;
@ -51,6 +52,7 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
@ -373,6 +375,42 @@ public class UserSessionProviderOfflineModelTest extends KeycloakModelTest {
} }
@Test
public void testOfflineClientSessionLoading() {
// create online user and client sessions
inComittedTransaction((Consumer<KeycloakSession>) session -> UserSessionPersisterProviderTest.createSessions(session, realmId));
// create offline user and client sessions
withRealm(realmId, (session, realm) -> {
session.sessions().getUserSessionsStream(realm, realm.getClientByClientId("test-app")).collect(Collectors.toList())
.forEach(userSession -> createOfflineSessionIncludeClientSessions(session, userSession));
return null;
});
List<String> offlineUserSessionIds = withRealm(realmId, (session, realm) -> {
UserModel user = session.users().getUserByUsername(realm, "user1");
List<String> ids = session.sessions().getOfflineUserSessionsStream(realm, user).map(UserSessionModel::getId).collect(Collectors.toList());
Assert.assertThat(ids, Matchers.hasSize(2));
return ids;
});
withRealm(realmId, (session, realm) -> {
// remove offline client sessions from the cache
// this simulates the cases when offline client sessions are lost from the cache due to various reasons (a cache limit/expiration/preloading issue)
session.getProvider(InfinispanConnectionProvider.class).getCache(InfinispanConnectionProvider.OFFLINE_CLIENT_SESSION_CACHE_NAME).clear();
String clientUUID = realm.getClientByClientId("test-app").getId();
offlineUserSessionIds.forEach(id -> {
UserSessionModel offlineUserSession = session.sessions().getOfflineUserSession(realm, id);
// each associated offline client session should be found by looking into persister
Assert.assertNotNull(offlineUserSession.getAuthenticatedClientSessionByClient(clientUUID));
});
return null;
});
}
private static Set<String> createOfflineSessionIncludeClientSessions(KeycloakSession session, UserSessionModel private static Set<String> createOfflineSessionIncludeClientSessions(KeycloakSession session, UserSessionModel
userSession) { userSession) {
Set<String> offlineSessions = new HashSet<>(); Set<String> offlineSessions = new HashSet<>();