This commit is contained in:
Bill Burke 2014-09-09 12:06:48 -04:00
parent c5662020d8
commit d0a3a04d34
12 changed files with 100 additions and 17 deletions

View file

@ -39,5 +39,7 @@ public interface AccountProvider extends Provider {
AccountProvider setPasswordSet(boolean passwordSet); AccountProvider setPasswordSet(boolean passwordSet);
AccountProvider setStateChecker(String stateChecker);
AccountProvider setFeatures(boolean social, boolean events, boolean passwordUpdateSupported); AccountProvider setFeatures(boolean social, boolean events, boolean passwordUpdateSupported);
} }

View file

@ -47,6 +47,7 @@ public class FreeMarkerAccountProvider implements AccountProvider {
private RealmModel realm; private RealmModel realm;
private String[] referrer; private String[] referrer;
private List<Event> events; private List<Event> events;
private String stateChecker;
private List<UserSessionModel> sessions; private List<UserSessionModel> sessions;
private boolean socialEnabled; private boolean socialEnabled;
private boolean eventsEnabled; private boolean eventsEnabled;
@ -107,6 +108,10 @@ public class FreeMarkerAccountProvider implements AccountProvider {
} }
URI baseQueryUri = baseUriBuilder.build(); URI baseQueryUri = baseUriBuilder.build();
if (stateChecker != null) {
attributes.put("stateChecker", stateChecker);
}
if (message != null) { if (message != null) {
attributes.put("message", new MessageBean(messages.containsKey(message) ? messages.getProperty(message) : message, messageType)); attributes.put("message", new MessageBean(messages.containsKey(message) ? messages.getProperty(message) : message, messageType));
} }
@ -115,7 +120,7 @@ public class FreeMarkerAccountProvider implements AccountProvider {
attributes.put("referrer", new ReferrerBean(referrer)); attributes.put("referrer", new ReferrerBean(referrer));
} }
attributes.put("url", new UrlBean(realm, theme, baseUri, baseQueryUri, uriInfo.getRequestUri())); attributes.put("url", new UrlBean(realm, theme, baseUri, baseQueryUri, uriInfo.getRequestUri(), stateChecker));
attributes.put("features", new FeaturesBean(socialEnabled, eventsEnabled, passwordUpdateSupported)); attributes.put("features", new FeaturesBean(socialEnabled, eventsEnabled, passwordUpdateSupported));
@ -127,7 +132,7 @@ public class FreeMarkerAccountProvider implements AccountProvider {
attributes.put("totp", new TotpBean(user, baseUri)); attributes.put("totp", new TotpBean(user, baseUri));
break; break;
case SOCIAL: case SOCIAL:
attributes.put("social", new AccountSocialBean(session, realm, user, uriInfo.getBaseUri())); attributes.put("social", new AccountSocialBean(session, realm, user, uriInfo.getBaseUri(), stateChecker));
break; break;
case LOG: case LOG:
attributes.put("log", new LogBean(events)); attributes.put("log", new LogBean(events));
@ -212,6 +217,12 @@ public class FreeMarkerAccountProvider implements AccountProvider {
return this; return this;
} }
@Override
public AccountProvider setStateChecker(String stateChecker) {
this.stateChecker = stateChecker;
return this;
}
@Override @Override
public AccountProvider setFeatures(boolean socialEnabled, boolean eventsEnabled, boolean passwordUpdateSupported) { public AccountProvider setFeatures(boolean socialEnabled, boolean eventsEnabled, boolean passwordUpdateSupported) {
this.socialEnabled = socialEnabled; this.socialEnabled = socialEnabled;

View file

@ -25,7 +25,7 @@ public class AccountSocialBean {
private final boolean removeLinkPossible; private final boolean removeLinkPossible;
private final KeycloakSession session; private final KeycloakSession session;
public AccountSocialBean(KeycloakSession session, RealmModel realm, UserModel user, URI baseUri) { public AccountSocialBean(KeycloakSession session, RealmModel realm, UserModel user, URI baseUri, String stateChecker) {
this.session = session; this.session = session;
URI accountSocialUpdateUri = Urls.accountSocialUpdate(baseUri, realm.getName()); URI accountSocialUpdateUri = Urls.accountSocialUpdate(baseUri, realm.getName());
this.socialLinks = new LinkedList<SocialLinkEntry>(); this.socialLinks = new LinkedList<SocialLinkEntry>();
@ -44,7 +44,11 @@ public class AccountSocialBean {
availableLinks++; availableLinks++;
} }
String action = socialLink != null ? "remove" : "add"; String action = socialLink != null ? "remove" : "add";
String actionUrl = UriBuilder.fromUri(accountSocialUpdateUri).queryParam("action", action).queryParam("provider_id", socialProviderId).build().toString(); String actionUrl = UriBuilder.fromUri(accountSocialUpdateUri)
.queryParam("action", action)
.queryParam("provider_id", socialProviderId)
.queryParam("stateChecker", stateChecker)
.build().toString();
SocialLinkEntry entry = new SocialLinkEntry(socialLink, provider.getName(), actionUrl); SocialLinkEntry entry = new SocialLinkEntry(socialLink, provider.getName(), actionUrl);
this.socialLinks.add(entry); this.socialLinks.add(entry);

View file

@ -44,6 +44,8 @@ public class SessionsBean {
this.session = session; this.session = session;
} }
public String getId() {return session.getId(); }
public String getIpAddress() { public String getIpAddress() {
return session.getIpAddress(); return session.getIpAddress();
} }

View file

@ -16,13 +16,15 @@ public class UrlBean {
private URI baseURI; private URI baseURI;
private URI baseQueryURI; private URI baseQueryURI;
private URI currentURI; private URI currentURI;
private String stateChecker;
public UrlBean(RealmModel realm, Theme theme, URI baseURI, URI baseQueryURI, URI currentURI) { public UrlBean(RealmModel realm, Theme theme, URI baseURI, URI baseQueryURI, URI currentURI, String stateChecker) {
this.realm = realm.getName(); this.realm = realm.getName();
this.theme = theme; this.theme = theme;
this.baseURI = baseURI; this.baseURI = baseURI;
this.baseQueryURI = baseQueryURI; this.baseQueryURI = baseQueryURI;
this.currentURI = currentURI; this.currentURI = currentURI;
this.stateChecker = stateChecker;
} }
public String getAccessUrl() { public String getAccessUrl() {
@ -54,11 +56,11 @@ public class UrlBean {
} }
public String getSessionsLogoutUrl() { public String getSessionsLogoutUrl() {
return Urls.accountSessionsLogoutPage(baseQueryURI, realm).toString(); return Urls.accountSessionsLogoutPage(baseQueryURI, realm, stateChecker).toString();
} }
public String getTotpRemoveUrl() { public String getTotpRemoveUrl() {
return Urls.accountTotpRemove(baseQueryURI, realm).toString(); return Urls.accountTotpRemove(baseQueryURI, realm, stateChecker).toString();
} }
public String getLogoutUrl() { public String getLogoutUrl() {

View file

@ -12,6 +12,8 @@
<form action="${url.accountUrl}" class="form-horizontal" method="post"> <form action="${url.accountUrl}" class="form-horizontal" method="post">
<input type="hidden" id="stateChecker" name="stateChecker" value="${stateChecker}">
<div class="form-group"> <div class="form-group">
<div class="col-sm-2 col-md-2"> <div class="col-sm-2 col-md-2">
<label for="username" class="control-label">${rb.username}</label> <label for="username" class="control-label">${rb.username}</label>

View file

@ -23,6 +23,8 @@
</div> </div>
</#if> </#if>
<input type="hidden" id="stateChecker" name="stateChecker" value="${stateChecker}">
<div class="form-group"> <div class="form-group">
<div class="col-sm-2 col-md-2"> <div class="col-sm-2 col-md-2">
<label for="password-new" class="control-label">${rb.passwordNew}</label> <label for="password-new" class="control-label">${rb.passwordNew}</label>

View file

View file

@ -36,6 +36,7 @@
<hr/> <hr/>
<form action="${url.totpUrl}" class="form-horizontal" method="post"> <form action="${url.totpUrl}" class="form-horizontal" method="post">
<input type="hidden" id="stateChecker" name="stateChecker" value="${stateChecker}">
<div class="form-group"> <div class="form-group">
<div class="col-sm-2 col-md-2"> <div class="col-sm-2 col-md-2">
<label for="totp" class="control-label">${rb.authenticatorCode}</label> <label for="totp" class="control-label">${rb.authenticatorCode}</label>

View file

@ -4,6 +4,7 @@ import org.keycloak.models.ApplicationModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
/** /**
@ -16,17 +17,19 @@ public class Auth {
private final AccessToken token; private final AccessToken token;
private final UserModel user; private final UserModel user;
private final ClientModel client; private final ClientModel client;
private final UserSessionModel session;
public Auth(RealmModel realm, AccessToken token, UserModel user, ClientModel client, boolean cookie) { public Auth(RealmModel realm, AccessToken token, UserModel user, ClientModel client, UserSessionModel session, boolean cookie) {
this.cookie = cookie; this.cookie = cookie;
this.token = token; this.token = token;
this.realm = realm; this.realm = realm;
this.user = user; this.user = user;
this.client = client; this.client = client;
this.session = session;
} }
public boolean isCookie() { public boolean isCookieAuthenticated() {
return cookie; return cookie;
} }
@ -46,6 +49,10 @@ public class Auth {
return token; return token;
} }
public UserSessionModel getSession() {
return session;
}
public boolean hasRealmRole(String role) { public boolean hasRealmRole(String role) {
if (cookie) { if (cookie) {
return user.hasRole(realm.getRole(role)); return user.hasRole(realm.getRole(role));

View file

@ -145,16 +145,26 @@ public class AccountService {
AuthenticationManager.AuthResult authResult = authManager.authenticateIdentityCookie(session, realm, uriInfo, clientConnection, headers); AuthenticationManager.AuthResult authResult = authManager.authenticateIdentityCookie(session, realm, uriInfo, clientConnection, headers);
if (authResult != null) { if (authResult != null) {
auth = new Auth(realm, authResult.getToken(), authResult.getUser(), application, true); auth = new Auth(realm, authResult.getToken(), authResult.getUser(), application, authResult.getSession(), true);
} else { } else {
authResult = authManager.authenticateBearerToken(session, realm, uriInfo, clientConnection, headers); authResult = authManager.authenticateBearerToken(session, realm, uriInfo, clientConnection, headers);
if (authResult != null) { if (authResult != null) {
auth = new Auth(realm, authResult.getToken(), authResult.getUser(), application, false); auth = new Auth(realm, authResult.getToken(), authResult.getUser(), application, authResult.getSession(), false);
} }
} }
// don't allow cors requests unless they were authenticated by an access token
// This is to prevent CSRF attacks.
if (auth != null && auth.isCookieAuthenticated()) {
if (headers.getRequestHeaders().containsKey("Origin")) {
throw new ForbiddenException();
}
}
if (authResult != null) { if (authResult != null) {
UserSessionModel userSession = authResult.getSession(); UserSessionModel userSession = authResult.getSession();
if (userSession != null) { if (userSession != null) {
account.setStateChecker(authResult.getSession().getId());
boolean associated = false; boolean associated = false;
for (ClientSessionModel c : userSession.getClientSessions()) { for (ClientSessionModel c : userSession.getClientSessions()) {
if (c.getClient().equals(application)) { if (c.getClient().equals(application)) {
@ -312,6 +322,34 @@ public class AccountService {
return forwardToPage("sessions", AccountPages.SESSIONS); return forwardToPage("sessions", AccountPages.SESSIONS);
} }
/**
* Check to see if form post has sessionId hidden field and match it against the session id.
*
* @param formData
*/
protected void csrfCheck(final MultivaluedMap<String, String> formData) {
if (!auth.isCookieAuthenticated()) return;
if (auth.getSession() == null) return;
String stateChecker = formData.getFirst("stateChecker");
if (!auth.getSession().getId().equals(stateChecker)) {
throw new ForbiddenException();
}
}
/**
* Check to see if form post has sessionId hidden field and match it against the session id.
*
*/
protected void csrfCheck(String stateChecker) {
if (!auth.isCookieAuthenticated()) return;
if (auth.getSession() == null) return;
if (!auth.getSession().getId().equals(stateChecker)) {
throw new ForbiddenException();
}
}
/** /**
* Update account information. * Update account information.
* *
@ -340,6 +378,8 @@ public class AccountService {
return account.createResponse(AccountPages.ACCOUNT); return account.createResponse(AccountPages.ACCOUNT);
} }
csrfCheck(formData);
UserModel user = auth.getUser(); UserModel user = auth.getUser();
String error = Validation.validateUpdateProfileForm(formData); String error = Validation.validateUpdateProfileForm(formData);
@ -393,12 +433,13 @@ public class AccountService {
@Path("sessions-logout") @Path("sessions-logout")
@GET @GET
public Response processSessionsLogout() { public Response processSessionsLogout(@QueryParam("stateChecker") String stateChecker) {
if (auth == null) { if (auth == null) {
return login("sessions"); return login("sessions");
} }
require(AccountRoles.MANAGE_ACCOUNT); require(AccountRoles.MANAGE_ACCOUNT);
csrfCheck(stateChecker);
UserModel user = auth.getUser(); UserModel user = auth.getUser();
session.sessions().removeUserSessions(realm, user); session.sessions().removeUserSessions(realm, user);
@ -440,6 +481,8 @@ public class AccountService {
return account.createResponse(AccountPages.TOTP); return account.createResponse(AccountPages.TOTP);
} }
csrfCheck(formData);
UserModel user = auth.getUser(); UserModel user = auth.getUser();
String totp = formData.getFirst("totp"); String totp = formData.getFirst("totp");
@ -494,6 +537,7 @@ public class AccountService {
return account.createResponse(AccountPages.PASSWORD); return account.createResponse(AccountPages.PASSWORD);
} }
csrfCheck(formData);
UserModel user = auth.getUser(); UserModel user = auth.getUser();
boolean requireCurrent = isPasswordSet(user); boolean requireCurrent = isPasswordSet(user);
@ -546,12 +590,14 @@ public class AccountService {
@Path("social-update") @Path("social-update")
@GET @GET
public Response processSocialUpdate(@QueryParam("action") String action, public Response processSocialUpdate(@QueryParam("action") String action,
@QueryParam("provider_id") String providerId) { @QueryParam("provider_id") String providerId,
@QueryParam("stateChecker") String stateChecker) {
if (auth == null) { if (auth == null) {
return login("social"); return login("social");
} }
require(AccountRoles.MANAGE_ACCOUNT); require(AccountRoles.MANAGE_ACCOUNT);
csrfCheck(stateChecker);
UserModel user = auth.getUser(); UserModel user = auth.getUser();
if (Validation.isEmpty(providerId)) { if (Validation.isEmpty(providerId)) {

View file

@ -68,8 +68,10 @@ public class Urls {
return accountBase(baseUri).path(AccountService.class, "totpPage").build(realmId); return accountBase(baseUri).path(AccountService.class, "totpPage").build(realmId);
} }
public static URI accountTotpRemove(URI baseUri, String realmId) { public static URI accountTotpRemove(URI baseUri, String realmId, String stateChecker) {
return accountBase(baseUri).path(AccountService.class, "processTotpRemove").build(realmId); return accountBase(baseUri).path(AccountService.class, "processTotpRemove")
.queryParam("stateChecker", stateChecker)
.build(realmId);
} }
public static URI accountLogPage(URI baseUri, String realmId) { public static URI accountLogPage(URI baseUri, String realmId) {
@ -80,8 +82,10 @@ public class Urls {
return accountBase(baseUri).path(AccountService.class, "sessionsPage").build(realmId); return accountBase(baseUri).path(AccountService.class, "sessionsPage").build(realmId);
} }
public static URI accountSessionsLogoutPage(URI baseUri, String realmId) { public static URI accountSessionsLogoutPage(URI baseUri, String realmId, String stateChecker) {
return accountBase(baseUri).path(AccountService.class, "processSessionsLogout").build(realmId); return accountBase(baseUri).path(AccountService.class, "processSessionsLogout")
.queryParam("stateChecker", stateChecker)
.build(realmId);
} }
public static URI accountLogout(URI baseUri, URI redirectUri, String realmId) { public static URI accountLogout(URI baseUri, URI redirectUri, String realmId) {