Implement invitation-only self-registration for realm users

Closes #31643

Signed-off-by: vramik <vramik@redhat.com>
This commit is contained in:
vramik 2024-09-09 13:54:37 +02:00 committed by Marek Posolda
parent 1f573eded0
commit fcb31a5aa6
19 changed files with 176 additions and 68 deletions

View file

@ -103,7 +103,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationAdapter adapter = new OrganizationAdapter(session, realm, this);
try {
session.setAttribute(OrganizationModel.class.getName(), adapter);
session.getContext().setOrganization(adapter);
GroupModel group = createOrganizationGroup(adapter.getId());
adapter.setGroupId(group.getId());
@ -113,7 +113,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
em.persist(adapter.getEntity());
} finally {
session.removeAttribute(OrganizationModel.class.getName());
session.getContext().setOrganization(null);
}
return adapter;
@ -124,7 +124,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationEntity entity = getEntity(organization.getId());
try {
session.setAttribute(OrganizationModel.class.getName(), organization);
session.getContext().setOrganization(organization);
RealmModel realm = session.realms().getRealm(getRealm().getId());
// check if the realm is being removed so that we don't need to remove manually remove any other data but the org
@ -143,7 +143,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
em.remove(entity);
} finally {
session.removeAttribute(OrganizationModel.class.getName());
session.getContext().setOrganization(null);
}
return true;
@ -178,7 +178,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
}
if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), organization);
session.getContext().setOrganization(organization);
}
try {
@ -191,7 +191,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
user.joinGroup(group, metadata);
} finally {
if (current == null) {
session.removeAttribute(OrganizationModel.class.getName());
session.getContext().setOrganization(null);
}
}
@ -417,14 +417,14 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationModel current = Organizations.resolveOrganization(session);
if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), organization);
session.getContext().setOrganization(organization);
}
try {
member.leaveGroup(getOrganizationGroup(organization));
} finally {
if (current == null) {
session.removeAttribute(OrganizationModel.class.getName());
session.getContext().setOrganization(null);
}
}
}

View file

@ -142,9 +142,9 @@ public final class OrganizationAdapter implements OrganizationModel, JpaModel<Or
}
// add organization to the session as the following code updates the underlying group
OrganizationModel current = (OrganizationModel) session.getAttribute(OrganizationModel.class.getName());
OrganizationModel current = session.getContext().getOrganization();
if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), this);
session.getContext().setOrganization(this);
}
try {
@ -154,7 +154,7 @@ public final class OrganizationAdapter implements OrganizationModel, JpaModel<Or
attributes.forEach(group::setAttribute);
} finally {
if (current == null) {
session.removeAttribute(OrganizationModel.class.getName());
session.getContext().setOrganization(null);
}
}
}

View file

@ -73,6 +73,10 @@ public interface KeycloakContext {
void setClient(ClientModel client);
OrganizationModel getOrganization();
void setOrganization(OrganizationModel organization);
ClientConnection getConnection();
Locale resolveLocale(UserModel user);

View file

@ -35,6 +35,16 @@ import jakarta.ws.rs.core.Response;
*/
public interface ActionTokenHandler<T extends JsonWebToken> extends Provider {
/**
* This method allows to parse the token and extract information from it after initial verification.
* @param token Token.
* @param tokenContext Token context.
* @return Error response if the initial verification fails, {@code null} otherwise.
*/
default Response preHandleToken(T token, ActionTokenContext<T> tokenContext) {
return null;
}
/**
* Performs the action as per the token details. This method is only called if all verifiers
* returned in {@link #handleToken} succeed.

View file

@ -72,6 +72,26 @@ public class InviteOrgActionTokenHandler extends AbstractActionTokenHandler<Invi
);
}
@Override
public Response preHandleToken(InviteOrgActionToken token, ActionTokenContext<InviteOrgActionToken> tokenContext) {
KeycloakSession session = tokenContext.getSession();
OrganizationProvider orgProvider = session.getProvider(OrganizationProvider.class);
AuthenticationSessionModel authSession = tokenContext.getAuthenticationSession();
OrganizationModel organization = orgProvider.getById(token.getOrgId());
if (organization == null) {
return session.getProvider(LoginFormsProvider.class)
.setAuthenticationSession(authSession)
.setInfo(Messages.ORG_NOT_FOUND, token.getOrgId())
.createInfoPage();
}
session.getContext().setOrganization(organization);
return super.preHandleToken(token, tokenContext);
}
@Override
public Response handleToken(InviteOrgActionToken token, ActionTokenContext<InviteOrgActionToken> tokenContext) {
UserModel user = tokenContext.getAuthenticationSession().getAuthenticatedUser();

View file

@ -72,6 +72,7 @@ public class RegistrationPage implements FormAuthenticator, FormAuthenticatorFac
form.setAttribute("messageHeader", Messages.REGISTER_ORGANIZATION_MEMBER);
form.setAttribute(OrganizationModel.ORGANIZATION_NAME_ATTRIBUTE, organization.getName());
form.setAttribute(FIELD_EMAIL, token.getEmail());
}
} catch (VerificationException e) {
return form.setError(Messages.EXPIRED_ACTION).createErrorPage(Status.BAD_REQUEST);

View file

@ -317,7 +317,7 @@ public class RegistrationUserCreation implements FormAction, FormActionFactory {
}
// make sure the organization is set to the session so that UP org-related validators can run
session.setAttribute(OrganizationModel.class.getName(), organization);
session.getContext().setOrganization(organization);
session.setAttribute(InviteOrgActionToken.class.getName(), token);
if (token.isExpired() || !token.getActionId().equals(InviteOrgActionToken.TOKEN_TYPE)) {

View file

@ -27,6 +27,7 @@ import org.keycloak.authentication.AuthenticationFlowContext;
import org.keycloak.authentication.AuthenticationProcessor;
import org.keycloak.authentication.authenticators.browser.AbstractUsernameFormAuthenticator;
import org.keycloak.authentication.authenticators.browser.OTPFormAuthenticator;
import org.keycloak.authentication.forms.RegistrationPage;
import org.keycloak.authentication.requiredactions.util.UpdateProfileContext;
import org.keycloak.authentication.requiredactions.util.UserUpdateProfileContext;
import org.keycloak.broker.provider.BrokeredIdentityContext;
@ -648,6 +649,17 @@ public class FreeMarkerLoginFormsProvider implements LoginFormsProvider {
}
}
if (Profile.isFeatureEnabled(Feature.ORGANIZATION)) {
String email = (String) attributes.get(RegistrationPage.FIELD_EMAIL);
if (this.formData == null) {
this.formData = new MultivaluedHashMap<>();
}
String value = this.formData.getFirst(RegistrationPage.FIELD_EMAIL);
if (value == null || value.trim().isEmpty()) {
this.formData.putSingle(RegistrationPage.FIELD_EMAIL, email);
}
}
return createResponse(LoginFormsPages.REGISTER);
}

View file

@ -78,10 +78,6 @@ public class OrganizationInvitationResource {
return sendInvitation(user);
}
if (!realm.isRegistrationAllowed()) {
throw ErrorResponse.error("Realm does not allow self-registration", Status.BAD_REQUEST);
}
user = new InMemoryUserAdapter(session, realm, null);
user.setEmail(email);

View file

@ -161,7 +161,7 @@ public class OrganizationsResource {
throw ErrorResponse.error("Organization not found.", Response.Status.NOT_FOUND);
}
session.setAttribute(OrganizationModel.class.getName(), organizationModel);
session.getContext().setOrganization(organizationModel);
return new OrganizationResource(session, organizationModel, adminEvent);
}

View file

@ -103,7 +103,7 @@ public class OrganizationAuthenticator extends IdentityProviderAuthenticator {
}
// make sure the organization is set to the session to make it available to templates
session.setAttribute(OrganizationModel.class.getName(), organization);
session.getContext().setOrganization(organization);
if (tryRedirectBroker(context, organization, user, username, domain)) {
return;

View file

@ -122,15 +122,11 @@ public class Organizations {
try {
OrganizationProvider provider = getProvider(session);
session.setAttribute(OrganizationModel.class.getName(), provider.getById(group.getName()));
session.getContext().setOrganization(provider.getById(group.getName()));
realm.removeGroup(group);
} finally {
if (current == null) {
session.removeAttribute(OrganizationModel.class.getName());
} else {
session.setAttribute(OrganizationModel.class.getName(), current);
}
session.getContext().setOrganization(current);
}
};
}
@ -249,7 +245,7 @@ public class Organizations {
}
public static OrganizationModel resolveOrganization(KeycloakSession session, UserModel user, String domain) {
Optional<OrganizationModel> organization = Optional.ofNullable((OrganizationModel) session.getAttribute(OrganizationModel.class.getName()));
Optional<OrganizationModel> organization = Optional.ofNullable(session.getContext().getOrganization());
if (organization.isPresent()) {
// resolved from current keycloak session
@ -297,4 +293,9 @@ public class Organizations {
public static OrganizationProvider getProvider(KeycloakSession session) {
return session.getProvider(OrganizationProvider.class);
}
public static boolean isRegistrationAllowed(KeycloakSession session, RealmModel realm) {
if (session.getContext().getOrganization() != null) return true;
return realm.isRegistrationAllowed();
}
}

View file

@ -24,6 +24,7 @@ import org.keycloak.common.ClientConnection;
import org.keycloak.events.EventBuilder;
import org.keycloak.forms.login.LoginFormsProvider;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.protocol.oidc.endpoints.AuthorizationEndpoint;
@ -153,9 +154,9 @@ public class OIDCLoginProtocolService {
* Registration endpoint
*/
@Path("registrations")
public Object registrations() {
public Object registrations(@QueryParam(Constants.TOKEN) String tokenString) {
AuthorizationEndpoint endpoint = new AuthorizationEndpoint(session, event);
return endpoint.register();
return endpoint.register(tokenString);
}
/**

View file

@ -31,6 +31,7 @@ import org.keycloak.models.AuthenticationFlowModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession;
import org.keycloak.organization.utils.Organizations;
import org.keycloak.protocol.AuthorizationEndpointBase;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.protocol.oidc.OIDCLoginProtocol;
@ -206,11 +207,19 @@ public class AuthorizationEndpoint extends AuthorizationEndpointBase {
throw new RuntimeException("Unknown action " + action);
}
public AuthorizationEndpoint register() {
public AuthorizationEndpoint register(String tokenString) {
event.event(EventType.REGISTER);
action = Action.REGISTER;
if (!realm.isRegistrationAllowed()) {
if (Profile.isFeatureEnabled(Profile.Feature.ORGANIZATION)) {
//this call should extract orgId from token and set the organization to the session context
Response errorResponse = new LoginActionsService(session, event).preHandleActionToken(tokenString);
if (errorResponse != null) {
throw new ErrorPageException(errorResponse);
}
}
if (!Organizations.isRegistrationAllowed(session, realm)) {
throw new ErrorPageException(session, authenticationSession, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED);
}

View file

@ -26,6 +26,7 @@ import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakContext;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakUriInfo;
import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.sessions.AuthenticationSessionModel;
@ -45,6 +46,8 @@ public abstract class DefaultKeycloakContext implements KeycloakContext {
private ClientModel client;
private OrganizationModel organization;
protected KeycloakSession session;
private Map<UrlType, KeycloakUriInfo> uriInfo;
@ -117,6 +120,16 @@ public abstract class DefaultKeycloakContext implements KeycloakContext {
this.client = client;
}
@Override
public OrganizationModel getOrganization() {
return organization;
}
@Override
public void setOrganization(OrganizationModel organization) {
this.organization = organization;
}
@Override
public ClientConnection getConnection() {
if (clientConnection == null) {

View file

@ -29,10 +29,11 @@ import jakarta.ws.rs.core.Response;
public class ErrorPageException extends WebApplicationException {
private final KeycloakSession session;
private Response.Status status;
private final Response.Status status;
private final String errorMessage;
private final Object[] parameters;
private final AuthenticationSessionModel authSession;
private final Response response;
public ErrorPageException(KeycloakSession session, Response.Status status, String errorMessage, Object... parameters) {
@ -42,6 +43,7 @@ public class ErrorPageException extends WebApplicationException {
this.errorMessage = errorMessage;
this.parameters = parameters;
this.authSession = null;
this.response = null;
}
public ErrorPageException(KeycloakSession session, AuthenticationSessionModel authSession, Response.Status status, String errorMessage, Object... parameters) {
@ -50,13 +52,20 @@ public class ErrorPageException extends WebApplicationException {
this.errorMessage = errorMessage;
this.parameters = parameters;
this.authSession = authSession;
this.response = null;
}
public ErrorPageException(Response response) {
this.session = null;
this.status = null;
this.errorMessage = null;
this.parameters = null;
this.authSession = null;
this.response = response;
}
@Override
public Response getResponse() {
return ErrorPage.error(session, authSession, status, errorMessage, parameters);
return response != null ? response : ErrorPage.error(session, authSession, status, errorMessage, parameters);
}
}

View file

@ -46,6 +46,7 @@ import org.keycloak.broker.provider.BrokeredIdentityContext;
import org.keycloak.common.ClientConnection;
import org.keycloak.common.VerificationException;
import org.keycloak.common.util.Time;
import org.keycloak.common.util.TriFunction;
import org.keycloak.crypto.SignatureProvider;
import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.events.Details;
@ -71,6 +72,7 @@ import org.keycloak.models.utils.FormMessage;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.models.utils.SystemClientUtil;
import org.keycloak.organization.OrganizationProvider;
import org.keycloak.organization.utils.Organizations;
import org.keycloak.protocol.AuthorizationEndpointBase;
import org.keycloak.protocol.LoginProtocol;
import org.keycloak.protocol.LoginProtocol.Error;
@ -415,7 +417,7 @@ public class LoginActionsService {
@QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.KEY) String key) {
if (key != null) {
return handleActionToken(key, execution, clientId, tabId, clientData);
return handleActionToken(key, execution, clientId, tabId, clientData, null);
}
event.event(EventType.RESET_PASSWORD);
@ -544,10 +546,11 @@ public class LoginActionsService {
@QueryParam(Constants.CLIENT_ID) String clientId,
@QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.TAB_ID) String tabId) {
return handleActionToken(key, execution, clientId, tabId, clientData);
return handleActionToken(key, execution, clientId, tabId, clientData, null);
}
protected <T extends JsonWebToken & SingleUseObjectKeyModel> Response handleActionToken(String tokenString, String execution, String clientId, String tabId, String clientData) {
protected <T extends JsonWebToken & SingleUseObjectKeyModel> Response handleActionToken(String tokenString, String execution, String clientId, String tabId, String clientData,
TriFunction<ActionTokenHandler<T>, T, ActionTokenContext<T>, Response> preHandleToken) {
T token;
ActionTokenHandler<T> handler;
ActionTokenContext<T> tokenContext;
@ -636,7 +639,11 @@ public class LoginActionsService {
}
// Now proceed with the verification and handle the token
tokenContext = new ActionTokenContext(session, realm, sessionContext.getUri(), clientConnection, request, event, handler, execution, clientData, this::processFlow, this::brokerLoginFlow);
tokenContext = new ActionTokenContext<>(session, realm, sessionContext.getUri(), clientConnection, request, event, handler, execution, clientData, this::processFlow, this::brokerLoginFlow);
if (preHandleToken != null) {
return preHandleToken.apply(handler, token, tokenContext);
}
try {
String tokenAuthSessionCompoundId = handler.getAuthenticationSessionIdFromToken(token, tokenContext, authSession);
@ -772,14 +779,20 @@ public class LoginActionsService {
@QueryParam(Constants.EXECUTION) String execution,
@QueryParam(Constants.CLIENT_ID) String clientId,
@QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.TAB_ID) String tabId) {
return registerRequest(authSessionId, code, execution, clientId, tabId,clientData);
@QueryParam(Constants.TAB_ID) String tabId,
@QueryParam(Constants.TOKEN) String tokenString) {
if (Profile.isFeatureEnabled(Profile.Feature.ORGANIZATION) && tokenString != null) {
//this call should extract orgId from token and set the organization to the session context
preHandleActionToken(tokenString);
}
return registerRequest(authSessionId, code, execution, clientId, tabId, clientData);
}
private Response registerRequest(String authSessionId, String code, String execution, String clientId, String tabId, String clientData) {
event.event(EventType.REGISTER);
if (!realm.isRegistrationAllowed()) {
if (!Organizations.isRegistrationAllowed(session, realm)) {
event.error(Errors.REGISTRATION_DISABLED);
return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED);
}
@ -944,7 +957,7 @@ public class LoginActionsService {
if (organizationId != null) {
OrganizationProvider provider = session.getProvider(OrganizationProvider.class);
session.setAttribute(OrganizationModel.class.getName(), provider.getById(organizationId));
session.getContext().setOrganization(provider.getById(organizationId));
session.setAttribute(BrokeredIdentityContext.class.getName(), brokerContext);
}
}
@ -1204,4 +1217,7 @@ public class LoginActionsService {
return false;
}
public Response preHandleActionToken(String tokenString) {
return handleActionToken(tokenString, null, null, null, null, ActionTokenHandler::preHandleToken);
}
}

View file

@ -202,4 +202,9 @@ public class RealmAttributeUpdater extends ServerResourceUpdater<RealmAttributeU
rep.setOrganizationsEnabled(organizationsEnabled);
return this;
}
public RealmAttributeUpdater setRegistrationAllowed(Boolean registrationAllowed) {
rep.setRegistrationAllowed(registrationAllowed);
return this;
}
}

View file

@ -20,7 +20,7 @@ package org.keycloak.testsuite.organization.admin;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.junit.Assert.assertEquals;
import static org.hamcrest.Matchers.equalTo;
import java.io.IOException;
import java.util.Arrays;
@ -32,16 +32,14 @@ import java.util.function.Predicate;
import jakarta.mail.MessagingException;
import jakarta.mail.internet.MimeMessage;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.time.Duration;
import org.hamcrest.Matchers;
import static org.hamcrest.Matchers.equalTo;
import org.jboss.arquillian.graphene.page.Page;
import org.junit.Rule;
import org.junit.Test;
import org.keycloak.admin.client.resource.OrganizationResource;
import org.keycloak.common.util.UriUtils;
import org.keycloak.cookie.CookieType;
import org.keycloak.representations.idm.ErrorRepresentation;
import org.keycloak.representations.idm.MemberRepresentation;
import org.keycloak.representations.idm.MembershipType;
import org.keycloak.representations.idm.RealmRepresentation;
@ -51,6 +49,7 @@ import org.keycloak.testsuite.AssertEvents;
import org.keycloak.testsuite.admin.ApiUtil;
import org.keycloak.testsuite.pages.InfoPage;
import org.keycloak.testsuite.pages.RegisterPage;
import org.keycloak.testsuite.updaters.RealmAttributeUpdater;
import org.keycloak.testsuite.util.GreenMailRule;
import org.keycloak.testsuite.util.MailUtils;
import org.keycloak.testsuite.util.MailUtils.EmailBody;
@ -124,19 +123,25 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
}
@Test
public void testFailRegistrationNotEnabledWhenInvitingNewUser() {
public void testRegistrationEnabledWhenInvitingNewUser() throws Exception {
String email = "inviteduser@email";
OrganizationResource organization = testRealm().organizations().get(createOrganization().getId());
RealmRepresentation realm = testRealm().toRepresentation();
realm.setRegistrationAllowed(false);
testRealm().update(realm);
try (Response response = organization.members().inviteUser(email, null, null)) {
assertEquals(Status.BAD_REQUEST.getStatusCode(), response.getStatus());
assertEquals("Realm does not allow self-registration", response.readEntity(ErrorRepresentation.class).getErrorMessage());
} finally {
realm.setRegistrationAllowed(true);
testRealm().update(realm);
try (
RealmAttributeUpdater rau = new RealmAttributeUpdater(testRealm()).setRegistrationAllowed(Boolean.TRUE).update();
Response response = organization.members().inviteUser(email, null, null)
) {
assertThat(response.getStatus(), equalTo(Response.Status.NO_CONTENT.getStatusCode()));
registerUser(organization, email);
// authenticated to the account console
Assert.assertTrue(driver.getPageSource().contains("Account Management"));
Assert.assertNotNull(driver.manage().getCookieNamed(CookieType.IDENTITY.getName()));
List<MemberRepresentation> memberByEmail = organization.members().search(email, Boolean.TRUE, null, null);
assertThat(memberByEmail, Matchers.hasSize(1));
assertThat(memberByEmail.get(0).getMembershipType(), equalTo(MembershipType.MANAGED));
}
}
@ -147,7 +152,7 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
OrganizationResource organization = testRealm().organizations().get(createOrganization().getId());
organization.members().inviteUser(email, null, null).close();
registerUser(organization, "invalid@email.com");
registerUser(organization, email, "invalid@email.com");
assertThat(driver.getPageSource(), Matchers.containsString("Email does not match the invitation"));
assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty());
@ -200,9 +205,10 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
try {
setTimeOffset((int) TimeUnit.DAYS.toSeconds(1));
registerUser(organization, email);
String link = getInvitationLinkFromEmail();
driver.navigate().to(link);
assertThat(driver.getPageSource(), Matchers.containsString("The provided token is not valid or has expired."));
assertThat(driver.getPageSource(), Matchers.containsString("Action expired."));
assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty());
} finally {
resetTimeOffset();
@ -210,11 +216,16 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
}
private void registerUser(OrganizationResource organization, String email) throws MessagingException, IOException {
registerUser(organization, email, email);
}
private void registerUser(OrganizationResource organization, String expectedEmail, String email) throws MessagingException, IOException {
String link = getInvitationLinkFromEmail();
driver.navigate().to(link);
Assert.assertFalse(organization.members().getAll().stream().anyMatch(actual -> email.equals(actual.getEmail())));
registerPage.assertCurrent(organizationName);
driver.manage().timeouts().pageLoadTimeout(1, TimeUnit.DAYS);
driver.manage().timeouts().pageLoadTimeout(Duration.ofSeconds(10));
assertThat(registerPage.getEmail(), equalTo(expectedEmail));
registerPage.register("firstName", "lastName", email,
"invitedUser", "password", "password", null, false, null);
}