Merge pull request #716 from stianst/master

Fixes to Mongo user session provider
This commit is contained in:
Stian Thorgersen 2014-09-30 14:25:49 +02:00
commit d1bb872aec
7 changed files with 55 additions and 24 deletions

View file

@ -17,20 +17,19 @@ import java.util.Set;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/ */
public class ClientSessionAdapter implements ClientSessionModel { public class ClientSessionAdapter extends AbstractMongoAdapter<MongoClientSessionEntity> implements ClientSessionModel {
private KeycloakSession session; private KeycloakSession session;
private MongoUserSessionProvider provider; private MongoUserSessionProvider provider;
private RealmModel realm; private RealmModel realm;
private MongoClientSessionEntity entity; private MongoClientSessionEntity entity;
private MongoStoreInvocationContext invContext;
public ClientSessionAdapter(KeycloakSession session, MongoUserSessionProvider provider, RealmModel realm, MongoClientSessionEntity entity, MongoStoreInvocationContext invContext) { public ClientSessionAdapter(KeycloakSession session, MongoUserSessionProvider provider, RealmModel realm, MongoClientSessionEntity entity, MongoStoreInvocationContext invContext) {
super(invContext);
this.session = session; this.session = session;
this.provider = provider; this.provider = provider;
this.realm = realm; this.realm = realm;
this.entity = entity; this.entity = entity;
this.invContext = invContext;
} }
@Override @Override
@ -58,13 +57,15 @@ public class ClientSessionAdapter implements ClientSessionModel {
public void setUserSession(UserSessionModel userSession) { public void setUserSession(UserSessionModel userSession) {
MongoUserSessionEntity userSessionEntity = provider.getUserSessionEntity(realm, userSession.getId()); MongoUserSessionEntity userSessionEntity = provider.getUserSessionEntity(realm, userSession.getId());
entity.setSessionId(userSessionEntity.getId()); entity.setSessionId(userSessionEntity.getId());
provider.getMongoStore().pushItemToList(userSessionEntity, "clientSessions", entity.getId(), true, invContext); updateMongoEntity();
provider.getMongoStore().pushItemToList(userSessionEntity, "clientSessions", entity.getId(), true, invocationContext);
} }
@Override @Override
public void setRedirectUri(String uri) { public void setRedirectUri(String uri) {
entity.setRedirectUri(uri); entity.setRedirectUri(uri);
updateMongoEntity();
} }
@Override @Override
@ -72,6 +73,7 @@ public class ClientSessionAdapter implements ClientSessionModel {
List<String> list = new LinkedList<String>(); List<String> list = new LinkedList<String>();
list.addAll(roles); list.addAll(roles);
entity.setRoles(list); entity.setRoles(list);
updateMongoEntity();
} }
@Override @Override
@ -87,6 +89,7 @@ public class ClientSessionAdapter implements ClientSessionModel {
@Override @Override
public void setTimestamp(int timestamp) { public void setTimestamp(int timestamp) {
entity.setTimestamp(timestamp); entity.setTimestamp(timestamp);
updateMongoEntity();
} }
@Override @Override
@ -97,6 +100,7 @@ public class ClientSessionAdapter implements ClientSessionModel {
@Override @Override
public void setAction(Action action) { public void setAction(Action action) {
entity.setAction(action); entity.setAction(action);
updateMongoEntity();
} }
@Override @Override
@ -112,13 +116,13 @@ public class ClientSessionAdapter implements ClientSessionModel {
@Override @Override
public void setNote(String name, String value) { public void setNote(String name, String value) {
entity.getNotes().put(name, value); entity.getNotes().put(name, value);
updateMongoEntity();
} }
@Override @Override
public void removeNote(String name) { public void removeNote(String name) {
entity.getNotes().remove(name); entity.getNotes().remove(name);
updateMongoEntity();
} }
@Override @Override
@ -129,5 +133,11 @@ public class ClientSessionAdapter implements ClientSessionModel {
@Override @Override
public void setAuthMethod(String method) { public void setAuthMethod(String method) {
entity.setAuthMethod(method); entity.setAuthMethod(method);
updateMongoEntity();
}
@Override
protected MongoClientSessionEntity getMongoEntity() {
return entity;
} }
} }

View file

@ -19,6 +19,7 @@ import org.keycloak.models.sessions.mongo.entities.MongoUsernameLoginFailureEnti
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.util.Time; import org.keycloak.util.Time;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -49,6 +50,9 @@ public class MongoUserSessionProvider implements UserSessionProvider {
entity.setTimestamp(Time.currentTime()); entity.setTimestamp(Time.currentTime());
entity.setClientId(client.getId()); entity.setClientId(client.getId());
entity.setRealmId(realm.getId()); entity.setRealmId(realm.getId());
mongoStore.insertEntity(entity, invocationContext);
return new ClientSessionAdapter(session, this, realm, entity, invocationContext); return new ClientSessionAdapter(session, this, realm, entity, invocationContext);
} }
@ -125,24 +129,22 @@ public class MongoUserSessionProvider implements UserSessionProvider {
public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) { public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) {
DBObject query = new QueryBuilder() DBObject query = new QueryBuilder()
.and("clientSessions.clientId").is(client.getId()) .and("clientId").is(client.getId())
.get(); .get();
DBObject sort = new BasicDBObject("started", 1).append("id", 1); DBObject sort = new BasicDBObject("timestamp", 1).append("id", 1);
List<MongoUserSessionEntity> sessions = mongoStore.loadEntities(MongoUserSessionEntity.class, query, sort, firstResult, maxResults, invocationContext); List<MongoClientSessionEntity> clientSessions = mongoStore.loadEntities(MongoClientSessionEntity.class, query, sort, firstResult, maxResults, invocationContext);
List<UserSessionModel> result = new LinkedList<UserSessionModel>(); List<UserSessionModel> result = new LinkedList<UserSessionModel>();
for (MongoUserSessionEntity session : sessions) { for (MongoClientSessionEntity clientSession : clientSessions) {
result.add(new UserSessionAdapter(this.session, this, session, realm, invocationContext)); MongoUserSessionEntity userSession = mongoStore.loadEntity(MongoUserSessionEntity.class, clientSession.getSessionId(), invocationContext);
result.add(new UserSessionAdapter(session, this, userSession, realm, invocationContext));
} }
return result; return result;
} }
@Override @Override
public int getActiveUserSessions(RealmModel realm, ClientModel client) { public int getActiveUserSessions(RealmModel realm, ClientModel client) {
DBObject query = new QueryBuilder() return getUserSessions(realm, client).size();
.and("clientSessions.clientId").is(client.getId())
.get();
return mongoStore.countEntities(MongoUserSessionEntity.class, query, invocationContext);
} }
@Override @Override
@ -232,7 +234,14 @@ public class MongoUserSessionProvider implements UserSessionProvider {
DBObject query = new QueryBuilder() DBObject query = new QueryBuilder()
.and("clientId").is(client.getId()) .and("clientId").is(client.getId())
.get(); .get();
mongoStore.removeEntities(MongoUserSessionEntity.class, query, invocationContext); DBObject sort = new BasicDBObject("timestamp", 1).append("id", 1);
List<MongoClientSessionEntity> clientSessions = mongoStore.loadEntities(MongoClientSessionEntity.class, query, sort, -1, -1, invocationContext);
for (MongoClientSessionEntity clientSession : clientSessions) {
MongoUserSessionEntity userSession = mongoStore.loadEntity(MongoUserSessionEntity.class, clientSession.getSessionId(), invocationContext);
getMongoStore().pullItemFromList(userSession, "clientSessions", clientSession.getId(), invocationContext);
mongoStore.removeEntity(clientSession, invocationContext);
}
} }
@Override @Override

View file

@ -1,13 +1,11 @@
package org.keycloak.models.sessions.mongo; package org.keycloak.models.sessions.mongo;
import org.jboss.logging.Logger;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext; import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.ClientSessionModel; import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.mongo.entities.MongoClientSessionEntity;
import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity; import org.keycloak.models.sessions.mongo.entities.MongoUserSessionEntity;
import java.util.LinkedList; import java.util.LinkedList;

View file

@ -1,5 +1,6 @@
package org.keycloak.models.sessions.mongo.entities; package org.keycloak.models.sessions.mongo.entities;
import org.keycloak.connections.mongo.api.MongoCollection;
import org.keycloak.connections.mongo.api.MongoIdentifiableEntity; import org.keycloak.connections.mongo.api.MongoIdentifiableEntity;
import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext; import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
import org.keycloak.models.ClientSessionModel; import org.keycloak.models.ClientSessionModel;
@ -12,6 +13,7 @@ import java.util.Map;
/** /**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a> * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/ */
@MongoCollection(collectionName = "clientSessions")
public class MongoClientSessionEntity extends AbstractIdentifiableEntity implements MongoIdentifiableEntity { public class MongoClientSessionEntity extends AbstractIdentifiableEntity implements MongoIdentifiableEntity {
private String id; private String id;

View file

@ -56,6 +56,7 @@ import javax.ws.rs.core.UriInfo;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -189,10 +190,10 @@ public class UsersResource {
user.setAttribute(attr.getKey(), attr.getValue()); user.setAttribute(attr.getKey(), attr.getValue());
} }
for (String key : user.getAttributes().keySet()) { Set<String> attrToRemove = new HashSet<String>(user.getAttributes().keySet());
if (!rep.getAttributes().containsKey(key)) { attrToRemove.removeAll(rep.getAttributes().keySet());
user.removeAttribute(key); for (String attr : attrToRemove) {
} user.removeAttribute(attr);
} }
} }
} }

View file

@ -406,6 +406,7 @@ public class OAuthClient {
private String refreshToken; private String refreshToken;
private String error; private String error;
private String errorDescription;
public AccessTokenResponse(HttpResponse response) throws Exception { public AccessTokenResponse(HttpResponse response) throws Exception {
statusCode = response.getStatusLine().getStatusCode(); statusCode = response.getStatusLine().getStatusCode();
@ -426,6 +427,7 @@ public class OAuthClient {
} }
} else { } else {
error = responseJson.getString(OAuth2Constants.ERROR); error = responseJson.getString(OAuth2Constants.ERROR);
errorDescription = responseJson.has(OAuth2Constants.ERROR_DESCRIPTION) ? responseJson.getString(OAuth2Constants.ERROR_DESCRIPTION) : null;
} }
} }
@ -437,6 +439,10 @@ public class OAuthClient {
return error; return error;
} }
public String getErrorDescription() {
return errorDescription;
}
public int getExpiresIn() { public int getExpiresIn() {
return expiresIn; return expiresIn;
} }

View file

@ -66,7 +66,7 @@ public class UserSessionProviderTest {
@Test @Test
public void testUpdateSession() { public void testUpdateSession() {
UserSessionModel[] sessions = createSessions(); UserSessionModel[] sessions = createSessions();
sessions[0].setLastSessionRefresh(1000); session.sessions().getUserSession(realm, sessions[0].getId()).setLastSessionRefresh(1000);
resetSession(); resetSession();
@ -137,6 +137,8 @@ public class UserSessionProviderTest {
List<String> clientSessionsRemoved = new LinkedList<String>(); List<String> clientSessionsRemoved = new LinkedList<String>();
List<String> clientSessionsKept = new LinkedList<String>(); List<String> clientSessionsKept = new LinkedList<String>();
for (UserSessionModel s : sessions) { for (UserSessionModel s : sessions) {
s = session.sessions().getUserSession(realm, s.getId());
for (ClientSessionModel c : s.getClientSessions()) { for (ClientSessionModel c : s.getClientSessions()) {
if (c.getUserSession().getUser().getUsername().equals("user1")) { if (c.getUserSession().getUser().getUsername().equals("user1")) {
clientSessionsRemoved.add(c.getId()); clientSessionsRemoved.add(c.getId());
@ -349,8 +351,11 @@ public class UserSessionProviderTest {
resetSession(); resetSession();
failure1 = session.sessions().getUserLoginFailure(realm, "user1");
failure1.clearFailures(); failure1.clearFailures();
resetSession();
failure1 = session.sessions().getUserLoginFailure(realm, "user1"); failure1 = session.sessions().getUserLoginFailure(realm, "user1");
assertEquals(0, failure1.getNumFailures()); assertEquals(0, failure1.getNumFailures());
} }