Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions authdaemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ var (
authDaemonServer *authdaemon.Server // Global auth daemon server instance
)

// startAuthDaemon initializes and starts the auth daemon in the background.
// It validates requirements (Linux, root, preshared key) and starts the server
// in a goroutine so it runs alongside normal newt operation.
// startAuthDaemon initializes and starts the auth daemon.
// It validates requirements (Linux, root, preshared key) and runs the server
// until the provided context is cancelled.
func startAuthDaemon(ctx context.Context) error {
// Validation
if runtime.GOOS != "linux" {
Expand Down Expand Up @@ -61,15 +61,11 @@ func startAuthDaemon(ctx context.Context) error {

authDaemonServer = srv

// Start the auth daemon in a goroutine so it runs alongside newt
go func() {
logger.Info("Auth daemon starting (native mode, no HTTP server)")
if err := srv.Run(ctx); err != nil {
logger.Error("Auth daemon error: %v", err)
}
logger.Info("Auth daemon stopped")
}()

logger.Info("Auth daemon starting (native mode, no HTTP server)")
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("auth daemon error: %w", err)
}
logger.Info("Auth daemon stopped")
return nil
}

Expand Down
56 changes: 26 additions & 30 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max
return totalLatency / time.Duration(successCount), nil
}

func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopChan chan struct{}, err error) {
func pingWithRetry(ctx context.Context, tnet *netstack.Net, dst string, timeout time.Duration) error {

if healthFile != "" {
err = os.Remove(healthFile)
Expand All @@ -151,7 +151,6 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
maxRetryDelay = 60 * time.Second // Cap the maximum delay
)

stopChan = make(chan struct{})
attempt := 1
retryDelay := initialRetryDelay

Expand All @@ -167,18 +166,24 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
logger.Warn(msgHealthFileWriteFailed, err)
}
}
return stopChan, nil
return nil
} else {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
}

// Start a goroutine that will attempt pings indefinitely with increasing delays
go func() {
attempt = 2 // Continue from attempt 2
var retryTimer *time.Timer
defer func() {
if retryTimer != nil {
retryTimer.Stop()
}
}()

for {
select {
case <-stopChan:
case <-ctx.Done():
return
default:
logger.Debug("Ping attempt %d", attempt)
Expand All @@ -195,7 +200,16 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
logger.Info("Increasing ping retry delay to %v", retryDelay)
}

time.Sleep(retryDelay)
if retryTimer == nil {
retryTimer = time.NewTimer(retryDelay)
} else {
retryTimer.Reset(retryDelay)
}
select {
case <-ctx.Done():
return
case <-retryTimer.C:
}
attempt++
} else {
// Successful ping
Expand All @@ -209,18 +223,15 @@ func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) (stopC
}
}
}
case <-pingStopChan:
// Stop the goroutine when signaled
return
}
}
}()

// Return an error for the first batch of attempts (to maintain compatibility with existing code)
return stopChan, fmt.Errorf("initial ping attempts failed, continuing in background")
return fmt.Errorf("initial ping attempts failed, continuing in background")
}

func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, tunnelID string) chan struct{} {
func startPingCheck(ctx context.Context, tnet *netstack.Net, serverIP string, tunnelID string, onConnectionLost func()) {
maxInterval := 6 * time.Second
currentInterval := pingInterval
consecutiveFailures := 0
Expand All @@ -229,13 +240,14 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
// Track recent latencies for adaptive timeout calculation
recentLatencies := make([]time.Duration, 0, 10)

pingStopChan := make(chan struct{})

go func() {
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
logger.Info("Stopping ping check (context cancelled)")
return
case <-ticker.C:
// Calculate adaptive timeout based on recent latencies
adaptiveTimeout := pingTimeout
Expand Down Expand Up @@ -288,19 +300,8 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
}
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"chainId": pingChainId,
}, 3*time.Second)
// Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"backwardsCompatible": true,
"chainId": bcChainId,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
if onConnectionLost != nil {
onConnectionLost()
}
if healthFile != "" {
err = os.Remove(healthFile)
Expand Down Expand Up @@ -347,14 +348,9 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
}
consecutiveFailures = 0
}
case <-pingStopChan:
logger.Info("Stopping ping check")
return
}
}
}()

return pingStopChan
}

func parseTargetData(data interface{}) (TargetData, error) {
Expand Down
Loading