Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
package com.hedera.pbj.integration.jmh.varint.read;

import com.hedera.pbj.runtime.io.DataEncodingException;
import com.hedera.pbj.runtime.io.buffer.BufferedData;
import java.util.Random;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -73,7 +74,7 @@ record NumRange(int min, int max) {}
/// results.
/// So we make every algorithm maintain the sum in this variable instead, so that they all spend the
/// exact same time updating it.
int sum;
long sum;

@Setup(Level.Trial)
public void setup() {
Expand Down Expand Up @@ -102,23 +103,27 @@ public void tearDown() {}
@OperationsPerInvocation(INVOCATIONS)
public void pbj(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
outer:
for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) {
int value = 0;
final int limit = Math.max(0, Math.min(state.array.length, pos + 10) - pos);
long value = 0;

for (int i = 0; i < 10; i++) {
for (int i = 0; i < limit; i++) {
final byte b = state.array[pos++];
value |= (b & 0x7F) << (i * 7);
value |= (b & 0x7FL) << (i * 7);
if (b >= 0) {
state.sum += state.zigZag ? (value >>> 1) ^ -(value & 1) : value;
break;
continue outer;
}
}

throw new DataEncodingException("Malformed var int");
}
blackhole.consume(state.sum);
}

/// A variation of LEB128 with the zigZag conditional removed.
@Benchmark
// @Benchmark // disabled because it didn't show a significant improvement
@OperationsPerInvocation(INVOCATIONS)
public void pbj_zigZagFalse(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
Expand All @@ -138,7 +143,7 @@ public void pbj_zigZagFalse(final BenchState state, final Blackhole blackhole) {
}

/// A variation of LEB128 that uses `(b & 0x80) == 0` instead of `b >= 0`.
@Benchmark
// @Benchmark // disabled because it didn't show a significant improvement
@OperationsPerInvocation(INVOCATIONS)
public void pbj_BitwiseAndCondition(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
Expand All @@ -158,7 +163,7 @@ public void pbj_BitwiseAndCondition(final BenchState state, final Blackhole blac
}

/// A variation of LEB128 that uses do/while loop instead of a for loop to skip one branch for 1-byte varint.
@Benchmark
// @Benchmark // disabled because it didn't show a significant improvement
@OperationsPerInvocation(INVOCATIONS)
public void pbj_doWhileLoop(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
Expand All @@ -182,7 +187,7 @@ public void pbj_doWhileLoop(final BenchState state, final Blackhole blackhole) {
/// PBJ used to use a very similar algorithm, just before https://github.com/hashgraph/pbj/pull/144
/// where we switched to LEB128.
@SuppressWarnings("lossy-conversions") // the impl is able to support longs, but we ignore that here and use ints.
@Benchmark
// @Benchmark // disabled because it performs worse than the standard pbj implementation
@OperationsPerInvocation(INVOCATIONS)
public void google(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
Expand Down Expand Up @@ -263,7 +268,7 @@ public void google(final BenchState state, final Blackhole blackhole) {
}

/// A LEB128 with fully unrolled loop.
@Benchmark
// @Benchmark // disabled because the algorithm is missing limit checks
@OperationsPerInvocation(INVOCATIONS)
public void loopLess(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
Expand Down Expand Up @@ -305,6 +310,140 @@ public void loopLess(final BenchState state, final Blackhole blackhole) {
blackhole.consume(state.sum);
}

/// A LEB128 with fully unrolled loop and limit checks after each read.
// @Benchmark // disabled because it's replaced with the vector thing below.
@OperationsPerInvocation(INVOCATIONS)
public void loopLess_withLimitChecks(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) {
final int lim = Math.min(state.array.length, pos + 5);

byte b = state.array[pos++];
if ((b & 0x80) == 0) {
state.sum += b;
continue;
}
if (pos == lim) throw new DataEncodingException("Malformed var int");

int v = b & 0x7F;
b = state.array[pos++];
if ((b & 0x80) == 0) {
state.sum += v | b << 7;
continue;
}
if (pos == lim) throw new DataEncodingException("Malformed var int");

v |= (b & 0x7F) << 7;
b = state.array[pos++];
if ((b & 0x80) == 0) {
state.sum += v | b << 14;
continue;
}
if (pos == lim) throw new DataEncodingException("Malformed var int");

v |= (b & 0x7F) << 14;
b = state.array[pos++];
if ((b & 0x80) == 0) {
state.sum += v | b << 21;
continue;
}
if (pos == lim) throw new DataEncodingException("Malformed var int");

v |= (b & 0x7F) << 21;
b = state.array[pos++];
// if (b > 0) {
// Stop here because this benchmark doesn't support longs, only ints.
state.sum += v | b << 28;
// }
}
blackhole.consume(state.sum);
}

/// A vectorized LEB128, similar to the loopLess above, with some minor tweaks and supporting long varints.
@Benchmark
@OperationsPerInvocation(INVOCATIONS)
public void vector_zigZag(final BenchState state, final Blackhole blackhole) {
state.sum = 0;
for (int invocation = 0, pos = 0; invocation < INVOCATIONS; invocation++) {
final int limit = Math.min(state.array.length, pos + 10);
if (pos >= limit) throw new DataEncodingException("Malformed var int");

byte b;
long v = (b = state.array[pos++]) & 0x7F;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7F) << 7;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7F) << 14;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7F) << 21;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7FL) << 28;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7FL) << 35;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7FL) << 42;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7FL) << 49;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = state.array[pos++]) & 0x7FL) << 56;
if ((b & 0x80) == 0) {
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

b = state.array[pos++];
if ((b & 0x80) == 0) {
v |= (long) b << 63;
state.sum += state.zigZag ? (v >>> 1) ^ -(v & 1) : v;
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
continue;
}

throw new DataEncodingException("Malformed var int");
}
blackhole.consume(state.sum);
}

public static void main(String[] args) throws Exception {
Options opt = new OptionsBuilder()
.include(VarIntByteArrayReadBench.class.getSimpleName())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// SPDX-License-Identifier: Apache-2.0
package com.hedera.pbj.integration.test;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.hedera.pbj.runtime.io.DataEncodingException;
import com.hedera.pbj.runtime.io.buffer.BufferedData;
import java.util.Arrays;
import java.util.Random;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/// A test to verify the correctness of the new varint reading algorithm.
public class VectorVarIntTest {

/// A refactored copy from VarIntByteArrayReadBench.vector_zigZag.
private long readVarInt(byte[] bytes, int pos, boolean zigZag) {
final int limit = Math.min(bytes.length, pos + 10);
if (pos >= limit) throw new DataEncodingException("Malformed var int");

byte b;
long v = (b = bytes[pos++]) & 0x7F;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7F) << 7;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7F) << 14;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7F) << 21;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7FL) << 28;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7FL) << 35;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7FL) << 42;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7FL) << 49;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

v |= ((b = bytes[pos++]) & 0x7FL) << 56;
if ((b & 0x80) == 0) {
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}
if (pos >= limit) throw new DataEncodingException("Malformed var int");

b = bytes[pos++];
if ((b & 0x80) == 0) {
v |= (long) b << 63;
return zigZag ? (v >>> 1) ^ -(v & 1) : v;
}

throw new DataEncodingException("Malformed var int");
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testVectorVarInt(boolean zigZag) {
final byte[] bytes = new byte[64];
final BufferedData bd = BufferedData.wrap(bytes);
final Random random = new Random(457639854);

for (int i = 0; i < 10 * 1024 * 1024; i++) {
final int val = random.nextInt();
Arrays.fill(bytes, (byte) 0);
bd.writeVarInt(val, zigZag);

assertEquals(val, readVarInt(bytes, 0, zigZag));

bd.reset();
}
}
}
Loading