diff --git a/clients/clients.go b/clients/clients.go index 4c64dbd..d46e8ca 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -3,7 +3,9 @@ package clients import ( "context" "encoding/json" + "errors" "fmt" + "io" "net" "net/netip" "os" @@ -105,6 +107,12 @@ type WireGuardService struct { netstackListener net.PacketConn netstackListenerMu sync.Mutex wgTesterServer *wgtester.Server + // Bandwidth check goroutine lifecycle + bandwidthCheckStop chan struct{} + bandwidthCheckWg sync.WaitGroup + bandwidthCheckMu sync.Mutex + // UAPI listener for native interface mode + uapiListener net.Listener } func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { @@ -196,6 +204,9 @@ func (s *WireGuardService) Close() { s.stopGetConfig = nil } + // Stop the periodic bandwidth check goroutine + s.stopPeriodicBandwidthCheck() + // Stop the direct UDP relay first s.StopDirectUDPRelay() @@ -204,6 +215,12 @@ func (s *WireGuardService) Close() { s.holePunchManager.Stop() } + // Close UAPI listener (native interface mode) - this will cause the Accept goroutine to exit + if s.uapiListener != nil { + s.uapiListener.Close() + s.uapiListener = nil + } + s.mu.Lock() defer s.mu.Unlock() @@ -236,6 +253,20 @@ func (s *WireGuardService) Close() { } } +func (s *WireGuardService) startPeriodicBandwidthCheck() { + s.bandwidthCheckMu.Lock() + defer s.bandwidthCheckMu.Unlock() + + if s.bandwidthCheckStop != nil { + close(s.bandwidthCheckStop) + s.bandwidthCheckWg.Wait() + } + + s.bandwidthCheckStop = make(chan struct{}) + s.bandwidthCheckWg.Add(1) + go s.periodicBandwidthCheck(s.bandwidthCheckStop) +} + func (s *WireGuardService) SetToken(token string) { s.token = token if s.holePunchManager != nil { @@ -378,9 +409,17 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) { n, remoteAddr, err := listener.ReadFrom(buf) if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Check for timeout first - this is normal operation + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { continue // Just a timeout, check for stop and try again } + // Check for connection closed conditions - exit gracefully + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + logger.Debug("Direct UDP relay connection closed, stopping") + return + } + // Check if we've been asked to stop if s.directRelayStop != nil { select { case <-s.directRelayStop: @@ -448,7 +487,9 @@ func (s *WireGuardService) LoadRemoteConfig() error { }, 2*time.Second) logger.Debug("Requesting WireGuard configuration from remote server") - go s.periodicBandwidthCheck() + + // Restart the periodic bandwidth check for the current device lifecycle. + s.startPeriodicBandwidthCheck() return nil } @@ -683,6 +724,16 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Parse the IP address and CIDR mask tunnelIP := netip.MustParseAddr(parts[0]) + // Config refreshes can legitimately resend the same config. Reuse the + // existing interface instead of creating a second device/listener stack. + if s.device != nil { + if s.TunnelIP != "" && s.TunnelIP != tunnelIP.String() { + logger.Warn("WireGuard interface already initialized with tunnel IP %s; ignoring re-init request for %s", s.TunnelIP, tunnelIP.String()) + } + s.mu.Unlock() + return nil + } + var err error if s.useNativeInterface { @@ -724,22 +775,23 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { logger.Error("UAPI listen error: %v", err) } - uapiListener, err := newtDevice.UapiListen(interfaceName, fileUAPI) + listener, err := newtDevice.UapiListen(interfaceName, fileUAPI) if err != nil { logger.Error("Failed to listen on uapi socket: %v", err) os.Exit(1) } + s.uapiListener = listener - go func() { + go func(listener net.Listener, dev *device.Device) { for { - conn, err := uapiListener.Accept() + conn, err := listener.Accept() if err != nil { - + // Listener closed, exit goroutine return } - go s.device.IpcHandle(conn) + go dev.IpcHandle(conn) } - }() + }(listener, s.device) logger.Info("UAPI listener started") // Configure WireGuard with private key @@ -1110,17 +1162,36 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { logger.Info("Peer %s updated successfully", request.PublicKey) } -func (s *WireGuardService) periodicBandwidthCheck() { +func (s *WireGuardService) periodicBandwidthCheck(stopCh <-chan struct{}) { + defer s.bandwidthCheckWg.Done() ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() - for range ticker.C { - if err := s.reportPeerBandwidth(); err != nil { - logger.Info("Failed to report peer bandwidth: %v", err) + for { + select { + case <-stopCh: + logger.Debug("Stopping periodic bandwidth check") + return + case <-ticker.C: + if err := s.reportPeerBandwidth(); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } } } } +// stopPeriodicBandwidthCheck stops the bandwidth check goroutine and waits for it to exit +func (s *WireGuardService) stopPeriodicBandwidthCheck() { + s.bandwidthCheckMu.Lock() + defer s.bandwidthCheckMu.Unlock() + + if s.bandwidthCheckStop != nil { + close(s.bandwidthCheckStop) + s.bandwidthCheckWg.Wait() + s.bandwidthCheckStop = nil + } +} + func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { if s.device == nil { return []PeerBandwidth{}, nil