diff --git a/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java b/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java index ecdb3da539..1b84bb26b2 100644 --- a/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java +++ b/core/src/main/java/org/keycloak/jose/jwk/JWKBuilder.java @@ -124,8 +124,8 @@ public class JWKBuilder { k.setAlgorithm(algorithm); k.setPublicKeyUse(DEFAULT_PUBLIC_KEY_USE); k.setCrv("P-" + fieldSize); - k.setX(Base64Url.encode(toIntegerBytes(ecKey.getW().getAffineX()))); - k.setY(Base64Url.encode(toIntegerBytes(ecKey.getW().getAffineY()))); + k.setX(Base64Url.encode(toIntegerBytes(ecKey.getW().getAffineX(), fieldSize))); + k.setY(Base64Url.encode(toIntegerBytes(ecKey.getW().getAffineY(), fieldSize))); return k; } diff --git a/core/src/main/java/org/keycloak/jose/jwk/JWKUtil.java b/core/src/main/java/org/keycloak/jose/jwk/JWKUtil.java index 3c929b667c..2b0eaf0081 100644 --- a/core/src/main/java/org/keycloak/jose/jwk/JWKUtil.java +++ b/core/src/main/java/org/keycloak/jose/jwk/JWKUtil.java @@ -18,35 +18,46 @@ package org.keycloak.jose.jwk; import java.math.BigInteger; +import java.util.Arrays; public class JWKUtil { /** - * Convert BigInteger to 64-byte integer array + * Coverts {@code BigInteger} to 64-byte array removing the sign byte if + * necessary. * - * Copied from org.apache.commons.codec.binary.Base64 + * @param bigInt {@code BigInteger} to be converted + * @return Byte array representation of the BigInteger parameter */ public static byte[] toIntegerBytes(final BigInteger bigInt) { - int bitlen = bigInt.bitLength(); - // round bitlen - bitlen = ((bitlen + 7) >> 3) << 3; - final byte[] bigBytes = bigInt.toByteArray(); + return toIntegerBytes(bigInt, bigInt.bitLength()); + } - if (((bigInt.bitLength() % 8) != 0) && (((bigInt.bitLength() / 8) + 1) == (bitlen / 8))) { - return bigBytes; + /** + * Coverts {@code BigInteger} to 64-byte array but maintaining the length + * to bitlen as specified in rfc7518 for certain fields (X and Y parameter + * for EC keys). + * + * @param bigInt {@code BigInteger} to be converted + * @param bitlen The bit length size of the integer (for example 521 for EC P-521) + * @return Byte array representation of the BigInteger parameter with length (bitlen + 7) / 8 + * @throws IllegalStateException if the big integer is longer than bitlen + */ + public static byte[] toIntegerBytes(final BigInteger bigInt, int bitlen) { + assert bigInt.bitLength() <= bitlen : "Incorrect big integer with bit length " + bigInt.bitLength() + " for " + bitlen; + final int bytelen = (bitlen + 7) / 8; + final byte[] array = bigInt.toByteArray(); + if (array.length == bytelen) { + // expected number of bytes, return them + return array; + } else if (bytelen < array.length) { + // if array is greater is because the sign bit (it can be only 1 byte more), remove it + return Arrays.copyOfRange(array, array.length - bytelen, array.length); + } else { + // if array is smaller fill it with zeros + final byte[] resizedBytes = new byte[bytelen]; + System.arraycopy(array, 0, resizedBytes, bytelen - array.length, array.length); + return resizedBytes; } - // set up params for copying everything but sign bit - int startSrc = 0; - int len = bigBytes.length; - - // if bigInt is exactly byte-aligned, just skip signbit in copy - if ((bigInt.bitLength() % 8) == 0) { - startSrc = 1; - len--; - } - final int startDst = bitlen / 8 - len; // to pad w/ nulls as per spec - final byte[] resizedBytes = new byte[bitlen / 8]; - System.arraycopy(bigBytes, startSrc, resizedBytes, startDst, len); - return resizedBytes; } } diff --git a/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java b/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java index 214a072b33..a25960d8ee 100644 --- a/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java +++ b/core/src/test/java/org/keycloak/jose/jwk/JWKTest.java @@ -135,15 +135,14 @@ public abstract class JWKTest { verify(data, sign, JavaAlgorithm.RS256, publicKeyFromJwk); } - @Test - public void publicEs256() throws Exception { + private void testPublicEs256(String algorithm) throws Exception { KeyPairGenerator keyGen = CryptoIntegration.getProvider().getKeyPairGen(KeyType.EC); SecureRandom randomGen = new SecureRandom(); - ECGenParameterSpec ecSpec = new ECGenParameterSpec("secp256r1"); + ECGenParameterSpec ecSpec = new ECGenParameterSpec(algorithm); keyGen.initialize(ecSpec, randomGen); KeyPair keyPair = keyGen.generateKeyPair(); - PublicKey publicKey = keyPair.getPublic(); + ECPublicKey publicKey = (ECPublicKey) keyPair.getPublic(); JWK jwk = JWKBuilder.create().kid(KeyUtils.createKeyId(keyPair.getPublic())).algorithm("ES256").ec(publicKey); @@ -162,28 +161,36 @@ public abstract class JWKTest { byte[] xBytes = Base64Url.decode(ecJwk.getX()); byte[] yBytes = Base64Url.decode(ecJwk.getY()); - assertTrue(publicKey instanceof ECPublicKey); - ECPoint ecPoint = ((ECPublicKey) publicKey).getW(); - assertNotNull(ecPoint); - - int lengthAffineX = JWKUtil.toIntegerBytes(ecPoint.getAffineX()).length; - int lengthAffineY = JWKUtil.toIntegerBytes(ecPoint.getAffineY()).length; - - assertEquals(lengthAffineX, xBytes.length); - assertEquals(lengthAffineY, yBytes.length); + final int expectedSize = (publicKey.getParams().getCurve().getField().getFieldSize() + 7) / 8; + assertEquals(expectedSize, xBytes.length); + assertEquals(expectedSize, yBytes.length); String jwkJson = JsonSerialization.writeValueAsString(jwk); JWKParser parser = JWKParser.create().parse(jwkJson); - PublicKey publicKeyFromJwk = parser.toPublicKey(); - - assertArrayEquals(publicKey.getEncoded(), publicKeyFromJwk.getEncoded()); + ECPublicKey publicKeyFromJwk = (ECPublicKey) parser.toPublicKey(); + assertEquals(publicKey.getW(), publicKeyFromJwk.getW()); byte[] data = "Some test string".getBytes(StandardCharsets.UTF_8); byte[] sign = sign(data, JavaAlgorithm.ES256, keyPair.getPrivate()); verify(data, sign, JavaAlgorithm.ES256, publicKeyFromJwk); } + @Test + public void publicEs256P256() throws Exception { + testPublicEs256("secp256r1"); + } + + @Test + public void publicEs256P521() throws Exception { + testPublicEs256("secp521r1"); + } + + @Test + public void publicEs256P384() throws Exception { + testPublicEs256("secp384r1"); + } + @Test public void parse() { String jwkJson = "{" + diff --git a/core/src/test/java/org/keycloak/jose/jwk/JWKUtilTest.java b/core/src/test/java/org/keycloak/jose/jwk/JWKUtilTest.java new file mode 100644 index 0000000000..927c8c3b8e --- /dev/null +++ b/core/src/test/java/org/keycloak/jose/jwk/JWKUtilTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2022 Red Hat, Inc. and/or its affiliates + * and other contributors as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.keycloak.jose.jwk; + +import java.math.BigInteger; +import org.junit.Assert; +import org.junit.Test; + +/** + *
Test class for JWKUtil.toIntegerBytes methods.
+ * + * @author rmartinc + */ +public class JWKUtilTest { + + @Test + public void testBigInteger256bit33bytes() { + // big integer that is 256b/32B (P-256) but positive sign adds one more byte + BigInteger bi = new BigInteger("106978455244904118504029146852168092303170743300495577837424194202315290288011"); + Assert.assertEquals(256, bi.bitLength()); + Assert.assertEquals(33, bi.toByteArray().length); + byte[] bytes = JWKUtil.toIntegerBytes(bi, 256); + Assert.assertEquals(32, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + } + + @Test + public void testBigInteger521bit66bytes() { + // big integer that is 521b/66B (P-521) + BigInteger bi = new BigInteger("6734373674814691396115132088653791161514881890352734019594374673014557152383502505390504647094584246525242385854438954847939940255492102589858760446395824148"); + Assert.assertEquals(521, bi.bitLength()); + Assert.assertEquals(66, bi.toByteArray().length); + byte[] bytes = JWKUtil.toIntegerBytes(bi, 521); + Assert.assertEquals(66, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + } + + @Test + public void testBigInteger519bit65bytes() { + // big integer is 519b/65B (P-521) + BigInteger bi = new BigInteger("1056406612537758216307284361941630998827278875643943164504783316640832530092186610655845467862847840003942818620330993843247554843391332954698064457598103921"); + Assert.assertEquals(519, bi.bitLength()); + Assert.assertEquals(65, bi.toByteArray().length); + byte[] bytes = JWKUtil.toIntegerBytes(bi, 521); + Assert.assertEquals(66, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + bytes = JWKUtil.toIntegerBytes(bi); + Assert.assertEquals(65, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + } + + @Test + public void testBigInteger509bit65bytes() { + // big integer is 509b/64B (P-521) + BigInteger bi = new BigInteger("1020105336060806799317581876370378670178920448263046037385822665297838480884942245045412789346716977404456327079571798657084244307627713218035021026706753"); + Assert.assertEquals(509, bi.bitLength()); + Assert.assertEquals(64, bi.toByteArray().length); + byte[] bytes = JWKUtil.toIntegerBytes(bi, 521); + Assert.assertEquals(66, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + bytes = JWKUtil.toIntegerBytes(bi); + Assert.assertEquals(64, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + } + + @Test + public void testBigInteger380bit48bytes() { + // big integer is 380b/48B (P-384) + BigInteger bi = new BigInteger("1318324198847573133767761135109898830134893480775680898178696604234765693579204018161102886445531980641666395659568"); + Assert.assertEquals(380, bi.bitLength()); + Assert.assertEquals(48, bi.toByteArray().length); + byte[] bytes = JWKUtil.toIntegerBytes(bi, 384); + Assert.assertEquals(48, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + bytes = JWKUtil.toIntegerBytes(bi); + Assert.assertEquals(48, bytes.length); + Assert.assertEquals(bi, new BigInteger(1, bytes)); + } + + @Test + public void testBigInteger380bit48bytesErrorFor256() { + // big integer is 380b/48B (P-384) not valid for 256b (P-256) + BigInteger bi = new BigInteger("1318324198847573133767761135109898830134893480775680898178696604234765693579204018161102886445531980641666395659568"); + Assert.assertEquals(380, bi.bitLength()); + Assert.assertEquals(48, bi.toByteArray().length); + AssertionError e = Assert.assertThrows(AssertionError.class, () -> JWKUtil.toIntegerBytes(bi, 256)); + Assert.assertEquals("Incorrect big integer with bit length 380 for 256", e.getMessage()); + } +}