KEYCLOAK-3189 KEYCLOAK-3190 Add kid and typ to JWT header

This commit is contained in:
Stian Thorgersen 2016-07-05 07:58:13 +02:00
parent 435cdb6180
commit 7cfee80e58
11 changed files with 98 additions and 3 deletions

View file

@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
import org.keycloak.common.util.Base64Url;
import org.keycloak.util.JsonSerialization;
import java.io.InputStream;
import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.PublicKey;

View file

@ -33,6 +33,7 @@ import java.security.PrivateKey;
*/
public class JWSBuilder {
String type;
String kid;
String contentType;
byte[] contentBytes;
@ -41,6 +42,11 @@ public class JWSBuilder {
return this;
}
public JWSBuilder kid(String kid) {
this.kid = kid;
return this;
}
public JWSBuilder contentType(String type) {
this.contentType = type;
return this;
@ -66,6 +72,7 @@ public class JWSBuilder {
builder.append("\"alg\":\"").append(alg.toString()).append("\"");
if (type != null) builder.append(",\"typ\" : \"").append(type).append("\"");
if (kid != null) builder.append(",\"kid\" : \"").append(kid).append("\"");
if (contentType != null) builder.append(",\"cty\":\"").append(contentType).append("\"");
builder.append("}");
try {

View file

@ -402,6 +402,12 @@ public class RealmAdapter implements RealmModel {
updated.setAccessCodeLifespanLogin(seconds);
}
@Override
public String getKeyId() {
if (isUpdated()) return updated.getKeyId();
return cached.getKeyId();
}
@Override
public String getPublicKeyPem() {
if (isUpdated()) return updated.getPublicKeyPem();

View file

@ -93,6 +93,7 @@ public class CachedRealm extends AbstractRevisioned {
protected PasswordPolicy passwordPolicy;
protected OTPPolicy otpPolicy;
protected transient String keyId;
protected transient PublicKey publicKey;
protected String publicKeyPem;
protected transient PrivateKey privateKey;
@ -189,6 +190,7 @@ public class CachedRealm extends AbstractRevisioned {
passwordPolicy = model.getPasswordPolicy();
otpPolicy = model.getOTPPolicy();
keyId = model.getKeyId();
publicKeyPem = model.getPublicKeyPem();
publicKey = model.getPublicKey();
privateKeyPem = model.getPrivateKeyPem();
@ -397,6 +399,10 @@ public class CachedRealm extends AbstractRevisioned {
return accessCodeLifespanLogin;
}
public String getKeyId() {
return keyId;
}
public String getPublicKeyPem() {
return publicKeyPem;
}

View file

@ -20,6 +20,7 @@ package org.keycloak.models.jpa;
import org.jboss.logging.Logger;
import org.keycloak.connections.jpa.util.JpaUtils;
import org.keycloak.common.enums.SslRequired;
import org.keycloak.jose.jwk.JWKBuilder;
import org.keycloak.models.AuthenticationExecutionModel;
import org.keycloak.models.AuthenticationFlowModel;
import org.keycloak.models.AuthenticatorConfigModel;
@ -459,6 +460,12 @@ public class RealmAdapter implements RealmModel, JpaModel<RealmEntity> {
em.flush();
}
@Override
public String getKeyId() {
PublicKey publicKey = getPublicKey();
return publicKey != null ? JWKBuilder.create().rs256(publicKey).getKeyId() : null;
}
@Override
public String getPublicKeyPem() {
return realm.getPublicKeyPem();

View file

@ -22,6 +22,7 @@ import com.mongodb.QueryBuilder;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.common.enums.SslRequired;
import org.keycloak.jose.jwk.JWKBuilder;
import org.keycloak.models.AuthenticationExecutionModel;
import org.keycloak.models.AuthenticationFlowModel;
import org.keycloak.models.AuthenticatorConfigModel;
@ -453,6 +454,12 @@ public class RealmAdapter extends AbstractMongoAdapter<MongoRealmEntity> impleme
return realm.getAccessCodeLifespanLogin();
}
@Override
public String getKeyId() {
PublicKey publicKey = getPublicKey();
return publicKey != null ? JWKBuilder.create().rs256(publicKey).getKeyId() : null;
}
@Override
public String getPublicKeyPem() {
return realm.getPublicKeyPem();

View file

@ -151,6 +151,8 @@ public interface RealmModel extends RoleContainerModel {
void setAccessCodeLifespanLogin(int seconds);
String getKeyId();
String getPublicKeyPem();
void setPublicKeyPem(String publicKeyPem);

View file

@ -99,6 +99,11 @@ public class OIDCLoginProtocolService {
return uriBuilder.path(OIDCLoginProtocolService.class, "token");
}
public static UriBuilder certsUrl(UriBuilder baseUriBuilder) {
UriBuilder uriBuilder = tokenServiceBaseUrl(baseUriBuilder);
return uriBuilder.path(OIDCLoginProtocolService.class, "certs");
}
public static UriBuilder tokenIntrospectionUrl(UriBuilder baseUriBuilder) {
return tokenUrl(baseUriBuilder).path(TokenEndpoint.class, "introspect");
}

View file

@ -78,6 +78,7 @@ import java.util.Set;
*/
public class TokenManager {
protected static final ServicesLogger logger = ServicesLogger.ROOT_LOGGER;
private static final String JWT = "JWT";
public static void applyScope(RoleModel role, RoleModel scope, Set<RoleModel> visited, Set<RoleModel> requested) {
if (visited.contains(scope)) return;
@ -570,6 +571,8 @@ public class TokenManager {
public String encodeToken(RealmModel realm, Object token) {
String encodedToken = new JWSBuilder()
.type(JWT)
.kid(realm.getKeyId())
.jsonContent(token)
.rsa256(realm.getPrivateKey());
return encodedToken;
@ -680,11 +683,11 @@ public class TokenManager {
AccessTokenResponse res = new AccessTokenResponse();
if (idToken != null) {
String encodedToken = new JWSBuilder().jsonContent(idToken).rsa256(realm.getPrivateKey());
String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(idToken).rsa256(realm.getPrivateKey());
res.setIdToken(encodedToken);
}
if (accessToken != null) {
String encodedToken = new JWSBuilder().jsonContent(accessToken).rsa256(realm.getPrivateKey());
String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(accessToken).rsa256(realm.getPrivateKey());
res.setToken(encodedToken);
res.setTokenType("bearer");
res.setSessionState(accessToken.getSessionState());
@ -693,7 +696,7 @@ public class TokenManager {
}
}
if (refreshToken != null) {
String encodedToken = new JWSBuilder().jsonContent(refreshToken).rsa256(realm.getPrivateKey());
String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(refreshToken).rsa256(realm.getPrivateKey());
res.setRefreshToken(encodedToken);
if (refreshToken.getExpiration() != 0) {
res.setRefreshExpiresIn(refreshToken.getExpiration() - Time.currentTime());

View file

@ -22,6 +22,8 @@ import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.http.HttpResponse;
import org.apache.http.NameValuePair;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.impl.client.CloseableHttpClient;
@ -34,9 +36,13 @@ import org.keycloak.admin.client.Keycloak;
import org.keycloak.common.VerificationException;
import org.keycloak.common.util.PemUtils;
import org.keycloak.constants.AdapterConstants;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.jose.jwk.JWKBuilder;
import org.keycloak.jose.jwk.JWKParser;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.jose.jws.crypto.RSAProvider;
import org.keycloak.protocol.oidc.OIDCLoginProtocolService;
import org.keycloak.protocol.oidc.representations.JSONWebKeySet;
import org.keycloak.representations.AccessToken;
import org.keycloak.representations.RefreshToken;
import org.keycloak.testsuite.arquillian.AuthServerTestEnricher;
@ -279,6 +285,17 @@ public class OAuthClient {
}
}
public JSONWebKeySet doCertsRequest(String realm) throws Exception {
CloseableHttpClient client = new DefaultHttpClient();
try {
HttpGet get = new HttpGet(getCertsUrl(realm));
CloseableHttpResponse response = client.execute(get);
return JsonSerialization.readValue(response.getEntity().getContent(), JSONWebKeySet.class);
} finally {
closeClient(client);
}
}
public AccessTokenResponse doClientCredentialsGrantAccessTokenRequest(String clientSecret) throws Exception {
CloseableHttpClient client = new DefaultHttpClient();
try {
@ -503,6 +520,11 @@ public class OAuthClient {
return b.build(realm).toString();
}
public String getCertsUrl(String realm) {
UriBuilder b = OIDCLoginProtocolService.certsUrl(UriBuilder.fromUri(baseUrl));
return b.build(realm).toString();
}
public String getServiceAccountUrl() {
return getResourceOwnerPasswordCredentialGrantUrl();
}
@ -591,6 +613,7 @@ public class OAuthClient {
public static class AccessTokenResponse {
private int statusCode;
private String idToken;
private String accessToken;
private String tokenType;
private int expiresIn;
@ -610,6 +633,7 @@ public class OAuthClient {
Map responseJson = JsonSerialization.readValue(s, Map.class);
if (statusCode == 200) {
idToken = (String)responseJson.get("id_token");
accessToken = (String)responseJson.get("access_token");
tokenType = (String)responseJson.get("token_type");
expiresIn = (Integer)responseJson.get("expires_in");
@ -624,6 +648,10 @@ public class OAuthClient {
}
}
public String getIdToken() {
return idToken;
}
public String getAccessToken() {
return accessToken;
}

View file

@ -32,8 +32,11 @@ import org.keycloak.admin.client.resource.ClientTemplateResource;
import org.keycloak.admin.client.resource.RealmResource;
import org.keycloak.admin.client.resource.UserResource;
import org.keycloak.common.enums.SslRequired;
import org.keycloak.common.util.PemUtils;
import org.keycloak.events.Details;
import org.keycloak.events.Errors;
import org.keycloak.jose.jwk.JWKBuilder;
import org.keycloak.jose.jws.JWSHeader;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.jose.jws.JWSInputException;
import org.keycloak.models.ProtocolMapperModel;
@ -155,6 +158,26 @@ public class AccessTokenTest extends AbstractKeycloakTest {
assertEquals("bearer", response.getTokenType());
String expectedKid = oauth.doCertsRequest("test").getKeys()[0].getKeyId();
JWSHeader header = new JWSInput(response.getAccessToken()).getHeader();
assertEquals("RS256", header.getAlgorithm().name());
assertEquals("JWT", header.getType());
assertEquals(expectedKid, header.getKeyId());
assertNull(header.getContentType());
header = new JWSInput(response.getIdToken()).getHeader();
assertEquals("RS256", header.getAlgorithm().name());
assertEquals("JWT", header.getType());
assertEquals(expectedKid, header.getKeyId());
assertNull(header.getContentType());
header = new JWSInput(response.getRefreshToken()).getHeader();
assertEquals("RS256", header.getAlgorithm().name());
assertEquals("JWT", header.getType());
assertEquals(expectedKid, header.getKeyId());
assertNull(header.getContentType());
AccessToken token = oauth.verifyToken(response.getAccessToken());
assertEquals(findUserByUsername(adminClient.realm("test"), "test-user@localhost").getId(), token.getSubject());