Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions network/dnsintercept/packet_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ func NewInterceptDNSPacketRelay(dnsRelay, defaultRelay packetrelay.PacketRelay,
// State machine for lazy default association initialization:
//
// stateIdle: Initial state. No default association created yet.
// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay.
// Other concurrent SendPacket calls will block waiting for this to finish.
// stateInitialized: The default association has been resolved (either success or error cached).
// Future SendPacket calls will immediately use the cached result.
// stateInitializing: A SendPacket call is currently invoking NewAssociation on the defaultRelay;
// other concurrent SendPacket calls will block waiting for this to finish.
// stateInitialized: The default association has been resolved (either success or error cached);
// future SendPacket calls will immediately use the cached result.
const (
stateIdle = iota
stateInitializing
Expand Down Expand Up @@ -295,7 +295,7 @@ var _ packetrelay.PacketSender = (*interceptSender)(nil)
// on the DNS relay, rewrites the destination, and forwards the packet.
// Otherwise, it lazily initializes and uses a single association on the default relay.
func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error {
if destination == s.a.relay.dnsLocalResolver {
if isSameAddrPort(destination, s.a.relay.dnsLocalResolver) {
return s.a.handleDNSQuery(p)
}

Expand All @@ -306,6 +306,12 @@ func (s *interceptSender) SendPacket(p []byte, destination netip.AddrPort) error
return defSender.SendPacket(p, destination)
}

// isSameAddrPort treats IPv4 and IPv4-mapped IPv6 addresses as equivalent because
// some network stacks surface IPv4 UDP destinations in IPv4-mapped form.
func isSameAddrPort(a, b netip.AddrPort) bool {
return a.Addr().Unmap() == b.Addr().Unmap() && a.Port() == b.Port()
}

// Close terminates the parent association and immediately closes all active sub-associations,
// including the default association and any pending DNS query associations.
func (s *interceptSender) Close() error {
Expand Down
56 changes: 56 additions & 0 deletions network/dnsintercept/packet_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,62 @@ func TestDNSQueryRouting(t *testing.T) {
})
}

func TestDNSQueryRoutingWithIPv4MappedLocalResolver(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
dnsRelay := &mockRelay{}
defaultRelay := &mockRelay{}
localRes := netip.MustParseAddrPort("10.0.0.1:53")
remoteRes := netip.MustParseAddrPort("8.8.8.8:53")
relay := NewInterceptDNSPacketRelay(dnsRelay, defaultRelay, localRes, remoteRes)

sender, receiver, err := relay.NewAssociation()
if err != nil {
t.Fatalf("NewAssociation failed: %v", err)
}

handler := &mockHandler{packets: make(chan packetData, 10)}
go receiver.ReceivePackets(handler)

reqData := []byte("dns_query")
mappedLocalRes := netip.MustParseAddrPort("[::ffff:10.0.0.1]:53")
if err := sender.SendPacket(reqData, mappedLocalRes); err != nil {
t.Fatalf("SendPacket failed: %v", err)
}
synctest.Wait()

if dnsRelay.newAssocCount != 1 {
t.Fatalf("Expected 1 DNS association, got %d", dnsRelay.newAssocCount)
}
if defaultRelay.newAssocCount != 0 {
t.Fatalf("Expected 0 default associations, got %d", defaultRelay.newAssocCount)
}

dnsSender := dnsRelay.senders[0]
pd := <-dnsSender.packets
if !bytes.Equal(pd.p, reqData) {
t.Errorf("Expected packet %q, got %q", reqData, pd.p)
}
if pd.dest != remoteRes {
t.Errorf("Expected dest %v, got %v", remoteRes, pd.dest)
}

respData := []byte("dns_response")
dnsSender.receiver.PushResponse(respData, remoteRes)
synctest.Wait()

resp := <-handler.packets
if !bytes.Equal(resp.p, respData) {
t.Errorf("Expected response %q, got %q", respData, resp.p)
}
if resp.dest != localRes {
t.Errorf("Expected rewritten source %v, got %v", localRes, resp.dest)
}
if !dnsSender.IsClosed() {
t.Errorf("Expected DNS sender to be closed after one response")
}
})
}

func TestDefaultRouting(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
dnsRelay := &mockRelay{}
Expand Down
17 changes: 5 additions & 12 deletions network/packet_listener_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,8 @@ type PacketListenerProxy struct {
//
// Deprecated: Use [packetrelay.NewPacketRelayFromPacketListener] instead.
func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...func(*PacketListenerProxy) error) (*PacketListenerProxy, error) {
// Create the underlying base relay
baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl)
if err != nil {
return nil, err
}

p := &PacketListenerProxy{
baseRelay: baseRelay,
writeIdleTimeout: 30 * time.Second, // Default timeout
writeIdleTimeout: 30 * time.Second,
}

// Apply options
Expand All @@ -58,12 +51,12 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu
}
}

// Build the final relay chain: TimeoutPacketRelay(PacketListenerRelay)
timeoutRelay, err := packetrelay.NewTimeoutPacketRelay(p.baseRelay, p.writeIdleTimeout)
baseRelay, err := packetrelay.NewPacketRelayFromPacketListener(pl, p.writeIdleTimeout)
if err != nil {
return nil, err
}
p.relay = timeoutRelay
p.baseRelay = baseRelay
p.relay = baseRelay

return p, nil
}
Expand All @@ -72,7 +65,7 @@ func NewPacketProxyFromPacketListener(pl transport.PacketListener, options ...fu
// This means that if there are no WriteTo operations on the UDP session created by NewSession for the specified amount
// of time, the proxy will end this session.
//
// Deprecated: Use [packetrelay.NewTimeoutPacketRelay] to decorate the underlying [packetrelay.PacketRelay] instead.
// Deprecated: Use [packetrelay.PacketListenerRelay.SetWriteIdleTimeout] instead.
func WithPacketListenerWriteIdleTimeout(timeout time.Duration) func(*PacketListenerProxy) error {
return func(p *PacketListenerProxy) error {
if timeout <= 0 {
Expand Down
20 changes: 9 additions & 11 deletions network/packetrelay/delegate_packet_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ type DelegatePacketRelay interface {

// SetRelay updates the underlying PacketRelay to `relay`; `relay` must not be nil. After this function
// returns, all new PacketRelay calls will be forwarded to the `relay`. Existing associations will not be affected.
SetRelay(relay PacketRelay) error
SetRelay(relay PacketRelay)
}

var errInvalidRelay = errors.New("the underlying relay must not be nil")
var errNoRelay = errors.New("no relay configured")

// Compilation guard against interface implementation
var _ DelegatePacketRelay = (*delegatePacketRelay)(nil)
Expand All @@ -51,24 +51,22 @@ type delegatePacketRelay struct {
// NewDelegatePacketRelay creates a new [DelegatePacketRelay] that forwards calls to the `relay` [PacketRelay].
// The `relay` must not be nil.
func NewDelegatePacketRelay(relay PacketRelay) (DelegatePacketRelay, error) {
if relay == nil {
return nil, errInvalidRelay
}
dr := delegatePacketRelay{}
dr.relay.Store(&relay)
return &dr, nil
}

// NewAssociation implements PacketRelay.NewAssociation, and it will forward the call to the underlying PacketRelay.
// Returns an error if the underlying relay is nil.
func (p *delegatePacketRelay) NewAssociation() (PacketSender, PacketReceiver, error) {
return (*p.relay.Load()).NewAssociation()
relayPtr := p.relay.Load()
if relayPtr == nil || *relayPtr == nil {
return nil, nil, errNoRelay
}
return (*relayPtr).NewAssociation()
}

// SetRelay implements DelegatePacketRelay.SetRelay.
func (p *delegatePacketRelay) SetRelay(relay PacketRelay) error {
if relay == nil {
return errInvalidRelay
}
func (p *delegatePacketRelay) SetRelay(relay PacketRelay) {
p.relay.Store(&relay)
return nil
}
26 changes: 13 additions & 13 deletions network/packetrelay/delegate_packet_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ func TestRelayCanBeUpdated(t *testing.T) {
require.Exactly(t, 0, newRelay.Count())

// SetRelay should not call NewAssociation
err = p.SetRelay(newRelay)
require.NoError(t, err)
p.SetRelay(newRelay)
require.Exactly(t, 1, defRelay.Count())
require.Exactly(t, 0, newRelay.Count())

Expand Down Expand Up @@ -84,8 +83,7 @@ func TestSetRelayRaceCondition(t *testing.T) {
setRelayTask.Add(1)
go func() {
for i := 0; !cancelSetRelay.Load(); i = (i + 1) % relaysCnt {
err := dr.SetRelay(relays[i])
require.NoError(t, err)
dr.SetRelay(relays[i])
}
setRelayTask.Done()
}()
Expand All @@ -112,20 +110,23 @@ func TestSetRelayRaceCondition(t *testing.T) {
require.Equal(t, expectedTotal, actualTotal)
}

// Make sure we cannot SetRelay to nil
// Make sure we can SetRelay to nil, which makes NewAssociation fail with errNoRelay
func TestSetRelayWithNilValue(t *testing.T) {
// must not initialize with nil
// initialization with nil does not return error, but calling NewAssociation will return errNoRelay
dr, err := NewDelegatePacketRelay(nil)
require.Error(t, err)
require.Nil(t, dr)
require.NoError(t, err)
require.NotNil(t, dr)
_, _, err = dr.NewAssociation()
require.ErrorIs(t, err, errNoRelay)

dr, err = NewDelegatePacketRelay(&sessionCountPacketRelay{})
require.NoError(t, err)
require.NotNil(t, dr)

// must not SetRelay to nil
err = dr.SetRelay(nil)
require.Error(t, err)
// SetRelay(nil) should not panic or error, but calling NewAssociation afterwards will return errNoRelay
dr.SetRelay(nil)
_, _, err = dr.NewAssociation()
require.ErrorIs(t, err, errNoRelay)
}

// Make sure we can SetRelay to different types
Expand All @@ -138,8 +139,7 @@ func TestSetRelayOfDifferentTypes(t *testing.T) {
require.NoError(t, err)

// SetRelay should not return error
err = p.SetRelay(newRelay)
require.NoError(t, err)
p.SetRelay(newRelay)

// NewAssociation should not go to defRelay
snd, rcv, err := p.NewAssociation()
Expand Down
Loading
Loading