Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions network/packetrelay/packet_listener_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ var _ PacketReceiver = (*packetListenerReceiver)(nil)
// PacketListenerRelay creates a new [PacketRelay] that uses the existing [transport.PacketListener] to
// create connections to a relay.
type PacketListenerRelay struct {
mu sync.RWMutex
listener transport.PacketListener
writeIdleTimeout time.Duration
}
Expand All @@ -52,9 +51,7 @@ type PacketListenerRelay struct {
// create connections to a relay.
// This function is useful if you already have an implementation of [transport.PacketListener] and you want to use it
// with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler.
//
// Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by
// incoming packets.
// The writeIdleTimeout is reset only by [PacketSender.SendPacket], not by incoming packets.
func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTimeout time.Duration) (*PacketListenerRelay, error) {
if pl == nil {
return nil, errors.New("pl must not be nil")
Expand All @@ -69,34 +66,17 @@ func NewPacketRelayFromPacketListener(pl transport.PacketListener, writeIdleTime
return r, nil
}

// SetWriteIdleTimeout sets the write-idle timeout for new associations.
// Existing associations keep the timeout they were created with.
func (relay *PacketListenerRelay) SetWriteIdleTimeout(timeout time.Duration) error {
if timeout <= 0 {
return errors.New("timeout must be greater than 0")
}
relay.mu.Lock()
defer relay.mu.Unlock()
relay.writeIdleTimeout = timeout
return nil
}

// NewAssociation implements [PacketRelay].NewAssociation. It uses [transport.PacketListener].ListenPacket to create
// a [net.PacketConn], and returns a [PacketSender] and [PacketReceiver] based on this [net.PacketConn].
func (relay *PacketListenerRelay) NewAssociation() (PacketSender, PacketReceiver, error) {
relay.mu.RLock()
listener := relay.listener
writeIdleTimeout := relay.writeIdleTimeout
relay.mu.RUnlock()

packetConn, err := listener.ListenPacket(context.Background())
packetConn, err := relay.listener.ListenPacket(context.Background())
if err != nil {
return nil, nil, err
}

association := &packetListenerAssociation{
packetConn: packetConn,
writeIdleTimeout: writeIdleTimeout,
writeIdleTimeout: relay.writeIdleTimeout,
}
if err := association.refreshDeadline(); err != nil {
_ = association.close()
Expand Down
94 changes: 66 additions & 28 deletions network/packetrelay/packet_listener_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,48 +51,56 @@ func TestNewPacketRelayFromPacketListenerRejectsInvalidTimeout(t *testing.T) {
require.Nil(t, relay)
}

func TestPacketListenerRelaySetWriteIdleTimeout(t *testing.T) {
conn := &fakePacketConn{}
func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) {
conn := &fakePacketConn{readErr: timeoutErr{}}
pl := &fakePacketListener{conn: conn}
timeout := 5 * time.Minute

relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second)
require.NoError(t, err)
require.NoError(t, relay.SetWriteIdleTimeout(timeout))

sender, _, err := relay.NewAssociation()
_, receiver, err := relay.NewAssociation()
require.NoError(t, err)
require.Len(t, conn.deadlines, 1)
requireDeadlineNear(t, conn.deadlines[0], timeout)

err = sender.SendPacket([]byte("hello"), netip.MustParseAddrPort("1.2.3.4:53"))
require.NoError(t, err)
require.Len(t, conn.deadlines, 2)
requireDeadlineNear(t, conn.deadlines[1], timeout)
require.Len(t, conn.writes, 1)
err = receiver.ReceivePackets(&mockPacketHandler{})
require.ErrorAs(t, err, &timeoutErr{})
require.True(t, conn.isClosed())
}

func TestPacketListenerRelaySetWriteIdleTimeoutRejectsInvalidTimeout(t *testing.T) {
conn := &fakePacketConn{}
pl := &fakePacketListener{conn: conn}

relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second)
func TestPacketListenerRelaySendCloseRace(t *testing.T) {
relay, err := NewPacketRelayFromPacketListener(&fakePacketListener{conn: &fakePacketConn{}}, time.Second)
require.NoError(t, err)
sender, _, err := relay.NewAssociation()
require.NoError(t, err)
require.Error(t, relay.SetWriteIdleTimeout(0))
}

func TestPacketListenerRelayReceiveTimeoutClosesAssociation(t *testing.T) {
conn := &fakePacketConn{readErr: timeoutErr{}}
pl := &fakePacketListener{conn: conn}
start := make(chan struct{})
var wg sync.WaitGroup
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
<-start
_ = sender.SendPacket([]byte("x"), netip.MustParseAddrPort("1.2.3.4:53"))
}()
}
close(start)
_ = sender.Close()
wg.Wait()
}

relay, err := NewPacketRelayFromPacketListener(pl, 30*time.Second)
func TestPacketListenerRelayReceiveCloseRace(t *testing.T) {
conn := newBlockingPacketConn()
relay, err := NewPacketRelayFromPacketListener(&connListener{conn: conn}, time.Second)
require.NoError(t, err)
_, receiver, err := relay.NewAssociation()
sender, receiver, err := relay.NewAssociation()
require.NoError(t, err)

err = receiver.ReceivePackets(&mockPacketHandler{})
require.ErrorAs(t, err, &timeoutErr{})
require.True(t, conn.isClosed())
done := make(chan struct{})
go func() {
defer close(done)
_ = receiver.ReceivePackets(&mockPacketHandler{})
}()

_ = sender.Close()
<-done
}

func requireDeadlineNear(t *testing.T, deadline time.Time, timeout time.Duration) {
Expand Down Expand Up @@ -169,6 +177,36 @@ func (c *fakePacketConn) isClosed() bool {
return c.closed
}

// blockingPacketConn blocks ReadFrom until Close is called, simulating a real UDP conn.
type blockingPacketConn struct {
fakePacketConn
once sync.Once
closed chan struct{}
}

func newBlockingPacketConn() *blockingPacketConn {
return &blockingPacketConn{closed: make(chan struct{})}
}

func (c *blockingPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
<-c.closed
return 0, nil, net.ErrClosed
}

func (c *blockingPacketConn) Close() error {
c.once.Do(func() { close(c.closed) })
return c.fakePacketConn.Close()
}

// connListener wraps a net.PacketConn as a transport.PacketListener.
type connListener struct {
conn net.PacketConn
}

func (l *connListener) ListenPacket(_ context.Context) (net.PacketConn, error) {
return l.conn, nil
}

type timeoutErr struct{}

func (timeoutErr) Error() string {
Expand Down
Loading