Merge pull request #561 from stianst/access-code

Added ClientSessionModel to UserSessionProvider
This commit is contained in:
Stian Thorgersen 2014-07-29 16:18:28 +01:00
commit 1bd219ed40
41 changed files with 1529 additions and 694 deletions

View file

@ -20,7 +20,8 @@
<class>org.keycloak.models.jpa.entities.ScopeMappingEntity</class> <class>org.keycloak.models.jpa.entities.ScopeMappingEntity</class>
<!-- JpaUserSessionProvider --> <!-- JpaUserSessionProvider -->
<class>org.keycloak.models.sessions.jpa.entities.ClientUserSessionAssociationEntity</class> <class>org.keycloak.models.sessions.jpa.entities.ClientSessionEntity</class>
<class>org.keycloak.models.sessions.jpa.entities.ClientSessionRoleEntity</class>
<class>org.keycloak.models.sessions.jpa.entities.UserSessionEntity</class> <class>org.keycloak.models.sessions.jpa.entities.UserSessionEntity</class>
<class>org.keycloak.models.sessions.jpa.entities.UsernameLoginFailureEntity</class> <class>org.keycloak.models.sessions.jpa.entities.UsernameLoginFailureEntity</class>

View file

@ -32,6 +32,7 @@ public class DefaultMongoConnectionFactoryProvider implements MongoConnectionPro
"org.keycloak.models.mongo.keycloak.entities.MongoOAuthClientEntity", "org.keycloak.models.mongo.keycloak.entities.MongoOAuthClientEntity",
"org.keycloak.models.sessions.mongo.entities.MongoUsernameLoginFailureEntity", "org.keycloak.models.sessions.mongo.entities.MongoUsernameLoginFailureEntity",
"org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity", "org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity",
"org.keycloak.models.sessions.mongo.entities.MongoClientSessionEntity",
"org.keycloak.models.entities.FederationProviderEntity" "org.keycloak.models.entities.FederationProviderEntity"
}; };

View file

@ -1,101 +0,0 @@
package org.keycloak.representations;
import java.util.Set;
/**
*
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class AccessCode {
protected String id;
protected String clientId;
protected String userId;
protected String state;
protected String sessionState;
protected String redirectUri;
protected int timestamp;
protected Action action;
protected Set<String> requestedRoles;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getClientId() {
return clientId;
}
public void setClientId(String clientId) {
this.clientId = clientId;
}
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public String getState() {
return state;
}
public void setState(String state) {
this.state = state;
}
public String getSessionState() {
return sessionState;
}
public void setSessionState(String sessionState) {
this.sessionState = sessionState;
}
public String getRedirectUri() {
return redirectUri;
}
public void setRedirectUri(String redirectUri) {
this.redirectUri = redirectUri;
}
public int getTimestamp() {
return timestamp;
}
public void setTimestamp(int timestamp) {
this.timestamp = timestamp;
}
public Action getAction() {
return action;
}
public void setAction(Action action) {
this.action = action;
}
public Set<String> getRequestedRoles() {
return requestedRoles;
}
public void setRequestedRoles(Set<String> requestedRoles) {
this.requestedRoles = requestedRoles;
}
public static enum Action {
OAUTH_GRANT,
VERIFY_EMAIL,
UPDATE_PROFILE,
CONFIGURE_TOTP,
UPDATE_PASSWORD
}
}

View file

@ -2,6 +2,7 @@ package org.keycloak.account.freemarker.model;
import org.keycloak.models.ApplicationModel; import org.keycloak.models.ApplicationModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.OAuthClientModel; import org.keycloak.models.OAuthClientModel;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
@ -9,8 +10,10 @@ import org.keycloak.util.Time;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -58,16 +61,18 @@ public class SessionsBean {
return Time.toDate(max); return Time.toDate(max);
} }
public List<String> getApplications() { public Set<String> getApplications() {
List<String> apps = new ArrayList<String>(); Set<String> apps = new HashSet<String>();
for (ClientModel client : session.getClientAssociations()) { for (ClientSessionModel clientSession : session.getClientSessions()) {
ClientModel client = clientSession.getClient();
if (client instanceof ApplicationModel) apps.add(client.getClientId()); if (client instanceof ApplicationModel) apps.add(client.getClientId());
} }
return apps; return apps;
} }
public List<String> getClients() { public List<String> getClients() {
List<String> apps = new ArrayList<String>(); List<String> apps = new ArrayList<String>();
for (ClientModel client : session.getClientAssociations()) { for (ClientSessionModel clientSession : session.getClientSessions()) {
ClientModel client = clientSession.getClient();
if (client instanceof OAuthClientModel) apps.add(client.getClientId()); if (client instanceof OAuthClientModel) apps.add(client.getClientId());
} }
return apps; return apps;

View file

@ -0,0 +1,39 @@
package org.keycloak.models;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public interface ClientSessionModel {
public String getId();
public ClientModel getClient();
public String getState();
public UserSessionModel getUserSession();
public String getRedirectUri();
public int getTimestamp();
public void setTimestamp(int timestamp);
public Action getAction();
public void setAction(Action action);
public Set<String> getRoles();
public static enum Action {
OAUTH_GRANT,
CODE_TO_TOKEN,
VERIFY_EMAIL,
UPDATE_PROFILE,
CONFIGURE_TOTP,
UPDATE_PASSWORD
}
}

View file

@ -39,10 +39,6 @@ public interface UserSessionModel {
void setLastSessionRefresh(int seconds); void setLastSessionRefresh(int seconds);
void associateClient(ClientModel client); List<ClientSessionModel> getClientSessions();
List<ClientModel> getClientAssociations();
void removeAssociatedClient(ClientModel client);
} }

View file

@ -3,6 +3,7 @@ package org.keycloak.models;
import org.keycloak.provider.Provider; import org.keycloak.provider.Provider;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@ -10,6 +11,9 @@ import java.util.List;
*/ */
public interface UserSessionProvider extends Provider { public interface UserSessionProvider extends Provider {
ClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, String redirectUri, String state, Set<String> roles);
ClientSessionModel getClientSession(RealmModel realm, String id);
UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe); UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe);
UserSessionModel getUserSession(RealmModel realm, String id); UserSessionModel getUserSession(RealmModel realm, String id);
List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user); List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user);

View file

@ -4,6 +4,7 @@ import org.keycloak.models.ApplicationModel;
import org.keycloak.models.AuthenticationProviderModel; import org.keycloak.models.AuthenticationProviderModel;
import org.keycloak.models.ClaimMask; import org.keycloak.models.ClaimMask;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.UserFederationProviderModel; import org.keycloak.models.UserFederationProviderModel;
import org.keycloak.models.OAuthClientModel; import org.keycloak.models.OAuthClientModel;
@ -209,7 +210,8 @@ public class ModelToRepresentation {
rep.setLastAccess(((long)session.getLastSessionRefresh())* 1000L); rep.setLastAccess(((long)session.getLastSessionRefresh())* 1000L);
rep.setUser(session.getUser().getUsername()); rep.setUser(session.getUser().getUsername());
rep.setIpAddress(session.getIpAddress()); rep.setIpAddress(session.getIpAddress());
for (ClientModel client : session.getClientAssociations()) { for (ClientSessionModel clientSession : session.getClientSessions()) {
ClientModel client = clientSession.getClient();
if (client instanceof ApplicationModel) { if (client instanceof ApplicationModel) {
rep.getApplications().add(client.getClientId()); rep.getApplications().add(client.getClientId());
} else if (client instanceof OAuthClientModel) { } else if (client instanceof OAuthClientModel) {

View file

@ -0,0 +1,87 @@
package org.keycloak.models.sessions.jpa;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.jpa.entities.ClientSessionEntity;
import org.keycloak.models.sessions.jpa.entities.ClientSessionRoleEntity;
import javax.persistence.EntityManager;
import java.util.HashSet;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class ClientSessionAdapter implements ClientSessionModel {
private KeycloakSession session;
private ClientSessionEntity entity;
private EntityManager em;
private RealmModel realm;
public ClientSessionAdapter(KeycloakSession session, EntityManager em, RealmModel realm, ClientSessionEntity entity) {
this.session = session;
this.em = em;
this.realm = realm;
this.entity = entity;
}
@Override
public String getId() {
return entity.getId();
}
@Override
public ClientModel getClient() {
return realm.findClientById(entity.getClientId());
}
@Override
public String getState() {
return entity.getState();
}
@Override
public UserSessionModel getUserSession() {
return new UserSessionAdapter(session, em, realm, entity.getSession());
}
@Override
public String getRedirectUri() {
return entity.getRedirectUri();
}
@Override
public int getTimestamp() {
return entity.getTimestamp();
}
@Override
public void setTimestamp(int timestamp) {
entity.setTimestamp(timestamp);
}
@Override
public Action getAction() {
return entity.getAction();
}
@Override
public void setAction(Action action) {
entity.setAction(action);
}
@Override
public Set<String> getRoles() {
Set<String> roles = new HashSet<String>();
if (entity.getRoles() != null) {
for (ClientSessionRoleEntity e : entity.getRoles()) {
roles.add(e.getRoleId());
}
}
return roles;
}
}

View file

@ -1,12 +1,15 @@
package org.keycloak.models.sessions.jpa; package org.keycloak.models.sessions.jpa;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider; import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.UsernameLoginFailureModel; import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.models.sessions.jpa.entities.ClientSessionEntity;
import org.keycloak.models.sessions.jpa.entities.ClientSessionRoleEntity;
import org.keycloak.models.sessions.jpa.entities.UserSessionEntity; import org.keycloak.models.sessions.jpa.entities.UserSessionEntity;
import org.keycloak.models.sessions.jpa.entities.UsernameLoginFailureEntity; import org.keycloak.models.sessions.jpa.entities.UsernameLoginFailureEntity;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
@ -17,6 +20,7 @@ import javax.persistence.TypedQuery;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -32,6 +36,46 @@ public class JpaUserSessionProvider implements UserSessionProvider {
this.em = em; this.em = em;
} }
@Override
public ClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, String redirectUri, String state, Set<String> roles) {
UserSessionEntity userSessionEntity = em.find(UserSessionEntity.class, userSession.getId());
ClientSessionEntity entity = new ClientSessionEntity();
entity.setId(KeycloakModelUtils.generateId());
entity.setTimestamp(Time.currentTime());
entity.setClientId(client.getId());
entity.setSession(userSessionEntity);
entity.setRedirectUri(redirectUri);
entity.setState(state);
em.persist(entity);
if (roles != null) {
List<ClientSessionRoleEntity> roleEntities = new LinkedList<ClientSessionRoleEntity>();
for (String r : roles) {
ClientSessionRoleEntity roleEntity = new ClientSessionRoleEntity();
roleEntity.setClientSession(entity);
roleEntity.setRoleId(r);
em.persist(roleEntity);
roleEntities.add(roleEntity);
}
entity.setRoles(roleEntities);
}
userSessionEntity.getClientSessions().add(entity);
return new ClientSessionAdapter(session, em, realm, entity);
}
@Override
public ClientSessionModel getClientSession(RealmModel realm, String id) {
ClientSessionEntity clientSession = em.find(ClientSessionEntity.class, id);
if (clientSession != null && clientSession.getSession().getRealmId().equals(realm.getId())) {
return new ClientSessionAdapter(session, em, realm, clientSession);
}
return null;
}
@Override @Override
public UsernameLoginFailureModel getUserLoginFailure(RealmModel realm, String username) { public UsernameLoginFailureModel getUserLoginFailure(RealmModel realm, String username) {
String id = username + "-" + realm; String id = username + "-" + realm;
@ -109,7 +153,7 @@ public class JpaUserSessionProvider implements UserSessionProvider {
List<UserSessionModel> list = new LinkedList<UserSessionModel>(); List<UserSessionModel> list = new LinkedList<UserSessionModel>();
TypedQuery<UserSessionEntity> query = em.createNamedQuery("getUserSessionByClient", UserSessionEntity.class) TypedQuery<UserSessionEntity> query = em.createNamedQuery("getUserSessionByClient", UserSessionEntity.class)
.setParameter("realmId", realm.getId()) .setParameter("realmId", realm.getId())
.setParameter("clientId", client.getClientId()); .setParameter("clientId", client.getId());
if (firstResult != -1) { if (firstResult != -1) {
query.setFirstResult(firstResult); query.setFirstResult(firstResult);
} }
@ -126,7 +170,7 @@ public class JpaUserSessionProvider implements UserSessionProvider {
public int getActiveUserSessions(RealmModel realm, ClientModel client) { public int getActiveUserSessions(RealmModel realm, ClientModel client) {
Object count = em.createNamedQuery("getActiveUserSessionByClient") Object count = em.createNamedQuery("getActiveUserSessionByClient")
.setParameter("realmId", realm.getId()) .setParameter("realmId", realm.getId())
.setParameter("clientId", client.getClientId()) .setParameter("clientId", client.getId())
.getSingleResult(); .getSingleResult();
return ((Number)count).intValue(); return ((Number)count).intValue();
} }
@ -141,7 +185,11 @@ public class JpaUserSessionProvider implements UserSessionProvider {
@Override @Override
public void removeUserSessions(RealmModel realm, UserModel user) { public void removeUserSessions(RealmModel realm, UserModel user) {
em.createNamedQuery("removeClientUserSessionByUser") em.createNamedQuery("removeClientSessionRoleByUser")
.setParameter("realmId", realm.getId())
.setParameter("userId", user.getId())
.executeUpdate();
em.createNamedQuery("removeClientSessionByUser")
.setParameter("realmId", realm.getId()) .setParameter("realmId", realm.getId())
.setParameter("userId", user.getId()) .setParameter("userId", user.getId())
.executeUpdate(); .executeUpdate();
@ -156,7 +204,12 @@ public class JpaUserSessionProvider implements UserSessionProvider {
int maxTime = Time.currentTime() - realm.getSsoSessionMaxLifespan(); int maxTime = Time.currentTime() - realm.getSsoSessionMaxLifespan();
int idleTime = Time.currentTime() - realm.getSsoSessionIdleTimeout(); int idleTime = Time.currentTime() - realm.getSsoSessionIdleTimeout();
em.createNamedQuery("removeClientUserSessionByExpired") em.createNamedQuery("removeClientSessionRoleByExpired")
.setParameter("realmId", realm.getId())
.setParameter("maxTime", maxTime)
.setParameter("idleTime", idleTime)
.executeUpdate();
em.createNamedQuery("removeClientSessionByExpired")
.setParameter("realmId", realm.getId()) .setParameter("realmId", realm.getId())
.setParameter("maxTime", maxTime) .setParameter("maxTime", maxTime)
.setParameter("idleTime", idleTime) .setParameter("idleTime", idleTime)
@ -170,7 +223,8 @@ public class JpaUserSessionProvider implements UserSessionProvider {
@Override @Override
public void removeUserSessions(RealmModel realm) { public void removeUserSessions(RealmModel realm) {
em.createNamedQuery("removeClientUserSessionByRealm").setParameter("realmId", realm.getId()).executeUpdate(); em.createNamedQuery("removeClientSessionRoleByRealm").setParameter("realmId", realm.getId()).executeUpdate();
em.createNamedQuery("removeClientSessionByRealm").setParameter("realmId", realm.getId()).executeUpdate();
em.createNamedQuery("removeUserSessionByRealm").setParameter("realmId", realm.getId()).executeUpdate(); em.createNamedQuery("removeUserSessionByRealm").setParameter("realmId", realm.getId()).executeUpdate();
} }
@ -182,7 +236,8 @@ public class JpaUserSessionProvider implements UserSessionProvider {
@Override @Override
public void onClientRemoved(RealmModel realm, ClientModel client) { public void onClientRemoved(RealmModel realm, ClientModel client) {
em.createNamedQuery("removeClientUserSessionByClient").setParameter("realmId", realm.getId()).setParameter("clientId", client.getClientId()).executeUpdate(); em.createNamedQuery("removeClientSessionRoleByClient").setParameter("realmId", realm.getId()).setParameter("clientId", client.getId()).executeUpdate();
em.createNamedQuery("removeClientSessionByClient").setParameter("realmId", realm.getId()).setParameter("clientId", client.getId()).executeUpdate();
} }
@Override @Override

View file

@ -1,15 +1,15 @@
package org.keycloak.models.sessions.jpa; package org.keycloak.models.sessions.jpa;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.jpa.entities.ClientUserSessionAssociationEntity; import org.keycloak.models.sessions.jpa.entities.ClientSessionEntity;
import org.keycloak.models.sessions.jpa.entities.UserSessionEntity; import org.keycloak.models.sessions.jpa.entities.UserSessionEntity;
import javax.persistence.EntityManager; import javax.persistence.EntityManager;
import java.util.ArrayList; import java.util.LinkedList;
import java.util.List; import java.util.List;
/** /**
@ -114,30 +114,12 @@ public class UserSessionAdapter implements UserSessionModel {
} }
@Override @Override
public void associateClient(ClientModel client) { public List<ClientSessionModel> getClientSessions() {
for (ClientUserSessionAssociationEntity ass : entity.getClients()) { List<ClientSessionModel> clientSessions = new LinkedList<ClientSessionModel>();
if (ass.getClientId().equals(client.getClientId())) return; for (ClientSessionEntity e : entity.getClientSessions()) {
clientSessions.add(new ClientSessionAdapter(session, em, realm, e));
} }
return clientSessions;
ClientUserSessionAssociationEntity association = new ClientUserSessionAssociationEntity();
association.setClientId(client.getClientId());
association.setSession(entity);
em.persist(association);
entity.getClients().add(association);
}
@Override
public void removeAssociatedClient(ClientModel client) {
em.createNamedQuery("removeClientUserSessionByClient").setParameter("clientId", client.getClientId()).executeUpdate();
}
@Override
public List<ClientModel> getClientAssociations() {
List<ClientModel> clients = new ArrayList<ClientModel>();
for (ClientUserSessionAssociationEntity association : entity.getClients()) {
clients.add(realm.findClient(association.getClientId()));
}
return clients;
} }
@Override @Override

View file

@ -0,0 +1,127 @@
package org.keycloak.models.sessions.jpa.entities;
import org.keycloak.models.ClientSessionModel;
import javax.persistence.CascadeType;
import javax.persistence.Column;
import javax.persistence.ElementCollection;
import javax.persistence.Entity;
import javax.persistence.FetchType;
import javax.persistence.Id;
import javax.persistence.IdClass;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.NamedQueries;
import javax.persistence.NamedQuery;
import javax.persistence.OneToMany;
import javax.persistence.Table;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
@Entity
@Table(name = "CLIENT_SESSION")
@NamedQueries({
@NamedQuery(name = "removeClientSessionByRealm", query = "delete from ClientSessionEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId)"),
@NamedQuery(name = "removeClientSessionByUser", query = "delete from ClientSessionEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId)"),
@NamedQuery(name = "removeClientSessionByClient", query = "delete from ClientSessionEntity a where a.clientId = :clientId and a.session IN (select s from UserSessionEntity s where s.realmId = :realmId)"),
@NamedQuery(name = "removeClientSessionByExpired", query = "delete from ClientSessionEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId and (s.started < :maxTime or s.lastSessionRefresh < :idleTime))")
})
public class ClientSessionEntity {
@Id
@Column(name = "ID", length = 36)
protected String id;
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "SESSION_ID")
protected UserSessionEntity session;
@Column(name="CLIENT_ID",length = 36)
protected String clientId;
@Column(name="TIMESTAMP")
protected int timestamp;
@Column(name="REDIRECT_URI")
protected String redirectUri;
@Column(name="STATE")
protected String state;
@Column(name="ACTION")
protected ClientSessionModel.Action action;
@OneToMany(cascade = CascadeType.REMOVE, orphanRemoval = true, mappedBy="clientSession")
protected Collection<ClientSessionRoleEntity> roles = new ArrayList<ClientSessionRoleEntity>();
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public UserSessionEntity getSession() {
return session;
}
public void setSession(UserSessionEntity session) {
this.session = session;
}
public String getClientId() {
return clientId;
}
public void setClientId(String clientId) {
this.clientId = clientId;
}
public int getTimestamp() {
return timestamp;
}
public void setTimestamp(int timestamp) {
this.timestamp = timestamp;
}
public String getRedirectUri() {
return redirectUri;
}
public void setRedirectUri(String redirectUri) {
this.redirectUri = redirectUri;
}
public String getState() {
return state;
}
public void setState(String state) {
this.state = state;
}
public ClientSessionModel.Action getAction() {
return action;
}
public void setAction(ClientSessionModel.Action action) {
this.action = action;
}
public Collection<ClientSessionRoleEntity> getRoles() {
return roles;
}
public void setRoles(Collection<ClientSessionRoleEntity> roles) {
this.roles = roles;
}
}

View file

@ -0,0 +1,97 @@
package org.keycloak.models.sessions.jpa.entities;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.FetchType;
import javax.persistence.Id;
import javax.persistence.IdClass;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.NamedQueries;
import javax.persistence.NamedQuery;
import javax.persistence.Table;
import java.io.Serializable;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
@NamedQueries({
@NamedQuery(name = "removeClientSessionRoleByUser", query="delete from ClientSessionRoleEntity r where r.clientSession IN (select c from ClientSessionEntity c where c.session IN (select s from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId))"),
@NamedQuery(name = "removeClientSessionRoleByClient", query="delete from ClientSessionRoleEntity r where r.clientSession IN (select c from ClientSessionEntity c where c.clientId = :clientId and c.session IN (select s from UserSessionEntity s where s.realmId = :realmId))"),
@NamedQuery(name = "removeClientSessionRoleByRealm", query="delete from ClientSessionRoleEntity r where r.clientSession IN (select c from ClientSessionEntity c where c.session IN (select s from UserSessionEntity s where s.realmId = :realmId))"),
@NamedQuery(name = "removeClientSessionRoleByExpired", query = "delete from ClientSessionRoleEntity r where r.clientSession IN (select c from ClientSessionEntity c where c.session IN (select s from UserSessionEntity s where s.realmId = :realmId and (s.started < :maxTime or s.lastSessionRefresh < :idleTime)))")
})
@Table(name="CLIENT_SESSION_ROLE")
@Entity
@IdClass(ClientSessionRoleEntity.Key.class)
public class ClientSessionRoleEntity {
@Id
@ManyToOne(fetch= FetchType.LAZY)
@JoinColumn(name="CLIENT_SESSION")
protected ClientSessionEntity clientSession;
@Id
@Column(name = "ROLE_ID")
protected String roleId;
public ClientSessionEntity getClientSession() {
return clientSession;
}
public void setClientSession(ClientSessionEntity clientSession) {
this.clientSession = clientSession;
}
public String getRoleId() {
return roleId;
}
public void setRoleId(String roleId) {
this.roleId = roleId;
}
public static class Key implements Serializable {
protected ClientSessionEntity clientSession;
protected String roleId;
public Key() {
}
public Key(ClientSessionEntity clientSession, String roleId) {
this.clientSession = clientSession;
this.roleId = roleId;
}
public ClientSessionEntity getClientSession() {
return clientSession;
}
public String getRoleId() {
return roleId;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Key key = (Key) o;
if (!roleId.equals(key.roleId)) return false;
if (!clientSession.getId().equals(key.clientSession.getId())) return false;
return true;
}
@Override
public int hashCode() {
int result = clientSession.getId().hashCode();
result = 31 * result + roleId.hashCode();
return result;
}
}
}

View file

@ -1,97 +0,0 @@
package org.keycloak.models.sessions.jpa.entities;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.FetchType;
import javax.persistence.Id;
import javax.persistence.IdClass;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.NamedQueries;
import javax.persistence.NamedQuery;
import javax.persistence.Table;
import java.io.Serializable;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
@Entity
@Table(name = "CLIENT_USERSESSION")
@NamedQueries({
@NamedQuery(name = "removeClientUserSessionByRealm", query = "delete from ClientUserSessionAssociationEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId)"),
@NamedQuery(name = "removeClientUserSessionByUser", query = "delete from ClientUserSessionAssociationEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId)"),
@NamedQuery(name = "removeClientUserSessionByClient", query = "delete from ClientUserSessionAssociationEntity a where a.clientId = :clientId and a.session IN (select s from UserSessionEntity s where s.realmId = :realmId)"),
@NamedQuery(name = "removeClientUserSessionByExpired", query = "delete from ClientUserSessionAssociationEntity a where a.session IN (select s from UserSessionEntity s where s.realmId = :realmId and (s.started < :maxTime or s.lastSessionRefresh < :idleTime))")
})
@IdClass(ClientUserSessionAssociationEntity.Key.class)
public class ClientUserSessionAssociationEntity {
@Id
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "SESSION_ID")
protected UserSessionEntity session;
@Id
@Column(name="CLIENT_ID",length = 36)
protected String clientId;
public UserSessionEntity getSession() {
return session;
}
public void setSession(UserSessionEntity session) {
this.session = session;
}
public String getClientId() {
return clientId;
}
public void setClientId(String clientId) {
this.clientId = clientId;
}
public static class Key implements Serializable {
private String clientId;
private UserSessionEntity session;
public Key() {
}
public Key(String clientId, UserSessionEntity session) {
this.clientId = clientId;
this.session = session;
}
public String getClientId() {
return clientId;
}
public UserSessionEntity getSession() {
return session;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Key key = (Key) o;
if (clientId != null ? !clientId.equals(key.clientId) : key.clientId != null) return false;
if (session != null ? !session.getId().equals(key.session != null ? key.session.getId() : null) : key.session != null) return false;
return true;
}
@Override
public int hashCode() {
int result = clientId != null ? clientId.hashCode() : 0;
result = 31 * result + (session != null ? session.getId().hashCode() : 0);
return result;
}
}
}

View file

@ -20,8 +20,8 @@ import java.util.Collection;
@Table(name = "USER_SESSION") @Table(name = "USER_SESSION")
@NamedQueries({ @NamedQueries({
@NamedQuery(name = "getUserSessionByUser", query = "select s from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId order by s.started, s.id"), @NamedQuery(name = "getUserSessionByUser", query = "select s from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId order by s.started, s.id"),
@NamedQuery(name = "getUserSessionByClient", query = "select s from UserSessionEntity s join s.clients c where s.realmId = :realmId and c.clientId = :clientId order by s.started, s.id"), @NamedQuery(name = "getUserSessionByClient", query = "select s from UserSessionEntity s join s.clientSessions c where s.realmId = :realmId and c.clientId = :clientId order by s.started, s.id"),
@NamedQuery(name = "getActiveUserSessionByClient", query = "select count(s) from UserSessionEntity s join s.clients c where s.realmId = :realmId and c.clientId = :clientId"), @NamedQuery(name = "getActiveUserSessionByClient", query = "select count(s) from UserSessionEntity s join s.clientSessions c where s.realmId = :realmId and c.clientId = :clientId"),
@NamedQuery(name = "removeUserSessionByRealm", query = "delete from UserSessionEntity s where s.realmId = :realmId"), @NamedQuery(name = "removeUserSessionByRealm", query = "delete from UserSessionEntity s where s.realmId = :realmId"),
@NamedQuery(name = "removeUserSessionByUser", query = "delete from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId"), @NamedQuery(name = "removeUserSessionByUser", query = "delete from UserSessionEntity s where s.realmId = :realmId and s.userId = :userId"),
@NamedQuery(name = "removeUserSessionByExpired", query = "delete from UserSessionEntity s where s.realmId = :realmId and (s.started < :maxTime or s.lastSessionRefresh < :idleTime)") @NamedQuery(name = "removeUserSessionByExpired", query = "delete from UserSessionEntity s where s.realmId = :realmId and (s.started < :maxTime or s.lastSessionRefresh < :idleTime)")
@ -56,8 +56,8 @@ public class UserSessionEntity {
@Column(name="LAST_SESSION_REFRESH") @Column(name="LAST_SESSION_REFRESH")
protected int lastSessionRefresh; protected int lastSessionRefresh;
@OneToMany(fetch = FetchType.LAZY, cascade = CascadeType.REMOVE, orphanRemoval = true, mappedBy="session") @OneToMany(cascade = CascadeType.REMOVE, orphanRemoval = true, mappedBy="session")
protected Collection<ClientUserSessionAssociationEntity> clients = new ArrayList<ClientUserSessionAssociationEntity>(); protected Collection<ClientSessionEntity> clientSessions = new ArrayList<ClientSessionEntity>();
public String getId() { public String getId() {
return id; return id;
@ -131,8 +131,8 @@ public class UserSessionEntity {
this.lastSessionRefresh = lastSessionRefresh; this.lastSessionRefresh = lastSessionRefresh;
} }
public Collection<ClientUserSessionAssociationEntity> getClients() { public Collection<ClientSessionEntity> getClientSessions() {
return clients; return clientSessions;
} }
} }

View file

@ -0,0 +1,79 @@
package org.keycloak.models.sessions.mem;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.mem.entities.ClientSessionEntity;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class ClientSessionAdapter implements ClientSessionModel {
private KeycloakSession session;
private MemUserSessionProvider provider;
private RealmModel realm;
private ClientSessionEntity entity;
public ClientSessionAdapter(KeycloakSession session, MemUserSessionProvider provider, RealmModel realm, ClientSessionEntity entity) {
this.session = session;
this.provider = provider;
this.realm = realm;
this.entity = entity;
}
@Override
public String getId() {
return entity.getId();
}
@Override
public ClientModel getClient() {
return realm.findClientById(entity.getClientId());
}
@Override
public String getState() {
return entity.getState();
}
@Override
public UserSessionModel getUserSession() {
return new UserSessionAdapter(session, provider, realm, entity.getSession());
}
@Override
public String getRedirectUri() {
return entity.getRedirectUri();
}
@Override
public int getTimestamp() {
return entity.getTimestamp();
}
@Override
public void setTimestamp(int timestamp) {
entity.setTimestamp(timestamp);
}
@Override
public ClientSessionModel.Action getAction() {
return entity.getAction();
}
@Override
public void setAction(ClientSessionModel.Action action) {
entity.setAction(action);
}
@Override
public Set<String> getRoles() {
return entity.getRoles();
}
}

View file

@ -1,14 +1,15 @@
package org.keycloak.models.sessions.mem; package org.keycloak.models.sessions.mem;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider; import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.UsernameLoginFailureModel; import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.models.sessions.mem.entities.ClientSessionEntity;
import org.keycloak.models.sessions.mem.entities.UserSessionEntity; import org.keycloak.models.sessions.mem.entities.UserSessionEntity;
import org.keycloak.models.sessions.mem.entities.UserSessionKey;
import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureEntity; import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureEntity;
import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureKey; import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureKey;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
@ -19,6 +20,7 @@ import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
/** /**
@ -27,15 +29,42 @@ import java.util.concurrent.ConcurrentHashMap;
public class MemUserSessionProvider implements UserSessionProvider { public class MemUserSessionProvider implements UserSessionProvider {
private final KeycloakSession session; private final KeycloakSession session;
private final ConcurrentHashMap<UserSessionKey, UserSessionEntity> sessions; private final ConcurrentHashMap<String, UserSessionEntity> userSessions;
private final ConcurrentHashMap<String, ClientSessionEntity> clientSessions;
private final ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures; private final ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures;
public MemUserSessionProvider(KeycloakSession session, ConcurrentHashMap<UserSessionKey, UserSessionEntity> sessions, ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures) { public MemUserSessionProvider(KeycloakSession session, ConcurrentHashMap<String, UserSessionEntity> userSessions, ConcurrentHashMap<String, ClientSessionEntity> clientSessions, ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures) {
this.session = session; this.session = session;
this.sessions = sessions; this.userSessions = userSessions;
this.clientSessions = clientSessions;
this.loginFailures = loginFailures; this.loginFailures = loginFailures;
} }
@Override
public ClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, String redirectUri, String state, Set<String> roles) {
UserSessionEntity userSessionEntity = getUserSessionEntity(realm, userSession.getId());
ClientSessionEntity entity = new ClientSessionEntity();
entity.setId(KeycloakModelUtils.generateId());
entity.setTimestamp(Time.currentTime());
entity.setClientId(client.getId());
entity.setSession(userSessionEntity);
entity.setRedirectUri(redirectUri);
entity.setState(state);
entity.setRoles(roles);
userSessionEntity.addClientSession(entity);
clientSessions.put(entity.getId(), entity);
return new ClientSessionAdapter(session, this, realm, entity);
}
@Override
public ClientSessionModel getClientSession(RealmModel realm, String id) {
ClientSessionEntity entity = clientSessions.get(id);
return entity != null ? new ClientSessionAdapter(session, this, realm, entity) : null;
}
@Override @Override
public UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe) { public UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe) {
String id = KeycloakModelUtils.generateId(); String id = KeycloakModelUtils.generateId();
@ -54,23 +83,31 @@ public class MemUserSessionProvider implements UserSessionProvider {
entity.setStarted(currentTime); entity.setStarted(currentTime);
entity.setLastSessionRefresh(currentTime); entity.setLastSessionRefresh(currentTime);
sessions.put(new UserSessionKey(realm.getId(), id), entity); userSessions.put(id, entity);
return new UserSessionAdapter(session, realm, entity); return new UserSessionAdapter(session, this, realm, entity);
} }
@Override @Override
public UserSessionModel getUserSession(RealmModel realm, String id) { public UserSessionModel getUserSession(RealmModel realm, String id) {
UserSessionEntity entity = sessions.get(new UserSessionKey(realm.getId(), id)); UserSessionEntity entity = getUserSessionEntity(realm, id);
return entity != null ? new UserSessionAdapter(session, realm, entity) : null; return entity != null ? new UserSessionAdapter(session, this, realm, entity) : null;
}
UserSessionEntity getUserSessionEntity(RealmModel realm, String id) {
UserSessionEntity entity = userSessions.get(id);
if (entity != null && entity.getRealm().equals(realm.getId())) {
return entity;
}
return null;
} }
@Override @Override
public List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) { public List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) {
List<UserSessionModel> userSessions = new LinkedList<UserSessionModel>(); List<UserSessionModel> userSessions = new LinkedList<UserSessionModel>();
for (UserSessionEntity s : sessions.values()) { for (UserSessionEntity s : this.userSessions.values()) {
if (s.getRealm().equals(realm.getId()) && s.getUser().equals(user.getId())) { if (s.getRealm().equals(realm.getId()) && s.getUser().equals(user.getId())) {
userSessions.add(new UserSessionAdapter(session, realm, s)); userSessions.add(new UserSessionAdapter(session, this, realm, s));
} }
} }
return userSessions; return userSessions;
@ -78,14 +115,21 @@ public class MemUserSessionProvider implements UserSessionProvider {
@Override @Override
public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) { public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) {
List<UserSessionModel> clientSessions = new LinkedList<UserSessionModel>(); List<UserSessionEntity> userSessionEntities = new LinkedList<UserSessionEntity>();
for (UserSessionEntity s : sessions.values()) { for (ClientSessionEntity s : clientSessions.values()) {
if (s.getRealm().equals(realm.getId()) && s.getClients().contains(client.getClientId())) { if (s.getSession().getRealm().equals(realm.getId()) && s.getClientId().equals(client.getId())) {
clientSessions.add(new UserSessionAdapter(session, realm, s)); if (!userSessionEntities.contains(s.getSession())) {
userSessionEntities.add(s.getSession());
}
} }
} }
Collections.sort(clientSessions, new UserSessionSort());
return clientSessions; List<UserSessionModel> userSessions = new LinkedList<UserSessionModel>();
for (UserSessionEntity e : userSessionEntities) {
userSessions.add(new UserSessionAdapter(session, this, realm, e));
}
Collections.sort(userSessions, new UserSessionSort());
return userSessions;
} }
@Override @Override
@ -101,49 +145,61 @@ public class MemUserSessionProvider implements UserSessionProvider {
@Override @Override
public int getActiveUserSessions(RealmModel realm, ClientModel client) { public int getActiveUserSessions(RealmModel realm, ClientModel client) {
int count = 0; return getUserSessions(realm, client).size();
for (UserSessionEntity s : sessions.values()) {
if (s.getRealm().equals(realm.getId()) && s.getClients().contains(client.getClientId())) {
count++;
}
}
return count;
} }
@Override @Override
public void removeUserSession(RealmModel realm, UserSessionModel session) { public void removeUserSession(RealmModel realm, UserSessionModel session) {
sessions.remove(new UserSessionKey(realm.getId(), session.getId())); UserSessionEntity entity = getUserSessionEntity(realm, session.getId());
if (entity != null) {
userSessions.remove(entity.getId());
for (ClientSessionEntity clientSession : entity.getClientSessions()) {
clientSessions.remove(clientSession.getId());
}
}
} }
@Override @Override
public void removeUserSessions(RealmModel realm, UserModel user) { public void removeUserSessions(RealmModel realm, UserModel user) {
Iterator<UserSessionEntity> itr = sessions.values().iterator(); Iterator<UserSessionEntity> itr = userSessions.values().iterator();
while (itr.hasNext()) { while (itr.hasNext()) {
UserSessionEntity s = itr.next(); UserSessionEntity s = itr.next();
if (s.getRealm().equals(realm.getId()) && s.getUser().equals(user.getId())) { if (s.getRealm().equals(realm.getId()) && s.getUser().equals(user.getId())) {
itr.remove(); itr.remove();
for (ClientSessionEntity clientSession : s.getClientSessions()) {
clientSessions.remove(clientSession.getId());
}
} }
} }
} }
@Override @Override
public void removeExpiredUserSessions(RealmModel realm) { public void removeExpiredUserSessions(RealmModel realm) {
Iterator<UserSessionEntity> itr = sessions.values().iterator(); Iterator<UserSessionEntity> itr = userSessions.values().iterator();
while (itr.hasNext()) { while (itr.hasNext()) {
UserSessionEntity s = itr.next(); UserSessionEntity s = itr.next();
if (s.getRealm().equals(realm.getId()) && (s.getLastSessionRefresh() < Time.currentTime() - realm.getSsoSessionIdleTimeout() || s.getStarted() < Time.currentTime() - realm.getSsoSessionMaxLifespan())) { if (s.getRealm().equals(realm.getId()) && (s.getLastSessionRefresh() < Time.currentTime() - realm.getSsoSessionIdleTimeout() || s.getStarted() < Time.currentTime() - realm.getSsoSessionMaxLifespan())) {
itr.remove(); itr.remove();
for (ClientSessionEntity clientSession : s.getClientSessions()) {
clientSessions.remove(clientSession.getId());
}
} }
} }
} }
@Override @Override
public void removeUserSessions(RealmModel realm) { public void removeUserSessions(RealmModel realm) {
Iterator<UserSessionEntity> itr = sessions.values().iterator(); Iterator<UserSessionEntity> itr = userSessions.values().iterator();
while (itr.hasNext()) { while (itr.hasNext()) {
UserSessionEntity s = itr.next(); UserSessionEntity s = itr.next();
if (s.getRealm().equals(realm.getId())) { if (s.getRealm().equals(realm.getId())) {
itr.remove(); itr.remove();
for (ClientSessionEntity clientSession : s.getClientSessions()) {
clientSessions.remove(clientSession.getId());
}
} }
} }
} }
@ -185,17 +241,10 @@ public class MemUserSessionProvider implements UserSessionProvider {
@Override @Override
public void onClientRemoved(RealmModel realm, ClientModel client) { public void onClientRemoved(RealmModel realm, ClientModel client) {
Iterator<UserSessionEntity> itr = sessions.values().iterator(); for (ClientSessionEntity e : clientSessions.values()) {
while (itr.hasNext()) { if (e.getSession().getRealm().equals(realm.getId()) && e.getClientId().equals(client.getId())) {
UserSessionEntity s = itr.next(); clientSessions.remove(e.getId());
if (s.getRealm().equals(realm.getId())) { e.getSession().removeClientSession(e);
itr.remove();
}
}
for (UserSessionEntity s : sessions.values()) {
if (s.getRealm().equals(realm.getId())) {
s.getClients().remove(client.getClientId());
} }
} }
} }

View file

@ -4,8 +4,8 @@ import org.keycloak.Config;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.UserSessionProvider; import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.UserSessionProviderFactory; import org.keycloak.models.UserSessionProviderFactory;
import org.keycloak.models.sessions.mem.entities.ClientSessionEntity;
import org.keycloak.models.sessions.mem.entities.UserSessionEntity; import org.keycloak.models.sessions.mem.entities.UserSessionEntity;
import org.keycloak.models.sessions.mem.entities.UserSessionKey;
import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureEntity; import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureEntity;
import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureKey; import org.keycloak.models.sessions.mem.entities.UsernameLoginFailureKey;
@ -18,13 +18,15 @@ public class MemUserSessionProviderFactory implements UserSessionProviderFactory
public static final String ID = "mem"; public static final String ID = "mem";
private ConcurrentHashMap<UserSessionKey, UserSessionEntity> sessions = new ConcurrentHashMap<UserSessionKey, UserSessionEntity>(); private ConcurrentHashMap<String, UserSessionEntity> userSessions = new ConcurrentHashMap<String, UserSessionEntity>();
private ConcurrentHashMap<String, ClientSessionEntity> clientSessions = new ConcurrentHashMap<String, ClientSessionEntity>();
private ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures = new ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity>(); private ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity> loginFailures = new ConcurrentHashMap<UsernameLoginFailureKey, UsernameLoginFailureEntity>();
@Override @Override
public UserSessionProvider create(KeycloakSession session) { public UserSessionProvider create(KeycloakSession session) {
return new MemUserSessionProvider(session, sessions, loginFailures); return new MemUserSessionProvider(session, userSessions, clientSessions, loginFailures);
} }
@Override @Override
@ -33,7 +35,7 @@ public class MemUserSessionProviderFactory implements UserSessionProviderFactory
@Override @Override
public void close() { public void close() {
sessions.clear(); userSessions.clear();
loginFailures.clear(); loginFailures.clear();
} }

View file

@ -1,10 +1,12 @@
package org.keycloak.models.sessions.mem; package org.keycloak.models.sessions.mem;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.mem.entities.ClientSessionEntity;
import org.keycloak.models.sessions.mem.entities.UserSessionEntity; import org.keycloak.models.sessions.mem.entities.UserSessionEntity;
import java.util.LinkedList; import java.util.LinkedList;
@ -17,12 +19,14 @@ public class UserSessionAdapter implements UserSessionModel {
private final KeycloakSession session; private final KeycloakSession session;
private MemUserSessionProvider provider;
private final RealmModel realm; private final RealmModel realm;
private final UserSessionEntity entity; private final UserSessionEntity entity;
public UserSessionAdapter(KeycloakSession session, RealmModel realm, UserSessionEntity entity) { public UserSessionAdapter(KeycloakSession session, MemUserSessionProvider provider, RealmModel realm, UserSessionEntity entity) {
this.session = session; this.session = session;
this.provider = provider;
this.realm = realm; this.realm = realm;
this.entity = entity; this.entity = entity;
} }
@ -98,24 +102,14 @@ public class UserSessionAdapter implements UserSessionModel {
} }
@Override @Override
public void associateClient(ClientModel client) { public List<ClientSessionModel> getClientSessions() {
if (!entity.getClients().contains(client.getClientId())) { List<ClientSessionModel> clientSessionModels = new LinkedList<ClientSessionModel>();
entity.getClients().add(client.getClientId()); if (entity.getClientSessions() != null) {
for (ClientSessionEntity e : entity.getClientSessions()) {
clientSessionModels.add(new ClientSessionAdapter(session, provider, realm, e));
}
} }
} return clientSessionModels;
@Override
public List<ClientModel> getClientAssociations() {
List<ClientModel> models = new LinkedList<ClientModel>();
for (String clientId : entity.getClients()) {
models.add(realm.findClient(clientId));
}
return models;
}
@Override
public void removeAssociatedClient(ClientModel client) {
entity.getClients().remove(client.getClientId());
} }
@Override @Override

View file

@ -0,0 +1,87 @@
package org.keycloak.models.sessions.mem.entities;
import org.keycloak.models.ClientSessionModel;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class ClientSessionEntity {
private String id;
private String clientId;
private UserSessionEntity session;
private String redirectUri;
private String state;
private int timestamp;
private ClientSessionModel.Action action;
private Set<String> roles;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getClientId() {
return clientId;
}
public void setClientId(String clientId) {
this.clientId = clientId;
}
public UserSessionEntity getSession() {
return session;
}
public void setSession(UserSessionEntity session) {
this.session = session;
}
public String getRedirectUri() {
return redirectUri;
}
public void setRedirectUri(String redirectUri) {
this.redirectUri = redirectUri;
}
public String getState() {
return state;
}
public void setState(String state) {
this.state = state;
}
public int getTimestamp() {
return timestamp;
}
public void setTimestamp(int timestamp) {
this.timestamp = timestamp;
}
public ClientSessionModel.Action getAction() {
return action;
}
public void setAction(ClientSessionModel.Action action) {
this.action = action;
}
public Set<String> getRoles() {
return roles;
}
public void setRoles(Set<String> roles) {
this.roles = roles;
}
}

View file

@ -1,5 +1,6 @@
package org.keycloak.models.sessions.mem.entities; package org.keycloak.models.sessions.mem.entities;
import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -17,7 +18,7 @@ public class UserSessionEntity {
private boolean rememberMe; private boolean rememberMe;
private int started; private int started;
private int lastSessionRefresh; private int lastSessionRefresh;
private List<String> clients = new LinkedList<String>(); private List<ClientSessionEntity> clientSessions = Collections.synchronizedList(new LinkedList<ClientSessionEntity>());
public String getId() { public String getId() {
return id; return id;
@ -91,12 +92,21 @@ public class UserSessionEntity {
this.lastSessionRefresh = lastSessionRefresh; this.lastSessionRefresh = lastSessionRefresh;
} }
public List<String> getClients() { public void addClientSession(ClientSessionEntity clientSession) {
return clients; if (clientSessions == null) {
clientSessions = new LinkedList<ClientSessionEntity>();
}
clientSessions.add(clientSession);
} }
public void setClients(List<String> clients) { public void removeClientSession(ClientSessionEntity clientSession) {
this.clients = clients; if (clientSessions != null) {
clientSessions.remove(clientSession);
}
}
public List<ClientSessionEntity> getClientSessions() {
return clientSessions;
} }
} }

View file

@ -1,36 +0,0 @@
package org.keycloak.models.sessions.mem.entities;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class UserSessionKey {
private final String realm;
private final String id;
public UserSessionKey(String realm, String id) {
this.realm = realm;
this.id = id;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
UserSessionKey key = (UserSessionKey) o;
if (realm != null ? !realm.equals(key.realm) : key.realm != null) return false;
if (id != null ? !id.equals(key.id) : key.id != null) return false;
return true;
}
@Override
public int hashCode() {
int result = realm != null ? realm.hashCode() : 0;
result = 31 * result + (id != null ? id.hashCode() : 0);
return result;
}
}

View file

@ -0,0 +1,88 @@
package org.keycloak.models.sessions.mongo;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.mongo.entities.MongoClientSessionEntity;
import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity;
import java.util.HashSet;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class ClientSessionAdapter implements ClientSessionModel {
private KeycloakSession session;
private MongoUserSessionProvider provider;
private RealmModel realm;
private MongoClientSessionEntity entity;
private MongoUserSessionEntity userSessionEntity;
private MongoStoreInvocationContext invContext;
public ClientSessionAdapter(KeycloakSession session, MongoUserSessionProvider provider, RealmModel realm, MongoClientSessionEntity entity, MongoUserSessionEntity userSessionEntity, MongoStoreInvocationContext invContext) {
this.session = session;
this.provider = provider;
this.realm = realm;
this.entity = entity;
this.userSessionEntity = userSessionEntity;
this.invContext = invContext;
}
@Override
public String getId() {
return entity.getId();
}
@Override
public ClientModel getClient() {
return realm.findClientById(entity.getClientId());
}
@Override
public String getState() {
return entity.getState();
}
@Override
public UserSessionModel getUserSession() {
return new UserSessionAdapter(session, provider, userSessionEntity, realm, invContext);
}
@Override
public String getRedirectUri() {
return entity.getRedirectUri();
}
@Override
public int getTimestamp() {
return entity.getTimestamp();
}
@Override
public void setTimestamp(int timestamp) {
entity.setTimestamp(timestamp);
invContext.getMongoStore().updateEntity(userSessionEntity, invContext);
}
@Override
public Action getAction() {
return entity.getAction();
}
@Override
public void setAction(Action action) {
entity.setAction(action);
invContext.getMongoStore().updateEntity(userSessionEntity, invContext);
}
@Override
public Set<String> getRoles() {
return entity.getRoles() != null ? new HashSet<String>(entity.getRoles()) : null;
}
}

View file

@ -6,18 +6,22 @@ import com.mongodb.QueryBuilder;
import org.keycloak.connections.mongo.api.MongoStore; import org.keycloak.connections.mongo.api.MongoStore;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext; import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider; import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.UsernameLoginFailureModel; import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.models.sessions.mongo.entities.MongoClientSessionEntity;
import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity; import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity;
import org.keycloak.models.sessions.mongo.entities.MongoUsernameLoginFailureEntity; import org.keycloak.models.sessions.mongo.entities.MongoUsernameLoginFailureEntity;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.util.Time; import org.keycloak.util.Time;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
@ -34,6 +38,47 @@ public class MongoUserSessionProvider implements UserSessionProvider {
this.invocationContext = invocationContext; this.invocationContext = invocationContext;
} }
@Override
public ClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, String redirectUri, String state, Set<String> roles) {
MongoUserSessionEntity userSessionEntity = getUserSessionEntity(realm, userSession.getId());
MongoClientSessionEntity entity = new MongoClientSessionEntity();
entity.setId(KeycloakModelUtils.generateId());
entity.setTimestamp(Time.currentTime());
entity.setClientId(client.getId());
entity.setRedirectUri(redirectUri);
entity.setState(state);
if (roles != null) {
entity.setRoles(new LinkedList<String>(roles));
}
mongoStore.pushItemToList(userSessionEntity, "clientSessions", entity, false, invocationContext);
return new ClientSessionAdapter(session, this, realm, entity, userSessionEntity, invocationContext);
}
@Override
public ClientSessionModel getClientSession(RealmModel realm, String id) {
DBObject query = new QueryBuilder()
.and("realmId").is(realm.getId())
.and("clientSessions.id").is(id).get();
List<MongoUserSessionEntity> entities = mongoStore.loadEntities(MongoUserSessionEntity.class, query, invocationContext);
if (entities.isEmpty()) {
return null;
}
MongoUserSessionEntity userSessionEntity = entities.get(0);
List<MongoClientSessionEntity> sessions = userSessionEntity.getClientSessions();
for (MongoClientSessionEntity s : sessions) {
if (s.getId().equals(id)) {
return new ClientSessionAdapter(session, this, realm, s, userSessionEntity, invocationContext);
}
}
return null;
}
@Override @Override
public UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe) { public UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe) {
MongoUserSessionEntity entity = new MongoUserSessionEntity(); MongoUserSessionEntity entity = new MongoUserSessionEntity();
@ -50,25 +95,29 @@ public class MongoUserSessionProvider implements UserSessionProvider {
entity.setLastSessionRefresh(currentTime); entity.setLastSessionRefresh(currentTime);
mongoStore.insertEntity(entity, invocationContext); mongoStore.insertEntity(entity, invocationContext);
return new UserSessionAdapter(session, entity, realm, invocationContext); return new UserSessionAdapter(session, this, entity, realm, invocationContext);
} }
@Override @Override
public UserSessionModel getUserSession(RealmModel realm, String id) { public UserSessionModel getUserSession(RealmModel realm, String id) {
MongoUserSessionEntity entity = mongoStore.loadEntity(MongoUserSessionEntity.class, id, invocationContext); MongoUserSessionEntity entity = getUserSessionEntity(realm, id);
if (entity == null) { if (entity == null) {
return null; return null;
} else { } else {
return new UserSessionAdapter(session, entity, realm, invocationContext); return new UserSessionAdapter(session, this, entity, realm, invocationContext);
} }
} }
MongoUserSessionEntity getUserSessionEntity(RealmModel realm, String id) {
return mongoStore.loadEntity(MongoUserSessionEntity.class, id, invocationContext);
}
@Override @Override
public List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) { public List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) {
DBObject query = new BasicDBObject("user", user.getId()); DBObject query = new BasicDBObject("user", user.getId());
List<UserSessionModel> sessions = new LinkedList<UserSessionModel>(); List<UserSessionModel> sessions = new LinkedList<UserSessionModel>();
for (MongoUserSessionEntity e : mongoStore.loadEntities(MongoUserSessionEntity.class, query, invocationContext)) { for (MongoUserSessionEntity e : mongoStore.loadEntities(MongoUserSessionEntity.class, query, invocationContext)) {
sessions.add(new UserSessionAdapter(session, e, realm, invocationContext)); sessions.add(new UserSessionAdapter(session, this, e, realm, invocationContext));
} }
return sessions; return sessions;
} }
@ -80,14 +129,14 @@ public class MongoUserSessionProvider implements UserSessionProvider {
public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) { public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) {
DBObject query = new QueryBuilder() DBObject query = new QueryBuilder()
.and("associatedClientIds").is(client.getId()) .and("clientSessions.clientId").is(client.getId())
.get(); .get();
DBObject sort = new BasicDBObject("started", 1).append("id", 1); DBObject sort = new BasicDBObject("started", 1).append("id", 1);
List<MongoUserSessionEntity> sessions = mongoStore.loadEntities(MongoUserSessionEntity.class, query, sort, firstResult, maxResults, invocationContext); List<MongoUserSessionEntity> sessions = mongoStore.loadEntities(MongoUserSessionEntity.class, query, sort, firstResult, maxResults, invocationContext);
List<UserSessionModel> result = new LinkedList<UserSessionModel>(); List<UserSessionModel> result = new LinkedList<UserSessionModel>();
for (MongoUserSessionEntity session : sessions) { for (MongoUserSessionEntity session : sessions) {
result.add(new UserSessionAdapter(this.session, session, realm, invocationContext)); result.add(new UserSessionAdapter(this.session, this, session, realm, invocationContext));
} }
return result; return result;
} }
@ -95,7 +144,7 @@ public class MongoUserSessionProvider implements UserSessionProvider {
@Override @Override
public int getActiveUserSessions(RealmModel realm, ClientModel client) { public int getActiveUserSessions(RealmModel realm, ClientModel client) {
DBObject query = new QueryBuilder() DBObject query = new QueryBuilder()
.and("associatedClientIds").is(client.getId()) .and("clientSessions.clientId").is(client.getId())
.get(); .get();
return mongoStore.countEntities(MongoUserSessionEntity.class, query, invocationContext); return mongoStore.countEntities(MongoUserSessionEntity.class, query, invocationContext);
} }
@ -184,13 +233,22 @@ public class MongoUserSessionProvider implements UserSessionProvider {
} }
@Override @Override
// TODO Not very efficient, should use Mongo $pull to remove directly
public void onClientRemoved(RealmModel realm, ClientModel client) { public void onClientRemoved(RealmModel realm, ClientModel client) {
DBObject query = new QueryBuilder() DBObject query = new QueryBuilder()
.and("realmId").is(realm.getId()) .and("clientSessions.clientId").is(client.getId())
.get(); .get();
List<MongoUserSessionEntity> sessions = invocationContext.getMongoStore().loadEntities(MongoUserSessionEntity.class, query, invocationContext); List<MongoUserSessionEntity> userSessionEntities = mongoStore.loadEntities(MongoUserSessionEntity.class, query, invocationContext);
for (MongoUserSessionEntity session : sessions) { for (MongoUserSessionEntity e : userSessionEntities) {
invocationContext.getMongoStore().pullItemFromList(session, "associatedClientIds", client.getClientId(), invocationContext); List<MongoClientSessionEntity> remove = new LinkedList<MongoClientSessionEntity>();
for (MongoClientSessionEntity c : e.getClientSessions()) {
if (c.getClientId().equals(client.getId())) {
remove.add(c);
}
}
for (MongoClientSessionEntity c : remove) {
mongoStore.pullItemFromList(e, "clientSessions", c, invocationContext);
}
} }
} }

View file

@ -2,14 +2,15 @@ package org.keycloak.models.sessions.mongo;
import org.jboss.logging.Logger; import org.jboss.logging.Logger;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext; import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.mongo.entities.MongoClientSessionEntity;
import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity; import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity;
import java.util.ArrayList; import java.util.LinkedList;
import java.util.List; import java.util.List;
/** /**
@ -19,16 +20,19 @@ public class UserSessionAdapter extends AbstractMongoAdapter<MongoUserSessionEnt
private static final Logger logger = Logger.getLogger(UserSessionAdapter.class); private static final Logger logger = Logger.getLogger(UserSessionAdapter.class);
private final MongoUserSessionProvider provider;
private MongoUserSessionEntity entity; private MongoUserSessionEntity entity;
private RealmModel realm; private RealmModel realm;
private KeycloakSession keycloakSession; private KeycloakSession keycloakSession;
private final MongoStoreInvocationContext invContext;
public UserSessionAdapter(KeycloakSession keycloakSession, MongoUserSessionEntity entity, RealmModel realm, MongoStoreInvocationContext invContext) public UserSessionAdapter(KeycloakSession keycloakSession, MongoUserSessionProvider provider, MongoUserSessionEntity entity, RealmModel realm, MongoStoreInvocationContext invContext) {
{
super(invContext); super(invContext);
this.provider = provider;
this.entity = entity; this.entity = entity;
this.realm = realm; this.realm = realm;
this.keycloakSession = keycloakSession; this.keycloakSession = keycloakSession;
this.invContext = invContext;
} }
@Override @Override
@ -125,36 +129,12 @@ public class UserSessionAdapter extends AbstractMongoAdapter<MongoUserSessionEnt
} }
@Override @Override
public void associateClient(ClientModel client) { public List<ClientSessionModel> getClientSessions() {
getMongoStore().pushItemToList(entity, "associatedClientIds", client.getId(), true, invocationContext); List<ClientSessionModel> sessions = new LinkedList<ClientSessionModel>();
} for (MongoClientSessionEntity e : entity.getClientSessions()) {
sessions.add(new ClientSessionAdapter(keycloakSession, provider, realm, e, entity, invocationContext));
@Override
public List<ClientModel> getClientAssociations() {
List<String> associatedClientIds = getMongoEntity().getAssociatedClientIds();
List<ClientModel> clients = new ArrayList<ClientModel>();
for (String clientId : associatedClientIds) {
// Try application first
ClientModel client = realm.getApplicationById(clientId);
// And then OAuthClient
if (client == null) {
client = realm.getOAuthClientById(clientId);
}
if (client != null) {
clients.add(client);
} else {
logger.warnf("Not found associated client with Id: %s", clientId);
}
} }
return clients; return sessions;
}
@Override
public void removeAssociatedClient(ClientModel client) {
getMongoStore().pullItemFromList(entity, "associatedClientIds", client.getId(), invocationContext);
} }
@Override @Override

View file

@ -0,0 +1,80 @@
package org.keycloak.models.sessions.mongo.entities;
import org.keycloak.models.ClientSessionModel;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class MongoClientSessionEntity {
private String id;
private String clientId;
private String redirectUri;
private String state;
private int timestamp;
private ClientSessionModel.Action action;
private List<String> roles;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getClientId() {
return clientId;
}
public void setClientId(String clientId) {
this.clientId = clientId;
}
public String getRedirectUri() {
return redirectUri;
}
public void setRedirectUri(String redirectUri) {
this.redirectUri = redirectUri;
}
public String getState() {
return state;
}
public void setState(String state) {
this.state = state;
}
public int getTimestamp() {
return timestamp;
}
public void setTimestamp(int timestamp) {
this.timestamp = timestamp;
}
public ClientSessionModel.Action getAction() {
return action;
}
public void setAction(ClientSessionModel.Action action) {
this.action = action;
}
public List<String> getRoles() {
return roles;
}
public void setRoles(List<String> roles) {
this.roles = roles;
}
}

View file

@ -6,6 +6,7 @@ import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.entities.AbstractIdentifiableEntity; import org.keycloak.models.entities.AbstractIdentifiableEntity;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List; import java.util.List;
/** /**
@ -30,7 +31,7 @@ public class MongoUserSessionEntity extends AbstractIdentifiableEntity implement
private int lastSessionRefresh; private int lastSessionRefresh;
private List<String> associatedClientIds = new ArrayList<String>(); private List<MongoClientSessionEntity> clientSessions;
public String getRealmId() { public String getRealmId() {
return realmId; return realmId;
@ -96,12 +97,12 @@ public class MongoUserSessionEntity extends AbstractIdentifiableEntity implement
this.lastSessionRefresh = lastSessionRefresh; this.lastSessionRefresh = lastSessionRefresh;
} }
public List<String> getAssociatedClientIds() { public List<MongoClientSessionEntity> getClientSessions() {
return associatedClientIds; return clientSessions;
} }
public void setAssociatedClientIds(List<String> associatedClientIds) { public void setClientSessions(List<MongoClientSessionEntity> clientSessions) {
this.associatedClientIds = associatedClientIds; this.clientSessions = clientSessions;
} }
@Override @Override

View file

@ -0,0 +1,171 @@
package org.keycloak.services.managers;
import org.keycloak.OAuthErrorException;
import org.keycloak.jose.jws.Algorithm;
import org.keycloak.jose.jws.crypto.RSAProvider;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserModel.RequiredAction;
import org.keycloak.util.Base64Url;
import org.keycloak.util.Time;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.Signature;
import java.util.HashSet;
import java.util.Set;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class AccessCode {
private final RealmModel realm;
private final ClientSessionModel clientSession;
public AccessCode(RealmModel realm, ClientSessionModel clientSession) {
this.realm = realm;
this.clientSession = clientSession;
}
public static AccessCode parse(String code, KeycloakSession session, RealmModel realm) {
try {
String[] parts = code.split("\\.");
String id = new String(Base64Url.decode(parts[1]));
ClientSessionModel clientSession = session.sessions().getClientSession(realm, id);
if (clientSession == null) {
return null;
}
String hash = createSignatureHash(realm, clientSession);
if (!hash.equals(parts[0])) {
return null;
}
return new AccessCode(realm, clientSession);
} catch (RuntimeException e) {
return null;
}
}
public String getCodeId() {
return clientSession.getId();
}
public UserModel getUser() {
return clientSession.getUserSession().getUser();
}
public String getSessionState() {
return clientSession.getUserSession().getId();
}
public boolean isValid(RequiredAction requiredAction) {
return isValid(convertToAction(requiredAction));
}
public boolean isValid(ClientSessionModel.Action requestedAction) {
ClientSessionModel.Action action = clientSession.getAction();
if (action == null) {
return false;
}
int timestamp = clientSession.getTimestamp();
if (!action.equals(requestedAction)) {
return false;
}
int lifespan = action.equals(ClientSessionModel.Action.CODE_TO_TOKEN) ? realm.getAccessCodeLifespan() : realm.getAccessCodeLifespanUserAction();
return timestamp + lifespan > Time.currentTime();
}
public Set<RoleModel> getRequestedRoles() {
Set<RoleModel> requestedRoles = new HashSet<RoleModel>();
for (String roleId : clientSession.getRoles()) {
RoleModel role = realm.getRoleById(roleId);
if (role == null) {
new OAuthErrorException(OAuthErrorException.INVALID_GRANT, "Invalid role " + roleId);
}
requestedRoles.add(realm.getRoleById(roleId));
}
return requestedRoles;
}
public ClientModel getClient() {
return clientSession.getClient();
}
public String getState() {
return clientSession.getState();
}
public String getRedirectUri() {
return clientSession.getRedirectUri();
}
public ClientSessionModel.Action getAction() {
return clientSession.getAction();
}
public void setAction(ClientSessionModel.Action action) {
clientSession.setAction(action);
clientSession.setTimestamp(Time.currentTime());
}
public void setRequiredAction(RequiredAction requiredAction) {
setAction(convertToAction(requiredAction));
}
private ClientSessionModel.Action convertToAction(RequiredAction requiredAction) {
switch (requiredAction) {
case CONFIGURE_TOTP:
return ClientSessionModel.Action.CONFIGURE_TOTP;
case UPDATE_PASSWORD:
return ClientSessionModel.Action.UPDATE_PASSWORD;
case UPDATE_PROFILE:
return ClientSessionModel.Action.UPDATE_PROFILE;
case VERIFY_EMAIL:
return ClientSessionModel.Action.VERIFY_EMAIL;
default:
throw new IllegalArgumentException("Unknown required action " + requiredAction);
}
}
public String getCode() {
String hash = createSignatureHash(realm, clientSession);
StringBuilder sb = new StringBuilder();
sb.append(hash);
sb.append(".");
sb.append(Base64Url.encode(clientSession.getId().getBytes()));
return sb.toString();
}
private static String createSignatureHash(RealmModel realm, ClientSessionModel clientSession) {
try {
Signature signature = Signature.getInstance(RSAProvider.getJavaAlgorithm(Algorithm.RS256));
signature.initSign(realm.getPrivateKey());
signature.update(clientSession.getId().getBytes());
signature.update(ByteBuffer.allocate(4).putInt(clientSession.getTimestamp()));
if (clientSession.getAction() != null) {
signature.update(clientSession.getAction().toString().getBytes());
}
byte[] sign = signature.sign();
MessageDigest digest = MessageDigest.getInstance("sha-1");
digest.update(sign);
return Base64Url.encode(digest.digest());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

View file

@ -1,125 +0,0 @@
package org.keycloak.services.managers;
import org.keycloak.OAuthErrorException;
import org.keycloak.jose.jws.JWSBuilder;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserModel.RequiredAction;
import org.keycloak.representations.AccessCode;
import org.keycloak.util.Time;
import java.util.HashSet;
import java.util.Set;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class AccessCodeEntry {
protected AccessCode accessCode;
protected RealmModel realm;
KeycloakSession keycloakSession;
public AccessCodeEntry(KeycloakSession keycloakSession, RealmModel realm, AccessCode accessCode) {
this.realm = realm;
this.accessCode = accessCode;
this.keycloakSession = keycloakSession;
}
public String getCodeId() {
return this.accessCode.getId();
}
public UserModel getUser() {
return keycloakSession.users().getUserById(accessCode.getUserId(), realm);
}
public String getSessionState() {
return accessCode.getSessionState();
}
public boolean isExpired() {
int lifespan = accessCode.getAction() == null ? realm.getAccessCodeLifespan() : realm.getAccessCodeLifespanUserAction();
return accessCode.getTimestamp() + lifespan < Time.currentTime();
}
public Set<RoleModel> getRequestedRoles() {
Set<RoleModel> requestedRoles = new HashSet<RoleModel>();
for (String roleId : accessCode.getRequestedRoles()) {
RoleModel role = realm.getRoleById(roleId);
if (role == null) {
new OAuthErrorException(OAuthErrorException.INVALID_GRANT, "Invalid role " + roleId);
}
requestedRoles.add(realm.getRoleById(roleId));
}
return requestedRoles;
}
public ClientModel getClient() {
return realm.findClient(accessCode.getClientId());
}
public String getState() {
return accessCode.getState();
}
public void setState(String state) {
accessCode.setState(state);
}
public String getRedirectUri() {
return accessCode.getRedirectUri();
}
public AccessCode.Action getAction() {
return accessCode.getAction();
}
public void setAction(AccessCode.Action action) {
accessCode.setAction(action);
accessCode.setTimestamp(Time.currentTime());
}
public RequiredAction getRequiredAction() {
AccessCode.Action action = accessCode.getAction();
if (action != null) {
switch (action) {
case CONFIGURE_TOTP:
return RequiredAction.CONFIGURE_TOTP;
case UPDATE_PASSWORD:
return RequiredAction.UPDATE_PASSWORD;
case UPDATE_PROFILE:
return RequiredAction.UPDATE_PROFILE;
case VERIFY_EMAIL:
return RequiredAction.VERIFY_EMAIL;
}
}
return null;
}
public void setRequiredAction(RequiredAction requiredAction) {
switch (requiredAction) {
case CONFIGURE_TOTP:
setAction(AccessCode.Action.CONFIGURE_TOTP);
break;
case UPDATE_PASSWORD:
setAction(AccessCode.Action.UPDATE_PASSWORD);
break;
case UPDATE_PROFILE:
setAction(AccessCode.Action.UPDATE_PROFILE);
break;
case VERIFY_EMAIL:
setAction(AccessCode.Action.VERIFY_EMAIL);
break;
default:
throw new IllegalArgumentException("Unknown required action " + requiredAction);
}
}
public String getCode() {
return new JWSBuilder().jsonContent(accessCode).rsa256(realm.getPrivateKey());
}
}

View file

@ -10,13 +10,13 @@ import org.keycloak.jose.jws.crypto.RSAProvider;
import org.keycloak.models.ApplicationModel; import org.keycloak.models.ApplicationModel;
import org.keycloak.models.ClaimMask; import org.keycloak.models.ClaimMask;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.representations.AccessCode;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
import org.keycloak.representations.AccessTokenResponse; import org.keycloak.representations.AccessTokenResponse;
import org.keycloak.representations.IDToken; import org.keycloak.representations.IDToken;
@ -28,7 +28,6 @@ import java.io.IOException;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID;
/** /**
* Stateful object that creates tokens and manages oauth access codes * Stateful object that creates tokens and manages oauth access codes
@ -39,22 +38,6 @@ import java.util.UUID;
public class TokenManager { public class TokenManager {
protected static final Logger logger = Logger.getLogger(TokenManager.class); protected static final Logger logger = Logger.getLogger(TokenManager.class);
public AccessCodeEntry parseCode(String code, KeycloakSession session, RealmModel realm) {
try {
JWSInput input = new JWSInput(code);
if (!RSAProvider.verify(input, realm.getPublicKey())) {
logger.error("Could not verify access code");
return null;
}
AccessCode accessCode = input.readJsonContent(AccessCode.class);
return new AccessCodeEntry(session, realm, accessCode);
} catch (Exception e) {
logger.error("error parsing access code", e);
return null;
}
}
public static void applyScope(RoleModel role, RoleModel scope, Set<RoleModel> visited, Set<RoleModel> requested) { public static void applyScope(RoleModel role, RoleModel scope, Set<RoleModel> visited, Set<RoleModel> requested) {
if (visited.contains(scope)) return; if (visited.contains(scope)) return;
visited.add(scope); visited.add(scope);
@ -69,30 +52,14 @@ public class TokenManager {
} }
} }
public AccessCode createAccessCode(String scopeParam, String state, String redirect, KeycloakSession session, RealmModel realm, ClientModel client, UserModel user, UserSessionModel userSession) {
public AccessCodeEntry createAccessCode(String scopeParam, String state, String redirect, KeycloakSession keycloakSession, RealmModel realm, ClientModel client, UserModel user, UserSessionModel session) {
return createAccessCodeEntry(scopeParam, state, redirect, keycloakSession, realm, client, user, session);
}
private AccessCodeEntry createAccessCodeEntry(String scopeParam, String state, String redirect, KeycloakSession keycloakSession, RealmModel realm, ClientModel client, UserModel user, UserSessionModel session) {
AccessCode code = new AccessCode();
code.setId(UUID.randomUUID().toString() + System.currentTimeMillis());
code.setClientId(client.getClientId());
code.setUserId(user.getId());
code.setTimestamp(Time.currentTime());
code.setSessionState(session != null ? session.getId() : null);
code.setRedirectUri(redirect);
code.setState(state);
Set<String> requestedRoles = new HashSet<String>(); Set<String> requestedRoles = new HashSet<String>();
for (RoleModel r : getAccess(scopeParam, client, user)) { for (RoleModel r : getAccess(scopeParam, client, user)) {
requestedRoles.add(r.getId()); requestedRoles.add(r.getId());
} }
code.setRequestedRoles(requestedRoles);
AccessCodeEntry entry = new AccessCodeEntry(keycloakSession, realm, code); ClientSessionModel clientSession = session.sessions().createClientSession(realm, client, userSession, redirect, state, requestedRoles);
return entry; return new AccessCode(realm, clientSession);
} }
public AccessToken refreshAccessToken(KeycloakSession session, UriInfo uriInfo, RealmModel realm, ClientModel client, String encodedRefreshToken, Audit audit) throws OAuthErrorException { public AccessToken refreshAccessToken(KeycloakSession session, UriInfo uriInfo, RealmModel realm, ClientModel client, String encodedRefreshToken, Audit audit) throws OAuthErrorException {

View file

@ -40,12 +40,14 @@ import org.keycloak.models.ApplicationModel;
import org.keycloak.models.AuthenticationLinkModel; import org.keycloak.models.AuthenticationLinkModel;
import org.keycloak.models.AuthenticationProviderModel; import org.keycloak.models.AuthenticationProviderModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.SocialLinkModel; import org.keycloak.models.SocialLinkModel;
import org.keycloak.models.UserCredentialModel; import org.keycloak.models.UserCredentialModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.ModelToRepresentation; import org.keycloak.models.utils.ModelToRepresentation;
import org.keycloak.models.utils.TimeBasedOTP; import org.keycloak.models.utils.TimeBasedOTP;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
@ -151,9 +153,20 @@ public class AccountService {
} }
} }
if (authResult != null) { if (authResult != null) {
if (authResult.getSession() != null) { UserSessionModel userSession = authResult.getSession();
authResult.getSession().associateClient(application); if (userSession != null) {
boolean associated = false;
for (ClientSessionModel c : userSession.getClientSessions()) {
if (c.getClient().equals(application)) {
associated = true;
break;
}
}
if (!associated) {
session.sessions().createClientSession(realm, application, userSession, null, null, null);
}
} }
account.setUser(auth.getUser()); account.setUser(auth.getUser());
AuthenticationLinkModel authLinkModel = auth.getUser().getAuthenticationLink(); AuthenticationLinkModel authLinkModel = auth.getUser().getAuthenticationLink();

View file

@ -34,6 +34,7 @@ import org.keycloak.email.EmailException;
import org.keycloak.email.EmailProvider; import org.keycloak.email.EmailProvider;
import org.keycloak.login.LoginFormsProvider; import org.keycloak.login.LoginFormsProvider;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserCredentialModel; import org.keycloak.models.UserCredentialModel;
@ -43,7 +44,7 @@ import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.TimeBasedOTP; import org.keycloak.models.utils.TimeBasedOTP;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.services.ClientConnection; import org.keycloak.services.ClientConnection;
import org.keycloak.services.managers.AccessCodeEntry; import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.TokenManager; import org.keycloak.services.managers.TokenManager;
import org.keycloak.services.messages.Messages; import org.keycloak.services.messages.Messages;
@ -106,7 +107,7 @@ public class RequiredActionsService {
@POST @POST
@Consumes(MediaType.APPLICATION_FORM_URLENCODED) @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
public Response updateProfile(final MultivaluedMap<String, String> formData) { public Response updateProfile(final MultivaluedMap<String, String> formData) {
AccessCodeEntry accessCode = getAccessCodeEntry(RequiredAction.UPDATE_PROFILE); AccessCode accessCode = getAccessCodeEntry(RequiredAction.UPDATE_PROFILE);
if (accessCode == null) { if (accessCode == null) {
return unauthorized(); return unauthorized();
} }
@ -144,7 +145,7 @@ public class RequiredActionsService {
@POST @POST
@Consumes(MediaType.APPLICATION_FORM_URLENCODED) @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
public Response updateTotp(final MultivaluedMap<String, String> formData) { public Response updateTotp(final MultivaluedMap<String, String> formData) {
AccessCodeEntry accessCode = getAccessCodeEntry(RequiredAction.CONFIGURE_TOTP); AccessCode accessCode = getAccessCodeEntry(RequiredAction.CONFIGURE_TOTP);
if (accessCode == null) { if (accessCode == null) {
return unauthorized(); return unauthorized();
} }
@ -182,7 +183,7 @@ public class RequiredActionsService {
@Consumes(MediaType.APPLICATION_FORM_URLENCODED) @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
public Response updatePassword(final MultivaluedMap<String, String> formData) { public Response updatePassword(final MultivaluedMap<String, String> formData) {
logger.debug("updatePassword"); logger.debug("updatePassword");
AccessCodeEntry accessCode = getAccessCodeEntry(RequiredAction.UPDATE_PASSWORD); AccessCode accessCode = getAccessCodeEntry(RequiredAction.UPDATE_PASSWORD);
if (accessCode == null) { if (accessCode == null) {
logger.debug("updatePassword access code is null"); logger.debug("updatePassword access code is null");
return unauthorized(); return unauthorized();
@ -231,8 +232,8 @@ public class RequiredActionsService {
@GET @GET
public Response emailVerification() { public Response emailVerification() {
if (uriInfo.getQueryParameters().containsKey("key")) { if (uriInfo.getQueryParameters().containsKey("key")) {
AccessCodeEntry accessCode = tokenManager.parseCode(uriInfo.getQueryParameters().getFirst("key"), session, realm); AccessCode accessCode = AccessCode.parse(uriInfo.getQueryParameters().getFirst("key"), session, realm);
if (accessCode == null || accessCode.isExpired() || !RequiredAction.VERIFY_EMAIL.equals(accessCode.getRequiredAction())) { if (accessCode == null || !accessCode.isValid(RequiredAction.VERIFY_EMAIL)) {
return unauthorized(); return unauthorized();
} }
@ -248,13 +249,12 @@ public class RequiredActionsService {
return redirectOauth(user, accessCode); return redirectOauth(user, accessCode);
} else { } else {
AccessCodeEntry accessCode = getAccessCodeEntry(RequiredAction.VERIFY_EMAIL); AccessCode accessCode = getAccessCodeEntry(RequiredAction.VERIFY_EMAIL);
if (accessCode == null) { if (accessCode == null) {
return unauthorized(); return unauthorized();
} }
initAudit(accessCode); initAudit(accessCode);
//audit.clone().event(EventType.SEND_VERIFY_EMAIL).detail(Details.EMAIL, accessCode.getUser().getEmail()).success();
return Flows.forms(session, realm, uriInfo).setAccessCode(accessCode.getCode()).setUser(accessCode.getUser()) return Flows.forms(session, realm, uriInfo).setAccessCode(accessCode.getCode()).setUser(accessCode.getUser())
.createResponse(RequiredAction.VERIFY_EMAIL); .createResponse(RequiredAction.VERIFY_EMAIL);
@ -265,8 +265,8 @@ public class RequiredActionsService {
@GET @GET
public Response passwordReset() { public Response passwordReset() {
if (uriInfo.getQueryParameters().containsKey("key")) { if (uriInfo.getQueryParameters().containsKey("key")) {
AccessCodeEntry accessCode = tokenManager.parseCode(uriInfo.getQueryParameters().getFirst("key"), session, realm); AccessCode accessCode = AccessCode.parse(uriInfo.getQueryParameters().getFirst("key"), session, realm);
if (accessCode == null || accessCode.isExpired() || !RequiredAction.UPDATE_PASSWORD.equals(accessCode.getRequiredAction())) { if (accessCode == null || !accessCode.isValid(RequiredAction.UPDATE_PASSWORD)) {
return unauthorized(); return unauthorized();
} }
@ -317,7 +317,7 @@ public class RequiredActionsService {
UserSessionModel userSession = session.sessions().createUserSession(realm, user, username, clientConnection.getRemoteAddr(), "form", false); UserSessionModel userSession = session.sessions().createUserSession(realm, user, username, clientConnection.getRemoteAddr(), "form", false);
audit.session(userSession); audit.session(userSession);
AccessCodeEntry accessCode = tokenManager.createAccessCode(scopeParam, state, redirect, session, realm, client, user, userSession); AccessCode accessCode = tokenManager.createAccessCode(scopeParam, state, redirect, session, realm, client, user, userSession);
accessCode.setRequiredAction(RequiredAction.UPDATE_PASSWORD); accessCode.setRequiredAction(RequiredAction.UPDATE_PASSWORD);
try { try {
@ -339,38 +339,33 @@ public class RequiredActionsService {
return Flows.forms(session, realm, uriInfo).setSuccess("emailSent").createPasswordReset(); return Flows.forms(session, realm, uriInfo).setSuccess("emailSent").createPasswordReset();
} }
private AccessCodeEntry getAccessCodeEntry(RequiredAction requiredAction) { private AccessCode getAccessCodeEntry(RequiredAction requiredAction) {
String code = uriInfo.getQueryParameters().getFirst(OAuth2Constants.CODE); String code = uriInfo.getQueryParameters().getFirst(OAuth2Constants.CODE);
if (code == null) { if (code == null) {
logger.debug("getAccessCodeEntry code as not in query param"); logger.debug("getAccessCodeEntry code as not in query param");
return null; return null;
} }
AccessCodeEntry accessCodeEntry = tokenManager.parseCode(code, session, realm); AccessCode accessCode = AccessCode.parse(code, session, realm);
if (accessCodeEntry == null) { if (accessCode == null) {
logger.debug("getAccessCodeEntry access code entry null"); logger.debug("getAccessCodeEntry access code entry null");
return null; return null;
} }
if (accessCodeEntry.isExpired()) { if (!accessCode.isValid(requiredAction)) {
logger.debugv("getAccessCodeEntry: access code id: {0}", accessCodeEntry.getCodeId()); logger.debugv("getAccessCodeEntry: access code id: {0}", accessCode.getCodeId());
logger.debugv("getAccessCodeEntry access code entry expired"); logger.debugv("getAccessCodeEntry access code not valid");
return null; return null;
} }
if (!requiredAction.equals(accessCodeEntry.getRequiredAction())) { return accessCode;
logger.debugv("Invalid access code action: {0}", requiredAction);
return null;
}
return accessCodeEntry;
} }
private UserModel getUser(AccessCodeEntry accessCode) { private UserModel getUser(AccessCode accessCode) {
return session.users().getUserByUsername(accessCode.getUser().getUsername(), realm); return session.users().getUserByUsername(accessCode.getUser().getUsername(), realm);
} }
private Response redirectOauth(UserModel user, AccessCodeEntry accessCode) { private Response redirectOauth(UserModel user, AccessCode accessCode) {
if (accessCode == null) { if (accessCode == null) {
return null; return null;
} }
@ -382,7 +377,7 @@ public class RequiredActionsService {
.createResponse(requiredActions.iterator().next()); .createResponse(requiredActions.iterator().next());
} else { } else {
logger.debugv("redirectOauth: redirecting to: {0}", accessCode.getRedirectUri()); logger.debugv("redirectOauth: redirecting to: {0}", accessCode.getRedirectUri());
accessCode.setAction(null); accessCode.setAction(ClientSessionModel.Action.CODE_TO_TOKEN);
AuthenticationManager authManager = new AuthenticationManager(); AuthenticationManager authManager = new AuthenticationManager();
@ -400,7 +395,7 @@ public class RequiredActionsService {
} }
} }
private void initAudit(AccessCodeEntry accessCode) { private void initAudit(AccessCode accessCode) {
audit.event(EventType.LOGIN).client(accessCode.getClient()) audit.event(EventType.LOGIN).client(accessCode.getClient())
.user(accessCode.getUser()) .user(accessCode.getUser())
.session(accessCode.getSessionState()) .session(accessCode.getSessionState())

View file

@ -19,21 +19,20 @@ import org.keycloak.authentication.AuthenticationProviderManager;
import org.keycloak.login.LoginFormsProvider; import org.keycloak.login.LoginFormsProvider;
import org.keycloak.models.ApplicationModel; import org.keycloak.models.ApplicationModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.RequiredCredentialModel; import org.keycloak.models.RequiredCredentialModel;
import org.keycloak.models.RoleModel;
import org.keycloak.models.UserCredentialModel; import org.keycloak.models.UserCredentialModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.representations.AccessCode;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
import org.keycloak.representations.AccessTokenResponse; import org.keycloak.representations.AccessTokenResponse;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.services.ClientConnection; import org.keycloak.services.ClientConnection;
import org.keycloak.services.managers.AccessCodeEntry; import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.AuthenticationManager.AuthenticationStatus; import org.keycloak.services.managers.AuthenticationManager.AuthenticationStatus;
import org.keycloak.services.managers.ResourceAdminManager; import org.keycloak.services.managers.ResourceAdminManager;
@ -43,6 +42,7 @@ import org.keycloak.services.resources.flows.Flows;
import org.keycloak.services.resources.flows.OAuthFlows; import org.keycloak.services.resources.flows.OAuthFlows;
import org.keycloak.services.resources.flows.Urls; import org.keycloak.services.resources.flows.Urls;
import org.keycloak.services.validation.Validation; import org.keycloak.services.validation.Validation;
import org.keycloak.util.Base64Url;
import org.keycloak.util.BasicAuthHelper; import org.keycloak.util.BasicAuthHelper;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -281,7 +281,6 @@ public class TokenService {
String scope = form.getFirst(OAuth2Constants.SCOPE); String scope = form.getFirst(OAuth2Constants.SCOPE);
UserSessionModel userSession = session.sessions().createUserSession(realm, user, username, clientConnection.getRemoteAddr(), "oauth_credentials", false); UserSessionModel userSession = session.sessions().createUserSession(realm, user, username, clientConnection.getRemoteAddr(), "oauth_credentials", false);
userSession.associateClient(client);
audit.session(userSession); audit.session(userSession);
AccessTokenResponse res = tokenManager.responseBuilder(realm, client, audit) AccessTokenResponse res = tokenManager.responseBuilder(realm, client, audit)
@ -623,10 +622,15 @@ public class TokenService {
throw new BadRequestException("Code not specified", Response.status(Response.Status.BAD_REQUEST).entity(error).type("application/json").build()); throw new BadRequestException("Code not specified", Response.status(Response.Status.BAD_REQUEST).entity(error).type("application/json").build());
} }
AccessCode accessCode = AccessCode.parse(code, session, realm);
AccessCodeEntry accessCode = tokenManager.parseCode(code, session, realm);
if (accessCode == null) { if (accessCode == null) {
String[] parts = code.split("\\.");
if (parts.length == 2) {
try {
audit.detail(Details.CODE_ID, new String(Base64Url.decode(parts[1])));
} catch (Throwable t) {
}
}
Map<String, String> res = new HashMap<String, String>(); Map<String, String> res = new HashMap<String, String>();
res.put(OAuth2Constants.ERROR, "invalid_grant"); res.put(OAuth2Constants.ERROR, "invalid_grant");
res.put(OAuth2Constants.ERROR_DESCRIPTION, "Code not found"); res.put(OAuth2Constants.ERROR_DESCRIPTION, "Code not found");
@ -635,7 +639,7 @@ public class TokenService {
.build(); .build();
} }
audit.detail(Details.CODE_ID, accessCode.getCodeId()); audit.detail(Details.CODE_ID, accessCode.getCodeId());
if (accessCode.isExpired()) { if (!accessCode.isValid(ClientSessionModel.Action.CODE_TO_TOKEN)) {
Map<String, String> res = new HashMap<String, String>(); Map<String, String> res = new HashMap<String, String>();
res.put(OAuth2Constants.ERROR, "invalid_grant"); res.put(OAuth2Constants.ERROR, "invalid_grant");
res.put(OAuth2Constants.ERROR_DESCRIPTION, "Code is expired"); res.put(OAuth2Constants.ERROR_DESCRIPTION, "Code is expired");
@ -643,14 +647,8 @@ public class TokenService {
return Response.status(Response.Status.BAD_REQUEST).type(MediaType.APPLICATION_JSON_TYPE).entity(res) return Response.status(Response.Status.BAD_REQUEST).type(MediaType.APPLICATION_JSON_TYPE).entity(res)
.build(); .build();
} }
if (accessCode.getAction() != null) {
Map<String, String> res = new HashMap<String, String>(); accessCode.setAction(null);
res.put(OAuth2Constants.ERROR, "invalid_grant");
res.put(OAuth2Constants.ERROR_DESCRIPTION, "Code is not active");
audit.error(Errors.INVALID_CODE);
return Response.status(Response.Status.BAD_REQUEST).type(MediaType.APPLICATION_JSON_TYPE).entity(res)
.build();
}
audit.user(accessCode.getUser()); audit.user(accessCode.getUser());
audit.session(accessCode.getSessionState()); audit.session(accessCode.getSessionState());
@ -698,8 +696,6 @@ public class TokenService {
logger.debug("accessRequest SUCCESS"); logger.debug("accessRequest SUCCESS");
userSession.associateClient(client);
AccessToken token = tokenManager.createClientAccessToken(accessCode.getRequestedRoles(), realm, client, user, userSession); AccessToken token = tokenManager.createClientAccessToken(accessCode.getRequestedRoles(), realm, client, user, userSession);
try { try {
@ -982,22 +978,22 @@ public class TokenService {
String code = formData.getFirst(OAuth2Constants.CODE); String code = formData.getFirst(OAuth2Constants.CODE);
AccessCodeEntry accessCodeEntry = tokenManager.parseCode(code, session, realm); AccessCode accessCode = AccessCode.parse(code, session, realm);
if (accessCodeEntry == null || !AccessCode.Action.OAUTH_GRANT.equals(accessCodeEntry.getAction())) { if (accessCode == null || !accessCode.isValid(ClientSessionModel.Action.OAUTH_GRANT)) {
audit.error(Errors.INVALID_CODE); audit.error(Errors.INVALID_CODE);
return oauth.forwardToSecurityFailure("Unknown access code."); return oauth.forwardToSecurityFailure("Invalid access code.");
} }
audit.detail(Details.CODE_ID, accessCodeEntry.getCodeId()); audit.detail(Details.CODE_ID, accessCode.getCodeId());
String redirect = accessCodeEntry.getRedirectUri(); String redirect = accessCode.getRedirectUri();
String state = accessCodeEntry.getState(); String state = accessCode.getState();
audit.client(accessCodeEntry.getClient()) audit.client(accessCode.getClient())
.user(accessCodeEntry.getUser()) .user(accessCode.getUser())
.detail(Details.RESPONSE_TYPE, "code") .detail(Details.RESPONSE_TYPE, "code")
.detail(Details.REDIRECT_URI, redirect); .detail(Details.REDIRECT_URI, redirect);
UserSessionModel userSession = session.sessions().getUserSession(realm, accessCodeEntry.getSessionState()); UserSessionModel userSession = session.sessions().getUserSession(realm, accessCode.getSessionState());
if (userSession != null) { if (userSession != null) {
audit.detail(Details.AUTH_METHOD, userSession.getAuthMethod()); audit.detail(Details.AUTH_METHOD, userSession.getAuthMethod());
audit.detail(Details.USERNAME, userSession.getLoginUsername()); audit.detail(Details.USERNAME, userSession.getLoginUsername());
@ -1020,8 +1016,8 @@ public class TokenService {
audit.success(); audit.success();
accessCodeEntry.setAction(null); accessCode.setAction(ClientSessionModel.Action.CODE_TO_TOKEN);
return oauth.redirectAccessCode(accessCodeEntry, userSession, state, redirect); return oauth.redirectAccessCode(accessCode, userSession, state, redirect);
} }
@Path("oauth/oob") @Path("oauth/oob")

View file

@ -27,7 +27,7 @@ import org.keycloak.representations.idm.RoleRepresentation;
import org.keycloak.representations.idm.SocialLinkRepresentation; import org.keycloak.representations.idm.SocialLinkRepresentation;
import org.keycloak.representations.idm.UserRepresentation; import org.keycloak.representations.idm.UserRepresentation;
import org.keycloak.representations.idm.UserSessionRepresentation; import org.keycloak.representations.idm.UserSessionRepresentation;
import org.keycloak.services.managers.AccessCodeEntry; import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.RealmManager; import org.keycloak.services.managers.RealmManager;
import org.keycloak.services.managers.ResourceAdminManager; import org.keycloak.services.managers.ResourceAdminManager;
import org.keycloak.services.managers.TokenManager; import org.keycloak.services.managers.TokenManager;
@ -820,7 +820,7 @@ public class UsersResource {
return Flows.errors().error("AccountProvider management not enabled", Response.Status.INTERNAL_SERVER_ERROR); return Flows.errors().error("AccountProvider management not enabled", Response.Status.INTERNAL_SERVER_ERROR);
} }
AccessCodeEntry accessCode = tokenManager.createAccessCode(scope, state, redirect, session, realm, client, user, null); AccessCode accessCode = tokenManager.createAccessCode(scope, state, redirect, session, realm, client, user, null);
accessCode.setRequiredAction(UserModel.RequiredAction.UPDATE_PASSWORD); accessCode.setRequiredAction(UserModel.RequiredAction.UPDATE_PASSWORD);
try { try {

View file

@ -30,6 +30,7 @@ import org.keycloak.audit.Details;
import org.keycloak.audit.EventType; import org.keycloak.audit.EventType;
import org.keycloak.models.ApplicationModel; import org.keycloak.models.ApplicationModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.RequiredCredentialModel; import org.keycloak.models.RequiredCredentialModel;
@ -37,9 +38,8 @@ import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserModel.RequiredAction; import org.keycloak.models.UserModel.RequiredAction;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.representations.AccessCode;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.services.managers.AccessCodeEntry; import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.TokenManager; import org.keycloak.services.managers.TokenManager;
@ -82,7 +82,7 @@ public class OAuthFlows {
this.tokenManager = tokenManager; this.tokenManager = tokenManager;
} }
public Response redirectAccessCode(AccessCodeEntry accessCode, UserSessionModel userSession, String state, String redirect) { public Response redirectAccessCode(AccessCode accessCode, UserSessionModel userSession, String state, String redirect) {
String code = accessCode.getCode(); String code = accessCode.getCode();
UriBuilder redirectUri = UriBuilder.fromUri(redirect).queryParam(OAuth2Constants.CODE, code); UriBuilder redirectUri = UriBuilder.fromUri(redirect).queryParam(OAuth2Constants.CODE, code);
log.debugv("redirectAccessCode: state: {0}", state); log.debugv("redirectAccessCode: state: {0}", state);
@ -122,7 +122,7 @@ public class OAuthFlows {
isEmailVerificationRequired(user); isEmailVerificationRequired(user);
boolean isResource = client instanceof ApplicationModel; boolean isResource = client instanceof ApplicationModel;
AccessCodeEntry accessCode = tokenManager.createAccessCode(scopeParam, state, redirect, this.session, realm, client, user, session); AccessCode accessCode = tokenManager.createAccessCode(scopeParam, state, redirect, this.session, realm, client, user, session);
log.debugv("processAccessCode: isResource: {0}", isResource); log.debugv("processAccessCode: isResource: {0}", isResource);
log.debugv("processAccessCode: go to oauth page?: {0}", log.debugv("processAccessCode: go to oauth page?: {0}",
@ -144,7 +144,7 @@ public class OAuthFlows {
} }
if (!isResource) { if (!isResource) {
accessCode.setAction(AccessCode.Action.OAUTH_GRANT); accessCode.setAction(ClientSessionModel.Action.OAUTH_GRANT);
List<RoleModel> realmRoles = new LinkedList<RoleModel>(); List<RoleModel> realmRoles = new LinkedList<RoleModel>();
MultivaluedMap<String, RoleModel> resourceRoles = new MultivaluedMapImpl<String, RoleModel>(); MultivaluedMap<String, RoleModel> resourceRoles = new MultivaluedMapImpl<String, RoleModel>();
@ -165,6 +165,8 @@ public class OAuthFlows {
if (redirect != null) { if (redirect != null) {
audit.success(); audit.success();
accessCode.setAction(ClientSessionModel.Action.CODE_TO_TOKEN);
return redirectAccessCode(accessCode, session, state, redirect); return redirectAccessCode(accessCode, session, state, redirect);
} else { } else {
return null; return null;

View file

@ -22,6 +22,7 @@ import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.representations.idm.UserRepresentation; import org.keycloak.representations.idm.UserRepresentation;
import org.keycloak.services.managers.RealmManager; import org.keycloak.services.managers.RealmManager;
import org.keycloak.testsuite.rule.KeycloakRule; import org.keycloak.testsuite.rule.KeycloakRule;
import org.keycloak.util.Time;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -328,17 +329,7 @@ public class AssertEvents implements TestRule, AuditListenerFactory {
} }
public static Matcher<String> isCodeId() { public static Matcher<String> isCodeId() {
return new TypeSafeMatcher<String>() { return isUUID();
@Override
protected boolean matchesSafely(String item) {
return (UUID.randomUUID().toString() + System.currentTimeMillis()).length() == item.length();
}
@Override
public void describeTo(Description description) {
description.appendText("Not an Code ID");
}
};
} }
public static Matcher<String> isUUID() { public static Matcher<String> isUUID() {

View file

@ -38,8 +38,11 @@ import org.keycloak.RSATokenVerifier;
import org.keycloak.VerificationException; import org.keycloak.VerificationException;
import org.keycloak.jose.jws.JWSInput; import org.keycloak.jose.jws.JWSInput;
import org.keycloak.jose.jws.crypto.RSAProvider; import org.keycloak.jose.jws.crypto.RSAProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
import org.keycloak.representations.RefreshToken; import org.keycloak.representations.RefreshToken;
import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.resources.TokenService; import org.keycloak.services.resources.TokenService;
import org.keycloak.util.BasicAuthHelper; import org.keycloak.util.BasicAuthHelper;
import org.keycloak.util.PemUtils; import org.keycloak.util.PemUtils;
@ -217,12 +220,6 @@ public class OAuthClient {
} }
} }
public void verifyCode(String code) {
if (!RSAProvider.verify(new JWSInput(code), realmPublicKey)) {
throw new RuntimeException("Failed to verify code");
}
}
public RefreshToken verifyRefreshToken(String refreshToken) { public RefreshToken verifyRefreshToken(String refreshToken) {
try { try {
JWSInput jws = new JWSInput(refreshToken); JWSInput jws = new JWSInput(refreshToken);

View file

@ -5,6 +5,7 @@ import org.junit.Before;
import org.junit.ClassRule; import org.junit.ClassRule;
import org.junit.Test; import org.junit.Test;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
@ -13,7 +14,11 @@ import org.keycloak.testsuite.rule.KeycloakRule;
import org.keycloak.util.Time; import org.keycloak.util.Time;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@ -59,7 +64,47 @@ public class UserSessionProviderTest {
assertSession(session.sessions().getUserSession(realm, sessions[0].getId()), session.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "test-app", "third-party"); assertSession(session.sessions().getUserSession(realm, sessions[0].getId()), session.users().getUserByUsername("user1", realm), "127.0.0.1", started, started, "test-app", "third-party");
assertSession(session.sessions().getUserSession(realm, sessions[1].getId()), session.users().getUserByUsername("user1", realm), "127.0.0.2", started, started, "test-app"); assertSession(session.sessions().getUserSession(realm, sessions[1].getId()), session.users().getUserByUsername("user1", realm), "127.0.0.2", started, started, "test-app");
assertSession(session.sessions().getUserSession(realm, sessions[2].getId()), session.users().getUserByUsername("user2", realm), "127.0.0.3", started, started); assertSession(session.sessions().getUserSession(realm, sessions[2].getId()), session.users().getUserByUsername("user2", realm), "127.0.0.3", started, started, "test-app");
}
@Test
public void testCreateClientSession() {
UserSessionModel[] sessions = createSessions();
List<ClientSessionModel> clientSessions = sessions[0].getClientSessions();
assertEquals(2, clientSessions.size());
ClientSessionModel session = clientSessions.get(0);
assertEquals(null, session.getAction());
assertEquals(realm.findClient("test-app").getClientId(), session.getClient().getClientId());
assertEquals(sessions[0].getId(), session.getUserSession().getId());
assertEquals("http://redirect", session.getRedirectUri());
assertEquals("state", session.getState());
assertEquals(2, session.getRoles().size());
assertTrue(session.getRoles().contains("one"));
assertTrue(session.getRoles().contains("two"));
}
@Test
public void testUpdateClientSession() {
UserSessionModel[] sessions = createSessions();
String id = sessions[0].getClientSessions().get(0).getId();
ClientSessionModel clientSession = session.sessions().getClientSession(realm, id);
int time = clientSession.getTimestamp();
assertEquals(null, clientSession.getAction());
clientSession.setAction(ClientSessionModel.Action.CODE_TO_TOKEN);
clientSession.setTimestamp(time + 10);
kc.stopSession(session, true);
session = kc.startSession();
ClientSessionModel updated = session.sessions().getClientSession(realm, id);
assertEquals(ClientSessionModel.Action.CODE_TO_TOKEN, updated.getAction());
assertEquals(time + 10, updated.getTimestamp());
} }
@Test @Test
@ -72,22 +117,108 @@ public class UserSessionProviderTest {
@Test @Test
public void testRemoveUserSessionsByUser() { public void testRemoveUserSessionsByUser() {
createSessions(); UserSessionModel[] sessions = createSessions();
List<String> clientSessionsRemoved = new LinkedList<String>();
List<String> clientSessionsKept = new LinkedList<String>();
for (UserSessionModel s : sessions) {
for (ClientSessionModel c : s.getClientSessions()) {
if (c.getUserSession().getUser().getUsername().equals("user1")) {
clientSessionsRemoved.add(c.getId());
} else {
clientSessionsKept.add(c.getId());
}
}
}
session.sessions().removeUserSessions(realm, session.users().getUserByUsername("user1", realm)); session.sessions().removeUserSessions(realm, session.users().getUserByUsername("user1", realm));
resetSession(); resetSession();
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty()); assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty());
assertFalse(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty()); assertFalse(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty());
for (String c : clientSessionsRemoved) {
assertNull(session.sessions().getClientSession(realm, c));
}
for (String c : clientSessionsKept) {
assertNotNull(session.sessions().getClientSession(realm, c));
}
}
@Test
public void testRemoveUserSession() {
UserSessionModel userSession = createSessions()[0];
List<String> clientSessionsRemoved = new LinkedList<String>();
for (ClientSessionModel c : userSession.getClientSessions()) {
clientSessionsRemoved.add(c.getId());
}
session.sessions().removeUserSession(realm, userSession);
resetSession();
assertNull(session.sessions().getUserSession(realm, userSession.getId()));
for (String c : clientSessionsRemoved) {
assertNull(session.sessions().getClientSession(realm, c));
}
} }
@Test @Test
public void testRemoveUserSessionsByRealm() { public void testRemoveUserSessionsByRealm() {
createSessions(); UserSessionModel[] sessions = createSessions();
List<ClientSessionModel> clientSessions = new LinkedList<ClientSessionModel>();
for (UserSessionModel s : sessions) {
clientSessions.addAll(s.getClientSessions());
}
session.sessions().removeUserSessions(realm); session.sessions().removeUserSessions(realm);
resetSession(); resetSession();
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty()); assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user1", realm)).isEmpty());
assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty()); assertTrue(session.sessions().getUserSessions(realm, session.users().getUserByUsername("user2", realm)).isEmpty());
for (ClientSessionModel c : clientSessions) {
assertNull(session.sessions().getClientSession(realm, c.getId()));
}
}
@Test
public void testOnClientRemoved() {
UserSessionModel[] sessions = createSessions();
List<String> clientSessionsRemoved = new LinkedList<String>();
List<String> clientSessionsKept = new LinkedList<String>();
for (UserSessionModel s : sessions) {
s = session.sessions().getUserSession(realm, s.getId());
for (ClientSessionModel c : s.getClientSessions()) {
if (c.getClient().getClientId().equals("third-party")) {
clientSessionsRemoved.add(c.getId());
} else {
clientSessionsKept.add(c.getId());
}
}
}
session.sessions().onClientRemoved(realm, realm.findClient("third-party"));
resetSession();
for (String c : clientSessionsRemoved) {
assertNull(session.sessions().getClientSession(realm, c));
}
for (String c : clientSessionsKept) {
assertNotNull(session.sessions().getClientSession(realm, c));
}
session.sessions().onClientRemoved(realm, realm.findClient("test-app"));
resetSession();
for (String c : clientSessionsRemoved) {
assertNull(session.sessions().getClientSession(realm, c));
}
for (String c : clientSessionsKept) {
assertNull(session.sessions().getClientSession(realm, c));
}
} }
@Test @Test
@ -111,7 +242,7 @@ public class UserSessionProviderTest {
public void testGetByClient() { public void testGetByClient() {
UserSessionModel[] sessions = createSessions(); UserSessionModel[] sessions = createSessions();
assertSessions(session.sessions().getUserSessions(realm, realm.findClient("test-app")), sessions[0], sessions[1]); assertSessions(session.sessions().getUserSessions(realm, realm.findClient("test-app")), sessions[0], sessions[1], sessions[2]);
assertSessions(session.sessions().getUserSessions(realm, realm.findClient("third-party")), sessions[0]); assertSessions(session.sessions().getUserSessions(realm, realm.findClient("third-party")), sessions[0]);
} }
@ -120,7 +251,7 @@ public class UserSessionProviderTest {
for (int i = 0; i < 25; i++) { for (int i = 0; i < 25; i++) {
UserSessionModel userSession = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0." + i, "form", false); UserSessionModel userSession = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0." + i, "form", false);
userSession.setStarted(Time.currentTime() + i); userSession.setStarted(Time.currentTime() + i);
userSession.associateClient(realm.findClient("test-app")); session.sessions().createClientSession(realm, realm.findClient("test-app"), userSession, "http://redirect", "state", new HashSet<String>());
} }
resetSession(); resetSession();
@ -147,26 +278,30 @@ public class UserSessionProviderTest {
assertArrayEquals(expectedIps, actualIps); assertArrayEquals(expectedIps, actualIps);
} }
@Test @Test
public void testGetCountByClient() { public void testGetCountByClient() {
createSessions(); createSessions();
assertEquals(2, session.sessions().getActiveUserSessions(realm, realm.findClient("test-app"))); assertEquals(3, session.sessions().getActiveUserSessions(realm, realm.findClient("test-app")));
assertEquals(1, session.sessions().getActiveUserSessions(realm, realm.findClient("third-party"))); assertEquals(1, session.sessions().getActiveUserSessions(realm, realm.findClient("third-party")));
} }
private UserSessionModel[] createSessions() { private UserSessionModel[] createSessions() {
UserSessionModel[] sessions = new UserSessionModel[4]; UserSessionModel[] sessions = new UserSessionModel[3];
sessions[0] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0.1", "form", true); sessions[0] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0.1", "form", true);
sessions[0].associateClient(realm.findClient("test-app"));
sessions[0].associateClient(realm.findClient("third-party")); Set<String> roles = new HashSet<String>();
roles.add("one");
roles.add("two");
session.sessions().createClientSession(realm, realm.findClient("test-app"), sessions[0], "http://redirect", "state", roles);
session.sessions().createClientSession(realm, realm.findClient("third-party"), sessions[0], "http://redirect", "state", new HashSet<String>());
sessions[1] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0.2", "form", true); sessions[1] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user1", realm), "user1", "127.0.0.2", "form", true);
sessions[1].associateClient(realm.findClient("test-app")); session.sessions().createClientSession(realm, realm.findClient("test-app"), sessions[1], "http://redirect", "state", new HashSet<String>());
sessions[2] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user2", realm), "user2", "127.0.0.3", "form", true); sessions[2] = session.sessions().createUserSession(realm, session.users().getUserByUsername("user2", realm), "user2", "127.0.0.3", "form", true);
session.sessions().createClientSession(realm, realm.findClient("test-app"), sessions[2], "http://redirect", "state", new HashSet<String>());
resetSession(); resetSession();
@ -205,9 +340,9 @@ public class UserSessionProviderTest {
assertTrue(session.getStarted() >= started - 1 && session.getStarted() <= started + 1); assertTrue(session.getStarted() >= started - 1 && session.getStarted() <= started + 1);
assertTrue(session.getLastSessionRefresh() >= lastRefresh - 1 && session.getLastSessionRefresh() <= lastRefresh + 1); assertTrue(session.getLastSessionRefresh() >= lastRefresh - 1 && session.getLastSessionRefresh() <= lastRefresh + 1);
String[] actualClients = new String[session.getClientAssociations().size()]; String[] actualClients = new String[session.getClientSessions().size()];
for (int i = 0; i < actualClients.length; i++) { for (int i = 0; i < actualClients.length; i++) {
actualClients[i] = session.getClientAssociations().get(i).getClientId(); actualClients[i] = session.getClientSessions().get(i).getClient().getClientId();
} }
Arrays.sort(clients); Arrays.sort(clients);

View file

@ -29,7 +29,9 @@ import org.keycloak.OAuth2Constants;
import org.keycloak.audit.Details; import org.keycloak.audit.Details;
import org.keycloak.audit.Errors; import org.keycloak.audit.Errors;
import org.keycloak.audit.Event; import org.keycloak.audit.Event;
import org.keycloak.audit.EventType;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.representations.AccessToken; import org.keycloak.representations.AccessToken;
import org.keycloak.services.managers.RealmManager; import org.keycloak.services.managers.RealmManager;
import org.keycloak.testsuite.AssertEvents; import org.keycloak.testsuite.AssertEvents;
@ -45,6 +47,7 @@ import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
/** /**
@ -72,13 +75,6 @@ public class AccessTokenTest {
@Test @Test
public void accessTokenRequest() throws Exception { public void accessTokenRequest() throws Exception {
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
appRealm.setAccessCodeLifespan(1);
}
});
oauth.doLogin("test-user@localhost", "password"); oauth.doLogin("test-user@localhost", "password");
Event loginEvent = events.expectLogin().assertEvent(); Event loginEvent = events.expectLogin().assertEvent();
@ -112,22 +108,6 @@ public class AccessTokenTest {
Assert.assertEquals(token.getId(), event.getDetails().get(Details.TOKEN_ID)); Assert.assertEquals(token.getId(), event.getDetails().get(Details.TOKEN_ID));
Assert.assertEquals(oauth.verifyRefreshToken(response.getRefreshToken()).getId(), event.getDetails().get(Details.REFRESH_TOKEN_ID)); Assert.assertEquals(oauth.verifyRefreshToken(response.getRefreshToken()).getId(), event.getDetails().get(Details.REFRESH_TOKEN_ID));
Assert.assertEquals(sessionId, token.getSessionState()); Assert.assertEquals(sessionId, token.getSessionState());
Thread.sleep(2000);
response = oauth.doAccessTokenRequest(code, "password");
Assert.assertEquals(400, response.getStatusCode());
AssertEvents.ExpectedEvent expectedEvent = events.expectCodeToToken(codeId, null);
expectedEvent.error("invalid_code").removeDetail(Details.TOKEN_ID).removeDetail(Details.REFRESH_TOKEN_ID).client((String) null).user((String) null);
expectedEvent.assertEvent();
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
appRealm.setAccessCodeLifespan(60);
}
});
} }
@Test @Test
@ -162,10 +142,109 @@ public class AccessTokenTest {
assertNull(tokenResponse.getAccessToken()); assertNull(tokenResponse.getAccessToken());
assertNull(tokenResponse.getRefreshToken()); assertNull(tokenResponse.getRefreshToken());
events.expectCodeToToken(codeId, sessionId).removeDetail(Details.TOKEN_ID).removeDetail(Details.REFRESH_TOKEN_ID).error(Errors.INVALID_CODE).assertEvent(); events.expectCodeToToken(codeId, sessionId).removeDetail(Details.TOKEN_ID).client((String) null).user((String) null).session((String) null).removeDetail(Details.REFRESH_TOKEN_ID).error(Errors.INVALID_CODE).assertEvent();
events.clear(); events.clear();
} }
@Test
public void accessTokenCodeExpired() {
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
appRealm.setAccessCodeLifespan(1);
}
});
oauth.doLogin("test-user@localhost", "password");
Event loginEvent = events.expectLogin().assertEvent();
String codeId = loginEvent.getDetails().get(Details.CODE_ID);
loginEvent.getSessionId();
String code = oauth.getCurrentQuery().get(OAuth2Constants.CODE);
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
}
OAuthClient.AccessTokenResponse response = oauth.doAccessTokenRequest(code, "password");
Assert.assertEquals(400, response.getStatusCode());
AssertEvents.ExpectedEvent expectedEvent = events.expectCodeToToken(codeId, null);
expectedEvent.error("invalid_code").removeDetail(Details.TOKEN_ID).removeDetail(Details.REFRESH_TOKEN_ID).client((String) null).user((String) null);
expectedEvent.assertEvent();
events.clear();
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
appRealm.setAccessCodeLifespan(60);
}
});
}
@Test
public void accessTokenCodeUsed() {
oauth.doLogin("test-user@localhost", "password");
Event loginEvent = events.expectLogin().assertEvent();
String codeId = loginEvent.getDetails().get(Details.CODE_ID);
loginEvent.getSessionId();
String code = oauth.getCurrentQuery().get(OAuth2Constants.CODE);
OAuthClient.AccessTokenResponse response = oauth.doAccessTokenRequest(code, "password");
Assert.assertEquals(200, response.getStatusCode());
events.clear();
response = oauth.doAccessTokenRequest(code, "password");
Assert.assertEquals(400, response.getStatusCode());
AssertEvents.ExpectedEvent expectedEvent = events.expectCodeToToken(codeId, null);
expectedEvent.error("invalid_code").removeDetail(Details.TOKEN_ID).removeDetail(Details.REFRESH_TOKEN_ID).client((String) null).user((String) null);
expectedEvent.assertEvent();
events.clear();
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
appRealm.setAccessCodeLifespan(60);
}
});
}
@Test
public void accessTokenCodeHasRequiredAction() {
keycloakRule.configure(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel defaultRealm, RealmModel appRealm) {
UserModel user = manager.getSession().users().getUserByUsername("test-user@localhost", appRealm);
user.addRequiredAction(UserModel.RequiredAction.UPDATE_PROFILE);
}
});
oauth.doLogin("test-user@localhost", "password");
String code = driver.getPageSource().split("code=")[1].split("&")[0].split("\"")[0];
OAuthClient.AccessTokenResponse response = oauth.doAccessTokenRequest(code, "password");
Assert.assertEquals(400, response.getStatusCode());
Event event = events.poll();
assertNotNull(event.getDetails().get(Details.CODE_ID));
keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override
public void config(RealmManager manager, RealmModel adminstrationRealm, RealmModel appRealm) {
manager.getSession().users().getUserByUsername("test-user@localhost", appRealm).removeRequiredAction(UserModel.RequiredAction.UPDATE_PROFILE);
}
});
}
} }

View file

@ -23,6 +23,7 @@ package org.keycloak.testsuite.oauth;
import org.junit.Assert; import org.junit.Assert;
import org.junit.ClassRule; import org.junit.ClassRule;
import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.keycloak.OAuth2Constants; import org.keycloak.OAuth2Constants;
@ -30,7 +31,7 @@ import org.keycloak.audit.Details;
import org.keycloak.jose.jws.JWSInput; import org.keycloak.jose.jws.JWSInput;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.representations.AccessCode; import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.RealmManager; import org.keycloak.services.managers.RealmManager;
import org.keycloak.testsuite.AssertEvents; import org.keycloak.testsuite.AssertEvents;
import org.keycloak.testsuite.OAuthClient; import org.keycloak.testsuite.OAuthClient;
@ -44,6 +45,8 @@ import org.openqa.selenium.WebDriver;
import java.io.IOException; import java.io.IOException;
import static org.junit.Assert.assertEquals;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/ */
@ -75,14 +78,13 @@ public class AuthorizationCodeTest {
Assert.assertTrue(response.isRedirected()); Assert.assertTrue(response.isRedirected());
Assert.assertNotNull(response.getCode()); Assert.assertNotNull(response.getCode());
Assert.assertEquals("mystate", response.getState()); assertEquals("mystate", response.getState());
Assert.assertNull(response.getError()); Assert.assertNull(response.getError());
oauth.verifyCode(response.getCode()); keycloakRule.verifyCode(response.getCode());
String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID); String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID);
AccessCode accessCode = new JWSInput(response.getCode()).readJsonContent(AccessCode.class); assertCode(codeId, response.getCode());
Assert.assertEquals(codeId,accessCode.getId());
} }
@Test @Test
@ -101,11 +103,10 @@ public class AuthorizationCodeTest {
Assert.assertTrue(title.startsWith("Success code=")); Assert.assertTrue(title.startsWith("Success code="));
String code = driver.findElement(By.id(OAuth2Constants.CODE)).getText(); String code = driver.findElement(By.id(OAuth2Constants.CODE)).getText();
oauth.verifyCode(code); keycloakRule.verifyCode(code);
String codeId = events.expectLogin().detail(Details.REDIRECT_URI, Constants.INSTALLED_APP_URN).assertEvent().getDetails().get(Details.CODE_ID); String codeId = events.expectLogin().detail(Details.REDIRECT_URI, Constants.INSTALLED_APP_URN).assertEvent().getDetails().get(Details.CODE_ID);
AccessCode accessCode = new JWSInput(code).readJsonContent(AccessCode.class); assertCode(codeId, code);
Assert.assertEquals(codeId,accessCode.getId());
keycloakRule.update(new KeycloakRule.KeycloakSetup() { keycloakRule.update(new KeycloakRule.KeycloakSetup() {
@Override @Override
@ -132,7 +133,7 @@ public class AuthorizationCodeTest {
Assert.assertTrue(title.equals("Error error=access_denied")); Assert.assertTrue(title.equals("Error error=access_denied"));
String error = driver.findElement(By.id(OAuth2Constants.ERROR)).getText(); String error = driver.findElement(By.id(OAuth2Constants.ERROR)).getText();
Assert.assertEquals("access_denied", error); assertEquals("access_denied", error);
events.expectLogin().error("rejected_by_user").user((String) null).session((String) null).removeDetail(Details.USERNAME).removeDetail(Details.CODE_ID).detail(Details.REDIRECT_URI, Constants.INSTALLED_APP_URN).assertEvent().getDetails().get(Details.CODE_ID); events.expectLogin().error("rejected_by_user").user((String) null).session((String) null).removeDetail(Details.USERNAME).removeDetail(Details.CODE_ID).detail(Details.REDIRECT_URI, Constants.INSTALLED_APP_URN).assertEvent().getDetails().get(Details.CODE_ID);
@ -160,11 +161,10 @@ public class AuthorizationCodeTest {
Assert.assertTrue(response.isRedirected()); Assert.assertTrue(response.isRedirected());
Assert.assertNotNull(response.getCode()); Assert.assertNotNull(response.getCode());
oauth.verifyCode(response.getCode()); keycloakRule.verifyCode(response.getCode());
String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID); String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID);
AccessCode accessCode = new JWSInput(response.getCode()).readJsonContent(AccessCode.class); assertCode(codeId, response.getCode());
Assert.assertEquals(codeId,accessCode.getId());
} }
@Test @Test
@ -176,11 +176,15 @@ public class AuthorizationCodeTest {
Assert.assertNull(response.getState()); Assert.assertNull(response.getState());
Assert.assertNull(response.getError()); Assert.assertNull(response.getError());
oauth.verifyCode(response.getCode()); keycloakRule.verifyCode(response.getCode());
String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID); String codeId = events.expectLogin().assertEvent().getDetails().get(Details.CODE_ID);
AccessCode accessCode = new JWSInput(response.getCode()).readJsonContent(AccessCode.class); assertCode(codeId, response.getCode());
Assert.assertEquals(codeId,accessCode.getId()); }
private void assertCode(String expectedCodeId, String actualCode) {
AccessCode code = keycloakRule.verifyCode(actualCode);
assertEquals(expectedCodeId, code.getCodeId());
} }
} }

View file

@ -21,10 +21,12 @@
*/ */
package org.keycloak.testsuite.rule; package org.keycloak.testsuite.rule;
import org.junit.Assert;
import org.keycloak.Config; import org.keycloak.Config;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.services.managers.AccessCode;
import org.keycloak.services.managers.RealmManager; import org.keycloak.services.managers.RealmManager;
import org.keycloak.testsuite.ApplicationServlet; import org.keycloak.testsuite.ApplicationServlet;
@ -107,6 +109,24 @@ public class KeycloakRule extends AbstractKeycloakRule {
stopSession(session, true); stopSession(session, true);
} }
public AccessCode verifyCode(String code) {
KeycloakSession session = startSession();
try {
RealmModel realm = session.realms().getRealm("test");
try {
AccessCode accessCode = AccessCode.parse(code, session, realm);
if (accessCode == null) {
Assert.fail("Invalid code");
}
return accessCode;
} catch (Throwable t) {
throw new AssertionError("Failed to parse code", t);
}
} finally {
stopSession(session, false);
}
}
public abstract static class KeycloakSetup { public abstract static class KeycloakSetup {
protected KeycloakSession session; protected KeycloakSession session;