Implement invitation-only self-registration for realm users

Closes #31643

Signed-off-by: vramik <vramik@redhat.com>
This commit is contained in:
vramik 2024-09-09 13:54:37 +02:00 committed by Marek Posolda
parent 1f573eded0
commit fcb31a5aa6
19 changed files with 176 additions and 68 deletions

View file

@ -103,7 +103,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationAdapter adapter = new OrganizationAdapter(session, realm, this); OrganizationAdapter adapter = new OrganizationAdapter(session, realm, this);
try { try {
session.setAttribute(OrganizationModel.class.getName(), adapter); session.getContext().setOrganization(adapter);
GroupModel group = createOrganizationGroup(adapter.getId()); GroupModel group = createOrganizationGroup(adapter.getId());
adapter.setGroupId(group.getId()); adapter.setGroupId(group.getId());
@ -113,7 +113,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
em.persist(adapter.getEntity()); em.persist(adapter.getEntity());
} finally { } finally {
session.removeAttribute(OrganizationModel.class.getName()); session.getContext().setOrganization(null);
} }
return adapter; return adapter;
@ -124,7 +124,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationEntity entity = getEntity(organization.getId()); OrganizationEntity entity = getEntity(organization.getId());
try { try {
session.setAttribute(OrganizationModel.class.getName(), organization); session.getContext().setOrganization(organization);
RealmModel realm = session.realms().getRealm(getRealm().getId()); RealmModel realm = session.realms().getRealm(getRealm().getId());
// check if the realm is being removed so that we don't need to remove manually remove any other data but the org // check if the realm is being removed so that we don't need to remove manually remove any other data but the org
@ -143,7 +143,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
em.remove(entity); em.remove(entity);
} finally { } finally {
session.removeAttribute(OrganizationModel.class.getName()); session.getContext().setOrganization(null);
} }
return true; return true;
@ -178,7 +178,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
} }
if (current == null) { if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), organization); session.getContext().setOrganization(organization);
} }
try { try {
@ -191,7 +191,7 @@ public class JpaOrganizationProvider implements OrganizationProvider {
user.joinGroup(group, metadata); user.joinGroup(group, metadata);
} finally { } finally {
if (current == null) { if (current == null) {
session.removeAttribute(OrganizationModel.class.getName()); session.getContext().setOrganization(null);
} }
} }
@ -417,14 +417,14 @@ public class JpaOrganizationProvider implements OrganizationProvider {
OrganizationModel current = Organizations.resolveOrganization(session); OrganizationModel current = Organizations.resolveOrganization(session);
if (current == null) { if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), organization); session.getContext().setOrganization(organization);
} }
try { try {
member.leaveGroup(getOrganizationGroup(organization)); member.leaveGroup(getOrganizationGroup(organization));
} finally { } finally {
if (current == null) { if (current == null) {
session.removeAttribute(OrganizationModel.class.getName()); session.getContext().setOrganization(null);
} }
} }
} }

View file

@ -142,9 +142,9 @@ public final class OrganizationAdapter implements OrganizationModel, JpaModel<Or
} }
// add organization to the session as the following code updates the underlying group // add organization to the session as the following code updates the underlying group
OrganizationModel current = (OrganizationModel) session.getAttribute(OrganizationModel.class.getName()); OrganizationModel current = session.getContext().getOrganization();
if (current == null) { if (current == null) {
session.setAttribute(OrganizationModel.class.getName(), this); session.getContext().setOrganization(this);
} }
try { try {
@ -154,7 +154,7 @@ public final class OrganizationAdapter implements OrganizationModel, JpaModel<Or
attributes.forEach(group::setAttribute); attributes.forEach(group::setAttribute);
} finally { } finally {
if (current == null) { if (current == null) {
session.removeAttribute(OrganizationModel.class.getName()); session.getContext().setOrganization(null);
} }
} }
} }

View file

@ -73,6 +73,10 @@ public interface KeycloakContext {
void setClient(ClientModel client); void setClient(ClientModel client);
OrganizationModel getOrganization();
void setOrganization(OrganizationModel organization);
ClientConnection getConnection(); ClientConnection getConnection();
Locale resolveLocale(UserModel user); Locale resolveLocale(UserModel user);

View file

@ -35,6 +35,16 @@ import jakarta.ws.rs.core.Response;
*/ */
public interface ActionTokenHandler<T extends JsonWebToken> extends Provider { public interface ActionTokenHandler<T extends JsonWebToken> extends Provider {
/**
* This method allows to parse the token and extract information from it after initial verification.
* @param token Token.
* @param tokenContext Token context.
* @return Error response if the initial verification fails, {@code null} otherwise.
*/
default Response preHandleToken(T token, ActionTokenContext<T> tokenContext) {
return null;
}
/** /**
* Performs the action as per the token details. This method is only called if all verifiers * Performs the action as per the token details. This method is only called if all verifiers
* returned in {@link #handleToken} succeed. * returned in {@link #handleToken} succeed.

View file

@ -72,6 +72,26 @@ public class InviteOrgActionTokenHandler extends AbstractActionTokenHandler<Invi
); );
} }
@Override
public Response preHandleToken(InviteOrgActionToken token, ActionTokenContext<InviteOrgActionToken> tokenContext) {
KeycloakSession session = tokenContext.getSession();
OrganizationProvider orgProvider = session.getProvider(OrganizationProvider.class);
AuthenticationSessionModel authSession = tokenContext.getAuthenticationSession();
OrganizationModel organization = orgProvider.getById(token.getOrgId());
if (organization == null) {
return session.getProvider(LoginFormsProvider.class)
.setAuthenticationSession(authSession)
.setInfo(Messages.ORG_NOT_FOUND, token.getOrgId())
.createInfoPage();
}
session.getContext().setOrganization(organization);
return super.preHandleToken(token, tokenContext);
}
@Override @Override
public Response handleToken(InviteOrgActionToken token, ActionTokenContext<InviteOrgActionToken> tokenContext) { public Response handleToken(InviteOrgActionToken token, ActionTokenContext<InviteOrgActionToken> tokenContext) {
UserModel user = tokenContext.getAuthenticationSession().getAuthenticatedUser(); UserModel user = tokenContext.getAuthenticationSession().getAuthenticatedUser();

View file

@ -72,6 +72,7 @@ public class RegistrationPage implements FormAuthenticator, FormAuthenticatorFac
form.setAttribute("messageHeader", Messages.REGISTER_ORGANIZATION_MEMBER); form.setAttribute("messageHeader", Messages.REGISTER_ORGANIZATION_MEMBER);
form.setAttribute(OrganizationModel.ORGANIZATION_NAME_ATTRIBUTE, organization.getName()); form.setAttribute(OrganizationModel.ORGANIZATION_NAME_ATTRIBUTE, organization.getName());
form.setAttribute(FIELD_EMAIL, token.getEmail());
} }
} catch (VerificationException e) { } catch (VerificationException e) {
return form.setError(Messages.EXPIRED_ACTION).createErrorPage(Status.BAD_REQUEST); return form.setError(Messages.EXPIRED_ACTION).createErrorPage(Status.BAD_REQUEST);

View file

@ -317,7 +317,7 @@ public class RegistrationUserCreation implements FormAction, FormActionFactory {
} }
// make sure the organization is set to the session so that UP org-related validators can run // make sure the organization is set to the session so that UP org-related validators can run
session.setAttribute(OrganizationModel.class.getName(), organization); session.getContext().setOrganization(organization);
session.setAttribute(InviteOrgActionToken.class.getName(), token); session.setAttribute(InviteOrgActionToken.class.getName(), token);
if (token.isExpired() || !token.getActionId().equals(InviteOrgActionToken.TOKEN_TYPE)) { if (token.isExpired() || !token.getActionId().equals(InviteOrgActionToken.TOKEN_TYPE)) {

View file

@ -27,6 +27,7 @@ import org.keycloak.authentication.AuthenticationFlowContext;
import org.keycloak.authentication.AuthenticationProcessor; import org.keycloak.authentication.AuthenticationProcessor;
import org.keycloak.authentication.authenticators.browser.AbstractUsernameFormAuthenticator; import org.keycloak.authentication.authenticators.browser.AbstractUsernameFormAuthenticator;
import org.keycloak.authentication.authenticators.browser.OTPFormAuthenticator; import org.keycloak.authentication.authenticators.browser.OTPFormAuthenticator;
import org.keycloak.authentication.forms.RegistrationPage;
import org.keycloak.authentication.requiredactions.util.UpdateProfileContext; import org.keycloak.authentication.requiredactions.util.UpdateProfileContext;
import org.keycloak.authentication.requiredactions.util.UserUpdateProfileContext; import org.keycloak.authentication.requiredactions.util.UserUpdateProfileContext;
import org.keycloak.broker.provider.BrokeredIdentityContext; import org.keycloak.broker.provider.BrokeredIdentityContext;
@ -648,6 +649,17 @@ public class FreeMarkerLoginFormsProvider implements LoginFormsProvider {
} }
} }
if (Profile.isFeatureEnabled(Feature.ORGANIZATION)) {
String email = (String) attributes.get(RegistrationPage.FIELD_EMAIL);
if (this.formData == null) {
this.formData = new MultivaluedHashMap<>();
}
String value = this.formData.getFirst(RegistrationPage.FIELD_EMAIL);
if (value == null || value.trim().isEmpty()) {
this.formData.putSingle(RegistrationPage.FIELD_EMAIL, email);
}
}
return createResponse(LoginFormsPages.REGISTER); return createResponse(LoginFormsPages.REGISTER);
} }

View file

@ -78,10 +78,6 @@ public class OrganizationInvitationResource {
return sendInvitation(user); return sendInvitation(user);
} }
if (!realm.isRegistrationAllowed()) {
throw ErrorResponse.error("Realm does not allow self-registration", Status.BAD_REQUEST);
}
user = new InMemoryUserAdapter(session, realm, null); user = new InMemoryUserAdapter(session, realm, null);
user.setEmail(email); user.setEmail(email);

View file

@ -161,7 +161,7 @@ public class OrganizationsResource {
throw ErrorResponse.error("Organization not found.", Response.Status.NOT_FOUND); throw ErrorResponse.error("Organization not found.", Response.Status.NOT_FOUND);
} }
session.setAttribute(OrganizationModel.class.getName(), organizationModel); session.getContext().setOrganization(organizationModel);
return new OrganizationResource(session, organizationModel, adminEvent); return new OrganizationResource(session, organizationModel, adminEvent);
} }

View file

@ -103,7 +103,7 @@ public class OrganizationAuthenticator extends IdentityProviderAuthenticator {
} }
// make sure the organization is set to the session to make it available to templates // make sure the organization is set to the session to make it available to templates
session.setAttribute(OrganizationModel.class.getName(), organization); session.getContext().setOrganization(organization);
if (tryRedirectBroker(context, organization, user, username, domain)) { if (tryRedirectBroker(context, organization, user, username, domain)) {
return; return;

View file

@ -122,15 +122,11 @@ public class Organizations {
try { try {
OrganizationProvider provider = getProvider(session); OrganizationProvider provider = getProvider(session);
session.setAttribute(OrganizationModel.class.getName(), provider.getById(group.getName())); session.getContext().setOrganization(provider.getById(group.getName()));
realm.removeGroup(group); realm.removeGroup(group);
} finally { } finally {
if (current == null) { session.getContext().setOrganization(current);
session.removeAttribute(OrganizationModel.class.getName());
} else {
session.setAttribute(OrganizationModel.class.getName(), current);
}
} }
}; };
} }
@ -249,7 +245,7 @@ public class Organizations {
} }
public static OrganizationModel resolveOrganization(KeycloakSession session, UserModel user, String domain) { public static OrganizationModel resolveOrganization(KeycloakSession session, UserModel user, String domain) {
Optional<OrganizationModel> organization = Optional.ofNullable((OrganizationModel) session.getAttribute(OrganizationModel.class.getName())); Optional<OrganizationModel> organization = Optional.ofNullable(session.getContext().getOrganization());
if (organization.isPresent()) { if (organization.isPresent()) {
// resolved from current keycloak session // resolved from current keycloak session
@ -297,4 +293,9 @@ public class Organizations {
public static OrganizationProvider getProvider(KeycloakSession session) { public static OrganizationProvider getProvider(KeycloakSession session) {
return session.getProvider(OrganizationProvider.class); return session.getProvider(OrganizationProvider.class);
} }
public static boolean isRegistrationAllowed(KeycloakSession session, RealmModel realm) {
if (session.getContext().getOrganization() != null) return true;
return realm.isRegistrationAllowed();
}
} }

View file

@ -24,6 +24,7 @@ import org.keycloak.common.ClientConnection;
import org.keycloak.events.EventBuilder; import org.keycloak.events.EventBuilder;
import org.keycloak.forms.login.LoginFormsProvider; import org.keycloak.forms.login.LoginFormsProvider;
import org.keycloak.jose.jwk.JSONWebKeySet; import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.protocol.oidc.endpoints.AuthorizationEndpoint; import org.keycloak.protocol.oidc.endpoints.AuthorizationEndpoint;
@ -153,9 +154,9 @@ public class OIDCLoginProtocolService {
* Registration endpoint * Registration endpoint
*/ */
@Path("registrations") @Path("registrations")
public Object registrations() { public Object registrations(@QueryParam(Constants.TOKEN) String tokenString) {
AuthorizationEndpoint endpoint = new AuthorizationEndpoint(session, event); AuthorizationEndpoint endpoint = new AuthorizationEndpoint(session, event);
return endpoint.register(); return endpoint.register(tokenString);
} }
/** /**

View file

@ -31,6 +31,7 @@ import org.keycloak.models.AuthenticationFlowModel;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.Constants; import org.keycloak.models.Constants;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.organization.utils.Organizations;
import org.keycloak.protocol.AuthorizationEndpointBase; import org.keycloak.protocol.AuthorizationEndpointBase;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper; import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.protocol.oidc.OIDCLoginProtocol; import org.keycloak.protocol.oidc.OIDCLoginProtocol;
@ -206,11 +207,19 @@ public class AuthorizationEndpoint extends AuthorizationEndpointBase {
throw new RuntimeException("Unknown action " + action); throw new RuntimeException("Unknown action " + action);
} }
public AuthorizationEndpoint register() { public AuthorizationEndpoint register(String tokenString) {
event.event(EventType.REGISTER); event.event(EventType.REGISTER);
action = Action.REGISTER; action = Action.REGISTER;
if (!realm.isRegistrationAllowed()) { if (Profile.isFeatureEnabled(Profile.Feature.ORGANIZATION)) {
//this call should extract orgId from token and set the organization to the session context
Response errorResponse = new LoginActionsService(session, event).preHandleActionToken(tokenString);
if (errorResponse != null) {
throw new ErrorPageException(errorResponse);
}
}
if (!Organizations.isRegistrationAllowed(session, realm)) {
throw new ErrorPageException(session, authenticationSession, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED); throw new ErrorPageException(session, authenticationSession, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED);
} }

View file

@ -26,6 +26,7 @@ import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakContext; import org.keycloak.models.KeycloakContext;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakUriInfo; import org.keycloak.models.KeycloakUriInfo;
import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.sessions.AuthenticationSessionModel; import org.keycloak.sessions.AuthenticationSessionModel;
@ -45,6 +46,8 @@ public abstract class DefaultKeycloakContext implements KeycloakContext {
private ClientModel client; private ClientModel client;
private OrganizationModel organization;
protected KeycloakSession session; protected KeycloakSession session;
private Map<UrlType, KeycloakUriInfo> uriInfo; private Map<UrlType, KeycloakUriInfo> uriInfo;
@ -117,6 +120,16 @@ public abstract class DefaultKeycloakContext implements KeycloakContext {
this.client = client; this.client = client;
} }
@Override
public OrganizationModel getOrganization() {
return organization;
}
@Override
public void setOrganization(OrganizationModel organization) {
this.organization = organization;
}
@Override @Override
public ClientConnection getConnection() { public ClientConnection getConnection() {
if (clientConnection == null) { if (clientConnection == null) {

View file

@ -29,10 +29,11 @@ import jakarta.ws.rs.core.Response;
public class ErrorPageException extends WebApplicationException { public class ErrorPageException extends WebApplicationException {
private final KeycloakSession session; private final KeycloakSession session;
private Response.Status status; private final Response.Status status;
private final String errorMessage; private final String errorMessage;
private final Object[] parameters; private final Object[] parameters;
private final AuthenticationSessionModel authSession; private final AuthenticationSessionModel authSession;
private final Response response;
public ErrorPageException(KeycloakSession session, Response.Status status, String errorMessage, Object... parameters) { public ErrorPageException(KeycloakSession session, Response.Status status, String errorMessage, Object... parameters) {
@ -42,6 +43,7 @@ public class ErrorPageException extends WebApplicationException {
this.errorMessage = errorMessage; this.errorMessage = errorMessage;
this.parameters = parameters; this.parameters = parameters;
this.authSession = null; this.authSession = null;
this.response = null;
} }
public ErrorPageException(KeycloakSession session, AuthenticationSessionModel authSession, Response.Status status, String errorMessage, Object... parameters) { public ErrorPageException(KeycloakSession session, AuthenticationSessionModel authSession, Response.Status status, String errorMessage, Object... parameters) {
@ -50,13 +52,20 @@ public class ErrorPageException extends WebApplicationException {
this.errorMessage = errorMessage; this.errorMessage = errorMessage;
this.parameters = parameters; this.parameters = parameters;
this.authSession = authSession; this.authSession = authSession;
this.response = null;
} }
public ErrorPageException(Response response) {
this.session = null;
this.status = null;
this.errorMessage = null;
this.parameters = null;
this.authSession = null;
this.response = response;
}
@Override @Override
public Response getResponse() { public Response getResponse() {
return ErrorPage.error(session, authSession, status, errorMessage, parameters); return response != null ? response : ErrorPage.error(session, authSession, status, errorMessage, parameters);
} }
} }

View file

@ -46,6 +46,7 @@ import org.keycloak.broker.provider.BrokeredIdentityContext;
import org.keycloak.common.ClientConnection; import org.keycloak.common.ClientConnection;
import org.keycloak.common.VerificationException; import org.keycloak.common.VerificationException;
import org.keycloak.common.util.Time; import org.keycloak.common.util.Time;
import org.keycloak.common.util.TriFunction;
import org.keycloak.crypto.SignatureProvider; import org.keycloak.crypto.SignatureProvider;
import org.keycloak.crypto.SignatureVerifierContext; import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.events.Details; import org.keycloak.events.Details;
@ -71,6 +72,7 @@ import org.keycloak.models.utils.FormMessage;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.models.utils.SystemClientUtil; import org.keycloak.models.utils.SystemClientUtil;
import org.keycloak.organization.OrganizationProvider; import org.keycloak.organization.OrganizationProvider;
import org.keycloak.organization.utils.Organizations;
import org.keycloak.protocol.AuthorizationEndpointBase; import org.keycloak.protocol.AuthorizationEndpointBase;
import org.keycloak.protocol.LoginProtocol; import org.keycloak.protocol.LoginProtocol;
import org.keycloak.protocol.LoginProtocol.Error; import org.keycloak.protocol.LoginProtocol.Error;
@ -415,7 +417,7 @@ public class LoginActionsService {
@QueryParam(Constants.CLIENT_DATA) String clientData, @QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.KEY) String key) { @QueryParam(Constants.KEY) String key) {
if (key != null) { if (key != null) {
return handleActionToken(key, execution, clientId, tabId, clientData); return handleActionToken(key, execution, clientId, tabId, clientData, null);
} }
event.event(EventType.RESET_PASSWORD); event.event(EventType.RESET_PASSWORD);
@ -544,10 +546,11 @@ public class LoginActionsService {
@QueryParam(Constants.CLIENT_ID) String clientId, @QueryParam(Constants.CLIENT_ID) String clientId,
@QueryParam(Constants.CLIENT_DATA) String clientData, @QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.TAB_ID) String tabId) { @QueryParam(Constants.TAB_ID) String tabId) {
return handleActionToken(key, execution, clientId, tabId, clientData); return handleActionToken(key, execution, clientId, tabId, clientData, null);
} }
protected <T extends JsonWebToken & SingleUseObjectKeyModel> Response handleActionToken(String tokenString, String execution, String clientId, String tabId, String clientData) { protected <T extends JsonWebToken & SingleUseObjectKeyModel> Response handleActionToken(String tokenString, String execution, String clientId, String tabId, String clientData,
TriFunction<ActionTokenHandler<T>, T, ActionTokenContext<T>, Response> preHandleToken) {
T token; T token;
ActionTokenHandler<T> handler; ActionTokenHandler<T> handler;
ActionTokenContext<T> tokenContext; ActionTokenContext<T> tokenContext;
@ -636,7 +639,11 @@ public class LoginActionsService {
} }
// Now proceed with the verification and handle the token // Now proceed with the verification and handle the token
tokenContext = new ActionTokenContext(session, realm, sessionContext.getUri(), clientConnection, request, event, handler, execution, clientData, this::processFlow, this::brokerLoginFlow); tokenContext = new ActionTokenContext<>(session, realm, sessionContext.getUri(), clientConnection, request, event, handler, execution, clientData, this::processFlow, this::brokerLoginFlow);
if (preHandleToken != null) {
return preHandleToken.apply(handler, token, tokenContext);
}
try { try {
String tokenAuthSessionCompoundId = handler.getAuthenticationSessionIdFromToken(token, tokenContext, authSession); String tokenAuthSessionCompoundId = handler.getAuthenticationSessionIdFromToken(token, tokenContext, authSession);
@ -772,14 +779,20 @@ public class LoginActionsService {
@QueryParam(Constants.EXECUTION) String execution, @QueryParam(Constants.EXECUTION) String execution,
@QueryParam(Constants.CLIENT_ID) String clientId, @QueryParam(Constants.CLIENT_ID) String clientId,
@QueryParam(Constants.CLIENT_DATA) String clientData, @QueryParam(Constants.CLIENT_DATA) String clientData,
@QueryParam(Constants.TAB_ID) String tabId) { @QueryParam(Constants.TAB_ID) String tabId,
@QueryParam(Constants.TOKEN) String tokenString) {
if (Profile.isFeatureEnabled(Profile.Feature.ORGANIZATION) && tokenString != null) {
//this call should extract orgId from token and set the organization to the session context
preHandleActionToken(tokenString);
}
return registerRequest(authSessionId, code, execution, clientId, tabId, clientData); return registerRequest(authSessionId, code, execution, clientId, tabId, clientData);
} }
private Response registerRequest(String authSessionId, String code, String execution, String clientId, String tabId, String clientData) { private Response registerRequest(String authSessionId, String code, String execution, String clientId, String tabId, String clientData) {
event.event(EventType.REGISTER); event.event(EventType.REGISTER);
if (!realm.isRegistrationAllowed()) { if (!Organizations.isRegistrationAllowed(session, realm)) {
event.error(Errors.REGISTRATION_DISABLED); event.error(Errors.REGISTRATION_DISABLED);
return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED); return ErrorPage.error(session, null, Response.Status.BAD_REQUEST, Messages.REGISTRATION_NOT_ALLOWED);
} }
@ -944,7 +957,7 @@ public class LoginActionsService {
if (organizationId != null) { if (organizationId != null) {
OrganizationProvider provider = session.getProvider(OrganizationProvider.class); OrganizationProvider provider = session.getProvider(OrganizationProvider.class);
session.setAttribute(OrganizationModel.class.getName(), provider.getById(organizationId)); session.getContext().setOrganization(provider.getById(organizationId));
session.setAttribute(BrokeredIdentityContext.class.getName(), brokerContext); session.setAttribute(BrokeredIdentityContext.class.getName(), brokerContext);
} }
} }
@ -1204,4 +1217,7 @@ public class LoginActionsService {
return false; return false;
} }
public Response preHandleActionToken(String tokenString) {
return handleActionToken(tokenString, null, null, null, null, ActionTokenHandler::preHandleToken);
}
} }

View file

@ -202,4 +202,9 @@ public class RealmAttributeUpdater extends ServerResourceUpdater<RealmAttributeU
rep.setOrganizationsEnabled(organizationsEnabled); rep.setOrganizationsEnabled(organizationsEnabled);
return this; return this;
} }
public RealmAttributeUpdater setRegistrationAllowed(Boolean registrationAllowed) {
rep.setRegistrationAllowed(registrationAllowed);
return this;
}
} }

View file

@ -20,7 +20,7 @@ package org.keycloak.testsuite.organization.admin;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.junit.Assert.assertEquals; import static org.hamcrest.Matchers.equalTo;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -32,16 +32,14 @@ import java.util.function.Predicate;
import jakarta.mail.MessagingException; import jakarta.mail.MessagingException;
import jakarta.mail.internet.MimeMessage; import jakarta.mail.internet.MimeMessage;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status; import java.time.Duration;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import static org.hamcrest.Matchers.equalTo;
import org.jboss.arquillian.graphene.page.Page; import org.jboss.arquillian.graphene.page.Page;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.keycloak.admin.client.resource.OrganizationResource; import org.keycloak.admin.client.resource.OrganizationResource;
import org.keycloak.common.util.UriUtils; import org.keycloak.common.util.UriUtils;
import org.keycloak.cookie.CookieType; import org.keycloak.cookie.CookieType;
import org.keycloak.representations.idm.ErrorRepresentation;
import org.keycloak.representations.idm.MemberRepresentation; import org.keycloak.representations.idm.MemberRepresentation;
import org.keycloak.representations.idm.MembershipType; import org.keycloak.representations.idm.MembershipType;
import org.keycloak.representations.idm.RealmRepresentation; import org.keycloak.representations.idm.RealmRepresentation;
@ -51,6 +49,7 @@ import org.keycloak.testsuite.AssertEvents;
import org.keycloak.testsuite.admin.ApiUtil; import org.keycloak.testsuite.admin.ApiUtil;
import org.keycloak.testsuite.pages.InfoPage; import org.keycloak.testsuite.pages.InfoPage;
import org.keycloak.testsuite.pages.RegisterPage; import org.keycloak.testsuite.pages.RegisterPage;
import org.keycloak.testsuite.updaters.RealmAttributeUpdater;
import org.keycloak.testsuite.util.GreenMailRule; import org.keycloak.testsuite.util.GreenMailRule;
import org.keycloak.testsuite.util.MailUtils; import org.keycloak.testsuite.util.MailUtils;
import org.keycloak.testsuite.util.MailUtils.EmailBody; import org.keycloak.testsuite.util.MailUtils.EmailBody;
@ -124,19 +123,25 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
} }
@Test @Test
public void testFailRegistrationNotEnabledWhenInvitingNewUser() { public void testRegistrationEnabledWhenInvitingNewUser() throws Exception {
String email = "inviteduser@email"; String email = "inviteduser@email";
OrganizationResource organization = testRealm().organizations().get(createOrganization().getId()); OrganizationResource organization = testRealm().organizations().get(createOrganization().getId());
RealmRepresentation realm = testRealm().toRepresentation(); try (
realm.setRegistrationAllowed(false); RealmAttributeUpdater rau = new RealmAttributeUpdater(testRealm()).setRegistrationAllowed(Boolean.TRUE).update();
testRealm().update(realm); Response response = organization.members().inviteUser(email, null, null)
try (Response response = organization.members().inviteUser(email, null, null)) { ) {
assertEquals(Status.BAD_REQUEST.getStatusCode(), response.getStatus()); assertThat(response.getStatus(), equalTo(Response.Status.NO_CONTENT.getStatusCode()));
assertEquals("Realm does not allow self-registration", response.readEntity(ErrorRepresentation.class).getErrorMessage());
} finally { registerUser(organization, email);
realm.setRegistrationAllowed(true);
testRealm().update(realm); // authenticated to the account console
Assert.assertTrue(driver.getPageSource().contains("Account Management"));
Assert.assertNotNull(driver.manage().getCookieNamed(CookieType.IDENTITY.getName()));
List<MemberRepresentation> memberByEmail = organization.members().search(email, Boolean.TRUE, null, null);
assertThat(memberByEmail, Matchers.hasSize(1));
assertThat(memberByEmail.get(0).getMembershipType(), equalTo(MembershipType.MANAGED));
} }
} }
@ -147,7 +152,7 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
OrganizationResource organization = testRealm().organizations().get(createOrganization().getId()); OrganizationResource organization = testRealm().organizations().get(createOrganization().getId());
organization.members().inviteUser(email, null, null).close(); organization.members().inviteUser(email, null, null).close();
registerUser(organization, "invalid@email.com"); registerUser(organization, email, "invalid@email.com");
assertThat(driver.getPageSource(), Matchers.containsString("Email does not match the invitation")); assertThat(driver.getPageSource(), Matchers.containsString("Email does not match the invitation"));
assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty()); assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty());
@ -200,9 +205,10 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
try { try {
setTimeOffset((int) TimeUnit.DAYS.toSeconds(1)); setTimeOffset((int) TimeUnit.DAYS.toSeconds(1));
registerUser(organization, email); String link = getInvitationLinkFromEmail();
driver.navigate().to(link);
assertThat(driver.getPageSource(), Matchers.containsString("The provided token is not valid or has expired.")); assertThat(driver.getPageSource(), Matchers.containsString("Action expired."));
assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty()); assertThat(testRealm().users().searchByEmail(email, true), Matchers.empty());
} finally { } finally {
resetTimeOffset(); resetTimeOffset();
@ -210,11 +216,16 @@ public class OrganizationInvitationLinkTest extends AbstractOrganizationTest {
} }
private void registerUser(OrganizationResource organization, String email) throws MessagingException, IOException { private void registerUser(OrganizationResource organization, String email) throws MessagingException, IOException {
registerUser(organization, email, email);
}
private void registerUser(OrganizationResource organization, String expectedEmail, String email) throws MessagingException, IOException {
String link = getInvitationLinkFromEmail(); String link = getInvitationLinkFromEmail();
driver.navigate().to(link); driver.navigate().to(link);
Assert.assertFalse(organization.members().getAll().stream().anyMatch(actual -> email.equals(actual.getEmail()))); Assert.assertFalse(organization.members().getAll().stream().anyMatch(actual -> email.equals(actual.getEmail())));
registerPage.assertCurrent(organizationName); registerPage.assertCurrent(organizationName);
driver.manage().timeouts().pageLoadTimeout(1, TimeUnit.DAYS); driver.manage().timeouts().pageLoadTimeout(Duration.ofSeconds(10));
assertThat(registerPage.getEmail(), equalTo(expectedEmail));
registerPage.register("firstName", "lastName", email, registerPage.register("firstName", "lastName", email,
"invitedUser", "password", "password", null, false, null); "invitedUser", "password", "password", null, false, null);
} }