diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/varint/read/VarIntByteArrayReadBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/varint/read/VarIntByteArrayReadBench.java index 38a90c867..98f7453ca 100644 --- a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/varint/read/VarIntByteArrayReadBench.java +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/integration/jmh/varint/read/VarIntByteArrayReadBench.java @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 package com.hedera.pbj.integration.jmh.varint.read; +import com.hedera.pbj.runtime.io.DataEncodingException; import com.hedera.pbj.runtime.io.buffer.BufferedData; import java.util.Random; import java.util.concurrent.TimeUnit; @@ -73,7 +74,7 @@ record NumRange(int min, int max) {} /// results. /// So we make every algorithm maintain the sum in this variable instead, so that they all spend the /// exact same time updating it. - int sum; + long sum; @Setup(Level.Trial) public void setup() { @@ -102,23 +103,27 @@ public void tearDown() {} @OperationsPerInvocation(INVOCATIONS) public void pbj(final BenchState state, final Blackhole blackhole) { state.sum = 0; + outer: for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) { - int value = 0; + final int limit = Math.max(0, Math.min(state.array.length, pos + 10) - pos); + long value = 0; - for (int i = 0; i < 10; i++) { + for (int i = 0; i < limit; i++) { final byte b = state.array[pos++]; - value |= (b & 0x7F) << (i * 7); + value |= (b & 0x7FL) << (i * 7); if (b >= 0) { state.sum += state.zigZag ? (value >>> 1) ^ -(value & 1) : value; - break; + continue outer; } } + + throw new DataEncodingException("Malformed var int"); } blackhole.consume(state.sum); } /// A variation of LEB128 with the zigZag conditional removed. - @Benchmark + // @Benchmark // disabled because it didn't show a significant improvement @OperationsPerInvocation(INVOCATIONS) public void pbj_zigZagFalse(final BenchState state, final Blackhole blackhole) { state.sum = 0; @@ -138,7 +143,7 @@ public void pbj_zigZagFalse(final BenchState state, final Blackhole blackhole) { } /// A variation of LEB128 that uses `(b & 0x80) == 0` instead of `b >= 0`. - @Benchmark + // @Benchmark // disabled because it didn't show a significant improvement @OperationsPerInvocation(INVOCATIONS) public void pbj_BitwiseAndCondition(final BenchState state, final Blackhole blackhole) { state.sum = 0; @@ -158,7 +163,7 @@ public void pbj_BitwiseAndCondition(final BenchState state, final Blackhole blac } /// A variation of LEB128 that uses do/while loop instead of a for loop to skip one branch for 1-byte varint. - @Benchmark + // @Benchmark // disabled because it didn't show a significant improvement @OperationsPerInvocation(INVOCATIONS) public void pbj_doWhileLoop(final BenchState state, final Blackhole blackhole) { state.sum = 0; @@ -182,7 +187,7 @@ public void pbj_doWhileLoop(final BenchState state, final Blackhole blackhole) { /// PBJ used to use a very similar algorithm, just before https://github.com/hashgraph/pbj/pull/144 /// where we switched to LEB128. @SuppressWarnings("lossy-conversions") // the impl is able to support longs, but we ignore that here and use ints. - @Benchmark + // @Benchmark // disabled because it performs worse than the standard pbj implementation @OperationsPerInvocation(INVOCATIONS) public void google(final BenchState state, final Blackhole blackhole) { state.sum = 0; @@ -263,7 +268,7 @@ public void google(final BenchState state, final Blackhole blackhole) { } /// A LEB128 with fully unrolled loop. - @Benchmark + // @Benchmark // disabled because the algorithm is missing limit checks @OperationsPerInvocation(INVOCATIONS) public void loopLess(final BenchState state, final Blackhole blackhole) { state.sum = 0; @@ -305,6 +310,140 @@ public void loopLess(final BenchState state, final Blackhole blackhole) { blackhole.consume(state.sum); } + /// A LEB128 with fully unrolled loop and limit checks after each read. + // @Benchmark // disabled because it's replaced with the vector thing below. + @OperationsPerInvocation(INVOCATIONS) + public void loopLess_withLimitChecks(final BenchState state, final Blackhole blackhole) { + state.sum = 0; + for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) { + final int lim = Math.min(state.array.length, pos + 5); + + byte b = state.array[pos++]; + if ((b & 0x80) == 0) { + state.sum += b; + continue; + } + if (pos == lim) throw new DataEncodingException("Malformed var int"); + + int v = b & 0x7F; + b = state.array[pos++]; + if ((b & 0x80) == 0) { + state.sum += v | b << 7; + continue; + } + if (pos == lim) throw new DataEncodingException("Malformed var int"); + + v |= (b & 0x7F) << 7; + b = state.array[pos++]; + if ((b & 0x80) == 0) { + state.sum += v | b << 14; + continue; + } + if (pos == lim) throw new DataEncodingException("Malformed var int"); + + v |= (b & 0x7F) << 14; + b = state.array[pos++]; + if ((b & 0x80) == 0) { + state.sum += v | b << 21; + continue; + } + if (pos == lim) throw new DataEncodingException("Malformed var int"); + + v |= (b & 0x7F) << 21; + b = state.array[pos++]; + // if (b > 0) { + // Stop here because this benchmark doesn't support longs, only ints. + state.sum += v | b << 28; + // } + } + blackhole.consume(state.sum); + } + + /// A vectorized LEB128, similar to the loopLess above, with some minor tweaks and supporting long varints. + @Benchmark + @OperationsPerInvocation(INVOCATIONS) + public void vector_zigZag(final BenchState state, final Blackhole blackhole) { + state.sum = 0; + for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) { + final int limit = Math.min(state.array.length, pos + 10); + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + byte b; + long v = (b = state.array[pos++]) & 0x7F; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7F) << 7; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7F) << 14; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7F) << 21; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7FL) << 28; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7FL) << 35; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7FL) << 42; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7FL) << 49; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = state.array[pos++]) & 0x7FL) << 56; + if ((b & 0x80) == 0) { + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + b = state.array[pos++]; + if ((b & 0x80) == 0) { + v |= (long) b << 63; + state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v; + continue; + } + + throw new DataEncodingException("Malformed var int"); + } + blackhole.consume(state.sum); + } + public static void main(String[] args) throws Exception { Options opt = new OptionsBuilder() .include(VarIntByteArrayReadBench.class.getSimpleName()) diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/VectorVarIntTest.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/VectorVarIntTest.java new file mode 100644 index 000000000..830bc0132 --- /dev/null +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/integration/test/VectorVarIntTest.java @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: Apache-2.0 +package com.hedera.pbj.integration.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.hedera.pbj.runtime.io.DataEncodingException; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import java.util.Arrays; +import java.util.Random; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/// A test to verify the correctness of the new varint reading algorithm. +public class VectorVarIntTest { + + /// A refactored copy from VarIntByteArrayReadBench.vector_zigZag. + private long readVarInt(byte[] bytes, int pos, boolean zigZag) { + final int limit = Math.min(bytes.length, pos + 10); + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + byte b; + long v = (b = bytes[pos++]) & 0x7F; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7F) << 7; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7F) << 14; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7F) << 21; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7FL) << 28; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7FL) << 35; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7FL) << 42; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7FL) << 49; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + v |= ((b = bytes[pos++]) & 0x7FL) << 56; + if ((b & 0x80) == 0) { + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + if (pos >= limit) throw new DataEncodingException("Malformed var int"); + + b = bytes[pos++]; + if ((b & 0x80) == 0) { + v |= (long) b << 63; + return zigZag ? (v >>> 1) ^ -(v & 1) : v; + } + + throw new DataEncodingException("Malformed var int"); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testVectorVarInt(boolean zigZag) { + final byte[] bytes = new byte[64]; + final BufferedData bd = BufferedData.wrap(bytes); + final Random random = new Random(457639854); + + for (int i = 0; i < 10 * 1024 * 1024; i++) { + final int val = random.nextInt(); + Arrays.fill(bytes, (byte) 0); + bd.writeVarInt(val, zigZag); + + assertEquals(val, readVarInt(bytes, 0, zigZag)); + + bd.reset(); + } + } +}