From 89892d7b4ddd6334e37a60a53a7178b4c5a1e764 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 12 Apr 2022 12:20:50 +0200 Subject: [PATCH] Prevent JVM crash when buffer too small --- pom.xml | 6 + .../airlift/compress/lz4/Lz4Compressor.java | 13 ++ .../airlift/compress/lz4/Lz4Decompressor.java | 13 ++ .../airlift/compress/lzo/LzoCompressor.java | 13 ++ .../airlift/compress/lzo/LzoDecompressor.java | 13 ++ .../compress/snappy/SnappyCompressor.java | 13 ++ .../compress/snappy/SnappyDecompressor.java | 13 ++ .../airlift/compress/zstd/ZstdCompressor.java | 13 ++ .../compress/zstd/ZstdDecompressor.java | 13 ++ .../compress/AbstractTestCompression.java | 172 ++++++++++++++++++ 10 files changed, 282 insertions(+) diff --git a/pom.xml b/pom.xml index 69a9475e..1dfe31f9 100644 --- a/pom.xml +++ b/pom.xml @@ -151,6 +151,12 @@ 0.4 test + + + org.assertj + assertj-core + test + diff --git a/src/main/java/io/airlift/compress/lz4/Lz4Compressor.java b/src/main/java/io/airlift/compress/lz4/Lz4Compressor.java index df918761..75787516 100644 --- a/src/main/java/io/airlift/compress/lz4/Lz4Compressor.java +++ b/src/main/java/io/airlift/compress/lz4/Lz4Compressor.java @@ -20,6 +20,8 @@ import static io.airlift.compress.lz4.Lz4RawCompressor.MAX_TABLE_SIZE; import static io.airlift.compress.lz4.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; /** @@ -39,6 +41,9 @@ public int maxCompressedLength(int uncompressedSize) @Override public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -109,4 +114,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/lz4/Lz4Decompressor.java b/src/main/java/io/airlift/compress/lz4/Lz4Decompressor.java index 5908da5c..abb55ea9 100644 --- a/src/main/java/io/airlift/compress/lz4/Lz4Decompressor.java +++ b/src/main/java/io/airlift/compress/lz4/Lz4Decompressor.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; import static io.airlift.compress.lz4.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class Lz4Decompressor @@ -29,6 +31,9 @@ public class Lz4Decompressor public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) throws MalformedInputException { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long inputLimit = inputAddress + inputLength; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -95,4 +100,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/lzo/LzoCompressor.java b/src/main/java/io/airlift/compress/lzo/LzoCompressor.java index 3b31f690..ae165991 100644 --- a/src/main/java/io/airlift/compress/lzo/LzoCompressor.java +++ b/src/main/java/io/airlift/compress/lzo/LzoCompressor.java @@ -20,6 +20,8 @@ import static io.airlift.compress.lzo.LzoRawCompressor.MAX_TABLE_SIZE; import static io.airlift.compress.lzo.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; /** @@ -39,6 +41,9 @@ public int maxCompressedLength(int uncompressedSize) @Override public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -109,4 +114,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/lzo/LzoDecompressor.java b/src/main/java/io/airlift/compress/lzo/LzoDecompressor.java index 17141dea..037bcb9a 100644 --- a/src/main/java/io/airlift/compress/lzo/LzoDecompressor.java +++ b/src/main/java/io/airlift/compress/lzo/LzoDecompressor.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; import static io.airlift.compress.lzo.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class LzoDecompressor @@ -29,6 +31,9 @@ public class LzoDecompressor public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) throws MalformedInputException { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long inputLimit = inputAddress + inputLength; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -95,4 +100,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/snappy/SnappyCompressor.java b/src/main/java/io/airlift/compress/snappy/SnappyCompressor.java index b3f1ed0c..cffe9ff7 100644 --- a/src/main/java/io/airlift/compress/snappy/SnappyCompressor.java +++ b/src/main/java/io/airlift/compress/snappy/SnappyCompressor.java @@ -19,6 +19,8 @@ import java.nio.ByteBuffer; import static io.airlift.compress.snappy.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class SnappyCompressor @@ -35,6 +37,9 @@ public int maxCompressedLength(int uncompressedSize) @Override public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long inputLimit = inputAddress + inputLength; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -107,4 +112,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/snappy/SnappyDecompressor.java b/src/main/java/io/airlift/compress/snappy/SnappyDecompressor.java index 3029953c..981748c9 100644 --- a/src/main/java/io/airlift/compress/snappy/SnappyDecompressor.java +++ b/src/main/java/io/airlift/compress/snappy/SnappyDecompressor.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; import static io.airlift.compress.snappy.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class SnappyDecompressor @@ -37,6 +39,9 @@ public static int getUncompressedLength(byte[] compressed, int compressedOffset) public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) throws MalformedInputException { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long inputLimit = inputAddress + inputLength; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -103,4 +108,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java index d7cdf850..91b5605a 100644 --- a/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java +++ b/src/main/java/io/airlift/compress/zstd/ZstdCompressor.java @@ -20,6 +20,8 @@ import static io.airlift.compress.zstd.Constants.MAX_BLOCK_SIZE; import static io.airlift.compress.zstd.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class ZstdCompressor @@ -40,6 +42,9 @@ public int maxCompressedLength(int uncompressedSize) @Override public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -110,4 +115,12 @@ else if (output.hasArray()) { } } } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java b/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java index 35744644..93761a1f 100644 --- a/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java +++ b/src/main/java/io/airlift/compress/zstd/ZstdDecompressor.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; import static io.airlift.compress.zstd.UnsafeUtil.getAddress; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET; public class ZstdDecompressor @@ -31,6 +33,9 @@ public class ZstdDecompressor public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength) throws MalformedInputException { + verifyRange(input, inputOffset, inputLength); + verifyRange(output, outputOffset, maxOutputLength); + long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset; long inputLimit = inputAddress + inputLength; long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset; @@ -103,4 +108,12 @@ public static long getDecompressedSize(byte[] input, int offset, int length) int baseAddress = ARRAY_BYTE_BASE_OFFSET + offset; return ZstdFrameDecompressor.getDecompressedSize(input, baseAddress, baseAddress + length); } + + private static void verifyRange(byte[] data, int offset, int length) + { + requireNonNull(data, "data is null"); + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length)); + } + } } diff --git a/src/test/java/io/airlift/compress/AbstractTestCompression.java b/src/test/java/io/airlift/compress/AbstractTestCompression.java index f1355a85..0d6a7038 100644 --- a/src/test/java/io/airlift/compress/AbstractTestCompression.java +++ b/src/test/java/io/airlift/compress/AbstractTestCompression.java @@ -22,15 +22,21 @@ import javax.inject.Inject; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Random; import java.util.concurrent.ThreadLocalRandom; import static com.google.common.base.Preconditions.checkPositionIndexes; +import static java.lang.System.arraycopy; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.catchThrowable; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; @@ -142,6 +148,89 @@ public void testDecompressionBufferOverrun(DataSet dataSet) assertByteArraysEqual(padding, 0, padding.length, uncompressed, uncompressed.length - padding.length, padding.length); } + @Test + public void testDecompressInputBoundsChecks() + { + byte[] data = new byte[1024]; + new Random(1234).nextBytes(data); + Compressor compressor = getCompressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(data.length)]; + int compressedLength = compressor.compress(data, 0, data.length, compressed, 0, compressed.length); + + Decompressor decompressor = getDecompressor(); + Throwable throwable; + + // null input buffer + assertThatThrownBy(() -> decompressor.decompress(null, 0, compressedLength, data, 0, data.length)) + .isInstanceOf(NullPointerException.class); + + // mis-declared buffer size + byte[] compressedChoppedOff = Arrays.copyOf(compressed, compressedLength - 1); + throwable = catchThrowable(() -> decompressor.decompress(compressedChoppedOff, 0, compressedLength, data, 0, data.length)); + if (throwable instanceof UncheckedIOException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + + // overrun because of offset + byte[] compressedWithPadding = new byte[10 + compressedLength - 1]; + arraycopy(compressed, 0, compressedWithPadding, 10, compressedLength - 1); + + throwable = catchThrowable(() -> decompressor.decompress(compressedWithPadding, 10, compressedLength, data, 0, data.length)); + if (throwable instanceof UncheckedIOException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + } + + @Test + public void testDecompressOutputBoundsChecks() + { + byte[] data = new byte[1024]; + new Random(1234).nextBytes(data); + Compressor compressor = getCompressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(data.length)]; + int compressedLength = compressor.compress(data, 0, data.length, compressed, 0, compressed.length); + byte[] input = Arrays.copyOf(compressed, compressedLength); + + Decompressor decompressor = getDecompressor(); + Throwable throwable; + + // null output buffer + assertThatThrownBy(() -> decompressor.decompress(input, 0, input.length, null, 0, data.length)) + .isInstanceOf(NullPointerException.class); + + // small buffer + assertThatThrownBy(() -> decompressor.decompress(input, 0, input.length, new byte[1], 0, 1)) + .hasMessageMatching("All input was not consumed|attempt to write.* outside of destination buffer.*|Malformed input.*|Uncompressed length 1024 must be less than 1|Output buffer too small.*"); + + // mis-declared buffer size + throwable = catchThrowable(() -> decompressor.decompress(input, 0, input.length, new byte[1], 0, data.length)); + if (throwable instanceof IndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + + // mis-declared buffer size with greater buffer + throwable = catchThrowable(() -> decompressor.decompress(input, 0, input.length, new byte[data.length - 1], 0, data.length)); + if (throwable instanceof IndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + } + @Test(dataProvider = "data") public void testDecompressByteBufferHeapToHeap(DataSet dataSet) throws Exception @@ -245,6 +334,89 @@ public void testCompress(DataSet testCase) verifyCompressedData(originalUncompressed, compressed, compressedLength); } + @Test + public void testCompressInputBoundsChecks() + { + Compressor compressor = getCompressor(); + int declaredInputLength = 1024; + int maxCompressedLength = compressor.maxCompressedLength(1024); + byte[] output = new byte[maxCompressedLength]; + Throwable throwable; + + // null input buffer + assertThatThrownBy(() -> compressor.compress(null, 0, declaredInputLength, output, 0, output.length)) + .isInstanceOf(NullPointerException.class); + + // mis-declared buffer size + throwable = catchThrowable(() -> compressor.compress(new byte[1], 0, declaredInputLength, output, 0, output.length)); + if (throwable instanceof IndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + + // max too small + throwable = catchThrowable(() -> compressor.compress(new byte[declaredInputLength - 1], 0, declaredInputLength, output, 0, output.length)); + if (throwable instanceof IndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + + // overrun because of offset + throwable = catchThrowable(() -> compressor.compress(new byte[declaredInputLength + 10], 11, declaredInputLength, output, 0, output.length)); + if (throwable instanceof IndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + } + + @Test + public void testCompressOutputBoundsChecks() + { + Compressor compressor = getCompressor(); + int minCompressionOverhead = compressor.maxCompressedLength(0); + byte[] input = new byte[minCompressionOverhead * 4 + 1024]; + new Random(1234).nextBytes(input); + int maxCompressedLength = compressor.maxCompressedLength(input.length); + Throwable throwable; + + // null output buffer + assertThatThrownBy(() -> compressor.compress(input, 0, input.length, null, 0, maxCompressedLength)) + .isInstanceOf(NullPointerException.class); + + // small buffer + assertThatThrownBy(() -> compressor.compress(input, 0, input.length, new byte[1], 0, 1)) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*|Max output length must be larger than .*|Output buffer must be at least.*|Output buffer too small"); + + // mis-declared buffer size + throwable = catchThrowable(() -> compressor.compress(input, 0, input.length, new byte[1], 0, maxCompressedLength)); + if (throwable instanceof ArrayIndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + + // mis-declared buffer size with buffer large enough to hold compression frame header (if any) + throwable = catchThrowable(() -> compressor.compress(input, 0, input.length, new byte[minCompressionOverhead * 2], 0, maxCompressedLength)); + if (throwable instanceof ArrayIndexOutOfBoundsException) { + // OK + } + else { + assertThat(throwable) + .hasMessageMatching(".*must not be greater than size.*|Invalid offset or length.*"); + } + } + @Test(dataProvider = "data") public void testCompressByteBufferHeapToHeap(DataSet dataSet) throws Exception