Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions network/dnsintercept/packet_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ func NewInterceptDNSPacketRelay(dnsRelay, defaultRelay packetrelay.PacketRelay,
// State machine for lazy default association initialization:
//
// stateIdle: Initial state. No default association created yet.
// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay.
// Other concurrent SendPacket calls will block waiting for this to finish.
// stateInitialized: The default association has been resolved (either success or error cached).
// Future SendPacket calls will immediately use the cached result.
// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay;
// other concurrent SendPacket calls will block waiting for this to finish.
// stateInitialized: The default association has been resolved (either success or error cached);
// future SendPacket calls will immediately use the cached result.
const (
stateIdle = iota
stateInitializing
Expand Down Expand Up @@ -295,7 +295,7 @@ var _ packetrelay.PacketSender = (*interceptSender)(nil)
// on the DNS relay, rewrites the destination, and forwards the packet.
// Otherwise, it lazily initializes and uses a single association on the default relay.
func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error {
if destination == s.a.relay.dnsLocalResolver {
if isSameAddrPort(destination, s.a.relay.dnsLocalResolver) {
return s.a.handleDNSQuery(p)
}

Expand All @@ -306,6 +306,12 @@ func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error
return defSender.SendPacket(p, destination)
}

// isSameAddrPort treats IPv4 and IPv4-mapped IPv6 addresses as equivalent because
// some network stacks surface IPv4 UDP destinations in IPv4-mapped form.
func isSameAddrPort(a, b netip.AddrPort) bool {
return a.Addr().Unmap() == b.Addr().Unmap() && a.Port() == b.Port()
}

// Close terminates the parent association and immediately closes all active sub-associations,
// including the default association and any pending DNS query associations.
func (s *interceptSender) Close() error {
Expand Down
56 changes: 56 additions & 0 deletions network/dnsintercept/packet_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,62 @@ func TestDNSQueryRouting(t *testing.T) {
})
}

func TestDNSQueryRoutingWithIPv4MappedLocalResolver(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
dnsRelay := &mockRelay{}
defaultRelay := &mockRelay{}
localRes := netip.MustParseAddrPort("10.0.0.1:53")
remoteRes := netip.MustParseAddrPort("8.8.8.8:53")
relay := NewInterceptDNSPacketRelay(dnsRelay, defaultRelay, localRes, remoteRes)

sender, receiver, err := relay.NewAssociation()
if err != nil {
t.Fatalf("NewAssociation failed: %v", err)
}

handler := &mockHandler{packets: make(chan packetData, 10)}
go receiver.ReceivePackets(handler)

reqData := []byte("dns_query")
mappedLocalRes := netip.MustParseAddrPort("[::ffff:10.0.0.1]:53")
if err := sender.SendPacket(reqData, mappedLocalRes); err != nil {
t.Fatalf("SendPacket failed: %v", err)
}
synctest.Wait()

if dnsRelay.newAssocCount != 1 {
t.Fatalf("Expected 1 DNS association, got %d", dnsRelay.newAssocCount)
}
if defaultRelay.newAssocCount != 0 {
t.Fatalf("Expected 0 default associations, got %d", defaultRelay.newAssocCount)
}

dnsSender := dnsRelay.senders[0]
pd := <-dnsSender.packets
if !bytes.Equal(pd.p, reqData) {
t.Errorf("Expected packet %q, got %q", reqData, pd.p)
}
if pd.dest != remoteRes {
t.Errorf("Expected dest %v, got %v", remoteRes, pd.dest)
}

respData := []byte("dns_response")
dnsSender.receiver.PushResponse(respData, remoteRes)
synctest.Wait()

resp := <-handler.packets
if !bytes.Equal(resp.p, respData) {
t.Errorf("Expected response %q, got %q", respData, resp.p)
}
if resp.dest != localRes {
t.Errorf("Expected rewritten source %v, got %v", localRes, resp.dest)
}
if !dnsSender.IsClosed() {
t.Errorf("Expected DNS sender to be closed after one response")
}
})
}

func TestDefaultRouting(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
dnsRelay := &mockRelay{}
Expand Down
20 changes: 8 additions & 12 deletions network/packet_listener_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,8 @@ type PacketListenerProxy struct {
//
// Deprecated: Use [packetrelay.NewPacketRelayFromPacketListener] instead.
func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerProxy) error) (*PacketListenerProxy, error) {
// Create the underlying base relay
baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl)
if err != nil {
return nil, err
}

p := &PacketListenerProxy{
baseRelay: baseRelay,
writeIdleTimeout: 30 * time.Second, // Default timeout
writeIdleTimeout: packetrelay.DefaultPacketListenerRelayWriteIdleTimeout,
}

// Apply options
Expand All @@ -58,12 +51,15 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu
}
}

// Build the final relay chain: TimeoutPacketRelay(PacketListenerRelay)
timeoutRelay, err := packetrelay.NewTimeoutPacketRelay(p.baseRelay, p.writeIdleTimeout)
baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl)
if err != nil {
return nil, err
}
p.relay = timeoutRelay
if err := baseRelay.SetWriteIdleTimeout(p.writeIdleTimeout); err != nil {
return nil, err
}
p.baseRelay = baseRelay
p.relay = baseRelay

return p, nil
}
Expand All @@ -72,7 +68,7 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu
// This means that if there are no WriteTo operations on the UDP session created by NewSession for the specified amount
// of time, the proxy will end this session.
//
// Deprecated: Use [packetrelay.NewTimeoutPacketRelay] to decorate the underlying [packetrelay.PacketRelay] instead.
// Deprecated: Use [packetrelay.PacketListenerRelay.SetWriteIdleTimeout] instead.
func WithPacketListenerWriteIdleTimeout(timeout time.Duration) func(*PacketListenerProxy) error {
return func(p *PacketListenerProxy) error {
if timeout <= 0 {
Expand Down
133 changes: 98 additions & 35 deletions network/packetrelay/packet_listener_relay.go
Comment thread
fortuna marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net"
"net/netip"
"sync"
"time"

"golang.getoutline.org/sdk/internal/slicepool"
"golang.getoutline.org/sdk/transport"
Expand All @@ -34,6 +35,10 @@ const packetMaxSize = 2048
// packetBufferPool is used to create buffers to read UDP response packets
var packetBufferPool = slicepool.MakePool(packetMaxSize)

// DefaultPacketListenerRelayWriteIdleTimeout is the default write-idle timeout for
// associations created by [PacketListenerRelay].
const DefaultPacketListenerRelayWriteIdleTimeout = 30 * time.Second

// Compilation guard against interface implementation
var _ PacketRelay = (*PacketListenerRelay)(nil)
var _ PacketSender = (*packetListenerSender)(nil)
Expand All @@ -42,87 +47,144 @@ var _ PacketReceiver = (*packetListenerReceiver)(nil)
// PacketListenerRelay creates a new [PacketRelay] that uses the existing [transport.PacketListener] to
// create connections to a relay.
type PacketListenerRelay struct {
listener transport.PacketListener
mu sync.RWMutex
listener transport.PacketListener
writeIdleTimeout time.Duration
}

// NewPacketRelayFromPacketListener creates a new [PacketRelay] that uses the existing [transport.PacketListener] to
// create connections to a relay. You can also specify additional options.
// create connections to a relay.
// This function is useful if you already have an implementation of [transport.PacketListener] and you want to use it
// with one of the network stacks (for example, network/lwip2transport) as a UDP traffic handler.
func NewPacketRelayFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerRelay) error) (*PacketListenerRelay, error) {
//
// Associations use a write-idle timeout that is reset only by [PacketSender.SendPacket], not by
// incoming packets. The default timeout is [DefaultPacketListenerRelayWriteIdleTimeout].
func NewPacketRelayFromPacketListener(pl transport.PacketListener) (*PacketListenerRelay, error) {
if pl == nil {
return nil, errors.New("pl must not be nil")
}
r := &PacketListenerRelay{
listener: pl,
}
for _, opt := range options {
if err := opt(r); err != nil {
return nil, err
}
listener: pl,
writeIdleTimeout: DefaultPacketListenerRelayWriteIdleTimeout,
}
return r, nil
}

// SetWriteIdleTimeout sets the write-idle timeout for new associations.
// Existing associations keep the timeout they were created with.
func (relay *PacketListenerRelay) SetWriteIdleTimeout(timeout time.Duration) error {
if timeout <= 0 {
return errors.New("timeout must be greater than 0")
}
relay.mu.Lock()
defer relay.mu.Unlock()
relay.writeIdleTimeout = timeout
return nil
}

// NewAssociation implements [PacketRelay].NewAssociation. It uses [transport.PacketListener].ListenPacket to create
// a [net.PacketConn], and returns a [PacketSender] and [PacketReceiver] based on this [net.PacketConn].
func (relay *PacketListenerRelay) NewAssociation() (PacketSender, PacketReceiver, error) {
packetConn, err := relay.listener.ListenPacket(context.Background())
relay.mu.RLock()
listener := relay.listener
writeIdleTimeout := relay.writeIdleTimeout
relay.mu.RUnlock()

packetConn, err := listener.ListenPacket(context.Background())
if err != nil {
return nil, nil, err
}

association := &packetListenerAssociation{
packetConn: packetConn,
writeIdleTimeout: writeIdleTimeout,
}
if err := association.refreshDeadline(); err != nil {
_ = association.close()
return nil, nil, err
}

sender := &packetListenerSender{
packetConn: packetConn,
association: association,
}

receiver := &packetListenerReceiver{
packetConn: packetConn,
association: association,
packetConn: packetConn,
}

return sender, receiver, nil
}

type packetListenerSender struct {
mu sync.Mutex // Protects closed flag
closed bool
type packetListenerAssociation struct {
mu sync.Mutex
closed bool
packetConn net.PacketConn
writeIdleTimeout time.Duration
}

func (a *packetListenerAssociation) refreshDeadline() error {
a.mu.Lock()
defer a.mu.Unlock()

if a.closed {
return ErrClosed
}
return a.packetConn.SetDeadline(time.Now().Add(a.writeIdleTimeout))
}

func (a *packetListenerAssociation) getPacketConn() (net.PacketConn, error) {
a.mu.Lock()
defer a.mu.Unlock()

if a.closed {
return nil, ErrClosed
}
return a.packetConn, nil
}

packetConn net.PacketConn
func (a *packetListenerAssociation) close() error {
a.mu.Lock()
if a.closed {
a.mu.Unlock()
return ErrClosed
}
a.closed = true
packetConn := a.packetConn
a.packetConn = nil
a.mu.Unlock()

return packetConn.Close()
}

type packetListenerSender struct {
association *packetListenerAssociation
}

// SendPacket implements [PacketSender].SendPacket. It simply forwards the packet to the underlying
// [net.PacketConn].WriteTo function.
func (s *packetListenerSender) SendPacket(p []byte, destination netip.AddrPort) error {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return ErrClosed
if err := s.association.refreshDeadline(); err != nil {
return err
}
packetConn, err := s.association.getPacketConn()
if err != nil {
return err
}
packetConn := s.packetConn
s.mu.Unlock()

_, err := packetConn.WriteTo(p, net.UDPAddrFromAddrPort(destination))
_, err = packetConn.WriteTo(p, net.UDPAddrFromAddrPort(destination))
return err
}

// Close implements [PacketSender].Close. It closes the underlying [net.PacketConn]. This will also
// terminate the blocking loop in ReceivePackets because s.packetConn.ReadFrom will return an error.
func (s *packetListenerSender) Close() error {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return ErrClosed
}
s.closed = true
packetConn := s.packetConn
s.packetConn = nil
s.mu.Unlock()

return packetConn.Close()
return s.association.close()
}

type packetListenerReceiver struct {
packetConn net.PacketConn
association *packetListenerAssociation
packetConn net.PacketConn
}

// ReceivePackets implements [PacketReceiver].ReceivePackets. It blocks and passes incoming packets
Expand All @@ -139,6 +201,7 @@ func (r *packetListenerReceiver) ReceivePackets(handler PacketHandler) error {
if errors.Is(err, io.ErrShortBuffer) {
continue
}
_ = r.association.close()
return err
}

Expand Down
Loading
Loading