From 3cdd3081565d890801a27a1f9ac8bc53e5711ce6 Mon Sep 17 00:00:00 2001 From: Martin Balao Date: Mon, 17 Mar 2025 17:54:00 +0000 Subject: [PATCH] 8337692: Better TLS connection support Reviewed-by: abakhtin, andrew Backport-of: 2adb01e8c5fbcc3dfa9f82df3deccb3a1705bf13 --- .../com/sun/crypto/provider/RSACipher.java | 62 +++++++--- .../classes/sun/security/rsa/RSAPadding.java | 109 +++++++++++++----- .../classes/sun/security/util/KeyUtil.java | 32 +++-- 3 files changed, 147 insertions(+), 56 deletions(-) diff --git a/jdk/src/share/classes/com/sun/crypto/provider/RSACipher.java b/jdk/src/share/classes/com/sun/crypto/provider/RSACipher.java index c90f7efcbd..d489036fa9 100644 --- a/jdk/src/share/classes/com/sun/crypto/provider/RSACipher.java +++ b/jdk/src/share/classes/com/sun/crypto/provider/RSACipher.java @@ -383,7 +383,7 @@ public final class RSACipher extends CipherSpi { byte[] decryptBuffer = RSACore.convert(buffer, 0, bufOfs); paddingCopy = RSACore.rsa(decryptBuffer, privateKey, false); result = padding.unpad(paddingCopy); - if (result == null && !forTlsPremasterSecret) { + if (!forTlsPremasterSecret && result == null) { throw new BadPaddingException ("Padding error in decryption"); } @@ -403,6 +403,34 @@ public final class RSACipher extends CipherSpi { } } + // TLS master secret decode version of the doFinal() method. + private byte[] doFinalForTls(int clientVersion, int serverVersion) + throws BadPaddingException, IllegalBlockSizeException { + if (bufOfs > buffer.length) { + throw new IllegalBlockSizeException("Data must not be longer " + + "than " + buffer.length + " bytes"); + } + byte[] paddingCopy = null; + byte[] result = null; + try { + byte[] decryptBuffer = RSACore.convert(buffer, 0, bufOfs); + + paddingCopy = RSACore.rsa(decryptBuffer, privateKey, false); + result = padding.unpadForTls(paddingCopy, clientVersion, + serverVersion); + + return result; + } finally { + Arrays.fill(buffer, 0, bufOfs, (byte)0); + bufOfs = 0; + if (paddingCopy != null + && paddingCopy != buffer // already cleaned + && paddingCopy != result) { // DO NOT CLEAN, THIS IS RESULT + Arrays.fill(paddingCopy, (byte)0); + } + } + } + // see JCE spec protected byte[] engineUpdate(byte[] in, int inOfs, int inLen) { update(in, inOfs, inLen); @@ -474,38 +502,34 @@ public final class RSACipher extends CipherSpi { byte[] encoded = null; update(wrappedKey, 0, wrappedKey.length); - try { - encoded = doFinal(); - } catch (BadPaddingException | IllegalBlockSizeException e) { - // BadPaddingException cannot happen for TLS RSA unwrap. - // In that case, padding error is indicated by returning null. - // IllegalBlockSizeException cannot happen in any case, - // because of the length check above. - throw new InvalidKeyException("Unwrapping failed", e); - } - try { if (isTlsRsaPremasterSecret) { if (!forTlsPremasterSecret) { throw new IllegalStateException( "No TlsRsaPremasterSecretParameterSpec specified"); } - - // polish the TLS premaster secret - encoded = KeyUtil.checkTlsPreMasterSecretKey( - ((TlsRsaPremasterSecretParameterSpec) spec).getClientVersion(), - ((TlsRsaPremasterSecretParameterSpec) spec).getServerVersion(), - random, encoded, encoded == null); + TlsRsaPremasterSecretParameterSpec parameterSpec = + (TlsRsaPremasterSecretParameterSpec) spec; + encoded = doFinalForTls(parameterSpec.getClientVersion(), + parameterSpec.getServerVersion()); + } else { + encoded = doFinal(); } - return ConstructKeys.constructKey(encoded, algorithm, type); + + } catch (BadPaddingException | IllegalBlockSizeException e) { + // BadPaddingException cannot happen for TLS RSA unwrap. + // Neither padding error nor server version error is indicated + // for TLS, but a fake unwrapped value is returned. + // IllegalBlockSizeException cannot happen in any case, + // because of the length check above. + throw new InvalidKeyException("Unwrapping failed", e); } finally { if (encoded != null) { Arrays.fill(encoded, (byte) 0); } } } - // see JCE spec protected int engineGetKeySize(Key key) throws InvalidKeyException { RSAKey rsaKey = RSAKeyFactory.toRSAKey(key); diff --git a/jdk/src/share/classes/sun/security/rsa/RSAPadding.java b/jdk/src/share/classes/sun/security/rsa/RSAPadding.java index d0a00c5a90..aab5a944c8 100644 --- a/jdk/src/share/classes/sun/security/rsa/RSAPadding.java +++ b/jdk/src/share/classes/sun/security/rsa/RSAPadding.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2003, 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -328,48 +328,103 @@ public final class RSAPadding { * Note that we want to make it a constant-time operation */ private byte[] unpadV15(byte[] padded) { - int k = 0; - boolean bp = false; + int paddedLength = padded.length; - if (padded[k++] != 0) { - bp = true; + if (paddedLength < 2) { + return null; } - if (padded[k++] != type) { - bp = true; - } - int p = 0; - while (k < padded.length) { + + // The following check ensures that the lead byte is zero and + // the second byte is equivalent to the padding type. The + // bp (bad padding) variable throughout this unpadding process will + // be updated and remain 0 if good padding, 1 if bad. + int p0 = padded[0]; + int p1 = padded[1]; + int bp = (-(p0 & 0xff) | ((p1 - type) | (type - p1))) >>> 31; + + int padLen = 0; + int k = 2; + // Walk through the random, nonzero padding bytes. For each padding + // byte bp and padLen will remain zero. When the end-of-padding + // byte (0x00) is reached then padLen will be set to the index of the + // first byte of the message content. + while (k < paddedLength) { int b = padded[k++] & 0xff; - if ((b == 0) && (p == 0)) { - p = k; - } - if ((k == padded.length) && (p == 0)) { - bp = true; - } - if ((type == PAD_BLOCKTYPE_1) && (b != 0xff) && - (p == 0)) { - bp = true; + padLen += (k * (1 - ((-(b | padLen)) >>> 31))); + if (k == paddedLength) { + bp = bp | (1 - ((-padLen) >>> 31)); } + bp = bp | (1 - (-(((type - PAD_BLOCKTYPE_1) & 0xff) | + padLen | (1 - ((b - 0xff) >>> 31))) >>> 31)); } - int n = padded.length - p; - if (n > maxDataSize) { - bp = true; - } + int n = paddedLength - padLen; + // So long as n <= maxDataSize, bp will remain zero + bp = bp | ((maxDataSize - n) >>> 31); // copy useless padding array for a constant-time method - byte[] padding = new byte[p]; - System.arraycopy(padded, 0, padding, 0, p); + byte[] padding = new byte[padLen + 2]; + for (int i = 0; i < padLen; i++) { + padding[i] = padded[i]; + } byte[] data = new byte[n]; - System.arraycopy(padded, p, data, 0, n); + for (int i = 0; i < n; i++) { + data[i] = padded[padLen + i]; + } - if (bp) { + if ((bp | padding[bp]) != 0) { + // using the array padding here hoping that this way + // the compiler does not eliminate the above useless copy return null; } else { return data; } } + public byte[] unpadForTls(byte[] padded, int clientVersion, + int serverVersion) { + int paddedLength = padded.length; + + // bp is positive if the padding is bad and 0 if it is good so far + int bp = (((int) padded[0] | ((int)padded[1] - PAD_BLOCKTYPE_2)) & + 0xFFF); + + int k = 2; + while (k < paddedLength - 49) { + int b = padded[k++] & 0xFF; + bp = bp | (1 - (-b >>> 31)); // if (padded[k] == 0) bp |= 1; + } + bp |= ((int)padded[k++] & 0xFF); + int encodedVersion = ((padded[k] & 0xFF) << 8) | (padded[k + 1] & 0xFF); + + int bv1 = clientVersion - encodedVersion; + bv1 |= -bv1; + int bv3 = serverVersion - encodedVersion; + bv3 |= -bv3; + int bv2 = (0x301 - clientVersion); + + bp |= ((bv1 & (bv2 | bv3)) >>> 28); + + byte[] data = Arrays.copyOfRange(padded, paddedLength - 48, + paddedLength); + if (random == null) { + random = JCAUtil.getSecureRandom(); + } + + byte[] fake = new byte[48]; + random.nextBytes(fake); + + bp = (-bp >> 24); + + // Now bp is 0 if the padding and version number were good and + // -1 otherwise. + for (int i = 0; i < 48; i++) { + data[i] = (byte)((~bp & data[i]) | (bp & fake[i])); + } + + return data; + } + /** * PKCS#1 v2.0 OAEP padding (MGF1). * Paragraph references refer to PKCS#1 v2.1 (June 14, 2002) diff --git a/jdk/src/share/classes/sun/security/util/KeyUtil.java b/jdk/src/share/classes/sun/security/util/KeyUtil.java index e4099b301f..f2b56d5fcd 100644 --- a/jdk/src/share/classes/sun/security/util/KeyUtil.java +++ b/jdk/src/share/classes/sun/security/util/KeyUtil.java @@ -287,19 +287,31 @@ public final class KeyUtil { tmp = encoded; } + // At this point tmp.length is 48 int encodedVersion = ((tmp[0] & 0xFF) << 8) | (tmp[1] & 0xFF); - int check1 = 0; - int check2 = 0; - int check3 = 0; - if (clientVersion != encodedVersion) check1 = 1; - if (clientVersion > 0x0301) check2 = 1; - if (serverVersion != encodedVersion) check3 = 1; - if ((check1 & (check2 | check3)) == 1) { - return replacer; - } else { - return tmp; + + // The following code is a time-constant version of + // if ((clientVersion != encodedVersion) || + // ((clientVersion > 0x301) && (serverVersion != encodedVersion))) { + // return replacer; + // } else { return tmp; } + int check1 = (clientVersion - encodedVersion) | + (encodedVersion - clientVersion); + int check2 = 0x0301 - clientVersion; + int check3 = (serverVersion - encodedVersion) | + (encodedVersion - serverVersion); + + check1 = (check1 & (check2 | check3)) >> 24; + + // Now check1 is either 0 or -1 + check2 = ~check1; + + for (int i = 0; i < 48; i++) { + tmp[i] = (byte) ((tmp[i] & check2) | (replacer[i] & check1)); } + + return tmp; } /**