diff --git a/inmem_store.go b/inmem_store.go index 026a9702..7504d5af 100644 --- a/inmem_store.go +++ b/inmem_store.go @@ -135,10 +135,6 @@ func (i *InmemStore) GetUint64(key []byte) (uint64, error) { return i.kvInt[string(key)], nil } -type commitIndexTrackingLog struct { - log *Log - CommitIndex uint64 -} type InmemCommitTrackingStore struct { InmemStore commitIndex atomic.Uint64 diff --git a/raft_test.go b/raft_test.go index f4f28acc..40d3074f 100644 --- a/raft_test.go +++ b/raft_test.go @@ -1,4 +1,4 @@ -// Copyright IBM Corp. 2013, 2025 +// Copyright IBM Corp. 2013, 2026 // SPDX-License-Identifier: MPL-2.0 package raft @@ -1021,10 +1021,10 @@ func TestRaft_RestoreSnapshotOnStartup_Monotonic(t *testing.T) { conf := inmemConfig(t) conf.TrailingLogs = 10 opts := &MakeClusterOpts{ - Peers: 1, - Bootstrap: true, - Conf: conf, - MonotonicLogs: true, + Peers: 1, + Bootstrap: true, + Conf: conf, + LogstoreWrapperFunc: NewMockMonotonicLogStore, } c := MakeClusterCustom(t, opts) defer c.Close() @@ -1262,6 +1262,9 @@ func TestRaft_RestoreCommittedLogs(t *testing.T) { t.Fatal("err: raft log store does not implement CommitTrackingLogStore interface") } commitIdx, err := store.GetCommitIndex() + if err != nil { + t.Fatalf("err: %v", err) + } // We should have applied all committed logs if last := r.getLastApplied(); last != commitIdx { t.Fatalf("bad last index: %d, expecting %d", last, commitIdx) @@ -1642,10 +1645,10 @@ func snapshotAndRestore(t *testing.T, offset uint64, monotonicLogStore bool, res var c *cluster numPeers := 3 optsMonotonic := &MakeClusterOpts{ - Peers: numPeers, - Bootstrap: true, - Conf: conf, - MonotonicLogs: true, + Peers: numPeers, + Bootstrap: true, + Conf: conf, + LogstoreWrapperFunc: NewMockMonotonicLogStore, } if monotonicLogStore { c = MakeClusterCustom(t, optsMonotonic) @@ -2926,7 +2929,7 @@ func TestRaft_LogStoreIsMonotonic(t *testing.T) { // Now create a new MockMonotonicLogStore using the leader logs and expect // it to work. - log = &MockMonotonicLogStore{s: leader.logs} + log = NewMockMonotonicLogStore(leader.logs) mcast, ok = log.(MonotonicLogStore) require.True(t, ok) assert.True(t, mcast.IsMonotonic()) diff --git a/replication.go b/replication.go index d48a8c83..44c0b9ef 100644 --- a/replication.go +++ b/replication.go @@ -1,4 +1,4 @@ -// Copyright IBM Corp. 2013, 2025 +// Copyright IBM Corp. 2013, 2026 // SPDX-License-Identifier: MPL-2.0 package raft @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/go-metrics/compat" + metrics "github.com/hashicorp/go-metrics/compat" ) const ( @@ -73,6 +73,14 @@ type followerReplication struct { // lastContactLock protects 'lastContact'. lastContactLock sync.RWMutex + // lastReplicationStart is updated to the current time whenever a + // replicateTo method call is started, and is cleared when the replicateTo + // method returns. This is used by heartbeats to check if replication has + // stalled for too long. + lastReplicationStart time.Time + // lastReplicationStartLock protects 'lastReplicationStart'. + lastReplicationStartLock sync.RWMutex + // failures counts the number of failed RPCs since the last success, which is // used to apply backoff. failures uint64 @@ -132,6 +140,30 @@ func (s *followerReplication) setLastContact() { s.lastContactLock.Unlock() } +// resetLastReplicationStart clears the marker for the start of the last replication +// attempt +func (s *followerReplication) resetLastReplicationStart() { + s.lastReplicationStartLock.Lock() + s.lastReplicationStart = time.Time{} + s.lastReplicationStartLock.Unlock() +} + +// setLastReplicationStart sets the marker for the start of the last replication +// attempt to the current time +func (s *followerReplication) setLastReplicationStart() { + s.lastReplicationStartLock.Lock() + s.lastReplicationStart = time.Now() + s.lastReplicationStartLock.Unlock() +} + +// getLastReplicationStart gets the start of the last replication attempt +func (s *followerReplication) getLastReplicationStart() time.Time { + s.lastReplicationStartLock.RLock() + t := s.lastReplicationStart + s.lastReplicationStartLock.RUnlock() + return t +} + // replicate is a long running routine that replicates log entries to a single // follower. func (r *Raft) replicate(s *followerReplication) { @@ -140,17 +172,22 @@ func (r *Raft) replicate(s *followerReplication) { defer close(stopHeartbeat) r.goFunc(func() { r.heartbeat(s, stopHeartbeat) }) + defer s.resetLastReplicationStart() + RPC: shouldStop := false for !shouldStop { + s.resetLastReplicationStart() select { case maxIndex := <-s.stopCh: // Make a best effort to replicate up to this index if maxIndex > 0 { + s.setLastReplicationStart() r.replicateTo(s, maxIndex) } return case deferErr := <-s.triggerDeferErrorCh: + s.setLastReplicationStart() lastLogIdx, _ := r.getLastLog() shouldStop = r.replicateTo(s, lastLogIdx) if !shouldStop { @@ -159,6 +196,7 @@ RPC: deferErr.respond(fmt.Errorf("replication failed")) } case <-s.triggerCh: + s.setLastReplicationStart() lastLogIdx, _ := r.getLastLog() shouldStop = r.replicateTo(s, lastLogIdx) // This is _not_ our heartbeat mechanism but is to ensure @@ -167,12 +205,14 @@ RPC: // can't do this to keep them unblocked by disk IO on the // follower. See https://github.com/hashicorp/raft/issues/282. case <-randomTimeout(r.config().CommitTimeout): + s.setLastReplicationStart() lastLogIdx, _ := r.getLastLog() shouldStop = r.replicateTo(s, lastLogIdx) } // If things looks healthy, switch to pipeline mode if !shouldStop && s.allowPipeline { + s.resetLastReplicationStart() goto PIPELINE } } @@ -409,6 +449,30 @@ func (r *Raft) heartbeat(s *followerReplication, stopCh chan struct{}) { s.peerLock.RUnlock() start := time.Now() + + lastReplicationStart := s.getLastReplicationStart() + if !lastReplicationStart.IsZero() { + maxLastReplication := r.config().HeartbeatTimeout * 10 + if lastReplicationStart.Add(maxLastReplication).Before(start) { + r.logger.Warn("delaying heartbeat for peer because replication is stalled", + "peer", peer.Address, + "timeout", maxLastReplication, + "replication_started", lastReplicationStart, + ) + // Replication has been stalled for too long. Delay the next + // heartbeat to allow a follower to take over, but don't exit the + // loop yet in case replication unblocks during the delay + select { + case <-s.notifyCh: + case <-randomTimeout(r.config().HeartbeatTimeout): + case <-stopCh: + return + } + + continue + } + } + if err := r.trans.AppendEntries(peer.ID, peer.Address, &req, &resp); err != nil { nextBackoffTime := cappedExponentialBackoff(failureWait, failures, maxFailureScale, r.config().HeartbeatTimeout/2) r.logger.Error("failed to heartbeat to", "peer", peer.Address, "backoff time", @@ -472,16 +536,19 @@ func (r *Raft) pipelineReplicate(s *followerReplication) error { shouldStop := false SEND: for !shouldStop { + s.resetLastReplicationStart() select { case <-finishCh: break SEND case maxIndex := <-s.stopCh: // Make a best effort to replicate up to this index if maxIndex > 0 { + s.setLastReplicationStart() r.pipelineSend(s, pipeline, &nextIndex, maxIndex) } break SEND case deferErr := <-s.triggerDeferErrorCh: + s.setLastReplicationStart() lastLogIdx, _ := r.getLastLog() shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx) if !shouldStop { @@ -490,13 +557,16 @@ SEND: deferErr.respond(fmt.Errorf("replication failed")) } case <-s.triggerCh: + s.setLastReplicationStart() lastLogIdx, _ := r.getLastLog() shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx) case <-randomTimeout(r.config().CommitTimeout): lastLogIdx, _ := r.getLastLog() + s.setLastReplicationStart() shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx) } } + s.resetLastReplicationStart() // Stop our decoder, and wait for it to finish close(stopCh) diff --git a/replication_test.go b/replication_test.go new file mode 100644 index 00000000..f4ec63bc --- /dev/null +++ b/replication_test.go @@ -0,0 +1,147 @@ +// Copyright IBM Corp. 2013, 2026 +// SPDX-License-Identifier: MPL-2.0 + +package raft + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLeaderHeartbeatsWithStalledDisk(t *testing.T) { + c := MakeClusterCustom(t, &MakeClusterOpts{ + Peers: 3, + Bootstrap: true, + LogstoreWrapperFunc: newBlockingLogStore, + }) + t.Cleanup(c.Close) + + go func() { + i := 0 + ticker := time.NewTicker(500 * time.Millisecond) + for { + select { + case <-t.Context().Done(): + return + case <-ticker.C: + i++ + leader := c.Leader() + fut := leader.Apply(fmt.Appendf([]byte{}, "test%d", i), 0) + if err := fut.Error(); err != nil { + t.Logf("got error trying to write: %v", err) + } else { + t.Logf("write %d ok", i) + } + } + } + }() + + t.Log("waiting for 5 seconds before partitioning leader") + time.Sleep(time.Second * 5) + + oldLeader := c.Leader() + oldLeaderID := oldLeader.leaderID + oldLeaderTerm := oldLeader.getCurrentTerm() + c.Partition([]ServerAddress{c.Leader().localAddr}) + + var newLeaderTerm uint64 + ctx, cancel := context.WithTimeout(t.Context(), 10*c.propagateTimeout) + t.Cleanup(cancel) +DONE: + for { + select { + case <-ctx.Done(): + t.Fatal("election didn't happen!") + default: + newLeader := c.Leader() + if newLeader.leaderID != oldLeaderID { + t.Log("leader has stepped down!") + newLeaderTerm = newLeader.getCurrentTerm() + require.NotEqual(t, newLeaderTerm, oldLeaderTerm) + cancel() + break DONE + } + } + } + + require.Len(t, c.WaitForFollowers(1), 1) + + t.Log("leader was elected. healing parition") + + // reconnect the partitioned node + c.FullyConnect() + time.Sleep(3 * c.propagateTimeout) + + leaderTerm := c.Leader().getCurrentTerm() + require.Equal(t, newLeaderTerm, leaderTerm) + + t.Log("blocking disk on leader") + + leader := c.Leader() + leaderID := leader.leaderID + leaderStore := leader.logs.(*blockingLogStore) + leaderStore.block() + t.Cleanup(leaderStore.unblock) + + ctx, cancel = context.WithTimeout(t.Context(), 5*time.Second) + t.Cleanup(cancel) + for { + select { + case <-ctx.Done(): + t.Fatal("leader did not step down!") + default: + if c.Leader().leaderID != leaderID { + t.Log("leader has stepped down!") + return + } + } + } +} + +// blockingLogStore wraps a LogStore and blocks GetLog calls on demand, +// simulating disk IO stalls. +type blockingLogStore struct { + LogStore + mu sync.Mutex + blocked atomic.Bool + unblockC chan struct{} +} + +func newBlockingLogStore(inner LogStore) LogStore { + return &blockingLogStore{ + LogStore: inner, + unblockC: make(chan struct{}), + } +} + +func (b *blockingLogStore) block() { + b.mu.Lock() + defer b.mu.Unlock() + b.blocked.Store(true) + b.unblockC = make(chan struct{}) +} + +func (b *blockingLogStore) unblock() { + b.mu.Lock() + defer b.mu.Unlock() + if b.blocked.Load() { + b.blocked.Store(false) + close(b.unblockC) + } +} + +func (b *blockingLogStore) GetLog(index uint64, log *Log) error { + if b.blocked.Load() { + b.mu.Lock() + ch := b.unblockC + b.mu.Unlock() + <-ch + } + return b.LogStore.GetLog(index, log) +} diff --git a/testing.go b/testing.go index ce9ee9f4..4e5dd493 100644 --- a/testing.go +++ b/testing.go @@ -1,4 +1,4 @@ -// Copyright IBM Corp. 2013, 2025 +// Copyright IBM Corp. 2013, 2026 // SPDX-License-Identifier: MPL-2.0 package raft @@ -136,6 +136,10 @@ type MockMonotonicLogStore struct { s LogStore } +func NewMockMonotonicLogStore(logs LogStore) LogStore { + return &MockMonotonicLogStore{s: logs} +} + // IsMonotonic implements the MonotonicLogStore interface. func (m *MockMonotonicLogStore) IsMonotonic() bool { return true @@ -717,15 +721,15 @@ WAIT: // NOTE: This is exposed for middleware testing purposes and is not a stable API type MakeClusterOpts struct { - Peers int - Bootstrap bool - Conf *Config - ConfigStoreFSM bool - MakeFSMFunc func() FSM - LongstopTimeout time.Duration - MonotonicLogs bool - CommitTrackingLogs bool - PropagateError bool // If true, return errors instead of calling t.Fatal + Peers int + Bootstrap bool + Conf *Config + ConfigStoreFSM bool + MakeFSMFunc func() FSM + LongstopTimeout time.Duration + LogstoreWrapperFunc func(LogStore) LogStore + CommitTrackingLogs bool + PropagateError bool // If true, return errors instead of calling t.Fatal } // makeCluster will return a cluster with the given config and number of peers. @@ -810,8 +814,8 @@ func makeCluster(t *testing.T, opts *MakeClusterOpts) (*cluster, error) { snap := c.snaps[i] trans := c.trans[i] - if opts.MonotonicLogs { - logs = &MockMonotonicLogStore{s: logs} + if opts.LogstoreWrapperFunc != nil { + logs = opts.LogstoreWrapperFunc(logs) } else if opts.CommitTrackingLogs { logs = NewInmemCommitTrackingStore() }