Make sure additional params are passed between device request and user authnetication.

This commit is contained in:
Łukasz Dywicki 2020-05-20 00:04:24 +02:00 committed by Pedro Igor
parent 319195236b
commit f58bf0deeb
2 changed files with 32 additions and 17 deletions

View file

@ -20,6 +20,7 @@ import org.keycloak.common.util.Time;
import javax.ws.rs.core.MultivaluedHashMap; import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -35,6 +36,7 @@ public class OAuth2DeviceCodeModel {
private static final String SCOPE_NOTE = "scope"; private static final String SCOPE_NOTE = "scope";
private static final String USER_SESSION_ID_NOTE = "uid"; private static final String USER_SESSION_ID_NOTE = "uid";
private static final String DENIED_NOTE = "denied"; private static final String DENIED_NOTE = "denied";
private static final String ADDITIONAL_PARAM_PREFIX = "additional_param_";
private final RealmModel realm; private final RealmModel realm;
private final String clientId; private final String clientId;
@ -45,26 +47,27 @@ public class OAuth2DeviceCodeModel {
private final String nonce; private final String nonce;
private final String userSessionId; private final String userSessionId;
private final boolean denied; private final boolean denied;
private final Map<String, String> additionalParams;
public static OAuth2DeviceCodeModel create(RealmModel realm, ClientModel client, public static OAuth2DeviceCodeModel create(RealmModel realm, ClientModel client,
String deviceCode, String scope, String nonce) { String deviceCode, String scope, String nonce, Map<String, String> additionalParams) {
int expiresIn = realm.getOAuth2DeviceCodeLifespan(); int expiresIn = realm.getOAuth2DeviceCodeLifespan();
int expiration = Time.currentTime() + expiresIn; int expiration = Time.currentTime() + expiresIn;
int pollingInterval = realm.getOAuth2DevicePollingInterval(); int pollingInterval = realm.getOAuth2DevicePollingInterval();
return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false); return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false, additionalParams);
} }
public OAuth2DeviceCodeModel approve(String userSessionId) { public OAuth2DeviceCodeModel approve(String userSessionId) {
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false); return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false, additionalParams);
} }
public OAuth2DeviceCodeModel deny() { public OAuth2DeviceCodeModel deny() {
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true); return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true, additionalParams);
} }
private OAuth2DeviceCodeModel(RealmModel realm, String clientId, private OAuth2DeviceCodeModel(RealmModel realm, String clientId,
String deviceCode, String scope, String nonce, int expiration, int pollingInterval, String deviceCode, String scope, String nonce, int expiration, int pollingInterval,
String userSessionId, boolean denied) { String userSessionId, boolean denied, Map<String, String> additionalParams) {
this.realm = realm; this.realm = realm;
this.clientId = clientId; this.clientId = clientId;
this.deviceCode = deviceCode; this.deviceCode = deviceCode;
@ -74,6 +77,7 @@ public class OAuth2DeviceCodeModel {
this.pollingInterval = pollingInterval; this.pollingInterval = pollingInterval;
this.userSessionId = userSessionId; this.userSessionId = userSessionId;
this.denied = denied; this.denied = denied;
this.additionalParams = additionalParams;
} }
public static OAuth2DeviceCodeModel fromCache(RealmModel realm, String deviceCode, Map<String, String> data) { public static OAuth2DeviceCodeModel fromCache(RealmModel realm, String deviceCode, Map<String, String> data) {
@ -81,15 +85,19 @@ public class OAuth2DeviceCodeModel {
} }
private OAuth2DeviceCodeModel(RealmModel realm, String deviceCode, Map<String, String> data) { private OAuth2DeviceCodeModel(RealmModel realm, String deviceCode, Map<String, String> data) {
this.realm = realm; this(realm, data.get(CLIENT_ID), deviceCode, data.get(SCOPE_NOTE), data.get(NONCE_NOTE),
this.clientId = data.get(CLIENT_ID); Integer.parseInt(data.get(EXPIRATION_NOTE)), Integer.parseInt(data.get(POLLING_INTERVAL_NOTE)), data.get(USER_SESSION_ID_NOTE),
this.deviceCode = deviceCode; Boolean.parseBoolean(data.get(DENIED_NOTE)), extractAdditionalParams(data));
this.scope = data.get(SCOPE_NOTE); }
this.nonce = data.get(NONCE_NOTE);
this.expiration = Integer.parseInt(data.get(EXPIRATION_NOTE)); private static Map<String, String> extractAdditionalParams(Map<String, String> data) {
this.pollingInterval = Integer.parseInt(data.get(POLLING_INTERVAL_NOTE)); Map<String, String> additionalParams = new HashMap<>();
this.userSessionId = data.get(USER_SESSION_ID_NOTE); for (Map.Entry<String, String> entry : data.entrySet()) {
this.denied = Boolean.parseBoolean(data.get(DENIED_NOTE)); if (entry.getKey().startsWith(ADDITIONAL_PARAM_PREFIX)) {
additionalParams.put(entry.getKey().substring(ADDITIONAL_PARAM_PREFIX.length()), entry.getValue());
}
}
return additionalParams;
} }
public String getDeviceCode() { public String getDeviceCode() {
@ -151,6 +159,7 @@ public class OAuth2DeviceCodeModel {
result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval)); result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval));
result.put(SCOPE_NOTE, scope); result.put(SCOPE_NOTE, scope);
result.put(NONCE_NOTE, nonce); result.put(NONCE_NOTE, nonce);
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result; return result;
} }
@ -161,6 +170,7 @@ public class OAuth2DeviceCodeModel {
result.put(SCOPE_NOTE, scope); result.put(SCOPE_NOTE, scope);
result.put(NONCE_NOTE, nonce); result.put(NONCE_NOTE, nonce);
result.put(USER_SESSION_ID_NOTE, userSessionId); result.put(USER_SESSION_ID_NOTE, userSessionId);
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result; return result;
} }
@ -169,13 +179,17 @@ public class OAuth2DeviceCodeModel {
result.put(EXPIRATION_NOTE, String.valueOf(expiration)); result.put(EXPIRATION_NOTE, String.valueOf(expiration));
result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval)); result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval));
result.put(DENIED_NOTE, String.valueOf(denied)); result.put(DENIED_NOTE, String.valueOf(denied));
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result; return result;
} }
public MultivaluedMap<String, String> getParams() { public MultivaluedMap<String, String> getParams() {
MultivaluedHashMap<String, String> params = new MultivaluedHashMap<>(); MultivaluedHashMap<String, String> params = new MultivaluedHashMap<>();
params.putSingle(SCOPE_NOTE, scope); params.putSingle(SCOPE_NOTE, scope);
params.putSingle(NONCE_NOTE, nonce); if (nonce != null) {
params.putSingle(NONCE_NOTE, nonce);
}
this.additionalParams.forEach(params::putSingle);
return params; return params;
} }
} }

View file

@ -133,13 +133,14 @@ public class OAuth2DeviceAuthorizationEndpoint extends AuthorizationEndpointBase
int interval = realm.getOAuth2DevicePollingInterval(); int interval = realm.getOAuth2DevicePollingInterval();
OAuth2DeviceCodeModel deviceCode = OAuth2DeviceCodeModel.create(realm, client, OAuth2DeviceCodeModel deviceCode = OAuth2DeviceCodeModel.create(realm, client,
Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce()); Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce(), request.getAdditionalReqParams());
OAuth2DeviceUserCodeProvider userCodeProvider = session.getProvider(OAuth2DeviceUserCodeProvider.class); OAuth2DeviceUserCodeProvider userCodeProvider = session.getProvider(OAuth2DeviceUserCodeProvider.class);
String secret = userCodeProvider.generate(); String secret = userCodeProvider.generate();
OAuth2DeviceUserCodeModel userCode = new OAuth2DeviceUserCodeModel(realm, OAuth2DeviceUserCodeModel userCode = new OAuth2DeviceUserCodeModel(realm,
deviceCode.getDeviceCode(), deviceCode.getDeviceCode(),
secret); secret
);
// To inform "expired_token" to the client, the lifespan of the cache provider is longer than device code // To inform "expired_token" to the client, the lifespan of the cache provider is longer than device code
int lifespanSeconds = expiresIn + interval + 10; int lifespanSeconds = expiresIn + interval + 10;