diff --git a/internal/infrastructure/db/postgres/asset_repo.go b/internal/infrastructure/db/postgres/asset_repo.go index 9d0eda536..ca67c061f 100644 --- a/internal/infrastructure/db/postgres/asset_repo.go +++ b/internal/infrastructure/db/postgres/asset_repo.go @@ -125,43 +125,46 @@ func (r *assetRepository) GetAssets( if len(assetIds) == 0 { return nil, nil } - var assets []domain.Asset - txBody := func(querierWithTx *queries.Queries) error { - rows, err := querierWithTx.SelectAssetsByIds(ctx, assetIds) - if err != nil { - return err - } - assets = make([]domain.Asset, 0, len(rows)) - for _, row := range rows { - supplyStr, err := querierWithTx.SelectAssetSupply(ctx, row.ID) - if err != nil { - return fmt.Errorf("failed to compute supply for asset %s: %w", row.ID, err) - } - supply := new(big.Int) - if _, ok := supply.SetString(supplyStr, 10); !ok { - return fmt.Errorf("invalid supply value: %s", supplyStr) - } + + rows, err := r.querier.SelectAssetsWithUnspentAmountsByIds(ctx, assetIds) + if err != nil { + return nil, err + } + + assets := make([]domain.Asset, 0, len(rows)) + indexByID := make(map[string]int, len(rows)) + for _, row := range rows { + idx, ok := indexByID[row.ID] + if !ok { ast := domain.Asset{ Id: row.ID, ControlAssetId: row.ControlAssetID.String, - Supply: *supply, + Supply: *big.NewInt(0), } + if row.Metadata.Valid { // Parsing metadata should never fail but if it does we just return an empty list // of metadata and log the error - ast.Metadata, err = asset.NewMetadataListFromString(row.Metadata.String) - if err != nil { - log.WithError(err).Warnf("failed to parse metadata for asset %s", row.ID) + metadata, parseErr := asset.NewMetadataListFromString(row.Metadata.String) + if parseErr != nil { + log.WithError(parseErr).Warnf("failed to parse metadata for asset %s", row.ID) + } else { + ast.Metadata = metadata } } assets = append(assets, ast) + idx = len(assets) - 1 + indexByID[row.ID] = idx } - return nil - } - if err := execTx(ctx, r.db, txBody); err != nil { - return nil, err + + amount, ok := new(big.Int).SetString(row.AssetAmount, 10) + if !ok { + return nil, fmt.Errorf("invalid supply value: %s", row.AssetAmount) + } + assets[idx].Supply.Add(&assets[idx].Supply, amount) } + return assets, nil } diff --git a/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go b/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go index a59cab93d..175621d63 100644 --- a/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go +++ b/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go @@ -330,6 +330,62 @@ func (q *Queries) SelectAssetsByIds(ctx context.Context, dollar_1 []string) ([]A return items, nil } +const selectAssetsWithUnspentAmountsByIds = `-- name: SelectAssetsWithUnspentAmountsByIds :many +SELECT + a.id, + a.is_immutable, + a.metadata_hash, + a.metadata, + a.control_asset_id, + COALESCE(v.asset_amount, 0)::TEXT AS asset_amount +FROM asset a +LEFT JOIN vtxo_vw v + ON v.asset_id = a.id + AND v.spent = false + AND v.asset_amount > 0 +WHERE a.id = ANY($1::varchar[]) +ORDER BY a.id +` + +type SelectAssetsWithUnspentAmountsByIdsRow struct { + ID string + IsImmutable bool + MetadataHash sql.NullString + Metadata sql.NullString + ControlAssetID sql.NullString + AssetAmount string +} + +func (q *Queries) SelectAssetsWithUnspentAmountsByIds(ctx context.Context, dollar_1 []string) ([]SelectAssetsWithUnspentAmountsByIdsRow, error) { + rows, err := q.db.QueryContext(ctx, selectAssetsWithUnspentAmountsByIds, pq.Array(dollar_1)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectAssetsWithUnspentAmountsByIdsRow + for rows.Next() { + var i SelectAssetsWithUnspentAmountsByIdsRow + if err := rows.Scan( + &i.ID, + &i.IsImmutable, + &i.MetadataHash, + &i.Metadata, + &i.ControlAssetID, + &i.AssetAmount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectControlAssetByID = `-- name: SelectControlAssetByID :one SELECT control_asset_id FROM asset WHERE id = $1 ` diff --git a/internal/infrastructure/db/postgres/sqlc/query.sql b/internal/infrastructure/db/postgres/sqlc/query.sql index f24c4f701..cc7a2a361 100644 --- a/internal/infrastructure/db/postgres/sqlc/query.sql +++ b/internal/infrastructure/db/postgres/sqlc/query.sql @@ -441,6 +441,22 @@ VALUES (@asset_id, @txid, @vout, @amount); -- name: SelectAssetsByIds :many SELECT * FROM asset WHERE asset.id = ANY($1::varchar[]); +-- name: SelectAssetsWithUnspentAmountsByIds :many +SELECT + a.id, + a.is_immutable, + a.metadata_hash, + a.metadata, + a.control_asset_id, + COALESCE(v.asset_amount, 0)::TEXT AS asset_amount +FROM asset a +LEFT JOIN vtxo_vw v + ON v.asset_id = a.id + AND v.spent = false + AND v.asset_amount > 0 +WHERE a.id = ANY($1::varchar[]) +ORDER BY a.id; + -- name: SelectAssetSupply :one SELECT (COALESCE(SUM(ap.amount), 0))::TEXT AS supply FROM asset_projection ap @@ -451,4 +467,4 @@ WHERE ap.asset_id = $1 AND v.spent = false; SELECT control_asset_id FROM asset WHERE id = $1; -- name: SelectAssetExists :one -SELECT 1 FROM asset WHERE id = $1 LIMIT 1; \ No newline at end of file +SELECT 1 FROM asset WHERE id = $1 LIMIT 1; diff --git a/internal/infrastructure/db/service.go b/internal/infrastructure/db/service.go index 299b873de..97ba29d6c 100644 --- a/internal/infrastructure/db/service.go +++ b/internal/infrastructure/db/service.go @@ -310,12 +310,16 @@ func NewService(config ServiceConfig, txDecoder ports.TxDecoder) (ports.RepoMana } dbFile := filepath.Join(baseDir, sqliteDbFile) - db, err := sqlitedb.OpenDb(dbFile) + db, err := sqlitedb.OpenDb( + dbFile, + sqlitedb.WithJournalModeWAL(), + sqlitedb.WithBusyTimeout(5*time.Second), + ) if err != nil { return nil, fmt.Errorf("failed to open db: %s", err) } - driver, err := sqlitemigrate.WithInstance(db, &sqlitemigrate.Config{}) + driver, err := sqlitemigrate.WithInstance(db.Write(), &sqlitemigrate.Config{}) if err != nil { return nil, fmt.Errorf("failed to init driver: %s", err) } @@ -330,7 +334,7 @@ func NewService(config ServiceConfig, txDecoder ports.TxDecoder) (ports.RepoMana return nil, fmt.Errorf("failed to create migration instance: %s", err) } - err = handleIntentTxidMigration(m, db, config.DataStoreType) + err = handleIntentTxidMigration(m, db.Write(), config.DataStoreType) if err != nil { return nil, fmt.Errorf("failed to handle intent txid migration: %w", err) } diff --git a/internal/infrastructure/db/service_test.go b/internal/infrastructure/db/service_test.go index ee006e436..a09a9c053 100644 --- a/internal/infrastructure/db/service_test.go +++ b/internal/infrastructure/db/service_test.go @@ -186,6 +186,7 @@ func TestService(t *testing.T) { testOffchainTxRepository(t, svc) testAssetRepository(t, svc) testVtxoRepository(t, svc) + testAssetRepositorySpentOnlySupply(t, svc) testScheduledSessionRepository(t, svc) testConvictionRepository(t, svc) testFeeRepository(t, svc) @@ -1769,6 +1770,53 @@ func testAssetRepository(t *testing.T, svc ports.RepoManager) { }) } +func testAssetRepositorySpentOnlySupply(t *testing.T, svc ports.RepoManager) { + t.Run("test_asset_repository_spent_only_supply", func(t *testing.T) { + ctx := t.Context() + repo := svc.Assets() + vtxoRepo := svc.Vtxos() + + assetID := randomString(16) + vtxoTxid := randomString(32) + spentBy := randomString(32) + arkTxid := randomString(32) + + count, err := repo.AddAssets(ctx, map[string][]domain.Asset{"spentOnlyAssetTx": { + { + Id: assetID, + Metadata: []asset.Metadata{}, + }, + }}) + require.NoError(t, err) + require.Equal(t, 1, count) + + spentOnlyVtxo := domain.Vtxo{ + Outpoint: domain.Outpoint{ + Txid: vtxoTxid, + VOut: 0, + }, + Amount: 330, + Assets: []domain.AssetDenomination{{ + AssetId: assetID, + Amount: 42, + }}, + } + err = vtxoRepo.AddVtxos(ctx, []domain.Vtxo{spentOnlyVtxo}) + require.NoError(t, err) + + err = vtxoRepo.SpendVtxos(ctx, map[domain.Outpoint]string{ + spentOnlyVtxo.Outpoint: spentBy, + }, arkTxid) + require.NoError(t, err) + + assets, err := repo.GetAssets(ctx, []string{assetID}) + require.NoError(t, err) + require.Len(t, assets, 1) + require.Equal(t, assetID, assets[0].Id) + require.Zero(t, assets[0].Supply.Sign()) + }) +} + func testFeeRepository(t *testing.T, svc ports.RepoManager) { t.Run("test_fee_repository", func(t *testing.T) { ctx := context.Background() diff --git a/internal/infrastructure/db/sqlite/asset_repo.go b/internal/infrastructure/db/sqlite/asset_repo.go index 7aa553e99..b422a6b75 100644 --- a/internal/infrastructure/db/sqlite/asset_repo.go +++ b/internal/infrastructure/db/sqlite/asset_repo.go @@ -4,33 +4,32 @@ import ( "context" "database/sql" "encoding/hex" + "errors" "fmt" + "math/big" "sort" "github.com/arkade-os/arkd/internal/core/domain" "github.com/arkade-os/arkd/internal/infrastructure/db/sqlite/sqlc/queries" "github.com/arkade-os/arkd/pkg/ark-lib/asset" - "github.com/shopspring/decimal" log "github.com/sirupsen/logrus" ) type assetRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewAssetRepository(config ...interface{}) (domain.AssetRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf("cannot open asset repository: invalid config") } return &assetRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -111,7 +110,7 @@ func (r *assetRepository) AddAssets( return nil } - if err := execTx(ctx, r.db, txBody); err != nil { + if err := execTx(ctx, r.db.Write(), txBody); err != nil { return -1, err } return count, nil @@ -123,54 +122,61 @@ func (r *assetRepository) GetAssets( if len(assetIds) == 0 { return nil, nil } - var assets []domain.Asset - txBody := func(querierWithTx *queries.Queries) error { - rows, err := querierWithTx.SelectAssetsByIds(ctx, assetIds) - if err != nil { - return err - } - assets = make([]domain.Asset, 0, len(rows)) - for _, row := range rows { - // TODO: this is not efficient, but avoids overflows - amounts, err := querierWithTx.SelectAssetAmounts(ctx, row.ID) - if err != nil { - return fmt.Errorf("failed to compute supply for asset %s: %w", row.ID, err) - } - supply := decimal.NewFromFloat(0) - for _, amount := range amounts { - // nolint - dec, _ := decimal.NewFromString(amount) - supply = supply.Add(dec) - } + var rows []queries.SelectAssetsWithUnspentAmountsByIdsRow + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectAssetsWithUnspentAmountsByIds(ctx, assetIds) + return err + }); err != nil { + return nil, err + } + + assets := make([]domain.Asset, 0, len(rows)) + indexByID := make(map[string]int, len(rows)) + for _, row := range rows { + idx, ok := indexByID[row.ID] + if !ok { ast := domain.Asset{ Id: row.ID, ControlAssetId: row.ControlAssetID.String, - Supply: *supply.BigInt(), + Supply: *big.NewInt(0), } if row.Metadata.Valid { // Parsing metadata should never fail but if it does we just return an empty list // of metadata and log the error - ast.Metadata, err = asset.NewMetadataListFromString(row.Metadata.String) - if err != nil { - log.WithError(err).Warnf("failed to parse metadata for asset %s", row.ID) + metadata, parseErr := asset.NewMetadataListFromString(row.Metadata.String) + if parseErr != nil { + log.WithError(parseErr).Warnf("failed to parse metadata for asset %s", row.ID) + } else { + ast.Metadata = metadata } } + assets = append(assets, ast) + idx = len(assets) - 1 + indexByID[row.ID] = idx } - return nil - } - if err := execTx(ctx, r.db, txBody); err != nil { - return nil, err + + amount, ok := new(big.Int).SetString(row.AssetAmount, 10) + if !ok { + return nil, fmt.Errorf("invalid supply value: %s", row.AssetAmount) + } + assets[idx].Supply.Add(&assets[idx].Supply, amount) } + return assets, nil } func (r *assetRepository) GetControlAsset(ctx context.Context, assetID string) (string, error) { - controlID, err := r.querier.SelectControlAssetByID(ctx, assetID) - if err != nil { - if err == sql.ErrNoRows { + var controlID sql.NullString + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + controlID, err = q.SelectControlAssetByID(ctx, assetID) + return err + }); err != nil { + if errors.Is(err, sql.ErrNoRows) { return "", fmt.Errorf("no control asset found") } return "", err @@ -182,9 +188,11 @@ func (r *assetRepository) GetControlAsset(ctx context.Context, assetID string) ( } func (r *assetRepository) AssetExists(ctx context.Context, assetID string) (bool, error) { - _, err := r.querier.SelectAssetExists(ctx, assetID) - if err != nil { - if err == sql.ErrNoRows { + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + _, err := q.SelectAssetExists(ctx, assetID) + return err + }); err != nil { + if errors.Is(err, sql.ErrNoRows) { return false, nil } return false, err diff --git a/internal/infrastructure/db/sqlite/cancellation_test.go b/internal/infrastructure/db/sqlite/cancellation_test.go new file mode 100644 index 000000000..6c01d6736 --- /dev/null +++ b/internal/infrastructure/db/sqlite/cancellation_test.go @@ -0,0 +1,80 @@ +package sqlitedb + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const slowRecursiveQuery = ` + WITH RECURSIVE cnt(x) AS ( + SELECT 1 + UNION ALL + SELECT x + 1 FROM cnt WHERE x < 100000000 + ) + SELECT x FROM cnt +` + +func TestCanceledReadQueryDiscardsConnection(t *testing.T) { + // Use a shared in-memory DB because the test exercises a pinned read + // connection while the DB wrapper exposes separate read/write pools. + db, err := OpenDb("file::memory:", WithSharedCache()) + require.NoError(t, err) + t.Cleanup(func() { + _ = db.Close() + }) + + ctx := t.Context() + // Pin a single read connection so the test can verify that an interrupted + // SQLite connection is discarded instead of being reused. + conn, err := db.Read().Conn(ctx) + require.NoError(t, err) + + queryCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- runSlowReadQuery(queryCtx, conn) + }() + + err = <-errCh + require.Error(t, err) + require.True(t, isInterruptError(queryCtx, err), "expected interrupt-like error, got %v", err) + // Discard the interrupted connection explicitly; a normal close would return + // it to the pool for reuse. + require.NoError(t, closeConn(conn, true)) + + assertReadPoolStillHealthy(t, db, ctx) +} + +func runSlowReadQuery(ctx context.Context, conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, slowRecursiveQuery) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var value int + if err := rows.Scan(&value); err != nil { + return err + } + } + + return rows.Err() +} + +func assertReadPoolStillHealthy(t *testing.T, db SQLiteDB, ctx context.Context) { + t.Helper() + + for range 20 { + var got int + err := db.Read().QueryRowContext(ctx, `SELECT 1`).Scan(&got) + require.NoError(t, err) + require.Equal(t, 1, got) + } +} diff --git a/internal/infrastructure/db/sqlite/conviction_repo.go b/internal/infrastructure/db/sqlite/conviction_repo.go index 2a951aee0..58f7e5d1e 100644 --- a/internal/infrastructure/db/sqlite/conviction_repo.go +++ b/internal/infrastructure/db/sqlite/conviction_repo.go @@ -11,15 +11,14 @@ import ( ) type convictionRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewConvictionRepository(config ...interface{}) (domain.ConvictionRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf( "cannot open conviction repository: invalid config, expected db at 0", @@ -27,8 +26,7 @@ func NewConvictionRepository(config ...interface{}) (domain.ConvictionRepository } return &convictionRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -38,8 +36,12 @@ func (r *convictionRepository) Close() { } func (r *convictionRepository) Get(ctx context.Context, id string) (domain.Conviction, error) { - conviction, err := r.querier.SelectConviction(ctx, id) - if err != nil { + var conviction queries.Conviction + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + conviction, err = q.SelectConviction(ctx, id) + return err + }); err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("conviction with id %s not found", id) } @@ -55,19 +57,24 @@ func (r *convictionRepository) GetActiveScriptConvictions( ) ([]domain.ScriptConviction, error) { currentTime := time.Now().Unix() - convictions, err := r.querier.SelectActiveScriptConvictions( - ctx, - queries.SelectActiveScriptConvictionsParams{ - Script: sql.NullString{ - String: script, - Valid: true, - }, - ExpiresAt: sql.NullInt64{ - Int64: currentTime, - Valid: true, + var convictions []queries.Conviction + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + convictions, err = q.SelectActiveScriptConvictions( + ctx, + queries.SelectActiveScriptConvictionsParams{ + Script: sql.NullString{ + String: script, + Valid: true, + }, + ExpiresAt: sql.NullInt64{ + Int64: currentTime, + Valid: true, + }, }, - }, - ) + ) + return err + }) if err != nil { if err == sql.ErrNoRows { return nil, nil @@ -95,7 +102,9 @@ func (r *convictionRepository) Add(ctx context.Context, convictions ...domain.Co return fmt.Errorf("failed to convert conviction to db params: %w", err) } - if err := r.querier.UpsertConviction(ctx, params); err != nil { + if err := withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.UpsertConviction(ctx, params) + }); err != nil { return fmt.Errorf("failed to upsert conviction: %w", err) } } @@ -108,14 +117,18 @@ func (r *convictionRepository) GetAll( from, to time.Time, ) ([]domain.Conviction, error) { - convictions, err := r.querier.SelectConvictionsInTimeRange( - ctx, - queries.SelectConvictionsInTimeRangeParams{ - FromTime: from.Unix(), - ToTime: to.Unix(), - }, - ) - if err != nil { + var convictions []queries.Conviction + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + convictions, err = q.SelectConvictionsInTimeRange( + ctx, + queries.SelectConvictionsInTimeRangeParams{ + FromTime: from.Unix(), + ToTime: to.Unix(), + }, + ) + return err + }); err != nil { return nil, fmt.Errorf("failed to get convictions in time range: %w", err) } @@ -136,8 +149,12 @@ func (r *convictionRepository) GetByRoundID( roundID string, ) ([]domain.Conviction, error) { - convictions, err := r.querier.SelectConvictionsByRoundID(ctx, roundID) - if err != nil { + var convictions []queries.Conviction + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + convictions, err = q.SelectConvictionsByRoundID(ctx, roundID) + return err + }); err != nil { return nil, fmt.Errorf("failed to get convictions by round ID: %w", err) } @@ -155,7 +172,9 @@ func (r *convictionRepository) GetByRoundID( func (r *convictionRepository) Pardon(ctx context.Context, id string) error { - if err := r.querier.UpdateConvictionPardoned(ctx, id); err != nil { + if err := withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.UpdateConvictionPardoned(ctx, id) + }); err != nil { return fmt.Errorf("failed to pardon conviction: %w", err) } diff --git a/internal/infrastructure/db/sqlite/intent_fees_repo.go b/internal/infrastructure/db/sqlite/intent_fees_repo.go index b1aa8e6f6..4e4b970a4 100644 --- a/internal/infrastructure/db/sqlite/intent_fees_repo.go +++ b/internal/infrastructure/db/sqlite/intent_fees_repo.go @@ -3,6 +3,7 @@ package sqlitedb import ( "context" "database/sql" + "errors" "fmt" "github.com/arkade-os/arkd/internal/core/domain" @@ -10,15 +11,14 @@ import ( ) type intentFeesRepo struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewIntentFeesRepository(config ...interface{}) (domain.FeeRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf( "cannot open intent fees repository: invalid config, expected db at 0", @@ -26,8 +26,7 @@ func NewIntentFeesRepository(config ...interface{}) (domain.FeeRepository, error } return &intentFeesRepo{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -37,9 +36,13 @@ func (r *intentFeesRepo) Close() { } func (r *intentFeesRepo) GetIntentFees(ctx context.Context) (*domain.IntentFees, error) { - intentFees, err := r.querier.SelectLatestIntentFees(ctx) - if err != nil { - if err == sql.ErrNoRows { + var intentFees queries.IntentFee + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + intentFees, err = q.SelectLatestIntentFees(ctx) + return err + }); err != nil { + if errors.Is(err, sql.ErrNoRows) { return &domain.IntentFees{}, nil } return nil, fmt.Errorf("failed to get intent fees: %w", err) @@ -60,11 +63,13 @@ func (r *intentFeesRepo) UpdateIntentFees(ctx context.Context, fees domain.Inten return fmt.Errorf("missing fees to update") } - if err := r.querier.AddIntentFees(ctx, queries.AddIntentFeesParams{ - OnchainInputFeeProgram: fees.OnchainInputFee, - OffchainInputFeeProgram: fees.OffchainInputFee, - OnchainOutputFeeProgram: fees.OnchainOutputFee, - OffchainOutputFeeProgram: fees.OffchainOutputFee, + if err := withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.AddIntentFees(ctx, queries.AddIntentFeesParams{ + OnchainInputFeeProgram: fees.OnchainInputFee, + OffchainInputFeeProgram: fees.OffchainInputFee, + OnchainOutputFeeProgram: fees.OnchainOutputFee, + OffchainOutputFeeProgram: fees.OffchainOutputFee, + }) }); err != nil { return fmt.Errorf("failed to add intent fees: %w", err) } @@ -73,7 +78,9 @@ func (r *intentFeesRepo) UpdateIntentFees(ctx context.Context, fees domain.Inten } func (r *intentFeesRepo) ClearIntentFees(ctx context.Context) error { - if err := r.querier.ClearIntentFees(ctx); err != nil { + if err := withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.ClearIntentFees(ctx) + }); err != nil { return fmt.Errorf("failed to clear intent fees: %w", err) } diff --git a/internal/infrastructure/db/sqlite/intent_txid_migration_test.go b/internal/infrastructure/db/sqlite/intent_txid_migration_test.go index a39a89347..12b5aaba6 100644 --- a/internal/infrastructure/db/sqlite/intent_txid_migration_test.go +++ b/internal/infrastructure/db/sqlite/intent_txid_migration_test.go @@ -13,7 +13,8 @@ import ( func TestIntentTxidMigration(t *testing.T) { ctx := context.Background() - db, err := sqlitedb.OpenDb(":memory:") + // shared in-memory SQLite DB so multiple connections (read/write pools) see the same data + db, err := sqlitedb.OpenDb("file::memory:", sqlitedb.WithSharedCache()) require.NoError(t, err) t.Cleanup(func() { @@ -21,24 +22,25 @@ func TestIntentTxidMigration(t *testing.T) { db.Close() }) // create table intent references - setupRoundTable(t, db) + setupRoundTable(t, db.Write()) // create intent table using old schema - setupOldIntentTable(t, db) + setupOldIntentTable(t, db.Write()) // insert dummy data into tables - insertTestRoundRows(t, db) - insertTestIntentRows(t, db) + insertTestRoundRows(t, db.Write()) + insertTestIntentRows(t, db.Write()) // add new txid field to intent table - modifyIntentTable(t, db) + modifyIntentTable(t, db.Write()) // run the backfill to populate intent rows with derived txids - err = sqlitedb.BackfillIntentTxid(ctx, db) + err = sqlitedb.BackfillIntentTxid(ctx, db.Write()) require.NoError(t, err) // check the intent table has the new txid column var hasID int - err = db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('intent') WHERE name = 'txid'`). + err = db.Read(). + QueryRow(`SELECT COUNT(*) FROM pragma_table_info('intent') WHERE name = 'txid'`). Scan(&hasID) require.NoError(t, err) require.Equal(t, 1, hasID) @@ -48,7 +50,7 @@ func TestIntentTxidMigration(t *testing.T) { Txid string Proof string } - rows, err := db.Query(` + rows, err := db.Read().Query(` SELECT id, txid, proof FROM intent; `) require.NoError(t, err) diff --git a/internal/infrastructure/db/sqlite/offchain_tx_repo.go b/internal/infrastructure/db/sqlite/offchain_tx_repo.go index c9df73897..2ef88db25 100644 --- a/internal/infrastructure/db/sqlite/offchain_tx_repo.go +++ b/internal/infrastructure/db/sqlite/offchain_tx_repo.go @@ -10,22 +10,20 @@ import ( ) type offchainTxRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewOffchainTxRepository(config ...interface{}) (domain.OffchainTxRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf("cannot open offchain tx repository: invalid config") } return &offchainTxRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -66,14 +64,18 @@ func (v *offchainTxRepository) AddOrUpdateOffchainTx( } return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *offchainTxRepository) GetOffchainTx( ctx context.Context, txid string, ) (*domain.OffchainTx, error) { - rows, err := v.querier.SelectOffchainTx(ctx, txid) - if err != nil { + var rows []queries.SelectOffchainTxRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectOffchainTx(ctx, txid) + return err + }); err != nil { return nil, err } if len(rows) == 0 { diff --git a/internal/infrastructure/db/sqlite/round_repo.go b/internal/infrastructure/db/sqlite/round_repo.go index ba4d4efac..ec800c357 100644 --- a/internal/infrastructure/db/sqlite/round_repo.go +++ b/internal/infrastructure/db/sqlite/round_repo.go @@ -14,22 +14,20 @@ import ( ) type roundRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewRoundRepository(config ...interface{}) (domain.RoundRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf("cannot open round repository: invalid config, expected db at 0") } return &roundRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -41,23 +39,23 @@ func (r *roundRepository) GetRoundIds( ctx context.Context, startedAfter, startedBefore int64, withFailed, withCompleted bool, ) ([]string, error) { var roundIDs []string - if startedAfter == 0 && startedBefore == 0 { - // Use filtering query when no time range is specified - ids, err := r.querier.SelectRoundIdsWithFilters( - ctx, - queries.SelectRoundIdsWithFiltersParams{ - WithFailed: withFailed, - WithCompleted: withCompleted, - }, - ) - if err != nil { - return nil, err + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + if startedAfter == 0 && startedBefore == 0 { + ids, err := q.SelectRoundIdsWithFilters( + ctx, + queries.SelectRoundIdsWithFiltersParams{ + WithFailed: withFailed, + WithCompleted: withCompleted, + }, + ) + if err != nil { + return err + } + roundIDs = ids + return nil } - roundIDs = ids - } else { - // Use time range filtering query - ids, err := r.querier.SelectRoundIdsInTimeRangeWithFilters( + ids, err := q.SelectRoundIdsInTimeRangeWithFilters( ctx, queries.SelectRoundIdsInTimeRangeWithFiltersParams{ StartTs: startedAfter, @@ -67,10 +65,14 @@ func (r *roundRepository) GetRoundIds( }, ) if err != nil { - return nil, err + return err } roundIDs = ids + return nil + }) + if err != nil { + return nil, err } return roundIDs, nil @@ -211,108 +213,126 @@ func (r *roundRepository) AddOrUpdateRound(ctx context.Context, round domain.Rou return nil } - return execTx(ctx, r.db, txBody) + return execTx(ctx, r.db.Write(), txBody) } func (r *roundRepository) GetRoundWithId(ctx context.Context, id string) (*domain.Round, error) { - rows, err := r.querier.SelectRoundWithId(ctx, id) - if err != nil { - return nil, err - } + var round *domain.Round + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + // Keep these related reads on the same pinned connection so cancellation + // and connection discard semantics apply consistently across the full load. + rows, err := q.SelectRoundWithId(ctx, id) + if err != nil { + return err + } - rvs := make([]combinedRow, 0, len(rows)) - for _, row := range rows { - rvs = append(rvs, combinedRow{ - round: row.Round, - intent: row.RoundIntentsVw, - tx: row.RoundTxsVw, - }) - } + rvs := make([]combinedRow, 0, len(rows)) + for _, row := range rows { + rvs = append( + rvs, + combinedRow{round: row.Round, intent: row.RoundIntentsVw, tx: row.RoundTxsVw}, + ) + } - rounds, err := rowsToRounds(rvs) - if err != nil { - return nil, err - } + rounds, err := rowsToRounds(rvs) + if err != nil { + return err + } + if len(rounds) == 0 { + return errors.New("batch not found") + } - if len(rounds) == 0 { - return nil, errors.New("batch not found") - } + round = rounds[0] + roundID := sql.NullString{String: round.Id, Valid: true} - round := rounds[0] - roundID := sql.NullString{String: round.Id, Valid: true} + receivers, err := q.SelectIntentReceiversByRoundId(ctx, roundID) + if err != nil { + return err + } + for _, row := range receivers { + applyReceiverToRound(round, row.IntentWithReceiversVw) + } - receivers, err := r.querier.SelectIntentReceiversByRoundId(ctx, roundID) - if err != nil { - return nil, err - } - for _, row := range receivers { - applyReceiverToRound(round, row.IntentWithReceiversVw) - } + vtxoInputs, err := q.SelectVtxoInputsByRoundId(ctx, roundID) + if err != nil { + return err + } + for _, row := range vtxoInputs { + applyVtxoInputToRound(round, row.IntentWithInputsVw) + } - vtxoInputs, err := r.querier.SelectVtxoInputsByRoundId(ctx, roundID) + return nil + }) if err != nil { return nil, err } - for _, row := range vtxoInputs { - applyVtxoInputToRound(round, row.IntentWithInputsVw) - } - return round, nil } func (r *roundRepository) GetRoundWithCommitmentTxid( ctx context.Context, txid string, ) (*domain.Round, error) { - rows, err := r.querier.SelectRoundWithTxid(ctx, txid) - if err != nil { - return nil, err - } + var round *domain.Round + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + // Keep these related reads on the same pinned connection so cancellation + // and connection discard semantics apply consistently across the full load. + rows, err := q.SelectRoundWithTxid(ctx, txid) + if err != nil { + return err + } - rvs := make([]combinedRow, 0, len(rows)) - for _, row := range rows { - rvs = append(rvs, combinedRow{ - round: row.Round, - intent: row.RoundIntentsVw, - tx: row.RoundTxsVw, - }) - } + rvs := make([]combinedRow, 0, len(rows)) + for _, row := range rows { + rvs = append( + rvs, + combinedRow{round: row.Round, intent: row.RoundIntentsVw, tx: row.RoundTxsVw}, + ) + } - rounds, err := rowsToRounds(rvs) - if err != nil { - return nil, err - } + rounds, err := rowsToRounds(rvs) + if err != nil { + return err + } + if len(rounds) == 0 { + return errors.New("batch not found") + } - if len(rounds) == 0 { - return nil, errors.New("batch not found") - } + round = rounds[0] + roundID := sql.NullString{String: round.Id, Valid: true} - round := rounds[0] - roundID := sql.NullString{String: round.Id, Valid: true} + receivers, err := q.SelectIntentReceiversByRoundId(ctx, roundID) + if err != nil { + return err + } + for _, row := range receivers { + applyReceiverToRound(round, row.IntentWithReceiversVw) + } - receivers, err := r.querier.SelectIntentReceiversByRoundId(ctx, roundID) - if err != nil { - return nil, err - } - for _, row := range receivers { - applyReceiverToRound(round, row.IntentWithReceiversVw) - } + vtxoInputs, err := q.SelectVtxoInputsByRoundId(ctx, roundID) + if err != nil { + return err + } + for _, row := range vtxoInputs { + applyVtxoInputToRound(round, row.IntentWithInputsVw) + } - vtxoInputs, err := r.querier.SelectVtxoInputsByRoundId(ctx, roundID) + return nil + }) if err != nil { return nil, err } - for _, row := range vtxoInputs { - applyVtxoInputToRound(round, row.IntentWithInputsVw) - } - return round, nil } func (r *roundRepository) GetRoundStats( ctx context.Context, id string, ) (*domain.RoundStats, error) { - rs, err := r.querier.SelectRoundStats(ctx, id) - if err != nil { + var rs queries.SelectRoundStatsRow + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rs, err = q.SelectRoundStats(ctx, id) + return err + }); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -372,14 +392,24 @@ func (r *roundRepository) GetRoundStats( } func (r *roundRepository) GetSweepableRounds(ctx context.Context) ([]string, error) { - return r.querier.SelectSweepableRounds(ctx) + var rounds []string + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rounds, err = q.SelectSweepableRounds(ctx) + return err + }) + return rounds, err } func (r *roundRepository) GetRoundForfeitTxs( ctx context.Context, commitmentTxid string, ) ([]domain.ForfeitTx, error) { - rows, err := r.querier.SelectRoundForfeitTxs(ctx, commitmentTxid) - if err != nil { + var rows []queries.Tx + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectRoundForfeitTxs(ctx, commitmentTxid) + return err + }); err != nil { return nil, err } @@ -397,8 +427,12 @@ func (r *roundRepository) GetRoundForfeitTxs( func (r *roundRepository) GetSweepTxs( ctx context.Context, commitmentTxid string, ) (map[string]string, error) { - rows, err := r.querier.SelectRoundSweepTxs(ctx, commitmentTxid) - if err != nil { + var rows []queries.SelectRoundSweepTxsRow + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectRoundSweepTxs(ctx, commitmentTxid) + return err + }); err != nil { return nil, err } @@ -413,8 +447,12 @@ func (r *roundRepository) GetSweepTxs( func (r *roundRepository) GetRoundConnectorTree( ctx context.Context, commitmentTxid string, ) (tree.FlatTxTree, error) { - rows, err := r.querier.SelectRoundConnectors(ctx, commitmentTxid) - if err != nil { + var rows []queries.Tx + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectRoundConnectors(ctx, commitmentTxid) + return err + }); err != nil { return nil, err } @@ -440,14 +478,24 @@ func (r *roundRepository) GetRoundConnectorTree( } func (r *roundRepository) GetSweptRoundsConnectorAddress(ctx context.Context) ([]string, error) { - return r.querier.SelectSweptRoundsConnectorAddress(ctx) + var addresses []string + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + addresses, err = q.SelectSweptRoundsConnectorAddress(ctx) + return err + }) + return addresses, err } func (r *roundRepository) GetRoundVtxoTree( ctx context.Context, txid string, ) (tree.FlatTxTree, error) { - rows, err := r.querier.SelectRoundVtxoTree(ctx, txid) - if err != nil { + var rows []queries.Tx + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectRoundVtxoTree(ctx, txid) + return err + }); err != nil { return nil, err } @@ -472,12 +520,12 @@ func (r *roundRepository) GetRoundVtxoTree( } func (r *roundRepository) GetTxsWithTxids(ctx context.Context, txids []string) ([]string, error) { - rows, err := r.querier.SelectTxs(ctx, queries.SelectTxsParams{ - Ids1: txids, - Ids2: txids, - Ids3: txids, - }) - if err != nil { + var rows []queries.SelectTxsRow + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectTxs(ctx, queries.SelectTxsParams{Ids1: txids, Ids2: txids, Ids3: txids}) + return err + }); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -495,13 +543,17 @@ func (r *roundRepository) GetTxsWithTxids(ctx context.Context, txids []string) ( func (r *roundRepository) GetRoundsWithCommitmentTxids( ctx context.Context, txids []string, ) (map[string]any, error) { - txids, err := r.querier.SelectRoundsWithTxids(ctx, txids) - if err != nil { + var roundTxids []string + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + roundTxids, err = q.SelectRoundsWithTxids(ctx, txids) + return err + }); err != nil { return nil, err } resp := make(map[string]any) - for _, txid := range txids { + for _, txid := range roundTxids { resp[txid] = nil } return resp, nil @@ -511,8 +563,12 @@ func (r *roundRepository) GetIntentByTxid( ctx context.Context, txid string, ) (*domain.Intent, error) { - intent, err := r.querier.SelectIntentByTxid(ctx, sql.NullString{String: txid, Valid: true}) - if err != nil { + var intent queries.SelectIntentByTxidRow + if err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + intent, err = q.SelectIntentByTxid(ctx, sql.NullString{String: txid, Valid: true}) + return err + }); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/internal/infrastructure/db/sqlite/scheduled_session_repo.go b/internal/infrastructure/db/sqlite/scheduled_session_repo.go index d371de1d3..d4d0bdbae 100644 --- a/internal/infrastructure/db/sqlite/scheduled_session_repo.go +++ b/internal/infrastructure/db/sqlite/scheduled_session_repo.go @@ -12,29 +12,32 @@ import ( ) type scheduledSessionRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewScheduledSessionRepository(config ...interface{}) (domain.ScheduledSessionRepo, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config: expected 1 argument, got %d", len(config)) } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf( - "cannot open scheduled session repository: expected *sql.DB but got %T", config[0], + "cannot open scheduled session repository: expected SQLiteDB but got %T", config[0], ) } return &scheduledSessionRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } func (r *scheduledSessionRepository) Get(ctx context.Context) (*domain.ScheduledSession, error) { - scheduledSession, err := r.querier.SelectLatestScheduledSession(ctx) + var scheduledSession queries.ScheduledSession + err := withReadQuerier(ctx, r.db, func(q *queries.Queries) error { + var err error + scheduledSession, err = q.SelectLatestScheduledSession(ctx) + return err + }) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -56,19 +59,23 @@ func (r *scheduledSessionRepository) Get(ctx context.Context) (*domain.Scheduled func (r *scheduledSessionRepository) Upsert( ctx context.Context, scheduledSession domain.ScheduledSession, ) error { - return r.querier.UpsertScheduledSession(ctx, queries.UpsertScheduledSessionParams{ - StartTime: scheduledSession.StartTime.Unix(), - EndTime: scheduledSession.EndTime.Unix(), - Period: int64(scheduledSession.Period), - Duration: int64(scheduledSession.Duration), - RoundMinParticipants: scheduledSession.RoundMinParticipantsCount, - RoundMaxParticipants: scheduledSession.RoundMaxParticipantsCount, - UpdatedAt: scheduledSession.UpdatedAt.Unix(), + return withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.UpsertScheduledSession(ctx, queries.UpsertScheduledSessionParams{ + StartTime: scheduledSession.StartTime.Unix(), + EndTime: scheduledSession.EndTime.Unix(), + Period: int64(scheduledSession.Period), + Duration: int64(scheduledSession.Duration), + RoundMinParticipants: scheduledSession.RoundMinParticipantsCount, + RoundMaxParticipants: scheduledSession.RoundMaxParticipantsCount, + UpdatedAt: scheduledSession.UpdatedAt.Unix(), + }) }) } func (r *scheduledSessionRepository) Clear(ctx context.Context) error { - return r.querier.ClearScheduledSession(ctx) + return withWriteQuerier(ctx, r.db, func(q *queries.Queries) error { + return q.ClearScheduledSession(ctx) + }) } func (r *scheduledSessionRepository) Close() { diff --git a/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go b/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go index 0bb137d25..38ac6f87e 100644 --- a/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go +++ b/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go @@ -352,6 +352,72 @@ func (q *Queries) SelectAssetsByIds(ctx context.Context, ids []string) ([]Asset, return items, nil } +const selectAssetsWithUnspentAmountsByIds = `-- name: SelectAssetsWithUnspentAmountsByIds :many +SELECT + a.id, + a.is_immutable, + a.metadata_hash, + a.metadata, + a.control_asset_id, + COALESCE(v.asset_amount, '0') AS asset_amount +FROM asset a +LEFT JOIN vtxo_vw v + ON v.asset_id = a.id + AND v.spent = false + AND v.asset_amount > 0 +WHERE a.id IN (/*SLICE:ids*/?) +ORDER BY a.id +` + +type SelectAssetsWithUnspentAmountsByIdsRow struct { + ID string + IsImmutable bool + MetadataHash sql.NullString + Metadata sql.NullString + ControlAssetID sql.NullString + AssetAmount string +} + +func (q *Queries) SelectAssetsWithUnspentAmountsByIds(ctx context.Context, ids []string) ([]SelectAssetsWithUnspentAmountsByIdsRow, error) { + query := selectAssetsWithUnspentAmountsByIds + var queryParams []interface{} + if len(ids) > 0 { + for _, v := range ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", strings.Repeat(",?", len(ids))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SelectAssetsWithUnspentAmountsByIdsRow + for rows.Next() { + var i SelectAssetsWithUnspentAmountsByIdsRow + if err := rows.Scan( + &i.ID, + &i.IsImmutable, + &i.MetadataHash, + &i.Metadata, + &i.ControlAssetID, + &i.AssetAmount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectControlAssetByID = `-- name: SelectControlAssetByID :one SELECT control_asset_id FROM asset WHERE id = ? ` diff --git a/internal/infrastructure/db/sqlite/sqlc/query.sql b/internal/infrastructure/db/sqlite/sqlc/query.sql index f4e2e9fe1..267efcf03 100644 --- a/internal/infrastructure/db/sqlite/sqlc/query.sql +++ b/internal/infrastructure/db/sqlite/sqlc/query.sql @@ -446,6 +446,22 @@ VALUES (@asset_id, @txid, @vout, @amount); -- name: SelectAssetsByIds :many SELECT * FROM asset WHERE asset.id IN (sqlc.slice('ids')); +-- name: SelectAssetsWithUnspentAmountsByIds :many +SELECT + a.id, + a.is_immutable, + a.metadata_hash, + a.metadata, + a.control_asset_id, + COALESCE(v.asset_amount, '0') AS asset_amount +FROM asset a +LEFT JOIN vtxo_vw v + ON v.asset_id = a.id + AND v.spent = false + AND v.asset_amount > 0 +WHERE a.id IN (sqlc.slice('ids')) +ORDER BY a.id; + -- name: SelectAssetAmounts :many SELECT v.asset_amount FROM vtxo_vw v WHERE v.asset_id = ? AND v.spent = false AND v.asset_amount > 0; @@ -454,4 +470,4 @@ WHERE v.asset_id = ? AND v.spent = false AND v.asset_amount > 0; SELECT control_asset_id FROM asset WHERE id = ?; -- name: SelectAssetExists :one -SELECT 1 FROM asset WHERE id = ? LIMIT 1; \ No newline at end of file +SELECT 1 FROM asset WHERE id = ? LIMIT 1; diff --git a/internal/infrastructure/db/sqlite/utils.go b/internal/infrastructure/db/sqlite/utils.go index 075ce0594..a08278dd3 100644 --- a/internal/infrastructure/db/sqlite/utils.go +++ b/internal/infrastructure/db/sqlite/utils.go @@ -3,14 +3,19 @@ package sqlitedb import ( "context" "database/sql" + "database/sql/driver" + "errors" "fmt" "os" "path/filepath" + "runtime" "strings" "time" "github.com/arkade-os/arkd/internal/infrastructure/db/sqlite/sqlc/queries" - _ "modernc.org/sqlite" + log "github.com/sirupsen/logrus" + sqlite "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" ) const ( @@ -18,7 +23,68 @@ const ( maxRetries = 5 ) -func OpenDb(dbPath string) (*sql.DB, error) { +type SQLiteDB interface { + Read() *sql.DB + Write() *sql.DB + Close() error +} + +type sqliteDB struct { + readDB *sql.DB + writeDB *sql.DB +} + +type openOptions struct { + sharedCache bool + journalModeWAL bool + busyTimeout *time.Duration +} + +// OpenOption configures how OpenDb builds the SQLite DSN. +type OpenOption func(*openOptions) + +// WithSharedCache enables SQLite shared-cache mode. +func WithSharedCache() OpenOption { + return func(opts *openOptions) { + opts.sharedCache = true + } +} + +// WithJournalModeWAL enables WAL journaling mode. +func WithJournalModeWAL() OpenOption { + return func(opts *openOptions) { + opts.journalModeWAL = true + } +} + +// WithBusyTimeout sets SQLite's busy timeout pragma. +func WithBusyTimeout(d time.Duration) OpenOption { + return func(opts *openOptions) { + opts.busyTimeout = &d + } +} + +func (s *sqliteDB) Read() *sql.DB { + return s.readDB +} + +func (s *sqliteDB) Write() *sql.DB { + return s.writeDB +} + +func (s *sqliteDB) Close() error { + readErr := s.readDB.Close() + writeErr := s.writeDB.Close() + return errors.Join(readErr, writeErr) +} + +// OpenDb returns a split SQLite handle with separate read and write pools. +func OpenDb(dbPath string, opts ...OpenOption) (SQLiteDB, error) { + openOpts := openOptions{} + for _, opt := range opts { + opt(&openOpts) + } + dir := filepath.Dir(dbPath) if _, err := os.Stat(dir); os.IsNotExist(err) { err = os.MkdirAll(dir, 0755) @@ -27,14 +93,59 @@ func OpenDb(dbPath string) (*sql.DB, error) { } } - db, err := sql.Open(driverName, dbPath) + dsn := buildDSN(dbPath, openOpts) + + readDB, err := sql.Open(driverName, dsn) + if err != nil { + return nil, fmt.Errorf("failed to open read db: %w", err) + } + readDB.SetMaxOpenConns(runtime.NumCPU()) + + // single connection writer pool + writeDB, err := sql.Open(driverName, dsn) if err != nil { - return nil, fmt.Errorf("failed to open db: %w", err) + _ = readDB.Close() + return nil, fmt.Errorf("failed to open write db: %w", err) + } + writeDB.SetMaxOpenConns(1) + + // Check there are no errors when opening a connection + if err := writeDB.Ping(); err != nil { + _ = readDB.Close() + _ = writeDB.Close() + return nil, fmt.Errorf("failed to ping write db: %w", err) } - db.SetMaxOpenConns(1) // prevent concurrent writes + if err := readDB.Ping(); err != nil { + _ = readDB.Close() + _ = writeDB.Close() + return nil, fmt.Errorf("failed to ping read db: %w", err) + } - return db, nil + return &sqliteDB{ + readDB: readDB, + writeDB: writeDB, + }, nil +} + +func buildDSN(dbPath string, opts openOptions) string { + params := make([]string, 0, 3) + if opts.sharedCache { + params = append(params, "cache=shared") + } + if opts.journalModeWAL { + params = append(params, "_pragma=journal_mode(WAL)") + } + if opts.busyTimeout != nil { + params = append( + params, + fmt.Sprintf("_pragma=busy_timeout(%d)", opts.busyTimeout.Milliseconds()), + ) + } + if len(params) == 0 { + return dbPath + } + return dbPath + "?" + strings.Join(params, "&") } func extendArray[T any](arr []T, position int) []T { @@ -49,21 +160,36 @@ func extendArray[T any](arr []T, position int) []T { return arr } +// execTx runs txBody on a pinned write connection. +// +// The connection is kept for the full transaction so SQLite interrupt/cancel +// state stays scoped to that connection. Conflict-like errors are retried, and +// interrupted connections are discarded instead of being returned to the pool. func execTx( ctx context.Context, db *sql.DB, txBody func(*queries.Queries) error, ) error { var lastErr error for range maxRetries { - tx, err := db.BeginTx(ctx, nil) + conn, err := db.Conn(ctx) + if err != nil { + return fmt.Errorf("failed to acquire connection: %w", err) + } + + tx, err := conn.BeginTx(ctx, nil) if err != nil { + _ = closeConn(conn, isInterruptError(ctx, err)) return fmt.Errorf("failed to begin transaction: %w", err) } - qtx := queries.New(db).WithTx(tx) + qtx := queries.New(conn).WithTx(tx) if err := txBody(qtx); err != nil { //nolint:all tx.Rollback() + if closeErr := closeConn(conn, isInterruptError(ctx, err)); closeErr != nil { + return fmt.Errorf("%w: %w", err, closeErr) + } + if isConflictError(err) { lastErr = err time.Sleep(100 * time.Millisecond) @@ -74,6 +200,9 @@ func execTx( // Commit the transaction if err := tx.Commit(); err != nil { + if closeErr := closeConn(conn, isInterruptError(ctx, err)); closeErr != nil { + return fmt.Errorf("failed to commit transaction: %w: %w", err, closeErr) + } if isConflictError(err) { lastErr = err time.Sleep(100 * time.Millisecond) @@ -81,12 +210,118 @@ func execTx( } return fmt.Errorf("failed to commit transaction: %w", err) } + + if err := closeConn(conn, false); err != nil { + log.WithError(err).Warn("failed to close connection after successful commit") + } return nil } return lastErr } +// withReadQuerier runs fn on a pinned read connection. +// +// If the read is canceled or interrupted, the connection is discarded so later +// callers do not reuse a tainted SQLite connection from the pool. +func withReadQuerier( + ctx context.Context, db SQLiteDB, fn func(*queries.Queries) error, +) error { + conn, err := db.Read().Conn(ctx) + if err != nil { + return fmt.Errorf("failed to acquire connection: %w", err) + } + + err = fn(queries.New(conn)) + if closeErr := closeConn(conn, isInterruptError(ctx, err)); closeErr != nil { + if err != nil { + return fmt.Errorf("%w: %w", err, closeErr) + } + return closeErr + } + + return err +} + +// withWriteQuerier runs fn on a pinned write connection. +// +// Even non-transactional writes go through an explicit connection so interrupt +// handling can discard the connection when SQLite reports it as tainted. +func withWriteQuerier( + ctx context.Context, db SQLiteDB, fn func(*queries.Queries) error, +) error { + conn, err := db.Write().Conn(ctx) + if err != nil { + return fmt.Errorf("failed to acquire connection: %w", err) + } + + err = fn(queries.New(conn)) + if closeErr := closeConn(conn, isInterruptError(ctx, err)); closeErr != nil { + if err != nil { + return fmt.Errorf("%w: %w", err, closeErr) + } + return closeErr + } + + return err +} + +// closeConn closes conn and optionally forces the sql pool to discard it. +// +// When discard is true we surface driver.ErrBadConn through Raw so database/sql +// treats the underlying SQLite connection as unusable and does not recycle it. +func closeConn(conn *sql.Conn, discard bool) error { + if conn == nil { + return nil + } + + if discard { + err := conn.Raw(func(any) error { + return driver.ErrBadConn + }) + if errors.Is(err, driver.ErrBadConn) { + err = nil + } + if err != nil { + _ = conn.Close() + return fmt.Errorf("failed to discard tainted connection: %w", err) + } + } + + if err := conn.Close(); err != nil { + if discard && errors.Is(err, sql.ErrConnDone) { + return nil + } + return fmt.Errorf("failed to close connection: %w", err) + } + + return nil +} + +// isInterruptError reports whether err is a context cancellation or a SQLite +// interrupt. +func isInterruptError(ctx context.Context, err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + if ctx != nil && ctx.Err() != nil && errors.Is(err, ctx.Err()) { + return true + } + + var sqliteErr *sqlite.Error + if errors.As(err, &sqliteErr) { + code := sqliteErr.Code() + return code == sqlite3.SQLITE_INTERRUPT || code&0xff == sqlite3.SQLITE_INTERRUPT + } + + return false +} + func isConflictError(err error) bool { if err == nil { return false diff --git a/internal/infrastructure/db/sqlite/vtxo_repo.go b/internal/infrastructure/db/sqlite/vtxo_repo.go index fd86e1fc5..ef17c7a6b 100644 --- a/internal/infrastructure/db/sqlite/vtxo_repo.go +++ b/internal/infrastructure/db/sqlite/vtxo_repo.go @@ -14,22 +14,20 @@ import ( ) type vtxoRepository struct { - db *sql.DB - querier *queries.Queries + db SQLiteDB } func NewVtxoRepository(config ...interface{}) (domain.VtxoRepository, error) { if len(config) != 1 { return nil, fmt.Errorf("invalid config") } - db, ok := config[0].(*sql.DB) + db, ok := config[0].(SQLiteDB) if !ok { return nil, fmt.Errorf("cannot open vtxo repository: invalid config") } return &vtxoRepository{ - db: db, - querier: queries.New(db), + db: db, }, nil } @@ -100,14 +98,18 @@ func (v *vtxoRepository) AddVtxos(ctx context.Context, vtxos []domain.Vtxo) erro return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *vtxoRepository) GetAllSweepableUnrolledVtxos( ctx context.Context, ) ([]domain.Vtxo, error) { - res, err := v.querier.SelectSweepableUnrolledVtxos(ctx) - if err != nil { + var res []queries.SelectSweepableUnrolledVtxosRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectSweepableUnrolledVtxos(ctx) + return err + }); err != nil { return nil, err } @@ -124,24 +126,30 @@ func (v *vtxoRepository) GetAllNonUnrolledVtxos( withPubkey := len(pubkey) > 0 var rows []queries.VtxoVw - if withPubkey { - res, err := v.querier.SelectNotUnrolledVtxosWithPubkey(ctx, pubkey) - if err != nil { - return nil, nil, err - } - rows = make([]queries.VtxoVw, 0, len(res)) - for _, row := range res { - rows = append(rows, row.VtxoVw) + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + if withPubkey { + res, err := q.SelectNotUnrolledVtxosWithPubkey(ctx, pubkey) + if err != nil { + return err + } + rows = make([]queries.VtxoVw, 0, len(res)) + for _, row := range res { + rows = append(rows, row.VtxoVw) + } + return nil } - } else { - res, err := v.querier.SelectNotUnrolledVtxos(ctx) + + res, err := q.SelectNotUnrolledVtxos(ctx) if err != nil { - return nil, nil, err + return err } rows = make([]queries.VtxoVw, 0, len(res)) for _, row := range res { rows = append(rows, row.VtxoVw) } + return nil + }); err != nil { + return nil, nil, err } vtxos, err := readRows(rows) @@ -167,45 +175,52 @@ func (v *vtxoRepository) GetVtxos( ctx context.Context, outpoints []domain.Outpoint, ) ([]domain.Vtxo, error) { vtxos := make([]domain.Vtxo, 0, len(outpoints)) - for _, o := range outpoints { - res, err := v.querier.SelectVtxo( - ctx, - queries.SelectVtxoParams{ - Txid: o.Txid, - Vout: int64(o.VOut), - }, - ) - if err != nil { - return nil, err - } + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + for _, o := range outpoints { + res, err := q.SelectVtxo( + ctx, + queries.SelectVtxoParams{Txid: o.Txid, Vout: int64(o.VOut)}, + ) + if err != nil { + return err + } - if len(res) == 0 { - continue - } + if len(res) == 0 { + continue + } - rows := make([]queries.VtxoVw, 0, len(res)) - for _, row := range res { - rows = append(rows, row.VtxoVw) - } + rows := make([]queries.VtxoVw, 0, len(res)) + for _, row := range res { + rows = append(rows, row.VtxoVw) + } - result, err := readRows(rows) - if err != nil { - return nil, err - } + result, err := readRows(rows) + if err != nil { + return err + } - if len(result) == 0 { - continue + if len(result) == 0 { + continue + } + + vtxos = append(vtxos, result[0]) } - vtxos = append(vtxos, result[0]) + return nil + }); err != nil { + return nil, err } return vtxos, nil } func (v *vtxoRepository) GetAllVtxos(ctx context.Context) ([]domain.Vtxo, error) { - res, err := v.querier.SelectAllVtxos(ctx) - if err != nil { + var res []queries.SelectAllVtxosRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectAllVtxos(ctx) + return err + }); err != nil { return nil, err } rows := make([]queries.VtxoVw, 0, len(res)) @@ -219,14 +234,15 @@ func (v *vtxoRepository) GetAllVtxos(ctx context.Context) ([]domain.Vtxo, error) func (v *vtxoRepository) GetExpiringLiquidity( ctx context.Context, after, before int64, ) (uint64, error) { - amount, err := v.querier.SelectExpiringLiquidityAmount( - ctx, - queries.SelectExpiringLiquidityAmountParams{ - After: after, - Before: before, - }, - ) - if err != nil { + var amount interface{} + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + amount, err = q.SelectExpiringLiquidityAmount( + ctx, + queries.SelectExpiringLiquidityAmountParams{After: after, Before: before}, + ) + return err + }); err != nil { return 0, err } @@ -241,13 +257,17 @@ func (v *vtxoRepository) GetExpiringLiquidity( } func (v *vtxoRepository) GetRecoverableLiquidity(ctx context.Context) (uint64, error) { - amount, err := v.querier.SelectRecoverableLiquidityAmount(ctx) - if err != nil { + var amount interface{} + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + amount, err = q.SelectRecoverableLiquidityAmount(ctx) + return err + }); err != nil { return 0, err } n, ok := amount.(int64) if !ok { - return 0, nil + return 0, fmt.Errorf("unexpected type for recoverable liquidity: %T", amount) } if n < 0 { return 0, fmt.Errorf("data integrity issue: got negative value %d", n) @@ -258,8 +278,12 @@ func (v *vtxoRepository) GetRecoverableLiquidity(ctx context.Context) (uint64, e func (v *vtxoRepository) GetLeafVtxosForBatch( ctx context.Context, txid string, ) ([]domain.Vtxo, error) { - res, err := v.querier.SelectRoundVtxoTreeLeaves(ctx, txid) - if err != nil { + var res []queries.SelectRoundVtxoTreeLeavesRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectRoundVtxoTreeLeaves(ctx, txid) + return err + }); err != nil { return nil, err } rows := make([]queries.VtxoVw, 0, len(res)) @@ -283,7 +307,7 @@ func (v *vtxoRepository) UnrollVtxos(ctx context.Context, vtxos []domain.Outpoin return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *vtxoRepository) SettleVtxos( @@ -307,7 +331,7 @@ func (v *vtxoRepository) SettleVtxos( return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *vtxoRepository) SpendVtxos( @@ -331,7 +355,7 @@ func (v *vtxoRepository) SpendVtxos( return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *vtxoRepository) SweepVtxos(ctx context.Context, vtxos []domain.Outpoint) (int, error) { @@ -356,7 +380,7 @@ func (v *vtxoRepository) SweepVtxos(ctx context.Context, vtxos []domain.Outpoint return nil } - if err := execTx(ctx, v.db, txBody); err != nil { + if err := execTx(ctx, v.db.Write(), txBody); err != nil { return -1, err } @@ -383,7 +407,7 @@ func (v *vtxoRepository) UpdateVtxosExpiration( return nil } - return execTx(ctx, v.db, txBody) + return execTx(ctx, v.db.Write(), txBody) } func (v *vtxoRepository) GetAllVtxosWithPubKeys( @@ -392,12 +416,16 @@ func (v *vtxoRepository) GetAllVtxosWithPubKeys( if err := validateTimeRange(after, before); err != nil { return nil, err } - res, err := v.querier.SelectVtxosWithPubkeys(ctx, queries.SelectVtxosWithPubkeysParams{ - Pubkeys: pubkeys, - After: sql.NullInt64{Int64: after, Valid: true}, - Before: before, - }) - if err != nil { + var res []queries.SelectVtxosWithPubkeysRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectVtxosWithPubkeys(ctx, queries.SelectVtxosWithPubkeysParams{ + Pubkeys: pubkeys, + After: sql.NullInt64{Int64: after, Valid: true}, + Before: before, + }) + return err + }); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -425,8 +453,12 @@ func (v *vtxoRepository) GetSweepableVtxosByCommitmentTxid( ) ( []domain.Outpoint, error, ) { - res, err := v.querier.SelectSweepableVtxoOutpointsByCommitmentTxid(ctx, commitmentTxid) - if err != nil { + var res []queries.SelectSweepableVtxoOutpointsByCommitmentTxidRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectSweepableVtxoOutpointsByCommitmentTxid(ctx, commitmentTxid) + return err + }); err != nil { return nil, err } @@ -444,8 +476,12 @@ func (v *vtxoRepository) GetSweepableVtxosByCommitmentTxid( func (v *vtxoRepository) GetAllChildrenVtxos( ctx context.Context, txid string, ) ([]domain.Outpoint, error) { - res, err := v.querier.SelectVtxosOutpointsByArkTxidRecursive(ctx, txid) - if err != nil { + var res []queries.SelectVtxosOutpointsByArkTxidRecursiveRow + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + res, err = q.SelectVtxosOutpointsByArkTxidRecursive(ctx, txid) + return err + }); err != nil { return nil, err } @@ -467,12 +503,16 @@ func (v *vtxoRepository) GetVtxoPubKeysByCommitmentTxid( return nil, nil } - taprootKeys, err := v.querier.SelectVtxoPubKeysByCommitmentTxid(ctx, - queries.SelectVtxoPubKeysByCommitmentTxidParams{ - MinAmount: int64(withMinimumAmount), - CommitmentTxid: commitmentTxid, - }) - if err != nil { + var taprootKeys []string + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + taprootKeys, err = q.SelectVtxoPubKeysByCommitmentTxid(ctx, + queries.SelectVtxoPubKeysByCommitmentTxidParams{ + MinAmount: int64(withMinimumAmount), + CommitmentTxid: commitmentTxid, + }) + return err + }); err != nil { return nil, err } @@ -485,15 +525,19 @@ func (v *vtxoRepository) GetPendingSpentVtxosWithPubKeys( if err := validateTimeRange(after, before); err != nil { return nil, err } - rows, err := v.querier.SelectPendingSpentVtxosWithPubkeys( - ctx, - queries.SelectPendingSpentVtxosWithPubkeysParams{ - Pubkeys: pubkeys, - After: sql.NullInt64{Int64: after, Valid: true}, - Before: before, - }, - ) - if err != nil { + var rows []queries.VtxoVw + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + var err error + rows, err = q.SelectPendingSpentVtxosWithPubkeys( + ctx, + queries.SelectPendingSpentVtxosWithPubkeysParams{ + Pubkeys: pubkeys, + After: sql.NullInt64{Int64: after, Valid: true}, + Before: before, + }, + ) + return err + }); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -515,27 +559,34 @@ func (v *vtxoRepository) GetPendingSpentVtxosWithOutpoints( ctx context.Context, outpoints []domain.Outpoint, ) ([]domain.Vtxo, error) { var vtxos []domain.Vtxo - for _, outpoint := range outpoints { - res, err := v.querier.SelectPendingSpentVtxo( - ctx, queries.SelectPendingSpentVtxoParams{ - Txid: outpoint.Txid, - Vout: int64(outpoint.VOut), - }, - ) - if err != nil { - return nil, err - } + if err := withReadQuerier(ctx, v.db, func(q *queries.Queries) error { + for _, outpoint := range outpoints { + res, err := q.SelectPendingSpentVtxo( + ctx, + queries.SelectPendingSpentVtxoParams{ + Txid: outpoint.Txid, + Vout: int64(outpoint.VOut), + }, + ) + if err != nil { + return err + } - if len(res) == 0 { - continue - } + if len(res) == 0 { + continue + } - result, err := readRows(res) - if err != nil { - return nil, err + result, err := readRows(res) + if err != nil { + return err + } + + vtxos = append(vtxos, result...) } - vtxos = append(vtxos, result...) + return nil + }); err != nil { + return nil, err } sort.SliceStable(vtxos, func(i, j int) bool { diff --git a/internal/test/e2e/e2e_test.go b/internal/test/e2e/e2e_test.go index 9fa5d6ae2..d68c9443c 100644 --- a/internal/test/e2e/e2e_test.go +++ b/internal/test/e2e/e2e_test.go @@ -43,6 +43,8 @@ import ( "github.com/btcsuite/btcwallet/waddrmgr" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( @@ -387,100 +389,100 @@ func TestUnilateralExit(t *testing.T) { // offchain. func TestUnrolledVtxoRejoinBatch(t *testing.T) { t.Run("without asset", func(t *testing.T) { - ctx := t.Context() - alice := setupArkSDK(t) + ctx := t.Context() + alice := setupArkSDK(t) - // Fund Alice offchain + small onchain amount for unroll fees - faucet(t, alice, 0.00021) - time.Sleep(5 * time.Second) + // Fund Alice offchain + small onchain amount for unroll fees + faucet(t, alice, 0.00021) + time.Sleep(5 * time.Second) - _, offchainAddr, _, err := alice.Receive(ctx) - require.NoError(t, err) + _, offchainAddr, _, err := alice.Receive(ctx) + require.NoError(t, err) - balance, err := alice.Balance(ctx) - require.NoError(t, err) - require.NotZero(t, balance.OffchainBalance.Total) - require.Empty(t, balance.OnchainBalance.LockedAmount) + balance, err := alice.Balance(ctx) + require.NoError(t, err) + require.NotZero(t, balance.OffchainBalance.Total) + require.Empty(t, balance.OnchainBalance.LockedAmount) - // Unroll: moves VTXOs onchain + // Unroll: moves VTXOs onchain txids, err := alice.Unroll(ctx) - require.NoError(t, err) + require.NoError(t, err) require.NotEmpty(t, txids) - err = generateBlocks(1) - require.NoError(t, err) + err = generateBlocks(1) + require.NoError(t, err) - // Poll for the wallet to index the new block instead of sleeping a fixed - // interval — every second spent here eats into the unrolled VTXO's CSV - // runway before the subsequent Settle call. - require.Eventually(t, func() bool { - b, err := alice.Balance(ctx) - if err != nil { - return false - } - return b.OffchainBalance.Total == 0 && - len(b.OnchainBalance.LockedAmount) > 0 && - b.OnchainBalance.LockedAmount[0].Amount > 0 - }, 15*time.Second, 200*time.Millisecond, "unroll did not settle onchain in time") + // Poll for the wallet to index the new block instead of sleeping a fixed + // interval — every second spent here eats into the unrolled VTXO's CSV + // runway before the subsequent Settle call. + require.Eventually(t, func() bool { + b, err := alice.Balance(ctx) + if err != nil { + return false + } + return b.OffchainBalance.Total == 0 && + len(b.OnchainBalance.LockedAmount) > 0 && + b.OnchainBalance.LockedAmount[0].Amount > 0 + }, 15*time.Second, 200*time.Millisecond, "unroll did not settle onchain in time") - balance, err = alice.Balance(ctx) - require.NoError(t, err) + balance, err = alice.Balance(ctx) + require.NoError(t, err) - // Find the unrolled VTXO in the spent list - _, spentVtxos, err := alice.ListVtxos(ctx) - require.NoError(t, err) + // Find the unrolled VTXO in the spent list + _, spentVtxos, err := alice.ListVtxos(ctx) + require.NoError(t, err) - var unrolledVtxo types.Vtxo - for _, v := range spentVtxos { - if v.Unrolled && !v.Spent { - unrolledVtxo = v - break + var unrolledVtxo types.Vtxo + for _, v := range spentVtxos { + if v.Unrolled && !v.Spent { + unrolledVtxo = v + break + } + } + require.NotZero(t, unrolledVtxo.Amount, "expected an unrolled VTXO") + + // Receive returns *types.Address which carries Tapscripts — use them + // to present the unrolled VTXO as a boarding input. + boardingUtxo := types.Utxo{ + Outpoint: unrolledVtxo.Outpoint, + Amount: unrolledVtxo.Amount, + Tapscripts: offchainAddr.Tapscripts, } - } - require.NotZero(t, unrolledVtxo.Amount, "expected an unrolled VTXO") - - // Receive returns *types.Address which carries Tapscripts — use them - // to present the unrolled VTXO as a boarding input. - boardingUtxo := types.Utxo{ - Outpoint: unrolledVtxo.Outpoint, - Amount: unrolledVtxo.Amount, - Tapscripts: offchainAddr.Tapscripts, - } - // Rejoin the batch — unrolled VTXO should be accepted as a boarding input - wg := &sync.WaitGroup{} - wg.Add(1) - var incomingErr error - go func() { - _, incomingErr = alice.NotifyIncomingFunds(ctx, offchainAddr.Address) - wg.Done() - }() + // Rejoin the batch — unrolled VTXO should be accepted as a boarding input + wg := &sync.WaitGroup{} + wg.Add(1) + var incomingErr error + go func() { + _, incomingErr = alice.NotifyIncomingFunds(ctx, offchainAddr.Address) + wg.Done() + }() - res, err := alice.Settle(ctx, - arksdk.WithFunds([]types.Utxo{boardingUtxo}, nil), - ) - require.NoError(t, err) - require.NotEmpty(t, res.CommitmentTxid) + res, err := alice.Settle(ctx, + arksdk.WithFunds([]types.Utxo{boardingUtxo}, nil), + ) + require.NoError(t, err) + require.NotEmpty(t, res.CommitmentTxid) - wg.Wait() - require.NoError(t, incomingErr) - time.Sleep(time.Second) + wg.Wait() + require.NoError(t, incomingErr) + time.Sleep(time.Second) - // Alice has offchain funds again - balance, err = alice.Balance(ctx) - require.NoError(t, err) - require.NotZero(t, balance.OffchainBalance.Total) + // Alice has offchain funds again + balance, err = alice.Balance(ctx) + require.NoError(t, err) + require.NotZero(t, balance.OffchainBalance.Total) - // Once the unrolled VTXO has been accepted into a batch, the onchain - // UTXO is spent. Mining past the unilateral exit delay and calling - // CompleteUnroll should find no mature funds to claim. - err = generateBlocks(20) - require.NoError(t, err) + // Once the unrolled VTXO has been accepted into a batch, the onchain + // UTXO is spent. Mining past the unilateral exit delay and calling + // CompleteUnroll should find no mature funds to claim. + err = generateBlocks(20) + require.NoError(t, err) - time.Sleep(5 * time.Second) + time.Sleep(5 * time.Second) - _, err = alice.CompleteUnroll(ctx, "") - require.ErrorContains(t, err, "no mature funds available") + _, err = alice.CompleteUnroll(ctx, "") + require.ErrorContains(t, err, "no mature funds available") }) } @@ -5421,6 +5423,341 @@ func TestAsset(t *testing.T) { }) } +func TestGetAssetQueryChurn(t *testing.T) { + ctx := t.Context() + + const supply = 200 + // join a batch after n offchain sends + const batchInterval = 10 + const assetQueryWorkers = 4 + + alice := setupArkSDK(t) + bob := setupArkSDK(t) + + faucetOffchain(t, alice, 0.002) + faucetOffchain(t, bob, 0.002) + + _, aliceOffchainAddr, _, err := alice.Receive(ctx) + require.NoError(t, err) + aliceOffchainAddrDecoded, err := arklib.DecodeAddressV0(aliceOffchainAddr.Address) + require.NoError(t, err) + aliceP2TR, err := script.P2TRScript(aliceOffchainAddrDecoded.VtxoTapKey) + require.NoError(t, err) + aliceP2TRStr := hex.EncodeToString(aliceP2TR) + + _, bobOffchainAddr, _, err := bob.Receive(ctx) + require.NoError(t, err) + bobOffchainAddrDecoded, err := arklib.DecodeAddressV0(bobOffchainAddr.Address) + require.NoError(t, err) + bobP2TR, err := script.P2TRScript(bobOffchainAddrDecoded.VtxoTapKey) + require.NoError(t, err) + bobP2TRStr := hex.EncodeToString(bobP2TR) + + _, aliceEvtCh, closeFn, err := alice.Indexer().NewSubscription(ctx, []string{aliceP2TRStr}) + require.NoError(t, err) + defer closeFn() + + _, bobEvtCh, closeFn, err := bob.Indexer().NewSubscription(ctx, []string{bobP2TRStr}) + require.NoError(t, err) + defer closeFn() + + recvVtxosTimeout := time.Second * 20 + + var aliceRecvErr, bobRecvErr error + + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + // expect 1 asset vtxo + change vtxo + _, aliceRecvErr = waitForVTXOs(aliceEvtCh, 2, recvVtxosTimeout) + wg.Done() + }() + go func() { + // expect 1 asset vtxo + change vtxo + _, bobRecvErr = waitForVTXOs(bobEvtCh, 2, recvVtxosTimeout) + wg.Done() + }() + + res, err := alice.IssueAsset(ctx, supply, nil, nil) + require.NoError(t, err) + require.NotNil(t, res) + require.NotEmpty(t, res.Txid) + require.Len(t, res.IssuedAssets, 1) + aliceAssetID := res.IssuedAssets[0].String() + + res, err = bob.IssueAsset(ctx, supply, nil, nil) + require.NoError(t, err) + require.NotNil(t, res) + require.NotEmpty(t, res.Txid) + require.Len(t, res.IssuedAssets, 1) + bobAssetID := res.IssuedAssets[0].String() + + wg.Wait() + + require.NoError(t, aliceRecvErr) + require.NoError(t, bobRecvErr) + + time.Sleep(2 * time.Second) + + stressCtx, cancelStress := context.WithCancel(ctx) + errCh := make(chan error, assetQueryWorkers) + var canceledAssetCalls atomic.Int64 + + assetTargets := []struct { + client arksdk.ArkClient + assetID string + }{ + {client: alice, assetID: aliceAssetID}, + {client: bob, assetID: bobAssetID}, + } + + assetQueryWG := &sync.WaitGroup{} + assetQueryWG.Add(assetQueryWorkers) + for i := range assetQueryWorkers { + // repeatedly issue and cancel GetAsset query requests + go func(workerID int) { + defer assetQueryWG.Done() + + // staggered start + time.Sleep(time.Duration(workerID) * time.Millisecond) + + target := assetTargets[workerID%len(assetTargets)] + for stressCtx.Err() == nil { + callCtx, cancel := context.WithTimeout(stressCtx, 50*time.Millisecond) + done := make(chan error, 1) + go func() { + _, getAssetErr := target.client.Indexer().GetAsset(callCtx, target.assetID) + done <- getAssetErr + }() + + time.Sleep(5 * time.Millisecond) + cancel() + + getAssetErr := <-done + if getAssetErr != nil { + if st, ok := status.FromError(getAssetErr); ok { + switch st.Code() { + case codes.Canceled, codes.DeadlineExceeded: + canceledAssetCalls.Add(1) + continue + case codes.Internal: + errMsg := strings.ToLower(st.Message()) + if strings.Contains(errMsg, "context") { + canceledAssetCalls.Add(1) + continue + } + } + } + + select { + case errCh <- fmt.Errorf("asset query worker %d: %w", workerID, getAssetErr): + default: + } + return + } + } + }(i) + } + defer func() { + cancelStress() + assetQueryWG.Wait() + }() + + var aliceSendErr, bobSendErr error + var aliceSendRes, bobSendRes *arksdk.SendOffChainRes + var aliceRecvd, bobRecvd []types.Vtxo + + for i := range supply { + completed := i + 1 + + sendWg := &sync.WaitGroup{} + sendWg.Add(2) + recvWg := &sync.WaitGroup{} + recvWg.Add(2) + + go func() { + // expect 1 asset from bob + change vtxo + aliceRecvd, aliceRecvErr = waitForVTXOs(aliceEvtCh, 2, recvVtxosTimeout) + recvWg.Done() + }() + go func() { + // expect 1 asset from alice + change vtxo + bobRecvd, bobRecvErr = waitForVTXOs(bobEvtCh, 2, recvVtxosTimeout) + recvWg.Done() + }() + go func() { + aliceSendRes, aliceSendErr = alice.SendOffChain(ctx, []types.Receiver{{ + To: bobOffchainAddr.Address, + Amount: 330, + Assets: []types.Asset{{ + AssetId: aliceAssetID, + Amount: 1, + }}, + }}) + sendWg.Done() + }() + go func() { + bobSendRes, bobSendErr = bob.SendOffChain(ctx, []types.Receiver{{ + To: aliceOffchainAddr.Address, + Amount: 330, + Assets: []types.Asset{{ + AssetId: bobAssetID, + Amount: 1, + }}, + }}) + sendWg.Done() + }() + + sendWg.Wait() + require.NoErrorf(t, aliceSendErr, "send %d/%d failed", completed, supply) + require.NoErrorf(t, bobSendErr, "send %d/%d failed", completed, supply) + + recvWg.Wait() + require.NoError(t, aliceRecvErr, "receiving vtxos for send %s %d/%d failed", + aliceSendRes.Txid, completed, supply) + require.NoError(t, bobRecvErr, "receiving vtxos for send %s %d/%d failed", + bobSendRes.Txid, completed, supply) + + outpoints := make([]types.Outpoint, 0) + spentVtxos := make([]types.Outpoint, 0) + unspentVtxos := make([]types.Outpoint, 0) + for _, input := range aliceSendRes.Inputs { + outpoints = append(outpoints, input.Outpoint) + spentVtxos = append(spentVtxos, input.Outpoint) + } + for _, input := range bobSendRes.Inputs { + outpoints = append(outpoints, input.Outpoint) + spentVtxos = append(spentVtxos, input.Outpoint) + } + for _, output := range aliceRecvd { + outpoints = append(outpoints, output.Outpoint) + unspentVtxos = append(unspentVtxos, output.Outpoint) + } + for _, output := range bobRecvd { + outpoints = append(outpoints, output.Outpoint) + unspentVtxos = append(unspentVtxos, output.Outpoint) + } + + dbVtxos := make(map[types.Outpoint]types.Vtxo) + vtxosInDBDeadline := time.Now().Add(10 * time.Second) + for time.Now().Before(vtxosInDBDeadline) { + res, err := alice.Indexer(). + GetVtxos(ctx, indexer.WithOutpoints(outpoints)) + require.NoError(t, err) + + if len(res.Vtxos) == len(outpoints) { + vtxos := res.Vtxos + for _, v := range vtxos { + dbVtxos[v.Outpoint] = v + } + break + } + + time.Sleep(100 * time.Millisecond) + } + + require.Len(t, dbVtxos, len(outpoints), "failed to find all sent/received vtxos in db") + + for _, spent := range spentVtxos { + require.Truef(t, dbVtxos[spent].Spent, "failed to update spent vtxo in db: %s", spent) + } + for _, unspent := range unspentVtxos { + require.Falsef(t, dbVtxos[unspent].Spent, "failed to add new vtxo in db: %s", unspent) + require.Truef(t, dbVtxos[unspent].Preconfirmed, + "failed to add new vtxo in db: %s", unspent) + } + + // start a batch after every batchInterval sends + if completed%batchInterval == 0 { + settleWg := &sync.WaitGroup{} + settleWg.Add(4) + + var aliceSettleErr, bobSettleErr error + var aliceSettleRes, bobSettleRes *arksdk.SettleRes + go func() { + // expect 1 new batch vtxo + _, aliceRecvErr = waitForVTXOs(aliceEvtCh, 1, recvVtxosTimeout) + settleWg.Done() + }() + go func() { + // expect 1 new batch vtxo + _, bobRecvErr = waitForVTXOs(bobEvtCh, 1, recvVtxosTimeout) + settleWg.Done() + }() + go func() { + aliceSettleRes, aliceSettleErr = alice.Settle(ctx) + settleWg.Done() + }() + go func() { + bobSettleRes, bobSettleErr = bob.Settle(ctx) + settleWg.Done() + }() + settleWg.Wait() + + require.NoError(t, aliceRecvErr) + require.NoError(t, bobRecvErr) + require.NoError(t, aliceSettleErr) + require.NoError(t, bobSettleErr) + + // ensure rounds were written to the DB + batchInDbDeadline := time.Now().Add(10 * time.Second) + outpoints := make([]types.Outpoint, 0) + for _, v := range aliceSettleRes.VtxoInputs { + outpoints = append(outpoints, v.Outpoint) + } + + var aliceCtx, bobCtx *indexer.CommitmentTx + var aliceGetCtxErr, bobGetCtxErr error + for time.Now().Before(batchInDbDeadline) { + aliceCtx, aliceGetCtxErr = alice.Indexer(). + GetCommitmentTx(ctx, aliceSettleRes.CommitmentTxid) + bobCtx, bobGetCtxErr = bob.Indexer(). + GetCommitmentTx(ctx, bobSettleRes.CommitmentTxid) + + dbVtxos, err := alice.Indexer().GetVtxos(ctx, indexer.WithOutpoints(outpoints)) + + require.NoError(t, err) + require.Len(t, dbVtxos.Vtxos, len(outpoints)) + + allSpent := true + for _, v := range dbVtxos.Vtxos { + allSpent = v.Spent + if !allSpent { + break + } + } + + if aliceGetCtxErr == nil && bobGetCtxErr == nil && allSpent { + break + } + time.Sleep(100 * time.Millisecond) + } + require.NoError(t, aliceGetCtxErr) + require.Len(t, aliceCtx.Batches, 1, "failed to update completed round in database") + require.NoError(t, bobGetCtxErr) + require.Len(t, bobCtx.Batches, 1, "failed to update completed round in database") + t.Logf("completed %d/%d offchain sends and batch %d/%d", + completed, supply, completed/batchInterval, supply/batchInterval) + } + } + + cancelStress() + assetQueryWG.Wait() + cancelledQueries := canceledAssetCalls.Load() + require.Greater(t, cancelledQueries, int64(0)) + + t.Logf("cancelled query count: %d", cancelledQueries) + + for { + select { + case runErr := <-errCh: + require.NoError(t, runErr) + default: + return + } + } +} + // TestTxListenerChurn verifies that the gRPC transaction stream fanout is // resilient to subscription churn. It runs three concurrent activities: // diff --git a/internal/test/e2e/utils_test.go b/internal/test/e2e/utils_test.go index 394ce07ef..cf5e58818 100644 --- a/internal/test/e2e/utils_test.go +++ b/internal/test/e2e/utils_test.go @@ -2,6 +2,7 @@ package e2e_test import ( "bytes" + "context" "encoding/hex" "encoding/json" "fmt" @@ -809,3 +810,35 @@ func isRetryableChurnError(err error) bool { return false } + +func waitForVTXOs( + ch <-chan indexer.ScriptEvent, + atLeastN int, + timeout time.Duration, +) ([]types.Vtxo, error) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + defer cancel() + vtxos := make([]types.Vtxo, 0) + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("timed out - %d/%d received", len(vtxos), atLeastN) + case evt, ok := <-ch: + if !ok { + return nil, fmt.Errorf("vtxo event channel closed") + } + if evt.Connection != nil { + continue + } + + if evt.Err != nil { + return nil, evt.Err + } + vtxos = append(vtxos, evt.Data.NewVtxos...) + } + + if len(vtxos) >= atLeastN { + return vtxos, nil + } + } +}