KEYCLOAK-10112: Issues in loading offline session in a cluster environment during startup

This commit is contained in:
rmartinc 2019-04-23 18:08:41 +02:00 committed by Marek Posolda
parent 53d0db80c3
commit bd5dec1830
19 changed files with 496 additions and 196 deletions

View file

@ -66,7 +66,7 @@ public class DBLockBasedCacheInitializer extends CacheInitializer {
DBLockManager dbLockManager = new DBLockManager(session);
dbLockManager.checkForcedUnlock();
DBLockProvider dbLock = dbLockManager.getDBLock();
dbLock.waitForLock();
dbLock.waitForLock(DBLockProvider.Namespace.OFFLINE_SESSIONS);
try {
if (isFinished()) {

View file

@ -119,17 +119,24 @@ public class InfinispanCacheInitializer extends BaseCacheInitializer {
ExecutorService executorService = distributed ? new DefaultExecutorService(workCache, localExecutor) : localExecutor;
int errors = 0;
int segmentToLoad = 0;
try {
List<SessionLoader.WorkerResult> previousResults = new LinkedList<>();
SessionLoader.WorkerResult previousResult = null;
SessionLoader.WorkerResult nextResult = null;
int distributedWorkersCount = 0;
boolean firstTryForSegment = true;
while (!state.isFinished()) {
while (segmentToLoad < state.getSegmentsCount()) {
if (firstTryForSegment) {
// do not change the node count if it's not the first try
int nodesCount = transport==null ? 1 : transport.getMembers().size();
int distributedWorkersCount = processors * nodesCount;
distributedWorkersCount = processors * nodesCount;
}
log.debugf("Starting next iteration with %d workers", distributedWorkersCount);
List<Integer> segments = state.getUnfinishedSegments(distributedWorkersCount);
List<Integer> segments = state.getSegmentsToLoad(segmentToLoad, distributedWorkersCount);
if (log.isTraceEnabled()) {
log.trace("unfinished segments for this iteration: " + segments);
@ -137,9 +144,8 @@ public class InfinispanCacheInitializer extends BaseCacheInitializer {
List<Future<SessionLoader.WorkerResult>> futures = new LinkedList<>();
int workerId = 0;
for (Integer segment : segments) {
SessionLoader.WorkerContext workerCtx = sessionLoader.computeWorkerContext(loaderCtx, segment, workerId, previousResults);
SessionLoader.WorkerContext workerCtx = sessionLoader.computeWorkerContext(loaderCtx, segment, segment - segmentToLoad, previousResult);
SessionInitializerWorker worker = new SessionInitializerWorker();
worker.setWorkerEnvironment(loaderCtx, workerCtx, sessionLoader);
@ -150,17 +156,19 @@ public class InfinispanCacheInitializer extends BaseCacheInitializer {
Future<SessionLoader.WorkerResult> future = executorService.submit(worker);
futures.add(future);
workerId++;
}
boolean anyFailure = false;
for (Future<SessionLoader.WorkerResult> future : futures) {
try {
SessionLoader.WorkerResult result = future.get();
previousResults.add(result);
if (!result.isSuccess()) {
if (result.isSuccess()) {
state.markSegmentFinished(result.getSegment());
if (result.getSegment() == segmentToLoad + distributedWorkersCount - 1) {
// last result for next iteration when complete
nextResult = result;
}
} else {
if (log.isTraceEnabled()) {
log.tracef("Segment %d failed to compute", result.getSegment());
}
@ -181,14 +189,19 @@ public class InfinispanCacheInitializer extends BaseCacheInitializer {
throw new RuntimeException("Maximum count of worker errors occured. Limit was " + maxErrors + ". See server.log for details");
}
// Save just if no error happened. Otherwise re-compute
if (!anyFailure) {
for (SessionLoader.WorkerResult result : previousResults) {
state.markSegmentFinished(result.getSegment());
}
// everything is OK, prepare the new row
segmentToLoad += distributedWorkersCount;
firstTryForSegment = true;
previousResult = nextResult;
nextResult = null;
if (log.isTraceEnabled()) {
log.debugf("New initializer state is: %s", state);
}
} else {
// some segments failed, try to load unloaded segments
firstTryForSegment = false;
}
}
// Push the state after computation is finished

View file

@ -44,15 +44,12 @@ public class InitializerState extends SessionEntity {
private final int segmentsCount;
private final BitSet segments;
private int lowestUnfinishedSegment = 0;
public InitializerState(int segmentsCount) {
this.segmentsCount = segmentsCount;
this.segments = new BitSet(segmentsCount);
log.debugf("segmentsCount: %d", segmentsCount);
updateLowestUnfinishedSegment();
}
private InitializerState(String realmId, int segmentsCount, BitSet segments) {
@ -61,8 +58,14 @@ public class InitializerState extends SessionEntity {
this.segments = segments;
log.debugf("segmentsCount: %d", segmentsCount);
}
updateLowestUnfinishedSegment();
/**
* Getter for the segments count.
* @return The number of segments of the state
*/
public int getSegmentsCount() {
return segmentsCount;
}
/** Return true just if computation is entirely finished (all segments are true) */
@ -70,39 +73,23 @@ public class InitializerState extends SessionEntity {
return segments.cardinality() == segmentsCount;
}
/** Return next un-finished segments. It returns at most {@code maxSegmentCount} segments. */
public List<Integer> getUnfinishedSegments(int maxSegmentCount) {
/** Return next un-finished segments in the next row of segments.
* @param segmentToLoad The segment we are loading
* @param maxSegmentCount The max segment to load
* @return The list of segments to work on this step
*/
public List<Integer> getSegmentsToLoad(int segmentToLoad, int maxSegmentCount) {
List<Integer> result = new LinkedList<>();
int next = lowestUnfinishedSegment;
boolean remaining = lowestUnfinishedSegment != -1;
while (remaining && result.size() < maxSegmentCount) {
next = getNextUnfinishedSegmentFromIndex(next);
if (next == -1) {
remaining = false;
} else {
result.add(next);
next++;
for (int i = segmentToLoad; i < (segmentToLoad + maxSegmentCount) && i < segmentsCount; i++) {
if (!segments.get(i)) {
result.add(i);
}
}
return result;
}
public void markSegmentFinished(int index) {
segments.set(index);
updateLowestUnfinishedSegment();
}
private void updateLowestUnfinishedSegment() {
this.lowestUnfinishedSegment = getNextUnfinishedSegmentFromIndex(lowestUnfinishedSegment);
}
private int getNextUnfinishedSegmentFromIndex(int index) {
final int nextFreeSegment = this.segments.nextClearBit(index);
return (nextFreeSegment < this.segmentsCount)
? nextFreeSegment
: -1;
}
@Override
@ -119,7 +106,6 @@ public class InitializerState extends SessionEntity {
int hash = 3;
hash = 97 * hash + this.segmentsCount;
hash = 97 * hash + Objects.hashCode(this.segments);
hash = 97 * hash + this.lowestUnfinishedSegment;
return hash;
}
@ -138,9 +124,6 @@ public class InitializerState extends SessionEntity {
if (this.segmentsCount != other.segmentsCount) {
return false;
}
if (this.lowestUnfinishedSegment != other.lowestUnfinishedSegment) {
return false;
}
if ( ! Objects.equals(this.segments, other.segments)) {
return false;
}

View file

@ -70,16 +70,15 @@ public class OfflinePersistentUserSessionLoader implements SessionLoader<Offline
@Override
public OfflinePersistentWorkerContext computeWorkerContext(OfflinePersistentLoaderContext loaderCtx, int segment, int workerId, List<OfflinePersistentWorkerResult> previousResults) {
public OfflinePersistentWorkerContext computeWorkerContext(OfflinePersistentLoaderContext loaderCtx, int segment, int workerId, OfflinePersistentWorkerResult previousResult) {
int lastCreatedOn;
String lastSessionId;
if (previousResults.isEmpty()) {
if (previousResult == null) {
lastCreatedOn = 0;
lastSessionId = FIRST_SESSION_ID;
} else {
OfflinePersistentWorkerResult lastResult = previousResults.get(previousResults.size() - 1);
lastCreatedOn = lastResult.getLastCreatedOn();
lastSessionId = lastResult.getLastSessionId();
lastCreatedOn = previousResult.getLastCreatedOn();
lastSessionId = previousResult.getLastSessionId();
}
// We know the last loaded session. New workers iteration will start from this place
@ -97,12 +96,12 @@ public class OfflinePersistentUserSessionLoader implements SessionLoader<Offline
public OfflinePersistentWorkerResult loadSessions(KeycloakSession session, OfflinePersistentLoaderContext loaderContext, OfflinePersistentWorkerContext ctx) {
int first = ctx.getWorkerId() * sessionsPerSegment;
log.tracef("Loading sessions for segment: %d", ctx.getSegment());
log.tracef("Loading sessions for segment=%d createdOn=%d lastSessionId=%s", ctx.getSegment(), ctx.getLastCreatedOn(), ctx.getLastSessionId());
UserSessionPersisterProvider persister = session.getProvider(UserSessionPersisterProvider.class);
List<UserSessionModel> sessions = persister.loadUserSessions(first, sessionsPerSegment, true, ctx.getLastCreatedOn(), ctx.getLastSessionId());
log.tracef("Sessions loaded from DB - segment: %d", ctx.getSegment());
log.tracef("Sessions loaded from DB - segment=%d createdOn=%d lastSessionId=%s", ctx.getSegment(), ctx.getLastCreatedOn(), ctx.getLastSessionId());
UserSessionModel lastSession = null;
if (!sessions.isEmpty()) {

View file

@ -58,10 +58,10 @@ public interface SessionLoader<LOADER_CONTEXT extends SessionLoader.LoaderContex
* @param loaderCtx global loader context
* @param segment the current segment (page) to compute
* @param workerId ID of worker for current worker iteration. Usually the number 0-8 (with single cluster node)
* @param previousResults workerResults from previous computation. Can be empty list in case of the operation is triggered for the 1st time
* @param previousResult last workerResult from previous computation. Can be empty list in case of the operation is triggered for the 1st time
* @return
*/
WORKER_CONTEXT computeWorkerContext(LOADER_CONTEXT loaderCtx, int segment, int workerId, List<WORKER_RESULT> previousResults);
WORKER_CONTEXT computeWorkerContext(LOADER_CONTEXT loaderCtx, int segment, int workerId, WORKER_RESULT previousResult);
/**

View file

@ -94,7 +94,7 @@ public class RemoteCacheSessionsLoader implements SessionLoader<RemoteCacheSessi
@Override
public WorkerContext computeWorkerContext(RemoteCacheSessionsLoaderContext loaderCtx, int segment, int workerId, List<WorkerResult> previousResults) {
public WorkerContext computeWorkerContext(RemoteCacheSessionsLoaderContext loaderCtx, int segment, int workerId, WorkerResult previousResult) {
return new WorkerContext(segment, workerId);
}

View file

@ -83,22 +83,22 @@ public class InitializerStateTest {
InitializerState state = new InitializerState(ctx.getSegmentsCount());
Assert.assertFalse(state.isFinished());
List<Integer> segments = state.getUnfinishedSegments(3);
List<Integer> segments = state.getSegmentsToLoad(0, 3);
assertContains(segments, 3, 0, 1, 2);
state.markSegmentFinished(1);
state.markSegmentFinished(2);
segments = state.getUnfinishedSegments(4);
assertContains(segments, 4, 0, 3, 4, 5);
segments = state.getSegmentsToLoad(0, 3);
assertContains(segments, 1, 0);
state.markSegmentFinished(0);
state.markSegmentFinished(3);
segments = state.getUnfinishedSegments(4);
segments = state.getSegmentsToLoad(4, 4);
assertContains(segments, 2, 4, 5);
state.markSegmentFinished(4);
state.markSegmentFinished(5);
segments = state.getUnfinishedSegments(4);
segments = state.getSegmentsToLoad(4, 4);
Assert.assertTrue(segments.isEmpty());
Assert.assertTrue(state.isFinished());
}

View file

@ -326,16 +326,12 @@ public class DefaultJpaConnectionProviderFactory implements JpaConnectionProvide
}
protected void update(Connection connection, String schema, KeycloakSession session, JpaUpdaterProvider updater) {
DBLockProvider dbLock = new DBLockManager(session).getDBLock();
if (dbLock.hasLock()) {
updater.update(connection, schema);
} else {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), new KeycloakSessionTask() {
@Override
public void run(KeycloakSession lockSession) {
DBLockManager dbLockManager = new DBLockManager(lockSession);
DBLockProvider dbLock2 = dbLockManager.getDBLock();
dbLock2.waitForLock();
dbLock2.waitForLock(DBLockProvider.Namespace.DATABASE);
try {
updater.update(connection, schema);
} finally {
@ -344,19 +340,14 @@ public class DefaultJpaConnectionProviderFactory implements JpaConnectionProvide
}
});
}
}
protected void export(Connection connection, String schema, File databaseUpdateFile, KeycloakSession session, JpaUpdaterProvider updater) {
DBLockProvider dbLock = new DBLockManager(session).getDBLock();
if (dbLock.hasLock()) {
updater.export(connection, schema, databaseUpdateFile);
} else {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), new KeycloakSessionTask() {
@Override
public void run(KeycloakSession lockSession) {
DBLockManager dbLockManager = new DBLockManager(lockSession);
DBLockProvider dbLock2 = dbLockManager.getDBLock();
dbLock2.waitForLock();
dbLock2.waitForLock(DBLockProvider.Namespace.DATABASE);
try {
updater.export(connection, schema, databaseUpdateFile);
} finally {
@ -365,7 +356,6 @@ public class DefaultJpaConnectionProviderFactory implements JpaConnectionProvide
}
});
}
}
@Override
public Connection getConnection() {

View file

@ -0,0 +1,38 @@
/*
* Copyright 2019 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.connections.jpa.updater.liquibase.lock;
import java.util.Set;
import liquibase.statement.core.InitializeDatabaseChangeLogLockTableStatement;
/**
*
* @author rmartinc
*/
public class CustomInitializeDatabaseChangeLogLockTableStatement extends InitializeDatabaseChangeLogLockTableStatement {
private final Set<Integer> currentIds;
public CustomInitializeDatabaseChangeLogLockTableStatement(Set<Integer> currentIds) {
this.currentIds = currentIds;
}
public Set<Integer> getCurrentIds() {
return currentIds;
}
}

View file

@ -25,7 +25,13 @@ import liquibase.sqlgenerator.core.AbstractSqlGenerator;
import liquibase.statement.core.InitializeDatabaseChangeLogLockTableStatement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import liquibase.sqlgenerator.SqlGeneratorFactory;
import liquibase.statement.core.InsertStatement;
import org.keycloak.models.dblock.DBLockProvider;
/**
* We need to remove DELETE SQL command, which liquibase adds by default when inserting record to table lock. This is causing buggy behaviour
@ -46,15 +52,20 @@ public class CustomInsertLockRecordGenerator extends AbstractSqlGenerator<Initia
@Override
public Sql[] generateSql(InitializeDatabaseChangeLogLockTableStatement statement, Database database, SqlGeneratorChain sqlGeneratorChain) {
// Generated by InitializeDatabaseChangeLogLockTableGenerator
Sql[] sqls = sqlGeneratorChain.generateSql(statement, database);
// get the IDs that are already in the database if migration
Set<Integer> currentIds = new HashSet<>();
if (statement instanceof CustomInitializeDatabaseChangeLogLockTableStatement) {
currentIds = ((CustomInitializeDatabaseChangeLogLockTableStatement) statement).getCurrentIds();
}
// Removing delete statement
// generate all the IDs that are currently missing in the lock table
List<Sql> result = new ArrayList<>();
for (Sql sql : sqls) {
String sqlCommand = sql.toSql();
if (!sqlCommand.toUpperCase().contains("DELETE")) {
result.add(sql);
for (DBLockProvider.Namespace lock : DBLockProvider.Namespace.values()) {
if (!currentIds.contains(lock.getId())) {
InsertStatement insertStatement = new InsertStatement(database.getLiquibaseCatalogName(), database.getLiquibaseSchemaName(), database.getDatabaseChangeLogLockTableName())
.addColumnValue("ID", lock.getId())
.addColumnValue("LOCKED", Boolean.FALSE);
result.addAll(Arrays.asList(SqlGeneratorFactory.getInstance().generateSql(insertStatement, database)));
}
}

View file

@ -48,13 +48,15 @@ public class CustomLockDatabaseChangeLogGenerator extends LockDatabaseChangeLogG
@Override
public Sql[] generateSql(LockDatabaseChangeLogStatement statement, Database database, SqlGeneratorChain sqlGeneratorChain) {
Sql selectForUpdateSql = generateSelectForUpdate(database);
Sql selectForUpdateSql = generateSelectForUpdate(database,
(statement instanceof CustomLockDatabaseChangeLogStatement)?
((CustomLockDatabaseChangeLogStatement) statement).getId() : 1);
return new Sql[] { selectForUpdateSql };
}
private Sql generateSelectForUpdate(Database database) {
private Sql generateSelectForUpdate(Database database, int id) {
String catalog = database.getLiquibaseCatalogName();
String schema = database.getLiquibaseSchemaName();
String rawLockTableName = database.getDatabaseChangeLogLockTableName();
@ -63,7 +65,7 @@ public class CustomLockDatabaseChangeLogGenerator extends LockDatabaseChangeLogG
String idColumnName = database.escapeColumnName(catalog, schema, rawLockTableName, "ID");
String sqlBase = "SELECT " + idColumnName + " FROM " + lockTableName;
String sqlWhere = " WHERE " + idColumnName + "=1";
String sqlWhere = " WHERE " + idColumnName + "=" + id;
String sql;
if (database instanceof MySQLDatabase || database instanceof PostgresDatabase || database instanceof H2Database ||

View file

@ -0,0 +1,38 @@
/*
* Copyright 2019 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.connections.jpa.updater.liquibase.lock;
import liquibase.statement.core.LockDatabaseChangeLogStatement;
/**
*
* @author rmartinc
*/
public class CustomLockDatabaseChangeLogStatement extends LockDatabaseChangeLogStatement {
final private int id;
public CustomLockDatabaseChangeLogStatement(int id) {
this.id = id;
}
public int getId() {
return id;
}
}

View file

@ -31,8 +31,15 @@ import liquibase.statement.core.RawSqlStatement;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Time;
import org.keycloak.common.util.reflections.Reflections;
import org.keycloak.models.dblock.DBLockProvider;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import liquibase.statement.SqlStatement;
/**
* Liquibase lock service, which has some bugfixes and assumes timeouts to be configured in milliseconds
@ -45,7 +52,6 @@ public class CustomLockService extends StandardLockService {
@Override
public void init() throws DatabaseException {
boolean createdTable = false;
Executor executor = ExecutorService.getInstance().getExecutor(database);
if (!hasDatabaseChangeLogLockTable()) {
@ -74,17 +80,15 @@ public class CustomLockService extends StandardLockService {
} catch (IllegalAccessException iae) {
throw new RuntimeException(iae);
}
createdTable = true;
}
try {
if (!isDatabaseChangeLogLockTableInitialized(createdTable)) {
Set<Integer> currentIds = currentIdsInDatabaseChangeLogLockTable();
if (!currentIds.containsAll(Arrays.asList(DBLockProvider.Namespace.values()))) {
if (log.isTraceEnabled()) {
log.trace("Initialize Database Lock Table");
log.tracef("Initialize Database Lock Table, current locks %s", currentIds);
}
executor.execute(new InitializeDatabaseChangeLogLockTableStatement());
executor.execute(new CustomInitializeDatabaseChangeLogLockTableStatement(currentIds));
database.commit();
log.debug("Initialized record in the database lock table");
@ -113,6 +117,32 @@ public class CustomLockService extends StandardLockService {
}
private Set<Integer> currentIdsInDatabaseChangeLogLockTable() throws DatabaseException {
try {
Executor executor = ExecutorService.getInstance().getExecutor(database);
String idColumnName = database.escapeColumnName(database.getLiquibaseCatalogName(),
database.getLiquibaseSchemaName(),
database.getDatabaseChangeLogLockTableName(),
"ID");
String lockTableName = database.escapeTableName(database.getLiquibaseCatalogName(),
database.getLiquibaseSchemaName(),
database.getDatabaseChangeLogLockTableName());
SqlStatement sqlStatement = new RawSqlStatement("SELECT " + idColumnName + " FROM " + lockTableName);
List<Map<String, ?>> rows = executor.queryForList(sqlStatement);
Set<Integer> ids = rows.stream().map(columnMap -> ((Number) columnMap.get("ID")).intValue()).collect(Collectors.toSet());
database.commit();
return ids;
} catch (UnexpectedLiquibaseException ulie) {
// It can happen with MariaDB Galera 10.1 that UnexpectedLiquibaseException is rethrown due the DB lock.
// It is sufficient to just rollback transaction and retry in that case.
if (ulie.getCause() != null && ulie.getCause() instanceof DatabaseException) {
throw (DatabaseException) ulie.getCause();
} else {
throw ulie;
}
}
}
@Override
public boolean isDatabaseChangeLogLockTableInitialized(boolean tableJustCreated) throws DatabaseException {
try {
@ -129,13 +159,21 @@ public class CustomLockService extends StandardLockService {
@Override
public void waitForLock() {
waitForLock(new LockDatabaseChangeLogStatement());
}
public void waitForLock(DBLockProvider.Namespace lock) {
waitForLock(new CustomLockDatabaseChangeLogStatement(lock.getId()));
}
private void waitForLock(LockDatabaseChangeLogStatement lockStmt) {
boolean locked = false;
long startTime = Time.toMillis(Time.currentTime());
long timeToGiveUp = startTime + (getChangeLogLockWaitTime());
boolean nextAttempt = true;
while (nextAttempt) {
locked = acquireLock();
locked = acquireLock(lockStmt);
if (!locked) {
int remainingTime = ((int)(timeToGiveUp / 1000)) - Time.currentTime();
if (remainingTime > 0) {
@ -156,6 +194,10 @@ public class CustomLockService extends StandardLockService {
@Override
public boolean acquireLock() {
return acquireLock(new LockDatabaseChangeLogStatement());
}
private boolean acquireLock(LockDatabaseChangeLogStatement lockStmt) {
if (hasChangeLogLock) {
// We already have a lock
return true;
@ -174,7 +216,7 @@ public class CustomLockService extends StandardLockService {
try {
log.debug("Trying to lock database");
executor.execute(new LockDatabaseChangeLogStatement());
executor.execute(lockStmt);
log.debug("Successfully acquired database lock");
hasChangeLogLock = true;

View file

@ -49,6 +49,7 @@ public class LiquibaseDBLockProvider implements DBLockProvider {
private CustomLockService lockService;
private Connection dbConnection;
private boolean initialized = false;
private Namespace namespaceLocked = null;
public LiquibaseDBLockProvider(LiquibaseDBLockProviderFactory factory, KeycloakSession session) {
this.factory = factory;
@ -88,17 +89,26 @@ public class LiquibaseDBLockProvider implements DBLockProvider {
lazyInit();
}
@Override
public void waitForLock() {
public void waitForLock(Namespace lock) {
KeycloakModelUtils.suspendJtaTransaction(session.getKeycloakSessionFactory(), () -> {
lazyInit();
if (this.lockService.hasChangeLogLock()) {
if (lock.equals(this.namespaceLocked)) {
logger.warnf("Locking namespace %s which was already locked in this provider", lock);
return;
} else {
throw new RuntimeException(String.format("Trying to get a lock when one was already taken by the provider"));
}
}
logger.debugf("Going to lock namespace=%s", lock);
Retry.executeWithBackoff((int iteration) -> {
lockService.waitForLock();
factory.setHasLock(true);
lockService.waitForLock(lock);
namespaceLocked = lock;
}, (int iteration, Throwable e) -> {
@ -116,21 +126,21 @@ public class LiquibaseDBLockProvider implements DBLockProvider {
}
@Override
public void releaseLock() {
KeycloakModelUtils.suspendJtaTransaction(session.getKeycloakSessionFactory(), () -> {
lazyInit();
logger.debugf("Going to release database lock namespace=%s", namespaceLocked);
namespaceLocked = null;
lockService.releaseLock();
lockService.reset();
factory.setHasLock(false);
});
}
@Override
public boolean hasLock() {
return factory.hasLock();
public Namespace getCurrentLock() {
return this.namespaceLocked;
}
@Override

View file

@ -24,8 +24,6 @@ import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.dblock.DBLockProviderFactory;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
@ -35,9 +33,6 @@ public class LiquibaseDBLockProviderFactory implements DBLockProviderFactory {
private long lockWaitTimeoutMillis;
// True if this node has a lock acquired
private AtomicBoolean hasLock = new AtomicBoolean(false);
protected long getLockWaitTimeoutMillis() {
return lockWaitTimeoutMillis;
}
@ -73,12 +68,4 @@ public class LiquibaseDBLockProviderFactory implements DBLockProviderFactory {
public String getId() {
return "jpa";
}
public boolean hasLock() {
return hasLock.get();
}
public void setHasLock(boolean hasLock) {
this.hasLock.set(hasLock);
}
}

View file

@ -20,39 +20,69 @@ package org.keycloak.models.dblock;
import org.keycloak.provider.Provider;
/**
* Global database lock to ensure that some actions in DB can be done just be one cluster node at a time.
* <p>Global database lock to ensure that some actions in DB can be done just be
* one cluster node at a time.</p>
*
* <p>There are different namespaces that can be locked. The same <em>DBLockProvider</em>
* (same session in keycloack) can only be used to lock one namespace, a second
* attempt will throw a <em>RuntimeException</em>. The <em>hasLock</em> method
* returns the local namespace locked by this provider.</p>
*
* <p>Different <em>DBLockProvider</em> instances can be used to lock in
* different threads. Note that the <em>DBLockProvider</em> is associated to
* the session (so in order to have different lock providers different sessions
* are needed).</p>
*
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public interface DBLockProvider extends Provider {
/**
* Try to retrieve DB lock or wait if retrieve was unsuccessful. Throw exception if lock can't be retrieved within specified timeout (900 seconds by default)
* Lock namespace to have different lock types or contexts.
*/
void waitForLock();
public enum Namespace {
DATABASE(1),
KEYCLOAK_BOOT(1000),
OFFLINE_SESSIONS(1001);
private final int id;
private Namespace(int id) {
this.id = id;
}
public int getId() {
return id;
}
};
/**
* Release previously acquired lock
* Try to retrieve DB lock or wait if retrieve was unsuccessful.
* Throw exception if lock can't be retrieved within specified timeout (900 seconds by default)
* Throw exception if a different namespace has already been locked by this provider.
*
* @param lock The namespace to lock
*/
void waitForLock(Namespace lock);
/**
* Release previously acquired lock by this provider.
*/
void releaseLock();
/**
* Check if I have lock
* Returns the current provider namespace locked or null
*
* @return
* @return The namespace locked or null if there is no lock
*/
boolean hasLock();
Namespace getCurrentLock();
/**
* @return true if provider supports forced unlock at startup
*/
boolean supportsForcedUnlock();
/**
* Will destroy whole state of DB lock (drop table/collection to track locking).
* */

View file

@ -141,7 +141,7 @@ public class KeycloakApplication extends Application {
DBLockManager dbLockManager = new DBLockManager(lockSession);
dbLockManager.checkForcedUnlock();
DBLockProvider dbLock = dbLockManager.getDBLock();
dbLock.waitForLock();
dbLock.waitForLock(DBLockProvider.Namespace.KEYCLOAK_BOOT);
try {
exportImportManager[0] = migrateAndBootstrap();
} finally {

View file

@ -27,7 +27,6 @@ import org.junit.Test;
import org.keycloak.admin.client.resource.UserResource;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.models.dblock.DBLockManager;
import org.keycloak.models.dblock.DBLockProvider;
import org.keycloak.models.dblock.DBLockProviderFactory;
@ -61,7 +60,10 @@ public class DBLockTest extends AbstractTestRealmKeycloakTest {
private static final int SLEEP_TIME_MILLIS = 10;
private static final int THREADS_COUNT = 20;
private static final int THREADS_COUNT_MEDIUM = 12;
private static final int ITERATIONS_PER_THREAD = 2;
private static final int ITERATIONS_PER_THREAD_MEDIUM = 4;
private static final int ITERATIONS_PER_THREAD_LONG = 20;
private static final int LOCK_TIMEOUT_MILLIS = 240000; // Rather bigger to handle slow DB connections in testing env
private static final int LOCK_RECHECK_MILLIS = 10;
@ -83,9 +85,152 @@ public class DBLockTest extends AbstractTestRealmKeycloakTest {
@Test
@ModelTest
public void testLockConcurrently(KeycloakSession session) throws Exception {
public void simpleLockTest(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
DBLockProvider dbLock = new DBLockManager(sessionLC).getDBLock();
dbLock.waitForLock(DBLockProvider.Namespace.DATABASE);
try {
Assert.assertEquals(DBLockProvider.Namespace.DATABASE, dbLock.getCurrentLock());
} finally {
dbLock.releaseLock();
}
Assert.assertNull(dbLock.getCurrentLock());
});
}
@Test
@ModelTest
public void simpleNestedLockTest(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
// first session lock DATABASE
DBLockProvider dbLock1 = new DBLockManager(sessionLC).getDBLock();
dbLock1.waitForLock(DBLockProvider.Namespace.DATABASE);
try {
Assert.assertEquals(DBLockProvider.Namespace.DATABASE, dbLock1.getCurrentLock());
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC2) -> {
// a second session/dblock-provider can lock another namespace OFFLINE_SESSIONS
DBLockProvider dbLock2 = new DBLockManager(sessionLC2).getDBLock();
dbLock2.waitForLock(DBLockProvider.Namespace.OFFLINE_SESSIONS);
try {
// getCurrentLock is local, each provider instance has one
Assert.assertEquals(DBLockProvider.Namespace.OFFLINE_SESSIONS, dbLock2.getCurrentLock());
} finally {
dbLock2.releaseLock();
}
Assert.assertNull(dbLock2.getCurrentLock());
});
} finally {
dbLock1.releaseLock();
}
Assert.assertNull(dbLock1.getCurrentLock());
});
}
@Test
@ModelTest
public void testLockConcurrentlyGeneral(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
testLockConcurrentlyInternal(sessionLC, DBLockProvider.Namespace.DATABASE);
});
}
@Test
@ModelTest
public void testLockConcurrentlyOffline(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
testLockConcurrentlyInternal(sessionLC, DBLockProvider.Namespace.OFFLINE_SESSIONS);
});
}
@Test
@ModelTest
public void testTwoLocksCurrently(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
testTwoLocksCurrentlyInternal(sessionLC, DBLockProvider.Namespace.DATABASE, DBLockProvider.Namespace.OFFLINE_SESSIONS);
});
}
@Test
@ModelTest
public void testTwoNestedLocksCurrently(KeycloakSession session) throws Exception {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(), (KeycloakSession sessionLC) -> {
testTwoNestedLocksCurrentlyInternal(sessionLC, DBLockProvider.Namespace.KEYCLOAK_BOOT, DBLockProvider.Namespace.DATABASE);
});
}
private void testTwoLocksCurrentlyInternal(KeycloakSession sessionLC, DBLockProvider.Namespace lock1, DBLockProvider.Namespace lock2) {
final Semaphore semaphore = new Semaphore();
final KeycloakSessionFactory sessionFactory = sessionLC.getKeycloakSessionFactory();
List<Thread> threads = new LinkedList<>();
// launch two threads and expect an error because the locks are different
for (int i = 0; i < 2; i++) {
final DBLockProvider.Namespace lock = (i % 2 == 0)? lock1 : lock2;
Thread thread = new Thread(() -> {
for (int j = 0; j < ITERATIONS_PER_THREAD_LONG; j++) {
try {
KeycloakModelUtils.runJobInTransaction(sessionFactory, session1 -> lock(session1, lock, semaphore));
} catch (RuntimeException e) {
semaphore.setException(e);
}
}
});
threads.add(thread);
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
// interference is needed because different namespaces can interfere
Assert.assertNotNull(semaphore.getException());
}
private void testTwoNestedLocksCurrentlyInternal(KeycloakSession sessionLC, DBLockProvider.Namespace lockTop, DBLockProvider.Namespace lockInner) {
final Semaphore semaphore = new Semaphore();
final KeycloakSessionFactory sessionFactory = sessionLC.getKeycloakSessionFactory();
List<Thread> threads = new LinkedList<>();
// launch two threads and expect an error because the locks are different
for (int i = 0; i < THREADS_COUNT_MEDIUM; i++) {
final boolean nested = i % 2 == 0;
Thread thread = new Thread(() -> {
for (int j = 0; j < ITERATIONS_PER_THREAD_MEDIUM; j++) {
try {
if (nested) {
// half the threads run two level lock top-inner
KeycloakModelUtils.runJobInTransaction(sessionFactory,
session1 -> nestedTwoLevelLock(session1, lockTop, lockInner, semaphore));
} else {
// the other half only run a lock in the top namespace
KeycloakModelUtils.runJobInTransaction(sessionFactory,
session1 -> lock(session1, lockTop, semaphore));
}
} catch (RuntimeException e) {
semaphore.setException(e);
}
}
});
threads.add(thread);
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Assert.assertEquals(THREADS_COUNT_MEDIUM * ITERATIONS_PER_THREAD_MEDIUM, semaphore.getTotal());
Assert.assertNull(semaphore.getException());
}
private void testLockConcurrentlyInternal(KeycloakSession sessionLC, DBLockProvider.Namespace lock) {
long startupTime = System.currentTimeMillis();
final Semaphore semaphore = new Semaphore();
@ -98,7 +243,7 @@ public class DBLockTest extends AbstractTestRealmKeycloakTest {
for (int j = 0; j < ITERATIONS_PER_THREAD; j++) {
try {
KeycloakModelUtils.runJobInTransaction(sessionFactory, session1 ->
lock(session1, semaphore));
lock(session1, lock, semaphore));
} catch (RuntimeException e) {
semaphore.setException(e);
throw e;
@ -123,14 +268,13 @@ public class DBLockTest extends AbstractTestRealmKeycloakTest {
long took = (System.currentTimeMillis() - startupTime);
log.infof("DBLockTest executed in %d ms with total counter %d. THREADS_COUNT=%d, ITERATIONS_PER_THREAD=%d", took, semaphore.getTotal(), THREADS_COUNT, ITERATIONS_PER_THREAD);
Assert.assertEquals(semaphore.getTotal(), THREADS_COUNT * ITERATIONS_PER_THREAD);
Assert.assertEquals(THREADS_COUNT * ITERATIONS_PER_THREAD, semaphore.getTotal());
Assert.assertNull(semaphore.getException());
});
}
private void lock(KeycloakSession session, Semaphore semaphore) {
private void lock(KeycloakSession session, DBLockProvider.Namespace lock, Semaphore semaphore) {
DBLockProvider dbLock = new DBLockManager(session).getDBLock();
dbLock.waitForLock();
dbLock.waitForLock(lock);
try {
semaphore.increase();
Thread.sleep(SLEEP_TIME_MILLIS);
@ -142,6 +286,19 @@ public class DBLockTest extends AbstractTestRealmKeycloakTest {
}
}
private void nestedTwoLevelLock(KeycloakSession session, DBLockProvider.Namespace lockTop,
DBLockProvider.Namespace lockInner, Semaphore semaphore) {
DBLockProvider dbLock = new DBLockManager(session).getDBLock();
dbLock.waitForLock(lockTop);
try {
// create a new session to call the lock method with the inner namespace
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(),
sessionInner -> lock(sessionInner, lockInner, semaphore));
} finally {
dbLock.releaseLock();
}
}
@Override
public void configureTestRealm(RealmRepresentation testRealm) {
}

View file

@ -643,7 +643,7 @@
<auth.server.undertow>false</auth.server.undertow>
<auth.server.config.property.value>standalone.xml</auth.server.config.property.value>
<auth.server.config.dir>${auth.server.home}/standalone/configuration</auth.server.config.dir>
<h2.version>1.3.173</h2.version>
<h2.version>1.4.193</h2.version>
<surefire.memory.Xmx>1024m</surefire.memory.Xmx>
</properties>
<dependencies>
@ -668,7 +668,7 @@
<auth.server.undertow>false</auth.server.undertow>
<auth.server.config.property.value>standalone.xml</auth.server.config.property.value>
<auth.server.config.dir>${auth.server.home}/standalone/configuration</auth.server.config.dir>
<h2.version>1.3.173</h2.version>
<h2.version>1.4.193</h2.version>
<surefire.memory.Xmx>1024m</surefire.memory.Xmx>
</properties>
<dependencies>