
Move save-time state off TrieBuilder.Node #15970

Merged

romseygeek merged 2 commits into apache:main from romseygeek:merge/improve-triebuilder-memory-usage
Apr 22, 2026

Conversation

@romseygeek
Contributor

TrieBuilder.Node objects are constructed during field merges, and last as long as the merge for a given field. They can use significant heap memory, particularly for fields with many long terms (e.g. fields that store UUIDs or URLs). Two of the member fields on Node are only used at the end of a merge, when the trie is being written, and so are effectively wasted memory during trie construction.

This commit moves these two member fields out of Node onto a separate SaveFrame class, which is only built when the trie is saved. This should reduce the amount of memory held by Nodes by around a quarter.

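A minimal sketch of the idea (the class and field names here are illustrative, not the actual Lucene code): the two save-time fields live on a short-lived frame object that exists only on the save stack, so they never occupy heap while the trie is being built.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Illustrative Node with only build-time state; fp and savedTo are gone.
class Node {
    final byte label;
    Node firstChild, next;
    Node(byte label) { this.label = label; }
}

// Hypothetical SaveFrame: built only during save() and discarded afterwards,
// so its two fields cost nothing during trie construction.
class SaveFrame {
    final Node node;
    long fp = -1;   // where this node was written; -1 = not yet saved
    Node savedTo;   // last child already saved
    SaveFrame(Node node) { this.node = node; }
}

public class SaveFrameDemo {
    // Post-order save walk driven by an explicit stack of SaveFrames,
    // mirroring how save-time state can live off the Node objects.
    static int countNodes(Node root) {
        int saved = 0;
        Deque<SaveFrame> stack = new ArrayDeque<>();
        stack.push(new SaveFrame(root));
        while (!stack.isEmpty()) {
            SaveFrame frame = stack.peek();
            Node nextChild =
                frame.savedTo == null ? frame.node.firstChild : frame.savedTo.next;
            if (nextChild != null) {
                // A child has not been saved yet: push it first.
                frame.savedTo = nextChild;
                stack.push(new SaveFrame(nextChild));
            } else {
                // All children saved: "write" the parent.
                frame.fp = saved++;  // stand-in for the real file pointer
                stack.pop();
            }
        }
        return saved;
    }

    public static void main(String[] args) {
        Node root = new Node((byte) 0);
        root.firstChild = new Node((byte) 'a');
        root.firstChild.next = new Node((byte) 'b');
        System.out.println(countNodes(root)); // 3 (children saved before parent)
    }
}
```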
@romseygeek
Contributor Author

We've seen a few out-of-memory crashes on Elasticsearch nodes running with small JVM heaps, with heap dumps consisting almost entirely of PendingBlock/TrieBuilder/Node hierarchies. This should help reduce memory pressure in these cases.

Claude also suggests reworking how Nodes are stored for cases where we have long shared prefixes (URL fields are particularly bad for this), but that's a much bigger change and I'm not quite ready to dive into how to do that yet :)

@github-actions github-actions Bot added this to the 10.5.0 milestone Apr 21, 2026
@gf2121
Contributor

gf2121 commented Apr 22, 2026

Thanks for opening this PR and looking into this! I apologize for not having optimized the TrieBuilder for better memory efficiency earlier. As the TODO in the JavaDoc points out, we definitely need to make it a much more memory-efficient structure.

Overall, I am fine with the idea of introducing a SaveFrame to shave off those 16 bytes (an 8-byte object reference + an 8-byte long fp) from each Node. However, I am a bit skeptical about whether this will resolve the OOM issues you are seeing with long fields like URLs. Even with this reduction, the sheer volume of Node object headers and structural references will still consume a massive amount of heap memory in those worst-case scenarios.
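As rough arithmetic under the same assumptions (8-byte references, i.e. no compressed oops), the saving scales linearly with node count; the node count below is purely illustrative:

```java
public class NodeSavings {
    public static void main(String[] args) {
        long nodes = 50_000_000L;   // hypothetical node count, for illustration only
        long savedPerNode = 8 + 8;  // 8-byte object reference + 8-byte long fp
        long savedMb = nodes * savedPerNode / (1024 * 1024);
        System.out.println(savedMb + " MB saved"); // 762 MB saved
    }
}
```

Real savings would depend on the actual node count and JVM settings, and the remaining per-node object headers and references still dominate in the worst cases described above.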

A couple of relatively straightforward optimization ideas come to mind that might yield more substantial memory savings:

  1. Path Compression (Radix Tree approach): Instead of storing a single int label for every single character, we could store a byte[] labels per Node. This would compress long, unbranched chains of single-child nodes into a single edge, which would drastically reduce the total number of allocated Node objects—especially for long, repetitive strings like URLs or UUIDs.

  2. Serialized merging (FST-style approach): Similar to how the FSTCompiler worked, we could potentially keep the structures serialized. When merging, we could append the already-serialized data term-by-term into a brand new TrieBuilder (or compiler). This would trade off some CPU overhead during merges for a significant reduction in peak heap memory usage, as we wouldn't need to keep the entire object graph resident in memory at once.

FWIW the original code with FST:

final ByteSequenceOutputs outputs = ByteSequenceOutputs.getSingleton();
final FSTCompiler<BytesRef> fstCompiler =
    new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, outputs)
        .suffixRAMLimitMB(0d)
        .bytesPageBits(pageBits)
        .build();

final byte[] bytes = scratchBytes.toArrayCopy();
assert bytes.length > 0;
fstCompiler.add(Util.toIntsRef(prefix, scratchIntsRef), new BytesRef(bytes, 0, bytes.length));
scratchBytes.reset();

for (PendingBlock block : blocks) {
  if (block.subIndices != null) {
    for (FST<BytesRef> subIndex : block.subIndices) {
      append(fstCompiler, subIndex, scratchIntsRef);
    }
    block.subIndices = null;
  }
}

index = fstCompiler.compile();

@romseygeek
Contributor Author

Yes, this is very much a stopgap and to be honest I'm not even sure if it will help our case as there's no real way of knowing if we crashed out at the beginning or the end of the process.

Path compression was the next step suggested by Claude so I can try and take a look at that, although this is not an area of the code that I know very well so it will need a lot more eyes on it than mine!

@gf2121
Contributor

gf2121 commented Apr 22, 2026

AI can easily generate the 'Path Compression' code; I just got a version that passes the tests, posting it here for reference :)

TrieBuilder
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.codecs.lucene103.blocktree;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.function.BiConsumer;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;

/**
 * A builder to build prefix tree (trie) as the index of block tree, and can be saved to disk.
 *
 * <p>TODO make this trie builder a more memory efficient structure.
 */
class TrieBuilder {

  static final int SIGN_NO_CHILDREN = 0x00;
  static final int SIGN_SINGLE_CHILD_WITH_OUTPUT = 0x01;
  static final int SIGN_SINGLE_CHILD_WITHOUT_OUTPUT = 0x02;
  static final int SIGN_MULTI_CHILDREN = 0x03;

  static final int LEAF_NODE_HAS_TERMS = 1 << 5;
  static final int LEAF_NODE_HAS_FLOOR = 1 << 6;
  static final long NON_LEAF_NODE_HAS_TERMS = 1L << 1;
  static final long NON_LEAF_NODE_HAS_FLOOR = 1L << 0;

  /**
   * The output describing the term block the prefix point to.
   *
   * @param fp the file pointer to the on-disk terms block which a trie node points to.
   * @param hasTerms false if this on-disk block consists entirely of pointers to child blocks.
   * @param floorData will be non-null when a large block of terms sharing a single trie prefix is
   *     split into multiple on-disk blocks.
   */
  record Output(long fp, boolean hasTerms, BytesRef floorData) {}

  private enum Status {
    BUILDING,
    SAVED,
    DESTROYED
  }

  private static class Node {

    // The utf8 bytes on the edge leading to this Node, empty for the root node
    private byte[] labels;
    // The output of this node.
    private Output output;
    // The number of children of this node.
    private int childrenNum;
    // Pointers to relative nodes
    private Node next;
    private Node firstChild;
    private Node lastChild;

    // Vars used during saving:

    // The file pointer to where the node is saved. -1 means the node has not been saved yet.
    private long fp = -1;
    // The latest child that has been saved. null means no child has been saved.
    private Node savedTo;

    Node(byte[] labels, Output output) {
      this.labels = labels;
      this.output = output;
    }
  }

  private final Node root = new Node(new byte[0], null);
  private final BytesRef minKey;
  private BytesRef maxKey;
  private Status status = Status.BUILDING;

  static TrieBuilder bytesRefToTrie(BytesRef k, Output v) {
    return new TrieBuilder(k, v);
  }

  private TrieBuilder(BytesRef k, Output v) {
    minKey = maxKey = BytesRef.deepCopyOf(k);
    if (k.length == 0) {
      root.output = v;
      return;
    }
    byte[] path = ArrayUtil.copyOfSubArray(k.bytes, k.offset, k.offset + k.length);
    Node node = new Node(path, v);
    root.firstChild = root.lastChild = node;
    root.childrenNum = 1;
  }

  private void splitNode(Node node, int prefixLen) {
    assert prefixLen > 0 && prefixLen < node.labels.length;
    byte[] prefix = ArrayUtil.copyOfSubArray(node.labels, 0, prefixLen);
    byte[] suffix = ArrayUtil.copyOfSubArray(node.labels, prefixLen, node.labels.length);

    Node child = new Node(suffix, node.output);
    child.childrenNum = node.childrenNum;
    child.firstChild = node.firstChild;
    child.lastChild = node.lastChild;

    node.labels = prefix;
    node.output = null;
    node.childrenNum = 1;
    node.firstChild = node.lastChild = child;
  }

  /**
   * Append all (K, V) pairs from the given trie into this one. The given trie builder must ensure
   * that its keys are greater than or equal to the max key of this one.
   *
   * <p>Note: the given trie will be destroyed after appending.
   */
  void append(TrieBuilder trieBuilder) {
    if (status != Status.BUILDING || trieBuilder.status != Status.BUILDING) {
      throw new IllegalStateException(
          "tries have wrong status, got this: " + status + ", append: " + trieBuilder.status);
    }
    assert this.maxKey.compareTo(trieBuilder.minKey) < 0;

    int mismatch =
        Arrays.mismatch(
            this.maxKey.bytes,
            this.maxKey.offset,
            this.maxKey.offset + this.maxKey.length,
            trieBuilder.minKey.bytes,
            trieBuilder.minKey.offset,
            trieBuilder.minKey.offset + trieBuilder.minKey.length);
    Node a = this.root;
    Node b = trieBuilder.root;

    int matched = 0;
    while (matched < mismatch) {
      final Node aLast = a.lastChild;
      final Node bFirst = b.firstChild;
      assert aLast.labels[0] == bFirst.labels[0];

      int edgeLen = Math.min(aLast.labels.length, bFirst.labels.length);
      int matchLen = Math.min(edgeLen, mismatch - matched);

      if (matchLen < aLast.labels.length) {
        splitNode(aLast, matchLen);
      }
      if (matchLen < bFirst.labels.length) {
        splitNode(bFirst, matchLen);
      }

      if (b.childrenNum > 1) {
        aLast.next = bFirst.next;
        a.childrenNum += b.childrenNum - 1;
        a.lastChild = b.lastChild;
        assert assertChildrenLabelInOrder(a);
      }

      a = aLast;
      b = bFirst;
      matched += matchLen;
    }

    assert b.childrenNum > 0;
    if (a.childrenNum == 0) {
      a.firstChild = b.firstChild;
      a.lastChild = b.lastChild;
      a.childrenNum = b.childrenNum;
    } else {
      assert (a.lastChild.labels[0] & 0xFF) < (b.firstChild.labels[0] & 0xFF);
      a.lastChild.next = b.firstChild;
      a.lastChild = b.lastChild;
      a.childrenNum += b.childrenNum;
    }
    assert assertChildrenLabelInOrder(a);

    this.maxKey = trieBuilder.maxKey;
    trieBuilder.status = Status.DESTROYED;
  }

  Output getEmptyOutput() {
    return root.output;
  }

  /**
   * Used for tests only. The recursive impl needs to be avoided if someone plans to use this in
   * production one day.
   */
  void visit(BiConsumer<BytesRef, Output> consumer) {
    assert status == Status.BUILDING;
    if (root.output != null) {
      consumer.accept(new BytesRef(), root.output);
    }
    visit(root.firstChild, new BytesRefBuilder(), consumer);
  }

  private void visit(Node first, BytesRefBuilder key, BiConsumer<BytesRef, Output> consumer) {
    while (first != null) {
      int len = first.labels.length;
      key.append(first.labels, 0, len);
      if (first.output != null) {
        consumer.accept(key.toBytesRef(), first.output);
      }
      visit(first.firstChild, key, consumer);
      key.setLength(key.length() - len);
      first = first.next;
    }
  }

  void save(DataOutput meta, IndexOutput index) throws IOException {
    if (status != Status.BUILDING) {
      throw new IllegalStateException("only unsaved trie can be saved, got: " + status);
    }
    meta.writeVLong(index.getFilePointer());
    saveNodes(index);
    meta.writeVLong(root.fp);
    index.writeLong(0L); // additional 8 bytes for over-reading
    meta.writeVLong(index.getFilePointer());
    status = Status.SAVED;
  }

  void saveNodes(IndexOutput index) throws IOException {
    final long startFP = index.getFilePointer();
    Deque<Node> stack = new ArrayDeque<>();
    stack.push(root);

    // Visit and save nodes of this trie in a post-order depth-first traversal.
    while (stack.isEmpty() == false) {
      Node node = stack.peek();
      assert node.fp == -1;
      assert assertChildrenLabelInOrder(node);

      final int childrenNum = node.childrenNum;

      if (childrenNum == 0) { // leaf node
        assert node.output != null : "leaf nodes should have output.";

        long bottomFp = index.getFilePointer() - startFP;

        // [n bytes] floor data
        // [n bytes] output fp
        // [1bit] x | [1bit] has floor | [1bit] has terms | [3bit] output fp bytes | [2bit] sign

        Output output = node.output;
        int outputFpBytes = bytesRequiredVLong(output.fp);
        int header =
            SIGN_NO_CHILDREN
                | ((outputFpBytes - 1) << 2)
                | (output.hasTerms ? LEAF_NODE_HAS_TERMS : 0)
                | (output.floorData != null ? LEAF_NODE_HAS_FLOOR : 0);
        index.writeByte(((byte) header));
        writeLongNBytes(output.fp, outputFpBytes, index);
        if (output.floorData != null) {
          index.writeBytes(
              output.floorData.bytes, output.floorData.offset, output.floorData.length);
        }

        node.fp = writeCompressedPathChain(node.labels, bottomFp, startFP, index);
        stack.pop();
        continue;
      }

      // If any children have not been saved yet, push the next one onto the stack and continue.
      // We want to ensure children are saved before their parent.

      if (node.savedTo == null) {
        node.savedTo = node.firstChild;
        stack.push(node.savedTo);
        continue;
      }
      if (node.savedTo.next != null) {
        assert node.savedTo.fp >= 0;
        node.savedTo = node.savedTo.next;
        stack.push(node.savedTo);
        continue;
      }

      // All children have been written, now it's time to write the parent!

      assert assertNonLeafNodePreparingSaving(node);
      long bottomFp = index.getFilePointer() - startFP;

      if (childrenNum == 1) {

        // [n bytes] floor data
        // [n bytes] encoded output fp | [n bytes] child fp | [1 byte] label
        // [3bit] encoded output fp bytes | [3bit] child fp bytes | [2bit] sign

        long childDeltaFp = bottomFp - node.firstChild.fp;
        assert childDeltaFp > 0 : "parent node is always written after children: " + childDeltaFp;
        int childFpBytes = bytesRequiredVLong(childDeltaFp);
        int encodedOutputFpBytes =
            node.output == null ? 0 : bytesRequiredVLong(node.output.fp << 2);

        // TODO if we have only one child and no output, we can store child labels in this node.
        // E.g. for a single term trie [foobar], we can save only two nodes [fooba] and [r]

        int sign =
            node.output != null ? SIGN_SINGLE_CHILD_WITH_OUTPUT : SIGN_SINGLE_CHILD_WITHOUT_OUTPUT;
        int header = sign | ((childFpBytes - 1) << 2) | ((encodedOutputFpBytes - 1) << 5);
        index.writeByte((byte) header);
        index.writeByte(node.firstChild.labels[0]);
        writeLongNBytes(childDeltaFp, childFpBytes, index);

        if (node.output != null) {
          Output output = node.output;
          long encodedFp = encodeFP(output);
          writeLongNBytes(encodedFp, encodedOutputFpBytes, index);
          if (output.floorData != null) {
            index.writeBytes(
                output.floorData.bytes, output.floorData.offset, output.floorData.length);
          }
        }
      } else {

        // [n bytes] floor data
        // [n bytes] children fps | [n bytes] strategy data
        // [1 byte] children count (if floor data) | [n bytes] encoded output fp | [1 byte] label
        // [5bit] strategy bytes | 2bit children strategy | [3bit] encoded output fp bytes
        // [1bit] has output | [3bit] children fp bytes | [2bit] sign

        final int minLabel = node.firstChild.labels[0] & 0xFF;
        final int maxLabel = node.lastChild.labels[0] & 0xFF;
        assert maxLabel > minLabel;
        ChildSaveStrategy childSaveStrategy =
            ChildSaveStrategy.choose(minLabel, maxLabel, childrenNum);
        int strategyBytes = childSaveStrategy.needBytes(minLabel, maxLabel, childrenNum);
        assert strategyBytes > 0 && strategyBytes <= 32;

        // children fps are in order, so the first child's fp is min, then delta is max.
        long maxChildDeltaFp = bottomFp - node.firstChild.fp;
        assert maxChildDeltaFp > 0 : "parent always written after all children";

        int childrenFpBytes = bytesRequiredVLong(maxChildDeltaFp);
        int encodedOutputFpBytes =
            node.output == null ? 1 : bytesRequiredVLong(node.output.fp << 2);
        int header =
            SIGN_MULTI_CHILDREN
                | ((childrenFpBytes - 1) << 2)
                | ((node.output != null ? 1 : 0) << 5)
                | ((encodedOutputFpBytes - 1) << 6)
                | (childSaveStrategy.code << 9)
                | ((strategyBytes - 1) << 11)
                | (minLabel << 16);

        writeLongNBytes(header, 3, index);

        if (node.output != null) {
          Output output = node.output;
          long encodedFp = encodeFP(output);
          writeLongNBytes(encodedFp, encodedOutputFpBytes, index);
          if (output.floorData != null) {
            // We need this childrenNum to compute where the floor data start.
            index.writeByte((byte) (childrenNum - 1));
          }
        }

        long strategyStartFp = index.getFilePointer();
        childSaveStrategy.save(node, childrenNum, strategyBytes, index);
        assert index.getFilePointer() == strategyStartFp + strategyBytes
            : childSaveStrategy.name()
                + " strategy bytes compute error, computed: "
                + strategyBytes
                + " actual: "
                + (index.getFilePointer() - strategyStartFp);

        for (Node child = node.firstChild; child != null; child = child.next) {
          assert bottomFp > child.fp : "parent always written after all children";
          writeLongNBytes(bottomFp - child.fp, childrenFpBytes, index);
        }

        if (node.output != null && node.output.floorData != null) {
          BytesRef floorData = node.output.floorData;
          index.writeBytes(floorData.bytes, floorData.offset, floorData.length);
        }
      }

      node.fp = writeCompressedPathChain(node.labels, bottomFp, startFP, index);
      stack.pop();
    }
  }

  private long writeCompressedPathChain(
      byte[] labels, long bottomFp, long startFP, IndexOutput index) throws IOException {
    long currentChildFp = bottomFp;
    for (int i = labels.length - 2; i >= 0; i--) {
      long currentFp = index.getFilePointer() - startFP;
      long childDeltaFp = currentFp - currentChildFp;
      int childFpBytes = bytesRequiredVLong(childDeltaFp);

      int sign = SIGN_SINGLE_CHILD_WITHOUT_OUTPUT;
      int header = sign | ((childFpBytes - 1) << 2);
      index.writeByte((byte) header);
      index.writeByte(labels[i + 1]);
      writeLongNBytes(childDeltaFp, childFpBytes, index);

      currentChildFp = currentFp;
    }
    return currentChildFp;
  }

  private long encodeFP(Output output) {
    assert output.fp < 1L << 62;
    return (output.floorData != null ? NON_LEAF_NODE_HAS_FLOOR : 0)
        | (output.hasTerms ? NON_LEAF_NODE_HAS_TERMS : 0)
        | (output.fp << 2);
  }

  private static int bytesRequiredVLong(long v) {
    return Long.BYTES - (Long.numberOfLeadingZeros(v | 1) >>> 3);
  }

  /**
   * Write the first (LSB order) n bytes of the given long v into the DataOutput.
   *
   * <p>This differs from writeVLong because it can write more bytes than would be needed for vLong
   * when the incoming int n is larger.
   */
  private static void writeLongNBytes(long v, int n, DataOutput out) throws IOException {
    for (int i = 0; i < n; i++) {
      // Note that we sometimes write trailing 0 bytes here, when the incoming int n is bigger than
      // would be required for a "normal" vLong
      out.writeByte((byte) v);
      v >>>= 8;
    }
    assert v == 0;
  }

  private static boolean assertChildrenLabelInOrder(Node node) {
    if (node.childrenNum == 0) {
      assert node.firstChild == null;
      assert node.lastChild == null;
    } else if (node.childrenNum == 1) {
      assert node.firstChild == node.lastChild;
      assert node.firstChild.next == null;
    } else if (node.childrenNum > 1) {
      int n = 0;
      for (Node child = node.firstChild; child != null; child = child.next) {
        n++;
        assert child.next == null || (child.labels[0] & 0xFF) < (child.next.labels[0] & 0xFF)
            : " the label of children nodes should always be in strictly increasing order.";
      }
      assert node.childrenNum == n;
    }
    return true;
  }

  private static boolean assertNonLeafNodePreparingSaving(Node node) {
    assert assertChildrenLabelInOrder(node);
    assert node.childrenNum != 0;
    if (node.childrenNum == 1) {
      assert node.firstChild == node.lastChild;
      assert node.firstChild.next == null;
      assert node.savedTo == node.firstChild;
      assert node.firstChild.fp >= 0;
    } else {
      int n = 0;
      for (Node child = node.firstChild; child != null; child = child.next) {
        n++;
        assert child.fp >= 0;
        assert child.next == null || child.fp < child.next.fp
            : " the fps of children nodes should always be in order.";
      }
      assert node.childrenNum == n;
      assert node.lastChild == node.savedTo;
      assert node.savedTo.next == null;
    }
    return true;
  }

  enum ChildSaveStrategy {

    /**
     * Store children labels in a bitset. This is likely the most efficient storage, as we can
     * compute the position with a bitCount instruction, so we give it the highest priority.
     */
    BITS(2) {
      @Override
      int needBytes(int minLabel, int maxLabel, int labelCnt) {
        int byteDistance = maxLabel - minLabel + 1;
        return (byteDistance + 7) >>> 3;
      }

      @Override
      void save(Node parent, int labelCnt, int strategyBytes, IndexOutput output)
          throws IOException {
        byte presenceBits = 1; // The first arc is always present.
        int presenceIndex = 0;
        int previousLabel = parent.firstChild.labels[0] & 0xFF;
        for (Node child = parent.firstChild.next; child != null; child = child.next) {
          int label = child.labels[0] & 0xFF;
          assert label > previousLabel;
          presenceIndex += label - previousLabel;
          while (presenceIndex >= Byte.SIZE) {
            output.writeByte(presenceBits);
            presenceBits = 0;
            presenceIndex -= Byte.SIZE;
          }
          // Set the bit at presenceIndex to flag that the corresponding arc is present.
          presenceBits |= 1 << presenceIndex;
          previousLabel = label;
        }
        assert presenceIndex
            == ((parent.lastChild.labels[0] & 0xFF) - (parent.firstChild.labels[0] & 0xFF)) % 8;
        assert presenceBits != 0; // The last byte is not 0.
        assert (presenceBits & (1 << presenceIndex)) != 0; // The last arc is always present.
        output.writeByte(presenceBits);
      }

      @Override
      int lookup(
          int targetLabel, RandomAccessInput in, long offset, int strategyBytes, int minLabel)
          throws IOException {
        int bitIndex = targetLabel - minLabel;
        if (bitIndex >= (strategyBytes << 3)) {
          return -1;
        }
        int wordIndex = bitIndex >>> 6;
        long wordFp = offset + (wordIndex << 3);
        long word = in.readLong(wordFp);
        long mask = 1L << bitIndex;
        if ((word & mask) == 0) {
          return -1;
        }
        int pos = 0;
        for (long fp = offset; fp < wordFp; fp += 8L) {
          pos += Long.bitCount(in.readLong(fp));
        }
        pos += Long.bitCount(word & (mask - 1));
        return pos;
      }
    },

    /**
     * Store labels in an array and lookup with binary search.
     *
     * <p>TODO: Can we use VectorAPI to speed up the lookup? we can check 64 labels once on AVX512!
     */
    ARRAY(1) {
      @Override
      int needBytes(int minLabel, int maxLabel, int labelCnt) {
        return labelCnt - 1; // min label saved
      }

      @Override
      void save(Node parent, int labelCnt, int strategyBytes, IndexOutput output)
          throws IOException {
        for (Node child = parent.firstChild.next; child != null; child = child.next) {
          output.writeByte(child.labels[0]);
        }
      }

      @Override
      int lookup(
          int targetLabel, RandomAccessInput in, long offset, int strategyBytes, int minLabel)
          throws IOException {
        int low = 0;
        int high = strategyBytes - 1;
        while (low <= high) {
          int mid = (low + high) >>> 1;
          int midLabel = in.readByte(offset + mid) & 0xFF;
          if (midLabel < targetLabel) {
            low = mid + 1;
          } else if (midLabel > targetLabel) {
            high = mid - 1;
          } else {
            return mid + 1; // min label not included, plus 1
          }
        }
        return -1;
      }
    },

    /**
     * Store labels that do not exist within the range. E.g. store 10 (max label) and 3, 5 (absent
     * labels) for [1, 2, 4, 6, 7, 8, 9, 10].
     *
     * <p>TODO: Can we use VectorAPI to speed up the lookup? we can check 64 labels once on AVX512!
     */
    REVERSE_ARRAY(0) {

      @Override
      int needBytes(int minLabel, int maxLabel, int labelCnt) {
        int byteDistance = maxLabel - minLabel + 1;
        return byteDistance - labelCnt + 1;
      }

      @Override
      void save(Node parent, int labelCnt, int strategyBytes, IndexOutput output)
          throws IOException {
        output.writeByte(parent.lastChild.labels[0]);
        int lastLabel = parent.firstChild.labels[0] & 0xFF;
        for (Node child = parent.firstChild.next; child != null; child = child.next) {
          while (++lastLabel < (child.labels[0] & 0xFF)) {
            output.writeByte((byte) lastLabel);
          }
        }
      }

      @Override
      int lookup(
          int targetLabel, RandomAccessInput in, long offset, int strategyBytes, int minLabel)
          throws IOException {
        int maxLabel = in.readByte(offset++) & 0xFF;
        if (targetLabel >= maxLabel) {
          return targetLabel == maxLabel ? maxLabel - minLabel - strategyBytes + 1 : -1;
        }
        if (strategyBytes == 1) {
          return targetLabel - minLabel;
        }

        int low = 0;
        int high = strategyBytes - 2;
        while (low <= high) {
          int mid = (low + high) >>> 1;
          int midLabel = in.readByte(offset + mid) & 0xFF;
          if (midLabel < targetLabel) {
            low = mid + 1;
          } else if (midLabel > targetLabel) {
            high = mid - 1;
          } else {
            return -1;
          }
        }
        return targetLabel - minLabel - low;
      }
    };

    private static final ChildSaveStrategy[] STRATEGIES_IN_PRIORITY_ORDER =
        new ChildSaveStrategy[] {BITS, ARRAY, REVERSE_ARRAY};
    private static final ChildSaveStrategy[] STRATEGIES_BY_CODE;

    static {
      STRATEGIES_BY_CODE = new ChildSaveStrategy[ChildSaveStrategy.values().length];
      for (ChildSaveStrategy strategy : ChildSaveStrategy.values()) {
        assert STRATEGIES_BY_CODE[strategy.code] == null;
        STRATEGIES_BY_CODE[strategy.code] = strategy;
      }
    }

    final int code;

    ChildSaveStrategy(int code) {
      this.code = code;
    }

    abstract int needBytes(int minLabel, int maxLabel, int labelCnt);

    abstract void save(Node parent, int labelCnt, int strategyBytes, IndexOutput output)
        throws IOException;

    abstract int lookup(
        int targetLabel, RandomAccessInput in, long offset, int strategyBytes, int minLabel)
        throws IOException;

    static ChildSaveStrategy byCode(int code) {
      return STRATEGIES_BY_CODE[code];
    }

    static ChildSaveStrategy choose(int minLabel, int maxLabel, int labelCnt) {
      ChildSaveStrategy childSaveStrategy = null;
      int strategyBytes = Integer.MAX_VALUE;
      for (ChildSaveStrategy strategy : ChildSaveStrategy.STRATEGIES_IN_PRIORITY_ORDER) {
        int strategyCost = strategy.needBytes(minLabel, maxLabel, labelCnt);
        if (strategyCost < strategyBytes) {
          childSaveStrategy = strategy;
          strategyBytes = strategyCost;
        }
      }
      assert childSaveStrategy != null;
      assert strategyBytes > 0 && strategyBytes <= 32;
      return childSaveStrategy;
    }
  }
}
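The `bytesRequiredVLong` bit-trick in the code above can be sanity-checked in isolation; this is a standalone recreation for illustration, not part of the patch:

```java
public class ByteWidth {
    // Number of bytes needed to hold v in LSB-first order (minimum 1),
    // matching the bytesRequiredVLong helper quoted above.
    static int bytesRequired(long v) {
        // (v | 1) guarantees at least one significant bit, so the result is >= 1.
        return Long.BYTES - (Long.numberOfLeadingZeros(v | 1) >>> 3);
    }

    public static void main(String[] args) {
        System.out.println(bytesRequired(0L));   // 1
        System.out.println(bytesRequired(255L)); // 1
        System.out.println(bytesRequired(256L)); // 2
        System.out.println(bytesRequired(-1L));  // 8 (all 64 bits set)
    }
}
```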

Contributor

@gf2121 gf2121 left a comment


LGTM

@romseygeek
Contributor Author

Nice! I'll merge this, do you want to open a PR with the path compression changes and I'll take a look?

@romseygeek romseygeek merged commit 198cc93 into apache:main Apr 22, 2026
13 checks passed
romseygeek added a commit that referenced this pull request Apr 22, 2026
@gf2121
Contributor

gf2121 commented Apr 22, 2026

> Nice! I'll merge this, do you want to open a PR with the path compression changes and I'll take a look?

I thought about this a bit more, and I realized that we do not really need the Tree structure on the heap at all. We can just use an efficient prefix encoding and delay the building of the trie until the save phase. I opened #15977.
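A minimal sketch of that direction (front-coding of sorted keys; this is illustrative only, not the code in the linked PR): each key stores just a shared-prefix length and its suffix, so long shared prefixes like URLs cost almost nothing per key.

```java
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.List;

public class FrontCoding {
    // Encode sorted keys as (sharedPrefixLen, suffixLen, suffixBytes) records.
    // For simplicity this assumes every length fits in one byte (< 128).
    static byte[] encode(List<byte[]> sortedKeys) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] prev = new byte[0];
        for (byte[] key : sortedKeys) {
            int shared = Arrays.mismatch(prev, key);
            if (shared < 0) shared = key.length; // identical to the previous key
            out.write(shared);                   // shared-prefix length
            out.write(key.length - shared);      // suffix length
            out.write(key, shared, key.length - shared);
            prev = key;
        }
        return out.toByteArray();
    }

    public static void main(String[] args) {
        byte[] encoded = encode(List.of(
            "http://example.com/a".getBytes(),
            "http://example.com/b".getBytes()));
        // First key stores all 20 bytes; the second only its 1-byte suffix.
        System.out.println(encoded.length); // (2 + 20) + (2 + 1) = 25
    }
}
```

The trie itself would then be rebuilt from these records during the save phase, keeping peak heap proportional to the encoded bytes rather than the node object graph.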

