Make sure additional params are passed between device request and user authnetication.

This commit is contained in:
Łukasz Dywicki 2020-05-20 00:04:24 +02:00 committed by Pedro Igor
parent 319195236b
commit f58bf0deeb
2 changed files with 32 additions and 17 deletions

View file

@ -20,6 +20,7 @@ import org.keycloak.common.util.Time;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@ -35,6 +36,7 @@ public class OAuth2DeviceCodeModel {
private static final String SCOPE_NOTE = "scope";
private static final String USER_SESSION_ID_NOTE = "uid";
private static final String DENIED_NOTE = "denied";
private static final String ADDITIONAL_PARAM_PREFIX = "additional_param_";
private final RealmModel realm;
private final String clientId;
@ -45,26 +47,27 @@ public class OAuth2DeviceCodeModel {
private final String nonce;
private final String userSessionId;
private final boolean denied;
private final Map<String, String> additionalParams;
public static OAuth2DeviceCodeModel create(RealmModel realm, ClientModel client,
String deviceCode, String scope, String nonce) {
String deviceCode, String scope, String nonce, Map<String, String> additionalParams) {
int expiresIn = realm.getOAuth2DeviceCodeLifespan();
int expiration = Time.currentTime() + expiresIn;
int pollingInterval = realm.getOAuth2DevicePollingInterval();
return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false);
return new OAuth2DeviceCodeModel(realm, client.getClientId(), deviceCode, scope, nonce, expiration, pollingInterval, null, false, additionalParams);
}
public OAuth2DeviceCodeModel approve(String userSessionId) {
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false);
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, userSessionId, false, additionalParams);
}
public OAuth2DeviceCodeModel deny() {
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true);
return new OAuth2DeviceCodeModel(realm, clientId, deviceCode, scope, nonce, expiration, pollingInterval, null, true, additionalParams);
}
private OAuth2DeviceCodeModel(RealmModel realm, String clientId,
String deviceCode, String scope, String nonce, int expiration, int pollingInterval,
String userSessionId, boolean denied) {
String userSessionId, boolean denied, Map<String, String> additionalParams) {
this.realm = realm;
this.clientId = clientId;
this.deviceCode = deviceCode;
@ -74,6 +77,7 @@ public class OAuth2DeviceCodeModel {
this.pollingInterval = pollingInterval;
this.userSessionId = userSessionId;
this.denied = denied;
this.additionalParams = additionalParams;
}
public static OAuth2DeviceCodeModel fromCache(RealmModel realm, String deviceCode, Map<String, String> data) {
@ -81,15 +85,19 @@ public class OAuth2DeviceCodeModel {
}
private OAuth2DeviceCodeModel(RealmModel realm, String deviceCode, Map<String, String> data) {
this.realm = realm;
this.clientId = data.get(CLIENT_ID);
this.deviceCode = deviceCode;
this.scope = data.get(SCOPE_NOTE);
this.nonce = data.get(NONCE_NOTE);
this.expiration = Integer.parseInt(data.get(EXPIRATION_NOTE));
this.pollingInterval = Integer.parseInt(data.get(POLLING_INTERVAL_NOTE));
this.userSessionId = data.get(USER_SESSION_ID_NOTE);
this.denied = Boolean.parseBoolean(data.get(DENIED_NOTE));
this(realm, data.get(CLIENT_ID), deviceCode, data.get(SCOPE_NOTE), data.get(NONCE_NOTE),
Integer.parseInt(data.get(EXPIRATION_NOTE)), Integer.parseInt(data.get(POLLING_INTERVAL_NOTE)), data.get(USER_SESSION_ID_NOTE),
Boolean.parseBoolean(data.get(DENIED_NOTE)), extractAdditionalParams(data));
}
private static Map<String, String> extractAdditionalParams(Map<String, String> data) {
Map<String, String> additionalParams = new HashMap<>();
for (Map.Entry<String, String> entry : data.entrySet()) {
if (entry.getKey().startsWith(ADDITIONAL_PARAM_PREFIX)) {
additionalParams.put(entry.getKey().substring(ADDITIONAL_PARAM_PREFIX.length()), entry.getValue());
}
}
return additionalParams;
}
public String getDeviceCode() {
@ -151,6 +159,7 @@ public class OAuth2DeviceCodeModel {
result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval));
result.put(SCOPE_NOTE, scope);
result.put(NONCE_NOTE, nonce);
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result;
}
@ -161,6 +170,7 @@ public class OAuth2DeviceCodeModel {
result.put(SCOPE_NOTE, scope);
result.put(NONCE_NOTE, nonce);
result.put(USER_SESSION_ID_NOTE, userSessionId);
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result;
}
@ -169,13 +179,17 @@ public class OAuth2DeviceCodeModel {
result.put(EXPIRATION_NOTE, String.valueOf(expiration));
result.put(POLLING_INTERVAL_NOTE, String.valueOf(pollingInterval));
result.put(DENIED_NOTE, String.valueOf(denied));
additionalParams.forEach((key, value) -> result.put(ADDITIONAL_PARAM_PREFIX + key, value));
return result;
}
public MultivaluedMap<String, String> getParams() {
MultivaluedHashMap<String, String> params = new MultivaluedHashMap<>();
params.putSingle(SCOPE_NOTE, scope);
if (nonce != null) {
params.putSingle(NONCE_NOTE, nonce);
}
this.additionalParams.forEach(params::putSingle);
return params;
}
}

View file

@ -133,13 +133,14 @@ public class OAuth2DeviceAuthorizationEndpoint extends AuthorizationEndpointBase
int interval = realm.getOAuth2DevicePollingInterval();
OAuth2DeviceCodeModel deviceCode = OAuth2DeviceCodeModel.create(realm, client,
Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce());
Base64Url.encode(KeycloakModelUtils.generateSecret()), request.getScope(), request.getNonce(), request.getAdditionalReqParams());
OAuth2DeviceUserCodeProvider userCodeProvider = session.getProvider(OAuth2DeviceUserCodeProvider.class);
String secret = userCodeProvider.generate();
OAuth2DeviceUserCodeModel userCode = new OAuth2DeviceUserCodeModel(realm,
deviceCode.getDeviceCode(),
secret);
secret
);
// To inform "expired_token" to the client, the lifespan of the cache provider is longer than device code
int lifespanSeconds = expiresIn + interval + 10;