diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorBase.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorBase.cs index f125768992c..bd15b587a15 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorBase.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/AllocatorBase.cs @@ -241,18 +241,27 @@ protected abstract void ReadAsync(ulong alignedSourceAddress, IntPtr d /// Flush checkpoint Delta to the Device [MethodImpl(MethodImplOptions.NoInlining)] internal virtual void AsyncFlushDeltaToDevice(CircularDiskWriteBuffer flushBuffers, long startAddress, long endAddress, long prevEndAddress, long version, DeltaLog deltaLog, - out SemaphoreSlim completedSemaphore, int throttleCheckpointFlushDelayMs) + out Task completedTask, int throttleCheckpointFlushDelayMs) { logger?.LogTrace("Starting async delta log flush with throttling {throttlingEnabled}", throttleCheckpointFlushDelayMs >= 0 ? $"enabled ({throttleCheckpointFlushDelayMs}ms)" : "disabled"); - var _completedSemaphore = new SemaphoreSlim(0); - completedSemaphore = _completedSemaphore; - // If throttled, convert rest of the method into a truly async task run because issuing IO can take up synchronous time if (throttleCheckpointFlushDelayMs >= 0) - _ = Task.Run(FlushRunner); + { + completedTask = Task.Run(FlushRunner); + } else - FlushRunner(); + { + try + { + FlushRunner(); + completedTask = Task.CompletedTask; + } + catch (Exception ex) + { + completedTask = Task.FromException(ex); + } + } void FlushRunner() { @@ -340,7 +349,6 @@ void FlushRunner() if (destOffset > 0) deltaLog.Seal(destOffset); - _completedSemaphore.Release(); } } @@ -403,7 +411,8 @@ public virtual void Dispose() bufferPool.Free(); flushEvent.Dispose(); - notifyFlushedUntilAddressSemaphore?.Dispose(); + notifyFlushedUntilAddressTcs?.TrySetCanceled(); + notifyFlushedUntilAddressTcs = null; onReadOnlyObserver?.OnCompleted(); onEvictionObserver?.OnCompleted(); @@ -1304,15 +1313,15 @@ internal long CalculateReadOnlyAddress(long tailAddress, long headAddress) } /// Used by applications to make the current state of the database immutable quickly - public bool ShiftReadOnlyToTail(out long tailAddress, out SemaphoreSlim notifyDone) + public bool ShiftReadOnlyToTail(out long tailAddress, out Task notifyDone) { notifyDone = null; tailAddress = GetTailAddress(); var localTailAddress = tailAddress; if (MonotonicUpdate(ref ReadOnlyAddress, tailAddress, out _)) { - notifyFlushedUntilAddressSemaphore = new SemaphoreSlim(0); - notifyDone = notifyFlushedUntilAddressSemaphore; + notifyFlushedUntilAddressTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + notifyDone = notifyFlushedUntilAddressTcs.Task; notifyFlushedUntilAddress = localTailAddress; epoch.BumpCurrentEpoch(() => OnPagesMarkedReadOnly(localTailAddress)); return true; @@ -1573,7 +1582,7 @@ protected void ShiftFlushedUntilAddress() flushEvent.Set(); if ((oldFlushedUntilAddress < notifyFlushedUntilAddress) && (currentFlushedUntilAddress >= notifyFlushedUntilAddress)) - _ = notifyFlushedUntilAddressSemaphore.Release(); + _ = notifyFlushedUntilAddressTcs?.TrySetResult(true); } } @@ -1592,8 +1601,8 @@ protected void ShiftFlushedUntilAddress() /// Address for notification of flushed-until public long notifyFlushedUntilAddress; - /// Semaphore for notification of flushed-until - public SemaphoreSlim notifyFlushedUntilAddressSemaphore; + /// TaskCompletionSource for notification of flushed-until + public TaskCompletionSource notifyFlushedUntilAddressTcs; /// Reset for recovery [MethodImpl(MethodImplOptions.NoInlining)] @@ -1870,16 +1879,16 @@ public void AsyncFlushPagesForRecovery(long scanFromAddress, long flus /// /// /// - /// + /// Task that completes when all pages are flushed, or faults if an exception occurs /// [MethodImpl(MethodImplOptions.NoInlining)] public void AsyncFlushPagesForSnapshot(CircularDiskWriteBuffer flushBuffers, long startPage, long endPage, long startLogicalAddress, long endLogicalAddress, - long fuzzyStartLogicalAddress, IDevice logDevice, IDevice objectLogDevice, out SemaphoreSlim completedSemaphore, int throttleCheckpointFlushDelayMs) + long fuzzyStartLogicalAddress, IDevice logDevice, IDevice objectLogDevice, out Task completedTask, int throttleCheckpointFlushDelayMs) { logger?.LogTrace("Starting async full log flush with throttling {throttlingEnabled}", throttleCheckpointFlushDelayMs >= 0 ? $"enabled ({throttleCheckpointFlushDelayMs}ms)" : "disabled"); - var _completedSemaphore = new SemaphoreSlim(0); - completedSemaphore = _completedSemaphore; + var completionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + completedTask = completionTcs.Task; // If throttled, convert rest of the method into a truly async task run because issuing IO can take up synchronous time if (throttleCheckpointFlushDelayMs >= 0) @@ -1891,49 +1900,69 @@ void FlushRunner() { var totalNumPages = (int)(endPage - startPage); - var flushCompletionTracker = new FlushCompletionTracker(_completedSemaphore, throttleCheckpointFlushDelayMs >= 0 ? new SemaphoreSlim(0) : null, totalNumPages); + var flushCompletionTracker = new FlushCompletionTracker(completionTcs, enableThrottling: throttleCheckpointFlushDelayMs >= 0, totalNumPages); - // Flush each page in sequence - for (long flushPage = startPage; flushPage < endPage; flushPage++) + try { - // For the first page, startLogicalAddress may be in the middle of the page; for the last page, endLogicalAddress may be in the middle of the page; - // for middle pages, we flush the entire page. - var flushStartAddress = GetLogicalAddressOfStartOfPage(flushPage); - if (startLogicalAddress > flushStartAddress) - flushStartAddress = startLogicalAddress; - var flushEndAddress = GetLogicalAddressOfStartOfPage(flushPage + 1); - if (endLogicalAddress < flushEndAddress) - flushEndAddress = endLogicalAddress; - var flushSize = flushEndAddress - flushStartAddress; - if (flushSize <= 0) - continue; - - var asyncResult = new PageAsyncFlushResult + // Flush each page in sequence + for (long flushPage = startPage; flushPage < endPage; flushPage++) { - flushCompletionTracker = flushCompletionTracker, - page = flushPage, - fromAddress = flushStartAddress, - untilAddress = flushEndAddress, - count = 1, - flushRequestState = FlushRequestState.Snapshot, - flushBuffers = flushBuffers - }; - - // Intended destination is flushPage - WriteAsyncToDeviceForSnapshot(startPage, flushPage, (int)flushSize, AsyncFlushPageForSnapshotCallback, asyncResult, logDevice, objectLogDevice, fuzzyStartLogicalAddress); - - // If we did not issue a flush write (due to HeadAddress moving past flushPage), then WriteAsync set isForSnapshot false and we release the asyncResult here; - // otherwise, we wait for the completion of the flush (and the callback will release the asyncResult). - if (asyncResult.flushRequestState != FlushRequestState.WriteNotIssued) - { - if (throttleCheckpointFlushDelayMs >= 0) + // For the first page, startLogicalAddress may be in the middle of the page; for the last page, endLogicalAddress may be in the middle of the page; + // for middle pages, we flush the entire page. + var flushStartAddress = GetLogicalAddressOfStartOfPage(flushPage); + if (startLogicalAddress > flushStartAddress) + flushStartAddress = startLogicalAddress; + var flushEndAddress = GetLogicalAddressOfStartOfPage(flushPage + 1); + if (endLogicalAddress < flushEndAddress) + flushEndAddress = endLogicalAddress; + var flushSize = flushEndAddress - flushStartAddress; + if (flushSize <= 0) { + // No data to flush for this page. Signal completion and drain the + // throttle semaphore so the next real page's WaitOneFlush is not + // satisfied by this page's release. + flushCompletionTracker.CompleteFlush(); + flushCompletionTracker.WaitOneFlush(); + continue; + } + + var asyncResult = new PageAsyncFlushResult + { + flushCompletionTracker = flushCompletionTracker, + page = flushPage, + fromAddress = flushStartAddress, + untilAddress = flushEndAddress, + count = 1, + flushRequestState = FlushRequestState.Snapshot, + flushBuffers = flushBuffers + }; + + // Intended destination is flushPage + WriteAsyncToDeviceForSnapshot(startPage, flushPage, (int)flushSize, AsyncFlushPageForSnapshotCallback, asyncResult, logDevice, objectLogDevice, fuzzyStartLogicalAddress); + + // If we did not issue a flush write (due to HeadAddress moving past flushPage), then WriteAsync set isForSnapshot false and we release the asyncResult here; + // otherwise, we wait for the completion of the flush (and the callback will release the asyncResult). + if (asyncResult.flushRequestState != FlushRequestState.WriteNotIssued) + { + if (throttleCheckpointFlushDelayMs >= 0) + { + flushCompletionTracker.WaitOneFlush(); + Thread.Sleep(throttleCheckpointFlushDelayMs); + } + } + else + { + _ = asyncResult.Release(); + // Release() called CompleteFlush() which released the throttle semaphore. + // Drain it so the next real page's WaitOneFlush is not satisfied by this no-op. flushCompletionTracker.WaitOneFlush(); - Thread.Sleep(throttleCheckpointFlushDelayMs); } } - else - _ = asyncResult.Release(); + } + catch (Exception ex) + { + logger?.LogError(ex, "{method} failed while flushing snapshot pages from {startPage} to {endPage}", nameof(AsyncFlushPagesForSnapshot), startPage, endPage); + flushCompletionTracker.SetException(ex); } } } @@ -2250,6 +2279,7 @@ protected void AsyncFlushPageForSnapshotCallback(uint errorCode, uint numBytes, if (info.Dirty) info.ClearDirtyAtomic(); // there may be read locks being taken, hence atomic physicalAddress += alignedRecordSize; + startAddress += alignedRecordSize; } } } @@ -2257,9 +2287,8 @@ protected void AsyncFlushPageForSnapshotCallback(uint errorCode, uint numBytes, { if (epochTaken) epoch.Suspend(); + _ = result.Release(); } - - _ = result.Release(); } catch when (disposed) { } } diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/MallocFixedPageSize.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/MallocFixedPageSize.cs index 35ff1d6baba..664e0954cda 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/MallocFixedPageSize.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/MallocFixedPageSize.cs @@ -39,7 +39,7 @@ public sealed class MallocFixedPageSize : IDisposable internal static bool IsBlittable => Utility.IsBlittable(); private int checkpointCallbackCount; - private SemaphoreSlim checkpointSemaphore; + private TaskCompletionSource checkpointTcs; private readonly ConcurrentQueue freeList; @@ -267,12 +267,10 @@ private unsafe long InternalAllocate(int blockSize) /// public async ValueTask IsCheckpointCompletedAsync(CancellationToken token = default) { - var s = checkpointSemaphore; - await s.WaitAsync(token).ConfigureAwait(false); - s.Release(); + await checkpointTcs.Task.WaitAsync(token).ConfigureAwait(false); } - public SemaphoreSlim GetCheckpointSemaphore() => checkpointSemaphore; + public Task GetCheckpointTask() => checkpointTcs.Task; /// /// Public facing persistence API @@ -299,7 +297,7 @@ internal unsafe void BeginCheckpoint(IDevice device, ulong offset, out ulong num int numCompleteLevels = localCount >> PageSizeBits; int numLevels = numCompleteLevels + (recordsCountInLastLevel > 0 ? 1 : 0); checkpointCallbackCount = numLevels; - checkpointSemaphore = new SemaphoreSlim(0); + checkpointTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); uint alignedPageSize = PageSize * (uint)RecordSize; uint lastLevelSize = (uint)recordsCountInLastLevel * (uint)RecordSize; @@ -353,7 +351,7 @@ private unsafe void AsyncFlushCallback(uint errorCode, uint numBytes, object con if (Interlocked.Decrement(ref checkpointCallbackCount) == 0) { - checkpointSemaphore.Release(); + checkpointTcs.TrySetResult(true); } } diff --git a/libs/storage/Tsavorite/cs/src/core/Allocator/ObjectAllocatorImpl.cs b/libs/storage/Tsavorite/cs/src/core/Allocator/ObjectAllocatorImpl.cs index e9dd7ba87c0..8d53bf3a2dc 100644 --- a/libs/storage/Tsavorite/cs/src/core/Allocator/ObjectAllocatorImpl.cs +++ b/libs/storage/Tsavorite/cs/src/core/Allocator/ObjectAllocatorImpl.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.Logging; namespace Tsavorite.core @@ -519,17 +520,22 @@ protected override void WriteAsyncToDeviceForSnapshot(long startPage, var epochTaken = epoch.ResumeIfNotProtected(); try { - if (HeadAddress >= asyncResult.untilAddress) + var headAddress = HeadAddress; + + if (headAddress >= asyncResult.untilAddress) { // Requested span on page is entirely unavailable in memory; ignore it and call the callback directly. callback(0, 0, asyncResult); return; } - // If requested page span is only partly available in memory, adjust the start position. WriteAsync will handle it if HeadAddress is lower, - // but this is faster. - if (HeadAddress > asyncResult.fromAddress) - asyncResult.fromAddress = HeadAddress; + // If requested page span is only partly available in memory, adjust the start position + // and mark as partial so WriteAsync recalculates the flush size from the adjusted range. + if (headAddress > asyncResult.fromAddress) + { + asyncResult.fromAddress = headAddress; + asyncResult.partial = true; + } // We are writing to a separate device which starts at startPage. Eventually, startPage becomes the basis of // HybridLogRecoveryInfo.snapshotStartFlushedLogicalAddress, which is the page starting at offset 0 of the snapshot file. @@ -732,6 +738,8 @@ private void WriteAsync(long flushPage, ulong alignedMainLogFlushPageA // which will never be less than HeadAddress. So we do not need to worry about whatever values are in the inline // record space between the current logicalAddress and HeadAddress. extraRecordOffset = (int)(headAddress - (logicalAddress + logRecordSize)); + // Skip object serialization + goto NextRecord; } else { @@ -777,6 +785,7 @@ private void WriteAsync(long flushPage, ulong alignedMainLogFlushPageA } } // endif record id Valid + NextRecord: logicalAddress += logRecordSize + extraRecordOffset; // advance in main log physicalAddress += logRecordSize + extraRecordOffset; // advance in source buffer } @@ -1067,7 +1076,7 @@ internal override void MemoryPageScan(long beginAddress, long endAddress, IObser observer?.OnNext(iter); } - internal override void AsyncFlushDeltaToDevice(CircularDiskWriteBuffer flushBuffers, long startAddress, long endAddress, long prevEndAddress, long version, DeltaLog deltaLog, out SemaphoreSlim completedSemaphore, int throttleCheckpointFlushDelayMs) + internal override void AsyncFlushDeltaToDevice(CircularDiskWriteBuffer flushBuffers, long startAddress, long endAddress, long prevEndAddress, long version, DeltaLog deltaLog, out Task completedTask, int throttleCheckpointFlushDelayMs) { throw new TsavoriteException("Incremental snapshots not supported with generic allocator"); } diff --git a/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs b/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs index e5663754cda..37314286d5d 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/CheckpointManagement/RecoveryInfo.cs @@ -3,7 +3,7 @@ using System; using System.IO; -using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.Logging; namespace Tsavorite.core @@ -338,7 +338,7 @@ internal struct HybridLogCheckpointInfo : IDisposable public IDevice snapshotFileObjectLogDevice; public IDevice deltaFileDevice; public DeltaLog deltaLog; - public SemaphoreSlim flushedSemaphore; + public Task flushedTask; public long prevVersion; internal CircularDiskWriteBuffer objectLogFlushBuffers; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/FoldOverSMTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/FoldOverSMTask.cs index aae6d23f474..f246211cdb6 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/FoldOverSMTask.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/FoldOverSMTask.cs @@ -37,9 +37,12 @@ public override void GlobalBeforeEnteringState(SystemState next, StateMachineDri try { store.epoch.Resume(); - _ = store.hlogBase.ShiftReadOnlyToTail(out var tailAddress, out store._hybridLogCheckpoint.flushedSemaphore); - if (store._hybridLogCheckpoint.flushedSemaphore != null) - stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedSemaphore); + _ = store.hlogBase.ShiftReadOnlyToTail(out var tailAddress, out var flushedTask); + if (flushedTask != null) + { + store._hybridLogCheckpoint.flushedTask = flushedTask; + stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedTask, StateMachineTaskType.FoldOverSMTaskHybridLogFlushed); + } // Update final logical address to the flushed tail - this may not be necessary store._hybridLogCheckpoint.info.finalLogicalAddress = tailAddress; diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/IncrementalSnapshotCheckpointSMTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/IncrementalSnapshotCheckpointSMTask.cs index 7b3798aa2e7..89926ddfa0c 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/IncrementalSnapshotCheckpointSMTask.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/IncrementalSnapshotCheckpointSMTask.cs @@ -56,10 +56,10 @@ public override void GlobalBeforeEnteringState(SystemState next, StateMachineDri store._lastSnapshotCheckpoint.info.finalLogicalAddress, store._hybridLogCheckpoint.prevVersion, store._hybridLogCheckpoint.deltaLog, - out store._hybridLogCheckpoint.flushedSemaphore, + out store._hybridLogCheckpoint.flushedTask, store.ThrottleCheckpointFlushDelayMs); - if (store._hybridLogCheckpoint.flushedSemaphore != null) - stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedSemaphore); + if (store._hybridLogCheckpoint.flushedTask != null) + stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedTask, StateMachineTaskType.IncrementalSnapshotCheckpointSMTaskHybridLogFlushed); break; case Phase.PERSISTENCE_CALLBACK: diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/SnapshotCheckpointSMTask.cs b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/SnapshotCheckpointSMTask.cs index c2d3e4015ca..e4ccecb2a25 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/SnapshotCheckpointSMTask.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/SnapshotCheckpointSMTask.cs @@ -67,10 +67,10 @@ public override void GlobalBeforeEnteringState(SystemState next, StateMachineDri fuzzyStartLogicalAddress: store._hybridLogCheckpoint.info.startLogicalAddress, logDevice: store._hybridLogCheckpoint.snapshotFileDevice, objectLogDevice: store._hybridLogCheckpoint.snapshotFileObjectLogDevice, - out store._hybridLogCheckpoint.flushedSemaphore, + out store._hybridLogCheckpoint.flushedTask, store.ThrottleCheckpointFlushDelayMs); - if (store._hybridLogCheckpoint.flushedSemaphore != null) - stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedSemaphore); + if (store._hybridLogCheckpoint.flushedTask != null) + stateMachineDriver.AddToWaitingList(store._hybridLogCheckpoint.flushedTask, StateMachineTaskType.SnapshotCheckpointSMTaskHybridLogFlushed); break; case Phase.PERSISTENCE_CALLBACK: diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineDriver.cs b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineDriver.cs index a03f631bf47..571ef07911d 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineDriver.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineDriver.cs @@ -17,7 +17,7 @@ public class StateMachineDriver { SystemState systemState; IStateMachine stateMachine; - readonly List waitingList; + readonly List<(Task task, StateMachineTaskType type)> waitingList; TaskCompletionSource stateMachineCompleted; // All threads have entered the given state SemaphoreSlim waitForTransitionIn; @@ -26,7 +26,7 @@ public class StateMachineDriver SemaphoreSlim waitForTransitionOut; // Transactions drained in last version long lastVersion; - SemaphoreSlim lastVersionTransactionsDone; + TaskCompletionSource lastVersionTransactionsDone; List callbacks; readonly LightEpoch epoch; readonly ILogger logger; @@ -59,7 +59,7 @@ void DecrementActiveTransactions(long txnVersion) var _lastVersionTransactionsDone = lastVersionTransactionsDone; if (_lastVersionTransactionsDone != null && txnVersion == lastVersion) { - _lastVersionTransactionsDone.Release(); + _lastVersionTransactionsDone.TrySetResult(true); } } } @@ -68,19 +68,19 @@ internal void TrackLastVersion(long version) { if (GetNumActiveTransactions(version) > 0) { - // Set version number first, then create semaphore + // Set version number first, then create TCS lastVersion = version; - lastVersionTransactionsDone = new(0); + lastVersionTransactionsDone = new(TaskCreationOptions.RunContinuationsAsynchronously); } // We have to re-check the number of active transactions after assigning lastVersion and lastVersionTransactionsDone if (GetNumActiveTransactions(version) > 0) - AddToWaitingList(lastVersionTransactionsDone); + AddToWaitingList(lastVersionTransactionsDone.Task, StateMachineTaskType.LastVersionTransactionsDone); } internal void ResetLastVersion() { - // First null semaphore, then reset version number + // First null TCS, then reset version number lastVersionTransactionsDone = null; lastVersion = 0; } @@ -155,10 +155,10 @@ public long VerifyTransactionVersion(long txnVersion) public void EndTransaction(long txnVersion) => DecrementActiveTransactions(txnVersion); - internal void AddToWaitingList(SemaphoreSlim waiter) + internal void AddToWaitingList(Task waiter, StateMachineTaskType type) { if (waiter != null) - waitingList.Add(waiter); + waitingList.Add((waiter, type)); } public bool Register(IStateMachine stateMachine, CancellationToken token = default) @@ -309,9 +309,17 @@ async Task ProcessWaitingListAsync(CancellationToken token = default) { throw waitForTransitionInException; } - foreach (var waiter in waitingList) + foreach (var (task, type) in waitingList) { - await waiter.WaitAsync(token); + try + { + await task.WaitAsync(token).ConfigureAwait(false); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + logger?.LogError(ex, "State machine task '{type}' faulted", type); + throw; + } } waitingList.Clear(); } diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineTaskType.cs b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineTaskType.cs new file mode 100644 index 00000000000..c400aabe6b3 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Index/Checkpointing/StateMachineTaskType.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Tsavorite.core +{ + /// + /// Identifies the type of waiter added to the state machine driver waiting list, + /// including the originating state machine task. + /// + internal enum StateMachineTaskType + { + /// + /// Waiting for all transactions in the last version to complete. + /// + LastVersionTransactionsDone, + + /// + /// Waiting for the main index checkpoint to complete (IndexCheckpointSMTask). + /// + IndexCheckpointSMTaskMainIndexCheckpoint, + + /// + /// Waiting for the overflow buckets checkpoint to complete (IndexCheckpointSMTask). + /// + IndexCheckpointSMTaskOverflowBucketsCheckpoint, + + /// + /// Waiting for the hybrid log flush to complete (FoldOverSMTask). + /// + FoldOverSMTaskHybridLogFlushed, + + /// + /// Waiting for the hybrid log flush to complete (IncrementalSnapshotCheckpointSMTask). + /// + IncrementalSnapshotCheckpointSMTaskHybridLogFlushed, + + /// + /// Waiting for the hybrid log flush to complete (SnapshotCheckpointSMTask). + /// + SnapshotCheckpointSMTaskHybridLogFlushed, + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Recovery/IndexCheckpoint.cs b/libs/storage/Tsavorite/cs/src/core/Index/Recovery/IndexCheckpoint.cs index a2d78b88694..6060ab5f21e 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Recovery/IndexCheckpoint.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Recovery/IndexCheckpoint.cs @@ -54,8 +54,8 @@ internal bool IsIndexFuzzyCheckpointCompleted() internal void AddIndexCheckpointWaitingList(StateMachineDriver stateMachineDriver) { - stateMachineDriver.AddToWaitingList(mainIndexCheckpointSemaphore); - stateMachineDriver.AddToWaitingList(overflowBucketsAllocator.GetCheckpointSemaphore()); + stateMachineDriver.AddToWaitingList(mainIndexCheckpointTcs.Task, StateMachineTaskType.IndexCheckpointSMTaskMainIndexCheckpoint); + stateMachineDriver.AddToWaitingList(overflowBucketsAllocator.GetCheckpointTask(), StateMachineTaskType.IndexCheckpointSMTaskOverflowBucketsCheckpoint); } internal async ValueTask IsIndexFuzzyCheckpointCompletedAsync(CancellationToken token = default) @@ -71,14 +71,14 @@ internal async ValueTask IsIndexFuzzyCheckpointCompletedAsync(CancellationToken // Implementation of an asynchronous checkpointing scheme // for main hash index of Tsavorite private int mainIndexCheckpointCallbackCount; - private SemaphoreSlim mainIndexCheckpointSemaphore; + private TaskCompletionSource mainIndexCheckpointTcs; private SemaphoreSlim throttleIndexCheckpointFlushSemaphore; internal unsafe void BeginMainIndexCheckpoint(int version, IDevice device, out ulong numBytesWritten, bool useReadCache = false, SkipReadCache skipReadCache = default, int throttleCheckpointFlushDelayMs = -1) { long totalSize = state[version].size * sizeof(HashBucket); numBytesWritten = (ulong)totalSize; - mainIndexCheckpointSemaphore = new SemaphoreSlim(0); + mainIndexCheckpointTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); if (throttleCheckpointFlushDelayMs >= 0) Task.Run(FlushRunner); @@ -87,64 +87,72 @@ internal unsafe void BeginMainIndexCheckpoint(int version, IDevice device, out u void FlushRunner() { - int numChunks = 1; - if (useReadCache && (totalSize > (1L << 25))) + try { - numChunks = (int)Math.Ceiling((double)totalSize / (1L << 25)); - numChunks = (int)Math.Pow(2, Math.Ceiling(Math.Log(numChunks, 2))); - } - else if (totalSize > uint.MaxValue) - { - numChunks = (int)Math.Ceiling((double)totalSize / (long)uint.MaxValue); - numChunks = (int)Math.Pow(2, Math.Ceiling(Math.Log(numChunks, 2))); - } + int numChunks = 1; + if (useReadCache && (totalSize > (1L << 25))) + { + numChunks = (int)Math.Ceiling((double)totalSize / (1L << 25)); + numChunks = (int)Math.Pow(2, Math.Ceiling(Math.Log(numChunks, 2))); + } + else if (totalSize > uint.MaxValue) + { + numChunks = (int)Math.Ceiling((double)totalSize / (long)uint.MaxValue); + numChunks = (int)Math.Pow(2, Math.Ceiling(Math.Log(numChunks, 2))); + } - uint chunkSize = (uint)(totalSize / numChunks); - mainIndexCheckpointCallbackCount = numChunks; + uint chunkSize = (uint)(totalSize / numChunks); + mainIndexCheckpointCallbackCount = numChunks; - if (throttleCheckpointFlushDelayMs >= 0) - throttleIndexCheckpointFlushSemaphore = new SemaphoreSlim(0); - HashBucket* start = state[version].tableAligned; + if (throttleCheckpointFlushDelayMs >= 0) + throttleIndexCheckpointFlushSemaphore = new SemaphoreSlim(0); + HashBucket* start = state[version].tableAligned; - ulong numBytesWritten = 0; - for (int index = 0; index < numChunks; index++) - { - IntPtr chunkStartBucket = (IntPtr)((byte*)start + (index * chunkSize)); - HashIndexPageAsyncFlushResult result = default; - result.chunkIndex = index; - if (!useReadCache) - { - device.WriteAsync(chunkStartBucket, numBytesWritten, chunkSize, AsyncPageFlushCallback, result); - } - else + ulong numBytesWritten = 0; + for (int index = 0; index < numChunks; index++) { - result.mem = new SectorAlignedMemory((int)chunkSize, (int)device.SectorSize); - bool prot = false; - if (!epoch.ThisInstanceProtected()) + IntPtr chunkStartBucket = (IntPtr)((byte*)start + (index * chunkSize)); + HashIndexPageAsyncFlushResult result = default; + result.chunkIndex = index; + if (!useReadCache) { - prot = true; - epoch.Resume(); + device.WriteAsync(chunkStartBucket, numBytesWritten, chunkSize, AsyncPageFlushCallback, result); } - Buffer.MemoryCopy((void*)chunkStartBucket, result.mem.aligned_pointer, chunkSize, chunkSize); - for (int j = 0; j < chunkSize; j += sizeof(HashBucket)) + else { - skipReadCache((HashBucket*)(result.mem.aligned_pointer + j)); + result.mem = new SectorAlignedMemory((int)chunkSize, (int)device.SectorSize); + bool prot = false; + if (!epoch.ThisInstanceProtected()) + { + prot = true; + epoch.Resume(); + } + Buffer.MemoryCopy((void*)chunkStartBucket, result.mem.aligned_pointer, chunkSize, chunkSize); + for (int j = 0; j < chunkSize; j += sizeof(HashBucket)) + { + skipReadCache((HashBucket*)(result.mem.aligned_pointer + j)); + } + if (prot) + epoch.Suspend(); + + device.WriteAsync((IntPtr)result.mem.aligned_pointer, numBytesWritten, chunkSize, AsyncPageFlushCallback, result); } - if (prot) - epoch.Suspend(); - - device.WriteAsync((IntPtr)result.mem.aligned_pointer, numBytesWritten, chunkSize, AsyncPageFlushCallback, result); - } - if (throttleCheckpointFlushDelayMs >= 0) - { - throttleIndexCheckpointFlushSemaphore.Wait(); - Thread.Sleep(throttleCheckpointFlushDelayMs); + if (throttleCheckpointFlushDelayMs >= 0) + { + throttleIndexCheckpointFlushSemaphore.Wait(); + Thread.Sleep(throttleCheckpointFlushDelayMs); + } + numBytesWritten += chunkSize; } - numBytesWritten += chunkSize; - } - Debug.Assert(numBytesWritten == (ulong)totalSize); - throttleIndexCheckpointFlushSemaphore = null; + Debug.Assert(numBytesWritten == (ulong)totalSize); + throttleIndexCheckpointFlushSemaphore = null; + } + catch (Exception ex) + { + logger?.LogError(ex, "{method} failed while flushing index checkpoint", nameof(BeginMainIndexCheckpoint)); + mainIndexCheckpointTcs.TrySetException(ex); + } } } @@ -155,9 +163,7 @@ private bool IsMainIndexCheckpointCompleted() private async ValueTask IsMainIndexCheckpointCompletedAsync(CancellationToken token = default) { - var s = mainIndexCheckpointSemaphore; - await s.WaitAsync(token).ConfigureAwait(false); - s.Release(); + await mainIndexCheckpointTcs.Task.WaitAsync(token).ConfigureAwait(false); } private unsafe void AsyncPageFlushCallback(uint errorCode, uint numBytes, object context) @@ -172,7 +178,7 @@ private unsafe void AsyncPageFlushCallback(uint errorCode, uint numBytes, object } if (Interlocked.Decrement(ref mainIndexCheckpointCallbackCount) == 0) { - mainIndexCheckpointSemaphore.Release(); + mainIndexCheckpointTcs.TrySetResult(true); } throttleIndexCheckpointFlushSemaphore?.Release(); } diff --git a/libs/storage/Tsavorite/cs/src/core/Utilities/FlushCompletionTracker.cs b/libs/storage/Tsavorite/cs/src/core/Utilities/FlushCompletionTracker.cs new file mode 100644 index 00000000000..2e844c24fb3 --- /dev/null +++ b/libs/storage/Tsavorite/cs/src/core/Utilities/FlushCompletionTracker.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Tsavorite.core +{ + /// + /// Tracks the completion of page flush operations during snapshot checkpoints. + /// Signals a when all pages have been flushed, + /// or faults it if an exception occurs. Optionally supports per-page throttle waiting. + /// + internal sealed class FlushCompletionTracker + { + /// + /// Task completion source to signal when all page flushes are done, or to fault on error. + /// + readonly TaskCompletionSource completionTcs; + + /// + /// Semaphore for per-page flush completion, used only when throttling is enabled. + /// + readonly SemaphoreSlim flushSemaphore; + + /// + /// Number of pages being flushed + /// + int count; + + public override string ToString() + { + var flushSemCount = flushSemaphore?.CurrentCount.ToString() ?? "null"; + return $"count {count}, flushSemCount {flushSemCount}"; + } + + /// + /// Create a flush completion tracker + /// + /// TaskCompletionSource to signal when all flushes complete or to fault on error + /// If true, creates a semaphore for per-page throttle waiting + /// Number of pages to flush + public FlushCompletionTracker(TaskCompletionSource completionTcs, bool enableThrottling, int count) + { + this.completionTcs = completionTcs; + this.flushSemaphore = enableThrottling ? new SemaphoreSlim(0) : null; + this.count = count; + + if (count == 0) + _ = completionTcs.TrySetResult(true); + } + + /// + /// Complete flush of one page + /// + public void CompleteFlush() + { + _ = (flushSemaphore?.Release()); + if (Interlocked.Decrement(ref count) == 0) + _ = completionTcs.TrySetResult(true); + } + + /// + /// Signal that the flush failed with an exception. + /// + public void SetException(Exception ex) + => _ = completionTcs.TrySetException(ex); + + /// + /// Wait for one page flush to complete. Only valid when throttling is enabled. + /// + public void WaitOneFlush() => flushSemaphore?.Wait(); + } +} \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Utilities/PageAsyncResultTypes.cs b/libs/storage/Tsavorite/cs/src/core/Utilities/PageAsyncResultTypes.cs index 3da7b85e19c..9bff4c26769 100644 --- a/libs/storage/Tsavorite/cs/src/core/Utilities/PageAsyncResultTypes.cs +++ b/libs/storage/Tsavorite/cs/src/core/Utilities/PageAsyncResultTypes.cs @@ -72,56 +72,6 @@ public void DisposeHandle() /// /// Shared flush completion tracker, when bulk-flushing many pages /// - internal sealed class FlushCompletionTracker - { - /// - /// Semaphore to set on flush completion - /// - readonly SemaphoreSlim completedSemaphore; - - /// - /// Semaphore to wait on for flush completion - /// - readonly SemaphoreSlim flushSemaphore; - - /// - /// Number of pages being flushed - /// - int count; - - public override string ToString() - { - var compSemCount = completedSemaphore?.CurrentCount.ToString() ?? "null"; - var flushSemCount = completedSemaphore?.CurrentCount.ToString() ?? "null"; - return $"count {count}, compSemCount {compSemCount}, flushSemCount {flushSemCount}"; - } - - /// - /// Create a flush completion tracker - /// - /// Semaphpore to release when all flushes completed - /// Semaphpore to release when each flush completes - /// Number of pages to flush - public FlushCompletionTracker(SemaphoreSlim completedSemaphore, SemaphoreSlim flushSemaphore, int count) - { - this.completedSemaphore = completedSemaphore; - this.flushSemaphore = flushSemaphore; - this.count = count; - } - - /// - /// Complete flush of one page - /// - public void CompleteFlush() - { - _ = (flushSemaphore?.Release()); - if (Interlocked.Decrement(ref count) == 0) - _ = completedSemaphore.Release(); - } - - public void WaitOneFlush() => flushSemaphore?.Wait(); - } - internal enum FlushRequestState : byte { /// The default; we are here for flush. This diff --git a/libs/storage/Tsavorite/cs/test/ObjectTests.cs b/libs/storage/Tsavorite/cs/test/ObjectTests.cs index 92ea2f3ae18..3931b830335 100644 --- a/libs/storage/Tsavorite/cs/test/ObjectTests.cs +++ b/libs/storage/Tsavorite/cs/test/ObjectTests.cs @@ -327,8 +327,8 @@ public void LargeObjectMultiFlushedPages([Values(SerializeKeyValueSize.Thirty, S store.epoch.Resume(); try { - Assert.That(store.hlogBase.ShiftReadOnlyToTail(out _, out var sroSemaphore), Is.True); - sroSemaphore.Wait(); + Assert.That(store.hlogBase.ShiftReadOnlyToTail(out _, out var sroTask), Is.True); + sroTask.Wait(); } finally { @@ -417,11 +417,11 @@ public async Task LargeObjectLinearizeFlushedPages([Values(SerializeKeyValueSize // We have to wait for this outside the epoch to avoid deadlock. gate.Wait(); - SemaphoreSlim sroSemaphore; + Task sroTask; store.epoch.Resume(); try { - Assert.That(store.hlogBase.ShiftReadOnlyToTail(out _, out sroSemaphore), Is.True); + Assert.That(store.hlogBase.ShiftReadOnlyToTail(out _, out sroTask), Is.True); } finally { @@ -429,7 +429,7 @@ public async Task LargeObjectLinearizeFlushedPages([Values(SerializeKeyValueSize } gate.Dispose(); - await Task.WhenAll(task, sroSemaphore.WaitAsync(millisecondsTimeout: 2000)); + await Task.WhenAll(task, sroTask); // Test that the FlushedUntilAddress is correct and that we get the right results back; nothing has been evicted yet, so all records are in memory. Assert.That(store.hlogBase.FlushedUntilAddress, Is.EqualTo(store.hlogBase.GetTailAddress())); diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs index 449799d5444..b310e54ef46 100644 --- a/test/Garnet.test.cluster/ClusterTestContext.cs +++ b/test/Garnet.test.cluster/ClusterTestContext.cs @@ -769,31 +769,31 @@ public void ClusterFailoverSpinWait(int replicaNodeIndex, ILogger logger) } } - public void AttachAndWaitForSync(int primary_count, int replica_count, bool disableObjects) + public void AttachAndWaitForSync(int primaryIndex, int replicaStartIndex, int replicaCount, bool disableObjects) { - var primaryId = clusterTestUtils.GetNodeIdFromNode(0, logger); + var primaryId = clusterTestUtils.GetNodeIdFromNode(primaryIndex, logger); // Wait until primary node is known so as not to fail replicate - for (var i = primary_count; i < primary_count + replica_count; i++) + for (var i = replicaStartIndex; i < replicaStartIndex + replicaCount; i++) clusterTestUtils.WaitUntilNodeIdIsKnown(i, primaryId, logger: logger); // Issue cluster replicate and bump epoch manually to capture config. - for (var i = primary_count; i < primary_count + replica_count; i++) + for (var i = replicaStartIndex; i < replicaStartIndex + replicaCount; i++) _ = clusterTestUtils.ClusterReplicate(i, primaryId, async: true, logger: logger); if (!checkpointTask.Wait(TimeSpan.FromSeconds(100))) Assert.Fail("Checkpoint task timeout"); // Wait for recovery and AofSync - for (var i = primary_count; i < replica_count; i++) + for (var i = replicaStartIndex; i < replicaStartIndex + replicaCount; i++) { clusterTestUtils.WaitForReplicaRecovery(i, logger); - clusterTestUtils.WaitForReplicaAofSync(0, i, logger); + clusterTestUtils.WaitForReplicaAofSync(primaryIndex, i, logger); } - clusterTestUtils.WaitForConnectedReplicaCount(0, replica_count, logger: logger); + clusterTestUtils.WaitForConnectedReplicaCount(primaryIndex, replicaCount, logger: logger); // Validate data on replicas - for (var i = primary_count; i < replica_count; i++) + for (var i = replicaStartIndex; i < replicaStartIndex + replicaCount; i++) { if (disableObjects) ValidateKVCollectionAgainstReplica(ref kvPairs, i); diff --git a/test/Garnet.test.cluster/ClusterTestUtils.cs b/test/Garnet.test.cluster/ClusterTestUtils.cs index 534f19c4c91..08e67bee676 100644 --- a/test/Garnet.test.cluster/ClusterTestUtils.cs +++ b/test/Garnet.test.cluster/ClusterTestUtils.cs @@ -3068,26 +3068,32 @@ public void Checkpoint(int nodeIndex, ILogger logger = null) public void Checkpoint(IPEndPoint endPoint, ILogger logger = null) { + const int maxRetries = 10; var server = redis.GetServer(endPoint); - try + for (var attempt = 0; ; attempt++) { - var previousSaveTicks = (long)server.Execute("LASTSAVE"); + try + { #pragma warning disable CS0618 // Type or member is obsolete - server.Save(SaveType.ForegroundSave); + server.Save(SaveType.ForegroundSave); #pragma warning restore CS0618 // Type or member is obsolete + break; + } + catch (RedisServerException ex) when (ex.Message.Contains("checkpoint already in progress", StringComparison.OrdinalIgnoreCase)) + { + if (attempt >= maxRetries) + Assert.Fail($"Checkpoint still in progress after {maxRetries} retries"); - //// Spin wait for checkpoint to complete - //while (true) - //{ - // var lastSaveTicks = (long)server.Execute("LASTSAVE"); - // if (previousSaveTicks < lastSaveTicks) break; - // BackOff(TimeSpan.FromSeconds(1)); - //} - } - catch (Exception ex) - { - logger?.LogError(ex, "An error has occurred; StoreWrapper.Checkpoint"); - Assert.Fail(); + // Another checkpoint is in progress (e.g., on-demand checkpoint from replication). + // Retry after a short delay. + logger?.LogWarning(ex, "Checkpoint already in progress, retrying (attempt {attempt})", attempt); + BackOff(cancellationToken: context?.cts?.Token ?? CancellationToken.None); + } + catch (Exception ex) + { + logger?.LogError(ex, "An error has occurred; StoreWrapper.Checkpoint"); + Assert.Fail(ex.Message); + } } } diff --git a/test/Garnet.test.cluster/ReplicationTests/ClusterReplicationBaseTests.cs b/test/Garnet.test.cluster/ReplicationTests/ClusterReplicationBaseTests.cs index d7007f1b6a5..9fabda2bf8b 100644 --- a/test/Garnet.test.cluster/ReplicationTests/ClusterReplicationBaseTests.cs +++ b/test/Garnet.test.cluster/ReplicationTests/ClusterReplicationBaseTests.cs @@ -743,7 +743,7 @@ public void ClusterReplicationCheckpointCleanupTest([Values] bool performRMW, [V var slotRangesStr = string.Join(",", myself.Slots.Select(x => $"({x.From}-{x.To})").ToList()); ClassicAssert.AreEqual(1, myself.Slots.Count, $"Setup failed slot ranges count greater than 1 {slotRangesStr}"); - var shards = context.clusterTestUtils.ClusterShards(0, context.logger); + var shards = context.clusterTestUtils.ClusterShards(primaryIndex, context.logger); ClassicAssert.AreEqual(2, shards.Count); ClassicAssert.AreEqual(1, shards[0].slotRanges.Count); ClassicAssert.AreEqual(0, shards[0].slotRanges[0].Item1); @@ -752,13 +752,11 @@ public void ClusterReplicationCheckpointCleanupTest([Values] bool performRMW, [V context.kvPairs = []; context.kvPairsObj = []; context.checkpointTask = Task.Run(() => context.PopulatePrimaryAndTakeCheckpointTask(performRMW, disableObjects, takeCheckpoint: true)); - var attachReplicaTask = Task.Run(() => context.AttachAndWaitForSync(primary_count, replica_count, disableObjects)); - - if (!context.checkpointTask.Wait(TimeSpan.FromSeconds(60))) - Assert.Fail("checkpointTask timeout"); + var attachReplicaTask = Task.Run(() => context.AttachAndWaitForSync(primaryIndex, primary_count, replica_count, disableObjects)); - if (!attachReplicaTask.Wait(TimeSpan.FromSeconds(60))) - Assert.Fail("attachReplicaTask timeout"); + var tasks = new Task[] { context.checkpointTask, attachReplicaTask }; + if (!Task.WhenAll(tasks).Wait(TimeSpan.FromSeconds(60))) + Assert.Fail($"Task timeout - checkpointTask: {context.checkpointTask.Status}, attachReplicaTask: {attachReplicaTask.Status}"); context.clusterTestUtils.WaitForReplicaAofSync(primaryIndex: primaryIndex, secondaryIndex: replicaIndex, logger: context.logger); }