diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go index e463f88..695d941 100644 --- a/bind/shared_bind_test.go +++ b/bind/shared_bind_test.go @@ -15,7 +15,7 @@ import ( // TestSharedBindCreation tests basic creation and initialization func TestSharedBindCreation(t *testing.T) { // Create a UDP connection - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } @@ -44,7 +44,7 @@ func TestSharedBindCreation(t *testing.T) { // TestSharedBindReferenceCount tests reference counting func TestSharedBindReferenceCount(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } @@ -82,7 +82,7 @@ func TestSharedBindReferenceCount(t *testing.T) { // TestSharedBindWriteToUDP tests the WriteToUDP functionality func TestSharedBindWriteToUDP(t *testing.T) { // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } @@ -94,7 +94,7 @@ func TestSharedBindWriteToUDP(t *testing.T) { defer senderBind.Close() // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } @@ -129,7 +129,7 @@ func TestSharedBindWriteToUDP(t *testing.T) { // TestSharedBindConcurrentWrites tests thread-safety func TestSharedBindConcurrentWrites(t *testing.T) { // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } @@ -141,7 +141,7 @@ func TestSharedBindConcurrentWrites(t *testing.T) { defer senderBind.Close() // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } @@ -170,7 +170,7 @@ func TestSharedBindConcurrentWrites(t *testing.T) { // TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation func TestSharedBindWireGuardInterface(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } @@ -210,7 +210,7 @@ func TestSharedBindWireGuardInterface(t *testing.T) { // TestSharedBindSend tests the Send method with WireGuard endpoints func TestSharedBindSend(t *testing.T) { // Create sender - senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + senderConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } @@ -222,7 +222,7 @@ func TestSharedBindSend(t *testing.T) { defer senderBind.Close() // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } @@ -258,7 +258,7 @@ func TestSharedBindSend(t *testing.T) { // TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind func TestSharedBindMultipleUsers(t *testing.T) { // Create shared bind - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } @@ -272,7 +272,7 @@ func TestSharedBindMultipleUsers(t *testing.T) { sharedBind.AddRef() // Create receiver - receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + receiverConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } @@ -360,7 +360,7 @@ func TestEndpoint(t *testing.T) { // TestParseEndpoint tests the ParseEndpoint method func TestParseEndpoint(t *testing.T) { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } @@ -426,7 +426,7 @@ func TestParseEndpoint(t *testing.T) { // TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack func TestNetstackRouting(t *testing.T) { // Create the SharedBind with a physical UDP socket - physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + physicalConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create physical UDP connection: %v", err) } @@ -438,7 +438,7 @@ func TestNetstackRouting(t *testing.T) { defer sharedBind.Close() // Create a mock "netstack" connection (just another UDP socket for testing) - netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + netstackConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create netstack UDP connection: %v", err) } @@ -448,7 +448,7 @@ func TestNetstackRouting(t *testing.T) { sharedBind.SetNetstackConn(netstackConn) // Create a "client" that would receive packets - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + clientConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create client UDP connection: %v", err) } @@ -494,7 +494,7 @@ func TestNetstackRouting(t *testing.T) { // TestSocketRouting tests that packets from socket endpoints are routed through socket func TestSocketRouting(t *testing.T) { // Create the SharedBind with a physical UDP socket - physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + physicalConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create physical UDP connection: %v", err) } @@ -506,7 +506,7 @@ func TestSocketRouting(t *testing.T) { defer sharedBind.Close() // Create a mock "netstack" connection - netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + netstackConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create netstack UDP connection: %v", err) } @@ -516,7 +516,7 @@ func TestSocketRouting(t *testing.T) { sharedBind.SetNetstackConn(netstackConn) // Create a "client" that would receive packets (this simulates a hole-punched client) - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + clientConn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))) if err != nil { t.Fatalf("Failed to create client UDP connection: %v", err) } diff --git a/clients/clients.go b/clients/clients.go index 05ed3cf..3134b02 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -124,10 +124,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string } // Create shared UDP socket for both holepunch and WireGuard - localAddr := &net.UDPAddr{ - Port: int(port), - IP: net.IPv4zero, - } + addrPort := netip.AddrPortFrom(netip.IPv4Unspecified(), port) + localAddr := net.UDPAddrFromAddrPort(addrPort) udpConn, err := net.ListenUDP("udp", localAddr) if err != nil { @@ -317,17 +315,15 @@ func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error { s.directRelayStop = make(chan struct{}) - // Parse the tunnel IP - ip := net.ParseIP(tunnelIP) - if ip == nil { + // Parse the tunnel IP using netip + addr, err := netip.ParseAddr(tunnelIP) + if err != nil { return fmt.Errorf("invalid tunnel IP: %s", tunnelIP) } // Listen on the main tunnel netstack for UDP packets destined for the clients' WireGuard port - listenAddr := &net.UDPAddr{ - IP: ip, - Port: int(s.Port), - } + addrPort := netip.AddrPortFrom(addr, s.Port) + listenAddr := net.UDPAddrFromAddrPort(addrPort) // Use othertnet (main tunnel's netstack) to listen listener, err := s.othertnet.ListenUDP(listenAddr) diff --git a/docker/docker.go b/docker/docker.go index 281c594..18a7e03 100644 --- a/docker/docker.go +++ b/docker/docker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "os" "strconv" "strings" @@ -124,14 +125,15 @@ func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int return false, err } - // Determine if given an IP address - var parsedTargetAddressIp = net.ParseIP(targetAddress) + // Determine if given an IP address using netip + parsedTargetAddressIp, parseErr := netip.ParseAddr(targetAddress) + isIPAddress := parseErr == nil // If we can find the passed hostname/IP address in the networks or as the container name, it is valid and can add it for _, c := range containers { for _, network := range c.Networks { // If the target address is not an IP address, use the container name - if parsedTargetAddressIp == nil { + if !isIPAddress { if c.Name == targetAddress { for _, port := range c.Ports { if port.PublicPort == targetPort || port.PrivatePort == targetPort { @@ -141,7 +143,7 @@ func IsWithinHostNetwork(socketPath string, targetAddress string, targetPort int } } else { //If the IP address matches, check the ports being mapped too - if network.IPAddress == targetAddress { + if network.IPAddress == parsedTargetAddressIp.String() { for _, port := range c.Ports { if port.PublicPort == targetPort || port.PrivatePort == targetPort { return true, nil diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 85679a9..e0c0fa5 100644 --- a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "net" - "strconv" + "net/netip" "sync" "time" @@ -287,14 +287,16 @@ func (m *Manager) TriggerHolePunch() error { continue } - serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort))) - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + // Parse the resolved host as netip.Addr + hostAddr, err := netip.ParseAddr(host) if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + logger.Error("Failed to parse resolved address %s: %v", host, err) continue } - if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil { + addrPort := netip.AddrPortFrom(hostAddr, exitNode.RelayPort) + + if err := m.sendHolePunch(addrPort, exitNode.PublicKey); err != nil { logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err) continue } @@ -377,9 +379,9 @@ func (m *Manager) runMultipleExitNodes() { // Resolve all endpoints upfront type resolvedExitNode struct { - remoteAddr *net.UDPAddr - publicKey string - endpointName string + remoteAddrPort netip.AddrPort + publicKey string + endpointName string } resolveNodes := func() []resolvedExitNode { @@ -398,19 +400,21 @@ func (m *Manager) runMultipleExitNodes() { continue } - serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort))) - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + // Parse the resolved host as netip.Addr + hostAddr, err := netip.ParseAddr(host) if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + logger.Error("Failed to parse resolved address %s: %v", host, err) continue } + addrPort := netip.AddrPortFrom(hostAddr, exitNode.RelayPort) + resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, + remoteAddrPort: addrPort, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, }) - logger.Debug("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + logger.Debug("Resolved exit node: %s -> %s", exitNode.Endpoint, addrPort.String()) } return resolvedNodes } @@ -422,7 +426,7 @@ func (m *Manager) runMultipleExitNodes() { } else { // Send initial hole punch to all exit nodes for _, node := range resolvedNodes { - if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil { logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err) } } @@ -457,7 +461,7 @@ func (m *Manager) runMultipleExitNodes() { ticker.Reset(m.sendHolepunchInterval) // Send immediate hole punch to newly resolved nodes for _, node := range resolvedNodes { - if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil { logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) } } @@ -465,7 +469,7 @@ func (m *Manager) runMultipleExitNodes() { // Send hole punch to all exit nodes (if any are available) if len(resolvedNodes) > 0 { for _, node := range resolvedNodes { - if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + if err := m.sendHolePunch(node.remoteAddrPort, node.publicKey); err != nil { logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) } } @@ -487,7 +491,7 @@ func (m *Manager) runMultipleExitNodes() { } // sendHolePunch sends an encrypted hole punch packet using the shared bind -func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { +func (m *Manager) sendHolePunch(remoteAddrPort netip.AddrPort, serverPubKey string) error { m.mu.Lock() token := m.token ID := m.ID @@ -537,12 +541,14 @@ func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) er return fmt.Errorf("failed to marshal encrypted payload: %w", err) } + // Convert netip.AddrPort to *net.UDPAddr for SharedBind interface + remoteAddr := net.UDPAddrFromAddrPort(remoteAddrPort) _, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr) if err != nil { return fmt.Errorf("failed to write to UDP: %w", err) } - logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddrPort.String(), string(jsonData)) return nil } diff --git a/holepunch/tester.go b/holepunch/tester.go index 9fb83df..5306c86 100644 --- a/holepunch/tester.go +++ b/holepunch/tester.go @@ -43,7 +43,7 @@ func DefaultTestOptions() TestConnectionOptions { // cachedAddr holds a cached resolved UDP address type cachedAddr struct { - addr *net.UDPAddr + addrPort netip.AddrPort resolvedAt time.Time } @@ -157,7 +157,7 @@ func (t *HolepunchTester) Stop() { } // resolveEndpoint resolves an endpoint to a UDP address, using cache when possible -func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) { +func (t *HolepunchTester) resolveEndpoint(endpoint string) (netip.AddrPort, error) { // Check cache first t.addrCacheMu.RLock() cached, ok := t.addrCache[endpoint] @@ -165,7 +165,7 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) t.addrCacheMu.RUnlock() if ok && time.Since(cached.resolvedAt) < ttl { - return cached.addr, nil + return cached.addrPort, nil } // Resolve the endpoint @@ -174,25 +174,26 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) host = endpoint } - _, _, err = net.SplitHostPort(host) + // Parse as AddrPort - if it fails, try adding default port + addrPort, err := netip.ParseAddrPort(host) if err != nil { - host = net.JoinHostPort(host, "21820") - } - - remoteAddr, err := net.ResolveUDPAddr("udp", host) - if err != nil { - return nil, fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + // Try parsing just the address and add default port + addr, addrErr := netip.ParseAddr(host) + if addrErr != nil { + return netip.AddrPort{}, fmt.Errorf("failed to parse address %s: %w", host, addrErr) + } + addrPort = netip.AddrPortFrom(addr, 21820) } // Cache the result t.addrCacheMu.Lock() t.addrCache[endpoint] = &cachedAddr{ - addr: remoteAddr, + addrPort: addrPort, resolvedAt: time.Now(), } t.addrCacheMu.Unlock() - return remoteAddr, nil + return addrPort, nil } // InvalidateCache removes a specific endpoint from the address cache @@ -255,7 +256,7 @@ func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) T } // Resolve the endpoint (using cache) - remoteAddr, err := t.resolveEndpoint(endpoint) + remoteAddrPort, err := t.resolveEndpoint(endpoint) if err != nil { result.Error = err return result @@ -283,7 +284,8 @@ func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) T copy(request, bind.MagicTestRequest) copy(request[len(bind.MagicTestRequest):], randomData) - // Send the test packet + // Send the test packet - convert netip.AddrPort to *net.UDPAddr for SharedBind + remoteAddr := net.UDPAddrFromAddrPort(remoteAddrPort) _, err = sharedBind.WriteToUDP(request, remoteAddr) if err != nil { t.pendingRequests.Delete(key) @@ -337,15 +339,16 @@ func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts * host = endpoint } - _, _, err = net.SplitHostPort(host) + // Parse as AddrPort - if it fails, try adding default port + remoteAddrPort, err := netip.ParseAddrPort(host) if err != nil { - host = net.JoinHostPort(host, "21820") - } - - remoteAddr, err := net.ResolveUDPAddr("udp", host) - if err != nil { - result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) - return result + // Try parsing just the address and add default port + addr, addrErr := netip.ParseAddr(host) + if addrErr != nil { + result.Error = fmt.Errorf("failed to parse address %s: %w", host, addrErr) + return result + } + remoteAddrPort = netip.AddrPortFrom(addr, 21820) } // Generate random data for the test packet @@ -367,6 +370,9 @@ func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts * return result } + // Convert netip.AddrPort to *net.UDPAddr for SharedBind + remoteAddr := net.UDPAddrFromAddrPort(remoteAddrPort) + attempts := opts.Retries + 1 for attempt := 0; attempt < attempts; attempt++ { if attempt > 0 { diff --git a/netstack2/handlers.go b/netstack2/handlers.go index 75c58b2..0492f40 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -540,7 +540,12 @@ func (h *ICMPHandler) sendAndReceiveICMP(conn *icmp.PacketConn, actualDstIP stri var writeErr error if isUnprivileged { // For unprivileged ICMP, use UDP-style addressing - udpAddr := &net.UDPAddr{IP: net.ParseIP(actualDstIP)} + dstAddr, err := netip.ParseAddr(actualDstIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse destination %s: %v", actualDstIP, err) + return false + } + udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(dstAddr, 0)) logger.Debug("ICMP Handler: Sending ping to %s (unprivileged)", udpAddr.String()) conn.SetDeadline(time.Now().Add(icmpTimeout)) _, writeErr = conn.WriteTo(msgBytes, udpAddr) diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 388a3d1..4876e42 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -253,8 +253,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro // Use the first resolved IP address ip := ips[0] - if ip4 := ip.To4(); ip4 != nil { - addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}) + if addr, ok := netip.AddrFromSlice(ip); ok && addr.Is4() { logger.Debug("Resolved %s to %s", rewriteTo, addr) return addr, nil } diff --git a/netstack2/tun.go b/netstack2/tun.go index e743f1e..2c1b5f7 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -284,7 +284,10 @@ func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.T if addr == nil { return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - ip, _ := netip.AddrFromSlice(addr.IP) + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok || !ip.IsValid() { + return nil, fmt.Errorf("invalid TCP address IP: %v", addr.IP) + } return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) } @@ -297,7 +300,10 @@ func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialTCPAddrPort(netip.AddrPort{}) } - ip, _ := netip.AddrFromSlice(addr.IP) + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok || !ip.IsValid() { + return nil, fmt.Errorf("invalid TCP address IP: %v", addr.IP) + } return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } @@ -310,7 +316,10 @@ func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { return net.ListenTCPAddrPort(netip.AddrPort{}) } - ip, _ := netip.AddrFromSlice(addr.IP) + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok || !ip.IsValid() { + return nil, fmt.Errorf("invalid TCP address IP: %v", addr.IP) + } return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } @@ -337,11 +346,17 @@ func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { var la, ra netip.AddrPort if laddr != nil { - ip, _ := netip.AddrFromSlice(laddr.IP) + ip, ok := netip.AddrFromSlice(laddr.IP) + if !ok || !ip.IsValid() { + return nil, fmt.Errorf("invalid UDP local address IP: %v", laddr.IP) + } la = netip.AddrPortFrom(ip, uint16(laddr.Port)) } if raddr != nil { - ip, _ := netip.AddrFromSlice(raddr.IP) + ip, ok := netip.AddrFromSlice(raddr.IP) + if !ok || !ip.IsValid() { + return nil, fmt.Errorf("invalid UDP remote address IP: %v", raddr.IP) + } ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) } return net.DialUDPAddrPort(la, ra) @@ -505,7 +520,11 @@ func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { case *PingAddr: na = v.addr case *net.IPAddr: - na, _ = netip.AddrFromSlice(v.IP) + var ok bool + na, ok = netip.AddrFromSlice(v.IP) + if !ok || !na.IsValid() { + return 0, fmt.Errorf("ping write: invalid IP address: %v", v.IP) + } default: return 0, fmt.Errorf("ping write: wrong net.Addr type") } @@ -550,8 +569,11 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) - return res.Count, &PingAddr{remoteAddr}, nil + remoteAddr, ok := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) + if !ok || !remoteAddr.IsValid() { + return int(res.Count), nil, fmt.Errorf("ping read: invalid remote address from stack") + } + return int(res.Count), &PingAddr{remoteAddr}, nil } func (pc *PingConn) Read(p []byte) (n int, err error) { diff --git a/network/interface.go b/network/interface.go index 70556be..a97410d 100644 --- a/network/interface.go +++ b/network/interface.go @@ -3,6 +3,7 @@ package network import ( "fmt" "net" + "net/netip" "os/exec" "regexp" "runtime" @@ -17,15 +18,15 @@ import ( func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { logger.Info("The tunnel IP is: %s", tunnelIp) - // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(tunnelIp) + // Parse the IP address and network using netip + prefix, err := netip.ParsePrefix(tunnelIp) if err != nil { return fmt.Errorf("invalid IP address: %v", err) } // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) - mask := net.IP(ipNet.Mask).String() - destinationAddress := ip.String() + mask := prefixToMaskString(prefix) + destinationAddress := prefix.Addr().String() logger.Debug("The destination address is: %s", destinationAddress) @@ -39,11 +40,11 @@ func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { switch runtime.GOOS { case "linux": - return configureLinux(interfaceName, ip, ipNet) + return configureLinux(interfaceName, prefix) case "darwin": - return configureDarwin(interfaceName, ip, ipNet) + return configureDarwin(interfaceName, prefix) case "windows": - return configureWindows(interfaceName, ip, ipNet) + return configureWindows(interfaceName, prefix) case "android": return nil case "ios": @@ -53,8 +54,19 @@ func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { return nil } +// prefixToMaskString converts a netip.Prefix to a dotted decimal subnet mask string +func prefixToMaskString(prefix netip.Prefix) string { + bits := prefix.Bits() + if prefix.Addr().Is4() { + mask := net.CIDRMask(bits, 32) + return net.IP(mask).String() + } + // For IPv6, return the prefix length as we don't typically use dotted decimal + return fmt.Sprintf("/%d", bits) +} + // waitForInterfaceUp polls the network interface until it's up or times out -func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { +func waitForInterfaceUp(interfaceName string, expectedIP netip.Addr, timeout time.Duration) error { logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) deadline := time.Now().Add(timeout) pollInterval := 500 * time.Millisecond @@ -70,9 +82,14 @@ func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Du if err == nil { for _, addr := range addrs { ipNet, ok := addr.(*net.IPNet) - if ok && ipNet.IP.Equal(expectedIP) { - logger.Info("Interface %s is up with correct IP", interfaceName) - return nil // Interface is up with correct IP + if ok { + if ifaceAddr, ok := netip.AddrFromSlice(ipNet.IP); ok { + // Unmap IPv4-mapped IPv6 addresses for comparison + if ifaceAddr.Unmap() == expectedIP.Unmap() { + logger.Info("Interface %s is up with correct IP", interfaceName) + return nil // Interface is up with correct IP + } + } } } logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) @@ -114,13 +131,13 @@ func FindUnusedUTUN() (string, error) { return "", fmt.Errorf("no unused utun interface found") } -func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { +func configureDarwin(interfaceName string, prefix netip.Prefix) error { logger.Info("Configuring darwin interface: %s", interfaceName) - prefix, _ := ipNet.Mask.Size() - ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + ipStr := prefix.String() + ip := prefix.Addr().String() - cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip, "alias") logger.Info("Running command: %v", cmd) out, err := cmd.CombinedOutput() @@ -140,19 +157,19 @@ func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { return nil } -func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { +func configureLinux(interfaceName string, prefix netip.Prefix) error { // Get the interface link, err := netlink.LinkByName(interfaceName) if err != nil { return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) } + // Convert netip.Prefix to net.IPNet for netlink library + ipNet := prefixToIPNet(prefix) + // Create the IP address attributes addr := &netlink.Addr{ - IPNet: &net.IPNet{ - IP: ip, - Mask: ipNet.Mask, - }, + IPNet: ipNet, } // Add the IP address to the interface @@ -167,3 +184,21 @@ func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { return nil } + +// prefixToIPNet converts a netip.Prefix to a *net.IPNet for compatibility with netlink +func prefixToIPNet(prefix netip.Prefix) *net.IPNet { + addr := prefix.Addr() + bits := prefix.Bits() + if addr.Is4() { + ip := addr.As4() + return &net.IPNet{ + IP: net.IP(ip[:]), + Mask: net.CIDRMask(bits, 32), + } + } + ip := addr.As16() + return &net.IPNet{ + IP: net.IP(ip[:]), + Mask: net.CIDRMask(bits, 128), + } +} diff --git a/network/interface_notwindows.go b/network/interface_notwindows.go index 5d15ace..c568752 100644 --- a/network/interface_notwindows.go +++ b/network/interface_notwindows.go @@ -4,9 +4,9 @@ package network import ( "fmt" - "net" + "net/netip" ) -func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { +func configureWindows(interfaceName string, prefix netip.Prefix) error { return fmt.Errorf("configureWindows called on non-Windows platform") } diff --git a/network/interface_windows.go b/network/interface_windows.go index 966486b..7dc4082 100644 --- a/network/interface_windows.go +++ b/network/interface_windows.go @@ -11,7 +11,7 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { +func configureWindows(interfaceName string, prefix netip.Prefix) error { logger.Info("Configuring Windows interface: %s", interfaceName) // Get the LUID for the interface @@ -25,23 +25,6 @@ func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) } - // Create the IP address prefix - maskBits, _ := ipNet.Mask.Size() - - // Ensure we convert to the correct IP version (IPv4 vs IPv6) - var addr netip.Addr - if ip4 := ip.To4(); ip4 != nil { - // IPv4 address - addr, _ = netip.AddrFromSlice(ip4) - } else { - // IPv6 address - addr, _ = netip.AddrFromSlice(ip) - } - if !addr.IsValid() { - return fmt.Errorf("failed to convert IP address") - } - prefix := netip.PrefixFrom(addr, maskBits) - // Add the IP address to the interface logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName) err = luid.AddIPAddress(prefix) @@ -54,7 +37,7 @@ func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { // need this step anymore as far as I can tell. // // Wait for the interface to be up and have the correct IP - // err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + // err = waitForInterfaceUp(interfaceName, prefix.Addr(), 30*time.Second) // if err != nil { // return fmt.Errorf("interface did not come up within timeout: %v", err) // } diff --git a/network/route.go b/network/route.go index 8aae063..72d9b91 100644 --- a/network/route.go +++ b/network/route.go @@ -3,6 +3,7 @@ package network import ( "fmt" "net" + "net/netip" "os/exec" "runtime" "strings" @@ -59,24 +60,34 @@ func LinuxAddRoute(destination string, gateway string, interfaceName string) err return nil } - // Parse destination CIDR - _, ipNet, err := net.ParseCIDR(destination) + // Parse destination CIDR using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("invalid destination address: %v", err) } + // Convert to net.IPNet for netlink library + ipNet := prefixToIPNet(prefix) + // Create route route := &netlink.Route{ Dst: ipNet, } if gateway != "" { - // Route with specific gateway - gw := net.ParseIP(gateway) - if gw == nil { + // Route with specific gateway using netip + gwAddr, err := netip.ParseAddr(gateway) + if err != nil { return fmt.Errorf("invalid gateway address: %s", gateway) } - route.Gw = gw + // Convert netip.Addr to net.IP for netlink + if gwAddr.Is4() { + ip4 := gwAddr.As4() + route.Gw = net.IP(ip4[:]) + } else { + ip6 := gwAddr.As16() + route.Gw = net.IP(ip6[:]) + } logger.Info("Adding route to %s via gateway %s", destination, gateway) } else if interfaceName != "" { // Route via interface @@ -103,12 +114,15 @@ func LinuxRemoveRoute(destination string) error { return nil } - // Parse destination CIDR - _, ipNet, err := net.ParseCIDR(destination) + // Parse destination CIDR using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("invalid destination address: %v", err) } + // Convert to net.IPNet for netlink library + ipNet := prefixToIPNet(prefix) + // Create route to delete route := &netlink.Route{ Dst: ipNet, @@ -165,15 +179,15 @@ func RemoveRouteForServerIP(serverIP string, interfaceName string) error { } func AddRouteForNetworkConfig(destination string) error { - // Parse the subnet to extract IP and mask - _, ipNet, err := net.ParseCIDR(destination) + // Parse the subnet to extract IP and mask using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("failed to parse subnet %s: %v", destination, err) } // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) - mask := net.IP(ipNet.Mask).String() - destinationAddress := ipNet.IP.String() + mask := prefixToMaskString(prefix) + destinationAddress := prefix.Masked().Addr().String() AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) @@ -181,15 +195,15 @@ func AddRouteForNetworkConfig(destination string) error { } func RemoveRouteForNetworkConfig(destination string) error { - // Parse the subnet to extract IP and mask - _, ipNet, err := net.ParseCIDR(destination) + // Parse the subnet to extract IP and mask using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("failed to parse subnet %s: %v", destination, err) } // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) - mask := net.IP(ipNet.Mask).String() - destinationAddress := ipNet.IP.String() + mask := prefixToMaskString(prefix) + destinationAddress := prefix.Masked().Addr().String() RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) diff --git a/network/route_windows.go b/network/route_windows.go index ba613b6..232c919 100644 --- a/network/route_windows.go +++ b/network/route_windows.go @@ -17,29 +17,12 @@ func WindowsAddRoute(destination string, gateway string, interfaceName string) e return nil } - // Parse destination CIDR - _, ipNet, err := net.ParseCIDR(destination) + // Parse destination CIDR using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("invalid destination address: %v", err) } - // Convert to netip.Prefix - maskBits, _ := ipNet.Mask.Size() - - // Ensure we convert to the correct IP version (IPv4 vs IPv6) - var addr netip.Addr - if ip4 := ipNet.IP.To4(); ip4 != nil { - // IPv4 address - addr, _ = netip.AddrFromSlice(ip4) - } else { - // IPv6 address - addr, _ = netip.AddrFromSlice(ipNet.IP) - } - if !addr.IsValid() { - return fmt.Errorf("failed to convert destination IP") - } - prefix := netip.PrefixFrom(addr, maskBits) - var luid winipcfg.LUID var nextHop netip.Addr @@ -57,24 +40,16 @@ func WindowsAddRoute(destination string, gateway string, interfaceName string) e } if gateway != "" { - // Route with specific gateway - gwIP := net.ParseIP(gateway) - if gwIP == nil { + // Route with specific gateway using netip + gwAddr, err := netip.ParseAddr(gateway) + if err != nil { return fmt.Errorf("invalid gateway address: %s", gateway) } - // Convert to correct IP version - if ip4 := gwIP.To4(); ip4 != nil { - nextHop, _ = netip.AddrFromSlice(ip4) - } else { - nextHop, _ = netip.AddrFromSlice(gwIP) - } - if !nextHop.IsValid() { - return fmt.Errorf("failed to convert gateway IP") - } + nextHop = gwAddr logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName) } else if interfaceName != "" { // Route via interface only - if addr.Is4() { + if prefix.Addr().Is4() { nextHop = netip.IPv4Unspecified() } else { nextHop = netip.IPv6Unspecified() @@ -94,33 +69,16 @@ func WindowsAddRoute(destination string, gateway string, interfaceName string) e } func WindowsRemoveRoute(destination string) error { - // Parse destination CIDR - _, ipNet, err := net.ParseCIDR(destination) + // Parse destination CIDR using netip + prefix, err := netip.ParsePrefix(destination) if err != nil { return fmt.Errorf("invalid destination address: %v", err) } - // Convert to netip.Prefix - maskBits, _ := ipNet.Mask.Size() - - // Ensure we convert to the correct IP version (IPv4 vs IPv6) - var addr netip.Addr - if ip4 := ipNet.IP.To4(); ip4 != nil { - // IPv4 address - addr, _ = netip.AddrFromSlice(ip4) - } else { - // IPv6 address - addr, _ = netip.AddrFromSlice(ipNet.IP) - } - if !addr.IsValid() { - return fmt.Errorf("failed to convert destination IP") - } - prefix := netip.PrefixFrom(addr, maskBits) - // Get all routes and find the one to delete // We need to get the LUID from the existing route var family winipcfg.AddressFamily - if addr.Is4() { + if prefix.Addr().Is4() { family = 2 // AF_INET } else { family = 23 // AF_INET6 diff --git a/util/util.go b/util/util.go index 58221c4..ab561aa 100644 --- a/util/util.go +++ b/util/util.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" "strings" mathrand "math/rand/v2" @@ -37,9 +38,9 @@ func ResolveDomain(domain string) (string, error) { // For IPv6, the host from SplitHostPort will already have brackets stripped // but if there was no port, we need to handle bracketed IPv6 addresses cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[") - if ip := net.ParseIP(cleanHost); ip != nil { + if addr, err := netip.ParseAddr(cleanHost); err == nil { // It's already an IP address, no need to resolve - ipAddr := ip.String() + ipAddr := addr.String() if port != "" { return net.JoinHostPort(ipAddr, port), nil } @@ -59,15 +60,19 @@ func ResolveDomain(domain string) (string, error) { // Get the first IPv4 address if available var ipAddr string for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() + if addr, ok := netip.AddrFromSlice(ip); ok && addr.Is4() { + ipAddr = addr.String() break } } // If no IPv4 found, use the first IP (might be IPv6) if ipAddr == "" { - ipAddr = ips[0].String() + if addr, ok := netip.AddrFromSlice(ips[0]); ok { + ipAddr = addr.String() + } else { + ipAddr = ips[0].String() + } } // Add port back if it existed @@ -122,11 +127,8 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { // Try each port in the randomized order for _, port := range portRange { // Check if port is available - addr1 := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port), - } - conn1, err1 := net.ListenUDP("udp", addr1) + addrPort := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port) + conn1, err1 := net.ListenUDP("udp", net.UDPAddrFromAddrPort(addrPort)) if err1 != nil { continue // Port is in use or there was an error, try next port } diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index ee88439..23795a1 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "time" @@ -75,13 +76,11 @@ func (s *Server) Start() error { return nil } - //create the address to listen on - addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort)) - if s.useNetstack && s.tnet != nil { // Use WireGuard netstack tnet := s.tnet.(*netstack2.Net) - udpAddr := &net.UDPAddr{Port: int(s.serverPort)} + addrPort := netip.AddrPortFrom(netip.IPv4Unspecified(), s.serverPort) + udpAddr := net.UDPAddrFromAddrPort(addrPort) netstackConn, err := tnet.ListenUDP(udpAddr) if err != nil { return err @@ -89,11 +88,21 @@ func (s *Server) Start() error { s.netstackConn = netstackConn s.conn = netstackConn } else { - // Use regular UDP socket - udpAddr, err := net.ResolveUDPAddr("udp", addr) + // Use regular UDP socket - parse address with netip, fallback to resolve for hostnames (match upstream) + serverAddr, err := netip.ParseAddr(s.serverAddr) if err != nil { - return err + // Not a literal IP; try resolving as hostname (same as upstream ResolveUDPAddr) + addr, resolveErr := net.ResolveUDPAddr("udp", net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))) + if resolveErr != nil { + serverAddr = netip.IPv4Unspecified() + } else if a, ok := netip.AddrFromSlice(addr.IP); ok && a.IsValid() { + serverAddr = a + } else { + serverAddr = netip.IPv4Unspecified() + } } + addrPort := netip.AddrPortFrom(serverAddr, s.serverPort) + udpAddr := net.UDPAddrFromAddrPort(addrPort) udpConn, err := net.ListenUDP("udp", udpAddr) if err != nil {