Avoid race condition when using initial-access-token

Closes #27294

Signed-off-by: Giuseppe Graziano <g.graziano94@gmail.com>
This commit is contained in:
Giuseppe Graziano 2024-06-18 14:18:09 +02:00 committed by Marek Posolda
parent 9300903674
commit 1df60461a9
2 changed files with 61 additions and 8 deletions

View file

@ -2376,7 +2376,7 @@ public class RealmAdapter implements StorageProviderRealmModel, JpaModel<RealmEn
@Override
public ClientInitialAccessModel getClientInitialAccessModel(String id) {
ClientInitialAccessEntity entity = em.find(ClientInitialAccessEntity.class, id);
ClientInitialAccessEntity entity = em.find(ClientInitialAccessEntity.class, id, LockModeType.PESSIMISTIC_WRITE);
if (entity == null) return null;
if (!entity.getRealm().getId().equals(realm.getId())) return null;
return entityToModel(entity);

View file

@ -27,6 +27,7 @@ import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.client.registration.Auth;
import org.keycloak.client.registration.ClientRegistration;
@ -38,6 +39,8 @@ import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
import org.keycloak.protocol.saml.SamlProtocol;
import org.keycloak.representations.idm.ClientInitialAccessCreatePresentation;
import org.keycloak.representations.idm.ClientInitialAccessPresentation;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.representations.idm.OAuth2ErrorRepresentation;
import org.keycloak.representations.idm.ProtocolMapperRepresentation;
@ -50,16 +53,24 @@ import org.keycloak.util.JsonSerialization;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
@ -798,4 +809,46 @@ public class ClientRegistrationTest extends AbstractClientRegistrationTest {
}
}
}
@Test
public void registerMultipleClients() {
int concurrentThreads = 5;
int iterations = 10;
int initialTokenCounts = 2;
ClientInitialAccessCreatePresentation clientInitialAccessCreatePresentation = new ClientInitialAccessCreatePresentation();
clientInitialAccessCreatePresentation.setCount(initialTokenCounts);
clientInitialAccessCreatePresentation.setExpiration(10000);
ClientInitialAccessPresentation response = adminClient.realm(REALM_NAME).clientInitialAccess().create(clientInitialAccessCreatePresentation);
ExecutorService threadPool = Executors.newFixedThreadPool(concurrentThreads);
AtomicInteger createdCount = new AtomicInteger();
try {
Collection<Callable<Void>> futures = new LinkedList<>();
for (int i = 0; i < iterations; i ++) {
final int j = i;
Callable<Void> f = () -> {
ClientRegistration client = ClientRegistration.create().url(suiteContext.getAuthServerInfo().getContextRoot() + "/auth", "test").build();
client.auth(Auth.token(response));
ClientRepresentation rep = new ClientRepresentation();
rep.setClientId("test-" + j);
rep = client.create(rep);
if(rep.getId() != null && rep.getClientId().equals("test-" + j)) {
createdCount.getAndIncrement();
}
return null;
};
futures.add(f);
}
threadPool.invokeAll(futures);
} catch (Exception ex) {
fail(ex.getMessage());
}
//controls the number of uses of the initial access token
assertEquals(initialTokenCounts, createdCount.get());
}
}