diff --git a/server-spi/src/main/java/org/keycloak/models/OAuth2DeviceCodeModel.java b/server-spi/src/main/java/org/keycloak/models/OAuth2DeviceCodeModel.java index 0d532a1f74..979220e61f 100755 --- a/server-spi/src/main/java/org/keycloak/models/OAuth2DeviceCodeModel.java +++ b/server-spi/src/main/java/org/keycloak/models/OAuth2DeviceCodeModel.java @@ -20,6 +20,7 @@ import org.keycloak.common.util.Time; import javax.ws.rs.core.MultivaluedHashMap; import javax.ws.rs.core.MultivaluedMap; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -35,6 +36,7 @@ public class OAuth2DeviceCodeModel { private static final String SCOPE_NOTE = "scope"; private static final String USER_SESSION_ID_NOTE = "uid"; private static final String DENIED_NOTE = "denied"; + private static final String ADDITIONAL_PARAM_PREFIX = "additional_param_"; private final RealmModel realm; private final String clientId; @@ -45,26 +47,27 @@ public class OAuth2DeviceCodeModel { private final String nonce; private final String userSessionId; private final boolean denied; + private final Map additionalParams; public static OAuth2DeviceCodeModel create(RealmModel realm, ClientModel client, - String deviceCode, String scope, String nonce) { + String deviceCode, String scope, String nonce, Map additionalParams) { int expiresIn = realm.getOAuth2DeviceCodeLifespan(); int expiration = Time.currentTime() + expiresIn; int pollingInterval = realm.getOAuth2DevicePollingInterval(); - return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false); + return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false, additionalParams); } public OAuth2DeviceCodeModel approve(String userSessionId) { - return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false); + return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false, additionalParams); } public OAuth2DeviceCodeModel deny() { - return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true); + return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true, additionalParams); } private OAuth2DeviceCodeModel(RealmModel realm, String clientId, String deviceCode, String scope, String nonce, int expiration, int pollingInterval, - String userSessionId, boolean denied) { + String userSessionId, boolean denied, Map additionalParams) { this.realm = realm; this.clientId = clientId; this.deviceCode = deviceCode; @@ -74,6 +77,7 @@ public class OAuth2DeviceCodeModel { this.pollingInterval = pollingInterval; this.userSessionId = userSessionId; this.denied = denied; + this.additionalParams = additionalParams; } public static OAuth2DeviceCodeModel fromCache(RealmModel realm, String deviceCode, Map data) { @@ -81,15 +85,19 @@ public class OAuth2DeviceCodeModel { } private OAuth2DeviceCodeModel(RealmModel realm, String deviceCode, Map data) { - this.realm = realm; - this.clientId = data.get(CLIENT_ID); - this.deviceCode = deviceCode; - this.scope = data.get(SCOPE_NOTE); - this.nonce = data.get(NONCE_NOTE); - this.expiration = Integer.parseInt(data.get(EXPIRATION_NOTE)); - this.pollingInterval = Integer.parseInt(data.get(POLLING_INTERVAL_NOTE)); - this.userSessionId = data.get(USER_SESSION_ID_NOTE); - this.denied = Boolean.parseBoolean(data.get(DENIED_NOTE)); + this(realm, data.get(CLIENT_ID), deviceCode, data.get(SCOPE_NOTE), data.get(NONCE_NOTE), + Integer.parseInt(data.get(EXPIRATION_NOTE)), Integer.parseInt(data.get(POLLING_INTERVAL_NOTE)), data.get(USER_SESSION_ID_NOTE), + Boolean.parseBoolean(data.get(DENIED_NOTE)), extractAdditionalParams(data)); + } + + private static Map extractAdditionalParams(Map data) { + Map additionalParams = new HashMap<>(); + for (Map.Entry entry : data.entrySet()) { + if (entry.getKey().startsWith(ADDITIONAL_PARAM_PREFIX)) { + additionalParams.put(entry.getKey().substring(ADDITIONAL_PARAM_PREFIX.length()), entry.getValue()); + } + } + return additionalParams; } public String getDeviceCode() { @@ -151,6 +159,7 @@ public class OAuth2DeviceCodeModel { result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval)); result.put(SCOPE_NOTE, scope); result.put(NONCE_NOTE, nonce); + additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value)); return result; } @@ -161,6 +170,7 @@ public class OAuth2DeviceCodeModel { result.put(SCOPE_NOTE, scope); result.put(NONCE_NOTE, nonce); result.put(USER_SESSION_ID_NOTE, userSessionId); + additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value)); return result; } @@ -169,13 +179,17 @@ public class OAuth2DeviceCodeModel { result.put(EXPIRATION_NOTE, String.valueOf(expiration)); result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval)); result.put(DENIED_NOTE, String.valueOf(denied)); + additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value)); return result; } public MultivaluedMap getParams() { MultivaluedHashMap params = new MultivaluedHashMap<>(); params.putSingle(SCOPE_NOTE, scope); - params.putSingle(NONCE_NOTE, nonce); + if (nonce != null) { + params.putSingle(NONCE_NOTE, nonce); + } + this.additionalParams.forEach(params::putSingle); return params; } } diff --git a/services/src/main/java/org/keycloak/protocol/oidc/endpoints/OAuth2DeviceAuthorizationEndpoint.java b/services/src/main/java/org/keycloak/protocol/oidc/endpoints/OAuth2DeviceAuthorizationEndpoint.java index 02666913b2..055cc17e02 100644 --- a/services/src/main/java/org/keycloak/protocol/oidc/endpoints/OAuth2DeviceAuthorizationEndpoint.java +++ b/services/src/main/java/org/keycloak/protocol/oidc/endpoints/OAuth2DeviceAuthorizationEndpoint.java @@ -133,13 +133,14 @@ public class OAuth2DeviceAuthorizationEndpoint extends AuthorizationEndpointBase int interval = realm.getOAuth2DevicePollingInterval(); OAuth2DeviceCodeModel deviceCode = OAuth2DeviceCodeModel.create(realm, client, - Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce()); + Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce(), request.getAdditionalReqParams()); OAuth2DeviceUserCodeProvider userCodeProvider = session.getProvider(OAuth2DeviceUserCodeProvider.class); String secret = userCodeProvider.generate(); OAuth2DeviceUserCodeModel userCode = new OAuth2DeviceUserCodeModel(realm, deviceCode.getDeviceCode(), - secret); + secret + ); // To inform "expired_token" to the client, the lifespan of the cache provider is longer than device code int lifespanSeconds = expiresIn + interval + 10;