From b982afe4e01eecf654467446f2cd596fe1cba18e Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 23 Jun 2026 00:25:27 -0400 Subject: [PATCH] refactor(packetrelay): remove SetWriteIdleTimeout from PacketListenerRelay The timeout is set once at construction via NewPacketRelayFromPacketListener and does not need to be mutable after that. Removes the method, the RWMutex it required, and adds race-condition tests for concurrent SendPacket/Close and ReceivePackets/Close. Co-Authored-By: Claude Sonnet 4.6 --- network/packetrelay/packet_listener_relay.go | 26 +---- .../packetrelay/packet_listener_relay_test.go | 94 +++++++++++++------ 2 files changed, 69 insertions(+), 51 deletions(-) diff --git a/network/packetrelay/packet_listener_relay.go b/network/packetrelay/packet_listener_relay.go index 4c060a39..b760073f 100644 --- a/network/packetrelay/packet_listener_relay.go +++ b/network/packetrelay/packet_listener_relay.go @@ -43,7 +43,6 @@ var _ PacketReceiver = (*packetListenerReceiver)(nil) // PacketListenerRelay creates a new [PacketRelay] that uses the existing [transport.PacketListener] to // create connections to a relay. type PacketListenerRelay struct { - mu sync.RWMutex listener transport.PacketListener writeIdleTimeout time.Duration } @@ -52,9 +51,7 @@ type PacketListenerRelay struct { // create connections to a relay. // This function is useful if you already have an implementation of [transport.PacketListener] and you want to use it // with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler. -// -// Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by -// incoming packets. +// The writeIdleTimeout is reset only by [PacketSender.SendPacket], not by incoming packets. func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTimeout time.Duration) (*PacketListenerRelay, error) { if pl == nil { return nil, errors.New("pl must not be nil") @@ -69,34 +66,17 @@ func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTime return r, nil } -// SetWriteIdleTimeout sets the write-idle timeout for new associations. -// Existing associations keep the timeout they were created with. -func (relay *PacketListenerRelay) SetWriteIdleTimeout(timeout time.Duration) error { - if timeout <= 0 { - return errors.New("timeout must be greater than 0") - } - relay.mu.Lock() - defer relay.mu.Unlock() - relay.writeIdleTimeout = timeout - return nil -} - // NewAssociation implements [PacketRelay].NewAssociation. It uses [transport.PacketListener].ListenPacket to create // a [net.PacketConn], and returns a [PacketSender] and [PacketReceiver] based on this [net.PacketConn]. func (relay *PacketListenerRelay) NewAssociation() (PacketSender, PacketReceiver, error) { - relay.mu.RLock() - listener := relay.listener - writeIdleTimeout := relay.writeIdleTimeout - relay.mu.RUnlock() - - packetConn, err := listener.ListenPacket(context.Background()) + packetConn, err := relay.listener.ListenPacket(context.Background()) if err != nil { return nil, nil, err } association := &packetListenerAssociation{ packetConn: packetConn, - writeIdleTimeout: writeIdleTimeout, + writeIdleTimeout: relay.writeIdleTimeout, } if err := association.refreshDeadline(); err != nil { _ = association.close() diff --git a/network/packetrelay/packet_listener_relay_test.go b/network/packetrelay/packet_listener_relay_test.go index 6e199a2b..14fa46c7 100644 --- a/network/packetrelay/packet_listener_relay_test.go +++ b/network/packetrelay/packet_listener_relay_test.go @@ -51,48 +51,56 @@ func TestNewPacketRelayFromPacketListenerRejectsInvalidTimeout(t *testing.T) { require.Nil(t, relay) } -func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) { - conn := &fakePacketConn{} +func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { + conn := &fakePacketConn{readErr: timeoutErr{}} pl := &fakePacketListener{conn: conn} - timeout := 5 * time.Minute relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) - require.NoError(t, relay.SetWriteIdleTimeout(timeout)) - - sender, _, err := relay.NewAssociation() + _, receiver, err := relay.NewAssociation() require.NoError(t, err) - require.Len(t, conn.deadlines, 1) - requireDeadlineNear(t, conn.deadlines[0], timeout) - err = sender.SendPacket([]byte("hello"), netip.MustParseAddrPort("1.2.3.4:53")) - require.NoError(t, err) - require.Len(t, conn.deadlines, 2) - requireDeadlineNear(t, conn.deadlines[1], timeout) - require.Len(t, conn.writes, 1) + err = receiver.ReceivePackets(&mockPacketHandler{}) + require.ErrorAs(t, err, &timeoutErr{}) + require.True(t, conn.isClosed()) } -func TestPacketListenerRelaySetWriteIdleTimeoutRejectsInvalidTimeout(t *testing.T) { - conn := &fakePacketConn{} - pl := &fakePacketListener{conn: conn} - - relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) +func TestPacketListenerRelaySendCloseRace(t *testing.T) { + relay, err := NewPacketRelayFromPacketListener(&fakePacketListener{conn: &fakePacketConn{}}, time.Second) + require.NoError(t, err) + sender, _, err := relay.NewAssociation() require.NoError(t, err) - require.Error(t, relay.SetWriteIdleTimeout(0)) -} -func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { - conn := &fakePacketConn{readErr: timeoutErr{}} - pl := &fakePacketListener{conn: conn} + start := make(chan struct{}) + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + <-start + _ = sender.SendPacket([]byte("x"), netip.MustParseAddrPort("1.2.3.4:53")) + }() + } + close(start) + _ = sender.Close() + wg.Wait() +} - relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) +func TestPacketListenerRelayReceiveCloseRace(t *testing.T) { + conn := newBlockingPacketConn() + relay, err := NewPacketRelayFromPacketListener(&connListener{conn: conn}, time.Second) require.NoError(t, err) - _, receiver, err := relay.NewAssociation() + sender, receiver, err := relay.NewAssociation() require.NoError(t, err) - err = receiver.ReceivePackets(&mockPacketHandler{}) - require.ErrorAs(t, err, &timeoutErr{}) - require.True(t, conn.isClosed()) + done := make(chan struct{}) + go func() { + defer close(done) + _ = receiver.ReceivePackets(&mockPacketHandler{}) + }() + + _ = sender.Close() + <-done } func requireDeadlineNear(t *testing.T, deadline time.Time, timeout time.Duration) { @@ -169,6 +177,36 @@ func (c *fakePacketConn) isClosed() bool { return c.closed } +// blockingPacketConn blocks ReadFrom until Close is called, simulating a real UDP conn. +type blockingPacketConn struct { + fakePacketConn + once sync.Once + closed chan struct{} +} + +func newBlockingPacketConn() *blockingPacketConn { + return &blockingPacketConn{closed: make(chan struct{})} +} + +func (c *blockingPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + <-c.closed + return 0, nil, net.ErrClosed +} + +func (c *blockingPacketConn) Close() error { + c.once.Do(func() { close(c.closed) }) + return c.fakePacketConn.Close() +} + +// connListener wraps a net.PacketConn as a transport.PacketListener. +type connListener struct { + conn net.PacketConn +} + +func (l *connListener) ListenPacket(_ context.Context) (net.PacketConn, error) { + return l.conn, nil +} + type timeoutErr struct{} func (timeoutErr) Error() string {