From f43e11f7e103d1282be002132b4fc1b376ff051b Mon Sep 17 00:00:00 2001 From: Laurence Date: Tue, 7 Apr 2026 15:59:25 +0100 Subject: [PATCH 1/2] Refactor Newt runtime lifecycle to context-driven worker supervision. Replace stop-channel orchestration with scoped context cancellation for tunnel and interval workers, add errgroup-based supervision for long-running services, and simplify shutdown to fail fast so Docker/systemd handles restarts. --- authdaemon.go | 20 +++--- common.go | 56 +++++++-------- main.go | 162 ++++++++++++++++++++++++++------------------ websocket/client.go | 66 +++++++++++++----- 4 files changed, 181 insertions(+), 123 deletions(-) diff --git a/authdaemon.go b/authdaemon.go index dc6d313..defb2a5 100644 --- a/authdaemon.go +++ b/authdaemon.go @@ -22,9 +22,9 @@ var ( authDaemonServer *authdaemon.Server // Global auth daemon server instance ) -// startAuthDaemon initializes and starts the auth daemon in the background. -// It validates requirements (Linux, root, preshared key) and starts the server -// in a goroutine so it runs alongside normal newt operation. +// startAuthDaemon initializes and starts the auth daemon. +// It validates requirements (Linux, root, preshared key) and runs the server +// until the provided context is cancelled. func startAuthDaemon(ctx context.Context) error { // Validation if runtime.GOOS != "linux" { @@ -61,15 +61,11 @@ func startAuthDaemon(ctx context.Context) error { authDaemonServer = srv - // Start the auth daemon in a goroutine so it runs alongside newt - go func() { - logger.Info("Auth daemon starting (native mode, no HTTP server)") - if err := srv.Run(ctx); err != nil { - logger.Error("Auth daemon error: %v", err) - } - logger.Info("Auth daemon stopped") - }() - + logger.Info("Auth daemon starting (native mode, no HTTP server)") + if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("auth daemon error: %w", err) + } + logger.Info("Auth daemon stopped") return nil } diff --git a/common.go b/common.go index 4e1ed00..09883b8 100644 --- a/common.go +++ b/common.go @@ -136,7 +136,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max return totalLatency / time.Duration(successCount), nil } -func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopChan chan struct{}, err error) { +func pingWithRetry(ctx context.Context, tnet *netstack.Net, dst string, timeout time.Duration) error { if healthFile != "" { err = os.Remove(healthFile) @@ -151,7 +151,6 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC maxRetryDelay = 60 * time.Second // Cap the maximum delay ) - stopChan = make(chan struct{}) attempt := 1 retryDelay := initialRetryDelay @@ -167,7 +166,7 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC logger.Warn(msgHealthFileWriteFailed, err) } } - return stopChan, nil + return nil } else { logger.Warn("Ping attempt %d failed: %v", attempt, err) } @@ -175,10 +174,16 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC // Start a goroutine that will attempt pings indefinitely with increasing delays go func() { attempt = 2 // Continue from attempt 2 + var retryTimer *time.Timer + defer func() { + if retryTimer != nil { + retryTimer.Stop() + } + }() for { select { - case <-stopChan: + case <-ctx.Done(): return default: logger.Debug("Ping attempt %d", attempt) @@ -195,7 +200,16 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC logger.Info("Increasing ping retry delay to %v", retryDelay) } - time.Sleep(retryDelay) + if retryTimer == nil { + retryTimer = time.NewTimer(retryDelay) + } else { + retryTimer.Reset(retryDelay) + } + select { + case <-ctx.Done(): + return + case <-retryTimer.C: + } attempt++ } else { // Successful ping @@ -209,18 +223,15 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC } } } - case <-pingStopChan: - // Stop the goroutine when signaled - return } } }() // Return an error for the first batch of attempts (to maintain compatibility with existing code) - return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background") + return fmt.Errorf("initial ping attempts failed, continuing in background") } -func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} { +func startPingCheck(ctx context.Context, tnet *netstack.Net, serverIP string, tunnelID string, onConnectionLost func()) { maxInterval := 6 * time.Second currentInterval := pingInterval consecutiveFailures := 0 @@ -229,13 +240,14 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien // Track recent latencies for adaptive timeout calculation recentLatencies := make([]time.Duration, 0, 10) - pingStopChan := make(chan struct{}) - go func() { ticker := time.NewTicker(currentInterval) defer ticker.Stop() for { select { + case <-ctx.Done(): + logger.Info("Stopping ping check (context cancelled)") + return case <-ticker.C: // Calculate adaptive timeout based on recent latencies adaptiveTimeout := pingTimeout @@ -288,19 +300,8 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien } pingChainId := generateChainId() pendingPingChainId = pingChainId - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ - "chainId": pingChainId, - }, 3*time.Second) - // Send registration message to the server for backward compatibility - bcChainId := generateChainId() - pendingRegisterChainId = bcChainId - err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "backwardsCompatible": true, - "chainId": bcChainId, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) + if onConnectionLost != nil { + onConnectionLost() } if healthFile != "" { err = os.Remove(healthFile) @@ -347,14 +348,9 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien } consecutiveFailures = 0 } - case <-pingStopChan: - logger.Info("Stopping ping check") - return } } }() - - return pingStopChan } func parseTargetData(data interface{}) (TargetData, error) { diff --git a/main.go b/main.go index d5f2a96..5e2280d 100644 --- a/main.go +++ b/main.go @@ -33,6 +33,7 @@ import ( "github.com/fosrl/newt/internal/state" "github.com/fosrl/newt/internal/telemetry" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "golang.org/x/sync/errgroup" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -130,8 +131,6 @@ var ( pingInterval time.Duration pingTimeout time.Duration publicKey wgtypes.Key - pingStopChan chan struct{} - stopFunc func() pendingRegisterChainId string pendingPingChainId string healthFile string @@ -223,11 +222,16 @@ func main() { defer stop() // Run the main newt logic - runNewtMain(ctx) + if err := runNewtMain(ctx); err != nil { + logger.Error("Newt terminated with lifecycle error: %v", err) + os.Exit(1) + } } // runNewtMain contains the main newt logic, extracted for service support -func runNewtMain(ctx context.Context) { +func runNewtMain(ctx context.Context) error { + g, runCtx := errgroup.WithContext(ctx) + // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values endpoint = os.Getenv("PANGOLIN_ENDPOINT") id = os.Getenv("NEWT_ID") @@ -503,9 +507,12 @@ func runNewtMain(ctx context.Context) { // Start auth daemon if enabled if authDaemonEnabled { - if err := startAuthDaemon(ctx); err != nil { - logger.Fatal("Failed to start auth daemon: %v", err) - } + g.Go(func() error { + if err := startAuthDaemon(runCtx); err != nil { + return fmt.Errorf("auth daemon failed: %w", err) + } + return nil + }) } logger.GetLogger().SetLevel(loggerLevel) @@ -523,7 +530,7 @@ func runNewtMain(ctx context.Context) { tcfg.BuildVersion = newtVersion tcfg.BuildCommit = os.Getenv("NEWT_COMMIT") - tel, telErr := telemetry.Init(ctx, tcfg) + tel, telErr := telemetry.Init(runCtx, tcfg) if telErr != nil { logger.Warn("Telemetry init failed: %v", telErr) } @@ -551,11 +558,12 @@ func runNewtMain(ctx context.Context) { ReadHeaderTimeout: 5 * time.Second, IdleTimeout: 30 * time.Second, } - go func() { + g.Go(func() error { if err := admin.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.Warn("admin http error: %v", err) + return fmt.Errorf("admin http error: %w", err) } - }() + return nil + }) defer func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -571,7 +579,7 @@ func runNewtMain(ctx context.Context) { // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) if err != nil { - logger.Fatal("Failed to parse MTU: %v", err) + return fmt.Errorf("failed to parse MTU: %w", err) } // parse if we want to enforce container network validation @@ -583,7 +591,7 @@ func runNewtMain(ctx context.Context) { // Add TLS configuration validation if err := validateTLSConfig(); err != nil { - logger.Fatal("TLS configuration error: %v", err) + return fmt.Errorf("TLS configuration error: %w", err) } // Show deprecation warning if using PKCS12 @@ -593,7 +601,7 @@ func runNewtMain(ctx context.Context) { privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { - logger.Fatal("Failed to generate private key: %v", err) + return fmt.Errorf("failed to generate private key: %w", err) } // Create client option based on TLS configuration @@ -628,7 +636,7 @@ func runNewtMain(ctx context.Context) { websocket.WithConfigFile(configFile), ) if err != nil { - logger.Fatal("Failed to create client: %v", err) + return fmt.Errorf("failed to create client: %w", err) } // If a provisioning key was supplied via CLI / env and the config file did // not already carry one, inject it now so provisionIfNeeded() can use it. @@ -717,13 +725,53 @@ func runNewtMain(ctx context.Context) { } }, enforceHealthcheckCert) - var pingWithRetryStopChan chan struct{} + var intervalSendCancel context.CancelFunc + var tunnelCtx context.Context + var tunnelCancel context.CancelFunc + + stopIntervalSend := func() { + if intervalSendCancel != nil { + intervalSendCancel() + intervalSendCancel = nil + } + } + + startIntervalSend := func(messageType string, data map[string]interface{}, interval time.Duration) { + stopIntervalSend() + var intervalCtx context.Context + intervalCtx, intervalSendCancel = context.WithCancel(runCtx) + client.SendMessageIntervalCtx(intervalCtx, messageType, data, interval) + } + + restartTunnelWorkers := func(serverIP string, tunnelID string) { + if tunnelCancel != nil { + tunnelCancel() + tunnelCancel = nil + } + tunnelCtx, tunnelCancel = context.WithCancel(runCtx) + _ = pingWithRetry(tunnelCtx, tnet, serverIP, pingTimeout) + startPingCheck(tunnelCtx, tnet, serverIP, tunnelID, func() { + pingChainID := generateChainId() + pendingPingChainId = pingChainID + startIntervalSend("newt/ping/request", map[string]interface{}{ + "chainId": pingChainID, + }, 3*time.Second) + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId + if sendErr := client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "backwardsCompatible": true, + "chainId": bcChainId, + }); sendErr != nil { + logger.Error("Failed to send registration message: %v", sendErr) + } + }) + } closeWgTunnel := func() { - if pingStopChan != nil { - // Stop the ping check - close(pingStopChan) - pingStopChan = nil + if tunnelCancel != nil { + tunnelCancel() + tunnelCancel = nil } // Stop proxy manager if running @@ -773,10 +821,7 @@ func runNewtMain(ctx context.Context) { pendingRegisterChainId = "" // consume – further duplicates with this id are rejected } - if stopFunc != nil { - stopFunc() // stop the ws from sending more requests - stopFunc = nil // reset stopFunc to nil to avoid double stopping - } + stopIntervalSend() if connected { // Mark as disconnected @@ -865,17 +910,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( logger.Debug("WireGuard device created. Lets ping the server now...") - // Even if pingWithRetry returns an error, it will continue trying in the background - if pingWithRetryStopChan != nil { - // Stop the previous pingWithRetry if it exists - close(pingWithRetryStopChan) - pingWithRetryStopChan = nil - } // Use reliable ping for initial connection test logger.Debug("Testing initial connection with reliable ping...") lat, err := reliablePing(tnet, wgData.ServerIP, pingTimeout, 5) if err == nil && wgData.PublicKey != "" { - telemetry.ObserveTunnelLatency(ctx, wgData.PublicKey, "wireguard", lat.Seconds()) + telemetry.ObserveTunnelLatency(runCtx, wgData.PublicKey, "wireguard", lat.Seconds()) } if err != nil { logger.Warn("Initial reliable ping failed, but continuing: %v", err) @@ -884,14 +923,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( logger.Debug("Initial connection test successful") } - pingWithRetryStopChan, _ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout) - - // Always mark as connected and start the proxy manager regardless of initial ping result - // as the pings will continue in the background - if !connected { - logger.Debug("Starting ping check") - pingStopChan = startPingCheck(tnet, wgData.ServerIP, client, wgData.PublicKey) - } + // Always mark as connected and start tunnel workers regardless of initial ping result. + restartTunnelWorkers(wgData.ServerIP, wgData.PublicKey) // Create proxy manager pm = proxy.NewProxyManager(tnet) @@ -951,15 +984,12 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( // Mark as disconnected connected = false - if stopFunc != nil { - stopFunc() // stop the ws from sending more requests - stopFunc = nil // reset stopFunc to nil to avoid double stopping - } + stopIntervalSend() // Request exit nodes from the server pingChainId := generateChainId() pendingPingChainId = pingChainId - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + startIntervalSend("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, "chainId": pingChainId, }, 3*time.Second) @@ -977,10 +1007,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( closeWgTunnel() closeClients() - if stopFunc != nil { - stopFunc() // stop the ws from sending more requests - stopFunc = nil // reset stopFunc to nil to avoid double stopping - } + stopIntervalSend() // Mark as disconnected connected = false @@ -991,10 +1018,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { logger.Debug("Received ping message") - if stopFunc != nil { - stopFunc() // stop the ws from sending more requests - stopFunc = nil // reset stopFunc to nil to avoid double stopping - } + stopIntervalSend() // Parse the incoming list of exit nodes var exitNodeData ExitNodeData @@ -1052,7 +1076,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( chainId := generateChainId() pendingRegisterChainId = chainId - stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ + startIntervalSend(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, @@ -1158,7 +1182,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( // Send the ping results to the cloud for selection chainId := generateChainId() pendingRegisterChainId = chainId - stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ + startIntervalSend(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, @@ -1807,14 +1831,12 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( logger.Info("Websocket connected") if !connected { - // make sure the stop function is called - if stopFunc != nil { - stopFunc() - } + // cancel previous periodic requests first + stopIntervalSend() // request from the server the list of nodes to ping pingChainId := generateChainId() pendingPingChainId = pingChainId - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + startIntervalSend("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, "chainId": pingChainId, }, 3*time.Second) @@ -1874,10 +1896,10 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( }) // Connect to the WebSocket server - if err := client.Connect(); err != nil { - logger.Fatal("Failed to connect to server: %v", err) + if err := client.ConnectWithContext(runCtx); err != nil { + return fmt.Errorf("failed to connect to server: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() // Initialize Docker event monitoring if Docker socket is available and monitoring is enabled if dockerSocket != "" { @@ -1903,12 +1925,16 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( logger.Error("Failed to start Docker event monitoring: %v", err) } else { logger.Debug("Docker event monitoring started successfully") + g.Go(func() error { + <-runCtx.Done() + dockerEventMonitor.Stop() + return nil + }) } } } - // Wait for context cancellation (from signal or service stop) - <-ctx.Done() + groupErr := g.Wait() // Close clients first (including WGTester) closeClients() @@ -1935,6 +1961,10 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( client.Close() } logger.Info("Exiting...") + if groupErr != nil && !errors.Is(groupErr, context.Canceled) { + return groupErr + } + return nil } // runNewtMainWithArgs is used by the Windows service to run newt with specific arguments @@ -1948,7 +1978,9 @@ func runNewtMainWithArgs(ctx context.Context, args []string) { setupWindowsEventLog() // Run the main newt logic - runNewtMain(ctx) + if err := runNewtMain(ctx); err != nil { + logger.Error("Service run failed: %v", err) + } } // validateTLSConfig validates the TLS configuration diff --git a/websocket/client.go b/websocket/client.go index 6990bd2..df421cc 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -54,6 +54,7 @@ type Client struct { processingMux sync.RWMutex // Protects processingMessage processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete justProvisioned bool // Set to true when provisionIfNeeded exchanges a key for permanent credentials + lifecycleCtx context.Context } type ClientOption func(*Client) @@ -192,7 +193,15 @@ func (c *Client) setConfigVersion(version int64) { // Connect establishes the WebSocket connection func (c *Client) Connect() error { - go c.connectWithRetry() + go c.connectWithRetry(context.Background()) + return nil +} + +// ConnectWithContext establishes the WebSocket connection and binds reconnect +// behavior to the provided context lifecycle. +func (c *Client) ConnectWithContext(ctx context.Context) error { + c.lifecycleCtx = ctx + go c.connectWithRetry(ctx) return nil } @@ -268,10 +277,17 @@ func (c *Client) SendMessageNoLog(messageType string, data interface{}) error { } func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { - stopChan := make(chan struct{}) + stopCtx, cancel := context.WithCancel(context.Background()) + c.SendMessageIntervalCtx(stopCtx, messageType, data, interval) + return cancel +} + +// SendMessageIntervalCtx sends a message repeatedly until ctx is cancelled. +func (c *Client) SendMessageIntervalCtx(ctx context.Context, messageType string, data interface{}, interval time.Duration) { go func() { count := 0 - maxAttempts := 10 + currentInterval := interval + maxInterval := 60 * time.Second // Cap the maximum interval err := c.SendMessage(messageType, data) // Send immediately if err != nil { @@ -279,28 +295,31 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter } count++ - ticker := time.NewTicker(interval) + ticker := time.NewTicker(currentInterval) defer ticker.Stop() for { select { + case <-ctx.Done(): + return case <-ticker.C: - if count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) - return - } err = c.SendMessage(messageType, data) if err != nil { logger.Error("Failed to send message: %v", err) } count++ - case <-stopChan: - return + + // Increase interval every 10 attempts up to maxInterval + if count%10 == 0 && currentInterval < maxInterval { + currentInterval = time.Duration(float64(currentInterval) * 1.5) + if currentInterval > maxInterval { + currentInterval = maxInterval + } + ticker.Reset(currentInterval) + logger.Debug("Increased message interval to %v after %d attempts for message type: %s", currentInterval, count, messageType) + } } } }() - return func() { - close(stopChan) - } } // RegisterHandler registers a handler for a specific message type @@ -479,16 +498,27 @@ func classifyWSDisconnect(err error) (result, reason string) { } } -func (c *Client) connectWithRetry() { +func (c *Client) connectWithRetry(ctx context.Context) { for { select { + case <-ctx.Done(): + return case <-c.done: return default: err := c.establishConnection() if err != nil { logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) - time.Sleep(c.reconnectInterval) + retryTimer := time.NewTimer(c.reconnectInterval) + select { + case <-ctx.Done(): + retryTimer.Stop() + return + case <-c.done: + retryTimer.Stop() + return + case <-retryTimer.C: + } continue } return @@ -869,7 +899,11 @@ func (c *Client) reconnect() { case <-c.done: return default: - go c.connectWithRetry() + ctx := c.lifecycleCtx + if ctx == nil { + ctx = context.Background() + } + go c.connectWithRetry(ctx) } } From b7b7a90e51db46022cbd21501beadd4d816ce33c Mon Sep 17 00:00:00 2001 From: Laurence Date: Tue, 7 Apr 2026 16:06:27 +0100 Subject: [PATCH 2/2] fix: dont throw away timers reuse them for gc --- websocket/client.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/websocket/client.go b/websocket/client.go index df421cc..81f0ecc 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -499,6 +499,13 @@ func classifyWSDisconnect(err error) (result, reason string) { } func (c *Client) connectWithRetry(ctx context.Context) { + var retryTimer *time.Timer + defer func() { + if retryTimer != nil { + retryTimer.Stop() + } + }() + for { select { case <-ctx.Done(): @@ -509,13 +516,15 @@ func (c *Client) connectWithRetry(ctx context.Context) { err := c.establishConnection() if err != nil { logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) - retryTimer := time.NewTimer(c.reconnectInterval) + if retryTimer == nil { + retryTimer = time.NewTimer(c.reconnectInterval) + } else { + retryTimer.Reset(c.reconnectInterval) + } select { case <-ctx.Done(): - retryTimer.Stop() return case <-c.done: - retryTimer.Stop() return case <-retryTimer.C: }