Merge pull request #1405 from patriot1burke/master

cleanup client sessions
This commit is contained in:
Bill Burke 2015-06-19 16:23:31 -04:00
commit 94e2983310
17 changed files with 136 additions and 16 deletions

View file

@ -29,6 +29,7 @@ public interface UserSessionProvider extends Provider {
void removeUserSessions(RealmModel realm, UserModel user); void removeUserSessions(RealmModel realm, UserModel user);
void removeExpiredUserSessions(RealmModel realm); void removeExpiredUserSessions(RealmModel realm);
void removeUserSessions(RealmModel realm); void removeUserSessions(RealmModel realm);
void removeClientSession(RealmModel realm, ClientSessionModel clientSession);
UsernameLoginFailureModel getUserLoginFailure(RealmModel realm, String username); UsernameLoginFailureModel getUserLoginFailure(RealmModel realm, String username);
UsernameLoginFailureModel addUserLoginFailure(RealmModel realm, String username); UsernameLoginFailureModel addUserLoginFailure(RealmModel realm, String username);

View file

@ -334,6 +334,21 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
} }
@Override
public void removeClientSession(RealmModel realm, ClientSessionModel clientSession) {
UserSessionModel userSession = clientSession.getUserSession();
if (userSession != null) {
UserSessionEntity entity = ((UserSessionAdapter) userSession).getEntity();
if (entity.getClientSessions() != null) {
entity.getClientSessions().remove(clientSession.getId());
}
tx.replace(sessionCache, entity.getId(), entity);
}
tx.remove(sessionCache, clientSession.getId());
}
void dettachSession(UserSessionModel userSession, ClientSessionModel clientSession) { void dettachSession(UserSessionModel userSession, ClientSessionModel clientSession) {
UserSessionEntity entity = ((UserSessionAdapter) userSession).getEntity(); UserSessionEntity entity = ((UserSessionAdapter) userSession).getEntity();
String clientSessionId = clientSession.getId(); String clientSessionId = clientSession.getId();
@ -359,6 +374,7 @@ public class InfinispanUserSessionProvider implements UserSessionProvider {
} }
} }
InfinispanKeycloakTransaction getTx() { InfinispanKeycloakTransaction getTx() {
return tx; return tx;
} }

View file

@ -135,4 +135,6 @@ public class ClientSessionEntity extends SessionEntity {
public void setUserSessionNotes(Map<String, String> userSessionNotes) { public void setUserSessionNotes(Map<String, String> userSessionNotes) {
this.userSessionNotes = userSessionNotes; this.userSessionNotes = userSessionNotes;
} }
} }

View file

@ -26,4 +26,21 @@ public class SessionEntity implements Serializable {
public void setRealm(String realm) { public void setRealm(String realm) {
this.realm = realm; this.realm = realm;
} }
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof SessionEntity)) return false;
SessionEntity that = (SessionEntity) o;
if (id != null ? !id.equals(that.id) : that.id != null) return false;
return true;
}
@Override
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
} }

View file

@ -128,6 +128,10 @@ public class ClientSessionAdapter implements ClientSessionModel {
return realm.getClientById(entity.getClientId()); return realm.getClientById(entity.getClientId());
} }
public ClientSessionEntity getEntity() {
return entity;
}
@Override @Override
public void setUserSession(UserSessionModel userSession) { public void setUserSession(UserSessionModel userSession) {
if (userSession == null) { if (userSession == null) {

View file

@ -18,6 +18,7 @@ import org.keycloak.util.Time;
import javax.persistence.EntityManager; import javax.persistence.EntityManager;
import javax.persistence.TypedQuery; import javax.persistence.TypedQuery;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -47,6 +48,13 @@ public class JpaUserSessionProvider implements UserSessionProvider {
return new ClientSessionAdapter(session, em, realm, entity); return new ClientSessionAdapter(session, em, realm, entity);
} }
@Override
public void removeClientSession(RealmModel realm, ClientSessionModel clientSession) {
ClientSessionEntity clientSessionEntity = ((ClientSessionAdapter)clientSession).getEntity();
em.remove(clientSessionEntity);
em.flush();
}
@Override @Override
public ClientSessionModel getClientSession(RealmModel realm, String id) { public ClientSessionModel getClientSession(RealmModel realm, String id) {
ClientSessionEntity clientSession = em.find(ClientSessionEntity.class, id); ClientSessionEntity clientSession = em.find(ClientSessionEntity.class, id);

View file

@ -186,4 +186,21 @@ public class ClientSessionEntity {
public void setUserSessionNotes(Collection<ClientUserSessionNoteEntity> userSessionNotes) { public void setUserSessionNotes(Collection<ClientUserSessionNoteEntity> userSessionNotes) {
this.userSessionNotes = userSessionNotes; this.userSessionNotes = userSessionNotes;
} }
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof ClientSessionEntity)) return false;
ClientSessionEntity that = (ClientSessionEntity) o;
if (id != null ? !id.equals(that.id) : that.id != null) return false;
return true;
}
@Override
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
} }

View file

@ -39,7 +39,9 @@ public class ClientSessionAdapter implements ClientSessionModel {
return session.realms().getRealm(entity.getRealmId()); return session.realms().getRealm(entity.getRealmId());
} }
public ClientSessionEntity getEntity() {
return entity;
}
@Override @Override
public ClientModel getClient() { public ClientModel getClient() {

View file

@ -58,6 +58,17 @@ public class MemUserSessionProvider implements UserSessionProvider {
return new ClientSessionAdapter(session, this, realm, entity); return new ClientSessionAdapter(session, this, realm, entity);
} }
@Override
public void removeClientSession(RealmModel realm, ClientSessionModel clientSession) {
ClientSessionEntity entity = ((ClientSessionAdapter)clientSession).getEntity();
UserSessionModel userSession = clientSession.getUserSession();
if (userSession != null) {
UserSessionEntity userSessionEntity = ((UserSessionAdapter)userSession).getEntity();
userSessionEntity.getClientSessions().remove(entity);
}
clientSessions.remove(clientSession.getId());
}
@Override @Override
public ClientSessionModel getClientSession(RealmModel realm, String id) { public ClientSessionModel getClientSession(RealmModel realm, String id) {
ClientSessionEntity entity = clientSessions.get(id); ClientSessionEntity entity = clientSessions.get(id);

View file

@ -132,4 +132,21 @@ public class ClientSessionEntity {
public Map<String, String> getUserSessionNotes() { public Map<String, String> getUserSessionNotes() {
return userSessionNotes; return userSessionNotes;
} }
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof ClientSessionEntity)) return false;
ClientSessionEntity that = (ClientSessionEntity) o;
if (id != null ? !id.equals(that.id) : that.id != null) return false;
return true;
}
@Override
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
} }

View file

@ -55,6 +55,19 @@ public class MongoUserSessionProvider implements UserSessionProvider {
return new ClientSessionAdapter(session, this, realm, entity, invocationContext); return new ClientSessionAdapter(session, this, realm, entity, invocationContext);
} }
@Override
public void removeClientSession(RealmModel realm, ClientSessionModel clientSession) {
MongoClientSessionEntity entity = ((ClientSessionAdapter)clientSession).getMongoEntity();
if (entity.getSessionId() != null) {
MongoUserSessionEntity userSessionEntity = getUserSessionEntity(realm, entity.getSessionId());
getMongoStore().pullItemFromList(userSessionEntity, "clientSessions", entity.getSessionId(), invocationContext);
}
mongoStore.removeEntity(entity, invocationContext);
}
@Override @Override
public ClientSessionModel getClientSession(RealmModel realm, String id) { public ClientSessionModel getClientSession(RealmModel realm, String id) {
MongoClientSessionEntity entity = getClientSessionEntity(id); MongoClientSessionEntity entity = getClientSessionEntity(id);

View file

@ -134,12 +134,9 @@ public class SamlProtocol implements LoginProtocol {
@Override @Override
public Response cancelLogin(ClientSessionModel clientSession) { public Response cancelLogin(ClientSessionModel clientSession) {
return getErrorResponse(clientSession, JBossSAMLURIConstants.STATUS_REQUEST_DENIED.get()); Response error = getErrorResponse(clientSession, JBossSAMLURIConstants.STATUS_REQUEST_DENIED.get());
} session.sessions().removeClientSession(realm, clientSession);
return error;
@Override
public Response invalidSessionError(ClientSessionModel clientSession) {
return getErrorResponse(clientSession, JBossSAMLURIConstants.STATUS_AUTHNFAILED.get());
} }
protected String getResponseIssuer(RealmModel realm) { protected String getResponseIssuer(RealmModel realm) {

View file

@ -459,7 +459,8 @@ public class AuthenticationProcessor {
return authenticationComplete(); return authenticationComplete();
} }
protected void resetFlow() { public static void resetFlow(ClientSessionModel clientSession) {
clientSession.setAuthenticatedUser(null);
clientSession.clearExecutionStatus(); clientSession.clearExecutionStatus();
clientSession.clearUserSessionNotes(); clientSession.clearUserSessionNotes();
clientSession.removeNote(CURRENT_AUTHENTICATION_EXECUTION); clientSession.removeNote(CURRENT_AUTHENTICATION_EXECUTION);
@ -471,14 +472,14 @@ public class AuthenticationProcessor {
if (!execution.equals(current)) { if (!execution.equals(current)) {
logger.debug("Current execution does not equal executed execution. Might be a page refresh"); logger.debug("Current execution does not equal executed execution. Might be a page refresh");
logFailure(); logFailure();
resetFlow(); resetFlow(clientSession);
return authenticate(); return authenticate();
} }
AuthenticationExecutionModel model = realm.getAuthenticationExecutionById(execution); AuthenticationExecutionModel model = realm.getAuthenticationExecutionById(execution);
if (model == null) { if (model == null) {
logger.debug("Cannot find execution, reseting flow"); logger.debug("Cannot find execution, reseting flow");
logFailure(); logFailure();
resetFlow(); resetFlow(clientSession);
return authenticate(); return authenticate();
} }
event.event(EventType.LOGIN); event.event(EventType.LOGIN);

View file

@ -28,7 +28,6 @@ public interface LoginProtocol extends Provider {
LoginProtocol setEventBuilder(EventBuilder event); LoginProtocol setEventBuilder(EventBuilder event);
Response cancelLogin(ClientSessionModel clientSession); Response cancelLogin(ClientSessionModel clientSession);
Response invalidSessionError(ClientSessionModel clientSession);
Response authenticated(UserSessionModel userSession, ClientSessionCode accessCode); Response authenticated(UserSessionModel userSession, ClientSessionCode accessCode);
Response consentDenied(ClientSessionModel clientSession); Response consentDenied(ClientSessionModel clientSession);

View file

@ -122,6 +122,7 @@ public class OIDCLoginProtocol implements LoginProtocol {
if (state != null) { if (state != null) {
redirectUri.queryParam(OAuth2Constants.STATE, state); redirectUri.queryParam(OAuth2Constants.STATE, state);
} }
session.sessions().removeClientSession(realm, clientSession);
return Response.status(302).location(redirectUri.build()).build(); return Response.status(302).location(redirectUri.build()).build();
} }

View file

@ -182,7 +182,15 @@ public class LoginActionsService {
} else if (!(clientCode.isActionActive(requiredAction) || clientCode.isActionActive(alternativeRequiredAction))) { } else if (!(clientCode.isActionActive(requiredAction) || clientCode.isActionActive(alternativeRequiredAction))) {
event.client(clientCode.getClientSession().getClient()); event.client(clientCode.getClientSession().getClient());
event.error(Errors.EXPIRED_CODE); event.error(Errors.EXPIRED_CODE);
if (clientCode.getClientSession().getAction().equals(ClientSessionModel.Action.AUTHENTICATE.name())) {
AuthenticationProcessor.resetFlow(clientCode.getClientSession());
response = processAuthentication(null, clientCode.getClientSession());
} else {
if (clientCode.getClientSession().getUserSession() == null) {
session.sessions().removeClientSession(realm, clientCode.getClientSession());
}
response = ErrorPage.error(session, Messages.EXPIRED_CODE); response = ErrorPage.error(session, Messages.EXPIRED_CODE);
}
return false; return false;
} else { } else {
return true; return true;
@ -207,21 +215,26 @@ public class LoginActionsService {
return false; return false;
} }
ClientSessionModel clientSession = clientCode.getClientSession(); ClientSessionModel clientSession = clientCode.getClientSession();
if (clientSession == null) {
event.error(Errors.INVALID_CODE);
response = ErrorPage.error(session, Messages.INVALID_CODE);
return false;
}
event.detail(Details.CODE_ID, clientSession.getId()); event.detail(Details.CODE_ID, clientSession.getId());
ClientModel client = clientSession.getClient(); ClientModel client = clientSession.getClient();
if (client == null) { if (client == null) {
event.error(Errors.CLIENT_NOT_FOUND); event.error(Errors.CLIENT_NOT_FOUND);
response = ErrorPage.error(session, Messages.UNKNOWN_LOGIN_REQUESTER); response = ErrorPage.error(session, Messages.UNKNOWN_LOGIN_REQUESTER);
session.sessions().removeClientSession(realm, clientSession);
return false; return false;
} }
session.getContext().setClient(client);
if (!client.isEnabled()) { if (!client.isEnabled()) {
event.error(Errors.CLIENT_NOT_FOUND); event.error(Errors.CLIENT_NOT_FOUND);
response = ErrorPage.error(session, Messages.LOGIN_REQUESTER_NOT_ENABLED); response = ErrorPage.error(session, Messages.LOGIN_REQUESTER_NOT_ENABLED);
session.sessions().removeClientSession(realm, clientSession);
return false; return false;
} }
session.getContext().setClient(clientCode.getClientSession().getClient()); session.getContext().setClient(client);
return true; return true;
} }
} }
@ -239,7 +252,7 @@ public class LoginActionsService {
@QueryParam("execution") String execution) { @QueryParam("execution") String execution) {
event.event(EventType.LOGIN); event.event(EventType.LOGIN);
Checks checks = new Checks(); Checks checks = new Checks();
if (!checks.check(code)) { if (!checks.check(code, ClientSessionModel.Action.AUTHENTICATE.name(), ClientSessionModel.Action.RECOVER_PASSWORD.name())) {
return checks.response; return checks.response;
} }
event.detail(Details.CODE_ID, code); event.detail(Details.CODE_ID, code);

View file

@ -137,6 +137,7 @@ public class ResetPasswordTest {
events.expectRequiredAction(EventType.SEND_RESET_PASSWORD).user(userId).detail(Details.USERNAME, "login-test").detail(Details.EMAIL, "login@test.com").assertEvent().getSessionId(); events.expectRequiredAction(EventType.SEND_RESET_PASSWORD).user(userId).detail(Details.USERNAME, "login-test").detail(Details.EMAIL, "login@test.com").assertEvent().getSessionId();
String src = driver.getPageSource();
resetPasswordPage.backToLogin(); resetPasswordPage.backToLogin();
assertTrue(loginPage.isCurrent()); assertTrue(loginPage.isCurrent());