Remove realm model storage from OAuth2DeviceConfig class to avoid persisting old session and entity manager in infinispan fixes keycloak/keycloak#23943

This commit is contained in:
Alice Wood 2023-10-12 12:14:28 -04:00 committed by Alexander Schwartz
parent f9386bd62b
commit 5a76ddfc2e
3 changed files with 19 additions and 33 deletions

View file

@ -260,8 +260,8 @@ public class LegacyExportImportManager implements ExportImportManager {
// OAuth 2.0 Device Authorization Grant
OAuth2DeviceConfig deviceConfig = newRealm.getOAuth2DeviceConfig();
deviceConfig.setOAuth2DeviceCodeLifespan(rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(rep.getOAuth2DevicePollingInterval());
deviceConfig.setOAuth2DeviceCodeLifespan(newRealm, rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(newRealm, rep.getOAuth2DevicePollingInterval());
if (rep.getSslRequired() != null)
newRealm.setSslRequired(SslRequired.valueOf(rep.getSslRequired().toUpperCase()));
@ -764,8 +764,8 @@ public class LegacyExportImportManager implements ExportImportManager {
OAuth2DeviceConfig deviceConfig = realm.getOAuth2DeviceConfig();
deviceConfig.setOAuth2DeviceCodeLifespan(rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(rep.getOAuth2DevicePollingInterval());
deviceConfig.setOAuth2DeviceCodeLifespan(realm, rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(realm, rep.getOAuth2DevicePollingInterval());
if (rep.getNotBefore() != null) realm.setNotBefore(rep.getNotBefore());
if (rep.getDefaultSignatureAlgorithm() != null) realm.setDefaultSignatureAlgorithm(rep.getDefaultSignatureAlgorithm());

View file

@ -266,8 +266,8 @@ public class MapExportImportManager implements ExportImportManager {
// OAuth 2.0 Device Authorization Grant
OAuth2DeviceConfig deviceConfig = newRealm.getOAuth2DeviceConfig();
deviceConfig.setOAuth2DeviceCodeLifespan(rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(rep.getOAuth2DevicePollingInterval());
deviceConfig.setOAuth2DeviceCodeLifespan(newRealm, rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(newRealm, rep.getOAuth2DevicePollingInterval());
if (rep.getSslRequired() != null)
newRealm.setSslRequired(SslRequired.valueOf(rep.getSslRequired().toUpperCase()));
@ -1053,8 +1053,8 @@ public class MapExportImportManager implements ExportImportManager {
OAuth2DeviceConfig deviceConfig = realm.getOAuth2DeviceConfig();
deviceConfig.setOAuth2DeviceCodeLifespan(rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(rep.getOAuth2DevicePollingInterval());
deviceConfig.setOAuth2DeviceCodeLifespan(realm, rep.getOAuth2DeviceCodeLifespan());
deviceConfig.setOAuth2DevicePollingInterval(realm, rep.getOAuth2DevicePollingInterval());
if (rep.getNotBefore() != null) realm.setNotBefore(rep.getNotBefore());
if (rep.getDefaultSignatureAlgorithm() != null) realm.setDefaultSignatureAlgorithm(rep.getDefaultSignatureAlgorithm());

View file

@ -41,17 +41,10 @@ public final class OAuth2DeviceConfig implements Serializable {
public static String OAUTH2_DEVICE_POLLING_INTERVAL_PER_CLIENT = "oauth2.device.polling.interval";
public static final String OAUTH2_DEVICE_AUTHORIZATION_GRANT_ENABLED = "oauth2.device.authorization.grant.enabled";
private transient Supplier<RealmModel> realm;
// Make sure setters are not called when calling this from constructor to avoid DB updates
private transient Supplier<RealmModel> realmForWrite;
private int lifespan = DEFAULT_OAUTH2_DEVICE_CODE_LIFESPAN;
private int poolingInterval = DEFAULT_OAUTH2_DEVICE_POLLING_INTERVAL;
public OAuth2DeviceConfig(RealmModel realm) {
this.realm = () -> realm;
String lifespan = realm.getAttribute(OAUTH2_DEVICE_CODE_LIFESPAN);
if (lifespan != null && !lifespan.trim().isEmpty()) {
@ -63,8 +56,6 @@ public final class OAuth2DeviceConfig implements Serializable {
if (pooling != null && !pooling.trim().isEmpty()) {
setOAuth2DevicePollingInterval(Integer.parseInt(pooling));
}
this.realmForWrite = () -> realm;
}
public int getLifespan() {
@ -72,11 +63,15 @@ public final class OAuth2DeviceConfig implements Serializable {
}
public void setOAuth2DeviceCodeLifespan(Integer seconds) {
setOAuth2DeviceCodeLifespan(null, seconds);
}
public void setOAuth2DeviceCodeLifespan(RealmModel realm, Integer seconds) {
if (seconds == null) {
seconds = DEFAULT_OAUTH2_DEVICE_CODE_LIFESPAN;
}
this.lifespan = seconds;
persistRealmAttribute(OAUTH2_DEVICE_CODE_LIFESPAN, lifespan);
persistRealmAttribute(realm, OAUTH2_DEVICE_CODE_LIFESPAN, lifespan);
}
public int getPoolingInterval() {
@ -84,14 +79,16 @@ public final class OAuth2DeviceConfig implements Serializable {
}
public void setOAuth2DevicePollingInterval(Integer seconds) {
setOAuth2DevicePollingInterval(null, seconds);
}
public void setOAuth2DevicePollingInterval(RealmModel realm, Integer seconds) {
if (seconds == null) {
seconds = DEFAULT_OAUTH2_DEVICE_POLLING_INTERVAL;
}
this.poolingInterval = seconds;
RealmModel model = getRealm();
persistRealmAttribute(OAUTH2_DEVICE_POLLING_INTERVAL, poolingInterval);
persistRealmAttribute(realm, OAUTH2_DEVICE_POLLING_INTERVAL, poolingInterval);
}
public int getLifespan(ClientModel client) {
@ -119,18 +116,7 @@ public final class OAuth2DeviceConfig implements Serializable {
return Boolean.parseBoolean(enabled);
}
private RealmModel getRealm() {
RealmModel model = realm.get();
if (model == null) {
throw new RuntimeException("Can only update after invalidating the realm");
}
return model;
}
private void persistRealmAttribute(String name, Integer value) {
RealmModel realm = realmForWrite == null ? null : this.realmForWrite.get();
private void persistRealmAttribute(RealmModel realm, String name, Integer value) {
if (realm != null) {
realm.setAttribute(name, value);
}