diff --git a/pkg/api/message/export_test.go b/pkg/api/message/export_test.go index 912914775..3ba38b9da 100644 --- a/pkg/api/message/export_test.go +++ b/pkg/api/message/export_test.go @@ -1,39 +1,32 @@ package message import ( - "context" - "time" - "github.com/xmtp/xmtpd/pkg/db" ) -// AwaitCursor blocks until the subscribe worker has polled past all sequence IDs in vc. -// Only compiled during testing (export_test.go pattern). -func (s *Service) AwaitCursor(ctx context.Context, vc db.VectorClock) error { - const checkInterval = 5 * time.Millisecond - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() +// DispatchedMet reports whether the subscribe worker has *dispatched* every +// sequence ID in vc to its listeners. This is stronger than "polled past" — +// a true return guarantees the start() loop has already handed those rows +// off, so any listener registered after this call will not retroactively +// receive them. Tests use this predicate with require.Eventually to pre-seed +// envelopes before opening a stream without racing the dispatch loop. +func (s *Service) DispatchedMet(vc db.VectorClock) bool { + return s.subscribeWorker.dispatchedMet(vc) +} - for { - if s.subscribeWorker.subscriptions.cursorMet(vc) { - return nil - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - } - } +// GlobalListenerCount returns the number of global listeners currently +// registered with the subscribe worker. Tests poll this (via require.Eventually) +// to wait for proof that a server-side listener is registered before +// triggering the code under test — no time.Sleep needed. +func (s *Service) GlobalListenerCount() int { + return s.subscribeWorker.countGlobalListeners() } -func (s *subscriptionHandler) cursorMet(vc db.VectorClock) bool { - s.Lock() - defer s.Unlock() - for nodeID, minSeq := range vc { - poller, ok := s.subs[nodeID] - if !ok || uint64(poller.sub.LastSeen()) < minSeq { - return false - } - } - return true +func (s *subscribeWorker) countGlobalListeners() int { + count := 0 + s.globalListeners.Range(func(_, _ any) bool { + count++ + return true + }) + return count } diff --git a/pkg/api/message/publish_test.go b/pkg/api/message/publish_test.go index 67c414827..020744efc 100644 --- a/pkg/api/message/publish_test.go +++ b/pkg/api/message/publish_test.go @@ -596,13 +596,16 @@ func TestPublishEnvelopeBatchPublishNoPartialError(t *testing.T) { require.Nil(t, resp) require.Contains(t, err.Error(), "published via the blockchain") - // give this some time to process just in case - time.Sleep(100 * time.Millisecond) - - envs, err := queries.New(suite.DB). - SelectGatewayEnvelopesUnfiltered(context.Background(), queries.SelectGatewayEnvelopesUnfilteredParams{}) - require.NoError(t, err) - require.Empty(t, envs) + // Assert that no envelopes ever land in the DB. require.Never fails fast + // if one does, instead of always waiting 100ms. + require.Never(t, func() bool { + envs, err := queries.New(suite.DB). 
+ SelectGatewayEnvelopesUnfiltered( + context.Background(), + queries.SelectGatewayEnvelopesUnfilteredParams{}, + ) + return err != nil || len(envs) > 0 + }, 100*time.Millisecond, 10*time.Millisecond) } func TestPublishEnvelopeBalanceEnforcement(t *testing.T) { diff --git a/pkg/api/message/subscribe_test.go b/pkg/api/message/subscribe_test.go index 7a9c9d9ef..2da60fa24 100644 --- a/pkg/api/message/subscribe_test.go +++ b/pkg/api/message/subscribe_test.go @@ -14,7 +14,6 @@ import ( "connectrpc.com/connect" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" - "github.com/xmtp/xmtpd/pkg/api/message" "github.com/xmtp/xmtpd/pkg/constants" "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" @@ -112,9 +111,9 @@ func insertInitialRows(t *testing.T, suite *testUtilsApi.APIServerTestSuite) { }) // Wait until the subscribe worker has polled past the inserted rows so that // a subsequent subscription with LastSeen=nil won't see them. - ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond) - defer cancel() - require.NoError(t, suite.MessageService.AwaitCursor(ctx, db.VectorClock{100: 1, 200: 1})) + require.Eventually(t, func() bool { + return suite.MessageService.DispatchedMet(db.VectorClock{100: 1, 200: 1}) + }, 500*time.Millisecond, 5*time.Millisecond) } func insertAdditionalRows(t *testing.T, store *sql.DB, notifyChan ...chan bool) { @@ -616,8 +615,12 @@ func TestSubscribeCatchUpSkewedOriginators(t *testing.T) { // Populate the database. saveEnvelopes(t, server.DB, sourceEnvelopes) - // Let the subscribeWorker's catch up. - time.Sleep(4 * message.SubscribeWorkerPollTime) + // Block until the subscribeWorker has polled past the last inserted sequence ID. + require.Eventually(t, func() bool { + return server.MessageService.DispatchedMet( + db.VectorClock{heavyOriginatorID: uint64(heavyMsgCount)}, + ) + }, 5*time.Second, 5*time.Millisecond) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() @@ -694,7 +697,6 @@ func TestSubscribeAll(t *testing.T) { minEnvelopes = 10 maxEnvelopes = 20 - insertDelay = 100 * time.Millisecond envelopeList = flattenEnvelopeMap( generateEnvelopes( t, @@ -716,12 +718,12 @@ func TestSubscribeAll(t *testing.T) { require.NoError(t, err) var ( - received = 0 + received atomic.Int64 streamWG sync.WaitGroup ) streamWG.Go(func() { - for received < total { + for received.Load() < int64(total) { ok := stream.Receive() if !ok { break @@ -730,23 +732,25 @@ func TestSubscribeAll(t *testing.T) { n := len(stream.Msg().GetEnvelopes()) t.Logf("stream produced %v envelopes", n) - received += n + received.Add(int64(n)) } cancel() }) - // Wait a bit - then start inserting envelopes. Make sure these are streamed. - time.Sleep(insertDelay) + // Wait until the server has registered the listener before inserting, so + // no inserts race the listener registration. 
+ require.Eventually(t, func() bool { + return server.MessageService.GlobalListenerCount() >= 1 + }, 5*time.Second, 5*time.Millisecond) for _, env := range envelopeList { testutils.InsertGatewayEnvelopes(t, server.DB, []queries.InsertGatewayEnvelopeV3Params{env}) - time.Sleep(insertDelay) } streamWG.Wait() - require.Equal(t, total, received) + require.Equal(t, int64(total), received.Load()) } func readOriginatorsStream( @@ -1080,11 +1084,11 @@ func TestOriginatorParity_SkewedPagination(t *testing.T) { ) saveEnvelopes(t, server.DB, sourceEnvelopes) - awaitCtx, awaitCancel := context.WithTimeout(t.Context(), 5*time.Second) - defer awaitCancel() - require.NoError(t, server.MessageService.AwaitCursor( - awaitCtx, db.VectorClock{heavyOriginatorID: uint64(heavyMsgCount)}, - )) + require.Eventually(t, func() bool { + return server.MessageService.DispatchedMet( + db.VectorClock{heavyOriginatorID: uint64(heavyMsgCount)}, + ) + }, 5*time.Second, 5*time.Millisecond) ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() @@ -1273,9 +1277,9 @@ func TestOriginatorParity_VariableVolume(t *testing.T) { for nodeID, envs := range sourceEnvelopes { expectedVC[uint32(nodeID)] = uint64(len(envs)) } - awaitCtx, awaitCancel := context.WithTimeout(t.Context(), 5*time.Second) - defer awaitCancel() - require.NoError(t, server.MessageService.AwaitCursor(awaitCtx, expectedVC)) + require.Eventually(t, func() bool { + return server.MessageService.DispatchedMet(expectedVC) + }, 5*time.Second, 5*time.Millisecond) ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() @@ -1335,39 +1339,50 @@ func TestSubscribeAll_StreamsOnlyNewMessages(t *testing.T) { topic.TopicKindGroupMessagesV1, fmt.Appendf(nil, "generic-topic-%v", rand.Int()), ) - - insertDelay = 100 * time.Millisecond ) - // Envelope data. + // generateEnvelopes returns perNode envelopes per originator, so flattening + // gives us 2*perNode total. Slice by exact bounds so the post-subscribe + // insert count is exactly streamSize — the final equality check depends on it. var ( initialBatchSize = 5 streamSize = 5 - totalMessages = initialBatchSize + streamSize + perNode = initialBatchSize + streamSize sourceEnvelopes = flattenEnvelopeMap( generateEnvelopes( t, nodeIDs, - totalMessages, - totalMessages, // Let's get exactly N messages. + perNode, + perNode, // Exactly perNode per node. payerID, subTopic, )) initialBatch = sourceEnvelopes[:initialBatchSize] - streamBatch = sourceEnvelopes[initialBatchSize:] + streamBatch = sourceEnvelopes[initialBatchSize : initialBatchSize+streamSize] ) defer cancel() - // Pre-seed envelopes in the DB. - // These should NOT get picked up by the stream. + // Pre-seed envelopes in the DB. These should NOT get picked up by the + // stream because the subscribe worker marks them as known before the + // stream subscription is registered. for _, env := range initialBatch { testutils.InsertGatewayEnvelopes(t, server.DB, []queries.InsertGatewayEnvelopeV3Params{env}) } - // Add a delay so the subscribe worker picks pre-seeded envelopes as known before the streaming started. - time.Sleep(insertDelay) + // Block until the subscribe worker has polled past every pre-seeded row. 
+ preSeedVC := make(db.VectorClock) + for _, env := range initialBatch { + nodeID := uint32(env.OriginatorNodeID) + seq := uint64(env.OriginatorSequenceID) + if cur, ok := preSeedVC[nodeID]; !ok || seq > cur { + preSeedVC[nodeID] = seq + } + } + require.Eventually(t, func() bool { + return server.MessageService.DispatchedMet(preSeedVC) + }, 5*time.Second, 5*time.Millisecond) // Start a subscriber stream. req := &message_api.SubscribeAllEnvelopesRequest{} @@ -1375,12 +1390,12 @@ func TestSubscribeAll_StreamsOnlyNewMessages(t *testing.T) { require.NoError(t, err) var ( - received = 0 + received atomic.Int64 streamWG sync.WaitGroup ) streamWG.Go(func() { - for received < streamSize { + for received.Load() < int64(streamSize) { ok := stream.Receive() if !ok { break @@ -1389,21 +1404,23 @@ func TestSubscribeAll_StreamsOnlyNewMessages(t *testing.T) { n := len(stream.Msg().GetEnvelopes()) t.Logf("stream produced %v envelopes", n) - received += n + received.Add(int64(n)) } cancel() }) - // Wait a bit - then start inserting envelopes. These should in fact be streamed. - time.Sleep(insertDelay) + // Wait until the server has registered the listener before inserting, so + // no inserts race the listener registration. + require.Eventually(t, func() bool { + return server.MessageService.GlobalListenerCount() >= 1 + }, 5*time.Second, 5*time.Millisecond) for _, env := range streamBatch { testutils.InsertGatewayEnvelopes(t, server.DB, []queries.InsertGatewayEnvelopeV3Params{env}) - time.Sleep(insertDelay) } streamWG.Wait() - require.Equal(t, streamSize, received) + require.Equal(t, int64(streamSize), received.Load()) } diff --git a/pkg/api/message/subscribe_topics_test.go b/pkg/api/message/subscribe_topics_test.go index 50e366933..12de426ae 100644 --- a/pkg/api/message/subscribe_topics_test.go +++ b/pkg/api/message/subscribe_topics_test.go @@ -10,7 +10,6 @@ import ( "connectrpc.com/connect" "github.com/stretchr/testify/require" - "github.com/xmtp/xmtpd/pkg/api/message" "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" @@ -23,16 +22,15 @@ import ( "github.com/xmtp/xmtpd/pkg/topic" ) -// setupTopicTest creates a test API server and returns the client, DB and mocks. -func setupTopicTest( - t *testing.T, -) (message_apiconnect.ReplicationApiClient, *sql.DB, testUtilsApi.APIServerMocks) { +// setupTopicTest creates a test API server and returns the full test suite so +// tests can access the client, DB, mocks, and MessageService for cursor +// synchronization. +func setupTopicTest(t *testing.T) *testUtilsApi.APIServerTestSuite { nodes := []registry.Node{ {NodeID: 100, IsCanonical: true}, {NodeID: 200, IsCanonical: true}, } - suite := testUtilsApi.NewTestAPIServer(t, testUtilsApi.WithRegistryNodes(nodes)) - return suite.ClientReplication, suite.DB, suite.APIServerMocks + return testUtilsApi.NewTestAPIServer(t, testUtilsApi.WithRegistryNodes(nodes)) } // makeFilter creates a TopicFilter with the given topic and optional LastSeen cursor. @@ -87,11 +85,31 @@ func subscribeTopics( return stream } -// insertAndWait inserts gateway envelopes and waits for the subscribe worker to poll them. -func insertAndWait(t *testing.T, store *sql.DB, rows []queries.InsertGatewayEnvelopeV3Params) { +// insertAndWait inserts gateway envelopes and blocks until the subscribe +// worker has polled past the inserted sequence IDs, so a subsequent +// subscription observes them as either catch-up or known history (depending on +// the cursor the caller provides). 
+func insertAndWait( + t *testing.T, + suite *testUtilsApi.APIServerTestSuite, + rows []queries.InsertGatewayEnvelopeV3Params, +) { t.Helper() - testutils.InsertGatewayEnvelopes(t, store, rows) - time.Sleep(message.SubscribeWorkerPollTime + 100*time.Millisecond) + testutils.InsertGatewayEnvelopes(t, suite.DB, rows) + + // Build the target VectorClock from the max sequence ID per originator. + vc := make(db.VectorClock, 2) + for _, r := range rows { + nodeID := uint32(r.OriginatorNodeID) + seq := uint64(r.OriginatorSequenceID) + if cur, ok := vc[nodeID]; !ok || seq > cur { + vc[nodeID] = seq + } + } + + require.Eventually(t, func() bool { + return suite.MessageService.DispatchedMet(vc) + }, 5*time.Second, 5*time.Millisecond) } // requireOriginatorOrdering verifies that envelopes are ordered per originator. @@ -154,7 +172,8 @@ func requireTopicStreamError( // ---- Validation Tests ---- func TestSubscribeTopics_Validation(t *testing.T) { - client, _, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client := suite.ClientReplication tooManyFilters := make([]*message_api.SubscribeTopicsRequest_TopicFilter, 10001) for i := range tooManyFilters { @@ -190,10 +209,11 @@ func TestSubscribeTopics_Validation(t *testing.T) { } func TestSubscribeTopics_UnknownOriginatorInCursor(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), }) @@ -213,7 +233,8 @@ func TestSubscribeTopics_UnknownOriginatorInCursor(t *testing.T) { // ---- Live-Only Tests (nil LastSeen) ---- func TestSubscribeTopics_LiveOnly(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) stream := subscribeTopics( @@ -241,7 +262,8 @@ func TestSubscribeTopics_LiveOnly(t *testing.T) { } func TestSubscribeTopics_LiveOnlyFiltersByTopic(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) stream := subscribeTopics( @@ -269,10 +291,11 @@ func TestSubscribeTopics_LiveOnlyFiltersByTopic(t *testing.T) { // ---- Catch-Up Tests ---- func TestSubscribeTopics_CatchUpFromEmpty(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), makeEnvRow(t, 200, 1, topicA, payerID), }) @@ -291,10 +314,11 @@ func TestSubscribeTopics_CatchUpFromEmpty(t *testing.T) { } func TestSubscribeTopics_CatchUpFromCursor(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), makeEnvRow(t, 100, 2, topicA, payerID), makeEnvRow(t, 100, 3, topicA, payerID), @@ -322,12 +346,13 @@ 
func TestSubscribeTopics_CatchUpFromCursor(t *testing.T) { } func TestSubscribeTopics_DifferentCursorsPerTopic(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) // topicA: seq 1, 2, 3 from node 100 // topicB: seq 4 from node 100 - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), makeEnvRow(t, 100, 2, topicA, payerID), makeEnvRow(t, 100, 3, topicA, payerID), @@ -360,10 +385,11 @@ func TestSubscribeTopics_DifferentCursorsPerTopic(t *testing.T) { } func TestSubscribeTopics_CatchUpThenLive(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), }) @@ -400,10 +426,11 @@ func TestSubscribeTopics_CatchUpThenLive(t *testing.T) { } func TestSubscribeTopics_NoDuplicatesBetweenCatchUpAndLive(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), makeEnvRow(t, 100, 2, topicA, payerID), }) @@ -460,7 +487,8 @@ func TestSubscribeTopics_NoDuplicatesBetweenCatchUpAndLive(t *testing.T) { } func TestSubscribeTopics_StatusStartedOnOpen(t *testing.T) { - client, _, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client := suite.ClientReplication stream, err := client.SubscribeTopics( t.Context(), @@ -482,11 +510,12 @@ func TestSubscribeTopics_StatusStartedOnOpen(t *testing.T) { } func TestSubscribeTopics_StatusLifecycle(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) // Insert envelopes before subscribing so catch-up is triggered. - insertAndWait(t, store, []queries.InsertGatewayEnvelopeV3Params{ + insertAndWait(t, suite, []queries.InsertGatewayEnvelopeV3Params{ makeEnvRow(t, 100, 1, topicA, payerID), makeEnvRow(t, 100, 2, topicA, payerID), }) @@ -554,7 +583,8 @@ func TestSubscribeTopics_StatusLifecycle(t *testing.T) { // ---- Ordering Tests ---- func TestSubscribeTopics_PerOriginatorOrdering(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) // Insert multiple envelopes from two originators. @@ -565,7 +595,7 @@ func TestSubscribeTopics_PerOriginatorOrdering(t *testing.T) { makeEnvRow(t, 200, seq, topicA, payerID), ) } - insertAndWait(t, store, rows) + insertAndWait(t, suite, rows) stream := subscribeTopics( t, @@ -605,7 +635,7 @@ func TestSubscribeTopics_MultiOriginatorMultiTopic(t *testing.T) { } } } - insertAndWait(t, suite.DB, rows) + insertAndWait(t, suite, rows) // Subscribe to all topics with empty cursors. 
filters := make([]*message_api.SubscribeTopicsRequest_TopicFilter, len(topics)) @@ -623,7 +653,8 @@ func TestSubscribeTopics_MultiOriginatorMultiTopic(t *testing.T) { // ---- Scale Tests ---- func TestSubscribeTopics_LargeCatchUpMultiplePages(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) // Insert >500 envelopes to trigger pagination. @@ -632,7 +663,7 @@ func TestSubscribeTopics_LargeCatchUpMultiplePages(t *testing.T) { for i := range total { rows[i] = makeEnvRow(t, 100, uint64(i+1), topicA, payerID) } - insertAndWait(t, store, rows) + insertAndWait(t, suite, rows) stream := subscribeTopics( t, @@ -648,7 +679,8 @@ func TestSubscribeTopics_LargeCatchUpMultiplePages(t *testing.T) { } func TestSubscribeTopics_ManyTopics(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) numTopics := 1000 @@ -663,7 +695,7 @@ func TestSubscribeTopics_ManyTopics(t *testing.T) { filters[i] = makeFilter(tp, map[uint32]uint64{}) rows[i] = makeEnvRow(t, 100, uint64(i+1), tp, payerID) } - insertAndWait(t, store, rows) + insertAndWait(t, suite, rows) stream := subscribeTopics(t, client, t.Context(), filters) @@ -674,7 +706,8 @@ func TestSubscribeTopics_ManyTopics(t *testing.T) { // ---- Error Path Tests ---- func TestSubscribeTopics_ContextCancelledDuringLive(t *testing.T) { - client, _, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client := suite.ClientReplication ctx, cancel := context.WithCancel(t.Context()) stream := subscribeTopics(t, client, ctx, []*message_api.SubscribeTopicsRequest_TopicFilter{ @@ -699,7 +732,8 @@ func TestSubscribeTopics_ContextCancelledDuringLive(t *testing.T) { } func TestSubscribeTopics_ContextCancelledDuringCatchUp(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) // Insert a large amount of data to make catch-up take time. 
@@ -707,7 +741,7 @@ func TestSubscribeTopics_ContextCancelledDuringCatchUp(t *testing.T) { for i := range 200 { rows[i] = makeEnvRow(t, 100, uint64(i+1), topicA, payerID) } - insertAndWait(t, store, rows) + insertAndWait(t, suite, rows) ctx, cancel := context.WithCancel(t.Context()) @@ -741,7 +775,8 @@ func TestSubscribeTopics_ContextCancelledDuringCatchUp(t *testing.T) { // ---- Concurrent Tests ---- func TestSubscribeTopics_SimultaneousWithSubscribeEnvelopes(t *testing.T) { - client, store, _ := setupTopicTest(t) + suite := setupTopicTest(t) + client, store := suite.ClientReplication, suite.DB payerID := db.NullInt32(testutils.CreatePayer(t, store)) ctx := t.Context() diff --git a/pkg/api/message/subscribe_worker.go b/pkg/api/message/subscribe_worker.go index 2382e06e1..126799d48 100644 --- a/pkg/api/message/subscribe_worker.go +++ b/pkg/api/message/subscribe_worker.go @@ -3,13 +3,16 @@ package message import ( "context" "fmt" + "maps" "slices" + "sync" "time" "go.uber.org/zap" "github.com/xmtp/xmtpd/pkg/constants" "github.com/xmtp/xmtpd/pkg/db" + "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/envelopes" "github.com/xmtp/xmtpd/pkg/migrator" "github.com/xmtp/xmtpd/pkg/registry" @@ -52,6 +55,13 @@ type subscribeWorker struct { globalListeners listenerSet originatorListeners listenersMap[uint32] topicListeners listenersMap[string] + + // dispatched tracks the highest per-originator sequence ID already handed + // off to listeners by start() — advanced *after* dispatch, unlike the + // per-poller LastSeen which advances on DB read. Tests use it to wait for + // pre-seeded rows to drain before opening a stream. + dispatchedMu sync.Mutex + dispatched db.VectorClock } func (s *subscribeWorker) getOriginatorNodeIds() ([]uint32, error) { @@ -96,6 +106,12 @@ func startSubscribeWorker( } vc := db.ToVectorClock(latestEnvelopes) + // Seed dispatched with the current vector clock so any envelope already + // written to the DB before the worker started counts as "dispatched" (from + // the perspective of a brand-new listener that will never see it anyway). + initialDispatched := make(db.VectorClock, len(vc)) + maps.Copy(initialDispatched, vc) + worker := &subscribeWorker{ ctx: ctx, logger: logger, @@ -105,6 +121,7 @@ func startSubscribeWorker( subscriptions: newSubscriptionHandler(logger, store, vc), originatorListeners: listenersMap[uint32]{}, topicListeners: listenersMap[string]{}, + dispatched: initialDispatched, } nodeIDs, err := worker.getOriginatorNodeIds() @@ -165,6 +182,12 @@ func (s *subscribeWorker) start() { s.dispatchToTopics(envs) s.dispatchToGlobals(envs) + // Advance the dispatched cursor based on the raw batch (not envs), + // so an envelope that failed to unmarshal still counts as + // "dispatched" — otherwise tests waiting on its sequence ID would + // spin forever. + s.advanceDispatched(batch) + span.Finish() } } @@ -236,6 +259,43 @@ func (s *subscribeWorker) dispatchToGlobals(envs []*envelopes.OriginatorEnvelope s.dispatchToListeners(&s.globalListeners, envs) } +// advanceDispatched records the highest sequence ID per originator seen in +// this batch, so dispatchedMet can confirm that every row up to a given +// vector clock has already been handed to the listener set. 
+func (s *subscribeWorker) advanceDispatched( + batch []queries.SelectGatewayEnvelopesBySingleOriginatorRow, +) { + if len(batch) == 0 { + return + } + // Batch rows come from SelectGatewayEnvelopesBySingleOriginator: all from + // one originator, ORDER BY originator_sequence_id — so the last row has the + // max seq. + last := batch[len(batch)-1] + nodeID := uint32(last.OriginatorNodeID) + seq := uint64(last.OriginatorSequenceID) + s.dispatchedMu.Lock() + defer s.dispatchedMu.Unlock() + if cur, ok := s.dispatched[nodeID]; !ok || seq > cur { + s.dispatched[nodeID] = seq + } +} + +// dispatchedMet returns true iff every (originator, seq) in target has been +// dispatched to listeners. A returned true is a guarantee that any listener +// registered after this call will not retroactively receive those envelopes +// — the batch has already left start(). +func (s *subscribeWorker) dispatchedMet(target db.VectorClock) bool { + s.dispatchedMu.Lock() + defer s.dispatchedMu.Unlock() + for nodeID, minSeq := range target { + if cur, ok := s.dispatched[nodeID]; !ok || cur < minSeq { + return false + } + } + return true +} + func (s *subscribeWorker) dispatchToListeners( listeners *listenerSet, envs []*envelopes.OriginatorEnvelope, diff --git a/pkg/api/metadata/cursor_test.go b/pkg/api/metadata/cursor_test.go index 355be0089..e6975f43f 100644 --- a/pkg/api/metadata/cursor_test.go +++ b/pkg/api/metadata/cursor_test.go @@ -1,7 +1,6 @@ package metadata_test import ( - "database/sql" "testing" "time" @@ -11,7 +10,6 @@ import ( "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api" metadata_apiconnect "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api/metadata_apiconnect" - "github.com/xmtp/xmtpd/pkg/api/message" dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" @@ -28,7 +26,7 @@ var ( func setupTest( t *testing.T, -) (metadata_apiconnect.MetadataApiClient, *sql.DB, testUtilsApi.APIServerMocks) { +) (metadata_apiconnect.MetadataApiClient, *testUtilsApi.APIServerTestSuite) { var ( suite = testUtilsApi.NewTestAPIServer(t) payerID = dbUtils.NullInt32(testutils.CreatePayer(t, suite.DB)) @@ -89,25 +87,49 @@ func setupTest( }, } - return suite.ClientMetadata, suite.DB, suite.APIServerMocks + return suite.ClientMetadata, suite } -func insertInitialRows(t *testing.T, store *sql.DB) { - testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeV3Params{ +// insertInitialRows inserts the first two rows and blocks until GetSyncCursor +// reports the subscribe worker has polled past them, so tests observe a known +// cursor state. 
+func insertInitialRows( + t *testing.T, + client metadata_apiconnect.MetadataApiClient, + suite *testUtilsApi.APIServerTestSuite, +) { + testutils.InsertGatewayEnvelopes(t, suite.DB, []queries.InsertGatewayEnvelopeV3Params{ allRows[0], allRows[1], }) - time.Sleep(message.SubscribeWorkerPollTime + 100*time.Millisecond) + expected := map[uint32]uint64{100: 1, 200: 1} + require.Eventually(t, func() bool { + resp, err := client.GetSyncCursor( + t.Context(), + connect.NewRequest(&metadata_api.GetSyncCursorRequest{}), + ) + if err != nil { + return false + } + return assert.ObjectsAreEqual( + expected, + resp.Msg.GetLatestSync().GetNodeIdToSequenceId(), + ) + }, 5*time.Second, 5*time.Millisecond) } -func insertAdditionalRows(t *testing.T, store *sql.DB, notifyChan ...chan bool) { - testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeV3Params{ +func insertAdditionalRows( + t *testing.T, + suite *testUtilsApi.APIServerTestSuite, + notifyChan ...chan bool, +) { + testutils.InsertGatewayEnvelopes(t, suite.DB, []queries.InsertGatewayEnvelopeV3Params{ allRows[2], allRows[3], allRows[4], }, notifyChan...) } func TestGetCursorBasic(t *testing.T) { - client, db, _ := setupTest(t) - insertInitialRows(t, db) + client, suite := setupTest(t) + insertInitialRows(t, client, suite) ctx := t.Context() @@ -126,7 +148,7 @@ func TestGetCursorBasic(t *testing.T) { require.Equal(t, expectedCursor, cursor.Msg.GetLatestSync().GetNodeIdToSequenceId()) - insertAdditionalRows(t, db) + insertAdditionalRows(t, suite) require.Eventually(t, func() bool { expectedCursor := map[uint32]uint64{ 100: 3, @@ -155,8 +177,8 @@ func TestGetCursorBasic(t *testing.T) { } func TestSubscribeSyncCursorBasic(t *testing.T) { - client, db, _ := setupTest(t) - insertInitialRows(t, db) + client, suite := setupTest(t) + insertInitialRows(t, client, suite) ctx := t.Context() @@ -181,7 +203,7 @@ func TestSubscribeSyncCursorBasic(t *testing.T) { require.Equal(t, expectedCursor, firstUpdate.GetLatestSync().GetNodeIdToSequenceId()) - insertAdditionalRows(t, db) + insertAdditionalRows(t, suite) expectedCursor = map[uint32]uint64{ 100: 3, diff --git a/pkg/api/payer/publish_test.go b/pkg/api/payer/publish_test.go index 7eb88a5f9..74b18300b 100644 --- a/pkg/api/payer/publish_test.go +++ b/pkg/api/payer/publish_test.go @@ -241,19 +241,25 @@ func TestPublishToNodes(t *testing.T) { ) } +// slowServer is a test fixture that simulates a replication node that never +// responds until the client cancels the RPC. When `slow` is set, the handler +// blocks on ctx.Done() rather than sleeping through a wall-clock duration, so +// the test completes as soon as the client's publish deadline fires instead of +// waiting out a hard-coded sleep. 
type slowServer struct { - delay atomic.Duration + slow atomic.Bool message_api.UnimplementedReplicationApiServer } func (s *slowServer) PublishPayerEnvelopes( - context.Context, - *message_api.PublishPayerEnvelopesRequest, + ctx context.Context, + _ *message_api.PublishPayerEnvelopesRequest, ) (*message_api.PublishPayerEnvelopesResponse, error) { - time.Sleep(s.delay.Load()) - - res := &message_api.PublishPayerEnvelopesResponse{} - return res, nil + if s.slow.Load() { + <-ctx.Done() + return nil, ctx.Err() + } + return &message_api.PublishPayerEnvelopesResponse{}, nil } func TestPublishToNodesExpires(t *testing.T) { @@ -299,8 +305,8 @@ func TestPublishToNodesExpires(t *testing.T) { envelopesTestUtils.GetRealisticGroupMessagePayload(false), ) - // Make the server take longer than the service is willing to wait. - srv.delay.Store(publishTimeout + time.Second) + // Make the server block until the client's publish deadline fires. + srv.slow.Store(true) _, err = svc.PublishClientEnvelopes( ctx, connect.NewRequest(&payer_api.PublishClientEnvelopesRequest{ @@ -311,7 +317,7 @@ func TestPublishToNodesExpires(t *testing.T) { require.Error(t, err) // Publish rpc should succeed if completed within the deadline. - srv.delay.Store(0) + srv.slow.Store(false) _, err = svc.PublishClientEnvelopes( ctx, connect.NewRequest(&payer_api.PublishClientEnvelopesRequest{ diff --git a/pkg/api/payer/selectors/node_selector_test.go b/pkg/api/payer/selectors/node_selector_test.go index 59a9eaecd..a78bba434 100644 --- a/pkg/api/payer/selectors/node_selector_test.go +++ b/pkg/api/payer/selectors/node_selector_test.go @@ -663,9 +663,8 @@ func TestClosestNodeSelector_WithPreferredNodes(t *testing.T) { require.NoError(t, err) require.NotNil(t, selector) - // Allow time for latency measurement - time.Sleep(200 * time.Millisecond) - + // The first GetNode call performs TCP latency probing synchronously inside + // updateLatencyCache, so no wall-clock wait is needed before calling it. tpc := *topic.NewTopic(topic.TopicKindIdentityUpdatesV1, []byte("test")) node, err := selector.GetNode(tpc) @@ -699,9 +698,8 @@ func TestClosestNodeSelector_WithoutPreferredNodes(t *testing.T) { require.NoError(t, err) require.NotNil(t, selector) - // Allow time for latency measurement - time.Sleep(200 * time.Millisecond) - + // The first GetNode call performs TCP latency probing synchronously inside + // updateLatencyCache, so no wall-clock wait is needed before calling it. tpc := *topic.NewTopic(topic.TopicKindIdentityUpdatesV1, []byte("test")) _, err = selector.GetNode(tpc) // In test environment, latency measurement may fail - both outcomes are acceptable @@ -726,9 +724,8 @@ func TestClosestNodeSelector_PreferredNodesFallback(t *testing.T) { require.NoError(t, err) require.NotNil(t, selector) - // Allow time for latency measurement - time.Sleep(200 * time.Millisecond) - + // The first GetNode call performs TCP latency probing synchronously inside + // updateLatencyCache, so no wall-clock wait is needed before calling it. 
tpc := *topic.NewTopic(topic.TopicKindIdentityUpdatesV1, []byte("test")) node, err := selector.GetNode(tpc) diff --git a/pkg/blockchain/noncemanager/manager_test.go b/pkg/blockchain/noncemanager/manager_test.go index d78d869b3..30a173adc 100644 --- a/pkg/blockchain/noncemanager/manager_test.go +++ b/pkg/blockchain/noncemanager/manager_test.go @@ -253,18 +253,27 @@ func TestSimultaneousAllocation(t *testing.T) { err := tm.manager.Replenish(ctx, *big.NewInt(0)) require.NoError(t, err) - const numGoroutines = 50 + // numGoroutines must not exceed BestGuessConcurrency: both the SQL + // and Redis backends cap concurrent nonce holders via a semaphore + // of that size, so going higher would leave some workers blocked + // inside GetNonce and deadlock the acquired-barrier below. + const numGoroutines = noncemanager.BestGuessConcurrency var wg sync.WaitGroup var activeNonces sync.Map // nonce -> count of simultaneous holders var errors []error var mu sync.Mutex - // Phase 1: All goroutines get nonces simultaneously. - // The barrier MUST be closed below (see `close(barrier)`) so the - // workers are released; otherwise they would block forever on - // `<-barrier` and wg.Wait() would hang. - barrier := make(chan struct{}) + // Phase 1: all goroutines race to acquire their nonce. + // Phase 2: once every goroutine has acquired, they all consume + // together. This keeps the "held but not consumed" window + // open deterministically across all workers, so a bug that + // double-allocates a nonce is caught without needing a + // wall-clock sleep. + startBarrier := make(chan struct{}) + consumeBarrier := make(chan struct{}) + var acquired sync.WaitGroup + acquired.Add(numGoroutines) for i := range numGoroutines { wg.Add(1) @@ -272,13 +281,14 @@ func TestSimultaneousAllocation(t *testing.T) { defer wg.Done() // Wait for all goroutines to be ready - <-barrier + <-startBarrier nonce, err := tm.manager.GetNonce(ctx) if err != nil { mu.Lock() errors = append(errors, err) mu.Unlock() + acquired.Done() return } @@ -302,8 +312,11 @@ func TestSimultaneousAllocation(t *testing.T) { mu.Unlock() } - // Hold the nonce briefly to ensure simultaneous allocation detection - time.Sleep(1 * time.Millisecond) + // Signal acquisition and hold the nonce until every other + // goroutine has also acquired. This is what makes the + // simultaneous-allocation window deterministic. + acquired.Done() + <-consumeBarrier // Phase 2: Always consume to avoid cancelled nonce reuse confusion err = nonce.Consume() @@ -320,7 +333,11 @@ func TestSimultaneousAllocation(t *testing.T) { } // Release all goroutines at once - close(barrier) + close(startBarrier) + // Wait for every goroutine to acquire (or fail) its nonce so they + // all hold their nonces concurrently before anyone consumes. + acquired.Wait() + close(consumeBarrier) wg.Wait() // For this test, we only care about true simultaneous allocation @@ -348,15 +365,14 @@ func TestCancelReuse(t *testing.T) { err := tm.manager.Replenish(ctx, *big.NewInt(0)) require.NoError(t, err) - // Allocate one nonce and cancel it + // Allocate one nonce and cancel it. Cancel is synchronous on every + // backend (sql, redis, in-memory), so when it returns the nonce is + // back in the available pool. 
nonce1, err := tm.manager.GetNonce(ctx) require.NoError(t, err) firstNonce := nonce1.Nonce.Int64() nonce1.Cancel() - // Allow time for cancellation to take effect (especially for Redis) - time.Sleep(10 * time.Millisecond) - // Get the next nonce - it should be the cancelled one OR the next available nonce2, err := tm.manager.GetNonce(ctx) require.NoError(t, err) diff --git a/pkg/blockchain/noncemanager/redis/manager_test.go b/pkg/blockchain/noncemanager/redis/manager_test.go index 777674a80..e5e2960d9 100644 --- a/pkg/blockchain/noncemanager/redis/manager_test.go +++ b/pkg/blockchain/noncemanager/redis/manager_test.go @@ -46,10 +46,9 @@ func TestRedisGetNonce_RevertMany(t *testing.T) { nonce, err := nonceManager.GetNonce(t.Context()) require.NoError(t, err) require.EqualValues(t, 0, nonce.Nonce.Int64()) + // Cancel is synchronous; when it returns, the nonce is already back + // in the available pool, so the next GetNonce can reuse it immediately. nonce.Cancel() - - // Add a small delay to ensure the Cancel operation completes - time.Sleep(1 * time.Millisecond) } } @@ -177,9 +176,8 @@ func TestRedisGetNonce_ContextCancellation(t *testing.T) { func TestRedisGetNonce_KeyPrefix(t *testing.T) { client1, keyPrefix1 := redistestutils.NewRedisForTest(t) - // The keyPrefix uses the timestamp as a tiebreak when run from the same test. - // Ensure we get a distinct prefix - time.Sleep(1 * time.Millisecond) + // NewRedisForTest adds a process-level monotonic counter to the prefix, + // so back-to-back calls always produce distinct prefixes. client2, keyPrefix2 := redistestutils.NewRedisForTest(t) logger, err := zap.NewDevelopment() diff --git a/pkg/blockchain/oracle/oracle.go b/pkg/blockchain/oracle/oracle.go index 80171af81..cee15316d 100644 --- a/pkg/blockchain/oracle/oracle.go +++ b/pkg/blockchain/oracle/oracle.go @@ -77,14 +77,25 @@ type Oracle struct { gasPriceSource GasPriceSource gasPriceLastUpdated atomic.Int64 gasPrice atomic.Int64 + maxStaleDuration time.Duration sfGroup singleflight.Group } +// Option configures an Oracle at construction time. +type Option func(*Oracle) + +// WithMaxStaleDuration overrides the cache staleness window. Mainly useful +// in tests that want to force every call to fetch a fresh price. 
+func WithMaxStaleDuration(d time.Duration) Option { + return func(o *Oracle) { o.maxStaleDuration = d } +} + func New( ctx context.Context, logger *zap.Logger, wsURL string, + opts ...Option, ) (*Oracle, error) { ethClient, err := ethclient.Dial(wsURL) if err != nil { @@ -135,6 +146,11 @@ func New( chainID: chainID.Int64(), gasPriceSource: gasPriceSource, gasPriceDefaultWei: gasPriceDefaultWei, + maxStaleDuration: gasPriceMaxStaleDuration, + } + + for _, opt := range opts { + opt(oracle) } return oracle, nil @@ -188,7 +204,7 @@ func (o *Oracle) isStale() bool { return true } - return time.Since(time.UnixMilli(lastUpdated)) > gasPriceMaxStaleDuration + return time.Since(time.UnixMilli(lastUpdated)) > o.maxStaleDuration } func (o *Oracle) updateGasPrice(ctx context.Context) error { diff --git a/pkg/blockchain/oracle/oracle_test.go b/pkg/blockchain/oracle/oracle_test.go index ae02ae387..0b6ee76ad 100644 --- a/pkg/blockchain/oracle/oracle_test.go +++ b/pkg/blockchain/oracle/oracle_test.go @@ -2,7 +2,6 @@ package oracle_test import ( "context" - "math/rand" "sync" "testing" "time" @@ -13,7 +12,7 @@ import ( "github.com/xmtp/xmtpd/pkg/testutils/anvil" ) -func buildOracle(t *testing.T) *oracle.Oracle { +func buildOracle(t *testing.T, opts ...oracle.Option) *oracle.Oracle { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -25,6 +24,7 @@ func buildOracle(t *testing.T) *oracle.Oracle { ctx, logger, wsURL, + opts..., ) require.NoError(t, err) @@ -38,8 +38,8 @@ func buildOracle(t *testing.T) *oracle.Oracle { func TestOracleGetGasPrice(t *testing.T) { o := buildOracle(t) - // By forcing waiting, we ensure that the gas price is fetched from the blockchain. - time.Sleep(500 * time.Millisecond) + // First GetGasPrice call always fetches because lastUpdated is zero, + // so no wall-clock wait is required. gasPrice := o.GetGasPrice() require.Positive(t, gasPrice) } @@ -72,7 +72,10 @@ func TestOracleConcurrentGetGasPrice(t *testing.T) { } func TestOracleGetGasPriceRandom(t *testing.T) { - o := buildOracle(t) + // Force every GetGasPrice call to refetch by configuring a zero stale + // window. singleflight still coalesces concurrent fetches, so this + // exercises the racy "fetch under load" path without any wall-clock sleep. + o := buildOracle(t, oracle.WithMaxStaleDuration(0)) initialPrice := o.GetGasPrice() require.Positive(t, initialPrice) @@ -80,8 +83,6 @@ func TestOracleGetGasPriceRandom(t *testing.T) { var wg sync.WaitGroup for range 100 { wg.Go(func() { - // Force the oracle to fetch a new gas price for some goroutines. - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) price := o.GetGasPrice() require.Positive(t, price) }) @@ -90,16 +91,15 @@ func TestOracleGetGasPriceRandom(t *testing.T) { } func TestOracleGasPriceRefreshesAfterStaleness(t *testing.T) { - o := buildOracle(t) + // A zero stale window forces the second GetGasPrice call to fetch a + // fresh value, deterministically exercising the refresh path. 
+ o := buildOracle(t, oracle.WithMaxStaleDuration(0)) // Get initial gas price initialPrice := o.GetGasPrice() t.Logf("initial gas price: %d", initialPrice) require.Positive(t, initialPrice) - // Wait for staleness (250ms + buffer) - time.Sleep(300 * time.Millisecond) - // Price should still be valid after refresh currentPrice := o.GetGasPrice() t.Logf("current gas price: %d", currentPrice) diff --git a/pkg/db/payer_test.go b/pkg/db/payer_test.go index 706ae7185..2ca329203 100644 --- a/pkg/db/payer_test.go +++ b/pkg/db/payer_test.go @@ -55,7 +55,7 @@ func TestFindOrCreatePayerWithRetry(t *testing.T) { rawDB, _ := testutils.NewRawDB(t, ctx) address := testutils.RandomString(42) - // Start transaction T1 and insert the payer (holds row lock, uncommitted) + // Start transaction T1 and insert the payer (holds row lock, uncommitted). tx1, err := rawDB.BeginTx(ctx, nil) require.NoError(t, err) defer func() { _ = tx1.Rollback() }() @@ -63,21 +63,45 @@ func TestFindOrCreatePayerWithRetry(t *testing.T) { _, err = tx1.ExecContext(ctx, "INSERT INTO payers(address) VALUES ($1)", address) require.NoError(t, err) - // Commit T1 after a short delay so the retry can succeed + poolQuerier := queries.New(rawDB) + + // Run FindOrCreatePayerWithRetry concurrently with T1 commit. The + // retry call will block inside INSERT ... ON CONFLICT on T1's unique- + // index lock until T1 resolves. Instead of sleeping a fixed duration + // to guarantee the retry is blocked before we commit, poll + // pg_stat_activity for a session waiting on a Lock event — a + // deterministic signal that the contending query is actually stalled. + type result struct { + id int32 + err error + } + resCh := make(chan result, 1) go func() { - time.Sleep(5 * time.Millisecond) - _ = tx1.Commit() + id, err := db.FindOrCreatePayerWithRetry(ctx, poolQuerier, address, 10) + resCh <- result{id: id, err: err} }() - // On a separate connection, the raw FindOrCreatePayer gets sql.ErrNoRows - // because the CTE INSERT conflicts (T1 holds the lock) and the SELECT - // uses the pre-commit snapshot. - poolQuerier := queries.New(rawDB) - - // FindOrCreatePayerWithRetry should succeed after T1 commits - id, err := db.FindOrCreatePayerWithRetry(ctx, poolQuerier, address, 3) - require.NoError(t, err) - require.NotZero(t, id) + require.Eventually(t, func() bool { + var count int + err := rawDB.QueryRowContext(ctx, ` + SELECT count(*) + FROM pg_stat_activity + WHERE state = 'active' + AND wait_event_type = 'Lock' + AND query ILIKE '%payers%' + `).Scan(&count) + return err == nil && count >= 1 + }, 5*time.Second, 10*time.Millisecond, "retry call should be blocked on T1's lock") + + require.NoError(t, tx1.Commit()) + + select { + case r := <-resCh: + require.NoError(t, r.err) + require.NotZero(t, r.id) + case <-time.After(5 * time.Second): + t.Fatal("FindOrCreatePayerWithRetry did not return after T1 commit") + } }) t.Run("context cancellation stops retries", func(t *testing.T) { diff --git a/pkg/db/pgx_test.go b/pkg/db/pgx_test.go index 81c383570..c1c5e3516 100644 --- a/pkg/db/pgx_test.go +++ b/pkg/db/pgx_test.go @@ -111,11 +111,17 @@ func TestNamespacedDBInvalidDSN(t *testing.T) { require.Error(t, err) } -func BlackHoleServer(ctx context.Context, port string) error { +// BlackHoleServer starts a TCP listener that accepts connections but never +// responds. It closes `ready` as soon as the listener is bound so callers can +// synchronize on readiness instead of sleeping. 
+func BlackHoleServer(ctx context.Context, port string, ready chan<- struct{}) error { ln, err := net.Listen("tcp", ":"+port) if err != nil { return fmt.Errorf("error starting blackhole server: %w", err) } + if ready != nil { + close(ready) + } defer func() { _ = ln.Close() }() @@ -169,11 +175,16 @@ func TestBlackholeDNS(t *testing.T) { defer cancelServer() serverErrCh := make(chan error, 1) + serverReady := make(chan struct{}) go func() { - serverErrCh <- BlackHoleServer(serverCtx, strconv.Itoa(port)) + serverErrCh <- BlackHoleServer(serverCtx, strconv.Itoa(port), serverReady) }() - // Wait for server to start - time.Sleep(50 * time.Millisecond) + // Wait for the listener to be bound before dialing. + select { + case <-serverReady: + case <-testCtx.Done(): + t.Fatal("blackhole server did not start in time") + } _, err = db.NewNamespacedDB( testCtx, diff --git a/pkg/db/sequences_test.go b/pkg/db/sequences_test.go index f21f1a906..2a2c82912 100644 --- a/pkg/db/sequences_test.go +++ b/pkg/db/sequences_test.go @@ -7,7 +7,6 @@ import ( "sync" "sync/atomic" "testing" - "time" "github.com/stretchr/testify/require" db2 "github.com/xmtp/xmtpd/pkg/db" @@ -55,7 +54,43 @@ func getNextPayerSequence(t *testing.T, ctx context.Context, db *sql.DB) (int64, return 0, err } t.Log("Acquired sequence ID: ", seq) - time.Sleep(10 * time.Millisecond) + + _, err = querier.DeleteAvailableNonce(ctx, seq) + if err != nil { + return 0, err + } + + return int64(seq), nil + }, + ) +} + +// getNextPayerSequenceHeld is like getNextPayerSequence but holds the locked +// row until the release channel is closed. It signals acquired.Done() after +// GetNextAvailableNonce returns, so tests can wait for all concurrent callers +// to hold their locks simultaneously. This replaces a `time.Sleep` hold with +// deterministic synchronization and still exercises `FOR UPDATE SKIP LOCKED`. +func getNextPayerSequenceHeld( + t *testing.T, + ctx context.Context, + db *sql.DB, + acquired *sync.WaitGroup, + release <-chan struct{}, +) (int64, error) { + return db2.RunInTxWithResult(ctx, db, &sql.TxOptions{}, + func(ctx context.Context, querier *queries.Queries) (int64, error) { + seq, err := querier.GetNextAvailableNonce(ctx) + acquired.Done() + if err != nil { + return 0, err + } + t.Log("Acquired sequence ID: ", seq) + + // Hold the row-level lock until every concurrent worker has + // acquired its own. This guarantees that FOR UPDATE SKIP LOCKED + // sees simultaneously-held rows, which is the property the test + // is verifying. + <-release _, err = querier.DeleteAvailableNonce(ctx, seq) if err != nil { @@ -75,7 +110,6 @@ func failNextPayerSequence(t *testing.T, ctx context.Context, db *sql.DB) (int64 return 0, err } t.Log("Acquired sequence ID: ", seq) - time.Sleep(10 * time.Millisecond) return 0, errors.New("failed to acquire sequence") }, @@ -98,9 +132,17 @@ func TestConcurrentReads(t *testing.T) { numClients := 20 results := make(chan int64, numClients) + // Force every worker to hold its row-level lock concurrently by using a + // two-phase barrier: first wait until all workers have acquired, then + // release them all at once. This makes `FOR UPDATE SKIP LOCKED` exercise + // real contention without relying on wall-clock sleeps. 
+ var acquired sync.WaitGroup + acquired.Add(numClients) + release := make(chan struct{}) + for range numClients { wg.Go(func() { - seqID, err := getNextPayerSequence(t, ctx, db) + seqID, err := getNextPayerSequenceHeld(t, ctx, db, &acquired, release) if err != nil { t.Errorf("Error acquiring sequence: %v", err) } else { @@ -109,6 +151,10 @@ func TestConcurrentReads(t *testing.T) { }) } + // Wait for every worker to be holding a row lock, then release them. + acquired.Wait() + close(release) + // Wait for all goroutines to complete wg.Wait() close(results) diff --git a/pkg/payerreport/store_test.go b/pkg/payerreport/store_test.go index ae5bb287f..a6ffe2c85 100644 --- a/pkg/payerreport/store_test.go +++ b/pkg/payerreport/store_test.go @@ -3,6 +3,7 @@ package payerreport_test import ( "context" "math" + "runtime" "sync" "testing" "time" @@ -179,10 +180,21 @@ func TestIdempotentStore(t *testing.T) { require.Equal(t, report.ID, storedReports[0].ID) } +// waitForWallClockAfter spins until time.Now() is strictly after ref. +// Used to guarantee the next postgres now() call produces a CreatedAt strictly +// greater than ref. pg's now() has microsecond precision, and postgres shares +// the host wall clock in the dev env, so this gives a deterministic ordering +// without a 1ms sleep. +func waitForWallClockAfter(ref time.Time) { + for !time.Now().After(ref) { + runtime.Gosched() + } +} + func TestFetchReport(t *testing.T) { store := createTestStore(t) report1 := insertRandomReport(t, store) - time.Sleep(1 * time.Millisecond) + waitForWallClockAfter(report1.CreatedAt) report2 := insertRandomReport(t, store) // Set the second report's status to Approved attestation := &payerreport.PayerReportAttestation{ diff --git a/pkg/payerreport/workers/integration_test.go b/pkg/payerreport/workers/integration_test.go index 8410a748d..eba1ba3a1 100644 --- a/pkg/payerreport/workers/integration_test.go +++ b/pkg/payerreport/workers/integration_test.go @@ -466,10 +466,28 @@ func TestCanGenerateReport(t *testing.T) { err = scaffold.reportGenerators[0].GenerateReports() require.NoError(t, err) - // Make sure there is still only one report after generating again - time.Sleep(100 * time.Millisecond) - messagesOnNode1 := scaffold.getMessagesFromTopic(t, 0, node1ReportTopic) - require.Len(t, messagesOnNode1, 1) + // Make sure there is still only one report after generating again. + // Poll via require.Never instead of a fixed sleep, but tolerate transient + // RPC errors inside the predicate so unrelated intermediate failures do + // not get promoted into test failures — we only want to catch a genuine + // "a second report appeared" condition. 
+ client := scaffold.clients[0] + require.Never(t, func() bool { + resp, err := client.QueryEnvelopes( + t.Context(), + &connect.Request[message_api.QueryEnvelopesRequest]{ + Msg: &message_api.QueryEnvelopesRequest{ + Query: &message_api.EnvelopesQuery{ + Topics: [][]byte{node1ReportTopic}, + }, + }, + }, + ) + if err != nil { + return false + } + return len(resp.Msg.GetEnvelopes()) != 1 + }, 100*time.Millisecond, 10*time.Millisecond) } func TestFullReportLifecycle(t *testing.T) { diff --git a/pkg/payerreport/workers/settlement_test.go b/pkg/payerreport/workers/settlement_test.go index 25f7a2545..3b145e74d 100644 --- a/pkg/payerreport/workers/settlement_test.go +++ b/pkg/payerreport/workers/settlement_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "math/big" + "runtime" "testing" "time" @@ -512,8 +513,12 @@ func TestSettleReportsWithPartialFailure(t *testing.T) { stored1 := storeReport(t, store, &report1.PayerReport) require.NoError(t, store.SetReportAttestationApproved(t.Context(), stored1.ID)) require.NoError(t, store.SetReportSubmitted(t.Context(), stored1.ID, 0)) - // Make sure the created_at is different between the two reports - time.Sleep(1 * time.Millisecond) + // Make sure the created_at is different between the two reports. pg now() has + // microsecond precision and shares the host wall clock in the dev env, so this + // spin gives a deterministic ordering without a 1ms sleep. + for !time.Now().After(stored1.CreatedAt) { + runtime.Gosched() + } payers2 := map[common.Address]currency.PicoDollar{ common.HexToAddress("0x2"): 200, diff --git a/pkg/ratelimiter/circuit_breaker.go b/pkg/ratelimiter/circuit_breaker.go index ee8e0593a..bd1c12615 100644 --- a/pkg/ratelimiter/circuit_breaker.go +++ b/pkg/ratelimiter/circuit_breaker.go @@ -39,6 +39,9 @@ type CircuitBreaker struct { mu sync.Mutex failureThreshold int cooldown time.Duration + // now returns the current time. Injectable so tests can advance the clock + // deterministically instead of sleeping through the cooldown. 
+ now func() time.Time state BreakerState failureCount int @@ -52,6 +55,7 @@ func NewCircuitBreaker(failureThreshold int, cooldown time.Duration) *CircuitBre failureThreshold: failureThreshold, cooldown: cooldown, state: BreakerClosed, + now: time.Now, } } @@ -74,7 +78,7 @@ func (cb *CircuitBreaker) Allow() bool { case BreakerHalfOpen: return true case BreakerOpen: - if time.Since(cb.openedAt) >= cb.cooldown { + if cb.now().Sub(cb.openedAt) >= cb.cooldown { cb.state = BreakerHalfOpen BreakerStateGauge.Set(1) return true @@ -108,7 +112,7 @@ func (cb *CircuitBreaker) RecordFailure() { } if cb.state == BreakerHalfOpen { cb.state = BreakerOpen - cb.openedAt = time.Now() + cb.openedAt = cb.now() BreakerStateGauge.Set(2) BreakerTripsTotal.Inc() return @@ -116,7 +120,7 @@ func (cb *CircuitBreaker) RecordFailure() { cb.failureCount++ if cb.failureCount >= cb.failureThreshold { cb.state = BreakerOpen - cb.openedAt = time.Now() + cb.openedAt = cb.now() BreakerStateGauge.Set(2) BreakerTripsTotal.Inc() } diff --git a/pkg/ratelimiter/circuit_breaker_test.go b/pkg/ratelimiter/circuit_breaker_test.go index 2d75b4a54..9b0315919 100644 --- a/pkg/ratelimiter/circuit_breaker_test.go +++ b/pkg/ratelimiter/circuit_breaker_test.go @@ -26,11 +26,14 @@ func TestCircuitBreaker_OpensAfterThresholdFailures(t *testing.T) { func TestCircuitBreaker_HalfOpenAfterCooldown(t *testing.T) { cb := NewCircuitBreaker(1, 50*time.Millisecond) + fakeNow := time.Now() + cb.now = func() time.Time { return fakeNow } cb.RecordFailure() require.Equal(t, BreakerOpen, cb.State()) require.False(t, cb.Allow()) - time.Sleep(80 * time.Millisecond) + // Advance the fake clock past the cooldown without sleeping. + fakeNow = fakeNow.Add(80 * time.Millisecond) require.True(t, cb.Allow()) require.Equal(t, BreakerHalfOpen, cb.State()) @@ -63,8 +66,10 @@ func TestCircuitBreaker_RecordFailureWhileOpenIsIdempotent(t *testing.T) { func TestCircuitBreaker_HalfOpenSuccessClosesCircuit(t *testing.T) { cb := NewCircuitBreaker(1, 50*time.Millisecond) + fakeNow := time.Now() + cb.now = func() time.Time { return fakeNow } cb.RecordFailure() - time.Sleep(80 * time.Millisecond) + fakeNow = fakeNow.Add(80 * time.Millisecond) require.True(t, cb.Allow()) cb.RecordSuccess() @@ -73,8 +78,10 @@ func TestCircuitBreaker_HalfOpenSuccessClosesCircuit(t *testing.T) { func TestCircuitBreaker_HalfOpenFailureReopens(t *testing.T) { cb := NewCircuitBreaker(1, 50*time.Millisecond) + fakeNow := time.Now() + cb.now = func() time.Time { return fakeNow } cb.RecordFailure() - time.Sleep(80 * time.Millisecond) + fakeNow = fakeNow.Add(80 * time.Millisecond) require.True(t, cb.Allow()) cb.RecordFailure() diff --git a/pkg/ratelimiter/export_test.go b/pkg/ratelimiter/export_test.go new file mode 100644 index 000000000..c3c5fd19b --- /dev/null +++ b/pkg/ratelimiter/export_test.go @@ -0,0 +1,8 @@ +package ratelimiter + +import "time" + +// SetNowForTest overrides the RedisLimiter clock. Test-only. +func (l *RedisLimiter) SetNowForTest(now func() time.Time) { + l.now = now +} diff --git a/pkg/ratelimiter/redis_limiter.go b/pkg/ratelimiter/redis_limiter.go index 21da6dd0c..9fdb2562a 100644 --- a/pkg/ratelimiter/redis_limiter.go +++ b/pkg/ratelimiter/redis_limiter.go @@ -18,6 +18,11 @@ type RedisLimiter struct { script *redis.Script keyPrefix string limits []Limit + // now returns the current time. Injectable so tests can advance the clock + // deterministically instead of sleeping through the refill window. 
The Lua + // script uses the client-provided timestamp for refill math, so overriding + // `now` is enough to fully control refill behavior in tests. + now func() time.Time } func NewRedisLimiter( @@ -34,6 +39,7 @@ func NewRedisLimiter( script: redis.NewScript(luaScript), keyPrefix: keyPrefix, limits: limits, + now: time.Now, }, nil } @@ -66,7 +72,7 @@ func (l *RedisLimiter) Allow(ctx context.Context, subject string, cost uint64) ( if cost == 0 { return nil, ErrCostMustBeGreaterThanZero } - now := time.Now() + now := l.now() keys := l.buildKeys(subject) args := l.buildArgs(now, cost) diff --git a/pkg/ratelimiter/redis_limiter_test.go b/pkg/ratelimiter/redis_limiter_test.go index 4511d2bf6..1d5694996 100644 --- a/pkg/ratelimiter/redis_limiter_test.go +++ b/pkg/ratelimiter/redis_limiter_test.go @@ -439,6 +439,12 @@ func TestRedisLimiter_Refill(t *testing.T) { []ratelimiter.Limit{{Capacity: 10, RefillEvery: 100 * time.Millisecond}}) require.NoError(t, err) + // The Lua script uses the client-provided timestamp for refill math, so + // overriding the limiter's clock fully controls refill behavior without + // any wall-clock sleeping. + fakeNow := time.Now() + limiter.SetNowForTest(func() time.Time { return fakeNow }) + // Consume all tokens res, err := limiter.Allow(context.Background(), "test-subject", 10) require.NoError(t, err) @@ -450,8 +456,8 @@ func TestRedisLimiter_Refill(t *testing.T) { require.NoError(t, err) require.False(t, res.Allowed) - // Wait for partial refill - time.Sleep(50 * time.Millisecond) + // Advance the fake clock past the partial refill window. + fakeNow = fakeNow.Add(50 * time.Millisecond) res, err = limiter.Allow(context.Background(), "test-subject", 1) require.NoError(t, err) require.True(t, res.Allowed, "should allow after partial refill") diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index 967a4fb49..43a80ca50 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -86,6 +86,13 @@ func (r *Registrant) SignStagedEnvelope( congestionFee currency.PicoDollar, retentionDays uint32, ) (*envelopes.OriginatorEnvelope, error) { + // Expiry is derived from the staged envelope's originator_time (set once by + // the DB at staging) rather than time.Now(), so that SignStagedEnvelope is a + // pure function of its inputs. PublishPayerEnvelopes signs the same staged + // envelope twice (once to return to the client, once in the publish worker + // before storing it); basing expiry on OriginatorTime guarantees both calls + // produce identical bytes and signatures, so the envelope the client sees + // in the publish response matches what is stored and later queried back. unsignedEnv := envelopes.UnsignedOriginatorEnvelope{ OriginatorNodeId: r.record.NodeID, OriginatorSequenceId: uint64(stagedEnv.ID), @@ -94,7 +101,7 @@ func (r *Registrant) SignStagedEnvelope( BaseFeePicodollars: uint64(baseFee), CongestionFeePicodollars: uint64(congestionFee), ExpiryUnixtime: uint64( - time.Now().UTC(). + stagedEnv.OriginatorTime.UTC(). Add(time.Hour * 24 * time.Duration(retentionDays)). 
Unix(), ), diff --git a/pkg/testutils/redis/redis.go b/pkg/testutils/redis/redis.go index 6db01f7d4..852e7680a 100644 --- a/pkg/testutils/redis/redis.go +++ b/pkg/testutils/redis/redis.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "testing" "time" @@ -13,6 +14,12 @@ import ( const redisAddress = "localhost:6379" +// testKeyPrefixCounter monotonically increases for every call to +// generateTestKeyPrefix to guarantee uniqueness even when two calls happen +// inside the same millisecond (as was the case when one test creates two +// Redis clients back-to-back). +var testKeyPrefixCounter atomic.Uint64 + // NewRedisForTest creates a Redis client configured for testing with proper cleanup. // It automatically generates a unique key prefix based on the test name to avoid conflicts. // All keys created with this prefix are cleaned up after the test. @@ -45,7 +52,10 @@ func NewRedisForTest(t *testing.T) (redis.UniversalClient, string) { return client, keyPrefix } -// generateTestKeyPrefix creates a unique key prefix based on test name and timestamp +// generateTestKeyPrefix creates a unique key prefix based on test name, +// timestamp, and a process-level monotonic counter. The counter ensures that +// two calls in the same millisecond (e.g. two clients created back-to-back +// within a single test) receive distinct prefixes. func generateTestKeyPrefix(t *testing.T) string { // Clean test name to be Redis-key safe testName := strings.ReplaceAll(t.Name(), "/", "_") @@ -54,8 +64,9 @@ func generateTestKeyPrefix(t *testing.T) string { // Add timestamp to ensure uniqueness even for parallel runs timestamp := time.Now().UnixNano() / int64(time.Millisecond) + seq := testKeyPrefixCounter.Add(1) - return fmt.Sprintf("test:%s:%d:", testName, timestamp) + return fmt.Sprintf("test:%s:%d:%d:", testName, timestamp, seq) } // cleanupKeysByPrefix removes all keys matching the prefix pattern diff --git a/pkg/tracing/context_store.go b/pkg/tracing/context_store.go index 998d7df58..b1c4cf73f 100644 --- a/pkg/tracing/context_store.go +++ b/pkg/tracing/context_store.go @@ -26,6 +26,9 @@ type TraceContextStore struct { ttl time.Duration lastCleanup time.Time cleanupCount int // Track cleanups for testing/monitoring + // now returns the current time. Injectable so tests can advance the + // clock deterministically instead of sleeping. + now func() time.Time } // Span limits for production safety - prevent runaway memory/payload sizes. 
@@ -51,6 +54,7 @@ func NewTraceContextStore() *TraceContextStore { contexts: make(map[int64]traceContextEntry), ttl: DefaultTraceContextTTL, lastCleanup: time.Now(), + now: time.Now, } } @@ -66,8 +70,10 @@ func (s *TraceContextStore) Store(stagedID int64, span Span) { s.mu.Lock() defer s.mu.Unlock() + now := s.now() + // Lazy cleanup: run every minute to prevent unbounded growth - if time.Since(s.lastCleanup) > time.Minute { + if now.Sub(s.lastCleanup) > time.Minute { s.cleanupExpiredLocked() } @@ -79,7 +85,7 @@ func (s *TraceContextStore) Store(stagedID int64, span Span) { s.contexts[stagedID] = traceContextEntry{ ctx: span.Context(), - createdAt: time.Now(), + createdAt: now, } } @@ -98,7 +104,7 @@ func (s *TraceContextStore) Retrieve(stagedID int64) ddtrace.SpanContext { delete(s.contexts, stagedID) // Check if expired - if time.Since(entry.createdAt) > s.ttl { + if s.now().Sub(entry.createdAt) > s.ttl { return nil } @@ -108,7 +114,7 @@ func (s *TraceContextStore) Retrieve(stagedID int64) ddtrace.SpanContext { // cleanupExpiredLocked removes entries older than TTL. // Must be called with lock held. func (s *TraceContextStore) cleanupExpiredLocked() { - now := time.Now() + now := s.now() for id, entry := range s.contexts { if now.Sub(entry.createdAt) > s.ttl { delete(s.contexts, id) diff --git a/pkg/tracing/tracing_test.go b/pkg/tracing/tracing_test.go index 01b10919b..272c6e56f 100644 --- a/pkg/tracing/tracing_test.go +++ b/pkg/tracing/tracing_test.go @@ -66,7 +66,10 @@ func TestTraceContextStore_StoreNilSpan(t *testing.T) { func TestTraceContextStore_TTLExpiration(t *testing.T) { enableTracingForTest(t) store := NewTraceContextStore() - // Set short TTL for testing + // Drive the store with a controllable fake clock so TTL expiration is + // deterministic and does not require a wall-clock sleep. + fakeNow := time.Now() + store.now = func() time.Time { return fakeNow } store.ttl = 50 * time.Millisecond mt := mocktracer.Start() @@ -78,8 +81,8 @@ func TestTraceContextStore_TTLExpiration(t *testing.T) { store.Store(stagedID, span) assert.Equal(t, 1, store.Size()) - // Wait for TTL to expire - time.Sleep(100 * time.Millisecond) + // Advance the fake clock past the TTL. + fakeNow = fakeNow.Add(100 * time.Millisecond) // Retrieve should return nil for expired entry ctx := store.Retrieve(stagedID) @@ -94,6 +97,9 @@ func TestTraceContextStore_TTLExpiration(t *testing.T) { func TestTraceContextStore_CleanupRemovesExpired(t *testing.T) { enableTracingForTest(t) store := NewTraceContextStore() + // Use a fake clock so we can advance time past the TTL without sleeping. + fakeNow := time.Now() + store.now = func() time.Time { return fakeNow } store.ttl = 50 * time.Millisecond mt := mocktracer.Start() @@ -103,11 +109,11 @@ func TestTraceContextStore_CleanupRemovesExpired(t *testing.T) { span1 := StartSpan("test.operation1") store.Store(1, span1) - // Wait for it to expire - time.Sleep(100 * time.Millisecond) + // Advance the clock past the TTL so span1 is now expired. + fakeNow = fakeNow.Add(100 * time.Millisecond) // Force cleanup to run on next store by setting lastCleanup in the past - store.lastCleanup = time.Now().Add(-2 * time.Minute) + store.lastCleanup = fakeNow.Add(-2 * time.Minute) // Store second span - this should trigger cleanup of expired span1 span2 := StartSpan("test.operation2")
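
The hunks above all apply the same technique: give the type an injectable `now func() time.Time` field that defaults to `time.Now`, route every timestamp read through it, and let tests swap in a fake clock and advance it explicitly instead of sleeping. The sketch below is a minimal, self-contained distillation of that pattern; it is illustrative only and not part of the patch — the `Breaker` type and its method names are hypothetical stand-ins for CircuitBreaker, RedisLimiter, and TraceContextStore.

package clocksketch

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// Breaker stays open for a cooldown after Trip and then admits calls again.
// Every timestamp read goes through the injectable now field rather than
// calling time.Now directly.
type Breaker struct {
	cooldown time.Duration
	openedAt time.Time
	open     bool
	// now returns the current time. Production code keeps the default
	// (time.Now); tests swap in a fake clock so they never have to sleep.
	now func() time.Time
}

func NewBreaker(cooldown time.Duration) *Breaker {
	return &Breaker{cooldown: cooldown, now: time.Now}
}

// Trip opens the breaker and records when it opened, via the injected clock.
func (b *Breaker) Trip() {
	b.open = true
	b.openedAt = b.now()
}

// Allow reports whether a call may proceed, closing the breaker once the
// cooldown has elapsed according to the injected clock.
func (b *Breaker) Allow() bool {
	if !b.open {
		return true
	}
	if b.now().Sub(b.openedAt) >= b.cooldown {
		b.open = false
		return true
	}
	return false
}

// TestBreaker_CooldownWithFakeClock advances a fake clock instead of
// sleeping, so the test is deterministic and independent of scheduler jitter.
// (In a real package this would live in a _test.go file.)
func TestBreaker_CooldownWithFakeClock(t *testing.T) {
	fakeNow := time.Now()
	b := NewBreaker(50 * time.Millisecond)
	// The closure captures the variable, so later reassignments of fakeNow
	// are visible to the breaker.
	b.now = func() time.Time { return fakeNow }

	b.Trip()
	require.False(t, b.Allow(), "still inside the cooldown window")

	fakeNow = fakeNow.Add(80 * time.Millisecond) // "wait" without time.Sleep
	require.True(t, b.Allow(), "cooldown elapsed on the fake clock")
}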