Clear local caches on split-brain heal

Closes #25837

Signed-off-by: Ryan Emerson <remerson@redhat.com>
This commit is contained in:
Ryan Emerson 2024-07-25 14:54:00 +01:00 committed by Alexander Schwartz
parent 17e30e9ec1
commit 8d7e18ec29
6 changed files with 411 additions and 0 deletions

View file

@ -17,6 +17,7 @@
package org.keycloak.cluster.infinispan;
import java.util.Arrays;
import java.util.Collection;
import java.util.Set;
import java.util.concurrent.ExecutorService;
@ -27,9 +28,12 @@ import java.util.stream.Collectors;
import org.infinispan.Cache;
import org.infinispan.client.hotrod.exceptions.HotRodClientException;
import org.infinispan.configuration.cache.CacheMode;
import org.infinispan.lifecycle.ComponentStatus;
import org.infinispan.notifications.Listener;
import org.infinispan.notifications.cachemanagerlistener.annotation.Merged;
import org.infinispan.notifications.cachemanagerlistener.annotation.ViewChanged;
import org.infinispan.notifications.cachemanagerlistener.event.MergeEvent;
import org.infinispan.notifications.cachemanagerlistener.event.ViewChangedEvent;
import org.infinispan.persistence.remote.RemoteStore;
import org.infinispan.remoting.transport.Address;
@ -195,6 +199,19 @@ public class InfinispanClusterProviderFactory implements ClusterProviderFactory,
@Listener
public class ViewChangeListener {
@Merged
public void mergeEvent(MergeEvent event) {
// During split-brain only Keycloak instances contained within the same partition will receive updates via
// the work cache. On split-brain heal it's necessary for us to clear all local caches so that potentially
// stale values are invalidated and subsequent requests are forced to read from the DB.
localExecutor.execute(() ->
Arrays.stream(InfinispanConnectionProvider.LOCAL_CACHE_NAMES)
.map(name -> workCache.getCacheManager().getCache(name))
.filter(cache -> cache.getCacheConfiguration().clustering().cacheMode() == CacheMode.LOCAL)
.forEach(Cache::clear)
);
}
@ViewChanged
public void viewChanged(ViewChangedEvent event) {
Set<String> removedNodesAddresses = convertAddresses(event.getOldMembers());

View file

@ -1414,6 +1414,12 @@
<version>${io.setl.rdf-urdna.version}</version>
</dependency>
<dependency>
<groupId>org.infinispan</groupId>
<artifactId>infinispan-core</artifactId>
<type>test-jar</type>
<version>${infinispan.version}</version>
</dependency>
<dependency>
<groupId>org.infinispan.protostream</groupId>
<artifactId>protostream</artifactId>

View file

@ -102,6 +102,11 @@
<groupId>org.infinispan</groupId>
<artifactId>infinispan-core</artifactId>
</dependency>
<dependency>
<groupId>org.infinispan</groupId>
<artifactId>infinispan-core</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>

View file

@ -601,6 +601,13 @@ public abstract class KeycloakModelTest {
});
}
protected void withRealmConsumer(String realmId, BiConsumer<KeycloakSession, RealmModel> what) {
withRealm(realmId, (session, realm) -> {
what.accept(session, realm);
return null;
});
}
protected boolean isUseSameKeycloakSessionFactoryForAllThreads() {
return false;
}

View file

@ -0,0 +1,115 @@
package org.keycloak.testsuite.model.infinispan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.WORK_CACHE_NAME;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Assume;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TestRule;
import org.keycloak.common.Profile;
import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.testsuite.model.KeycloakModelTest;
import org.keycloak.testsuite.model.RequireProvider;
/**
Tests to ensure that Keycloak correctly handles various split-brain scenarios when an Embedded Infinispan instance
is used for clustering.
*/
@RequireProvider(CacheRealmProvider.class)
@RequireProvider(InfinispanConnectionProvider.class)
public class EmbeddedInfinispanSplitBrainTest extends KeycloakModelTest {
private String realmId;
@ClassRule
public static final TestRule SKIPPED_PROFILES = (base, description) -> {
// We skip split-brain tests for the REMOTE_CACHE and MULTI_SITE features as neither of these architectures
// utilise embedded distributed/replicated caches
Assume.assumeFalse(Profile.isFeatureEnabled(Profile.Feature.REMOTE_CACHE));
Assume.assumeFalse(Profile.isFeatureEnabled(Profile.Feature.MULTI_SITE));
return base;
};
@Override
public void createEnvironment(KeycloakSession s) {
RealmModel realm = createRealm(s, "test");
realm.setDefaultRole(s.roles().addRealmRole(realm, Constants.DEFAULT_ROLES_ROLE_PREFIX + "-" + realm.getName()));
this.realmId = realm.getId();
s.users().addUser(realm, "user1").setEmail("user1@localhost");
}
/**
* A Test to ensure that when Infinispan recovers from a split-brain event, all Keycloak local caches are cleared
* and subsequent user requests read from the DB.
* <p>
* <a href="https://github.com/keycloak/keycloak/issues/25837" />
*/
@Test
public void testLocalCacheClearedOnMergeEvent() throws InterruptedException {
var numFactories = 2;
var partitionManager = new PartitionManager(numFactories, Set.of(WORK_CACHE_NAME));
var factoryIndex = new AtomicInteger(0);
var addManagerLatch = new CountDownLatch(numFactories);
var splitLatch = new CountDownLatch(1);
var mergeLatch = new CountDownLatch(1);
closeKeycloakSessionFactory();
inIndependentFactories(numFactories, 60, () -> {
var customDisplayName = "custom-value";
var index = factoryIndex.getAndIncrement();
// Init PartitionManager
withRealmConsumer(realmId, (session, realm) -> {
var cm = session.getProvider(InfinispanConnectionProvider.class)
.getCache(InfinispanConnectionProvider.USER_CACHE_NAME)
.getCacheManager();
partitionManager.addManager(index, cm);
addManagerLatch.countDown();
});
awaitLatch(addManagerLatch);
// Split the cluster and update the realm on the 1st partition
if (index == 0) {
partitionManager.splitCluster(new int[]{0}, new int[]{1});
withRealmConsumer(realmId, (session, realm) -> realm.setDisplayNameHtml(customDisplayName));
splitLatch.countDown();
}
awaitLatch(splitLatch);
// Assert that the display name has not been updated on the 2nd partition
if (index == 1) {
withRealmConsumer(realmId, (session, realm) -> assertNotEquals(customDisplayName, realm.getDisplayNameHtml()));
}
// Heal the cluster by merging the two partitions
if (index == 0) {
partitionManager.merge(0, 1);
mergeLatch.countDown();
}
awaitLatch(mergeLatch);
// Ensure that both nodes see the updated realm entity after merge
withRealmConsumer(realmId, (session, realm) -> assertEquals(customDisplayName, realm.getDisplayNameHtml()));
});
}
private void awaitLatch(CountDownLatch latch) {
try {
assertTrue(latch.await(10, TimeUnit.SECONDS));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
}

View file

@ -0,0 +1,261 @@
package org.keycloak.testsuite.model.infinispan;
import static org.infinispan.test.TestingUtil.blockUntilViewsReceived;
import static org.infinispan.test.TestingUtil.waitForNoRebalance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.infinispan.Cache;
import org.infinispan.configuration.global.TransportConfiguration;
import org.infinispan.manager.EmbeddedCacheManager;
import org.infinispan.test.TestingUtil;
import org.jboss.logging.Logger;
import org.jgroups.Address;
import org.jgroups.JChannel;
import org.jgroups.MergeView;
import org.jgroups.View;
import org.jgroups.protocols.DISCARD;
import org.jgroups.protocols.Discovery;
import org.jgroups.protocols.TP;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.STABLE;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;
import org.jgroups.util.MutableDigest;
public class PartitionManager {
private static final Logger log = Logger.getLogger(PartitionManager.class);
final int numberOfCacheManagers;
final Set<String> cacheNames;
final EmbeddedCacheManager[] cacheManagers;
final AtomicInteger viewId;
volatile Partition[] partitions;
public PartitionManager(int numberOfCacheManagers, Set<String> cacheNames) {
this.numberOfCacheManagers = numberOfCacheManagers;
this.cacheNames = cacheNames;
this.cacheManagers = new EmbeddedCacheManager[2];
this.viewId = new AtomicInteger(5);
}
public void addManager(int index, EmbeddedCacheManager cacheManager) {
this.cacheManagers[index] = cacheManager;
}
public void splitCluster(int[]... parts) {
List<Address> allMembers = channel(0).getView().getMembers();
partitions = new Partition[parts.length];
for (int i = 0; i < parts.length; i++) {
Partition p = new Partition(viewId, allMembers, getCaches());
for (int j : parts[i]) {
p.addNode(channel(j));
}
partitions[i] = p;
p.discardOtherMembers();
}
// Only install the new views after installing DISCARD
// Otherwise broadcasts from the first partition would be visible in the other partitions
for (Partition p : partitions) {
p.partition();
}
}
public void merge(int p1, int p2) {
var partition = partitions[p1];
partition.merge(partitions[p2]);
List<Partition> tmp = new ArrayList<>(Arrays.asList(this.partitions));
if (!tmp.remove(partition)) throw new AssertionError();
this.partitions = tmp.toArray(new Partition[0]);
}
private List<Cache<?, ?>> getCaches() {
return cacheNames.stream()
.flatMap(
name -> Arrays.stream(cacheManagers).map(m -> m.getCache(name))
)
.collect(Collectors.toList());
}
private JChannel channel(int index) {
return TestingUtil.extractJChannel(cacheManagers[index]);
}
private static class Partition {
final AtomicInteger viewId;
final List<Address> allMembers;
final List<Cache<?, ?>> caches;
final List<JChannel> channels = new ArrayList<>();
public Partition(AtomicInteger viewId, List<Address> allMembers, List<Cache<?, ?>> caches) {
this.viewId = viewId;
this.allMembers = allMembers;
this.caches = caches;
}
public void addNode(JChannel c) {
channels.add(c);
}
public void partition() {
log.trace("Partition forming");
disableDiscovery();
installNewView();
assertPartitionFormed();
log.trace("New views installed");
}
private void disableDiscovery() {
channels.forEach(c ->
((Discovery) c.getProtocolStack().findProtocol(Discovery.class)).setClusterName(c.getAddressAsString())
);
}
private void assertPartitionFormed() {
final List<Address> viewMembers = new ArrayList<>();
for (JChannel ac : channels) viewMembers.add(ac.getAddress());
for (JChannel c : channels) {
List<Address> members = c.getView().getMembers();
if (!members.equals(viewMembers)) throw new AssertionError();
}
}
private List<Address> installNewView() {
final List<Address> viewMembers = new ArrayList<>();
for (JChannel c : channels) viewMembers.add(c.getAddress());
View view = View.create(channels.get(0).getAddress(), viewId.incrementAndGet(),
viewMembers.toArray(new Address[0]));
log.trace("Before installing new view...");
for (JChannel c : channels) {
getGms(c).installView(view);
}
return viewMembers;
}
private List<Address> installMergeView(ArrayList<JChannel> view1, ArrayList<JChannel> view2) {
List<Address> allAddresses =
Stream.concat(view1.stream(), view2.stream()).map(JChannel::getAddress).distinct()
.collect(Collectors.toList());
View v1 = toView(view1);
View v2 = toView(view2);
List<View> allViews = new ArrayList<>();
allViews.add(v1);
allViews.add(v2);
// Remove all sent NAKACK2 messages to reproduce ISPN-9291
for (JChannel c : channels) {
STABLE stable = c.getProtocolStack().findProtocol(STABLE.class);
stable.gc();
}
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
MergeView mv = new MergeView(view1.get(0).getAddress(), viewId.incrementAndGet(), allAddresses, allViews);
// Compute the merge digest, without it nodes would request the retransmission of all messages
// Including those that were removed by STABLE earlier
MutableDigest digest = new MutableDigest(allAddresses.toArray(new Address[0]));
for (JChannel c : channels) {
digest.merge(getGms(c).getDigest());
}
for (JChannel c : channels) {
getGms(c).installView(mv, digest);
}
return allMembers;
}
private View toView(ArrayList<JChannel> channels) {
final List<Address> viewMembers = new ArrayList<>();
for (JChannel c : channels) viewMembers.add(c.getAddress());
return View.create(channels.get(0).getAddress(), viewId.incrementAndGet(),
viewMembers.toArray(new Address[0]));
}
private void discardOtherMembers() {
List<Address> outsideMembers = new ArrayList<>();
for (Address a : allMembers) {
boolean inThisPartition = false;
for (JChannel c : channels) {
if (c.getAddress().equals(a)) inThisPartition = true;
}
if (!inThisPartition) outsideMembers.add(a);
}
for (JChannel c : channels) {
DISCARD discard = new DISCARD();
log.tracef("%s discarding messages from %s", c.getAddress(), outsideMembers);
for (Address a : outsideMembers) discard.addIgnoreMember(a);
try {
c.getProtocolStack().insertProtocol(discard, ProtocolStack.Position.ABOVE, TP.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
private GMS getGms(JChannel c) {
return c.getProtocolStack().findProtocol(GMS.class);
}
public void merge(Partition partition) {
observeMembers(partition);
partition.observeMembers(this);
ArrayList<JChannel> view1 = new ArrayList<>(channels);
ArrayList<JChannel> view2 = new ArrayList<>(partition.channels);
partition.channels.stream().filter(c -> !channels.contains(c)).forEach(c -> channels.add(c));
installMergeView(view1, view2);
enableDiscovery();
waitForPartitionToForm();
}
private void waitForPartitionToForm() {
var caches = new ArrayList<>(this.caches);
caches.removeIf(c -> !channels.contains(TestingUtil.extractJChannel(c.getCacheManager())));
blockUntilViewsReceived(10000, caches);
waitForNoRebalance(caches);
}
public void enableDiscovery() {
channels.forEach(c -> {
try {
String defaultClusterName = TransportConfiguration.CLUSTER_NAME.getDefaultValue();
((Discovery) c.getProtocolStack().findProtocol(Discovery.class)).setClusterName(defaultClusterName);
} catch (Exception e) {
throw new RuntimeException(e);
}
});
log.trace("Discovery started.");
}
private void observeMembers(Partition partition) {
for (JChannel c : channels) {
List<Protocol> protocols = c.getProtocolStack().getProtocols();
for (Protocol p : protocols) {
if (p instanceof DISCARD) {
for (JChannel oc : partition.channels) {
((DISCARD) p).removeIgnoredMember(oc.getAddress());
}
}
}
}
}
@Override
public String toString() {
StringBuilder addresses = new StringBuilder();
for (JChannel c : channels) addresses.append(c.getAddress()).append(" ");
return "Partition{" + addresses + '}';
}
}
}