Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ func (c *Client) AddBond(params []*bond.BondParams) {
c.bondInfo.AddBonds(params, time.Now())
}

// WaitForConnection blocks until the client has an active primary mesh connection
// or the context is done.
func (c *Client) WaitForConnection(ctx context.Context) error {
if c.connManager == nil {
return errNoMeshConnection
}
Comment thread
dnldd marked this conversation as resolved.
return c.connManager.waitForConnection(ctx)
}

Comment thread
dnldd marked this conversation as resolved.
Outdated
// parseBootstrapAddrs parses a list of multiaddr strings into peer.AddrInfo.
// Multiple addresses for the same peer ID are combined into a single AddrInfo.
func parseBootstrapAddrs(addrs []string) ([]peer.AddrInfo, error) {
Expand Down
6 changes: 5 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ func (m *tMeshConnection) fail(err error) {
// setTestMeshConnection sets up a mock mesh connection for testing.
func (c *Client) setTestMeshConnection(mc meshConn) {
if c.connManager == nil {
c.connManager = &meshConnectionManager{}
logBackend := slog.NewBackend(os.Stdout)
logger := logBackend.Logger("test")
c.connManager = newMeshConnectionManager(&meshConnectionManagerConfig{
log: logger,
})
}
c.connManager.setPrimaryConnection(mc)
}
Expand Down
33 changes: 31 additions & 2 deletions client/mesh_connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ type meshConnectionManager struct {

nodesMtx sync.RWMutex
knownNodes []peer.ID

connMtx sync.Mutex
connChanged chan struct{}
}

// meshConnectionManagerConfig holds the configuration for creating a meshConnectionManager.
Expand All @@ -64,6 +67,7 @@ func newMeshConnectionManager(cfg *meshConnectionManagerConfig) *meshConnectionM
host: cfg.host,
log: cfg.log,
connFactory: cfg.connFactory,
connChanged: make(chan struct{}),
}

for _, peer := range cfg.bootstrapPeers {
Expand All @@ -84,11 +88,36 @@ func (m *meshConnectionManager) primaryConnection() (meshConn, error) {
}

func (m *meshConnectionManager) setPrimaryConnection(mc meshConn) {
m.connMtx.Lock()
if mc == nil {
m.primaryConn.Store(nil)
return
} else {
m.primaryConn.Store(&meshConnHolder{mc: mc})
}
oldCh := m.connChanged
m.connChanged = make(chan struct{})
Comment thread
dnldd marked this conversation as resolved.
Outdated
m.connMtx.Unlock()
close(oldCh)
}
Comment thread
dnldd marked this conversation as resolved.

// waitForConnection blocks until the connection manager has an active primary mesh connection
// or the context is done.
func (m *meshConnectionManager) waitForConnection(ctx context.Context) error {
for {
if m.primaryConn.Load() != nil {
return nil
}

m.connMtx.Lock()
ch := m.connChanged
m.connMtx.Unlock()

select {
case <-ch:
case <-ctx.Done():
return ctx.Err()
}
}
m.primaryConn.Store(&meshConnHolder{mc: mc})
}

// connectResult holds the result of a connection attempt.
Expand Down
223 changes: 223 additions & 0 deletions client/mesh_connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,226 @@ func TestMeshConnectionManagerFailover(t *testing.T) {
t.Fatal("mesh connection manager did not stop")
}
}

func TestWaitForConnection(t *testing.T) {
logBackend := slog.NewBackend(os.Stdout)
logger := logBackend.Logger("wait-for-connection-test")

t.Run("already connected returns immediately", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})
conn := newTMeshConnection(randomPeerID(t))
m.setPrimaryConnection(conn)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

start := time.Now()
err := m.waitForConnection(ctx)
elapsed := time.Since(start)

if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if elapsed > 50*time.Millisecond {
t.Fatalf("expected immediate return, took %v", elapsed)
}
})

t.Run("waits for connection when disconnected", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// Start waiter in goroutine
done := make(chan error, 1)
go func() {
done <- m.waitForConnection(ctx)
}()

// Wait a bit to ensure goroutine is blocking
time.Sleep(100 * time.Millisecond)

// Connect
conn := newTMeshConnection(randomPeerID(t))
m.setPrimaryConnection(conn)

// Waiter should return
select {
case err := <-done:
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("waiter did not return after connection")
}
})

t.Run("multiple concurrent waiters all wake up", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// Start multiple waiters
numWaiters := 5
done := make(chan error, numWaiters)
for i := 0; i < numWaiters; i++ {
go func() {
done <- m.waitForConnection(ctx)
}()
}

// Wait a bit to ensure all goroutines are blocking
time.Sleep(100 * time.Millisecond)

// Connect
conn := newTMeshConnection(randomPeerID(t))
m.setPrimaryConnection(conn)

// All waiters should return
for i := 0; i < numWaiters; i++ {
select {
case err := <-done:
if err != nil {
t.Fatalf("waiter %d: expected no error, got %v", i, err)
}
case <-time.After(time.Second):
t.Fatalf("waiter %d did not return", i)
}
}
})

t.Run("context cancellation returns error", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})

ctx, cancel := context.WithCancel(context.Background())

// Start waiter in goroutine
done := make(chan error, 1)
go func() {
done <- m.waitForConnection(ctx)
}()

// Wait a bit to ensure goroutine is blocking
time.Sleep(100 * time.Millisecond)

// Cancel context
cancel()

// Waiter should return with context error
select {
case err := <-done:
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("waiter did not return after context cancel")
}
})

t.Run("handles connect/disconnect cycles", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

nodeID := randomPeerID(t)

// Cycle 1: wait, then connect
done1 := make(chan error, 1)
go func() {
done1 <- m.waitForConnection(ctx)
}()
time.Sleep(50 * time.Millisecond)
conn1 := newTMeshConnection(nodeID)
m.setPrimaryConnection(conn1)

if err := <-done1; err != nil {
t.Fatalf("cycle 1: expected no error, got %v", err)
}

// Cycle 2: disconnect, then wait and connect
m.setPrimaryConnection(nil)

done2 := make(chan error, 1)
go func() {
done2 <- m.waitForConnection(ctx)
}()
time.Sleep(50 * time.Millisecond)
conn2 := newTMeshConnection(nodeID)
m.setPrimaryConnection(conn2)

if err := <-done2; err != nil {
t.Fatalf("cycle 2: expected no error, got %v", err)
}

// Cycle 3: disconnect, then wait and connect
m.setPrimaryConnection(nil)

done3 := make(chan error, 1)
go func() {
done3 <- m.waitForConnection(ctx)
}()
time.Sleep(50 * time.Millisecond)
conn3 := newTMeshConnection(nodeID)
m.setPrimaryConnection(conn3)

if err := <-done3; err != nil {
t.Fatalf("cycle 3: expected no error, got %v", err)
}
})

t.Run("new waiter after disconnection waits again", func(t *testing.T) {
m := newMeshConnectionManager(&meshConnectionManagerConfig{log: logger})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// Connect
conn1 := newTMeshConnection(randomPeerID(t))
m.setPrimaryConnection(conn1)

// First waiter returns immediately (already connected)
done1 := make(chan error, 1)
go func() {
done1 <- m.waitForConnection(ctx)
}()

if err := <-done1; err != nil {
t.Fatalf("first waiter: expected no error, got %v", err)
}

// Disconnect
m.setPrimaryConnection(nil)

// Second waiter should block
done2 := make(chan error, 1)
go func() {
done2 <- m.waitForConnection(ctx)
}()

time.Sleep(100 * time.Millisecond)

// Should still be blocked
select {
case <-done2:
t.Fatal("waiter returned but should still be waiting")
default:
}

// Reconnect
conn2 := newTMeshConnection(randomPeerID(t))
m.setPrimaryConnection(conn2)

// Now waiter should return
select {
case err := <-done2:
if err != nil {
t.Fatalf("second waiter: expected no error, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("waiter did not return after reconnection")
}
})
}
50 changes: 50 additions & 0 deletions client/oracle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package client

import (
"context"
"fmt"
"math/big"

"github.com/bisoncraft/mesh/protocols"
protocolsPb "github.com/bisoncraft/mesh/protocols/pb"
"google.golang.org/protobuf/proto"
)

// SubscribeToPriceOracle subscribes to live price updates for ticker (e.g. "BTC").
// handler is called with the USD price as a float64 for each update received.
func (c *Client) SubscribeToPriceOracle(ctx context.Context, ticker string, handler func(float64)) error {
return c.Subscribe(ctx, protocols.PriceTopic(ticker), func(event TopicEvent) {
if event.Type != TopicEventData {
return
}
var priceUpdate protocolsPb.ClientPriceUpdate
if err := unmarshalOracleData(event.Data, &priceUpdate); err != nil {
c.log.Errorf("Failed to unmarshal price update for %s: %v", ticker, err)
return
}
handler(priceUpdate.Price)
})
}

// SubscribeToFeeRateOracle subscribes to live fee rate updates for network (e.g. "BTC").
// handler is called with the fee rate as a *big.Int for each update received.
func (c *Client) SubscribeToFeeRateOracle(ctx context.Context, network string, handler func(*big.Int)) error {
return c.Subscribe(ctx, protocols.FeeRateTopic(network), func(event TopicEvent) {
if event.Type != TopicEventData {
return
}
var feeRateUpdate protocolsPb.ClientFeeRateUpdate
if err := unmarshalOracleData(event.Data, &feeRateUpdate); err != nil {
c.log.Errorf("Failed to unmarshal fee rate update for %s: %v", network, err)
return
}
handler(new(big.Int).SetBytes(feeRateUpdate.FeeRate))
})
}

func unmarshalOracleData(data []byte, msg proto.Message) error {
if err := proto.Unmarshal(data, msg); err != nil {
return fmt.Errorf("failed to unmarshal oracle data: %w", err)
}
return nil
}
Comment thread
dnldd marked this conversation as resolved.
Outdated
Loading