diff --git a/census/censusdb/censusdb.go b/census/censusdb/censusdb.go index a5416ae75..ac1710e9b 100644 --- a/census/censusdb/censusdb.go +++ b/census/censusdb/censusdb.go @@ -281,6 +281,75 @@ func (c *CensusDB) LoadByScopedAddress(chainID uint64, address common.Address) ( return c.loadCensusRef(scopedAddressToCensusID(chainID, address), scopedAddressDBPrefix(chainID, address)) } +// LoadFromPersistentTree loads a census tree directly from the persistent KV +// storage (censusTreeDBPrefix) for the given root hash, bypassing the +// ephemeral Pebble cache. This recovers census state after a node restart +// when the Pebble tree on disk is empty but the KV-backed tree still holds +// all leaves from the original Import call. +// +// If the census is already in the in-memory cache but its tree has no root +// (i.e. the cached ref holds an empty Pebble tree), the cached tree is +// replaced with the persistent KV-backed tree. +func (c *CensusDB) LoadFromPersistentTree(root types.HexBytes) (*CensusRef, error) { + censusID := rootToCensusID(root) + + // Check in-memory cache, but only return the cached ref if its tree has a + // valid root. If the cached tree is empty (happens after restart when the + // Pebble store has been wiped), fall through and reload from the KV backend. + c.mu.RLock() + if ref, exists := c.loadedCensus[censusID]; exists { + if _, ok := ref.tree.Root(); ok { + c.mu.RUnlock() + return ref, nil + } + } + c.mu.RUnlock() + + // Open the persistent KV-backed tree. + treeDB := prefixeddb.NewPrefixedDatabase(c.db, censusTreeDBPrefix(censusID)) + tree, err := census.NewCensusIMT(treeDB, censusHasher) + if err != nil { + return nil, fmt.Errorf("failed to open persistent census tree: %w", err) + } + treeRoot, ok := tree.Root() + if !ok { + return nil, fmt.Errorf("%w: persistent tree has no root for census root %s", + ErrCensusNotFound, root.String()) + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Re-check under write lock. + if ref, exists := c.loadedCensus[censusID]; exists { + if _, ok := ref.tree.Root(); ok { + return ref, nil + } + // Cached ref has an empty tree — replace it with the persistent tree. + ref.SetTree(tree) + ref.currentRoot = treeRoot.Bytes() + ref.LastUsed = time.Now() + return ref, nil + } + + ref := &CensusRef{ + ID: censusID, + HashType: censusHasherName, + LastUsed: time.Now(), + updateRootRequest: c.updateRootChan, + } + ref.currentRoot = treeRoot.Bytes() + ref.SetTree(tree) + + c.loadedCensus[censusID] = ref + rk := rootKey(ref.currentRoot) + if _, exists := c.rootIndex[rk]; !exists { + c.rootIndex[rk] = censusID + } + + return ref, nil +} + // loadCensusRef loads a census reference from memory or persistent DB using a // double‑check. It takes the censusID and the key to use. func (c *CensusDB) loadCensusRef(censusID uuid.UUID, key types.HexBytes) (*CensusRef, error) { diff --git a/census/graphql.go b/census/graphql.go index b82655f18..30b662bda 100644 --- a/census/graphql.go +++ b/census/graphql.go @@ -248,7 +248,7 @@ func queryEvents( if err != nil { return nil, fmt.Errorf("error creating request: %v", err) } - req.Header.Set("Content-Type", "application/json") + req.Header.Set(contentTypeHeader, "application/json") res, err := client.Do(req) if err != nil { return nil, fmt.Errorf("error executing request: %v", err) diff --git a/census/importer_test.go b/census/importer_test.go index 2e8a9ecdb..f820f7df0 100644 --- a/census/importer_test.go +++ b/census/importer_test.go @@ -14,6 +14,8 @@ import ( "github.com/vocdoni/davinci-node/types" ) +const testCensusURI = "https://example.invalid/dump" + func testNewCensusDB(c *qt.C) *censusdb.CensusDB { c.Helper() internalDB, err := metadb.New(db.TypeInMem, "") @@ -73,7 +75,7 @@ func TestCensusImporter(t *testing.T) { importer := NewCensusImporter(nil) _, err := importer.ImportCensus(c.Context(), 0, &types.Census{ CensusOrigin: types.CensusOriginUnknown, - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI, CensusRoot: types.HexBytes{0x01}, }, 0) c.Assert(err, qt.Not(qt.IsNil)) @@ -104,13 +106,13 @@ func TestCensusImporter(t *testing.T) { validFn: func(string) bool { return false }, } plugin2 := &testImporterPlugin{ - validFn: func(uri string) bool { return uri == "https://example.invalid/dump" }, + validFn: func(uri string) bool { return uri == testCensusURI }, } importer := NewCensusImporter(stg, plugin1, plugin2) census := &types.Census{ CensusOrigin: types.CensusOriginMerkleTreeOffchainStaticV1, - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI, CensusRoot: types.HexBytes{0xaa, 0xbb}, } @@ -136,7 +138,7 @@ func TestCensusImporter(t *testing.T) { importer := NewCensusImporter(stg, plugin) _, err := importer.ImportCensus(c.Context(), 0, &types.Census{ CensusOrigin: types.CensusOriginMerkleTreeOffchainDynamicV1, - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI, CensusRoot: types.HexBytes{0x01}, }, 0) c.Assert(err, qt.ErrorIs, sentinelErr) @@ -150,7 +152,7 @@ func TestCensusImporter(t *testing.T) { importer := NewCensusImporter(stg, plugin) _, err := importer.ImportCensus(c.Context(), 0, &types.Census{ CensusOrigin: types.CensusOriginMerkleTreeOffchainStaticV1, - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI, CensusRoot: types.HexBytes{0x01}, }, 0) c.Assert(err, qt.Not(qt.IsNil)) diff --git a/census/json.go b/census/json.go index bfaee710b..be5689d15 100644 --- a/census/json.go +++ b/census/json.go @@ -29,6 +29,8 @@ const ( UnknownJSON ) +const contentTypeHeader = "Content-Type" + // String returns the string representation of the JSONFormat. func (format JSONFormat) String() string { switch format { @@ -127,7 +129,7 @@ func requestRawDump(ctx context.Context, targetURL string) (*http.Response, erro // cannot be determined, it returns UnknownJSON. func jsonReader(res *http.Response) (io.Reader, JSONFormat, error) { // Check Content-Type header first - contentType := strings.ToLower(res.Header.Get("Content-Type")) + contentType := strings.ToLower(res.Header.Get(contentTypeHeader)) if strings.Contains(contentType, "ndjson") || strings.Contains(contentType, "jsonl") { return res.Body, JSONL, nil diff --git a/census/json_test.go b/census/json_test.go index d3872e6b9..edc7a9cf7 100644 --- a/census/json_test.go +++ b/census/json_test.go @@ -18,6 +18,11 @@ import ( leancensus "github.com/vocdoni/lean-imt-go/census" ) +const ( + testCensusURI2 = "https://example.invalid/dump" + testContentTypeNDJSON = "application/x-ndjson" +) + type testErrReader struct { err error } @@ -125,7 +130,7 @@ func TestRequestRawDump(t *testing.T) { c.Assert(r.Header.Get("Accept"), qt.Equals, "application/x-ndjson, application/json;q=0.9, */*;q=0.1") return &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/x-ndjson"}}, + Header: http.Header{contentTypeHeader: []string{testContentTypeNDJSON}}, Body: io.NopCloser(strings.NewReader(`{"ok":true}` + "\n")), Request: r, }, nil @@ -192,7 +197,7 @@ func TestJSONReader(t *testing.T) { c.Run("ContentTypeJSONL", func(c *qt.C) { const body = `{"a":1}` + "\n" res := &http.Response{ - Header: http.Header{"Content-Type": []string{"application/x-ndjson"}}, + Header: http.Header{contentTypeHeader: []string{testContentTypeNDJSON}}, Body: io.NopCloser(strings.NewReader(body)), } @@ -208,7 +213,7 @@ func TestJSONReader(t *testing.T) { c.Run("ContentTypeJSONArray", func(c *qt.C) { const body = `[{"a":1}]` res := &http.Response{ - Header: http.Header{"Content-Type": []string{"application/json; charset=utf-8"}}, + Header: http.Header{contentTypeHeader: []string{"application/json; charset=utf-8"}}, Body: io.NopCloser(strings.NewReader(body)), } @@ -352,7 +357,7 @@ func TestJSONDownloadAndImportCensus(t *testing.T) { http.DefaultTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}}, + Header: http.Header{contentTypeHeader: []string{"application/json"}}, Body: &testReadCloser{ Reader: bytes.NewReader(dumpJSON), closeErr: fmt.Errorf("close error"), @@ -382,7 +387,7 @@ func TestJSONDownloadAndImportCensus(t *testing.T) { c.Cleanup(func() { http.DefaultTransport = oldTransport }) _, err := ji.ImportCensus(c.Context(), censusDB, 0, &types.Census{ - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI2, CensusRoot: expectedRoot, }, 0) c.Assert(err, qt.Not(qt.IsNil)) @@ -402,7 +407,7 @@ func TestJSONDownloadAndImportCensus(t *testing.T) { c.Cleanup(func() { http.DefaultTransport = oldTransport }) _, err := ji.ImportCensus(c.Context(), censusDB, 0, &types.Census{ - CensusURI: "https://example.invalid/dump", + CensusURI: testCensusURI2, CensusRoot: expectedRoot, }, 0) c.Assert(err, qt.Not(qt.IsNil)) @@ -414,7 +419,7 @@ func TestJSONDownloadAndImportCensus(t *testing.T) { http.DefaultTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/x-ndjson"}}, + Header: http.Header{contentTypeHeader: []string{testContentTypeNDJSON}}, Body: io.NopCloser(strings.NewReader("not json\n")), Request: r, }, nil @@ -435,7 +440,7 @@ func TestJSONDownloadAndImportCensus(t *testing.T) { http.DefaultTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}}, + Header: http.Header{contentTypeHeader: []string{"application/json"}}, Body: io.NopCloser(bytes.NewReader(dumpJSON)), Request: r, }, nil diff --git a/census/test/graphql.go b/census/test/graphql.go index 047e08327..08178e0c3 100644 --- a/census/test/graphql.go +++ b/census/test/graphql.go @@ -13,11 +13,16 @@ import ( "github.com/vocdoni/davinci-node/types" ) +const ( + testAccountID1 = "0xdeb8699659be5d41a0e57e179d6cb42e00b9200c" + testAccountID2 = "0xb1f05b11ba3d892edd00f2e7689779e2b8841827" +) + var ( DefaultExpectedRoot = types.HexStringToHexBytesMustUnmarshal("0x0b3600e19a4f5017dea4f91f03d8aa0dd6f4c797795e7a5aa542e81b2c5a9485") DefaultGraphQLEvents = []TestWeightChangeEvent{ { - AccountID: "0xdeb8699659be5d41a0e57e179d6cb42e00b9200c", + AccountID: testAccountID1, PreviousWeight: "0", NewWeight: "1", }, @@ -27,7 +32,7 @@ var ( NewWeight: "1", }, { - AccountID: "0xb1f05b11ba3d892edd00f2e7689779e2b8841827", + AccountID: testAccountID2, PreviousWeight: "0", NewWeight: "1", }, @@ -47,12 +52,12 @@ var ( NewWeight: "0", }, { - AccountID: "0xdeb8699659be5d41a0e57e179d6cb42e00b9200c", + AccountID: testAccountID1, PreviousWeight: "1", NewWeight: "0", }, { - AccountID: "0xb1f05b11ba3d892edd00f2e7689779e2b8841827", + AccountID: testAccountID2, PreviousWeight: "1", NewWeight: "2", }, @@ -62,12 +67,12 @@ var ( NewWeight: "1", }, { - AccountID: "0xb1f05b11ba3d892edd00f2e7689779e2b8841827", + AccountID: testAccountID2, PreviousWeight: "2", NewWeight: "1", }, { - AccountID: "0xdeb8699659be5d41a0e57e179d6cb42e00b9200c", + AccountID: testAccountID1, PreviousWeight: "0", NewWeight: "1", }, diff --git a/circuits/test/statetransition/statetransition_test.go b/circuits/test/statetransition/statetransition_test.go index 13cf9736a..7b0cccbca 100644 --- a/circuits/test/statetransition/statetransition_test.go +++ b/circuits/test/statetransition/statetransition_test.go @@ -37,11 +37,14 @@ import ( "github.com/vocdoni/davinci-node/util" ) -const falseString = "false" +const ( + falseString = "false" + testTimeFormat = "15:04:05" +) func TestMain(m *testing.M) { // enable log to see nbConstraints - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) m.Run() } @@ -50,7 +53,7 @@ func testCircuitCompile(t *testing.T, c frontend.Circuit) { t.Skip("skipping circuit tests...") } // enable log to see nbConstraints - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) if _, err := frontend.Compile(params.StateTransitionCurve.ScalarField(), r1cs.NewBuilder, c); err != nil { panic(err) } @@ -69,7 +72,7 @@ func testCircuitProve(t *testing.T, circuit, assignment frontend.Circuit) { } func TestStateTransitionCircuit(t *testing.T) { - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == falseString { t.Skip("skipping circuit tests...") } @@ -87,7 +90,7 @@ func TestStateTransitionCircuit(t *testing.T) { } func TestStateTransitionFullProvingCircuit(t *testing.T) { - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == falseString { t.Skip("skipping circuit tests...") } diff --git a/circuits/test/voteverifier/vote_verifier_test.go b/circuits/test/voteverifier/vote_verifier_test.go index b0f5bd22d..0d2792e5f 100644 --- a/circuits/test/voteverifier/vote_verifier_test.go +++ b/circuits/test/voteverifier/vote_verifier_test.go @@ -21,8 +21,10 @@ import ( "github.com/vocdoni/davinci-node/types" ) +const testTimeFormat = "15:04:05" + func TestVerifyMerkletreeVoteCircuit(t *testing.T) { - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) c := qt.New(t) // Generate a deterministic voter account for reproducible test data. s, err := ballottest.GenDeterministicECDSAaccountForTest(0) @@ -45,7 +47,7 @@ func TestVerifyMerkletreeVoteCircuit(t *testing.T) { } func TestVerifyCSPVoteCircuit(t *testing.T) { - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) c := qt.New(t) // Generate a deterministic voter account for reproducible test data. s, err := ballottest.GenDeterministicECDSAaccountForTest(0) @@ -81,7 +83,7 @@ func TestVerifyNoValidVoteCircuit(t *testing.T) { } func TestVerifyMultipleVotesCircuit(t *testing.T) { - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == "false" { t.Skip("skipping circuit tests...") } @@ -109,7 +111,7 @@ func TestCompileAndPrintConstraints(t *testing.T) { if os.Getenv("RUN_CIRCUIT_TESTS") == "" || os.Getenv("RUN_CIRCUIT_TESTS") == "false" { t.Skip("skipping circuit tests...") } - logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: testTimeFormat}).With().Timestamp().Logger()) c := qt.New(t) // generate vote verifier circuit and inputs with deterministic ProcessID vvPlaceholder, err := voteverifier.DummyPlaceholder() diff --git a/db/mongodb/mongodb.go b/db/mongodb/mongodb.go index 9238d2302..2cbaf85d0 100644 --- a/db/mongodb/mongodb.go +++ b/db/mongodb/mongodb.go @@ -24,6 +24,8 @@ const ( MongodbTimeoutCommit = 12 * time.Second // MongodbTimeoutQuery is the timeout for querying the database MongodbTimeoutQuery = 4 * time.Second + // mongoIDField is the MongoDB field name for the document ID + mongoIDField = "_id" ) // MongoDB is a MongoDB implementation of the db.DB interface @@ -110,7 +112,7 @@ func (tx *WriteTx) Get(k []byte) ([]byte, error) { } collection := tx.db.Collection(tx.collection) - filter := bson.M{"_id": string(k)} // Convert to string + filter := bson.M{mongoIDField: string(k)} // Convert to string var result KeyVal ctx, cancel := context.WithTimeout(context.Background(), MongodbTimeoutQuery) defer cancel() @@ -137,7 +139,7 @@ func (tx *WriteTx) Iterate(prefix []byte, callback func(key, value []byte) bool) filter := bson.M{} if len(prefix) > 0 { filter = bson.M{ - "_id": bson.M{ + mongoIDField: bson.M{ "$regex": primitive.Regex{ Pattern: "^" + string(prefix), }, @@ -182,7 +184,7 @@ func (tx *WriteTx) Set(k, v []byte) error { model := mongo.NewUpdateOneModel().SetFilter( bson.D{ primitive.E{ - Key: "_id", + Key: mongoIDField, Value: string(k), }, }, @@ -194,7 +196,7 @@ func (tx *WriteTx) Set(k, v []byte) error { } func (tx *WriteTx) Delete(k []byte) error { - model := mongo.NewDeleteOneModel().SetFilter(bson.M{"_id": string(k)}) // Convert to string + model := mongo.NewDeleteOneModel().SetFilter(bson.M{mongoIDField: string(k)}) // Convert to string tx.batch = append(tx.batch, model) delete(tx.inMem, string(k)) return nil diff --git a/metadata/pinata_test.go b/metadata/pinata_test.go index 69b133e96..65c8c0c8e 100644 --- a/metadata/pinata_test.go +++ b/metadata/pinata_test.go @@ -14,6 +14,13 @@ import ( "github.com/vocdoni/davinci-node/types" ) +const ( + testGatewayURL = "gateway.example" + testJWT = "jwt" + testToken = "token" + testStatusOK = "200 OK" +) + type roundTripFunc func(*http.Request) (*http.Response, error) func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { @@ -40,7 +47,7 @@ func newTestPinataProvider() *PinataMetadataProvider { return NewPinataMetadataProvider(PinataMetadataProviderConfig{ HostnameURL: "https://pinata.example/upload", HostnameJWT: "jwt-token", - GatewayURL: "gateway.example", + GatewayURL: testGatewayURL, GatewayToken: "gateway-token", }) } @@ -57,42 +64,42 @@ func TestPinataMetadataProviderConfigValid(t *testing.T) { name: "valid with all fields", config: PinataMetadataProviderConfig{ HostnameURL: "https://pinata.example/upload", - HostnameJWT: "jwt", - GatewayURL: "gateway.example", - GatewayToken: "token", + HostnameJWT: testJWT, + GatewayURL: testGatewayURL, + GatewayToken: testToken, }, valid: true, }, { name: "missing jwt", config: PinataMetadataProviderConfig{ - GatewayURL: "gateway.example", - GatewayToken: "token", + GatewayURL: testGatewayURL, + GatewayToken: testToken, }, valid: false, }, { name: "missing gateway url", config: PinataMetadataProviderConfig{ - HostnameJWT: "jwt", - GatewayToken: "token", + HostnameJWT: testJWT, + GatewayToken: testToken, }, valid: false, }, { name: "missing gateway token", config: PinataMetadataProviderConfig{ - HostnameJWT: "jwt", - GatewayURL: "gateway.example", + HostnameJWT: testJWT, + GatewayURL: testGatewayURL, }, valid: false, }, { name: "missing hostname url", config: PinataMetadataProviderConfig{ - HostnameJWT: "jwt", - GatewayURL: "gateway.example", - GatewayToken: "token", + HostnameJWT: testJWT, + GatewayURL: testGatewayURL, + GatewayToken: testToken, }, valid: false, }, @@ -150,7 +157,7 @@ func TestPinataMetadataProviderSetMetadata(t *testing.T) { respBody := fmt.Sprintf(`{"data":{"cid":"%s"}}`, mustCIDString(c, key)) return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: io.NopCloser(strings.NewReader(respBody)), }, nil }) @@ -171,7 +178,7 @@ func TestPinataMetadataProviderSetMetadata(t *testing.T) { provider := NewPinataMetadataProvider(PinataMetadataProviderConfig{ HostnameURL: server.URL, HostnameJWT: "default-client-jwt", - GatewayURL: "gateway.example", + GatewayURL: testGatewayURL, GatewayToken: "gateway-token", }) @@ -204,7 +211,7 @@ func TestPinataMetadataProviderSetMetadata(t *testing.T) { provider.httpClient = newTestHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: errReadCloser{err: expectedErr}, }, nil }) @@ -236,7 +243,7 @@ func TestPinataMetadataProviderSetMetadata(t *testing.T) { provider.httpClient = newTestHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: io.NopCloser(strings.NewReader("not-json")), }, nil }) @@ -254,7 +261,7 @@ func TestPinataMetadataProviderSetMetadata(t *testing.T) { provider.httpClient = newTestHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: io.NopCloser(strings.NewReader(respBody)), }, nil }) @@ -286,7 +293,7 @@ func TestPinataMetadataProviderMetadata(t *testing.T) { c.Assert(err, qt.IsNil) return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: io.NopCloser(strings.NewReader(string(body))), }, nil }) @@ -323,7 +330,7 @@ func TestPinataMetadataProviderMetadata(t *testing.T) { provider.httpClient = newTestHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: errReadCloser{err: expectedErr}, }, nil }) @@ -357,7 +364,7 @@ func TestPinataMetadataProviderMetadata(t *testing.T) { provider.httpClient = newTestHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, - Status: "200 OK", + Status: testStatusOK, Body: io.NopCloser(strings.NewReader("not-json")), }, nil }) diff --git a/sequencer/aggregate.go b/sequencer/aggregate.go index 16f0e5bc6..77a2c1540 100644 --- a/sequencer/aggregate.go +++ b/sequencer/aggregate.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/std/math/emulated" stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" + "github.com/ethereum/go-ethereum/common" "github.com/vocdoni/davinci-node/circuits/aggregator" "github.com/vocdoni/davinci-node/circuits/voteverifier" "github.com/vocdoni/davinci-node/log" @@ -43,6 +44,7 @@ func collectAggregationBatchInputs( maxVotersReached bool, proofToRecursion proofToRecursionFn, verifyVoteVerifierProof voteVerifierProofValidatorFn, + checkCensusMembership func(b *storage.VerifiedBallot) bool, ) (*aggregator.AggregatorInputs, error) { // Prepare data structures for the aggregator circuit proofs := [params.VotesPerBatch]stdgroth16.Proof[sw_bls12377.G1Affine, sw_bls12377.G2Affine]{} @@ -54,19 +56,22 @@ func collectAggregationBatchInputs( for i, b := range ballots { if b == nil { - log.Warnw("skipping nil verified ballot", + log.Warnw( + "skipping nil verified ballot", "processID", processID.String(), "index", i, ) if i < len(keys) { if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark nil ballot as failed", + log.Warnw( + "failed to mark nil ballot as failed", "error", err.Error(), "processID", processID.String(), "index", i, ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after nil ballot failure marking", + log.Warnw( + "failed to release ballot reservation after nil ballot failure marking", "error", err.Error(), "processID", processID.String(), "index", i, @@ -82,20 +87,23 @@ func collectAggregationBatchInputs( if b.Address != nil { addressStr = b.Address.String() } - log.Warnw("skipping verified ballot with missing voteID", + log.Warnw( + "skipping verified ballot with missing voteID", "processID", processID.String(), "index", i, "address", addressStr, ) if i < len(keys) { if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "index", i, ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "index", i, @@ -106,20 +114,23 @@ func collectAggregationBatchInputs( continue } if b.Address == nil { - log.Warnw("skipping verified ballot with missing address", + log.Warnw( + "skipping verified ballot with missing address", "processID", processID.String(), "index", i, "voteID", b.VoteID.String(), ) if i < len(keys) { if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -132,18 +143,21 @@ func collectAggregationBatchInputs( // if the vote ID already exists in the state, skip it if processState.ContainsVoteID(b.VoteID) { - log.Debugw("skipping ballot already in state", + log.Debugw( + "skipping ballot already in state", "processID", processID.String(), "voteID", b.VoteID.String(), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -160,14 +174,16 @@ func collectAggregationBatchInputs( "address", types.HexBytes(b.Address.Bytes()), "processID", processID.String()) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -179,20 +195,23 @@ func collectAggregationBatchInputs( } if b.Proof == nil { - log.Warnw("skipping verified ballot with missing vote verifier proof", + log.Warnw( + "skipping verified ballot with missing vote verifier proof", "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -203,20 +222,23 @@ func collectAggregationBatchInputs( continue } if b.InputsHash == nil { - log.Warnw("skipping verified ballot with missing vote verifier inputs hash", + log.Warnw( + "skipping verified ballot with missing vote verifier inputs hash", "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -228,20 +250,23 @@ func collectAggregationBatchInputs( } if !b.Proof.Ar.IsInSubGroup() || !b.Proof.Krs.IsInSubGroup() || !b.Proof.Bs.IsInSubGroup() { - log.Warnw("skipping verified ballot with malformed vote verifier proof (subgroup check failed)", + log.Warnw( + "skipping verified ballot with malformed vote verifier proof (subgroup check failed)", "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -254,21 +279,24 @@ func collectAggregationBatchInputs( if verifyVoteVerifierProof != nil { if err := verifyVoteVerifierProof(b); err != nil { - log.Warnw("skipping verified ballot with invalid vote verifier proof", + log.Warnw( + "skipping verified ballot with invalid vote verifier proof", "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), "error", err.Error(), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -280,6 +308,34 @@ func collectAggregationBatchInputs( } } + if checkCensusMembership != nil && !checkCensusMembership(b) { + log.Warnw( + "skipping ballot: address not found in census tree at aggregation time", + "processID", processID.String(), + "voteID", b.VoteID.String(), + "address", types.HexBytes(b.Address.Bytes()), + ) + if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { + log.Warnw( + "failed to mark census-absent ballot as failed", + "error", err.Error(), + "processID", processID.String(), + "voteID", b.VoteID.String(), + "address", types.HexBytes(b.Address.Bytes()), + ) + if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { + log.Warnw( + "failed to release ballot reservation after census-absent failure marking", + "error", err.Error(), + "processID", processID.String(), + "voteID", b.VoteID.String(), + "address", types.HexBytes(b.Address.Bytes()), + ) + } + } + continue + } + batchIdx := len(aggBallots) if batchIdx >= params.VotesPerBatch { remainingKeys := keys[i:] @@ -293,21 +349,24 @@ func collectAggregationBatchInputs( var err error proofs[batchIdx], err = proofToRecursion(groth16.Proof(b.Proof)) if err != nil { - log.Warnw("failed to transform proof for recursion; marking ballot as failed", + log.Warnw( + "failed to transform proof for recursion; marking ballot as failed", "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), "error", err.Error(), ) if err := stg.MarkVerifiedBallotsFailed(keys[i]); err != nil { - log.Warnw("failed to mark ballot as failed", + log.Warnw( + "failed to mark ballot as failed", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), ) if err := stg.ReleaseVerifiedBallotReservations([][]byte{keys[i]}); err != nil { - log.Warnw("failed to release ballot reservation after failure marking", + log.Warnw( + "failed to release ballot reservation after failure marking", "error", err.Error(), "processID", processID.String(), "voteID", b.VoteID.String(), @@ -515,6 +574,34 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { return s.voteVerifier.Verify(vb.Proof, pubAssignment) } + // Build a census membership checker for merkle-tree censuses so that + // any address absent from the census is filtered out before the + // aggregator proof is generated. Filtering after proof generation would + // invalidate the BatchHash public input of the aggregator circuit. + // For CSP censuses no local tree is available, so the check is skipped. + var checkCensusMembership func(*storage.VerifiedBallot) bool + if proc, pErr := s.stg.Process(processID); pErr == nil && proc.Census.CensusOrigin.IsMerkleTree() { + var chainID uint64 + if proc.Census.CensusOrigin == types.CensusOriginMerkleTreeOnchainDynamicV1 { + if contracts, cErr := s.contractsForProcess(processID); cErr == nil { + chainID = contracts.ChainID + } + } + if censusRef, cErr := s.stg.LoadCensus(chainID, proc.Census); cErr == nil { + censusTree := censusRef.Tree() + checkCensusMembership = func(b *storage.VerifiedBallot) bool { + _, ok := censusTree.GetWeight(common.BigToAddress(b.Address)) + return ok + } + } else { + log.Warnw( + "could not load census for pre-aggregation check; census filter skipped", + "processID", processID.String(), + "error", cErr.Error(), + ) + } + } + batchInputs, err := collectAggregationBatchInputs( s.stg, processID, @@ -524,6 +611,7 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { process.MaxVotersReached(), proofToRecursion, verifyVoteVerifierProof, + checkCensusMembership, ) if err != nil { return err @@ -559,7 +647,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { log.Debugw("filling with dummy proofs", "count", params.VotesPerBatch-len(batchInputs.AggBallots)) if err := assignment.FillWithDummy(len(batchInputs.AggBallots), s.voteVerifierDummyProof); err != nil { if err := s.stg.ReleaseVerifiedBallotReservations(batchInputs.ProcessedKeys); err != nil { - log.Warnw("failed to release ballot reservations after dummy fill failure", + log.Warnw( + "failed to release ballot reservations after dummy fill failure", "error", err.Error(), "processID", processID.String(), ) @@ -592,7 +681,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { addressStr = vb.Address.String() } } - log.Warnw("vote verifier proof does not verify for aggregation batch; excluding ballot", + log.Warnw( + "vote verifier proof does not verify for aggregation batch; excluding ballot", "processID", processID.String(), "voteID", voteIDStr, "address", addressStr, @@ -604,7 +694,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { // Mark any invalid ballots as failed if len(invalidKeys) > 0 { if err := s.stg.MarkVerifiedBallotsFailed(invalidKeys...); err != nil { - log.Warnw("failed to mark invalid ballots as failed after aggregation proving failure", + log.Warnw( + "failed to mark invalid ballots as failed after aggregation proving failure", "error", err.Error(), "processID", processID.String(), "invalidCount", len(invalidKeys), @@ -612,7 +703,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { } } if err := s.stg.ReleaseVerifiedBallotReservations(batchInputs.ProcessedKeys); err != nil { - log.Warnw("failed to release ballot reservations after aggregation proving failure", + log.Warnw( + "failed to release ballot reservations after aggregation proving failure", "error", err.Error(), "processID", processID.String(), ) @@ -627,7 +719,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { proofBW6, ok := proof.(*groth16_bw6761.Proof) if !ok { if err := s.stg.ReleaseVerifiedBallotReservations(batchInputs.ProcessedKeys); err != nil { - log.Warnw("failed to release ballot reservations after unexpected aggregate proof type", + log.Warnw( + "failed to release ballot reservations after unexpected aggregate proof type", "error", err.Error(), "processID", processID.String(), ) @@ -644,7 +737,8 @@ func (s *Sequencer) aggregateBatch(processID types.ProcessID) error { if err := s.stg.PushAggregatorBatch(&abb); err != nil { if err := s.stg.ReleaseVerifiedBallotReservations(batchInputs.ProcessedKeys); err != nil { - log.Warnw("failed to release ballot reservations after batch push failure", + log.Warnw( + "failed to release ballot reservations after batch push failure", "error", err.Error(), "processID", processID.String(), ) diff --git a/sequencer/aggregate_debug.go b/sequencer/aggregate_debug.go index d6187dede..a05dac80e 100644 --- a/sequencer/aggregate_debug.go +++ b/sequencer/aggregate_debug.go @@ -27,7 +27,8 @@ func (s *Sequencer) debugAggregationFailure( return } - log.Warnw("aggregator proving failed; investigating batch inputs", + log.Warnw( + "aggregator proving failed; investigating batch inputs", "processID", processID.String(), "error", proveErr.Error(), "votersCount", len(batchInputs.VerifiedBallots), @@ -40,7 +41,8 @@ func (s *Sequencer) debugAggregationFailure( if pubW, err := frontend.NewWitness(assignment, params.AggregatorCurve.ScalarField(), frontend.PublicOnly()); err != nil { log.Warnw("failed to build aggregator public witness", "processID", processID.String(), "error", err.Error()) } else { - log.Debugw("aggregator public witness", + log.Debugw( + "aggregator public witness", "processID", processID.String(), "vector", witnessVectorStrings(pubW), ) @@ -48,7 +50,8 @@ func (s *Sequencer) debugAggregationFailure( proofInputsHashStrings := bigIntStrings(batchInputs.ProofsInputsHashInputs) hashPrefix, hashSuffix := prefixSuffixStrings(proofInputsHashStrings, 5) - log.Debugw("aggregator inputs hash preimage (vote verifier inputs hashes)", + log.Debugw( + "aggregator inputs hash preimage (vote verifier inputs hashes)", "processID", processID.String(), "count", len(proofInputsHashStrings), "prefix", hashPrefix, @@ -57,14 +60,16 @@ func (s *Sequencer) debugAggregationFailure( for i, vb := range batchInputs.VerifiedBallots { if vb == nil { - log.Warnw("nil verified ballot in aggregation batch", + log.Warnw( + "nil verified ballot in aggregation batch", "processID", processID.String(), "index", i, ) continue } if vb.Proof == nil { - log.Warnw("missing vote verifier proof in aggregation batch", + log.Warnw( + "missing vote verifier proof in aggregation batch", "processID", processID.String(), "index", i, "voteID", vb.VoteID.String(), @@ -73,7 +78,8 @@ func (s *Sequencer) debugAggregationFailure( continue } if vb.InputsHash == nil { - log.Warnw("missing vote verifier inputs hash in aggregation batch", + log.Warnw( + "missing vote verifier inputs hash in aggregation batch", "processID", processID.String(), "index", i, "voteID", vb.VoteID.String(), @@ -87,7 +93,8 @@ func (s *Sequencer) debugAggregationFailure( BallotHash: emulated.ValueOf[sw_bn254.ScalarField](vb.InputsHash), } if err := s.voteVerifier.Verify(vb.Proof, pubAssignment); err != nil { - log.Warnw("vote verifier proof does not verify (native)", + log.Warnw( + "vote verifier proof does not verify (native)", "processID", processID.String(), "index", i, "voteID", vb.VoteID.String(), @@ -97,7 +104,8 @@ func (s *Sequencer) debugAggregationFailure( ) pubAssignment.IsValid = 0 if err := s.voteVerifier.Verify(vb.Proof, pubAssignment); err == nil { - log.Warnw("vote verifier proof verifies only with IsValid=0; aggregator treating it as real will fail", + log.Warnw( + "vote verifier proof verifies only with IsValid=0; aggregator treating it as real will fail", "processID", processID.String(), "index", i, "voteID", vb.VoteID.String(), @@ -108,7 +116,8 @@ func (s *Sequencer) debugAggregationFailure( continue } - log.Debugw("vote verifier proof verifies (native)", + log.Debugw( + "vote verifier proof verifies (native)", "processID", processID.String(), "index", i, "voteID", vb.VoteID.String(), diff --git a/sequencer/aggregate_inputs_test.go b/sequencer/aggregate_inputs_test.go index e88078d99..813350d7f 100644 --- a/sequencer/aggregate_inputs_test.go +++ b/sequencer/aggregate_inputs_test.go @@ -75,7 +75,7 @@ func TestCollectAggregationBatchInputs_SkipsDontCreateHoles(t *testing.T) { return stdgroth16.Proof[sw_bls12377.G1Affine, sw_bls12377.G2Affine]{}, nil } - inputs, err := collectAggregationBatchInputs(stg, processID, ballots, keys, processState, false, proofToRecursion, nil) + inputs, err := collectAggregationBatchInputs(stg, processID, ballots, keys, processState, false, proofToRecursion, nil, nil) c.Assert(err, qt.IsNil) c.Assert(len(inputs.AggBallots), qt.Equals, params.VotesPerBatch) diff --git a/sequencer/aggregate_test.go b/sequencer/aggregate_test.go index 6d1149928..8e73e0670 100644 --- a/sequencer/aggregate_test.go +++ b/sequencer/aggregate_test.go @@ -1,11 +1,18 @@ package sequencer import ( + "math/big" "testing" "time" + "github.com/consensys/gnark/backend/groth16" + groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377" + "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + stdgroth16 "github.com/consensys/gnark/std/recursion/groth16" qt "github.com/frankban/quicktest" "github.com/vocdoni/davinci-node/internal/testutil" + "github.com/vocdoni/davinci-node/storage" + "github.com/vocdoni/davinci-node/types" ) // TestBatchTimingBehavior tests the core behavior of our timing update: @@ -49,3 +56,92 @@ func TestBatchTimingBehavior(t *testing.T) { c.Assert(newTime.After(newStartTime.Add(-time.Second)), qt.IsTrue, qt.Commentf("New first ballot time should be recent")) } + +// TestAggregatePreflightQuarantine verifies that when a batch contains N valid +// ballots plus one ballot whose address is absent from the census, the +// pre-flight census check in collectAggregationBatchInputs: +// 1. Quarantines the absent-address ballot by marking it failed in storage. +// 2. Passes the N valid ballots through to the aggregation inputs unchanged. +// +// This covers AC7 from the plan. +func TestAggregatePreflightQuarantine(t *testing.T) { + c := qt.New(t) + + stg := &mockAggregationStore{} + processState := mockAggregationState{ + voteIDs: make(map[string]struct{}), + addresses: make(map[string]struct{}), + } + + processID := testutil.FixedProcessID() + + // N = 3 valid ballots whose addresses are "in census". + // 1 additional ballot whose address is deliberately absent. + const n = 3 + absentAddress := big.NewInt(0xDEAD) + + ballots := make([]*storage.VerifiedBallot, 0, n+1) + keys := make([][]byte, 0, n+1) + + for i := range n { + ballots = append(ballots, &storage.VerifiedBallot{ + VoteID: types.VoteID(i + 1), + Address: big.NewInt(int64(i + 1)), + Proof: new(groth16_bls12377.Proof), + InputsHash: big.NewInt(int64(1000 + i)), + }) + keys = append(keys, []byte{0xCC, byte(i)}) + } + absentKey := []byte{0xCC, byte(n)} + ballots = append(ballots, &storage.VerifiedBallot{ + VoteID: types.VoteID(n + 1), + Address: absentAddress, + Proof: new(groth16_bls12377.Proof), + InputsHash: big.NewInt(9999), + }) + keys = append(keys, absentKey) + + // checkCensusMembership returns false only for absentAddress. + checkCensus := func(b *storage.VerifiedBallot) bool { + return b.Address.Cmp(absentAddress) != 0 + } + + proofToRecursion := func(_ groth16.Proof) (stdgroth16.Proof[sw_bls12377.G1Affine, sw_bls12377.G2Affine], error) { + return stdgroth16.Proof[sw_bls12377.G1Affine, sw_bls12377.G2Affine]{}, nil + } + + inputs, err := collectAggregationBatchInputs( + stg, + processID, + ballots, + keys, + processState, + false, // maxVotersReached + proofToRecursion, + nil, // verifyVoteVerifierProof — skip ZK proof check in this unit test + checkCensus, + ) + c.Assert(err, qt.IsNil) + + // The N valid ballots must proceed to aggregation. + c.Assert(inputs.AggBallots, qt.HasLen, n, + qt.Commentf("N valid ballots must proceed to aggregation")) + c.Assert(inputs.ProcessedKeys, qt.HasLen, n) + c.Assert(inputs.ProofsInputsHashInputs, qt.HasLen, n) + + // Each of the N valid ballots must appear in order. + for i := range n { + c.Assert(inputs.AggBallots[i].VoteID, qt.Equals, ballots[i].VoteID, + qt.Commentf("valid ballot %d must be in aggregation inputs", i)) + c.Assert(inputs.ProcessedKeys[i], qt.DeepEquals, keys[i]) + } + + // The absent-address ballot must be quarantined: exactly one key in failed. + c.Assert(stg.failed, qt.HasLen, 1, + qt.Commentf("absent-address ballot must be quarantined (marked failed)")) + c.Assert(stg.failed[0], qt.DeepEquals, absentKey, + qt.Commentf("the quarantined key must be the absent-address ballot's storage key")) + + // Nothing should be in released — valid ballots are consumed, not released. + c.Assert(stg.released, qt.HasLen, 0) +} diff --git a/sequencer/ballot.go b/sequencer/ballot.go index a1bc99b5c..0a0df7e1c 100644 --- a/sequencer/ballot.go +++ b/sequencer/ballot.go @@ -91,7 +91,8 @@ func (s *Sequencer) processAvailableBallots() bool { continue } - log.Infow("processing ballot", + log.Infow( + "processing ballot", "address", types.HexBytes(ballot.Address.Bytes()), "voteID", ballot.VoteID.String(), "processID", ballot.ProcessID.String(), @@ -99,7 +100,8 @@ func (s *Sequencer) processAvailableBallots() bool { verifiedBallot, err := s.processBallot(ballot) if err != nil { - log.Warnw("invalid ballot", + log.Warnw( + "invalid ballot", "error", err.Error(), "ballot", ballot.String(), ) @@ -111,7 +113,8 @@ func (s *Sequencer) processAvailableBallots() bool { // Mark the ballot as processed if err := s.stg.MarkBallotVerified(key, verifiedBallot); err != nil { - log.Warnw("failed to mark ballot as processed", + log.Warnw( + "failed to mark ballot as processed", "error", err.Error(), "address", types.HexBytes(ballot.Address.Bytes()), "processID", ballot.ProcessID.String(), @@ -177,7 +180,8 @@ func (s *Sequencer) processBallot(b *storage.Ballot) (*storage.VerifiedBallot, e CircomProof: b.BallotProof, } - log.Debugw("vote verifier inputs ready", + log.Debugw( + "vote verifier inputs ready", "processID", b.ProcessID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), @@ -190,7 +194,8 @@ func (s *Sequencer) processBallot(b *storage.Ballot) (*storage.VerifiedBallot, e return nil, fmt.Errorf("failed to generate proof: %w", err) } - log.InfoTime("vote verification proof generated", startTime, + log.InfoTime( + "vote verification proof generated", startTime, "processID", b.ProcessID.String(), "voteID", b.VoteID.String(), "address", types.HexBytes(b.Address.Bytes()), diff --git a/sequencer/helpers.go b/sequencer/helpers.go index 2464cf6d6..9bac8899d 100644 --- a/sequencer/helpers.go +++ b/sequencer/helpers.go @@ -4,11 +4,17 @@ import ( "errors" "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/vocdoni/davinci-node/log" "github.com/vocdoni/davinci-node/state" + "github.com/vocdoni/davinci-node/storage" "github.com/vocdoni/davinci-node/types" ) +const ( + errGetProcessMetadata = "failed to get process metadata: %w" +) + // currentProcessState retrieves the current in-construction state for a given // process ID. This state includes all locally processed batches, even if they // haven't been confirmed on-chain yet. Use this for processing new votes. @@ -16,7 +22,7 @@ func (s *Sequencer) currentProcessState(processID types.ProcessID) (*state.State // get the process from the storage process, err := s.stg.Process(processID) if err != nil { - return nil, fmt.Errorf("failed to get process metadata: %w", err) + return nil, fmt.Errorf(errGetProcessMetadata, err) } // Open the state tree - this gives us the in-construction root @@ -50,3 +56,55 @@ func (s *Sequencer) currentProcessState(processID types.ProcessID) (*state.State return st, nil } + +// filterBallotsByCensus returns only the ballots whose voter address is +// present in the process census tree. Ballots with absent addresses are +// discarded and logged at WARN level. If the census tree is unavailable +// (no root), an error is returned so the caller can retry rather than +// silently discarding all ballots. +// +// For CSP-based censuses no local merkle lookup is possible; all ballots +// are returned unchanged. +func (s *Sequencer) filterBallotsByCensus(processID types.ProcessID, ballots []*storage.AggregatorBallot) ([]*storage.AggregatorBallot, error) { + process, err := s.stg.Process(processID) + if err != nil { + return nil, fmt.Errorf(errGetProcessMetadata, err) + } + if !process.Census.CensusOrigin.IsMerkleTree() { + // CSP censuses are verified via the embedded proof; no local tree to query. + return ballots, nil + } + + var chainID uint64 + if process.Census.CensusOrigin == types.CensusOriginMerkleTreeOnchainDynamicV1 { + contracts, err := s.contractsForProcess(processID) + if err != nil { + return nil, fmt.Errorf("failed to resolve contracts for process %s: %w", processID.String(), err) + } + chainID = contracts.ChainID + } + censusRef, err := s.stg.LoadCensus(chainID, process.Census) + if err != nil { + return nil, fmt.Errorf("failed to load census for process %s: %w", processID.String(), err) + } + censusTree := censusRef.Tree() + if _, ok := censusTree.Root(); !ok { + return nil, fmt.Errorf("census tree has no root for process %s (censusRoot=%s)", + processID.String(), process.Census.CensusRoot.String()) + } + + filtered := make([]*storage.AggregatorBallot, 0, len(ballots)) + for _, b := range ballots { + addr := common.BigToAddress(b.Address) + if _, ok := censusTree.GetWeight(addr); !ok { + log.Warnw( + "address not found in census, skipping ballot", + "processID", processID.String(), + "address", addr.Hex(), + ) + continue + } + filtered = append(filtered, b) + } + return filtered, nil +} diff --git a/sequencer/helpers_test.go b/sequencer/helpers_test.go new file mode 100644 index 000000000..6a4530e10 --- /dev/null +++ b/sequencer/helpers_test.go @@ -0,0 +1,113 @@ +package sequencer + +import ( + "math/big" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/vocdoni/davinci-node/internal/testutil" + "github.com/vocdoni/davinci-node/storage" + "github.com/vocdoni/davinci-node/types" + leanimt "github.com/vocdoni/lean-imt-go" + leancensus "github.com/vocdoni/lean-imt-go/census" +) + +const testMetadataURI = "http://example.com/metadata" + +// buildCensusForTest creates an in-memory census tree with the given addresses +// (all with weight 1), imports it into stg, and returns the census root bytes. +func buildCensusForTest(t *testing.T, stg *storage.Storage, addrs ...uint64) types.HexBytes { + t.Helper() + tree, err := leancensus.NewCensusIMT(nil, leanimt.PoseidonHasher) + if err != nil { + t.Fatalf("NewCensusIMT: %v", err) + } + for _, n := range addrs { + if err := tree.Add(testutil.DeterministicAddress(n), big.NewInt(1)); err != nil { + t.Fatalf("tree.Add(%d): %v", n, err) + } + } + root, ok := tree.Root() + if !ok { + t.Fatal("census tree has no root after adding entries") + } + censusRoot := types.HexBytes(root.Bytes()) + if _, err := stg.CensusDB().Import(censusRoot, tree.Dump()); err != nil { + t.Fatalf("CensusDB.Import: %v", err) + } + return censusRoot +} + +// ballot returns a minimal AggregatorBallot for the given deterministic address index. +func ballot(n uint64) *storage.AggregatorBallot { + return &storage.AggregatorBallot{ + VoteID: testutil.RandomVoteID(), + Address: testutil.DeterministicAddress(n).Big(), + Weight: big.NewInt(1), + } +} + +// TestFilterBallotsByCensusAllPresent verifies that when every ballot address +// is in the census tree the full batch is returned unchanged. +func TestFilterBallotsByCensusAllPresent(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + censusRoot := buildCensusForTest(t, stg, 0, 1) + makeTestProcessWithCensus(t, stg, processID, censusRoot) + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + input := []*storage.AggregatorBallot{ballot(0), ballot(1)} + + got, err := seq.filterBallotsByCensus(processID, input) + c.Assert(err, qt.IsNil) + c.Assert(got, qt.HasLen, 2) +} + +// TestFilterBallotsByCensusOneAbsent verifies that a ballot whose address is +// absent from the census tree is removed from the returned slice while the +// remaining ballot is kept and no error is returned. +func TestFilterBallotsByCensusOneAbsent(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + // Census contains addresses 0 and 1; address 999 is not in the census. + censusRoot := buildCensusForTest(t, stg, 0, 1) + makeTestProcessWithCensus(t, stg, processID, censusRoot) + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + input := []*storage.AggregatorBallot{ballot(0), ballot(999), ballot(1)} + + got, err := seq.filterBallotsByCensus(processID, input) + c.Assert(err, qt.IsNil) + c.Assert(got, qt.HasLen, 2) + // The two surviving ballots must be the ones with indices 0 and 1. + c.Assert(got[0].Address.Cmp(testutil.DeterministicAddress(0).Big()), qt.Equals, 0) + c.Assert(got[1].Address.Cmp(testutil.DeterministicAddress(1).Big()), qt.Equals, 0) +} + +// TestFilterBallotsByCensusCensusUnavailable verifies that an error is +// returned when the census root stored on the process has no corresponding +// entry in the census database, so the caller can retry rather than silently +// dropping all ballots. +func TestFilterBallotsByCensusCensusUnavailable(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + // Create the process with a census root that was never imported into the DB. + missingRoot := types.HexBytes(testutil.RandomCensusRoot().Bytes()) + makeTestProcessWithCensus(t, stg, processID, missingRoot) + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + input := []*storage.AggregatorBallot{ballot(0)} + + _, err := seq.filterBallotsByCensus(processID, input) + c.Assert(err, qt.Not(qt.IsNil)) + c.Assert(err.Error(), qt.Contains, "failed to load census") +} diff --git a/sequencer/onchain_test.go b/sequencer/onchain_test.go index 807dbf4d1..28f634cdb 100644 --- a/sequencer/onchain_test.go +++ b/sequencer/onchain_test.go @@ -51,7 +51,7 @@ func ensureSequencerTestProcess(t *testing.T, stg *storage.Storage, pid types.Pr Status: types.ProcessStatusReady, StartTime: time.Now(), Duration: time.Hour, - MetadataURI: "http://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: testutil.BallotMode(), EncryptionKey: &encryptionKey, StateRoot: types.BigIntConverter(stateRoot), @@ -97,7 +97,7 @@ func TestPushStateTransitionCallbackMarksBatchDoneAfterSuccess(t *testing.T) { Status: types.ProcessStatusReady, StartTime: time.Now(), Duration: time.Hour, - MetadataURI: "http://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: ballotMode, EncryptionKey: func() *types.EncryptionKey { key := types.EncryptionKeyFromPoint(encryptionKey) diff --git a/sequencer/sequencer_test.go b/sequencer/sequencer_test.go index 05b8f215c..5b968fb05 100644 --- a/sequencer/sequencer_test.go +++ b/sequencer/sequencer_test.go @@ -116,7 +116,7 @@ func createReadyProcess(t *testing.T, pid types.ProcessID) *types.Process { Status: types.ProcessStatusReady, StartTime: time.Now(), Duration: time.Hour, - MetadataURI: "http://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: testutil.BallotMode(), EncryptionKey: &encryptionKey, StateRoot: types.BigIntConverter(stateRoot), diff --git a/sequencer/statetransition.go b/sequencer/statetransition.go index 13ec45555..70f5937cb 100644 --- a/sequencer/statetransition.go +++ b/sequencer/statetransition.go @@ -1,6 +1,7 @@ package sequencer import ( + "errors" "fmt" "math/big" "time" @@ -23,6 +24,11 @@ import ( imtcircuit "github.com/vocdoni/lean-imt-go/circuit" ) +// errCensusTreeUnavailable is returned by processCensusProofs when the census +// tree cannot be loaded even after a reload attempt. The caller must not +// permanently fail the batch; it should retry on the next tick instead. +var errCensusTreeUnavailable = errors.New("census tree unavailable, retry next tick") + func (s *Sequencer) startStateTransitionProcessor() error { const tickInterval = time.Second ticker := time.NewTicker(tickInterval) @@ -89,7 +95,8 @@ func (s *Sequencer) processPendingTransitions() { return true // Continue to next process ID } - log.Debugw("state transition ready for processing", + log.Debugw( + "state transition ready for processing", "processID", batch.ProcessID.String(), "ballotCount", len(batch.Ballots), ) @@ -109,6 +116,14 @@ func (s *Sequencer) processPendingTransitions() { } censusRoot, circuitCensusProofs, err := s.processCensusProofs(batch.ProcessID, reencryptedVotes, censusProofs) if err != nil { + if errors.Is(err, errCensusTreeUnavailable) { + // Census tree is temporarily unavailable (e.g. after restart); + // leave the batch intact so it is retried on the next tick. + log.Warnw("census tree unavailable, batch will be retried", + "processID", batch.ProcessID.String(), + "error", err) + return true // Continue to next process ID + } log.Errorw(err, "failed to get census proofs") s.markAggregatorBatchFailed(batchID) return true // Continue to next process ID @@ -122,7 +137,8 @@ func (s *Sequencer) processPendingTransitions() { *circuitCensusProofs, reencryptedVotes, kSeed, - batch.Proof) + batch.Proof, + ) if err != nil { log.Errorw(err, "failed to process state transition batch") s.markAggregatorBatchFailed(batchID) @@ -145,7 +161,8 @@ func (s *Sequencer) processPendingTransitions() { // Get blob sidecar and hash blobSidecar := stateBatch.BlobEvalData().TxSidecar() - log.InfoTime("state transition proof generated", startTime, + log.InfoTime( + "state transition proof generated", startTime, "processID", processID.String(), "rootHashBefore", stateBatch.RootHashBefore().String(), "rootHashAfter", stateBatch.RootHashAfter().String(), @@ -235,7 +252,8 @@ func (s *Sequencer) processStateTransitionBatch( return nil, nil, fmt.Errorf("failed to generate assignment: %w", err) } defer batch.Discard() - log.DebugTime("state transition assignment ready for proof generation", startTime, + log.DebugTime( + "state transition assignment ready for proof generation", startTime, "processID", processState.ProcessID(), "votersCount", assignment.VotersCount, "overwrittenVotesCount", assignment.OverwrittenVotesCount, @@ -267,7 +285,8 @@ func (s *Sequencer) logStateTransitionDebugInfo( ) { log.Errorw(err, "STATE TRANSITION CONSTRAINT ERROR - DEBUG INFO") if assignment != nil { - log.Infow("constraint error details", + log.Infow( + "constraint error details", "processID", processState.ProcessID().String(), "rootHashBefore", assignment.RootHashBefore, "rootHashAfter", assignment.RootHashAfter, @@ -285,7 +304,8 @@ func (s *Sequencer) logStateTransitionDebugInfo( // Log vote details for i, v := range votes { - log.Infow("vote details", + log.Infow( + "vote details", "index", i, "voteID", v.VoteID.String(), "address", types.HexBytes(v.Address.Bytes()), @@ -376,7 +396,7 @@ func (s *Sequencer) processCensusProofs( // get the process from the storage process, err := s.stg.Process(processID) if err != nil { - return nil, nil, fmt.Errorf("failed to get process metadata: %w", err) + return nil, nil, fmt.Errorf(errGetProcessMetadata, err) } var root *big.Int @@ -402,7 +422,29 @@ func (s *Sequencer) processCensusProofs( censusTree := censusRef.Tree() var ok bool if root, ok = censusTree.Root(); !ok { - log.Warnw("census tree has no root?", "censusRoot", process.Census.CensusRoot.String(), "fetchedRoot", root.String()) + // The ephemeral Pebble tree is empty (e.g. after a node restart). + // Attempt to reload from the persistent KV-backed storage. + log.Warnw("census tree has no root, attempting reload from persistent storage", + "processID", processID.String(), + "censusRoot", process.Census.CensusRoot.String()) + reloadedRef, reloadErr := s.stg.CensusDB().LoadFromPersistentTree(process.Census.CensusRoot) + if reloadErr != nil { + log.Errorw(reloadErr, fmt.Sprintf("census tree reload failed for process %s (censusRoot=%s)", + processID.String(), process.Census.CensusRoot.String())) + return nil, nil, fmt.Errorf("%w: process %s (censusRoot=%s): %w", + errCensusTreeUnavailable, processID.String(), process.Census.CensusRoot.String(), reloadErr) + } + censusTree = reloadedRef.Tree() + if root, ok = censusTree.Root(); !ok { + log.Warnw("census tree still has no root after reload", + "processID", processID.String(), + "censusRoot", process.Census.CensusRoot.String()) + return nil, nil, fmt.Errorf("%w: process %s (censusRoot=%s): tree empty after reload", + errCensusTreeUnavailable, processID.String(), process.Census.CensusRoot.String()) + } + log.Infow("census tree successfully reloaded from persistent storage", + "processID", processID.String(), + "censusRoot", process.Census.CensusRoot.String()) } // iterate over the votes to generate the merkle proofs of each voter for i := range params.VotesPerBatch { diff --git a/sequencer/statetransition_test.go b/sequencer/statetransition_test.go index 541054e7a..3ba2834f1 100644 --- a/sequencer/statetransition_test.go +++ b/sequencer/statetransition_test.go @@ -5,18 +5,25 @@ import ( "math/big" "os" "testing" + "time" "github.com/consensys/gnark/backend/groth16" "github.com/ethereum/go-ethereum/accounts/abi" qt "github.com/frankban/quicktest" stc "github.com/vocdoni/davinci-node/circuits/statetransition" statetransitiontest "github.com/vocdoni/davinci-node/circuits/test/statetransition" + "github.com/vocdoni/davinci-node/db" + "github.com/vocdoni/davinci-node/db/metadb" "github.com/vocdoni/davinci-node/internal/testutil" + spechash "github.com/vocdoni/davinci-node/spec/hash" specutil "github.com/vocdoni/davinci-node/spec/util" + "github.com/vocdoni/davinci-node/state" statetest "github.com/vocdoni/davinci-node/state/testutil" "github.com/vocdoni/davinci-node/storage" "github.com/vocdoni/davinci-node/types" "github.com/vocdoni/davinci-node/web3" + leanimt "github.com/vocdoni/lean-imt-go" + leancensus "github.com/vocdoni/lean-imt-go/census" ) func testVariableAsBigInt(t *testing.T, v any) *big.Int { @@ -327,3 +334,189 @@ func publicStateTransitionCircuitFromInputs(inputs storage.StateTransitionBatchP circuit.BlobCommitmentLimbs[2] = inputs.BlobCommitmentLimbs[2] return circuit } + +// makeTestProcessWithCensus creates a process in storage with the given MerkleTree census root. +func makeTestProcessWithCensus(t *testing.T, stg *storage.Storage, processID types.ProcessID, censusRoot types.HexBytes) { + t.Helper() + encryptionKey := testutil.RandomEncryptionPubKey() + censusOrigin := types.CensusOriginMerkleTreeOffchainStaticV1 + stateRoot, err := spechash.StateRoot( + processID.MathBigInt(), + censusOrigin.BigInt().MathBigInt(), + encryptionKey.X.MathBigInt(), + encryptionKey.Y.MathBigInt(), + testutil.BallotModePacked(), + ) + if err != nil { + t.Fatalf("spechash.StateRoot: %v", err) + } + proc := &types.Process{ + ID: &processID, + Status: types.ProcessStatusReady, + StartTime: time.Now(), + Duration: time.Hour, + MetadataURI: testMetadataURI, + BallotMode: testutil.BallotMode(), + EncryptionKey: &encryptionKey, + StateRoot: types.BigIntConverter(stateRoot), + Census: &types.Census{ + CensusOrigin: censusOrigin, + CensusRoot: censusRoot, + }, + } + if err := stg.NewProcess(proc); err != nil { + t.Fatalf("NewProcess: %v", err) + } +} + +// TestProcessCensusProofsNilRootReturnsError verifies that processCensusProofs +// returns an error when the census tree exists in storage but has no root +// (empty tree, no entries). +func TestProcessCensusProofsNilRootReturnsError(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + censusRoot := types.HexBytes(testutil.RandomCensusRoot().Bytes()) + + // Register an empty census (no entries) so LoadCensus succeeds but + // Tree().Root() returns (nil, false). + _, err := stg.CensusDB().NewByRoot(censusRoot) + c.Assert(err, qt.IsNil) + makeTestProcessWithCensus(t, stg, processID, censusRoot) + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + + _, _, err = seq.processCensusProofs(processID, nil, nil) + c.Assert(err, qt.Not(qt.IsNil)) + c.Assert(err.Error(), qt.Contains, "census tree unavailable") +} + +// TestProcessCensusProofsMissingAddressReturnsError verifies that processCensusProofs +// returns an error when a vote's address is absent from the census tree. Census +// filtering must happen before aggregation (to preserve the aggregator proof's +// BatchHash public input); a missing address at state-transition time is fatal. +func TestProcessCensusProofsMissingAddressReturnsError(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + + addr0 := testutil.DeterministicAddress(0) + addr1 := testutil.DeterministicAddress(1) + sourceTree, err := leancensus.NewCensusIMT(nil, leanimt.PoseidonHasher) + c.Assert(err, qt.IsNil) + c.Assert(sourceTree.Add(addr0, big.NewInt(1)), qt.IsNil) + c.Assert(sourceTree.Add(addr1, big.NewInt(1)), qt.IsNil) + root, ok := sourceTree.Root() + c.Assert(ok, qt.IsTrue) + + censusRoot := types.HexBytes(root.Bytes()) + _, err = stg.CensusDB().Import(censusRoot, sourceTree.Dump()) + c.Assert(err, qt.IsNil) + makeTestProcessWithCensus(t, stg, processID, censusRoot) + + // addr0 and addr1 are in the census; addr2 (index 999) is not. + addr2 := testutil.DeterministicAddress(999) + votes := []*state.Vote{ + {Address: addr0.Big(), Weight: big.NewInt(1)}, + {Address: addr1.Big(), Weight: big.NewInt(1)}, + {Address: addr2.Big(), Weight: big.NewInt(1)}, + } + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + _, _, err = seq.processCensusProofs(processID, votes, nil) + c.Assert(err, qt.Not(qt.IsNil)) +} + +// TestProcessCensusProofsAllValidReturnsAllVotes verifies that when every vote +// address is present in the census tree, all votes are returned unchanged. +func TestProcessCensusProofsAllValidReturnsAllVotes(t *testing.T) { + c := qt.New(t) + stg := newTestSequencerStorage(t) + defer stg.Close() + + processID := testutil.RandomProcessID() + + addr0 := testutil.DeterministicAddress(0) + addr1 := testutil.DeterministicAddress(1) + sourceTree, err := leancensus.NewCensusIMT(nil, leanimt.PoseidonHasher) + c.Assert(err, qt.IsNil) + c.Assert(sourceTree.Add(addr0, big.NewInt(1)), qt.IsNil) + c.Assert(sourceTree.Add(addr1, big.NewInt(1)), qt.IsNil) + root, ok := sourceTree.Root() + c.Assert(ok, qt.IsTrue) + + censusRoot := types.HexBytes(root.Bytes()) + _, err = stg.CensusDB().Import(censusRoot, sourceTree.Dump()) + c.Assert(err, qt.IsNil) + makeTestProcessWithCensus(t, stg, processID, censusRoot) + + votes := []*state.Vote{ + {Address: addr0.Big(), Weight: big.NewInt(1)}, + {Address: addr1.Big(), Weight: big.NewInt(1)}, + } + + seq := &Sequencer{stg: stg, processIDs: NewProcessIDMap()} + _, _, err = seq.processCensusProofs(processID, votes, nil) + c.Assert(err, qt.IsNil) +} + +// TestStateTransitionCensusTreeReload verifies that processCensusProofs +// recovers when the census tree Pebble cache is empty after a node restart. +// +// Import stores tree data in the persistent KV DB (censusTreeDBPrefix). After +// a restart the in-memory census cache is gone; loadCensusRef re-opens the +// census by creating a fresh empty Pebble tree (in /tmp, never populated by +// Import), so Tree().Root() returns (nil, false). +// +// The fix in T005 must reload from censusTreeDBPrefix so the state transition +// is not permanently lost. +// +// Failure mode today: returns "census tree has no root" error. +// Expected after T005: returns (root, proofs, nil). +func TestStateTransitionCensusTreeReload(t *testing.T) { + c := qt.New(t) + + // Build a small census with one address so the imported tree has a valid root. + addr0 := testutil.DeterministicAddress(0) + sourceTree, err := leancensus.NewCensusIMT(nil, leanimt.PoseidonHasher) + c.Assert(err, qt.IsNil) + c.Assert(sourceTree.Add(addr0, big.NewInt(1)), qt.IsNil) + treeRoot, ok := sourceTree.Root() + c.Assert(ok, qt.IsTrue) + censusRoot := types.HexBytes(treeRoot.Bytes()) + + // Phase 1: persist the census and process, then close to simulate shutdown. + // Import writes tree nodes into the persistent KV DB (censusTreeDBPrefix), + // NOT into a Pebble /tmp directory. The persistent data survives the close. + dbDir := t.TempDir() + testdb1, err := metadb.New(db.TypePebble, dbDir) + c.Assert(err, qt.IsNil) + stg1 := storage.New(testdb1) + _, err = stg1.CensusDB().Import(censusRoot, sourceTree.Dump()) + c.Assert(err, qt.IsNil) + processID := testutil.RandomProcessID() + makeTestProcessWithCensus(t, stg1, processID, censusRoot) + stg1.Close() + + // Phase 2: reopen the same DB to simulate a node restart. + // The in-memory census cache is empty; loadCensusRef creates a NEW empty + // Pebble tree at censusPrefix (in /tmp) because Import never touched Pebble. + // Tree().Root() therefore returns (nil, false). + testdb2, err := metadb.New(db.TypePebble, dbDir) + c.Assert(err, qt.IsNil) + stg2 := storage.New(testdb2) + t.Cleanup(func() { stg2.Close() }) + + seq := &Sequencer{stg: stg2} + + // Today this fails with "census tree has no root" because loadCensusRef + // creates an empty Pebble tree and T005's reload logic does not yet exist. + // After T005 the function must reload from censusTreeDBPrefix and return nil. + _, _, err = seq.processCensusProofs(processID, nil, nil) + c.Assert(err, qt.IsNil, + qt.Commentf("processCensusProofs must reload census from persistent KV DB when Pebble tree is empty after restart")) +} diff --git a/sequencer/worker.go b/sequencer/worker.go index a7b43bd19..3410f4dca 100644 --- a/sequencer/worker.go +++ b/sequencer/worker.go @@ -98,7 +98,8 @@ func NewWorker(stg *storage.Storage, rawSequencerURL, workerAddr, workerToken, w return nil, fmt.Errorf("failed to load vote verifier artifacts: %w", err) } - log.DebugTime("worker sequencer initialized", startTime, + log.DebugTime( + "worker sequencer initialized", startTime, "sequencerURL", sequencerURL, "workerAddress", workerAddr, "workerName", workerName, @@ -339,7 +340,8 @@ func (s *Sequencer) submitJobToMaster(vb *storage.VerifiedBallot) error { return fmt.Errorf("failed to decode worker response: %w", err) } - log.Infow("submitted job to master", + log.Infow( + "submitted job to master", "voteID", fmt.Sprintf("%x", vb.VoteID), "processID", vb.ProcessID.String(), "success", workerResponse.SuccessCount, diff --git a/types/hexbytes_test.go b/types/hexbytes_test.go index 775b59d34..f2f16e768 100644 --- a/types/hexbytes_test.go +++ b/types/hexbytes_test.go @@ -7,6 +7,8 @@ import ( qt "github.com/frankban/quicktest" ) +const testNameEmpty = "empty" + func TestHexBytes(t *testing.T) { c := qt.New(t) @@ -26,7 +28,7 @@ func TestHexBytes(t *testing.T) { want string }{ {name: "nil slice", in: nil, want: "0x"}, - {name: "empty", in: HexBytes{}, want: "0x"}, + {name: testNameEmpty, in: HexBytes{}, want: "0x"}, {name: "non-empty", in: HexBytes{0x00, 0xAB, 0xCD}, want: "0x00abcd"}, } @@ -43,7 +45,7 @@ func TestHexBytes(t *testing.T) { in HexBytes want string }{ - {name: "empty", in: HexBytes{}, want: "0"}, + {name: testNameEmpty, in: HexBytes{}, want: "0"}, {name: "big-endian", in: HexBytes{0x01, 0x00}, want: "256"}, {name: "leading zeros", in: HexBytes{0x00, 0x00, 0x02}, want: "2"}, } @@ -170,7 +172,7 @@ func TestHexBytes(t *testing.T) { in HexBytes want string }{ - {name: "empty", in: HexBytes{}, want: `"0x"`}, + {name: testNameEmpty, in: HexBytes{}, want: `"0x"`}, {name: "non-empty", in: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}, want: `"0xdeadbeef"`}, } @@ -196,7 +198,7 @@ func TestHexBytes(t *testing.T) { {name: "with 0x prefix", in: `"0xdeadbeef"`, want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, {name: "with 0X prefix", in: `"0Xdeadbeef"`, want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, {name: "without prefix", in: `"deadbeef"`, want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, - {name: "empty", in: `"0x"`, want: HexBytes{}}, + {name: testNameEmpty, in: `"0x"`, want: HexBytes{}}, } for _, tc := range testCases { @@ -253,7 +255,7 @@ func TestHexBytes(t *testing.T) { {name: "with prefix", in: "0xdeadbeef", want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, {name: "with uppercase prefix", in: "0Xdeadbeef", want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, {name: "without prefix", in: "deadbeef", want: HexBytes{0xDE, 0xAD, 0xBE, 0xEF}}, - {name: "empty", in: "", want: HexBytes{}}, + {name: testNameEmpty, in: "", want: HexBytes{}}, } for _, tc := range testCases { diff --git a/types/process_test.go b/types/process_test.go index f07c4253c..0c4b52911 100644 --- a/types/process_test.go +++ b/types/process_test.go @@ -8,6 +8,8 @@ import ( "github.com/fxamacker/cbor/v2" ) +const testNilProcess = "nil process" + func TestNestedMetadata(t *testing.T) { metadata := Metadata{ Meta: GenericMetadata{ @@ -132,7 +134,7 @@ func TestProcessIsActive(t *testing.T) { want: false, }, { - name: "nil process", + name: testNilProcess, process: nil, want: false, }, @@ -219,7 +221,7 @@ func TestProcessIsAcceptingVotes(t *testing.T) { want: false, }, { - name: "nil process", + name: testNilProcess, process: nil, want: false, }, @@ -315,7 +317,7 @@ func TestProcessMaxVotersReached(t *testing.T) { want: false, }, { - name: "nil process", + name: testNilProcess, process: nil, want: false, }, diff --git a/util/circomgnark/cache_test.go b/util/circomgnark/cache_test.go index a8222b8b7..78b33ae87 100644 --- a/util/circomgnark/cache_test.go +++ b/util/circomgnark/cache_test.go @@ -7,6 +7,8 @@ import ( "github.com/vocdoni/davinci-node/circuits/ballotproof" ) +const testProtocol = "groth16" + func TestUnmarshalCircomVerificationKeyJSONCachesByInput(t *testing.T) { c := qt.New(t) @@ -59,7 +61,7 @@ func TestToGnarkRecursionFixedVkSkipsVerificationKeyConversion(t *testing.T) { "4110411832118690910191887320272248494012149664813960539989768130756673868858", "1", }, - Protocol: "groth16", + Protocol: testProtocol, } recursionProof, err := proof.ToGnarkRecursion(&CircomVerificationKey{}, []string{ diff --git a/util/circomgnark/marshal_test.go b/util/circomgnark/marshal_test.go index 1b3c279c6..1a1f0b78d 100644 --- a/util/circomgnark/marshal_test.go +++ b/util/circomgnark/marshal_test.go @@ -13,7 +13,7 @@ func TestMarshalCircomProofJSON(t *testing.T) { PiA: []string{"1", "2", "1"}, PiB: [][]string{{"1", "2"}, {"3", "4"}, {"1", "0"}}, PiC: []string{"5", "6", "1"}, - Protocol: "groth16", + Protocol: testProtocol, } data, err := MarshalCircomProofJSON(proof) @@ -28,7 +28,7 @@ func TestMarshalCircomVerificationKeyJSON(t *testing.T) { c := qt.New(t) vk := &CircomVerificationKey{ - Protocol: "groth16", + Protocol: testProtocol, Curve: "bn128", NPublic: 3, VkAlpha1: []string{"1", "2", "1"}, diff --git a/web3/contracts_test.go b/web3/contracts_test.go index a5ea72db1..ab60e8df8 100644 --- a/web3/contracts_test.go +++ b/web3/contracts_test.go @@ -21,6 +21,11 @@ import ( "github.com/vocdoni/davinci-node/web3/rpc" ) +const ( + testMetadataURI = "https://example.com/metadata" + testCensusURI = "https://example.com/census" +) + type testRPCRequest struct { ID json.RawMessage `json:"id"` Method string `json:"method"` @@ -106,13 +111,13 @@ func TestProcessAtBlockUsesHistoricalSnapshot(t *testing.T) { OverwrittenVotesCount: big.NewInt(0), CreationBlock: big.NewInt(10), BatchNumber: big.NewInt(0), - MetadataURI: "https://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: npbindings.DAVINCITypesBallotMode{UniqueValues: false, NumFields: 5, GroupSize: 0, CostExponent: 2, MaxValue: big.NewInt(10), MinValue: big.NewInt(0), MaxValueSum: big.NewInt(100), MinValueSum: big.NewInt(0)}, Census: npbindings.DAVINCITypesCensus{ CensusOrigin: uint8(types.CensusOriginMerkleTreeOffchainStaticV1), CensusRoot: [32]byte{}, ContractAddress: common.Address{}, - CensusURI: "https://example.com/census", + CensusURI: testCensusURI, }, }, latest: npbindings.DAVINCITypesProcess{ @@ -128,13 +133,13 @@ func TestProcessAtBlockUsesHistoricalSnapshot(t *testing.T) { OverwrittenVotesCount: big.NewInt(0), CreationBlock: big.NewInt(10), BatchNumber: big.NewInt(0), - MetadataURI: "https://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: npbindings.DAVINCITypesBallotMode{UniqueValues: false, NumFields: 5, GroupSize: 0, CostExponent: 2, MaxValue: big.NewInt(10), MinValue: big.NewInt(0), MaxValueSum: big.NewInt(100), MinValueSum: big.NewInt(0)}, Census: npbindings.DAVINCITypesCensus{ CensusOrigin: uint8(types.CensusOriginMerkleTreeOffchainStaticV1), CensusRoot: [32]byte{}, ContractAddress: common.Address{}, - CensusURI: "https://example.com/census", + CensusURI: testCensusURI, }, }, } @@ -175,13 +180,13 @@ func TestProcessAtBlockUsesCreationSnapshot(t *testing.T) { OverwrittenVotesCount: big.NewInt(0), CreationBlock: big.NewInt(10), BatchNumber: big.NewInt(0), - MetadataURI: "https://example.com/metadata", + MetadataURI: testMetadataURI, BallotMode: npbindings.DAVINCITypesBallotMode{UniqueValues: false, NumFields: 5, GroupSize: 0, CostExponent: 2, MaxValue: big.NewInt(10), MinValue: big.NewInt(0), MaxValueSum: big.NewInt(100), MinValueSum: big.NewInt(0)}, Census: npbindings.DAVINCITypesCensus{ CensusOrigin: uint8(types.CensusOriginMerkleTreeOffchainStaticV1), CensusRoot: [32]byte{}, ContractAddress: common.Address{}, - CensusURI: "https://example.com/census", + CensusURI: testCensusURI, }, }, } diff --git a/web3/rpc/chainlist/chainlist_test.go b/web3/rpc/chainlist/chainlist_test.go index a69c189f7..7504def81 100644 --- a/web3/rpc/chainlist/chainlist_test.go +++ b/web3/rpc/chainlist/chainlist_test.go @@ -11,6 +11,12 @@ import ( "time" ) +const ( + testArbRPC1 = "https://arb-rpc-1.example.com" + testArbRPC2 = "https://arb-rpc-2.example.com" + testValidBlockAndLogsEndpoint = "https://valid-block-and-logs.example.com" +) + // testRand is a package-level random number generator for consistent test results var testRand = rand.New(rand.NewSource(1234)) @@ -81,8 +87,8 @@ func TestEndpointList(t *testing.T) { ShortName: "arb1", ChainID: 42161, RPC: []RPCEntry{ - {URL: "https://arb-rpc-1.example.com"}, - {URL: "https://arb-rpc-2.example.com"}, + {URL: testArbRPC1}, + {URL: testArbRPC2}, }, }, } @@ -104,8 +110,8 @@ func TestEndpointList(t *testing.T) { // Verify all endpoints are present (in any order due to randomization) expectURLs := map[string]bool{ - "https://arb-rpc-1.example.com": true, - "https://arb-rpc-2.example.com": true, + testArbRPC1: true, + testArbRPC2: true, } for _, url := range endpoints { @@ -128,8 +134,8 @@ func TestEndpointList(t *testing.T) { } expectURLs := map[string]bool{ - "https://arb-rpc-1.example.com": true, - "https://arb-rpc-2.example.com": true, + testArbRPC1: true, + testArbRPC2: true, } for _, url := range endpoints { if !expectURLs[url] { @@ -257,7 +263,7 @@ func TestEnhancedHealthCheck(t *testing.T) { ShortName: "test", ChainID: 1, RPC: []RPCEntry{ - {URL: "https://valid-block-and-logs.example.com"}, // Both valid + {URL: testValidBlockAndLogsEndpoint}, // Both valid {URL: "https://valid-block-no-logs.example.com"}, // Only valid block but no getLogs {URL: "https://zero-block-valid-logs.example.com"}, // Zero block but valid getLogs {URL: "https://zero-block-no-logs.example.com"}, // Neither valid @@ -294,7 +300,7 @@ func TestEnhancedHealthCheck(t *testing.T) { supportsGetLogs bool correctChainID bool }{ - "https://valid-block-and-logs.example.com": {true, true, true}, + testValidBlockAndLogsEndpoint: {true, true, true}, "https://valid-block-no-logs.example.com": {true, false, true}, "https://zero-block-valid-logs.example.com": {false, true, true}, "https://zero-block-no-logs.example.com": {false, false, true}, @@ -324,8 +330,8 @@ func TestEnhancedHealthCheck(t *testing.T) { } // Verify it's the correct endpoint - if len(endpoints) > 0 && endpoints[0] != "https://valid-block-and-logs.example.com" { - t.Errorf("Expected healthy endpoint 'https://valid-block-and-logs.example.com', got %s", endpoints[0]) + if len(endpoints) > 0 && endpoints[0] != testValidBlockAndLogsEndpoint { + t.Errorf("Expected healthy endpoint '%s', got %s", testValidBlockAndLogsEndpoint, endpoints[0]) } }) } diff --git a/web3/rpc/web3_pool_test.go b/web3/rpc/web3_pool_test.go index 6ed9bc012..f09d76c68 100644 --- a/web3/rpc/web3_pool_test.go +++ b/web3/rpc/web3_pool_test.go @@ -9,6 +9,12 @@ import ( qt "github.com/frankban/quicktest" ) +const ( + testEndpoint1 = "http://endpoint1.example.com" + testEndpoint2 = "http://endpoint2.example.com" + testEndpoint3 = "http://endpoint3.example.com" +) + // TestEndpointSwitchingOnFailure tests that when an endpoint fails, // the retry mechanism switches to the next available endpoint func TestEndpointSwitchingOnFailure(t *testing.T) { @@ -17,9 +23,9 @@ func TestEndpointSwitchingOnFailure(t *testing.T) { // Create a mock iterator with multiple endpoints endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, - {ChainID: 1, URI: "http://endpoint3.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, + {ChainID: 1, URI: testEndpoint3}, } pool.endpoints[1] = NewWeb3Iterator(endpoints...) @@ -28,16 +34,16 @@ func TestEndpointSwitchingOnFailure(t *testing.T) { c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 3) // Disable first endpoint - pool.DisableEndpoint(1, "http://endpoint1.example.com") + pool.DisableEndpoint(1, testEndpoint1) c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 2) c.Assert(pool.NumberOfEndpoints(1, false), qt.Equals, 3) // Disable second endpoint - pool.DisableEndpoint(1, "http://endpoint2.example.com") + pool.DisableEndpoint(1, testEndpoint2) c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 1) // Disable third endpoint - should trigger reset - pool.DisableEndpoint(1, "http://endpoint3.example.com") + pool.DisableEndpoint(1, testEndpoint3) // After disabling all endpoints, they should be reset to available c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 3) @@ -50,8 +56,8 @@ func TestDisableNonExistentEndpoint(t *testing.T) { pool := NewWeb3Pool() endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, } pool.endpoints[1] = NewWeb3Iterator(endpoints...) @@ -63,7 +69,7 @@ func TestDisableNonExistentEndpoint(t *testing.T) { c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 2) // Try to disable from a chainID that doesn't exist - pool.DisableEndpoint(999, "http://endpoint1.example.com") + pool.DisableEndpoint(999, testEndpoint1) // Original chain should still have 2 endpoints c.Assert(pool.NumberOfEndpoints(1, true), qt.Equals, 2) @@ -73,9 +79,9 @@ func TestDisableNonExistentEndpoint(t *testing.T) { func TestIteratorRoundRobin(t *testing.T) { c := qt.New(t) endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, - {ChainID: 1, URI: "http://endpoint3.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, + {ChainID: 1, URI: testEndpoint3}, } iter := NewWeb3Iterator(endpoints...) @@ -83,29 +89,29 @@ func TestIteratorRoundRobin(t *testing.T) { // Get endpoints in round-robin fashion ep1, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep1.URI, qt.Equals, "http://endpoint1.example.com") + c.Assert(ep1.URI, qt.Equals, testEndpoint1) ep2, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep2.URI, qt.Equals, "http://endpoint2.example.com") + c.Assert(ep2.URI, qt.Equals, testEndpoint2) ep3, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep3.URI, qt.Equals, "http://endpoint3.example.com") + c.Assert(ep3.URI, qt.Equals, testEndpoint3) // Should wrap around to first endpoint ep4, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep4.URI, qt.Equals, "http://endpoint1.example.com") + c.Assert(ep4.URI, qt.Equals, testEndpoint1) } // TestIteratorDisableAndNext tests that disabling an endpoint properly updates the next index func TestIteratorDisableAndNext(t *testing.T) { c := qt.New(t) endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, - {ChainID: 1, URI: "http://endpoint3.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, + {ChainID: 1, URI: testEndpoint3}, } iter := NewWeb3Iterator(endpoints...) @@ -113,32 +119,32 @@ func TestIteratorDisableAndNext(t *testing.T) { // Get first endpoint (nextIndex moves to 1, pointing to endpoint2) ep1, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep1.URI, qt.Equals, "http://endpoint1.example.com") + c.Assert(ep1.URI, qt.Equals, testEndpoint1) // Disable the second endpoint (at index 1, which is where nextIndex points) // After removal: [endpoint1, endpoint3], nextIndex stays at 1 but gets decremented to 0 // because we removed an element before it - iter.Disable("http://endpoint2.example.com") + iter.Disable(testEndpoint2) // Next should return endpoint1 (at index 0, since nextIndex was adjusted) // Then nextIndex moves to 1 ep2, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep2.URI, qt.Equals, "http://endpoint1.example.com") + c.Assert(ep2.URI, qt.Equals, testEndpoint1) // Next should return endpoint3 (at index 1) ep3, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep3.URI, qt.Equals, "http://endpoint3.example.com") + c.Assert(ep3.URI, qt.Equals, testEndpoint3) } // TestIteratorDisableCurrentEndpoint tests disabling the current endpoint func TestIteratorDisableCurrentEndpoint(t *testing.T) { c := qt.New(t) endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, - {ChainID: 1, URI: "http://endpoint3.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, + {ChainID: 1, URI: testEndpoint3}, } iter := NewWeb3Iterator(endpoints...) @@ -153,7 +159,7 @@ func TestIteratorDisableCurrentEndpoint(t *testing.T) { // Next should return endpoint2 (index 0 after removal, but nextIndex was adjusted) ep2, err := iter.Next() c.Assert(err, qt.IsNil) - c.Assert(ep2.URI, qt.Equals, "http://endpoint2.example.com") + c.Assert(ep2.URI, qt.Equals, testEndpoint2) } // TestIteratorEmptyPool tests behavior with no endpoints @@ -171,17 +177,17 @@ func TestIteratorEmptyPool(t *testing.T) { func TestIteratorAllDisabled(t *testing.T) { c := qt.New(t) endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, } iter := NewWeb3Iterator(endpoints...) // Disable all endpoints - iter.Disable("http://endpoint1.example.com") + iter.Disable(testEndpoint1) c.Assert(iter.Available(), qt.Equals, 1) - iter.Disable("http://endpoint2.example.com") + iter.Disable(testEndpoint2) // Should have reset all to available c.Assert(iter.Available(), qt.Equals, 2) @@ -192,9 +198,9 @@ func TestIteratorAllDisabled(t *testing.T) { func TestConcurrentAccess(t *testing.T) { c := qt.New(t) endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, - {ChainID: 1, URI: "http://endpoint3.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, + {ChainID: 1, URI: testEndpoint3}, } iter := NewWeb3Iterator(endpoints...) @@ -214,7 +220,7 @@ func TestConcurrentAccess(t *testing.T) { // Also disable endpoints concurrently go func() { for range 10 { - iter.Disable("http://endpoint1.example.com") + iter.Disable(testEndpoint1) time.Sleep(time.Millisecond) } done <- true @@ -235,8 +241,8 @@ func TestRetryLogic(t *testing.T) { pool := NewWeb3Pool() endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, } pool.endpoints[1] = NewWeb3Iterator(endpoints...) @@ -272,8 +278,8 @@ func TestRetryAllEndpointsFail(t *testing.T) { pool := NewWeb3Pool() endpoints := []*Web3Endpoint{ - {ChainID: 1, URI: "http://endpoint1.example.com"}, - {ChainID: 1, URI: "http://endpoint2.example.com"}, + {ChainID: 1, URI: testEndpoint1}, + {ChainID: 1, URI: testEndpoint2}, } pool.endpoints[1] = NewWeb3Iterator(endpoints...) diff --git a/web3/web3_config_test.go b/web3/web3_config_test.go index cd9f195ab..3f3b9b028 100644 --- a/web3/web3_config_test.go +++ b/web3/web3_config_test.go @@ -7,6 +7,12 @@ import ( qt "github.com/frankban/quicktest" ) +const ( + testSepoliaContract = "11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02" + testSepoliaPrefix = "11155111:" + testMainnetContract = "1:0x9b7c0b5e1240373c2d8a1f3b7e0b2d8d4a6f3c2e" +) + func TestAddressesByChainID(t *testing.T) { c := qt.New(t) @@ -29,7 +35,7 @@ func TestAddressesByChainID(t *testing.T) { }, { desc: "single valid match", - contracts: []string{"11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02"}, + contracts: []string{testSepoliaContract}, chainID: 11155111, wantAddr: sepoliaAddr, wantResult: true, @@ -43,7 +49,7 @@ func TestAddressesByChainID(t *testing.T) { }, { desc: "chain ID mismatch returns nil", - contracts: []string{"11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02"}, + contracts: []string{testSepoliaContract}, chainID: 1, wantResult: false, }, @@ -67,14 +73,14 @@ func TestAddressesByChainID(t *testing.T) { }, { desc: "empty address part skips entry", - contracts: []string{"11155111:"}, + contracts: []string{testSepoliaPrefix}, chainID: 11155111, wantResult: false, }, { desc: "multiple entries, first matching returns", contracts: []string{ - "11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02", + testSepoliaContract, "42220:0x68dac70af68aa0bed8cef36c523243941d7d7876", }, chainID: 11155111, @@ -84,8 +90,8 @@ func TestAddressesByChainID(t *testing.T) { { desc: "multiple entries, later entry matches", contracts: []string{ - "1:0x9b7c0b5e1240373c2d8a1f3b7e0b2d8d4a6f3c2e", - "11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02", + testMainnetContract, + testSepoliaContract, }, chainID: 11155111, wantAddr: sepoliaAddr, @@ -95,7 +101,7 @@ func TestAddressesByChainID(t *testing.T) { desc: "multiple entries, none match returns nil", contracts: []string{ "42220:0x68dac70af68aa0bed8cef36c523243941d7d7876", - "1:0x9b7c0b5e1240373c2d8a1f3b7e0b2d8d4a6f3c2e", + testMainnetContract, }, chainID: 11155111, wantResult: false, @@ -105,9 +111,9 @@ func TestAddressesByChainID(t *testing.T) { contracts: []string{ "bad", ":", - "11155111:", + testSepoliaPrefix, "11155111:0x0000000000000000000000000000000000000000", - "11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02", + testSepoliaContract, }, chainID: 11155111, wantAddr: sepoliaAddr, @@ -118,7 +124,7 @@ func TestAddressesByChainID(t *testing.T) { contracts: []string{ "bad", "nope:0x015eac820688da203a0bd730a8a7a4cdb97e1a02", - "11155111:", + testSepoliaPrefix, }, chainID: 11155111, wantResult: false, @@ -126,8 +132,8 @@ func TestAddressesByChainID(t *testing.T) { { desc: "matching entry with mainnet address", contracts: []string{ - "11155111:0x015eac820688da203a0bd730a8a7a4cdb97e1a02", - "1:0x9b7c0b5e1240373c2d8a1f3b7e0b2d8d4a6f3c2e", + testSepoliaContract, + testMainnetContract, }, chainID: 1, wantAddr: mainnetAddr, diff --git a/workers/worker_manager_test.go b/workers/worker_manager_test.go index 06fa4e593..23db774b7 100644 --- a/workers/worker_manager_test.go +++ b/workers/worker_manager_test.go @@ -69,7 +69,7 @@ func TestWorkerIsBanned(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { worker := &Worker{ - Address: "test-worker", + Address: testWorkerAddr, consecutiveFails: int64(tt.consecutiveFails), } @@ -113,8 +113,8 @@ func TestWorkerManagerGetWorker(t *testing.T) { c.Assert(exists, qt.IsFalse) // Add a worker and test getting it - addedWorker := wm.AddWorker("test-worker", testWorkerName) - retrievedWorker, exists := wm.GetWorker("test-worker") + addedWorker := wm.AddWorker(testWorkerAddr, testWorkerName) + retrievedWorker, exists := wm.GetWorker(testWorkerAddr) c.Assert(retrievedWorker, qt.IsNotNil) c.Assert(exists, qt.IsTrue) c.Assert(retrievedWorker, qt.Equals, addedWorker) @@ -259,8 +259,8 @@ func TestWorkerManagerStartStop(t *testing.T) { c.Assert(wm.cancelFunc, qt.IsNotNil) // Add a worker to verify it exists - wm.AddWorker("test-worker", testWorkerName) - worker, exists := wm.GetWorker("test-worker") + wm.AddWorker(testWorkerAddr, testWorkerName) + worker, exists := wm.GetWorker(testWorkerAddr) c.Assert(exists, qt.IsTrue) c.Assert(worker, qt.IsNotNil) @@ -268,7 +268,7 @@ func TestWorkerManagerStartStop(t *testing.T) { wm.Stop() // Verify workers are cleared - worker, exists = wm.GetWorker("test-worker") + worker, exists = wm.GetWorker(testWorkerAddr) c.Assert(exists, qt.IsFalse) c.Assert(worker, qt.IsNil) } @@ -568,8 +568,8 @@ func TestWorkerManagerContextCancellation(t *testing.T) { c.Assert(wm.innerCtx, qt.IsNotNil) // Add a worker - wm.AddWorker("test-worker", testWorkerName) - _, exists := wm.GetWorker("test-worker") + wm.AddWorker(testWorkerAddr, testWorkerName) + _, exists := wm.GetWorker(testWorkerAddr) c.Assert(exists, qt.IsTrue) // Cancel context @@ -579,7 +579,7 @@ func TestWorkerManagerContextCancellation(t *testing.T) { time.Sleep(50 * time.Millisecond) // Worker should be cleared because context cancellation calls stop() which clears workers - _, exists = wm.GetWorker("test-worker") + _, exists = wm.GetWorker(testWorkerAddr) c.Assert(exists, qt.IsFalse) } diff --git a/workers/worker_manager_time_ban_test.go b/workers/worker_manager_time_ban_test.go index 6bc00f634..f30323731 100644 --- a/workers/worker_manager_time_ban_test.go +++ b/workers/worker_manager_time_ban_test.go @@ -18,7 +18,7 @@ func TestWorkerTimeBanCoverage(t *testing.T) { t.Run("Time-based banning scenarios", func(t *testing.T) { worker := &Worker{ - Address: "test-worker", + Address: testWorkerAddr, consecutiveFails: 0, // No consecutive fails } @@ -48,7 +48,7 @@ func TestWorkerTimeBanCoverage(t *testing.T) { t.Run("Combined consecutive fails and time-based banning", func(t *testing.T) { // Test worker that is banned by consecutive fails AND has a time-based ban worker := &Worker{ - Address: "test-worker", + Address: testWorkerAddr, consecutiveFails: 5, // Above threshold } @@ -77,7 +77,7 @@ func TestWorkerTimeBanCoverage(t *testing.T) { t.Run("Edge cases for time comparison", func(t *testing.T) { worker := &Worker{ - Address: "test-worker", + Address: testWorkerAddr, consecutiveFails: 0, }