session mgmt as7

This commit is contained in:
Bill Burke 2014-03-05 15:00:21 -05:00
parent d6bd02ea7d
commit 633ddb70e7
3 changed files with 199 additions and 87 deletions

View file

@ -20,7 +20,12 @@ import org.keycloak.adapters.RefreshableKeycloakSession;
import org.keycloak.adapters.ResourceMetadata; import org.keycloak.adapters.ResourceMetadata;
import org.keycloak.adapters.as7.config.CatalinaAdapterConfigLoader; import org.keycloak.adapters.as7.config.CatalinaAdapterConfigLoader;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
import org.keycloak.representations.adapters.action.AdminAction;
import org.keycloak.representations.adapters.action.PushNotBeforeAction; import org.keycloak.representations.adapters.action.PushNotBeforeAction;
import org.keycloak.representations.adapters.action.SessionStats;
import org.keycloak.representations.adapters.action.SessionStatsAction;
import org.keycloak.representations.adapters.action.UserStats;
import org.keycloak.representations.adapters.action.UserStatsAction;
import org.keycloak.representations.adapters.config.AdapterConfig; import org.keycloak.representations.adapters.config.AdapterConfig;
import org.keycloak.adapters.config.RealmConfiguration; import org.keycloak.adapters.config.RealmConfiguration;
import org.keycloak.adapters.config.RealmConfigurationLoader; import org.keycloak.adapters.config.RealmConfigurationLoader;
@ -35,7 +40,9 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map;
import java.util.Set; import java.util.Set;
/** /**
@ -99,6 +106,21 @@ public class KeycloakAuthenticatorValve extends FormAuthenticator implements Lif
return; // we failed to verify the request return; // we failed to verify the request
} }
pushNotBefore(input, response); pushNotBefore(input, response);
return;
} else if (requestURI.endsWith(AdapterConstants.K_GET_SESSION_STATS)) {
JWSInput input = verifyAdminRequest(request, response);
if (input == null) {
return; // we failed to verify the request
}
getSessionStats(input, response);
return;
} else if (requestURI.endsWith(AdapterConstants.K_GET_USER_STATS)) {
JWSInput input = verifyAdminRequest(request, response);
if (input == null) {
return; // we failed to verify the request
}
getUserStats(input, response);
return;
} }
checkKeycloakSession(request); checkKeycloakSession(request);
super.invoke(request, response); super.invoke(request, response);
@ -136,63 +158,112 @@ public class KeycloakAuthenticatorValve extends FormAuthenticator implements Lif
String token = StreamUtil.readString(request.getInputStream()); String token = StreamUtil.readString(request.getInputStream());
if (token == null) { if (token == null) {
log.warn("admin request failed, no token"); log.warn("admin request failed, no token");
response.sendError(HttpServletResponse.SC_FORBIDDEN, "no token"); response.sendError(403, "no token");
return null; return null;
} }
JWSInput input = new JWSInput(token); JWSInput input = new JWSInput(token);
boolean verified = false; boolean verified = false;
try { try {
verified = RSAProvider.verify(input, resourceMetadata.getRealmKey()); verified = RSAProvider.verify(input, realmConfiguration.getMetadata().getRealmKey());
} catch (Exception ignore) { } catch (Exception ignore) {
} }
if (!verified) { if (!verified) {
log.warn("admin request failed, unable to verify token"); log.warn("admin request failed, unable to verify token");
response.sendError(HttpServletResponse.SC_FORBIDDEN, "verification failed"); response.sendError(403, "verification failed");
return null; return null;
} }
return input; return input;
} }
protected void pushNotBefore(JWSInput token, HttpServletResponse response) throws IOException {
try {
log.info("->> pushNotBefore: ");
PushNotBeforeAction action = JsonSerialization.readValue(token.getContent(), PushNotBeforeAction.class);
if (action.isExpired()) {
log.warn("admin request failed, expired token");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Expired token");
return;
}
if (!resourceMetadata.getResourceName().equals(action.getResource())) {
log.warn("Resource name does not match");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Resource name does not match");
return;
} protected boolean validateAction(HttpServletResponse response, AdminAction action) throws IOException {
realmConfiguration.setNotBefore(action.getNotBefore()); if (!action.validate()) {
} catch (Exception e) { log.warn("admin request failed, not validated" + action.getAction());
log.warn("failed to logout", e); response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Not validated");
response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Failed to logout"); return false;
} }
if (action.isExpired()) {
log.warn("admin request failed, expired token");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Expired token");
return false;
}
if (!resourceMetadata.getResourceName().equals(action.getResource())) {
log.warn("Resource name does not match");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Resource name does not match");
return false;
}
return true;
}
protected void pushNotBefore(JWSInput token, HttpServletResponse response) throws IOException {
log.info("->> pushNotBefore: ");
PushNotBeforeAction action = JsonSerialization.readValue(token.getContent(), PushNotBeforeAction.class);
if (!validateAction(response, action)) {
return;
}
realmConfiguration.setNotBefore(action.getNotBefore());
response.setStatus(HttpServletResponse.SC_NO_CONTENT); response.setStatus(HttpServletResponse.SC_NO_CONTENT);
} }
protected UserStats getUserStats(String user) {
UserStats stats = new UserStats();
Long loginTime = userSessionManagement.getUserLoginTime(user);
if (loginTime != null) {
stats.setLoggedIn(true);
stats.setWhenLoggedIn(loginTime);
} else {
stats.setLoggedIn(false);
}
return stats;
}
protected void getSessionStats(JWSInput token, HttpServletResponse response) throws IOException {
log.info("->> getSessionStats: ");
SessionStatsAction action = JsonSerialization.readValue(token.getContent(), SessionStatsAction.class);
if (!validateAction(response, action)) {
return;
}
SessionStats stats = new SessionStats();
stats.setActiveSessions(userSessionManagement.getActiveSessions());
stats.setActiveUsers(userSessionManagement.getActiveUsers().size());
if (action.isListUsers() && userSessionManagement.getActiveSessions() > 0) {
Map<String, UserStats> list = new HashMap<String, UserStats>();
for (String user : userSessionManagement.getActiveUsers()) {
list.put(user, getUserStats(user));
}
stats.setUsers(list);
}
response.setStatus(200);
response.setContentType("application/json");
JsonSerialization.writeValueToStream(response.getOutputStream(), stats);
}
protected void getUserStats(JWSInput token, HttpServletResponse response) throws IOException {
log.info("->> getUserStats: ");
UserStatsAction action = JsonSerialization.readValue(token.getContent(), UserStatsAction.class);
if (!validateAction(response, action)) {
return;
}
String user = action.getUser();
UserStats stats = getUserStats(user);
response.setStatus(200);
response.setContentType("application/json");
JsonSerialization.writeValueToStream(response.getOutputStream(), stats);
}
protected void remoteLogout(JWSInput token, HttpServletResponse response) throws IOException { protected void remoteLogout(JWSInput token, HttpServletResponse response) throws IOException {
try { try {
log.debug("->> remoteLogout: "); log.debug("->> remoteLogout: ");
LogoutAction action = JsonSerialization.readValue(token.getContent(), LogoutAction.class); LogoutAction action = JsonSerialization.readValue(token.getContent(), LogoutAction.class);
if (action.isExpired()) { if (!validateAction(response, action)) {
log.warn("admin request failed, expired token");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Expired token");
return; return;
} }
if (!resourceMetadata.getResourceName().equals(action.getResource())) {
log.warn("Resource name does not match");
response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Resource name does not match");
return;
}
String user = action.getUser(); String user = action.getUser();
if (user != null) { if (user != null) {
log.debug("logout of session for: " + user); log.debug("logout of session for: " + user);

View file

@ -8,8 +8,10 @@ import org.jboss.logging.Logger;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
/** /**
@ -20,19 +22,63 @@ import java.util.concurrent.ConcurrentHashMap;
*/ */
public class UserSessionManagement implements SessionListener { public class UserSessionManagement implements SessionListener {
private static final Logger log = Logger.getLogger(UserSessionManagement.class); private static final Logger log = Logger.getLogger(UserSessionManagement.class);
protected ConcurrentHashMap<String, Map<String, Session>> userSessionMap = new ConcurrentHashMap<String, Map<String, Session>>(); protected ConcurrentHashMap<String, UserSessions> userSessionMap = new ConcurrentHashMap<String, UserSessions>();
public static class UserSessions {
protected Map<String, Session> sessions = new ConcurrentHashMap<String, Session>();
protected long loggedIn = System.currentTimeMillis();
public Map<String, Session> getSessions() {
return sessions;
}
public long getLoggedIn() {
return loggedIn;
}
}
public int getNumUserLogins() {
return userSessionMap.size();
}
public int getActiveSessions() {
int active = 0;
synchronized (userSessionMap) {
for (UserSessions sessions : userSessionMap.values()) {
active += sessions.getSessions().size();
}
}
return active;
}
/**
*
* @param username
* @return null if user not logged in
*/
public Long getUserLoginTime(String username) {
UserSessions sessions = userSessionMap.get(username);
if (sessions == null) return null;
return sessions.getLoggedIn();
}
public Set<String> getActiveUsers() {
HashSet<String> set = new HashSet<String>();
set.addAll(userSessionMap.keySet());
return set;
}
protected void login(Session session, String username) { protected void login(Session session, String username) {
Map<String, Session> map = userSessionMap.get(username); synchronized (userSessionMap) {
if (map == null) { UserSessions userSessions = userSessionMap.get(username);
final Map<String, Session> value = new HashMap<String, Session>(); if (userSessions == null) {
map = userSessionMap.putIfAbsent(username, value); userSessions = new UserSessions();
if (map == null) { userSessionMap.put(username, userSessions);
map = value;
} }
} userSessions.getSessions().put(session.getId(), session);
synchronized (map) {
map.put(session.getId(), session);
} }
session.addSessionListener(this); session.addSessionListener(this);
} }
@ -43,32 +89,24 @@ public class UserSessionManagement implements SessionListener {
for (String user : users) logout(user); for (String user : users) logout(user);
} }
public void logoutAllBut(String but) {
List<String> users = new ArrayList<String>();
users.addAll(userSessionMap.keySet());
for (String user : users) {
if (!but.equals(user)) logout(user);
}
}
public void logout(String user) { public void logout(String user) {
log.debug("logoutUser: " + user); log.debug("logoutUser: " + user);
Map<String, Session> map = userSessionMap.remove(user); UserSessions sessions = null;
if (map == null) { synchronized (userSessionMap) {
sessions = userSessionMap.remove(user);
}
if (sessions == null) {
log.debug("no session for user: " + user); log.debug("no session for user: " + user);
return; return;
} }
log.debug("found session for user"); log.debug("found session for user");
synchronized (map) { for (Session session : sessions.getSessions().values()) {
for (Session session : map.values()) { session.setPrincipal(null);
log.debug("invalidating session for user: " + user); session.setAuthType(null);
session.setPrincipal(null); session.getSession().invalidate();
session.setAuthType(null);
session.getSession().invalidate();
}
} }
} }
public void sessionEvent(SessionEvent event) { public void sessionEvent(SessionEvent event) {
@ -85,13 +123,14 @@ public class UserSessionManagement implements SessionListener {
session.setAuthType(null); session.setAuthType(null);
String username = principal.getUserPrincipal().getName(); String username = principal.getUserPrincipal().getName();
Map<String, Session> map = userSessionMap.get(username); synchronized (userSessionMap) {
if (map == null) return; UserSessions sessions = userSessionMap.get(username);
synchronized (map) { if (sessions != null) {
map.remove(session.getId()); sessions.getSessions().remove(session.getId());
if (map.isEmpty()) userSessionMap.remove(username); if (sessions.getSessions().isEmpty()) {
userSessionMap.remove(username);
}
}
} }
} }
} }

View file

@ -69,27 +69,6 @@ public class ServletAdminActionsHandler implements HttpHandler {
this.realmConfig = realmConfig; this.realmConfig = realmConfig;
} }
protected JWSInput verifyAdminRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
String token = StreamUtil.readString(request.getInputStream());
if (token == null) {
log.warn("admin request failed, no token");
response.sendError(StatusCodes.FORBIDDEN, "no token");
return null;
}
JWSInput input = new JWSInput(token);
boolean verified = false;
try {
verified = RSAProvider.verify(input, realmConfig.getMetadata().getRealmKey());
} catch (Exception ignore) {
}
if (!verified) {
log.warn("admin request failed, unable to verify token");
response.sendError(StatusCodes.FORBIDDEN, "verification failed");
return null;
}
return input;
}
@ -135,6 +114,29 @@ public class ServletAdminActionsHandler implements HttpHandler {
return; return;
} }
protected JWSInput verifyAdminRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
String token = StreamUtil.readString(request.getInputStream());
if (token == null) {
log.warn("admin request failed, no token");
response.sendError(StatusCodes.FORBIDDEN, "no token");
return null;
}
JWSInput input = new JWSInput(token);
boolean verified = false;
try {
verified = RSAProvider.verify(input, realmConfig.getMetadata().getRealmKey());
} catch (Exception ignore) {
}
if (!verified) {
log.warn("admin request failed, unable to verify token");
response.sendError(StatusCodes.FORBIDDEN, "verification failed");
return null;
}
return input;
}
protected boolean validateAction(HttpServletResponse response, AdminAction action) throws IOException { protected boolean validateAction(HttpServletResponse response, AdminAction action) throws IOException {
if (!action.validate()) { if (!action.validate()) {
log.warn("admin request failed, not validated" + action.getAction()); log.warn("admin request failed, not validated" + action.getAction());