Avoid race condition when using initial-access-token
Closes #27294 Signed-off-by: Giuseppe Graziano <g.graziano94@gmail.com>
This commit is contained in:
parent
9300903674
commit
1df60461a9
2 changed files with 61 additions and 8 deletions
|
@ -2376,7 +2376,7 @@ public class RealmAdapter implements StorageProviderRealmModel, JpaModel<RealmEn
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ClientInitialAccessModel getClientInitialAccessModel(String id) {
|
public ClientInitialAccessModel getClientInitialAccessModel(String id) {
|
||||||
ClientInitialAccessEntity entity = em.find(ClientInitialAccessEntity.class, id);
|
ClientInitialAccessEntity entity = em.find(ClientInitialAccessEntity.class, id, LockModeType.PESSIMISTIC_WRITE);
|
||||||
if (entity == null) return null;
|
if (entity == null) return null;
|
||||||
if (!entity.getRealm().getId().equals(realm.getId())) return null;
|
if (!entity.getRealm().getId().equals(realm.getId())) return null;
|
||||||
return entityToModel(entity);
|
return entityToModel(entity);
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.http.impl.client.HttpClientBuilder;
|
||||||
import org.apache.http.util.EntityUtils;
|
import org.apache.http.util.EntityUtils;
|
||||||
import org.hamcrest.CoreMatchers;
|
import org.hamcrest.CoreMatchers;
|
||||||
import org.hamcrest.Matchers;
|
import org.hamcrest.Matchers;
|
||||||
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.keycloak.client.registration.Auth;
|
import org.keycloak.client.registration.Auth;
|
||||||
import org.keycloak.client.registration.ClientRegistration;
|
import org.keycloak.client.registration.ClientRegistration;
|
||||||
|
@ -38,6 +39,8 @@ import org.keycloak.models.utils.KeycloakModelUtils;
|
||||||
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
|
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
|
||||||
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
|
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
|
||||||
import org.keycloak.protocol.saml.SamlProtocol;
|
import org.keycloak.protocol.saml.SamlProtocol;
|
||||||
|
import org.keycloak.representations.idm.ClientInitialAccessCreatePresentation;
|
||||||
|
import org.keycloak.representations.idm.ClientInitialAccessPresentation;
|
||||||
import org.keycloak.representations.idm.ClientRepresentation;
|
import org.keycloak.representations.idm.ClientRepresentation;
|
||||||
import org.keycloak.representations.idm.OAuth2ErrorRepresentation;
|
import org.keycloak.representations.idm.OAuth2ErrorRepresentation;
|
||||||
import org.keycloak.representations.idm.ProtocolMapperRepresentation;
|
import org.keycloak.representations.idm.ProtocolMapperRepresentation;
|
||||||
|
@ -50,16 +53,24 @@ import org.keycloak.util.JsonSerialization;
|
||||||
|
|
||||||
import jakarta.ws.rs.NotFoundException;
|
import jakarta.ws.rs.NotFoundException;
|
||||||
import jakarta.ws.rs.core.Response;
|
import jakarta.ws.rs.core.Response;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
|
@ -798,4 +809,46 @@ public class ClientRegistrationTest extends AbstractClientRegistrationTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void registerMultipleClients() {
|
||||||
|
|
||||||
|
int concurrentThreads = 5;
|
||||||
|
int iterations = 10;
|
||||||
|
int initialTokenCounts = 2;
|
||||||
|
|
||||||
|
ClientInitialAccessCreatePresentation clientInitialAccessCreatePresentation = new ClientInitialAccessCreatePresentation();
|
||||||
|
clientInitialAccessCreatePresentation.setCount(initialTokenCounts);
|
||||||
|
clientInitialAccessCreatePresentation.setExpiration(10000);
|
||||||
|
ClientInitialAccessPresentation response = adminClient.realm(REALM_NAME).clientInitialAccess().create(clientInitialAccessCreatePresentation);
|
||||||
|
|
||||||
|
ExecutorService threadPool = Executors.newFixedThreadPool(concurrentThreads);
|
||||||
|
AtomicInteger createdCount = new AtomicInteger();
|
||||||
|
try {
|
||||||
|
Collection<Callable<Void>> futures = new LinkedList<>();
|
||||||
|
for (int i = 0; i < iterations; i ++) {
|
||||||
|
final int j = i;
|
||||||
|
|
||||||
|
Callable<Void> f = () -> {
|
||||||
|
ClientRegistration client = ClientRegistration.create().url(suiteContext.getAuthServerInfo().getContextRoot() + "/auth", "test").build();
|
||||||
|
client.auth(Auth.token(response));
|
||||||
|
ClientRepresentation rep = new ClientRepresentation();
|
||||||
|
rep.setClientId("test-" + j);
|
||||||
|
rep = client.create(rep);
|
||||||
|
if(rep.getId() != null && rep.getClientId().equals("test-" + j)) {
|
||||||
|
createdCount.getAndIncrement();
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
futures.add(f);
|
||||||
|
}
|
||||||
|
threadPool.invokeAll(futures);
|
||||||
|
|
||||||
|
} catch (Exception ex) {
|
||||||
|
fail(ex.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
//controls the number of uses of the initial access token
|
||||||
|
assertEquals(initialTokenCounts, createdCount.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue