diff --git a/Directory.Packages.props b/Directory.Packages.props index a33ea0a1fec..c7f8278c249 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -32,5 +32,6 @@ + \ No newline at end of file diff --git a/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs b/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs index 7662b533f83..9ac7428ef40 100644 --- a/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs +++ b/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs @@ -25,6 +25,7 @@ public sealed unsafe partial class GarnetClientSession : IServerHook, IMessageCo static ReadOnlySpan MAIN_STORE => "SSTORE"u8; static ReadOnlySpan OBJECT_STORE => "OSTORE"u8; + static ReadOnlySpan VECTOR_STORE => "VSTORE"u8; static ReadOnlySpan T => "T"u8; static ReadOnlySpan F => "F"u8; @@ -170,14 +171,30 @@ public Task SetSlotRange(Memory state, string nodeid, List<(int, i /// /// /// - public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMainStore) + public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMainStore, bool isVectorSets) { currTcsIterationTask = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); tcsQueue.Enqueue(currTcsIterationTask); curr = offset; this.isMainStore = isMainStore; this.ist = IncrementalSendType.MIGRATE; - var storeType = isMainStore ? MAIN_STORE : OBJECT_STORE; + ReadOnlySpan storeType; + if (isMainStore) + { + if (isVectorSets) + { + storeType = VECTOR_STORE; + } + else + { + storeType = MAIN_STORE; + } + } + else + { + storeType = OBJECT_STORE; + } + var replaceOption = replace ? 
T : F; var arraySize = 6; @@ -249,7 +266,7 @@ public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMa /// public Task CompleteMigrate(string sourceNodeId, bool replace, bool isMainStore) { - SetClusterMigrateHeader(sourceNodeId, replace, isMainStore); + SetClusterMigrateHeader(sourceNodeId, replace, isMainStore, isVectorSets: false); Debug.Assert(end - curr >= 2); *curr++ = (byte)'\r'; diff --git a/libs/cluster/Server/ClusterManager.cs b/libs/cluster/Server/ClusterManager.cs index 68f86b9171f..a6e1f026773 100644 --- a/libs/cluster/Server/ClusterManager.cs +++ b/libs/cluster/Server/ClusterManager.cs @@ -240,22 +240,27 @@ public string GetInfo() public static string GetRange(int[] slots) { var range = "> "; - var start = slots[0]; - var end = slots[0]; - for (var i = 1; i < slots.Length + 1; i++) + if (slots.Length >= 1) { - if (i < slots.Length && slots[i] == end + 1) - end = slots[i]; - else + + var start = slots[0]; + var end = slots[0]; + for (var i = 1; i < slots.Length + 1; i++) { - range += $"{start}-{end} "; - if (i < slots.Length) - { - start = slots[i]; + if (i < slots.Length && slots[i] == end + 1) end = slots[i]; + else + { + range += $"{start}-{end} "; + if (i < slots.Length) + { + start = slots[i]; + end = slots[i]; + } } } } + return range; } diff --git a/libs/cluster/Server/ClusterManagerSlotState.cs b/libs/cluster/Server/ClusterManagerSlotState.cs index a35e474a263..0ef36402b84 100644 --- a/libs/cluster/Server/ClusterManagerSlotState.cs +++ b/libs/cluster/Server/ClusterManagerSlotState.cs @@ -17,7 +17,10 @@ namespace Garnet.cluster SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; /// /// Cluster manager diff --git a/libs/cluster/Server/ClusterManagerWorkerState.cs b/libs/cluster/Server/ClusterManagerWorkerState.cs index 7f9b23599ba..0775debb7b2 100644 --- a/libs/cluster/Server/ClusterManagerWorkerState.cs +++ 
b/libs/cluster/Server/ClusterManagerWorkerState.cs @@ -222,6 +222,10 @@ public bool TryAddReplica(string nodeid, bool force, bool upgradeLock, out ReadO clusterProvider.replicationManager.EndRecovery(RecoveryStatus.NoRecovery, downgradeLock: false); } } + + clusterProvider.storeWrapper.SuspendPrimaryOnlyTasks().Wait(); + clusterProvider.storeWrapper.StartReplicaTasks(); + FlushConfig(); return true; } diff --git a/libs/cluster/Server/ClusterProvider.cs b/libs/cluster/Server/ClusterProvider.cs index 51ac87401f0..9af5cf6a02e 100644 --- a/libs/cluster/Server/ClusterProvider.cs +++ b/libs/cluster/Server/ClusterProvider.cs @@ -15,12 +15,21 @@ namespace Garnet.cluster { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; /// /// Cluster provider @@ -100,8 +109,8 @@ public void Start() } /// - public IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null) - => new ClusterSession(this, txnManager, authenticator, userHandle, garnetSessionMetrics, basicGarnetApi, networkSender, logger); + public IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, BasicContext basicContext, VectorContext vectorContext, INetworkSender networkSender, ILogger logger = null) + => new ClusterSession(this, txnManager, authenticator, userHandle, garnetSessionMetrics, basicGarnetApi, basicContext, vectorContext, networkSender, logger); /// public void UpdateClusterAuth(string clusterUsername, string clusterPassword) diff --git 
a/libs/cluster/Server/Failover/ReplicaFailoverSession.cs b/libs/cluster/Server/Failover/ReplicaFailoverSession.cs index 62ffea2f996..8feebf8d958 100644 --- a/libs/cluster/Server/Failover/ReplicaFailoverSession.cs +++ b/libs/cluster/Server/Failover/ReplicaFailoverSession.cs @@ -307,6 +307,9 @@ public async Task BeginAsyncReplicaFailover() // Attach to old replicas, and old primary if DEFAULT option await IssueAttachReplicas(); + await clusterProvider.storeWrapper.SuspendReplicaOnlyTasks(); + clusterProvider.storeWrapper.StartPrimaryTasks(); + return true; } catch (Exception ex) diff --git a/libs/cluster/Server/Migration/MigrateOperation.cs b/libs/cluster/Server/Migration/MigrateOperation.cs index d4f069a8189..3f677c959ee 100644 --- a/libs/cluster/Server/Migration/MigrateOperation.cs +++ b/libs/cluster/Server/Migration/MigrateOperation.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using Garnet.client; using Garnet.server; +using Microsoft.Extensions.Logging; using Tsavorite.core; namespace Garnet.cluster @@ -18,16 +20,25 @@ internal sealed partial class MigrateOperation public MainStoreScan mss; public ObjectStoreScan oss; + private readonly ConcurrentDictionary vectorSetsIndexKeysToMigrate; + readonly MigrateSession session; readonly GarnetClientSession gcs; readonly LocalServerSession localServerSession; public GarnetClientSession Client => gcs; + public IEnumerable> VectorSets => vectorSetsIndexKeysToMigrate; + public void ThrowIfCancelled() => session._cts.Token.ThrowIfCancellationRequested(); public bool Contains(int slot) => session._sslots.Contains(slot); + public bool ContainsNamespace(ulong ns) => session._namespaces?.Contains(ns) ?? 
false; + + public void EncounteredVectorSet(byte[] key, byte[] value) + => vectorSetsIndexKeysToMigrate.TryAdd(key, value); + public MigrateOperation(MigrateSession session, Sketch sketch = null, int batchSize = 1 << 18) { this.session = session; @@ -37,6 +48,7 @@ public MigrateOperation(MigrateSession session, Sketch sketch = null, int batchS mss = new MainStoreScan(this); oss = new ObjectStoreScan(this); keysToDelete = []; + vectorSetsIndexKeysToMigrate = new(ByteArrayComparer.Instance); } public bool Initialize() @@ -72,7 +84,7 @@ public void Scan(StoreType storeType, ref long currentAddress, long endAddress) /// /// /// - public bool TrasmitSlots(StoreType storeType) + public bool TransmitSlots(StoreType storeType) { var bufferSize = 1 << 10; SectorAlignedMemory buffer = new(bufferSize, 1); @@ -87,7 +99,7 @@ public bool TrasmitSlots(StoreType storeType) { foreach (var key in sketch.argSliceVector) { - var spanByte = key.SpanByte; + var spanByte = key; if (!session.WriteOrSendMainStoreKeyValuePair(gcs, localServerSession, ref spanByte, ref input, ref o, out _)) return false; @@ -117,7 +129,10 @@ public bool TrasmitSlots(StoreType storeType) return true; } - public bool TransmitKeys(StoreType storeType) + /// + /// Move keys in sketch out of the given store, UNLESS they are also in . 
+ /// + public bool TransmitKeys(StoreType storeType, Dictionary vectorSetKeysToIgnore) { var bufferSize = 1 << 10; SectorAlignedMemory buffer = new(bufferSize, 1); @@ -131,12 +146,30 @@ public bool TransmitKeys(StoreType storeType) var keys = sketch.Keys; if (storeType == StoreType.Main) { +#if NET9_0_OR_GREATER + var ignoreLookup = vectorSetKeysToIgnore.GetAlternateLookup>(); +#endif + for (var i = 0; i < keys.Count; i++) { if (keys[i].Item2) continue; var spanByte = keys[i].Item1.SpanByte; + + // Don't transmit if a Vector Set + var isVectorSet = + vectorSetKeysToIgnore.Count > 0 && +#if NET9_0_OR_GREATER + ignoreLookup.ContainsKey(spanByte.AsReadOnlySpan()); +#else + vectorSetKeysToIgnore.ContainsKey(spanByte.ToByteArray()); +#endif + if (isVectorSet) + { + continue; + } + if (!session.WriteOrSendMainStoreKeyValuePair(gcs, localServerSession, ref spanByte, ref input, ref o, out var status)) return false; @@ -158,8 +191,8 @@ public bool TransmitKeys(StoreType storeType) if (keys[i].Item2) continue; - var argSlice = keys[i].Item1; - if (!session.WriteOrSendObjectStoreKeyValuePair(gcs, localServerSession, ref argSlice, out var status)) + var spanByte = keys[i].Item1.SpanByte; + if (!session.WriteOrSendObjectStoreKeyValuePair(gcs, localServerSession, ref spanByte, out var status)) return false; // Skip if key NOTFOUND @@ -182,6 +215,54 @@ public bool TransmitKeys(StoreType storeType) return true; } + /// + /// Transmit data in namespaces during a MIGRATE ... KEYS operation. + /// + /// Doesn't delete anything, just scans and transmits. 
+ /// + public bool TransmitKeysNamespaces(ILogger logger) + { + var migrateOperation = this; + + if (!migrateOperation.Initialize()) + return false; + + var workerStartAddress = migrateOperation.session.clusterProvider.storeWrapper.store.Log.BeginAddress; + var workerEndAddress = migrateOperation.session.clusterProvider.storeWrapper.store.Log.TailAddress; + + var cursor = workerStartAddress; + logger?.LogWarning(" migrate keys (namespaces) scan range [{workerStartAddress}, {workerEndAddress}]", workerStartAddress, workerEndAddress); + while (true) + { + var current = cursor; + // Build Sketch + migrateOperation.sketch.SetStatus(SketchStatus.INITIALIZING); + migrateOperation.Scan(StoreType.Main, ref current, workerEndAddress); + + // Stop if no keys have been found + if (migrateOperation.sketch.argSliceVector.IsEmpty) break; + + logger?.LogWarning("Scan from {cursor} to {current} and discovered {count} keys", cursor, current, migrateOperation.sketch.argSliceVector.Count); + + // Transition EPSM to MIGRATING + migrateOperation.sketch.SetStatus(SketchStatus.TRANSMITTING); + migrateOperation.session.WaitForConfigPropagation(); + + // Transmit all keys gathered + migrateOperation.TransmitSlots(StoreType.Main); + + // Transition EPSM to DELETING + migrateOperation.sketch.SetStatus(SketchStatus.DELETING); + migrateOperation.session.WaitForConfigPropagation(); + + // Clear keys from buffer + migrateOperation.sketch.Clear(); + cursor = current; + } + + return true; + } + /// /// Delete keys after migration if copyOption is not set /// @@ -193,7 +274,13 @@ public void DeleteKeys() { foreach (var key in sketch.argSliceVector) { - var spanByte = key.SpanByte; + if (key.MetadataSize == 1) + { + // Namespace'd keys are not deleted here, but when migration finishes + continue; + } + + var spanByte = key; _ = localServerSession.BasicGarnetApi.DELETE(ref spanByte); } } @@ -209,6 +296,19 @@ public void DeleteKeys() } } } + + /// + /// Delete a Vector Set after migration if 
_copyOption is not set. + /// + public void DeleteVectorSet(ref SpanByte key) + { + if (session._copyOption) + return; + + var delRes = localServerSession.BasicGarnetApi.DELETE(ref key); + + session.logger?.LogDebug("Deleting Vector Set {key} after migration: {delRes}", System.Text.Encoding.UTF8.GetString(key.AsReadOnlySpan()), delRes); + } } } } \ No newline at end of file diff --git a/libs/cluster/Server/Migration/MigrateScanFunctions.cs b/libs/cluster/Server/Migration/MigrateScanFunctions.cs index 03cb23d1af8..25d9f5da3d3 100644 --- a/libs/cluster/Server/Migration/MigrateScanFunctions.cs +++ b/libs/cluster/Server/Migration/MigrateScanFunctions.cs @@ -36,10 +36,34 @@ public unsafe bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMeta if (ClusterSession.Expired(ref value)) return true; - var s = HashSlotUtils.HashSlot(ref key); - // Check if key belongs to slot that is being migrated and if it can be added to our buffer - if (mss.Contains(s) && !mss.sketch.TryHashAndStore(key.AsSpan())) - return false; + // TODO: Some other way to detect namespaces + if (key.MetadataSize == 1) + { + var ns = key.GetNamespaceInPayload(); + + if (mss.ContainsNamespace(ns) && !mss.sketch.TryHashAndStore(ns, key.AsSpan())) + return false; + } + else + { + var s = HashSlotUtils.HashSlot(ref key); + + // Check if key belongs to slot that is being migrated... 
+ if (mss.Contains(s)) + { + if (recordMetadata.RecordInfo.VectorSet) + { + // We can't delete the vector set _yet_ nor can we migrate it, + // we just need to remember it to migrate once the associated namespaces are all moved over + mss.EncounteredVectorSet(key.ToByteArray(), value.ToByteArray()); + } + else if (!mss.sketch.TryHashAndStore(key.AsSpan())) + { + // Out of space, end scan for now + return false; + } + } + } return true; } diff --git a/libs/cluster/Server/Migration/MigrateSession.cs b/libs/cluster/Server/Migration/MigrateSession.cs index 16c4cb481dd..cd59a66d347 100644 --- a/libs/cluster/Server/Migration/MigrateSession.cs +++ b/libs/cluster/Server/Migration/MigrateSession.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Frozen; using System.Collections.Generic; using System.Linq; using System.Net; @@ -48,6 +49,9 @@ internal sealed unsafe partial class MigrateSession : IDisposable readonly HashSet _sslots; readonly CancellationTokenSource _cts = new(); + HashSet _namespaces; + FrozenDictionary _namespaceMap; + /// /// Get endpoint of target node /// @@ -276,9 +280,10 @@ public bool TrySetSlotRanges(string nodeid, MigrateState state) Status = MigrateState.FAIL; return false; } - logger?.LogTrace("[Completed] SETSLOT {slots} {state} {nodeid}", ClusterManager.GetRange([.. _sslots]), state, nodeid == null ? "" : nodeid); + logger?.LogTrace("[Completed] SETSLOT {slots} {state} {nodeid}", ClusterManager.GetRange([.. _sslots]), state, nodeid ?? 
""); return true; - }, TaskContinuationOptions.OnlyOnRanToCompletion).WaitAsync(_timeout, _cts.Token).Result; + }, TaskContinuationOptions.OnlyOnRanToCompletion) + .WaitAsync(_timeout, _cts.Token).Result; } catch (Exception ex) { @@ -338,6 +343,8 @@ public bool TryRecoverFromFailure() // This will execute the equivalent of SETSLOTRANGE STABLE for the slots of the failed migration task ResetLocalSlot(); + // TODO: Need to relinquish any migrating Vector Set contexts from target node + // Log explicit migration failure. Status = MigrateState.FAIL; return true; diff --git a/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs b/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs index 835f755a4b8..a11059bfe49 100644 --- a/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs +++ b/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Diagnostics; using System.Threading.Tasks; using Garnet.client; using Garnet.server; @@ -29,6 +30,18 @@ private bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, LocalServ value = ref SpanByte.ReinterpretWithoutLength(o.Memory.Memory.Span); } + // Map up any namespaces as needed + // TODO: Better way to do "has namespace" + if (key.MetadataSize == 1) + { + var oldNs = key.GetNamespaceInPayload(); + if (_namespaceMap.TryGetValue(oldNs, out var newNs)) + { + Debug.Assert(newNs <= byte.MaxValue, "Namespace too large"); + key.SetNamespaceInPayload((byte)newNs); + } + } + // Write key to network buffer if it has not expired if (!ClusterSession.Expired(ref value) && !WriteOrSendMainStoreKeyValuePair(gcs, ref key, ref value)) return false; @@ -39,7 +52,7 @@ bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, ref SpanByte key, { // Check if we need to initialize cluster migrate command arguments if (gcs.NeedsInitialization) - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true); + 
gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: false); // Try write serialized key value to client buffer while (!gcs.TryWriteKeyValueSpanByte(ref key, ref value, out var task)) @@ -49,15 +62,15 @@ bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, ref SpanByte key, return false; // re-initialize cluster migrate command parameters - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: false); } return true; } } - private bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, LocalServerSession localServerSession, ref ArgSlice key, out GarnetStatus status) + private bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, LocalServerSession localServerSession, ref SpanByte key, out GarnetStatus status) { - var keyByteArray = key.ToArray(); + var keyByteArray = key.AsReadOnlySpan().ToArray(); ObjectInput input = default; GarnetObjectStoreOutput value = default; @@ -81,14 +94,14 @@ bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, byte[] key, byt { // Check if we need to initialize cluster migrate command arguments if (gcs.NeedsInitialization) - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false, isVectorSets: false); while (!gcs.TryWriteKeyValueByteArray(key, value, expiration, out var task)) { // Flush key value pairs in the buffer if (!HandleMigrateTaskResponse(task)) return false; - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false, isVectorSets: false); } return true; } diff --git a/libs/cluster/Server/Migration/MigrateSessionKeys.cs b/libs/cluster/Server/Migration/MigrateSessionKeys.cs index 294b4ae3172..0d0de71bd2b 100644 --- 
a/libs/cluster/Server/Migration/MigrateSessionKeys.cs +++ b/libs/cluster/Server/Migration/MigrateSessionKeys.cs @@ -2,6 +2,8 @@ // Licensed under the MIT license. using System; +using System.Collections.Generic; +using System.Linq; using Garnet.server; using Microsoft.Extensions.Logging; using Tsavorite.core; @@ -33,13 +35,78 @@ private bool MigrateKeysFromMainStore() migrateTask.sketch.SetStatus(SketchStatus.TRANSMITTING); WaitForConfigPropagation(); + // Discover Vector Sets linked namespaces + var indexesToMigrate = new Dictionary(ByteArrayComparer.Instance); + _namespaces = clusterProvider.storeWrapper.DefaultDatabase.VectorManager.GetNamespacesForKeys(clusterProvider.storeWrapper, migrateTask.sketch.Keys.Select(t => t.Item1.ToArray()), indexesToMigrate); + + // If we have any namespaces, that implies Vector Sets, and if we have any of THOSE + // we need to reserve destination sets on the other side + if ((_namespaces?.Count ?? 0) > 0 && !ReserveDestinationVectorSetsAsync().GetAwaiter().GetResult()) + { + logger?.LogError("Failed to reserve destination vector sets, migration failed"); + return false; + } + // Transmit keys from main store - if (!migrateTask.TransmitKeys(StoreType.Main)) + if (!migrateTask.TransmitKeys(StoreType.Main, indexesToMigrate)) { logger?.LogError("Failed transmitting keys from main store"); return false; } + if ((_namespaces?.Count ?? 
0) > 0) + { + // Actually move element data over + if (!migrateTask.TransmitKeysNamespaces(logger)) + { + logger?.LogError("Failed to transmit vector set (namespaced) element data, migration failed"); + return false; + } + + // Move the indexes over + var gcs = migrateTask.Client; + + foreach (var (key, value) in indexesToMigrate) + { + // Update the index context as we move it, so it arrives on the destination node pointed at the appropriate + // namespaces for element data + VectorManager.ReadIndex(value, out var oldContext, out _, out _, out _, out _, out _, out _, out _, out _); + + var newContext = _namespaceMap[oldContext]; + VectorManager.SetContextForMigration(value, newContext); + + unsafe + { + fixed (byte* keyPtr = key, valuePtr = value) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length); + var valSpan = SpanByte.FromPinnedPointer(valuePtr, value.Length); + + if (gcs.NeedsInitialization) + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + + while (!gcs.TryWriteKeyValueSpanByte(ref keySpan, ref valSpan, out var task)) + { + if (!HandleMigrateTaskResponse(task)) + { + logger?.LogCritical("Failed to migrate Vector Set key {key} during migration", keySpan); + return false; + } + + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + } + } + } + } + + if (!HandleMigrateTaskResponse(gcs.SendAndResetIterationBuffer())) + { + logger?.LogCritical("Final flush after Vector Set migration failed"); + return false; + } + } + + // Final cleanup, which will also delete Vector Sets DeleteKeys(); } finally @@ -68,7 +135,7 @@ private bool MigrateKeysFromObjectStore() WaitForConfigPropagation(); // Transmit keys from object store - if (!migrateTask.TransmitKeys(StoreType.Object)) + if (!migrateTask.TransmitKeys(StoreType.Object, new(ByteArrayComparer.Instance))) { logger?.LogError("Failed transmitting keys from object store"); return false; diff --git 
a/libs/cluster/Server/Migration/MigrateSessionSlots.cs b/libs/cluster/Server/Migration/MigrateSessionSlots.cs index 0d153cc4aa0..5c54d00e37c 100644 --- a/libs/cluster/Server/Migration/MigrateSessionSlots.cs +++ b/libs/cluster/Server/Migration/MigrateSessionSlots.cs @@ -2,17 +2,68 @@ // Licensed under the MIT license. using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Threading.Tasks; #if DEBUG using Garnet.common; #endif using Garnet.server; using Microsoft.Extensions.Logging; +using Tsavorite.core; namespace Garnet.cluster { internal sealed partial class MigrateSession : IDisposable { + /// + /// Attempts to reserve contexts on the destination node for migrating vector sets. + /// + /// This maps roughly to "for each namespaces, reserve one context, record the mapping". + /// + public async Task ReserveDestinationVectorSetsAsync() + { + Debug.Assert((_namespaces.Count % (int)VectorManager.ContextStep) == 0, "Expected to be migrating Vector Sets, and thus to have an even number of namespaces"); + + var neededContexts = _namespaces.Count / (int)VectorManager.ContextStep; + + try + { + var reservedCtxs = await migrateOperation[0].Client.ExecuteForArrayAsync("CLUSTER", "RESERVE", "VECTOR_SET_CONTEXTS", neededContexts.ToString()); + + var rootNamespacesMigrating = _namespaces.Where(static x => (x % VectorManager.ContextStep) == 0); + + var nextReservedIx = 0; + + var namespaceMap = new Dictionary(); + + foreach (var migratingContext in rootNamespacesMigrating) + { + var toMapTo = ulong.Parse(reservedCtxs[nextReservedIx]); + for (var i = 0U; i < VectorManager.ContextStep; i++) + { + var fromCtx = migratingContext + i; + var toCtx = toMapTo + i; + + namespaceMap[fromCtx] = toCtx; + } + + nextReservedIx++; + } + + _namespaceMap = namespaceMap.ToFrozenDictionary(); + + return true; + } + catch (Exception ex) + { + logger?.LogError(ex, "Failed to reserve {count} Vector Set contexts 
on destination node {node}", neededContexts, _targetNodeId); + return false; + } + } + /// /// Migrate Slots inline driver /// @@ -61,6 +112,60 @@ async Task CreateAndRunMigrateTasks(StoreType storeType, long beginAddress try { await Task.WhenAll(migrateOperationRunners).WaitAsync(_timeout, _cts.Token).ConfigureAwait(false); + + // Handle migration of discovered Vector Set keys now that they're namespaces have been moved + if (storeType == StoreType.Main) + { + var vectorSets = migrateOperation.SelectMany(static mo => mo.VectorSets).GroupBy(static g => g.Key, ByteArrayComparer.Instance).ToDictionary(static g => g.Key, g => g.First().Value, ByteArrayComparer.Instance); + + if (vectorSets.Count > 0) + { + var gcs = migrateOperation[0].Client; + + foreach (var (key, value) in vectorSets) + { + // Update the index context as we move it, so it arrives on the destination node pointed at the appropriate + // namespaces for element data + VectorManager.ReadIndex(value, out var oldContext, out _, out _, out _, out _, out _, out _, out _, out _); + + var newContext = _namespaceMap[oldContext]; + VectorManager.SetContextForMigration(value, newContext); + + unsafe + { + fixed (byte* keyPtr = key, valuePtr = value) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length); + var valSpan = SpanByte.FromPinnedPointer(valuePtr, value.Length); + + if (gcs.NeedsInitialization) + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + + while (!gcs.TryWriteKeyValueSpanByte(ref keySpan, ref valSpan, out var task)) + { + if (!HandleMigrateTaskResponse(task)) + { + logger?.LogCritical("Failed to migrate Vector Set key {key} during migration", keySpan); + return false; + } + + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + } + + // Force a flush before doing the delete, in case that fails + if (!HandleMigrateTaskResponse(gcs.SendAndResetIterationBuffer())) + { + 
logger?.LogCritical("Flush failed before deletion of Vector Set {key} duration migration", keySpan); + return false; + } + + // Delete the index on this node now that it's moved over to the destination node + migrateOperation[0].DeleteVectorSet(ref keySpan); + } + } + } + } + } } catch (Exception ex) { @@ -68,6 +173,7 @@ async Task CreateAndRunMigrateTasks(StoreType storeType, long beginAddress _cts.Cancel(); return false; } + return true; } @@ -103,7 +209,7 @@ Task ScanStoreTask(int taskId, StoreType storeType, long beginAddress, lon WaitForConfigPropagation(); // Transmit all keys gathered - migrateOperation.TrasmitSlots(storeType); + migrateOperation.TransmitSlots(storeType); // Transition EPSM to DELETING migrateOperation.sketch.SetStatus(SketchStatus.DELETING); diff --git a/libs/cluster/Server/Migration/MigrationDriver.cs b/libs/cluster/Server/Migration/MigrationDriver.cs index d2e6af5c1c2..eeda6d6d7e2 100644 --- a/libs/cluster/Server/Migration/MigrationDriver.cs +++ b/libs/cluster/Server/Migration/MigrationDriver.cs @@ -78,6 +78,19 @@ private async Task BeginAsyncMigrationTask() if (!clusterProvider.BumpAndWaitForEpochTransition()) return; #endregion + // Acquire namespaces at this point, after slots have been switch to migration + _namespaces = clusterProvider.storeWrapper.DefaultDatabase.VectorManager.GetNamespacesForHashSlots(_sslots); + + // If we have any namespaces, that implies Vector Sets, and if we have any of THOSE + // we need to reserve destination sets on the other side + if ((_namespaces?.Count ?? 
0) > 0 && !await ReserveDestinationVectorSetsAsync()) + { + logger?.LogError("Failed to reserve destination vector sets, migration failed"); + TryRecoverFromFailure(); + Status = MigrateState.FAIL; + return; + } + #region migrateData // Migrate actual data if (!await MigrateSlotsDriverInline()) @@ -87,6 +100,7 @@ private async Task BeginAsyncMigrationTask() Status = MigrateState.FAIL; return; } + #endregion #region transferSlotOwnnershipToTargetNode diff --git a/libs/cluster/Server/Migration/Sketch.cs b/libs/cluster/Server/Migration/Sketch.cs index 4c1ff3e376e..59f3d0bc4a5 100644 --- a/libs/cluster/Server/Migration/Sketch.cs +++ b/libs/cluster/Server/Migration/Sketch.cs @@ -44,6 +44,19 @@ public bool TryHashAndStore(Span key) return true; } + public bool TryHashAndStore(ulong ns, Span key) + { + if (!argSliceVector.TryAddItem(ns, key)) + return false; + + var slot = (int)HashUtils.MurmurHash2x64A(key, seed: (uint)ns) & (size - 1); + var byteOffset = slot >> 3; + var bitOffset = slot & 7; + bitmap[byteOffset] = (byte)(bitmap[byteOffset] | (1UL << bitOffset)); + + return true; + } + /// /// Hash key to bloomfilter and store it for future use (NOTE: Use only with KEYS option) /// @@ -65,7 +78,19 @@ public unsafe void HashAndStore(ref ArgSlice key) /// public unsafe bool Probe(SpanByte key, out SketchStatus status) { - var slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length) & (size - 1); + int slot; + + // TODO: better way to detect namespace + if (key.MetadataSize == 1) + { + var ns = key.GetNamespaceInPayload(); + slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length, seed: (uint)ns) & (size - 1); + } + else + { + slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length) & (size - 1); + } + var byteOffset = slot >> 3; var bitOffset = slot & 7; diff --git a/libs/cluster/Server/Replication/ReplicaOps/ReplicaReplayTask.cs b/libs/cluster/Server/Replication/ReplicaOps/ReplicaReplayTask.cs index f5bedf6d469..b275574c676 100644 --- 
a/libs/cluster/Server/Replication/ReplicaOps/ReplicaReplayTask.cs +++ b/libs/cluster/Server/Replication/ReplicaOps/ReplicaReplayTask.cs @@ -36,6 +36,7 @@ void ResetReplayCts() try { activeReplay.WriteLock(); + replicaReplayTaskCts.Dispose(); replicaReplayTaskCts = CancellationTokenSource.CreateLinkedTokenSource(ctsRepManager.Token); } diff --git a/libs/cluster/Session/ClusterCommands.cs b/libs/cluster/Session/ClusterCommands.cs index 104e05144b7..d938b710340 100644 --- a/libs/cluster/Session/ClusterCommands.cs +++ b/libs/cluster/Session/ClusterCommands.cs @@ -135,7 +135,7 @@ private bool TryParseSlots(int startIdx, out HashSet slots, out ReadOnlySpa /// Subcommand to execute. /// True if number of parameters is invalid /// True if command is fully processed, false if more processing is needed. - private void ProcessClusterCommands(RespCommand command, out bool invalidParameters) + private void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, out bool invalidParameters) { _ = command switch { @@ -173,6 +173,7 @@ private void ProcessClusterCommands(RespCommand command, out bool invalidParamet RespCommand.CLUSTER_PUBLISH or RespCommand.CLUSTER_SPUBLISH => NetworkClusterPublish(out invalidParameters), RespCommand.CLUSTER_REPLICAS => NetworkClusterReplicas(out invalidParameters), RespCommand.CLUSTER_REPLICATE => NetworkClusterReplicate(out invalidParameters), + RespCommand.CLUSTER_RESERVE => NetworkClusterReserve(vectorManager, out invalidParameters), RespCommand.CLUSTER_RESET => NetworkClusterReset(out invalidParameters), RespCommand.CLUSTER_SEND_CKPT_FILE_SEGMENT => NetworkClusterSendCheckpointFileSegment(out invalidParameters), RespCommand.CLUSTER_SEND_CKPT_METADATA => NetworkClusterSendCheckpointMetadata(out invalidParameters), diff --git a/libs/cluster/Session/ClusterKeyIterationFunctions.cs b/libs/cluster/Session/ClusterKeyIterationFunctions.cs index 54d91d6cd3d..af011f3798c 100644 --- a/libs/cluster/Session/ClusterKeyIterationFunctions.cs 
+++ b/libs/cluster/Session/ClusterKeyIterationFunctions.cs @@ -34,6 +34,14 @@ internal sealed class MainStoreCountKeys : IScanIteratorFunctions keys, int slot, int maxKeyCount) public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) { + // TODO: better way to detect namespace + if (key.MetadataSize == 1) + { + // Namespace means not visible + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + cursorRecordResult = CursorRecordResult.Accept; // default; not used here, out CursorRecordResult cursorRecordResult + if (HashSlotUtils.HashSlot(ref key) == slot && !Expired(ref value)) keys.Add(key.ToByteArray()); return keys.Count < maxKeyCount; diff --git a/libs/cluster/Session/ClusterSession.cs b/libs/cluster/Session/ClusterSession.cs index 45780b2d2bf..bfe1f6c475a 100644 --- a/libs/cluster/Session/ClusterSession.cs +++ b/libs/cluster/Session/ClusterSession.cs @@ -12,12 +12,21 @@ namespace Garnet.cluster { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; internal sealed unsafe partial class ClusterSession : IClusterSession { @@ -57,7 +66,20 @@ internal sealed unsafe partial class ClusterSession : IClusterSession /// public IGarnetServer Server { get; set; } - public ClusterSession(ClusterProvider clusterProvider, TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics sessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null) + private VectorContext vectorContext; + private BasicContext basicContext; + + public ClusterSession( + ClusterProvider clusterProvider, + TransactionManager txnManager, + IGarnetAuthenticator authenticator, + 
UserHandle userHandle, + GarnetSessionMetrics sessionMetrics, + BasicGarnetApi basicGarnetApi, + BasicContext basicContext, + VectorContext vectorContext, + INetworkSender networkSender, + ILogger logger = null) { this.clusterProvider = clusterProvider; this.authenticator = authenticator; @@ -65,11 +87,13 @@ public ClusterSession(ClusterProvider clusterProvider, TransactionManager txnMan this.txnManager = txnManager; this.sessionMetrics = sessionMetrics; this.basicGarnetApi = basicGarnetApi; + this.basicContext = basicContext; + this.vectorContext = vectorContext; this.networkSender = networkSender; this.logger = logger; } - public void ProcessClusterCommands(RespCommand command, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend) + public void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend) { this.dcurr = dcurr; this.dend = dend; @@ -89,7 +113,7 @@ public void ProcessClusterCommands(RespCommand command, ref SessionParseState pa return; } - ProcessClusterCommands(command, out invalidParameters); + ProcessClusterCommands(command, vectorManager, out invalidParameters); } else { diff --git a/libs/cluster/Session/MigrateCommand.cs b/libs/cluster/Session/MigrateCommand.cs index 897ca187e02..f884d4d27b1 100644 --- a/libs/cluster/Session/MigrateCommand.cs +++ b/libs/cluster/Session/MigrateCommand.cs @@ -13,7 +13,7 @@ namespace Garnet.cluster { internal sealed unsafe partial class ClusterSession : IClusterSession { - public static bool Expired(ref SpanByte value) => value.MetadataSize > 0 && value.ExtraMetadata < DateTimeOffset.UtcNow.Ticks; + public static bool Expired(ref SpanByte value) => value.MetadataSize == 8 && value.ExtraMetadata < DateTimeOffset.UtcNow.Ticks; public static bool Expired(ref IGarnetObject value) => value.Expiration != 0 && value.Expiration < DateTimeOffset.UtcNow.Ticks; diff --git a/libs/cluster/Session/ReplicaOfCommand.cs 
b/libs/cluster/Session/ReplicaOfCommand.cs index bdcdb6fbbe8..8dbb0330224 100644 --- a/libs/cluster/Session/ReplicaOfCommand.cs +++ b/libs/cluster/Session/ReplicaOfCommand.cs @@ -25,7 +25,7 @@ private bool TryREPLICAOF(out bool invalidParameters) var addressSpan = parseState.GetArgSliceByRef(0).ReadOnlySpan; var portSpan = parseState.GetArgSliceByRef(1).ReadOnlySpan; - // Turn of replication and make replica into a primary but do not delete data + // Turn off replication and make replica into a primary but do not delete data if (addressSpan.EqualsUpperCaseSpanIgnoringCase("NO"u8) && portSpan.EqualsUpperCaseSpanIgnoringCase("ONE"u8)) { @@ -45,6 +45,7 @@ private bool TryREPLICAOF(out bool invalidParameters) clusterProvider.replicationManager.TryUpdateForFailover(); clusterProvider.replicationManager.ResetReplayIterator(); UnsafeBumpAndWaitForEpochTransition(); + clusterProvider.storeWrapper.SuspendReplicaOnlyTasks().Wait(); clusterProvider.storeWrapper.StartPrimaryTasks(); } finally @@ -84,6 +85,8 @@ private bool TryREPLICAOF(out bool invalidParameters) clusterProvider.replicationManager.TryReplicateDisklessSync(this, syncOpts, out var errorMessage) : clusterProvider.replicationManager.TryReplicateDiskbasedSync(this, syncOpts, out errorMessage); + clusterProvider.storeWrapper.StartReplicaTasks(); + if (!success) { while (!RespWriteUtils.TryWriteError(errorMessage, ref dcurr, dend)) diff --git a/libs/cluster/Session/RespClusterMigrateCommands.cs b/libs/cluster/Session/RespClusterMigrateCommands.cs index 3dd58cf82a1..5fe9c8d1c4c 100644 --- a/libs/cluster/Session/RespClusterMigrateCommands.cs +++ b/libs/cluster/Session/RespClusterMigrateCommands.cs @@ -17,7 +17,10 @@ namespace Garnet.cluster SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; internal sealed unsafe partial class ClusterSession : IClusterSession { @@ -103,18 +106,30 @@ void Process(BasicGarnetApi basicGarnetApi, byte[] input, string 
storeTypeSpan, continue; } - var slot = HashSlotUtils.HashSlot(ref key); - if (!currentConfig.IsImportingSlot(slot)) // Slot is not in importing state + // TODO: better way to handle namespaces + if (key.MetadataSize == 1) { - migrateState = 1; - i++; - continue; + // This is a Vector Set namespace key being migrated - it won't necessarily look like it's "in" a hash slot + // because it's dependent on some other key (the index key) being migrated which itself is in a moving hash slot + + clusterProvider.storeWrapper.DefaultDatabase.VectorManager.HandleMigratedElementKey(ref basicContext, ref vectorContext, ref key, ref value); + } + else + { + var slot = HashSlotUtils.HashSlot(ref key); + if (!currentConfig.IsImportingSlot(slot)) // Slot is not in importing state + { + migrateState = 1; + i++; + continue; + } + + // Set if key replace flag is set or key does not exist + var keySlice = new ArgSlice(key.ToPointer(), key.Length); + if (replaceOption || !Exists(ref keySlice)) + _ = basicGarnetApi.SET(ref key, ref value); } - // Set if key replace flag is set or key does not exist - var keySlice = new ArgSlice(key.ToPointer(), key.Length); - if (replaceOption || !Exists(ref keySlice)) - _ = basicGarnetApi.SET(ref key, ref value); i++; } } @@ -150,6 +165,35 @@ void Process(BasicGarnetApi basicGarnetApi, byte[] input, string storeTypeSpan, i++; } } + else if (storeTypeSpan.Equals("VSTORE", StringComparison.OrdinalIgnoreCase)) + { + // This is the subset of the main store that holds Vector Set _index_ keys + // + // Namespace'd element keys are handled by the SSTORE path + + var keyCount = *(int*)payloadPtr; + payloadPtr += 4; + var i = 0; + + TrackImportProgress(keyCount, isMainStore: true, keyCount == 0); + while (i < keyCount) + { + ref var key = ref SpanByte.Reinterpret(payloadPtr); + payloadPtr += key.TotalSize; + ref var value = ref SpanByte.Reinterpret(payloadPtr); + payloadPtr += value.TotalSize; + + // An error has occurred + if (migrateState > 0) + { + i++; + 
continue; + } + + clusterProvider.storeWrapper.DefaultDatabase.VectorManager.HandleMigratedIndexKey(clusterProvider.storeWrapper.DefaultDatabase, clusterProvider.storeWrapper, ref key, ref value); + i++; + } + } else { throw new Exception("CLUSTER MIGRATE STORE TYPE ERROR!"); diff --git a/libs/cluster/Session/RespClusterReplicationCommands.cs b/libs/cluster/Session/RespClusterReplicationCommands.cs index b30a3ff00f4..8bc596c8e54 100644 --- a/libs/cluster/Session/RespClusterReplicationCommands.cs +++ b/libs/cluster/Session/RespClusterReplicationCommands.cs @@ -115,6 +115,59 @@ private bool NetworkClusterReplicate(out bool invalidParameters) return true; } + /// + /// Implements CLUSTER reserve command (only for internode use). + /// + /// Allows for pre-migration reservation of certain resources. + /// + /// For now, this is only used for Vector Sets. + /// + private bool NetworkClusterReserve(VectorManager vectorManager, out bool invalidParameters) + { + if (parseState.Count < 2) + { + invalidParameters = true; + return true; + } + + var kind = parseState.GetArgSliceByRef(0); + if (!kind.ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("VECTOR_SET_CONTEXTS"u8)) + { + while (!RespWriteUtils.TryWriteError("Unrecognized reservation type"u8, ref dcurr, dend)) + SendAndReset(); + + invalidParameters = false; + return true; + } + + if (!parseState.TryGetInt(1, out var numVectorSetContexts) || numVectorSetContexts <= 0) + { + invalidParameters = true; + return true; + } + + invalidParameters = false; + + if (!vectorManager.TryReserveContextsForMigration(ref vectorContext, numVectorSetContexts, out var newContexts)) + { + while (!RespWriteUtils.TryWriteError("Insufficients contexts available to reserve"u8, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + while (!RespWriteUtils.TryWriteArrayLength(newContexts.Count, ref dcurr, dend)) + SendAndReset(); + + foreach (var ctx in newContexts) + { + while (!RespWriteUtils.TryWriteInt64AsSimpleString((long)ctx, ref dcurr, 
dend)) + SendAndReset(); + } + + return true; + } + /// /// Implements CLUSTER aofsync command (only for internode use) /// diff --git a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs b/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs index 0416c064d43..73b274fb168 100644 --- a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading; using Garnet.server; @@ -23,9 +24,13 @@ private bool CheckIfKeyExists(byte[] key) } } - private ClusterSlotVerificationResult SingleKeySlotVerify(ref ClusterConfig config, ref ArgSlice keySlice, bool readOnly, byte SessionAsking, int slot = -1) + private ClusterSlotVerificationResult SingleKeySlotVerify(ref ClusterConfig config, ref ArgSlice keySlice, bool readOnly, byte SessionAsking, bool waitForStableSlot, int slot = -1) { - return readOnly ? SingleKeyReadSlotVerify(ref config, ref keySlice) : SingleKeyReadWriteSlotVerify(ref config, ref keySlice); + Debug.Assert(!waitForStableSlot || (waitForStableSlot && !readOnly), "Shouldn't see Vector Set writes and readonly at same time"); + + var ret = readOnly ? 
SingleKeyReadSlotVerify(ref config, ref keySlice) : SingleKeyReadWriteSlotVerify(waitForStableSlot, ref config, ref keySlice); + + return ret; [MethodImpl(MethodImplOptions.AggressiveInlining)] ClusterSlotVerificationResult SingleKeyReadSlotVerify(ref ClusterConfig config, ref ArgSlice keySlice) @@ -69,12 +74,20 @@ ClusterSlotVerificationResult SingleKeyReadSlotVerify(ref ClusterConfig config, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - ClusterSlotVerificationResult SingleKeyReadWriteSlotVerify(ref ClusterConfig config, ref ArgSlice keySlice) + ClusterSlotVerificationResult SingleKeyReadWriteSlotVerify(bool waitForStableSlot, ref ClusterConfig config, ref ArgSlice keySlice) { var _slot = slot == -1 ? ArgSliceUtils.HashSlot(ref keySlice) : (ushort)slot; + + tryAgain: var IsLocal = config.IsLocal(_slot, readWriteSession: readWriteSession); var state = config.GetState(_slot); + if (waitForStableSlot && state is SlotState.IMPORTING or SlotState.MIGRATING) + { + WaitForSlotToStabalize(_slot, ref keySlice, ref config); + goto tryAgain; + } + // Redirect r/w requests towards primary if (config.LocalNodeRole == NodeRole.REPLICA && !readWriteSession) return new(SlotVerifiedState.MOVED, _slot); @@ -123,18 +136,35 @@ bool CanOperateOnKey(ref ArgSlice key, int slot, bool readOnly) } return Exists(ref key); } + + void WaitForSlotToStabalize(ushort slot, ref ArgSlice keySlice, ref ClusterConfig config) + { + // For Vector Set ops specifically, we need a slot to be stable (or faulted, but not migrating) before writes can proceed + // + // This isn't key specific because we can't know the Vector Sets being migrated in advance, only that the slot is moving + + do + { + ReleaseCurrentEpoch(); + _ = Thread.Yield(); + AcquireCurrentEpoch(); + + config = clusterProvider.clusterManager.CurrentConfig; + } + while (config.GetState(slot) is SlotState.IMPORTING or SlotState.MIGRATING); + } } - ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Span keys, 
bool readOnly, byte sessionAsking, int count) + ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Span keys, bool readOnly, byte sessionAsking, bool waitForStableSlot, int count) { var _end = count < 0 ? keys.Length : count; var slot = ArgSliceUtils.HashSlot(ref keys[0]); - var verifyResult = SingleKeySlotVerify(ref config, ref keys[0], readOnly, sessionAsking, slot); + var verifyResult = SingleKeySlotVerify(ref config, ref keys[0], readOnly, sessionAsking, waitForStableSlot, slot); for (var i = 1; i < _end; i++) { var _slot = ArgSliceUtils.HashSlot(ref keys[i]); - var _verifyResult = SingleKeySlotVerify(ref config, ref keys[i], readOnly, sessionAsking, _slot); + var _verifyResult = SingleKeySlotVerify(ref config, ref keys[i], readOnly, sessionAsking, waitForStableSlot, _slot); // Check if slot changes between keys if (_slot != slot) @@ -152,7 +182,7 @@ ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Sessi { ref var key = ref parseState.GetArgSliceByRef(csvi.firstKey); var slot = ArgSliceUtils.HashSlot(ref key); - var verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, slot); + var verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, csvi.waitForStableSlot, slot); var secondKey = csvi.firstKey + csvi.step; for (var i = secondKey; i < csvi.lastKey; i += csvi.step) @@ -161,7 +191,7 @@ ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Sessi continue; key = ref parseState.GetArgSliceByRef(i); var _slot = ArgSliceUtils.HashSlot(ref key); - var _verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, _slot); + var _verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, csvi.waitForStableSlot, _slot); // Check if slot changes between keys if (_slot != slot) diff --git a/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs 
b/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs index 3fe36867e9c..26ba3937764 100644 --- a/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs @@ -28,14 +28,14 @@ public void ResetCachedSlotVerificationResult() /// /// /// - public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking) + public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking, bool waitForStableSlot) { ClusterSlotVerificationResult verifyResult; // If it is the first verification initialize the result cache if (!initialized) { - verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking); + verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking, waitForStableSlot); cachedVerificationResult = verifyResult; initialized = true; return verifyResult.state == SlotVerifiedState.OK; @@ -45,7 +45,7 @@ public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte Se if (cachedVerificationResult.state != SlotVerifiedState.OK) return false; - verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking); + verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking, waitForStableSlot); // Check if slot changes between keys if (verifyResult.slot != cachedVerificationResult.slot) diff --git a/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs b/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs index bd93685f49e..fa3efe3dc11 100644 --- a/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs @@ -92,17 +92,18 @@ private void WriteClusterSlotVerificationMessage(ClusterConfig config, ClusterSl /// /// /// + /// /// /// /// /// - public bool NetworkKeyArraySlotVerify(Span keys, 
bool readOnly, byte sessionAsking, ref byte* dcurr, ref byte* dend, int count = -1) + public bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte sessionAsking, bool waitForStableSlot, ref byte* dcurr, ref byte* dend, int count = -1) { // If cluster is not enabled or a transaction is running skip slot check if (!clusterProvider.serverOptions.EnableCluster || txnManager.state == TxnState.Running) return false; var config = clusterProvider.clusterManager.CurrentConfig; - var vres = MultiKeySlotVerify(config, ref keys, readOnly, sessionAsking, count); + var vres = MultiKeySlotVerify(config, ref keys, readOnly, sessionAsking, waitForStableSlot, count); if (vres.state == SlotVerifiedState.OK) return false; diff --git a/libs/common/CountingEventSlim.cs b/libs/common/CountingEventSlim.cs new file mode 100644 index 00000000000..197130fc95a --- /dev/null +++ b/libs/common/CountingEventSlim.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Threading; + +namespace Garnet.common +{ + /// + /// A that is triggered based on a count hitting or advancing above 0 rather than + /// explicit Set and Reset calls. + /// + /// + /// Akin to a , but allows for count to go back up after hitting 0. + /// + public struct CountingEventSlim : IDisposable + { + private const int HoldForSetValue = int.MinValue / 2; + + private readonly ManualResetEventSlim resetEvent; + private int count = 0; + + private CountingEventSlim(ManualResetEventSlim resetEvent) + { + this.resetEvent = resetEvent; + } + + /// + /// Increment the internal count. + /// + /// Any caller to after this call returns but before a paired call to will block. 
+ /// + public void Increment() + { + while (true) + { + var addRes = Interlocked.Increment(ref count); + if (addRes <= 0) + { + // Some thread is about to Set, undo and wait + + _ = Interlocked.Decrement(ref count); + _ = Thread.Yield(); + } + else + { + resetEvent.Reset(); + break; + } + } + } + + /// + /// Decrement the internal count. + /// + /// If this is the last outstanding paired to a completed , threads blocked in will be unblocked. + /// + public void Decrement() + { + var decrRes = Interlocked.Decrement(ref count); + Debug.Assert(decrRes >= 0, "Decrement fell below 0, implies unbalanced calls to Increment and Decrement"); + + if (decrRes == 0 && Interlocked.CompareExchange(ref count, HoldForSetValue, 0) == 0) + { + resetEvent.Set(); + + var unlockRes = Interlocked.Add(ref count, -HoldForSetValue); + Debug.Assert(unlockRes >= 0, "Unlock resulted in incoherent count"); + } + } + + /// + /// Block until the internal count hits 0. + /// + /// Returns true if wait was successful, false otherwise. + /// + /// Timeout for wait operation. -1 (the default) waits indefinitely, 0 returns immediately. + public readonly bool Wait(int millisecondsTimeout = -1) + => resetEvent.Wait(millisecondsTimeout); + + /// + /// Create a new . + /// + public static CountingEventSlim Create() + => new(new(true)); + + /// + public readonly void Dispose() + => resetEvent.Dispose(); + } +} \ No newline at end of file diff --git a/libs/common/ExceptionInjectionType.cs b/libs/common/ExceptionInjectionType.cs index 6388be3e81c..d5ce8123a4a 100644 --- a/libs/common/ExceptionInjectionType.cs +++ b/libs/common/ExceptionInjectionType.cs @@ -64,6 +64,18 @@ public enum ExceptionInjectionType /// /// Replication diskless sync reset cts /// - Replication_Diskless_Sync_Reset_Cts + Replication_Diskless_Sync_Reset_Cts, + /// + /// During deletion of a Vector Set, leaving it partially deleted - at a particular point of execution. 
+ /// + VectorSet_Interrupt_Delete_0, + /// + /// During deletion of a Vector Set, leaving it partially deleted - at a particular point of execution. + /// + VectorSet_Interrupt_Delete_1, + /// + /// During deletion of a Vector Set, leaving it partially deleted - at a particular point of execution. + /// + VectorSet_Interrupt_Delete_2, } } \ No newline at end of file diff --git a/libs/common/HashSlotUtils.cs b/libs/common/HashSlotUtils.cs index f1811ce3a7e..67fbc4d29fd 100644 --- a/libs/common/HashSlotUtils.cs +++ b/libs/common/HashSlotUtils.cs @@ -10,6 +10,8 @@ namespace Garnet.common { public static unsafe class HashSlotUtils { + public const ushort MaxHashSlot = 16_383; + /// /// This table is based on the CRC-16-CCITT polynomial (0x1021) /// @@ -101,14 +103,14 @@ public static unsafe ushort HashSlot(byte* keyPtr, int ksize) var startPtr = keyPtr; var end = keyPtr + ksize; - // Find first occurence of '{' + // Find first occurrence of '{' while (startPtr < end && *startPtr != '{') { startPtr++; } // Return early if did not find '{' - if (startPtr == end) return (ushort)(Hash(keyPtr, ksize) & 16383); + if (startPtr == end) return (ushort)(Hash(keyPtr, ksize) & MaxHashSlot); var endPtr = startPtr + 1; @@ -116,10 +118,10 @@ public static unsafe ushort HashSlot(byte* keyPtr, int ksize) while (endPtr < end && *endPtr != '}') { endPtr++; } // Return early if did not find '}' after '{' - if (endPtr == end || endPtr == startPtr + 1) return (ushort)(Hash(keyPtr, ksize) & 16383); + if (endPtr == end || endPtr == startPtr + 1) return (ushort)(Hash(keyPtr, ksize) & MaxHashSlot); // Return hash for byte sequence between brackets - return (ushort)(Hash(startPtr + 1, (int)(endPtr - startPtr - 1)) & 16383); + return (ushort)(Hash(startPtr + 1, (int)(endPtr - startPtr - 1)) & MaxHashSlot); } } } \ No newline at end of file diff --git a/libs/common/ReadOptimizedLock.cs b/libs/common/ReadOptimizedLock.cs new file mode 100644 index 00000000000..2d3e76b8861 --- /dev/null +++ 
b/libs/common/ReadOptimizedLock.cs @@ -0,0 +1,408 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Garnet.common +{ + /// + /// Holds a set of RW-esque locks, optimized for reads. + /// + /// This was originally created for Vector Sets, but is general enough for reuse. + /// For Vector Sets, these are acquired and released as needed to prevent concurrent creation/deletion operations or deletion concurrent with read operations. + /// + /// These are outside of Tsavorite for re-entrancy reasons reasons. + /// + /// + /// This is a counter based r/w lock scheme, with a bit of biasing for cache line awareness. + /// + /// Each "key" acquires locks based on its hash. + /// Each hash is mapped to a range of indexes, each range is lockShardCount in length. + /// When acquiring a shared lock, we take one index out of the keys range and acquire a read lock. + /// This will block exclusive locks, but not impact other readers. + /// When acquiring an exclusive lock, we acquire write locks for all indexes in the key's range IN INCREASING _LOGICAL_ ORDER. + /// The order is necessary to avoid deadlocks. + /// By ensuring all exclusive locks walk "up" we guarantee no two exclusive lock acquisitions end up waiting for each other. + /// + /// Locks themselves are just ints, where a negative value indicates an exclusive lock and a positive value is the number of active readers. + /// Read locks are acquired optimistically, so actual lock values will fluctate above int.MinValue when an exclusive lock is held. 
+ /// + /// The last set of optimizations is around cache lines coherency: + /// We assume cache lines of 64-bytes (the x86 default, which is also true for some [but not all] ARM processors) and size counters-per-core in multiples of that + /// We access array elements via reference, to avoid thrashing cache lines due to length checks + /// Each shard is placed, in so much as is possible, into a different cache line rather than grouping a hash's counts physically near each other + /// This will tend to allow a core to retain ownership of the same cache lines even as it moves between different hashes + /// + /// Experimentally (using some rough microbenchmarks) various optimizations are worth (on either shared or exclusive acquisiton paths): + /// - Split shards across cache lines : 7x (read path), 2.5x (write path) + /// - Fast math instead of mod and mult : 50% (read path), 20% (write path) + /// - Unsafe ref instead of array access: 0% (read path), 10% (write path) + /// + public struct ReadOptimizedLock + { + // Beyond 4K bytes per core we're well past "this is worth the tradeoff", so cut off then. + // + // Must be a power of 2. + private const int MaxPerCoreContexts = 1_024; + + /// + /// Estimated size of cache lines on a processor. + /// + /// Generally correct for x86-derived processors, sometimes correct for ARM-derived ones. + /// + public const int CacheLineSizeBytes = 64; + + [ThreadStatic] + private static int ProcessorHint; + + private readonly int[] lockCounts; + private readonly int coreSelectionMask; + private readonly int perCoreCounts; + private readonly ulong perCoreCountsFastMod; + private readonly byte perCoreCountsMultShift; + + /// + /// Create a new . + /// + /// accuracy impacts performance, not correctness. + /// + /// Too low and unrelated locks will end up delaying each other. + /// Too high and more memory than is necessary will be used. 
+ /// + public ReadOptimizedLock(int estimatedSimultaneousActiveLockers) + { + Debug.Assert(estimatedSimultaneousActiveLockers > 0); + + // ~1 per core + var coreCount = (int)BitOperations.RoundUpToPowerOf2((uint)Environment.ProcessorCount); + coreSelectionMask = coreCount - 1; + + // Use estimatedSimultaneousActiveLockers to determine number of shards per lock. + // + // We scale up to a whole multiple of CacheLineSizeBytes to reduce cache line thrashing. + // + // We scale to a power of 2 to avoid divisions (and some multiplies) in index calculation. + perCoreCounts = estimatedSimultaneousActiveLockers; + if (perCoreCounts % (CacheLineSizeBytes / sizeof(int)) != 0) + { + perCoreCounts += (CacheLineSizeBytes / sizeof(int)) - (perCoreCounts % (CacheLineSizeBytes / sizeof(int))); + } + Debug.Assert(perCoreCounts % (CacheLineSizeBytes / sizeof(int)) == 0, "Each core should be whole cache lines of data"); + + perCoreCounts = (int)BitOperations.RoundUpToPowerOf2((uint)perCoreCounts); + + // Put an upper bound of ~1 page worth of locks per core (which is still quite high). + // + // For the largest realistic machines out there (384 cores) this will put us at around ~2M of lock data, max. + if (perCoreCounts is <= 0 or > MaxPerCoreContexts) + { + perCoreCounts = MaxPerCoreContexts; + } + + // Pre-calculate an alternative to %, as that division will be in the hot path + perCoreCountsFastMod = (ulong.MaxValue / (uint)perCoreCounts) + 1; + + // Avoid two multiplies in the hot path + perCoreCountsMultShift = (byte)BitOperations.Log2((uint)perCoreCounts); + + var numInts = coreCount * perCoreCounts; + lockCounts = new int[numInts]; + } + + /// + /// Take a hash and a _hint_ about the current processor and determine which count should be used. + /// + /// Walking from 0 to ( + 1) [exclusive] will return + /// all possible counts for a given hash. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly int CalculateIndex(long hashLong, int currentProcessorHint) + { + // Throw away half the top half of the hash + // + // This set of locks will be small enough that the extra bits shoulnd't matter + var hash = (int)hashLong; + + // Hint might be out of range, so force it into the space we expect + var currentProcessor = currentProcessorHint & coreSelectionMask; + + var startOfCoreCounts = currentProcessor << perCoreCountsMultShift; + + // Avoid doing a division in the hot path + // Based on: https://github.com/dotnet/runtime/blob/3a95842304008b9ca84c14b4bec9ec99ed5802db/src/libraries/System.Private.CoreLib/src/System/Collections/HashHelpers.cs#L99 + var hashOffset = (uint)(((((perCoreCountsFastMod * (uint)hash) >> 32) + 1) << perCoreCountsMultShift) >> 32); + + Debug.Assert(hashOffset == ((uint)hash % perCoreCounts), "Replacing mod with multiplies failed"); + + var ix = (int)(startOfCoreCounts + hashOffset); + + Debug.Assert(ix >= 0 && ix < lockCounts.Length, "About to do something out of bounds"); + + return ix; + } + + /// + /// Attempt to acquire a shared lock for the given hash. + /// + /// Will block exclusive locks until released. + /// + public readonly bool TryAcquireSharedLock(long hash, out int lockToken) + { + var ix = CalculateIndex(hash, GetProcessorHint()); + + ref var acquireRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), ix); + + var res = Interlocked.Increment(ref acquireRef); + if (res < 0) + { + // Exclusively locked + _ = Interlocked.Decrement(ref acquireRef); + Unsafe.SkipInit(out lockToken); + return false; + } + + lockToken = ix; + return true; + } + + /// + /// Acquire a shared lock for the given hash, blocking until that succeeds. + /// + /// Will block exclusive locks until released. 
+ /// + public readonly void AcquireSharedLock(long hash, out int lockToken) + { + var ix = CalculateIndex(hash, GetProcessorHint()); + + ref var acquireRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), ix); + + while (true) + { + var res = Interlocked.Increment(ref acquireRef); + if (res < 0) + { + // Exclusively locked + _ = Interlocked.Decrement(ref acquireRef); + + // Spin until we can grab this one + _ = Thread.Yield(); + } + else + { + lockToken = ix; + return; + } + } + } + + /// + /// Release a lock previously acquired with or . + /// + public readonly void ReleaseSharedLock(int lockToken) + { + Debug.Assert(lockToken >= 0 && lockToken < lockCounts.Length, "Invalid lock token"); + + ref var releaseRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), lockToken); + + _ = Interlocked.Decrement(ref releaseRef); + } + + /// + /// Attempt to acquire an exclusive lock for the given hash. + /// + /// Will block all other locks until released. 
+ /// + public readonly bool TryAcquireExclusiveLock(long hash, out int lockToken) + { + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Failed, release previously acquired + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + Unsafe.SkipInit(out lockToken); + return false; + } + } + + // Successfully acquired all shards exclusively + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + lockToken = (int)hash; + + return true; + } + + /// + /// Acquire an exclusive lock for the given hash, blocking until that succeeds. + /// + /// Will block all other locks until released. 
+ /// + public readonly void AcquireExclusiveLock(long hash, out int lockToken) + { + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + while (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Optimistic shared lock got us, or conflict with some other exclusive lock acquisition + // + // Back off and try again + _ = Thread.Yield(); + } + } + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + lockToken = (int)hash; + } + + /// + /// Release a lock previously acquired with , , or . + /// + public readonly void ReleaseExclusiveLock(int lockToken) + { + // The lockToken is a hash, so no range check here + + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var hash = lockToken; + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var releaseIx = CalculateIndex(hash, i); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + } + + /// + /// Attempt to promote a shared lock previously acquired via or to an exclusive lock. + /// + /// If successful, will block all other locks until released. + /// + /// If successful, must be released with . + /// + /// If unsuccessful, shared lock will still be held and must be released with . 
+ /// + public readonly bool TryPromoteSharedLock(long hash, int lockToken, out int newLockToken) + { + Debug.Assert(Interlocked.CompareExchange(ref lockCounts[lockToken], 0, 0) > 0, "Illegal call when not holding shard lock"); + + Debug.Assert(lockToken >= 0 && lockToken < lockCounts.Length, "Invalid lock token"); + + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + + if (acquireIx == lockToken) + { + // Do the promote + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 1) != 1) + { + // Failed, release previously acquired all of which are exclusive locks + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + // Note we're still holding the shared lock here + Unsafe.SkipInit(out newLockToken); + return false; + } + } + else + { + // Otherwise attempt an exclusive acquire + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Failed, release previously acquired - one of which MIGHT be the shared lock + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + var releaseTargetValue = releaseIx == lockToken ? 
1 : 0; + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, releaseTargetValue, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + // Note we're still holding the shared lock here + Unsafe.SkipInit(out newLockToken); + return false; + } + } + } + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + newLockToken = (int)hash; + return true; + } + + /// + /// Get a somewhat-correlated-to-processor value. + /// + /// While we could use , that isn't fast on all platforms. + /// + /// For our purposes, we just need something that will tend to keep different active processors + /// from touching each other. ManagedThreadId works well enough. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetProcessorHint() + { + var ret = ProcessorHint; + if (ret == 0) + { + ProcessorHint = ret = Environment.CurrentManagedThreadId; + } + + return ret; + } + } +} \ No newline at end of file diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index 92c41ec4739..1202e8c0e09 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -1341,5 +1341,40 @@ public static bool TryReadInfinity(ReadOnlySpan value, out double number) number = default; return false; } + + /// + /// Parses "[+/-]inf" string and returns float.PositiveInfinity/float.NegativeInfinity respectively. + /// If string is not an infinity, parsing fails. 
+ /// + /// input data + /// If parsing was successful, contains positive or negative infinity + /// True if infinity was read, false otherwise + public static bool TryReadInfinity(ReadOnlySpan value, out float number) + { + if (value.Length == 3) + { + if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.INFINITY)) + { + number = float.PositiveInfinity; + return true; + } + } + else if (value.Length == 4) + { + if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.POS_INFINITY, true)) + { + number = float.PositiveInfinity; + return true; + } + else if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.NEG_INFINITY, true)) + { + number = float.NegativeInfinity; + return true; + } + } + + number = default; + return false; + } } } \ No newline at end of file diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs index 415c8b2e494..03607922117 100644 --- a/libs/host/Configuration/Options.cs +++ b/libs/host/Configuration/Options.cs @@ -672,6 +672,10 @@ public IEnumerable LuaAllowedFunctions [Option("cluster-replica-resume-with-data", Required = false, HelpText = "If a Cluster Replica resumes with data, allow it to be served prior to a Primary being available")] public bool ClusterReplicaResumeWithData { get; set; } + [RequiresMinimumMemory(nameof(PageSize), minimumValue: "16K")] + [Option("enable-vector-set-preview", Required = false, HelpText = "Enable Vector Sets (preview) - this feature (and associated commands) are incomplete, unstable, and subject to change while still in preview")] + public bool EnableVectorSetPreview { get; set; } + /// /// This property contains all arguments that were not parsed by the command line argument parser /// @@ -960,6 +964,7 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null) ExpiredKeyDeletionScanFrequencySecs = ExpiredKeyDeletionScanFrequencySecs, ClusterReplicationReestablishmentTimeout = ClusterReplicationReestablishmentTimeout, ClusterReplicaResumeWithData = 
ClusterReplicaResumeWithData, + EnableVectorSetPreview = EnableVectorSetPreview, }; } diff --git a/libs/host/Configuration/OptionsValidators.cs b/libs/host/Configuration/OptionsValidators.cs index 545bd9c9330..f865ff477d5 100644 --- a/libs/host/Configuration/OptionsValidators.cs +++ b/libs/host/Configuration/OptionsValidators.cs @@ -11,6 +11,7 @@ using System.Text; using System.Text.RegularExpressions; using Garnet.common; +using Garnet.server; using Microsoft.Extensions.Logging; namespace Garnet @@ -628,6 +629,53 @@ protected override ValidationResult IsValid(object value, ValidationContext vali } } + /// + /// Validate that, when annotated property is set, another option has at least a minimum memory value. + /// + [AttributeUsage(AttributeTargets.Property)] + internal sealed class RequiresMinimumMemory : OptionValidationAttribute + { + private readonly string otherOptionName; + private readonly string minimumValue; + private readonly long minimumValueBytes; + + internal RequiresMinimumMemory(string otherOptionName, string minimumValue) + { + this.otherOptionName = otherOptionName; + this.minimumValue = minimumValue; + + minimumValueBytes = GarnetServerOptions.ParseSize(this.minimumValue, out var readBytes); + if (readBytes != minimumValue.Length) + { + // If we can't parse config, disable validation + minimumValueBytes = long.MinValue; + } + } + + /// + protected override ValidationResult IsValid(object value, ValidationContext validationContext) + { + var optionIsSet = value is bool valueBool && valueBool; + if (optionIsSet) + { + var propAccessor = validationContext.ObjectInstance?.GetType()?.GetProperty(otherOptionName, BindingFlags.Instance | BindingFlags.Public); + if (propAccessor != null) + { + var otherOptionValue = propAccessor.GetValue(validationContext.ObjectInstance); + var otherOptionValueAsString = (otherOptionValue is string strVal ? 
strVal : otherOptionValue?.ToString())?.Trim(); + + var otherOptionValueBytes = GarnetServerOptions.ParseSize(otherOptionValueAsString, out var readBytes); + if (readBytes == otherOptionValueAsString.Length && otherOptionValueBytes < minimumValueBytes) + { + return new ValidationResult($"{validationContext.DisplayName} requires {otherOptionName} be at least '{minimumValue}'"); + } + } + } + + return ValidationResult.Success; + } + } + /// /// Forbids a config option from being set if the current OS platform is not supported. /// diff --git a/libs/host/GarnetServer.cs b/libs/host/GarnetServer.cs index 2d12f43a0d4..05103d39106 100644 --- a/libs/host/GarnetServer.cs +++ b/libs/host/GarnetServer.cs @@ -303,9 +303,18 @@ private GarnetDatabase CreateDatabase(int dbId, GarnetServerOptions serverOption var store = CreateMainStore(dbId, clusterFactory, out var epoch, out var stateMachineDriver); var objectStore = CreateObjectStore(dbId, clusterFactory, customCommandManager, epoch, stateMachineDriver, out var objectStoreSizeTracker); var (aofDevice, aof) = CreateAOF(dbId); + + var vectorManager = new VectorManager( + serverOptions.EnableVectorSetPreview, + dbId, + () => Provider.GetSession(WireFormat.ASCII, null), + loggerFactory + ); + return new GarnetDatabase(dbId, store, objectStore, epoch, stateMachineDriver, objectStoreSizeTracker, aofDevice, aof, serverOptions.AdjustedIndexMaxCacheLines == 0, - serverOptions.AdjustedObjectStoreIndexMaxCacheLines == 0); + serverOptions.AdjustedObjectStoreIndexMaxCacheLines == 0, + vectorManager); } private void LoadModules(CustomCommandManager customCommandManager) diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf index d2628beedb1..a4519d58816 100644 --- a/libs/host/defaults.conf +++ b/libs/host/defaults.conf @@ -453,5 +453,8 @@ "ClusterReplicationReestablishmentTimeout": 0, /* If a Cluster Replica has on disk checkpoints or AOF, if that data should be loaded on restart instead of waiting for a Primary to sync with */ - 
"ClusterReplicaResumeWithData": false + "ClusterReplicaResumeWithData": false, + + /* Enable Vector Sets (preview) - this feature (and associated commands) are incomplete, unstable, and subject to change while still in preview */ + "EnableVectorSetPreview": false } \ No newline at end of file diff --git a/libs/resources/RespCommandsDocs.json b/libs/resources/RespCommandsDocs.json index be77703a3ed..049b46fd80f 100644 --- a/libs/resources/RespCommandsDocs.json +++ b/libs/resources/RespCommandsDocs.json @@ -7719,6 +7719,204 @@ "Group": "Transactions", "Complexity": "O(1)" }, + { + "Command": "VADD", + "Name": "VADD", + "Summary": "Add a new element into the vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VCARD", + "Name": "VCARD", + "Summary": "Return the number of elements in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VDIM", + "Name": "VDIM", + "Summary": "Return the number of dimensions in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VEMB", + "Name": "VEMB", + "Summary": "Return the approximate vector associated with an element in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VGETATTR", + "Name": "VGETATTR", + "Summary": "Return the JSON attributes associated with the element in the vector set.", + 
"Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VINFO", + "Name": "VINFO", + "Summary": "Return details about a vector set, including dimensions, quantization, and structure.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VISMEMBER", + "Name": "VISMEMBER", + "Summary": "Determines whether a member belongs to vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + }, + { + "TypeDiscriminator": "RespCommandBasicArgument", + "Name": "ELEMENT", + "DisplayText": "element", + "Type": "String" + } + ] + }, + { + "Command": "VLINKS", + "Name": "VLINKS", + "Summary": "Return the neighbors of an element in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VRANDMEMBER", + "Name": "VRANDMEMBER", + "Summary": "Return some number of random elements from a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VREM", + "Name": "VREM", + "Summary": "Remove an element from a vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": 
"VSETATTR", + "Name": "VSETATTR", + "Summary": "Store attributes alongside a member of a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VSIM", + "Name": "VSIM", + "Summary": "Return elements similar to a given vector or existing element of a vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, { "Command": "WATCH", "Name": "WATCH", diff --git a/libs/resources/RespCommandsInfo.json b/libs/resources/RespCommandsInfo.json index daa1acd29d3..40aa9686505 100644 --- a/libs/resources/RespCommandsInfo.json +++ b/libs/resources/RespCommandsInfo.json @@ -811,6 +811,14 @@ "Flags": "Admin, NoMulti, NoScript", "AclCategories": "Admin, Dangerous, Slow, Garnet" }, + { + "Command": "CLUSTER_RESERVE", + "Name": "CLUSTER|RESERVE", + "IsInternal": true, + "Arity": 4, + "Flags": "Admin, NoMulti, NoScript", + "AclCategories": "Admin, Dangerous, Garnet" + }, { "Command": "CLUSTER_MTASKS", "Name": "CLUSTER|MTASKS", @@ -5093,6 +5101,306 @@ "Flags": "Fast, Loading, NoScript, Stale, AllowBusy", "AclCategories": "Fast, Transaction" }, + { + "Command": "VADD", + "Name": "VADD", + "Arity": -1, + "Flags": "DenyOom, Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Vector, Write", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Insert" + } + ] + }, + { + "Command": "VCARD", + "Name": "VCARD", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, 
Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VDIM", + "Name": "VDIM", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VEMB", + "Name": "VEMB", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VGETATTR", + "Name": "VGETATTR", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VINFO", + "Name": "VINFO", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VISMEMBER", + 
"Name": "VISMEMBER", + "Arity": 3, + "Flags": "Fast, ReadOnly", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VLINKS", + "Name": "VLINKS", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VRANDMEMBER", + "Name": "VRANDMEMBER", + "Arity": -1, + "Flags": "ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VREM", + "Name": "VREM", + "Arity": -1, + "Flags": "Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Write, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Delete" + } + ] + }, + { + "Command": "VSETATTR", + "Name": "VSETATTR", + "Arity": -1, + "Flags": "Fast, Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Write, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + 
"TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Insert" + } + ] + }, + { + "Command": "VSIM", + "Name": "VSIM", + "Arity": -1, + "Flags": "ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, { "Command": "WATCH", "Name": "WATCH", diff --git a/libs/server/ACL/ACLParser.cs b/libs/server/ACL/ACLParser.cs index 621d7a44e8e..2ee3297867c 100644 --- a/libs/server/ACL/ACLParser.cs +++ b/libs/server/ACL/ACLParser.cs @@ -33,6 +33,7 @@ class ACLParser ["stream"] = RespAclCategories.Stream, ["string"] = RespAclCategories.String, ["transaction"] = RespAclCategories.Transaction, + ["vector"] = RespAclCategories.Vector, ["write"] = RespAclCategories.Write, ["garnet"] = RespAclCategories.Garnet, ["custom"] = RespAclCategories.Custom, diff --git a/libs/server/AOF/AofProcessor.cs b/libs/server/AOF/AofProcessor.cs index 0c20a4ba6cc..1c33a4efb5e 100644 --- a/libs/server/AOF/AofProcessor.cs +++ b/libs/server/AOF/AofProcessor.cs @@ -28,6 +28,7 @@ public sealed unsafe partial class AofProcessor readonly AofReplayCoordinator aofReplayCoordinator; int activeDbId; + VectorManager activeVectorManager; /// /// Set ReadWriteSession on the cluster session (NOTE: used for replaying stored procedures only) @@ -47,11 +48,12 @@ public void SetReadWriteSession() /// BasicContext objectStoreBasicContext; - readonly StoreWrapper replayAofStoreWrapper; readonly IClusterProvider clusterProvider; readonly ILogger logger; + readonly Func obtainServerSession; + /// /// Create new AOF processor /// @@ -64,10 +66,12 @@ public AofProcessor( this.storeWrapper = storeWrapper; this.clusterProvider = clusterProvider; - replayAofStoreWrapper = new 
StoreWrapper(storeWrapper, recordToAof); + var replayAofStoreWrapper = new StoreWrapper(storeWrapper, recordToAof); + + obtainServerSession = () => new(0, networkSender: null, storeWrapper: replayAofStoreWrapper, subscribeBroker: null, authenticator: null, enableScripts: false, clusterProvider: clusterProvider); this.activeDbId = 0; - this.respServerSession = ObtainServerSession(); + this.respServerSession = obtainServerSession(); // Switch current contexts to match the default database SwitchActiveDatabaseContext(storeWrapper.DefaultDatabase, true); @@ -76,9 +80,6 @@ public AofProcessor( this.logger = logger; } - private RespServerSession ObtainServerSession() - => new(0, networkSender: null, storeWrapper: replayAofStoreWrapper, subscribeBroker: null, authenticator: null, enableScripts: false, clusterProvider: clusterProvider); - /// /// Dispose /// @@ -199,6 +200,12 @@ public unsafe void ProcessAofRecordInternal(byte* ptr, int length, bool asReplic var replayContext = aofReplayCoordinator.GetReplayContext(); isCheckpointStart = false; + // Aggressively do not move data if VADDs are being replayed + if (header.opType != AofEntryType.StoreRMW) + { + activeVectorManager.WaitForVectorOperationsToComplete(); + } + // Handle transactions if (aofReplayCoordinator.AddOrReplayTransactionOperation(ptr, length, asReplica)) return; @@ -279,6 +286,14 @@ private unsafe bool ReplayOp(TContext storeContext, TO var header = *(AofHeader*)entryPtr; var replayContext = aofReplayCoordinator.GetReplayContext(); + // StoreRMW can queue VADDs onto different threads + // but everything else needs to WAIT for those to complete + // otherwise we might lose consistency + if (header.opType != AofEntryType.StoreRMW) + { + activeVectorManager.WaitForVectorOperationsToComplete(); + } + + // Skips (1) entries with versions that were part of prior checkpoint; and (2) future entries in fuzzy region if (SkipRecord(replayContext.inFuzzyRegion, entryPtr, length, asReplica)) return false; @@ -291,10 
+306,10 @@ private unsafe bool ReplayOp(TContext storeContext, TO StoreUpsert(storeContext, replayContext.storeInput, entryPtr + sizeof(AofHeader)); break; case AofEntryType.StoreRMW: - StoreRMW(storeContext, replayContext.storeInput, entryPtr + sizeof(AofHeader)); + StoreRMW(storeContext, replayContext.storeInput, activeVectorManager, respServerSession, obtainServerSession, entryPtr + sizeof(AofHeader)); break; case AofEntryType.StoreDelete: - StoreDelete(storeContext, entryPtr + sizeof(AofHeader)); + StoreDelete(storeContext, activeVectorManager, respServerSession.storageSession, entryPtr + sizeof(AofHeader)); break; case AofEntryType.ObjectStoreRMW: ObjectStoreRMW(objectStoreContext, replayContext.objectStoreInput, entryPtr + sizeof(AofHeader), bufferPtr, bufferLength); @@ -337,6 +352,8 @@ private void SwitchActiveDatabaseContext(GarnetDatabase db, bool initialSetup = objectStoreBasicContext = objectStoreSession.BasicContext; this.activeDbId = db.Id; } + + activeVectorManager = db.VectorManager; } static void StoreUpsert( @@ -364,6 +381,9 @@ static void StoreUpsert( static void StoreRMW( TContext context, RawStringInput storeInput, + VectorManager vectorManager, + RespServerSession currentSession, + Func obtainServerSession, byte* ptr) where TContext : ITsavoriteContext { @@ -374,22 +394,52 @@ static void StoreRMW( // Reconstructing RawStringInput _ = storeInput.DeserializeFrom(curr); + // VADD requires special handling, shove it over to the VectorManager + if (storeInput.header.cmd == RespCommand.VADD) + { + vectorManager.HandleVectorSetAddReplication(currentSession.storageSession, obtainServerSession, ref key, ref storeInput); + return; + } + else + { + // Any other op (include other vector ops) need to wait for pending VADDs to complete + vectorManager.WaitForVectorOperationsToComplete(); + + // VREM is also read-like, so requires special handling - shove it over to the VectorManager + if (storeInput.header.cmd == RespCommand.VREM) + { + 
vectorManager.HandleVectorSetRemoveReplication(currentSession.storageSession, ref key, ref storeInput); + return; + } + } + var pbOutput = stackalloc byte[32]; var output = new SpanByteAndMemory(pbOutput, 32); if (context.RMW(ref key, ref storeInput, ref output).IsPending) _ = context.CompletePending(true); + if (!output.IsSpanByte) output.Memory.Dispose(); } static void StoreDelete( TContext context, + VectorManager vectorManager, + StorageSession storageSession, byte* ptr) where TContext : ITsavoriteContext { ref var key = ref Unsafe.AsRef(ptr); - _ = context.Delete(ref key); + var res = context.Delete(ref key); + + if (res.IsCanceled) + { + // Might be a vector set + res = vectorManager.TryDeleteVectorSet(storageSession, ref key, out _); + if (res.IsPending) + _ = context.CompletePending(true); + } } static void ObjectStoreUpsert( diff --git a/libs/server/API/GarnetApi.cs b/libs/server/API/GarnetApi.cs index 09d23aad563..1b5275eac30 100644 --- a/libs/server/API/GarnetApi.cs +++ b/libs/server/API/GarnetApi.cs @@ -21,9 +21,10 @@ namespace Garnet.server /// /// Garnet API implementation /// - public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi + public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext + where TVectorContext : ITsavoriteContext { readonly StorageSession storageSession; TContext context; @@ -48,8 +49,12 @@ public void WATCH(byte[] key, StoreType type) #region GET /// - public GarnetStatus GET(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output) - => storageSession.GET(ref key, ref input, ref output, ref context); + public GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) + { + var asSpanByte = key.SpanByte; + + return storageSession.GET(ref asSpanByte, ref input, ref output, ref context); + } /// public GarnetStatus GET_WithPending(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output, 
long ctx, out bool pending) @@ -68,7 +73,9 @@ public unsafe GarnetStatus GETForMemoryResult(ArgSlice key, out MemoryResult public unsafe GarnetStatus GET(ArgSlice key, out ArgSlice value) - => storageSession.GET(key, out value, ref context); + { + return storageSession.GET(key, out value, ref context); + } /// public GarnetStatus GET(byte[] key, out GarnetObjectStoreOutput value) @@ -118,33 +125,52 @@ public GarnetStatus PEXPIRETIME(ref SpanByte key, StoreType storeType, ref SpanB #endregion #region SET - /// + public GarnetStatus SET(ref SpanByte key, ref SpanByte value) - => storageSession.SET(ref key, ref value, ref context); + => storageSession.SET(ref key, ref value, ref context); /// - public GarnetStatus SET(ref SpanByte key, ref RawStringInput input, ref SpanByte value) - => storageSession.SET(ref key, ref input, ref value, ref context); + public GarnetStatus SET(ArgSlice key, ref RawStringInput input, ref SpanByte value) + { + var asSpanByte = key.SpanByte; - /// - public GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input) - => storageSession.SET_Conditional(ref key, ref input, ref context); + return storageSession.SET(ref asSpanByte, ref input, ref value, ref context); + } /// public GarnetStatus DEL_Conditional(ref SpanByte key, ref RawStringInput input) => storageSession.DEL_Conditional(ref key, ref input, ref context); /// - public GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output) - => storageSession.SET_Conditional(ref key, ref input, ref output, ref context); + public GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) + { + var asSpanByte = key.SpanByte; + + return storageSession.SET_Conditional(ref asSpanByte, ref input, ref output, ref context); + } + + /// + public GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input) + { + var asSpanByte = key.SpanByte; + + return storageSession.SET_Conditional(ref asSpanByte, 
ref input, ref context); + } /// public GarnetStatus SET(ArgSlice key, Memory value) - => storageSession.SET(key, value, ref context); + { + return storageSession.SET(key, value, ref context); + } /// public GarnetStatus SET(ArgSlice key, ArgSlice value) - => storageSession.SET(key, value, ref context); + { + var asSpanByte = key.SpanByte; + var valSpanByte = value.SpanByte; + + return storageSession.SET(ref asSpanByte, ref valSpanByte, ref context); + } /// public GarnetStatus SET(byte[] key, IGarnetObject value) @@ -302,7 +328,7 @@ public GarnetStatus DELETE(ArgSlice key, StoreType storeType = StoreType.All) /// public GarnetStatus DELETE(ref SpanByte key, StoreType storeType = StoreType.All) - => storageSession.DELETE(ref key, storeType, ref context, ref objectContext); + => storageSession.DELETE(ref key, storeType, ref context, ref objectContext); /// public GarnetStatus DELETE(byte[] key, StoreType storeType = StoreType.All) @@ -482,5 +508,41 @@ public int GetScratchBufferOffset() public bool ResetScratchBuffer(int offset) => storageSession.scratchBufferBuilder.ResetScratchBuffer(offset); #endregion + + #region VectorSet commands + + /// + public unsafe GarnetStatus VectorSetAdd(ArgSlice key, int reduceDims, VectorValueType valueType, ArgSlice values, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, VectorDistanceMetricType distanceMetric, out VectorManagerResult result, out ReadOnlySpan errorMsg) + => storageSession.VectorSetAdd(SpanByte.FromPinnedPointer(key.ptr, key.length), reduceDims, valueType, values, element, quantizer, buildExplorationFactor, attributes, numLinks, distanceMetric, out result, out errorMsg); + + /// + public unsafe GarnetStatus VectorSetRemove(ArgSlice key, ArgSlice element) + => storageSession.VectorSetRemove(SpanByte.FromPinnedPointer(key.ptr, key.length), SpanByte.FromPinnedPointer(element.ptr, element.length)); + + /// + public unsafe GarnetStatus 
VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + => storageSession.VectorSetValueSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), valueType, values, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + + /// + public unsafe GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + => storageSession.VectorSetElementSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), element.ReadOnlySpan, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + + /// + public unsafe GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, out VectorQuantType quantType, ref SpanByteAndMemory outputDistances) + => storageSession.VectorSetEmbedding(SpanByte.FromPinnedPointer(key.ptr, key.length), element.ReadOnlySpan, out quantType, ref outputDistances); + + /// + public unsafe GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions) + => storageSession.VectorSetDimensions(SpanByte.FromPinnedPointer(key.ptr, key.length), out dimensions); + + /// + public unsafe GarnetStatus VectorSetInfo(ArgSlice key, out VectorQuantType quantType, out 
VectorDistanceMetricType distanceMetricType, out uint vectorDimensions, out uint reducedDimensions, out uint buildExplorationFactor, out uint numberOfLinks, out long size) + => storageSession.VectorSetInfo(SpanByte.FromPinnedPointer(key.ptr, key.length), out quantType, out distanceMetricType, out vectorDimensions, out reducedDimensions, out buildExplorationFactor, out numberOfLinks, out size); + + /// + public unsafe GarnetStatus VectorSetGetAttribute(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputAttributes) + => storageSession.VectorSetGetAttribute(SpanByte.FromPinnedPointer(key.ptr, key.length), element, ref outputAttributes); + + #endregion } } \ No newline at end of file diff --git a/libs/server/API/GarnetApiObjectCommands.cs b/libs/server/API/GarnetApiObjectCommands.cs index b0a72473b8e..9ba483e08d7 100644 --- a/libs/server/API/GarnetApiObjectCommands.cs +++ b/libs/server/API/GarnetApiObjectCommands.cs @@ -16,9 +16,10 @@ namespace Garnet.server /// /// Garnet API implementation /// - public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi + public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext + where TVectorContext : ITsavoriteContext { #region SortedSet Methods diff --git a/libs/server/API/GarnetStatus.cs b/libs/server/API/GarnetStatus.cs index 2277461ad43..20c965d1668 100644 --- a/libs/server/API/GarnetStatus.cs +++ b/libs/server/API/GarnetStatus.cs @@ -23,6 +23,10 @@ public enum GarnetStatus : byte /// /// Wrong type /// - WRONGTYPE + WRONGTYPE, + /// + /// Bad state + /// + BADSTATE, } } \ No newline at end of file diff --git a/libs/server/API/GarnetWatchApi.cs b/libs/server/API/GarnetWatchApi.cs index ac68e97e66f..b5da7158b7c 100644 --- a/libs/server/API/GarnetWatchApi.cs +++ b/libs/server/API/GarnetWatchApi.cs @@ -23,10 +23,10 @@ public GarnetWatchApi(TGarnetApi garnetApi) #region GET /// - public GarnetStatus GET(ref SpanByte key, ref RawStringInput 
input, ref SpanByteAndMemory output) + public GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) { - garnetApi.WATCH(new ArgSlice(ref key), StoreType.Main); - return garnetApi.GET(ref key, ref input, ref output); + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.GET(key, ref input, ref output); } /// @@ -647,5 +647,50 @@ public bool ResetScratchBuffer(int offset) => garnetApi.ResetScratchBuffer(offset); #endregion + + #region Vector Sets + /// + public GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetValueSimilarity(key, valueType, value, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + } + + /// + public GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetElementSimilarity(key, element, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + } + + /// + public GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, out VectorQuantType quantType, ref SpanByteAndMemory 
outputDistances) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetEmbedding(key, element, out quantType, ref outputDistances); + } + + /// + public GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetDimensions(key, out dimensions); + } + + /// + public GarnetStatus VectorSetInfo(ArgSlice key, out VectorQuantType quantType, out VectorDistanceMetricType distanceMetricType, out uint vectorDimensions, out uint reducedDimensions, out uint buildExplorationFactor, out uint numberOfLinks, out long size) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetInfo(key, out quantType, out distanceMetricType, out vectorDimensions, out reducedDimensions, out buildExplorationFactor, out numberOfLinks, out size); + } + + /// + public GarnetStatus VectorSetGetAttribute(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputAttributes) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetGetAttribute(key, element, ref outputAttributes); + } + + #endregion } } \ No newline at end of file diff --git a/libs/server/API/IGarnetApi.cs b/libs/server/API/IGarnetApi.cs index a78ac22f556..e657001e035 100644 --- a/libs/server/API/IGarnetApi.cs +++ b/libs/server/API/IGarnetApi.cs @@ -26,17 +26,12 @@ public interface IGarnetApi : IGarnetReadApi, IGarnetAdvancedApi /// /// SET /// - GarnetStatus SET(ref SpanByte key, ref SpanByte value); - - /// - /// SET - /// - GarnetStatus SET(ref SpanByte key, ref RawStringInput input, ref SpanByte value); + GarnetStatus SET(ArgSlice key, ref RawStringInput input, ref SpanByte value); /// /// SET Conditional /// - GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input); + GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input); /// /// DEL Conditional @@ -46,7 +41,7 @@ public interface IGarnetApi : IGarnetReadApi, IGarnetAdvancedApi /// /// SET Conditional /// - GarnetStatus 
SET_Conditional(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output); + GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output); /// /// SET @@ -1206,6 +1201,18 @@ GarnetStatus GeoSearchStore(ArgSlice key, ArgSlice destinationKey, ref GeoSearch GarnetStatus HyperLogLogMerge(ref RawStringInput input, out bool error); #endregion + + #region VectorSet Methods + /// + /// Adds to (and may create) a vector set with the given parameters. + /// + GarnetStatus VectorSetAdd(ArgSlice key, int reduceDims, VectorValueType valueType, ArgSlice value, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, VectorDistanceMetricType distanceMetric, out VectorManagerResult result, out ReadOnlySpan errorMsg); + + /// + /// Remove a member from a vector set, if it is present and the key exists. + /// + GarnetStatus VectorSetRemove(ArgSlice key, ArgSlice element); + #endregion } /// @@ -1217,7 +1224,7 @@ public interface IGarnetReadApi /// /// GET /// - GarnetStatus GET(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output); + GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output); /// /// GET @@ -2026,6 +2033,47 @@ public bool IterateObjectStore(ref TScanFunctions scanFunctions, #endregion + #region Vector Sets + + /// + /// Perform a similarity search given a vector and these parameters. + /// + /// Ids are encoded in as length prefixed blobs of bytes. + /// Attributes are encoded in as length prefixed blobs of bytes. 
+ /// + GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result); + + /// + /// Perform a similarity search given an element already in the vector set and these parameters. + /// + /// Ids are encoded in as length prefixed blobs of bytes. + /// Attributes are encoded in as length prefixed blobs of bytes. + /// + GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result); + + /// + /// Fetch the embedding of a given element in a Vector set. + /// + GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, out VectorQuantType quantType, ref SpanByteAndMemory outputDistances); + + /// + /// Fetch the dimensionality of the given Vector Set. + /// + /// If the Vector Set was created with reduced dimensions, reports the reduced dimensions. + /// + GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions); + + /// + /// Fetch debugging information about the Vector Set. + /// + GarnetStatus VectorSetInfo(ArgSlice key, out VectorQuantType quantType, out VectorDistanceMetricType distanceMetricType, out uint vectorDimensions, out uint reducedDimensions, out uint buildExplorationFactor, out uint numberOfLinks, out long size); + + /// + /// Get the attributes associated with an element in the Vector Set. 
+ /// + GarnetStatus VectorSetGetAttribute(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputAttributes); + + #endregion } /// diff --git a/libs/server/ArgSlice/ArgSliceVector.cs b/libs/server/ArgSlice/ArgSliceVector.cs index 07091e1b130..26e792d4f56 100644 --- a/libs/server/ArgSlice/ArgSliceVector.cs +++ b/libs/server/ArgSlice/ArgSliceVector.cs @@ -4,6 +4,8 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; +using Tsavorite.core; namespace Garnet.server { @@ -11,13 +13,13 @@ namespace Garnet.server /// Vector of ArgSlices /// /// - public unsafe class ArgSliceVector(int maxItemNum = 1 << 18) : IEnumerable + public unsafe class ArgSliceVector(int maxItemNum = 1 << 18) : IEnumerable { ScratchBufferBuilder bufferManager = new(); readonly int maxCount = maxItemNum; public int Count => items.Count; public bool IsEmpty => items.Count == 0; - readonly List items = []; + readonly List items = []; /// /// Try to add ArgSlice @@ -29,7 +31,32 @@ public bool TryAddItem(Span item) if (Count + 1 >= maxCount) return false; - items.Add(bufferManager.CreateArgSlice(item)); + var argSlice = bufferManager.CreateArgSlice(item); + + items.Add(argSlice.SpanByte); + return true; + } + + /// + /// Try to add ArgSlice + /// + /// + /// True if it succeeds to add ArgSlice, false if maxCount has been reached. 
+ public bool TryAddItem(ulong ns, Span item) + { + Debug.Assert(ns <= byte.MaxValue, "Only byte-size namespaces supported currently"); + + if (Count + 1 >= maxCount) + return false; + + var argSlice = bufferManager.CreateArgSlice(item.Length + 1); + var sb = argSlice.SpanByte; + + sb.MarkNamespace(); + sb.SetNamespaceInPayload((byte)ns); + item.CopyTo(sb.AsSpan()); + + items.Add(sb); return true; } @@ -42,7 +69,7 @@ public void Clear() bufferManager.Reset(); } - public IEnumerator GetEnumerator() + public IEnumerator GetEnumerator() { foreach (var item in items) yield return item; diff --git a/libs/server/Cluster/ClusterSlotVerificationInput.cs b/libs/server/Cluster/ClusterSlotVerificationInput.cs index 8b673189add..9ccf6d4b315 100644 --- a/libs/server/Cluster/ClusterSlotVerificationInput.cs +++ b/libs/server/Cluster/ClusterSlotVerificationInput.cs @@ -34,5 +34,14 @@ public struct ClusterSlotVerificationInput /// Offset of key num if any /// public int keyNumOffset; + + /// + /// If the command being executed requires a slot be STABLE for executing. + /// + /// This requires special handling during migrations. + /// + /// Currently only true for Vector Set commands that are writes. + /// + public bool waitForStableSlot; } } \ No newline at end of file diff --git a/libs/server/Cluster/IClusterProvider.cs b/libs/server/Cluster/IClusterProvider.cs index 344c88c41e2..f8d854ed409 100644 --- a/libs/server/Cluster/IClusterProvider.cs +++ b/libs/server/Cluster/IClusterProvider.cs @@ -12,22 +12,33 @@ namespace Garnet.server { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; /// /// Cluster provider /// public interface IClusterProvider : IDisposable { + // TODO: I really hate having to pass Vector and Basic contexts here... 
cleanup + /// /// Create cluster session /// - IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null); + IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, BasicContext basicContext, VectorContext vectorContext, INetworkSender networkSender, ILogger logger = null); /// diff --git a/libs/server/Cluster/IClusterSession.cs b/libs/server/Cluster/IClusterSession.cs index 045d4de959b..f42e5b96490 100644 --- a/libs/server/Cluster/IClusterSession.cs +++ b/libs/server/Cluster/IClusterSession.cs @@ -62,7 +62,7 @@ public interface IClusterSession /// /// Process cluster commands /// - unsafe void ProcessClusterCommands(RespCommand command, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend); + unsafe void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend); /// /// Reset cached slot verification result @@ -77,7 +77,7 @@ public interface IClusterSession /// /// /// - bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking); + bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking, bool waitForStableSlot); /// /// Write cached slot verification message to output @@ -88,7 +88,7 @@ public interface IClusterSession /// /// Key array slot verify (write result to network) /// - unsafe bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte SessionAsking, ref byte* dcurr, ref byte* dend, int count = -1); + unsafe bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte SessionAsking, bool waitForStableSlot, ref byte* dcurr, ref byte* dend, int count = -1); /// /// Array 
slot verify (write result to network) diff --git a/libs/server/Databases/DatabaseManagerBase.cs b/libs/server/Databases/DatabaseManagerBase.cs index 2700eaa088c..29e6f9f8ec7 100644 --- a/libs/server/Databases/DatabaseManagerBase.cs +++ b/libs/server/Databases/DatabaseManagerBase.cs @@ -119,6 +119,9 @@ public abstract Task TaskCheckpointBasedOnAofSizeLimitAsync(long aofSizeLimit, /// public abstract IDatabaseManager Clone(bool enableAof); + /// + public abstract void RecoverVectorSets(); + /// public TsavoriteKV MainStore => DefaultDatabase.MainStore; @@ -414,7 +417,7 @@ protected void ExecuteObjectCollection(GarnetDatabase db, ILogger logger = null) { var scratchBufferManager = new ScratchBufferBuilder(); db.ObjectStoreCollectionDbStorageSession = - new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } ExecuteHashCollect(db.ObjectStoreCollectionDbStorageSession); @@ -722,7 +725,7 @@ private static void ExecuteSortedSetCollect(StorageSession storageSession) if (db.MainStoreExpiredKeyDeletionDbStorageSession == null) { var scratchBufferManager = new ScratchBufferBuilder(); - db.MainStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + db.MainStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } var scanFrom = StoreWrapper.store.Log.ReadOnlyAddress; @@ -738,7 +741,7 @@ private static void ExecuteSortedSetCollect(StorageSession storageSession) if (db.ObjectStoreExpiredKeyDeletionDbStorageSession == null) { var scratchBufferManager = new ScratchBufferBuilder(); - db.ObjectStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + db.ObjectStoreExpiredKeyDeletionDbStorageSession = new 
StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } var scanFrom = StoreWrapper.objectStore.Log.ReadOnlyAddress; @@ -778,7 +781,7 @@ private HybridLogScanMetrics CollectHybridLogStats>(sessionFunctions); diff --git a/libs/server/Databases/IDatabaseManager.cs b/libs/server/Databases/IDatabaseManager.cs index 54b6ebcc716..8959648f334 100644 --- a/libs/server/Databases/IDatabaseManager.cs +++ b/libs/server/Databases/IDatabaseManager.cs @@ -274,5 +274,12 @@ public Task TaskCheckpointBasedOnAofSizeLimitAsync(long aofSizeLimit, Cancellati /// /// public (HybridLogScanMetrics mainStore, HybridLogScanMetrics objectStore)[] CollectHybridLogStats(); + + /// + /// Perform any recovery necessary for Vector Sets. + /// + /// Must be called after checkpoints and AOF are recovered, as Vector Sets may make modifications to the log. + /// + public void RecoverVectorSets(); } } \ No newline at end of file diff --git a/libs/server/Databases/MultiDatabaseManager.cs b/libs/server/Databases/MultiDatabaseManager.cs index 768c9ed66d8..c4885ed42f3 100644 --- a/libs/server/Databases/MultiDatabaseManager.cs +++ b/libs/server/Databases/MultiDatabaseManager.cs @@ -147,6 +147,9 @@ public override void RecoverCheckpoint(bool replicaRecover = false, bool recover if (StoreWrapper.serverOptions.FailOnRecoveryError) throw new GarnetException("Main store and object store checkpoint versions do not match"); } + + // Once everything is setup, initialize the VectorManager + db.VectorManager.Initialize(); } } @@ -711,7 +714,7 @@ public override FunctionsState CreateFunctionsState(int dbId = 0, byte respProto throw new GarnetException($"Database with ID {dbId} was not found."); return new(db.AppendOnlyFile, db.VersionMap, StoreWrapper.customCommandManager, null, db.ObjectStoreSizeTracker, - StoreWrapper.GarnetObjectSerializer, respProtocolVersion); + StoreWrapper.GarnetObjectSerializer, db.VectorManager, respProtocolVersion); } /// @@ -1054,6 +1057,21 @@ 
private void UpdateLastSaveData(int dbId, long? storeTailAddress, long? objectSt } } + /// + public override void RecoverVectorSets() + { + var databasesMapSnapshot = databases.Map; + + var activeDbIdsMapSize = activeDbIds.ActualSize; + var activeDbIdsMapSnapshot = activeDbIds.Map; + + for (var i = 0; i < activeDbIdsMapSize; i++) + { + var dbId = activeDbIdsMapSnapshot[i]; + databasesMapSnapshot[dbId].VectorManager.ResumePostRecovery(); + } + } + public override void Dispose() { if (Disposed) return; diff --git a/libs/server/Databases/SingleDatabaseManager.cs b/libs/server/Databases/SingleDatabaseManager.cs index dc8d5f8a29b..eb3fa3375a7 100644 --- a/libs/server/Databases/SingleDatabaseManager.cs +++ b/libs/server/Databases/SingleDatabaseManager.cs @@ -107,6 +107,9 @@ public override void RecoverCheckpoint(bool replicaRecover = false, bool recover if (StoreWrapper.serverOptions.FailOnRecoveryError) throw new GarnetException("Main store and object store checkpoint versions do not match"); } + + // Once everything is setup, initialize the VectorManager + defaultDatabase.VectorManager.Initialize(); } /// @@ -384,7 +387,7 @@ public override FunctionsState CreateFunctionsState(int dbId = 0, byte respProto ArgumentOutOfRangeException.ThrowIfNotEqual(dbId, 0); return new(AppendOnlyFile, VersionMap, StoreWrapper.customCommandManager, null, ObjectStoreSizeTracker, - StoreWrapper.GarnetObjectSerializer, respProtocolVersion); + StoreWrapper.GarnetObjectSerializer, DefaultDatabase.VectorManager, respProtocolVersion); } private async Task TryPauseCheckpointsContinuousAsync(int dbId, @@ -430,6 +433,12 @@ private void SafeTruncateAOF(AofEntryType entryType, bool unsafeTruncateLog) } } + /// + public override void RecoverVectorSets() + { + defaultDatabase.VectorManager.ResumePostRecovery(); + } + public override void Dispose() { if (Disposed) return; diff --git a/libs/server/Garnet.server.csproj b/libs/server/Garnet.server.csproj index 2c351e80f45..dc679f37e8f 100644 --- 
a/libs/server/Garnet.server.csproj +++ b/libs/server/Garnet.server.csproj @@ -22,6 +22,7 @@ + \ No newline at end of file diff --git a/libs/server/GarnetDatabase.cs b/libs/server/GarnetDatabase.cs index 41eb4784f6d..ef3788c7e85 100644 --- a/libs/server/GarnetDatabase.cs +++ b/libs/server/GarnetDatabase.cs @@ -100,6 +100,14 @@ public class GarnetDatabase : IDisposable /// public SingleWriterMultiReaderLock CheckpointingLock; + /// + /// Per-DB VectorManager + /// + /// Contexts, metadata, and associated namespaces are DB-specific, and meaningless + /// outside of the container DB. + /// + public readonly VectorManager VectorManager; + /// /// Storage session intended for store-wide object collection operations /// @@ -124,7 +132,7 @@ public GarnetDatabase(int id, TsavoriteKV objectStore, LightEpoch epoch, StateMachineDriver stateMachineDriver, CacheSizeTracker objectStoreSizeTracker, IDevice aofDevice, TsavoriteLog appendOnlyFile, - bool mainStoreIndexMaxedOut, bool objectStoreIndexMaxedOut) : this() + bool mainStoreIndexMaxedOut, bool objectStoreIndexMaxedOut, VectorManager vectorManager) : this() { Id = id; MainStore = mainStore; @@ -136,6 +144,7 @@ public GarnetDatabase(int id, TsavoriteKV + /// Header for Garnet Main Store inputs but for Vector element r/w/d ops + /// + public struct VectorInput : IStoreInput + { + public int SerializedLength => throw new NotImplementedException(); + + public int ReadDesiredSize { get; set; } + + public int WriteDesiredSize { get; set; } + + public int Index { get; set; } + public nint CallbackContext { get; set; } + public nint Callback { get; set; } + + public VectorInput() + { + } + + public unsafe int CopyTo(byte* dest, int length) => throw new NotImplementedException(); + public unsafe int DeserializeFrom(byte* src) => throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/libs/server/Resp/AdminCommands.cs b/libs/server/Resp/AdminCommands.cs index 73851314355..fa134a1498f 100644 --- 
a/libs/server/Resp/AdminCommands.cs +++ b/libs/server/Resp/AdminCommands.cs @@ -703,7 +703,7 @@ private bool NetworkProcessClusterCommand(RespCommand command) return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_CLUSTER_DISABLED); } - clusterSession.ProcessClusterCommands(command, ref parseState, ref dcurr, ref dend); + clusterSession.ProcessClusterCommands(command, storageSession.vectorManager, ref parseState, ref dcurr, ref dend); return true; } diff --git a/libs/server/Resp/BasicCommands.cs b/libs/server/Resp/BasicCommands.cs index 6cc37408b4a..2755fcf5e74 100644 --- a/libs/server/Resp/BasicCommands.cs +++ b/libs/server/Resp/BasicCommands.cs @@ -29,14 +29,17 @@ bool NetworkGET(ref TGarnetApi storageApi) if (useAsync) return NetworkGETAsync(ref storageApi); - RawStringInput input = default; + RawStringInput input = new(RespCommand.GET, arg1: -1); - var key = parseState.GetArgSliceByRef(0).SpanByte; + ref var key = ref parseState.GetArgSliceByRef(0); var o = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref o); + var status = storageApi.GET(key, ref input, ref o); switch (status) { + case GarnetStatus.WRONGTYPE: + WriteError(CmdStrings.RESP_ERR_WRONG_TYPE); + break; case GarnetStatus.OK: if (!o.IsSpanByte) SendAndReset(o.Memory, o.Length); @@ -175,7 +178,7 @@ bool NetworkGET_SG(ref TGarnetApi storageApi) where TGarnetApi : IGarnetAdvancedApi { var key = parseState.GetArgSliceByRef(0).SpanByte; - RawStringInput input = default; + RawStringInput input = new(RespCommand.GET, arg1: -1); var firstPending = -1; (GarnetStatus, SpanByteAndMemory)[] outputArr = null; SpanByteAndMemory o = new(dcurr, (int)(dend - dcurr)); @@ -278,10 +281,10 @@ private bool NetworkSET(ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; - var value = parseState.GetArgSliceByRef(1).SpanByte; + var key = parseState.GetArgSliceByRef(0); + var 
value = parseState.GetArgSliceByRef(1); - storageApi.SET(ref key, ref value); + storageApi.SET(key, value); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); @@ -296,9 +299,9 @@ private bool NetworkGETSET(ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); - return NetworkSET_Conditional(RespCommand.SET, 0, ref key, true, + return NetworkSET_Conditional(RespCommand.SET, 0, key, true, false, false, ref storageApi); } @@ -377,7 +380,7 @@ private bool NetworkGetRange(ref TGarnetApi storageApi) private bool NetworkSETEX(bool highPrecision, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); // Validate expiry if (!parseState.TryGetInt(1, out var expiry)) @@ -398,7 +401,7 @@ private bool NetworkSETEX(bool highPrecision, ref TGarnetApi storage var sbVal = parseState.GetArgSliceByRef(2).SpanByte; var input = new RawStringInput(RespCommand.SETEX, 0, valMetadata); - _ = storageApi.SET(ref key, ref input, ref sbVal); + _ = storageApi.SET(key, ref input, ref sbVal); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); @@ -418,10 +421,9 @@ private bool NetworkSETNX(bool highPrecision, ref TGarnetApi storage } var key = parseState.GetArgSliceByRef(0); - var sbKey = key.SpanByte; var input = new RawStringInput(RespCommand.SETEXNX, ref parseState, startIdx: 1); - var status = storageApi.SET_Conditional(ref sbKey, ref input); + var status = storageApi.SET_Conditional(key, ref input); // The status returned for SETNX as NOTFOUND is the expected status in the happy path var retVal = status == GarnetStatus.NOTFOUND ? 1 : 0; @@ -573,14 +575,14 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) { case ExistOptions.None: return getValue || withEtag - ? 
NetworkSET_Conditional(RespCommand.SET, expiry, ref sbKey, getValue, + ? NetworkSET_Conditional(RespCommand.SET, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi) - : NetworkSET_EX(RespCommand.SET, expOption, expiry, ref sbKey, ref sbVal, ref storageApi); // Can perform a blind update + : NetworkSET_EX(RespCommand.SET, expOption, expiry, key, ref sbVal, ref storageApi); // Can perform a blind update case ExistOptions.XX: - return NetworkSET_Conditional(RespCommand.SETEXXX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXXX, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi); case ExistOptions.NX: - return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi); } break; @@ -590,13 +592,13 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) { case ExistOptions.None: // We can never perform a blind update due to KEEPTTL - return NetworkSET_Conditional(RespCommand.SETKEEPTTL, expiry, ref sbKey + return NetworkSET_Conditional(RespCommand.SETKEEPTTL, expiry, key , getValue, highPrecision: false, withEtag, ref storageApi); case ExistOptions.XX: - return NetworkSET_Conditional(RespCommand.SETKEEPTTLXX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETKEEPTTLXX, expiry, key, getValue, highPrecision: false, withEtag, ref storageApi); case ExistOptions.NX: - return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, key, getValue, highPrecision: false, withEtag, ref storageApi); } break; @@ -608,7 +610,7 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) } private unsafe bool NetworkSET_EX(RespCommand cmd, ExpirationOption expOption, int expiry, - ref SpanByte key, ref SpanByte val, ref TGarnetApi storageApi) + ArgSlice key, ref SpanByte val, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { 
Debug.Assert(cmd == RespCommand.SET); @@ -621,14 +623,14 @@ private unsafe bool NetworkSET_EX(RespCommand cmd, ExpirationOption var input = new RawStringInput(cmd, 0, valMetadata); - storageApi.SET(ref key, ref input, ref val); + storageApi.SET(key, ref input, ref val); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); return true; } - private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref SpanByte key, bool getValue, bool highPrecision, bool withEtag, ref TGarnetApi storageApi) + private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ArgSlice key, bool getValue, bool highPrecision, bool withEtag, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { var inputArg = expiry == 0 @@ -645,7 +647,7 @@ private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref // the following debug assertion is the catch any edge case leading to SETIFMATCH, or SETIFGREATER skipping the above block Debug.Assert(cmd is not (RespCommand.SETIFMATCH or RespCommand.SETIFGREATER), "SETIFMATCH should have gone though pointing to right output variable"); - var status = storageApi.SET_Conditional(ref key, ref input); + var status = storageApi.SET_Conditional(key, ref input); // KEEPTTL without flags doesn't care whether it was found or not. if (cmd == RespCommand.SETKEEPTTL) @@ -684,7 +686,7 @@ private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref // anything with getValue or withEtag always writes to the buffer in the happy path SpanByteAndMemory outputBuffer = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - GarnetStatus status = storageApi.SET_Conditional(ref key, ref input, ref outputBuffer); + GarnetStatus status = storageApi.SET_Conditional(key, ref input, ref outputBuffer); // The data will be on the buffer either when we know the response is ok or when the withEtag flag is set. 
bool ok = status != GarnetStatus.NOTFOUND || withEtag; diff --git a/libs/server/Resp/BasicEtagCommands.cs b/libs/server/Resp/BasicEtagCommands.cs index 59ef098eaa7..2fee440918d 100644 --- a/libs/server/Resp/BasicEtagCommands.cs +++ b/libs/server/Resp/BasicEtagCommands.cs @@ -22,10 +22,10 @@ private bool NetworkGETWITHETAG(ref TGarnetApi storageApi) { Debug.Assert(parseState.Count == 1); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); var input = new RawStringInput(RespCommand.GETWITHETAG); var output = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref output); + var status = storageApi.GET(key, ref input, ref output); switch (status) { @@ -53,10 +53,10 @@ private bool NetworkGETIFNOTMATCH(ref TGarnetApi storageApi) { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); var input = new RawStringInput(RespCommand.GETIFNOTMATCH, ref parseState, startIdx: 1); var output = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref output); + var status = storageApi.GET(key, ref input, ref output); switch (status) { @@ -213,9 +213,9 @@ private bool NetworkSetETagConditional(RespCommand cmd, ref TGarnetA return true; } - SpanByte key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); - NetworkSET_Conditional(cmd, expiry, ref key, getValue: !noGet, highPrecision: expOption == ExpirationOption.PX, withEtag: true, ref storageApi); + NetworkSET_Conditional(cmd, expiry, key, getValue: !noGet, highPrecision: expOption == ExpirationOption.PX, withEtag: true, ref storageApi); return true; } diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index cd3263aa808..e8c5ba5fb9e 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -440,6 +440,7 @@ static partial class 
CmdStrings public static ReadOnlySpan publish => "PUBLISH"u8; public static ReadOnlySpan spublish => "SPUBLISH"u8; public static ReadOnlySpan mtasks => "MTASKS"u8; + public static ReadOnlySpan reserve => "RESERVE"u8; public static ReadOnlySpan aofsync => "AOFSYNC"u8; public static ReadOnlySpan appendlog => "APPENDLOG"u8; public static ReadOnlySpan attach_sync => "ATTACH_SYNC"u8; diff --git a/libs/server/Resp/GarnetDatabaseSession.cs b/libs/server/Resp/GarnetDatabaseSession.cs index 0e52d40d9c1..1eed9e96553 100644 --- a/libs/server/Resp/GarnetDatabaseSession.cs +++ b/libs/server/Resp/GarnetDatabaseSession.cs @@ -8,13 +8,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; /// /// Represents a logical database session in Garnet diff --git a/libs/server/Resp/KeyAdminCommands.cs b/libs/server/Resp/KeyAdminCommands.cs index 812617a3a57..1e9e18efefe 100644 --- a/libs/server/Resp/KeyAdminCommands.cs +++ b/libs/server/Resp/KeyAdminCommands.cs @@ -99,8 +99,6 @@ bool NetworkRESTORE(ref TGarnetApi storageApi) var valArgSlice = scratchBufferBuilder.CreateArgSlice(val); - var sbKey = key.SpanByte; - parseState.InitializeWithArgument(valArgSlice); RawStringInput input; @@ -114,7 +112,7 @@ bool NetworkRESTORE(ref TGarnetApi storageApi) input = new RawStringInput(RespCommand.SETEXNX, ref parseState); } - var status = storageApi.SET_Conditional(ref sbKey, ref input); + var status = storageApi.SET_Conditional(key, ref input); if (status is GarnetStatus.NOTFOUND) { diff --git a/libs/server/Resp/LocalServerSession.cs b/libs/server/Resp/LocalServerSession.cs index b3283504041..3bf4a4ca1c5 100644 --- a/libs/server/Resp/LocalServerSession.cs +++ b/libs/server/Resp/LocalServerSession.cs @@ -2,6 +2,7 @@ // Licensed under the MIT 
license. using System; +using System.Diagnostics; using Microsoft.Extensions.Logging; using Tsavorite.core; @@ -12,7 +13,10 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; /// /// Local server session @@ -47,8 +51,11 @@ public LocalServerSession(StoreWrapper storeWrapper) // Initialize session-local scratch buffer of size 64 bytes, used for constructing arguments in GarnetApi this.scratchBufferBuilder = new ScratchBufferBuilder(); + var dbRes = storeWrapper.TryGetOrAddDatabase(0, out var database, out _); + Debug.Assert(dbRes, "Should always be able to get DB 0"); + // Create storage session and API - this.storageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId: 0, logger); + this.storageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId: 0, database.VectorManager, logger); this.BasicGarnetApi = new BasicGarnetApi(storageSession, storageSession.basicContext, storageSession.objectStoreBasicContext); } diff --git a/libs/server/Resp/MGetReadArgBatch.cs b/libs/server/Resp/MGetReadArgBatch.cs index 899113d5dfa..77bcfc36006 100644 --- a/libs/server/Resp/MGetReadArgBatch.cs +++ b/libs/server/Resp/MGetReadArgBatch.cs @@ -44,7 +44,7 @@ public readonly int Count /// public readonly void GetInput(int i, out RawStringInput input) - => input = default; + => input = new(RespCommand.GET, arg1: -1); /// public readonly void GetKey(int i, out SpanByte key) @@ -132,10 +132,10 @@ private readonly bool HasGoneAsync /// public readonly void GetInput(int i, out RawStringInput input) { - input = default; - // Save the index so we can order async completions correctly in the response - input.arg1 = i; + // + // Use a - so we get "include RESP protocol"-behavior + input = new(RespCommand.GET, arg1: -(i + 1)); } /// @@ -277,7 +277,7 @@ public readonly unsafe void CompletePending(ref 
TGarnetApi storageAp while (iter.Next()) { - var rawIndex = (int)iter.Current.Input.arg1; + var rawIndex = -(int)iter.Current.Input.arg1 - 1; var shiftedIndex = rawIndex - asyncOffset; var asyncStatus = iter.Current.Status; diff --git a/libs/server/Resp/Parser/ParseUtils.cs b/libs/server/Resp/Parser/ParseUtils.cs index 14d6e0f5edc..02e9a2c41ca 100644 --- a/libs/server/Resp/Parser/ParseUtils.cs +++ b/libs/server/Resp/Parser/ParseUtils.cs @@ -130,6 +130,44 @@ public static bool TryReadDouble(ref ArgSlice slice, out double number, bool can return canBeInfinite && RespReadUtils.TryReadInfinity(sbNumber, out number); } + /// + /// Read a signed 32-bit float from a given ArgSlice. + /// + /// Source + /// Allow reading an infinity + /// + /// Parsed double + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float ReadFloat(ref ArgSlice slice, bool canBeInfinite) + { + if (!TryReadFloat(ref slice, out var number, canBeInfinite)) + { + RespParsingException.ThrowNotANumber(slice.ptr, slice.length); + } + return number; + } + + /// + /// Try to read a signed 32-bit float from a given ArgSlice. + /// + /// Source + /// Result + /// Allow reading an infinity + /// + /// True if float parsed successfully + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryReadFloat(ref ArgSlice slice, out float number, bool canBeInfinite) + { + var sbNumber = slice.ReadOnlySpan; + if (Utf8Parser.TryParse(sbNumber, out number, out var bytesConsumed) && + bytesConsumed == sbNumber.Length) + return true; + + return canBeInfinite && RespReadUtils.TryReadInfinity(sbNumber, out number); + } + /// /// Read an ASCII string from a given ArgSlice. 
/// diff --git a/libs/server/Resp/Parser/RespCommand.cs b/libs/server/Resp/Parser/RespCommand.cs index cc81121b1df..3ed6ce1554a 100644 --- a/libs/server/Resp/Parser/RespCommand.cs +++ b/libs/server/Resp/Parser/RespCommand.cs @@ -81,6 +81,15 @@ public enum RespCommand : ushort SUNION, TTL, TYPE, + VCARD, + VDIM, + VEMB, + VGETATTR, + VINFO, + VISMEMBER, + VLINKS, + VRANDMEMBER, + VSIM, WATCH, WATCHMS, WATCHOS, @@ -195,6 +204,9 @@ public enum RespCommand : ushort SUNIONSTORE, SWAPDB, UNLINK, + VADD, + VREM, + VSETATTR, ZADD, ZCOLLECT, ZDIFFSTORE, @@ -374,6 +386,7 @@ public enum RespCommand : ushort CLUSTER_SPUBLISH, CLUSTER_REPLICAS, CLUSTER_REPLICATE, + CLUSTER_RESERVE, CLUSTER_RESET, CLUSTER_SEND_CKPT_FILE_SEGMENT, CLUSTER_SEND_CKPT_METADATA, @@ -627,6 +640,12 @@ public static bool IsClusterSubCommand(this RespCommand cmd) bool inRange = test <= (RespCommand.CLUSTER_SYNC - RespCommand.CLUSTER_ADDSLOTS); return inRange; } + + /// + /// Returns true if this command can operate on a Vector Set. 
+ /// + public static bool IsLegalOnVectorSet(this RespCommand cmd) + => cmd is RespCommand.DEL or server.RespCommand.UNLINK or RespCommand.TYPE or RespCommand.DEBUG or RespCommand.RENAME or RespCommand.RENAMENX or RespCommand.VADD or RespCommand.VCARD or RespCommand.VDIM or RespCommand.VEMB or RespCommand.VGETATTR or RespCommand.VINFO or server.RespCommand.VISMEMBER or RespCommand.VLINKS or RespCommand.VRANDMEMBER or RespCommand.VREM or RespCommand.VSETATTR or RespCommand.VSIM; } /// @@ -961,6 +980,29 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVADD\r\n"u8)) + { + return RespCommand.VADD; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVDIM\r\n"u8)) + { + return RespCommand.VDIM; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVEMB\r\n"u8)) + { + return RespCommand.VEMB; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVREM\r\n"u8)) + { + return RespCommand.VREM; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVSIM\r\n"u8)) + { + return RespCommand.VSIM; + } + break; + case 'Z': if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nZADD\r\n"u8)) { @@ -1141,6 +1183,17 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nVCARD\r\n"u8)) + { + return RespCommand.VCARD; + } + else if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nVINFO\r\n"u8)) + { + return RespCommand.VINFO; + } + break; + case 'W': if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nWATCH\r\n"u8)) { @@ -1335,6 +1388,13 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VLINKS\r\n"u8)) + { + return RespCommand.VLINKS; + } + break; + case 'Z': if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("ZCOUNT\r\n"u8)) { @@ -1510,6 +1570,14 @@ private RespCommand 
FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.SPUBLISH; } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VGETATTR"u8) && *(ushort*)(ptr + 12) == MemoryMarshal.Read("\r\n"u8)) + { + return RespCommand.VGETATTR; + } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VSETATTR"u8) && *(ushort*)(ptr + 12) == MemoryMarshal.Read("\r\n"u8)) + { + return RespCommand.VSETATTR; + } break; case 9: if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("SUBSCRIB"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("BE\r\n"u8)) @@ -1548,6 +1616,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.ZEXPIREAT; } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VISMEMBE"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("ER\r\n"u8)) + { + return RespCommand.VISMEMBER; + } break; case 10: if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("SSUBSCRI"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("BE\r\n"u8)) @@ -1684,6 +1756,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.ZEXPIRETIME; } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("1\r\nVRAND"u8) && *(ulong*)(ptr + 10) == MemoryMarshal.Read("MEMBER\r\n"u8)) + { + return RespCommand.VRANDMEMBER; + } break; case 12: @@ -2201,6 +2277,10 @@ private RespCommand SlowParseCommand(ReadOnlySpan command, ref int count, { return RespCommand.CLUSTER_MIGRATE; } + else if (subCommand.SequenceEqual(CmdStrings.reserve)) + { + return RespCommand.CLUSTER_RESERVE; + } else if (subCommand.SequenceEqual(CmdStrings.mtasks)) { return RespCommand.CLUSTER_MTASKS; diff --git a/libs/server/Resp/Parser/SessionParseState.cs b/libs/server/Resp/Parser/SessionParseState.cs index e0e523c7ea2..358b37b14fc 100644 --- a/libs/server/Resp/Parser/SessionParseState.cs +++ b/libs/server/Resp/Parser/SessionParseState.cs @@ -163,18 +163,19 @@ public void InitializeWithArguments(ArgSlice arg1, ArgSlice arg2, ArgSlice arg3, } /// - /// Initialize the 
parse state with a given set of arguments + /// Expand (if necessary) capacity of , preserving contents. /// - /// Set of arguments to initialize buffer with - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void InitializeWithArguments(ArgSlice[] args) + public void EnsureCapacity(int count) { - Initialize(args.Length); - - for (var i = 0; i < args.Length; i++) + if (count <= Count) { - *(bufferPtr + i) = args[i]; + return; } + + var oldBuffer = rootBuffer; + Initialize(count); + + oldBuffer?.AsSpan().CopyTo(rootBuffer); } /// @@ -432,6 +433,28 @@ public bool TryGetDouble(int i, out double value, bool canBeInfinite = true) return ParseUtils.TryReadDouble(ref Unsafe.AsRef(bufferPtr + i), out value, canBeInfinite); } + /// + /// Get float argument at the given index + /// + /// True if double parsed successfully + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float GetFloat(int i, bool canBeInfinite = true) + { + Debug.Assert(i < Count); + return ParseUtils.ReadFloat(ref Unsafe.AsRef(bufferPtr + i), canBeInfinite); + } + + /// + /// Try to get double argument at the given index + /// + /// True if double parsed successfully + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryGetFloat(int i, out float value, bool canBeInfinite = true) + { + Debug.Assert(i < Count); + return ParseUtils.TryReadFloat(ref Unsafe.AsRef(bufferPtr + i), out value, canBeInfinite); + } + /// /// Get ASCII string argument at the given index /// diff --git a/libs/server/Resp/RespCommandDocs.cs b/libs/server/Resp/RespCommandDocs.cs index f6adceaecf0..b58578f7371 100644 --- a/libs/server/Resp/RespCommandDocs.cs +++ b/libs/server/Resp/RespCommandDocs.cs @@ -330,6 +330,8 @@ public enum RespCommandGroup : byte String, [Description("transactions")] Transactions, + [Description("vector")] + Vector } /// diff --git a/libs/server/Resp/RespCommandInfoFlags.cs b/libs/server/Resp/RespCommandInfoFlags.cs index e4f391a8613..bfe03845bf7 100644 --- 
a/libs/server/Resp/RespCommandInfoFlags.cs +++ b/libs/server/Resp/RespCommandInfoFlags.cs @@ -55,6 +55,8 @@ public enum RespCommandFlags Write = 1 << 19, [Description("allow_busy")] AllowBusy = 1 << 20, + [Description("module")] + Module = 1 << 21, } /// @@ -110,6 +112,8 @@ public enum RespAclCategories Garnet = 1 << 21, [Description("custom")] Custom = 1 << 22, + [Description("vector")] + Vector = 1 << 23, [Description("all")] All = (Custom << 1) - 1, } diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index d64275383af..abcefb4f34d 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -25,13 +25,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; /// /// RESP server session @@ -220,6 +226,8 @@ public IGarnetServer Server // Threshold for slow log in ticks (0 means disabled) readonly long slowLogThreshold; + private readonly long maximumVectorSetValueBytes; + /// /// Create a new RESP server session /// @@ -282,7 +290,8 @@ public RespServerSession( this.AuthenticateUser(Encoding.ASCII.GetBytes(this.storeWrapper.accessControlList.GetDefaultUserHandle().User.Name)); var cp = clusterProvider ?? 
storeWrapper.clusterProvider; - clusterSession = cp?.CreateClusterSession(txnManager, this._authenticator, this._userHandle, sessionMetrics, basicGarnetApi, networkSender, logger); + + clusterSession = cp?.CreateClusterSession(txnManager, this._authenticator, this._userHandle, sessionMetrics, basicGarnetApi, storageSession.basicContext, storageSession.vectorContext, networkSender, logger); clusterSession?.SetUserHandle(this._userHandle); sessionScriptCache?.SetUserHandle(this._userHandle); @@ -299,6 +308,8 @@ public RespServerSession( if (this.networkSender.GetMaxSizeSettings?.MaxOutputSize < sizeof(int)) this.networkSender.GetMaxSizeSettings.MaxOutputSize = sizeof(int); } + + maximumVectorSetValueBytes = GarnetServerOptions.ParseSize(storeWrapper.serverOptions.PageSize, out _) - 16; // Just assume header is 16-ish bytes for now } /// @@ -945,6 +956,20 @@ private bool ProcessArrayCommands(RespCommand cmd, ref TGarnetApi st RespCommand.SUNIONSTORE => SetUnionStore(ref storageApi), RespCommand.SDIFF => SetDiff(ref storageApi), RespCommand.SDIFFSTORE => SetDiffStore(ref storageApi), + // Vector Commands + RespCommand.VADD => NetworkVADD(ref storageApi), + RespCommand.VCARD => NetworkVCARD(ref storageApi), + RespCommand.VDIM => NetworkVDIM(ref storageApi), + RespCommand.VEMB => NetworkVEMB(ref storageApi), + RespCommand.VGETATTR => NetworkVGETATTR(ref storageApi), + RespCommand.VINFO => NetworkVINFO(ref storageApi), + RespCommand.VISMEMBER => NetworkVISMEMBER(ref storageApi), + RespCommand.VLINKS => NetworkVLINKS(ref storageApi), + RespCommand.VRANDMEMBER => NetworkVRANDMEMBER(ref storageApi), + RespCommand.VREM => NetworkVREM(ref storageApi), + RespCommand.VSETATTR => NetworkVSETATTR(ref storageApi), + RespCommand.VSIM => NetworkVSIM(ref storageApi), + // Everything else _ => ProcessOtherCommands(cmd, ref storageApi) }; return success; @@ -1331,7 +1356,7 @@ private void Send(byte* d) if ((int)(dcurr - d) > 0) { - // Debug.WriteLine("SEND: [" + 
Encoding.UTF8.GetString(new Span(d, (int)(dcurr - d))).Replace("\n", "|").Replace("\r", "!") + "]"); + //Debug.WriteLine("SEND: [" + Encoding.UTF8.GetString(new Span(d, (int)(dcurr - d))).Replace("\n", "|").Replace("\r", "!") + "]"); if (waitForAofBlocking) { var task = storeWrapper.WaitForCommitAsync(); @@ -1495,7 +1520,10 @@ private GarnetDatabaseSession TryGetOrSetDatabaseSession(int dbId, out bool succ /// New database session private GarnetDatabaseSession CreateDatabaseSession(int dbId) { - var dbStorageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId, logger, respProtocolVersion); + var dbRes = storeWrapper.TryGetOrAddDatabase(dbId, out var database, out _); + Debug.Assert(dbRes, "Should always find database if we're switching to it"); + + var dbStorageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId, database.VectorManager, logger, respProtocolVersion); var dbGarnetApi = new BasicGarnetApi(dbStorageSession, dbStorageSession.basicContext, dbStorageSession.objectStoreBasicContext); var dbLockableGarnetApi = new LockableGarnetApi(dbStorageSession, dbStorageSession.lockableContext, dbStorageSession.objectStoreLockableContext); diff --git a/libs/server/Resp/RespServerSessionSlotVerify.cs b/libs/server/Resp/RespServerSessionSlotVerify.cs index cc278a4e359..bb1f6326164 100644 --- a/libs/server/Resp/RespServerSessionSlotVerify.cs +++ b/libs/server/Resp/RespServerSessionSlotVerify.cs @@ -17,9 +17,10 @@ internal sealed unsafe partial class RespServerSession : ServerSessionBase /// Array of key ArgSlice /// Whether caller is going to perform a readonly or read/write operation /// Key count if different than keys array length + /// Whether the executing command requires the containing slot be STABLE. 
/// True when ownership is verified, false otherwise - bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, int count = -1) - => clusterSession != null && clusterSession.NetworkKeyArraySlotVerify(keys, readOnly, SessionAsking, ref dcurr, ref dend, count); + bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, bool waitForStableSlot, int count = -1) + => clusterSession != null && clusterSession.NetworkKeyArraySlotVerify(keys, readOnly, SessionAsking, waitForStableSlot, ref dcurr, ref dend, count); /// /// Validate if this command can be served based on the current slot assignment @@ -48,6 +49,7 @@ bool CanServeSlot(RespCommand cmd) storeWrapper.clusterProvider.ExtractKeySpecs(commandInfo, cmd, ref parseState, ref csvi); csvi.readOnly = cmd.IsReadOnly(); csvi.sessionAsking = SessionAsking; + csvi.waitForStableSlot = cmd is RespCommand.VADD or RespCommand.VREM or RespCommand.VSETATTR; return !clusterSession.NetworkMultiKeySlotVerify(ref parseState, ref csvi, ref dcurr, ref dend); } } diff --git a/libs/server/Resp/Vector/DiskANNService.cs b/libs/server/Resp/Vector/DiskANNService.cs new file mode 100644 index 00000000000..5e02c9fa404 --- /dev/null +++ b/libs/server/Resp/Vector/DiskANNService.cs @@ -0,0 +1,416 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + internal sealed unsafe class DiskANNService + { + // Term types. 
+ internal const byte FullVector = 0; + private const byte NeighborList = 1; + private const byte QuantizedVector = 2; + internal const byte Attributes = 3; + private const byte Metadata = 4; + internal const byte InternalIdMap = 5; + private const byte ExternalIdMap = 6; + + public nint CreateIndex( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + VectorDistanceMetricType distanceMetric, + delegate* unmanaged[Cdecl] readCallback, + delegate* unmanaged[Cdecl] writeCallback, + delegate* unmanaged[Cdecl] deleteCallback, + delegate* unmanaged[Cdecl] readModifyWriteCallback + ) + { + // TODO: actually pass distance metric + + unsafe + { + return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + } + } + + public nint RecreateIndex( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + VectorDistanceMetricType distanceMetricType, + delegate* unmanaged[Cdecl] readCallback, + delegate* unmanaged[Cdecl] writeCallback, + delegate* unmanaged[Cdecl] deleteCallback, + delegate* unmanaged[Cdecl] readModifyWriteCallback + ) + => CreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetricType, readCallback, writeCallback, deleteCallback, readModifyWriteCallback); + + public void DropIndex(ulong context, nint index) + { + NativeDiskANNMethods.drop_index(context, index); + } + + public bool Insert(ulong context, nint index, ReadOnlySpan id, VectorValueType vectorType, ReadOnlySpan vector, ReadOnlySpan attributes) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); + int vector_len; + + if 
(vectorType == VectorValueType.FP32) + { + vector_len = vector.Length / sizeof(float); + } + else if (vectorType == VectorValueType.XB8) + { + vector_len = vector.Length; + } + else + { + throw new NotImplementedException($"{vectorType}"); + } + + var attributes_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(attributes)); + var attributes_len = attributes.Length; + + return NativeDiskANNMethods.insert(context, index, (nint)id_data, (nuint)id_len, vectorType, (nint)vector_data, (nuint)vector_len, (nint)attributes_data, (nuint)attributes_len) == 1; + } + + public bool Remove(ulong context, nint index, ReadOnlySpan id) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + return NativeDiskANNMethods.remove(context, index, (nint)id_data, (nuint)id_len) == 1; + } + + public int SearchVector( + ulong context, + nint index, + VectorValueType vectorType, + ReadOnlySpan vector, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + SpanByteAndMemory outputIds, + SpanByteAndMemory outputDistances, + out nint continuation + ) + { + var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); + int vector_len; + + if (vectorType == VectorValueType.FP32) + { + vector_len = vector.Length / sizeof(float); + } + else if (vectorType == VectorValueType.XB8) + { + vector_len = vector.Length; + } + else + { + throw new NotImplementedException($"{vectorType}"); + } + + var filter_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var filter_len = filter.Length; + + void* output_ids; + void* output_distances; + + GCHandle? outputIdsHandle = null; + GCHandle? 
outputDistancesHandle = null; + try + { + if (!outputIds.IsSpanByte) + { + var getRes = MemoryMarshal.TryGetArray(outputIds.Memory.Memory, out var arrSeg); + Debug.Assert(getRes, "Should always be able to get array to pin"); + + outputIdsHandle = GCHandle.Alloc(arrSeg.Array, GCHandleType.Pinned); + output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(arrSeg.Array)); + } + else + { + outputIdsHandle = null; + output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds.AsSpan())); + } + + var output_ids_len = outputIds.Length; + + if (!outputDistances.IsSpanByte) + { + var getRes = MemoryMarshal.TryGetArray(outputDistances.Memory.Memory, out var arrSeg); + Debug.Assert(getRes, "Should always be able to get array to pin"); + + outputDistancesHandle = GCHandle.Alloc(arrSeg.Array, GCHandleType.Pinned); + output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(arrSeg.Array)); + } + else + { + outputDistancesHandle = null; + output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances.AsSpan())); + } + + var output_distances_len = outputDistances.Length / sizeof(float); + + continuation = 0; + ref var continuationRef = ref continuation; + var continuationAddr = (nint)Unsafe.AsPointer(ref continuationRef); + + return NativeDiskANNMethods.search_vector( + context, + index, + vectorType, + (nint)vector_data, + (nuint)vector_len, + delta, + searchExplorationFactor, + (nint)filter_data, + (nuint)filter_len, + (nuint)maxFilteringEffort, + (nint)output_ids, + (nuint)output_ids_len, + (nint)output_distances, + (nuint)output_distances_len, + continuationAddr + ); + } + finally + { + outputIdsHandle?.Free(); + outputDistancesHandle?.Free(); + } + } + + public int SearchElement( + ulong context, + nint index, + ReadOnlySpan id, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + SpanByteAndMemory outputIds, + SpanByteAndMemory outputDistances, + out nint continuation + 
) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + var filter_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var filter_len = filter.Length; + + void* output_ids; + void* output_distances; + + GCHandle? outputIdsHandle = null; + GCHandle? outputDistancesHandle = null; + try + { + if (!outputIds.IsSpanByte) + { + var getRes = MemoryMarshal.TryGetArray(outputIds.Memory.Memory, out var arrSeg); + Debug.Assert(getRes, "Should always be able to get array to pin"); + + outputIdsHandle = GCHandle.Alloc(arrSeg.Array, GCHandleType.Pinned); + output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(arrSeg.Array)); + } + else + { + outputIdsHandle = null; + output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds.AsSpan())); + } + + var output_ids_len = outputIds.Length; + + if (!outputDistances.IsSpanByte) + { + var getRes = MemoryMarshal.TryGetArray(outputDistances.Memory.Memory, out var arrSeg); + Debug.Assert(getRes, "Should always be able to get array to pin"); + + outputDistancesHandle = GCHandle.Alloc(arrSeg.Array, GCHandleType.Pinned); + output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(arrSeg.Array)); + } + else + { + outputDistancesHandle = null; + output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances.AsSpan())); + } + + var output_distances_len = outputDistances.Length / sizeof(float); + + continuation = 0; + ref var continuationRef = ref continuation; + var continuationAddr = (nint)Unsafe.AsPointer(ref continuationRef); + + return NativeDiskANNMethods.search_element( + context, + index, + (nint)id_data, + (nuint)id_len, + delta, + searchExplorationFactor, + (nint)filter_data, + (nuint)filter_len, + (nuint)maxFilteringEffort, + (nint)output_ids, + (nuint)output_ids_len, + (nint)output_distances, + (nuint)output_distances_len, + continuationAddr + ); + } + finally + { + outputIdsHandle?.Free(); + 
outputDistancesHandle?.Free(); + } + } + + public int ContinueSearch(ulong context, nint index, nint continuation, Span outputIds, Span outputDistances, out nint newContinuation) + { + throw new NotImplementedException(); + } + + public bool CheckInternalIdValid(ulong context, nint index, ReadOnlySpan internalId) + { + var internal_id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(internalId)); + var internal_id_len = internalId.Length; + + return NativeDiskANNMethods.check_internal_id_valid(context, index, (nint)internal_id_data, (nuint)internal_id_len) == 1; + } + } + + public static partial class NativeDiskANNMethods + { + const string DISKANN_GARNET = "diskann_garnet"; + + [LibraryImport(DISKANN_GARNET)] + public static partial nint create_index( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + nint readCallback, + nint writeCallback, + nint deleteCallback, + nint readModifyWriteCallback + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial void drop_index( + ulong context, + nint index + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte insert( + ulong context, + nint index, + nint id_data, + nuint id_len, + VectorValueType vector_value_type, + nint vector_data, + nuint vector_len, + nint attribute_data, + nuint attribute_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte remove( + ulong context, + nint index, + nint id_data, + nuint id_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte set_attribute( + ulong context, + nint index, + nint id_data, + nuint id_len, + nint attribute_data, + nuint attribute_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int search_vector( + ulong context, + nint index, + VectorValueType vector_value_type, + nint vector_data, + nuint vector_len, + float delta, + int search_exploration_factor, + nint filter_data, + nuint filter_len, + nuint 
max_filtering_effort, + nint output_ids, + nuint output_ids_len, + nint output_distances, + nuint output_distances_len, + nint continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int search_element( + ulong context, + nint index, + nint id_data, + nuint id_len, + float delta, + int search_exploration_factor, + nint filter_data, + nuint filter_len, + nuint max_filtering_effort, + nint output_ids, + nuint output_ids_len, + nint output_distances, + nuint output_distances_len, + nint continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int continue_search( + ulong context, + nint index, + nint continuation, + nint output_ids, + nuint output_ids_len, + nint output_distances, + nuint output_distances_len, + nint new_continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial ulong card( + ulong context, + nint index + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte check_internal_id_valid( + ulong context, + nint index, + nint internal_id, + nuint internal_id_len + ); + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/RespServerSessionVectors.cs b/libs/server/Resp/Vector/RespServerSessionVectors.cs new file mode 100644 index 00000000000..eb60f50a2db --- /dev/null +++ b/libs/server/Resp/Vector/RespServerSessionVectors.cs @@ -0,0 +1,1273 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Runtime.InteropServices; +using Garnet.common; +using Tsavorite.core; + +namespace Garnet.server +{ + internal sealed unsafe partial class RespServerSession : ServerSessionBase + { + private bool NetworkVADD(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + // VADD key [REDUCE dim] (FP32 | XB8 | VALUES num) vector element [CAS] [NOQUANT | Q8 | BIN | XPREQ8] [EF build-exploration-factor] [SETATTR attributes] [M numlinks] + // + // XB8 is a non-Redis extension, stands for: eXtension Binary 8-bit values - encodes [0, 255] per dimension + // XPREQ8 is a non-Redis extension, stands for: eXtension PREcalculated Quantization 8-bit - requests no quantization on pre-calculated [0, 255] values + + const int MinM = 4; + const int MaxM = 4_096; + + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // key FP32|VALUES vector element + if (parseState.Count < 4) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + + var curIx = 1; + + var reduceDim = 0; + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("REDUCE"u8)) + { + curIx++; + if (!parseState.TryGetInt(curIx, out var reduceDimValue) || reduceDimValue <= 0) + { + return AbortWithErrorMessage("REDUCE dimension must be > 0"u8); + } + + reduceDim = reduceDimValue; + curIx++; + } + + var valueType = VectorValueType.Invalid; + byte[] rentedValues = null; + Span values = stackalloc byte[64 * sizeof(float)]; + + try + { + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("FP32"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + if ((asBytes.Length % sizeof(float)) != 0) + { + return AbortWithErrorMessage("ERR invalid 
vector specification"); + } + + curIx++; + valueType = VectorValueType.FP32; + values = asBytes; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("VALUES"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + if (!parseState.TryGetInt(curIx, out var valueCount) || valueCount <= 0) + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + curIx++; + + if (valueCount * sizeof(float) > values.Length) + { + values = rentedValues = ArrayPool.Shared.Rent(valueCount * sizeof(float)); + } + values = values[..(valueCount * sizeof(float))]; + + if (curIx + valueCount > parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + valueType = VectorValueType.FP32; + var floatValues = MemoryMarshal.Cast(values); + + for (var valueIx = 0; valueIx < valueCount; valueIx++) + { + if (!parseState.TryGetFloat(curIx, out floatValues[valueIx])) + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + + curIx++; + } + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + curIx++; + + valueType = VectorValueType.XB8; + values = asBytes; + } + + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var element = parseState.GetArgSliceByRef(curIx); + curIx++; + + // Order for everything after element is unspecified + var cas = false; + VectorQuantType? quantType = null; + int? buildExplorationFactor = null; + ArgSlice? attributes = null; + int? numLinks = null; + VectorDistanceMetricType? 
distanceMetric = null; + + while (curIx < parseState.Count) + { + // REDUCE is illegal after values, no matter how specified + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("REDUCE"u8)) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + // Look for CAS + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("CAS"u8)) + { + if (cas) + { + return AbortWithErrorMessage("CAS specified multiple times"); + } + + // We ignore CAS, just remember we saw it + cas = true; + curIx++; + + continue; + } + + // Look for quantizer specs + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("NOQUANT"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.NoQuant; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("Q8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.Q8; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("BIN"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.Bin; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XPREQ8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.XPreQ8; + curIx++; + + continue; + } + + // Look for build-exploration-factor + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("EF"u8)) + { + if (buildExplorationFactor != null) + { + return AbortWithErrorMessage("EF specified multiple times"); + } + + curIx++; + + if (curIx >= parseState.Count) + { + return 
AbortWithErrorMessage("ERR invalid option after element"); + } + + if (!parseState.TryGetInt(curIx, out var buildExplorationFactorNonNull) || buildExplorationFactorNonNull <= 0) + { + return AbortWithErrorMessage("ERR invalid EF"); + } + + buildExplorationFactor = buildExplorationFactorNonNull; + curIx++; + continue; + } + + // Look for attributes + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("SETATTR"u8)) + { + if (attributes != null) + { + return AbortWithErrorMessage("SETATTR specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + attributes = parseState.GetArgSliceByRef(curIx); + curIx++; + + // You might think we need to validate attributes, but Redis actually lets anything through + + continue; + } + + // Look for num links + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("M"u8)) + { + if (numLinks != null) + { + return AbortWithErrorMessage("M specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + if (!parseState.TryGetInt(curIx, out var numLinksNonNull) || numLinksNonNull < MinM || numLinksNonNull > MaxM) + { + return AbortWithErrorMessage("ERR invalid M"); + } + + numLinks = numLinksNonNull; + curIx++; + + continue; + } + + // Didn't recognize this option, error out + return AbortWithErrorMessage("ERR invalid option after element"); + } + + if (key.ReadOnlySpan.IsEmpty) + { + // TODO: this is not a Redis restriction, but once that comes from Replication Keys being in the 0 namespace, we should lift it + return AbortWithErrorMessage("ERR Vector Set key cannot be empty"u8); + } + + // Default unspecified options + quantType ??= VectorQuantType.Q8; + buildExplorationFactor ??= 200; + attributes ??= default; + numLinks ??= 16; + + // TODO: Distance metric specification is an extension - still needs 
to be implemented + distanceMetric ??= VectorDistanceMetricType.L2; + + // Validate that DiskANN is expected to succeed given data sizes + // + // Note that this goes away in store v2 + if (values.Length > maximumVectorSetValueBytes) + { + WriteError("ERR Vector exceed configured page size"u8); + return true; + } + + if (attributes.Value.Length > maximumVectorSetValueBytes) + { + WriteError("ERR Attribute exceed configured page size"u8); + return true; + } + + if (quantType != VectorQuantType.XPreQ8 && quantType != VectorQuantType.NoQuant) + { + WriteError("ERR Unsupported quantization type"u8); + return true; + } + + // We need to reject these HERE because validation during create_index is very awkward + GarnetStatus res; + VectorManagerResult result; + ReadOnlySpan customErrMsg; + if (quantType == VectorQuantType.XPreQ8 && reduceDim != 0) + { + result = VectorManagerResult.BadParams; + res = GarnetStatus.OK; + customErrMsg = default; + } + else + { + res = storageApi.VectorSetAdd(key, reduceDim, valueType, ArgSlice.FromPinnedSpan(values), element, quantType.Value, buildExplorationFactor.Value, attributes.Value, numLinks.Value, distanceMetric.Value, out result, out customErrMsg); + } + + if (res == GarnetStatus.OK) + { + if (result == VectorManagerResult.OK) + { + if (respProtocolVersion == 3) + { + while (!RespWriteUtils.TryWriteTrue(ref dcurr, dend)) + SendAndReset(); + } + else + { + while (!RespWriteUtils.TryWriteInt32(1, ref dcurr, dend)) + SendAndReset(); + } + } + else if (result == VectorManagerResult.Duplicate) + { + if (respProtocolVersion == 3) + { + while (!RespWriteUtils.TryWriteFalse(ref dcurr, dend)) + SendAndReset(); + } + else + { + while (!RespWriteUtils.TryWriteInt32(0, ref dcurr, dend)) + SendAndReset(); + } + } + else if (result == VectorManagerResult.BadParams) + { + if (customErrMsg.IsEmpty) + { + return AbortWithErrorMessage("ERR asked quantization mismatch with existing vector set"u8); + } + + return AbortWithErrorMessage(customErrMsg); + 
} + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + else + { + return AbortWithErrorMessage($"Unexpected GarnetStatus: {res}"); + } + + return true; + } + finally + { + if (rentedValues != null) + { + ArrayPool.Shared.Return(rentedValues); + } + } + } + + private bool NetworkVSIM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + const int DefaultResultSetSize = 64; + const int DefaultIdSize = sizeof(ulong); + const int DefaultAttributeSize = 32; + + // VSIM key (ELE | FP32 | XB8 | VALUES num) (vector | element) [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON delta] [EF search-exploration - factor] [FILTER expression][FILTER-EF max - filtering - effort] [TRUTH][NOTHREAD] + // + // XB8 is a non-Redis extension, stands for: eXtension Binary 8-bit values - encodes [0, 255] per dimension + + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count < 3) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + var kind = parseState.GetArgSliceByRef(1); + + var curIx = 2; + + ArgSlice? 
element; + + VectorValueType valueType = VectorValueType.Invalid; + byte[] rentedValues = null; + try + { + Span values = stackalloc byte[64 * sizeof(float)]; + if (kind.Span.EqualsUpperCaseSpanIgnoringCase("ELE"u8)) + { + element = parseState.GetArgSliceByRef(curIx); + values = default; + curIx++; + } + else + { + element = default; + if (kind.Span.EqualsUpperCaseSpanIgnoringCase("FP32"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + if ((asBytes.Length % sizeof(float)) != 0) + { + return AbortWithErrorMessage("FP32 values must be multiple of 4-bytes in size"); + } + + valueType = VectorValueType.FP32; + values = asBytes; + curIx++; + } + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + + valueType = VectorValueType.XB8; + values = asBytes; + curIx++; + } + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("VALUES"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var valueCount) || valueCount <= 0) + { + return AbortWithErrorMessage("VALUES count must > 0"); + } + curIx++; + + if (valueCount * sizeof(float) > values.Length) + { + values = rentedValues = ArrayPool.Shared.Rent(valueCount * sizeof(float)); + } + values = values[..(valueCount * sizeof(float))]; + + if (curIx + valueCount > parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + valueType = VectorValueType.FP32; + var floatValues = MemoryMarshal.Cast(values); + + for (var valueIx = 0; valueIx < valueCount; valueIx++) + { + if (!parseState.TryGetFloat(curIx, out floatValues[valueIx])) + { + return AbortWithErrorMessage("VALUES value must be valid float"); + } + + curIx++; + } + } + else + { + return 
AbortWithErrorMessage("VSIM expected ELE, FP32, or VALUES"); + } + } + + bool? withScores = null; + bool? withAttributes = null; + int? count = null; + float? delta = null; + int? searchExplorationFactor = null; + ArgSlice? filter = null; + int? maxFilteringEffort = null; + var truth = false; + var noThread = false; + + while (curIx < parseState.Count) + { + // Check for withScores + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("WITHSCORES"u8)) + { + if (withScores != null) + { + return AbortWithErrorMessage("WITHSCORES specified multiple times"); + } + + withScores = true; + curIx++; + continue; + } + + // Check for withAttributes + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("WITHATTRIBS"u8)) + { + if (withAttributes != null) + { + return AbortWithErrorMessage("WITHATTRIBS specified multiple times"); + } + + withAttributes = true; + curIx++; + continue; + } + + // Check for count + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("COUNT"u8)) + { + if (count != null) + { + return AbortWithErrorMessage("COUNT specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var countNonNull) || countNonNull < 0) + { + return AbortWithErrorMessage("COUNT must be integer >= 0"); + } + + count = countNonNull; + curIx++; + continue; + } + + // Check for delta + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("EPSILON"u8)) + { + if (delta != null) + { + return AbortWithErrorMessage("EPSILON specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetFloat(curIx, out var deltaNonNull) || deltaNonNull <= 0) + { + return AbortWithErrorMessage("EPSILON must be float > 0"); + } + + delta = deltaNonNull; + curIx++; + 
continue; + } + + // Check for search exploration factor + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("EF"u8)) + { + if (searchExplorationFactor != null) + { + return AbortWithErrorMessage("EF specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var searchExplorationFactorNonNull) || searchExplorationFactorNonNull < 0) + { + return AbortWithErrorMessage("EF must be >= 0"); + } + + searchExplorationFactor = searchExplorationFactorNonNull; + curIx++; + continue; + } + + // Check for filter + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("FILTER"u8)) + { + if (filter != null) + { + return AbortWithErrorMessage("FILTER specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + filter = parseState.GetArgSliceByRef(curIx); + curIx++; + + // TODO: validate filter + + continue; + } + + // Check for max filtering effort + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("FILTER-EF"u8)) + { + if (maxFilteringEffort != null) + { + return AbortWithErrorMessage("FILTER-EF specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var maxFilteringEffortNonNull) || maxFilteringEffortNonNull < 0) + { + return AbortWithErrorMessage("FILTER-EF must be >= 0"); + } + + maxFilteringEffort = maxFilteringEffortNonNull; + curIx++; + continue; + } + + // Check for truth + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("TRUTH"u8)) + { + if (truth) + { + + } + + // TODO: should we implement TRUTH? 
+ truth = true; + curIx++; + continue; + } + + // Check for no thread + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("NOTHREAD"u8)) + { + if (noThread) + { + return AbortWithErrorMessage("NOTHREAD specified multiple times"); + } + + // We ignore NOTHREAD + noThread = true; + curIx++; + continue; + } + + // Didn't recognize this option, error out + return AbortWithErrorMessage("Unknown option"); + } + + // Default unspecified options + withScores ??= false; + withAttributes ??= false; + count ??= 10; + delta ??= 2f; + searchExplorationFactor ??= 100; + filter ??= default; + maxFilteringEffort ??= count.Value * 100; + + // TODO: these stackallocs are dangerous, need logic to avoid stack overflow + Span idSpace = stackalloc byte[(DefaultResultSetSize * DefaultIdSize) + (DefaultResultSetSize * sizeof(int))]; + Span distanceSpace = stackalloc float[DefaultResultSetSize]; + Span attributeSpace = withAttributes.Value ? stackalloc byte[(DefaultResultSetSize * DefaultAttributeSize) + (DefaultResultSetSize * sizeof(int))] : default; + + var idResult = SpanByteAndMemory.FromPinnedSpan(idSpace); + var distanceResult = SpanByteAndMemory.FromPinnedSpan(MemoryMarshal.Cast(distanceSpace)); + var attributeResult = SpanByteAndMemory.FromPinnedSpan(attributeSpace); + try + { + + GarnetStatus res; + VectorManagerResult vectorRes; + VectorIdFormat idFormat; + if (!element.HasValue) + { + res = storageApi.VectorSetValueSimilarity(key, valueType, ArgSlice.FromPinnedSpan(values), count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, ref distanceResult, ref attributeResult, out vectorRes); + } + else + { + res = storageApi.VectorSetElementSimilarity(key, element.Value, count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, ref distanceResult, ref attributeResult, out 
vectorRes); + } + + if (res == GarnetStatus.NOTFOUND) + { + // Vector Set does not exist + + while (!RespWriteUtils.TryWriteEmptyArray(ref dcurr, dend)) + SendAndReset(); + } + else if (res == GarnetStatus.OK) + { + if (vectorRes == VectorManagerResult.MissingElement) + { + while (!RespWriteUtils.TryWriteError("Element not in Vector Set"u8, ref dcurr, dend)) + SendAndReset(); + } + else if (vectorRes == VectorManagerResult.OK) + { + if (respProtocolVersion == 3) + { + // TODO: this is rather complicated, so punt for now + throw new NotImplementedException(); + } + else + { + var remainingIds = idResult.AsReadOnlySpan(); + var distancesSpan = MemoryMarshal.Cast(distanceResult.AsReadOnlySpan()); + var remaininingAttributes = withAttributes.Value ? attributeResult.AsReadOnlySpan() : default; + + var arrayItemCount = distancesSpan.Length; + if (withScores.Value) + { + arrayItemCount += distancesSpan.Length; + } + if (withAttributes.Value) + { + arrayItemCount += distancesSpan.Length; + } + + while (!RespWriteUtils.TryWriteArrayLength(arrayItemCount, ref dcurr, dend)) + SendAndReset(); + + for (var resultIndex = 0; resultIndex < distancesSpan.Length; resultIndex++) + { + ReadOnlySpan elementData; + + if (idFormat == VectorIdFormat.I32LengthPrefixed) + { + if (remainingIds.Length < sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for result id length at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + var elementLen = BinaryPrimitives.ReadInt32LittleEndian(remainingIds); + + if (remainingIds.Length < sizeof(int) + elementLen) + { + throw new GarnetException($"Insufficient bytes for result of length={elementLen} at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + elementData = remainingIds.Slice(sizeof(int), elementLen); + remainingIds = remainingIds[(sizeof(int) + elementLen)..]; + } + else if (idFormat == VectorIdFormat.FixedI32) + { + if (remainingIds.Length < 
sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for result id length at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + elementData = remainingIds[..sizeof(int)]; + remainingIds = remainingIds[sizeof(int)..]; + } + else + { + throw new GarnetException($"Unexpected id format: {idFormat}"); + } + + while (!RespWriteUtils.TryWriteBulkString(elementData, ref dcurr, dend)) + SendAndReset(); + + if (withScores.Value) + { + var distance = distancesSpan[resultIndex]; + + while (!RespWriteUtils.TryWriteDoubleBulkString(distance, ref dcurr, dend)) + SendAndReset(); + } + + if (withAttributes.Value) + { + if (remaininingAttributes.Length < sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for attribute length at resultIndex={resultIndex}: {Convert.ToHexString(attributeResult.AsReadOnlySpan())}"); + } + + var attrLen = BinaryPrimitives.ReadInt32LittleEndian(remaininingAttributes); + var attr = remaininingAttributes.Slice(sizeof(int), attrLen); + remaininingAttributes = remaininingAttributes[(sizeof(int) + attrLen)..]; + + while (!RespWriteUtils.TryWriteBulkString(attr, ref dcurr, dend)) + SendAndReset(); + } + } + } + } + else + { + throw new GarnetException($"Unexpected {nameof(VectorManagerResult)}: {vectorRes}"); + } + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + else + { + throw new GarnetException($"Unexpected {nameof(GarnetStatus)}: {res}"); + } + + return true; + } + finally + { + idResult.Memory?.Dispose(); + distanceResult.Memory?.Dispose(); + attributeResult.Memory?.Dispose(); + } + } + finally + { + if (rentedValues != null) + { + ArrayPool.Shared.Return(rentedValues); + } + } + } + + private bool NetworkVEMB(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + const int DefaultResultSetSize = 64; + + // VEMB key element [RAW] + + if 
(!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count < 2 || parseState.Count > 3) + { + return AbortWithWrongNumberOfArguments("VEMB"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + var elem = parseState.GetArgSliceByRef(1); + + var raw = false; + if (parseState.Count == 3) + { + if (!parseState.GetArgSliceByRef(2).Span.EqualsUpperCaseSpanIgnoringCase("RAW"u8)) + { + return AbortWithErrorMessage("Unexpected option to VEMB"); + } + + raw = true; + } + + // TODO: what do we do here? + if (raw) + { + throw new NotImplementedException(); + } + + Span distanceSpace = stackalloc float[DefaultResultSetSize]; + + var distanceResult = SpanByteAndMemory.FromPinnedSpan(MemoryMarshal.Cast(distanceSpace)); + + try + { + var res = storageApi.VectorSetEmbedding(key, elem, out var quantType, ref distanceResult); + + if (res == GarnetStatus.OK) + { + if (quantType == VectorQuantType.NoQuant) + { + var distanceSpan = MemoryMarshal.Cast(distanceResult.AsReadOnlySpan()); + WriteArrayLength(distanceSpan.Length); + for (var i = 0; i < distanceSpan.Length; i++) + { + WriteDoubleNumeric(distanceSpan[i]); + } + } + else if (quantType == VectorQuantType.XPreQ8) + { + var distanceSpan = distanceResult.AsReadOnlySpan(); + WriteArrayLength(distanceSpan.Length); + for (var i = 0; i < distanceSpan.Length; i++) + { + WriteDoubleNumeric(distanceSpan[i]); + } + } + else + { + throw new GarnetException($"Unsupported quantization type for embedding extraction: {quantType}"); + } + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + else + { + while (!RespWriteUtils.TryWriteEmptyArray(ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + finally + { + if (!distanceResult.IsSpanByte) + { + distanceResult.Memory.Dispose(); + } + } + } + + 
private bool NetworkVCARD(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVDIM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 1) + return AbortWithWrongNumberOfArguments("VDIM"); + + var key = parseState.GetArgSliceByRef(0); + + var res = storageApi.VectorSetDimensions(key, out var dimensions); + + if (res == GarnetStatus.NOTFOUND) + { + while (!RespWriteUtils.TryWriteError("ERR Key not found"u8, ref dcurr, dend)) + SendAndReset(); + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + else + { + while (!RespWriteUtils.TryWriteInt32(dimensions, ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + + private bool NetworkVGETATTR(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 2) + { + return AbortWithWrongNumberOfArguments("VGETATTR"); + } + + var key = parseState.GetArgSliceByRef(0); + var element = parseState.GetArgSliceByRef(1); + + // Here we reserve some stack buffer to try to avoid allocations if the attributes are small + // However, if it's not enough, VectorSetGetAttribute will allocate and replace attributesOutput + // and attach a Memory to it - So we need to make sure to dispose of that if it happens + Span attributesBuffer = stackalloc 
byte[256]; + SpanByteAndMemory attributesOutput = SpanByteAndMemory.FromPinnedSpan(attributesBuffer); + + try + { + var res = storageApi.VectorSetGetAttribute(key, element, ref attributesOutput); + if (res != GarnetStatus.OK) + { + if (res == GarnetStatus.NOTFOUND) + { + WriteNull(); + return true; + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + + return AbortWithErrorMessage($"Unexpected GarnetStatus: {res}"); + } + + WriteSimpleString(attributesOutput.AsReadOnlySpan()); + return true; + } + finally + { + attributesOutput.Memory?.Dispose(); + } + } + + private bool NetworkVINFO(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 1) + { + return AbortWithWrongNumberOfArguments("VINFO"); + } + + var key = parseState.GetArgSliceByRef(0); + var res = storageApi.VectorSetInfo(key, out VectorQuantType quantType, out var distanceMetricType, out var vectorDimensions, out var reducedDimensions, out var buildExplorationFactor, out var numLinks, out var size); + if (res != GarnetStatus.OK) + { + if (res == GarnetStatus.NOTFOUND) + { + WriteNullArray(); + return true; + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + + return AbortWithErrorMessage($"Unexpected GarnetStatus: {res}"); + } + + var quantTypeSpan = quantType switch + { + VectorQuantType.NoQuant => "f32"u8, + VectorQuantType.Bin => "bin"u8, + VectorQuantType.Q8 => "q8"u8, + VectorQuantType.XPreQ8 => "xpreq8"u8, + _ => throw new GarnetException($"Invalid VectorQuantType: {quantType}"), + }; + + var distanceMetricTypeSpan = distanceMetricType switch + { + 
VectorDistanceMetricType.Cosine => "cosine"u8, + VectorDistanceMetricType.InnerProduct => "inner-product"u8, + VectorDistanceMetricType.L2 => "l2"u8, + VectorDistanceMetricType.CosineNormalized => "cosine-normalized"u8, + _ => throw new GarnetException($"Invalid VectorDistanceMetricType: {distanceMetricType}"), + }; + + WriteArrayLength(14); + WriteSimpleString("quant-type"u8); + WriteSimpleString(quantTypeSpan); + WriteSimpleString("distance-metric"u8); + WriteSimpleString(distanceMetricTypeSpan); + WriteSimpleString("input-vector-dimensions"u8); + WriteInt32AsBulkString((int)vectorDimensions); + WriteSimpleString("reduced-dimensions"u8); + WriteInt32AsBulkString((int)reducedDimensions); + WriteSimpleString("build-exploration-factor"u8); + WriteInt32AsBulkString((int)buildExplorationFactor); + WriteSimpleString("num-links"u8); + WriteInt32AsBulkString((int)numLinks); + WriteSimpleString("size"u8); + WriteInt64AsBulkString(size); + return true; + } + + private bool NetworkVISMEMBER(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVLINKS(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVRANDMEMBER(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! 
+ + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVREM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 2) + return AbortWithWrongNumberOfArguments("VREM"); + + var key = parseState.GetArgSliceByRef(0); + var elem = parseState.GetArgSliceByRef(1); + + var res = storageApi.VectorSetRemove(key, elem); + + if (res == GarnetStatus.BADSTATE) + { + return AbortVectorSetPartiallyDeleted(ref key); + } + else if (res == GarnetStatus.WRONGTYPE) + { + return AbortVectorSetWrongType(); + } + else + { + var resp = res == GarnetStatus.OK ? 1 : 0; + + while (!RespWriteUtils.TryWriteInt32(resp, ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + + private bool NetworkVSETATTR(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool AbortVectorSetPartiallyDeleted(ref ArgSlice key) + { + // TODO: We could _finish_ the delete here... though if we do that we should do it for ALL commands, not just Vector Set commands + // That's more intrusive, and is more of a V2 thing... 
so lets just give a workaround for now + + while (!RespWriteUtils.TryWriteError("ERR Vector Set is in a partially deleted state - re-execute DEL to complete deletion"u8, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool AbortVectorSetWrongType() + { + // Matches Redis behavior - doesn't indicate the type involved + while (!RespWriteUtils.TryWriteError("WRONGTYPE Operation against a key holding the wrong kind of value"u8, ref dcurr, dend)) + SendAndReset(); + + return true; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Callbacks.cs b/libs/server/Resp/Vector/VectorManager.Callbacks.cs new file mode 100644 index 00000000000..61b128af600 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Callbacks.cs @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + using MainStoreAllocator = SpanByteAllocator>; + using MainStoreFunctions = StoreFunctions; + + /// + /// Methods which calls back into to interact with Garnet. 
+ /// + public sealed partial class VectorManager + { + public unsafe struct VectorReadBatch : IReadArgBatch, IDisposable + { + public int Count { get; } + + private readonly ulong context; + private readonly SpanByte lengthPrefixedKeys; + + public readonly unsafe delegate* unmanaged[Cdecl, SuppressGCTransition] callback; + public readonly nint callbackContext; + + private int currentIndex; + + private int currentLen; + private byte* currentPtr; + + private bool hasPending; + + public VectorReadBatch(nint callback, nint callbackContext, ulong context, uint keyCount, SpanByte lengthPrefixedKeys) + { + this.context = context; + this.lengthPrefixedKeys = lengthPrefixedKeys; + + this.callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])callback; + this.callbackContext = callbackContext; + + currentIndex = 0; + Count = (int)keyCount; + + currentPtr = this.lengthPrefixedKeys.ToPointerWithMetadata(); + currentLen = *(int*)currentPtr; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void AdvanceTo(int i) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + if (i == currentIndex) + { + return; + } + + // Undo namespace mutation + *(int*)currentPtr = currentLen; + + // Most likely case, we're going one forward + if (i == (currentIndex + 1)) + { + currentPtr += currentLen + sizeof(int); // Skip length prefix too + + Debug.Assert(currentPtr < lengthPrefixedKeys.ToPointerWithMetadata() + lengthPrefixedKeys.Length, "About to access out of bounds data"); + + currentLen = *currentPtr; + + currentIndex = i; + + return; + } + + // Next most likely case, we're going back to the start + currentPtr = lengthPrefixedKeys.ToPointerWithMetadata(); + currentLen = *(int*)currentPtr; + currentIndex = 0; + + if (i == 0) + { + return; + } + + SlowPath(ref this, i); + + // For the case where we're not just scanning or rolling back to 0, just iterate + // + // This should basically never happen + [MethodImpl(MethodImplOptions.NoInlining)] + 
static void SlowPath(ref VectorReadBatch self, int i) + { + for (var subI = 1; subI <= i; subI++) + { + self.AdvanceTo(subI); + } + } + } + + /// + public void GetKey(int i, out SpanByte key) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + AdvanceTo(i); + + key = SpanByte.FromPinnedPointer(currentPtr + 3, currentLen + 1); + key.MarkNamespace(); + key.SetNamespaceInPayload((byte)context); + } + + /// + public readonly void GetInput(int i, out VectorInput input) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + input = default; + input.CallbackContext = callbackContext; + input.Callback = (nint)callback; + input.Index = i; + } + + /// + public readonly void GetOutput(int i, out SpanByte output) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + // Don't care, won't be used + Unsafe.SkipInit(out output); + } + + /// + public readonly void SetOutput(int i, SpanByte output) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + } + + /// + public void SetStatus(int i, Status status) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + hasPending |= status.IsPending; + } + + internal readonly void CompletePending(ref TContext objectContext) + where TContext : ITsavoriteContext + { + // Undo mutations + *(int*)currentPtr = currentLen; + + if (hasPending) + { + _ = objectContext.CompletePending(wait: true); + } + } + + /// + public void Dispose() + { + if (currentPtr == null) + { + return; + } + + // Undo mangling of prefix, if any + *(int*)currentPtr = currentLen; + currentPtr = null; + } + } + + private unsafe delegate* unmanaged[Cdecl] ReadCallbackPtr { get; } = &ReadCallbackUnmanaged; + private unsafe delegate* unmanaged[Cdecl] WriteCallbackPtr { get; } = &WriteCallbackUnmanaged; + private unsafe delegate* unmanaged[Cdecl] DeleteCallbackPtr { get; } = &DeleteCallbackUnmanaged; + private unsafe delegate* unmanaged[Cdecl] 
        private unsafe delegate* unmanaged[Cdecl] ReadModifyWriteCallbackPtr { get; } = &ReadModifyWriteCallbackUnmanaged;

        /// <summary>
        /// Used to thread the active <see cref="StorageSession"/> across p/invoke and reverse p/invoke boundaries into DiskANN.
        ///
        /// Not the most elegant option, but works so long as DiskANN remains single threaded.
        /// </summary>
        [ThreadStatic]
        internal static StorageSession ActiveThreadSession;

        /// <summary>
        /// Reverse p/invoke entry point: DiskANN asks Garnet to read a batch of keys, receiving
        /// each value via <paramref name="dataCallback"/>.
        /// </summary>
        [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
        private static unsafe void ReadCallbackUnmanaged(
            ulong context,
            uint numKeys,
            nint keysData,
            nuint keysLength,
            nint dataCallback,
            nint dataCallbackContext
        )
        {
            // dataCallback takes: index, dataCallbackContext, data pointer, data length, and returns nothing

            var enumerable = new VectorReadBatch(dataCallback, dataCallbackContext, context, numKeys, SpanByte.FromPinnedPointer((byte*)keysData, (int)keysLength));
            try
            {
                ref var ctx = ref ActiveThreadSession.vectorContext;

                ctx.ReadWithPrefetch(ref enumerable);

                enumerable.CompletePending(ref ctx);
            }
            finally
            {
                // Dispose restores the length prefix the batch mutated in DiskANN's buffer
                enumerable.Dispose();
            }
        }

        /// <summary>
        /// Reverse p/invoke entry point: upsert a single value under a namespaced key.
        /// Returns 1 on success, 0 otherwise.
        /// </summary>
        [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
        private static unsafe byte WriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength)
        {
            var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);
            try
            {

                ref var ctx = ref ActiveThreadSession.vectorContext;
                VectorInput input = default;
                var valueSpan = SpanByte.FromPinnedPointer((byte*)writeData, (int)writeLength);
                SpanByte outputSpan = default;

                var status = ctx.Upsert(ref keyWithNamespace, ref input, ref valueSpan, ref outputSpan);
                if (status.IsPending)
                {
                    CompletePending(ref status, ref outputSpan, ref ctx);
                }

                return status.IsCompletedSuccessfully ? (byte)1 : default;
            }
            finally
            {
                UnmarkDiskANNKey(keyWithNamespace);
            }
        }

        /// <summary>
        /// Reverse p/invoke entry point: delete a single namespaced key.
        /// Returns 1 only if the key existed and was deleted.
        /// </summary>
        [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
        private static unsafe byte DeleteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength)
        {
            var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);

            try
            {
                ref var ctx = ref ActiveThreadSession.vectorContext;

                var status = ctx.Delete(ref keyWithNamespace);
                Debug.Assert(!status.IsPending, "Deletes should never go async");

                return status.IsCompletedSuccessfully && status.Found ? (byte)1 : default;
            }
            finally
            {
                UnmarkDiskANNKey(keyWithNamespace);
            }
        }

        /// <summary>
        /// Reverse p/invoke entry point: read-modify-write on a namespaced key, with the new value
        /// produced by <paramref name="dataCallback"/>.  Returns 1 on success, 0 otherwise.
        /// </summary>
        [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
        private static unsafe byte ReadModifyWriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext)
        {
            var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);

            try
            {
                ref var ctx = ref ActiveThreadSession.vectorContext;

                VectorInput input = default;
                input.Callback = dataCallback;
                input.CallbackContext = dataCallbackContext;
                input.WriteDesiredSize = (int)writeLength;

                var status = ctx.RMW(ref keyWithNamespace, ref input);
                if (status.IsPending)
                {
                    SpanByte ignored = default;

                    CompletePending(ref status, ref ignored, ref ctx);
                }

                return status.IsCompletedSuccessfully ? (byte)1 : default;
            }
            finally
            {
                UnmarkDiskANNKey(keyWithNamespace);
            }
        }

        /// <summary>
        /// Read a value whose size is not known up front, growing <paramref name="value"/> (from
        /// the shared memory pool) and retrying until it fits.  Returns false if the key is absent.
        /// </summary>
        private static unsafe bool ReadSizeUnknown(ulong context, ReadOnlySpan key, ref SpanByteAndMemory value)
        {
            // Build a copy of the key with a leading namespace byte
            Span distinctKey = stackalloc byte[key.Length + 1];
            var keyWithNamespace = SpanByte.FromPinnedSpan(distinctKey);
            keyWithNamespace.MarkNamespace();
            keyWithNamespace.SetNamespaceInPayload((byte)context);
            key.CopyTo(keyWithNamespace.AsSpan());

            ref var ctx = ref ActiveThreadSession.vectorContext;

            while (true)
            {
                VectorInput input = new();
                // -1 signals "tell me how big the value is" via input.ReadDesiredSize
                input.ReadDesiredSize = -1;
                fixed (byte* ptr = value.AsSpan())
                {
                    SpanByte asSpanByte = new(value.Length, (nint)ptr);

                    var status = ctx.Read(ref keyWithNamespace, ref input, ref asSpanByte);
                    if (status.IsPending)
                    {
                        CompletePending(ref status, ref asSpanByte, ref ctx);
                    }

                    if (!status.Found)
                    {
                        value.Length = 0;
                        return false;
                    }

                    // Buffer too small: rent a bigger one and retry the read
                    if (input.ReadDesiredSize > asSpanByte.Length)
                    {
                        value.Memory?.Dispose();
                        var newAlloc = MemoryPool.Shared.Rent(input.ReadDesiredSize);
                        value = new(newAlloc, newAlloc.Memory.Length);
                        continue;
                    }

                    value.Length = asSpanByte.Length;
                    return true;
                }
            }
        }

        /// <summary>
        /// Get a <see cref="SpanByte"/> which covers (keyData, keyLength), but has a namespace component based on <paramref name="context"/>.
        ///
        /// Attempts to do this in place: the namespace byte is written into the byte immediately
        /// before the key, which DiskANN reserves (see remark below).  Undone by <see cref="UnmarkDiskANNKey"/>.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal static unsafe SpanByte MarkDiskANNKeyWithNamespace(ulong context, nint keyData, nuint keyLength)
        {
            // DiskANN guarantees we have 4-bytes worth of unused data right before the key
            var keyPtr = (byte*)keyData;
            var keyNamespaceByte = keyPtr - 1;

            // TODO: if/when namespace can be > 4-bytes, we'll need to copy here

            var keyWithNamespace = SpanByte.FromPinnedPointer(keyNamespaceByte, (int)(keyLength + 1));
            keyWithNamespace.MarkNamespace();
            keyWithNamespace.SetNamespaceInPayload((byte)context);

            return keyWithNamespace;
        }
        /// <summary>
        /// Inverse of <see cref="MarkDiskANNKeyWithNamespace"/>.
        ///
        /// Used so DiskANN can keep using the same buffer for multiple calls with the same keys.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal static unsafe void UnmarkDiskANNKey(SpanByte keyWithNamespace)
        {
            // Strip the namespace byte that was prepended to the raw key...
            var expectedLen = keyWithNamespace.Length - 1;
            // ...and restore the 4-byte little-endian length prefix that sits just before it
            var start = keyWithNamespace.ToPointerWithMetadata() - 3;
            BinaryPrimitives.WriteInt32LittleEndian(new Span(start, 4), expectedLen);
        }
    }
}

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Buffers.Binary;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading.Channels;
using System.Threading.Tasks;
using Garnet.common;
using Garnet.networking;
using Microsoft.Extensions.Logging;
using Tsavorite.core;

namespace Garnet.server
{
    using MainStoreAllocator = SpanByteAllocator>;
    using MainStoreFunctions = StoreFunctions;

    /// <summary>
    /// Methods related to cleaning up data after a Vector Set is deleted.
    /// </summary>
    public sealed partial class VectorManager
    {
+ /// + private sealed class PostDropCleanupFunctions : IScanIteratorFunctions + { + private readonly StorageSession storageSession; + private readonly FrozenSet contexts; + + public PostDropCleanupFunctions(StorageSession storageSession, HashSet contexts) + { + this.contexts = contexts.ToFrozenSet(); + this.storageSession = storageSession; + } + + public bool ConcurrentReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + => SingleReader(ref key, ref value, recordMetadata, numberOfRecords, out cursorRecordResult); + + public void OnException(Exception exception, long numberOfRecords) { } + public bool OnStart(long beginAddress, long endAddress) => true; + public void OnStop(bool completed, long numberOfRecords) { } + + public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + { + if (key.MetadataSize != 1) + { + // Not Vector Set, ignore + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + + var ns = key.GetNamespaceInPayload(); + var pairedContext = (ulong)ns & ~(ContextStep - 1); + if (!contexts.Contains(pairedContext)) + { + // Vector Set, but not one we're scanning for + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + + // Delete it + var status = storageSession.vectorContext.Delete(ref key, 0); + if (status.IsPending) + { + SpanByte ignored = default; + CompletePending(ref status, ref ignored, ref storageSession.vectorContext); + } + + cursorRecordResult = CursorRecordResult.Accept; + return true; + } + } + + private readonly Channel cleanupTaskChannel; + private readonly Task cleanupTask; + private readonly Func getCleanupSession; + + private async Task RunCleanupTaskAsync() + { + // Each drop index will queue a null object here + // We'll handle multiple at once if possible, but using a channel simplifies cancellation and dispose + 
await foreach (var ignored in cleanupTaskChannel.Reader.ReadAllAsync()) + { + try + { + HashSet needCleanup; + lock (this) + { + needCleanup = contextMetadata.GetNeedCleanup(); + } + + if (needCleanup == null) + { + // Previous run already got here, so bail + continue; + } + + // TODO: this doesn't work with non-RESP impls... which maybe we don't care about? + using var cleanupSession = (RespServerSession)getCleanupSession(); + if (cleanupSession.activeDbId != dbId && !cleanupSession.TrySwitchActiveDatabaseSession(dbId)) + { + throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed"); + } + + PostDropCleanupFunctions callbacks = new(cleanupSession.storageSession, needCleanup); + + ref var ctx = ref cleanupSession.storageSession.vectorContext; + + // Scan whole keyspace (sigh) and remove any associated data + // + // We don't really have a choice here, just do it + _ = ctx.Session.Iterate(ref callbacks); + + // Key is mostly ignored when deleting from InProgressDeletes + // So we just need a non-empty one to use with the context + Span basicKeySpan = new byte[1]; + unsafe + { + fixed (byte* basicKeyPtr = basicKeySpan) + { + var basicKey = SpanByte.FromPinnedPointer(basicKeyPtr, basicKeySpan.Length); + + // Generally there will already be removed, but if deletes fail in odd spots there can + // be a little bit to cleanup - so go ahead and do it. + // + // Not really worth optimizing given that we just scanned the whole key space to remove elements + // and that will dominate. 
+ foreach (var cleanedUp in needCleanup) + { + ClearDeleteInProgress(ref ctx, ref basicKey, cleanedUp); + } + } + } + + lock (this) + { + foreach (var cleanedUp in needCleanup) + { + contextMetadata.FinishedCleaningUp(cleanedUp); + } + } + + UpdateContextMetadata(ref ctx); + } + catch (Exception e) + { + logger?.LogError(e, "Failure during background cleanup of deleted vector sets, implies storage leak"); + } + } + } + + /// + /// Called in response to or to update metadata in Tsavorite. + /// + /// Returns false if there is insufficient size for the value. + /// + internal static bool TryUpdateInProgressDeletes(Span updateMessage, ref SpanByte inLogValue, ref RecordInfo recordInfo, ref RMWInfo rmwInfo) + { + var context = BinaryPrimitives.ReadUInt64LittleEndian(updateMessage); + var len = BinaryPrimitives.ReadInt32LittleEndian(updateMessage[sizeof(ulong)..]); + var isAdding = len > 0; + var key = updateMessage[(sizeof(ulong) + sizeof(int))..]; + + Debug.Assert(key.Length == (isAdding ? 
len : -len), "Key length not expected"); + Debug.Assert(context is >= ContextStep, "Special context not allowed"); + + var remaining = inLogValue.AsSpan(); + while (remaining.Length >= sizeof(ulong) + sizeof(int)) + { + var curCtx = BinaryPrimitives.ReadUInt64LittleEndian(remaining); + + if (curCtx == 0) + { + // Reached uninitialized data + break; + } + + var curLen = BinaryPrimitives.ReadInt32LittleEndian(remaining[sizeof(ulong)..]); + if (curCtx == context) + { + if (isAdding) + { + // Already added, ignore and make no other changes + return true; + } + + // Copy later values to cover the one we're removing + var afterCur = remaining[(sizeof(ulong) + sizeof(int) + curLen)..]; + afterCur.CopyTo(remaining); + + // Clear everything after that so we won't think it's valid + remaining[^(sizeof(ulong) + sizeof(int) + curLen)..].Clear(); + + // Shrink record by removed chunk size + var newSize = inLogValue.TotalSize - (sizeof(ulong) + sizeof(int) + curLen); + rmwInfo.ClearExtraValueLength(ref recordInfo, ref inLogValue, inLogValue.TotalSize); + inLogValue.ShrinkSerializedLength(inLogValue.TotalSize - newSize); + rmwInfo.SetUsedValueLength(ref recordInfo, ref inLogValue, inLogValue.TotalSize); + + return true; + } + + remaining = remaining[(sizeof(ulong) + sizeof(int) + curLen)..]; + } + + if (isAdding) + { + if (remaining.Length < sizeof(ulong) + sizeof(int) + key.Length) + { + return false; + } + + // Not already added, so slap it in + BinaryPrimitives.WriteUInt64LittleEndian(remaining, context); + BinaryPrimitives.WriteInt32LittleEndian(remaining[sizeof(ulong)..], len); + + key.CopyTo(remaining[(sizeof(ulong) + sizeof(int))..]); + + remaining = remaining[(sizeof(ulong) + sizeof(int) + key.Length)..]; + + // Record used length + var newSize = inLogValue.TotalSize - remaining.Length; + rmwInfo.ClearExtraValueLength(ref recordInfo, ref inLogValue, inLogValue.TotalSize); + inLogValue.ShrinkSerializedLength(newSize); + rmwInfo.SetUsedValueLength(ref recordInfo, ref 
inLogValue, inLogValue.TotalSize); + } + + return true; + } + + /// + /// Before we start smashing a for deletion, records that we started to delete it so we can recover from crashes. + /// + internal bool TryMarkDeleteInProgress(ref TContext ctx, ref SpanByte key, ulong context) + where TContext : ITsavoriteContext + { + Span keySpan = stackalloc byte[2]; + + Span dataSpan = stackalloc byte[sizeof(ulong) + sizeof(int) + key.Length]; + BinaryPrimitives.WriteUInt64LittleEndian(dataSpan, context); + + // Positive length indicates we're adding this to the list + BinaryPrimitives.WriteInt32LittleEndian(dataSpan[sizeof(ulong)..], key.LengthWithoutMetadata); + key.AsReadOnlySpan().CopyTo(dataSpan[(sizeof(ulong) + sizeof(int))..]); + + // 0:0 is ContextMetadata + // 0:1 is InProgressDeletes + var inProgressDeletesKey = SpanByte.FromPinnedSpan(keySpan); + + inProgressDeletesKey.MarkNamespace(); + inProgressDeletesKey.SetNamespaceInPayload(0); + inProgressDeletesKey.AsSpan()[0] = 1; + + VectorInput input = default; + input.Callback = 0; + + // Negative to indicate dynamic-ness + input.WriteDesiredSize = -(sizeof(ulong) + sizeof(int) + key.Length); + unsafe + { + input.CallbackContext = (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(dataSpan)); + } + + var status = ctx.RMW(ref inProgressDeletesKey, ref input); + + if (status.IsPending) + { + SpanByte ignored = default; + CompletePending(ref status, ref ignored, ref ctx); + } + + return status.IsCompletedSuccessfully; + } + + /// + /// Enumerate any deletes of Vector Sets that are in progress. + /// + /// Used with and to recover from interrupted deletes. 
        /// <summary>
        /// Enumerate any deletes of Vector Sets that are in progress.
        ///
        /// Used with <see cref="TryMarkDeleteInProgress"/> and <see cref="ClearDeleteInProgress"/> to recover from interrupted deletes.
        /// </summary>
        internal List<(ReadOnlyMemory Key, ulong Context)> GetDeletesInProgress(StorageSession storageSession)
        {
            Span keySpan = stackalloc byte[1];

            // 0:1 is InProgressDeletes, but ReadSizeUnknown will attach the context for us
            var inProgressDeletesKey = SpanByte.FromPinnedSpan(keySpan);

            inProgressDeletesKey.AsSpan()[0] = 1;

            SpanByteAndMemory readValue = default;

            List<(ReadOnlyMemory Key, ulong Context)> ret = [];
            try
            {
                // ReadSizeUnknown reads ActiveThreadSession, so thread it through for the duration
                ActiveThreadSession = storageSession;
                try
                {
                    if (!ReadSizeUnknown(context: 0, keySpan, ref readValue))
                    {
                        return ret;
                    }
                }
                finally
                {
                    ActiveThreadSession = null;
                }

                // Walk the [context: ulong][len: int][key bytes] entries until the end or a zero context
                var remaining = readValue.AsReadOnlySpan();
                while (remaining.Length >= sizeof(ulong) + sizeof(int))
                {
                    var ctx = BinaryPrimitives.ReadUInt64LittleEndian(remaining);
                    if (ctx == 0)
                    {
                        // Encountered uninitialized data
                        break;
                    }

                    var len = BinaryPrimitives.ReadInt32LittleEndian(remaining[sizeof(ulong)..]);

                    var key = remaining.Slice(sizeof(ulong) + sizeof(int), len);

                    ret.Add((key.ToArray(), ctx));

                    remaining = remaining[(sizeof(ulong) + sizeof(int) + len)..];
                }

                return ret;
            }
            finally
            {
                readValue.Memory?.Dispose();
            }
        }

        /// <summary>
        /// After a delete has completed, removes the given key from metadata.
        /// </summary>
        internal void ClearDeleteInProgress<TContext>(ref TContext ctx, ref SpanByte key, ulong context)
            where TContext : ITsavoriteContext
        {
            Span keySpan = stackalloc byte[2];

            Span dataSpan = stackalloc byte[sizeof(ulong) + sizeof(int) + key.Length];
            BinaryPrimitives.WriteUInt64LittleEndian(dataSpan, context);

            // Negative length indicates we're removing this from the list
            BinaryPrimitives.WriteInt32LittleEndian(dataSpan[sizeof(ulong)..], -key.LengthWithoutMetadata);
            key.AsReadOnlySpan().CopyTo(dataSpan[(sizeof(ulong) + sizeof(int))..]);

            // 0:0 is ContextMetadata
            // 0:1 is InProgressDeletes
            var inProgressDeletesKey = SpanByte.FromPinnedSpan(keySpan);

            inProgressDeletesKey.MarkNamespace();
            inProgressDeletesKey.SetNamespaceInPayload(0);
            inProgressDeletesKey.AsSpan()[0] = 1;

            VectorInput input = default;
            input.Callback = 0;

            // Negative to indicate dynamic-ness
            input.WriteDesiredSize = -(sizeof(ulong) + sizeof(int) + key.Length);
            unsafe
            {
                input.CallbackContext = (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(dataSpan));
            }

            var status = ctx.RMW(ref inProgressDeletesKey, ref input);

            if (status.IsPending)
            {
                SpanByte ignored = default;
                CompletePending(ref status, ref ignored, ref ctx);
            }
        }

        /// <summary>
        /// After an index is dropped, called to start the process of removing ancillary data (elements, neighbor lists, attributes, etc.).
        /// </summary>
        internal void CleanupDroppedIndex<TContext>(ref TContext ctx, ulong context)
            where TContext : ITsavoriteContext
        {
            lock (this)
            {
                contextMetadata.MarkCleaningUp(context);
            }

            UpdateContextMetadata(ref ctx);

            // Wake up cleanup task
            var writeRes = cleanupTaskChannel.Writer.TryWrite(null);
            Debug.Assert(writeRes, "Request for cleanup failed, this should never happen");
        }
        /// <summary>
        /// Detects if a Vector Set index read out of the main store is in the middle of being deleted.
        /// </summary>
        private static bool PartiallyDeleted(ReadOnlySpan indexConfig)
        {
            // Context 0 is reserved; a stored context of 0 here is taken to mean the delete
            // started but has not finished.  (NOTE(review): inferred from context 0 being
            // reserved elsewhere - confirm the delete path zeroes Context first.)
            ReadIndex(indexConfig, out var context, out _, out _, out _, out _, out _, out _, out _, out _);
            return context == 0;
        }
    }
}

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Garnet.common;
using Microsoft.Extensions.Logging;
using Tsavorite.core;

namespace Garnet.server
{
    using MainStoreAllocator = SpanByteAllocator>;
    using MainStoreFunctions = StoreFunctions;

    /// <summary>
    /// Methods for managing <see cref="ContextMetadata"/>, which tracks process wide
    /// information about different contexts.
    ///
    /// <see cref="ContextMetadata"/> is persisted to the log when modified, but a copy is kept in memory for rapid access.
    /// </summary>
    public sealed partial class VectorManager
    {
        /// <summary>
        /// Used for tracking which contexts are currently active.
        ///
        /// Three 64-bit bitmaps (in-use / cleaning-up / migrating) plus, per bit, the hash slot
        /// the context block is assigned to.  Bit index = context / ContextStep.
        /// </summary>
        [StructLayout(LayoutKind.Explicit, Size = Size)]
        internal struct ContextMetadata
        {
            [InlineArray(64)]
            private struct HashSlots
            {
                private ushort element0;
            }

            internal const int Size =
                (4 * sizeof(ulong)) + // Bitmaps
                (64 * sizeof(ushort)); // HashSlots for assigned contexts

            [FieldOffset(0)]
            public ulong Version;

            [FieldOffset(8)]
            private ulong inUse;

            [FieldOffset(16)]
            private ulong cleaningUp;

            [FieldOffset(24)]
            private ulong migrating;

            [FieldOffset(32)]
            private HashSlots slots;

            public readonly bool IsInUse(ulong context)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                return (inUse & mask) != 0;
            }

            public readonly bool IsMigrating(ulong context)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                return (migrating & mask) != 0;
            }

            /// <summary>
            /// Collect every namespace belonging to an in-use (and not cleaning-up) context whose
            /// assigned hash slot is in <paramref name="hashSlots"/>; null when none match.
            /// </summary>
            public readonly HashSet GetNamespacesForHashSlots(HashSet hashSlots)
            {
                HashSet ret = null;

                var remaining = inUse;
                while (remaining != 0)
                {
                    var inUseIx = BitOperations.TrailingZeroCount(remaining);
                    var inUseMask = 1UL << inUseIx;

                    remaining &= ~inUseMask;

                    if ((cleaningUp & inUseMask) != 0)
                    {
                        // If something is being cleaned up, no reason to migrate it
                        continue;
                    }

                    var hashSlot = slots[inUseIx];
                    if (!hashSlots.Contains(hashSlot))
                    {
                        // Active, but not a target
                        continue;
                    }

                    ret ??= [];

                    // Every namespace in the context block belongs to this Vector Set
                    var nsStart = ContextStep * (ulong)inUseIx;
                    for (var i = 0U; i < ContextStep; i++)
                    {
                        _ = ret.Add(nsStart + i);
                    }
                }

                return ret;
            }

            /// <summary>
            /// Lowest context block not currently in use; throws when all 64 are allocated.
            /// </summary>
            public readonly ulong NextNotInUse()
            {
                // Bit 0 (context 0) is reserved, so always treat it as taken
                var ignoringZero = inUse | 1;

                // Isolate the lowest clear bit and take its index
                var bit = (ulong)BitOperations.TrailingZeroCount(~ignoringZero & (ulong)-(long)(~ignoringZero));

                if (bit == 64)
                {
                    throw new GarnetException("All possible Vector Sets allocated");
                }

                var ret = bit * ContextStep;

                return ret;
            }

            /// <summary>
            /// Atomically (under the caller's lock) reserve <paramref name="count"/> context blocks
            /// for an inbound migration; fails without side effects if too few are free.
            /// </summary>
            public bool TryReserveForMigration(int count, out List reserved)
            {
                var ignoringZero = inUse | 1;

                var available = BitOperations.PopCount(~ignoringZero);

                if (available < count)
                {
                    reserved = null;
                    return false;
                }

                reserved = new();
                for (var i = 0; i < count; i++)
                {
                    var ctx = NextNotInUse();
                    reserved.Add(ctx);

                    MarkInUse(ctx, ushort.MaxValue); // HashSlot isn't known yet, so use an invalid value
                    MarkMigrating(ctx);
                }

                return true;
            }

            public void MarkInUse(ulong context, ushort hashSlot)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                Debug.Assert((inUse & mask) == 0, "About to mark context which is already in use");
                inUse |= mask;

                slots[(int)bitIx] = hashSlot;

                Version++;
            }

            public void MarkMigrating(ulong context)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                Debug.Assert((inUse & mask) != 0, "About to mark migrating a context which is not in use");
                Debug.Assert((migrating & mask) == 0, "About to mark migrating a context which is already migrating");
                migrating |= mask;

                Version++;
            }

            public void MarkMigrationComplete(ulong context, ushort hashSlot)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                Debug.Assert((inUse & mask) != 0, "Should already be in use");
                Debug.Assert((migrating & mask) != 0, "Should be migrating target");
                Debug.Assert(slots[(int)bitIx] == ushort.MaxValue, "Hash slot should not be known yet");

                migrating &= ~mask;

                slots[(int)bitIx] = hashSlot;

                Version++;
            }

            public void MarkCleaningUp(ulong context)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                Debug.Assert((inUse & mask) != 0, "About to mark for cleanup when not actually in use");
                Debug.Assert((cleaningUp & mask) == 0, "About to mark for cleanup when already marked");
                cleaningUp |= mask;

                // If this slot were migrating, it isn't anymore
                migrating &= ~mask;

                // Leave the slot around, we need it

                Version++;
            }

            public void FinishedCleaningUp(ulong context)
            {
                Debug.Assert(context > 0, "Context 0 is reserved, should never queried");
                Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit");
                Debug.Assert(context <= byte.MaxValue, "Context larger than expected");

                var bitIx = context / ContextStep;
                var mask = 1UL << (byte)bitIx;

                Debug.Assert((inUse & mask) != 0, "Cleaned up context which isn't in use");
                Debug.Assert((cleaningUp & mask) != 0, "Cleaned up context not marked for it");
                cleaningUp &= ~mask;
                inUse &= ~mask;

                slots[(int)bitIx] = 0;

                Version++;
            }

            /// <summary>
            /// Contexts flagged for cleanup, or null if there are none.
            /// </summary>
            public readonly HashSet GetNeedCleanup()
            {
                if (cleaningUp == 0)
                {
                    return null;
                }

                var ret = new HashSet();

                var remaining = cleaningUp;
                while (remaining != 0UL)
                {
                    var ix = BitOperations.TrailingZeroCount(remaining);

                    _ = ret.Add((ulong)ix * ContextStep);

                    remaining &= ~(1UL << (byte)ix);
                }

                return ret;
            }

            /// <summary>
            /// Contexts currently being migrated in, or null if there are none.
            /// </summary>
            public readonly HashSet GetMigrating()
            {
                if (migrating == 0)
                {
                    return null;
                }

                var ret = new HashSet();

                var remaining = migrating;
                while (remaining != 0UL)
                {
                    var ix = BitOperations.TrailingZeroCount(remaining);

                    _ = ret.Add((ulong)ix * ContextStep);

                    remaining &= ~(1UL << (byte)ix);
                }

                return ret;
            }

            /// <inheritdoc/>
            public override readonly string ToString()
            {
                // Just for debugging purposes

                var sb = new StringBuilder();
                sb.AppendLine();
                _ = sb.AppendLine($"Version: {Version}");
                var mask = 1UL;
                var ix = 0;
                while (mask != 0)
                {
                    var isInUse = (inUse & mask) != 0;
                    var isMigrating = (migrating & mask) != 0;
                    var cleanup = (cleaningUp & mask) != 0;

                    var hashSlot = this.slots[ix];

                    if (isInUse || isMigrating || cleanup)
                    {
                        var ctxStart = (ulong)ix * ContextStep;
                        var ctxEnd = ctxStart + ContextStep - 1;

                        sb.AppendLine($"[{ctxStart:00}-{ctxEnd:00}): {(isInUse ? "in-use " : "")}{(isMigrating ? "migrating " : "")}{(cleanup ? "cleanup" : "")}");
                    }

                    mask <<= 1;
                    ix++;
                }

                return sb.ToString();
            }
        }

        // In-memory copy of the persisted metadata; guarded by lock(this)
        private ContextMetadata contextMetadata;
        /// <summary>
        /// Get a new unique context for a vector set.
        ///
        /// This value is guaranteed to not be shared by any other vector set in the store.
        /// </summary>
        private ulong NextVectorSetContext(ushort hashSlot)
        {
            var start = Stopwatch.GetTimestamp();

            // TODO: This retry is no good, but will go away when namespaces >= 256 are possible
            while (true)
            {
                // Lock isn't amazing, but _new_ vector set creation should be rare
                // So just serializing it all is easier.
                try
                {
                    ulong nextFree;
                    lock (this)
                    {
                        nextFree = contextMetadata.NextNotInUse();

                        contextMetadata.MarkInUse(nextFree, hashSlot);
                    }
                    return nextFree;
                }
                catch (Exception e)
                {
                    // NextNotInUse throws when all contexts are allocated; wait for cleanup to free one
                    logger?.LogError(e, "NextContext not available, delaying and retrying");
                }

                if (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(30))
                {
                    lock (this)
                    {
                        if (contextMetadata.GetNeedCleanup() == null)
                        {
                            // Nothing is pending cleanup, so no context will ever free up - fail fast
                            throw new GarnetException("No available Vector Sets contexts to allocate, none scheduled for cleanup");
                        }
                    }

                    // Wait a little bit for cleanup to make progress
                    Thread.Sleep(1_000);
                }
                else
                {
                    throw new GarnetException("No available Vector Sets contexts to allocate, timeout reached");
                }
            }
        }

        /// <summary>
        /// Obtain some number of contexts for migrating Vector Sets.
        ///
        /// The return contexts are unavailable for other use, but are not yet "live" for visibility purposes.
        /// </summary>
        public bool TryReserveContextsForMigration<TContext>(ref TContext ctx, int count, out List contexts)
            where TContext : ITsavoriteContext
        {
            lock (this)
            {
                if (!contextMetadata.TryReserveForMigration(count, out contexts))
                {
                    contexts = null;
                    return false;
                }
            }

            // Persist the reservation so it survives restarts
            UpdateContextMetadata(ref ctx);

            return true;
        }

        /// <summary>
        /// Called when an index creation succeeds to flush <see cref="contextMetadata"/> into the store.
        /// </summary>
        private void UpdateContextMetadata<TContext>(ref TContext ctx)
            where TContext : ITsavoriteContext
        {
            Span keySpan = stackalloc byte[1];
            Span dataSpan = stackalloc byte[ContextMetadata.Size];

            // Snapshot under the lock so we persist a consistent view
            lock (this)
            {
                MemoryMarshal.Cast(dataSpan)[0] = contextMetadata;
            }

            // 0:0 is the ContextMetadata record
            var key = SpanByte.FromPinnedSpan(keySpan);

            key.MarkNamespace();
            key.SetNamespaceInPayload(0);

            VectorInput input = default;
            input.Callback = 0;
            input.WriteDesiredSize = ContextMetadata.Size;
            unsafe
            {
                input.CallbackContext = (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(dataSpan));
            }

            var data = SpanByte.FromPinnedSpan(dataSpan);

            var status = ctx.RMW(ref key, ref input);

            if (status.IsPending)
            {
                SpanByte ignored = default;
                CompletePending(ref status, ref ignored, ref ctx);
            }
        }

        /// <summary>
        /// Find all namespaces in use by vector sets that are logically members of the given hash slots.
        ///
        /// Meant for use during migration.
        /// </summary>
        public HashSet GetNamespacesForHashSlots(HashSet hashSlots)
        {
            lock (this)
            {
                return contextMetadata.GetNamespacesForHashSlots(hashSlots);
            }
        }
    }
}

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Garnet.common;
using Microsoft.Extensions.Logging;
using Tsavorite.core;

namespace Garnet.server
{
+ /// + public sealed partial class VectorManager + { + [StructLayout(LayoutKind.Explicit, Size = Size)] + private struct Index + { + internal const int Size = 56; + + [FieldOffset(0)] + public ulong Context; + [FieldOffset(8)] + public ulong IndexPtr; + [FieldOffset(16)] + public uint Dimensions; + [FieldOffset(20)] + public uint ReduceDims; + [FieldOffset(24)] + public uint NumLinks; + [FieldOffset(28)] + public uint BuildExplorationFactor; + [FieldOffset(32)] + public VectorQuantType QuantType; + [FieldOffset(36)] + public VectorDistanceMetricType DistanceMetric; + [FieldOffset(40)] + public Guid ProcessInstanceId; + } + + /// + /// Construct a new index, and stash enough data to recover it with . + /// + internal void CreateIndex( + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + VectorDistanceMetricType distanceMetric, + ulong newContext, + nint newIndexPtr, + ref SpanByte indexValue) + { + AssertHaveStorageSession(); + + var indexSpan = indexValue.AsSpan(); + + Debug.Assert((newContext % 8) == 0 && newContext != 0, "Illegal context provided"); + Debug.Assert(Unsafe.SizeOf() == Index.Size, "Constant index size is incorrect"); + + if (indexSpan.Length != Index.Size) + { + logger?.LogCritical("Acquired space for vector set index does not match expectations, {Length} != {Size}", indexSpan.Length, Index.Size); + throw new GarnetException($"Acquired space for vector set index does not match expectations, {indexSpan.Length} != {Index.Size}"); + } + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexSpan)); + asIndex.Context = newContext; + asIndex.Dimensions = dimensions; + asIndex.ReduceDims = reduceDims; + asIndex.QuantType = quantType; + asIndex.BuildExplorationFactor = buildExplorationFactor; + asIndex.NumLinks = numLinks; + asIndex.DistanceMetric = distanceMetric; + asIndex.IndexPtr = (ulong)newIndexPtr; + asIndex.ProcessInstanceId = processInstanceId; + } + + /// + /// 
Recreate an index that was created by a prior instance of Garnet. + /// + /// This implies the index still has element data, but the pointer is garbage. + /// + internal void RecreateIndex(nint newIndexPtr, ref SpanByte indexValue) + { + AssertHaveStorageSession(); + + var indexSpan = indexValue.AsSpan(); + + if (indexSpan.Length != Index.Size) + { + logger?.LogCritical("Acquired space for vector set index does not match expectations, {Length} != {Size}", indexSpan.Length, Index.Size); + throw new GarnetException($"Acquired space for vector set index does not match expectations, {indexSpan.Length} != {Index.Size}"); + } + + ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out _, out var indexProcessInstanceId); + Debug.Assert(processInstanceId != indexProcessInstanceId, "Shouldn't be recreating an index that matched our instance id"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexSpan)); + asIndex.IndexPtr = (ulong)newIndexPtr; + asIndex.ProcessInstanceId = processInstanceId; + } + + /// + /// Drop an index previously constructed with . + /// + internal void DropIndex(ReadOnlySpan indexValue) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out _, out _, out _, out _, out _, out _, out var indexPtr, out var indexProcessInstanceId); + + if (indexProcessInstanceId != processInstanceId) + { + // We never actually spun this index up, so nothing to drop + return; + } + + Service.DropIndex(context, indexPtr); + } + + /// + /// Deconstruct index stored in the value under a Vector Set index key. 
+ /// + public static void ReadIndex( + ReadOnlySpan indexValue, + out ulong context, + out uint dimensions, + out uint reduceDims, + out VectorQuantType quantType, + out uint buildExplorationFactor, + out uint numLinks, + out VectorDistanceMetricType distanceMetric, + out nint indexPtr, + out Guid processInstanceId + ) + { + Debug.Assert(indexValue.Length == Index.Size, $"Index size is incorrect ({indexValue.Length} != {Index.Size}), implies vector set index is probably corrupted"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexValue)); + + context = asIndex.Context; + dimensions = asIndex.Dimensions; + reduceDims = asIndex.ReduceDims; + quantType = asIndex.QuantType; + buildExplorationFactor = asIndex.BuildExplorationFactor; + numLinks = asIndex.NumLinks; + distanceMetric = asIndex.DistanceMetric; + indexPtr = (nint)asIndex.IndexPtr; + processInstanceId = asIndex.ProcessInstanceId; + + Debug.Assert((context % ContextStep) == 0, $"Context ({context}) not as expected (% 4 == {context % 4}), vector set index is probably corrupted"); + } + + /// + /// Update the context (which defines a range of namespaces) stored in a given index. + /// + /// Doing this also smashes the ProcessInstanceId, so the destination node won't + /// think it's already creating this index. 
+ /// + public static void SetContextForMigration(Span indexValue, ulong newContext) + { + Debug.Assert(newContext != 0, "0 is special, should not be assigning to an index"); + Debug.Assert(indexValue.Length == Index.Size, $"Index size is incorrect ({indexValue.Length} != {Index.Size}), implies vector set index is probably corrupted"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexValue)); + + asIndex.Context = newContext; + asIndex.ProcessInstanceId = MigratedInstanceId; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Locking.cs b/libs/server/Resp/Vector/VectorManager.Locking.cs new file mode 100644 index 00000000000..f74c8c69134 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Locking.cs @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Garnet.common; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Methods managing locking around Vector Sets. + /// + /// Locking is bespoke because of read-like nature of most Vector Set operations, and the re-entrancy implied by DiskANN callbacks. + /// + public sealed partial class VectorManager + { + /// + /// Used to scope a shared lock related to a Vector Set operation. + /// + /// Disposing this releases the lock and exits the storage session context on the current thread. 
+ /// + internal readonly ref struct ReadVectorLock : IDisposable + { + private readonly ref readonly ReadOptimizedLock lockableCtx; + private readonly int lockToken; + + internal ReadVectorLock(ref readonly ReadOptimizedLock lockableCtx, int lockToken) + { + this.lockToken = lockToken; + this.lockableCtx = ref lockableCtx; + } + + /// + public void Dispose() + { + Debug.Assert(ActiveThreadSession != null, "Shouldn't exit context when not in one"); + ActiveThreadSession = null; + + if (Unsafe.IsNullRef(in lockableCtx)) + { + return; + } + + lockableCtx.ReleaseSharedLock(lockToken); + } + } + + /// + /// Used to scope exclusive locks to exclusive Vector Set operation (delete, migrate, etc.). + /// + /// Disposing this releases the lock and exits the storage session context on the current thread. + /// + internal readonly ref struct ExclusiveVectorLock : IDisposable + { + private readonly ref readonly ReadOptimizedLock lockableCtx; + private readonly int lockToken; + + internal ExclusiveVectorLock(ref readonly ReadOptimizedLock lockableCtx, int lockToken) + { + this.lockToken = lockToken; + this.lockableCtx = ref lockableCtx; + } + + /// + public void Dispose() + { + Debug.Assert(ActiveThreadSession != null, "Shouldn't exit context when not in one"); + ActiveThreadSession = null; + + if (Unsafe.IsNullRef(in lockableCtx)) + { + return; + } + + lockableCtx.ReleaseExclusiveLock(lockToken); + } + } + + private readonly ReadOptimizedLock vectorSetLocks; + + /// + /// Returns true for indexes that were created via a previous instance of . + /// + /// Such indexes still have element data, but the index pointer to the DiskANN bits are invalid. + /// + internal bool NeedsRecreate(ReadOnlySpan indexConfig) + { + ReadIndex(indexConfig, out _, out _, out _, out _, out _, out _, out _, out _, out var indexProcessInstanceId); + + return indexProcessInstanceId != processInstanceId; + } + + /// + /// Utility method that will read an vector set index out but not create one. 
+ /// + /// It will however RECREATE one if needed. + /// + /// Returns a disposable that prevents the index from being deleted while undisposed. + /// + internal ReadVectorLock ReadVectorIndex(StorageSession storageSession, ref SpanByte key, ref RawStringInput input, scoped Span indexSpan, out GarnetStatus status) + { + Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index"); + + Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one"); + ActiveThreadSession = storageSession; + + var keyHash = storageSession.basicContext.GetKeyHash(ref key); + + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + var readCmd = input.header.cmd; + + while (true) + { + input.header.cmd = readCmd; + input.arg1 = 0; + + vectorSetLocks.AcquireSharedLock(keyHash, out var sharedLockToken); + + GarnetStatus readRes; + try + { + readRes = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + Debug.Assert(indexConfig.IsSpanByte, "Should never need to move index onto the heap"); + } + catch + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + throw; + } + + bool needsRecreate; + if (readRes == GarnetStatus.OK) + { + if (PartiallyDeleted(indexConfig.AsReadOnlySpan())) + { + status = GarnetStatus.BADSTATE; + + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + return default; + } + + needsRecreate = NeedsRecreate(indexConfig.AsReadOnlySpan()); + } + else + { + needsRecreate = false; + } + + if (needsRecreate) + { + if (!vectorSetLocks.TryPromoteSharedLock(keyHash, sharedLockToken, out var exclusiveLockToken)) + { + // Release the SHARED lock if we can't promote and try again + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + continue; + } + + ReadIndex(indexSpan, out var indexContext, out var dims, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var distanceMetric, out _, out _); + + input.arg1 = RecreateIndexArg; 
+ + nint newlyAllocatedIndex; + unsafe + { + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + input.header.cmd = RespCommand.VADD; + input.arg1 = RecreateIndexArg; + + input.parseState.EnsureCapacity(12); + + // Save off for recreation + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); // Strictly we don't _need_ this, but it keeps everything else aligned nicely + input.parseState.SetArgument(11, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + + GarnetStatus writeRes; + try + { + try + { + writeRes = storageSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + + if (writeRes != GarnetStatus.OK) + { + // If we didn't write, drop index so we don't leak it + Service.DropIndex(indexContext, newlyAllocatedIndex); + } + } + catch + { + // Drop to avoid leak on error + Service.DropIndex(indexContext, newlyAllocatedIndex); + throw; + } + } + catch + { + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + throw; + } + + if (writeRes == GarnetStatus.OK) + { + // Try again so we don't hold an exclusive lock while performing a search + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + continue; + } + else + { + status = writeRes; + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + return default; + } + } + else if (readRes != GarnetStatus.OK) + { + status = readRes; + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + return default; + } + + status = GarnetStatus.OK; + return new(in vectorSetLocks, sharedLockToken); + } + } + + /// + /// Utility method that will read vector set index out, create one if it doesn't exist, or RECREATE one if needed. 
+ /// + /// Returns a disposable that prevents the index from being deleted while undisposed. + /// + internal ReadVectorLock ReadOrCreateVectorIndex( + StorageSession storageSession, + ref SpanByte key, + ref RawStringInput input, + scoped Span indexSpan, + out GarnetStatus status + ) + { + Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index"); + + Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one"); + ActiveThreadSession = storageSession; + + var keyHash = storageSession.basicContext.GetKeyHash(ref key); + + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + while (true) + { + input.arg1 = 0; + + vectorSetLocks.AcquireSharedLock(keyHash, out var sharedLockToken); + + GarnetStatus readRes; + try + { + readRes = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + Debug.Assert(indexConfig.IsSpanByte, "Should never need to move index onto the heap"); + } + catch + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + throw; + } + + bool needsRecreate; + if (readRes == GarnetStatus.OK) + { + if (PartiallyDeleted(indexConfig.AsReadOnlySpan())) + { + status = GarnetStatus.BADSTATE; + + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + return default; + } + + needsRecreate = NeedsRecreate(indexConfig.AsReadOnlySpan()); + } + else + { + needsRecreate = false; + } + + if (readRes == GarnetStatus.NOTFOUND || needsRecreate) + { + if (!vectorSetLocks.TryPromoteSharedLock(keyHash, sharedLockToken, out var exclusiveLockToken)) + { + // Release the SHARED lock if we can't promote and try again + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + continue; + } + + ulong indexContext; + nint newlyAllocatedIndex; + if (needsRecreate) + { + ReadIndex(indexSpan, out indexContext, out var dims, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var distanceMetric, out _, out _); + + input.arg1 = 
RecreateIndexArg; + + unsafe + { + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + input.parseState.EnsureCapacity(12); + + // Save off for recreation + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); // Strictly we don't _need_ this, but it keeps everything else aligned nicely + input.parseState.SetArgument(11, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + } + else + { + // Create a new index, grab a new context + + // We must associate the index with a hash slot at creation time to enable future migrations + // TODO: RENAME and friends need to also update this data + var slot = HashSlotUtils.HashSlot(ref key); + + indexContext = NextVectorSetContext(slot); + + var dims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span); + var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span); + // ValueType is here, skipping during index creation + // Values is here, skipping during index creation + // Element is here, skipping during index creation + var quantizer = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span); + var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span); + // Attributes is here, skipping during index creation + var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span); + var distanceMetric = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(9).Span); + + unsafe + { + newlyAllocatedIndex = Service.CreateIndex(indexContext, dims, reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + input.parseState.EnsureCapacity(12); + + // Save off for 
insertion + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); + input.parseState.SetArgument(11, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + } + + GarnetStatus writeRes; + try + { + try + { + writeRes = storageSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + + if (writeRes != GarnetStatus.OK) + { + // Insertion failed, drop index + Service.DropIndex(indexContext, newlyAllocatedIndex); + + // If the failure was for a brand new index, free up the context too + if (!needsRecreate) + { + CleanupDroppedIndex(ref ActiveThreadSession.vectorContext, indexContext); + } + } + } + catch + { + if (newlyAllocatedIndex != 0) + { + // Drop to avoid a leak on error + Service.DropIndex(indexContext, newlyAllocatedIndex); + + // If the failure was for a brand new index, free up the context too + if (!needsRecreate) + { + CleanupDroppedIndex(ref ActiveThreadSession.vectorContext, indexContext); + } + } + + throw; + } + + if (!needsRecreate) + { + UpdateContextMetadata(ref storageSession.vectorContext); + } + } + catch + { + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + throw; + } + + if (writeRes == GarnetStatus.OK) + { + // Try again so we don't hold an exclusive lock while adding a vector (which might be time consuming) + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + continue; + } + else + { + status = writeRes; + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + return default; + } + } + else if (readRes != GarnetStatus.OK) + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + status = readRes; + return default; + } + + status = GarnetStatus.OK; + return new(in vectorSetLocks, sharedLockToken); + } + } + + /// + /// Acquire exclusive lock over a given key. 
+ /// + private ExclusiveVectorLock AcquireExclusiveLocks(StorageSession storageSession, ref SpanByte key) + { + var keyHash = storageSession.lockableContext.GetKeyHash(key); + + vectorSetLocks.AcquireExclusiveLock(keyHash, out var exclusiveLockToken); + + return new(in vectorSetLocks, exclusiveLockToken); + } + + /// + /// Utility method that will read vector set index out, and acquire exclusive locks to allow it to be deleted. + /// + /// If the index is partially deleted, will be set to but the locks will be still acquired. + /// + internal ExclusiveVectorLock ReadForDeleteVectorIndex(StorageSession storageSession, ref SpanByte key, ref RawStringInput input, scoped Span indexSpan, out GarnetStatus status) + { + Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index"); + + Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one"); + ActiveThreadSession = storageSession; + + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + // Get the index + var acquiredLock = AcquireExclusiveLocks(storageSession, ref key); + try + { + status = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + } + catch + { + acquiredLock.Dispose(); + + throw; + } + + if (status == GarnetStatus.OK) + { + // Even if we read the value, it might be in a bad state due to a prior delete + if (PartiallyDeleted(indexConfig.AsReadOnlySpan())) + { + status = GarnetStatus.BADSTATE; + } + } + + return acquiredLock; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Migration.cs b/libs/server/Resp/Vector/VectorManager.Migration.cs new file mode 100644 index 00000000000..4a9725ec713 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Migration.cs @@ -0,0 +1,317 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Garnet.common; +using Microsoft.Extensions.Logging; +using Tsavorite.core; + +namespace Garnet.server +{ + using MainStoreAllocator = SpanByteAllocator>; + using MainStoreFunctions = StoreFunctions; + + /// + /// Methods related to migrating Vector Sets between different primaries. + /// + /// This is bespoke because normal migration is key based, but Vector Set migration has to move whole namespaces first. + /// + public sealed partial class VectorManager + { + // This is a V8 GUID based on 'GARNET MIGRATION' ASCII string + // It cannot collide with processInstanceIds because it's v8 + // It's unlikely other projects will select the value, so it's unlikely to collide with other v8s + // If it ends up in logs, it's ASCII equivalent looks suspcious enough to lead back here + private static readonly Guid MigratedInstanceId = new("4e524147-5445-8d20-8947-524154494f4e"); + + /// + /// Called to handle a key in a namespace being received during a migration. + /// + /// These keys are what DiskANN stores, that is they are "element" data. + /// + /// The index is handled specially by . + /// + public void HandleMigratedElementKey( + ref BasicContext basicCtx, + ref BasicContext vectorCtx, + ref SpanByte key, + ref SpanByte value + ) + { + Debug.Assert(key.MetadataSize == 1, "Should have namespace if we're migrating a key"); + +#if DEBUG + // Do some extra sanity checking in DEBUG builds + lock (this) + { + var ns = key.GetNamespaceInPayload(); + var context = (ulong)(ns & ~(ContextStep - 1)); + Debug.Assert(contextMetadata.IsInUse(context), "Shouldn't be migrating to an unused context"); + Debug.Assert(contextMetadata.IsMigrating(context), "Shouldn't be migrating to context not marked for it"); + Debug.Assert(!(contextMetadata.GetNeedCleanup()?.Contains(context) ?? 
false), "Shouldn't be migrating into context being deleted"); + } +#endif + + VectorInput input = default; + SpanByte outputSpan = default; + + var status = vectorCtx.Upsert(ref key, ref input, ref value, ref outputSpan); + if (status.IsPending) + { + CompletePending(ref status, ref outputSpan, ref vectorCtx); + } + + if (!status.IsCompletedSuccessfully) + { + throw new GarnetException("Failed to migrate key, this should fail migration"); + } + + ReplicateMigratedElementKey(ref basicCtx, ref key, ref value, logger); + + // Fake a write for post-migration replication + static void ReplicateMigratedElementKey(ref BasicContext basicCtx, ref SpanByte key, ref SpanByte value, ILogger logger) + { + RawStringInput input = default; + + input.header.cmd = RespCommand.VADD; + input.arg1 = MigrateElementKeyLogArg; + + input.parseState.InitializeWithArguments([ArgSlice.FromPinnedSpan(key.AsReadOnlySpanWithMetadata()), ArgSlice.FromPinnedSpan(value.AsReadOnlySpan())]); + + SpanByte dummyKey = default; + SpanByteAndMemory dummyOutput = default; + + var res = basicCtx.RMW(ref dummyKey, ref input, ref dummyOutput); + + if (res.IsPending) + { + CompletePending(ref res, ref dummyOutput, ref basicCtx); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for migrated Vector Set key/value into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set write operation for key/value migration, data loss may occur"); + } + + // Helper to complete read/writes during vector set synthetic op goes async + static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref BasicContext basicCtx) + { + _ = basicCtx.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; + output = completedOutputs.Current.Output; + more = completedOutputs.Next(); + Debug.Assert(!more); + 
completedOutputs.Dispose(); + } + } + } + + /// + /// Called to handle a Vector Set key being received during a migration. These are "index" keys. + /// + /// This is the metadata stuff Garnet creates, DiskANN is not involved. + /// + /// Invoked after all the namespace data is moved via . + /// + public void HandleMigratedIndexKey( + GarnetDatabase db, + StoreWrapper storeWrapper, + ref SpanByte key, + ref SpanByte value) + { + Debug.Assert(key.MetadataSize != 1, "Shouldn't have a namespace if we're migrating a Vector Set index"); + + RawStringInput input = default; + input.header.cmd = RespCommand.VADD; + input.arg1 = RecreateIndexArg; + + ReadIndex(value.AsReadOnlySpan(), out var context, out var dimensions, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var distanceMetric, out _, out var processInstanceId); + + Debug.Assert(processInstanceId == MigratedInstanceId, "Shouldn't receive a real process instance id during a migration"); + + // Extra validation in DEBUG +#if DEBUG + lock (this) + { + Debug.Assert(contextMetadata.IsInUse(context), "Context should be assigned if we're migrating"); + Debug.Assert(contextMetadata.IsMigrating(context), "Context should be marked migrating if we're moving an index key in"); + } +#endif + + // Spin up a new Storage Session is we don't have one + StorageSession newStorageSession; + if (ActiveThreadSession == null) + { + Debug.Assert(db != null, "Must have DB if session is not already set"); + Debug.Assert(storeWrapper != null, "Must have StoreWrapper if session is not already set"); + + ActiveThreadSession = newStorageSession = new StorageSession(storeWrapper, new(), null, null, db.Id, this, this.logger); + } + else + { + newStorageSession = null; + } + + try + { + // Prepare as a psuedo-VADD + var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dimensions, 1))); + var reduceDimsArg = 
ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); + ArgSlice valueTypeArg = default; + ArgSlice valuesArg = default; + ArgSlice elementArg = default; + var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref quantType, 1))); + var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1))); + ArgSlice attributesArg = default; + var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref numLinks, 1))); + var distanceMetricArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref distanceMetric, 1))); + + nint newlyAllocatedIndex; + unsafe + { + newlyAllocatedIndex = Service.RecreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + var ctxArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref context, 1))); + var indexArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1))); + + input.parseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg, distanceMetricArg, ctxArg, indexArg]); + + Span indexSpan = stackalloc byte[Index.Size]; + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + // Exclusive lock to prevent other modification of this key + + using (AcquireExclusiveLocks(ActiveThreadSession, ref key)) + { + // Perform the write + var writeRes = ActiveThreadSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref ActiveThreadSession.basicContext); + if (writeRes != GarnetStatus.OK) + { + Service.DropIndex(context, newlyAllocatedIndex); + throw new GarnetException("Failed to import migrated Vector Set index, aborting migration"); + } + + var 
hashSlot = HashSlotUtils.HashSlot(ref key); + + lock (this) + { + contextMetadata.MarkMigrationComplete(context, hashSlot); + } + + UpdateContextMetadata(ref ActiveThreadSession.vectorContext); + + // For REPLICAs which are following, we need to fake up a write + ReplicateMigratedIndexKey(ref ActiveThreadSession.basicContext, ref key, ref value, context, logger); + } + } + finally + { + ActiveThreadSession = null; + + // If we spun up a new storage session, dispose it + newStorageSession?.Dispose(); + } + + // Fake a write for post-migration replication + static void ReplicateMigratedIndexKey( + ref BasicContext basicCtx, + ref SpanByte key, + ref SpanByte value, + ulong context, + ILogger logger) + { + RawStringInput input = default; + + input.header.cmd = RespCommand.VADD; + input.arg1 = MigrateIndexKeyLogArg; + + var contextArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref context, 1))); + + input.parseState.InitializeWithArguments([ArgSlice.FromPinnedSpan(key.AsReadOnlySpanWithMetadata()), ArgSlice.FromPinnedSpan(value.AsReadOnlySpan()), contextArg]); + + SpanByte dummyKey = default; + SpanByteAndMemory dummyOutput = default; + + var res = basicCtx.RMW(ref dummyKey, ref input, ref dummyOutput); + + if (res.IsPending) + { + CompletePending(ref res, ref dummyOutput, ref basicCtx); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for migrated Vector Set index into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set write operation for index migration, data loss may occur"); + } + + // Helper to complete read/writes during vector set synthetic op goes async + static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref BasicContext basicCtx) + { + _ = basicCtx.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; 
+ output = completedOutputs.Current.Output; + more = completedOutputs.Next(); + Debug.Assert(!more); + completedOutputs.Dispose(); + } + } + } + + /// + /// Find namespaces used by the given keys, IFF they are Vector Sets. They may (and often will) not be. + /// + /// Meant for use during migration. + /// + public unsafe HashSet GetNamespacesForKeys(StoreWrapper storeWrapper, IEnumerable keys, Dictionary vectorSetKeys) + { + // TODO: Ideally we wouldn't make a new session for this, but it's fine for now + using var storageSession = new StorageSession(storeWrapper, new(), null, null, storeWrapper.DefaultDatabase.Id, this, logger); + + HashSet namespaces = null; + + Span indexSpan = stackalloc byte[Index.Size]; + + foreach (var key in keys) + { + fixed (byte* keyPtr = key) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length); + + // Dummy command, we just need something Vector Set-y + RawStringInput input = default; + input.header.cmd = RespCommand.VSIM; + + using (ReadVectorIndex(storageSession, ref keySpan, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + continue; + } + + namespaces ??= []; + + ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out _, out _); + for (var i = 0UL; i < ContextStep; i++) + { + _ = namespaces.Add(context + i); + } + + vectorSetKeys[key] = indexSpan.ToArray(); + } + } + } + + return namespaces; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Replication.cs b/libs/server/Resp/Vector/VectorManager.Replication.cs new file mode 100644 index 00000000000..eb383b9912a --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Replication.cs @@ -0,0 +1,569 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System; +using System.Buffers; +using System.Diagnostics; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Garnet.common; +using Microsoft.Extensions.Logging; +using Tsavorite.core; + +namespace Garnet.server +{ + using MainStoreAllocator = SpanByteAllocator>; + using MainStoreFunctions = StoreFunctions; + + /// + /// Methods for managing the replication of Vector Sets from primaries to other replicas. + /// + /// This is very bespoke because Vector Set operations are phrased as reads for most things, which + /// bypasses Garnet's usual replication logic. + /// + public sealed partial class VectorManager + { + /// + /// Represents a copy of a VADD being replayed during replication. + /// + private readonly record struct VADDReplicationState(Memory Key, uint Dims, uint ReduceDims, VectorValueType ValueType, Memory Values, Memory Element, VectorQuantType Quantizer, uint BuildExplorationFactor, Memory Attributes, uint NumLinks, VectorDistanceMetricType DistanceMetric) + { + } + + private int replicationReplayStarted; + private CountingEventSlim replicationBlockEvent; + private readonly Channel replicationReplayChannel; + private readonly Task[] replicationReplayTasks; + + private CancellationToken replicationReplayCancellation; + + /// + /// For testing purposes, are the replication replay tasks active. + /// + public bool AreReplicationTasksActive + => replicationReplayCancellation.CanBeCanceled && replicationReplayTasks.Any(static r => !r.IsCompleted); + + /// + /// Hook for to request replication tasks start. + /// + /// The underlying tasks may not be spun up until later, but the provided will be used + /// if the yare. 
+ /// + public async Task StartReplicationTasksAsync(CancellationToken cancellationToken) + { + try + { + await Task.Yield(); + + replicationReplayCancellation = cancellationToken; + + using var cts = new CancellationTokenSource(); + + _ = cancellationToken.Register(() => cts.Cancel()); + + try + { + await Task.Delay(Timeout.InfiniteTimeSpan, cts.Token); + } + catch { } + + var abandoned = ResetReplayTasks(); + logger?.LogInformation("VectorManager replication cancellation abandoned {abandoned} VADDs", abandoned); + } + finally + { + replicationReplayCancellation = default; + } + } + + /// + /// For replication purposes, we need a write against the main log. + /// + /// But we don't actually want to do the (expensive) vector ops as part of a write. + /// + /// So this fakes up a modify operation that we can then intercept as part of replication. + /// + /// This the Primary part, on a Replica runs. + /// + internal void ReplicateVectorSetAdd(ref SpanByte key, ref RawStringInput input, ref TContext context) + where TContext : ITsavoriteContext + { + Debug.Assert(input.header.cmd == RespCommand.VADD, "Shouldn't be called with anything but VADD inputs"); + + var inputCopy = input; + inputCopy.arg1 = VADDAppendLogArg; + + Span keyWithNamespaceBytes = stackalloc byte[key.Length + 1]; + var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes); + keyWithNamespace.MarkNamespace(); + keyWithNamespace.SetNamespaceInPayload(0); + key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan()); + + var res = context.RMW(ref keyWithNamespace, ref inputCopy); + + if (res.IsPending) + { + CompletePending(ref res, ref context); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for VADD into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set add operation for replication, data loss will occur"); + } + } + + /// + /// For replication purposes, we need a write against the main log. 
+ /// + /// But we don't actually want to do the (expensive) vector ops as part of a write. + /// + /// So this fakes up a modify operation that we can then intercept as part of replication. + /// + /// This the Primary part, on a Replica runs. + /// + internal void ReplicateVectorSetRemove(ref SpanByte key, ref SpanByte element, ref RawStringInput input, ref TContext context) + where TContext : ITsavoriteContext + { + Debug.Assert(input.header.cmd == RespCommand.VREM, "Shouldn't be called with anything but VREM inputs"); + + var inputCopy = input; + inputCopy.arg1 = VREMAppendLogArg; + + Span keyWithNamespaceBytes = stackalloc byte[key.Length + 1]; + var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes); + keyWithNamespace.MarkNamespace(); + keyWithNamespace.SetNamespaceInPayload(0); + key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan()); + + inputCopy.parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(element.AsReadOnlySpan())); + + var res = context.RMW(ref keyWithNamespace, ref inputCopy); + + if (res.IsPending) + { + CompletePending(ref res, ref context); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for VREM into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set remove operation for replication, data loss will occur"); + } + } + + /// + /// After an index is dropped, called to cleanup state injected by + /// + /// Amounts to delete a synthetic key in namespace 0. 
+ /// + internal bool TryDropVectorSetReplicationKey(SpanByte key, ref TContext context) + where TContext : ITsavoriteContext + { + Span keyWithNamespaceBytes = stackalloc byte[key.Length + 1]; + var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes); + keyWithNamespace.MarkNamespace(); + keyWithNamespace.SetNamespaceInPayload(0); + key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan()); + + Span dummyBytes = stackalloc byte[4]; + + var res = context.Delete(ref keyWithNamespace); + + if (res.IsPending) + { + CompletePending(ref res, ref context); + } + + return res.IsCompletedSuccessfully; + } + + /// + /// Vector Set adds are phrased as reads (once the index is created), so they require special handling. + /// + /// Operations that are faked up by running on the Primary get diverted here on a Replica. + /// + internal void HandleVectorSetAddReplication(StorageSession currentSession, Func obtainServerSession, ref SpanByte keyWithNamespace, ref RawStringInput input) + { + if (input.arg1 == MigrateElementKeyLogArg) + { + // These are special, injecting by a PRIMARY applying migration operations + // These get replayed on REPLICAs typically, though role changes might still cause these + // to get replayed on now-primary nodes + + var key = input.parseState.GetArgSliceByRef(0).SpanByte; + var value = input.parseState.GetArgSliceByRef(1).SpanByte; + + // TODO: Namespace is present, but not actually transmitted + // This presumably becomes unnecessary in Store v2 + key.MarkNamespace(); + + var ns = key.GetNamespaceInPayload(); + + // REPLICAs wouldn't have seen a reservation message, so allocate this on demand + var ctx = ns & ~(ContextStep - 1); + if (!contextMetadata.IsMigrating(ctx)) + { + var needsUpdate = false; + + lock (this) + { + if (!contextMetadata.IsMigrating(ctx)) + { + contextMetadata.MarkInUse(ctx, ushort.MaxValue); + contextMetadata.MarkMigrating(ctx); + + needsUpdate = true; + } + } + + if (needsUpdate) + { + UpdateContextMetadata(ref 
currentSession.vectorContext); + } + } + + HandleMigratedElementKey(ref currentSession.basicContext, ref currentSession.vectorContext, ref key, ref value); + return; + } + else if (input.arg1 == MigrateIndexKeyLogArg) + { + // These also injected by a PRIMARY applying migration operations + + var key = input.parseState.GetArgSliceByRef(0).SpanByte; + var value = input.parseState.GetArgSliceByRef(1).SpanByte; + var context = MemoryMarshal.Cast(input.parseState.GetArgSliceByRef(2).Span)[0]; + + // Most of the time a replica will have seen an element moving before now + // but if you a migrate an EMPTY Vector Set that is not necessarily true + // + // So force reservation now + if (!contextMetadata.IsMigrating(context)) + { + var needsUpdate = false; + + lock (this) + { + if (!contextMetadata.IsMigrating(context)) + { + contextMetadata.MarkInUse(context, ushort.MaxValue); + contextMetadata.MarkMigrating(context); + + needsUpdate = true; + } + } + + if (needsUpdate) + { + UpdateContextMetadata(ref currentSession.vectorContext); + } + } + + ActiveThreadSession = currentSession; + try + { + HandleMigratedIndexKey(null, null, ref key, ref value); + } + finally + { + ActiveThreadSession = null; + } + return; + } + + Debug.Assert(input.arg1 == VADDAppendLogArg, "Unexpected operation during replication"); + + // Undo mangling that got replication going + var inputCopy = input; + inputCopy.arg1 = default; + var keyBytesArr = ArrayPool.Shared.Rent(keyWithNamespace.Length - 1); + var keyBytes = keyBytesArr.AsMemory()[..(keyWithNamespace.Length - 1)]; + + keyWithNamespace.AsReadOnlySpan().CopyTo(keyBytes.Span); + + var dims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span); + var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span); + var valueType = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(2).Span); + var values = input.parseState.GetArgSliceByRef(3).Span; + var element = input.parseState.GetArgSliceByRef(4).Span; + var quantizer 
= MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span); + var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span); + var attributes = input.parseState.GetArgSliceByRef(7).Span; + var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span); + var distanceMetric = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(9).Span); + + // We have to make copies (and they need to be on the heap) to pass to background tasks + var valuesBytes = ArrayPool.Shared.Rent(values.Length).AsMemory()[..values.Length]; + values.CopyTo(valuesBytes.Span); + + var elementBytes = ArrayPool.Shared.Rent(element.Length).AsMemory()[..element.Length]; + element.CopyTo(elementBytes.Span); + + var attributesBytes = ArrayPool.Shared.Rent(attributes.Length).AsMemory()[..attributes.Length]; + attributes.CopyTo(attributesBytes.Span); + + // Spin up replication replay tasks on first use + if (replicationReplayStarted == 0) + { + if (Interlocked.CompareExchange(ref replicationReplayStarted, 1, 0) == 0) + { + StartReplicationReplayTasks(this, obtainServerSession); + } + } + + // We need a running count of pending VADDs so WaitForVectorOperationsToComplete can work + + replicationBlockEvent.Increment(); + var queued = replicationReplayChannel.Writer.TryWrite(new(keyBytes, dims, reduceDims, valueType, valuesBytes, elementBytes, quantizer, buildExplorationFactor, attributesBytes, numLinks, distanceMetric)); + if (!queued) + { + replicationBlockEvent.Decrement(); + } + + static void StartReplicationReplayTasks(VectorManager self, Func obtainServerSession) + { + if (self.dbId != 0) + { + throw new GarnetException($"Unexpected DB ({self.dbId}) in cluster mode, expected 0"); + } + + self.logger?.LogInformation("Starting {numTasks} replication tasks for VADDs", self.replicationReplayTasks.Length); + + for (var i = 0; i < self.replicationReplayTasks.Length; i++) + { + // Allocate session outside of task so we fail "nicely" if something goes wrong with 
acquiring them + var allocatedSession = obtainServerSession(); + if (allocatedSession.activeDbId != self.dbId && !allocatedSession.TrySwitchActiveDatabaseSession(self.dbId)) + { + allocatedSession.Dispose(); + throw new GarnetException($"Could not switch replication replay session to {self.dbId}, replication will fail"); + } + + self.replicationReplayTasks[i] = Task.Factory.StartNew( + async () => + { + try + { + using (allocatedSession) + { + var reader = self.replicationReplayChannel.Reader; + + SessionParseState reusableParseState = default; + reusableParseState.Initialize(11); + + await foreach (var entry in reader.ReadAllAsync(self.replicationReplayCancellation)) + { + try + { + try + { + ApplyVectorSetAdd(self, allocatedSession.storageSession, entry, ref reusableParseState); + } + finally + { + self.replicationBlockEvent.Decrement(); + } + } + catch + { + self.logger?.LogCritical( + "Faulting ApplyVectorSetAdd ({key}, {dims}, {reducedDims}, {valueType}, 0x{values}, 0x{element}, {quantizer}, {bef}, {attributes}, {numLinks}", + Encoding.UTF8.GetString(entry.Key.Span), + entry.Dims, + entry.ReduceDims, + entry.ValueType, + Convert.ToBase64String(entry.Values.Span), + Convert.ToBase64String(entry.Values.Span), + entry.Quantizer, + entry.BuildExplorationFactor, + Encoding.UTF8.GetString(entry.Attributes.Span), + entry.NumLinks + ); + + throw; + } + } + } + } + catch (OperationCanceledException cancelEx) + { + self.logger?.LogInformation(cancelEx, "ReplicationReplayTask cancelled"); + } + catch (Exception e) + { + self.logger?.LogCritical(e, "Unexpected abort of replication replay task"); + throw; + } + } + ) + .Unwrap(); + } + } + + // Actually apply a replicated VADD + static unsafe void ApplyVectorSetAdd(VectorManager self, StorageSession storageSession, VADDReplicationState state, ref SessionParseState reusableParseState) + { + ref var context = ref storageSession.basicContext; + + var (keyBytes, dims, reduceDims, valueType, valuesBytes, elementBytes, 
quantizer, buildExplorationFactor, attributesBytes, numLinks, distanceMetric) = state; + try + { + Span indexSpan = stackalloc byte[IndexSizeBytes]; + + fixed (byte* keyPtr = keyBytes.Span) + fixed (byte* valuesPtr = valuesBytes.Span) + fixed (byte* elementPtr = elementBytes.Span) + fixed (byte* attributesPtr = attributesBytes.Span) + { + var key = SpanByte.FromPinnedPointer(keyPtr, keyBytes.Length); + var values = SpanByte.FromPinnedPointer(valuesPtr, valuesBytes.Length); + var element = SpanByte.FromPinnedPointer(elementPtr, elementBytes.Length); + var attributes = SpanByte.FromPinnedPointer(attributesPtr, attributesBytes.Length); + + var indexBytes = stackalloc byte[IndexSizeBytes]; + SpanByteAndMemory indexConfig = new(indexBytes, IndexSizeBytes); + + var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dims, 1))); + var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); + var valueTypeArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref valueType, 1))); + var valuesArg = ArgSlice.FromPinnedSpan(values.AsReadOnlySpan()); + var elementArg = ArgSlice.FromPinnedSpan(element.AsReadOnlySpan()); + var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref quantizer, 1))); + var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1))); + var attributesArg = ArgSlice.FromPinnedSpan(attributes.AsReadOnlySpan()); + var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref numLinks, 1))); + var distanceMetricArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref distanceMetric, 1))); + + reusableParseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg, distanceMetricArg]); + + var input = new 
RawStringInput(RespCommand.VADD, ref reusableParseState); + + // Equivalent to VectorStoreOps.VectorSetAdd + // + // We still need locking here because the replays may proceed in parallel + + using (self.ReadOrCreateVectorIndex(storageSession, ref key, ref input, indexSpan, out var status)) + { + Debug.Assert(status == GarnetStatus.OK, "Replication should only occur when an add is successful, so index must exist"); + + var addRes = self.TryAdd(indexSpan, element.AsReadOnlySpan(), valueType, values.AsReadOnlySpan(), attributes.AsReadOnlySpan(), reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, out _); + + if (addRes != VectorManagerResult.OK) + { + throw new GarnetException("Failed to add to vector set index during AOF sync, this should never happen but will cause data loss if it does"); + } + } + } + } + finally + { + if (MemoryMarshal.TryGetArray(keyBytes, out var toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(valuesBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(elementBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(attributesBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + } + } + } + + /// + /// Cancels replication tasks, resetting enough state that they can be resumed by a future call to . + /// + /// Returns the number of abanded VADDs. + /// + private int ResetReplayTasks() + { + Task.WaitAll(replicationReplayTasks); + Array.Fill(replicationReplayTasks, Task.CompletedTask); + + _ = Interlocked.Exchange(ref replicationReplayStarted, 0); + + var abandoned = 0; + while (replicationReplayChannel.Reader.TryRead(out _)) + { + replicationBlockEvent.Decrement(); + abandoned++; + } + + return abandoned; + } + + /// + /// Vector Set removes are phrased as reads (once the index is created), so they require special handling. 
+ /// + /// Operations that are faked up by running on the Primary get diverted here on a Replica. + /// + internal void HandleVectorSetRemoveReplication(StorageSession storageSession, ref SpanByte key, ref RawStringInput input) + { + Span indexSpan = stackalloc byte[IndexSizeBytes]; + var element = input.parseState.GetArgSliceByRef(0); + + // Replication adds a (0) namespace - remove it + Span keyWithoutNamespaceSpan = stackalloc byte[key.Length - 1]; + key.AsReadOnlySpan().CopyTo(keyWithoutNamespaceSpan); + var keyWithoutNamespace = SpanByte.FromPinnedSpan(keyWithoutNamespaceSpan); + + var inputCopy = input; + inputCopy.arg1 = default; + + using (ReadVectorIndex(storageSession, ref keyWithoutNamespace, ref inputCopy, indexSpan, out var status)) + { + Debug.Assert(status == GarnetStatus.OK, "Replication should only occur when a remove is successful, so index must exist"); + + var addRes = TryRemove(indexSpan, element.ReadOnlySpan); + + if (addRes != VectorManagerResult.OK) + { + throw new GarnetException("Failed to remove from vector set index during AOF sync, this should never happen but will cause data loss if it does"); + } + } + } + + /// + /// Wait until all ops passed to have completed. 
+ /// + public void WaitForVectorOperationsToComplete() + { + try + { + replicationBlockEvent.Wait(); + } + catch (ObjectDisposedException) + { + // This is possible during dispose + // + // Dispose already takes pains to drain everything before disposing, so this is safe to ignore + } + } + // Helper to complete read/writes during vector set synthetic op goes async + private static void CompletePending(ref Status status, ref TContext context) + where TContext : ITsavoriteContext + { + _ = context.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; + more = completedOutputs.Next(); + Debug.Assert(!more); + completedOutputs.Dispose(); + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.cs b/libs/server/Resp/Vector/VectorManager.cs new file mode 100644 index 00000000000..865eeee6216 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.cs @@ -0,0 +1,937 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Channels; +using System.Threading.Tasks; +using Garnet.common; +using Garnet.networking; +using Microsoft.Extensions.Logging; +using Tsavorite.core; + +namespace Garnet.server +{ + using MainStoreAllocator = SpanByteAllocator>; + using MainStoreFunctions = StoreFunctions; + + public enum VectorManagerResult + { + Invalid = 0, + + OK, + BadParams, + Duplicate, + MissingElement, + } + + /// + /// Methods for managing an implementation of various vector operations. 
+ /// + public sealed partial class VectorManager : IDisposable + { + // MUST BE A POWER OF 2 + public const ulong ContextStep = 8; + + internal const int IndexSizeBytes = Index.Size; + internal const long VADDAppendLogArg = long.MinValue; + internal const long DeleteAfterDropArg = VADDAppendLogArg + 1; + internal const long RecreateIndexArg = DeleteAfterDropArg + 1; + internal const long VREMAppendLogArg = RecreateIndexArg + 1; + internal const long MigrateElementKeyLogArg = VREMAppendLogArg + 1; + internal const long MigrateIndexKeyLogArg = MigrateElementKeyLogArg + 1; + + /// + /// Minimum size of an id is assumed to be at least 8 bytes + a length prefix. + /// + private const int MinimumSpacePerId = sizeof(int) + 8; + + /// + /// The process wide instances of DiskANN. + /// + /// We only need the one, even if we have multiple DBs, because all context is provided by DiskANN instances and Garnet storage. + /// + private DiskANNService Service { get; } = new DiskANNService(); + + /// + /// Whether or not Vector Set preview is enabled. + /// + /// TODO: This goes away once we're stable. + /// + public bool IsEnabled { get; } + + /// + /// Unique id for this . + /// + /// Is used to determine if an is backed by a DiskANN index that was created in this process. 
+ /// + private readonly Guid processInstanceId = Guid.NewGuid(); + + private readonly ILogger logger; + + private readonly int dbId; + + public VectorManager(bool enabled, int dbId, Func getCleanupSession, ILoggerFactory loggerFactory) + { + this.dbId = dbId; + + IsEnabled = enabled; + + // Include DB and id so we correlate to what's actually stored in the log + logger = loggerFactory?.CreateLogger($"{nameof(VectorManager)}:{dbId}:{processInstanceId}"); + + replicationBlockEvent = CountingEventSlim.Create(); + replicationReplayChannel = Channel.CreateUnbounded(new() { SingleWriter = true, SingleReader = false, AllowSynchronousContinuations = false }); + + // TODO: Pull this off a config or something + replicationReplayTasks = new Task[Environment.ProcessorCount]; + for (var i = 0; i < replicationReplayTasks.Length; i++) + { + replicationReplayTasks[i] = Task.CompletedTask; + } + + // TODO: Probably configurable? + // For now, just number of processors + vectorSetLocks = new(Environment.ProcessorCount); + + this.getCleanupSession = getCleanupSession; + cleanupTaskChannel = Channel.CreateUnbounded(new() { SingleWriter = false, SingleReader = true, AllowSynchronousContinuations = false }); + cleanupTask = RunCleanupTaskAsync(); + + logger?.LogInformation("Created VectorManager"); + } + + /// + /// Load state necessary for VectorManager from main store. 
+ /// + public void Initialize() + { + using var session = (RespServerSession)getCleanupSession(); + if (session.activeDbId != dbId && !session.TrySwitchActiveDatabaseSession(dbId)) + { + throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed"); + } + + Span keySpan = stackalloc byte[1]; + Span dataSpan = stackalloc byte[ContextMetadata.Size]; + + var key = SpanByte.FromPinnedSpan(keySpan); + + key.MarkNamespace(); + key.SetNamespaceInPayload(0); + + var data = SpanByte.FromPinnedSpan(dataSpan); + + ref var ctx = ref session.storageSession.vectorContext; + + var status = ctx.Read(ref key, ref data); + + if (status.IsPending) + { + SpanByte ignored = default; + CompletePending(ref status, ref ignored, ref ctx); + } + + // Can be not found if we've never spun up a Vector Set + if (status.Found) + { + lock (this) + { + contextMetadata = MemoryMarshal.Cast(dataSpan)[0]; + } + } + + } + + /// + /// Restart or update any pending work that was discovered as part of recovery. 
+ /// + public void ResumePostRecovery() + { + using var session = (RespServerSession)getCleanupSession(); + + ref var ctx = ref session.storageSession.vectorContext; + + // If we come up and contexts are marked for migration, that means the migration FAILED + // and we'd like those contexts back ASAP + lock (this) + { + var abandonedMigrations = contextMetadata.GetMigrating(); + + if (abandonedMigrations != null) + { + foreach (var abandoned in abandonedMigrations) + { + contextMetadata.MarkMigrationComplete(abandoned, ushort.MaxValue); + contextMetadata.MarkCleaningUp(abandoned); + } + + UpdateContextMetadata(ref ctx); + } + } + + Span indexSpan = stackalloc byte[Index.Size]; + + // Finish any deletes that were in progress before we restarted + var failedDeletes = GetDeletesInProgress(session.storageSession); + var clearInProgressDeletes = true; + foreach (var (toDeleteKey, toDeleteCtx) in failedDeletes) + { + logger?.LogInformation("Cleaning up in progress Vector Set delete of {key} (context: {ctx})", Encoding.UTF8.GetString(toDeleteKey.Span), toDeleteCtx); + + unsafe + { + fixed (byte* toDeleteKeyPtr = toDeleteKey.Span) + { + var toDeleteKeySpanByte = SpanByte.FromPinnedPointer(toDeleteKeyPtr, toDeleteKey.Span.Length); + + RawStringInput input = new(RespCommand.VADD); + + // Check if delete got far enough that we should re-apply it + using (ReadForDeleteVectorIndex(session.storageSession, ref toDeleteKeySpanByte, ref input, indexSpan, out var garnetStatus)) + { + if (garnetStatus is not (GarnetStatus.BADSTATE or GarnetStatus.NOTFOUND)) + { + // It didn't - so don't re-apply (But do remove the "we're deleting"-entry later) + continue; + } + } + + try + { + if (TryDeleteVectorSet(session.storageSession, ref toDeleteKeySpanByte, out var garnetStatus).IsCompletedSuccessfully && garnetStatus != GarnetStatus.BADSTATE) + { + // Normal delete worked, easy enough + // + // This happens if we fail between the "remember we're deleting" and "zero everything out" steps + 
logger?.LogInformation("Vector Set under {key} (context: {ctx}) deleted normally", Encoding.UTF8.GetString(toDeleteKey.Span), toDeleteCtx); + continue; + } + } + catch (Exception ex) + { + logger?.LogError(ex, "Attempt at normal cleanup of {key} failed", Encoding.UTF8.GetString(toDeleteKey.Span)); + } + + // Partial delete, do these bits directly + // 1. Try to zero out the index key + // 2. Try to delete the index key + // 3. Try to drop the replication key + // 4. Mark the context as needing cleanup + + // Zero out the index (which may already be zero'd, but that's fine to redo) + RawStringInput updateToDroppableVectorSet = new(RespCommand.VADD, arg1: DeleteAfterDropArg); + var update = session.storageSession.basicContext.RMW(ref toDeleteKeySpanByte, ref updateToDroppableVectorSet); + if (!update.IsCompletedSuccessfully) + { + throw new GarnetException("Failed to make Vector Set delete-able, this should never happen but will leave vector sets corrupted"); + } + + // Note that we don't need to DROP the index because we know we haven't re-created it yet + + // Actually delete the value + var del = session.storageSession.basicContext.Delete(ref toDeleteKeySpanByte); + if (!(del.Found || del.NotFound)) + { + logger?.LogCritical("Failed to cleanup delete dropped Vector Set {key} (context: {ctx}), Vector Set will remain corrupted", Encoding.UTF8.GetString(toDeleteKey.Span), toDeleteCtx); + clearInProgressDeletes = false; + continue; + } + + // Cleanup incidental additional state + if (!TryDropVectorSetReplicationKey(toDeleteKeySpanByte, ref session.storageSession.basicContext)) + { + logger?.LogCritical("Failed to cleanup delete dropped Vector Set {key} (context: {ctx}), Vector Set will remain corrupted", Encoding.UTF8.GetString(toDeleteKey.Span), toDeleteCtx); + clearInProgressDeletes = false; + continue; + } + + // Schedule cleanup of element data + CleanupDroppedIndex(ref session.storageSession.vectorContext, toDeleteCtx); + + logger?.LogInformation("Vector Set 
under {key} (context: {ctx}) deleted normally", Encoding.UTF8.GetString(toDeleteKey.Span), toDeleteCtx); + } + } + } + + if (clearInProgressDeletes) + { + // We successfully dealt with all pending deletes, we can delete the metadata key + Span toDeleteKeySpan = stackalloc byte[2]; + var toDeleteKey = SpanByte.FromPinnedSpan(toDeleteKeySpan); + + // 0:1 is InProgressDeletes + toDeleteKey.MarkNamespace(); + toDeleteKey.SetNamespaceInPayload(0); + toDeleteKey.AsSpan()[0] = 1; + + var deleteStatus = session.storageSession.vectorContext.Delete(ref toDeleteKey); + Debug.Assert(!deleteStatus.IsPending, "Delete shouldn't go async"); + } + + // Resume any cleanups we didn't complete before recovery + _ = cleanupTaskChannel.Writer.TryWrite(null); + } + + /// + public void Dispose() + { + // We must drain all these before disposing, otherwise we'll leave replicationBlockEvent unset + replicationReplayChannel.Writer.Complete(); + replicationReplayChannel.Reader.Completion.Wait(); + + Task.WhenAll(replicationReplayTasks).Wait(); + + replicationBlockEvent.Dispose(); + + // Wait for any in progress cleanup to finish + cleanupTaskChannel.Writer.Complete(); + cleanupTaskChannel.Reader.Completion.Wait(); + cleanupTask.Wait(); + } + + private static void CompletePending(ref Status status, ref SpanByte output, ref TContext ctx) + where TContext : ITsavoriteContext + { + _ = ctx.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; + output = completedOutputs.Current.Output; + Debug.Assert(!completedOutputs.Next()); + completedOutputs.Dispose(); + } + + /// + /// Add a vector to a vector set encoded by . + /// + /// Assumes that the index is locked in the Tsavorite store. + /// + /// Result of the operation. 
+ internal VectorManagerResult TryAdd( + scoped ReadOnlySpan indexValue, + ReadOnlySpan element, + VectorValueType valueType, + ReadOnlySpan values, + ReadOnlySpan attributes, + uint providedReduceDims, + VectorQuantType providedQuantType, + uint providedBuildExplorationFactor, + uint providedNumLinks, + VectorDistanceMetricType providedDistanceMetric, + out ReadOnlySpan errorMsg + ) + { + AssertHaveStorageSession(); + + errorMsg = default; + + ReadIndex(indexValue, out var context, out var dimensions, out var reduceDims, out var quantType, out _, out var numLinks, out var distanceMetric, out var indexPtr, out _); + + var valueDims = CalculateValueDimensions(valueType, values); + + if (dimensions != valueDims) + { + // Matching Redis behavior + errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {dimensions}"); + return VectorManagerResult.BadParams; + } + + if (providedReduceDims == 0 && reduceDims != 0) + { + // Matching Redis behavior, which is definitely a bit weird here + errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {reduceDims}"); + return VectorManagerResult.BadParams; + } + else if (providedReduceDims != 0 && providedReduceDims != reduceDims) + { + return VectorManagerResult.BadParams; + } + + if (providedQuantType != VectorQuantType.Invalid && providedQuantType != quantType) + { + return VectorManagerResult.BadParams; + } + + if (providedDistanceMetric != VectorDistanceMetricType.Invalid && providedDistanceMetric != distanceMetric) + { + errorMsg = Encoding.ASCII.GetBytes($"ERR Distance metric mismatch - got {providedDistanceMetric} but set has {distanceMetric}"); + return VectorManagerResult.BadParams; + } + + if (providedNumLinks != numLinks) + { + // Matching Redis behavior + errorMsg = "ERR asked M value mismatch with existing vector set"u8; + return VectorManagerResult.BadParams; + } + + var insert = + Service.Insert( + context, + indexPtr, + element, + 
valueType, + values, + attributes + ); + + if (insert) + { + return VectorManagerResult.OK; + } + + return VectorManagerResult.Duplicate; + } + + /// + /// Try to remove a vector (and associated attributes) from a Vector Set, as identified by element key. + /// + internal VectorManagerResult TryRemove(ReadOnlySpan indexValue, ReadOnlySpan element) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out _, out _, out var quantType, out _, out _, out _, out var indexPtr, out _); + + var del = Service.Remove(context, indexPtr, element); + + return del ? VectorManagerResult.OK : VectorManagerResult.MissingElement; + } + + /// + /// Deletion of a Vector Set needs special handling. + /// + /// This is called by DEL and UNLINK after a naive delete fails for us to _try_ and delete a Vector Set. + /// + internal Status TryDeleteVectorSet(StorageSession storageSession, ref SpanByte key, out GarnetStatus status) + { + storageSession.parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VADD, ref storageSession.parseState); + + Span indexSpan = stackalloc byte[Index.Size]; + + using (ReadForDeleteVectorIndex(storageSession, ref key, ref input, indexSpan, out status)) + { + if (status != GarnetStatus.OK) + { + // This can happen is something else successfully deleted before we acquired the lock + return Status.CreateNotFound(); + } + + ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out _, out _); + + if (!TryMarkDeleteInProgress(ref storageSession.vectorContext, ref key, context)) + { + // We can't recover from a crash or error, so fail the delete for safety + return Status.CreateError(); + } + + ExceptionInjectionHelper.TriggerException(ExceptionInjectionType.VectorSet_Interrupt_Delete_0); + + // Update the index to be delete-able + RawStringInput updateToDroppableVectorSet = new(RespCommand.VADD, arg1: DeleteAfterDropArg); + + var update = 
storageSession.basicContext.RMW(ref key, ref updateToDroppableVectorSet); + if (!update.IsCompletedSuccessfully) + { + throw new GarnetException("Failed to make Vector Set delete-able, this should never happen but will leave vector sets corrupted"); + } + + // Drop the native side of the index now - we can't fault between the two unless the process is torn down + DropIndex(indexSpan); + + ExceptionInjectionHelper.TriggerException(ExceptionInjectionType.VectorSet_Interrupt_Delete_1); + + // Actually delete the value + var del = storageSession.basicContext.Delete(ref key); + if (!del.IsCompletedSuccessfully) + { + throw new GarnetException("Failed to delete dropped Vector Set, this should never happen but will leave vector sets corrupted"); + } + + ExceptionInjectionHelper.TriggerException(ExceptionInjectionType.VectorSet_Interrupt_Delete_2); + + // Cleanup incidental additional state + if (!TryDropVectorSetReplicationKey(key, ref storageSession.basicContext)) + { + logger?.LogCritical("Couldn't synthesize Vector Set delete operation for replication, data loss will occur"); + } + + // Schedule cleanup of element data + CleanupDroppedIndex(ref storageSession.vectorContext, context); + + // Delete has finished, so remove the in progress metadata + // + // A crash or error before this will cause some work to be retried, but no correctness issues + ClearDeleteInProgress(ref storageSession.vectorContext, ref key, context); + + return Status.CreateFound(); + } + } + + /// + /// Perform a similarity search given a vector to compare against. 
+ /// + internal VectorManagerResult ValueSimilarity( + ReadOnlySpan indexValue, + VectorValueType valueType, + ReadOnlySpan values, + int count, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + bool includeAttributes, + ref SpanByteAndMemory outputIds, + out VectorIdFormat outputIdFormat, + ref SpanByteAndMemory outputDistances, + ref SpanByteAndMemory outputAttributes + ) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out var dimensions, out _, out var quantType, out _, out _, out _, out var indexPtr, out _); + + var valueDims = CalculateValueDimensions(valueType, values); + if (dimensions != valueDims) + { + outputIdFormat = VectorIdFormat.Invalid; + return VectorManagerResult.BadParams; + } + + // No point in asking for more data than the effort we'll put in + if (count > searchExplorationFactor) + { + count = searchExplorationFactor; + } + + // Make sure enough space in distances for requested count + if (count > outputDistances.Length) + { + if (!outputDistances.IsSpanByte) + { + outputDistances.Memory.Dispose(); + } + + outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * sizeof(float)), count * sizeof(float)); + } + + // Indicate requested # of matches + outputDistances.Length = count * sizeof(float); + + // If we're fairly sure the ids won't fit, go ahead and grab more memory now + // + // If we're still wrong, we'll end up using continuation callbacks which have more overhead + if (count * MinimumSpacePerId > outputIds.Length) + { + if (!outputIds.IsSpanByte) + { + outputIds.Memory.Dispose(); + } + + outputIds = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * MinimumSpacePerId), count * MinimumSpacePerId); + } + + var found = + Service.SearchVector( + context, + indexPtr, + valueType, + values, + delta, + searchExplorationFactor, + filter, + maxFilteringEffort, + outputIds, + outputDistances, + out var continuation + ); + + if (found < 0) + { + 
logger?.LogWarning("Error indicating response from vector service {found}", found); + outputIdFormat = VectorIdFormat.Invalid; + return VectorManagerResult.BadParams; + } + + if (includeAttributes) + { + FetchVectorElementAttributes(context, found, outputIds, ref outputAttributes); + } + + if (continuation != 0) + { + // TODO: paged results! + throw new NotImplementedException(); + } + + outputDistances.Length = sizeof(float) * found; + + // Default assumption is length prefixed + outputIdFormat = VectorIdFormat.I32LengthPrefixed; + + if (quantType == VectorQuantType.XPreQ8) + { + // But in this special case, we force them to be 4-byte ids + //outputIdFormat = VectorIdFormat.FixedI32; + outputIdFormat = VectorIdFormat.I32LengthPrefixed; + } + + return VectorManagerResult.OK; + } + + /// + /// Perform a similarity search given a vector to compare against. + /// + internal VectorManagerResult ElementSimilarity( + ReadOnlySpan indexValue, + ReadOnlySpan element, + int count, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + bool includeAttributes, + ref SpanByteAndMemory outputIds, + out VectorIdFormat outputIdFormat, + ref SpanByteAndMemory outputDistances, + ref SpanByteAndMemory outputAttributes + ) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out _, out _, out var quantType, out _, out _, out _, out var indexPtr, out _); + + // No point in asking for more data than the effort we'll put in + if (count > searchExplorationFactor) + { + count = searchExplorationFactor; + } + + // Make sure enough space in distances for requested count + if (count * sizeof(float) > outputDistances.Length) + { + if (!outputDistances.IsSpanByte) + { + outputDistances.Memory.Dispose(); + } + + outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * sizeof(float)), count * sizeof(float)); + } + + // Indicate requested # of matches + outputDistances.Length = count * sizeof(float); + + // If we're 
fairly sure the ids won't fit, go ahead and grab more memory now + // + // If we're still wrong, we'll end up using continuation callbacks which have more overhead + if (count * MinimumSpacePerId > outputIds.Length) + { + if (!outputIds.IsSpanByte) + { + outputIds.Memory.Dispose(); + } + + outputIds = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * MinimumSpacePerId), count * MinimumSpacePerId); + } + + var found = + Service.SearchElement( + context, + indexPtr, + element, + delta, + searchExplorationFactor, + filter, + maxFilteringEffort, + outputIds, + outputDistances, + out var continuation + ); + + if (found < 0) + { + logger?.LogWarning("Error indicating response from vector service {found}", found); + outputIdFormat = VectorIdFormat.Invalid; + return VectorManagerResult.BadParams; + } + + if (includeAttributes) + { + FetchVectorElementAttributes(context, found, outputIds, ref outputAttributes); + } + + if (continuation != 0) + { + // TODO: paged results! + throw new NotImplementedException(); + } + + outputDistances.Length = sizeof(float) * found; + + // Default assumption is length prefixed + outputIdFormat = VectorIdFormat.I32LengthPrefixed; + + if (quantType == VectorQuantType.XPreQ8) + { + // But in this special case, we force them to be 4-byte ids + //outputIdFormat = VectorIdFormat.FixedI32; + outputIdFormat = VectorIdFormat.I32LengthPrefixed; + } + + return VectorManagerResult.OK; + } + + /// + /// Fetch attributes for a single element id. + /// + /// This must only be called while holding locks which prevent the Vector Set from being dropped. + /// + /// IMPORTANT: outputAttributes may be replaced with an allocated memory, so the caller needs to check + /// if the buffer is stack-based or heap-based, and dispose if it's the latter. 
+ /// + internal VectorManagerResult FetchSingleVectorElementAttributes(ReadOnlySpan indexValue, SpanByte element, ref SpanByteAndMemory outputAttributes) + { + AssertHaveStorageSession(); + ReadIndex(indexValue, out var context, out _, out _, out _, out _, out _, out _, out _, out _); + var found = ReadSizeUnknown(context | DiskANNService.Attributes, element.AsReadOnlySpan(), ref outputAttributes); + return found ? VectorManagerResult.OK : VectorManagerResult.MissingElement; + } + + /// + /// Fetch attributes for a given set of element ids. + /// + /// This must only be called while holding locks which prevent the Vector Set from being dropped. + /// + private void FetchVectorElementAttributes(ulong context, int numIds, SpanByteAndMemory ids, ref SpanByteAndMemory attributes) + { + var remainingIds = ids.AsReadOnlySpan(); + + GCHandle idPin = default; + byte[] idWithNamespaceArr = null; + + var attributesNextIx = 0; + + Span attributeFull = stackalloc byte[32]; + var attributeMem = SpanByteAndMemory.FromPinnedSpan(attributeFull); + + try + { + Span idWithNamespace = stackalloc byte[128]; + + // TODO: we could scatter/gather this like MGET - doesn't matter when everything is in memory, + // but if anything is on disk it'd help perf + for (var i = 0; i < numIds; i++) + { + var idLen = BinaryPrimitives.ReadInt32LittleEndian(remainingIds); + if (idLen + sizeof(int) > remainingIds.Length) + { + throw new GarnetException($"Malformed ids, {idLen} + {sizeof(int)} > {remainingIds.Length}"); + } + + var id = remainingIds.Slice(sizeof(int), idLen); + + // Make sure we've got enough space to query the element + if (id.Length + 1 > idWithNamespace.Length) + { + if (idWithNamespaceArr != null) + { + idPin.Free(); + ArrayPool.Shared.Return(idWithNamespaceArr); + } + + idWithNamespaceArr = ArrayPool.Shared.Rent(id.Length + 1); + idPin = GCHandle.Alloc(idWithNamespaceArr, GCHandleType.Pinned); + idWithNamespace = idWithNamespaceArr; + } + + if (attributeMem.Memory != null) + { + 
attributeMem.Length = attributeMem.Memory.Memory.Length; + } + else + { + attributeMem.Length = attributeMem.SpanByte.Length; + } + + var found = ReadSizeUnknown(context | DiskANNService.Attributes, id, ref attributeMem); + + // Copy attribute into output buffer, length prefixed, resizing as necessary + var neededSpace = 4 + (found ? attributeMem.Length : 0); + + var destSpan = attributes.AsSpan()[attributesNextIx..]; + if (destSpan.Length < neededSpace) + { + var newAttrArr = MemoryPool.Shared.Rent(attributes.Length + neededSpace); + attributes.AsReadOnlySpan().CopyTo(newAttrArr.Memory.Span); + + attributes.Memory?.Dispose(); + + attributes = new SpanByteAndMemory(newAttrArr, newAttrArr.Memory.Length); + destSpan = attributes.AsSpan()[attributesNextIx..]; + } + + BinaryPrimitives.WriteInt32LittleEndian(destSpan, attributeMem.Length); + attributeMem.AsReadOnlySpan().CopyTo(destSpan[sizeof(int)..]); + + attributesNextIx += neededSpace; + + remainingIds = remainingIds[(sizeof(int) + idLen)..]; + } + + attributes.Length = attributesNextIx; + } + finally + { + if (idWithNamespaceArr != null) + { + idPin.Free(); + ArrayPool.Shared.Return(idWithNamespaceArr); + } + + attributeMem.Memory?.Dispose(); + } + } + + /// + /// Try to read the associated dimensions for an element out of a Vector Set. 
+ /// + internal bool TryGetEmbedding(ReadOnlySpan indexValue, ReadOnlySpan element, out VectorQuantType quantType, ref SpanByteAndMemory outputDistances) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out var dimensions, out _, out quantType, out _, out _, out _, out var indexPtr, out _); + + Span internalId = stackalloc byte[sizeof(int)]; + var internalIdBytes = SpanByteAndMemory.FromPinnedSpan(internalId); + try + { + if (!ReadSizeUnknown(context | DiskANNService.InternalIdMap, element, ref internalIdBytes)) + { + return false; + } + + Debug.Assert(internalIdBytes.IsSpanByte, "Internal Id should always be of known size"); + } + finally + { + internalIdBytes.Memory?.Dispose(); + } + + if (quantType == VectorQuantType.NoQuant) + { + return TryGetEmbeddingRaw(context, indexPtr, (int)dimensions, internalId, ref outputDistances); + } + else if (quantType == VectorQuantType.XPreQ8) + { + return TryGetEmbeddingRaw(context, indexPtr, (int)dimensions, internalId, ref outputDistances); + } + else + { + throw new GarnetException($"Unsupported quantization type for embedding retrieval: {quantType}"); + } + } + + private bool TryGetEmbeddingRaw(ulong context, nint indexPtr, int dimensions, ReadOnlySpan internalId, ref SpanByteAndMemory outputDistances) + where T : unmanaged + { + var requiredBytes = dimensions * Unsafe.SizeOf(); + if (requiredBytes > outputDistances.Length) + { + outputDistances.Memory?.Dispose(); + outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(requiredBytes), requiredBytes); + } + else + { + outputDistances.Length = requiredBytes; + } + + Span storedVectorAsBytesSpan = stackalloc byte[requiredBytes]; + var storedVectorAsBytes = SpanByteAndMemory.FromPinnedSpan(storedVectorAsBytesSpan); + try + { + if (!ReadSizeUnknown(context | DiskANNService.FullVector, internalId, ref storedVectorAsBytes)) + { + return false; + } + + Debug.Assert(storedVectorAsBytes.Length == requiredBytes, "Unexpected length for raw FP32 
stored vector"); + + var into = MemoryMarshal.Cast(outputDistances.AsSpan()); + MemoryMarshal.Cast(storedVectorAsBytes.AsReadOnlySpan()).CopyTo(into); + + // Vector might have been deleted, so check that after getting data + return Service.CheckInternalIdValid(context, indexPtr, internalId); + } + finally + { + storedVectorAsBytes.Memory?.Dispose(); + } + } + + private bool TryGetEmbeddingU8(ulong context, nint indexPtr, int dimensions, ReadOnlySpan internalId, ref SpanByteAndMemory outputDistances) + { + // Make sure enough space in distances for requested count + if (dimensions > outputDistances.Length) + { + outputDistances.Memory?.Dispose(); + outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(dimensions), dimensions); + } + else + { + outputDistances.Length = dimensions; + } + + Span storedVectorAsBytesSpan = stackalloc byte[dimensions]; + var storedVectorAsBytes = SpanByteAndMemory.FromPinnedSpan(storedVectorAsBytesSpan); + try + { + if (!ReadSizeUnknown(context | DiskANNService.FullVector, internalId, ref storedVectorAsBytes)) + { + return false; + } + + Debug.Assert(storedVectorAsBytes.Length == dimensions, "Unexpected length for raw stored vector"); + storedVectorAsBytes.AsReadOnlySpan().CopyTo(outputDistances.AsSpan()); + + // Vector might have been deleted, so check that after getting data + return Service.CheckInternalIdValid(context, indexPtr, internalId); + } + finally + { + storedVectorAsBytes.Memory?.Dispose(); + } + } + + /// + /// Determine the dimensions of a vector given its and its raw data. 
+ /// + internal static uint CalculateValueDimensions(VectorValueType valueType, ReadOnlySpan values) + { + if (valueType == VectorValueType.FP32) + { + return (uint)(values.Length / sizeof(float)); + } + else if (valueType == VectorValueType.XB8) + { + return (uint)(values.Length); + } + else + { + throw new NotImplementedException($"{valueType}"); + } + } + + [Conditional("DEBUG")] + private static void AssertHaveStorageSession() + { + Debug.Assert(ActiveThreadSession != null, "Should have StorageSession by now"); + } + } +} \ No newline at end of file diff --git a/libs/server/Servers/GarnetServerOptions.cs b/libs/server/Servers/GarnetServerOptions.cs index e0d344e87fb..4b6ee6b2797 100644 --- a/libs/server/Servers/GarnetServerOptions.cs +++ b/libs/server/Servers/GarnetServerOptions.cs @@ -532,6 +532,13 @@ public class GarnetServerOptions : ServerOptions /// public bool ClusterReplicaResumeWithData = false; + /// + /// If true, enable Vector Set commands. + /// + /// This is a preview feature, subject to substantial change, and should not be relied upon. + /// + public bool EnableVectorSetPreview = false; + /// /// Get the directory name for database checkpoints /// diff --git a/libs/server/Storage/Functions/FunctionsState.cs b/libs/server/Storage/Functions/FunctionsState.cs index 4ef24a38260..32eddffbe4e 100644 --- a/libs/server/Storage/Functions/FunctionsState.cs +++ b/libs/server/Storage/Functions/FunctionsState.cs @@ -22,11 +22,12 @@ internal sealed class FunctionsState public EtagState etagState; public byte respProtocolVersion; public bool StoredProcMode; + public readonly VectorManager vectorManager; internal ReadOnlySpan nilResp => respProtocolVersion >= 3 ? 
CmdStrings.RESP3_NULL_REPLY : CmdStrings.RESP_ERRNOTFOUND; public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomCommandManager customCommandManager, - MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer, + MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer, VectorManager vectorManager, byte respProtocolVersion = ServerOptions.DEFAULT_RESP_VERSION) { this.appendOnlyFile = appendOnlyFile; @@ -36,6 +37,7 @@ public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionM this.objectStoreSizeTracker = objectStoreSizeTracker; this.garnetObjectSerializer = garnetObjectSerializer; this.etagState = new EtagState(); + this.vectorManager = vectorManager; this.respProtocolVersion = respProtocolVersion; } diff --git a/libs/server/Storage/Functions/MainStore/DeleteMethods.cs b/libs/server/Storage/Functions/MainStore/DeleteMethods.cs index eddf29d54d0..2b3d5cb859a 100644 --- a/libs/server/Storage/Functions/MainStore/DeleteMethods.cs +++ b/libs/server/Storage/Functions/MainStore/DeleteMethods.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
+using System; using Tsavorite.core; namespace Garnet.server @@ -13,6 +14,15 @@ namespace Garnet.server /// public bool SingleDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) { + if (recordInfo.VectorSet && value.AsReadOnlySpan().ContainsAnyExcept((byte)0)) + { + // Implies this is a vector set, needs special handling + // + // Will call back in after a drop with an all 0 value + deleteInfo.Action = DeleteAction.CancelOperation; + return false; + } + recordInfo.ClearHasETag(); functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); return true; @@ -28,6 +38,15 @@ public void PostSingleDeleter(ref SpanByte key, ref DeleteInfo deleteInfo) /// public bool ConcurrentDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) { + if (recordInfo.VectorSet && value.AsReadOnlySpan().ContainsAnyExcept((byte)0)) + { + // Implies this is a vector set, needs special handling + // + // Will call back in after a drop with an all 0 value + deleteInfo.Action = DeleteAction.CancelOperation; + return false; + } + recordInfo.ClearHasETag(); if (!deleteInfo.RecordInfo.Modified) functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); diff --git a/libs/server/Storage/Functions/MainStore/PrivateMethods.cs b/libs/server/Storage/Functions/MainStore/PrivateMethods.cs index d6c0a618839..cc4c8ba6495 100644 --- a/libs/server/Storage/Functions/MainStore/PrivateMethods.cs +++ b/libs/server/Storage/Functions/MainStore/PrivateMethods.cs @@ -118,6 +118,13 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB value.CopyTo(dst.Memory.Memory.Span); break; + case RespCommand.VADD: + case RespCommand.VSIM: + case RespCommand.VEMB: + case RespCommand.VGETATTR: + case RespCommand.VINFO: + case RespCommand.VREM: + case RespCommand.VDIM: case RespCommand.GET: // Get value without RESP header; exclude expiration if (value.LengthWithoutMetadata <= dst.Length) @@ 
-242,12 +249,12 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB throw new GarnetException($"Not enough space in {input.header.cmd} buffer"); case RespCommand.TTL: - var ttlValue = ConvertUtils.SecondsFromDiffUtcNowTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var ttlValue = ConvertUtils.SecondsFromDiffUtcNowTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(ttlValue, ref dst); return; case RespCommand.PTTL: - var pttlValue = ConvertUtils.MillisecondsFromDiffUtcNowTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var pttlValue = ConvertUtils.MillisecondsFromDiffUtcNowTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(pttlValue, ref dst); return; @@ -260,12 +267,12 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB CopyRespTo(ref value, ref dst, start + functionsState.etagState.etagSkippedStart, end + functionsState.etagState.etagSkippedStart); return; case RespCommand.EXPIRETIME: - var expireTime = ConvertUtils.UnixTimeInSecondsFromTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var expireTime = ConvertUtils.UnixTimeInSecondsFromTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(expireTime, ref dst); return; case RespCommand.PEXPIRETIME: - var pexpireTime = ConvertUtils.UnixTimeInMillisecondsFromTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var pexpireTime = ConvertUtils.UnixTimeInMillisecondsFromTicks(value.MetadataSize == 8 ? 
value.ExtraMetadata : -1); CopyRespNumber(pexpireTime, ref dst); return; @@ -730,6 +737,11 @@ void WriteLogUpsert(ref SpanByte key, ref RawStringInput input, { if (functionsState.StoredProcMode) return; + if (input.header.cmd == RespCommand.VADD && input.arg1 is not (VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg)) + { + return; + } + // We need this check because when we ingest records from the primary // if the input is zero then input overlaps with value so any update to RespInputHeader->flags // will incorrectly modify the total length of value. @@ -751,6 +763,12 @@ void WriteLogRMW(ref SpanByte key, ref RawStringInput input, lon where TEpochAccessor : IEpochAccessor { if (functionsState.StoredProcMode) return; + + if (input.header.cmd == RespCommand.VADD && input.arg1 is not (VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg)) + { + return; + } + input.header.flags |= RespInputFlags.Deterministic; functionsState.appendOnlyFile.Enqueue( diff --git a/libs/server/Storage/Functions/MainStore/RMWMethods.cs b/libs/server/Storage/Functions/MainStore/RMWMethods.cs index 3f6d55acef7..b909f37f94a 100644 --- a/libs/server/Storage/Functions/MainStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/MainStore/RMWMethods.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Runtime.InteropServices; using Garnet.common; using Tsavorite.core; @@ -242,6 +243,39 @@ public bool InitialUpdater(ref SpanByte key, ref RawStringInput input, ref SpanB var incrByFloat = BitConverter.Int64BitsToDouble(input.arg1); CopyUpdateNumber(incrByFloat, ref value, ref output); break; + + case RespCommand.VADD: + { + if (input.arg1 is VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg) + { + // Synthetic op, do nothing + break; + } + + var dims = 
MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span); + var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span); + // ValueType is here, skipping during index creation + // Values is here, skipping during index creation + // Element is here, skipping during index creation + var quantizer = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span); + var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span); + // Attributes is here, skipping during index creation + var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span); + var distanceMetric = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(9).Span); + + // Pre-allocated by caller because DiskANN needs to be able to call into Garnet as part of create_index + // and thus we can't call into it from session functions + var context = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(10).Span); + var index = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(11).Span); + + recordInfo.VectorSet = true; + + functionsState.vectorManager.CreateIndex(dims, reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, context, index, ref value); + } + break; + case RespCommand.VREM: + Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "Should only see VREM writes as part of replication"); + break; default: if (input.header.cmd > RespCommandExtensions.LastValidCommand) { @@ -327,7 +361,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu { RespCommand cmd = input.header.cmd; // Expired data - if (value.MetadataSize > 0 && input.header.CheckExpiry(value.ExtraMetadata)) + if (value.MetadataSize == 8 && input.header.CheckExpiry(value.ExtraMetadata)) { rmwInfo.Action = cmd is RespCommand.DELIFEXPIM ? 
RMWAction.ExpireAndStop : RMWAction.ExpireAndResume; recordInfo.ClearHasETag(); @@ -583,7 +617,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu break; case RespCommand.EXPIRE: - var expiryExists = value.MetadataSize > 0; + var expiryExists = value.MetadataSize == 8; var expirationWithOption = new ExpirationWithOption(input.arg1); @@ -593,7 +627,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu return EvaluateExpireInPlace(expirationWithOption.ExpireOption, expiryExists, expirationWithOption.ExpirationTimeInTicks, ref value, ref output); case RespCommand.PERSIST: - if (value.MetadataSize != 0) + if (value.MetadataSize == 8) { rmwInfo.ClearExtraValueLength(ref recordInfo, ref value, value.TotalSize); value.AsSpan().CopyTo(value.AsSpanWithMetadata()); @@ -752,7 +786,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu var _output = new SpanByteAndMemory(SpanByte.FromPinnedPointer(pbOutput, ObjectOutputHeader.Size)); var newExpiry = input.arg1; - return EvaluateExpireInPlace(ExpireOption.None, expiryExists: value.MetadataSize > 0, newExpiry, ref value, ref _output); + return EvaluateExpireInPlace(ExpireOption.None, expiryExists: value.MetadataSize == 8, newExpiry, ref value, ref _output); } if (input.parseState.Count > 0) @@ -794,6 +828,38 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu // this is the case where it isn't expired shouldUpdateEtag = false; break; + case RespCommand.VADD: + // Adding to an existing VectorSet is modeled as a read operations + // + // However, we do synthesize some (pointless) writes to implement replication + // and a "make me delete=able"-update during drop. + // + // Another "not quite write" is the recreate an index write operation + // that occurs if we're adding to an index that was restored from disk + // or a primary node. 
+ + // Handle "make me delete-able" + if (input.arg1 == VectorManager.DeleteAfterDropArg) + { + value.AsSpan().Clear(); + } + else if (input.arg1 == VectorManager.RecreateIndexArg) + { + var newIndexPtr = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(11).Span); + + functionsState.vectorManager.RecreateIndex(newIndexPtr, ref value); + } + + // Ignore everything else + return IPUResult.Succeeded; + case RespCommand.VREM: + // Removing from a VectorSet is modeled as a read operations + // + // However, we do synthesize some (pointless) writes to implement replication + // in a similar manner to VADD. + + Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "VREM in place update should only happen for replication"); // Ignore everything else + return IPUResult.Succeeded; default: if (cmd > RespCommandExtensions.LastValidCommand) { @@ -877,7 +943,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB switch (input.header.cmd) { case RespCommand.DELIFEXPIM: - if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata)) + if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata)) { rmwInfo.Action = RMWAction.ExpireAndStop; } @@ -940,7 +1006,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB case RespCommand.SETEXNX: // Expired data, return false immediately // ExpireAndResume ensures that we set as new value, since it does not exist - if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata)) + if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata)) { rmwInfo.Action = RMWAction.ExpireAndResume; rmwInfo.RecordInfo.ClearHasETag(); @@ -968,7 +1034,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB case RespCommand.SETEXXX: // Expired data, return false immediately so we do not set, since it does not exist // ExpireAndStop ensures that caller sees a NOTFOUND status - 
if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata)) + if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata)) { rmwInfo.RecordInfo.ClearHasETag(); rmwInfo.Action = RMWAction.ExpireAndStop; @@ -1009,7 +1075,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByteAndMemory output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) { // Expired data - if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata)) + if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata)) { recordInfo.ClearHasETag(); rmwInfo.Action = RMWAction.ExpireAndResume; @@ -1171,7 +1237,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte case RespCommand.EXPIRE: shouldUpdateEtag = false; - var expiryExists = oldValue.MetadataSize > 0; + var expiryExists = oldValue.MetadataSize == 8; var expirationWithOption = new ExpirationWithOption(input.arg1); @@ -1181,7 +1247,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte case RespCommand.PERSIST: shouldUpdateEtag = false; oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan()); - if (oldValue.MetadataSize != 0) + if (oldValue.MetadataSize == 8) { newValue.AsSpan().CopyTo(newValue.AsSpanWithMetadata()); newValue.ShrinkSerializedLength(newValue.Length - newValue.MetadataSize); @@ -1306,7 +1372,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte byte* pbOutput = stackalloc byte[ObjectOutputHeader.Size]; var _output = new SpanByteAndMemory(SpanByte.FromPinnedPointer(pbOutput, ObjectOutputHeader.Size)); var newExpiry = input.arg1; - EvaluateExpireCopyUpdate(ExpireOption.None, expiryExists: oldValue.MetadataSize > 0, newExpiry, ref oldValue, ref newValue, ref _output); + 
EvaluateExpireCopyUpdate(ExpireOption.None, expiryExists: oldValue.MetadataSize == 8, newExpiry, ref oldValue, ref newValue, ref _output); } oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan()); @@ -1337,6 +1403,27 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte CopyValueLengthToOutput(ref newValue, ref output, functionsState.etagState.etagSkippedStart); break; + case RespCommand.VADD: + // Handle "make me delete-able" + if (input.arg1 == VectorManager.DeleteAfterDropArg) + { + newValue.AsSpan().Clear(); + } + else if (input.arg1 == VectorManager.RecreateIndexArg) + { + var newIndexPtr = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(11).Span); + + oldValue.CopyTo(ref newValue); + + functionsState.vectorManager.RecreateIndex(newIndexPtr, ref newValue); + } + + break; + + case RespCommand.VREM: + Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "Unexpected CopyUpdater call on VREM key"); + break; + default: if (input.header.cmd > RespCommandExtensions.LastValidCommand) { diff --git a/libs/server/Storage/Functions/MainStore/ReadMethods.cs b/libs/server/Storage/Functions/MainStore/ReadMethods.cs index d23e5af89dd..53d6a72fe3f 100644 --- a/libs/server/Storage/Functions/MainStore/ReadMethods.cs +++ b/libs/server/Storage/Functions/MainStore/ReadMethods.cs @@ -17,7 +17,7 @@ public bool SingleReader( ref SpanByte key, ref RawStringInput input, ref SpanByte value, ref SpanByteAndMemory dst, ref ReadInfo readInfo) { - if (value.MetadataSize != 0 && CheckExpiry(ref value)) + if (value.MetadataSize == 8 && CheckExpiry(ref value)) { readInfo.RecordInfo.ClearHasETag(); return false; @@ -25,6 +25,36 @@ public bool SingleReader( var cmd = input.header.cmd; + // Ignore special Vector Set logic if we're scanning, detected with cmd == NONE + if (cmd != RespCommand.NONE) + { + // Vector sets are reachable (key not mangled) and hidden. + // So we can use that to detect type mismatches. 
+ if (readInfo.RecordInfo.VectorSet && !cmd.IsLegalOnVectorSet()) + { + // Attempted an illegal op on a VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + else if (!readInfo.RecordInfo.VectorSet && cmd.IsLegalOnVectorSet()) + { + // Attempted a vector set op on a non-VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + } + + // GET is used in a number of non-RESP contexts, which messes up existing logic + // + // Easiest to mark the actually-RESP commands with a < 0 arg1 and roll back to old logic + // after the Vector Set checks + // + // TODO: This is quite hacky, but requires a bunch of non-Vector Set changes - do those and remove + if (input.arg1 < 0 && cmd == RespCommand.GET) + { + cmd = RespCommand.NONE; + } + if (cmd == RespCommand.GETIFNOTMATCH) { if (handleGetIfNotMatch(ref input, ref value, ref dst, ref readInfo)) @@ -87,7 +117,7 @@ public bool ConcurrentReader( ref SpanByte key, ref RawStringInput input, ref SpanByte value, ref SpanByteAndMemory dst, ref ReadInfo readInfo, ref RecordInfo recordInfo) { - if (value.MetadataSize != 0 && CheckExpiry(ref value)) + if (value.MetadataSize == 8 && CheckExpiry(ref value)) { recordInfo.ClearHasETag(); return false; @@ -95,6 +125,36 @@ public bool ConcurrentReader( var cmd = input.header.cmd; + // Ignore special Vector Set logic if we're scanning, detected with cmd == NONE + if (cmd != RespCommand.NONE) + { + // Vector sets are reachable (key not mangled) and hidden. + // So we can use that to detect type mismatches. 
+ if (recordInfo.VectorSet && !cmd.IsLegalOnVectorSet()) + { + // Attempted an illegal op on a VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + else if (!recordInfo.VectorSet && cmd.IsLegalOnVectorSet()) + { + // Attempted a vector set op on a non-VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + } + + // GET is used in a number of non-RESP contexts, which messes up existing logic + // + // Easiest to mark the actually-RESP commands with a < 0 arg1 and roll back to old logic + // after the Vector Set checks + // + // TODO: This is quite hacky, but requires a bunch of non-Vector Set changes - do those and remove + if (input.arg1 < 0 && cmd == RespCommand.GET) + { + cmd = RespCommand.NONE; + } + if (cmd == RespCommand.GETIFNOTMATCH) { if (handleGetIfNotMatch(ref input, ref value, ref dst, ref readInfo)) @@ -137,7 +197,6 @@ public bool ConcurrentReader( return true; } - if (cmd == RespCommand.NONE) CopyRespTo(ref value, ref dst, functionsState.etagState.etagSkippedStart, functionsState.etagState.etagAccountedLength); else diff --git a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs index adc5b124249..0130dcbe389 100644 --- a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs +++ b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs @@ -113,6 +113,9 @@ public int GetRMWInitialValueLength(ref RawStringInput input) ndigits = NumUtils.CountCharsInDouble(incrByFloat, out var _, out var _, out var _); return sizeof(int) + ndigits; + case RespCommand.VADD: + return sizeof(int) + VectorManager.IndexSizeBytes; + default: if (cmd > RespCommandExtensions.LastValidCommand) { @@ -236,6 +239,9 @@ public int GetRMWModifiedValueLength(ref SpanByte t, ref RawStringInput input) // Min allocation (only metadata) needed since this is going to be used for tombstoning anyway. 
return sizeof(int); + case RespCommand.VADD: + return t.Length; + default: if (cmd > RespCommandExtensions.LastValidCommand) { diff --git a/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs b/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs new file mode 100644 index 00000000000..ca375af1717 --- /dev/null +++ b/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs @@ -0,0 +1,413 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Functions for operating against the Main Store, but for data stored as part of a Vector Set operation - not a RESP command. + /// + public readonly struct VectorSessionFunctions : ISessionFunctions + { + private readonly FunctionsState functionsState; + + /// + /// Constructor + /// + internal VectorSessionFunctions(FunctionsState functionsState) + { + this.functionsState = functionsState; + } + + #region Deletes + /// + public bool SingleDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never delete a non-namespaced value with VectorSessionFunctions"); + + recordInfo.ClearHasETag(); + functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); + return true; + } + /// + public bool ConcurrentDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never delete a non-namespaced value with VectorSessionFunctions"); + + recordInfo.ClearHasETag(); + if (!deleteInfo.RecordInfo.Modified) + functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); + return true; + } + /// + public void PostSingleDeleter(ref SpanByte key, ref DeleteInfo deleteInfo) { } + + public void PostDeleteOperation(ref 
SpanByte key, ref DeleteInfo deleteInfo, TEpochAccessor epoch) where TEpochAccessor : IEpochAccessor { } + #endregion + + #region Reads + /// + public bool SingleReader(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte dst, ref ReadInfo readInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never read a non-namespaced value with VectorSessionFunctions"); + + unsafe + { + if (input.Callback != 0) + { + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + + callback(input.Index, input.CallbackContext, (nint)value.ToPointer(), (nuint)value.Length); + return true; + } + } + + if (input.ReadDesiredSize > 0) + { + Debug.Assert(dst.Length >= value.Length, "Should always have space for vector point reads"); + + dst.Length = value.Length; + value.AsReadOnlySpan(functionsState.etagState.etagSkippedStart).CopyTo(dst.AsSpan()); + } + else + { + input.ReadDesiredSize = value.Length; + if (dst.Length >= value.Length) + { + value.AsReadOnlySpan(functionsState.etagState.etagSkippedStart).CopyTo(dst.AsSpan()); + dst.Length = value.Length; + } + } + + return true; + } + /// + public bool ConcurrentReader(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte dst, ref ReadInfo readInfo, ref RecordInfo recordInfo) + => SingleReader(ref key, ref input, ref value, ref dst, ref readInfo); + + /// + public void ReadCompletionCallback(ref SpanByte key, ref VectorInput input, ref SpanByte output, long ctx, Status status, RecordMetadata recordMetadata) + { + } + #endregion + + #region Initial Values + /// + public bool NeedInitialUpdate(ref SpanByte key, ref VectorInput input, ref SpanByte output, ref RMWInfo rmwInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never write a non-namespaced value with VectorSessionFunctions"); + + // Only needed when updating ContextMetadata or InProgressDeletes via RMW or the DiskANN RMW callback, all of which set WriteDesiredSize + return input.WriteDesiredSize != 0; + } + 
/// + public bool InitialUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never write a non-namespaced value with VectorSessionFunctions"); + + if (input.Callback == 0) + { + Debug.Assert(key.GetNamespaceInPayload() == 0, "Should be operating on special namespace"); + + if (key.LengthWithoutMetadata == 0) + { + // Operating on ContextMetadata + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref value, ref rmwInfo, ref recordInfo); + } + else + { + // Operating on InProgressDeletes + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + Debug.Assert(key.LengthWithoutMetadata == 1 && key.AsReadOnlySpan()[0] == 1, "Should be working on InProgressDeletes"); + + Span inProgressDeleteUpdateData; + bool adding; + + unsafe + { + var len = BinaryPrimitives.ReadInt32LittleEndian(new Span((byte*)input.CallbackContext + sizeof(long), sizeof(int))); + adding = len > 0; + if (!adding) + { + len = -len; + } + + inProgressDeleteUpdateData = new Span((byte*)input.CallbackContext, sizeof(ulong) + sizeof(int) + len); + } + + if (!adding) + { + // We may be recovering and doing some optimistic deletes, but since we're creating... 
just ignore the op, it does nothing + rmwInfo.Action = RMWAction.CancelOperation; + return false; + } + + var fits = VectorManager.TryUpdateInProgressDeletes(inProgressDeleteUpdateData, ref value, ref recordInfo, ref rmwInfo); + Debug.Assert(fits, "Initial size of record should have been correct for in progress deletes"); + + return true; + } + } + else + { + Debug.Assert(input.WriteDesiredSize <= value.LengthWithoutMetadata, "Insufficient space for initial update, this should never happen"); + + rmwInfo.ClearExtraValueLength(ref recordInfo, ref value, value.TotalSize); + + // Must explicitly 0 before passing if we're doing an initial update + value.AsSpan().Clear(); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)value.ToPointer(), (nuint)input.WriteDesiredSize); + + value.ShrinkSerializedLength(input.WriteDesiredSize); + value.Length = input.WriteDesiredSize; + } + + return true; + } + } + /// + public void PostInitialUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo) { } + #endregion + + #region Writes + /// + public bool SingleWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, WriteReason reason, ref RecordInfo recordInfo) + => ConcurrentWriter(ref key, ref input, ref src, ref dst, ref output, ref upsertInfo, ref recordInfo); + + /// + public void PostSingleWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, WriteReason reason) { } + /// + public bool ConcurrentWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never write a non-namespaced 
value with VectorSessionFunctions"); + + return SpanByteFunctions.DoSafeCopy(ref src, ref dst, ref upsertInfo, ref recordInfo, 0); + } + + public void PostUpsertOperation(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref UpsertInfo upsertInfo, TEpochAccessor epoch) where TEpochAccessor : IEpochAccessor { } + + #endregion + + #region RMW + /// + public int GetRMWInitialValueLength(ref VectorInput input) + { + var effectiveWriteDesiredSize = input.WriteDesiredSize; + + if (effectiveWriteDesiredSize < 0) + { + effectiveWriteDesiredSize = -effectiveWriteDesiredSize; + } + + return sizeof(int) + effectiveWriteDesiredSize; + } + /// + public int GetRMWModifiedValueLength(ref SpanByte value, ref VectorInput input) + { + if (input.WriteDesiredSize < 0) + { + // Add to value, this is a dynamically sized type + return value.Length + (-input.WriteDesiredSize); + } + + // Constant size indicated + return sizeof(int) + input.WriteDesiredSize; + } + + /// + public int GetUpsertValueLength(ref SpanByte value, ref VectorInput input) + => sizeof(int) + value.Length; + + /// + public bool InPlaceUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never write a non-namespaced value with VectorSessionFunctions"); + + if (input.Callback == 0) + { + // We're doing a Metadata or InProgressDelete update + + Debug.Assert(key.GetNamespaceInPayload() == 0, "Should be operating on special namespace"); + + if (key.LengthWithoutMetadata == 0) + { + // Doing a Metadata update + Debug.Assert(value.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + + ref readonly var oldMetadata = ref MemoryMarshal.Cast(value.AsReadOnlySpan())[0]; + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = 
SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + ref readonly var newMetadata = ref MemoryMarshal.Cast(newMetadataValue.AsReadOnlySpan())[0]; + + if (newMetadata.Version < oldMetadata.Version) + { + rmwInfo.Action = RMWAction.CancelOperation; + return false; + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref value, ref rmwInfo, ref recordInfo); + } + else + { + // Doing an InProgressDelete update + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + Debug.Assert(key.LengthWithoutMetadata == 1 && key.AsReadOnlySpan()[0] == 1, "Should be working on InProgressDeletes"); + + Span inProgressDeleteUpdateData; + bool adding; + + unsafe + { + var len = BinaryPrimitives.ReadInt32LittleEndian(new Span(((byte*)input.CallbackContext + sizeof(long)), sizeof(int))); + adding = len > 0; + if (!adding) + { + len = -len; + } + + inProgressDeleteUpdateData = new Span((byte*)input.CallbackContext, sizeof(ulong) + sizeof(int) + len); + } + + return VectorManager.TryUpdateInProgressDeletes(inProgressDeleteUpdateData, ref value, ref recordInfo, ref rmwInfo); + } + } + else + { + Debug.Assert(input.WriteDesiredSize <= value.LengthWithoutMetadata, "Insufficient space for inplace update, this should never happen"); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)value.ToPointer(), (nuint)input.WriteDesiredSize); + } + + return true; + } + } + + /// + public bool NeedCopyUpdate(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte output, ref RMWInfo rmwInfo) + => input.WriteDesiredSize != 0; + + /// + public bool CopyUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + Debug.Assert(key.MetadataSize == 1, 
"Should never write a non-namespaced value with VectorSessionFunctions"); + + if (input.Callback == 0) + { + // We're doing a Metadata or InProgressDelete update + + Debug.Assert(key.GetNamespaceInPayload() == 0, "Should be operating on special namespace"); + + if (key.LengthWithoutMetadata == 0) + { + // Doing a Metadata update + Debug.Assert(oldValue.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(newValue.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + + ref readonly var oldMetadata = ref MemoryMarshal.Cast(oldValue.AsReadOnlySpan())[0]; + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + ref readonly var newMetadata = ref MemoryMarshal.Cast(newMetadataValue.AsReadOnlySpan())[0]; + + if (newMetadata.Version < oldMetadata.Version) + { + rmwInfo.Action = RMWAction.CancelOperation; + return false; + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref newValue, ref rmwInfo, ref recordInfo); + } + else + { + // Doing an InProgressDelete update + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + Debug.Assert(key.LengthWithoutMetadata == 1 && key.AsReadOnlySpan()[0] == 1, "Should be working on InProgressDeletes"); + + Span inProgressDeleteUpdateData; + bool adding; + + oldValue.CopyTo(ref newValue); + + unsafe + { + var len = BinaryPrimitives.ReadInt32LittleEndian(new Span(((byte*)input.CallbackContext + sizeof(long)), sizeof(int))); + adding = len > 0; + if (!adding) + { + len = -len; + } + + inProgressDeleteUpdateData = new Span((byte*)input.CallbackContext, sizeof(ulong) + sizeof(int) + len); + } + + var fits = VectorManager.TryUpdateInProgressDeletes(inProgressDeleteUpdateData, ref newValue, ref recordInfo, ref rmwInfo); + 
Debug.Assert(fits, "Copy update should have allocated enough space for in progress deletes"); + + return true; + } + } + else + { + Debug.Assert(input.WriteDesiredSize <= newValue.LengthWithoutMetadata, "Insufficient space for copy update, this should never happen"); + Debug.Assert(input.WriteDesiredSize <= oldValue.LengthWithoutMetadata, "Insufficient space for copy update, this should never happen"); + + oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan()); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)newValue.ToPointer(), (nuint)input.WriteDesiredSize); + } + + return true; + } + } + + /// + public bool PostCopyUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByte output, ref RMWInfo rmwInfo) + => true; + /// + public void RMWCompletionCallback(ref SpanByte key, ref VectorInput input, ref SpanByte output, long ctx, Status status, RecordMetadata recordMetadata) { } + + public void PostRMWOperation(ref SpanByte key, ref VectorInput input, ref RMWInfo rmwInfo, TEpochAccessor epoch) where TEpochAccessor : IEpochAccessor { } + #endregion + + #region Utilities + /// + public void ConvertOutputToHeap(ref VectorInput input, ref SpanByte output) { } + #endregion + } +} \ No newline at end of file diff --git a/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs b/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs index b4cb3c530de..319f440ff9a 100644 --- a/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs +++ b/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs @@ -258,7 +258,7 @@ protected override bool DeleteIfExpiredInMemory(ref byte[] key, ref IGarnetObjec internal sealed class MainStoreExpiredKeyDeletionScan : ExpiredKeysBase { - protected override bool IsExpired(ref SpanByte value) => value.MetadataSize > 0 
&& MainSessionFunctions.CheckExpiry(ref value); + protected override bool IsExpired(ref SpanByte value) => value.MetadataSize == 8 && MainSessionFunctions.CheckExpiry(ref value); protected override bool DeleteIfExpiredInMemory(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata) { var input = new RawStringInput(RespCommand.DELIFEXPIM); @@ -323,8 +323,15 @@ public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata re public bool ConcurrentReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) { + // TODO: A better check for "is probably a vector key" + if (key.MetadataSize == 1) + { + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + if ((info.patternB != null && !GlobUtils.Match(info.patternB, info.patternLength, key.ToPointer(), key.Length, true)) - || (value.MetadataSize != 0 && MainSessionFunctions.CheckExpiry(ref value))) + || (value.MetadataSize == 8 && MainSessionFunctions.CheckExpiry(ref value))) { cursorRecordResult = CursorRecordResult.Skip; } @@ -410,7 +417,14 @@ internal sealed class MainStoreGetDBSize : IScanIteratorFunctions(ref SpanByte key, ref RawStringInpu CompletePendingForSession(ref status, ref output, ref context); if (status.Found) + { return GarnetStatus.OK; + } + else if (status.IsCanceled) + { + // Vector Sets signal WRONGTYPE via cancellation - everything else will fall into NOTFOUND + return GarnetStatus.WRONGTYPE; + } else + { return GarnetStatus.NOTFOUND; + } } diff --git a/libs/server/Storage/Session/MainStore/MainStoreOps.cs b/libs/server/Storage/Session/MainStore/MainStoreOps.cs index 0d7d870c936..e3e53d63808 100644 --- a/libs/server/Storage/Session/MainStore/MainStoreOps.cs +++ b/libs/server/Storage/Session/MainStore/MainStoreOps.cs @@ -36,6 +36,10 @@ public GarnetStatus GET(ref SpanByte key, ref RawStringInput input, re incr_session_found(); return GarnetStatus.OK; } + else if 
(status.IsCanceled) + { + return GarnetStatus.WRONGTYPE; + } else { incr_session_notfound(); @@ -107,7 +111,7 @@ public unsafe GarnetStatus GET(ArgSlice key, out ArgSlice value, ref T public unsafe GarnetStatus GET(ArgSlice key, out MemoryResult value, ref TContext context) where TContext : ITsavoriteContext { - var input = new RawStringInput(RespCommand.GET); + var input = new RawStringInput(RespCommand.GET, arg1: -1); var _key = key.SpanByte; var _output = new SpanByteAndMemory(); @@ -589,6 +593,12 @@ public GarnetStatus DELETE(ref SpanByte key, StoreType if (storeType == StoreType.Main || storeType == StoreType.All) { var status = context.Delete(ref key); + if (status.IsCanceled) + { + // Might be a vector set + status = vectorManager.TryDeleteVectorSet(this, ref key, out _); + } + Debug.Assert(!status.IsPending); if (status.Found) found = true; } @@ -600,10 +610,11 @@ public GarnetStatus DELETE(ref SpanByte key, StoreType Debug.Assert(!status.IsPending); if (status.Found) found = true; } + return found ? 
GarnetStatus.OK : GarnetStatus.NOTFOUND; } - public GarnetStatus DELETE(byte[] key, StoreType storeType, ref TContext context, ref TObjectContext objectContext) + public unsafe GarnetStatus DELETE(byte[] key, StoreType storeType, ref TContext context, ref TObjectContext objectContext) where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext { @@ -612,6 +623,18 @@ public GarnetStatus DELETE(byte[] key, StoreType store if ((storeType == StoreType.Object || storeType == StoreType.All) && !objectStoreBasicContext.IsNull) { var status = objectContext.Delete(key); + if (status.IsCanceled) + { + // Might be a vector set + fixed (byte* keyPtr = key) + { + SpanByte keySpan = new(key.Length, (nint)keyPtr); + status = vectorManager.TryDeleteVectorSet(this, ref keySpan, out _); + } + + if (status.Found) found = true; + } + Debug.Assert(!status.IsPending); if (status.Found) found = true; } diff --git a/libs/server/Storage/Session/MainStore/VectorStoreOps.cs b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs new file mode 100644 index 00000000000..1719d921e2d --- /dev/null +++ b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs @@ -0,0 +1,370 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Supported quantizations of vector data. + /// + /// This controls the mapping of vector elements to how they're actually stored. + /// + public enum VectorQuantType + { + Invalid = 0, + + // Redis quantiziations + + /// + /// Vectors stored as is with no quantization. + /// + NoQuant, + /// + /// Vectors stored as binary (1 bit). + /// + Bin, + /// + /// Vectors stored as bytes (8 bits). + /// + Q8, + + // Extended quantizations + + /// + /// Vectors stored as bytes (8 bits). 
XPREQ8 is a non-Redis extension, stands for: + /// eXtension PREcalculated Quantization 8-bit - requests no quantization on pre-calculated [0, 255] values + /// + XPreQ8, + } + + /// + /// Supported formats for Vector value data. + /// + public enum VectorValueType : int + { + Invalid = 0, + + // Redis formats + + /// + /// Floats (FP32). + /// + FP32, + + // Extended formats + + /// + /// Bytes (8 bit). + /// + XB8, + } + + /// + /// How result ids are formatted in responses from DiskANN. + /// + public enum VectorIdFormat : int + { + Invalid = 0, + + /// + /// Has 4 bytes of unsigned length before the data. + /// + I32LengthPrefixed, + + /// + /// Ids are actually 4-byte ints, no prefix. + /// + FixedI32 + } + + /// + /// Supported distance metrics for vector similarity search. + /// Aligned with DiskANN's Metric type + /// + public enum VectorDistanceMetricType : int + { + Invalid = -1, + + /// + /// Cosine similarity + /// + Cosine = 0, + + /// + /// Inner product + /// + InnerProduct, + + /// + /// Squared Euclidean (L2-Squared) + /// + L2, + + /// + /// Normalized Cosine Similarity + /// + CosineNormalized, + } + + /// + /// Implementation of Vector Set operations. + /// + sealed partial class StorageSession : IDisposable + { + /// + /// Implement Vector Set Add - this may also create a Vector Set if one does not already exist. 
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetAdd(SpanByte key, int reduceDims, VectorValueType valueType, ArgSlice values, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, VectorDistanceMetricType distanceMetric, out VectorManagerResult result, out ReadOnlySpan errorMsg) + { + var dims = VectorManager.CalculateValueDimensions(valueType, values.ReadOnlySpan); + + var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dims, 1))); + var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); + var valueTypeArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref valueType, 1))); + var valuesArg = values; + var elementArg = element; + var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref quantizer, 1))); + var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1))); + var attributesArg = attributes; + var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref numLinks, 1))); + var distanceMetricArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref distanceMetric, 1))); + + parseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg, distanceMetricArg]); + + var input = new RawStringInput(RespCommand.VADD, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadOrCreateVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + errorMsg = default; + return status; + } + + // After a successful read we add the vector while holding a shared lock + // That lock prevents deletion, but 
everything else can proceed in parallel + result = vectorManager.TryAdd(indexSpan, element.ReadOnlySpan, valueType, values.ReadOnlySpan, attributes.ReadOnlySpan, (uint)reduceDims, quantizer, (uint)buildExplorationFactor, (uint)numLinks, distanceMetric, out errorMsg); + + if (result == VectorManagerResult.OK) + { + // On successful addition, we need to manually replicate the write + vectorManager.ReplicateVectorSetAdd(ref key, ref input, ref basicContext); + } + + return GarnetStatus.OK; + } + } + + /// + /// Implement Vector Set Remove - returns not found if the element is not present, or the vector set does not exist. + /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetRemove(SpanByte key, SpanByte element) + { + var input = new RawStringInput(RespCommand.VREM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + return status; + } + + // After a successful read we remove the vector while holding a shared lock + // That lock prevents deletion, but everything else can proceed in parallel + var res = vectorManager.TryRemove(indexSpan, element.AsReadOnlySpan()); + + if (res == VectorManagerResult.OK) + { + // On successful removal, we need to manually replicate the write + vectorManager.ReplicateVectorSetRemove(ref key, ref element, ref input, ref basicContext); + + return GarnetStatus.OK; + } + + return GarnetStatus.NOTFOUND; + } + } + + /// + /// Perform a similarity search on an existing Vector Set given a vector as a bunch of floats. 
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetValueSimilarity(SpanByte key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + // Get the index + var input = new RawStringInput(RespCommand.VSIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + outputIdFormat = VectorIdFormat.Invalid; + return status; + } + + result = vectorManager.ValueSimilarity(indexSpan, valueType, values.ReadOnlySpan, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes); + + return GarnetStatus.OK; + } + } + + /// + /// Perform a similarity search on an existing Vector Set given an element that is already in the Vector Set. 
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetElementSimilarity(SpanByte key, ReadOnlySpan element, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VSIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + outputIdFormat = VectorIdFormat.Invalid; + return status; + } + + result = vectorManager.ElementSimilarity(indexSpan, element, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes); + return GarnetStatus.OK; + } + } + + /// + /// Get the approximate vector associated with an element, after (approximately) reversing any transformation. 
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetEmbedding(SpanByte key, ReadOnlySpan element, out VectorQuantType quantType, ref SpanByteAndMemory outputDistances) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VEMB, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + quantType = VectorQuantType.Invalid; + return status; + } + + if (!vectorManager.TryGetEmbedding(indexSpan, element, out quantType, ref outputDistances)) + { + return GarnetStatus.NOTFOUND; + } + + return GarnetStatus.OK; + } + } + + [SkipLocalsInit] + internal unsafe GarnetStatus VectorSetDimensions(SpanByte key, out int dimensions) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VDIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + dimensions = 0; + return status; + } + + // After a successful read we extract metadata + VectorManager.ReadIndex(indexSpan, out _, out var dimensionsUS, out var reducedDimensionsUS, out _, out _, out _, out _, out _, out _); + + dimensions = (int)(reducedDimensionsUS == 0 ? 
dimensionsUS : reducedDimensionsUS); + + return GarnetStatus.OK; + } + } + + /// + /// Get debugging information about the VectorSet + /// + [SkipLocalsInit] + internal unsafe GarnetStatus VectorSetInfo(SpanByte key, + out VectorQuantType quantType, + out VectorDistanceMetricType distanceMetricType, + out uint vectorDimensions, + out uint reducedDimensions, + out uint buildExplorationFactor, + out uint numberOfLinks, + out long size) + { + parseState.InitializeWithArgument(new(ref key)); + + var input = new RawStringInput(RespCommand.VINFO, ref parseState); + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + quantType = VectorQuantType.Invalid; + distanceMetricType = VectorDistanceMetricType.Invalid; + vectorDimensions = 0; + reducedDimensions = 0; + buildExplorationFactor = 0; + numberOfLinks = 0; + size = 0; + return status; + } + + // After a successful read we extract metadata + VectorManager.ReadIndex(indexSpan, out var context, out vectorDimensions, out reducedDimensions, out quantType, out buildExplorationFactor, out numberOfLinks, out distanceMetricType, out var indexPtr, out _); + size = (long)NativeDiskANNMethods.card(context, indexPtr); + + return GarnetStatus.OK; + } + } + + /// + /// Get the attributes associated with an element in the VectorSet + /// + [SkipLocalsInit] + internal unsafe GarnetStatus VectorSetGetAttribute(SpanByte key, ArgSlice elementId, ref SpanByteAndMemory outputAttributes) + { + parseState.InitializeWithArgument(new(ref key)); + + // Get the index + var input = new RawStringInput(RespCommand.VGETATTR, ref parseState); + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + return status; + } + + var result = 
vectorManager.FetchSingleVectorElementAttributes(indexSpan, elementId.SpanByte, ref outputAttributes); + return result == VectorManagerResult.OK ? GarnetStatus.OK : GarnetStatus.NOTFOUND; + } + } + } +} \ No newline at end of file diff --git a/libs/server/Storage/Session/ObjectStore/Common.cs b/libs/server/Storage/Session/ObjectStore/Common.cs index b8ebf286995..5e5a69ad82e 100644 --- a/libs/server/Storage/Session/ObjectStore/Common.cs +++ b/libs/server/Storage/Session/ObjectStore/Common.cs @@ -783,6 +783,41 @@ unsafe GarnetStatus ReadObjectStoreOperation(byte[] key, ref Obj return GarnetStatus.NOTFOUND; } + /// + /// Gets the value of the key store in the Object Store + /// + unsafe GarnetStatus ReadObjectStoreOperationWithObject(byte[] key, ref ObjectInput input, out ObjectOutputHeader output, out IGarnetObject garnetObject, ref TObjectContext objectStoreContext) + where TObjectContext : ITsavoriteContext + { + if (objectStoreContext.Session is null) + ThrowObjectStoreUninitializedException(); + + var _output = new GarnetObjectStoreOutput(); + + // Perform Read on object store + var status = objectStoreContext.Read(ref key, ref input, ref _output); + + if (status.IsPending) + CompletePendingForObjectStoreSession(ref status, ref _output, ref objectStoreContext); + + output = _output.Header; + + if (_output.HasWrongType) + { + garnetObject = null; + return GarnetStatus.WRONGTYPE; + } + + if (status.Found && (!status.Record.Created && !status.Record.CopyUpdated && !status.Record.InPlaceUpdated)) + { + garnetObject = _output.GarnetObject; + return GarnetStatus.OK; + } + + garnetObject = null; + return GarnetStatus.NOTFOUND; + } + /// /// Iterates members of a collection object using a cursor, /// a match pattern and count parameters diff --git a/libs/server/Storage/Session/StorageSession.cs b/libs/server/Storage/Session/StorageSession.cs index 22edec64896..0ff9717d3fb 100644 --- a/libs/server/Storage/Session/StorageSession.cs +++ 
b/libs/server/Storage/Session/StorageSession.cs @@ -42,6 +42,12 @@ sealed partial class StorageSession : IDisposable public BasicContext objectStoreBasicContext; public LockableContext objectStoreLockableContext; + /// + /// Session Contexts for vector ops against the main store + /// + public BasicContext vectorContext; + public LockableContext vectorLockableContext; + public readonly ScratchBufferBuilder scratchBufferBuilder; public readonly FunctionsState functionsState; @@ -55,11 +61,14 @@ sealed partial class StorageSession : IDisposable public readonly int ObjectScanCountLimit; + public readonly VectorManager vectorManager; + public StorageSession(StoreWrapper storeWrapper, ScratchBufferBuilder scratchBufferBuilder, GarnetSessionMetrics sessionMetrics, GarnetLatencyMetricsSession LatencyMetrics, int dbId, + VectorManager vectorManager, ILogger logger = null, byte respProtocolVersion = ServerOptions.DEFAULT_RESP_VERSION) { @@ -68,6 +77,7 @@ public StorageSession(StoreWrapper storeWrapper, this.scratchBufferBuilder = scratchBufferBuilder; this.logger = logger; this.itemBroker = storeWrapper.itemBroker; + this.vectorManager = vectorManager; parseState.Initialize(); functionsState = storeWrapper.CreateFunctionsState(dbId, respProtocolVersion); @@ -83,6 +93,9 @@ public StorageSession(StoreWrapper storeWrapper, var objectStoreFunctions = new ObjectSessionFunctions(functionsState); var objectStoreSession = db.ObjectStore?.NewSession(objectStoreFunctions); + var vectorFunctions = new VectorSessionFunctions(functionsState); + var vectorSession = db.MainStore.NewSession(vectorFunctions); + basicContext = session.BasicContext; lockableContext = session.LockableContext; if (objectStoreSession != null) @@ -90,6 +103,8 @@ public StorageSession(StoreWrapper storeWrapper, objectStoreBasicContext = objectStoreSession.BasicContext; objectStoreLockableContext = objectStoreSession.LockableContext; } + vectorContext = vectorSession.BasicContext; + vectorLockableContext = 
vectorSession.LockableContext; HeadAddress = db.MainStore.Log.HeadAddress; ObjectScanCountLimit = storeWrapper.serverOptions.ObjectScanCountLimit; diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index 9b28782b0c5..8201b14adc0 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -363,9 +363,11 @@ internal void Recover() { RecoverCheckpoint(); RecoverAOF(); - ReplayAOF(); + _ = ReplayAOF(); } } + + databaseManager.RecoverVectorSets(); } /// @@ -815,6 +817,10 @@ internal void Start() { StartPrimaryTasks(); } + else if (clusterProvider?.IsReplica() ?? false) + { + StartReplicaTasks(); + } // Start generic node tasks StartGenericNodeTasks(); @@ -833,6 +839,13 @@ public bool HasKeysInSlots(List slots) while (!hasKeyInSlots && iter.GetNext(out RecordInfo record)) { ref var key = ref iter.GetKey(); + + // TODO: better way to ignore vector set elements + if (key.MetadataSize == 1) + { + continue; + } + ushort hashSlotForKey = HashSlotUtils.HashSlot(ref key); if (slots.Contains(hashSlotForKey)) { @@ -919,6 +932,15 @@ public async Task SuspendPrimaryOnlyTasks() await taskManager.Cancel(TaskPlacementCategory.Primary); } + /// + /// Suspend background task that may interfere with the primary store. + /// + /// + public async Task SuspendReplicaOnlyTasks() + { + await taskManager.Cancel(TaskPlacementCategory.Replica); + } + /// /// Start background maintenance tasks that should only run when this node is a primary /// @@ -952,6 +974,17 @@ public void StartPrimaryTasks() } } + /// + /// Start background maintenance tasks that hsould only be run when this node is a replica. 
+ /// + public void StartReplicaTasks() + { + if (serverOptions.EnableVectorSetPreview) + { + _ = taskManager.RegisterAndRun(TaskType.VectorReplicationReplayTask, token => DefaultDatabase.VectorManager.StartReplicationTasksAsync(token)); + } + } + /// /// Start background maintenance generic tasks /// diff --git a/libs/server/TaskManager/TaskType.cs b/libs/server/TaskManager/TaskType.cs index 59c295e3d47..b3a2766bd14 100644 --- a/libs/server/TaskManager/TaskType.cs +++ b/libs/server/TaskManager/TaskType.cs @@ -48,6 +48,11 @@ public enum TaskType : byte /// See for implementation. /// IndexAutoGrowTask, + + /// + /// Replays s on replicas in parallel. + /// + VectorReplicationReplayTask, } /// @@ -65,6 +70,7 @@ public static class TaskTypeExtensions /// static TaskTypeExtensions() { + TaskPlacementMapping[(int)TaskType.VectorReplicationReplayTask] = TaskPlacementCategory.Replica; TaskPlacementMapping[(int)TaskType.AofSizeLimitTask] = TaskPlacementCategory.Primary; TaskPlacementMapping[(int)TaskType.CommitTask] = TaskPlacementCategory.Primary; TaskPlacementMapping[(int)TaskType.CompactionTask] = TaskPlacementCategory.Primary; diff --git a/libs/server/Transaction/TransactionManager.cs b/libs/server/Transaction/TransactionManager.cs index e734fd85f4d..d86218f654c 100644 --- a/libs/server/Transaction/TransactionManager.cs +++ b/libs/server/Transaction/TransactionManager.cs @@ -15,13 +15,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; using MainStoreAllocator = SpanByteAllocator>; using MainStoreFunctions = StoreFunctions; diff --git a/libs/server/Transaction/TxnKeyManager.cs b/libs/server/Transaction/TxnKeyManager.cs index 2e54ad3c95f..e57b568b158 100644 --- a/libs/server/Transaction/TxnKeyManager.cs +++ 
b/libs/server/Transaction/TxnKeyManager.cs @@ -50,7 +50,7 @@ public unsafe void VerifyKeyOwnership(ArgSlice key, LockType type) if (!clusterEnabled) return; var readOnly = type == LockType.Shared; - if (!respSession.clusterSession.NetworkIterativeSlotVerify(key, readOnly, respSession.SessionAsking)) + if (!respSession.clusterSession.NetworkIterativeSlotVerify(key, readOnly, respSession.SessionAsking, waitForStableSlot: false)) { this.state = TxnState.Aborted; } diff --git a/libs/server/Transaction/TxnRespCommands.cs b/libs/server/Transaction/TxnRespCommands.cs index e2c333d94d6..6b31be3c015 100644 --- a/libs/server/Transaction/TxnRespCommands.cs +++ b/libs/server/Transaction/TxnRespCommands.cs @@ -60,7 +60,7 @@ private bool NetworkEXEC() endReadHead = txnManager.txnStartHead; txnManager.GetKeysForValidation(recvBufferPtr, out var keys, out int keyCount, out bool readOnly); - if (NetworkKeyArraySlotVerify(keys, readOnly, keyCount)) + if (NetworkKeyArraySlotVerify(keys, readOnly, waitForStableSlot: false, keyCount)) // TODO: We should actually verify if commands contained are Vector Set writes { logger?.LogWarning("Failed CheckClusterTxnKeys"); txnManager.Reset(false); diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs b/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs index 5d82c473f53..180dfbb0259 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs @@ -11,7 +11,7 @@ namespace Tsavorite.core { // RecordInfo layout (64 bits total): - // [Unused1][Modified][InNewVersion][Filler][Dirty][ETag][Sealed][Valid][Tombstone][LLLLLLL] [RAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] + // [VectorSet][Modified][InNewVersion][Filler][Dirty][ETag][Sealed][Valid][Tombstone][LLLLLLL] [RAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] // where L = leftover, R = readcache, A = address [StructLayout(LayoutKind.Explicit, Size = 
8)] public struct RecordInfo @@ -35,7 +35,7 @@ public struct RecordInfo const int kFillerBitOffset = kDirtyBitOffset + 1; const int kInNewVersionBitOffset = kFillerBitOffset + 1; const int kModifiedBitOffset = kInNewVersionBitOffset + 1; - const int kUnused1BitOffset = kModifiedBitOffset + 1; + const int kVectorSetBitOffset = kModifiedBitOffset + 1; const long kTombstoneBitMask = 1L << kTombstoneBitOffset; const long kValidBitMask = 1L << kValidBitOffset; @@ -45,7 +45,7 @@ public struct RecordInfo const long kFillerBitMask = 1L << kFillerBitOffset; const long kInNewVersionBitMask = 1L << kInNewVersionBitOffset; const long kModifiedBitMask = 1L << kModifiedBitOffset; - const long kUnused1BitMask = 1L << kUnused1BitOffset; + const long kVectorSetBitMask = 1L << kVectorSetBitOffset; [FieldOffset(0)] private long word; @@ -269,10 +269,10 @@ public long PreviousAddress [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int GetLength() => kTotalSizeInBytes; - internal bool Unused1 + public bool VectorSet { - readonly get => (word & kUnused1BitMask) != 0; - set => word = value ? word | kUnused1BitMask : word & ~kUnused1BitMask; + readonly get => (word & kVectorSetBitMask) != 0; + set => word = value ? word | kVectorSetBitMask : word & ~kVectorSetBitMask; } public bool ETag @@ -289,7 +289,7 @@ public override readonly string ToString() var paRC = IsReadCache(PreviousAddress) ? "(rc)" : string.Empty; static string bstr(bool value) => value ? 
"T" : "F"; return $"prev {AbsoluteAddress(PreviousAddress)}{paRC}, valid {bstr(Valid)}, tomb {bstr(Tombstone)}, seal {bstr(IsSealed)}," - + $" mod {bstr(Modified)}, dirty {bstr(Dirty)}, fill {bstr(HasFiller)}, etag {bstr(ETag)}, Un1 {bstr(Unused1)}"; + + $" mod {bstr(Modified)}, dirty {bstr(Dirty)}, fill {bstr(HasFiller)}, etag {bstr(ETag)}, vset {bstr(VectorSet)}"; } } } \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs index d949cc4def1..f0b7744e03a 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs @@ -228,6 +228,7 @@ private OperationStatus CreateNewRecordDelete public static Status CreatePending() => new(StatusCode.Pending); + /// + /// Create a Status value. + /// + public static Status CreateNotFound() => new(StatusCode.NotFound); + + /// + /// Create a Error Status value. 
+ /// + public static Status CreateError() => new(StatusCode.Error); + /// /// Whether a Read or RMW found the key /// diff --git a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs index 5d46bdb8ec4..d3058c49f5a 100644 --- a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs +++ b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs @@ -25,8 +25,10 @@ public unsafe struct SpanByte private const int UnserializedBitMask = 1 << 31; // Byte #30 is used to denote extra metadata present (1) or absent (0) in payload private const int ExtraMetadataBitMask = 1 << 30; + // Bit #29 used to denote if a namespace is present in payload + private const int NamespaceBitMask = 1 << 29; // Mask for header - private const int HeaderMask = 0x3 << 30; + private const int HeaderMask = UnserializedBitMask | ExtraMetadataBitMask | NamespaceBitMask; /// /// Length of the payload @@ -93,9 +95,9 @@ public int Length public readonly int TotalSize => sizeof(int) + Length; /// - /// Size of metadata header, if any (returns 0 or 8) + /// Size of metadata header, if any (returns 0, 1, 8, or 9) /// - public readonly int MetadataSize => (length & ExtraMetadataBitMask) >> (30 - 3); + public readonly int MetadataSize => ((length & ExtraMetadataBitMask) >> (30 - 3)) + ((length & NamespaceBitMask) >> 29); /// /// Create a around a given pointer and given @@ -144,6 +146,7 @@ public long ExtraMetadata public void MarkExtraMetadata() { Debug.Assert(Length >= 8); + Debug.Assert((length & NamespaceBitMask) == 0, "Don't use both extension for now"); length |= ExtraMetadataBitMask; } @@ -153,6 +156,23 @@ public void MarkExtraMetadata() [MethodImpl(MethodImplOptions.AggressiveInlining)] public void UnmarkExtraMetadata() => length &= ~ExtraMetadataBitMask; + /// + /// Mark as having 1-byte namespace in header of payload + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MarkNamespace() + { + Debug.Assert(Length >= 1); + 
Debug.Assert((length & ExtraMetadataBitMask) == 0, "Don't use both extension for now"); + length |= NamespaceBitMask; + } + + /// + /// Unmark as having 1-byte namespace in header of payload + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void UnmarkNamespace() => length &= ~NamespaceBitMask; + /// /// Check or set struct as invalid /// @@ -526,6 +546,18 @@ public void CopyTo(byte* destination) [MethodImpl(MethodImplOptions.AggressiveInlining)] public void SetEtagInPayload(long etag) => *(long*)this.ToPointer() = etag; + /// + /// Gets a namespace from the payload of the SpanByte, caller should make sure the SpanByte has a namespace for the record by checking RecordInfo + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public byte GetNamespaceInPayload() => *(byte*)this.ToPointerWithMetadata(); + + /// + /// Gets a namespace from the payload of the SpanByte, caller should make sure the SpanByte has a namespace for the record by checking RecordInfo + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void SetNamespaceInPayload(byte ns) => *(byte*)this.ToPointerWithMetadata() = ns; + /// public override string ToString() { diff --git a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs index 6e8460c2662..cf6a1c5c9d0 100644 --- a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs +++ b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs @@ -83,6 +83,12 @@ public SpanByteAndMemory(IMemoryOwner memory, int length) [MethodImpl(MethodImplOptions.AggressiveInlining)] public ReadOnlySpan AsReadOnlySpan() => IsSpanByte ? SpanByte.AsReadOnlySpan() : Memory.Memory.Span.Slice(0, Length); + /// + /// As a span of the contained data. Use this when you haven't tested . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AsSpan() => IsSpanByte ? 
SpanByte.AsSpan() : Memory.Memory.Span.Slice(0, Length); + /// /// As a span of the contained data. Use this when you have already tested . /// diff --git a/playground/CommandInfoUpdater/GarnetCommandsInfo.json b/playground/CommandInfoUpdater/GarnetCommandsInfo.json index 52786d649d8..afb17f2c2e5 100644 --- a/playground/CommandInfoUpdater/GarnetCommandsInfo.json +++ b/playground/CommandInfoUpdater/GarnetCommandsInfo.json @@ -215,6 +215,19 @@ "KeySpecifications": null, "SubCommands": null }, + { + "Command": "CLUSTER_RESERVE", + "Name": "CLUSTER|RESERVE", + "IsInternal": true, + "Arity": 4, + "Flags": "Admin, NoScript, NoMulti", + "FirstKey": 0, + "LastKey": 0, + "Step": 0, + "AclCategories": "Admin, Dangerous, Garnet", + "KeySpecifications": null, + "SubCommands": null + }, { "Command": "CLUSTER_MTASKS", "Name": "CLUSTER|MTASKS", diff --git a/playground/CommandInfoUpdater/SupportedCommand.cs b/playground/CommandInfoUpdater/SupportedCommand.cs index a1a61b79234..20b2ed74a77 100644 --- a/playground/CommandInfoUpdater/SupportedCommand.cs +++ b/playground/CommandInfoUpdater/SupportedCommand.cs @@ -93,6 +93,7 @@ public class SupportedCommand new("CLUSTER|REPLICAS", RespCommand.CLUSTER_REPLICAS), new("CLUSTER|REPLICATE", RespCommand.CLUSTER_REPLICATE), new("CLUSTER|RESET", RespCommand.CLUSTER_RESET), + new("CLUSTER|RESERVE", RespCommand.CLUSTER_RESERVE), new("CLUSTER|SEND_CKPT_FILE_SEGMENT", RespCommand.CLUSTER_SEND_CKPT_FILE_SEGMENT), new("CLUSTER|SEND_CKPT_METADATA", RespCommand.CLUSTER_SEND_CKPT_METADATA), new("CLUSTER|SET-CONFIG-EPOCH", RespCommand.CLUSTER_SETCONFIGEPOCH), diff --git a/test/Garnet.test.cluster/ClusterAuthCommsTests.cs b/test/Garnet.test.cluster/ClusterAuthCommsTests.cs index 78f17c35e5a..9ed5f44dd26 100644 --- a/test/Garnet.test.cluster/ClusterAuthCommsTests.cs +++ b/test/Garnet.test.cluster/ClusterAuthCommsTests.cs @@ -185,7 +185,7 @@ public void ClusterSimpleFailoverAuth() // Setup single primary populate and then attach replicas 
ClusterReplicationAuth(); - context.ClusterFailoveSpinWait(replicaNodeIndex: 1, logger: context.logger); + context.ClusterFailoverSpinWait(replicaNodeIndex: 1, logger: context.logger); // Reconfigure slotMap to reflect new primary int[] slotMap = new int[16384]; diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs index 3205eb8de28..1ccf8ca6bdd 100644 --- a/test/Garnet.test.cluster/ClusterTestContext.cs +++ b/test/Garnet.test.cluster/ClusterTestContext.cs @@ -117,7 +117,6 @@ public void RestartNode(int nodeIndex) nodes[nodeIndex].Start(); } - public void TearDown() { cts.Cancel(); @@ -671,7 +670,7 @@ public void SendAndValidateKeys(int primaryIndex, int replicaIndex, int keyLengt } } - public void ClusterFailoveSpinWait(int replicaNodeIndex, ILogger logger) + public void ClusterFailoverSpinWait(int replicaNodeIndex, ILogger logger) { // Failover primary _ = clusterTestUtils.ClusterFailover(replicaNodeIndex, "ABORT", logger); diff --git a/test/Garnet.test.cluster/ClusterTestUtils.cs b/test/Garnet.test.cluster/ClusterTestUtils.cs index e9ed58f2625..75b1cca2aef 100644 --- a/test/Garnet.test.cluster/ClusterTestUtils.cs +++ b/test/Garnet.test.cluster/ClusterTestUtils.cs @@ -9,12 +9,14 @@ using System.Net; using System.Net.Security; using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; using System.Threading.Tasks; using Garnet.client; using Garnet.common; +using Garnet.server; using Garnet.server.TLS; using GarnetClusterManagement; using Microsoft.Extensions.Logging; @@ -1871,12 +1873,22 @@ public int MigrateTasks(IPEndPoint endPoint, ILogger logger) } } - public void WaitForMigrationCleanup(int nodeIndex, ILogger logger = null) - => WaitForMigrationCleanup(endpoints[nodeIndex].ToIPEndPoint(), logger); + public void WaitForMigrationCleanup(int nodeIndex, ILogger logger = null, CancellationToken 
cancellationToken = default) + => WaitForMigrationCleanup(endpoints[nodeIndex].ToIPEndPoint(), logger, cancellationToken); - public void WaitForMigrationCleanup(IPEndPoint endPoint, ILogger logger) + public void WaitForMigrationCleanup(IPEndPoint endPoint, ILogger logger, CancellationToken cancellationToken = default) { - while (MigrateTasks(endPoint, logger) > 0) { BackOff(cancellationToken: context.cts.Token); } + CancellationToken backoffToken; + if (cancellationToken.CanBeCanceled) + { + backoffToken = cancellationToken; + } + else + { + backoffToken = context.cts.Token; + } + + while (MigrateTasks(endPoint, logger) > 0) { BackOff(cancellationToken: backoffToken); } } public void WaitForMigrationCleanup(ILogger logger) @@ -2991,11 +3003,29 @@ public void WaitForReplicaAofSync(int primaryIndex, int secondaryIndex, ILogger primaryReplicationOffset = GetReplicationOffset(primaryIndex, logger); secondaryReplicationOffset1 = GetReplicationOffset(secondaryIndex, logger); if (primaryReplicationOffset == secondaryReplicationOffset1) + { + var storeWrapper = GetStoreWrapper(this.context.nodes[secondaryIndex]); + var dbManager = GetDatabaseManager(storeWrapper); + + dbManager.DefaultDatabase.VectorManager.WaitForVectorOperationsToComplete(); + break; + } var primaryMainStoreVersion = context.clusterTestUtils.GetStoreCurrentVersion(primaryIndex, isMainStore: true, logger); var replicaMainStoreVersion = context.clusterTestUtils.GetStoreCurrentVersion(secondaryIndex, isMainStore: true, logger); - BackOff(cancellationToken: context.cts.Token, msg: $"[{endpoints[primaryIndex]}]: {primaryMainStoreVersion},{primaryReplicationOffset} != [{endpoints[secondaryIndex]}]: {replicaMainStoreVersion},{secondaryReplicationOffset1}"); + + CancellationToken backoffToken; + if (cancellation.CanBeCanceled) + { + backoffToken = cancellation; + } + else + { + backoffToken = context.cts.Token; + } + + BackOff(cancellationToken: backoffToken, msg: $"[{endpoints[primaryIndex]}]: 
{primaryMainStoreVersion},{primaryReplicationOffset} != [{endpoints[secondaryIndex]}]: {replicaMainStoreVersion},{secondaryReplicationOffset1}"); } logger?.LogInformation("[{primaryEndpoint}]{primaryReplicationOffset} ?? [{endpoints[secondaryEndpoint}]{secondaryReplicationOffset1}", endpoints[primaryIndex], primaryReplicationOffset, endpoints[secondaryIndex], secondaryReplicationOffset1); } @@ -3258,5 +3288,11 @@ public int DBSize(IPEndPoint endPoint, ILogger logger = null) return -1; } } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "storeWrapper")] + private static extern ref StoreWrapper GetStoreWrapper(GarnetServer server); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "databaseManager")] + private static extern ref IDatabaseManager GetDatabaseManager(StoreWrapper server); } } \ No newline at end of file diff --git a/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs b/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs index e7a0607cfd2..9d793d0f952 100644 --- a/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs +++ b/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs @@ -115,13 +115,13 @@ public override void Main(TGarnetApi api, ref CustomProcedureInput p { var offset = 0; var getA = GetNextArg(ref procInput, ref offset); - var setB = GetNextArg(ref procInput, ref offset).SpanByte; - var setC = GetNextArg(ref procInput, ref offset).SpanByte; + var setB = GetNextArg(ref procInput, ref offset); + var setC = GetNextArg(ref procInput, ref offset); _ = api.GET(getA, out _); - var status = api.SET(ref setB, ref setB); + var status = api.SET(setB, setB); ClassicAssert.AreEqual(GarnetStatus.OK, status); - status = api.SET(ref setC, ref setC); + status = api.SET(setC, setC); ClassicAssert.AreEqual(GarnetStatus.OK, status); WriteSimpleString(ref output, "SUCCESS"); } diff --git a/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs new file mode 100644 index 
00000000000..5e87f597d38 --- /dev/null +++ b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs @@ -0,0 +1,2063 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Net; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Allure.NUnit; +using Garnet.server; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; + +namespace Garnet.test.cluster +{ + [TestFixture, NonParallelizable] + [AllureNUnit] + public class ClusterVectorSetTests : AllureTestBase + { + private sealed class StringAndByteArrayComparer : IEqualityComparer<(string Key, byte[] Elem)> + { + public static readonly StringAndByteArrayComparer Instance = new(); + + private StringAndByteArrayComparer() { } + + public bool Equals((string Key, byte[] Elem) x, (string Key, byte[] Elem) y) + => x.Key.Equals(y.Key) && x.Elem.SequenceEqual(y.Elem); + + public int GetHashCode([DisallowNull] (string Key, byte[] Elem) obj) + { + HashCode code = default; + code.Add(obj.Key); + code.AddBytes(obj.Elem); + + return code.ToHashCode(); + } + } + + private sealed class CaptureLogWriter(TextWriter passThrough) : TextWriter + { + public bool capture; + public readonly StringBuilder buffer = new(); + + public override Encoding Encoding + => passThrough.Encoding; + + public override void Write(string value) + { + passThrough.Write(value); + + if (capture) + { + lock (buffer) + { + _ = buffer.Append(value); + } + } + } + } + + private const int DefaultShards = 2; + private const int 
HighReplicationShards = 6; + private const int DefaultMultiPrimaryShards = 4; + + private static readonly Dictionary MonitorTests = new() + { + [nameof(MigrateVectorStressAsync)] = LogLevel.Debug, + }; + + + private ClusterTestContext context; + + private CaptureLogWriter captureLogWriter; + + [SetUp] + public virtual void Setup() + { + captureLogWriter = new(TestContext.Progress); + + context = new ClusterTestContext(); + context.logTextWriter = captureLogWriter; + context.Setup(MonitorTests); + } + + [TearDown] + public virtual void TearDown() + { + context?.TearDown(); + } + + [Test] + [TestCase("XB8", "XPREQ8")] + [TestCase("XB8", "Q8")] + [TestCase("XB8", "BIN")] + [TestCase("XB8", "NOQUANT")] + [TestCase("FP32", "XPREQ8")] + [TestCase("FP32", "Q8")] + [TestCase("FP32", "BIN")] + [TestCase("FP32", "NOQUANT")] + public void BasicVADDReplicates(string vectorFormat, string quantizer) + { + // TODO: also test VALUES format? + + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + + ClassicAssert.IsTrue(Enum.TryParse(vectorFormat, ignoreCase: true, out var vectorFormatParsed)); + ClassicAssert.IsTrue(Enum.TryParse(quantizer, ignoreCase: true, out var quantTypeParsed)); + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 1); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + byte[] vectorAddData; + if (vectorFormatParsed == VectorValueType.XB8) + { + vectorAddData = new byte[75]; + vectorAddData[0] = 1; + for (var i = 1; i < vectorAddData.Length; i++) + { + vectorAddData[i] = (byte)(vectorAddData[i - 1] + 1); + } + } + else if (vectorFormatParsed == 
VectorValueType.FP32) + { + var floats = new float[75]; + floats[0] = 1; + for (var i = 1; i < floats.Length; i++) + { + floats[i] = floats[i - 1] + 1; + } + + vectorAddData = MemoryMarshal.Cast(floats).ToArray(); + } + else + { + ClassicAssert.Fail("Unexpected vector format"); + return; + } + + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", vectorFormat, vectorAddData, new byte[] { 0, 0, 0, 0 }, quantizer]); + ClassicAssert.AreEqual(1, addRes); + + byte[] vectorSimData; + if (vectorFormatParsed == VectorValueType.XB8) + { + vectorSimData = new byte[75]; + vectorSimData[0] = 2; + for (var i = 1; i < vectorSimData.Length; i++) + { + vectorSimData[i] = (byte)(vectorSimData[i - 1] + 1); + } + } + else if (vectorFormatParsed == VectorValueType.FP32) + { + var floats = new float[75]; + floats[0] = 2; + for (var i = 1; i < floats.Length; i++) + { + floats[i] = floats[i - 1] + 1; + } + + vectorSimData = MemoryMarshal.Cast(floats).ToArray(); + } + else + { + ClassicAssert.Fail("Unexpected vector format"); + return; + } + + var simRes = (byte[][])context.clusterTestUtils.Execute(primary, "VSIM", ["foo", vectorFormat, vectorSimData]); + ClassicAssert.IsTrue(simRes.Length > 0); + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + var simOnReplica = context.clusterTestUtils.Execute(secondary, "VSIM", ["foo", vectorFormat, vectorSimData]); + ClassicAssert.IsTrue(simOnReplica.Length > 0); + } + + [Test] + [TestCase(false)] + [TestCase(true)] + public async Task ConcurrentVADDReplicatedVSimsAsync(bool withAttributes) + { + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + const int Vectors = 2_000; + const string Key = nameof(ConcurrentVADDReplicatedVSimsAsync); + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + 
context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 1); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + // Build some repeatably random data for inserts + var vectors = new byte[Vectors][]; + { + var r = new Random(2025_09_15_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + } + + using var sync = new SemaphoreSlim(2); + + var writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + int addRes; + if (withAttributes) + { + addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8", "SETATTR", $"{{ \"id\": {i} }}"]); + } + else + { + addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + } + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_15_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + var gotAttrs = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + if (withAttributes) + { + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val, "WITHATTRIBS"]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + + for (var i = 0; i < readRes.Length; i += 2) + { + 
var id = readRes[i]; + var attr = readRes[i + 1]; + + var asInt = BinaryPrimitives.ReadInt32LittleEndian(id); + + var actualAttr = Encoding.UTF8.GetString(attr); + var expectedAttr = $"{{ \"id\": {asInt} }}"; + + ClassicAssert.AreEqual(expectedAttr, actualAttr); + + gotAttrs++; + } + } + else + { + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + } + + return (nonZeroReturns, gotAttrs); + } + ); + + _ = sync.Release(2); + await writeTask; + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var (searchesWithNonZeroResults, searchesWithAttrs) = await readTask; + + ClassicAssert.IsTrue(searchesWithNonZeroResults > 0); + + if (withAttributes) + { + ClassicAssert.IsTrue(searchesWithAttrs > 0); + } + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, fromPrimary.Length); + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + var s = (byte)float.Parse(fromSecondary[i]); + + ClassicAssert.AreEqual(expected[i], p); + ClassicAssert.AreEqual(expected[i], s); + } + } + } + } + + [Test] + public void RepeatedCreateDelete() + { + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 
1); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + var bytes1 = new byte[75]; + bytes1[0] = 1; + for (var j = 1; j < bytes1.Length; j++) + { + bytes1[j] = (byte)(bytes1[j - 1] + 1); + } + + var bytes2 = new byte[75]; + bytes2[0] = 5; + for (var j = 1; j < bytes2.Length; j++) + { + bytes2[j] = (byte)(bytes2[j - 1] + 1); + } + + var bytes3 = new byte[75]; + bytes3[0] = 10; + for (var j = 1; j < bytes3.Length; j++) + { + bytes3[j] = (byte)(bytes3[j - 1] + 1); + } + + var key0 = new byte[4]; + key0[0] = 1; + var key1 = new byte[4]; + key1[0] = 2; + + for (var i = 0; i < 100; i++) + { + var delRes = (int)context.clusterTestUtils.Execute(primary, "DEL", ["foo"]); + + if (i != 0) + { + ClassicAssert.AreEqual(1, delRes); + } + else + { + ClassicAssert.AreEqual(0, delRes); + } + + var addRes1 = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", "XB8", bytes1, key0, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes1); + + var addRes2 = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", "XB8", bytes2, key1, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes2); + + var readPrimaryExc = (string)context.clusterTestUtils.Execute(primary, "GET", ["foo"]); + ClassicAssert.IsTrue(readPrimaryExc.StartsWith("WRONGTYPE ")); + + var queryPrimary = (byte[][])context.clusterTestUtils.Execute(primary, "VSIM", ["foo", "XB8", bytes3]); + ClassicAssert.AreEqual(2, queryPrimary.Length); + + _ = context.clusterTestUtils.Execute(secondary, "READONLY", []); + + // The vector set has either replicated, or not + // If so - we get WRONGTYPE + // If not - we get a null + var readSecondary = (string)context.clusterTestUtils.Execute(secondary, "GET", ["foo"]); + ClassicAssert.IsTrue(readSecondary is null || 
readSecondary.StartsWith("WRONGTYPE ")); + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + var querySecondary = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", ["foo", "XB8", bytes3]); + ClassicAssert.IsTrue(querySecondary.Length >= 1); + + for (var j = 0; j < querySecondary.Length; j++) + { + var expected = + querySecondary[j].AsSpan().SequenceEqual(key0) || + querySecondary[j].AsSpan().SequenceEqual(key1); + + ClassicAssert.IsTrue(expected); + } + + Incr(key0); + Incr(key1); + } + + static void Incr(byte[] k) + { + var ix = k.Length - 1; + while (true) + { + k[ix]++; + if (k[ix] == 0) + { + ix--; + } + else + { + break; + } + } + } + } + + [Test] + public async Task MultipleReplicasWithVectorSetsAsync() + { + const int PrimaryIndex = 0; + const int SecondaryStartIndex = 1; + const int SecondaryEndIndex = 5; + const int Vectors = 2_000; + const string Key = nameof(MultipleReplicasWithVectorSetsAsync); + + context.CreateInstances(HighReplicationShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 5); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondaries = new IPEndPoint[SecondaryEndIndex - SecondaryStartIndex + 1]; + for (var i = SecondaryStartIndex; i <= SecondaryEndIndex; i++) + { + secondaries[i - SecondaryStartIndex] = (IPEndPoint)context.endpoints[i]; + } + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + + foreach (var secondary in secondaries) + { + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + } + + // Build some repeatably random data for inserts + var vectors = new byte[Vectors][]; + { + var r = new Random(2025_09_23_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + } + + using var sync = new SemaphoreSlim(2); + + var 
writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTasks = new Task[secondaries.Length]; + + for (var i = 0; i < secondaries.Length; i++) + { + var secondary = secondaries[i]; + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_23_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + + return nonZeroReturns; + } + ); + + readTasks[i] = readTask; + } + + _ = sync.Release(secondaries.Length + 1); + await writeTask; + + for (var secondaryIndex = SecondaryStartIndex; secondaryIndex <= SecondaryEndIndex; secondaryIndex++) + { + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, secondaryIndex); + } + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var searchesWithNonZeroResults = await Task.WhenAll(readTasks); + + ClassicAssert.IsTrue(searchesWithNonZeroResults.All(static x => x > 0)); + + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, 
fromPrimary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + ClassicAssert.AreEqual(expected[i], p); + } + + for (var secondaryIx = 0; secondaryIx < secondaries.Length; secondaryIx++) + { + var secondary = secondaries[secondaryIx]; + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var s = (byte)float.Parse(fromSecondary[i]); + ClassicAssert.AreEqual(expected[i], s); + } + } + } + } + } + + [Test] + public async Task MultipleReplicasWithVectorSetsAndDeletesAsync() + { + const int PrimaryIndex = 0; + const int SecondaryStartIndex = 1; + const int SecondaryEndIndex = 5; + const int Vectors = 2_000; + const int Deletes = Vectors / 10; + const string Key = nameof(MultipleReplicasWithVectorSetsAndDeletesAsync); + + context.CreateInstances(HighReplicationShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 5); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondaries = new IPEndPoint[SecondaryEndIndex - SecondaryStartIndex + 1]; + for (var i = SecondaryStartIndex; i <= SecondaryEndIndex; i++) + { + secondaries[i - SecondaryStartIndex] = (IPEndPoint)context.endpoints[i]; + } + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + + foreach (var secondary in secondaries) + { + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + } + + // Build some repeatably random data for inserts + var vectors = new byte[Vectors][]; + var toDeleteVectors = new HashSet(); + var pendingRemove = new List(); + { + var r = new Random(2025_10_20_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + + 
while (toDeleteVectors.Count < Deletes) + { + _ = toDeleteVectors.Add(r.Next(vectors.Length)); + } + + pendingRemove.AddRange(toDeleteVectors); + } + + using var sync = new SemaphoreSlim(2); + + var writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + var deleteTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + + while (pendingRemove.Count > 0) + { + var i = Random.Shared.Next(pendingRemove.Count); + var id = pendingRemove[i]; + + BinaryPrimitives.WriteInt32LittleEndian(key, id); + var remRes = (int)context.clusterTestUtils.Execute(primary, "VREM", [Key, key]); + if (remRes == 1) + { + pendingRemove.RemoveAt(i); + } + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTasks = new Task[secondaries.Length]; + + for (var i = 0; i < secondaries.Length; i++) + { + var secondary = secondaries[i]; + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_23_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + + return nonZeroReturns; + } + ); + + readTasks[i] = readTask; + } + + _ = sync.Release(secondaries.Length + 2); + await writeTask; + await deleteTask; + + for (var secondaryIndex = SecondaryStartIndex; secondaryIndex <= SecondaryEndIndex; secondaryIndex++) + { + 
context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, secondaryIndex); + } + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var searchesWithNonZeroResults = await Task.WhenAll(readTasks); + + ClassicAssert.IsTrue(searchesWithNonZeroResults.All(static x => x > 0)); + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + + var shouldBePresent = !toDeleteVectors.Contains(id); + if (shouldBePresent) + { + ClassicAssert.AreEqual(expected.Length, fromPrimary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + ClassicAssert.AreEqual(expected[i], p); + } + } + else + { + ClassicAssert.IsEmpty(fromPrimary); + } + + for (var secondaryIx = 0; secondaryIx < secondaries.Length; secondaryIx++) + { + var secondary = secondaries[secondaryIx]; + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + if (shouldBePresent) + { + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var s = (byte)float.Parse(fromSecondary[i]); + ClassicAssert.AreEqual(expected[i], s); + } + } + else + { + ClassicAssert.IsEmpty(fromSecondary); + } + } + } + } + } + + [Test] + public void VectorSetMigrateSingleBySlot() + { + // Test migrating a single slot with a vector set of one element in it + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1); + + var 
primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string primary0Key; + int primary0HashSlot; + { + var ix = 0; + + while (true) + { + primary0Key = $"{nameof(VectorSetMigrateSingleBySlot)}_{ix}"; + primary0HashSlot = context.clusterTestUtils.HashSlot(primary0Key); + + if (slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && primary0HashSlot >= x.startSlot && primary0HashSlot <= x.endSlot)) + { + break; + } + + ix++; + } + } + + // Setup simple vector set on Primary0 in some hash slot + + var vectorData = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + var vectorSimData = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + + var add0Res = (int)context.clusterTestUtils.Execute(primary0, "VADD", [primary0Key, "XB8", vectorData, new byte[] { 0, 0, 0, 0 }, "XPREQ8", "SETATTR", "{\"hello\": \"world\"}"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add0Res); + + var sim0Res = (byte[][])context.clusterTestUtils.Execute(primary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim0Res.Length); + ClassicAssert.IsTrue(new byte[] { 0, 0, 0, 0 }.SequenceEqual(sim0Res[0])); + 
ClassicAssert.IsFalse(float.IsNaN(float.Parse(Encoding.ASCII.GetString(sim0Res[1])))); + ClassicAssert.IsTrue("{\"hello\": \"world\"}"u8.SequenceEqual(sim0Res[2])); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + + var readonlyOnReplica0 = (string)context.clusterTestUtils.Execute(secondary0, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica0); + + var simOnReplica0 = (byte[][])context.clusterTestUtils.Execute(secondary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(simOnReplica0.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(simOnReplica0[i])); + } + + // Move to other primary + + context.clusterTestUtils.MigrateSlots(primary0, primary1, [primary0HashSlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + ClassicAssert.IsFalse(curPrimary0Slots.Contains(primary0HashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(primary0HashSlot)); + + // Check available on other primary & secondary + + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(sim1Res.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(sim1Res[i])); + } + + var readonlyOnReplica1 = 
(string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + var simOnReplica1 = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(simOnReplica1.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(simOnReplica0[i])); + } + + // Check no longer available on old primary or secondary + var exc0 = (string)context.clusterTestUtils.Execute(primary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc0.StartsWith("Key has MOVED to ")); + + var start = Stopwatch.GetTimestamp(); + + var success = false; + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + try + { + var exc1 = (string)context.clusterTestUtils.Execute(secondary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc1.StartsWith("Key has MOVED to ")); + success = true; + break; + } + catch + { + // Secondary can still have the key for a bit + Thread.Sleep(100); + } + } + + ClassicAssert.IsTrue(success, "Original replica still has Vector Set long after primary has completed"); + } + + [Test] + public void VectorSetMigrateByKeys() + { + // Based on : ClusterSimpleMigrateKeys test + + const int ShardCount = 3; + const int KeyCount = 10; + + context.CreateInstances(ShardCount, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(); + + var otherNodeIndex = 0; + var sourceNodeIndex = 1; + var targetNodeIndex = 2; + var sourceNodeId = context.clusterTestUtils.GetNodeIdFromNode(sourceNodeIndex, NullLogger.Instance); + var targetNodeId = 
context.clusterTestUtils.GetNodeIdFromNode(targetNodeIndex, NullLogger.Instance); + + var key = Encoding.ASCII.GetBytes("{abc}a"); + List keys = []; + List<(byte[] Key, byte[] Data)> vectors = []; + List attributes = []; + + var _workingSlot = ClusterTestUtils.HashSlot(key); + ClassicAssert.AreEqual(7638, _workingSlot); + + Random rand = new(2025_11_04_00); + + for (var i = 0; i < KeyCount; i++) + { + var newKey = new byte[key.Length]; + Array.Copy(key, 0, newKey, 0, key.Length); + newKey[^1] = (byte)(newKey[^1] + i); + ClassicAssert.AreEqual(_workingSlot, ClusterTestUtils.HashSlot(newKey)); + + var elem = new byte[4]; + rand.NextBytes(elem); + + var data = new byte[75]; + rand.NextBytes(data); + + var attrs = new byte[16]; + rand.NextBytes(attrs); + + var addRes = (int)context.clusterTestUtils.Execute(context.clusterTestUtils.GetEndPoint(sourceNodeIndex), "VADD", [newKey, "XB8", data, elem, "XPREQ8", "SETATTR", attrs]); + ClassicAssert.AreEqual(1, addRes); + + keys.Add(newKey); + vectors.Add((elem, data)); + attributes.Add(attrs); + } + + // Start migration + var respImport = context.clusterTestUtils.SetSlot(targetNodeIndex, _workingSlot, "IMPORTING", sourceNodeId); + ClassicAssert.AreEqual(respImport, "OK"); + + var respMigrate = context.clusterTestUtils.SetSlot(sourceNodeIndex, _workingSlot, "MIGRATING", targetNodeId); + ClassicAssert.AreEqual(respMigrate, "OK"); + + // Check key count + var countKeys = context.clusterTestUtils.CountKeysInSlot(sourceNodeIndex, _workingSlot); + ClassicAssert.AreEqual(countKeys, KeyCount); + + // Enumerate keys in slots + var keysInSlot = context.clusterTestUtils.GetKeysInSlot(sourceNodeIndex, _workingSlot, countKeys); + ClassicAssert.AreEqual(keys, keysInSlot); + + // Migrate keys, but in a random-ish order so context reservation gets stressed + var toMigrate = keysInSlot.ToList(); + while (toMigrate.Count > 0) + { + var migrateSingleIx = rand.Next(toMigrate.Count); + var migrateKey = toMigrate[migrateSingleIx]; + 
context.clusterTestUtils.MigrateKeys(context.clusterTestUtils.GetEndPoint(sourceNodeIndex), context.clusterTestUtils.GetEndPoint(targetNodeIndex), [migrateKey], NullLogger.Instance); + + toMigrate.RemoveAt(migrateSingleIx); + } + + // Finish migration + var respNodeTarget = context.clusterTestUtils.SetSlot(targetNodeIndex, _workingSlot, "NODE", targetNodeId); + ClassicAssert.AreEqual(respNodeTarget, "OK"); + context.clusterTestUtils.BumpEpoch(targetNodeIndex, waitForSync: true); + + var respNodeSource = context.clusterTestUtils.SetSlot(sourceNodeIndex, _workingSlot, "NODE", targetNodeId); + ClassicAssert.AreEqual(respNodeSource, "OK"); + context.clusterTestUtils.BumpEpoch(sourceNodeIndex, waitForSync: true); + // End Migration + + // Check config + var targetConfigEpochFromTarget = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(targetNodeIndex, targetNodeId, NullLogger.Instance); + var targetConfigEpochFromSource = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(sourceNodeIndex, targetNodeId, NullLogger.Instance); + var targetConfigEpochFromOther = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(otherNodeIndex, targetNodeId, NullLogger.Instance); + + while (targetConfigEpochFromOther != targetConfigEpochFromTarget || targetConfigEpochFromSource != targetConfigEpochFromTarget) + { + _ = Thread.Yield(); + targetConfigEpochFromTarget = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(targetNodeIndex, targetNodeId, NullLogger.Instance); + targetConfigEpochFromSource = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(sourceNodeIndex, targetNodeId, NullLogger.Instance); + targetConfigEpochFromOther = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(otherNodeIndex, targetNodeId, NullLogger.Instance); + } + ClassicAssert.AreEqual(targetConfigEpochFromTarget, targetConfigEpochFromOther); + ClassicAssert.AreEqual(targetConfigEpochFromTarget, targetConfigEpochFromSource); + + // Check migration in progress + 
foreach (var _key in keys) + { + var resp = context.clusterTestUtils.GetKey(otherNodeIndex, _key, out var slot, out var endpoint, out var responseState); + while (endpoint.Port != context.clusterTestUtils.GetEndPoint(targetNodeIndex).Port && responseState != ResponseState.OK) + { + resp = context.clusterTestUtils.GetKey(otherNodeIndex, _key, out slot, out endpoint, out responseState); + } + ClassicAssert.AreEqual(resp, "MOVED"); + ClassicAssert.AreEqual(_workingSlot, slot); + ClassicAssert.AreEqual(context.clusterTestUtils.GetEndPoint(targetNodeIndex), endpoint); + } + + // Finish migration + context.clusterTestUtils.WaitForMigrationCleanup(NullLogger.Instance); + + // Validate vector sets coherent + for (var i = 0; i < keys.Count; i++) + { + var _key = keys[i]; + var (elem, data) = vectors[i]; + var attrs = attributes[i]; + + var res = (byte[][])context.clusterTestUtils.Execute(context.clusterTestUtils.GetEndPoint(targetNodeIndex), "VSIM", [_key, "XB8", data, "WITHATTRIBS"]); + ClassicAssert.AreEqual(2, res.Length); + ClassicAssert.IsTrue(res[0].SequenceEqual(elem)); + ClassicAssert.IsTrue(res[1].SequenceEqual(attrs)); + } + } + + [Test] + public void VectorSetMigrateManyBySlot() + { + // Test migrating several vector sets from one primary to another primary, which already has vectors sets of its own + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + const int VectorSetsPerPrimary = 8; + + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; 
+ + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + List<(string Key, ushort HashSlot, byte[] Element, byte[] Data, byte[] Attr)> primary0Keys = []; + List<(string Key, ushort HashSlot, byte[] Element, byte[] Data, byte[] Attr)> primary1Keys = []; + + { + var ix = 0; + + while (primary0Keys.Count < VectorSetsPerPrimary || primary1Keys.Count < VectorSetsPerPrimary) + { + var key = $"{nameof(VectorSetMigrateManyBySlot)}_{ix}"; + var hashSlot = context.clusterTestUtils.HashSlot(key); + + var isOnPrimary0 = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && hashSlot >= x.startSlot && hashSlot <= x.endSlot); + var isOnPrimary1 = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary1Id) && hashSlot >= x.startSlot && hashSlot <= x.endSlot); + + if (isOnPrimary0 && primary0Keys.Count < VectorSetsPerPrimary) + { + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[10]; + Random.Shared.NextBytes(elem); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + primary0Keys.Add((key, (ushort)hashSlot, elem, data, attr)); + } + + if (isOnPrimary1 && primary1Keys.Count < VectorSetsPerPrimary) + { + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[10]; + Random.Shared.NextBytes(elem); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + primary1Keys.Add((key, (ushort)hashSlot, elem, data, attr)); + } + + ix++; + } + } + + // Setup vectors on the primaries + foreach (var (key, _, elem, data, 
attr) in primary0Keys) + { + var add0Res = (int)context.clusterTestUtils.Execute(primary0, "VADD", [key, "XB8", data, elem, "XPREQ8", "SETATTR", attr], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add0Res); + } + + foreach (var (key, _, elem, data, attr) in primary1Keys) + { + var add1Res = (int)context.clusterTestUtils.Execute(primary1, "VADD", [key, "XB8", data, elem, "XPREQ8", "SETATTR", attr], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add1Res); + } + + // Query expected results + Dictionary<(string Key, byte[] Data), (byte[] Elem, byte[] Attr, float Score)> expected = new(StringAndByteArrayComparer.Instance); + + foreach (var (key, _, _, data, _) in primary0Keys) + { + var sim0Res = (byte[][])context.clusterTestUtils.Execute(primary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim0Res.Length); + expected.Add((key, data), (sim0Res[0], sim0Res[2], float.Parse(Encoding.ASCII.GetString(sim0Res[1])))); + } + + foreach (var (key, _, _, data, _) in primary1Keys) + { + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim1Res.Length); + expected.Add((key, data), (sim1Res[0], sim1Res[2], float.Parse(Encoding.ASCII.GetString(sim1Res[1])))); + } + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + + // Move from primary0 to primary1 + var migratedHashSlots = primary0Keys.Select(static t => t.HashSlot).Distinct().Select(static s => (int)s).ToList(); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, migratedHashSlots); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + 
context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + foreach (var hashSlot in migratedHashSlots) + { + ClassicAssert.IsFalse(curPrimary0Slots.Contains(hashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(hashSlot)); + } + + // Check available on other primary + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var migrateSimRes = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, migrateSimRes.Length); + + var (elem, attr, score) = expected[(key, data)]; + + ClassicAssert.IsTrue(elem.SequenceEqual(migrateSimRes[0])); + ClassicAssert.AreEqual(score, float.Parse(Encoding.ASCII.GetString(migrateSimRes[1]))); + ClassicAssert.IsTrue(attr.SequenceEqual(migrateSimRes[2])); + } + + // Check no longer available on old primary or secondary + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var exc0 = (string)context.clusterTestUtils.Execute(primary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc0.StartsWith("Key has MOVED to ")); + } + + var start = Stopwatch.GetTimestamp(); + + var success = false; + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + try + { + var migrationNotFinished = false; + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var exc1 = (string)context.clusterTestUtils.Execute(secondary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + if (!exc1.StartsWith("Key has MOVED to ")) + { + migrationNotFinished = true; + break; + } + } + + if (migrationNotFinished) + { + continue; + 
} + + success = true; + break; + } + catch + { + // Secondary can still have the key for a bit + Thread.Sleep(100); + } + } + + ClassicAssert.IsTrue(success, "Original replica still has Vector Set long after primary has completed"); + + // Check available on new secondary + var readonlyOnReplica1 = (string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + start = Stopwatch.GetTimestamp(); + + success = false; + + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + success = true; + + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var migrateSimRes = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + + if (migrateSimRes.Length == 1 && Encoding.UTF8.GetString(migrateSimRes[1]).StartsWith("Key has MOVED to ")) + { + success = false; + break; + } + + ClassicAssert.AreEqual(3, migrateSimRes.Length); + + var (elem, attr, score) = expected[(key, data)]; + + ClassicAssert.IsTrue(elem.SequenceEqual(migrateSimRes[0])); + ClassicAssert.AreEqual(score, float.Parse(Encoding.ASCII.GetString(migrateSimRes[1]))); + ClassicAssert.IsTrue(attr.SequenceEqual(migrateSimRes[2])); + } + + if (success) + { + break; + } + } + + ClassicAssert.IsTrue(success, "New replica hasn't replicated Vector Set long after primary has received data"); + } + + [Test] + public async Task MigrateVectorSetWhileModifyingAsync() + { + // Test migrating a single slot with a vector set while moving it + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true, OnDemandCheckpoint: true, EnableIncrementalSnapshots: true); + context.CreateConnection(useTLS: true); + _ = 
context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string primary0Key; + int primary0HashSlot; + { + var ix = 0; + + while (true) + { + primary0Key = $"{nameof(MigrateVectorSetWhileModifyingAsync)}_{ix}"; + primary0HashSlot = context.clusterTestUtils.HashSlot(primary0Key); + + if (slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && primary0HashSlot >= x.startSlot && primary0HashSlot <= x.endSlot)) + { + break; + } + + ix++; + } + } + + // Start writing to this Vector Set + using var cts = new CancellationTokenSource(); + + var added = new ConcurrentBag<(byte[] Elem, byte[] Data, byte[] Attr)>(); + + var writeTask = + Task.Run( + async () => + { + // Force async + await Task.Yield(); + + using var readWriteCon = ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDb = readWriteCon.GetDatabase(); + + var ix = 0; + + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[100]; + + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + while 
(!cts.IsCancellationRequested) + { + if (TestUtils.IsRunningAsGitHubAction) + { + // Throw some delay in when running as a GitHub Action to work around the weak drives those VMs have + await Task.Delay(1); + } + + // This should follow redirects, so migration shouldn't cause any failures + try + { + var addRes = (int)readWriteDb.Execute("VADD", [new RedisKey(primary0Key), "XB8", data, elem, "XPREQ8", "SETATTR", attr]); + ClassicAssert.AreEqual(1, addRes); + } + catch (RedisServerException exc) + { + if (exc.Message.StartsWith("MOVED ")) + { + continue; + } + + throw; + } + + added.Add((elem.ToArray(), data.ToArray(), attr.ToArray())); + + ix++; + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + } + } + ); + + await Task.Delay(1_000); + + var lenPreMigration = added.Count; + ClassicAssert.IsTrue(lenPreMigration > 0, "Should have seen some writes pre-migration"); + + // Move to other primary + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, [primary0HashSlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + using (var replicationToken = new CancellationTokenSource()) + { + replicationToken.CancelAfter(30_000); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index, cancellation: replicationToken.Token); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index, cancellation: replicationToken.Token); + } + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + 
ClassicAssert.IsFalse(curPrimary0Slots.Contains(primary0HashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(primary0HashSlot)); + + var lenPrePause = added.Count; + await Task.Delay(5_000); + var lenPostPause = added.Count; + + ClassicAssert.IsTrue(lenPostPause > lenPrePause, "Writes after migration did not resume"); + + // Stop Writes and wait for replication to catch up + cts.Cancel(); + await writeTask; + + var addedLookup = added.ToFrozenDictionary(static t => t.Elem, t => t, ByteArrayComparer.Instance); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + // Check available on other primary & secondary + + foreach (var (_, data, _) in added) + { + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [primary0Key, "XB8", data, "WITHSCORES", "WITHATTRIBS", "COUNT", "1"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim1Res.Length); + + // No guarantee we'll get the exact same element, but we should always get _a_ result and the correct associated attribute + var resElem = sim1Res[0]; + var resAttr = sim1Res[2]; + var expectedAttr = addedLookup[resElem].Attr; + ClassicAssert.IsTrue(resAttr.SequenceEqual(expectedAttr)); + } + + var readonlyOnReplica1 = (string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + foreach (var (elem, data, attr) in added) + { + var simOnReplica1Res = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [primary0Key, "XB8", data, "WITHSCORES", "WITHATTRIBS", "COUNT", "1"], flags: CommandFlags.NoRedirect); + + // No guarantee we'll get the exact same element, but we should always get _a_ result and the correct associated attribute + var resElem = simOnReplica1Res[0]; + var resAttr = simOnReplica1Res[2]; + var expectedAttr = addedLookup[resElem].Attr; + 
ClassicAssert.IsTrue(resAttr.SequenceEqual(expectedAttr)); + } + } + + [Test] + public void MigrateVectorSetBack() + { + const int Primary0Index = 0; + const int Primary1Index = 1; + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultShards, replica_count: 0); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string vectorSetKey; + int vectorSetKeySlot; + { + var ix = 0; + + while (true) + { + vectorSetKey = $"{nameof(MigrateVectorSetBack)}_{ix}"; + vectorSetKeySlot = context.clusterTestUtils.HashSlot(vectorSetKey); + + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && vectorSetKeySlot >= x.startSlot && vectorSetKeySlot <= x.endSlot); + if (isPrimary0Slot) + { + break; + } + + ix++; + } + } + + using var readWriteCon = ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDB = readWriteCon.GetDatabase(); + + var data0 = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + byte[] elem0 = [1, 2, 3, 0]; + var attr0 = "hello world"u8.ToArray(); + + var add0Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data0, elem0, "XPREQ8", "SETATTR", attr0]); + ClassicAssert.AreEqual(1, add0Res); + + // Migrate 0 -> 1 + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, 
primary1, [vectorSetKeySlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + var movedOffPrimary0 = !curPrimary0Slots.Contains(vectorSetKeySlot); + var movedOntoPrimary1 = curPrimary1Slots.Contains(vectorSetKeySlot); + + if (movedOffPrimary0 && movedOntoPrimary1) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 0 -> 1 migration took too long"); + } + + // Confirm still valid to add, with client side routing + var data1 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + byte[] elem1 = [4, 5, 6, 7]; + var attr1 = "fizz buzz"u8.ToArray(); + + var add1Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data1, elem1, "XPREQ8", "SETATTR", attr1]); + ClassicAssert.AreEqual(1, add1Res); + + // Migrate 1 -> 0 + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary1, primary0, [vectorSetKeySlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + 
var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + var movedOntoPrimary0 = curPrimary0Slots.Contains(vectorSetKeySlot); + var movedOffPrimary1 = !curPrimary1Slots.Contains(vectorSetKeySlot); + + if (movedOntoPrimary0 && movedOffPrimary1) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 1 -> 0 migration took too long"); + } + + // Confirm still valid to add, with client side routing + var data2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 3)).ToArray(); + byte[] elem2 = [8, 9, 10, 11]; + var attr2 = "foo bar"u8.ToArray(); + + var add2Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data2, elem2, "XPREQ8", "SETATTR", attr2]); + ClassicAssert.AreEqual(1, add2Res); + + // Confirm no data loss + var emb0 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem0])).Select(static x => (byte)float.Parse(x)).ToArray(); + var emb1 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem1])).Select(static x => (byte)float.Parse(x)).ToArray(); + var emb2 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem2])).Select(static x => (byte)float.Parse(x)).ToArray(); + ClassicAssert.IsTrue(data0.SequenceEqual(emb0)); + ClassicAssert.IsTrue(data1.SequenceEqual(emb1)); + ClassicAssert.IsTrue(data2.SequenceEqual(emb2)); + } + + [Test] + public async Task MigrateVectorStressAsync() + { + // Move vector sets back and forth between replicas, making sure we don't drop data + // Keeps reads and writes going continuously + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + const int VectorSetsPerPrimary = 2; + + var gossipFaultsAtTestStart = 0; + + captureLogWriter.capture = true; + + try + { + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + 
context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + var vectorSetKeys = new List<(string Key, ushort HashSlot)>(); + + { + var ix = 0; + + var numP0 = 0; + var numP1 = 0; + + while (numP0 < VectorSetsPerPrimary || numP1 < VectorSetsPerPrimary) + { + var key = $"{nameof(MigrateVectorStressAsync)}_{ix}"; + var slot = context.clusterTestUtils.HashSlot(key); + + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && slot >= x.startSlot && slot <= x.endSlot); + + if (isPrimary0Slot) + { + if (numP0 < VectorSetsPerPrimary) + { + vectorSetKeys.Add((key, (ushort)slot)); + numP0++; + } + } + else + { + if (numP1 < VectorSetsPerPrimary) + { + vectorSetKeys.Add((key, (ushort)slot)); + numP1++; + } + } + + ix++; + } + } + + // Remember how cluster looked right after it was stable + gossipFaultsAtTestStart = CountGossipFaults(captureLogWriter); + + // Start writing to this Vector Set + using var writeCancel = new CancellationTokenSource(); + + using var readWriteCon = 
ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDB = readWriteCon.GetDatabase(); + + var writeTasks = new Task[vectorSetKeys.Count]; + var writeResults = new ConcurrentBag<(byte[] Elem, byte[] Data, byte[] Attr, DateTime InsertionTime)>[vectorSetKeys.Count]; + + var mostRecentWrite = 0L; + + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, _) = vectorSetKeys[i]; + var written = writeResults[i] = new(); + + writeTasks[i] = + Task.Run( + async () => + { + // Force async + await Task.Yield(); + + var ix = 0; + + while (!writeCancel.IsCancellationRequested) + { + var elem = new byte[4]; + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + + var data = new byte[75]; + Random.Shared.NextBytes(data); + + var attr = new byte[100]; + Random.Shared.NextBytes(attr); + + while (true) + { + try + { + var addRes = (int)readWriteDB.Execute("VADD", [new RedisKey(key), "XB8", data, elem, "XPREQ8", "SETATTR", attr]); + ClassicAssert.AreEqual(1, addRes); + break; + } + catch (RedisServerException exc) + { + if (exc.Message.StartsWith("MOVED ")) + { + // This is fine, just try again if we're not cancelled + if (writeCancel.IsCancellationRequested) + { + return; + } + + continue; + } + + throw; + } + } + + var now = DateTime.UtcNow; + written.Add((elem, data, attr, now)); + + var mostRecentCopy = mostRecentWrite; + while (mostRecentCopy < now.Ticks) + { + var currentMostRecent = Interlocked.CompareExchange(ref mostRecentWrite, now.Ticks, mostRecentCopy); + if (currentMostRecent == mostRecentCopy) + { + break; + } + mostRecentCopy = currentMostRecent; + } + + ix++; + } + } + ); + } + + using var readCancel = new CancellationTokenSource(); + + var readTasks = new Task[vectorSetKeys.Count]; + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, _) = vectorSetKeys[i]; + var written = writeResults[i]; + readTasks[i] = + Task.Run( + async () => + { + await Task.Yield(); + + var successfulReads = 0; + + while 
(!readCancel.IsCancellationRequested) + { + var r = written.Count; + if (r == 0) + { + await Task.Delay(10); + continue; + } + + var (elem, data, _, _) = written.ToList()[Random.Shared.Next(r)]; + + var emb = (string[])readWriteDB.Execute("VEMB", [new RedisKey(key), elem]); + + // If we got data, make sure it's coherent + ClassicAssert.AreEqual(data.Length, emb.Length); + + for (var i = 0; i < data.Length; i++) + { + ClassicAssert.AreEqual(data[i], (byte)float.Parse(emb[i])); + } + + successfulReads++; + } + + return successfulReads; + } + ); + } + + await Task.Delay(1_000); + + ClassicAssert.IsTrue(writeResults.All(static r => !r.IsEmpty), "Should have seen some writes pre-migration"); + + // Task to flip back and forth between primaries + using var migrateCancel = new CancellationTokenSource(); + + var migrateTask = + Task.Run( + async () => + { + var hashSlotsOnP0 = new List(); + var hashSlotsOnP1 = new List(); + foreach (var (_, slot) in vectorSetKeys) + { + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && slot >= x.startSlot && slot <= x.endSlot); + if (isPrimary0Slot) + { + if (!hashSlotsOnP0.Contains(slot)) + { + hashSlotsOnP0.Add(slot); + } + } + else + { + if (!hashSlotsOnP1.Contains(slot)) + { + hashSlotsOnP1.Add(slot); + } + } + } + + var migrationTimes = new List(); + + var mostRecentMigration = 0L; + + while (!migrateCancel.IsCancellationRequested) + { + await Task.Delay(100); + + // Don't start another migration until we get at least one successful write + if (Interlocked.CompareExchange(ref mostRecentWrite, 0, 0) < mostRecentMigration) + { + continue; + } + + // Move 0 -> 1 + if (hashSlotsOnP0.Count > 0) + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, hashSlotsOnP0); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + 
context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + var movedOffPrimary0 = !curPrimary0Slots.Any(h => hashSlotsOnP0.Contains(h)); + var movedOntoPrimary1 = hashSlotsOnP0.All(h => curPrimary1Slots.Contains(h)); + + if (movedOffPrimary0 && movedOntoPrimary1) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 0 -> 1 migration took too long"); + } + + // Move 1 -> 0 + if (hashSlotsOnP1.Count > 0) + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary1, primary0, hashSlotsOnP1); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + var movedOffPrimary1 = !curPrimary1Slots.Any(h => hashSlotsOnP1.Contains(h)); + var movedOntoPrimary0 = hashSlotsOnP1.All(h => curPrimary0Slots.Contains(h)); + + if (movedOffPrimary1 && movedOntoPrimary0) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 1 -> 0 migration took too long"); + } + + // Remember 
for next iteration + var now = DateTime.UtcNow; + mostRecentMigration = now.Ticks; + migrationTimes.Add(now); + + // Flip around assignment for next pass + (hashSlotsOnP0, hashSlotsOnP1) = (hashSlotsOnP1, hashSlotsOnP0); + } + + return migrationTimes; + } + ); + + await Task.Delay(10_000); + + migrateCancel.Cancel(); + var migrationTimes = await migrateTask; + + ClassicAssert.IsTrue(migrationTimes.Count > 2, "Should have moved back and forth at least twice"); + + writeCancel.Cancel(); + await Task.WhenAll(writeTasks); + + readCancel.Cancel(); + var readResults = await Task.WhenAll(readTasks); + ClassicAssert.IsTrue(readResults.All(static r => r > 0), "Should have successful reads on all Vector Sets"); + + // Check that everything written survived all the migrations + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, NullLogger.Instance); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, NullLogger.Instance); + + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, slot) = vectorSetKeys[i]; + + var isOnPrimary0 = curPrimary0Slots.Contains(slot); + var isOnPrimary1 = curPrimary1Slots.Contains(slot); + + ClassicAssert.IsTrue(isOnPrimary0 || isOnPrimary1, "Hash slot not found on either node"); + ClassicAssert.IsFalse(isOnPrimary0 && isOnPrimary1, "Hash slot found on both nodes"); + + var endpoint = isOnPrimary0 ? 
primary0 : primary1; + + foreach (var (elem, data, attr, _) in writeResults[i]) + { + var actualData = (string[])context.clusterTestUtils.Execute(endpoint, "VEMB", [key, elem]); + + for (var j = 0; j < data.Length; j++) + { + ClassicAssert.AreEqual(data[j], (byte)float.Parse(actualData[j])); + } + } + } + } + + } + catch (Exception exc) + { + var gossipFaultsAtEnd = CountGossipFaults(captureLogWriter); + + if (gossipFaultsAtTestStart != gossipFaultsAtEnd) + { + // The cluster broke in some way, so data loss is _expected_ + ClassicAssert.Inconclusive($"Gossip fault lead to data loss, Vector Set migration is (probably) not to blame: {exc.Message}"); + } + + // Anything else, keep it going up + throw; + } + + static int CountGossipFaults(CaptureLogWriter captureLogWriter) + { + var capturedLog = captureLogWriter.buffer.ToString(); + + // These kinds of errors happen from stressing migration independent of Vector Sets + // + // TODO: These out to be fixed outside of Vector Set work + var faultRound = capturedLog.Split("^GOSSIP round faulted^").Length - 1; + var faultResponse = capturedLog.Split("^GOSSIP faulted processing response^").Length - 1; + var faultMergeMap = capturedLog.Split("ClusterConfig.MergeSlotMap(").Length - 1; + + return faultRound + faultResponse + faultMergeMap; + } + } + + [Test] + public async Task FailoverStopsVectorManagerReplicationTasksAsync() + { + const int PrimaryIndex = 0; + const int ReplicaIndex = 1; + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultShards / 2, replica_count: DefaultShards / 2); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var replica = (IPEndPoint)context.endpoints[ReplicaIndex]; + + var primaryVectorManager = GetStoreWrapper(context.nodes[PrimaryIndex]).DefaultDatabase.VectorManager; + var replicaVectorManager = 
GetStoreWrapper(context.nodes[ReplicaIndex]).DefaultDatabase.VectorManager; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(replica).Value); + + var vectorData0 = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + + var vadd0Res = (int)context.clusterTestUtils.Execute(primary, "VADD", [new RedisKey("foo"), "XB8", vectorData0, new byte[] { 1, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, vadd0Res); + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, ReplicaIndex); + await Task.Delay(10); + + ClassicAssert.IsFalse(primaryVectorManager.AreReplicationTasksActive); + ClassicAssert.IsTrue(replicaVectorManager.AreReplicationTasksActive); + + context.ClusterFailoverSpinWait(ReplicaIndex, NullLogger.Instance); + + context.clusterTestUtils.WaitForReplicaAofSync(ReplicaIndex, PrimaryIndex); + await Task.Delay(10); + + var vectorData1 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + + var vadd1Res = (int)context.clusterTestUtils.Execute(replica, "VADD", [new RedisKey("foo"), "XB8", vectorData1, new byte[] { 2, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, vadd1Res); + + ClassicAssert.IsTrue(primaryVectorManager.AreReplicationTasksActive); + ClassicAssert.IsFalse(replicaVectorManager.AreReplicationTasksActive); + + var vsimRes = (byte[][])context.clusterTestUtils.Execute(replica, "VSIM", [new RedisKey("foo"), "XB8", vectorData0]); + ClassicAssert.IsTrue(vsimRes.Length > 0); + } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "storeWrapper")] + private static extern ref StoreWrapper GetStoreWrapper(GarnetServer server); + } +} \ No newline at end of file diff --git a/test/Garnet.test/CountingEventSlimTests.cs b/test/Garnet.test/CountingEventSlimTests.cs new file mode 100644 index 00000000000..0e86600b181 --- /dev/null +++ b/test/Garnet.test/CountingEventSlimTests.cs @@ -0,0 +1,100 @@ +// Copyright (c) 
Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Concurrent; +using System.Threading.Channels; +using System.Threading.Tasks; +using Allure.NUnit; +using Garnet.common; +using NUnit.Framework; +using NUnit.Framework.Legacy; + +namespace Garnet.test +{ + [AllureNUnit] + [TestFixture] + public class CountingEventSlimTests : AllureTestBase + { + [Test] + public void Basic() + { + using var evt = CountingEventSlim.Create(); + + // Starts signalled + ClassicAssert.IsTrue(evt.Wait()); + + // Signalled can be wait'd multiple times + ClassicAssert.IsTrue(evt.Wait()); + + // Once incremented blocks + evt.Increment(); + ClassicAssert.IsFalse(evt.Wait(0)); + + // Decrementing unblocks + evt.Decrement(); + + ClassicAssert.IsTrue(evt.Wait()); + } + + [Test] + public void Concurrent() + { + // Spawn a number of tasks and fill them with work from a channel + // + // Periodically waits for CountingEventSlim to be 0, and confirms that all work is complete when the signal is received + + const int Iters = 1_000; + const int ItemsPerIter = 1_000; + + for (var iter = 0; iter < Iters; iter++) + { + using var evt = CountingEventSlim.Create(); + + var channel = Channel.CreateUnbounded(new() { SingleReader = false, SingleWriter = true, AllowSynchronousContinuations = false }); + + var working = new ConcurrentDictionary(); + + var tasks = new Task[Math.Max(Environment.ProcessorCount, 4)]; + + for (var i = 0; i < tasks.Length; i++) + { + tasks[i] = + Task.Run( + async () => + { + await foreach (var item in channel.Reader.ReadAllAsync()) + { + ClassicAssert.IsTrue(working.TryRemove(item, out _)); + evt.Decrement(); + } + } + ); + } + + var ix = 0; + while (ix < ItemsPerIter) + { + var toAdd = Random.Shared.Next(ItemsPerIter - ix) + 1; + + for (var i = 0; i < toAdd; i++) + { + ClassicAssert.IsTrue(working.TryAdd(ix, true)); + + evt.Increment(); + ClassicAssert.IsTrue(channel.Writer.TryWrite(ix)); + + ix++; + } + + _ = evt.Wait(); + 
ClassicAssert.IsTrue(working.IsEmpty); + } + + channel.Writer.Complete(); + + Task.WaitAll(tasks); + } + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/DiskANN/DiskANNGridTests.cs b/test/Garnet.test/DiskANN/DiskANNGridTests.cs new file mode 100644 index 00000000000..6ec1b9c3474 --- /dev/null +++ b/test/Garnet.test/DiskANN/DiskANNGridTests.cs @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using Allure.NUnit; +using Garnet.server; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; + +namespace Garnet.test.DiskANN +{ + [AllureNUnit] + [TestFixture] + public class DiskANNGridTests : AllureTestBase + { + private GarnetServer server; + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir); + server.Start(); + } + + [TearDown] + public void TearDown() + { + server.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + [TestCase(100, 1, VectorQuantType.XPreQ8)] + [TestCase(10, 2, VectorQuantType.XPreQ8)] + [TestCase(3, 7, VectorQuantType.XPreQ8)] + [TestCase(4, 5, VectorQuantType.XPreQ8)] + [TestCase(100, 1, VectorQuantType.NoQuant)] + [TestCase(10, 2, VectorQuantType.NoQuant)] + [TestCase(3, 7, VectorQuantType.NoQuant)] + [TestCase(4, 5, VectorQuantType.NoQuant)] + public void SearchVectorsInGrid(int gridSize, int dimension, VectorQuantType quantType) + { + string quantTypeStr = quantType switch + { + VectorQuantType.NoQuant => "NOQUANT", + VectorQuantType.Bin => "BIN", + VectorQuantType.Q8 => "Q8", + VectorQuantType.XPreQ8 => "XPREQ8", + _ => throw new ArgumentException("Invalid quant type") + }; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var key = 
$"gridset_{gridSize}_{dimension}_{quantTypeStr}"; + var addedGridVectors = AddGridVectors(db, key, dimension, gridSize, quantTypeStr); + + // VSIM using existing elements + var vsimArgs = new object[5]; + vsimArgs[0] = key; + vsimArgs[1] = "ELE"; + vsimArgs[3] = "COUNT"; + vsimArgs[4] = "25"; + foreach (var addedGridVector in addedGridVectors.Values) + { + vsimArgs[2] = addedGridVector.IdBytes; + var res = db.Execute("VSIM", vsimArgs); + ClassicAssert.AreEqual(25, res.Length); + var vsimIds = GetVectorIdsForVsimResults((RedisResult[])res); + var expectedNN = BruteForceNearestNeighbors(addedGridVectors, addedGridVector.Vector, 25); + var intersectionCount = CalculateDistanceCountsIntersection(addedGridVectors, addedGridVector.Vector, expectedNN, vsimIds); + var recall = (double)intersectionCount / expectedNN.Count; + ClassicAssert.GreaterOrEqual(recall, 0.99, $"Recall too low: {recall} for vector ID {addedGridVector.Id}"); + } + + // VSIM using values + var vsimValuesArgs = new object[5 + dimension]; + vsimValuesArgs[0] = key; + vsimValuesArgs[1] = "VALUES"; + vsimValuesArgs[2] = dimension.ToString(); + var valuesPos = 3; + vsimValuesArgs[3 + dimension] = "COUNT"; + vsimValuesArgs[4 + dimension] = "25"; + foreach (var addedGridVector in addedGridVectors.Values) + { + for (var i = 0; i < dimension; i++) + { + vsimValuesArgs[valuesPos + i] = addedGridVector.VectorStringValues[i]; + } + + var res = db.Execute("VSIM", vsimValuesArgs); + ClassicAssert.AreEqual(25, res.Length); + var vsimIds = GetVectorIdsForVsimResults((RedisResult[])res); + var expectedNN = BruteForceNearestNeighbors(addedGridVectors, addedGridVector.Vector, 25); + var intersectionCount = CalculateDistanceCountsIntersection(addedGridVectors, addedGridVector.Vector, expectedNN, vsimIds); + var recall = (double)intersectionCount / expectedNN.Count; + ClassicAssert.GreaterOrEqual(recall, 0.99, $"Recall too low: {recall} for vector ID {addedGridVector.Id}"); + } + } + + private static List 
GenerateGridVectors(int dimensions, int gridSize) + { + List vectors = []; + var totalVectors = (int)Math.Pow(gridSize, dimensions); + for (var i = 0; i < totalVectors; i++) + { + var vector = new int[dimensions]; + var pos = i; + for (var d = 0; d < dimensions; d++) + { + vector[d] = pos % gridSize; + pos /= gridSize; + } + + var vectorId = i + 1; + var idBytes = new byte[4]; + idBytes[0] = (byte)(vectorId & 0xFF); + idBytes[1] = (byte)((vectorId >> 8) & 0xFF); + idBytes[2] = (byte)((vectorId >> 16) & 0xFF); + idBytes[3] = (byte)((vectorId >> 24) & 0xFF); + + vectors.Add(new GridVector + { + Id = vectorId, + Vector = vector, + IdBytes = idBytes, + VectorStringValues = vector.Select(v => v.ToString()).ToArray() + }); + } + + return vectors; + } + + private static Dictionary AddGridVectors(IDatabase db, string key, int dimension, int gridSize, string quantType) + { + var gridVectors = GenerateGridVectors(dimension, gridSize); + List baseArgs = []; + + baseArgs.Add(key); + baseArgs.Add("VALUES"); + baseArgs.Add(dimension.ToString()); + var dimensionsPos = baseArgs.Count; + for (var i = 0; i < dimension; i++) + { + baseArgs.Add(null); + } + + var idBytesPos = baseArgs.Count; + baseArgs.Add(null); + baseArgs.Add(quantType); + baseArgs.Add("EF"); + baseArgs.Add("10"); + baseArgs.Add("M"); + baseArgs.Add(Math.Max(5, dimension * 2).ToString()); + + var args = baseArgs.ToArray(); + foreach (var gridVector in gridVectors) + { + args[idBytesPos] = gridVector.IdBytes; + for (var i = 0; i < dimension; i++) + { + args[dimensionsPos + i] = gridVector.VectorStringValues[i]; + } + + var res = db.Execute("VADD", args); + ClassicAssert.AreEqual(1, (int)res); + } + + return gridVectors.ToDictionary(gv => gv.Id); + } + + private static HashSet GetVectorIdsForVsimResults(RedisResult[] vsimResults) + { + HashSet ids = []; + foreach (var item in vsimResults) + { + var bytes = (byte[])item; + var id = bytes[0] | (bytes[1] << 8) | (bytes[2] << 16) | (bytes[3] << 24); + ids.Add(id); + } + + 
return ids; + } + + private static int CalculateDistanceCountsIntersection(Dictionary gridVectors, int[] queryVector, HashSet bruteForceSearch, HashSet vsimSearch) + { + var expectedDistances = CalculateCountsPerDistance(gridVectors, queryVector, bruteForceSearch); + var actualDistances = CalculateCountsPerDistance(gridVectors, queryVector, vsimSearch); + + // Intersect distance counts + var intersectionCount = 0; + foreach (var kvp in expectedDistances) + { + var dist = kvp.Key; + var expectedCount = kvp.Value; + if (actualDistances.ContainsKey(dist)) + { + var actualCount = actualDistances[dist]; + intersectionCount += Math.Min(expectedCount, actualCount); + } + } + + return intersectionCount; + } + + private static Dictionary CalculateCountsPerDistance(Dictionary gridVectors, int[] queryVector, HashSet vsimIdResults) + { + Dictionary countsPerDistance = []; + foreach (var id in vsimIdResults) + { + var gridVector = gridVectors[id]; + var dist = CalculateSquaredL2Distance(gridVector.Vector, queryVector); + if (!countsPerDistance.ContainsKey(dist)) + { + countsPerDistance[dist] = 0; + } + + countsPerDistance[dist]++; + } + + return countsPerDistance; + } + + private static HashSet BruteForceNearestNeighbors(Dictionary gridVectors, int[] queryVector, int count) + { + PriorityQueue pq = new(); + foreach (var gridVector in gridVectors.Values) + { + double dist = 0; + for (var i = 0; i < queryVector.Length; i++) + { + double diff = gridVector.Vector[i] - queryVector[i]; + dist += diff * diff; + } + pq.Enqueue(gridVector.Id, -dist); + if (pq.Count > count) + { + pq.Dequeue(); + } + } + + HashSet result = []; + while (pq.Count > 0) + { + result.Add(pq.Dequeue()); + } + + return result; + } + + private static long CalculateSquaredL2Distance(int[] vec1, int[] vec2) + { + long dist = 0; + for (var i = 0; i < vec1.Length; i++) + { + long diff = vec1[i] - vec2[i]; + dist += diff * diff; + } + + return dist; + } + + private class GridVector + { + public int Id; + public 
byte[] IdBytes; + public int[] Vector; + public string[] VectorStringValues; + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs new file mode 100644 index 00000000000..b3e3166ba66 --- /dev/null +++ b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs @@ -0,0 +1,498 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Allure.NUnit; +using Garnet.server; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; + +namespace Garnet.test.DiskANN +{ + [AllureNUnit] + [TestFixture] + public class DiskANNServiceTests : AllureTestBase + { + private delegate void ReadCallbackDelegate(ulong context, uint numKeys, nint keysData, nuint keysLength, nint dataCallback, nint dataCallbackContext); + private delegate byte WriteCallbackDelegate(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength); + private delegate byte DeleteCallbackDelegate(ulong context, nint keyData, nuint keyLength); + private delegate byte ReadModifyWriteCallbackDelegate(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext); + + private sealed class ContextAndKeyComparer : IEqualityComparer<(ulong Context, byte[] Data)> + { + public bool Equals((ulong Context, byte[] Data) x, (ulong Context, byte[] Data) y) + => x.Context == y.Context && x.Data.AsSpan().SequenceEqual(y.Data); + public int GetHashCode([DisallowNull] (ulong Context, byte[] Data) obj) + { + HashCode hash = default; + hash.Add(obj.Context); + hash.AddBytes(obj.Data); + + return hash.ToHashCode(); + } + } + + GarnetServer server; + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + 
server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir); + server.Start(); + } + + [TearDown] + public void TearDown() + { + server.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + public void CheckInternalId() + { + const ulong Context = 8; + + ConcurrentDictionary<(ulong Context, byte[] Key), byte[]> data = new(new ContextAndKeyComparer()); + + unsafe void ReadCallback( + ulong context, + uint numKeys, + nint keysData, + nuint keysLength, + nint dataCallback, + nint dataCallbackContext + ) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keysData), (int)keysLength); + + var remainingKeyDataSpan = keyDataSpan; + var dataCallbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])dataCallback; + + for (var index = 0; index < numKeys; index++) + { + var keyLen = BinaryPrimitives.ReadInt32LittleEndian(remainingKeyDataSpan); + var keyData = remainingKeyDataSpan.Slice(sizeof(int), keyLen); + + remainingKeyDataSpan = remainingKeyDataSpan[(sizeof(int) + keyLen)..]; + + var lookup = (context, keyData.ToArray()); + if (data.TryGetValue(lookup, out var res)) + { + fixed (byte* resPtr = res) + { + dataCallbackDel(index, dataCallbackContext, (nint)resPtr, (nuint)res.Length); + } + } + } + } + + unsafe byte WriteCallback(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + var writeDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)writeData), (int)writeLength); + + var lookup = (context, keyDataSpan.ToArray()); + + data[lookup] = writeDataSpan.ToArray(); + + return 1; + } + + unsafe byte DeleteCallback(ulong context, nint keyData, nuint keyLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + if (data.TryRemove(lookup, out _)) + { + 
return 1; + } + + return 0; + } + + unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint callback, nint callbackContext) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + var callbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])callback; + + _ = data.AddOrUpdate( + lookup, + key => + { + var ret = new byte[writeLength]; + fixed (byte* retPtr = ret) + { + callbackDel(callbackContext, (nint)retPtr, (nuint)ret.Length); + } + + return ret; + }, + (key, old) => + { + // Garnet guarantees no concurrent RMW update same value, but ConcurrentDictionary doesn't; so use a lock + lock (old) + { + fixed (byte* oldPtr = old) + { + callbackDel(callbackContext, (nint)oldPtr, (nuint)old.Length); + } + + return old; + } + } + ); + + return 1; + } + + ReadCallbackDelegate readDel = ReadCallback; + WriteCallbackDelegate writeDel = WriteCallback; + DeleteCallbackDelegate deleteDel = DeleteCallback; + ReadModifyWriteCallbackDelegate rmwDel = ReadModifyWriteCallback; + + var readFuncPtr = Marshal.GetFunctionPointerForDelegate(readDel); + var writeFuncPtr = Marshal.GetFunctionPointerForDelegate(writeDel); + var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); + var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); + + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + + Span id = [0, 1, 2, 3]; + Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + Span attr = []; + + // Insert + unsafe + { + var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref 
MemoryMarshal.GetReference(attr)), (nuint)attr.Length); + ClassicAssert.AreEqual(1, insertRes); + } + + // Check valid initially + var internalId = data[(Context | DiskANNService.InternalIdMap, id.ToArray())]; + unsafe + { + var validRes = NativeDiskANNMethods.check_internal_id_valid(Context, rawIndex, (nint)Unsafe.AsPointer(ref internalId[0]), (nuint)internalId.Length); + ClassicAssert.AreEqual(1, validRes); + } + + // Remove + unsafe + { + var numRes = + NativeDiskANNMethods.remove( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length + ); + ClassicAssert.AreEqual(1, numRes); + } + + // Check no longer valid + unsafe + { + var validRes = NativeDiskANNMethods.check_internal_id_valid(Context, rawIndex, (nint)Unsafe.AsPointer(ref internalId[0]), (nuint)internalId.Length); + ClassicAssert.AreEqual(0, validRes); + } + + GC.KeepAlive(deleteDel); + GC.KeepAlive(writeDel); + GC.KeepAlive(readDel); + GC.KeepAlive(rmwDel); + } + + + [Test] + public void VADD() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var res1 = db.Execute("VADD", ["foo", "VALUES", "4", "1.0", "1.0", "1.0", "1.0", new byte[] { 1, 0, 0, 0 }, "NOQUANT", "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "VALUES", "4", "2.0", "2.0", "2.0", "2.0", new byte[] { 2, 0, 0, 0 }, "NOQUANT", "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + } + + [Test] + public void VSIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var res1 = db.Execute("VADD", ["foo", "VALUES", "4", "1.0", "1.0", "1.0", "1.0", new byte[] { 1, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "VALUES", "4", "2.0", "2.0", "2.0", "2.0", new byte[] { 2, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, 
(int)res1); + + var res3 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "4", "0.0", "0.0", "0.0", "0.0", "COUNT", "5", "EF", "128"]); + ClassicAssert.AreEqual(2, res3.Length); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 2, 0, 0, 0 }))); + + var res4 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 1, 0, 0, 0 }, "COUNT", "5", "EF", "128"]); + ClassicAssert.AreEqual(2, res4.Length); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 2, 0, 0, 0 }))); + } + + [Test] + public void Recreate() + { + const ulong Context = 8; + + ConcurrentDictionary<(ulong Context, byte[] Key), byte[]> data = new(new ContextAndKeyComparer()); + + unsafe void ReadCallback( + ulong context, + uint numKeys, + nint keysData, + nuint keysLength, + nint dataCallback, + nint dataCallbackContext + ) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keysData), (int)keysLength); + + var remainingKeyDataSpan = keyDataSpan; + var dataCallbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])dataCallback; + + for (var index = 0; index < numKeys; index++) + { + var keyLen = BinaryPrimitives.ReadInt32LittleEndian(remainingKeyDataSpan); + var keyData = remainingKeyDataSpan.Slice(sizeof(int), keyLen); + + remainingKeyDataSpan = remainingKeyDataSpan[(sizeof(int) + keyLen)..]; + + var lookup = (context, keyData.ToArray()); + if (data.TryGetValue(lookup, out var res)) + { + fixed (byte* resPtr = res) + { + dataCallbackDel(index, dataCallbackContext, (nint)resPtr, (nuint)res.Length); + } + } + } + } + + unsafe byte WriteCallback(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + var writeDataSpan = 
MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)writeData), (int)writeLength); + + var lookup = (context, keyDataSpan.ToArray()); + + data[lookup] = writeDataSpan.ToArray(); + + return 1; + } + + unsafe byte DeleteCallback(ulong context, nint keyData, nuint keyLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + if (data.TryRemove(lookup, out _)) + { + return 1; + } + + return 0; + } + + unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint callback, nint callbackContext) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + var callbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])callback; + + _ = data.AddOrUpdate( + lookup, + key => + { + var ret = new byte[writeLength]; + fixed (byte* retPtr = ret) + { + callbackDel(callbackContext, (nint)retPtr, (nuint)ret.Length); + } + + return ret; + }, + (key, old) => + { + // Garnet guarantees no concurrent RMW update same value, but ConcurrentDictionary doesn't; so use a lock + lock (old) + { + fixed (byte* oldPtr = old) + { + callbackDel(callbackContext, (nint)oldPtr, (nuint)old.Length); + } + + return old; + } + } + ); + + return 1; + } + + ReadCallbackDelegate readDel = ReadCallback; + WriteCallbackDelegate writeDel = WriteCallback; + DeleteCallbackDelegate deleteDel = DeleteCallback; + ReadModifyWriteCallbackDelegate rmwDel = ReadModifyWriteCallback; + + var readFuncPtr = Marshal.GetFunctionPointerForDelegate(readDel); + var writeFuncPtr = Marshal.GetFunctionPointerForDelegate(writeDel); + var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); + var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); + + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, 
readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + + Span id = [0, 1, 2, 3]; + Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + Span attr = []; + + // Insert + unsafe + { + var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); + ClassicAssert.AreEqual(1, insertRes); + } + + Span filter = []; + + // Search + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint continuation = 0; + + var numRes = + NativeDiskANNMethods.search_vector( + Context, rawIndex, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Drop does not cleanup data, so use it to simulate a process stop and recreate + { + NativeDiskANNMethods.drop_index(Context, rawIndex); + + rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + } + + // Search value + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint 
continuation = 0; + + var numRes = + NativeDiskANNMethods.search_vector( + Context, rawIndex, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Search element + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint continuation = 0; + + var numRes = + NativeDiskANNMethods.search_element( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Remove + unsafe + { + var numRes = + NativeDiskANNMethods.remove( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length + ); + ClassicAssert.AreEqual(1, numRes); + 
} + + // Insert + unsafe + { + Span id2 = [4, 5, 6, 7]; + Span elem2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + ReadOnlySpan attr2 = "{\"foo\": \"bar\"}"u8; + + var insertRes = NativeDiskANNMethods.insert( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id2)), (nuint)id2.Length, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem2)), (nuint)elem2.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr2)), (nuint)attr2.Length + ); + ClassicAssert.AreEqual(1, insertRes); + } + + GC.KeepAlive(deleteDel); + GC.KeepAlive(writeDel); + GC.KeepAlive(readDel); + GC.KeepAlive(rmwDel); + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs index f36df8a9ab0..565655051ea 100644 --- a/test/Garnet.test/GarnetServerConfigTests.cs +++ b/test/Garnet.test/GarnetServerConfigTests.cs @@ -940,6 +940,123 @@ public void ClusterReplicaResumeWithData() } } + [Test] + public void EnableVectorSetPreview() + { + // Command line args + { + // Default accepted + { + var args = Array.Empty(); + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + + // Switch is accepted + { + var args = new[] { "--enable-vector-set-preview" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + } + } + + // JSON args + { + // Default accepted + { + const string JSON = @"{ }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + 
+ // False is accepted + { + const string JSON = @"{ ""EnableVectorSetPreview"": false }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + + // True is accepted + { + const string JSON = @"{ ""EnableVectorSetPreview"": true }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + } + + // Invalid rejected + { + const string JSON = @"{ ""EnableVectorSetPreview"": ""foo"" }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + } + } + + [Test] + public void MinimumPageSizeWithVectorSetPreview() + { + // Command line args + { + // Allow exactly minimum + { + var args = new[] { "--enable-vector-set-preview", "--page", "16k" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + ClassicAssert.AreEqual("16k", options.PageSize); + } + + // Allow lower than minimum if preview not enabled + { + var args = new[] { "--page", "1k" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + ClassicAssert.AreEqual("1k", options.PageSize); + } + + // Reject too small + { + var args = new[] { "--enable-vector-set-preview", "--page", "4k" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out _, out _, out _, out _); + ClassicAssert.IsFalse(parseSuccessful); + } + } + + 
// JSON args + { + // Allow exactly minimum + { + const string JSON = @"{ ""EnableVectorSetPreview"": true, ""PageSize"": ""16k"" }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + ClassicAssert.AreEqual("16k", options.PageSize); + } + + // Allow lower than minimum if preview not enabled + { + const string JSON = @"{ ""EnableVectorSetPreview"": false, ""PageSize"": ""1k"" }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + ClassicAssert.AreEqual("1k", options.PageSize); + } + + // Reject too small + { + const string JSON = @"{ ""EnableVectorSetPreview"": true, ""PageSize"": ""4k"" }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out _, out _, out _); + ClassicAssert.IsFalse(parseSuccessful); + } + } + } + /// /// Import a garnet.conf file with the given contents /// diff --git a/test/Garnet.test/ReadOptimizedLockTests.cs b/test/Garnet.test/ReadOptimizedLockTests.cs new file mode 100644 index 00000000000..43e749b03e7 --- /dev/null +++ b/test/Garnet.test/ReadOptimizedLockTests.cs @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Allure.NUnit; +using Garnet.common; +using NUnit.Framework; +using NUnit.Framework.Legacy; + +namespace Garnet.test +{ + [AllureNUnit] + [TestFixture] + public class ReadOptimizedLockTests : AllureTestBase + { + [TestCase(123)] + [TestCase(0)] + [TestCase(1)] + [TestCase(-1)] + [TestCase(int.MaxValue)] + [TestCase(int.MinValue)] + public void BasicLocks(int hash) + { + var lockContext = new ReadOptimizedLock(16); + + var gotShared0 = lockContext.TryAcquireSharedLock(hash, out var sharedToken0); + ClassicAssert.IsTrue(gotShared0); + + var gotShared1 = lockContext.TryAcquireSharedLock(hash, out var sharedToken1); + ClassicAssert.IsTrue(gotShared1); + + var gotExclusive = lockContext.TryAcquireExclusiveLock(hash, out _); + ClassicAssert.IsFalse(gotExclusive); + + lockContext.ReleaseSharedLock(sharedToken0); + lockContext.ReleaseSharedLock(sharedToken1); + + var gotExclusiveAgain = lockContext.TryAcquireExclusiveLock(hash, out var exclusiveToken); + ClassicAssert.IsTrue(gotExclusiveAgain); + + var gotSharedAgain = lockContext.TryAcquireSharedLock(hash, out _); + ClassicAssert.IsFalse(gotSharedAgain); + + lockContext.ReleaseExclusiveLock(exclusiveToken); + } + + [Test] + public void IndexCalculations() + { + const int Iters = 10_000; + + var lockContext = new ReadOptimizedLock(16); + + var rand = new Random(2025_11_17_00); + + var offsets = new HashSet(); + + for (var i = 0; i < Iters; i++) + { + offsets.Clear(); + + // Bunch of random hashes, including negative ones, to prove reasonable calculations + var hash = (int)rand.NextInt64(); + + var hintBase = (int)rand.NextInt64(); + + for (var j = 0; j < Environment.ProcessorCount; j++) + { + var offset = lockContext.CalculateIndex(hash, hintBase + j); + ClassicAssert.True(offsets.Add(offset)); + } + + foreach (var offset in offsets) + { + var tooClose = offsets.Except([offset]).Where(x => Math.Abs(x - offset) < 
ReadOptimizedLock.CacheLineSizeBytes / sizeof(int)); + ClassicAssert.IsEmpty(tooClose); + } + } + } + + [TestCase(1)] + [TestCase(4)] + [TestCase(16)] + [TestCase(64)] + [TestCase(128)] + public void Threaded(int hashCount) + { + // Guard some number of distinct value "slots" (defined by hashes) + // + // Runs threads which (randomly) either read values, write values, or read (then promote) and write. + // + // Reads check for correctness. + // Writes are done "plain" with no other locking or coherency enforcement. + + const int Iters = 100_000; + const int LongsPerSlot = 4; + + var lockContext = new ReadOptimizedLock(Math.Min(Math.Max(hashCount / 2, 1), Environment.ProcessorCount)); + + var threads = new Thread[Math.Max(Environment.ProcessorCount, 4)]; + + using var threadStart = new SemaphoreSlim(0, threads.Length); + + var globalRandom = new Random(2025_11_17_01); + + var hashes = new int[hashCount]; + for (var i = 0; i < hashes.Length; i++) + { + var nextHash = (int)globalRandom.NextInt64(); + if (hashes.AsSpan()[..i].Contains(nextHash)) + { + i--; + continue; + } + hashes[i] = nextHash; + } + + var values = new long[hashes.Length][]; + for (var i = 0; i < values.Length; i++) + { + values[i] = new long[LongsPerSlot]; + } + + // Spin up a bunch of mutators + for (var i = 0; i < threads.Length; i++) + { + var threadRandom = new Random(2025_11_17_01 + ((i + 1) * 100_000)); + + threads[i] = + new( + () => + { + threadStart.Wait(); + + for (var j = 0; j < Iters; j++) + { + var hashIx = threadRandom.Next(hashes.Length); + var hash = hashes[hashIx]; + + switch (threadRandom.Next(5)) + { + // Try: Read and verify + case 0: + { + if (lockContext.TryAcquireSharedLock(hash, out var sharedLockToken)) + { + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + lockContext.ReleaseSharedLock(sharedLockToken); + } + else + { + j--; + } + } + break; + + // Try: Lock, modify + case 1: + { + if 
(lockContext.TryAcquireExclusiveLock(hash, out var exclusiveLockToken)) + { + var sub = values[hashIx]; + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + else + { + j--; + } + } + break; + + // Demand: Read and verify + case 2: + { + lockContext.AcquireSharedLock(hash, out var sharedLockToken); + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + lockContext.ReleaseSharedLock(sharedLockToken); + } + + break; + + // Demand: Lock, modify + case 3: + { + lockContext.AcquireExclusiveLock(hash, out var exclusiveLockToken); + var sub = values[hashIx]; + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + + break; + + // Try: Read, verify, promote, modify + case 4: + { + if (lockContext.TryAcquireSharedLock(hash, out var sharedLockToken)) + { + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + if (lockContext.TryPromoteSharedLock(hash, sharedLockToken, out var exclusiveLockToken)) + { + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + else + { + lockContext.ReleaseSharedLock(sharedLockToken); + + j--; + } + } + else + { + j--; + } + } + + break; + + // There is no Demand version of Promote because that is not safe in general + + default: throw new InvalidOperationException($"Unexpected op"); + } + } + } + ) + { + Name = $"{nameof(Threaded)} #{i}" + }; + threads[i].Start(); + } + + // Let threads run + _ = threadStart.Release(threads.Length); + + // Wait for threads to finish + foreach (var thread in threads) + { + thread.Join(); + } + + // Validate correctness of final state 
+ foreach (var vals in values) + { + for (var k = 1; k < vals.Length; k++) + { + ClassicAssert.AreEqual(vals[0], vals[k]); + } + } + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/Resp/ACL/RespCommandTests.cs b/test/Garnet.test/Resp/ACL/RespCommandTests.cs index 8fb706cfd41..e89942bdb60 100644 --- a/test/Garnet.test/Resp/ACL/RespCommandTests.cs +++ b/test/Garnet.test/Resp/ACL/RespCommandTests.cs @@ -2036,6 +2036,35 @@ static async Task DoClusterReplicateAsync(GarnetClient client) } } + [Test] + public async Task ClusterReserveACLsAsync() + { + // All cluster command "success" is a thrown exception, because clustering is disabled + + await CheckCommandsAsync( + "CLUSTER RESERVE", + [DoClusterReserveAsync] + ); + + static async Task DoClusterReserveAsync(GarnetClient client) + { + try + { + await client.ExecuteForStringResultAsync("CLUSTER", ["RESERVE", "VECTOR_SET_CONTEXTS", "16"]); + Assert.Fail("Shouldn't be reachable, cluster isn't enabled"); + } + catch (Exception e) + { + if (e.Message == "ERR This instance has cluster support disabled") + { + return; + } + + throw; + } + } + } + [Test] public async Task ClusterResetACLsAsync() { @@ -7487,6 +7516,205 @@ static async Task DoUnwatchAsync(GarnetClient client) } } + [Test] + public async Task VAddACLsAsync() + { + await CheckCommandsAsync( + "VADD", + [DoVAddAsync] + ); + + static async Task DoVAddAsync(GarnetClient client) + { + var elem = Encoding.ASCII.GetString("\x0\x1\x2\x3"u8); + + long val = await client.ExecuteForLongResultAsync("VADD", ["foo", "REDUCE", "50", "VALUES", "4", "1.0", "2.0", "3.0", "4.0", elem, "CAS", "Q8", "EF", "16", "SETATTR", "{ 'hello': 'world' }", "M", "32"]); + ClassicAssert.AreEqual(1, val); + } + } + + [Test] + public async Task VCardACLsAsync() + { + await CheckCommandsAsync( + "VCARD", + [DoVCardAsync] + ); + + static async Task DoVCardAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await 
client.ExecuteForStringResultAsync("VCARD", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VDimACLsAsync() + { + await CheckCommandsAsync( + "VDIM", + [DoVDimAsync] + ); + + static async Task DoVDimAsync(GarnetClient client) + { + try + { + _ = await client.ExecuteForStringResultAsync("VDIM", ["foo"]); + ClassicAssert.Fail("Shouldn't be reachable"); + } + catch (Exception e) when (e.Message.Equals("ERR Key not found")) + { + // Excepted + } + } + } + + [Test] + public async Task VEmbACLsAsync() + { + await CheckCommandsAsync( + "VEMB", + [DoVEmbAsync] + ); + + static async Task DoVEmbAsync(GarnetClient client) + { + string[] val = await client.ExecuteForStringArrayResultAsync("VEMB", ["foo", "bar"]); + ClassicAssert.AreEqual(0, val.Length); + } + } + + [Test] + public async Task VGetAttrACLsAsync() + { + await CheckCommandsAsync( + "VGETATTR", + [DoVGetAttrAsync] + ); + + static async Task DoVGetAttrAsync(GarnetClient client) + { + string val = await client.ExecuteForStringResultAsync("VGETATTR", ["foo", "wololo"]); + ClassicAssert.AreEqual(null, val); + } + } + + [Test] + public async Task VInfoACLsAsync() + { + await CheckCommandsAsync( + "VINFO", + [DoVInfoAsync] + ); + + static async Task DoVInfoAsync(GarnetClient client) + { + var res = await client.ExecuteForStringArrayResultAsync("VINFO", ["foo"]); + ClassicAssert.AreEqual(res, null); + } + } + + [Test] + public async Task VIsMemberACLsAsync() + { + await CheckCommandsAsync( + "VISMEMBER", + [DoVIsMemberAsync] + ); + + static async Task DoVIsMemberAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VISMEMBER", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VLinksACLsAsync() + { + await CheckCommandsAsync( + "VLINKS", + [DoVLinksAsync] + ); + + static async Task DoVLinksAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation 
+ + string val = await client.ExecuteForStringResultAsync("VLINKS", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VRandMemberACLsAsync() + { + await CheckCommandsAsync( + "VRANDMEMBER", + [DoVRandMemberAsync] + ); + + static async Task DoVRandMemberAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VRANDMEMBER", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VRemACLsAsync() + { + await CheckCommandsAsync( + "VREM", + [DoVRemAsync] + ); + + static async Task DoVRemAsync(GarnetClient client) + { + long val = await client.ExecuteForLongResultAsync("VREM", ["foo", Encoding.UTF8.GetString("\0\0\0\0"u8)]); + ClassicAssert.AreEqual(0, val); + } + } + + [Test] + public async Task VSetAttrACLsAsync() + { + await CheckCommandsAsync( + "VSETATTR", + [DoVSetAttrAsync] + ); + + static async Task DoVSetAttrAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VSETATTR", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VSimACLsAsync() + { + await CheckCommandsAsync( + "VSIM", + [DoVSimAsync] + ); + + static async Task DoVSimAsync(GarnetClient client) + { + string[] val = await client.ExecuteForStringArrayResultAsync("VSIM", ["foo", "ELE", "bar"]); + ClassicAssert.AreEqual(0, val.Length); + } + } + /// /// Take a command (or subcommand, with a space) and check that adding and removing /// command, subcommand, and categories ACLs behaves as expected. 
diff --git a/test/Garnet.test/Resp/GarnetAuthenticatorTests.cs b/test/Garnet.test/Resp/GarnetAuthenticatorTests.cs index 8728325e62f..810b1f55bf6 100644 --- a/test/Garnet.test/Resp/GarnetAuthenticatorTests.cs +++ b/test/Garnet.test/Resp/GarnetAuthenticatorTests.cs @@ -60,18 +60,21 @@ public async Task InvalidatingAuthorizationAsync() auth.HasACLSupport = false; auth.IsAuthenticated = false; - int authCalls = 0; + var authCalls = 0; + var authingAsFoo = false; + var authedAsFoo = false; auth.AuthenticateCallback = (p, u) => { - if (authCalls == 0) + if (!authingAsFoo) { ClassicAssert.AreEqual("default", Encoding.UTF8.GetString(u)); } else { ClassicAssert.AreEqual("foo", Encoding.UTF8.GetString(u)); + authedAsFoo = true; } authCalls++; @@ -87,21 +90,22 @@ public async Task InvalidatingAuthorizationAsync() c.Connect(); // Initial command runs under default user - await c.ExecuteAsync("PING"); - ClassicAssert.AreEqual(1, authCalls); + _ = await c.ExecuteAsync("PING"); // Auth as proper user, should get another call - await c.ExecuteAsync("AUTH", "foo", "bar"); - ClassicAssert.AreEqual(2, authCalls); + authingAsFoo = true; + _ = await c.ExecuteAsync("AUTH", "foo", "bar"); + ClassicAssert.IsTrue(authedAsFoo); - await c.ExecuteAsync("PING"); - ClassicAssert.AreEqual(2, authCalls); + _ = await c.ExecuteAsync("PING"); // Command after auth invalidation fails as no auth + + var oldAuthCalls = authCalls; auth.IsAuthenticated = false; try { - await c.ExecuteAsync("PING"); + _ = await c.ExecuteAsync("PING"); Assert.Fail("Should be denied, user is not authed"); } catch (Exception e) @@ -109,8 +113,8 @@ public async Task InvalidatingAuthorizationAsync() ClassicAssert.AreEqual("NOAUTH Authentication required.", e.Message); } - await c.ExecuteAsync("AUTH", "foo", "bar"); - ClassicAssert.AreEqual(3, authCalls); + _ = await c.ExecuteAsync("AUTH", "foo", "bar"); + ClassicAssert.True(authCalls > oldAuthCalls); } } } \ No newline at end of file diff --git 
a/test/Garnet.test/RespCustomCommandTests.cs b/test/Garnet.test/RespCustomCommandTests.cs index 93b34e44cc1..954a915dc16 100644 --- a/test/Garnet.test/RespCustomCommandTests.cs +++ b/test/Garnet.test/RespCustomCommandTests.cs @@ -211,8 +211,7 @@ public override unsafe void Main(TGarnetApi garnetApi, ref CustomPro ArgSlice valForKey1 = new ArgSlice(valuePtr, valueToMessWith.Count); input.parseState.InitializeWithArgument(valForKey1); // since we are setting with retain to etag, this change should be reflected in an etag update - SpanByte sameKeyToUse = key.SpanByte; - garnetApi.SET_Conditional(ref sameKeyToUse, ref input); + garnetApi.SET_Conditional(key, ref input); } diff --git a/test/Garnet.test/RespSortedSetTests.cs b/test/Garnet.test/RespSortedSetTests.cs index 06554221d98..82e5ffb3a21 100644 --- a/test/Garnet.test/RespSortedSetTests.cs +++ b/test/Garnet.test/RespSortedSetTests.cs @@ -25,7 +25,10 @@ namespace Garnet.test SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; [AllureNUnit] [TestFixture] diff --git a/test/Garnet.test/RespVectorSetTests.cs b/test/Garnet.test/RespVectorSetTests.cs new file mode 100644 index 00000000000..c4eff14c39a --- /dev/null +++ b/test/Garnet.test/RespVectorSetTests.cs @@ -0,0 +1,2247 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using Allure.NUnit; +using Garnet.common; +using Garnet.server; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; +using Tsavorite.core; + +namespace Garnet.test +{ + [AllureNUnit] + [TestFixture] + public class RespVectorSetTests : AllureTestBase + { + GarnetServer server; + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true); + + server.Start(); + } + + [TearDown] + public void TearDown() + { + server.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + public void DisabledWithFeatureFlag() + { + // Restart with Vector Sets disabled + TearDown(); + + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true, enableVectorSetPreview: false); + + server.Start(); + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + ReadOnlySpan vectorSetCommands = [RespCommand.VADD, RespCommand.VCARD, RespCommand.VDIM, RespCommand.VEMB, RespCommand.VGETATTR, RespCommand.VINFO, RespCommand.VISMEMBER, RespCommand.VLINKS, RespCommand.VRANDMEMBER, RespCommand.VREM, RespCommand.VSETATTR, RespCommand.VSIM]; + foreach (var cmd in vectorSetCommands) + { + // Should all fault before any validation + var exc = ClassicAssert.Throws(() => db.Execute(cmd.ToString())); + ClassicAssert.AreEqual("ERR Vector Set (preview) commands are not enabled", exc.Message); + } + } + + [Test] + public void OversizedRejected() + { + var options = GetOpts(server); + + var overflowSizeBytes = (int)(GarnetServerOptions.ParseSize(options.PageSize, out _) * 2); + + using var redis = 
ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var oversizedVectorData = Enumerable.Repeat(1, overflowSizeBytes).ToArray(); + var oversideAttribute = Enumerable.Repeat(2, overflowSizeBytes).ToArray(); + + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD", ["foo", "XB8", oversizedVectorData, new byte[] { 0, 0, 0, 0 }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR Vector exceed configured page size", exc1.Message); + + var basicVectorData = Enumerable.Repeat(3, 75).ToArray(); + + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", ["foo", "XB8", basicVectorData, new byte[] { 0, 0, 0, 1 }, "XPREQ8", "SETATTR", oversideAttribute])); + ClassicAssert.AreEqual("ERR Attribute exceed configured page size", exc2.Message); + } + + [Test] + public void WrongTypeForVectorSetOpsOnNonVectorSetKeys() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var vectorSetCommands = Enum.GetValues().Where(static t => t.IsLegalOnVectorSet() && !(t is RespCommand.DEL or RespCommand.UNLINK or RespCommand.DEBUG or RespCommand.RENAME or RespCommand.RENAMENX or RespCommand.TYPE)); + + // Strings + { + var res = db.StringSet("foo", "bar"); + ClassicAssert.IsTrue(res); + + foreach (var cmd in vectorSetCommands) + { + RedisServerException exc; + switch (cmd) + { + case RespCommand.VADD: + exc = ClassicAssert.Throws(() => db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", 
"2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"])); + break; + case RespCommand.VCARD: + // TODO: Implement when VCARD works + continue; + case RespCommand.VDIM: + exc = ClassicAssert.Throws(() => db.Execute("VDIM", ["foo"])); + break; + case RespCommand.VEMB: + exc = ClassicAssert.Throws(() => db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }])); + break; + case RespCommand.VGETATTR: + exc = ClassicAssert.Throws(() => db.Execute("VGETATTR", ["foo", new byte[] { 0, 0, 0, 0 }])); + break; + case RespCommand.VINFO: + exc = ClassicAssert.Throws(() => db.Execute("VINFO", ["foo"])); + break; + case RespCommand.VISMEMBER: + // TODO: Implement when VISMEMBER works + continue; + case RespCommand.VLINKS: + // TODO: Implement when VLINKS works + continue; + case RespCommand.VRANDMEMBER: + // TODO: Implement when VRANDMEMBER works + continue; + case RespCommand.VREM: + exc = ClassicAssert.Throws(() => db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }])); + break; + case RespCommand.VSETATTR: + // TODO: Implement when VSETATTR works + continue; + case RespCommand.VSIM: + exc = ClassicAssert.Throws(() => db.Execute("VSIM", ["foo", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40"])); + break; + default: + throw new InvalidOperationException($"Unexpected Vector Set command: {cmd}"); + } + + ClassicAssert.AreEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exc.Message, $"RESP Command: 
{cmd}"); + } + } + + // TODO: Other objects - but we can wait for store v2 for that + } + + [Test] + public void VADD() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // VALUES + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 1, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + var float3 = new float[75]; + float3[0] = 5f; + for (var i = 1; i < float3.Length; i++) + { + float3[i] = float3[i - 1] + 1; + } + + // FP32 + var res3 = db.Execute("VADD", ["foo", "REDUCE", "50", "FP32", MemoryMarshal.Cast(float3).ToArray(), new byte[] { 2, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + 
ClassicAssert.AreEqual(1, (int)res3); + + var byte4 = new byte[75]; + byte4[0] = 9; + for (var i = 1; i < byte4.Length; i++) + { + byte4[i] = (byte)(byte4[i - 1] + 1); + } + + // XB8 + var res4 = db.Execute("VADD", ["foo", "REDUCE", "50", "XB8", byte4, new byte[] { 3, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res4); + + // TODO: exact duplicates - what does Redis do? + + // Add without specifying reductions after first vector + var res5 = db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "150.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res5); + + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "VALUES", "4", "5.0", "6.0", "7.0", "8.0", new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 4 but set has 75", exc1.Message); + + // Add without specifying EF after first vector + var res6 = db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "170.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", 
"3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 3 }, "CAS", "NOQUANT", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res6); + + // Add without specifying M after first vector + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "180.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 4 }, "CAS", "NOQUANT", "EF", "16"])); + ClassicAssert.AreEqual("ERR asked M value mismatch with existing vector set", exc2.Message); + + // Mismatch vector size for projection + var exc3 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "5", "1.0", "2.0", "3.0", "4.0", "5.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 5 but set has 75", exc3.Message); + } + + [Test] + public void VADDVariableLengthElementIds() + { + const int MinElementLength = 1; + const int MaxElementLength = 1024; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Always put a 0 length in as a stress test + List ids = [[]]; + for (var len = MinElementLength; len <= MaxElementLength; len *= 2) + { + ids.Add(Enumerable.Range(0, len).Select(_ => (byte)len).ToArray()); + } + + foreach (var id in ids) + { + var addRes = (int)db.Execute("VADD", ["foo", "VALUES", "1", ((float)(byte)id.Length).ToString(), id, 
"XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + } + + foreach (var id in ids) + { + var embRes = (string[])db.Execute("VEMB", ["foo", id]); + ClassicAssert.AreEqual(1, embRes.Length); + ClassicAssert.AreEqual((float)(byte)id.Length, float.Parse(embRes[0])); + } + } + + [Test] + public void VADDXPREQB8() + { + // Extra validation is required for this extension quantifier + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Build byte array for vector data (75 bytes) + var vectorData1 = new byte[75]; + vectorData1[0] = 1; + for (var i = 1; i < vectorData1.Length; i++) + { + vectorData1[i] = (byte)(vectorData1[i - 1] + 1); + } + + var vectorData2 = new byte[75]; + vectorData2[0] = 100; + for (var i = 1; i < vectorData2.Length; i++) + { + vectorData2[i] = (byte)(vectorData2[i - 1] + 1); + } + + // Small vector for REDUCE test + var smallVectorData = new byte[4]; + for (var i = 0; i < smallVectorData.Length; i++) + { + smallVectorData[i] = (byte)(i + 1); + } + + // REDUCE not allowed with XPREQ8 + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "2", "XB8", smallVectorData, new byte[] { 0, 0, 0, 0 }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR asked quantization mismatch with existing vector set", exc1.Message); + + // Create a vector set with XB8 + XPREQ8 + var res1 = db.Execute("VADD", ["fizz", "XB8", vectorData1, new byte[] { 0, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Add another element + var res2 = db.Execute("VADD", ["fizz", "XB8", vectorData2, new byte[] { 0, 0, 0, 1 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Verify the vector was stored correctly + var embRes = (string[])db.Execute("VEMB", ["fizz", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(75, embRes.Length); + for (var i = 0; i < embRes.Length; i++) + { + ClassicAssert.AreEqual((float)vectorData1[i], float.Parse(embRes[i])); + } + } + + [Test] + public void 
VADDErrors() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var vectorSetKey = $"{nameof(VADDErrors)}_{Guid.NewGuid()}"; + + // Bad arity + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD")); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc1.Message); + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc2.Message); + var exc3 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "FP32"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc3.Message); + var exc4 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc4.Message); + var exc5 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc5.Message); + var exc6 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "1.0"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc6.Message); + + // Reduce after vector + var exc7 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "1.0", "2.0", "bar", "REDUCE", "1"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc7.Message); + + // Duplicate flags + // TODO: Redis doesn't error on these which seems... 
wrong, confirm with them + //var exc8 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "CAS", "CAS"])); + //var exc9 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "NOQUANT", "Q8"])); + //var exc10 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF", "1", "EF", "1"])); + //var exc11 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "SETATTR", "abc", "SETATTR", "abc"])); + //var exc12 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "5", "M", "5"])); + + // M out of range (Redis imposes M >= 4 and m <= 4096 + var exc13 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "1"])); + ClassicAssert.AreEqual("ERR invalid M", exc13.Message); + var exc14 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "10000"])); + ClassicAssert.AreEqual("ERR invalid M", exc14.Message); + + // Missing/bad option value + var exc20 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc20.Message); + var exc21 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF", "0"])); + ClassicAssert.AreEqual("ERR invalid EF", exc21.Message); + var exc22 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "SETATTR"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc22.Message); + var exc23 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc23.Message); + var exc24 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "2.0", "bar"])); + 
ClassicAssert.AreEqual("ERR invalid vector specification", exc24.Message); + var exc25 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "0", "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc25.Message); + var exc26 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "fizz", "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc26.Message); + + // Unknown option + var exc27 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "FOO"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc27.Message); + + // Malformed FP32 + var binary = new float[] { 1, 2, 3 }; + var blob = MemoryMarshal.Cast(binary)[..^1].ToArray(); + var exc15 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "FP32", blob, "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc15.Message); + + // Mismatch after creating a vector set + _ = db.KeyDelete(vectorSetKey); + + _ = db.Execute("VADD", [vectorSetKey, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 1, 0 }, "NOQUANT", "EF", "6", "M", "10"]); + + var exc16 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "1.0", "2.0", "fizz", "NOQUANT", "EF", "6", "M", "10"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 2 but set has 75", exc16.Message); + var exc17 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", 
"75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "fizz", "XPREQ8", "EF", "6", "M", "10"])); + ClassicAssert.AreEqual("ERR asked quantization mismatch with existing vector set", exc17.Message); + var exc18 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "fizz", "NOQUANT", "EF", "12", "M", "20"])); + ClassicAssert.AreEqual("ERR asked M value mismatch with existing vector set", exc18.Message); + + // TODO: Redis doesn't appear to validate attributes... 
so that's weird + + // Empty Vector Set keys are forbidden (TODO: Remove this constraint) + var exc19 = ClassicAssert.Throws(() => db.Execute("VADD", ["", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR Vector Set key cannot be empty", exc19.Message); + + // Unsupported quantization types (Q8 and BIN are not yet supported) + var exc28 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar"])); + ClassicAssert.AreEqual("ERR Unsupported quantization type", exc28.Message); + var exc29 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "Q8"])); + ClassicAssert.AreEqual("ERR Unsupported quantization type", exc29.Message); + var exc30 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "BIN"])); + ClassicAssert.AreEqual("ERR Unsupported quantization type", exc30.Message); + } + + [Test] + public void VEMB_FP32Storage() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Add a vector using VALUES format with NOQUANT (FP32 storage) + var res1 = db.Execute("VADD", ["foo", "VALUES", "8", "1.0", "2.0", "3.0", "4.0", "5.0", "6.0", "7.0", "8.0", new byte[] { 0, 0, 0, 0 }, "NOQUANT"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Add a vector using XB8 format with NOQUANT (FP32 storage) + byte[] vectorBytes = new byte[8]; + for (int i = 0; i < 8; 
i++) + { + vectorBytes[i] = (byte)(i + 10); + } + + var res2 = db.Execute("VADD", ["foo", "XB8", vectorBytes, new byte[] { 0, 0, 0, 2 }, "NOQUANT"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Verify VEMB for XB8 input vector + var res3 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 2 }]); + ClassicAssert.AreEqual(8, res3.Length); + for (var i = 0; i < 8; i++) + { + ClassicAssert.AreEqual((float)vectorBytes[i], float.Parse(res3[i])); + } + + // Verify VEMB for VALUES input vector + var res4 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(8, res4.Length); + for (var i = 0; i < 8; i++) + { + ClassicAssert.AreEqual((float)(i + 1), float.Parse(res4[i])); + } + + // Verify non-existent element returns empty + var res5 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 1 }]); + ClassicAssert.AreEqual(0, res5.Length); + } + + [Test] + public void VEMB_BinaryStorage() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Add a vector using VALUES format with XPREQ8 (Binary storage) + var res1 = db.Execute("VADD", ["foo", "VALUES", "8", "1.0", "2.0", "3.0", "4.0", "5.0", "6.0", "7.0", "8.0", new byte[] { 0, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Add a vector using XB8 format with XPREQ8 (Binary storage) + //byte[] vectorBytes = new byte[8]; + //for (int i = 0; i < 8; i++) + //{ + // vectorBytes[i] = (byte)(i + 10); + //} + + //var res2 = db.Execute("VADD", ["foo", "XB8", vectorBytes, new byte[] { 0, 0, 0, 2 }, "XPREQ8"]); + //ClassicAssert.AreEqual(1, (int)res2); + + // Verify VEMB for XB8 input vector + //var res3 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 2 }]); + //ClassicAssert.AreEqual(8, res3.Length); + //for (var i = 0; i < 8; i++) + //{ + //ClassicAssert.AreEqual((float)vectorBytes[i], float.Parse(res3[i])); + //} + + // Verify VEMB for VALUES input vector - should return the original 
float values + var res4 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(8, res4.Length); + for (var i = 0; i < 8; i++) + { + ClassicAssert.AreEqual((float)(i + 1), float.Parse(res4[i])); + } + + // Verify non-existent element returns empty + var res5 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 1 }]); + ClassicAssert.AreEqual(0, res5.Length); + } + + [Test] + public void VEMB_RespVersions() + { + // VEMB response format depends on the RESP version used: + // - Resp3: Array of doubles + // - Resp2: Array of bulk strings + + var vectorSetKey = "foo"; + var elementId = new byte[] { 0, 0, 0, 0 }; + + // Create a vector set with known values + using var redisResp3 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp3)); + var dbResp3 = redisResp3.GetDatabase(); + + var addRes = dbResp3.Execute("VADD", [vectorSetKey, "VALUES", "4", "1.0", "2.0", "3.0", "4.0", elementId, "NOQUANT"]); + ClassicAssert.AreEqual(1, (int)addRes); + + // Test RESP3 response - should be array of doubles + var resp3Result = dbResp3.Execute("VEMB", [vectorSetKey, elementId]); + ClassicAssert.IsFalse(resp3Result.IsNull); + ClassicAssert.AreEqual(ResultType.Array, resp3Result.Resp3Type); + + var resp3Array = (RedisValue[])resp3Result; + ClassicAssert.AreEqual(4, resp3Array.Length); + for (var i = 0; i < resp3Array.Length; i++) + { + // In RESP3, the values should be doubles that can be directly cast + ClassicAssert.AreEqual((double)(i + 1), (double)resp3Array[i]); + } + + // Test RESP2 response - should be array of bulk strings + using var redisResp2 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp2)); + var dbResp2 = redisResp2.GetDatabase(); + + var resp2Result = dbResp2.Execute("VEMB", [vectorSetKey, elementId]); + ClassicAssert.IsFalse(resp2Result.IsNull); + ClassicAssert.AreEqual(ResultType.Array, resp2Result.Resp2Type); + + var resp2Array = (RedisValue[])resp2Result; + 
ClassicAssert.AreEqual(4, resp2Array.Length); + for (var i = 0; i < resp2Array.Length; i++) + { + // In RESP2, the values are bulk strings that need parsing + ClassicAssert.AreEqual((float)(i + 1), float.Parse((string)resp2Array[i])); + } + + // Test not found case - both should return empty array + var nonExistentElementId = new byte[] { 9, 9, 9, 9 }; + + var resp3NotFound = dbResp3.Execute("VEMB", [vectorSetKey, nonExistentElementId]); + ClassicAssert.AreEqual(ResultType.Array, resp3NotFound.Resp3Type); + ClassicAssert.AreEqual(0, ((RedisValue[])resp3NotFound).Length); + + var resp2NotFound = dbResp2.Execute("VEMB", [vectorSetKey, nonExistentElementId]); + ClassicAssert.AreEqual(ResultType.Array, resp2NotFound.Resp2Type); + ClassicAssert.AreEqual(0, ((RedisValue[])resp2NotFound).Length); + } + + [Test] + public void VectorSetOpacity() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = ClassicAssert.Throws(() => db.StringGet("foo")); + ClassicAssert.True(res2.Message.Contains("WRONGTYPE")); + } + + [Test] + public void VectorElementOpacity() + { + // Check that we can't touch an element with GET despite it also being in the main store + + using var redis = 
ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = (string)db.StringGet(new byte[] { 0, 0, 0, 0 }); + ClassicAssert.IsNull(res2); + + var res3 = db.KeyDelete(new byte[] { 0, 0, 0, 0 }); + ClassicAssert.IsFalse(res3); + + var res4 = db.StringSet(new byte[] { 0, 0, 0, 0 }, "def", when: When.NotExists); + ClassicAssert.IsTrue(res4); + + Span buffer = stackalloc byte[128]; + + // Check we haven't messed up the element + var res7 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(75, res7.Length); + for (var i = 0; i < res7.Length; i++) + { + var expected = + (i % 4) switch + { + 0 => float.Parse("1.0"), + 1 => float.Parse("2.0"), + 2 => float.Parse("3.0"), + 3 => float.Parse("4.0"), + _ => throw new InvalidOperationException(), + }; + + ClassicAssert.AreEqual(expected, float.Parse(res7[i])); + } + } + + [Test] + public void VSIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", 
"2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + var res3 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res3.Length); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res3.Any(static x => 
x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + var res4 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }, "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res4.Length); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // FP32 + var float5 = new float[75]; + float5[0] = 3; + for (var i = 1; i < float5.Length; i++) + { + float5[i] = float5[i - 1] + 0.1f; + } + var res5 = (byte[][])db.Execute("VSIM", ["foo", "FP32", MemoryMarshal.Cast(float5).ToArray(), "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res5.Length); + ClassicAssert.IsTrue(res5.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res5.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // XB8 + var byte6 = new byte[75]; + byte6[0] = 10; + for (var i = 1; i < byte6.Length; i++) + { + byte6[i] = (byte)(byte6[i - 1] + 1); + } + var res6 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte6, "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res6.Length); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // COUNT > EF + var byte7 = new byte[75]; + byte7[0] = 20; + for (var i = 1; i < byte7.Length; i++) + { + byte7[i] = (byte)(byte7[i - 1] + 1); + } + var res7 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte7, "COUNT", "100", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res7.Length); + ClassicAssert.IsTrue(res7.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res7.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // WITHSCORES + var res8 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte7, "COUNT", "100", "EPSILON", "1.0", "EF", "40", 
"WITHSCORES"]); + ClassicAssert.AreEqual(4, res8.Length); + ClassicAssert.IsTrue(res8.Where(static (x, ix) => (ix % 2) == 0).Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res8.Where(static (x, ix) => (ix % 2) == 0).Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + ClassicAssert.IsFalse(double.IsNaN(double.Parse(Encoding.UTF8.GetString(res8[1])))); + ClassicAssert.IsFalse(double.IsNaN(double.Parse(Encoding.UTF8.GetString(res8[3])))); + + // Large Count + var res9 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte7, "COUNT", "1000"]); + ClassicAssert.AreEqual(2, res9.Length); + ClassicAssert.IsTrue(res9.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res9.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + } + + [Test] + public void VSIMWithAttribs() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", 
"4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "fizz buzz"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Equivalent to no attribute + var res3 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 2 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", ""]); + ClassicAssert.AreEqual(1, (int)res3); + + // Actually no attribute + var res4 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "120.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 3 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res4); + + // Very long attribute + var bigAttr = 
Enumerable.Repeat((byte)'a', 1_024).ToArray(); + var res5 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "130.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 4 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", bigAttr]); + ClassicAssert.AreEqual(1, (int)res5); + + var res6 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "140.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40", "WITHATTRIBS"]); + ClassicAssert.AreEqual(10, res6.Length); + for (var i = 0; i < res6.Length; i += 2) + { + var id = res6[i]; + var attr = res6[i + 1]; + + if (id.SequenceEqual(new byte[] { 0, 0, 0, 0 })) + { + ClassicAssert.True(attr.SequenceEqual("hello world"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 1 })) + { + ClassicAssert.True(attr.SequenceEqual("fizz buzz"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 2 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 3 })) + { + 
ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 4 })) + { + ClassicAssert.True(bigAttr.SequenceEqual(attr)); + } + else + { + ClassicAssert.Fail("Unexpected id"); + } + } + + // WITHSCORES + var res7 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "140.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40", "WITHATTRIBS", "WITHSCORES"]); + ClassicAssert.AreEqual(15, res7.Length); + for (var i = 0; i < res7.Length; i += 3) + { + var id = res7[i]; + var score = double.Parse(Encoding.UTF8.GetString(res7[i + 1])); + var attr = res7[i + 2]; + + ClassicAssert.IsFalse(double.IsNaN(score)); + + if (id.SequenceEqual(new byte[] { 0, 0, 0, 0 })) + { + ClassicAssert.True(attr.SequenceEqual("hello world"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 1 })) + { + ClassicAssert.True(attr.SequenceEqual("fizz buzz"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 2 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 3 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 4 })) + { + ClassicAssert.True(bigAttr.SequenceEqual(attr)); + } + else + { + ClassicAssert.Fail("Unexpected id"); + } + } + } + + [Test] + public void VDIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", 
"3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VDIM", "foo"); + ClassicAssert.AreEqual(3, (int)res2); + + var res3 = db.Execute("VADD", ["bar", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res3); + + var res4 = db.Execute("VDIM", "bar"); + ClassicAssert.AreEqual(75, (int)res4); + + var exc1 = ClassicAssert.Throws(() => db.Execute("VDIM", "fizz")); + ClassicAssert.IsTrue(exc1.Message.Contains("Key not found")); + + // TODO: Add WRONGTYPE behavior check once implemented + } + + [Test] + public void DeleteVectorSet() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", 
"1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.KeyDelete("foo"); + ClassicAssert.IsTrue(res2); + + var res3 = db.Execute("VADD", ["fizz", "REDUCE", "3", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res3); + + var res4 = db.StringSet("buzz", "abc"); + ClassicAssert.IsTrue(res4); + + var res5 = db.KeyDelete(["fizz", "buzz"]); + ClassicAssert.AreEqual(2, res5); + } + + [Test] + public void InterruptedVectorSetDelete_AfterMark() + => InterruptedVectorSetDelete(ExceptionInjectionType.VectorSet_Interrupt_Delete_0); + + [Test] + public void InterruptedVectorSetDelete_AfterZeroingOut() + => InterruptedVectorSetDelete(ExceptionInjectionType.VectorSet_Interrupt_Delete_1); + + [Test] + public void InterruptedVectorSetDelete_AfterDelete() + => InterruptedVectorSetDelete(ExceptionInjectionType.VectorSet_Interrupt_Delete_2); + + private void 
InterruptedVectorSetDelete(ExceptionInjectionType faultLocation) + { +#if !DEBUG + ClassicAssert.Ignore("Relies on ExceptionInjectionHelper, disable in non-DEBUG"); +#endif + + var key = $"{nameof(InterruptedVectorSetDelete)}_{faultLocation}"; + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig())) + { + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", [key, "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + // TODO: we could use EXISTS here... 
except not all non-Vector Set commands understand Vector Sets, so that's a bit flaky + ExceptionInjectionHelper.EnableException(faultLocation); + try + { + _ = ClassicAssert.Throws(() => db.KeyDelete(key)); + } + finally + { + ExceptionInjectionHelper.DisableException(faultLocation); + } + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig())) + { + var db = redis.GetDatabase(); + + var deleteWasEffective = false; + + try + { + _ = (string)db.StringGet(key); + deleteWasEffective = true; + } + catch + { + } + + var vectorSetCommands = Enum.GetValues().Where(static x => x.IsLegalOnVectorSet() && x is not (RespCommand.DEL or RespCommand.UNLINK or RespCommand.TYPE or RespCommand.DEBUG or RespCommand.RENAME or RespCommand.RENAMENX)).OrderBy(static x => x); + + if (!deleteWasEffective) + { + // Check that all Vector Set commands on a partially deleted vector set give a reasonable error message OR succeed + // + // Success is possible if the delete failed early enough that we didn't actually begin a "real" delete + // + // Such cases leave some trash around, but it'll be cleaned up either at restart or the next time a Vector Set is really deleted + foreach (var cmd in vectorSetCommands) + { + RedisServerException exc = null; + switch (cmd) + { + case RespCommand.VADD: + try + { + var res = db.Execute("VADD", [key, "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + 
ClassicAssert.AreEqual(1, (int)res); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VCARD: + // TODO: Implement once VCARD is implemented + continue; + case RespCommand.VDIM: + try + { + var res = db.Execute("VDIM", [key]); + ClassicAssert.AreEqual(3, (int)res); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VEMB: + try + { + var res = (string[])db.Execute("VEMB", [key, new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(75, res.Length); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VGETATTR: + try + { + var res = db.Execute("VGETATTR", [key, "wololo"]); + ClassicAssert.IsTrue(res.IsNull); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VINFO: + try + { + var res = (RedisValue[])db.Execute("VINFO", [key]); + ClassicAssert.AreEqual(14, res.Length); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VISMEMBER: + // TODO: Implement once VISMEMBER is implemented + continue; + case RespCommand.VLINKS: + // TODO: Implement once VLINKS is implemented + continue; + case RespCommand.VRANDMEMBER: + // TODO: Implement once VRANDMEMBER is implemented + continue; + case RespCommand.VREM: + try + { + var res = db.Execute("VREM", [key, new byte[] { 0, 0, 0, 5 }]); + ClassicAssert.AreEqual(0, (int)res); + } + catch (RedisServerException e) + { + exc = e; + } + break; + case RespCommand.VSETATTR: + // TODO: Implement once VSETATTR is implemented + continue; + case RespCommand.VSIM: + try + { + var res = (byte[][])db.Execute("VSIM", [key, "ELE", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.IsTrue(res.Length > 0); + } + catch (RedisServerException e) + { + exc = e; + } + break; + default: + Assert.Fail($"No test for command: {cmd}"); + return; + } + + if (exc != null) + { + ClassicAssert.AreEqual("ERR Vector Set is in a partially deleted state - re-execute DEL to complete deletion", exc.Message, $"For 
command: {cmd}"); + } + } + + // Delete again, this time we'll succeed + var delRes = db.KeyDelete(key); + ClassicAssert.IsTrue(delRes); + } + + // Now accessing the key should give a null, no matter what happened + var res2 = (string)db.StringGet(key); + ClassicAssert.IsNull(res2); + } + } + + [Test] + public void InterruptedVectorSetDelete_AfterMark_Recovery() + => InterruptedVectorSetDeleteRecovery(ExceptionInjectionType.VectorSet_Interrupt_Delete_0); + + [Test] + public void InterruptedVectorSetDelete_AfterZeroingOut_Recovery() + => InterruptedVectorSetDeleteRecovery(ExceptionInjectionType.VectorSet_Interrupt_Delete_1); + + [Test] + public void InterruptedVectorSetDelete_AfterDelete_Recovery() + => InterruptedVectorSetDeleteRecovery(ExceptionInjectionType.VectorSet_Interrupt_Delete_2); + + private void InterruptedVectorSetDeleteRecovery(ExceptionInjectionType faultLocation) + { +#if !DEBUG + ClassicAssert.Ignore("Relies on ExceptionInjectionHelper, disable in non-DEBUG"); +#endif + + var key = $"{nameof(InterruptedVectorSetDeleteRecovery)}_{faultLocation}"; + + // Create a partially deleted Vector Set, then take a checkpoint and shutdown + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res = db.Execute("VADD", [key, "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res); + 
+ ExceptionInjectionHelper.EnableException(faultLocation); + try + { + _ = ClassicAssert.Throws(() => db.KeyDelete(key)); + } + finally + { + ExceptionInjectionHelper.DisableException(faultLocation); + } + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + } + + // Restart Garnet, which should block applying any pending Vector Set deletes + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + + // Validate that Vector Set index key is gone, even if no Vector Set command ran + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + // Now accessing the key should give a null OR a WRONGTYPE (that still has data) if delete didn't get particularly far + try + { + var res = (string)db.StringGet(key); + ClassicAssert.IsNull(res); + } + catch (RedisServerException exc) + { + ClassicAssert.IsTrue(exc.Message.StartsWith("WRONGTYPE ")); + + // If the value still exists, the Vector Set needs to still work + var res = (byte[][])db.Execute("VSIM", [key, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0"]); + 
ClassicAssert.AreEqual(1, res.Length); + } + } + } + + [Test] + public void RepeatedVectorSetDeletes() + { + var bytes1 = new byte[75]; + var bytes2 = new byte[75]; + var bytes3 = new byte[75]; + bytes1[0] = 1; + bytes2[0] = 75; + bytes3[0] = 128; + for (var i = 1; i < bytes1.Length; i++) + { + bytes1[i] = (byte)(bytes1[i - 1] + 1); + bytes2[i] = (byte)(bytes2[i - 1] + 1); + bytes3[i] = (byte)(bytes3[i - 1] + 1); + } + + for (var i = 0; i < 1_000; i++) + { + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig())) + { + var db = redis.GetDatabase(); + + var delRes = (int)db.Execute("DEL", ["foo"]); + + if (i != 0) + { + ClassicAssert.AreEqual(1, delRes); + } + else + { + ClassicAssert.AreEqual(0, delRes); + } + + var addRes1 = (int)db.Execute("VADD", ["foo", "XB8", bytes1, new byte[] { 0, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes1); + + var addRes2 = (int)db.Execute("VADD", ["foo", "XB8", bytes2, new byte[] { 0, 0, 0, 1 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes2); + + var readExc = ClassicAssert.Throws(() => db.Execute("GET", ["foo"])); + ClassicAssert.IsTrue(readExc.Message.Equals("WRONGTYPE Operation against a key holding the wrong kind of value."), $"In iteration: {i}"); + } + + // After an exception, get a clean connection + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig())) + { + var db = redis.GetDatabase(); + + var query = (byte[][])db.Execute("VSIM", ["foo", "XB8", bytes3]); + + if (query is null) + { + try + { + var res = db.Execute("FOO"); + Console.WriteLine($"After unexpected null, got: {res}"); + } + catch { } + } + else if (query.Length != 2) + { + Console.WriteLine($"Wrong length {query.Length} != 2 response was"); + for (var j = 0; j < query.Length; j++) + { + var txt = Encoding.UTF8.GetString(query[j]); + Console.WriteLine("---"); + Console.WriteLine(txt); + } + } + + ClassicAssert.AreEqual(2, query.Length, $"In iteration: {i}"); + } + } + } + + [Test] + public unsafe void 
VectorReadBatchVariants() + { + // Single key, 4 byte keys + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234 }; + var dataCopy = data.ToArray(); + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 64, 1, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(64, keyCopy.GetNamespaceInPayload()); + ClassicAssert.IsTrue(keyCopy.AsReadOnlySpan().SequenceEqual(MemoryMarshal.Cast(data.AsSpan().Slice(1, 1)))); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(1, iters); + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + + // Multiple keys, 4 byte keys + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234, 4, 5678, 4, 0123, 4, 9999, 4, 0000, 4, int.MaxValue, 4, int.MinValue }; + var dataCopy = data.ToArray(); + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 32, 7, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, 
inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(32, keyCopy.GetNamespaceInPayload()); + + var offset = i * 2 + 1; + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = MemoryMarshal.Cast(data.AsSpan().Slice(offset, 1)); + ClassicAssert.IsTrue(keyCopyData.SequenceEqual(expectedData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(7, iters); + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + + // Multiple keys, 4 byte keys, random order + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234, 4, 5678, 4, 0123, 4, 9999, 4, 0000, 4, int.MaxValue, 4, int.MinValue }; + var dataCopy = data.ToArray(); + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 16, 7, keyData); + + var rand = new Random(2025_10_06_00); + + for (var j = 0; j < 1_000; j++) + { + var i = rand.Next(batch.Count); + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(16, keyCopy.GetNamespaceInPayload()); + + var offset = i * 2 + 1; + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = MemoryMarshal.Cast(data.AsSpan().Slice(offset, 1)); + ClassicAssert.IsTrue(keyCopyData.SequenceEqual(expectedData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + + // Single key, variable length + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + 
var key0 = "hello"u8.ToArray(); + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .ToArray(); + var dataCopy = data.ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 8, 1, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(8, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(1, iters); + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + + // Multiple keys, variable length + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var key0 = "hello"u8.ToArray(); + var key1 = "fizz"u8.ToArray(); + var key2 = "the quick brown fox jumps over the lazy dog"u8.ToArray(); + var key3 = "CF29E323-E376-4BC4-AB63-FCFD371EB445"u8.ToArray(); + var key4 = Array.Empty(); + var key5 = new byte[] { 1 }; + var key6 = new byte[] { 2, 3 }; + var key7 = new byte[] { 4, 5, 6 }; + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .Concat( + 
MemoryMarshal.Cast([key1.Length]).ToArray() + ) + .Concat( + key1 + ) + .Concat( + MemoryMarshal.Cast([key2.Length]).ToArray() + ) + .Concat( + key2 + ) + .Concat( + MemoryMarshal.Cast([key3.Length]).ToArray() + ) + .Concat( + key3 + ) + .Concat( + MemoryMarshal.Cast([key4.Length]).ToArray() + ) + .Concat( + key4 + ) + .Concat( + MemoryMarshal.Cast([key5.Length]).ToArray() + ) + .Concat( + key5 + ) + .Concat( + MemoryMarshal.Cast([key6.Length]).ToArray() + ) + .Concat( + key6 + ) + .Concat( + MemoryMarshal.Cast([key7.Length]).ToArray() + ) + .Concat( + key7 + ) + .ToArray(); + var dataCopy = data.ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 4, 8, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + 1 => key1.Length, + 2 => key2.Length, + 3 => key3.Length, + 4 => key4.Length, + 5 => key5.Length, + 6 => key6.Length, + 7 => key7.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + 1 => key0.Length + 2 * sizeof(int), + 2 => key0.Length + key1.Length + 3 * sizeof(int), + 3 => key0.Length + key1.Length + key2.Length + 4 * sizeof(int), + 4 => key0.Length + key1.Length + key2.Length + key3.Length + 5 * sizeof(int), + 5 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + 6 * sizeof(int), + 6 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + 7 * sizeof(int), + 7 => key0.Length + key1.Length + 
key2.Length + key3.Length + key4.Length + key5.Length + key6.Length + 8 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(4, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(8, iters); + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + + // Multiple keys, variable length, random access + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var key0 = "hello"u8.ToArray(); + var key1 = "fizz"u8.ToArray(); + var key2 = "the quick brown fox jumps over the lazy dog"u8.ToArray(); + var key3 = "CF29E323-E376-4BC4-AB63-FCFD371EB445"u8.ToArray(); + var key4 = Array.Empty(); + var key5 = new byte[] { 1 }; + var key6 = new byte[] { 2, 3 }; + var key7 = new byte[] { 4, 5, 6 }; + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .Concat( + MemoryMarshal.Cast([key1.Length]).ToArray() + ) + .Concat( + key1 + ) + .Concat( + MemoryMarshal.Cast([key2.Length]).ToArray() + ) + .Concat( + key2 + ) + .Concat( + MemoryMarshal.Cast([key3.Length]).ToArray() + ) + .Concat( + key3 + ) + .Concat( + MemoryMarshal.Cast([key4.Length]).ToArray() + ) + .Concat( + key4 + ) + .Concat( + MemoryMarshal.Cast([key5.Length]).ToArray() + ) + .Concat( + key5 + ) + .Concat( + MemoryMarshal.Cast([key6.Length]).ToArray() + ) + .Concat( + key6 + ) + .Concat( + MemoryMarshal.Cast([key7.Length]).ToArray() + ) + .Concat( + key7 + ) + .ToArray(); + var dataCopy = data.ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + using var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 
4, 8, keyData); + + var rand = new Random(2025_10_06_01); + + for (var j = 0; j < 1_000; j++) + { + var i = rand.Next(batch.Count); + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + 1 => key1.Length, + 2 => key2.Length, + 3 => key3.Length, + 4 => key4.Length, + 5 => key5.Length, + 6 => key6.Length, + 7 => key7.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + 1 => key0.Length + 2 * sizeof(int), + 2 => key0.Length + key1.Length + 3 * sizeof(int), + 3 => key0.Length + key1.Length + key2.Length + 4 * sizeof(int), + 4 => key0.Length + key1.Length + key2.Length + key3.Length + 5 * sizeof(int), + 5 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + 6 * sizeof(int), + 6 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + 7 * sizeof(int), + 7 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + key6.Length + 8 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(4, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + } + ClassicAssert.IsTrue(dataCopy.SequenceEqual(data)); + } + } + + [Test] + public unsafe void MarkWithNamespace() + { + var data = new int[] { 4, 1234 }; + var dataCopy = data.ToArray(); + fixed (int* intPtr = data) + { + var bytePtr = (byte*)intPtr; + var span = 
VectorManager.MarkDiskANNKeyWithNamespace(8, (nint)(bytePtr + 4), 4);
+                ClassicAssert.AreEqual(8, span.GetNamespaceInPayload());
+                ClassicAssert.AreEqual(1234, *(int*)span.ToPointer());
+
+                VectorManager.UnmarkDiskANNKey(span);
+            }
+            ClassicAssert.IsTrue(dataCopy.SequenceEqual(data));
+        }
+
+        [Test]
+        public void RecreateIndexesOnRestore()
+        {
+            var addData1 = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray();
+            var addData2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray();
+            var queryData = addData1.ToArray();
+            queryData[0]++;
+
+            // VADD
+            {
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "fizz buzz"]);
+                    ClassicAssert.AreEqual(1, (int)res2);
+                }
+            }
+
+            // VSIM with vector
+            {
+                byte[][] expectedVSimResult;
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+                    expectedVSimResult = (byte[][])db.Execute("VSIM", ["foo", "XB8", queryData]);
+                    ClassicAssert.AreEqual(1, expectedVSimResult.Length);
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res2 = (byte[][])db.Execute("VSIM", ["foo", "XB8", queryData]);
+                    ClassicAssert.AreEqual(expectedVSimResult.Length, res2.Length);
+                    for (var i = 0; i < res2.Length; i++)
+                    {
+                        ClassicAssert.IsTrue(expectedVSimResult[i].AsSpan().SequenceEqual(res2[i]));
+                    }
+                }
+            }
+
+            // VSIM with element
+            {
+                byte[][] expectedVSimResult;
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+                    var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res2); // FIX: was asserting res1 (copy/paste); second VADD result was never checked
+
+                    expectedVSimResult = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }]);
+                    ClassicAssert.AreEqual(2, expectedVSimResult.Length);
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res2 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }]);
+                    ClassicAssert.AreEqual(expectedVSimResult.Length, res2.Length);
+                    for (var i = 0; i < res2.Length; i++)
+                    {
+                        ClassicAssert.IsTrue(expectedVSimResult[i].AsSpan().SequenceEqual(res2[i]));
+                    }
+                }
+            }
+
+            // VDIM
+            {
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res2 = (int)db.Execute("VDIM", ["foo"]);
+                    ClassicAssert.AreEqual(addData1.Length, res2);
+                }
+            }
+
+            // VEMB
+            {
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res2 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]);
+                    ClassicAssert.AreEqual(addData1.Length, res2.Length); // FIX: NUnit AreEqual takes (expected, actual) in that order
+
+                    for (var i = 0; i < res2.Length; i++)
+                    {
+                        ClassicAssert.AreEqual((float)addData1[i], float.Parse(res2[i]));
+                    }
+                }
+            }
+
+            // VREM
+            {
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var s = redis.GetServers()[0];
+                    var db = redis.GetDatabase(0);
+
+                    _ = db.KeyDelete("foo");
+
+                    var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res1);
+
+                    var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "NOQUANT", "EF", "16", "M", "32", "SETATTR", "hello world"]);
+                    ClassicAssert.AreEqual(1, (int)res2); // FIX: was asserting res1 (copy/paste); second VADD result was never checked
+
+#pragma warning disable CS0618 // Intentionally doing bad things
+                    s.Save(SaveType.ForegroundSave);
+#pragma warning restore CS0618
+
+                    var commit = server.Store.WaitForCommit();
+                    ClassicAssert.IsTrue(commit);
+                    server.Dispose(deleteDir: false);
+
+                    server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true);
+                    server.Start();
+                }
+
+                using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+                {
+                    var db = redis.GetDatabase(0);
+
+                    var res1 = (int)db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]);
+                    ClassicAssert.AreEqual(1, res1);
+
+                    var res2 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 1 }]);
+                    ClassicAssert.AreEqual(addData1.Length, res2.Length); // FIX: NUnit AreEqual takes (expected, actual) in that order
+
+                    for (var i = 0; i < res2.Length; i++)
+                    {
+                        ClassicAssert.AreEqual((float)addData2[i], float.Parse(res2[i]));
+                    }
+                }
+            }
+        }
+
+        // TODO: FLUSHDB needs to cleanup too...
+
+        [Test]
+        public void VINFO_NotFound()
+        {
+            // VINFO NotFound response depends on the RESP version used:
+            // - Resp3: Null
+            // - Resp2: Null array reply
+            using var redisResp3 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp3));
+            var resp3Result = redisResp3.GetDatabase().Execute("VINFO", ["nonexistent"]);
+            ClassicAssert.IsTrue(resp3Result.IsNull);
+            ClassicAssert.IsTrue(resp3Result.Resp3Type == ResultType.Null);
+
+            using var redisResp2 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp2));
+            var resp2Result = redisResp2.GetDatabase().Execute("VINFO", ["nonexistent"]);
+            ClassicAssert.IsTrue(resp2Result.IsNull);
+            ClassicAssert.IsTrue(resp2Result.Resp2Type == ResultType.Array);
+        }
+
+        [Test]
+        public void VINFO()
+        {
+            using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+            var db = redis.GetDatabase();
+
+            // TODO: Add tests for Q8 and BIN quantizers once supported in diskann-garnet
+            string[] quantizers = ["XPREQ8", "NOQUANT"];
+            int[] reduceValues = [0, 5];
+            int[] efValues = [0, 8];
+            int[] mValues = [0, 16];
+            int[] vectorDimensions = [9, 10];
+            var testCnt = 0;
+
+            foreach (var quantizer in quantizers)
+            {
+                var expectedQuantType = quantizer == "NOQUANT" ? "f32" : quantizer.ToLower();
+
+                foreach (var reduceValue in reduceValues)
+                {
+                    var reduceValueToUse = quantizer == "XPREQ8" ?
0 : reduceValue; + foreach (var ef in efValues) + { + foreach (var numLinks in mValues) + { + foreach (var vectorDim in vectorDimensions) + { + testCnt++; + string fooKey = $"foo:{testCnt}"; + + // Generate vector data based on quantizer type + // XPREQ8 requires XB8 format, NOQUANT uses VALUES format + object vectorData1; + object vectorData2; + if (quantizer == "XPREQ8") + { + // XB8 format: byte array + var bytes1 = new byte[vectorDim]; + var bytes2 = new byte[vectorDim]; + for (int i = 0; i < vectorDim; i++) + { + bytes1[i] = (byte)(i + 1); + bytes2[i] = (byte)(i + 2); + } + vectorData1 = bytes1; + vectorData2 = bytes2; + } + else + { + // VALUES format: list of float strings + var values1 = new List { "VALUES", vectorDim.ToString() }; + var values2 = new List { "VALUES", vectorDim.ToString() }; + for (int i = 1; i <= vectorDim; i++) + { + values1.Add($"{i}.0"); + values2.Add($"{i + 1}.0"); + } + vectorData1 = values1.ToArray(); + vectorData2 = values2.ToArray(); + } + + // Create a vector set with known parameters + var res = db.Execute("VADD", GenerateVADDOptions(fooKey, quantizer, reduceValueToUse, ef, numLinks, vectorData1, [0, 0, 0, 0])); + ClassicAssert.AreEqual(1, (int)res); + + string expectedEf = ef == 0 ? "200" : ef.ToString(); + string expectedNumLinks = numLinks == 0 ? 
"16" : numLinks.ToString(); + + // Get VINFO - should return an array of 14 elements (6 key-value pairs) + var vinfoRes = (RedisValue[])db.Execute("VINFO", [fooKey]); + ClassicAssert.AreEqual(14, vinfoRes.Length); + var values = BuildDictionaryFromResponse(vinfoRes); + ClassicAssert.AreEqual(values["quant-type"], expectedQuantType); + ClassicAssert.AreEqual(values["distance-metric"], "l2"); + ClassicAssert.AreEqual(values["input-vector-dimensions"], vectorDim.ToString()); + ClassicAssert.AreEqual(values["reduced-dimensions"], reduceValueToUse.ToString()); + ClassicAssert.AreEqual(values["build-exploration-factor"], expectedEf); + ClassicAssert.AreEqual(values["num-links"], expectedNumLinks); + ClassicAssert.AreEqual(values["size"], "1"); + + // Add another element and try again + res = db.Execute("VADD", GenerateVADDOptions(fooKey, quantizer, reduceValueToUse, ef, numLinks, vectorData2, [0, 0, 0, 1])); + ClassicAssert.AreEqual(1, (int)res); + + vinfoRes = (RedisValue[])db.Execute(command: "VINFO", [fooKey]); + ClassicAssert.AreEqual(14, vinfoRes.Length); + values = BuildDictionaryFromResponse(vinfoRes); + ClassicAssert.AreEqual(values["quant-type"], expectedQuantType); + ClassicAssert.AreEqual(values["distance-metric"], "l2"); + ClassicAssert.AreEqual(values["input-vector-dimensions"], vectorDim.ToString()); + ClassicAssert.AreEqual(values["reduced-dimensions"], reduceValueToUse.ToString()); + ClassicAssert.AreEqual(values["build-exploration-factor"], expectedEf); + ClassicAssert.AreEqual(values["num-links"], expectedNumLinks); + ClassicAssert.AreEqual(values["size"], "2"); + + // Delete vector set + db.KeyDelete(fooKey); + } + } + } + } + } + + static object[] GenerateVADDOptions(string key, string quantizer, int reduce, int buildExplorationFactor, int numLinks, object vectorData, byte[] elementId) + { + if (quantizer == "XPREQ8") + { + reduce = 0; + } + + List opts = [key]; + if (reduce > 0) + { + opts.Add("REDUCE"); + opts.Add(reduce.ToString()); + } + + // Add 
vector data based on quantizer type + if (quantizer == "XPREQ8") + { + // XB8 format for XPREQ8 + opts.Add("XB8"); + opts.Add(vectorData); + } + else + { + // VALUES format for NOQUANT + opts.AddRange((object[])vectorData); + } + + opts.Add(elementId); + opts.Add(quantizer); + if (buildExplorationFactor > 0) + { + opts.Add("EF"); + opts.Add(buildExplorationFactor.ToString()); + } + + if (numLinks > 0) + { + opts.Add("M"); + opts.Add(numLinks.ToString()); + } + + return opts.ToArray(); + } + + static Dictionary BuildDictionaryFromResponse(RedisValue[] response) + { + Dictionary values = new(); + for (var i = 0; i < response.Length; i += 2) + { + values[response[i]] = response[i + 1]; + } + + return values; + } + } + + [Test] + public void VGETATTR_NotFound() + { + var vectorSetKey = "foo"; + var elementId1 = new byte[] { 0, 0, 0, 0 }; + var nonExistentElementId = new byte[] { 9, 9, 9, 9 }; + + // Test not found case - non-existent vector set (RESP3) + using var redisResp3 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp3)); + var dbResp3 = redisResp3.GetDatabase(); + + var resp3Result1 = dbResp3.Execute("VGETATTR", [vectorSetKey, elementId1]); + ClassicAssert.IsTrue(resp3Result1.IsNull); + ClassicAssert.IsTrue(resp3Result1.Resp3Type == ResultType.Null); + + // Test not found case - non-existent vector set (RESP2) + using var redisResp2 = ConnectionMultiplexer.Connect(TestUtils.GetConfig(protocol: RedisProtocol.Resp2)); + var dbResp2 = redisResp2.GetDatabase(); + + var resp2Result1 = dbResp2.Execute("VGETATTR", [vectorSetKey, elementId1]); + ClassicAssert.IsTrue(resp2Result1.IsNull); + ClassicAssert.IsTrue(resp2Result1.Resp2Type == ResultType.BulkString); + + // Create a vector set with first element + var res1 = dbResp3.Execute("VADD", ["foo", "VALUES", "3", "1.0", "2.0", "3.0", elementId1, "NOQUANT"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Test not found case - non-existent element (RESP3) + var resp3Result2 = 
dbResp3.Execute("VGETATTR", [vectorSetKey, nonExistentElementId]); + ClassicAssert.IsTrue(resp3Result2.IsNull); + ClassicAssert.IsTrue(resp3Result2.Resp3Type == ResultType.Null); + + // Test not found case - non-existent element (RESP2) + var resp2Result2 = dbResp2.Execute("VGETATTR", [vectorSetKey, nonExistentElementId]); + ClassicAssert.IsTrue(resp2Result2.IsNull); + ClassicAssert.IsTrue(resp2Result2.Resp2Type == ResultType.BulkString); + } + + [Test] + public void VGETATTR() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var vectorSetKey = "foo"; + var elementId1 = new byte[] { 0, 0, 0, 0 }; + + // Create a vector set with first element (no attribute) + var res1 = db.Execute("VADD", ["foo", "VALUES", "3", "1.0", "2.0", "3.0", elementId1, "NOQUANT"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Test success case - element with no attribute + var res2 = (byte[])db.Execute("VGETATTR", [vectorSetKey, elementId1]); + ClassicAssert.AreEqual(0, res2.Length); + + // Test various attribute sizes + int[] attributeSizes = [64, 128, 256, 257, 512, 1024]; + + for (var i = 0; i < attributeSizes.Length; i++) + { + var attrSize = attributeSizes[i]; + var attrData = Enumerable.Repeat((byte)(i + '0'), attrSize).ToArray(); + var elementId = new byte[] { 0, 0, 0, (byte)(i + 1) }; + + // Add element with attribute of specific size + var addRes = db.Execute("VADD", ["foo", "VALUES", "3", "4.0", "5.0", "6.0", elementId, "NOQUANT", "SETATTR", attrData]); + ClassicAssert.AreEqual(1, (int)addRes); + + // Get and validate attribute + var getAttrRes = (byte[])db.Execute(command: "VGETATTR", [vectorSetKey, elementId]); + ClassicAssert.AreEqual(attrSize, getAttrRes.Length, $"Attribute size mismatch for size {attrSize}"); + ClassicAssert.IsTrue(attrData.SequenceEqual(getAttrRes), $"Attribute content mismatch for size {attrSize}"); + } + + // Test empty string attribute (equivalent to no attribute) + var emptyAttrElement = 
new byte[] { 0, 0, 0, 99 }; + var res3 = db.Execute("VADD", ["foo", "VALUES", "3", "7.0", "8.0", "9.0", emptyAttrElement, "NOQUANT", "SETATTR", ""]); + ClassicAssert.AreEqual(1, (int)res3); + + var res4 = (byte[])db.Execute("VGETATTR", [vectorSetKey, emptyAttrElement]); + ClassicAssert.AreEqual(0, res4.Length); + } + + [Test] + public void VREM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Populate + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 1, 0, 0, 0 }, "CAS", "NOQUANT", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Remove on non-vector set fails + // TODO: test against Redis, how do they respond (I expect 
WRONGTYPE, but needs verification) + //_ = db.StringSet("fizz", "buzz"); + //var exc1 = ClassicAssert.Throws(() => db.Execute("VREM", "fizz", new byte[] { 0, 0, 0, 0 })); + //ClassicAssert.AreEqual("", exc1.Message); + + // Remove exists + var res3 = db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(1, (int)res3); + + // Remove again fails + var res4 = db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(0, (int)res4); + + // Remove not present + var res5 = db.Execute("VREM", ["foo", new byte[] { 1, 2, 3, 4 }]); + ClassicAssert.AreEqual(0, (int)res5); + + // VSIM doesn't return removed element + var res6 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(1, res6.Length); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + + // VEMB doesn't return removed element + var res7 = (string[])db.Execute("VEMB", "foo", new byte[] { 0, 0, 0, 0 }); + ClassicAssert.IsEmpty(res7); + } + + [Test] + public void SimpleInternalIdReuse() + { + const string Key = "SimpleInternalIdReuse"; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Both these adds get internal id 1 due the interleaves remove + ExpectSuccess(db.Execute("VADD", [Key, "XB8", new byte[] { 36, 127, 75, 189, 65, 104, 32, 98, 182, 97, 52, 85, 16, 176, 0, 233, 236, 
90, 153, 239, 88, 107, 60, 191, 208, 50, 60, 241, 27, 21, 30, 233, 23, 9, 23, 6, 152, 179, 206, 168, 117, 201, 179, 226, 72, 114, 149, 45, 95, 5, 57, 230, 72, 50, 83, 184, 67, 140, 236, 15, 43, 46, 71, 161, 67, 75, 62, 7, 152, 249, 80, 57, 139, 241, 121 }, new byte[] { 143 }, "XPREQ8"])); + ExpectSuccess(db.Execute("VREM", [Key, new byte[] { 143 }])); + ExpectSuccess(db.Execute("VADD", [Key, "XB8", new byte[] { 176, 79, 173, 190, 74, 104, 121, 238, 209, 182, 91, 37, 70, 231, 58, 20, 151, 19, 62, 38, 143, 52, 79, 148, 24, 98, 242, 192, 96, 39, 76, 254, 82, 13, 217, 35, 79, 91, 9, 141, 41, 169, 86, 220, 64, 191, 98, 105, 38, 131, 145, 14, 198, 28, 190, 124, 0, 24, 165, 231, 117, 184, 142, 170, 106, 93, 210, 56, 14, 22, 197, 60, 10, 177, 253 }, new byte[] { 230, 221, 114, 84, 89, 0, 137, 154, 220, 149, 61 }, "XPREQ8"])); + var shouldBeEmpty = (string[])db.Execute("VEMB", [Key, new byte[] { 143 }]); + ClassicAssert.IsEmpty(shouldBeEmpty); + + static void ExpectSuccess(dynamic res) + { + ClassicAssert.AreEqual(1, (int)res); + } + } + + [Test] + public void StressInternalIdReuse() + { + const int Vectors = 1_000; + const int Deletes = 200; + const string Key = "StressInternalIdReuse"; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Build some repeatably random data for inserts + var vectors = new List<(byte[] Id, byte[] Data)>(); + var toDeleteVectors = new HashSet(ByteArrayComparer.Instance); + var pendingAdd = new List<(byte[] Id, byte[] Data)>(); + var pendingRemove = new List(); + var alreadyRemoved = new List(); + var r = new Random(2026_01_21); + { + for (var i = 0; i < Vectors; i++) + { + var id = new byte[r.Next(16) + 1]; + var data = new byte[75]; + r.NextBytes(data); + r.NextBytes(id); + + if (vectors.Any(t => t.Id.SequenceEqual(id))) + { + i--; + continue; + } + + vectors.Add((id, data)); + } + + while (toDeleteVectors.Count < Deletes) + { + _ = 
toDeleteVectors.Add(vectors[r.Next(vectors.Count)].Id); + } + + pendingAdd.AddRange(vectors); + pendingRemove.AddRange(toDeleteVectors); + } + + // Randomly interleave adds and removes + while (pendingAdd.Count > 0 || pendingRemove.Count > 0) + { + if (r.Next(2) == 0 && pendingAdd.Count > 0) + { + var addIx = r.Next(pendingAdd.Count); + var (id, data) = pendingAdd[addIx]; + + var addRes = (int)db.Execute("VADD", [Key, "XB8", data, id, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + + pendingAdd.RemoveAt(addIx); + } + else if (pendingRemove.Count > 0) + { + var removeIx = r.Next(pendingRemove.Count); + var id = pendingRemove[removeIx]; + + var shouldSucceed = !pendingAdd.Any(t => t.Id.SequenceEqual(id)); + + var remRes = (int)db.Execute("VREM", [Key, id]); + + if (shouldSucceed) + { + ClassicAssert.AreEqual(1, remRes); + + var embRes = (string[])db.Execute("VEMB", [Key, id]); + ClassicAssert.IsEmpty(embRes); + + pendingRemove.RemoveAt(removeIx); + alreadyRemoved.Add(id); + } + else + { + ClassicAssert.AreEqual(0, remRes); + } + } + + // Check that prior deletes remain deleted + foreach (var id in alreadyRemoved) + { + var embRes = (string[])db.Execute("VEMB", [Key, id]); + ClassicAssert.IsEmpty(embRes); + } + } + + // Validate final state + foreach (var (id, data) in vectors) + { + var shouldExists = !toDeleteVectors.Contains(id); + + var embRes = (string[])db.Execute("VEMB", [Key, id]); + + if (shouldExists) + { + ClassicAssert.AreEqual(data.Length, embRes.Length); + for (var i = 0; i < data.Length; i++) + { + ClassicAssert.AreEqual(data[i], byte.Parse(embRes[i])); + } + } + else + { + ClassicAssert.IsEmpty(embRes); + } + } + } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "opts")] + private static extern ref GarnetServerOptions GetOpts(GarnetServer server); + } +} \ No newline at end of file diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs index 11570261f53..ba9dcab783f 100644 --- a/test/Garnet.test/TestUtils.cs +++ 
b/test/Garnet.test/TestUtils.cs @@ -125,6 +125,9 @@ internal static bool IsRunningAzureTests } } + internal static bool IsRunningAsGitHubAction + => "true".Equals(Environment.GetEnvironmentVariable("GITHUB_ACTIONS"), StringComparison.OrdinalIgnoreCase); + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static void AssertEqualUpToExpectedLength(string expectedResponse, byte[] response) { @@ -273,8 +276,9 @@ public static GarnetServer CreateGarnetServer( int expiredKeyDeletionScanFrequencySecs = -1, bool useReviv = false, bool useInChainRevivOnly = false, - bool useLogNullDevice = false - ) + bool useLogNullDevice = false, + bool enableVectorSetPreview = true + ) { if (useAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -361,6 +365,7 @@ public static GarnetServer CreateGarnetServer( UnixSocketPermission = unixSocketPermission, SlowLogThreshold = slowLogThreshold, ExpiredKeyDeletionScanFrequencySecs = expiredKeyDeletionScanFrequencySecs, + EnableVectorSetPreview = enableVectorSetPreview, }; if (!string.IsNullOrEmpty(memorySize)) @@ -662,7 +667,8 @@ public static GarnetServerOptions GetGarnetServerOptions( int replicaSyncTimeout = 60, int expiredObjectCollectionFrequencySecs = 0, ClusterPreferredEndpointType clusterPreferredEndpointType = ClusterPreferredEndpointType.Ip, - string clusterAnnounceHostname = null) + string clusterAnnounceHostname = null, + bool enableVectorSetPreview = true) { if (useAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -786,6 +792,7 @@ public static GarnetServerOptions GetGarnetServerOptions( CheckpointThrottleFlushDelayMs = checkpointThrottleFlushDelayMs, ClusterReplicaResumeWithData = clusterReplicaResumeWithData, ReplicaSyncTimeout = replicaSyncTimeout <= 0 ? 
Timeout.InfiniteTimeSpan : TimeSpan.FromSeconds(replicaSyncTimeout), + EnableVectorSetPreview = enableVectorSetPreview, ExpiredObjectCollectionFrequencySecs = expiredObjectCollectionFrequencySecs, }; diff --git a/website/docs/dev/vector-sets.md b/website/docs/dev/vector-sets.md new file mode 100644 index 00000000000..79fa84f09e1 --- /dev/null +++ b/website/docs/dev/vector-sets.md @@ -0,0 +1,421 @@ +--- +id: vector-sets +sidebar_label: Vector Sets +title: Vector Sets +--- + +# Overview + +Garnet has partial support for Vector Sets, implemented on top of the [DiskANN project](https://www.nuget.org/packages/diskann-garnet/). + +This data type is very strange when compared to others Garnet supports. + +> [!IMPORTANT] +> The DiskANN link needs to be updated once OSS'd. + +# Design + +Vector Sets are a combination of one "index" key, which stores metadata and a pointer to the DiskANN data structure, and many "element" keys, which store vectors/quantized vectors/attributes/etc. All Vector Set keys are kept in the main store, but only the index key is visible - this is accomplished by putting all element keys in different namespaces. + +## Global Metadata + +In order to track allocated Vector Sets (and their respective hash slots), in progress cleanups, in progress migrations - we keep a single `ContextMetadata` struct under the empty key in namespace 0. + +This is loaded and cached on startup, and updated (both in memory and in Tsavorite) whenever a Vector Set is created or deleted. Simple locking (on the `VectorManager` instance) is used to serialize these updates as they should be rare. + +> [!IMPORTANT] +> Today `ContextMetadata` can track only 64 Vector Sets in some state of creation or cleanup. +> +> The practical limit is actually 31, because context must be < 256, divisible by 8, and not 0 (which is reserved). +> +> This limitation will be lifted eventually, perhaps after Store V2 lands. 
+ +## Indexes + +The index key (represented by the `Index` struct) contains the following data: + - `ulong Context` - used to derive namespaces, detailed below + - `ulong IndexPtr` - a pointer to the DiskANN data structure, note this may be _dangling_ after [recovery](#recovery) or [replication](#replication) + - `uint Dimensions` - the expected dimension of vectors in commands targeting the Vector Set, this is inferred based on the `VADD` that creates the Vector Set + - `uint ReduceDims` - if a Vector Set was created with the `REDUCE` option that value, otherwise zero + * > [!NOTE] + > Today this is ignored except for validation purposes, eventually DiskANN will use it. + - `uint NumLinks` - the `M` used to create the Vector Set, or the default value of 16 if not specified + - `uint BuildExplorationFactor` - the `EF` used to create the Vector Set, or the default value of 200 if not specified + - `VectorQuantType QuantType` - the quantizer specified at creation time, or the default value of `Q8` if not specified + * > [!NOTE] + > We have an extension here, `XPREQ8` which is not from Redis. + > This is a quantizer for data sets which have already been 8-bit quantized or are otherwise naturally small byte vectors, and is extremely optimized for reducing reads during queries. + > It forbids the `REDUCE` option and requires 4-byte element ids. + * > [!IMPORTANT] + > Today only `XPREQ8` is actually implemented, eventually DiskANN will provide reasonable versions of all the Redis builtin quantizers. + - `Guid ProcessInstanceId` - an identifier which is used to distinguish the current process from previous instances, this is used after [recovery](#recovery) or [replication](#replication) to detect if `IndexPtr` is dangling + +The index key is in the main store alongside other binary values like strings, hyperloglogs, and so on. It is distinguished for `WRONGTYPE` purposes with the `VectorSet` bit on `RecordInfo`. 
+ +> [!IMPORTANT] +> `RecordInfo.VectorSet` is checked in a few places to correctly produce `WRONGTYPE` responses, but we need more coverage for all commands. Probably something akin to how ACLs required per-command tests. + +> [!IMPORTANT] +> A generalization of the `VectorSet`-bit should be used for all data types, this can happen once we have Store V2. + +## Elements + +While the Vector Set API only concerns itself with top-level index keys, ids, vectors, and attributes; DiskANN has different storage needs. To abstract around these needs a bit, we reserve a number of different "namespaces" for each Vector Set. + +These namespaces are simple numbers, starting at the `Context` value stored in the `Index` struct - we currently reserve 8 namespaces per Vector Set. What goes in which namespace is mostly hidden from Garnet, DiskANN indicates namespace (and index) to use with a modified `Context` passed to relevant callbacks. +> There are two cases where we "know" the namespace involved: attributes (+3) and full vectors (+0) which are used to implement the `WITHATTR` option and the `VEMB` command respectively. These exceptions _may_ go away in the future, but don't have to. + +Using namespaces prevents other commands from accessing keys which store element data. + +To illustrate, this means that: +``` +VADD vector-set-key VALUES 1 123 element-key +SET element-key string-value +``` +Can work as expected. Without namespacing, the `SET` would overwrite (or otherwise mangle) the element data of the Vector Set. + +# Operations + +We implement the [Redis Vector Set API](https://redis.io/docs/latest/commands/?group=vector_set): + +Implemented commands: + - [x] VADD + - [ ] VCARD + - [x] VDIM + - [x] VEMB + - [ ] VGETATTR + - [ ] VINFO + - [ ] VISMEMBER + - [ ] VLINKS + - [ ] VRANDMEMBER + - [x] VREM + - [ ] VSETATTR + - [x] VSIM + +## Creation (via `VADD`) + +[`VADD`](https://redis.io/docs/latest/commands/vadd/) implicitly creates a Vector Set when run on an empty key. 
+ +DiskANN index creation must be serialized, so this requires holding an exclusive lock ([more details on locking](#locking)) that covers just that key. During the `create_index` call to DiskANN the read/write/delete callbacks provided may be invoked - accordingly creation is re-entrant and we cannot call `create_index` directly from any Tsavorite session functions. + +## Insertion (via `VADD`) + +Once a Vector Set exists, insertions (which also use `VADD`) can proceed in parallel. + +Every insertion begins with a Tsavorite read, to get the [`Index`](#indexes) metadata (for validation) and the pointer to DiskANN's index. As a consequence, most `VADD` operations despite _semantically_ being writes are, from Tsavorite's perspective, reads. This has implications for replication, [which is discussed below](#replication). + +To prevent the index from being deleted mid-insertion, we hold a shared lock while calling DiskANN's `insert` function. These locks are sharded for performance purposes, [which is discussed below](#locking). + +## Removal (via `VREM`) + +Removal works much the same as insertion, using shared locks so it can proceed in parallel. The only meaningful difference is calling DiskANN's `remove` instead of `insert`. + +> [!NOTE] +> Removing all elements from a Vector Set is not the same as deleting it. While it is not possible to create an empty Vector Set with a single command, it is legal for one to exist after a `VREM`. + +## Search (via `VSIM`) + +Searching is a pure read operation, and so holds shared locks and proceeds in parallel like insertions and removals. + +Great care is taken to avoid copying during `VSIM`. In particular, values and element ids are passed directly from the receive buffer for all encodings except `VALUES`. Callbacks from DiskANN to Garnet likewise take great care to avoid copying, and are [detailed below](#diskann-integration). 
+ +## Element Data (via `VEMB` and `VGETATTR`) + +These operations are handled purely on the Garnet side by first reading out the [`Index`](#indexes) structure, and then using the context value to look for data in the appropriate namespaces. + +> [!NOTE] +> Strictly speaking we don't need the DiskANN index to access this data, but the current implementation does make sure the index is valid. + +## Metadata (via `VDIM` and `VINFO`) + +Metadata is handled purely on the Garnet side by reading out the [`Index`](#indexes) structure. + +> [!NOTE] +> `VINFO` directly exposes Redis implementation details in addition to "normal" data. +> Because our implementation is different, we intentionally will not expose all the same information. +> To be concrete `max-level`, `vset-uid`, and `hnsw-max-node-uid` are not returned. + +> [!IMPORTANT] +> We _may_ return more details of our own implementation. What those are needs to be documented, and why, +> when we implement `VINFO`. + +## Deletion (via `DEL` and `UNLINK`) + +`DEL` (and its equivalent `UNLINK`) is the only non-Vector Set command to be routinely expected on a Vector Set key. It is complicated by not knowing we're operating on a Vector Set until we get rather far into deletion. + +We cope with this by _cancelling_ the Tsavorite delete operation once we have a `RecordInfo` with the `VectorSet`-bit set and a value which is not all zeros, detecting that cancellation in `MainStoreOps`, and shunting the delete attempt to `VectorManager`. 
+ +`VectorManager` performs the delete in seven steps: + - Acquire exclusive locks covering the Vector Set ([more locking details](#locking)) + - Add the key to an `InProgressDeletes` key (namespace 0, key=0x01) + - If the index was initialized in the current process ([see recovery for more details](#recovery)), call DiskANN's `drop_index` function + - Perform a write to zero out the index key in Tsavorite + - Reattempt the Tsavorite delete + - Cleanup ancillary metadata and schedule element data for cleanup ([more details below](#cleanup)) + - Remove the key from the `InProgressDeletes` key + + The `InProgressDeletes` key is necessary to recover from interrupted deletes. At process start, `VectorManager` consults the `InProgressDeletes` key and completes any deletes that got as far as zero-ing out the index key. + + > [!IMPORTANT] Interrupted deletes are expected only during process exits, but if they occur without the process exiting they will leave the Vector Set in a partially deleted state. We detect that and return a new `GarnetStatus.BADSTATE` which returns an explanatory error. +> +> We _could_ resume the delete on `GarnetStatus.BADSTATE`, but like `GarnetStatus.WRONGTYPE` that needs to be done for _all_ commands not just Vector Set commands. This work is likewise left for the future. + +## FlushDB + +`FLUSHDB` (and its relative `FLUSHALL`) require special handling. + +> [!IMPORTANT] +> This is not currently implemented. + +# Locking + +Vector Set workloads require extreme parallelism, and so intricate locking protocols are required for both performance and correctness. + +Concretely, there are 3 sorts of locks involved: + - Tsavorite hashbucket locks + - A `ReadOptimizedLock` instance + - `VectorManager` lock around `ContextMetadata` + +## Tsavorite Locks + +Whenever we read or write a key/value pair in the main store, we acquire locks in Tsavorite. 
Importantly, we cannot start a new Tsavorite operation while still holding these locks - we must copy the index out before each operation so Garnet can use the read/write/delete callbacks. + +> [!NOTE] +> Based on profiling, Tsavorite shared locks are a significant source of contention. Even though reads will not block each other we still pay a cache coherency tax. Accordingly, reducing the number of Tsavorite operations (even reads) can lead to significant performance gains. + +> [!IMPORTANT] +> Some effort was spent early attempting to elide the initial index read in common cases. This did not pay dividends on smaller clusters, but is worth exploring again on large SKUs. + +## `ReadOptimizedLock` + +As noted above, to prevent `DEL` from clobbering in use Vector Sets and concurrent `VADD`s from calling `create_index` multiple times we have to hold locks based on the Vector Set key. As every Vector Set operation starts by taking these locks, we have sharded them into separate locks. To derive many related keys from a single key, we mangle the low bits of a key's hash value - this is implemented in the new (but not bound to Vector Sets) type `ReadOptimizedLock`. + +For operations which remain reads, we only acquire a single shared lock (based on the current thread) to prevent destructive operations. + +For operations which are always writes (like `DEL`) we acquire all sharded locks in exclusive mode. + +For operations which might be either (like `VADD`) we first acquire the usual single sharded lock (in shared mode), then promote to an exclusive lock if needed. + +## `VectorManager` Lock Around `ContextMetadata` + +Whenever we need to allocate a new context or mark an old one for cleanup, we need to modify the cached `ContextMetadata` and write the new value to Tsavorite. To simplify this, we take a plain `lock` around `VectorManager` while preparing a new `ContextMetadata`. 
+ +The `RMW` into Tsavorite still proceeds in parallel, outside of the lock, but a version counter in `ContextMetadata` allows us to keep only the latest version in the store. + +> [!NOTE] +> Rapid creation or deletion of Vector Sets is expected to perform poorly due to this lock. +> This isn't a case we're very interested in right now, but if that changes this will need to be reworked. + +# Replication + +Replicating Vector Sets is tricky because of the unusual "writes are actually reads"-semantics of most operations. + +## On Primaries + +As noted above, inserts (via `VADD`) and deletes (via `VREM`) are reads from Tsavorite's perspective. As a consequence, normal replication (which is triggered via `MainSessionFunctions.WriteLog(Delete|RMW|Upsert)`) does not happen on those operations. + +To fix that, synthetic writes against related keys are made after an insert or remove. These writes are against the same Vector Set key, but in namespace 0. See `VectorManager.ReplicateVectorSetAdd` and `VectorManager.ReplicateVectorSetRemove` for details. + +> [!IMPORTANT] +> There is a failure case here where we crash between the insert operation completing and the replication operation completing. +> +> This appears to simply extend a window that already existed between when a Tsavorite operation completed and an entry was written to the AOF. +> This needs to be confirmed - if it is not the case, handling this failure needs to be figured out. + +> [!IMPORTANT] +> This code assumes a Vector Set under the empty string is illegal. That does not seem to be true with Redis - so we will need to move these keys elsewhere. For now, we just forbid the empty key for VADDs. + +> [!NOTE] +> These synthetic writes might appear to double write volume, but that is not the case. 
Actual inserts and deletes have extreme write amplification (that is, each causes DiskANN to perform many writes against the Main Store), whereas the synthetic writes cause a single (no-op) modification to the Main Store plus an AOF entry. + +> [!NOTE] +> The replication key is the same for all operations against the same Vector Set, this could be sharded which may improve performance. + +## On Replicas + +The synthetic writes on primary are intercepted on replicas and redirected to `VectorManager.HandleVectorSetAddReplication` and `VectorManager.HandleVectorSetRemoveReplication`, rather than being handled directly by `AOFProcessor`. + +For performance reasons, replicated `VADD`s are applied across many threads instead of serially. This introduces a new source of non-determinism, since `VADD`s will occur in a different order than on the primary, but this is acceptable as Vector Sets are inherently non-deterministic. While not _exactly_ the same, Redis also permits a degree of non-determinism with its `CAS` option for `VADD`, so we're not diverging an incredible amount here. + +While a `VADD` can proceed in parallel with respect to other `VADD`s, that is not the case for any other commands. Accordingly, `AofProcessor` now calls `VectorManager.WaitForVectorOperationsToComplete()` before applying any other updates to maintain coherency. + +## Migration + +Migrating a Vector Set between two primaries (either as part of a `MIGRATE ... KEYS` or migration of a whole hash slot) is complicated by storing element data in namespaces. + +Namespaces (intentionally) do not participate in hash slots or clustering, and are a node specific concept. This means that migration must also update the namespaces of elements as they are migrated. + +At a high level, migration between the originating primary and a destination primary behaves as follows: + 1. Once target slots transition to `MIGRATING`... 
+ * An addition to `ClusterSession.SingleKeySlotVerify` causes all WRITE Vector Set commands to pause once a slot is `MIGRATING` or `IMPORTING` - this is necessary because we cannot block based on the key as Vector Sets are composed of many key-value pairs across several namespaces + 2. `VectorManager` on the originating primary enumerates all _namespaces_ and Vector Sets that are covered by those slots + 3. The originating primary contacts the destination primary and reserves enough new Vector Set contexts to handle those found in step 2 + * These Vector Sets are "in use" but also in a migrating state in `ContextMetadata` + 4. During the scan of main store in `MigrateOperation` any keys found with namespaces found in step 2 are migrated, but their namespace is updated prior to transmission to the appropriate new namespaces reserved in step 3 + * Unlike with normal keys, we do not _delete_ the keys in namespaces as we enumerate them + * Also unlike with normal keys, we synthesize a write on the _destination_ (using a special arg and `VADD`) so replicas of the destination also get these writes + 5. Once all namespace keys are migrated, we migrate the Vector Set index keys, but mutate their values to have the appropriate context reserved in step 3 + * As in 4, we synthesize a write on the _destination_ to tell any replicas to also create the index key + 6. When the target slots transition back to `STABLE`, we do a delete of the Vector Set index keys, drop the DiskANN indexes, and schedule the original contexts for cleanup on the originating primary + * Unlike in 4 & 5, we do no synthetic writes here. The normal replication of `DEL` will clean up replicas of the originating primary. + + `KEYS` migrations differ only in the slot discovery being omitted. We still have to determine the migrating namespaces, reserve new ones on the destination primary, and schedule cleanup only once migration is completed. 
This does mean that, if any of the keys being migrated is a Vector Set, `MIGRATE ... KEYS` now causes a scan of the main store. + +> [!NOTE] +> This approach prevents the Vector Set from being visible when it is partially migrated, which has the desirable property of not returning weird results during a migration. + +> [!NOTE] +> While we explicitly reserve contexts on primaries, they are implicit on replicas. This is because a replica should always come up with the same determination of reserved contexts. +> +> To keep that determinism, the synthetic `VADD`s introduced by migration are not executed in parallel. + +# Cleanup + +Deleting a Vector Set only drops the DiskANN index and removes the top-level keys (ie. the visible key and related hidden keys for replication). This leaves all element, attribute, neighbor lists, etc. still in the Main Store. + +To clean up the remaining data we record the deleted index context value in `ContextMetadata` and then schedule a full sweep of the Main Store looking for any keys under namespaces related to that context. When we find those keys we delete them, see `VectorManager.RunCleanupTaskAsync()` and `VectorManager.PostDropCleanupFunctions` for details. + +> [!NOTE] +> There isn't really an elegant way to avoid scanning the whole keyspace which can take a while to free everything up. +> +> If we wanted to explore better options, we'd need to build something that can drop whole namespaces at once in Tsavorite. + +> [!IMPORTANT] +> Today because we only have ~30 available Vector Set contexts, it is quite likely that deleting a Vector Set and then immediately creating a new one will fail if you're near the limit. +> +> This will be fixed once we have arbitrarily long namespaces in Store V2, and have updated `ContextMetadata` to track those. + +# Recovery + +Vector Sets represent a unique kind of recovery because most operations are mediated through DiskANN, for which we only ever have a pointer to a data structure. 
This means that recovery needs to both deal with Vector Sets metadata AND the recreation of the DiskANN side of things. + +## Vector Set Metadata + +During startup we read any old `ContextMetadata` out of the Main Store, cache it, and resume any in progress cleanups. + +## Vector Sets + +While reading out [`Index`](#indexes) before performing a DiskANN function call, we check the stored `ProcessInstanceId` against the (randomly generated) one in our `VectorManager` instance. If they do not match, we know that the DiskANN `IndexPtr` is dangling and we need to recreate the index. + +To recreate, we acquire exclusive locks (in the same way we would for `VADD` or `DEL`) and invoke `create_index` again. From DiskANN's perspective, there's no difference between creating a new empty index and recreating an old one which has existing data. + +This means we recreate indexes lazily after recovery. Consequently the _first_ command (regardless of if it's a `VADD`, a `VSIM`, or whatever) against an index after recovery will be slower since it needs to do extra work, and will block other commands since it needs exclusive locking. + +> [!NOTE] +> Today `ProcessInstanceId` is a `GUID`, which means we're paying for a 16-byte comparison on every command. +> +> This comparison is highly predictable, but we could try and remove the comparison (with caching, as mentioned for `Index` above). +> We could also make it cheaper by using a random `ulong` instead, but would need to do some math to convince ourselves collisions aren't possible in realistic scenarios. + +# DiskANN Integration + +Almost all of how Vector Sets actually function is handled by DiskANN. Garnet just embeds it, translates between RESP commands and DiskANN functions, and manages storage. + +In order for DiskANN to access and store data in Garnet, we provide a set of callbacks. All callbacks are `[UnmanagedCallersOnly]` and converted to function pointers before they are passed to Garnet. 
+ +All callbacks take a `ulong context` parameter which identifies the Vector Set involved (the high 61-bits of the context) and the associated namespace (the low 3-bits of the context). On the Garnet side, the whole `context` is effectively a namespace, but from DiskANN's perspective the top 61-bits are an opaque identifier. + +> [!IMPORTANT] +> As noted elsewhere, we only have a byte's worth of namespaces today - so although `context` could handle quintillions of Vector Sets, today we're limited to just 31. +> +> This restriction will go away with Store V2, but we expect "lower" Vector Sets to outperform "higher" ones due to the need for intermediate data copies with longer namespaces. + +## Read Callback + +The most complicated of our callbacks, the signature is: +```csharp +void ReadCallbackUnmanaged(ulong context, uint numKeys, nint keysData, nuint keysLength, nint dataCallback, nint dataCallbackContext) +``` + +`context` identifies which Vector Set is being operated on AND the associated namespace, `numKeys` tells us how many keys have been encoded into `keysData`, `keysData` and `keysLength` define a `Span` of length prefixed keys, `dataCallback` is a `delegate* unmanaged[Cdecl, SuppressGCTransition]` used to push found keys back into DiskANN, and `dataCallbackContext` is passed back unaltered to `dataCallback`. + +In the `Span` defined by `keysData` and `keysLength` the keys are length prefixed with a 4-byte little endian `int`. This is necessary to support variable length element ids, but also gives us some scratch space to store a namespace when we convert these to `SpanByte`s. This mangling is done as part of the `IReadArgBatch` implementation we use to read keys from Tsavorite. + +> [!NOTE] +> Once variable sized namespaces are supported we'll have to handle the case where the namespace can't fit in 4 bytes. 
However, we expect that to be rare (4-bytes would give us ~53,000,000 Vector Sets) and the performance benefits of _not_ copying during querying are very large. + +As we find keys, we invoke `dataCallback(index, dataCallbackContext, keyPointer, keyLength)`. If a key is not found, its index is simply skipped. The benefit of this is that we don't copy data out of the Tsavorite log as part of reads, DiskANN is able to do distance calculations and traversal over in-place data. + +> [!NOTE] +> Each invocation of `dataCallback` is a managed -> native transition, which can add up very quickly. We've reduced that as much as possible with function pointers and `SuppressGCTransition`, but that comes with risks. +> +> In particular if DiskANN raises an error or blocks in the `dataCallback` expect very bad things to happen, up to the runtime corrupting itself. Great care must be taken to keep the DiskANN side of this call cheap and reliable. + +> [!IMPORTANT] +> Tsavorite has been extended with a `ContextReadWithPrefetch` method to accommodate this pattern, which also employs prefetching when we have batches of keys to lookup. +> +> Additionally, some experimentation to figure out good prefetch sizes (and if [AMAC](https://dl.acm.org/doi/10.14778/2856318.2856321) is useful) based on hardware is merited. Right now we've chosen 12 based on testing with some 96-core Intel machines, but that is unlikely to be correct in all interesting circumstances. + +## Write Callback + +A simpler callback, the signature is: +```csharp +byte WriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength) +``` + +`context` identifies which Vector Set is being operated on AND the associated namespace, `keyData` and `keyLength` represent a `Span` of the key to write, and `writeData` and `writeLength` represent a `Span` of the value to write. + +DiskANN guarantees an extra 4-bytes BEFORE `keyData` that we can safely modify. 
This is used to avoid copying the key value when we add a namespace to the `SpanByte` before invoking Tsavorite's `Upsert`.
+
+This callback returns 1 if successful, and 0 otherwise.
+
+## Delete Callback
+
+Another simple callback, the signature is:
+```csharp
+byte DeleteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, and `keyData` and `keyLength` represent a `Span` of the key to delete.
+
+As with the write callback, DiskANN guarantees an extra 4-bytes BEFORE `keyData` that we use to store a namespace, and thus avoid copying the key value before invoking Tsavorite's `Delete`.
+
+This callback returns 1 if the key was found and removed, and 0 otherwise.
+
+## Read Modify Write Callback
+
+A more complicated callback, the signature is:
+```csharp
+byte ReadModifyWriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, and `keyData` and `keyLength` represent a `Span` of the key to create, read, or update.
+
+`writeLength` is the desired number of bytes; this is only used if we are creating a new key-value pair.
+
+As with the write and delete callbacks, DiskANN guarantees an extra 4-bytes BEFORE `keyData` that we use to store a namespace, and thus avoid copying the key value before invoking Tsavorite's `RMW`.
+
+After we allocate a new key-value pair or find an existing one, `dataCallback(nint dataCallbackContext, nint dataPointer, nuint dataLength)` is called. Changes made to data in this callback are persisted. This needs to be _fast_ to prevent gumming up Tsavorite, as we are under epoch protection.
+
+Newly allocated values are guaranteed to be all zeros.
+
+The callback returns 1 if the key-value pair was found or created, and 0 if some error occurred. 
+ +## DiskANN Functions + +Garnet calls into the following DiskANN functions: + + - [x] `nint create_index(ulong context, uint dimensions, uint reduceDims, VectorQuantType quantType, uint buildExplorationFactor, uint numLinks, nint readCallback, nint writeCallback, nint deleteCallback, nint readModifyWriteCallback)` + - [x] `void drop_index(ulong context, nint index)` + - [x] `byte insert(ulong context, nint index, nint id_data, nuint id_len, VectorValueType vector_value_type, nint vector_data, nuint vector_len, nint attribute_data, nuint attribute_len)` + - [x] `byte remove(ulong context, nint index, nint id_data, nuint id_len)` + - [ ] `byte set_attribute(ulong context, nint index, nint id_data, nuint id_len, nint attribute_data, nuint attribute_len)` + - [x] `int search_vector(ulong context, nint index, VectorValueType vector_value_type, nint vector_data, nuint vector_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` + - [x] `int search_element(ulong context, nint index, nint id_data, nuint id_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` + - [ ] `int continue_search(ulong context, nint index, nint continuation, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint new_continuation)` + - [ ] `ulong card(ulong context, nint index)` + - [x] `byte check_internal_id_valid(ulong context, nint index, nint internal_id, nuint internal_id_len)` + + Some non-obvious subtleties: + - The number of results _requested_ from `search_vector` and `search_element` is indicated by `output_distances_len` + - `output_distances_len` is the number of _floats_ in `output_distances`, not bytes + - When 
inserting, if `vector_value_type == FP32` then `vector_len` is the number of _floats_ in `vector_data`, otherwise it is the number of bytes + - `byte` returning functions are effectively returning booleans, `0 == false` and `1 == true` + - `index` is always a pointer created by DiskANN and returned from `create_index` + - `context` is always the `Context` value created by Garnet and stored in [`Index`](#indexes) for a Vector Set, this implies it is always a non-0 multiple of 8 + - `search_vector`, `search_element`, and `continue_search` all return the number of ids written into `output_ids`, and if there are more values to return they set the `nint` _pointed to by_ `continuation` or `new_continuation` + +> [!IMPORTANT] +> These p/invoke definitions are all a little rough and should be cleaned up. +> +> They were defined very loosely to ease getting the .NET <-> Rust interface working quickly. \ No newline at end of file