Merge pull request #2320 from mposolda/master

KEYCLOAK-2523 Fix concurrency tests with all databases by track trans…
This commit is contained in:
Marek Posolda 2016-03-03 12:34:29 +01:00
commit 002074bb30
6 changed files with 115 additions and 81 deletions

View file

@ -305,7 +305,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
if (invalidations.contains(id)) return model; if (invalidations.contains(id)) return model;
cached = new CachedRealm(loaded, model); cached = new CachedRealm(loaded, model);
cache.addRevisioned(cached); cache.addRevisioned(cached, session);
} else if (invalidations.contains(id)) { } else if (invalidations.contains(id)) {
return getDelegate().getRealm(id); return getDelegate().getRealm(id);
} else if (managedRealms.containsKey(id)) { } else if (managedRealms.containsKey(id)) {
@ -329,7 +329,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
if (invalidations.contains(model.getId())) return model; if (invalidations.contains(model.getId())) return model;
query = new RealmListQuery(loaded, cacheKey, model.getId()); query = new RealmListQuery(loaded, cacheKey, model.getId());
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} else if (invalidations.contains(cacheKey)) { } else if (invalidations.contains(cacheKey)) {
return getDelegate().getRealmByName(name); return getDelegate().getRealmByName(name);
@ -435,7 +435,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
for (ClientModel client : model) ids.add(client.getId()); for (ClientModel client : model) ids.add(client.getId());
query = new ClientListQuery(loaded, cacheKey, realm, ids); query = new ClientListQuery(loaded, cacheKey, realm, ids);
logger.tracev("adding realm clients cache miss: realm {0} key {1}", realm.getName(), cacheKey); logger.tracev("adding realm clients cache miss: realm {0} key {1}", realm.getName(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
List<ClientModel> list = new LinkedList<>(); List<ClientModel> list = new LinkedList<>();
@ -508,7 +508,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
for (RoleModel role : model) ids.add(role.getId()); for (RoleModel role : model) ids.add(role.getId());
query = new RoleListQuery(loaded, cacheKey, realm, ids); query = new RoleListQuery(loaded, cacheKey, realm, ids);
logger.tracev("adding realm roles cache miss: realm {0} key {1}", realm.getName(), cacheKey); logger.tracev("adding realm roles cache miss: realm {0} key {1}", realm.getName(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
Set<RoleModel> list = new HashSet<>(); Set<RoleModel> list = new HashSet<>();
@ -544,7 +544,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
for (RoleModel role : model) ids.add(role.getId()); for (RoleModel role : model) ids.add(role.getId());
query = new RoleListQuery(loaded, cacheKey, realm, ids, client.getClientId()); query = new RoleListQuery(loaded, cacheKey, realm, ids, client.getClientId());
logger.tracev("adding client roles cache miss: client {0} key {1}", client.getClientId(), cacheKey); logger.tracev("adding client roles cache miss: client {0} key {1}", client.getClientId(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
Set<RoleModel> list = new HashSet<>(); Set<RoleModel> list = new HashSet<>();
@ -593,7 +593,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
query = new RoleListQuery(loaded, cacheKey, realm, model.getId()); query = new RoleListQuery(loaded, cacheKey, realm, model.getId());
logger.tracev("adding realm role cache miss: client {0} key {1}", realm.getName(), cacheKey); logger.tracev("adding realm role cache miss: client {0} key {1}", realm.getName(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
RoleModel role = getRoleById(query.getRoles().iterator().next(), realm); RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@ -623,7 +623,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
query = new RoleListQuery(loaded, cacheKey, realm, model.getId(), client.getClientId()); query = new RoleListQuery(loaded, cacheKey, realm, model.getId(), client.getClientId());
logger.tracev("adding client role cache miss: client {0} key {1}", client.getClientId(), cacheKey); logger.tracev("adding client role cache miss: client {0} key {1}", client.getClientId(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
RoleModel role = getRoleById(query.getRoles().iterator().next(), realm); RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@ -660,7 +660,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
} else { } else {
cached = new CachedRealmRole(loaded, model, realm); cached = new CachedRealmRole(loaded, model, realm);
} }
cache.addRevisioned(cached); cache.addRevisioned(cached, session);
} else if (invalidations.contains(id)) { } else if (invalidations.contains(id)) {
return getDelegate().getRoleById(id, realm); return getDelegate().getRoleById(id, realm);
@ -685,7 +685,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
if (invalidations.contains(id)) return model; if (invalidations.contains(id)) return model;
cached = new CachedGroup(loaded, realm, model); cached = new CachedGroup(loaded, realm, model);
cache.addRevisioned(cached); cache.addRevisioned(cached, session);
} else if (invalidations.contains(id)) { } else if (invalidations.contains(id)) {
return getDelegate().getGroupById(id, realm); return getDelegate().getGroupById(id, realm);
@ -725,7 +725,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
for (GroupModel client : model) ids.add(client.getId()); for (GroupModel client : model) ids.add(client.getId());
query = new GroupListQuery(loaded, cacheKey, realm, ids); query = new GroupListQuery(loaded, cacheKey, realm, ids);
logger.tracev("adding realm getGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey); logger.tracev("adding realm getGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
List<GroupModel> list = new LinkedList<>(); List<GroupModel> list = new LinkedList<>();
@ -761,7 +761,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
for (GroupModel client : model) ids.add(client.getId()); for (GroupModel client : model) ids.add(client.getId());
query = new GroupListQuery(loaded, cacheKey, realm, ids); query = new GroupListQuery(loaded, cacheKey, realm, ids);
logger.tracev("adding realm getTopLevelGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey); logger.tracev("adding realm getTopLevelGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
cache.addRevisioned(query); cache.addRevisioned(query, session);
return model; return model;
} }
List<GroupModel> list = new LinkedList<>(); List<GroupModel> list = new LinkedList<>();
@ -837,7 +837,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (invalidations.contains(id)) return model; if (invalidations.contains(id)) return model;
cached = new CachedClient(loaded, realm, model); cached = new CachedClient(loaded, realm, model);
logger.tracev("adding client by id cache miss: {0}", cached.getClientId()); logger.tracev("adding client by id cache miss: {0}", cached.getClientId());
cache.addRevisioned(cached); cache.addRevisioned(cached, session);
} else if (invalidations.contains(id)) { } else if (invalidations.contains(id)) {
return getDelegate().getClientById(id, realm); return getDelegate().getClientById(id, realm);
} else if (managedApplications.containsKey(id)) { } else if (managedApplications.containsKey(id)) {
@ -866,7 +866,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
id = model.getId(); id = model.getId();
query = new ClientListQuery(loaded, cacheKey, realm, id); query = new ClientListQuery(loaded, cacheKey, realm, id);
logger.tracev("adding client by name cache miss: {0}", clientId); logger.tracev("adding client by name cache miss: {0}", clientId);
cache.addRevisioned(query); cache.addRevisioned(query, session);
} else if (invalidations.contains(cacheKey)) { } else if (invalidations.contains(cacheKey)) {
return getDelegate().getClientByClientId(clientId, realm); return getDelegate().getClientByClientId(clientId, realm);
} else { } else {
@ -895,7 +895,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
if (model == null) return null; if (model == null) return null;
if (invalidations.contains(id)) return model; if (invalidations.contains(id)) return model;
cached = new CachedClientTemplate(loaded, realm, model); cached = new CachedClientTemplate(loaded, realm, model);
cache.addRevisioned(cached); cache.addRevisioned(cached, session);
} else if (invalidations.contains(id)) { } else if (invalidations.contains(id)) {
return getDelegate().getClientTemplateById(id, realm); return getDelegate().getClientTemplateById(id, realm);
} else if (managedClientTemplates.containsKey(id)) { } else if (managedClientTemplates.containsKey(id)) {

View file

@ -24,6 +24,7 @@ import org.infinispan.notifications.cachelistener.annotation.CacheEntryInvalidat
import org.infinispan.notifications.cachelistener.event.CacheEntriesEvictedEvent; import org.infinispan.notifications.cachelistener.event.CacheEntriesEvictedEvent;
import org.infinispan.notifications.cachelistener.event.CacheEntryInvalidatedEvent; import org.infinispan.notifications.cachelistener.event.CacheEntryInvalidatedEvent;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.cache.infinispan.entities.AbstractRevisioned; import org.keycloak.models.cache.infinispan.entities.AbstractRevisioned;
import org.keycloak.models.cache.infinispan.entities.CachedClient; import org.keycloak.models.cache.infinispan.entities.CachedClient;
import org.keycloak.models.cache.infinispan.entities.CachedClientTemplate; import org.keycloak.models.cache.infinispan.entities.CachedClientTemplate;
@ -38,7 +39,7 @@ import org.keycloak.models.cache.infinispan.stream.HasRolePredicate;
import org.keycloak.models.cache.infinispan.stream.InClientPredicate; import org.keycloak.models.cache.infinispan.stream.InClientPredicate;
import org.keycloak.models.cache.infinispan.stream.InRealmPredicate; import org.keycloak.models.cache.infinispan.stream.InRealmPredicate;
import org.keycloak.models.cache.infinispan.stream.RealmQueryPredicate; import org.keycloak.models.cache.infinispan.stream.RealmQueryPredicate;
import org.keycloak.models.cache.infinispan.stream.RoleQueryPredicate; import org.keycloak.models.utils.UpdateCounter;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
@ -73,7 +74,9 @@ public class StreamRealmCache {
public Long getCurrentRevision(String id) { public Long getCurrentRevision(String id) {
Long revision = revisions.get(id); Long revision = revisions.get(id);
if (revision == null) revision = UpdateCounter.current(); if (revision == null) {
revision = UpdateCounter.current();
}
// if you do cache.remove() on node 1 and the entry doesn't exist on node 2, node 2 never receives a invalidation event // if you do cache.remove() on node 1 and the entry doesn't exist on node 2, node 2 never receives a invalidation event
// so, we do this to force this. // so, we do this to force this.
String invalidationKey = "invalidation.key" + id; String invalidationKey = "invalidation.key" + id;
@ -121,7 +124,7 @@ public class StreamRealmCache {
Object rev = revisions.put(id, next); Object rev = revisions.put(id, next);
} }
public void addRevisioned(Revisioned object) { public void addRevisioned(Revisioned object, KeycloakSession session) {
//startRevisionBatch(); //startRevisionBatch();
String id = object.getId(); String id = object.getId();
try { try {
@ -135,12 +138,19 @@ public class StreamRealmCache {
revisions.startBatch(); revisions.startBatch();
if (!revisions.getAdvancedCache().lock(id)) { if (!revisions.getAdvancedCache().lock(id)) {
logger.trace("Could not obtain version lock"); logger.trace("Could not obtain version lock");
return;
} }
rev = revisions.get(id); rev = revisions.get(id);
if (rev == null) { if (rev == null) {
if (id.endsWith("realm.clients")) logger.trace("addRevisioned rev2 == null realm.clients"); if (id.endsWith("realm.clients")) logger.trace("addRevisioned rev2 == null realm.clients");
return; return;
} }
if (rev > session.getTransaction().getStartupRevision()) { // revision is ahead transaction start. Other transaction updated in the meantime. Don't cache
if (logger.isTraceEnabled()) {
logger.tracev("Skipped cache. Current revision {0}, Transaction start revision {1}", object.getRevision(), session.getTransaction().getStartupRevision());
}
return;
}
if (rev.equals(object.getRevision())) { if (rev.equals(object.getRevision())) {
if (id.endsWith("realm.clients")) logger.tracev("adding Object.revision {0} rev {1}", object.getRevision(), rev); if (id.endsWith("realm.clients")) logger.tracev("adding Object.revision {0} rev {1}", object.getRevision(), rev);
cache.putForExternalRead(id, object); cache.putForExternalRead(id, object);

View file

@ -23,6 +23,8 @@ package org.keycloak.models;
*/ */
public interface KeycloakTransactionManager extends KeycloakTransaction { public interface KeycloakTransactionManager extends KeycloakTransaction {
long getStartupRevision();
void enlist(KeycloakTransaction transaction); void enlist(KeycloakTransaction transaction);
void enlistAfterCompletion(KeycloakTransaction transaction); void enlistAfterCompletion(KeycloakTransaction transaction);

View file

@ -1,8 +1,10 @@
package org.keycloak.models.cache.infinispan; package org.keycloak.models.utils;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/** /**
* Used to track cache revisions
*
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/ */
public class UpdateCounter { public class UpdateCounter {

View file

@ -18,6 +18,7 @@ package org.keycloak.services;
import org.keycloak.models.KeycloakTransaction; import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.KeycloakTransactionManager; import org.keycloak.models.KeycloakTransactionManager;
import org.keycloak.models.utils.UpdateCounter;
import org.keycloak.services.ServicesLogger; import org.keycloak.services.ServicesLogger;
import java.util.LinkedList; import java.util.LinkedList;
@ -35,6 +36,12 @@ public class DefaultKeycloakTransactionManager implements KeycloakTransactionMan
private List<KeycloakTransaction> afterCompletion = new LinkedList<KeycloakTransaction>(); private List<KeycloakTransaction> afterCompletion = new LinkedList<KeycloakTransaction>();
private boolean active; private boolean active;
private boolean rollback; private boolean rollback;
private long startupRevision;
@Override
public long getStartupRevision() {
return startupRevision;
}
@Override @Override
public void enlist(KeycloakTransaction transaction) { public void enlist(KeycloakTransaction transaction) {
@ -69,6 +76,8 @@ public class DefaultKeycloakTransactionManager implements KeycloakTransactionMan
throw new IllegalStateException("Transaction already active"); throw new IllegalStateException("Transaction already active");
} }
startupRevision = UpdateCounter.current();
for (KeycloakTransaction tx : transactions) { for (KeycloakTransaction tx : transactions) {
tx.begin(); tx.begin();
} }

View file

@ -464,72 +464,83 @@ public abstract class AbstractKeycloakIdentityProviderTest extends AbstractIdent
setUpdateProfileFirstLogin(IdentityProviderRepresentation.UPFLM_ON); setUpdateProfileFirstLogin(IdentityProviderRepresentation.UPFLM_ON);
IdentityProviderModel identityProviderModel = getIdentityProviderModel(); IdentityProviderModel identityProviderModel = getIdentityProviderModel();
identityProviderModel.setStoreToken(true); setStoreToken(identityProviderModel, true);
try {
authenticateWithIdentityProvider(identityProviderModel, "test-user", true);
authenticateWithIdentityProvider(identityProviderModel, "test-user", true); brokerServerRule.stopSession(session, true);
session = brokerServerRule.startSession();
brokerServerRule.stopSession(session, true); UserModel federatedUser = getFederatedUser();
RealmModel realm = getRealm();
Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
assertFalse(federatedIdentities.isEmpty());
assertEquals(1, federatedIdentities.size());
FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
assertNotNull(identityModel.getToken());
UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
String accessToken = userSessionStatus.getAccessTokenString();
URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
final String authHeader = "Bearer " + accessToken;
ClientRequestFilter authFilter = new ClientRequestFilter() {
@Override
public void filter(ClientRequestContext requestContext) throws IOException {
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
}
};
Client client = ClientBuilder.newBuilder().register(authFilter).build();
WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
Response response = tokenEndpoint.request().get();
assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
assertNotNull(response.readEntity(String.class));
revokeGrant();
driver.navigate().to("http://localhost:8081/test-app/logout");
String currentUrl = this.driver.getCurrentUrl();
System.out.println("after logout currentUrl: " + currentUrl);
assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
unconfigureUserRetrieveToken("test-user");
loginIDP("test-user");
//authenticateWithIdentityProvider(identityProviderModel, "test-user");
assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
userSessionStatus = retrieveSessionStatus();
accessToken = userSessionStatus.getAccessTokenString();
final String authHeader2 = "Bearer " + accessToken;
ClientRequestFilter authFilter2 = new ClientRequestFilter() {
@Override
public void filter(ClientRequestContext requestContext) throws IOException {
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
}
};
client = ClientBuilder.newBuilder().register(authFilter2).build();
tokenEndpoint = client.target(tokenEndpointUrl);
response = tokenEndpoint.request().get();
assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
revokeGrant();
driver.navigate().to("http://localhost:8081/test-app/logout");
driver.navigate().to("http://localhost:8081/test-app");
assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
} finally {
setStoreToken(identityProviderModel, false);
}
}
private void setStoreToken(IdentityProviderModel identityProviderModel, boolean storeToken) {
identityProviderModel.setStoreToken(storeToken);
getRealm().updateIdentityProvider(identityProviderModel);
brokerServerRule.stopSession(session, storeToken);
session = brokerServerRule.startSession(); session = brokerServerRule.startSession();
UserModel federatedUser = getFederatedUser();
RealmModel realm = getRealm();
Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
assertFalse(federatedIdentities.isEmpty());
assertEquals(1, federatedIdentities.size());
FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
assertNotNull(identityModel.getToken());
UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
String accessToken = userSessionStatus.getAccessTokenString();
URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
final String authHeader = "Bearer " + accessToken;
ClientRequestFilter authFilter = new ClientRequestFilter() {
@Override
public void filter(ClientRequestContext requestContext) throws IOException {
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
}
};
Client client = ClientBuilder.newBuilder().register(authFilter).build();
WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
Response response = tokenEndpoint.request().get();
assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
assertNotNull(response.readEntity(String.class));
revokeGrant();
driver.navigate().to("http://localhost:8081/test-app/logout");
String currentUrl = this.driver.getCurrentUrl();
System.out.println("after logout currentUrl: " + currentUrl);
assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
unconfigureUserRetrieveToken("test-user");
loginIDP("test-user");
//authenticateWithIdentityProvider(identityProviderModel, "test-user");
assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
userSessionStatus = retrieveSessionStatus();
accessToken = userSessionStatus.getAccessTokenString();
final String authHeader2 = "Bearer " + accessToken;
ClientRequestFilter authFilter2 = new ClientRequestFilter() {
@Override
public void filter(ClientRequestContext requestContext) throws IOException {
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
}
};
client = ClientBuilder.newBuilder().register(authFilter2).build();
tokenEndpoint = client.target(tokenEndpointUrl);
response = tokenEndpoint.request().get();
assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
revokeGrant();
driver.navigate().to("http://localhost:8081/test-app/logout");
driver.navigate().to("http://localhost:8081/test-app");
assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
} }
protected abstract void doAssertTokenRetrieval(String pageSource); protected abstract void doAssertTokenRetrieval(String pageSource);