KEYCLOAK-12749 fix "invalid state" error due to IE requesting favicon

Internet Explorer occasionally requests a favicon before doing the
actual redirect to localhost. This commit adds Undertow to properly
handle those unwanted requests.
This commit is contained in:
Thomas Kuestermann 2020-02-05 15:06:20 +01:00 committed by Pedro Igor
parent 7b1b1cd35f
commit 22555371d8
2 changed files with 136 additions and 145 deletions

View file

@ -71,6 +71,10 @@
<groupId>org.jboss.spec.javax.ws.rs</groupId>
<artifactId>jboss-jaxrs-api_2.1_spec</artifactId>
</dependency>
<dependency>
<groupId>io.undertow</groupId>
<artifactId>undertow-core</artifactId>
</dependency>
</dependencies>

View file

@ -1,5 +1,5 @@
/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -17,6 +17,35 @@
package org.keycloak.adapters.installed;
import java.awt.Desktop;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Deque;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response;
import org.jboss.resteasy.client.jaxrs.ResteasyClient;
import org.jboss.resteasy.client.jaxrs.ResteasyClientBuilder;
import org.keycloak.OAuth2Constants;
@ -33,36 +62,21 @@ import org.keycloak.representations.AccessToken;
import org.keycloak.representations.AccessTokenResponse;
import org.keycloak.representations.IDToken;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response;
import java.awt.*;
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Locale;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import io.undertow.Handlers;
import io.undertow.Undertow;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.AllowedMethodsHandler;
import io.undertow.server.handlers.GracefulShutdownHandler;
import io.undertow.server.handlers.PathHandler;
import io.undertow.util.Headers;
import io.undertow.util.Methods;
import io.undertow.util.StatusCodes;
/**
* @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
*/
public class KeycloakInstalled {
public interface HttpResponseWriter {
void success(PrintWriter pw, KeycloakInstalled ki);
void failure(PrintWriter pw, KeycloakInstalled ki);
}
private static final String KEYCLOAK_JSON = "META-INF/keycloak.json";
private KeycloakDeployment deployment;
@ -79,8 +93,6 @@ public class KeycloakInstalled {
private String refreshToken;
private Status status;
private Locale locale;
private HttpResponseWriter loginResponseWriter;
private HttpResponseWriter logoutResponseWriter;
private ResteasyClient resteasyClient;
Pattern callbackPattern = Pattern.compile("callback\\s*=\\s*\"([^\"]+)\"");
Pattern paramPattern = Pattern.compile("param=\"([^\"]+)\"\\s+label=\"([^\"]+)\"\\s+mask=(\\S+)");
@ -100,22 +112,6 @@ public class KeycloakInstalled {
this.deployment = deployment;
}
public HttpResponseWriter getLoginResponseWriter() {
return null;
}
public HttpResponseWriter getLogoutResponseWriter() {
return null;
}
public void setLoginResponseWriter(HttpResponseWriter loginResponseWriter) {
this.loginResponseWriter = loginResponseWriter;
}
public void setLogoutResponseWriter(HttpResponseWriter logoutResponseWriter) {
this.logoutResponseWriter = logoutResponseWriter;
}
public void setResteasyClient(ResteasyClient resteasyClient) {
this.resteasyClient = resteasyClient;
}
@ -161,10 +157,10 @@ public class KeycloakInstalled {
}
public void loginDesktop() throws IOException, VerificationException, OAuthErrorException, URISyntaxException, ServerRequest.HttpFailure, InterruptedException {
CallbackListener callback = new CallbackListener(getLoginResponseWriter());
CallbackListener callback = new CallbackListener();
callback.start();
String redirectUri = "http://localhost:" + callback.server.getLocalPort();
String redirectUri = "http://localhost:" + callback.getLocalPort();
String state = UUID.randomUUID().toString();
Pkce pkce = deployment.isPkce() ? generatePkce() : null;
@ -172,18 +168,19 @@ public class KeycloakInstalled {
Desktop.getDesktop().browse(new URI(authUrl));
callback.join();
if (!state.equals(callback.state)) {
throw new VerificationException("Invalid state");
try {
callback.await();
} catch (InterruptedException e) {
callback.stop();
throw e;
}
if (callback.error != null) {
throw new OAuthErrorException(callback.error, callback.errorDescription);
}
if (callback.errorException != null) {
throw callback.errorException;
if (!state.equals(callback.state)) {
throw new VerificationException("Invalid state");
}
processCode(callback.code, redirectUri, pkce);
@ -220,10 +217,10 @@ public class KeycloakInstalled {
}
private void logoutDesktop() throws IOException, URISyntaxException, InterruptedException {
CallbackListener callback = new CallbackListener(getLogoutResponseWriter());
CallbackListener callback = new CallbackListener();
callback.start();
String redirectUri = "http://localhost:" + callback.server.getLocalPort();
String redirectUri = "http://localhost:" + callback.getLocalPort();
String logoutUrl = deployment.getLogoutUrl()
.queryParam(OAuth2Constants.REDIRECT_URI, redirectUri)
@ -231,10 +228,11 @@ public class KeycloakInstalled {
Desktop.getDesktop().browse(new URI(logoutUrl));
callback.join();
if (callback.errorException != null) {
throw callback.errorException;
try {
callback.await();
} catch (InterruptedException e) {
callback.stop();
throw e;
}
}
@ -590,7 +588,6 @@ public class KeycloakInstalled {
return deployment;
}
private void processCode(String code, String redirectUri, Pkce pkce) throws IOException, ServerRequest.HttpFailure, VerificationException {
AccessTokenResponse tokenResponse = ServerRequest.invokeAccessCodeToToken(deployment, code, redirectUri, null, pkce == null ? null : pkce.getCodeVerifier());
@ -613,96 +610,86 @@ public class KeycloakInstalled {
return sb.toString();
}
public class CallbackListener extends Thread {
private ServerSocket server;
private String code;
private String error;
private String errorDescription;
private IOException errorException;
private String state;
private Socket socket;
private HttpResponseWriter writer;
public CallbackListener(HttpResponseWriter writer) throws IOException {
this.writer = writer;
server = new ServerSocket(0);
}
@Override
public void run() {
try {
socket = server.accept();
BufferedReader br = new BufferedReader(new InputStreamReader(socket.getInputStream()));
String request = br.readLine();
String url = request.split(" ")[1];
if (url.indexOf('?') >= 0) {
url = url.split("\\?")[1];
String[] params = url.split("&");
for (String param : params) {
String[] p = param.split("=");
if (p[0].equals(OAuth2Constants.CODE)) {
code = p[1];
} else if (p[0].equals(OAuth2Constants.ERROR)) {
error = p[1];
} else if (p[0].equals("error-description")) {
errorDescription = p[1];
} else if (p[0].equals(OAuth2Constants.STATE)) {
state = p[1];
}
}
}
OutputStreamWriter out = new OutputStreamWriter(socket.getOutputStream());
PrintWriter pw = new PrintWriter(out);
if (writer != null) {
System.err.println("Using a writer is deprecated. Please remove its usage. This is now handled by endpoint on server");
}
if (error == null) {
if (writer != null) {
writer.success(pw, KeycloakInstalled.this);
} else {
pw.println("HTTP/1.1 302 Found");
pw.println("Location: " + deployment.getTokenUrl().replace("/token", "/delegated"));
}
} else {
if (writer != null) {
writer.failure(pw, KeycloakInstalled.this);
} else {
pw.println("HTTP/1.1 302 Found");
pw.println("Location: " + deployment.getTokenUrl().replace("/token", "/delegated?error=true"));
}
}
pw.flush();
socket.close();
} catch (IOException e) {
errorException = e;
}
try {
server.close();
} catch (IOException e) {
}
}
KeycloakInstalled(int i) {
}
public static class Pkce {
class CallbackListener implements HttpHandler {
private final CountDownLatch shutdownSignal = new CountDownLatch(1);
private String code;
private String error;
private String errorDescription;
private String state;
private Undertow server;
private GracefulShutdownHandler gracefulShutdownHandler;
public void start() {
PathHandler pathHandler = Handlers.path().addExactPath("/", this);
AllowedMethodsHandler allowedMethodsHandler = new AllowedMethodsHandler(pathHandler, Methods.GET);
gracefulShutdownHandler = Handlers.gracefulShutdown(allowedMethodsHandler);
server = Undertow.builder()
.addHttpListener(0, "localhost")
.setHandler(gracefulShutdownHandler).build();
server.start();
}
public void stop() {
server.stop();
}
public int getLocalPort() {
return ((InetSocketAddress) server.getListenerInfo().get(0).getAddress()).getPort();
}
public void await() throws InterruptedException {
shutdownSignal.await();
}
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
gracefulShutdownHandler.shutdown();
if (!exchange.getQueryParameters().isEmpty()) {
readQueryParameters(exchange);
}
exchange.setStatusCode(StatusCodes.FOUND);
exchange.getResponseHeaders().add(Headers.LOCATION, getRedirectUrl());
exchange.endExchange();
shutdownSignal.countDown();
ForkJoinPool.commonPool().execute(this::stop);
}
private void readQueryParameters(HttpServerExchange exchange) {
code = getQueryParameterIfPresent(exchange, OAuth2Constants.CODE);
error = getQueryParameterIfPresent(exchange, OAuth2Constants.ERROR);
errorDescription = getQueryParameterIfPresent(exchange, "error-description");
state = getQueryParameterIfPresent(exchange, OAuth2Constants.STATE);
}
private String getQueryParameterIfPresent(HttpServerExchange exchange, String name) {
Map<String, Deque<String>> queryParameters = exchange.getQueryParameters();
return queryParameters.containsKey(name) ? queryParameters.get(name).getFirst() : null;
}
private String getRedirectUrl() {
String redirectUrl = deployment.getTokenUrl().replace("/token", "/delegated");
if (error != null) {
redirectUrl += "?error=true";
}
return redirectUrl;
}
}
public static class Pkce {
// https://tools.ietf.org/html/rfc7636#section-4.1
public static final int PKCE_CODE_VERIFIER_MAX_LENGTH = 128;