Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions bind/shared_bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
Expand Down Expand Up @@ -44,7 +44,7 @@ func TestSharedBindCreation(t *testing.T) {

// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestSharedBindReferenceCount(t *testing.T) {
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
Expand All @@ -94,7 +94,7 @@ func TestSharedBindWriteToUDP(t *testing.T) {
defer senderBind.Close()

// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestSharedBindWriteToUDP(t *testing.T) {
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
Expand All @@ -141,7 +141,7 @@ func TestSharedBindConcurrentWrites(t *testing.T) {
defer senderBind.Close()

// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestSharedBindConcurrentWrites(t *testing.T) {

// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestSharedBindWireGuardInterface(t *testing.T) {
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
Expand All @@ -222,7 +222,7 @@ func TestSharedBindSend(t *testing.T) {
defer senderBind.Close()

// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestSharedBindSend(t *testing.T) {
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
Expand All @@ -272,7 +272,7 @@ func TestSharedBindMultipleUsers(t *testing.T) {
sharedBind.AddRef()

// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
Expand Down Expand Up @@ -360,7 +360,7 @@ func TestEndpoint(t *testing.T) {

// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
Expand Down Expand Up @@ -426,7 +426,7 @@ func TestParseEndpoint(t *testing.T) {
// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack
func TestNetstackRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
physicalConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
Expand All @@ -438,7 +438,7 @@ func TestNetstackRouting(t *testing.T) {
defer sharedBind.Close()

// Create a mock "netstack" connection (just another UDP socket for testing)
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
netstackConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
Expand All @@ -448,7 +448,7 @@ func TestNetstackRouting(t *testing.T) {
sharedBind.SetNetstackConn(netstackConn)

// Create a "client" that would receive packets
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
clientConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
Expand Down Expand Up @@ -494,7 +494,7 @@ func TestNetstackRouting(t *testing.T) {
// TestSocketRouting tests that packets from socket endpoints are routed through socket
func TestSocketRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
physicalConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
Expand All @@ -506,7 +506,7 @@ func TestSocketRouting(t *testing.T) {
defer sharedBind.Close()

// Create a mock "netstack" connection
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
netstackConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
Expand All @@ -516,7 +516,7 @@ func TestSocketRouting(t *testing.T) {
sharedBind.SetNetstackConn(netstackConn)

// Create a "client" that would receive packets (this simulates a hole-punched client)
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
clientConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)))
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
Expand Down
18 changes: 7 additions & 11 deletions clients/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
}

// Create shared UDP socket for both holepunch and WireGuard
localAddr := &net.UDPAddr{
Port: int(port),
IP: net.IPv4zero,
}
addrPort := netip.AddrPortFrom(netip.IPv4Unspecified(), port)
localAddr := net.UDPAddrFromAddrPort(addrPort)

udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
Expand Down Expand Up @@ -317,17 +315,15 @@ func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error {

s.directRelayStop = make(chan struct{})

// Parse the tunnel IP
ip := net.ParseIP(tunnelIP)
if ip == nil {
// Parse the tunnel IP using netip
addr, err := netip.ParseAddr(tunnelIP)
if err != nil {
return fmt.Errorf("invalid tunnel IP: %s", tunnelIP)
}

// Listen on the main tunnel netstack for UDP packets destined for the clients' WireGuard port
listenAddr := &net.UDPAddr{
IP: ip,
Port: int(s.Port),
}
addrPort := netip.AddrPortFrom(addr, s.Port)
listenAddr := net.UDPAddrFromAddrPort(addrPort)

// Use othertnet (main tunnel's netstack) to listen
listener, err := s.othertnet.ListenUDP(listenAddr)
Expand Down
10 changes: 6 additions & 4 deletions docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -124,14 +125,15 @@ func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int
return false, err
}

// Determine if given an IP address
var parsedTargetAddressIp = net.ParseIP(targetAddress)
// Determine if given an IP address using netip
parsedTargetAddressIp, parseErr := netip.ParseAddr(targetAddress)
isIPAddress := parseErr == nil

// If we can find the passed hostname/IP address in the networks or as the container name, it is valid and can add it
for _, c := range containers {
for _, network := range c.Networks {
// If the target address is not an IP address, use the container name
if parsedTargetAddressIp == nil {
if !isIPAddress {
if c.Name == targetAddress {
for _, port := range c.Ports {
if port.PublicPort == targetPort || port.PrivatePort == targetPort {
Expand All @@ -141,7 +143,7 @@ func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int
}
} else {
//If the IP address matches, check the ports being mapped too
if network.IPAddress == targetAddress {
if network.IPAddress == parsedTargetAddressIp.String() {
for _, port := range c.Ports {
if port.PublicPort == targetPort || port.PrivatePort == targetPort {
return true, nil
Expand Down
46 changes: 26 additions & 20 deletions holepunch/holepunch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"encoding/json"
"fmt"
"net"
"strconv"
"net/netip"
"sync"
"time"

Expand Down Expand Up @@ -287,14 +287,16 @@ func (m *Manager) TriggerHolePunch() error {
continue
}

serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
// Parse the resolved host as netip.Addr
hostAddr, err := netip.ParseAddr(host)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
logger.Error("Failed to parse resolved address %s: %v", host, err)
continue
}

if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
addrPort := netip.AddrPortFrom(hostAddr, exitNode.RelayPort)

if err := m.sendHolePunch(addrPort, exitNode.PublicKey); err != nil {
logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
continue
}
Expand Down Expand Up @@ -377,9 +379,9 @@ func (m *Manager) runMultipleExitNodes() {

// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
remoteAddrPort netip.AddrPort
publicKey string
endpointName string
}

resolveNodes := func() []resolvedExitNode {
Expand All @@ -398,19 +400,21 @@ func (m *Manager) runMultipleExitNodes() {
continue
}

serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
// Parse the resolved host as netip.Addr
hostAddr, err := netip.ParseAddr(host)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
logger.Error("Failed to parse resolved address %s: %v", host, err)
continue
}

addrPort := netip.AddrPortFrom(hostAddr, exitNode.RelayPort)

resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
remoteAddrPort: addrPort,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Debug("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
logger.Debug("Resolved exit node: %s -> %s", exitNode.Endpoint, addrPort.String())
}
return resolvedNodes
}
Expand All @@ -422,7 +426,7 @@ func (m *Manager) runMultipleExitNodes() {
} else {
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
Expand Down Expand Up @@ -457,15 +461,15 @@ func (m *Manager) runMultipleExitNodes() {
ticker.Reset(m.sendHolepunchInterval)
// Send immediate hole punch to newly resolved nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
case <-ticker.C:
// Send hole punch to all exit nodes (if any are available)
if len(resolvedNodes) > 0 {
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
Expand All @@ -487,7 +491,7 @@ func (m *Manager) runMultipleExitNodes() {
}

// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
func (m *Manager) sendHolePunch(remoteAddrPort netip.AddrPort, serverPubKey string) error {
m.mu.Lock()
token := m.token
ID := m.ID
Expand Down Expand Up @@ -537,12 +541,14 @@ func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) er
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}

// Convert netip.AddrPort to *net.UDPAddr for SharedBind interface
remoteAddr := net.UDPAddrFromAddrPort(remoteAddrPort)
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}

logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddrPort.String(), string(jsonData))

return nil
}
Expand Down
Loading