From 6dbaeb4453ad3c2c57274fd84b8da45a107b99dc Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 16 Jun 2026 01:52:44 -0400 Subject: [PATCH 1/4] fix(dnsintercept): match IPv4-mapped DNS resolver --- network/dnsintercept/packet_relay.go | 16 +++++-- network/dnsintercept/packet_relay_test.go | 56 +++++++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/network/dnsintercept/packet_relay.go b/network/dnsintercept/packet_relay.go index 34915080..eabc3757 100644 --- a/network/dnsintercept/packet_relay.go +++ b/network/dnsintercept/packet_relay.go @@ -58,10 +58,10 @@ func NewInterceptDNSPacketRelay(dnsRelay, defaultRelay packetrelay.PacketRelay, // State machine for lazy default association initialization: // // stateIdle: Initial state. No default association created yet. -// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay. -// Other concurrent SendPacket calls will block waiting for this to finish. -// stateInitialized: The default association has been resolved (either success or error cached). -// Future SendPacket calls will immediately use the cached result. +// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay; +// other concurrent SendPacket calls will block waiting for this to finish. +// stateInitialized: The default association has been resolved (either success or error cached); +// future SendPacket calls will immediately use the cached result. const ( stateIdle = iota stateInitializing @@ -295,7 +295,7 @@ var _ packetrelay.PacketSender = (*interceptSender)(nil) // on the DNS relay, rewrites the destination, and forwards the packet. // Otherwise, it lazily initializes and uses a single association on the default relay. func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error { - if destination == s.a.relay.dnsLocalResolver { + if isSameAddrPort(destination, s.a.relay.dnsLocalResolver) { return s.a.handleDNSQuery(p) } @@ -306,6 +306,12 @@ func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error return defSender.SendPacket(p, destination) } +// isSameAddrPort treats IPv4 and IPv4-mapped IPv6 addresses as equivalent because +// some network stacks surface IPv4 UDP destinations in IPv4-mapped form. +func isSameAddrPort(a, b netip.AddrPort) bool { + return a.Addr().Unmap() == b.Addr().Unmap() && a.Port() == b.Port() +} + // Close terminates the parent association and immediately closes all active sub-associations, // including the default association and any pending DNS query associations. func (s *interceptSender) Close() error { diff --git a/network/dnsintercept/packet_relay_test.go b/network/dnsintercept/packet_relay_test.go index 03d470de..a7d6d7a8 100644 --- a/network/dnsintercept/packet_relay_test.go +++ b/network/dnsintercept/packet_relay_test.go @@ -174,6 +174,62 @@ func TestDNSQueryRouting(t *testing.T) { }) } +func TestDNSQueryRoutingWithIPv4MappedLocalResolver(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + dnsRelay := &mockRelay{} + defaultRelay := &mockRelay{} + localRes := netip.MustParseAddrPort("10.0.0.1:53") + remoteRes := netip.MustParseAddrPort("8.8.8.8:53") + relay := NewInterceptDNSPacketRelay(dnsRelay, defaultRelay, localRes, remoteRes) + + sender, receiver, err := relay.NewAssociation() + if err != nil { + t.Fatalf("NewAssociation failed: %v", err) + } + + handler := &mockHandler{packets: make(chan packetData, 10)} + go receiver.ReceivePackets(handler) + + reqData := []byte("dns_query") + mappedLocalRes := netip.MustParseAddrPort("[::ffff:10.0.0.1]:53") + if err := sender.SendPacket(reqData, mappedLocalRes); err != nil { + t.Fatalf("SendPacket failed: %v", err) + } + synctest.Wait() + + if dnsRelay.newAssocCount != 1 { + t.Fatalf("Expected 1 DNS association, got %d", dnsRelay.newAssocCount) + } + if defaultRelay.newAssocCount != 0 { + t.Fatalf("Expected 0 default associations, got %d", defaultRelay.newAssocCount) + } + + dnsSender := dnsRelay.senders[0] + pd := <-dnsSender.packets + if !bytes.Equal(pd.p, reqData) { + t.Errorf("Expected packet %q, got %q", reqData, pd.p) + } + if pd.dest != remoteRes { + t.Errorf("Expected dest %v, got %v", remoteRes, pd.dest) + } + + respData := []byte("dns_response") + dnsSender.receiver.PushResponse(respData, remoteRes) + synctest.Wait() + + resp := <-handler.packets + if !bytes.Equal(resp.p, respData) { + t.Errorf("Expected response %q, got %q", respData, resp.p) + } + if resp.dest != localRes { + t.Errorf("Expected rewritten source %v, got %v", localRes, resp.dest) + } + if !dnsSender.IsClosed() { + t.Errorf("Expected DNS sender to be closed after one response") + } + }) +} + func TestDefaultRouting(t *testing.T) { synctest.Test(t, func(t *testing.T) { dnsRelay := &mockRelay{} From 4f32f083df861a4886c68be6d69d4b509e3773af Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 16 Jun 2026 02:19:59 -0400 Subject: [PATCH 2/4] fix(packetrelay): add PacketListenerRelay idle deadlines --- network/packet_listener_proxy.go | 20 +-- network/packetrelay/packet_listener_relay.go | 133 +++++++++++----- .../packetrelay/packet_listener_relay_test.go | 146 +++++++++++++++++- 3 files changed, 250 insertions(+), 49 deletions(-) diff --git a/network/packet_listener_proxy.go b/network/packet_listener_proxy.go index 1e1e69b0..5aea2b59 100644 --- a/network/packet_listener_proxy.go +++ b/network/packet_listener_proxy.go @@ -40,15 +40,8 @@ type PacketListenerProxy struct { // // Deprecated: Use [packetrelay.NewPacketRelayFromPacketListener] instead. func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerProxy) error) (*PacketListenerProxy, error) { - // Create the underlying base relay - baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl) - if err != nil { - return nil, err - } - p := &PacketListenerProxy{ - baseRelay: baseRelay, - writeIdleTimeout: 30 * time.Second, // Default timeout + writeIdleTimeout: packetrelay.DefaultPacketListenerRelayWriteIdleTimeout, } // Apply options @@ -58,12 +51,15 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu } } - // Build the final relay chain: TimeoutPacketRelay(PacketListenerRelay) - timeoutRelay, err := packetrelay.NewTimeoutPacketRelay(p.baseRelay, p.writeIdleTimeout) + baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl) if err != nil { return nil, err } - p.relay = timeoutRelay + if err := baseRelay.SetWriteIdleTimeout(p.writeIdleTimeout); err != nil { + return nil, err + } + p.baseRelay = baseRelay + p.relay = baseRelay return p, nil } @@ -72,7 +68,7 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu // This means that if there are no WriteTo operations on the UDP session created by NewSession for the specified amount // of time, the proxy will end this session. // -// Deprecated: Use [packetrelay.NewTimeoutPacketRelay] to decorate the underlying [packetrelay.PacketRelay] instead. +// Deprecated: Use [packetrelay.PacketListenerRelay.SetWriteIdleTimeout] instead. func WithPacketListenerWriteIdleTimeout(timeout time.Duration) func(*PacketListenerProxy) error { return func(p *PacketListenerProxy) error { if timeout <= 0 { diff --git a/network/packetrelay/packet_listener_relay.go b/network/packetrelay/packet_listener_relay.go index 2fa175c9..cd7d9dd3 100644 --- a/network/packetrelay/packet_listener_relay.go +++ b/network/packetrelay/packet_listener_relay.go @@ -21,6 +21,7 @@ import ( "net" "net/netip" "sync" + "time" "golang.getoutline.org/sdk/internal/slicepool" "golang.getoutline.org/sdk/transport" @@ -34,6 +35,10 @@ const packetMaxSize = 2048 // packetBufferPool is used to create buffers to read UDP response packets var packetBufferPool = slicepool.MakePool(packetMaxSize) +// DefaultPacketListenerRelayWriteIdleTimeout is the default write-idle timeout for +// associations created by [PacketListenerRelay]. +const DefaultPacketListenerRelayWriteIdleTimeout = 30 * time.Second + // Compilation guard against interface implementation var _ PacketRelay = (*PacketListenerRelay)(nil) var _ PacketSender = (*packetListenerSender)(nil) @@ -42,87 +47,144 @@ var _ PacketReceiver = (*packetListenerReceiver)(nil) // PacketListenerRelay creates a new [PacketRelay] that uses the existing [transport.PacketListener] to // create connections to a relay. type PacketListenerRelay struct { - listener transport.PacketListener + mu sync.RWMutex + listener transport.PacketListener + writeIdleTimeout time.Duration } // NewPacketRelayFromPacketListener creates a new [PacketRelay] that uses the existing [transport.PacketListener] to -// create connections to a relay. You can also specify additional options. +// create connections to a relay. // This function is useful if you already have an implementation of [transport.PacketListener] and you want to use it // with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler. -func NewPacketRelayFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerRelay) error) (*PacketListenerRelay, error) { +// +// Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by +// incoming packets. The default timeout is [DefaultPacketListenerRelayWriteIdleTimeout]. +func NewPacketRelayFromPacketListener(pl transport.PacketListener) (*PacketListenerRelay, error) { if pl == nil { return nil, errors.New("pl must not be nil") } r := &PacketListenerRelay{ - listener: pl, - } - for _, opt := range options { - if err := opt(r); err != nil { - return nil, err - } + listener: pl, + writeIdleTimeout: DefaultPacketListenerRelayWriteIdleTimeout, } return r, nil } +// SetWriteIdleTimeout sets the write-idle timeout for new associations. +// Existing associations keep the timeout they were created with. +func (relay *PacketListenerRelay) SetWriteIdleTimeout(timeout time.Duration) error { + if timeout <= 0 { + return errors.New("timeout must be greater than 0") + } + relay.mu.Lock() + defer relay.mu.Unlock() + relay.writeIdleTimeout = timeout + return nil +} + // NewAssociation implements [PacketRelay].NewAssociation. It uses [transport.PacketListener].ListenPacket to create // a [net.PacketConn], and returns a [PacketSender] and [PacketReceiver] based on this [net.PacketConn]. func (relay *PacketListenerRelay) NewAssociation() (PacketSender, PacketReceiver, error) { - packetConn, err := relay.listener.ListenPacket(context.Background()) + relay.mu.RLock() + listener := relay.listener + writeIdleTimeout := relay.writeIdleTimeout + relay.mu.RUnlock() + + packetConn, err := listener.ListenPacket(context.Background()) if err != nil { return nil, nil, err } + association := &packetListenerAssociation{ + packetConn: packetConn, + writeIdleTimeout: writeIdleTimeout, + } + if err := association.refreshDeadline(); err != nil { + _ = association.close() + return nil, nil, err + } + sender := &packetListenerSender{ - packetConn: packetConn, + association: association, } receiver := &packetListenerReceiver{ - packetConn: packetConn, + association: association, + packetConn: packetConn, } return sender, receiver, nil } -type packetListenerSender struct { - mu sync.Mutex // Protects closed flag - closed bool +type packetListenerAssociation struct { + mu sync.Mutex + closed bool + packetConn net.PacketConn + writeIdleTimeout time.Duration +} + +func (a *packetListenerAssociation) refreshDeadline() error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.closed { + return ErrClosed + } + return a.packetConn.SetDeadline(time.Now().Add(a.writeIdleTimeout)) +} + +func (a *packetListenerAssociation) getPacketConn() (net.PacketConn, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.closed { + return nil, ErrClosed + } + return a.packetConn, nil +} - packetConn net.PacketConn +func (a *packetListenerAssociation) close() error { + a.mu.Lock() + if a.closed { + a.mu.Unlock() + return ErrClosed + } + a.closed = true + packetConn := a.packetConn + a.packetConn = nil + a.mu.Unlock() + + return packetConn.Close() +} + +type packetListenerSender struct { + association *packetListenerAssociation } // SendPacket implements [PacketSender].SendPacket. It simply forwards the packet to the underlying // [net.PacketConn].WriteTo function. func (s *packetListenerSender) SendPacket(p []byte, destination netip.AddrPort) error { - s.mu.Lock() - if s.closed { - s.mu.Unlock() - return ErrClosed + if err := s.association.refreshDeadline(); err != nil { + return err + } + packetConn, err := s.association.getPacketConn() + if err != nil { + return err } - packetConn := s.packetConn - s.mu.Unlock() - _, err := packetConn.WriteTo(p, net.UDPAddrFromAddrPort(destination)) + _, err = packetConn.WriteTo(p, net.UDPAddrFromAddrPort(destination)) return err } // Close implements [PacketSender].Close. It closes the underlying [net.PacketConn]. This will also // terminate the blocking loop in ReceivePackets because s.packetConn.ReadFrom will return an error. func (s *packetListenerSender) Close() error { - s.mu.Lock() - if s.closed { - s.mu.Unlock() - return ErrClosed - } - s.closed = true - packetConn := s.packetConn - s.packetConn = nil - s.mu.Unlock() - - return packetConn.Close() + return s.association.close() } type packetListenerReceiver struct { - packetConn net.PacketConn + association *packetListenerAssociation + packetConn net.PacketConn } // ReceivePackets implements [PacketReceiver].ReceivePackets. It blocks and passes incoming packets @@ -139,6 +201,7 @@ func (r *packetListenerReceiver) ReceivePackets(handler PacketHandler) error { if errors.Is(err, io.ErrShortBuffer) { continue } + _ = r.association.close() return err } diff --git a/network/packetrelay/packet_listener_relay_test.go b/network/packetrelay/packet_listener_relay_test.go index cac6b2b0..59ec1925 100644 --- a/network/packetrelay/packet_listener_relay_test.go +++ b/network/packetrelay/packet_listener_relay_test.go @@ -15,16 +15,158 @@ package packetrelay import ( + "context" + "net" + "net/netip" + "sync" "testing" + "time" - "golang.getoutline.org/sdk/transport" "github.com/stretchr/testify/require" ) func TestNewPacketRelayFromPacketListener(t *testing.T) { - pl := &transport.UDPListener{} + conn := &fakePacketConn{} + pl := &fakePacketListener{conn: conn} relay, err := NewPacketRelayFromPacketListener(pl) require.NoError(t, err) require.NotNil(t, relay) + + _, _, err = relay.NewAssociation() + require.NoError(t, err) + require.Len(t, conn.deadlines, 1) + requireDeadlineNear(t, conn.deadlines[0], DefaultPacketListenerRelayWriteIdleTimeout) +} + +func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) { + conn := &fakePacketConn{} + pl := &fakePacketListener{conn: conn} + timeout := 5 * time.Minute + + relay, err := NewPacketRelayFromPacketListener(pl) + require.NoError(t, err) + require.NoError(t, relay.SetWriteIdleTimeout(timeout)) + + sender, _, err := relay.NewAssociation() + require.NoError(t, err) + require.Len(t, conn.deadlines, 1) + requireDeadlineNear(t, conn.deadlines[0], timeout) + + err = sender.SendPacket([]byte("hello"), netip.MustParseAddrPort("1.2.3.4:53")) + require.NoError(t, err) + require.Len(t, conn.deadlines, 2) + requireDeadlineNear(t, conn.deadlines[1], timeout) + require.Len(t, conn.writes, 1) +} + +func TestPacketListenerRelaySetWriteIdleTimeoutRejectsInvalidTimeout(t *testing.T) { + conn := &fakePacketConn{} + pl := &fakePacketListener{conn: conn} + + relay, err := NewPacketRelayFromPacketListener(pl) + require.NoError(t, err) + require.Error(t, relay.SetWriteIdleTimeout(0)) +} + +func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { + conn := &fakePacketConn{readErr: timeoutErr{}} + pl := &fakePacketListener{conn: conn} + + relay, err := NewPacketRelayFromPacketListener(pl) + require.NoError(t, err) + _, receiver, err := relay.NewAssociation() + require.NoError(t, err) + + err = receiver.ReceivePackets(&mockPacketHandler{}) + require.ErrorAs(t, err, &timeoutErr{}) + require.True(t, conn.isClosed()) +} + +func requireDeadlineNear(t *testing.T, deadline time.Time, timeout time.Duration) { + t.Helper() + require.WithinDuration(t, time.Now().Add(timeout), deadline, time.Second) +} + +type fakePacketListener struct { + conn *fakePacketConn +} + +func (l *fakePacketListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + return l.conn, nil +} + +type packetWrite struct { + payload []byte + addr net.Addr +} + +type fakePacketConn struct { + mu sync.Mutex + deadlines []time.Time + writes []packetWrite + readErr error + closed bool +} + +func (c *fakePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + if c.readErr != nil { + return 0, nil, c.readErr + } + return 0, nil, net.ErrClosed +} + +func (c *fakePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + payload := make([]byte, len(p)) + copy(payload, p) + c.writes = append(c.writes, packetWrite{payload: payload, addr: addr}) + return len(p), nil +} + +func (c *fakePacketConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *fakePacketConn) LocalAddr() net.Addr { + return &net.UDPAddr{} +} + +func (c *fakePacketConn) SetDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + c.deadlines = append(c.deadlines, t) + return nil +} + +func (c *fakePacketConn) SetReadDeadline(t time.Time) error { + return c.SetDeadline(t) +} + +func (c *fakePacketConn) SetWriteDeadline(t time.Time) error { + return c.SetDeadline(t) +} + +func (c *fakePacketConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +type timeoutErr struct{} + +func (timeoutErr) Error() string { + return "timeout" +} + +func (timeoutErr) Timeout() bool { + return true +} + +func (timeoutErr) Temporary() bool { + return true } From 79c8a04b30e9ebefe7a50c4ee24a4f192cb43162 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 16 Jun 2026 18:31:01 -0400 Subject: [PATCH 3/4] Clean up SetRelay --- network/packetrelay/delegate_packet_relay.go | 20 +++++++------- .../packetrelay/delegate_packet_relay_test.go | 26 +++++++++---------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/network/packetrelay/delegate_packet_relay.go b/network/packetrelay/delegate_packet_relay.go index 891a7ae9..7701281d 100644 --- a/network/packetrelay/delegate_packet_relay.go +++ b/network/packetrelay/delegate_packet_relay.go @@ -33,10 +33,10 @@ type DelegatePacketRelay interface { // SetRelay updates the underlying PacketRelay to `relay`; `relay` must not be nil. After this function // returns, all new PacketRelay calls will be forwarded to the `relay`. Existing associations will not be affected. - SetRelay(relay PacketRelay) error + SetRelay(relay PacketRelay) } -var errInvalidRelay = errors.New("the underlying relay must not be nil") +var errNoRelay = errors.New("no relay configured") // Compilation guard against interface implementation var _ DelegatePacketRelay = (*delegatePacketRelay)(nil) @@ -51,24 +51,22 @@ type delegatePacketRelay struct { // NewDelegatePacketRelay creates a new [DelegatePacketRelay] that forwards calls to the `relay` [PacketRelay]. // The `relay` must not be nil. func NewDelegatePacketRelay(relay PacketRelay) (DelegatePacketRelay, error) { - if relay == nil { - return nil, errInvalidRelay - } dr := delegatePacketRelay{} dr.relay.Store(&relay) return &dr, nil } // NewAssociation implements PacketRelay.NewAssociation, and it will forward the call to the underlying PacketRelay. +// Returns an error if the underlying relay is nil. func (p *delegatePacketRelay) NewAssociation() (PacketSender, PacketReceiver, error) { - return (*p.relay.Load()).NewAssociation() + relayPtr := p.relay.Load() + if relayPtr == nil || *relayPtr == nil { + return nil, nil, errNoRelay + } + return (*relayPtr).NewAssociation() } // SetRelay implements DelegatePacketRelay.SetRelay. -func (p *delegatePacketRelay) SetRelay(relay PacketRelay) error { - if relay == nil { - return errInvalidRelay - } +func (p *delegatePacketRelay) SetRelay(relay PacketRelay) { p.relay.Store(&relay) - return nil } diff --git a/network/packetrelay/delegate_packet_relay_test.go b/network/packetrelay/delegate_packet_relay_test.go index aa4d101a..404f03c6 100644 --- a/network/packetrelay/delegate_packet_relay_test.go +++ b/network/packetrelay/delegate_packet_relay_test.go @@ -44,8 +44,7 @@ func TestRelayCanBeUpdated(t *testing.T) { require.Exactly(t, 0, newRelay.Count()) // SetRelay should not call NewAssociation - err = p.SetRelay(newRelay) - require.NoError(t, err) + p.SetRelay(newRelay) require.Exactly(t, 1, defRelay.Count()) require.Exactly(t, 0, newRelay.Count()) @@ -84,8 +83,7 @@ func TestSetRelayRaceCondition(t *testing.T) { setRelayTask.Add(1) go func() { for i := 0; !cancelSetRelay.Load(); i = (i + 1) % relaysCnt { - err := dr.SetRelay(relays[i]) - require.NoError(t, err) + dr.SetRelay(relays[i]) } setRelayTask.Done() }() @@ -112,20 +110,23 @@ func TestSetRelayRaceCondition(t *testing.T) { require.Equal(t, expectedTotal, actualTotal) } -// Make sure we cannot SetRelay to nil +// Make sure we can SetRelay to nil, which makes NewAssociation fail with errNoRelay func TestSetRelayWithNilValue(t *testing.T) { - // must not initialize with nil + // initialization with nil does not return error, but calling NewAssociation will return errNoRelay dr, err := NewDelegatePacketRelay(nil) - require.Error(t, err) - require.Nil(t, dr) + require.NoError(t, err) + require.NotNil(t, dr) + _, _, err = dr.NewAssociation() + require.ErrorIs(t, err, errNoRelay) dr, err = NewDelegatePacketRelay(&sessionCountPacketRelay{}) require.NoError(t, err) require.NotNil(t, dr) - // must not SetRelay to nil - err = dr.SetRelay(nil) - require.Error(t, err) + // SetRelay(nil) should not panic or error, but calling NewAssociation afterwards will return errNoRelay + dr.SetRelay(nil) + _, _, err = dr.NewAssociation() + require.ErrorIs(t, err, errNoRelay) } // Make sure we can SetRelay to different types @@ -138,8 +139,7 @@ func TestSetRelayOfDifferentTypes(t *testing.T) { require.NoError(t, err) // SetRelay should not return error - err = p.SetRelay(newRelay) - require.NoError(t, err) + p.SetRelay(newRelay) // NewAssociation should not go to defRelay snd, rcv, err := p.NewAssociation() From 9ac5658faac332952a7e69669330b4b8f6e3daaa Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 17 Jun 2026 12:42:56 -0400 Subject: [PATCH 4/4] refactor(packetrelay): make writeIdleTimeout explicit in NewPacketRelayFromPacketListener --- network/packet_listener_proxy.go | 7 ++---- network/packetrelay/packet_listener_relay.go | 13 +++++------ .../packetrelay/packet_listener_relay_test.go | 22 ++++++++++++++----- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/network/packet_listener_proxy.go b/network/packet_listener_proxy.go index 5aea2b59..dd69cf8f 100644 --- a/network/packet_listener_proxy.go +++ b/network/packet_listener_proxy.go @@ -41,7 +41,7 @@ type PacketListenerProxy struct { // Deprecated: Use [packetrelay.NewPacketRelayFromPacketListener] instead. func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerProxy) error) (*PacketListenerProxy, error) { p := &PacketListenerProxy{ - writeIdleTimeout: packetrelay.DefaultPacketListenerRelayWriteIdleTimeout, + writeIdleTimeout: 30 * time.Second, } // Apply options @@ -51,13 +51,10 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu } } - baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl) + baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl, p.writeIdleTimeout) if err != nil { return nil, err } - if err := baseRelay.SetWriteIdleTimeout(p.writeIdleTimeout); err != nil { - return nil, err - } p.baseRelay = baseRelay p.relay = baseRelay diff --git a/network/packetrelay/packet_listener_relay.go b/network/packetrelay/packet_listener_relay.go index cd7d9dd3..4c060a39 100644 --- a/network/packetrelay/packet_listener_relay.go +++ b/network/packetrelay/packet_listener_relay.go @@ -35,10 +35,6 @@ const packetMaxSize = 2048 // packetBufferPool is used to create buffers to read UDP response packets var packetBufferPool = slicepool.MakePool(packetMaxSize) -// DefaultPacketListenerRelayWriteIdleTimeout is the default write-idle timeout for -// associations created by [PacketListenerRelay]. -const DefaultPacketListenerRelayWriteIdleTimeout = 30 * time.Second - // Compilation guard against interface implementation var _ PacketRelay = (*PacketListenerRelay)(nil) var _ PacketSender = (*packetListenerSender)(nil) @@ -58,14 +54,17 @@ type PacketListenerRelay struct { // with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler. // // Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by -// incoming packets. The default timeout is [DefaultPacketListenerRelayWriteIdleTimeout]. -func NewPacketRelayFromPacketListener(pl transport.PacketListener) (*PacketListenerRelay, error) { +// incoming packets. +func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTimeout time.Duration) (*PacketListenerRelay, error) { if pl == nil { return nil, errors.New("pl must not be nil") } + if writeIdleTimeout <= 0 { + return nil, errors.New("writeIdleTimeout must be greater than 0") + } r := &PacketListenerRelay{ listener: pl, - writeIdleTimeout: DefaultPacketListenerRelayWriteIdleTimeout, + writeIdleTimeout: writeIdleTimeout, } return r, nil } diff --git a/network/packetrelay/packet_listener_relay_test.go b/network/packetrelay/packet_listener_relay_test.go index 59ec1925..6e199a2b 100644 --- a/network/packetrelay/packet_listener_relay_test.go +++ b/network/packetrelay/packet_listener_relay_test.go @@ -29,14 +29,26 @@ func TestNewPacketRelayFromPacketListener(t *testing.T) { conn := &fakePacketConn{} pl := &fakePacketListener{conn: conn} - relay, err := NewPacketRelayFromPacketListener(pl) + relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) require.NotNil(t, relay) _, _, err = relay.NewAssociation() require.NoError(t, err) require.Len(t, conn.deadlines, 1) - requireDeadlineNear(t, conn.deadlines[0], DefaultPacketListenerRelayWriteIdleTimeout) + requireDeadlineNear(t, conn.deadlines[0], 30*time.Second) +} + +func TestNewPacketRelayFromPacketListenerRejectsInvalidTimeout(t *testing.T) { + pl := &fakePacketListener{conn: &fakePacketConn{}} + + relay, err := NewPacketRelayFromPacketListener(pl, 0) + require.Error(t, err) + require.Nil(t, relay) + + relay, err = NewPacketRelayFromPacketListener(pl, -1*time.Second) + require.Error(t, err) + require.Nil(t, relay) } func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) { @@ -44,7 +56,7 @@ func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) { pl := &fakePacketListener{conn: conn} timeout := 5 * time.Minute - relay, err := NewPacketRelayFromPacketListener(pl) + relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) require.NoError(t, relay.SetWriteIdleTimeout(timeout)) @@ -64,7 +76,7 @@ func TestPacketListenerRelaySetWriteIdleTimeoutRejectsInvalidTimeout(t *testing. conn := &fakePacketConn{} pl := &fakePacketListener{conn: conn} - relay, err := NewPacketRelayFromPacketListener(pl) + relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) require.Error(t, relay.SetWriteIdleTimeout(0)) } @@ -73,7 +85,7 @@ func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) { conn := &fakePacketConn{readErr: timeoutErr{}} pl := &fakePacketListener{conn: conn} - relay, err := NewPacketRelayFromPacketListener(pl) + relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second) require.NoError(t, err) _, receiver, err := relay.NewAssociation() require.NoError(t, err)