Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions inmem_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,6 @@ func (i *InmemStore) GetUint64(key []byte) (uint64, error) {
return i.kvInt[string(key)], nil
}

// commitIndexTrackingLog pairs a log entry pointer with a commit index value.
type commitIndexTrackingLog struct {
log *Log
CommitIndex uint64
}
type InmemCommitTrackingStore struct {
InmemStore
commitIndex atomic.Uint64
Expand Down
23 changes: 13 additions & 10 deletions raft_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright IBM Corp. 2013, 2025
// Copyright IBM Corp. 2013, 2026
// SPDX-License-Identifier: MPL-2.0

package raft
Expand Down Expand Up @@ -1021,10 +1021,10 @@ func TestRaft_RestoreSnapshotOnStartup_Monotonic(t *testing.T) {
conf := inmemConfig(t)
conf.TrailingLogs = 10
opts := &MakeClusterOpts{
Peers: 1,
Bootstrap: true,
Conf: conf,
MonotonicLogs: true,
Peers: 1,
Bootstrap: true,
Conf: conf,
LogstoreWrapperFunc: NewMockMonotonicLogStore,
}
c := MakeClusterCustom(t, opts)
defer c.Close()
Expand Down Expand Up @@ -1262,6 +1262,9 @@ func TestRaft_RestoreCommittedLogs(t *testing.T) {
t.Fatal("err: raft log store does not implement CommitTrackingLogStore interface")
}
commitIdx, err := store.GetCommitIndex()
if err != nil {
t.Fatalf("err: %v", err)
}
// We should have applied all committed logs
if last := r.getLastApplied(); last != commitIdx {
t.Fatalf("bad last index: %d, expecting %d", last, commitIdx)
Expand Down Expand Up @@ -1642,10 +1645,10 @@ func snapshotAndRestore(t *testing.T, offset uint64, monotonicLogStore bool, res
var c *cluster
numPeers := 3
optsMonotonic := &MakeClusterOpts{
Peers: numPeers,
Bootstrap: true,
Conf: conf,
MonotonicLogs: true,
Peers: numPeers,
Bootstrap: true,
Conf: conf,
LogstoreWrapperFunc: NewMockMonotonicLogStore,
}
if monotonicLogStore {
c = MakeClusterCustom(t, optsMonotonic)
Expand Down Expand Up @@ -2926,7 +2929,7 @@ func TestRaft_LogStoreIsMonotonic(t *testing.T) {

// Now create a new MockMonotonicLogStore using the leader logs and expect
// it to work.
log = &MockMonotonicLogStore{s: leader.logs}
log = NewMockMonotonicLogStore(leader.logs)
mcast, ok = log.(MonotonicLogStore)
require.True(t, ok)
assert.True(t, mcast.IsMonotonic())
Expand Down
74 changes: 72 additions & 2 deletions replication.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright IBM Corp. 2013, 2025
// Copyright IBM Corp. 2013, 2026
// SPDX-License-Identifier: MPL-2.0

package raft
Expand All @@ -10,7 +10,7 @@ import (
"sync/atomic"
"time"

"github.com/hashicorp/go-metrics/compat"
metrics "github.com/hashicorp/go-metrics/compat"
)

const (
Expand Down Expand Up @@ -73,6 +73,14 @@ type followerReplication struct {
// lastContactLock protects 'lastContact'.
lastContactLock sync.RWMutex

// lastReplicationStart is updated to the current time whenever a
// replicateTo method call is started, and is cleared when the replicateTo
// method returns. This is used by heartbeats to check if replication has
// stalled for too long.
lastReplicationStart time.Time
// lastReplicationStartLock protects 'lastReplicationStart'.
lastReplicationStartLock sync.RWMutex
Comment on lines +80 to +82
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the critical section is only a load or store, we could use an atomic.Pointer (from sync/atomic) instead. Your choice matches existing code (lastContact) though, so either way is fine.


// failures counts the number of failed RPCs since the last success, which is
// used to apply backoff.
failures uint64
Expand Down Expand Up @@ -132,6 +140,30 @@ func (s *followerReplication) setLastContact() {
s.lastContactLock.Unlock()
}

// resetLastReplicationStart zeroes the replication-start marker, signalling
// that no replication attempt is currently being timed. The write happens
// under lastReplicationStartLock.
func (s *followerReplication) resetLastReplicationStart() {
	s.lastReplicationStartLock.Lock()
	defer s.lastReplicationStartLock.Unlock()

	// The zero time.Time is the "not started" sentinel.
	var notStarted time.Time
	s.lastReplicationStart = notStarted
}

// setLastReplicationStart records the current wall-clock time as the moment
// the latest replication attempt began. The write happens under
// lastReplicationStartLock.
func (s *followerReplication) setLastReplicationStart() {
	s.lastReplicationStartLock.Lock()
	defer s.lastReplicationStartLock.Unlock()

	s.lastReplicationStart = time.Now()
}

// getLastReplicationStart returns the time at which the latest replication
// attempt began. The value is read under a read lock on
// lastReplicationStartLock; a zero time.Time means the marker is cleared.
func (s *followerReplication) getLastReplicationStart() time.Time {
	s.lastReplicationStartLock.RLock()
	defer s.lastReplicationStartLock.RUnlock()

	return s.lastReplicationStart
}

// replicate is a long running routine that replicates log entries to a single
// follower.
func (r *Raft) replicate(s *followerReplication) {
Expand All @@ -140,17 +172,22 @@ func (r *Raft) replicate(s *followerReplication) {
defer close(stopHeartbeat)
r.goFunc(func() { r.heartbeat(s, stopHeartbeat) })

defer s.resetLastReplicationStart()

RPC:
shouldStop := false
for !shouldStop {
s.resetLastReplicationStart()
select {
case maxIndex := <-s.stopCh:
// Make a best effort to replicate up to this index
if maxIndex > 0 {
s.setLastReplicationStart()
r.replicateTo(s, maxIndex)
}
return
case deferErr := <-s.triggerDeferErrorCh:
s.setLastReplicationStart()
lastLogIdx, _ := r.getLastLog()
shouldStop = r.replicateTo(s, lastLogIdx)
if !shouldStop {
Expand All @@ -159,6 +196,7 @@ RPC:
deferErr.respond(fmt.Errorf("replication failed"))
}
case <-s.triggerCh:
s.setLastReplicationStart()
lastLogIdx, _ := r.getLastLog()
shouldStop = r.replicateTo(s, lastLogIdx)
// This is _not_ our heartbeat mechanism but is to ensure
Expand All @@ -167,12 +205,14 @@ RPC:
// can't do this to keep them unblocked by disk IO on the
// follower. See https://github.com/hashicorp/raft/issues/282.
case <-randomTimeout(r.config().CommitTimeout):
s.setLastReplicationStart()
lastLogIdx, _ := r.getLastLog()
shouldStop = r.replicateTo(s, lastLogIdx)
}

// If things looks healthy, switch to pipeline mode
if !shouldStop && s.allowPipeline {
s.resetLastReplicationStart()
goto PIPELINE
}
}
Expand Down Expand Up @@ -409,6 +449,30 @@ func (r *Raft) heartbeat(s *followerReplication, stopCh chan struct{}) {
s.peerLock.RUnlock()

start := time.Now()

lastReplicationStart := s.getLastReplicationStart()
if !lastReplicationStart.IsZero() {
maxLastReplication := r.config().HeartbeatTimeout * 10
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's always comment magic numbers (especially since I can't remember exactly why we picked this one)

Suggested change
maxLastReplication := r.config().HeartbeatTimeout * 10
// Replication timeout should be relatively long to avoid
// costly leadership elections due to temporary replication
// stalls.
maxLastReplication := r.config().HeartbeatTimeout * 10

maybe?

if lastReplicationStart.Add(maxLastReplication).Before(start) {
r.logger.Warn("delaying heartbeat for peer because replication is stalled",
"peer", peer.Address,
"timeout", maxLastReplication,
"replication_started", lastReplicationStart,
)
// Replication has been stalled for too long. Delay the next
// heartbeat to allow a follower to take over, but don't exit the
// loop yet in case replication unblocks during the delay
select {
case <-s.notifyCh:
case <-randomTimeout(r.config().HeartbeatTimeout):
case <-stopCh:
return
}

continue
}
}

if err := r.trans.AppendEntries(peer.ID, peer.Address, &req, &resp); err != nil {
nextBackoffTime := cappedExponentialBackoff(failureWait, failures, maxFailureScale, r.config().HeartbeatTimeout/2)
r.logger.Error("failed to heartbeat to", "peer", peer.Address, "backoff time",
Expand Down Expand Up @@ -472,16 +536,19 @@ func (r *Raft) pipelineReplicate(s *followerReplication) error {
shouldStop := false
SEND:
for !shouldStop {
s.resetLastReplicationStart()
select {
case <-finishCh:
break SEND
case maxIndex := <-s.stopCh:
// Make a best effort to replicate up to this index
if maxIndex > 0 {
s.setLastReplicationStart()
r.pipelineSend(s, pipeline, &nextIndex, maxIndex)
}
break SEND
case deferErr := <-s.triggerDeferErrorCh:
s.setLastReplicationStart()
lastLogIdx, _ := r.getLastLog()
shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx)
if !shouldStop {
Expand All @@ -490,13 +557,16 @@ SEND:
deferErr.respond(fmt.Errorf("replication failed"))
}
case <-s.triggerCh:
s.setLastReplicationStart()
lastLogIdx, _ := r.getLastLog()
shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx)
case <-randomTimeout(r.config().CommitTimeout):
lastLogIdx, _ := r.getLastLog()
s.setLastReplicationStart()
shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx)
}
}
s.resetLastReplicationStart()

// Stop our decoder, and wait for it to finish
close(stopCh)
Expand Down
147 changes: 147 additions & 0 deletions replication_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright IBM Corp. 2013, 2026
// SPDX-License-Identifier: MPL-2.0

package raft

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestLeaderHeartbeatsWithStalledDisk checks that a leader whose log store
// stops answering GetLog calls loses leadership.
//
// The test runs in three phases against a 3-node cluster whose log stores are
// wrapped in blockingLogStore:
//  1. partition the initial leader and wait for a different leader to appear
//     in a new term;
//  2. heal the partition and confirm the term did not change again;
//  3. block the current leader's log store and expect the observed leader ID
//     to change within 5 seconds.
//
// A background goroutine issues a write every 500ms for the whole test so that
// replication is continuously exercised.
func TestLeaderHeartbeatsWithStalledDisk(t *testing.T) {
	c := MakeClusterCustom(t, &MakeClusterOpts{
		Peers:               3,
		Bootstrap:           true,
		LogstoreWrapperFunc: newBlockingLogStore,
	})
	t.Cleanup(c.Close)

	// Background writer: keeps applying entries until the test context is
	// cancelled. Write failures are expected while leadership is changing and
	// are only logged.
	//
	// NOTE(review): this goroutine is not joined before the test returns; a
	// write that resolves after the test has finished would call t.Logf on a
	// completed test. Consider waiting for it in a Cleanup.
	go func() {
		i := 0
		ticker := time.NewTicker(500 * time.Millisecond)
		// Release the ticker once the writer exits.
		defer ticker.Stop()
		for {
			select {
			case <-t.Context().Done():
				return
			case <-ticker.C:
				i++
				leader := c.Leader()
				fut := leader.Apply(fmt.Appendf(nil, "test%d", i), 0)
				if err := fut.Error(); err != nil {
					t.Logf("got error trying to write: %v", err)
				} else {
					t.Logf("write %d ok", i)
				}
			}
		}
	}()

	t.Log("waiting for 5 seconds before partitioning leader")
	time.Sleep(time.Second * 5)

	oldLeader := c.Leader()
	oldLeaderID := oldLeader.leaderID
	oldLeaderTerm := oldLeader.getCurrentTerm()
	// Partition exactly the node whose ID and term were just recorded, rather
	// than querying the cluster for its leader a second time.
	c.Partition([]ServerAddress{oldLeader.localAddr})

	// Phase 1: wait for a leader with a different ID and a different term.
	var newLeaderTerm uint64
	ctx, cancel := context.WithTimeout(t.Context(), 10*c.propagateTimeout)
	t.Cleanup(cancel)
DONE:
	for {
		select {
		case <-ctx.Done():
			t.Fatal("election didn't happen!")
		default:
			newLeader := c.Leader()
			if newLeader.leaderID != oldLeaderID {
				t.Log("leader has stepped down!")
				newLeaderTerm = newLeader.getCurrentTerm()
				require.NotEqual(t, newLeaderTerm, oldLeaderTerm)
				cancel()
				break DONE
			}
		}
	}

	require.Len(t, c.WaitForFollowers(1), 1)

	t.Log("leader was elected. healing partition")

	// Phase 2: reconnect the partitioned node and make sure rejoining did not
	// trigger another term change.
	c.FullyConnect()
	time.Sleep(3 * c.propagateTimeout)

	leaderTerm := c.Leader().getCurrentTerm()
	require.Equal(t, newLeaderTerm, leaderTerm)

	t.Log("blocking disk on leader")

	// Phase 3: stall the leader's log store and expect leadership to move.
	leader := c.Leader()
	leaderID := leader.leaderID
	// Fail the test cleanly instead of panicking if the store is not wrapped.
	leaderStore, ok := leader.logs.(*blockingLogStore)
	require.True(t, ok, "leader log store is not a *blockingLogStore")
	leaderStore.block()
	// Always release blocked GetLog callers so the cluster can shut down.
	t.Cleanup(leaderStore.unblock)

	ctx, cancel = context.WithTimeout(t.Context(), 5*time.Second)
	t.Cleanup(cancel)
	for {
		select {
		case <-ctx.Done():
			t.Fatal("leader did not step down!")
		default:
			if c.Leader().leaderID != leaderID {
				t.Log("leader has stepped down!")
				return
			}
		}
	}
}

// blockingLogStore is a test double that decorates a LogStore so that GetLog
// can be made to hang on demand, imitating a disk whose IO has stalled. All
// other LogStore methods are served by the embedded store.
type blockingLogStore struct {
	// LogStore is the wrapped store that ultimately serves every call.
	LogStore

	// blocked reports whether GetLog should currently wait before delegating.
	blocked atomic.Bool

	// mu guards unblockC.
	mu sync.Mutex
	// unblockC is closed by unblock to release GetLog callers that are waiting.
	unblockC chan struct{}
}

// newBlockingLogStore wraps inner in a blockingLogStore that starts out
// unblocked, so GetLog passes straight through until block is called.
func newBlockingLogStore(inner LogStore) LogStore {
	store := new(blockingLogStore)
	store.LogStore = inner
	store.unblockC = make(chan struct{})
	return store
}

// block makes subsequent GetLog calls wait until unblock is called.
//
// It is idempotent: calling block while the store is already blocked is a
// no-op. Replacing the channel in that case would strand any GetLog callers
// already waiting on the previous channel, because unblock only closes the
// current one.
func (b *blockingLogStore) block() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.blocked.Load() {
		return
	}
	// Install a fresh channel before raising the flag; GetLog reads the
	// channel under mu, so it always observes the new one.
	b.unblockC = make(chan struct{})
	b.blocked.Store(true)
}

// unblock releases every GetLog caller currently waiting and lets later calls
// pass straight through. Calling it while the store is not blocked does
// nothing, so the channel is never closed twice.
func (b *blockingLogStore) unblock() {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Flip blocked from true to false; bail out if it was not set.
	if !b.blocked.CompareAndSwap(true, false) {
		return
	}
	close(b.unblockC)
}

// GetLog delegates to the wrapped LogStore. While the store is blocked it
// first waits for the current unblock channel to be closed.
func (b *blockingLogStore) GetLog(index uint64, log *Log) error {
	if !b.blocked.Load() {
		return b.LogStore.GetLog(index, log)
	}

	// Snapshot the channel under the lock, then wait without holding it so
	// that unblock can acquire mu and close the channel.
	b.mu.Lock()
	gate := b.unblockC
	b.mu.Unlock()
	<-gate

	return b.LogStore.GetLog(index, log)
}
Loading
Loading