diff --git a/network/packetrelay/packet_listener_relay.go b/network/packetrelay/packet_listener_relay.go index 4c060a39..b760073f 100644 --- a/network/packetrelay/packet_listener_relay.go +++ b/network/packetrelay/packet_listener_relay.go @@ -43,7 +43,6 @@ var _ PacketReceiver = (*packetListenerReceiver)(nil) // PacketListenerRelay creates a new [PacketRelay] that uses the existing [transport.PacketListener] to // create connections to a relay. type PacketListenerRelay struct { - mu sync.RWMutex listener transport.PacketListener writeIdleTimeout time.Duration } @@ -52,9 +51,7 @@ type PacketListenerRelay struct { // create connections to a relay. // This function is useful if you already have an implementation of [transport.PacketListener] and you want to use it // with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler. -// -// Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by -// incoming packets. +// The writeIdleTimeout is reset only by [PacketSender.SendPacket], not by incoming packets. func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTimeout time.Duration) (*PacketListenerRelay, error) { if pl == nil { return nil, errors.New("pl must not be nil") @@ -69,34 +66,17 @@ func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTime return r, nil } -// SetWriteIdleTimeout sets the write-idle timeout for new associations. -// Existing associations keep the timeout they were created with. -func (relay *PacketListenerRelay) SetWriteIdleTimeout(timeout time.Duration) error { - if timeout <= 0 { - return errors.New("timeout must be greater than 0") - } - relay.mu.Lock() - defer relay.mu.Unlock() - relay.writeIdleTimeout = timeout - return nil -} - // NewAssociation implements [PacketRelay].NewAssociation. It uses [transport.PacketListener].ListenPacket to create // a [net.PacketConn], and returns a [PacketSender] and [PacketReceiver] based on this [net.PacketConn]. func (relay *PacketListenerRelay) NewAssociation() (PacketSender, PacketReceiver, error) { - relay.mu.RLock() - listener := relay.listener - writeIdleTimeout := relay.writeIdleTimeout - relay.mu.RUnlock() - - packetConn, err := listener.ListenPacket(context.Background()) + packetConn, err := relay.listener.ListenPacket(context.Background()) if err != nil { return nil, nil, err } association := &packetListenerAssociation{ packetConn: packetConn, - writeIdleTimeout: writeIdleTimeout, + writeIdleTimeout: relay.writeIdleTimeout, } if err := association.refreshDeadline(); err != nil { _ = association.close() diff --git a/network/packetrelay/packet_listener_relay_test.go b/network/packetrelay/packet_listener_relay_test.go index 6e199a2b..14fa46c7 100644 --- a/network/packetrelay/packet_listener_relay_test.go +++ b/network/packetrelay/packet_listener_relay_test.go @@ -51,48 +51,56 @@ func TestNewPacketRelayFromPacketListenerRejectsInvalidTimeout(t *testing.T) { require.Nil(t, relay) } -func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) { - conn := &fakePacketConn{} +func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { + conn := &fakePacketConn{readErr: timeoutErr{}} pl := &fakePacketListener{conn: conn} - timeout := 5 * time.Minute relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) - require.NoError(t, relay.SetWriteIdleTimeout(timeout)) - - sender, _, err := relay.NewAssociation() + _, receiver, err := relay.NewAssociation() require.NoError(t, err) - require.Len(t, conn.deadlines, 1) - requireDeadlineNear(t, conn.deadlines[0], timeout) - err = sender.SendPacket([]byte("hello"), netip.MustParseAddrPort("1.2.3.4:53")) - require.NoError(t, err) - require.Len(t, conn.deadlines, 2) - requireDeadlineNear(t, conn.deadlines[1], timeout) - require.Len(t, conn.writes, 1) + err = receiver.ReceivePackets(&mockPacketHandler{}) + require.ErrorAs(t, err, &timeoutErr{}) + require.True(t, conn.isClosed()) } -func TestPacketListenerRelaySetWriteIdleTimeoutRejectsInvalidTimeout(t *testing.T) { - conn := &fakePacketConn{} - pl := &fakePacketListener{conn: conn} - - relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) +func TestPacketListenerRelaySendCloseRace(t *testing.T) { + relay, err := NewPacketRelayFromPacketListener(&fakePacketListener{conn: &fakePacketConn{}}, time.Second) + require.NoError(t, err) + sender, _, err := relay.NewAssociation() require.NoError(t, err) - require.Error(t, relay.SetWriteIdleTimeout(0)) -} -func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { - conn := &fakePacketConn{readErr: timeoutErr{}} - pl := &fakePacketListener{conn: conn} + start := make(chan struct{}) + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + <-start + _ = sender.SendPacket([]byte("x"), netip.MustParseAddrPort("1.2.3.4:53")) + }() + } + close(start) + _ = sender.Close() + wg.Wait() +} - relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) +func TestPacketListenerRelayReceiveCloseRace(t *testing.T) { + conn := newBlockingPacketConn() + relay, err := NewPacketRelayFromPacketListener(&connListener{conn: conn}, time.Second) require.NoError(t, err) - _, receiver, err := relay.NewAssociation() + sender, receiver, err := relay.NewAssociation() require.NoError(t, err) - err = receiver.ReceivePackets(&mockPacketHandler{}) - require.ErrorAs(t, err, &timeoutErr{}) - require.True(t, conn.isClosed()) + done := make(chan struct{}) + go func() { + defer close(done) + _ = receiver.ReceivePackets(&mockPacketHandler{}) + }() + + _ = sender.Close() + <-done } func requireDeadlineNear(t *testing.T, deadline time.Time, timeout time.Duration) { @@ -169,6 +177,36 @@ func (c *fakePacketConn) isClosed() bool { return c.closed } +// blockingPacketConn blocks ReadFrom until Close is called, simulating a real UDP conn. +type blockingPacketConn struct { + fakePacketConn + once sync.Once + closed chan struct{} +} + +func newBlockingPacketConn() *blockingPacketConn { + return &blockingPacketConn{closed: make(chan struct{})} +} + +func (c *blockingPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + <-c.closed + return 0, nil, net.ErrClosed +} + +func (c *blockingPacketConn) Close() error { + c.once.Do(func() { close(c.closed) }) + return c.fakePacketConn.Close() +} + +// connListener wraps a net.PacketConn as a transport.PacketListener. +type connListener struct { + conn net.PacketConn +} + +func (l *connListener) ListenPacket(_ context.Context) (net.PacketConn, error) { + return l.conn, nil +} + type timeoutErr struct{} func (timeoutErr) Error() string {