DefaultBruteForceProtector leverages a single thread to write success/failed events

Closes #14084

Signed-off-by: Douglas Palmer <dpalmer@redhat.com>
This commit is contained in:
Douglas Palmer 2024-04-02 07:42:05 -07:00 committed by Marek Posolda
parent ca00395877
commit 69ba92808d
3 changed files with 48 additions and 229 deletions

View file

@ -64,7 +64,9 @@ public class LoginFailureEntity extends SessionEntity {
} }
public void setFailedLoginNotBefore(int failedLoginNotBefore) { public void setFailedLoginNotBefore(int failedLoginNotBefore) {
this.failedLoginNotBefore = failedLoginNotBefore; if(failedLoginNotBefore>this.failedLoginNotBefore) {
this.failedLoginNotBefore = failedLoginNotBefore;
}
} }
public int getNumFailures() { public int getNumFailures() {
@ -88,7 +90,9 @@ public class LoginFailureEntity extends SessionEntity {
} }
public void setLastFailure(long lastFailure) { public void setLastFailure(long lastFailure) {
this.lastFailure = lastFailure; if(lastFailure>this.lastFailure) {
this.lastFailure = lastFailure;
}
} }
public String getLastIPFailure() { public String getLastIPFailure() {

View file

@ -23,23 +23,19 @@ import org.keycloak.common.util.Time;
import org.keycloak.events.Details; import org.keycloak.events.Details;
import org.keycloak.events.EventBuilder; import org.keycloak.events.EventBuilder;
import org.keycloak.events.EventType; import org.keycloak.events.EventType;
import org.keycloak.executors.ExecutorsProvider;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory; import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserLoginFailureModel; import org.keycloak.models.UserLoginFailureModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.services.ServicesLogger; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.storage.ReadOnlyException; import org.keycloak.storage.ReadOnlyException;
import java.time.Instant; import java.time.Instant;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.ZoneId; import java.time.ZoneId;
import java.time.ZoneOffset; import java.util.concurrent.ExecutorService;
import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.keycloak.models.UserModel.DISABLED_REASON; import static org.keycloak.models.UserModel.DISABLED_REASON;
@ -49,125 +45,30 @@ import static org.keycloak.models.UserModel.DISABLED_REASON;
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $ * @version $Revision: 1 $
*/ */
public class DefaultBruteForceProtector implements Runnable, BruteForceProtector { public class DefaultBruteForceProtector implements BruteForceProtector {
private static final Logger logger = Logger.getLogger(DefaultBruteForceProtector.class); private static final Logger logger = Logger.getLogger(DefaultBruteForceProtector.class);
protected volatile boolean run = true;
protected int maxDeltaTimeSeconds = 60 * 60 * 12; // 12 hours protected int maxDeltaTimeSeconds = 60 * 60 * 12; // 12 hours
protected KeycloakSessionFactory factory; protected KeycloakSessionFactory factory;
protected CountDownLatch shutdownLatch = new CountDownLatch(1);
protected volatile long failures;
protected volatile long lastFailure;
protected volatile long totalTime;
protected LinkedBlockingQueue<LoginEvent> queue = new LinkedBlockingQueue<LoginEvent>();
public static final int TRANSACTION_SIZE = 20;
protected abstract class LoginEvent implements Comparable<LoginEvent> {
protected final String realmId;
protected final String userId;
protected final ClientConnection clientConnection;
protected LoginEvent(String realmId, String userId, ClientConnection clientConnection) {
this.realmId = realmId;
this.userId = userId;
this.clientConnection = new AdaptedClientConnection(clientConnection);
}
@Override
public int compareTo(LoginEvent o) {
return userId.compareTo(o.userId);
}
}
protected class ShutdownEvent extends LoginEvent {
public ShutdownEvent() {
super(null, null, null);
}
}
protected class FailedLogin extends LoginEvent {
protected final CountDownLatch latch = new CountDownLatch(1);
public FailedLogin(String realmId, String userId, ClientConnection clientConnection) {
super(realmId, userId, clientConnection);
}
}
protected class SuccessfulLogin extends LoginEvent {
protected final CountDownLatch latch = new CountDownLatch(1);
public SuccessfulLogin(String realmId, String userId, ClientConnection clientConnection) {
super(realmId, userId, clientConnection);
}
}
protected static class AdaptedClientConnection implements ClientConnection {
private final String remoteAddr;
private final String remoteHost;
private final int remotePort;
private final String localAddr;
private final int localPort;
public AdaptedClientConnection(ClientConnection c) {
this.remoteAddr = c == null ? null : c.getRemoteAddr();
this.remoteHost = c == null ? null : c.getRemoteHost();
this.remotePort = c == null ? 0 : c.getRemotePort();
this.localAddr = c == null ? null : c.getLocalAddr();
this.localPort = c == null ? 0 : c.getLocalPort();
}
@Override
public String getRemoteAddr() {
return this.remoteAddr;
}
@Override
public String getRemoteHost() {
return this.remoteHost;
}
@Override
public int getRemotePort() {
return this.remotePort;
}
@Override
public String getLocalAddr() {
return this.localAddr;
}
@Override
public int getLocalPort() {
return this.localPort;
}
}
public DefaultBruteForceProtector(KeycloakSessionFactory factory) { public DefaultBruteForceProtector(KeycloakSessionFactory factory) {
this.factory = factory; this.factory = factory;
} }
protected void failure(KeycloakSession session, LoginEvent event) { protected void failure(KeycloakSession session, RealmModel realm, String userId, String remoteAddr, long failureTime) {
logger.debug("failure"); logger.debug("failure");
RealmModel realm = getRealmModel(session, event);
logFailure(event);
String userId = event.userId; UserLoginFailureModel userLoginFailure = getUserFailureModel(session, realm, userId);
UserLoginFailureModel userLoginFailure = getUserModel(session, event);
if (userLoginFailure == null) { if (userLoginFailure == null) {
userLoginFailure = session.loginFailures().addUserLoginFailure(realm, userId); userLoginFailure = session.loginFailures().addUserLoginFailure(realm, userId);
} }
userLoginFailure.setLastIPFailure(event.clientConnection.getRemoteAddr()); userLoginFailure.setLastIPFailure(remoteAddr);
long currentTime = Time.currentTimeMillis();
long last = userLoginFailure.getLastFailure(); long last = userLoginFailure.getLastFailure();
long deltaTime = 0; long deltaTime = 0;
if (last > 0) { if (last > 0) {
deltaTime = currentTime - last; deltaTime = failureTime - last;
} }
userLoginFailure.setLastFailure(currentTime); userLoginFailure.setLastFailure(failureTime);
if (deltaTime > 0) { if (deltaTime > 0) {
// if last failure was more than MAX_DELTA clear failures // if last failure was more than MAX_DELTA clear failures
@ -198,7 +99,7 @@ public class DefaultBruteForceProtector implements Runnable, BruteForceProtector
userLoginFailure.incrementTemporaryLockouts(); userLoginFailure.incrementTemporaryLockouts();
} }
if(quickLoginFailure || !realm.isPermanentLockout() || userLoginFailure.getNumTemporaryLockouts() <= realm.getMaxTemporaryLockouts()) { if(quickLoginFailure || !realm.isPermanentLockout() || userLoginFailure.getNumTemporaryLockouts() <= realm.getMaxTemporaryLockouts()) {
int notBefore = (int) (currentTime / 1000) + waitSeconds; int notBefore = (int) (failureTime / 1000) + waitSeconds;
logger.debugv("set notBefore: {0}", notBefore); logger.debugv("set notBefore: {0}", notBefore);
userLoginFailure.setFailedLoginNotBefore(notBefore); userLoginFailure.setFailedLoginNotBefore(notBefore);
sendEvent(session, realm, userLoginFailure, EventType.USER_DISABLED_BY_TEMPORARY_LOCKOUT); sendEvent(session, realm, userLoginFailure, EventType.USER_DISABLED_BY_TEMPORARY_LOCKOUT);
@ -226,19 +127,9 @@ public class DefaultBruteForceProtector implements Runnable, BruteForceProtector
} }
} }
protected UserLoginFailureModel getUserFailureModel(KeycloakSession session, RealmModel realm, String userId) {
protected UserLoginFailureModel getUserModel(KeycloakSession session, LoginEvent event) {
RealmModel realm = getRealmModel(session, event);
if (realm == null) return null; if (realm == null) return null;
UserLoginFailureModel user = session.loginFailures().getUserLoginFailure(realm, event.userId); return session.loginFailures().getUserLoginFailure(realm, userId);
if (user == null) return null;
return user;
}
protected RealmModel getRealmModel(KeycloakSession session, LoginEvent event) {
RealmModel realm = session.realms().getRealm(event.realmId);
if (realm == null) return null;
return realm;
} }
protected void sendEvent(KeycloakSession session, RealmModel realm, UserLoginFailureModel userLoginFailure, EventType type) { protected void sendEvent(KeycloakSession session, RealmModel realm, UserLoginFailureModel userLoginFailure, EventType type) {
@ -261,125 +152,54 @@ public class DefaultBruteForceProtector implements Runnable, BruteForceProtector
builder.success(); builder.success();
} }
public void start() { public void shutdown() {}
new Thread(this, "Brute Force Protector").start();
}
public void shutdown() { protected void success(KeycloakSession session, RealmModel realm, String userId) {
run = false; UserLoginFailureModel userLoginFailure = getUserFailureModel(session, realm, userId);
try { if(userLoginFailure == null) return;
queue.offer(new ShutdownEvent());
shutdownLatch.await(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
public void run() {
final ArrayList<LoginEvent> events = new ArrayList<LoginEvent>(TRANSACTION_SIZE + 1);
try {
while (run) {
try {
LoginEvent take = queue.poll(2, TimeUnit.SECONDS);
if (take == null) {
continue;
}
try {
events.add(take);
queue.drainTo(events, TRANSACTION_SIZE);
Collections.sort(events); // we sort to avoid deadlock due to ordered updates. Maybe I'm overthinking this.
try (KeycloakSession session = factory.create()) {
session.getTransactionManager().begin();
try {
for (LoginEvent event : events) {
if (event instanceof FailedLogin) {
failure(session, event);
} else if (event instanceof SuccessfulLogin) {
success(session, event);
} else if (event instanceof ShutdownEvent) {
run = false;
}
}
} catch (Exception e) {
session.getTransactionManager().setRollbackOnly();
throw e;
}
} finally {
for (LoginEvent event : events) {
if (event instanceof FailedLogin) {
((FailedLogin) event).latch.countDown();
} else if (event instanceof SuccessfulLogin) {
((SuccessfulLogin) event).latch.countDown();
}
}
events.clear();
}
} catch (Exception e) {
ServicesLogger.LOGGER.failedProcessingType(e);
}
} catch (InterruptedException e) {
break;
}
}
} finally {
shutdownLatch.countDown();
}
}
protected void success(KeycloakSession session, LoginEvent event) {
String userId = event.userId;
UserLoginFailureModel user = getUserModel(session, event);
if(user == null) return;
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
UserModel model = session.users().getUserById(getRealmModel(session, event), userId); UserModel model = session.users().getUserById(realm, userId);
logger.debugv("user {0} successfully logged in, clearing all failures", model.getUsername()); logger.debugv("user {0} successfully logged in, clearing all failures", model.getUsername());
} }
user.clearFailures(); userLoginFailure.clearFailures();
}
protected void logFailure(LoginEvent event) {
failures++;
long delta = 0;
if (lastFailure > 0) {
delta = Time.currentTimeMillis() - lastFailure;
if (delta > (long)maxDeltaTimeSeconds * 1000L) {
totalTime = 0;
} else {
totalTime += delta;
}
}
} }
@Override @Override
public void failedLogin(RealmModel realm, UserModel user, ClientConnection clientConnection) { public void failedLogin(RealmModel realm, UserModel user, ClientConnection clientConnection) {
try { processLogin(realm, user, clientConnection, false);
FailedLogin event = new FailedLogin(realm.getId(), user.getId(), clientConnection); // wait a minimum of seconds for type to process so that a hacker
queue.offer(event); // cannot flood with failed logins and overwhelm the queue and not have notBefore updated to block next requests
// wait a minimum of seconds for type to process so that a hacker // todo failure HTTP responses should be queued via async HTTP
// cannot flood with failed logins and overwhelm the queue and not have notBefore updated to block next requests //event.latch.await(5, TimeUnit.SECONDS);
// todo failure HTTP responses should be queued via async HTTP
event.latch.await(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
}
logger.trace("sent failure event"); logger.trace("sent failure event");
} }
@Override @Override
public void successfulLogin(final RealmModel realm, final UserModel user, final ClientConnection clientConnection) { public void successfulLogin(RealmModel realm, UserModel user, ClientConnection clientConnection) {
SuccessfulLogin event = new SuccessfulLogin(realm.getId(), user.getId(), clientConnection); processLogin(realm, user, clientConnection, true);
queue.offer(event);
logger.trace("sent success event"); logger.trace("sent success event");
} }
private void processLogin(RealmModel realm, UserModel user, ClientConnection clientConnection, boolean success) {
KeycloakSession session = factory.create();
ExecutorsProvider provider = session.getProvider(ExecutorsProvider.class);
ExecutorService executor = provider.getExecutor("bruteforce");
executor.execute(() -> KeycloakModelUtils.runJobInTransaction(factory, s -> {
if (success) {
success(s, realm, user.getId());
} else {
failure(s, realm, user.getId(), clientConnection.getRemoteAddr(), Time.currentTimeMillis());
}
}));
}
@Override @Override
public boolean isTemporarilyDisabled(KeycloakSession session, RealmModel realm, UserModel user) { public boolean isTemporarilyDisabled(KeycloakSession session, RealmModel realm, UserModel user) {
UserLoginFailureModel failure = session.loginFailures().getUserLoginFailure(realm, user.getId()); UserLoginFailureModel userLoginFailure = getUserFailureModel(session, realm, user.getId());
if (failure != null) { if (userLoginFailure != null) {
int currTime = (int) (Time.currentTimeMillis() / 1000); int currTime = (int) (Time.currentTimeMillis() / 1000);
int failedLoginNotBefore = failure.getFailedLoginNotBefore(); int failedLoginNotBefore = userLoginFailure.getFailedLoginNotBefore();
if (currTime < failedLoginNotBefore) { if (currTime < failedLoginNotBefore) {
logger.debugv("Current: {0} notBefore: {1}", currTime, failedLoginNotBefore); logger.debugv("Current: {0} notBefore: {1}", currTime, failedLoginNotBefore);
return true; return true;
@ -403,7 +223,5 @@ public class DefaultBruteForceProtector implements Runnable, BruteForceProtector
} }
@Override @Override
public void close() { public void close() {}
}
} }

View file

@ -41,14 +41,11 @@ public class DefaultBruteForceProtectorFactory implements BruteForceProtectorFac
@Override @Override
public void postInit(KeycloakSessionFactory factory) { public void postInit(KeycloakSessionFactory factory) {
protector = new DefaultBruteForceProtector(factory); protector = new DefaultBruteForceProtector(factory);
protector.start();
} }
@Override @Override
public void close() { public void close() {
protector.shutdown(); protector.shutdown();
} }
@Override @Override