diff --git a/README.md b/README.md index 3ac0be7..7b31c12 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,32 @@ When Newt receives WireGuard control messages, it will use the information encod When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets. +### DNS Authority + +Newt includes an authoritative DNS server that can serve customized DNS records for specific domains (zones) managed by Pangolin. This allows for intelligent routing and high-availability setups where Newt can respond with the healthiest target IPs for a given service. + +The DNS server runs on port 53 (UDP/TCP). By default, it binds to `0.0.0.0`, but this can be customized using the `--dns-bind` flag or `DNS_BIND_ADDR` environment variable. + +#### systemd-resolved Conflict + +On many modern Linux distributions, `systemd-resolved` binds to `127.0.0.53:53`, which prevents Newt from binding to `0.0.0.0:53`. To resolve this, you can: +1. Disable `systemd-resolved`: `sudo systemctl disable --now systemd-resolved` +2. Or bind Newt to a specific public IP that doesn't conflict with the loopback address used by resolved: `--dns-bind 1.2.3.4` +3. Or disable the DNS Authority feature entirely if you don't need it: `--disable-dns-authority` + +## Configuration + +Newt can be configured via environment variables or command-line flags. + +| Environment Variable | Flag | Description | Default | +|----------------------|------|-------------|---------| +| `PANGOLIN_ENDPOINT` | `--endpoint` | Pangolin server endpoint | | +| `NEWT_ID` | `--id` | Newt Site ID | | +| `NEWT_SECRET` | `--secret` | Newt Site Secret | | +| `DNS_BIND_ADDR` | `--dns-bind` | Bind address for DNS Authority | `0.0.0.0` | +| `DISABLE_DNS_AUTHORITY` | `--disable-dns-authority` | Disable the DNS Authority server | `false` | +| `LOG_LEVEL` | `--log-level` | Logging level (DEBUG, INFO, WARN, ERROR, FATAL) | `INFO` | + ## Build ### Binary diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..1db5c1c --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,1160 @@ +package auth + +import ( + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/fosrl/newt/logger" + "github.com/golang-jwt/jwt/v5" +) + +// AuthConfig holds the authentication configuration synced from Pangolin +type AuthConfig struct { + Enabled bool `json:"enabled"` + PangolinURL string `json:"pangolinUrl"` // e.g., "https://pangolin.example.com" + JWTPublicKey string `json:"jwtPublicKey"` // PEM-encoded RSA public key + CookieName string `json:"cookieName"` // Session cookie name + CookieDomain string `json:"cookieDomain"` // Shared cookie domain + SessionValidationURL string `json:"sessionValidationUrl"` // API endpoint to validate sessions + AllowedEmails []string `json:"allowedEmails"` // Email whitelist (if enabled) + EmailWhitelistEnabled bool `json:"emailWhitelistEnabled"` +} + +// TargetConfig holds a single backend target +type TargetConfig struct { + TargetURL string `json:"targetUrl"` + Path string `json:"path,omitempty"` + PathMatchType string `json:"pathMatchType,omitempty"` // exact, prefix, regex + RewritePath string `json:"rewritePath,omitempty"` + RewritePathType string `json:"rewritePathType,omitempty"` // exact, prefix, regex, stripPrefix + Priority int `json:"priority,omitempty"` +} + +// ResourceAuthConfig holds auth configuration for a specific resource +type ResourceAuthConfig struct { + ResourceID int `json:"resourceId"` + Domain string `json:"domain"` // Full domain for the resource + SSO bool `json:"sso"` // SSO enabled + BlockAccess bool `json:"blockAccess"` // Block all access + EmailWhitelistEnabled bool `json:"emailWhitelistEnabled"` + AllowedEmails []string `json:"allowedEmails"` + SSL bool `json:"ssl"` // Frontend TLS + TargetURL string `json:"targetUrl,omitempty"` // backward compat: single target URL from older Pangolin + Targets []TargetConfig `json:"targets"` + StickySession bool `json:"stickySession,omitempty"` + TLSServerName string `json:"tlsServerName,omitempty"` + SetHostHeader string `json:"setHostHeader,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + PostAuthPath string `json:"postAuthPath,omitempty"` + rrIndex uint64 // internal: atomic round-robin counter +} + +// TLSCertificateConfig holds a TLS certificate pushed from Pangolin +type TLSCertificateConfig struct { + Domain string `json:"domain"` // Domain this cert covers (may be wildcard like *.example.com) + CertPEM string `json:"certPem"` // PEM-encoded certificate chain + KeyPEM string `json:"keyPem"` // PEM-encoded private key + ExpiresAt int64 `json:"expiresAt"` // Unix timestamp when cert expires + Wildcard bool `json:"wildcard"` // Whether this is a wildcard cert +} + +// AuthProxyConfig is the full config message from Pangolin +type AuthProxyConfig struct { + Action string `json:"action"` // "update", "remove", "start", "stop" + Auth AuthConfig `json:"auth"` + Resources []ResourceAuthConfig `json:"resources"` + TLSCertificates []TLSCertificateConfig `json:"tlsCertificates,omitempty"` +} + +// AuthProxy handles authentication for direct-routed resources +type AuthProxy struct { + mu sync.RWMutex + config AuthConfig + resources map[string]*ResourceAuthConfig // domain -> config + servers map[string]*http.Server // domain -> server + jwtPublicKey *rsa.PublicKey + httpClient *http.Client + proxyTransport *http.Transport + proxyCache map[string]*httputil.ReverseProxy // target URL -> reverse proxy + sessionCacheTTL time.Duration + sessionMu sync.RWMutex + sessionCache map[string]cachedSession + running bool + ctx context.Context + cancel context.CancelFunc + listenAddr string + httpsListenAddr string + httpsServer *http.Server + certStore map[string]*tls.Certificate // domain -> parsed TLS cert (lowercase) + certWildcards map[string]*tls.Certificate // base domain -> wildcard cert (e.g. "example.com" -> *.example.com cert) + hasCerts bool // whether any TLS certs have been loaded + httpBindFailed bool // true if HTTP port was already in use (e.g. Traefik colocated) + httpsBindFailed bool // true if HTTPS port was already in use + onSessionEstablished func(domain string, clientIP string) +} + +// NewAuthProxy creates a new auth proxy +func NewAuthProxy() *AuthProxy { + ctx, cancel := context.WithCancel(context.Background()) + listenAddr := os.Getenv("NEWT_AUTH_PROXY_BIND") + if listenAddr == "" { + listenAddr = ":80" + } + httpsListenAddr := os.Getenv("NEWT_AUTH_PROXY_HTTPS_BIND") + if httpsListenAddr == "" { + httpsListenAddr = ":443" + } + + proxyTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{Timeout: 5 * time.Second, KeepAlive: 30 * time.Second}).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 200, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + sessionTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{Timeout: 5 * time.Second, KeepAlive: 30 * time.Second}).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + sessionCacheTTL := sessionCacheTTLFromEnv() + + return &AuthProxy{ + resources: make(map[string]*ResourceAuthConfig), + servers: make(map[string]*http.Server), + certStore: make(map[string]*tls.Certificate), + certWildcards: make(map[string]*tls.Certificate), + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: sessionTransport, + }, + proxyTransport: proxyTransport, + proxyCache: make(map[string]*httputil.ReverseProxy), + sessionCacheTTL: sessionCacheTTL, + sessionCache: make(map[string]cachedSession), + ctx: ctx, + cancel: cancel, + listenAddr: listenAddr, + httpsListenAddr: httpsListenAddr, + } +} + +// UpdateConfig updates the global auth configuration +func (p *AuthProxy) UpdateConfig(config AuthConfig) error { + p.mu.Lock() + defer p.mu.Unlock() + + p.config = config + + // Parse JWT public key if provided + if config.JWTPublicKey != "" { + key, err := parseRSAPublicKey(config.JWTPublicKey) + if err != nil { + return fmt.Errorf("failed to parse JWT public key: %w", err) + } + p.jwtPublicKey = key + logger.Info("Auth Proxy: Updated JWT public key") + } + + return nil +} + +// SetSessionEstablishedHandler sets a callback invoked when a request is +// served for a resource. This is used by DNS authority sticky affinity to +// remember which site last served a client session. +func (p *AuthProxy) SetSessionEstablishedHandler(handler func(domain string, clientIP string)) { + p.mu.Lock() + defer p.mu.Unlock() + p.onSessionEstablished = handler +} + +// UpdateResource updates or adds a resource auth configuration +func (p *AuthProxy) UpdateResource(resource ResourceAuthConfig) error { + p.mu.Lock() + defer p.mu.Unlock() + + domain := strings.ToLower(resource.Domain) + if _, ok := p.resources[domain]; ok { + p.proxyCache = make(map[string]*httputil.ReverseProxy) + } + + // Store the resource config + p.resources[domain] = &resource + + logger.Info("Auth Proxy: Updated resource %s (SSO: %v, BlockAccess: %v, Targets: %d)", + domain, resource.SSO, resource.BlockAccess, len(resource.Targets)) + + return nil +} + +// RemoveResource removes a resource configuration +func (p *AuthProxy) RemoveResource(domain string) { + p.mu.Lock() + defer p.mu.Unlock() + + domain = strings.ToLower(domain) + if existing, ok := p.resources[domain]; ok { + if len(existing.Targets) > 0 { + p.proxyCache = make(map[string]*httputil.ReverseProxy) + } + } + delete(p.resources, domain) + + // Stop server if running for this domain + if server, exists := p.servers[domain]; exists { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + server.Shutdown(ctx) + delete(p.servers, domain) + } + + logger.Info("Auth Proxy: Removed resource %s", domain) +} + +// Start starts the auth proxy +func (p *AuthProxy) Start() error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.running { + return nil + } + + p.ctx, p.cancel = context.WithCancel(context.Background()) + + // Try to bind the HTTP port. If another process owns it (e.g. Traefik + // colocated on the same machine), log a clear message and skip the HTTP + // listener but still mark as running so certs/resources are stored. + httpUp := false + listener, err := net.Listen("tcp", p.listenAddr) + if err != nil { + p.httpBindFailed = true + logger.Warn("Auth Proxy: HTTP port %s is already in use by another process "+ + "(likely Traefik/Gerbil on this machine). HTTP listener skipped. "+ + "Set NEWT_AUTH_PROXY_BIND to use a different port.", p.listenAddr) + } else { + listener.Close() + p.httpBindFailed = false + + // HTTP server: serves requests directly when no TLS certs are available, + // otherwise redirects to HTTPS + httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.mu.RLock() + hasCerts := p.hasCerts + p.mu.RUnlock() + + if hasCerts { + // Redirect HTTP → HTTPS + host := r.Host + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + target := "https://" + host + r.RequestURI + http.Redirect(w, r, target, http.StatusMovedPermanently) + return + } + // No TLS certs loaded: serve directly on HTTP + p.ServeHTTP(w, r) + }) + + server := &http.Server{ + Addr: p.listenAddr, + Handler: httpHandler, + } + p.servers["__default__"] = server + + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("Auth Proxy: HTTP server error on %s: %v", p.listenAddr, err) + } + }() + httpUp = true + } + + // Start HTTPS server if we have certificates + if p.hasCerts { + p.startHTTPSServerLocked() + } + + p.running = true + + if httpUp { + logger.Info("Auth Proxy: Started on %s", p.listenAddr) + } else { + logger.Info("Auth Proxy: Started (HTTP skipped — port in use; HTTPS will be attempted when certs arrive)") + } + return nil +} + +// startHTTPSServerLocked starts the HTTPS server. Must be called with p.mu held. +func (p *AuthProxy) startHTTPSServerLocked() { + if p.httpsServer != nil { + return // already running + } + if p.httpsBindFailed { + return // previously failed — don't retry until restart + } + + // Preflight check — if the HTTPS port is in use, record and skip + ln, err := net.Listen("tcp", p.httpsListenAddr) + if err != nil { + p.httpsBindFailed = true + logger.Warn("Auth Proxy: HTTPS port %s is already in use by another process "+ + "(likely Traefik/Gerbil on this machine). HTTPS listener skipped. "+ + "Set NEWT_AUTH_PROXY_HTTPS_BIND to use a different port.", p.httpsListenAddr) + return + } + ln.Close() + + tlsConfig := &tls.Config{ + GetCertificate: p.getCertificate, + MinVersion: tls.VersionTLS12, + } + + p.httpsServer = &http.Server{ + Addr: p.httpsListenAddr, + Handler: p, // use the same ServeHTTP handler + TLSConfig: tlsConfig, + } + + go func() { + // ListenAndServeTLS with empty cert/key files because GetCertificate handles it + if err := p.httpsServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("Auth Proxy: HTTPS server error on %s: %v", p.httpsListenAddr, err) + } + }() + + logger.Info("Auth Proxy: HTTPS server started on %s", p.httpsListenAddr) +} + +// stopHTTPSServerLocked stops the HTTPS server. Must be called with p.mu held. +func (p *AuthProxy) stopHTTPSServerLocked() { + if p.httpsServer == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + p.httpsServer.Shutdown(ctx) + p.httpsServer = nil + logger.Info("Auth Proxy: HTTPS server stopped") +} + +// getCertificate is the tls.Config.GetCertificate callback for SNI-based cert selection +func (p *AuthProxy) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + serverName := strings.ToLower(hello.ServerName) + + // Try exact domain match first + if cert, ok := p.certStore[serverName]; ok { + return cert, nil + } + + // Try wildcard match: for "sub.example.com", check if we have a wildcard cert for "example.com" + parts := strings.SplitN(serverName, ".", 2) + if len(parts) == 2 { + baseDomain := parts[1] + if cert, ok := p.certWildcards[baseDomain]; ok { + return cert, nil + } + } + + return nil, fmt.Errorf("no certificate found for %s", serverName) +} + +// Stop stops the auth proxy +func (p *AuthProxy) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.running { + return nil + } + + p.cancel() + + // Stop HTTPS server + p.stopHTTPSServerLocked() + + // Shutdown all HTTP servers + for domain, server := range p.servers { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + server.Shutdown(ctx) + cancel() + logger.Debug("Auth Proxy: Stopped server for %s", domain) + } + p.servers = make(map[string]*http.Server) + p.proxyCache = make(map[string]*httputil.ReverseProxy) + if p.proxyTransport != nil { + p.proxyTransport.CloseIdleConnections() + } + + p.sessionMu.Lock() + p.sessionCache = make(map[string]cachedSession) + p.sessionMu.Unlock() + + p.running = false + logger.Info("Auth Proxy: Stopped") + return nil +} + +// ServeHTTP handles incoming requests with authentication +func (p *AuthProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host := strings.ToLower(r.Host) + // Remove port if present + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + + p.mu.RLock() + resource, exists := p.resources[host] + config := p.config + p.mu.RUnlock() + + if !exists { + http.Error(w, "Resource not found", http.StatusNotFound) + return + } + + // Check if access is blocked + if resource.BlockAccess { + http.Error(w, "Access blocked", http.StatusForbidden) + return + } + + // If SSO is enabled, validate authentication + if resource.SSO && config.Enabled { + user, err := p.validateAuth(r) + if err != nil { + logger.Debug("Auth Proxy: Auth validation failed for %s: %v", host, err) + p.redirectToLogin(w, r, resource) + return + } + + // Check email whitelist if enabled + if resource.EmailWhitelistEnabled && len(resource.AllowedEmails) > 0 { + if !p.isEmailAllowed(user.Email, resource.AllowedEmails) { + http.Error(w, "Access denied: email not in whitelist", http.StatusForbidden) + return + } + } + + // Add user info to headers for the backend + r.Header.Set("X-Auth-User", user.Email) + r.Header.Set("X-Auth-User-ID", user.UserID) + } + + // Proxy to backend + p.proxyToBackend(w, r, resource) +} + +// UserClaims represents the claims in a Pangolin JWT +type UserClaims struct { + jwt.RegisteredClaims + UserID string `json:"userId"` + Email string `json:"email"` + OrgID string `json:"orgId"` + Resources []int `json:"resources"` // Resource IDs the user can access +} + +type sessionValidationData struct { + Valid bool `json:"valid"` + UserID string `json:"userId"` + Email string `json:"email"` + OrgID string `json:"orgId"` + ExpiresAt string `json:"expiresAt"` +} + +type sessionValidationAPIResponse struct { + Data sessionValidationData `json:"data"` + Success bool `json:"success"` + Error bool `json:"error"` + Message string `json:"message"` +} + +type cachedSession struct { + claims UserClaims + expiresAt time.Time +} + +// validateAuth validates the request authentication +func (p *AuthProxy) validateAuth(r *http.Request) (*UserClaims, error) { + p.mu.RLock() + config := p.config + publicKey := p.jwtPublicKey + p.mu.RUnlock() + + // Try to get token from cookie + cookie, err := r.Cookie(config.CookieName) + if err != nil { + // Try Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("no auth token found") + } + + // Extract Bearer token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return nil, fmt.Errorf("invalid authorization header") + } + + return p.validateJWT(parts[1], publicKey) + } + + // Validate cookie token + return p.validateJWT(cookie.Value, publicKey) +} + +// validateJWT validates a JWT token +func (p *AuthProxy) validateJWT(tokenString string, publicKey *rsa.PublicKey) (*UserClaims, error) { + if publicKey != nil { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return publicKey, nil + }) + + if err == nil { + claims, ok := token.Claims.(*UserClaims) + if ok && token.Valid { + return claims, nil + } + } + + // If JWT validation fails (e.g. it's an opaque session token, or expired), + // fall back to session validation against the Pangolin API. + logger.Debug("Auth Proxy: JWT validation failed/skipped, falling back to session API") + } + + return p.validateSession(tokenString) +} + +// validateSession validates a session token against Pangolin's API +func (p *AuthProxy) validateSession(sessionToken string) (*UserClaims, error) { + if claims, ok := p.getCachedSession(sessionToken); ok { + return claims, nil + } + + p.mu.RLock() + config := p.config + p.mu.RUnlock() + + if config.SessionValidationURL == "" { + return nil, fmt.Errorf("session validation not configured") + } + + req, err := http.NewRequest("GET", config.SessionValidationURL, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", config.CookieName, sessionToken)) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("session validation request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("session invalid: status %d", resp.StatusCode) + } + + var validationResp sessionValidationAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&validationResp); err != nil { + return nil, fmt.Errorf("failed to parse session response: %w", err) + } + + if !validationResp.Data.Valid { + return nil, fmt.Errorf("session invalid") + } + + if validationResp.Data.UserID == "" { + return nil, fmt.Errorf("session validation response missing userId") + } + + claims := UserClaims{ + UserID: validationResp.Data.UserID, + Email: validationResp.Data.Email, + OrgID: validationResp.Data.OrgID, + } + + p.cacheSession(sessionToken, &claims, validationResp.Data.ExpiresAt) + + return &claims, nil +} + +// redirectToLogin redirects the user to Pangolin's login page +func (p *AuthProxy) redirectToLogin(w http.ResponseWriter, r *http.Request, resource *ResourceAuthConfig) { + p.mu.RLock() + config := p.config + p.mu.RUnlock() + + // Build the redirect-after-login URL + scheme := "https" + if r.TLS == nil { + scheme = "http" + } + // If postAuthPath is set, redirect to that path after login instead of the original URL + redirectTarget := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI) + if resource.PostAuthPath != "" { + redirectTarget = fmt.Sprintf("%s://%s%s", scheme, r.Host, resource.PostAuthPath) + } + + // Build login URL with redirect + loginURL := fmt.Sprintf("%s/auth/login?redirect=%s&resource=%d", + config.PangolinURL, + url.QueryEscape(redirectTarget), + resource.ResourceID, + ) + + http.Redirect(w, r, loginURL, http.StatusFound) +} + +// isEmailAllowed checks if an email is in the allowed list +func (p *AuthProxy) isEmailAllowed(email string, allowedEmails []string) bool { + email = strings.ToLower(email) + for _, allowed := range allowedEmails { + allowed = strings.ToLower(allowed) + if allowed == email { + return true + } + // Support wildcard domain matching like *@example.com + if strings.HasPrefix(allowed, "*@") { + domain := allowed[2:] + if strings.HasSuffix(email, "@"+domain) { + return true + } + } + } + return false +} + +// proxyToBackend selects a backend target, applies path rewriting, and proxies the request +func (p *AuthProxy) proxyToBackend(w http.ResponseWriter, r *http.Request, resource *ResourceAuthConfig) { + target := p.selectTarget(r, resource) + if target == nil { + http.Error(w, "No available backend", http.StatusBadGateway) + return + } + + if clientIP := clientIPFromRemoteAddr(r.RemoteAddr); clientIP != "" { + p.mu.RLock() + handler := p.onSessionEstablished + p.mu.RUnlock() + if handler != nil { + handler(resource.Domain, clientIP) + } + } + + // Apply path rewriting before proxying + applyPathRewrite(r, target) + + proxy, err := p.getOrCreateResourceProxy(resource, target) + if err != nil { + logger.Error("Auth Proxy: Failed to create proxy for resource %d target %s: %v", resource.ResourceID, target.TargetURL, err) + http.Error(w, "Invalid backend configuration", http.StatusInternalServerError) + return + } + + // Set sticky session cookie if enabled and there are multiple targets + if resource.StickySession && len(resource.Targets) > 1 { + http.SetCookie(w, &http.Cookie{ + Name: "p_sticky", + Value: target.TargetURL, + Path: "/", + Secure: resource.SSL, + HttpOnly: true, + }) + } + + proxy.ServeHTTP(w, r) +} + +func clientIPFromRemoteAddr(remoteAddr string) string { + if remoteAddr == "" { + return "" + } + + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + + return host +} + +// selectTarget picks a backend target based on path matching, sticky sessions, and round-robin +func (p *AuthProxy) selectTarget(r *http.Request, resource *ResourceAuthConfig) *TargetConfig { + targets := resource.Targets + if len(targets) == 0 { + return nil + } + if len(targets) == 1 { + if matchesPath(r.URL.Path, &targets[0]) { + return &targets[0] + } + // Single target with no path constraint always matches + if targets[0].Path == "" { + return &targets[0] + } + return nil + } + + // Filter targets by path match + var matched []*TargetConfig + for i := range targets { + if matchesPath(r.URL.Path, &targets[i]) { + matched = append(matched, &targets[i]) + } + } + + // If no path-specific targets matched, fall back to targets without path constraints + if len(matched) == 0 { + for i := range targets { + if targets[i].Path == "" { + matched = append(matched, &targets[i]) + } + } + } + + if len(matched) == 0 { + return nil + } + if len(matched) == 1 { + return matched[0] + } + + // Sticky session: check cookie for target affinity + if resource.StickySession { + if cookie, err := r.Cookie("p_sticky"); err == nil { + for _, t := range matched { + if t.TargetURL == cookie.Value { + return t + } + } + } + } + + // Round-robin across matched targets + idx := atomic.AddUint64(&resource.rrIndex, 1) - 1 + return matched[idx%uint64(len(matched))] +} + +// matchesPath checks if a request path matches a target's path constraints +func matchesPath(reqPath string, target *TargetConfig) bool { + if target.Path == "" { + return true + } + + path := target.Path + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + switch target.PathMatchType { + case "exact": + return reqPath == path + case "prefix": + return strings.HasPrefix(reqPath, path) + case "regex": + matched, err := regexp.MatchString(target.Path, reqPath) + return err == nil && matched + default: + return true + } +} + +// applyPathRewrite modifies the request URL path based on the target's rewrite configuration +func applyPathRewrite(r *http.Request, target *TargetConfig) { + if target.RewritePathType == "" { + return + } + + switch target.RewritePathType { + case "stripPrefix": + if target.PathMatchType == "prefix" && target.Path != "" { + prefix := target.Path + if !strings.HasPrefix(prefix, "/") { + prefix = "/" + prefix + } + r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix) + if r.URL.Path == "" { + r.URL.Path = "/" + } + // If rewritePath is set, prepend it (acts as addPrefix after strip) + if target.RewritePath != "" { + r.URL.Path = target.RewritePath + r.URL.Path + } + } + case "prefix": + if target.Path != "" { + escaped := regexp.QuoteMeta(target.Path) + re, err := regexp.Compile("^" + escaped + "(.*)") + if err == nil { + r.URL.Path = re.ReplaceAllString(r.URL.Path, target.RewritePath+"$1") + } + } + case "exact": + if target.Path != "" { + escaped := regexp.QuoteMeta(target.Path) + re, err := regexp.Compile("^" + escaped + "$") + if err == nil { + r.URL.Path = re.ReplaceAllString(r.URL.Path, target.RewritePath) + } + } + case "regex": + if target.Path != "" { + re, err := regexp.Compile(target.Path) + if err == nil { + r.URL.Path = re.ReplaceAllString(r.URL.Path, target.RewritePath) + } + } + } + + // Ensure path always starts with / + if !strings.HasPrefix(r.URL.Path, "/") { + r.URL.Path = "/" + r.URL.Path + } + + // Update RawPath as well + r.URL.RawPath = r.URL.Path +} + +// UpdateCertificates updates the TLS certificate store with certificates pushed from Pangolin. +// If certs are loaded for the first time and the proxy is already running, it starts the HTTPS server. +func (p *AuthProxy) UpdateCertificates(certs []TLSCertificateConfig) error { + p.mu.Lock() + defer p.mu.Unlock() + + newStore := make(map[string]*tls.Certificate) + newWildcards := make(map[string]*tls.Certificate) + loaded := 0 + + for _, certCfg := range certs { + tlsCert, err := tls.X509KeyPair([]byte(certCfg.CertPEM), []byte(certCfg.KeyPEM)) + if err != nil { + logger.Error("Auth Proxy: Failed to parse TLS cert for %s: %v", certCfg.Domain, err) + continue + } + + domain := strings.ToLower(certCfg.Domain) + + if certCfg.Wildcard { + // Wildcard cert: domain is stored as the base domain (e.g. "example.com") + // and covers *.example.com + // Strip leading "*." if present + baseDomain := domain + if strings.HasPrefix(baseDomain, "*.") { + baseDomain = baseDomain[2:] + } + newWildcards[baseDomain] = &tlsCert + // Also store as exact match for the base domain itself + newStore[baseDomain] = &tlsCert + logger.Info("Auth Proxy: Loaded wildcard TLS cert for *.%s", baseDomain) + } else { + newStore[domain] = &tlsCert + logger.Info("Auth Proxy: Loaded TLS cert for %s", domain) + } + loaded++ + } + + p.certStore = newStore + p.certWildcards = newWildcards + hadCerts := p.hasCerts + p.hasCerts = loaded > 0 + + // If we just got certs for the first time and the proxy is already running, start HTTPS + if p.hasCerts && !hadCerts && p.running { + p.startHTTPSServerLocked() + } + + // If we lost all certs, stop HTTPS + if !p.hasCerts && hadCerts { + p.stopHTTPSServerLocked() + } + + logger.Info("Auth Proxy: Certificate store updated with %d cert(s)", loaded) + return nil +} + +// GetResource returns the auth config for a domain +func (p *AuthProxy) GetResource(domain string) *ResourceAuthConfig { + p.mu.RLock() + defer p.mu.RUnlock() + return p.resources[strings.ToLower(domain)] +} + +// ReplaceResources replaces the full in-memory resource configuration set. +func (p *AuthProxy) ReplaceResources(resources []ResourceAuthConfig) { + p.mu.Lock() + defer p.mu.Unlock() + + newResources := make(map[string]*ResourceAuthConfig, len(resources)) + for _, resource := range resources { + resourceCopy := resource + + // Backward compat: if targets is empty but targetUrl is set (older Pangolin), + // synthesize a single target from the flat targetUrl field + if len(resourceCopy.Targets) == 0 && resourceCopy.TargetURL != "" { + resourceCopy.Targets = []TargetConfig{{TargetURL: resourceCopy.TargetURL}} + } + + domain := strings.ToLower(resourceCopy.Domain) + newResources[domain] = &resourceCopy + } + + p.resources = newResources + p.proxyCache = make(map[string]*httputil.ReverseProxy) + logger.Info("Auth Proxy: Replaced resource set with %d resources", len(resources)) +} + +// IsRunning returns whether the proxy is running +func (p *AuthProxy) IsRunning() bool { + p.mu.RLock() + defer p.mu.RUnlock() + return p.running +} + +// BindStatus returns whether each listener is active, skipped (port in use), or not started. +func (p *AuthProxy) BindStatus() (httpOk, httpsOk, httpSkipped, httpsSkipped bool) { + p.mu.RLock() + defer p.mu.RUnlock() + httpOk = p.running && !p.httpBindFailed && p.servers["__default__"] != nil + httpsOk = p.running && !p.httpsBindFailed && p.httpsServer != nil + httpSkipped = p.httpBindFailed + httpsSkipped = p.httpsBindFailed + return +} + +// parseRSAPublicKey parses a PEM-encoded RSA public key +func parseRSAPublicKey(pemStr string) (*rsa.PublicKey, error) { + // Try PEM decode first + block, _ := pem.Decode([]byte(pemStr)) + if block != nil { + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("not an RSA public key") + } + return rsaPub, nil + } + + // Try base64 decode + decoded, err := base64.StdEncoding.DecodeString(pemStr) + if err != nil { + return nil, fmt.Errorf("failed to decode public key: %w", err) + } + + pub, err := x509.ParsePKIXPublicKey(decoded) + if err != nil { + return nil, err + } + + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("not an RSA public key") + } + + return rsaPub, nil +} + +func (p *AuthProxy) getCachedSession(sessionToken string) (*UserClaims, bool) { + if sessionToken == "" || p.sessionCacheTTL <= 0 { + return nil, false + } + + now := time.Now() + p.sessionMu.RLock() + entry, exists := p.sessionCache[sessionToken] + p.sessionMu.RUnlock() + + if !exists { + return nil, false + } + + if now.After(entry.expiresAt) { + p.sessionMu.Lock() + if current, ok := p.sessionCache[sessionToken]; ok && now.After(current.expiresAt) { + delete(p.sessionCache, sessionToken) + } + p.sessionMu.Unlock() + return nil, false + } + + claimsCopy := entry.claims + return &claimsCopy, true +} + +func (p *AuthProxy) cacheSession(sessionToken string, claims *UserClaims, apiExpiresAt string) { + if sessionToken == "" || claims == nil || p.sessionCacheTTL <= 0 { + return + } + + now := time.Now() + expiresAt := now.Add(p.sessionCacheTTL) + if parsed, ok := parseSessionExpiry(apiExpiresAt); ok && parsed.Before(expiresAt) { + expiresAt = parsed + } + + if !expiresAt.After(now) { + return + } + + claimsCopy := *claims + p.sessionMu.Lock() + p.sessionCache[sessionToken] = cachedSession{claims: claimsCopy, expiresAt: expiresAt} + p.sessionMu.Unlock() +} + +// getOrCreateResourceProxy creates or retrieves a cached reverse proxy for a resource+target combination. +// Each proxy is configured with the resource's TLS settings, host header, and custom headers. +func (p *AuthProxy) getOrCreateResourceProxy(resource *ResourceAuthConfig, target *TargetConfig) (*httputil.ReverseProxy, error) { + cacheKey := fmt.Sprintf("%d:%s", resource.ResourceID, target.TargetURL) + + p.mu.RLock() + if proxy, ok := p.proxyCache[cacheKey]; ok { + p.mu.RUnlock() + return proxy, nil + } + p.mu.RUnlock() + + targetURL, err := url.Parse(target.TargetURL) + if err != nil { + return nil, err + } + + p.mu.Lock() + defer p.mu.Unlock() + if proxy, ok := p.proxyCache[cacheKey]; ok { + return proxy, nil + } + + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // Determine host header: prefer setHostHeader, else use target host + hostHeader := targetURL.Host + if resource.SetHostHeader != "" { + hostHeader = resource.SetHostHeader + } + + // Capture custom headers for the Director closure + customHeaders := resource.Headers + + originalDirector := proxy.Director + proxy.Director = func(req *http.Request) { + originalDirector(req) + originalHost := req.Host + req.Host = hostHeader + req.Header.Set("X-Forwarded-Host", originalHost) + + // X-Forwarded-Proto based on incoming connection TLS state + if req.TLS != nil { + req.Header.Set("X-Forwarded-Proto", "https") + } else { + req.Header.Set("X-Forwarded-Proto", "http") + } + + // X-Real-IP from remote address + clientIP := req.RemoteAddr + if host, _, splitErr := net.SplitHostPort(req.RemoteAddr); splitErr == nil { + clientIP = host + } + req.Header.Set("X-Real-IP", clientIP) + + // Apply custom headers from resource config + for name, value := range customHeaders { + req.Header.Set(name, value) + } + } + + // Transport: use per-resource TLS config for HTTPS backends or when tlsServerName is set + transport := p.proxyTransport + if targetURL.Scheme == "https" || resource.TLSServerName != "" { + transport = p.proxyTransport.Clone() + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + if resource.TLSServerName != "" { + transport.TLSClientConfig.ServerName = resource.TLSServerName + } + } + proxy.Transport = transport + + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, proxyErr error) { + domain := r.Header.Get("X-Forwarded-Host") + if domain == "" { + domain = r.Host + } + logger.Error("Auth Proxy: Backend error for %s → %s: %v", domain, target.TargetURL, proxyErr) + http.Error(w, "Backend unavailable", http.StatusBadGateway) + } + + p.proxyCache[cacheKey] = proxy + return proxy, nil +} + +func sessionCacheTTLFromEnv() time.Duration { + const defaultTTL = 15 * time.Second + raw := strings.TrimSpace(os.Getenv("NEWT_AUTH_SESSION_CACHE_TTL")) + if raw == "" { + return defaultTTL + } + + ttl, err := time.ParseDuration(raw) + if err != nil || ttl < 0 { + logger.Warn("Auth Proxy: Invalid NEWT_AUTH_SESSION_CACHE_TTL=%q, using default %s", raw, defaultTTL) + return defaultTTL + } + + return ttl +} + +func parseSessionExpiry(value string) (time.Time, bool) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return time.Time{}, false + } + + if t, err := time.Parse(time.RFC3339Nano, trimmed); err == nil { + return t, true + } + if t, err := time.Parse(time.RFC3339, trimmed); err == nil { + return t, true + } + + return time.Time{}, false +} diff --git a/dns/authority.go b/dns/authority.go new file mode 100644 index 0000000..a37a460 --- /dev/null +++ b/dns/authority.go @@ -0,0 +1,868 @@ +package dns + +import ( + "context" + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/miekg/dns" +) + +// DNSAuthorityConfig holds configuration for a DNS authority zone +type DNSAuthorityConfig struct { + Enabled bool `json:"enabled"` + Domain string `json:"domain"` // e.g., "hub.docker.visnovsky.us" + TTL uint32 `json:"ttl"` // TTL for DNS responses + RoutingPolicy string `json:"routingPolicy"` // "failover", "roundrobin", "priority", "intelligent" + StickySession bool `json:"stickySession,omitempty"` + ServingSiteID int `json:"servingSiteId,omitempty"` + Targets []DNSAuthorityTarget `json:"targets"` +} + +func authoritativeBaseDomain(domain string) string { + trimmed := strings.TrimSuffix(domain, ".") + trimmed = strings.TrimPrefix(trimmed, "*.") + return dns.Fqdn(trimmed) +} + +// DNSAuthorityTarget represents a target IP with health status +type DNSAuthorityTarget struct { + IP string `json:"ip"` // Public IP to respond with + Priority int `json:"priority"` // Lower = higher priority for failover + Healthy bool `json:"healthy"` // Health status from Pangolin + SiteID int `json:"siteId"` // Site ID for reference + SiteName string `json:"siteName"` // Human-readable name + BackendLatencyMs int64 `json:"backendLatencyMs,omitempty"` // Existing target healthcheck latency from site to backend +} + +// DNSAuthorityServer serves authoritative DNS responses on port 53 +type DNSAuthorityServer struct { + mu sync.RWMutex + zones map[string]*DNSAuthorityConfig // domain -> config + server *dns.Server + tcpServer *dns.Server + ctx context.Context + cancel context.CancelFunc + running bool + bindAddr string + rrIndex map[string]int // For round-robin: domain -> current index + latencyCache map[string]map[string]latencySample // domain -> target IP -> sample + latencyRefreshing map[string]bool // domain -> refresh in progress + stickyAffinities map[string]map[string]stickyAffinity // queried domain -> client IP -> sticky target + stickyAffinityTTL time.Duration + intelligentProbeInterval time.Duration + intelligentProbeTimeout time.Duration +} + +type latencySample struct { + latency time.Duration + measuredAt time.Time +} + +type stickyAffinity struct { + targetIP string + establishedAt time.Time +} + +// NewDNSAuthorityServer creates a new DNS authority server +func NewDNSAuthorityServer(bindAddr string) *DNSAuthorityServer { + ctx, cancel := context.WithCancel(context.Background()) + + if bindAddr == "" { + bindAddr = "0.0.0.0" + } + + return &DNSAuthorityServer{ + zones: make(map[string]*DNSAuthorityConfig), + ctx: ctx, + cancel: cancel, + bindAddr: bindAddr, + rrIndex: make(map[string]int), + latencyCache: make(map[string]map[string]latencySample), + latencyRefreshing: make(map[string]bool), + stickyAffinities: make(map[string]map[string]stickyAffinity), + stickyAffinityTTL: 24 * time.Hour, + intelligentProbeInterval: 15 * time.Second, + intelligentProbeTimeout: 500 * time.Millisecond, + } +} + +// UpdateZone updates or adds a zone configuration +func (s *DNSAuthorityServer) UpdateZone(config *DNSAuthorityConfig) { + s.mu.Lock() + defer s.mu.Unlock() + + // Normalize domain to FQDN format (trailing dot) + domain := config.Domain + if len(domain) > 0 && domain[len(domain)-1] != '.' { + domain = domain + "." + } + + // Set defaults + if config.TTL == 0 { + config.TTL = 60 + } + if config.RoutingPolicy == "" { + config.RoutingPolicy = "failover" + } + + s.zones[domain] = config + if _, ok := s.latencyCache[domain]; !ok { + s.latencyCache[domain] = make(map[string]latencySample) + } + logger.Info("DNS Authority: Updated zone %s with %d targets (policy: %s)", domain, len(config.Targets), config.RoutingPolicy) +} + +// RemoveZone removes a zone configuration +func (s *DNSAuthorityServer) RemoveZone(domain string) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(domain) > 0 && domain[len(domain)-1] != '.' { + domain = domain + "." + } + + delete(s.zones, domain) + delete(s.rrIndex, domain) + delete(s.latencyCache, domain) + delete(s.latencyRefreshing, domain) + delete(s.stickyAffinities, normalizeDomainKey(domain)) + logger.Info("DNS Authority: Removed zone %s", domain) +} + +// UpdateTargetHealth updates the health status of a target +func (s *DNSAuthorityServer) UpdateTargetHealth(domain string, siteID int, healthy bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(domain) > 0 && domain[len(domain)-1] != '.' { + domain = domain + "." + } + + zone, exists := s.zones[domain] + if !exists { + return + } + + for i := range zone.Targets { + if zone.Targets[i].SiteID == siteID { + zone.Targets[i].Healthy = healthy + logger.Debug("DNS Authority: Updated health for %s site %d to %v", domain, siteID, healthy) + break + } + } +} + +// Start starts the DNS authority server on port 53. +// +// CAVEAT: Port 53 is a privileged port on most systems. This method performs +// a pre-flight bind check before attempting to start. Common reasons for +// failure include: +// - systemd-resolved (Linux) already listening on 127.0.0.53:53 +// - Another DNS server (dnsmasq, unbound) occupying port 53 +// - Insufficient privileges (non-root on Linux, no admin on Windows) +// - macOS mDNSResponder listening on port 53 +func (s *DNSAuthorityServer) Start() error { + s.mu.Lock() + if s.running { + s.mu.Unlock() + return nil + } + s.mu.Unlock() + + addr := fmt.Sprintf("%s:53", s.bindAddr) + + // Pre-flight: check if port 53 is bindable before committing to start + if err := s.checkPort53Available(addr); err != nil { + logger.Warn("DNS Authority: Port 53 pre-flight check failed on %s: %v", addr, err) + logger.Warn("DNS Authority: Common causes:") + logger.Warn(" - systemd-resolved is listening on 127.0.0.53:53 (try: sudo systemctl disable --now systemd-resolved)") + logger.Warn(" - Another DNS server (dnsmasq, unbound, pihole) is using port 53") + logger.Warn(" - Insufficient privileges (port 53 requires root/admin)") + logger.Warn(" - macOS mDNSResponder is occupying port 53") + logger.Warn("DNS Authority: The server will NOT start. DNS authority zones are configured but inactive.") + return fmt.Errorf("port 53 is not available on %s: %w", s.bindAddr, err) + } + + logger.Info("DNS Authority: Port 53 pre-flight check passed on %s", addr) + + // Create DNS handler + mux := dns.NewServeMux() + mux.HandleFunc(".", s.handleDNSQuery) + + // Create UDP server + s.server = &dns.Server{ + Addr: addr, + Net: "udp", + Handler: mux, + } + + // Create TCP server (some clients prefer TCP) + s.tcpServer = &dns.Server{ + Addr: addr, + Net: "tcp", + Handler: mux, + } + + // Start UDP server + go func() { + logger.Info("DNS Authority: Starting UDP server on %s", addr) + if err := s.server.ListenAndServe(); err != nil { + if s.ctx.Err() == nil { + logger.Error("DNS Authority: UDP server error: %v", err) + } + } + }() + + // Start TCP server + go func() { + logger.Info("DNS Authority: Starting TCP server on %s", addr) + if err := s.tcpServer.ListenAndServe(); err != nil { + if s.ctx.Err() == nil { + logger.Error("DNS Authority: TCP server error: %v", err) + } + } + }() + + // Give servers time to start and check for bind errors + time.Sleep(100 * time.Millisecond) + + s.mu.Lock() + s.running = true + s.mu.Unlock() + + s.startIntelligentRefreshLoop() + + logger.Info("DNS Authority: Server started successfully on %s", addr) + return nil +} + +// Stop stops the DNS authority server +func (s *DNSAuthorityServer) Stop() error { + s.mu.Lock() + if !s.running { + s.mu.Unlock() + return nil + } + s.running = false + s.mu.Unlock() + + s.cancel() + + if s.server != nil { + if err := s.server.Shutdown(); err != nil { + logger.Error("DNS Authority: Error shutting down UDP server: %v", err) + } + } + + if s.tcpServer != nil { + if err := s.tcpServer.Shutdown(); err != nil { + logger.Error("DNS Authority: Error shutting down TCP server: %v", err) + } + } + + logger.Info("DNS Authority: Server stopped") + return nil +} + +// handleDNSQuery handles incoming DNS queries +func (s *DNSAuthorityServer) handleDNSQuery(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + + if len(r.Question) == 0 { + w.WriteMsg(m) + return + } + + q := r.Question[0] + logger.Debug("DNS Authority: Query for %s (type %s) from %s", q.Name, dns.TypeToString[q.Qtype], w.RemoteAddr()) + + s.mu.RLock() + zone, exactMatch := s.zones[q.Name] + + // If no exact match, try to find a wildcard match + if !exactMatch { + zone = s.findWildcardMatch(q.Name) + } + s.mu.RUnlock() + + if zone == nil || !zone.Enabled { + // Not authoritative for this domain - return NXDOMAIN or REFUSED + m.Rcode = dns.RcodeRefused + w.WriteMsg(m) + return + } + + switch q.Qtype { + case dns.TypeA: + s.handleARecord(m, q, zone, clientIPFromRemoteAddr(w.RemoteAddr())) + case dns.TypeAAAA: + // Return empty response for AAAA (no IPv6 support yet) + // This prevents browsers from waiting for AAAA timeout + m.Rcode = dns.RcodeSuccess + case dns.TypeNS: + s.handleNSRecord(m, q, zone) + case dns.TypeSOA: + s.handleSOARecord(m, q, zone) + default: + // Return empty response for unsupported types + m.Rcode = dns.RcodeSuccess + } + + w.WriteMsg(m) +} + +// findWildcardMatch finds a zone that matches via wildcard +func (s *DNSAuthorityServer) findWildcardMatch(name string) *DNSAuthorityConfig { + // Try progressively shorter domain prefixes + // e.g., for "foo.bar.example.com." try "*.bar.example.com." etc. + labels := dns.SplitDomainName(name) + for i := 1; i < len(labels); i++ { + wildcard := "*." + dns.Fqdn(labels[i]) + for j := i + 1; j < len(labels); j++ { + wildcard = wildcard[:len(wildcard)-1] + "." + labels[j] + "." + } + if zone, ok := s.zones[wildcard]; ok { + return zone + } + } + return nil +} + +// handleARecord responds with A records based on health and routing policy +func (s *DNSAuthorityServer) handleARecord(m *dns.Msg, q dns.Question, zone *DNSAuthorityConfig, clientIP string) { + ips := s.selectTargetIPs(zone, q.Name, clientIP) + + for _, ip := range ips { + parsedIP := net.ParseIP(ip) + if parsedIP == nil || parsedIP.To4() == nil { + continue + } + + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: zone.TTL, + }, + A: parsedIP.To4(), + } + m.Answer = append(m.Answer, rr) + } + + if len(m.Answer) == 0 { + logger.Warn("DNS Authority: No healthy targets for %s", q.Name) + } +} + +// handleNSRecord responds with NS records for all healthy targets and includes +// glue A-records in the Additional section so resolvers can reach each nameserver. +func (s *DNSAuthorityServer) handleNSRecord(m *dns.Msg, q dns.Question, zone *DNSAuthorityConfig) { + baseDomain := authoritativeBaseDomain(zone.Domain) + nsIndex := 1 + for _, target := range zone.Targets { + if target.Healthy || len(zone.Targets) == 1 { + nsName := dns.Fqdn(fmt.Sprintf("ns%d.%s", nsIndex, baseDomain)) + rr := &dns.NS{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: zone.TTL, + }, + Ns: nsName, + } + m.Answer = append(m.Answer, rr) + + // Add glue A-record in Additional section + parsedIP := net.ParseIP(target.IP) + if parsedIP != nil && parsedIP.To4() != nil { + glue := &dns.A{ + Hdr: dns.RR_Header{ + Name: nsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: zone.TTL, + }, + A: parsedIP.To4(), + } + m.Extra = append(m.Extra, glue) + } + + nsIndex++ + } + } +} + +// handleSOARecord responds with SOA record +func (s *DNSAuthorityServer) handleSOARecord(m *dns.Msg, q dns.Question, zone *DNSAuthorityConfig) { + baseDomain := authoritativeBaseDomain(zone.Domain) + soa := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: zone.TTL, + }, + Ns: dns.Fqdn(fmt.Sprintf("ns1.%s", baseDomain)), + Mbox: dns.Fqdn(fmt.Sprintf("hostmaster.%s", baseDomain)), + Serial: uint32(time.Now().Unix()), + Refresh: 86400, + Retry: 7200, + Expire: 3600000, + Minttl: zone.TTL, + } + m.Answer = append(m.Answer, soa) +} + +// selectTargetIPs selects IPs based on routing policy and health status +func (s *DNSAuthorityServer) selectTargetIPs(zone *DNSAuthorityConfig, queriedDomain string, clientIP string) []string { + var ips []string + + // Get healthy targets + var healthyTargets []DNSAuthorityTarget + var allTargets []DNSAuthorityTarget + + for _, t := range zone.Targets { + allTargets = append(allTargets, t) + if t.Healthy { + healthyTargets = append(healthyTargets, t) + } + } + + // If no healthy targets, fall back to all targets (best effort) + targets := healthyTargets + if len(targets) == 0 { + targets = allTargets + logger.Warn("DNS Authority: No healthy targets for %s, using all targets", zone.Domain) + } + + if len(targets) == 0 { + return ips + } + + var stickyTarget *DNSAuthorityTarget + if zone.StickySession && clientIP != "" { + if target, ok := s.getStickyTarget(queriedDomain, clientIP, targets); ok { + stickyTarget = &target + } + } + + switch zone.RoutingPolicy { + case "failover": + if stickyTarget != nil { + ips = append(ips, stickyTarget.IP) + } else { + ips = append(ips, selectLowestPriorityTarget(targets).IP) + } + + case "roundrobin": + if stickyTarget != nil { + ips = append(ips, stickyTarget.IP) + } else { + // Rotate through all healthy targets + s.mu.Lock() + idx := s.rrIndex[zone.Domain] + s.rrIndex[zone.Domain] = (idx + 1) % len(targets) + s.mu.Unlock() + ips = append(ips, targets[idx%len(targets)].IP) + } + + case "priority": + // Return all healthy targets (client can choose) + if stickyTarget != nil { + ips = append(ips, stickyTarget.IP) + } + for _, t := range targets { + if stickyTarget != nil && t.IP == stickyTarget.IP { + continue + } + ips = append(ips, t.IP) + } + + case "intelligent": + if stickyTarget != nil { + ips = append(ips, stickyTarget.IP) + } else { + best := s.selectIntelligentTarget(zone, targets) + ips = append(ips, best.IP) + } + + default: + // Default to failover behavior + if stickyTarget != nil { + ips = append(ips, stickyTarget.IP) + } else { + ips = append(ips, selectLowestPriorityTarget(targets).IP) + } + } + + return ips +} + +// RecordSessionEstablished records that a client has established a session on +// this Newt for the given domain. Sticky DNS responses will prioritize this +// site's public IP for subsequent queries from that client. +func (s *DNSAuthorityServer) RecordSessionEstablished(domain string, clientIP string) { + if clientIP == "" || domain == "" { + return + } + + domainKey := normalizeDomainKey(domain) + + s.mu.RLock() + var zone *DNSAuthorityConfig + if z, ok := s.zones[domainKey]; ok { + zone = z + } else { + zone = s.findWildcardMatch(domainKey) + } + + if zone == nil || !zone.Enabled || !zone.StickySession || zone.ServingSiteID == 0 { + s.mu.RUnlock() + return + } + + targetIP := "" + for _, target := range zone.Targets { + if target.SiteID == zone.ServingSiteID { + targetIP = target.IP + break + } + } + s.mu.RUnlock() + + if targetIP == "" { + return + } + + s.setStickyTarget(domainKey, clientIP, targetIP) +} + +func (s *DNSAuthorityServer) getStickyTarget(queriedDomain string, clientIP string, targets []DNSAuthorityTarget) (DNSAuthorityTarget, bool) { + domainKey := normalizeDomainKey(queriedDomain) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + byClient := s.stickyAffinities[domainKey] + if byClient == nil { + return DNSAuthorityTarget{}, false + } + + affinity, ok := byClient[clientIP] + if !ok { + return DNSAuthorityTarget{}, false + } + + if now.Sub(affinity.establishedAt) > s.stickyAffinityTTL { + delete(byClient, clientIP) + if len(byClient) == 0 { + delete(s.stickyAffinities, domainKey) + } + return DNSAuthorityTarget{}, false + } + + for _, t := range targets { + if t.IP == affinity.targetIP { + return t, true + } + } + + delete(byClient, clientIP) + if len(byClient) == 0 { + delete(s.stickyAffinities, domainKey) + } + + return DNSAuthorityTarget{}, false +} + +func (s *DNSAuthorityServer) setStickyTarget(queriedDomain string, clientIP string, targetIP string) { + domainKey := normalizeDomainKey(queriedDomain) + s.mu.Lock() + defer s.mu.Unlock() + + byClient := s.stickyAffinities[domainKey] + if byClient == nil { + byClient = make(map[string]stickyAffinity) + s.stickyAffinities[domainKey] = byClient + } + + byClient[clientIP] = stickyAffinity{targetIP: targetIP, establishedAt: time.Now()} +} + +func clientIPFromRemoteAddr(addr net.Addr) string { + if addr == nil { + return "" + } + + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.String() + } + if tcpAddr, ok := addr.(*net.TCPAddr); ok { + return tcpAddr.IP.String() + } + + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + return addr.String() + } + return host +} + +func normalizeDomainKey(domain string) string { + trimmed := strings.TrimSpace(domain) + if trimmed == "" { + return "" + } + return strings.ToLower(dns.Fqdn(trimmed)) +} + +func selectLowestPriorityTarget(targets []DNSAuthorityTarget) DNSAuthorityTarget { + best := targets[0] + for _, t := range targets[1:] { + if t.Priority < best.Priority { + best = t + } + } + return best +} + +func (s *DNSAuthorityServer) selectIntelligentTarget(zone *DNSAuthorityConfig, targets []DNSAuthorityTarget) DNSAuthorityTarget { + now := time.Now() + refreshNeeded := false + + s.mu.RLock() + zoneCache := s.latencyCache[zone.Domain] + bestScore := int64(0) + var bestTarget *DNSAuthorityTarget + for i := range targets { + t := &targets[i] + sample, ok := zoneCache[t.IP] + if !ok || now.Sub(sample.measuredAt) > s.intelligentProbeInterval { + refreshNeeded = true + continue + } + frontendLatencyMs := sample.latency.Milliseconds() + if frontendLatencyMs <= 0 { + frontendLatencyMs = 1 + } + + backendLatencyMs := t.BackendLatencyMs + if backendLatencyMs <= 0 { + backendLatencyMs = frontendLatencyMs + } + + // Weight edge reachability higher than backend health latency so DNS answers + // prefer the site clients can connect to fastest, while still accounting for + // backend responsiveness when edge latencies are close. + score := (frontendLatencyMs * 70) + (backendLatencyMs * 30) + if bestTarget == nil || score < bestScore || (score == bestScore && t.Priority < bestTarget.Priority) { + bestTarget = t + bestScore = score + } + } + s.mu.RUnlock() + + if refreshNeeded { + s.scheduleLatencyRefresh(zone.Domain, targets) + } + + if bestTarget != nil { + return *bestTarget + } + + // If no fresh latency is available yet, preserve HA semantics via failover. + return selectLowestPriorityTarget(targets) +} + +func (s *DNSAuthorityServer) scheduleLatencyRefresh(domain string, targets []DNSAuthorityTarget) { + s.mu.Lock() + if s.latencyRefreshing[domain] { + s.mu.Unlock() + return + } + s.latencyRefreshing[domain] = true + timeout := s.intelligentProbeTimeout + s.mu.Unlock() + + go func() { + results := make(map[string]latencySample) + for _, t := range targets { + if latency, ok := probeTargetLatency(t.IP, timeout); ok { + results[t.IP] = latencySample{latency: latency, measuredAt: time.Now()} + } + } + + s.mu.Lock() + cache := s.latencyCache[domain] + if cache == nil { + cache = make(map[string]latencySample) + s.latencyCache[domain] = cache + } + for ip, sample := range results { + cache[ip] = sample + } + s.latencyRefreshing[domain] = false + s.mu.Unlock() + }() +} + +func (s *DNSAuthorityServer) startIntelligentRefreshLoop() { + go func() { + ticker := time.NewTicker(s.intelligentProbeInterval) + defer ticker.Stop() + + // Prime the latency cache shortly after start so intelligent routing + // can use measured data without waiting for a query-triggered refresh. + s.refreshIntelligentZones() + + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + s.refreshIntelligentZones() + } + } + }() +} + +type intelligentRefreshJob struct { + domain string + targets []DNSAuthorityTarget +} + +func (s *DNSAuthorityServer) refreshIntelligentZones() { + jobs := make([]intelligentRefreshJob, 0) + + s.mu.RLock() + for domain, zone := range s.zones { + if zone == nil || !zone.Enabled || zone.RoutingPolicy != "intelligent" { + continue + } + + healthyTargets := make([]DNSAuthorityTarget, 0, len(zone.Targets)) + allTargets := make([]DNSAuthorityTarget, 0, len(zone.Targets)) + for _, target := range zone.Targets { + allTargets = append(allTargets, target) + if target.Healthy { + healthyTargets = append(healthyTargets, target) + } + } + + targets := healthyTargets + if len(targets) == 0 { + targets = allTargets + } + + if len(targets) == 0 { + continue + } + + jobs = append(jobs, intelligentRefreshJob{domain: domain, targets: targets}) + } + s.mu.RUnlock() + + for _, job := range jobs { + s.scheduleLatencyRefresh(job.domain, job.targets) + } +} + +func probeTargetLatency(ip string, timeout time.Duration) (time.Duration, bool) { + ports := []string{"443", "80"} + for _, port := range ports { + addr := net.JoinHostPort(ip, port) + start := time.Now() + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + continue + } + _ = conn.Close() + return time.Since(start), true + } + return 0, false +} + +// IsRunning returns whether the server is running +func (s *DNSAuthorityServer) IsRunning() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.running +} + +// GetZones returns a copy of all configured zones +func (s *DNSAuthorityServer) GetZones() map[string]*DNSAuthorityConfig { + s.mu.RLock() + defer s.mu.RUnlock() + + zones := make(map[string]*DNSAuthorityConfig) + for k, v := range s.zones { + zones[k] = v + } + return zones +} + +// checkPort53Available performs a pre-flight check to determine whether port 53 +// can be bound (both UDP and TCP). The test listeners are closed immediately +// after a successful bind. This catches common conflicts early with a clear +// error message instead of a silent goroutine failure. +func (s *DNSAuthorityServer) checkPort53Available(addr string) error { + // Check UDP + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return fmt.Errorf("invalid address %s: %w", addr, err) + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return fmt.Errorf("cannot bind UDP %s: %w", addr, err) + } + udpConn.Close() + + // Check TCP + tcpListener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("cannot bind TCP %s: %w", addr, err) + } + tcpListener.Close() + + return nil +} + +// SelfTest performs a DNS query to itself to verify the server is responding +func (s *DNSAuthorityServer) SelfTest() error { + addr := fmt.Sprintf("%s:53", s.bindAddr) + + c := new(dns.Client) + c.Timeout = 1 * time.Second + + m := new(dns.Msg) + // We use a dummy query; even a NXDOMAIN response confirms the server is alive + m.SetQuestion("healthcheck.newt.", dns.TypeA) + + // In some environments, binding to 0.0.0.0 then querying 127.0.0.1:53 works. + // We'll try the bind address first, then localhost as fallback. + testAddrs := []string{addr} + if s.bindAddr == "0.0.0.0" { + testAddrs = append(testAddrs, "127.0.0.1:53") + } + + var lastErr error + for _, testAddr := range testAddrs { + _, _, err := c.Exchange(m, testAddr) + if err == nil { + return nil + } + lastErr = err + } + + return fmt.Errorf("self-test failed: %w", lastErr) +} diff --git a/get-newt.sh b/get-newt.sh index d4ddd3f..ad2b18a 100644 --- a/get-newt.sh +++ b/get-newt.sh @@ -31,7 +31,7 @@ print_error() { # Function to get latest version from GitHub API get_latest_version() { local latest_info - + if command -v curl >/dev/null 2>&1; then latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null) elif command -v wget >/dev/null 2>&1; then @@ -40,30 +40,30 @@ get_latest_version() { print_error "Neither curl nor wget is available. Please install one of them." >&2 exit 1 fi - + if [ -z "$latest_info" ]; then print_error "Failed to fetch latest version information" >&2 exit 1 fi - + # Extract version from JSON response (works without jq) local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/') - + if [ -z "$version" ]; then print_error "Could not parse version from GitHub API response" >&2 exit 1 fi - + # Remove 'v' prefix if present version=$(echo "$version" | sed 's/^v//') - + echo "$version" } # Detect OS and architecture detect_platform() { local os arch - + # Detect OS case "$(uname -s)" in Linux*) os="linux" ;; @@ -75,12 +75,12 @@ detect_platform() { exit 1 ;; esac - + # Detect architecture case "$(uname -m)" in x86_64|amd64) arch="amd64" ;; arm64|aarch64) arch="arm64" ;; - armv7l|armv6l) + armv7l|armv6l) if [ "$os" = "linux" ]; then if [ "$(uname -m)" = "armv6l" ]; then arch="arm32v6" @@ -91,7 +91,7 @@ detect_platform() { arch="arm64" # Default for non-Linux ARM fi ;; - riscv64) + riscv64) if [ "$os" = "linux" ]; then arch="riscv64" else @@ -104,10 +104,52 @@ detect_platform() { exit 1 ;; esac - + echo "${os}_${arch}" } +# Check for potential system conflicts (Port 53, systemd-resolved, etc.) +check_conflicts() { + local platform="$1" + + # Only check on Linux as that's where systemd-resolved is common + if [[ "$platform" == *"linux"* ]]; then + print_status "Checking for potential system conflicts..." + + # Check if port 53 is in use + if command -v ss >/dev/null 2>&1; then + if ss -tuln | grep -q ":53 "; then + print_warning "Port 53 is already in use on this system." + + # Check if it's systemd-resolved + if ss -tulnp | grep -q "systemd-resolve\|resolved"; then + print_warning "systemd-resolved appears to be occupying port 53." + print_warning "This will prevent Newt's DNS Authority from starting on 0.0.0.0:53." + print_warning "To fix this, you can either:" + print_warning " 1. Disable systemd-resolved: sudo systemctl disable --now systemd-resolved" + print_warning " 2. Bind Newt to a specific IP: newt --dns-bind " + print_warning " 3. Disable DNS Authority if not needed: newt --disable-dns-authority" + else + print_warning "Another process is using port 53. DNS Authority may fail to start." + print_warning "If you don't need this feature, you can disable it with: newt --disable-dns-authority" + fi + fi + fi + + # Check for WireGuard kernel module (optional for Newt but recommended for high performance) + if [ -f /proc/modules ] && ! grep -q "^wireguard" /proc/modules; then + print_status "WireGuard kernel module not loaded. Newt will use userspace implementation (netstack)." + print_status "For better performance, you can load it with: sudo modprobe wireguard" + fi + + # Check privileges for port 53 + if [ "$EUID" -ne 0 ]; then + print_warning "Newt is being installed as a non-root user." + print_warning "Note: Binding to port 53 (DNS) typically requires root privileges (sudo)." + fi + fi +} + # Get installation directory get_install_dir() { if [ "$OS" = "windows" ]; then @@ -184,9 +226,11 @@ install_newt() { $sudo_cmd mkdir -p "$install_dir" print_status "Using sudo to install to ${install_dir}" $sudo_cmd mv "$temp_file" "$final_path" + $sudo_cmd chmod +x "$final_path" else mkdir -p "$install_dir" mv "$temp_file" "$final_path" + chmod +x "$final_path" fi print_status "newt installed to ${final_path}" @@ -203,13 +247,21 @@ install_newt() { verify_installation() { local install_dir="$1" local exe_suffix="" +<<<<<<< HEAD case "$PLATFORM" in *windows*) exe_suffix=".exe" ;; esac +======= + + if [[ "$PLATFORM" == *"windows"* ]]; then + exe_suffix=".exe" + fi + +>>>>>>> 6307483 (DNS authority management and features) local newt_path="${install_dir}/newt${exe_suffix}" - + if [ -f "$newt_path" ] && [ -x "$newt_path" ]; then print_status "Installation successful!" print_status "newt version: $("$newt_path" --version 2>/dev/null || echo "unknown")" @@ -236,10 +288,17 @@ main() { PLATFORM=$(detect_platform) print_status "Detected platform: ${PLATFORM}" +<<<<<<< HEAD +======= + # Check for conflicts + check_conflicts "$PLATFORM" + +>>>>>>> 6307483 (DNS authority management and features) # Get install directory INSTALL_DIR=$(get_install_dir) print_status "Install directory: ${INSTALL_DIR}" +<<<<<<< HEAD # Check if we need sudo SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR") if [ -n "$SUDO_CMD" ]; then @@ -248,6 +307,10 @@ main() { # Install newt install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD" +======= + # Install newt + install_newt "$PLATFORM" "$INSTALL_DIR" +>>>>>>> 6307483 (DNS authority management and features) # Verify installation if verify_installation "$INSTALL_DIR"; then diff --git a/go.mod b/go.mod index 2aa8f5e..833b8ad 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.25.0 require ( github.com/docker/docker v28.5.2+incompatible github.com/gaissmai/bart v0.26.0 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/gorilla/websocket v1.5.3 + github.com/miekg/dns v1.1.62 github.com/prometheus/client_golang v1.23.2 github.com/vishvananda/netlink v1.3.1 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.66.0 diff --git a/go.sum b/go.sum index d345b1d..882160a 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= @@ -53,6 +55,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= +github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index f618803..8ba6636 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -59,6 +59,7 @@ type Target struct { Status Health `json:"status"` LastCheck time.Time `json:"lastCheck"` LastError string `json:"lastError,omitempty"` + LastLatencyMs int64 `json:"latencyMs,omitempty"` CheckCount int `json:"checkCount"` timer *time.Timer ctx context.Context @@ -351,10 +352,13 @@ func (m *Monitor) monitorTarget(target *Target) { // Reset timer for next check with current interval target.timer.Reset(interval) - // Notify callback if status changed - if oldStatus != target.Status && m.callback != nil { - logger.Info("Target %d status changed: %s -> %s", - target.Config.ID, oldStatus.String(), target.Status.String()) + // Notify callback on every check so downstream systems receive fresh + // latency telemetry even when health status is unchanged. + if m.callback != nil { + if oldStatus != target.Status { + logger.Info("Target %d status changed: %s -> %s", + target.Config.ID, oldStatus.String(), target.Status.String()) + } go m.callback(m.GetTargets()) } } @@ -366,6 +370,7 @@ func (m *Monitor) performHealthCheck(target *Target) { target.CheckCount++ target.LastCheck = time.Now() target.LastError = "" + target.LastLatencyMs = 0 // Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports) host := target.Config.Hostname @@ -411,6 +416,7 @@ func (m *Monitor) performHealthCheck(target *Target) { } // Perform request + requestStart := time.Now() resp, err := target.client.Do(req) if err != nil { target.Status = StatusUnhealthy @@ -419,6 +425,7 @@ func (m *Monitor) performHealthCheck(target *Target) { return } defer resp.Body.Close() + target.LastLatencyMs = time.Since(requestStart).Milliseconds() // Check response status var expectedStatus int diff --git a/main.go b/main.go index d5f2a96..2d93ed0 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,9 @@ import ( "syscall" "time" + "github.com/fosrl/newt/auth" "github.com/fosrl/newt/authdaemon" + dnsauthority "github.com/fosrl/newt/dns" "github.com/fosrl/newt/docker" "github.com/fosrl/newt/healthcheck" "github.com/fosrl/newt/logger" @@ -139,6 +141,10 @@ var ( authorizedKeysFile string preferEndpoint string healthMonitor *healthcheck.Monitor + dnsAuthorityServer *dnsauthority.DNSAuthorityServer // DNS Authority server for intelligent routing + dnsBindAddr string // Bind address for DNS Authority (default 0.0.0.0) + disableDNSAuthority bool // Disable the DNS Authority server entirely + authProxy *auth.AuthProxy // Auth proxy for SSO protection on direct routes enforceHealthcheckCert bool authDaemonKey string authDaemonPrincipalsFile string @@ -252,6 +258,10 @@ func runNewtMain(ctx context.Context) { asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED") + dnsBindAddr = os.Getenv("DNS_BIND_ADDR") + disableDNSAuthorityEnv := os.Getenv("DISABLE_DNS_AUTHORITY") + disableDNSAuthority = disableDNSAuthorityEnv == "true" + disableClientsEnv := os.Getenv("DISABLE_CLIENTS") disableClients = disableClientsEnv == "true" useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE") @@ -472,6 +482,13 @@ func runNewtMain(ctx context.Context) { } } + if dnsBindAddr == "" { + flag.StringVar(&dnsBindAddr, "dns-bind", "", "Bind address for DNS Authority server (default 0.0.0.0, also DNS_BIND_ADDR)") + } + if disableDNSAuthorityEnv == "" { + flag.BoolVar(&disableDNSAuthority, "disable-dns-authority", false, "Disable the DNS Authority server (default false, also DISABLE_DNS_AUTHORITY)") + } + // do a --version check version := flag.Bool("version", false, "Print the version") @@ -701,6 +718,7 @@ func runNewtMain(ctx context.Context) { "status": target.Status.String(), "lastCheck": target.LastCheck.Format(time.RFC3339), "checkCount": target.CheckCount, + "latencyMs": target.LastLatencyMs, "lastError": target.LastError, "config": target.Config, } @@ -1596,6 +1614,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( "status": target.Status.String(), "lastCheck": target.LastCheck.Format(time.RFC3339), "checkCount": target.CheckCount, + "latencyMs": target.LastLatencyMs, "lastError": target.LastError, "config": target.Config, } @@ -1609,6 +1628,291 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } }) + // Register handler for DNS Authority configuration + bindDNSAuthoritySessionAffinity := func() { + if authProxy == nil { + return + } + authProxy.SetSessionEstablishedHandler(func(domain string, clientIP string) { + if dnsAuthorityServer == nil { + return + } + dnsAuthorityServer.RecordSessionEstablished(domain, clientIP) + }) + } + + dnsStatusAddress := func() string { + if dnsBindAddr == "" { + return "0.0.0.0:53" + } + return dnsBindAddr + ":53" + } + + client.RegisterHandler("newt/dns/authority/config", func(msg websocket.WSMessage) { + logger.Debug("Received DNS authority config message: %+v", msg.Data) + + type DNSAuthorityConfigMessage struct { + Action string `json:"action"` // "update", "remove", "start", "stop" + Zones []dnsauthority.DNSAuthorityConfig `json:"zones,omitempty"` + } + + var configMsg DNSAuthorityConfigMessage + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling DNS authority config data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &configMsg); err != nil { + logger.Error("Error unmarshaling DNS authority config data: %v", err) + return + } + + switch configMsg.Action { + case "start": + if disableDNSAuthority { + logger.Warn("Received request to start DNS authority, but it is disabled via configuration") + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "disabled", + "message": "DNS Authority is disabled via --disable-dns-authority", + }) + return + } + if dnsAuthorityServer == nil { + dnsAuthorityServer = dnsauthority.NewDNSAuthorityServer(dnsBindAddr) + bindDNSAuthoritySessionAffinity() + } + if !dnsAuthorityServer.IsRunning() { + if err := dnsAuthorityServer.Start(); err != nil { + logger.Error("Failed to start DNS authority server: %v", err) + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "error", + "error": err.Error(), + }) + } else { + logger.Info("DNS Authority server started") + // Wait a moment for the server to bind, then self-test + go func() { + time.Sleep(200 * time.Millisecond) + if err := dnsAuthorityServer.SelfTest(); err != nil { + logger.Warn("DNS Authority self-test failed: %v", err) + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "warning", + "message": "Server bound but self-test failed (firewall or loopback issue?)", + "error": err.Error(), + "address": dnsStatusAddress(), + }) + } else { + logger.Info("DNS Authority self-test passed") + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "running", + "address": dnsStatusAddress(), + }) + } + }() + } + } + + case "stop": + if dnsAuthorityServer != nil && dnsAuthorityServer.IsRunning() { + if err := dnsAuthorityServer.Stop(); err != nil { + logger.Error("Failed to stop DNS authority server: %v", err) + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "error", + "error": err.Error(), + }) + } else { + logger.Info("DNS Authority server stopped") + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "disabled", + }) + } + } + + case "update": + // Ensure DNS authority server is running + if disableDNSAuthority { + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "disabled", + "message": "DNS Authority is disabled via --disable-dns-authority", + }) + return + } + if dnsAuthorityServer == nil { + dnsAuthorityServer = dnsauthority.NewDNSAuthorityServer(dnsBindAddr) + bindDNSAuthoritySessionAffinity() + } + justStarted := false + if !dnsAuthorityServer.IsRunning() { + if err := dnsAuthorityServer.Start(); err != nil { + logger.Error("Failed to start DNS authority server: %v", err) + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "error", + "error": err.Error(), + }) + return + } + justStarted = true + } + for _, zone := range configMsg.Zones { + zoneCopy := zone + dnsAuthorityServer.UpdateZone(&zoneCopy) + } + logger.Info("Updated %d DNS authority zones", len(configMsg.Zones)) + if justStarted { + // Self-test after start; report status once the test completes + zoneCount := len(configMsg.Zones) + go func() { + time.Sleep(200 * time.Millisecond) + if err := dnsAuthorityServer.SelfTest(); err != nil { + logger.Warn("DNS Authority self-test failed: %v", err) + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "warning", + "message": "Server bound but self-test failed", + "error": err.Error(), + "address": dnsStatusAddress(), + "zones": zoneCount, + }) + } else { + logger.Info("DNS Authority self-test passed") + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "running", + "address": dnsStatusAddress(), + "zones": zoneCount, + }) + } + }() + } else { + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "running", + "address": dnsStatusAddress(), + "zones": len(configMsg.Zones), + }) + } + + case "remove": + if dnsAuthorityServer == nil { + return + } + for _, zone := range configMsg.Zones { + dnsAuthorityServer.RemoveZone(zone.Domain) + } + // If no zones left, stop the server + if len(dnsAuthorityServer.GetZones()) == 0 { + if err := dnsAuthorityServer.Stop(); err != nil { + logger.Error("Failed to stop DNS authority server: %v", err) + } else { + _ = client.SendMessage("newt/dns/status", map[string]interface{}{ + "status": "disabled", + }) + } + } + logger.Info("Removed %d DNS authority zones", len(configMsg.Zones)) + + default: + logger.Warn("Unknown DNS authority config action: %s", configMsg.Action) + } + }) + + // Register handler for Auth Proxy configuration (SSO protection for direct routes) + client.RegisterHandler("newt/auth/proxy/config", func(msg websocket.WSMessage) { + logger.Debug("Received auth proxy config message: %+v", msg.Data) + + var configMsg auth.AuthProxyConfig + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling auth proxy config data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &configMsg); err != nil { + logger.Error("Error unmarshaling auth proxy config data: %v", err) + return + } + + switch configMsg.Action { + case "start": + if authProxy == nil { + authProxy = auth.NewAuthProxy() + bindDNSAuthoritySessionAffinity() + } + if !authProxy.IsRunning() { + if err := authProxy.Start(); err != nil { + logger.Error("Failed to start auth proxy: %v", err) + } else { + logger.Info("Auth Proxy started") + } + } + + case "stop": + if authProxy != nil && authProxy.IsRunning() { + if err := authProxy.Stop(); err != nil { + logger.Error("Failed to stop auth proxy: %v", err) + } else { + logger.Info("Auth Proxy stopped") + } + } + + case "update": + // Ensure auth proxy is running + if authProxy == nil { + authProxy = auth.NewAuthProxy() + bindDNSAuthoritySessionAffinity() + } + if !authProxy.IsRunning() { + if err := authProxy.Start(); err != nil { + logger.Error("Failed to start auth proxy: %v", err) + return + } + } + + // Update TLS certificates if provided + if len(configMsg.TLSCertificates) > 0 { + if err := authProxy.UpdateCertificates(configMsg.TLSCertificates); err != nil { + logger.Error("Failed to update TLS certificates: %v", err) + } else { + logger.Info("Updated auth proxy with %d TLS certificate(s)", len(configMsg.TLSCertificates)) + } + } + + // Update global auth config + if err := authProxy.UpdateConfig(configMsg.Auth); err != nil { + logger.Error("Failed to update auth config: %v", err) + } + + // Update resource configs + authProxy.ReplaceResources(configMsg.Resources) + logger.Info("Updated auth proxy with %d resources", len(configMsg.Resources)) + + // Report auth proxy bind status back to Pangolin + httpOk, httpsOk, httpSkipped, httpsSkipped := authProxy.BindStatus() + statusData := map[string]interface{}{ + "httpListening": httpOk, + "httpsListening": httpsOk, + "httpSkipped": httpSkipped, + "httpsSkipped": httpsSkipped, + "certCount": len(configMsg.TLSCertificates), + "resourceCount": len(configMsg.Resources), + } + if httpSkipped || httpsSkipped { + statusData["warning"] = "One or more auth proxy ports are already in use by another process (e.g. Traefik). Set NEWT_AUTH_PROXY_BIND / NEWT_AUTH_PROXY_HTTPS_BIND to use alternate ports." + } + _ = client.SendMessage("newt/auth/proxy/status", statusData) + + case "remove": + if authProxy == nil { + return + } + for _, resource := range configMsg.Resources { + authProxy.RemoveResource(resource.Domain) + } + logger.Info("Removed %d auth proxy resources", len(configMsg.Resources)) + + default: + logger.Warn("Unknown auth proxy config action: %s", configMsg.Action) + } + }) + // Register handler for getting health check status client.RegisterHandler("newt/blueprint/results", func(msg websocket.WSMessage) { logger.Debug("Received blueprint results message")