diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 81f903dc..babf401b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -5,9 +5,12 @@ name: Go on: push: - branches: [ "main" ] + branches: [ "main", "refactor" ] + # No branch filter here — run CI on every pull request regardless of target + # branch. This catches PRs aimed at refactor, feature branches, or anything + # else, not just main. The previous filter ("branches: [ main ]") silently + # skipped CI for PRs into refactor, which is our active dev branch. pull_request: - branches: [ "main" ] jobs: @@ -21,6 +24,13 @@ jobs: with: go-version-file: "go.mod" + - name: Check for slog overwrite calls in tests + run: | + if grep -rn 'slog\.SetDefault\|slog\.SetLogLoggerLevel' --include='*_test.go' .; then + echo "::error::Test files should not upate the slog.Default logger or level. This pollutes the output." + exit 1 + fi + - name: Build run: go build -v -tags with_clash_api ./... - name: Test diff --git a/AGENTS.md b/AGENTS.md index b1bbed87..c7041d4f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1 +1,61 @@ - Telemetry attributes: follow rules in https://github.com/getlantern/semconv/blob/main/AGENTS.md + +## Code Comments + +**Default: no comment.** Only add one if a specific *why* is load-bearing — invariant, concurrency guarantee, error condition, zero-value behavior, non-obvious caller contract, or a constraint that would surprise the reader. Aesthetic "this section is well-documented" comments are noise. + +Before writing any comment, run this checklist on the proposed text. If any answer is yes, delete or rewrite: + +1. Does it restate the identifier name or signature? (`// Foo does foo`, `// updateX manages X across Y`) +2. Does it narrate what the visible next line does? (`// Cancel any existing listener` immediately above `cancel()`) +3. Does it open with a generic lifecycle/management preamble before getting to the point? (`// manages the lifecycle of...`, `// handles the X for Y`) +4. Does it reference tickets, coworkers, sibling files, commit SHAs, or other code locations? Those belong in the commit message / PR description — they rot in source. +5. Does it describe the mechanism instead of the contract? (`authenticates via peer credentials over a Unix socket` vs. `authenticates each connection`) + +Lead with the *why*, not a summary of the function. If the only thing you can write is a summary, the comment isn't needed. + +Examples: + +```go +// BAD — restates name, generic preamble, narrates the code +// updateURLTestListener manages the lifecycle of the URL test result listener +// across VPN status changes. Connected always re-attaches (canceling any +// existing listener) so a stale event still leaves the listener bound to +// the live storage. + +// GOOD — leads with the trap, no narration +// Status events are dispatched in unordered goroutines, so reacting to +// intermediate statuses risks a stale handler tearing down a listener +// a concurrent Connected handler just attached. Only Connected (which +// re-attaches unconditionally) and terminal-down statuses are acted on. +``` + +```go +// BAD — narrates the next line +// Cancel any in-flight offline tests and wait for them to finish. +c.offlineTestCancel() +<-done + +// GOOD — no comment; the names already say it +c.offlineTestCancel() +<-done +``` + +```go +// BAD — references ticket and coworker +// Per Freshdesk #172640 (reported by Alice), saveServers held the lock +// for 1+ minute. We now release access before disk I/O. + +// GOOD — states the invariant; the ticket lives in git history +// access is released before disk I/O so a slow write can't starve readers. +``` + +Before writing an inline comment, consider whether a doc comment on the enclosing function or type would make it unnecessary. Prefer documenting contracts at the declaration over explaining implementation details inline. + +TODO comments must state *what* needs to happen and *why* it isn't done now. `TODO: ???` is not actionable — either resolve it or remove it. + +## Go Doc Comments + +- When a doc comment is warranted on an exported identifier, start it with the identifier's name and use complete sentences: `// Foo does X.` The first sentence is the summary shown by `go doc` and pkg.go.dev. +- Package comments: one per package, above the `package` clause (conventionally in `doc.go` for larger packages), starting with `// Package foo ...`. +- Formatting (gofmt-aware since Go 1.19): blank lines separate paragraphs; indented lines render as code blocks; lines starting with `-`, `*`, or `1.` render as lists; `[Name]` links to other symbols; `# Heading` renders as a heading. Avoid HTML and manual wrapping. diff --git a/Makefile b/Makefile index c44ff48b..f5dad53f 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,5 @@ proto: protoc --go_out=. --plugin=build/protoc-gen-go --go_opt=paths=source_relative api/protos/subscription.proto protoc --go_out=. --plugin=build/protoc-gen-go --go_opt=paths=source_relative issue/issue.proto -mock: - go install go.uber.org/mock/mockgen@latest - go generate ./... - test: go test -v ./... diff --git a/README.md b/README.md index 395dfd38..63d61e34 100644 --- a/README.md +++ b/README.md @@ -34,50 +34,103 @@ Available variables: * `RADIANCE_FEATURE_OVERRIDE`: Comma-separated list of feature flags to force-enable on the server side. If set, the value is sent as the `X-Lantern-Feature-Override` header on config requests in any environment, and it is recommended for testing/non-production use. For example, `RADIANCE_FEATURE_OVERRIDE=bandit_assignment` enables bandit-based proxy assignment during testing. -## Packages +## Architecture -Use `common.Init` to setup directories and configure loggers. -> [!note] -> This isn't necessary if `NewRadiance` was called as it will call `Init` for you. +Radiance is structured around a `LocalBackend` pattern that ties together all core functionality: configuration, servers, VPN connection, account management, issue reporting, and telemetry. The `LocalBackend` is the central coordinator and should be the primary interface for interacting with Radiance programmatically. -### `vpn` +In addition to being the core of the [Lantern client](https://github.com/getlantern/lantern), radiance also provides a standalone daemon and CLI: -The `vpn` package provides high-level functions for controlling the VPN tunnel. +- **`lanternd`** — the VPN daemon that runs the `LocalBackend` and exposes an IPC server. It can run in the foreground or be installed as a system service. +- **`lantern`** — a CLI client that communicates with the daemon over IPC. -To connect to the best available server, you can use the `QuickConnect` function. This function takes a server group (`servers.SGLantern`, `servers.SGUser`, or `"all"`) and a `PlatformInterface` as input. For example: +### Building CLI & Daemon -```go -err := vpn.QuickConnect(servers.SGLantern, platIfce) +From the `cmd/` directory: + +```sh +make build-daemon +make build-cli +``` +Or using [just](https://github.com/casey/just) +```sh +just build-daemon +just build-cli ``` -will connect to the best Lantern server, while: +Both binaries are output to `bin/`. You can also run the daemon directly with `make run-daemon`. -```go -err := vpn.QuickConnect("all", platIfce) +### Running + +```sh +# Start the daemon +lanternd run --data-path ~/data --log-path ~/logs + +# Install/uninstall as a system service +lanternd install --data-path ~/data --log-path ~/logs +lanternd uninstall + +# CLI commands (requires a running daemon) +lantern connect [--tag ] +lantern disconnect +lantern status +lantern servers +lantern account login +lantern subscription +lantern split-tunnel +lantern logs +lantern ip ``` -will connect to the best overall. +## Packages + +Use `common.Init` to setup directories and configure loggers. +> [!note] +> This isn't necessary if `NewLocalBackend` was called as it will call `Init` for you. + +### `backend` + +The `backend` package provides `LocalBackend`, the main entry point for all Radiance functionality. Create one with `NewLocalBackend(ctx, opts)` and call `Start()` to begin fetching configuration and serving requests. `LocalBackend` owns and coordinates the `VPNClient`, `ServerManager`, `ConfigHandler`, `AccountClient`, `IssueReporter`, and telemetry. + +### `vpn` -You can also connect to a specific server using `ConnectToServer`. This function requires a server group, a server tag, and a `PlatformInterface`. For example: +The `vpn` package provides `VPNClient`, which manages the lifecycle of the VPN tunnel. ```go -err := vpn.ConnectToServer(servers.SGUser, "my-server", platIfce) +client := vpn.NewVPNClient(dataPath, logger, platformIfce) +err := client.Connect(boxOptions) ``` -Both `QuickConnect` and `ConnectToServer` can be called without disconnecting first, allowing you to seamlessly switch between servers or connection modes. +`Connect` can be called without disconnecting first, allowing you to seamlessly switch between servers. Once connected, you can query status or view `Connections`. To stop the VPN, call `Disconnect`. -Once connected, you can check the `GetStatus` or view `ActiveConnections`. To stop the VPN, simply call `Disconnect`. The package also supports reconnecting to the last used server with `Reconnect`. +> [!note] +> In most cases, you should use the `LocalBackend` methods (`ConnectVPN`, `DisconnectVPN`, `RestartVPN`, `VPNStatus`) rather than using `VPNClient` directly. -This package also includes split tunneling capabilities, allowing you to include or exclude specific applications, domains, or IP addresses from the VPN tunnel. You can manage split tunneling by creating a `SplitTunnel` handler with `NewSplitTunnelHandler`. This handler allows you to `Enable` or `Disable` split tunneling, `AddItem` or `RemoveItem` from the filter, and view the current `Filters`. +This package also includes split tunneling capabilities via the `SplitTunnel` type, allowing you to include or exclude specific applications, domains, or IP addresses from the VPN tunnel. ### `servers` -The `servers` package is responsible for managing all VPN server configurations, separating them into two groups: `lantern` (official Lantern servers) and `user` (user-provided servers). +The `servers` package manages all VPN server configurations, separating them into two groups: `lantern` (official Lantern servers fetched from the config) and `user` (user-provided servers). -The `Manager` allows you to `AddServers` and `RemoveServer` configurations. You can retrieve the config for a specific server with `GetServerByTag` or use `Servers` to retrieve all configs. +The `Manager` allows you to `AddServers` and `RemoveServers` configurations. You can retrieve the config for a specific server with `GetServerByTag` or use `Servers` to retrieve all configs. > [!caution] -> While you can get a new `Manager` instance with `NewManager`, it is recommended to use `Radiance.ServerManager`. This will return the shared manager instance. `NewManager` can be useful for retrieving server information if you don't have access to the shared instance, but the new instance should not be kept as it won't stay in sync and adding server configs to it will overwrite existing configs if both manager instances are pointed to the same server file. +> While you can get a new `Manager` instance with `NewManager`, it is recommended to use the `LocalBackend`'s server methods (`Servers`, `AddServers`, `RemoveServers`, `GetServerByTag`). These use the shared manager instance. `NewManager` can be useful for retrieving server information if you don't have access to the shared instance, but the new instance should not be kept as it won't stay in sync. + +A key feature of this package is the ability to add private servers from a server manager via an access token using `AddPrivateServer`. This process uses Trust-on-first-use (TOFU) to securely add the server. Once a private server is added, you can invite other users with `InviteToPrivateServer` and revoke access with `RevokePrivateServerInvite`. + +### `ipc` + +The `ipc` package provides the communication layer between the `lantern` CLI and the `lanternd` daemon. The `ipc.Server` exposes an HTTP API backed by the `LocalBackend`, and the `ipc.Client` provides a typed Go client for calling it. All communication happens over a local socket. + +### `account` + +The `account` package handles user authentication (email/password and OAuth), signup, email verification, account recovery, device management, and subscription operations. It communicates with the Lantern account server and caches authentication state locally. + +### `config` + +The `config` package fetches proxy configuration from the Lantern API on a polling interval and emits `NewConfigEvent` events when the configuration changes. The `LocalBackend` subscribes to these events to update server configurations automatically. + +### `events` -A key feature of this package is the ability to add private servers from a server manager via an access token using `AddPrivateServer`. This process uses Trust-on-first-use (TOFU) to securely add the server. Once a private server is added, you can use the manager to invite other users to it with `InviteToPrivateServer` and revoke access with `RevokePrivateServerInvite`. +A generic pub-sub event system used throughout Radiance for decoupled communication between components (config changes, VPN status updates, log entries, etc.). diff --git a/account/auth.go b/account/auth.go new file mode 100644 index 00000000..da7568da --- /dev/null +++ b/account/auth.go @@ -0,0 +1,117 @@ +package account + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "math/big" + + "github.com/1Password/srp" + "golang.org/x/crypto/pbkdf2" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" +) + +func (a *Client) fetchSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { + query := map[string]string{"email": email} + resp, err := a.sendRequest(ctx, "GET", "/users/salt", query, nil, nil) + if err != nil { + return nil, err + } + var salt protos.GetSaltResponse + if err := proto.Unmarshal(resp, &salt); err != nil { + return nil, fmt.Errorf("unmarshaling salt response: %w", err) + } + return &salt, nil +} + +// clientProof performs the SRP authentication flow to generate the client proof for the given email and password. +func (a *Client) clientProof(ctx context.Context, email, password string, salt []byte) ([]byte, error) { + srpClient, err := newSRPClient(email, password, salt) + if err != nil { + return nil, err + } + + A := srpClient.EphemeralPublic() + data := &protos.PrepareRequest{ + Email: email, + A: A.Bytes(), + } + resp, err := a.sendRequest(ctx, "POST", "/users/prepare", nil, nil, data) + if err != nil { + return nil, err + } + + var srpB protos.PrepareResponse + if err := proto.Unmarshal(resp, &srpB); err != nil { + return nil, fmt.Errorf("unmarshaling prepare response: %w", err) + } + B := big.NewInt(0).SetBytes(srpB.B) + if err = srpClient.SetOthersPublic(B); err != nil { + return nil, err + } + + key, err := srpClient.Key() + if err != nil || key == nil { + return nil, fmt.Errorf("user_not_found error while generating Client key %w", err) + } + if !srpClient.GoodServerProof(salt, email, srpB.Proof) { + return nil, fmt.Errorf("user_not_found checking server proof %w", err) + } + + proof, err := srpClient.ClientProof() + if err != nil { + return nil, fmt.Errorf("user_not_found generating client proof %w", err) + } + return proof, nil +} + +// getSalt retrieves the salt for the given email address or it's cached value. +func (a *Client) getSalt(ctx context.Context, email string) ([]byte, error) { + if cached := a.getSaltCached(); cached != nil { + return cached, nil + } + resp, err := a.fetchSalt(ctx, email) + if err != nil { + return nil, err + } + return resp.Salt, nil +} + +const group = srp.RFC5054Group3072 + +func newSRPClient(email, password string, salt []byte) (*srp.SRP, error) { + if len(salt) == 0 || len(password) == 0 || len(email) == 0 { + return nil, errors.New("salt, password and email should not be empty") + } + + encryptedKey, err := generateEncryptedKey(password, email, salt) + if err != nil { + return nil, fmt.Errorf("failed to generate encrypted key: %w", err) + } + + return srp.NewSRPClient(srp.KnownGroups[group], encryptedKey, nil), nil +} + +func generateEncryptedKey(password, email string, salt []byte) (*big.Int, error) { + if len(salt) == 0 || len(password) == 0 || len(email) == 0 { + return nil, errors.New("salt or password or email is empty") + } + combinedInput := password + email + encryptedKey := pbkdf2.Key([]byte(combinedInput), salt, 4096, 32, sha256.New) + encryptedKeyBigInt := big.NewInt(0).SetBytes(encryptedKey) + return encryptedKeyBigInt, nil +} + +func generateSalt() ([]byte, error) { + salt := make([]byte, 16) + if n, err := rand.Read(salt); err != nil { + return nil, err + } else if n != 16 { + return nil, errors.New("failed to generate 16 byte salt") + } + return salt, nil +} diff --git a/account/client.go b/account/client.go new file mode 100644 index 00000000..e5b1a95a --- /dev/null +++ b/account/client.go @@ -0,0 +1,241 @@ +// Package account provides a client for communicating with the account server to perform operations +// such as user authentication, subscription management, and account information retrieval. +package account + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "path/filepath" + "sort" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/settings" +) + +const tracerName = "github.com/getlantern/radiance/account" + +// Client is an account client that communicates with the account server to perform operations such as +// user authentication, subscription management, and account information retrieval. +type Client struct { + httpClient *http.Client + // proURL and authURL override the default server URLs. Used for testing. + proURL string + authURL string + + salt []byte + saltPath string + mu sync.RWMutex +} + +// NewClient creates a new account client with the given HTTP client and data directory for caching +// the salt value. +func NewClient(httpClient *http.Client, dataDir string) *Client { + path := filepath.Join(dataDir, saltFileName) + salt, err := readSalt(path) + if err != nil { + slog.Warn("failed to read salt", "error", err) + } + return &Client{ + httpClient: httpClient, + salt: salt, + saltPath: path, + } +} + +func (a *Client) getSaltCached() []byte { + a.mu.RLock() + defer a.mu.RUnlock() + return a.salt +} + +func (a *Client) setSalt(salt []byte) { + a.mu.Lock() + defer a.mu.Unlock() + a.salt = salt +} + +func (a *Client) proBaseURL() string { + if a.proURL != "" { + return a.proURL + } + return common.GetProServerURL() +} + +func (a *Client) baseURL() string { + if a.authURL != "" { + return a.authURL + } + return common.GetBaseURL() +} + +// sendRequest sends an HTTP request to the specified URL with the given method, query parameters, +// headers, and body. If the URL is relative, the base URL will be prepended. +func (a *Client) sendRequest( + ctx context.Context, + method, url string, + queryParams, headers map[string]string, + body any, +) ([]byte, error) { + // check if url is absolute, if not prepend base URL + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = a.baseURL() + url + } + + var bodyReader io.Reader + contentType := "" + if body != nil { + if pb, ok := body.(proto.Message); ok { + data, err := proto.Marshal(pb) + if err != nil { + return nil, fmt.Errorf("marshaling protobuf request: %w", err) + } + bodyReader = bytes.NewReader(data) + contentType = "application/x-protobuf" + } else { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshaling JSON request: %w", err) + } + bodyReader = bytes.NewReader(data) + contentType = "application/json" + } + } + req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + req.Header.Set(common.AppNameHeader, common.Name) + req.Header.Set(common.VersionHeader, common.GetVersion()) + req.Header.Set(common.PlatformHeader, common.Platform) + if contentType != "" { + req.Header.Set("Content-Type", contentType) + req.Header.Set("Accept", contentType) + } + if len(queryParams) > 0 { + q := req.URL.Query() + for k, v := range queryParams { + q.Set(k, v) + } + req.URL.RawQuery = q.Encode() + } + + if env.GetBool(env.PrintCurl) { + slog.Debug("CURL command", "curl", curlFromRequest(req)) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + sanitized := sanitizeResponseBody(respBody) + slog.Debug("error response", "path", req.URL.Path, "status", resp.StatusCode, "body", string(sanitized)) + return nil, fmt.Errorf("unexpected status %v body %s", resp.StatusCode, sanitized) + } + + if len(respBody) == 0 { + return nil, nil + } + if contentType := resp.Header.Get("Content-Type"); strings.Contains(contentType, "application/json") { + return sanitizeResponseBody(respBody), nil + } + return respBody, nil +} + +// sendProRequest sends a request to the Pro server, automatically adding the required headers, +// including the device ID, user ID, and Pro token from settings, if available. If the URL is relative, +// the Pro server base URL will be prepended. +func (a *Client) sendProRequest( + ctx context.Context, + method, url string, + queryParams, additionalheaders map[string]string, + body any, +) ([]byte, error) { + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = a.proBaseURL() + url + } + headers := map[string]string{ + common.DeviceIDHeader: settings.GetString(settings.DeviceIDKey), + } + if tok := settings.GetString(settings.TokenKey); tok != "" { + headers[common.ProTokenHeader] = tok + } + if uid := settings.GetString(settings.UserIDKey); uid != "" { + headers[common.UserIDHeader] = uid + } + maps.Copy(headers, additionalheaders) + return a.sendRequest(ctx, method, url, queryParams, headers, body) +} + +// curlFromRequest generates a curl command string from an [http.Request]. +func curlFromRequest(req *http.Request) string { + var b strings.Builder + fmt.Fprintf(&b, "curl -X %s", req.Method) + + keys := make([]string, 0, len(req.Header)) + for k := range req.Header { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + for _, v := range req.Header[k] { + fmt.Fprintf(&b, " -H '%s: %s'", k, v) + } + } + + if req.Body != nil { + buf, _ := io.ReadAll(req.Body) + // Important! we need to reset the body since it can only be read once. + req.Body = io.NopCloser(bytes.NewBuffer(buf)) + fmt.Fprintf(&b, " -d '%s'", shellEscape(string(buf))) + } + + fmt.Fprintf(&b, " '%s'", req.URL.String()) + return b.String() +} + +func shellEscape(s string) string { + return strings.ReplaceAll(s, "'", "'\\''") +} + +func sanitizeResponseBody(data []byte) []byte { + var out bytes.Buffer + r := bytes.NewReader(data) + for { + ch, size, err := r.ReadRune() + if err != nil { + break + } + if ch == utf8.RuneError && size == 1 { + continue + } + if unicode.IsControl(ch) && ch != '\n' && ch != '\r' && ch != '\t' { + continue + } + out.WriteRune(ch) + } + return out.Bytes() +} diff --git a/account/datacap.go b/account/datacap.go new file mode 100644 index 00000000..151ac545 --- /dev/null +++ b/account/datacap.go @@ -0,0 +1,150 @@ +package account + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "go.opentelemetry.io/otel" + + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/traces" +) + +type sseEvent struct { + Type string + Data string +} + +// readSSE reads Server-Sent Events from body and sends parsed events on the +// returned channel. The channel is closed when the body returns EOF, an error +// occurs, or ctx is cancelled. The caller is responsible for closing body. +// After the channel is closed, call the returned function to retrieve any +// scanner error (nil on clean EOF). +func readSSE(ctx context.Context, body io.Reader) (<-chan sseEvent, func() error) { + ch := make(chan sseEvent, 1) + var scanErr error + go func() { + defer close(ch) + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // 1 MB max token + var evt sseEvent + for scanner.Scan() { + if ctx.Err() != nil { + return + } + line := scanner.Text() + switch { + case strings.HasPrefix(line, "event:"): + evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + case strings.HasPrefix(line, "data:"): + dataLine := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if evt.Data == "" { + evt.Data = dataLine + } else { + evt.Data = evt.Data + "\n" + dataLine + } + case strings.HasPrefix(line, ":"): + // comment / heartbeat — ignore + case line == "": + // blank line = event delimiter + if evt.Type != "" || evt.Data != "" { + select { + case ch <- evt: + case <-ctx.Done(): + return + } + evt = sseEvent{} + } + } + } + scanErr = scanner.Err() + }() + return ch, func() error { return scanErr } +} + +// DataCapStream connects to the datacap SSE endpoint and calls handler whenever +// the server pushes an update. The method blocks until ctx is cancelled, +// reconnecting with backoff on stream errors. +func (a *Client) DataCapStream(ctx context.Context, handler func(*DataCapInfo)) error { + bo := common.NewBackoff(2 * time.Minute) + for { + if ctx.Err() != nil { + return ctx.Err() + } + start := time.Now() + err := a.connectDataCapSSE(ctx, handler) + if err != nil { + slog.Debug("datacap SSE stream ended", "error", err) + } + if ctx.Err() != nil { + return ctx.Err() + } + // Reset backoff if the connection was up for a while before dropping, + // so we reconnect quickly after a transient disconnect. + if time.Since(start) > 30*time.Second { + bo.Reset() + } + bo.Wait(ctx) + } +} + +// connectDataCapSSE opens an SSE connection to the datacap stream endpoint and +// processes events until the stream ends or ctx is cancelled. +func (a *Client) connectDataCapSSE(ctx context.Context, handler func(*DataCapInfo)) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "datacap_sse") + defer span.End() + + sseURL := fmt.Sprintf("%s/stream/datacap/%s", a.baseURL(), settings.GetString(settings.DeviceIDKey)) + req, err := common.NewRequestWithHeaders(ctx, http.MethodGet, sseURL, nil) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("datacap SSE request: %w", err)) + } + req.Header.Set(common.AcceptHeader, "text/event-stream") + + resp, err := a.httpClient.Do(req) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("datacap SSE connect: %w", err)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return traces.RecordError(ctx, fmt.Errorf("datacap SSE status %d", resp.StatusCode)) + } + + slog.Debug("connected to datacap SSE stream") + eventCh, scanErr := readSSE(ctx, resp.Body) + for evt := range eventCh { + switch evt.Type { + case "datacap": + var datacap DataCapInfo + if err := json.Unmarshal([]byte(evt.Data), &datacap); err != nil { + slog.Debug("datacap SSE unmarshal error", "error", err) + continue + } + handler(&datacap) + if datacap.Usage != nil { + slog.Debug("datacap updated", "bytesUsed", datacap.Usage.BytesUsed) + } + case "cap_exhausted": + slog.Warn("datacap exhausted") + default: + // heartbeat or unknown event — ignore + } + } + if err := ctx.Err(); err != nil { + return traces.RecordError(ctx, err) + } + if err := scanErr(); err != nil { + return traces.RecordError(ctx, fmt.Errorf("datacap SSE scanner: %w", err)) + } + return traces.RecordError(ctx, errors.New("datacap SSE stream ended unexpectedly")) +} diff --git a/api/sse_test.go b/account/datacap_test.go similarity index 99% rename from api/sse_test.go rename to account/datacap_test.go index d92e09ec..834be833 100644 --- a/api/sse_test.go +++ b/account/datacap_test.go @@ -1,4 +1,4 @@ -package api +package account import ( "context" diff --git a/api/jwt.go b/account/jwt.go similarity index 79% rename from api/jwt.go rename to account/jwt.go index cb14f482..c381243a 100644 --- a/api/jwt.go +++ b/account/jwt.go @@ -1,4 +1,4 @@ -package api +package account import ( "encoding/json" @@ -10,13 +10,15 @@ import ( type JWTUserInfo struct { UserID string `json:"user_id"` Email string `json:"email"` - DeviceId string `json:"device_id"` + DeviceID string `json:"device_id"` LegacyUserID int64 `json:"legacy_user_id"` LegacyToken string `json:"legacy_token"` } func decodeJWT(tokenStr string) (*JWTUserInfo, error) { claims := jwt.MapClaims{} + // ParseUnverified is used intentionally: the JWT has already been validated + // server-side and the client only needs to extract claims for local use. token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, &claims) if err != nil { return nil, err diff --git a/api/protos/auth.pb.go b/account/protos/auth.pb.go similarity index 100% rename from api/protos/auth.pb.go rename to account/protos/auth.pb.go diff --git a/api/protos/auth.proto b/account/protos/auth.proto similarity index 100% rename from api/protos/auth.proto rename to account/protos/auth.proto diff --git a/api/protos/subscription.pb.go b/account/protos/subscription.pb.go similarity index 100% rename from api/protos/subscription.pb.go rename to account/protos/subscription.pb.go diff --git a/api/protos/subscription.proto b/account/protos/subscription.proto similarity index 100% rename from api/protos/subscription.proto rename to account/protos/subscription.proto diff --git a/api/subscription.go b/account/subscription.go similarity index 58% rename from api/subscription.go rename to account/subscription.go index 2fcadf44..2a6f54f1 100644 --- a/api/subscription.go +++ b/account/subscription.go @@ -1,19 +1,20 @@ -package api +package account import ( "context" + "encoding/json" "fmt" "log/slog" "net/url" "strconv" "time" - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" + "go.opentelemetry.io/otel" + + "github.com/getlantern/radiance/account/protos" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/traces" - "go.opentelemetry.io/otel" ) type ( @@ -39,49 +40,48 @@ type PaymentRedirectData struct { BillingType SubscriptionType `json:"billingType"` } -// SubscriptionPlans contains information about available subscription plans and payment providers. type SubscriptionPlans struct { *protos.BaseResponse `json:",inline"` Providers map[string][]*protos.PaymentMethod `json:"providers"` Plans []*protos.Plan `json:"plans"` } -// SubscriptionResponse contains information about a created subscription. type SubscriptionResponse struct { - CustomerId string `json:"customerId"` - SubscriptionId string `json:"subscriptionId"` + CustomerID string `json:"customerId"` + SubscriptionID string `json:"subscriptionId"` ClientSecret string `json:"clientSecret"` PendingSecret string `json:"pending_secret"` PublishableKey string `json:"publishableKey"` } // SubscriptionPlans retrieves available subscription plans for a given channel. -func (ac *APIClient) SubscriptionPlans(ctx context.Context, channel string) (string, error) { +func (a *Client) SubscriptionPlans(ctx context.Context, channel string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_plans") defer span.End() - var resp SubscriptionPlans params := map[string]string{ "locale": settings.GetString(settings.LocaleKey), "distributionChannel": channel, } - proWC := ac.proWebClient() - req := proWC.NewRequest(params, nil, nil) - err := proWC.Get(ctx, "/plans-v5", req, &resp) + resp, err := a.sendProRequest(ctx, "GET", "/plans-v5", params, nil, nil) if err != nil { slog.Error("retrieving plans", "error", err) return "", traces.RecordError(ctx, err) } - if resp.BaseResponse != nil && resp.Error != "" { - err = fmt.Errorf("received bad response: %s", resp.Error) + var plans SubscriptionPlans + if err := json.Unmarshal(resp, &plans); err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("unmarshaling plans response: %w", err)) + } + if plans.BaseResponse != nil && plans.Error != "" { + err = fmt.Errorf("received bad response: %s", plans.Error) slog.Error("retrieving plans", "error", err) return "", traces.RecordError(ctx, err) } - return withMarshalJsonString(resp, nil) + return string(resp), nil } // NewStripeSubscription creates a new Stripe subscription for the given email and plan ID. -func (ac *APIClient) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { +func (a *Client) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "new_stripe_subscription") defer span.End() @@ -89,25 +89,25 @@ func (ac *APIClient) NewStripeSubscription(ctx context.Context, email, planID st "email": email, "planId": planID, } - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - var resp SubscriptionResponse - err := proWC.Post(ctx, "/stripe-subscription", req, &resp) - return withMarshalJsonString(resp, err) + resp, err := a.sendProRequest(ctx, "POST", "/stripe-subscription", nil, nil, data) + if err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("creating stripe subscription: %w", err)) + } + return string(resp), nil } type VerifySubscriptionResponse struct { Status string `json:"status"` - SubscriptionId string `json:"subscriptionId"` - ActualUserId int64 `json:"actualUserId" json:",omitempty"` - ActualUserToken string `json:"actualUserToken" json:",omitempty"` + SubscriptionID string `json:"subscriptionId"` + ActualUserID int64 `json:"actualUserId,omitempty"` + ActualUserToken string `json:"actualUserToken,omitempty"` } // VerifySubscription verifies a subscription for a given service (Google or Apple). data // should contain the information required by service to verify the subscription, such as the // purchase token for Google Play or the receipt for Apple. The status and subscription ID are returned // along with any error that occurred during the verification process. -func (ac *APIClient) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (string, error) { +func (a *Client) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "verify_subscription") defer span.End() @@ -122,46 +122,57 @@ func (ac *APIClient) VerifySubscription(ctx context.Context, service Subscriptio return "", traces.RecordError(ctx, fmt.Errorf("unsupported service: %s", service)) } - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - var resp VerifySubscriptionResponse - err := proWC.Post(ctx, path, req, &resp) + resp, err := a.sendProRequest(ctx, "POST", path, nil, nil, data) if err != nil { slog.Error("verifying subscription", "error", err) return "", traces.RecordError(ctx, fmt.Errorf("verifying subscription: %w", err)) } - return withMarshalJsonString(resp, nil) + return string(resp), nil + } -// StripeBillingPortalUrl generates the Stripe billing portal URL for the given user ID. -func (ac *APIClient) StripeBillingPortalUrl(ctx context.Context) (string, error) { +// StripeBillingPortalURL generates the Stripe billing portal URL for the given user ID. +// baseURL = common.GetProServerURL +func (a *Client) StripeBillingPortalURL(ctx context.Context, baseURL, userID, proToken string) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "stripe_billing_portal_url") defer span.End() - portalURL, err := url.Parse(fmt.Sprintf("%s/%s", common.GetProServerURL(), "stripe-billing-portal")) + portalURL, err := url.Parse(baseURL + "/stripe-billing-portal") if err != nil { slog.Error("parsing portal URL", "error", err) return "", traces.RecordError(ctx, fmt.Errorf("parsing portal URL: %w", err)) } query := portalURL.Query() query.Set("referer", "https://lantern.io/") - query.Set("userId", strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - query.Set("proToken", settings.GetString(settings.TokenKey)) + query.Set("userId", userID) + query.Set("proToken", proToken) portalURL.RawQuery = query.Encode() return portalURL.String(), nil } -// SubscriptionPaymentRedirectURL generates a redirect URL for subscription payment. -func (ac *APIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_payment_redirect_url") - defer span.End() +type redirect struct { + Redirect string +} - type response struct { - Redirect string - } - var resp response +func (a *Client) paymentRedirect(ctx context.Context, path string, params map[string]string) (string, error) { headers := map[string]string{ - backend.RefererHeader: "https://lantern.io/", + common.RefererHeader: "https://lantern.io/", } + resp, err := a.sendProRequest(ctx, "GET", path, params, headers, nil) + if err != nil { + slog.Error("payment redirect", "error", err) + return "", traces.RecordError(ctx, fmt.Errorf("payment redirect: %w", err)) + } + var r redirect + if err := json.Unmarshal(resp, &r); err != nil { + return "", traces.RecordError(ctx, fmt.Errorf("unmarshaling payment redirect response: %w", err)) + } + return r.Redirect, nil +} + +// SubscriptionPaymentRedirectURL generates a redirect URL for subscription payment. +func (a *Client) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "subscription_payment_redirect_url") + defer span.End() params := map[string]string{ "provider": data.Provider, "plan": data.Plan, @@ -169,43 +180,21 @@ func (ac *APIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data Pa "email": data.Email, "billingType": string(data.BillingType), } - proWC := ac.proWebClient() - req := proWC.NewRequest(params, headers, nil) - err := proWC.Get(ctx, "/subscription-payment-redirect", req, &resp) - if err != nil { - slog.Error("subscription payment redirect", "error", err) - return "", traces.RecordError(ctx, fmt.Errorf("subscription payment redirect: %w", err)) - } - return resp.Redirect, traces.RecordError(ctx, err) + return a.paymentRedirect(ctx, "/subscription-payment-redirect", params) } -// PaymentRedirect is used to get the payment redirect URL with PaymentRedirectData -// this is used in desktop app and android app -func (ac *APIClient) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { +// PaymentRedirect is used to get the payment redirect URL with PaymentRedirectData. +// This is used in the desktop and android apps. +func (a *Client) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "payment_redirect") defer span.End() - - type response struct { - Redirect string - } - var resp response - headers := map[string]string{ - backend.RefererHeader: "https://lantern.io/", - } - mapping := map[string]string{ + params := map[string]string{ "provider": data.Provider, "plan": data.Plan, "deviceName": data.DeviceName, "email": data.Email, } - proWC := ac.proWebClient() - req := proWC.NewRequest(mapping, headers, nil) - err := proWC.Get(ctx, "/payment-redirect", req, &resp) - if err != nil { - slog.Error("subscription payment redirect", "error", err) - return "", traces.RecordError(ctx, fmt.Errorf("subscription payment redirect: %w", err)) - } - return resp.Redirect, traces.RecordError(ctx, err) + return a.paymentRedirect(ctx, "/payment-redirect", params) } type PurchaseResponse struct { @@ -216,28 +205,29 @@ type PurchaseResponse struct { } // ActivationCode is used to purchase a subscription using a reseller code. -func (ac *APIClient) ActivationCode(ctx context.Context, email, resellerCode string) (*PurchaseResponse, error) { +func (a *Client) ActivationCode(ctx context.Context, email, resellerCode string) (*PurchaseResponse, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "activation_code") defer span.End() - data := map[string]interface{}{ + data := map[string]any{ "idempotencyKey": strconv.FormatInt(time.Now().UnixNano(), 10), "provider": "reseller-code", "email": email, "deviceName": settings.GetString(settings.DeviceIDKey), "resellerCode": resellerCode, } - var resp PurchaseResponse - proWC := ac.proWebClient() - req := proWC.NewRequest(nil, nil, data) - err := proWC.Post(ctx, "/purchase", req, &resp) + resp, err := a.sendProRequest(ctx, "POST", "/purchase", nil, nil, data) if err != nil { slog.Error("retrieving subscription status", "error", err) return nil, traces.RecordError(ctx, fmt.Errorf("retrieving subscription status: %w", err)) } - if resp.BaseResponse != nil && resp.Error != "" { - slog.Error("retrieving subscription status", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("received bad response: %s", resp.Error)) + var purchase PurchaseResponse + if err := json.Unmarshal(resp, &purchase); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("unmarshaling purchase response: %w", err)) + } + if purchase.BaseResponse != nil && purchase.Error != "" { + slog.Error("retrieving subscription status", "error", purchase.Error) + return nil, traces.RecordError(ctx, fmt.Errorf("received bad response: %s", purchase.Error)) } - return &resp, nil + return &purchase, nil } diff --git a/account/subscription_test.go b/account/subscription_test.go new file mode 100644 index 00000000..cedd3ee3 --- /dev/null +++ b/account/subscription_test.go @@ -0,0 +1,61 @@ +package account + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubscriptionPaymentRedirect(t *testing.T) { + ac, _ := newTestClient(t) + data := PaymentRedirectData{ + Provider: "stripe", + Plan: "pro", + DeviceName: "test-device", + Email: "", + BillingType: SubscriptionTypeOneTime, + } + url, err := ac.SubscriptionPaymentRedirectURL(context.Background(), data) + require.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestPaymentRedirect(t *testing.T) { + ac, _ := newTestClient(t) + data := PaymentRedirectData{ + Provider: "stripe", + Plan: "pro", + DeviceName: "test-device", + Email: "", + } + url, err := ac.PaymentRedirect(context.Background(), data) + require.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestNewUser(t *testing.T) { + ac, _ := newTestClient(t) + resp, err := ac.NewUser(context.Background()) + require.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestVerifySubscription(t *testing.T) { + ac, _ := newTestClient(t) + data := map[string]string{ + "email": "test@getlantern.org", + "planID": "1y-usd-10", + } + resp, err := ac.VerifySubscription(context.Background(), AppleService, data) + require.NoError(t, err) + assert.NotEmpty(t, resp) +} + +func TestPlans(t *testing.T) { + ac, _ := newTestClient(t) + resp, err := ac.SubscriptionPlans(context.Background(), "store") + require.NoError(t, err) + assert.NotEmpty(t, resp) +} diff --git a/account/user.go b/account/user.go new file mode 100644 index 00000000..4f34cd1b --- /dev/null +++ b/account/user.go @@ -0,0 +1,700 @@ +package account + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/url" + "os" + "strings" + + "go.opentelemetry.io/otel" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/fileperm" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/traces" +) + +const saltFileName = ".salt" + +type UserDataResponse struct { + *protos.BaseResponse + *protos.LoginResponse_UserData +} + +type SignupResponse = protos.SignupResponse +type UserData = protos.LoginResponse + +// NewUser creates a new user account +func (a *Client) NewUser(ctx context.Context) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "new_user") + defer span.End() + + resp, err := a.sendProRequest(ctx, "POST", "/user-create", nil, nil, nil) + if err != nil { + slog.Error("creating new user", "error", err) + return nil, traces.RecordError(ctx, err) + } + var userResp UserDataResponse + if err := json.Unmarshal(resp, &userResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling new user response: %w", err)) + } + userData, err := a.storeData(ctx, userResp) + if err != nil { + return nil, err + } + return userData, nil +} + +// FetchUserData fetches user data from the server. +func (a *Client) FetchUserData(ctx context.Context) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "fetch_user_data") + defer span.End() + return a.fetchUserData(ctx) +} + +// fetchUserData calls the /user-data endpoint and stores the result via storeData. +func (a *Client) fetchUserData(ctx context.Context) (*UserData, error) { + resp, err := a.sendProRequest(ctx, "GET", "/user-data", nil, nil, nil) + if err != nil { + slog.Error("user data", "error", err) + return nil, traces.RecordError(ctx, fmt.Errorf("getting user data: %w", err)) + } + var userResp UserDataResponse + if err := json.Unmarshal(resp, &userResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling new user response: %w", err)) + } + return a.storeData(ctx, userResp) +} + +func (a *Client) storeData(ctx context.Context, resp UserDataResponse) (*UserData, error) { + if resp.BaseResponse != nil && resp.Error != "" { + err := fmt.Errorf("received bad response: %s", resp.Error) + slog.Error("user data", "error", err) + return nil, traces.RecordError(ctx, err) + } + if resp.LoginResponse_UserData == nil { + slog.Error("user data", "error", "no user data in response") + return nil, traces.RecordError(ctx, fmt.Errorf("no user data in response")) + } + resp.DeviceID = settings.GetString(settings.DeviceIDKey) + login := &UserData{ + LegacyID: resp.UserId, + LegacyToken: resp.Token, + LegacyUserData: resp.LoginResponse_UserData, + } + a.setData(login) + return login, nil +} + +type DataCapInfo struct { + // Whether data cap is enabled for this device/user + Enabled bool `json:"enabled"` + // Data cap usage details (only populated if enabled is true) + Usage *DataCapUsageDetails `json:"usage,omitempty"` +} + +type DataCapUsageDetails struct { + BytesAllotted string `json:"bytesAllotted"` + BytesUsed string `json:"bytesUsed"` + AllotmentStartTime string `json:"allotmentStartTime"` + AllotmentEndTime string `json:"allotmentEndTime"` +} + +func (a *Client) DataCapInfo(ctx context.Context) (*DataCapInfo, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "data_cap_info") + defer span.End() + + getURL := "/datacap/" + settings.GetString(settings.DeviceIDKey) + headers := map[string]string{ + "Content-Type": "application/json", + } + resp, err := a.sendRequest(ctx, "GET", getURL, nil, headers, nil) + if err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("getting datacap info: %w", err)) + } + var usage *DataCapInfo + if err := json.Unmarshal(resp, &usage); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling datacap info response: %w", err)) + } + return usage, nil +} + +// SignUp signs the user up for an account. +func (a *Client) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + salt, err := generateSalt() + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + srpClient, err := newSRPClient(lowerCaseEmail, password, salt) + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + data := &protos.SignupRequest{ + Email: lowerCaseEmail, + Salt: salt, + Verifier: verifierKey.Bytes(), + SkipEmailConfirmation: true, + // Set temp always to true for now + // If new user faces any issue while sign up user can sign up again + Temp: true, + } + // Signup endpoint need to include device ID, user ID and pro token + // if not api wil create new user instead of linking to existing user which cause issue + resp, err := a.sendProRequest(ctx, "POST", "/users/signup", nil, nil, data) + if err != nil { + return nil, nil, traces.RecordError(ctx, err) + } + a.setSalt(salt) + + var signupData protos.SignupResponse + if err := proto.Unmarshal(resp, &signupData); err != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling sign up response: %w", err)) + } + idErr := settings.Set(settings.UserIDKey, signupData.LegacyID) + if idErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save user id: %w", idErr)) + } + proTokenErr := settings.Set(settings.TokenKey, signupData.ProToken) + if proTokenErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save token: %w", proTokenErr)) + } + jwtTokenErr := settings.Set(settings.JwtTokenKey, signupData.Token) + if jwtTokenErr != nil { + return nil, nil, traces.RecordError(ctx, fmt.Errorf("could not save JWT token: %w", jwtTokenErr)) + } + + return salt, &signupData, nil +} + +var ErrNoSalt = errors.New("no salt available") +var ErrNotLoggedIn = errors.New("not logged in") +var ErrInvalidCode = errors.New("invalid code") + +// SignupEmailResendCode requests that the sign-up code be resent via email. +func (a *Client) SignupEmailResendCode(ctx context.Context, email string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_resend_code") + defer span.End() + + salt := a.getSaltCached() + if salt == nil { + return traces.RecordError(ctx, ErrNoSalt) + } + data := &protos.SignupEmailResendRequest{ + Email: email, + Salt: salt, + } + _, err := a.sendRequest(ctx, "POST", "/users/signup/resend/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// SignupEmailConfirmation confirms the new account using the sign-up code received via email. +func (a *Client) SignupEmailConfirmation(ctx context.Context, email, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_confirmation") + defer span.End() + + data := &protos.ConfirmSignupRequest{ + Email: email, + Code: code, + } + _, err := a.sendRequest(ctx, "POST", "/users/signup/complete/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +func writeSalt(salt []byte, path string) error { + if err := os.WriteFile(path, salt, fileperm.File); err != nil { + return fmt.Errorf("writing salt to %s: %w", path, err) + } + return nil +} + +func readSalt(path string) ([]byte, error) { + buf, err := os.ReadFile(path) + if err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("reading salt from %s: %w", path, err) + } + if len(buf) == 0 { + return nil, nil + } + return buf, nil +} + +// Login logs the user in. +func (a *Client) Login(ctx context.Context, email, password string) (*UserData, error) { + // clear any previous salt value + a.setSalt(nil) + ctx, span := otel.Tracer(tracerName).Start(ctx, "login") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + deviceID := settings.GetString(settings.DeviceIDKey) + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return nil, err + } + + loginData := &protos.LoginRequest{ + Email: lowerCaseEmail, + DeviceId: deviceID, + Proof: proof, + } + resp, err := a.sendRequest(ctx, "POST", "/users/login", nil, nil, loginData) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + var loginResp UserData + if err := proto.Unmarshal(resp, &loginResp); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling login response: %w", err)) + } + //this can be nil if the user has reached the device limit + if loginResp.LegacyUserData != nil { + loginResp.LegacyUserData.DeviceID = deviceID + } + + // regardless of state we need to save login information + // We have device flow limit on login + a.setData(&loginResp) + a.setSalt(salt) + if saltErr := writeSalt(salt, a.saltPath); saltErr != nil { + return nil, traces.RecordError(ctx, saltErr) + } + settings.Set(settings.OAuthLoginKey, false) + settings.Set(settings.OAuthProviderKey, "") + return &loginResp, nil +} + +// Logout logs the user out. No-op if there is no user account logged in. +func (a *Client) Logout(ctx context.Context, email string) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "logout") + defer span.End() + logout := &protos.LogoutRequest{ + Email: email, + DeviceId: settings.GetString(settings.DeviceIDKey), + LegacyUserID: settings.GetInt64(settings.UserIDKey), + LegacyToken: settings.GetString(settings.TokenKey), + } + // JWT token is only set for OAuth users; omit the field entirely when empty + jwtToken := settings.GetString(settings.JwtTokenKey) + if jwtToken != "" { + logout.Token = jwtToken + } + slog.Info("Logout request", "request", logout, "JWTTokenSet", jwtToken != "") + _, err := a.sendRequest(ctx, "POST", "/users/logout", nil, nil, logout) + if err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("logging out: %w", err)) + } + a.ClearUser() + a.setSalt(nil) + settings.Set(settings.OAuthLoginKey, false) + settings.Set(settings.OAuthProviderKey, "") + if err := writeSalt(nil, a.saltPath); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("writing salt after logout: %w", err)) + } + return a.NewUser(ctx) +} + +// StartRecoveryByEmail initializes the account recovery process for the provided email. +func (a *Client) StartRecoveryByEmail(ctx context.Context, email string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "start_recovery_by_email") + defer span.End() + + data := &protos.StartRecoveryByEmailRequest{Email: email} + _, err := a.sendRequest(ctx, "POST", "/users/recovery/start/email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// CompleteRecoveryByEmail completes account recovery using the code received via email. +func (a *Client) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_recovery_by_email") + defer span.End() + lowerCaseEmail := strings.ToLower(email) + newSalt, err := generateSalt() + if err != nil { + return traces.RecordError(ctx, err) + } + srpClient, err := newSRPClient(lowerCaseEmail, newPassword, newSalt) + if err != nil { + return traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.CompleteRecoveryByEmailRequest{ + Email: lowerCaseEmail, + Code: code, + NewSalt: newSalt, + NewVerifier: verifierKey.Bytes(), + } + _, err = a.sendRequest(ctx, "POST", "/users/recovery/complete/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to complete recovery by email: %w", err)) + } + if err = writeSalt(newSalt, a.saltPath); err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to write new salt: %w", err)) + } + return nil +} + +// ValidateEmailRecoveryCode validates the recovery code received via email. +func (a *Client) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "validate_email_recovery_code") + defer span.End() + + data := &protos.ValidateRecoveryCodeRequest{ + Email: email, + Code: code, + } + resp, err := a.sendRequest(ctx, "POST", "/users/recovery/validate/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, err) + } + var codeResp protos.ValidateRecoveryCodeResponse + if err := proto.Unmarshal(resp, &codeResp); err != nil { + return traces.RecordError(ctx, fmt.Errorf("error unmarshalling validate recovery code response: %w", err)) + } + if !codeResp.Valid { + return traces.RecordError(ctx, ErrInvalidCode) + } + return nil +} + +// StartChangeEmail initializes a change of the email address associated with this user account. +func (a *Client) StartChangeEmail(ctx context.Context, newEmail, password string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "start_change_email") + defer span.End() + + lowerCaseEmail := strings.ToLower(settings.GetString(settings.EmailKey)) + lowerCaseNewEmail := strings.ToLower(newEmail) + + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return traces.RecordError(ctx, err) + } + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.ChangeEmailRequest{ + OldEmail: lowerCaseEmail, + NewEmail: lowerCaseNewEmail, + Proof: proof, + } + _, err = a.sendRequest(ctx, "POST", "/users/change_email", nil, nil, data) + return traces.RecordError(ctx, err) +} + +// CompleteChangeEmail completes a change of the email address associated with this user account, +// using the code received via email. +func (a *Client) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_change_email") + defer span.End() + + newSalt, err := generateSalt() + if err != nil { + return traces.RecordError(ctx, err) + } + + newEmail = strings.ToLower(newEmail) + srpClient, err := newSRPClient(newEmail, password, newSalt) + if err != nil { + return traces.RecordError(ctx, err) + } + verifierKey, err := srpClient.Verifier() + if err != nil { + return traces.RecordError(ctx, err) + } + + data := &protos.CompleteChangeEmailRequest{ + OldEmail: settings.GetString(settings.EmailKey), + NewEmail: newEmail, + Code: code, + NewSalt: newSalt, + NewVerifier: verifierKey.Bytes(), + } + _, err = a.sendRequest(ctx, "POST", "/users/change_email/complete/email", nil, nil, data) + if err != nil { + return traces.RecordError(ctx, err) + } + if err := writeSalt(newSalt, a.saltPath); err != nil { + return traces.RecordError(ctx, err) + } + if err := settings.Set(settings.EmailKey, newEmail); err != nil { + return traces.RecordError(ctx, err) + } + + a.setSalt(newSalt) + return nil +} + +// DeleteAccount deletes this user account. +func (a *Client) DeleteAccount(ctx context.Context, email, password string) (*UserData, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "delete_account") + defer span.End() + + lowerCaseEmail := strings.ToLower(email) + data := &protos.DeleteUserRequest{ + Email: lowerCaseEmail, + Permanent: true, + DeviceId: settings.GetString(settings.DeviceIDKey), + Token: settings.GetString(settings.JwtTokenKey), + } + if !settings.GetBool(settings.OAuthLoginKey) { + salt, err := a.getSalt(ctx, lowerCaseEmail) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + proof, err := a.clientProof(ctx, lowerCaseEmail, password, salt) + if err != nil { + return nil, err + } + data.Proof = proof + } else { + if data.Token == "" { + return nil, traces.RecordError(ctx, errors.New("jwt token is required for OAuth account deletion")) + } + } + + _, err := a.sendRequest(ctx, "POST", "/users/delete", nil, nil, data) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + + a.ClearUser() + a.setSalt(nil) + if err := writeSalt(nil, a.saltPath); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to write salt during account deletion cleanup: %w", err)) + } + + return a.NewUser(ctx) +} + +// OAuthLoginURL initiates the OAuth login process for the specified provider. +func (a *Client) OAuthLoginURL(ctx context.Context, provider string) (string, error) { + authURL := a.authURL + if authURL == "" { + authURL = common.GetBaseURL() + } + loginURL, err := url.Parse(authURL + "/users/oauth2/" + provider) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + query := loginURL.Query() + query.Set("deviceId", settings.GetString(settings.DeviceIDKey)) + query.Set("userId", settings.GetString(settings.UserIDKey)) + query.Set("proToken", settings.GetString(settings.TokenKey)) + query.Set("returnTo", "lantern://auth") + loginURL.RawQuery = query.Encode() + // Persist the provider so it's available after the callback completes. + if err := settings.Set(settings.OAuthProviderKey, provider); err != nil { + return "", fmt.Errorf("failed to persist OAuth provider: %w", err) + } + return loginURL.String(), nil +} + +func (a *Client) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*UserData, error) { + slog.Debug("Getting OAuth login callback") + jwtUserInfo, err := decodeJWT(oAuthToken) + if err != nil { + return nil, fmt.Errorf("error decoding JWT: %w", err) + } + + // Temporary set user data to so api can read it + login := &UserData{ + LegacyID: jwtUserInfo.LegacyUserID, + LegacyToken: jwtUserInfo.LegacyToken, + LegacyUserData: &protos.LoginResponse_UserData{ + UserId: jwtUserInfo.LegacyUserID, + Token: jwtUserInfo.LegacyToken, + DeviceID: jwtUserInfo.DeviceID, + Email: jwtUserInfo.Email, + }, + } + a.setData(login) + // Get user data from api this will also save data in user config + user, err := a.fetchUserData(ctx) + if err != nil { + return nil, fmt.Errorf("error getting user data: %w", err) + } + + if err := settings.Set(settings.JwtTokenKey, oAuthToken); err != nil { + slog.Error("Failed to persist JWT token", "error", err) + return nil, fmt.Errorf("failed to persist JWT token: %w", err) + } + settings.Set(settings.OAuthLoginKey, true) + user.Id = jwtUserInfo.Email + user.EmailConfirmed = true + a.setData(user) + return user, nil +} + +type LinkResponse struct { + *protos.BaseResponse `json:",inline"` + UserID int `json:"userID"` + ProToken string `json:"token"` +} + +// RemoveDevice removes a device from the user's account. +func (a *Client) RemoveDevice(ctx context.Context, deviceID string) (*LinkResponse, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "remove_device") + defer span.End() + + data := map[string]string{ + "deviceId": deviceID, + } + resp, err := a.sendProRequest(ctx, "POST", "/user-link-remove", nil, nil, data) + if err != nil { + return nil, traces.RecordError(ctx, err) + } + var link LinkResponse + if err := json.Unmarshal(resp, &link); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("error unmarshalling remove device response: %w", err)) + } + if link.BaseResponse != nil && link.BaseResponse.Error != "" { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to remove device: %s", link.BaseResponse.Error)) + } + return &link, nil +} + +func (a *Client) ReferralAttach(ctx context.Context, code string) (bool, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "referral_attach") + defer span.End() + + data := map[string]string{ + "code": code, + } + resp, err := a.sendProRequest(ctx, "POST", "/referral-attach", nil, nil, data) + if err != nil { + return false, traces.RecordError(ctx, err) + } + var baseResp protos.BaseResponse + if err := proto.Unmarshal(resp, &baseResp); err != nil { + return false, traces.RecordError(ctx, fmt.Errorf("error unmarshalling referral attach response: %w", err)) + } + if baseResp.Error != "" { + return false, traces.RecordError(ctx, errors.New(baseResp.Error)) + } + return true, nil +} + +type UserChangeEvent struct { + events.Event +} + +func (a *Client) setData(data *UserData) { + a.mu.Lock() + defer a.mu.Unlock() + if data == nil { + a.ClearUser() + return + } + + // This case when user hits device limit while login + if data.LegacyUserData == nil { + slog.Info("no user data to set, storing id and token only") + if data.LegacyID != 0 { + if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { + slog.Error("failed to set user ID in settings", "error", err) + } + } + if data.LegacyToken != "" { + if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { + slog.Error("failed to set token in settings", "error", err) + } + } + return + } + + existingUser := settings.GetInt64(settings.UserIDKey) != 0 + + var changed bool + if data.LegacyUserData.UserLevel != "" { + oldUserLevel := settings.GetString(settings.UserLevelKey) + changed = changed || oldUserLevel != data.LegacyUserData.UserLevel + if err := settings.Set(settings.UserLevelKey, data.LegacyUserData.UserLevel); err != nil { + slog.Error("failed to set user level in settings", "error", err) + } + } + if data.LegacyUserData.Email != "" { + oldEmail := settings.GetString(settings.EmailKey) + changed = changed || oldEmail != data.LegacyUserData.Email + if err := settings.Set(settings.EmailKey, data.LegacyUserData.Email); err != nil { + slog.Error("failed to set email in settings", "error", err) + } + } + if data.LegacyID != 0 { + oldUserID := settings.GetInt64(settings.UserIDKey) + changed = changed || oldUserID != data.LegacyID + if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { + slog.Error("failed to set user ID in settings", "error", err) + } + } + if data.LegacyToken != "" { + oldToken := settings.GetString(settings.TokenKey) + changed = changed || oldToken != data.LegacyToken + if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { + slog.Error("failed to set token in settings", "error", err) + } + } + if data.Token != "" { + oldJwtToken := settings.GetString(settings.JwtTokenKey) + changed = changed || oldJwtToken != data.Token + if err := settings.Set(settings.JwtTokenKey, data.Token); err != nil { + slog.Error("failed to set JWT token in settings", "error", err) + } + } + + if len(data.Devices) > 0 { + devices := []settings.Device{} + for _, d := range data.Devices { + devices = append(devices, settings.Device{ + Name: d.Name, + ID: d.Id, + }) + } + if err := settings.Set(settings.DevicesKey, devices); err != nil { + slog.Error("failed to set devices in settings", "error", err) + } + } + + if err := settings.Set(settings.UserDataKey, data); err != nil { + slog.Error("failed to set login response in settings", "error", err) + } + + // We only consider the user to have changed if there was a previous user. + if existingUser && changed { + events.Emit(UserChangeEvent{}) + } +} + +func (a *Client) ClearUser() { + settings.Clear(settings.UserIDKey) + settings.Clear(settings.TokenKey) + settings.Clear(settings.UserLevelKey) + settings.Clear(settings.EmailKey) + settings.Clear(settings.DevicesKey) + settings.Clear(settings.JwtTokenKey) + settings.Clear(settings.UserDataKey) +} diff --git a/account/user_test.go b/account/user_test.go new file mode 100644 index 00000000..87ee1f1c --- /dev/null +++ b/account/user_test.go @@ -0,0 +1,353 @@ +package account + +import ( + "context" + "encoding/hex" + "encoding/json" + "io" + "math/big" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/1Password/srp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/getlantern/radiance/account/protos" + "github.com/getlantern/radiance/common/settings" +) + +// testServer holds server-side SRP state for the mock auth server. +type testServer struct { + salt map[string][]byte + verifier []byte + cache map[string]string +} + +func writeProtoResponse(w http.ResponseWriter, msg proto.Message) { + data, err := proto.Marshal(msg) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.Header().Set("Content-Type", "application/x-protobuf") + w.Write(data) +} + +func readProtoRequest(r *http.Request, msg proto.Message) error { + data, err := io.ReadAll(r.Body) + if err != nil { + return err + } + return proto.Unmarshal(data, msg) +} + +func writeJSONResponse(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func newTestServer(t *testing.T) (*httptest.Server, *testServer) { + state := &testServer{ + salt: make(map[string][]byte), + cache: make(map[string]string), + } + mux := http.NewServeMux() + + // Auth endpoints + mux.HandleFunc("/users/salt", func(w http.ResponseWriter, r *http.Request) { + email := r.URL.Query().Get("email") + salt := state.salt[email] + if salt == nil { + salt = []byte("salt") + } + writeProtoResponse(w, &protos.GetSaltResponse{Salt: salt}) + }) + + mux.HandleFunc("/users/signup", func(w http.ResponseWriter, r *http.Request) { + var req protos.SignupRequest + if err := readProtoRequest(r, &req); err != nil { + http.Error(w, err.Error(), 500) + return + } + state.salt[req.Email] = req.Salt + state.verifier = req.Verifier + writeProtoResponse(w, &protos.SignupResponse{}) + }) + + mux.HandleFunc("/users/prepare", func(w http.ResponseWriter, r *http.Request) { + var req protos.PrepareRequest + if err := readProtoRequest(r, &req); err != nil { + http.Error(w, err.Error(), 500) + return + } + A := big.NewInt(0).SetBytes(req.A) + verifier := big.NewInt(0).SetBytes(state.verifier) + server := srp.NewSRPServer(srp.KnownGroups[srp.RFC5054Group3072], verifier, nil) + if err := server.SetOthersPublic(A); err != nil { + http.Error(w, err.Error(), 500) + return + } + B := server.EphemeralPublic() + if B == nil { + http.Error(w, "cannot generate B", 500) + return + } + if _, err := server.Key(); err != nil { + http.Error(w, "cannot generate key", 500) + return + } + proof, err := server.M(state.salt[req.Email], req.Email) + if err != nil { + http.Error(w, "cannot generate proof", 500) + return + } + serverState, _ := server.MarshalBinary() + state.cache[req.Email] = hex.EncodeToString(serverState) + writeProtoResponse(w, &protos.PrepareResponse{B: B.Bytes(), Proof: proof}) + }) + + mux.HandleFunc("/users/login", func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.LoginResponse{ + LegacyUserData: &protos.LoginResponse_UserData{ + DeviceID: "deviceId", + }, + }) + }) + + // Simple auth endpoints that return empty responses + for _, path := range []string{ + "/users/signup/resend/email", + "/users/signup/complete/email", + "/users/recovery/start/email", + "/users/recovery/complete/email", + "/users/change_email", + "/users/change_email/complete/email", + "/users/delete", + "/users/logout", + } { + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.EmptyResponse{}) + }) + } + + mux.HandleFunc("/users/recovery/validate/email", func(w http.ResponseWriter, r *http.Request) { + writeProtoResponse(w, &protos.ValidateRecoveryCodeResponse{Valid: true}) + }) + + // Pro server endpoints + mux.HandleFunc("/user-create", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, UserDataResponse{ + BaseResponse: &protos.BaseResponse{}, + LoginResponse_UserData: &protos.LoginResponse_UserData{ + UserId: 123, + Token: "test-token", + }, + }) + }) + + mux.HandleFunc("/user-data", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, UserDataResponse{ + BaseResponse: &protos.BaseResponse{}, + LoginResponse_UserData: &protos.LoginResponse_UserData{ + UserId: 123, + Token: "test-token", + }, + }) + }) + + mux.HandleFunc("/user-link-remove", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, LinkResponse{ + BaseResponse: &protos.BaseResponse{}, + UserID: 123, + ProToken: "token", + }) + }) + + mux.HandleFunc("/referral-attach", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, protos.BaseResponse{}) + }) + + // Subscription endpoints + mux.HandleFunc("/plans-v5", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, SubscriptionPlans{ + BaseResponse: &protos.BaseResponse{}, + Plans: []*protos.Plan{{Id: "1y-usd-10", Description: "Pro Plan"}}, + }) + }) + + mux.HandleFunc("/subscription-payment-redirect", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, map[string]string{"Redirect": "https://example.com/redirect"}) + }) + + mux.HandleFunc("/payment-redirect", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, map[string]string{"Redirect": "https://example.com/redirect"}) + }) + + mux.HandleFunc("/stripe-subscription", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, SubscriptionResponse{ + CustomerID: "cus_123", + SubscriptionID: "sub_123", + ClientSecret: "secret", + }) + }) + + mux.HandleFunc("/purchase-apple-subscription-v2", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, VerifySubscriptionResponse{ + Status: "active", + SubscriptionID: "sub_1234567890", + }) + }) + + mux.HandleFunc("/purchase", func(w http.ResponseWriter, r *http.Request) { + writeJSONResponse(w, PurchaseResponse{ + BaseResponse: &protos.BaseResponse{}, + PaymentStatus: "completed", + }) + }) + + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + return ts, state +} + +func newTestClient(t *testing.T) (*Client, *testServer) { + ts, state := newTestServer(t) + settings.InitSettings(t.TempDir()) + t.Cleanup(settings.Reset) + return &Client{ + httpClient: ts.Client(), + proURL: ts.URL, + authURL: ts.URL, + saltPath: filepath.Join(t.TempDir(), saltFileName), + }, state +} + +// newTestClientWithSRP creates a test client and pre-registers an email/password on the mock server. +func newTestClientWithSRP(t *testing.T, email, password string) (*Client, *testServer) { + ac, state := newTestClient(t) + + salt, err := generateSalt() + require.NoError(t, err) + + encKey, err := generateEncryptedKey(password, email, salt) + require.NoError(t, err) + + srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) + verifierKey, err := srpClient.Verifier() + require.NoError(t, err) + + state.salt[email] = salt + state.verifier = verifierKey.Bytes() + ac.salt = salt + + return ac, state +} + +func TestSignUp(t *testing.T) { + ac, _ := newTestClient(t) + salt, signupResponse, err := ac.SignUp(context.Background(), "test@example.com", "password") + assert.NoError(t, err) + assert.NotNil(t, salt) + assert.NotNil(t, signupResponse) +} + +func TestSignupEmailResendCode(t *testing.T) { + ac, _ := newTestClient(t) + ac.salt = []byte("salt") + err := ac.SignupEmailResendCode(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestSignupEmailConfirmation(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.SignupEmailConfirmation(context.Background(), "test@example.com", "code") + assert.NoError(t, err) +} + +func TestLogin(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + // Clear cached salt to test the full flow (getSalt → srpLogin) + ac.salt = nil + _, err := ac.Login(context.Background(), email, "password") + assert.NoError(t, err) +} + +func TestLogout(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.DeviceIDKey, "deviceId") + _, err := ac.Logout(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestStartRecoveryByEmail(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.StartRecoveryByEmail(context.Background(), "test@example.com") + assert.NoError(t, err) +} + +func TestCompleteRecoveryByEmail(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.CompleteRecoveryByEmail(context.Background(), "test@example.com", "newPassword", "code") + assert.NoError(t, err) +} + +func TestValidateEmailRecoveryCode(t *testing.T) { + ac, _ := newTestClient(t) + err := ac.ValidateEmailRecoveryCode(context.Background(), "test@example.com", "code") + assert.NoError(t, err) +} + +func TestStartChangeEmail(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + settings.Set(settings.EmailKey, email) + err := ac.StartChangeEmail(context.Background(), "new@example.com", "password") + assert.NoError(t, err) +} + +func TestCompleteChangeEmail(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.EmailKey, "old@example.com") + err := ac.CompleteChangeEmail(context.Background(), "new@example.com", "password", "code") + assert.NoError(t, err) +} + +func TestDeleteAccount(t *testing.T) { + email := "test@example.com" + ac, _ := newTestClientWithSRP(t, email, "password") + settings.Set(settings.DeviceIDKey, "deviceId") + _, err := ac.DeleteAccount(context.Background(), email, "password") + assert.NoError(t, err) +} + +func TestOAuthLoginUrl(t *testing.T) { + ac, _ := newTestClient(t) + url, err := ac.OAuthLoginURL(context.Background(), "google") + assert.NoError(t, err) + assert.NotEmpty(t, url) +} + +func TestOAuthLoginCallback(t *testing.T) { + ac, _ := newTestClient(t) + settings.Set(settings.DeviceIDKey, "deviceId") + + // Mock JWT with unverified signature — decodeJWT uses ParseUnverified so this succeeds. + mockToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20iLCJsZWdhY3lfdXNlcl9pZCI6MTIzNDUsImxlZ2FjeV90b2tlbiI6InRlc3QtdG9rZW4ifQ.test" + + data, err := ac.OAuthLoginCallback(context.Background(), mockToken) + assert.NoError(t, err) + assert.NotEmpty(t, data) +} + +func TestOAuthLoginCallback_InvalidToken(t *testing.T) { + ac, _ := newTestClient(t) + + _, err := ac.OAuthLoginCallback(context.Background(), "invalid-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "error decoding JWT") +} diff --git a/api/api.go b/api/api.go deleted file mode 100644 index cd8b6b27..00000000 --- a/api/api.go +++ /dev/null @@ -1,77 +0,0 @@ -package api - -import ( - "log/slog" - "net/http" - "path/filepath" - "strconv" - "sync" - "time" - - "github.com/go-resty/resty/v2" - - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/bypass" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" -) - -const tracerName = "github.com/getlantern/radiance/api" - -type APIClient struct { - salt []byte - saltPath string - authClient AuthClient - mu sync.RWMutex -} - -func NewAPIClient(dataDir string) *APIClient { - path := filepath.Join(dataDir, saltFileName) - salt, err := readSalt(path) - if err != nil { - slog.Warn("failed to read salt", "error", err) - } - - cli := &APIClient{ - salt: salt, - saltPath: path, - authClient: &authClient{}, - } - return cli -} - -func (a *APIClient) proWebClient() *webClient { - httpClient := kindling.HTTPClient() - proWC := newWebClient(httpClient, common.GetProServerURL()) - proWC.client.OnBeforeRequest(func(client *resty.Client, req *resty.Request) error { - req.Header.Set(backend.DeviceIDHeader, settings.GetString(settings.DeviceIDKey)) - if settings.GetString(settings.TokenKey) != "" { - req.Header.Set(backend.ProTokenHeader, settings.GetString(settings.TokenKey)) - } - if settings.GetInt64(settings.UserIDKey) != 0 { - req.Header.Set(backend.UserIDHeader, strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - } - return nil - }) - return proWC -} - -func authWebClient() *webClient { - return newWebClient(kindling.HTTPClient(), common.GetBaseURL()) -} - -// tunnelClient routes through the local tunnel proxy when the VPN is running. -// Unlike the bypass proxy (which routes to direct), this routes through the -// active VPN proxy outbound. No client-level timeout — SSE streams are -// long-lived. When the VPN is not running, connections fail (no fallback). -var tunnelClient = &http.Client{ - Transport: &http.Transport{ - DialContext: bypass.TunnelDialContext, - ForceAttemptHTTP2: false, - MaxIdleConns: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, -} diff --git a/api/auth.go b/api/auth.go deleted file mode 100644 index 0949c1a2..00000000 --- a/api/auth.go +++ /dev/null @@ -1,178 +0,0 @@ -package api - -import ( - "context" - "fmt" - "strconv" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common/settings" -) - -type AuthClient interface { - // Sign up methods - SignUp(ctx context.Context, email string, password string) ([]byte, *protos.SignupResponse, error) - SignupEmailResendCode(ctx context.Context, data *protos.SignupEmailResendRequest) error - SignupEmailConfirmation(ctx context.Context, data *protos.ConfirmSignupRequest) error - // Login methods - GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) - LoginPrepare(ctx context.Context, loginData *protos.PrepareRequest) (*protos.PrepareResponse, error) - Login(ctx context.Context, email, password, deviceID string, salt []byte) (*protos.LoginResponse, error) - // Recovery methods - StartRecoveryByEmail(ctx context.Context, loginData *protos.StartRecoveryByEmailRequest) error - CompleteRecoveryByEmail(ctx context.Context, loginData *protos.CompleteRecoveryByEmailRequest) error - ValidateEmailRecoveryCode(ctx context.Context, loginData *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) - // Change email methods - ChangeEmail(ctx context.Context, loginData *protos.ChangeEmailRequest) error - // Complete change email methods - CompleteChangeEmail(ctx context.Context, loginData *protos.CompleteChangeEmailRequest) error - DeleteAccount(ctc context.Context, loginData *protos.DeleteUserRequest) error - // Logout - SignOut(ctx context.Context, logoutData *protos.LogoutRequest) error -} - -type authClient struct{} - -// Auth APIS -// GetSalt is used to get the salt for a given email address -func (c *authClient) GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { - var resp protos.GetSaltResponse - query := map[string]string{ - "email": email, - } - header := map[string]string{ - "Content-Type": "application/x-protobuf", - "Accept": "application/x-protobuf", - } - wc := authWebClient() - req := wc.NewRequest(query, header, nil) - if err := wc.Get(ctx, "/users/salt", req, &resp); err != nil { - return nil, err - } - return &resp, nil -} - -// Sign up API -// SignUp is used to sign up a new user with the SignupRequest -func (c *authClient) signUp(ctx context.Context, signupData *protos.SignupRequest) (*protos.SignupResponse, error) { - var resp protos.SignupResponse - header := map[string]string{ - backend.DeviceIDHeader: settings.GetString(settings.DeviceIDKey), - backend.UserIDHeader: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), - backend.ProTokenHeader: settings.GetString(settings.TokenKey), - } - wc := authWebClient() - req := wc.NewRequest(nil, header, signupData) - if err := wc.Post(ctx, "/users/signup", req, &resp); err != nil { - return nil, err - } - return &resp, nil -} - -// SignupEmailResendCode is used to resend the email confirmation code -// Params: ctx context.Context, data *SignupEmailResendRequest -func (c *authClient) SignupEmailResendCode(ctx context.Context, data *protos.SignupEmailResendRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, data) - return wc.Post(ctx, "/users/signup/resend/email", req, &resp) -} - -// SignupEmailConfirmation is used to confirm the email address once user enter code -// Params: ctx context.Context, data *ConfirmSignupRequest -func (c *authClient) SignupEmailConfirmation(ctx context.Context, data *protos.ConfirmSignupRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, data) - return wc.Post(ctx, "/users/signup/complete/email", req, &resp) -} - -// LoginPrepare does the initial login preparation with come make sure the user exists and match user salt -func (c *authClient) LoginPrepare(ctx context.Context, loginData *protos.PrepareRequest) (*protos.PrepareResponse, error) { - var model protos.PrepareResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - if err := wc.Post(ctx, "/users/prepare", req, &model); err != nil { - // Send custom error to show error on client side - return nil, fmt.Errorf("user_not_found %w", err) - } - return &model, nil -} - -// Login is used to login a user with the LoginRequest -func (c *authClient) login(ctx context.Context, loginData *protos.LoginRequest) (*protos.LoginResponse, error) { - var resp protos.LoginResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - if err := wc.Post(ctx, "/users/login", req, &resp); err != nil { - return nil, err - } - - return &resp, nil -} - -// StartRecoveryByEmail is used to start the recovery process by sending a recovery code to the user's email -func (c *authClient) StartRecoveryByEmail(ctx context.Context, loginData *protos.StartRecoveryByEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/recovery/start/email", req, &resp) -} - -// CompleteRecoveryByEmail is used to complete the recovery process by validating the recovery code -func (c *authClient) CompleteRecoveryByEmail(ctx context.Context, loginData *protos.CompleteRecoveryByEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/recovery/complete/email", req, &resp) -} - -// // ValidateEmailRecoveryCode is used to validate the recovery code -func (c *authClient) ValidateEmailRecoveryCode(ctx context.Context, recoveryData *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) { - var resp protos.ValidateRecoveryCodeResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, recoveryData) - err := wc.Post(ctx, "/users/recovery/validate/email", req, &resp) - if err != nil { - return nil, err - } - if !resp.Valid { - return nil, fmt.Errorf("invalid_code Error decoding response body: %w", err) - } - return &resp, nil -} - -// ChangeEmail is used to change the email address of a user -func (c *authClient) ChangeEmail(ctx context.Context, loginData *protos.ChangeEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/change_email", req, &resp) -} - -// CompleteChangeEmail is used to complete the email change process -func (c *authClient) CompleteChangeEmail(ctx context.Context, loginData *protos.CompleteChangeEmailRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, loginData) - return wc.Post(ctx, "/users/change_email/complete/email", req, &resp) -} - -// DeleteAccount is used to delete the account of a user -// Once account is delete make sure to create new account -func (c *authClient) DeleteAccount(ctx context.Context, accountData *protos.DeleteUserRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, accountData) - return wc.Post(ctx, "/users/delete", req, &resp) -} - -// DeleteAccount is used to delete the account of a user -// Once account is delete make sure to create new account -func (c *authClient) SignOut(ctx context.Context, logoutData *protos.LogoutRequest) error { - var resp protos.EmptyResponse - wc := authWebClient() - req := wc.NewRequest(nil, nil, logoutData) - return wc.Post(ctx, "/users/logout", req, &resp) -} diff --git a/api/srp.go b/api/srp.go deleted file mode 100644 index 8d0d60c7..00000000 --- a/api/srp.go +++ /dev/null @@ -1,138 +0,0 @@ -package api - -import ( - "context" - "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "math/big" - "strings" - - "github.com/1Password/srp" - "golang.org/x/crypto/pbkdf2" - - "github.com/getlantern/radiance/api/protos" -) - -func newSRPClient(email string, password string, salt []byte) (*srp.SRP, error) { - if len(salt) == 0 || len(password) == 0 || len(email) == 0 { - return nil, errors.New("salt, password and email should not be empty") - } - - lowerCaseEmail := strings.ToLower(email) - encryptedKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return nil, fmt.Errorf("failed to generate encrypted key: %w", err) - } - - return srp.NewSRPClient(srp.KnownGroups[group], encryptedKey, nil), nil -} - -// Takes password and email, salt and returns encrypted key -func generateEncryptedKey(password string, email string, salt []byte) (*big.Int, error) { - if len(salt) == 0 || len(password) == 0 || len(email) == 0 { - return nil, errors.New("salt or password or email is empty") - } - lowerCaseEmail := strings.ToLower(email) - combinedInput := password + lowerCaseEmail - encryptedKey := pbkdf2.Key([]byte(combinedInput), salt, 4096, 32, sha256.New) - encryptedKeyBigInt := big.NewInt(0).SetBytes(encryptedKey) - return encryptedKeyBigInt, nil -} - -func generateSalt() ([]byte, error) { - salt := make([]byte, 16) - if n, err := rand.Read(salt); err != nil { - return nil, err - } else if n != 16 { - return nil, errors.New("failed to generate 16 byte salt") - } - return salt, nil -} - -func (c *authClient) SignUp(ctx context.Context, email string, password string) ([]byte, *protos.SignupResponse, error) { - lowerCaseEmail := strings.ToLower(email) - salt, err := generateSalt() - if err != nil { - return nil, nil, err - } - srpClient, err := newSRPClient(lowerCaseEmail, password, salt) - if err != nil { - return nil, nil, err - } - verifierKey, err := srpClient.Verifier() - if err != nil { - return nil, nil, err - } - signUpRequestBody := &protos.SignupRequest{ - Email: lowerCaseEmail, - Salt: salt, - Verifier: verifierKey.Bytes(), - SkipEmailConfirmation: true, - // Set temp always to true for now - // If new user faces any issue while sign up user can sign up again - Temp: true, - } - - body, err := c.signUp(ctx, signUpRequestBody) - if err != nil { - return salt, nil, err - } - return salt, body, nil -} - -// Todo find way to optimize this method -func (c *authClient) Login(ctx context.Context, email string, password string, deviceId string, salt []byte) (*protos.LoginResponse, error) { - lowerCaseEmail := strings.ToLower(email) - - // Prepare login request body - client, err := newSRPClient(lowerCaseEmail, password, salt) - if err != nil { - return nil, err - } - //Send this key to client - A := client.EphemeralPublic() - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := c.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return nil, err - } - - // // Once the client receives B from the server Client should check error status here as defense against - // // a malicious B sent from server - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return nil, err - } - - // client can now make the session key - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return nil, fmt.Errorf("user_not_found error while generating Client key %w", err) - } - - // Step 3 - - // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return nil, fmt.Errorf("user_not_found error while checking server proof %w", err) - } - - clientProof, err := client.ClientProof() - if err != nil { - return nil, fmt.Errorf("user_not_found error while generating client proof %w", err) - } - loginRequestBody := &protos.LoginRequest{ - Email: lowerCaseEmail, - Proof: clientProof, - DeviceId: deviceId, - } - return c.login(ctx, loginRequestBody) -} diff --git a/api/sse.go b/api/sse.go deleted file mode 100644 index abc10ac4..00000000 --- a/api/sse.go +++ /dev/null @@ -1,60 +0,0 @@ -package api - -import ( - "bufio" - "context" - "io" - "strings" -) - -type sseEvent struct { - Type string - Data string -} - -// readSSE reads Server-Sent Events from body and sends parsed events on the -// returned channel. The channel is closed when the body returns EOF, an error -// occurs, or ctx is cancelled. The caller is responsible for closing body. -// After the channel is closed, call the returned function to retrieve any -// scanner error (nil on clean EOF). -func readSSE(ctx context.Context, body io.Reader) (<-chan sseEvent, func() error) { - ch := make(chan sseEvent, 1) - var scanErr error - go func() { - defer close(ch) - scanner := bufio.NewScanner(body) - scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // 1 MB max token - var evt sseEvent - for scanner.Scan() { - if ctx.Err() != nil { - return - } - line := scanner.Text() - switch { - case strings.HasPrefix(line, "event:"): - evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - case strings.HasPrefix(line, "data:"): - dataLine := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if evt.Data == "" { - evt.Data = dataLine - } else { - evt.Data = evt.Data + "\n" + dataLine - } - case strings.HasPrefix(line, ":"): - // comment / heartbeat — ignore - case line == "": - // blank line = event delimiter - if evt.Type != "" || evt.Data != "" { - select { - case ch <- evt: - case <-ctx.Done(): - return - } - evt = sseEvent{} - } - } - } - scanErr = scanner.Err() - }() - return ch, func() error { return scanErr } -} diff --git a/api/subscription_test.go b/api/subscription_test.go deleted file mode 100644 index 0fa2c052..00000000 --- a/api/subscription_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package api - -import ( - "context" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/api/protos" -) - -func TestSubscriptionPaymentRedirect(t *testing.T) { - ac := mockAPIClient(t) - data := PaymentRedirectData{ - Provider: "stripe", - Plan: "pro", - DeviceName: "test-device", - Email: "", - BillingType: SubscriptionTypeOneTime, - } - url, err := ac.SubscriptionPaymentRedirectURL(context.Background(), data) - require.NoError(t, err) - assert.NotEmpty(t, url) -} -func TestPaymentRedirect(t *testing.T) { - ac := mockAPIClient(t) - data := PaymentRedirectData{ - Provider: "stripe", - Plan: "pro", - DeviceName: "test-device", - Email: "", - } - url, err := ac.PaymentRedirect(context.Background(), data) - require.NoError(t, err) - assert.NotEmpty(t, url) -} - -func TestNewUser(t *testing.T) { - ac := mockAPIClient(t) - resp, err := ac.NewUser(context.Background()) - require.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestVerifySubscription(t *testing.T) { - ac := mockAPIClient(t) - email := "test@getlantern.org" - planID := "1y-usd-10" - data := map[string]string{ - "email": email, - "planID": planID, - } - status, subID, err := ac.VerifySubscription(context.Background(), AppleService, data) - require.NoError(t, err) - assert.NotEmpty(t, status) - assert.NotEmpty(t, subID) -} - -func TestPlans(t *testing.T) { - ac := mockAPIClient(t) - resp, err := ac.SubscriptionPlans(context.Background(), "store") - require.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Plans) -} - -type MockAPIClient struct { - *APIClient -} - -func mockAPIClient(t *testing.T) *MockAPIClient { - return &MockAPIClient{ - APIClient: &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - salt: []byte{1, 2, 3, 4, 5}, - }, - } -} - -func (m *MockAPIClient) VerifySubscription(ctx context.Context, service SubscriptionService, data map[string]string) (status, subID string, err error) { - return "active", "sub_1234567890", nil -} - -func (m *MockAPIClient) SubscriptionPlans(ctx context.Context, channel string) (*SubscriptionPlans, error) { - resp := &SubscriptionPlans{ - BaseResponse: &protos.BaseResponse{}, - Plans: []*protos.Plan{ - {Id: "1y-usd-10", Description: "Pro Plan", Price: map[string]int64{}}, - }, - } - return resp, nil -} -func (m *MockAPIClient) SubscriptionPaymentRedirectURL(ctx context.Context, data PaymentRedirectData) (string, error) { - return "https://example.com/redirect", nil -} - -func (m *MockAPIClient) PaymentRedirect(ctx context.Context, data PaymentRedirectData) (string, error) { - return "https://example.com/redirect", nil -} -func (m *MockAPIClient) NewUser(ctx context.Context) (*protos.LoginResponse, error) { - return &protos.LoginResponse{}, nil -} diff --git a/api/user.go b/api/user.go deleted file mode 100644 index 7262157e..00000000 --- a/api/user.go +++ /dev/null @@ -1,865 +0,0 @@ -package api - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "math/big" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "time" - - "github.com/1Password/srp" - - "go.opentelemetry.io/otel" - "google.golang.org/protobuf/proto" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/traces" -) - -// The main output of this file is Radiance.GetUser, which provides a hook into all user account -// functionality. - -const saltFileName = ".salt" - -// datacapBackoffResetAfter is the minimum connection duration before we reset -// the reconnect backoff. It must exceed the server's SSE idle timeout (~60s); -// otherwise every normal timeout-triggered disconnect looks "long-lived" and -// resets the backoff, causing a tight-loop reconnect. -const datacapBackoffResetAfter = 90 * time.Second - -// pro-server requests -type UserDataResponse struct { - *protos.BaseResponse `json:",inline"` - *protos.LoginResponse_UserData `json:",inline"` -} - -// NewUser creates a new user account -func (ac *APIClient) NewUser(ctx context.Context) (*protos.LoginResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "new_user") - defer span.End() - - var resp UserDataResponse - header := map[string]string{ - backend.ContentTypeHeader: "application/json", - } - req := ac.proWebClient().NewRequest(nil, header, nil) - err := ac.proWebClient().Post(ctx, "/user-create", req, &resp) - if err != nil { - slog.Error("creating new user", "error", err) - return nil, traces.RecordError(ctx, err) - } - loginResponse, err := ac.storeData(ctx, resp) - if err != nil { - return nil, err - } - return loginResponse, nil -} - -func (ac *APIClient) UserData() ([]byte, error) { - slog.Debug("Getting user data") - user := &protos.LoginResponse{} - err := settings.GetStruct(settings.LoginResponseKey, user) - return withMarshalProto(user, err) -} - -// FetchUserData fetches user data from the server. -func (ac *APIClient) FetchUserData(ctx context.Context) ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "fetch_user_data") - defer span.End() - return withMarshalProto(ac.fetchUserData(ctx)) -} - -// fetchUserData calls the /user-data endpoint and stores the result via storeData. -func (ac *APIClient) fetchUserData(ctx context.Context) (*protos.LoginResponse, error) { - var resp UserDataResponse - err := ac.proWebClient().Get(ctx, "/user-data", nil, &resp) - if err != nil { - slog.Error("user data", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("getting user data: %w", err)) - } - return ac.storeData(ctx, resp) -} - -func (a *APIClient) storeData(ctx context.Context, resp UserDataResponse) (*protos.LoginResponse, error) { - if resp.BaseResponse != nil && resp.Error != "" { - err := fmt.Errorf("received bad response: %s", resp.Error) - slog.Error("user data", "error", err) - return nil, traces.RecordError(ctx, err) - } - if resp.LoginResponse_UserData == nil { - slog.Error("user data", "error", "no user data in response") - return nil, traces.RecordError(ctx, fmt.Errorf("no user data in response")) - } - // Append device ID to user data - resp.LoginResponse_UserData.DeviceID = settings.GetString(settings.DeviceIDKey) - login := &protos.LoginResponse{ - LegacyID: resp.UserId, - LegacyToken: resp.Token, - LegacyUserData: resp.LoginResponse_UserData, - } - a.setData(login) - return login, nil -} - -// user-server requests - -// Devices returns a list of devices associated with this user account. -func (a *APIClient) Devices() ([]settings.Device, error) { - return settings.Devices() -} - -// DataCapUsageResponse represents the data cap usage response -type DataCapUsageResponse struct { - // Whether data cap is enabled for this device/user - Enabled bool `json:"enabled"` - // Data cap usage details (only populated if enabled is true) - Usage *DataCapUsageDetails `json:"usage,omitempty"` -} - -// DataCapUsageDetails contains details of the data cap usage -type DataCapUsageDetails struct { - BytesAllotted string `json:"bytesAllotted"` - BytesUsed string `json:"bytesUsed"` - AllotmentStartTime string `json:"allotmentStartTime"` - AllotmentEndTime string `json:"allotmentEndTime"` -} - -// DataCapInfo returns information about this user's data cap from the local cache. -func (a *APIClient) DataCapInfo(ctx context.Context) (string, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "data_cap_info") - defer span.End() - var cached DataCapUsageResponse - if err := settings.GetStruct(settings.DataCapUsageKey, &cached); err != nil { - return withMarshalJsonString(&DataCapUsageResponse{}, nil) - } - return withMarshalJsonString(&cached, nil) -} - -type DataCapChangeEvent struct { - events.Event - *DataCapUsageResponse -} - -// DataCapStream connects to the datacap SSE endpoint via the tunnel proxy -// and emits DataCapChangeEvent whenever the server pushes an update. Each -// update is persisted to settings so the UI has data even when the VPN is off. -// The method blocks until ctx is cancelled, reconnecting with backoff on -// stream errors. -func (a *APIClient) DataCapStream(ctx context.Context) error { - // Emit cached datacap immediately so the UI has data before connecting. - var cached DataCapUsageResponse - if err := settings.GetStruct(settings.DataCapUsageKey, &cached); err == nil { - events.Emit(DataCapChangeEvent{DataCapUsageResponse: &cached}) - slog.Debug("emitted cached datacap from settings") - } - - bo := common.NewBackoff(2 * time.Minute) - for { - if ctx.Err() != nil { - return ctx.Err() - } - start := time.Now() - err := a.connectSSE(ctx) - if err != nil { - slog.Debug("datacap SSE stream ended", "error", err) - } - if ctx.Err() != nil { - return ctx.Err() - } - // Reset backoff if the connection was up for a while before dropping, - // so we reconnect quickly after a transient disconnect. - if time.Since(start) > datacapBackoffResetAfter { - bo.Reset() - } - bo.Wait(ctx) - } -} - -// connectSSE opens an SSE connection to the datacap stream endpoint and -// processes events until the stream ends or ctx is cancelled. -func (a *APIClient) connectSSE(ctx context.Context) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "datacap_sse") - defer span.End() - - sseURL := fmt.Sprintf("%s/stream/datacap/%s", common.GetBaseURL(), settings.GetString(settings.DeviceIDKey)) - req, err := backend.NewRequestWithHeaders(ctx, http.MethodGet, sseURL, nil) - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("datacap SSE request: %w", err)) - } - req.Header.Set(backend.AcceptHeader, "text/event-stream") - - resp, err := tunnelClient.Do(req) - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("datacap SSE connect: %w", err)) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return traces.RecordError(ctx, fmt.Errorf("datacap SSE status %d", resp.StatusCode)) - } - - slog.Debug("connected to datacap SSE stream") - eventCh, scanErr := readSSE(ctx, resp.Body) - for evt := range eventCh { - switch evt.Type { - case "datacap": - var datacap DataCapUsageResponse - if err := json.Unmarshal([]byte(evt.Data), &datacap); err != nil { - slog.Debug("datacap SSE unmarshal error", "error", err) - continue - } - if err := settings.Set(settings.DataCapUsageKey, &datacap); err != nil { - slog.Debug("datacap persist error", "error", err) - } - events.Emit(DataCapChangeEvent{DataCapUsageResponse: &datacap}) - if datacap.Usage != nil { - slog.Debug("datacap updated", "bytesUsed", datacap.Usage.BytesUsed) - } - case "cap_exhausted": - slog.Warn("datacap exhausted") - default: - // heartbeat or unknown event — ignore - } - } - if err := ctx.Err(); err != nil { - return traces.RecordError(ctx, err) - } - if err := scanErr(); err != nil { - return traces.RecordError(ctx, fmt.Errorf("datacap SSE scanner: %w", err)) - } - return traces.RecordError(ctx, errors.New("datacap SSE stream ended unexpectedly")) -} - -// SignUp signs the user up for an account. -func (a *APIClient) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up") - defer span.End() - - salt, signupResponse, err := a.authClient.SignUp(ctx, email, password) - if err != nil { - return nil, nil, traces.RecordError(ctx, err) - } - a.salt = salt - - idErr := settings.Set(settings.UserIDKey, signupResponse.LegacyID) - if idErr != nil { - return nil, nil, fmt.Errorf("could not save user id: %w", idErr) - } - proTokenErr := settings.Set(settings.TokenKey, signupResponse.ProToken) - if proTokenErr != nil { - return nil, nil, fmt.Errorf("could not save token: %w", proTokenErr) - } - jwtTokenErr := settings.Set(settings.JwtTokenKey, signupResponse.Token) - if jwtTokenErr != nil { - return nil, nil, fmt.Errorf("could not save JWT token: %w", jwtTokenErr) - } - - return salt, signupResponse, nil -} - -var ErrNoSalt = errors.New("not salt available, call GetSalt/Signup first") -var ErrNotLoggedIn = errors.New("not logged in") -var ErrInvalidCode = errors.New("invalid code") - -// SignupEmailResendCode requests that the sign-up code be resent via email. -func (a *APIClient) SignupEmailResendCode(ctx context.Context, email string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_resend_code") - defer span.End() - - if a.salt == nil { - return traces.RecordError(ctx, ErrNoSalt) - } - return traces.RecordError(ctx, a.authClient.SignupEmailResendCode(ctx, &protos.SignupEmailResendRequest{ - Email: email, - Salt: a.salt, - })) -} - -// SignupEmailConfirmation confirms the new account using the sign-up code received via email. -func (a *APIClient) SignupEmailConfirmation(ctx context.Context, email, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "sign_up_email_confirmation") - defer span.End() - - return traces.RecordError(ctx, a.authClient.SignupEmailConfirmation(ctx, &protos.ConfirmSignupRequest{ - Email: email, - Code: code, - })) -} - -func writeSalt(salt []byte, path string) error { - if err := os.WriteFile(path, salt, 0600); err != nil { - return fmt.Errorf("writing salt to %s: %w", path, err) - } - return nil -} - -func readSalt(path string) ([]byte, error) { - buf, err := os.ReadFile(path) - if err != nil && !os.IsNotExist(err) { - return nil, fmt.Errorf("reading salt from %s: %w", path, err) - } - if len(buf) == 0 { - return nil, nil - } - return buf, nil -} - -// getSalt retrieves the salt for the given email address or it's cached value. -func (a *APIClient) getSalt(ctx context.Context, email string) ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "get_salt") - defer span.End() - - if a.salt != nil { - return a.salt, nil // use cached value - } - resp, err := a.authClient.GetSalt(ctx, email) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - return resp.Salt, nil -} - -// Login logs the user in. -func (a *APIClient) Login(ctx context.Context, email string, password string) ([]byte, error) { - // clear any previous salt value - a.salt = nil - ctx, span := otel.Tracer(tracerName).Start(ctx, "login") - defer span.End() - - salt, err := a.getSalt(ctx, email) - if err != nil { - return nil, err - } - - deviceId := settings.GetString(settings.DeviceIDKey) - resp, err := a.authClient.Login(ctx, email, password, deviceId, salt) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - //this can be nil if the user has reached the device limit - if resp.LegacyUserData != nil { - // Append device ID to user data - resp.LegacyUserData.DeviceID = deviceId - } - - // regardless of state we need to save login information - // We have device flow limit on login - a.setData(resp) - a.salt = salt - if saltErr := writeSalt(salt, a.saltPath); saltErr != nil { - return nil, traces.RecordError(ctx, saltErr) - } - return withMarshalProto(resp, nil) -} - -// Logout logs the user out. No-op if there is no user account logged in. -func (a *APIClient) Logout(ctx context.Context, email string) ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "logout") - defer span.End() - logout := &protos.LogoutRequest{ - Email: email, - DeviceId: settings.GetString(settings.DeviceIDKey), - LegacyUserID: settings.GetInt64(settings.UserIDKey), - LegacyToken: settings.GetString(settings.TokenKey), - Token: settings.GetString(settings.JwtTokenKey), - } - if err := a.authClient.SignOut(ctx, logout); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("logging out: %w", err)) - } - a.Reset() - a.salt = nil - if err := writeSalt(nil, a.saltPath); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("writing salt after logout: %w", err)) - } - return withMarshalProto(a.NewUser(context.Background())) -} - -func withMarshalProto(resp *protos.LoginResponse, err error) ([]byte, error) { - if err != nil { - return nil, err - } - protoUserData, err := proto.Marshal(resp) - if err != nil { - return nil, fmt.Errorf("error marshalling login response: %w", err) - } - return protoUserData, nil -} - -func withMarshalJson(data any, err error) ([]byte, error) { - if err != nil { - return nil, err - } - jsonData, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("error marshalling user data: %w", err) - } - return jsonData, nil -} - -func withMarshalJsonString(data any, err error) (string, error) { - raw, err := withMarshalJson(data, err) - return string(raw), err -} - -// StartRecoveryByEmail initializes the account recovery process for the provided email. -func (a *APIClient) StartRecoveryByEmail(ctx context.Context, email string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "start_recovery_by_email") - defer span.End() - - return traces.RecordError(ctx, a.authClient.StartRecoveryByEmail(ctx, &protos.StartRecoveryByEmailRequest{ - Email: email, - })) -} - -// CompleteRecoveryByEmail completes account recovery using the code received via email. -func (a *APIClient) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_recovery_by_email") - defer span.End() - lowerCaseEmail := strings.ToLower(email) - newSalt, err := generateSalt() - if err != nil { - return traces.RecordError(ctx, err) - } - srpClient, err := newSRPClient(lowerCaseEmail, newPassword, newSalt) - if err != nil { - return traces.RecordError(ctx, err) - } - verifierKey, err := srpClient.Verifier() - if err != nil { - return traces.RecordError(ctx, err) - } - - err = a.authClient.CompleteRecoveryByEmail(ctx, &protos.CompleteRecoveryByEmailRequest{ - Email: email, - Code: code, - NewSalt: newSalt, - NewVerifier: verifierKey.Bytes(), - }) - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to complete recovery by email: %w", err)) - } - if err = writeSalt(newSalt, a.saltPath); err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to write new salt: %w", err)) - } - return nil -} - -// ValidateEmailRecoveryCode validates the recovery code received via email. -func (a *APIClient) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "validate_email_recovery_code") - defer span.End() - - resp, err := a.authClient.ValidateEmailRecoveryCode(ctx, &protos.ValidateRecoveryCodeRequest{ - Email: email, - Code: code, - }) - if err != nil { - return traces.RecordError(ctx, err) - } - if !resp.Valid { - return traces.RecordError(ctx, ErrInvalidCode) - } - return nil -} - -const group = srp.RFC5054Group3072 - -// StartChangeEmail initializes a change of the email address associated with this user account. -func (a *APIClient) StartChangeEmail(ctx context.Context, newEmail string, password string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "start_change_email") - defer span.End() - lowerCaseEmail := strings.ToLower(settings.GetString(settings.EmailKey)) - lowerCaseNewEmail := strings.ToLower(newEmail) - salt, err := a.getSalt(ctx, lowerCaseEmail) - if err != nil { - return traces.RecordError(ctx, err) - } - // Prepare login request body - encKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return traces.RecordError(ctx, err) - } - client := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - - //Send this key to client - A := client.EphemeralPublic() - - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := a.authClient.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return traces.RecordError(ctx, err) - } - // Once the client receives B from the server Client should check error status here as defense against - // a malicious B sent from server - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return traces.RecordError(ctx, err) - } - - // client can now make the session key - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating Client key %w", err)) - } - - // // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while checking server proof %w", err)) - } - - clientProof, err := client.ClientProof() - if err != nil { - return traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating client proof %w", err)) - } - - changeEmailRequestBody := &protos.ChangeEmailRequest{ - OldEmail: lowerCaseEmail, - NewEmail: lowerCaseNewEmail, - Proof: clientProof, - } - - return traces.RecordError(ctx, a.authClient.ChangeEmail(ctx, changeEmailRequestBody)) -} - -// CompleteChangeEmail completes a change of the email address associated with this user account, -// using the code recieved via email. -func (a *APIClient) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "complete_change_email") - defer span.End() - newSalt, err := generateSalt() - if err != nil { - return traces.RecordError(ctx, err) - } - - encKey, err := generateEncryptedKey(password, newEmail, newSalt) - if err != nil { - return traces.RecordError(ctx, err) - } - - srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - verifierKey, err := srpClient.Verifier() - if err != nil { - return traces.RecordError(ctx, err) - } - if err := a.authClient.CompleteChangeEmail(ctx, &protos.CompleteChangeEmailRequest{ - OldEmail: settings.GetString(settings.EmailKey), - NewEmail: newEmail, - Code: code, - NewSalt: newSalt, - NewVerifier: verifierKey.Bytes(), - }); err != nil { - return traces.RecordError(ctx, err) - } - if err := writeSalt(newSalt, a.saltPath); err != nil { - return traces.RecordError(ctx, err) - } - if err := settings.Set(settings.EmailKey, newEmail); err != nil { - return traces.RecordError(ctx, err) - } - - a.salt = newSalt - return nil -} - -// DeleteAccount deletes this user account. -func (a *APIClient) DeleteAccount(ctx context.Context, email, password string, isOAuthUser bool) ([]byte, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "delete_account") - defer span.End() - var deleteRequestBody *protos.DeleteUserRequest - lowerCaseEmail := strings.ToLower(email) - if !isOAuthUser { - salt, err := a.getSalt(ctx, lowerCaseEmail) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - // Prepare login request body - encKey, err := generateEncryptedKey(password, lowerCaseEmail, salt) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - client := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - - //Send this key to client - A := client.EphemeralPublic() - - //Create body - prepareRequestBody := &protos.PrepareRequest{ - Email: lowerCaseEmail, - A: A.Bytes(), - } - - srpB, err := a.authClient.LoginPrepare(ctx, prepareRequestBody) - if err != nil { - return nil, traces.RecordError(ctx, err) - } - - B := big.NewInt(0).SetBytes(srpB.B) - - if err = client.SetOthersPublic(B); err != nil { - return nil, traces.RecordError(ctx, err) - } - - clientKey, err := client.Key() - if err != nil || clientKey == nil { - return nil, traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating Client key %w", err)) - } - - // check if the server proof is valid - if !client.GoodServerProof(salt, lowerCaseEmail, srpB.Proof) { - return nil, traces.RecordError(ctx, errors.New("user_not_found error while checking server proof")) - } - - clientProof, err := client.ClientProof() - if err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("user_not_found error while generating client proof %w", err)) - } - deleteRequestBody = &protos.DeleteUserRequest{ - Email: lowerCaseEmail, - Proof: clientProof, - Permanent: true, - DeviceId: settings.GetString(settings.DeviceIDKey), - Token: settings.GetString(settings.JwtTokenKey), - } - } else { - jwtToken := settings.GetString(settings.JwtTokenKey) - if jwtToken == "" { - return nil, traces.RecordError(ctx, errors.New("jwt token is required for OAuth account deletion")) - } - deleteRequestBody = &protos.DeleteUserRequest{ - Email: lowerCaseEmail, - Permanent: true, - Token: jwtToken, - DeviceId: settings.GetString(settings.DeviceIDKey), - } - } - if err := a.authClient.DeleteAccount(ctx, deleteRequestBody); err != nil { - return nil, traces.RecordError(ctx, err) - } - // clean up local data - a.Reset() - a.salt = nil - if err := writeSalt(nil, a.saltPath); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to write salt during account deletion cleanup: %w", err)) - } - - return withMarshalProto(a.NewUser(context.Background())) -} - -// OAuthLoginUrl initiates the OAuth login process for the specified provider. -func (a *APIClient) OAuthLoginUrl(ctx context.Context, provider string) (string, error) { - loginURL, err := url.Parse(fmt.Sprintf("%s/%s/%s", common.GetBaseURL(), "users/oauth2", provider)) - if err != nil { - return "", fmt.Errorf("failed to parse URL: %w", err) - } - query := loginURL.Query() - query.Set("deviceId", settings.GetString(settings.DeviceIDKey)) - query.Set("userId", strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - query.Set("proToken", settings.GetString(settings.TokenKey)) - query.Set("returnTo", "lantern://auth") - loginURL.RawQuery = query.Encode() - return loginURL.String(), nil -} - -func (a *APIClient) OAuthLoginCallback(ctx context.Context, oAuthToken string) ([]byte, error) { - slog.Debug("Getting OAuth login callback") - jwtUserInfo, err := decodeJWT(oAuthToken) - if err != nil { - return nil, fmt.Errorf("error decoding JWT: %w", err) - } - - // Temporary set user data to so api can read it - login := &protos.LoginResponse{ - LegacyID: jwtUserInfo.LegacyUserID, - LegacyToken: jwtUserInfo.LegacyToken, - LegacyUserData: &protos.LoginResponse_UserData{ - UserId: jwtUserInfo.LegacyUserID, - Token: jwtUserInfo.LegacyToken, - DeviceID: jwtUserInfo.DeviceId, - Email: jwtUserInfo.Email, - }, - } - a.setData(login) - // Get user data from api this will also save data in user config - user, err := a.fetchUserData(context.Background()) - if err != nil { - return nil, fmt.Errorf("error getting user data: %w", err) - } - - if err := settings.Set(settings.JwtTokenKey, oAuthToken); err != nil { - slog.Error("Failed to persist JWT token", "error", err) - return nil, fmt.Errorf("failed to persist JWT token: %w", err) - } - user.Id = jwtUserInfo.Email - user.EmailConfirmed = true - a.setData(user) - return withMarshalProto(user, nil) -} - -type LinkResponse struct { - *protos.BaseResponse `json:",inline"` - UserID int `json:"userID"` - ProToken string `json:"token"` -} - -// RemoveDevice removes a device from the user's account. -func (a *APIClient) RemoveDevice(ctx context.Context, deviceID string) (*LinkResponse, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "remove_device") - defer span.End() - - data := map[string]string{ - "deviceId": deviceID, - } - proWC := a.proWebClient() - req := proWC.NewRequest(nil, nil, data) - resp := &LinkResponse{} - if err := proWC.Post(ctx, "/user-link-remove", req, resp); err != nil { - return nil, traces.RecordError(ctx, err) - } - if resp.BaseResponse != nil && resp.BaseResponse.Error != "" { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to remove device: %s", resp.BaseResponse.Error)) - } - return resp, nil -} - -func (a *APIClient) ReferralAttach(ctx context.Context, code string) (bool, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "referral_attach") - defer span.End() - - data := map[string]string{ - "code": code, - } - proWC := a.proWebClient() - req := proWC.NewRequest(nil, nil, data) - resp := &protos.BaseResponse{} - if err := proWC.Post(ctx, "/referral-attach", req, resp); err != nil { - return false, traces.RecordError(ctx, err) - } - if resp.Error != "" { - return false, traces.RecordError(ctx, fmt.Errorf("%s", resp.Error)) - } - return true, nil -} - -func (a *APIClient) setData(data *protos.LoginResponse) { - a.mu.Lock() - defer a.mu.Unlock() - if data == nil { - a.Reset() - return - } - var changed bool - if data.LegacyUserData == nil { - slog.Info("No legacy user data in response, storing legacy ID and pro token only") - if data.LegacyID != 0 { - if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { - slog.Error("failed to set user ID in settings", "error", err) - } - } - if data.LegacyToken != "" { - if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { - slog.Error("failed to set token in settings", "error", err) - } - } - - if data.Devices != nil && len(data.Devices) > 0 { - devices := []settings.Device{} - for _, d := range data.Devices { - devices = append(devices, settings.Device{ - Name: d.Name, - ID: d.Id, - }) - } - if err := settings.Set(settings.DevicesKey, devices); err != nil { - slog.Error("failed to set devices in settings", "error", err) - } - } - return - } - - existingUser := settings.GetInt64(settings.UserIDKey) != 0 - - if data.LegacyUserData.UserLevel != "" { - oldUserLevel := settings.GetString(settings.UserLevelKey) - changed = changed || oldUserLevel != data.LegacyUserData.UserLevel - if err := settings.Set(settings.UserLevelKey, data.LegacyUserData.UserLevel); err != nil { - slog.Error("failed to set user level in settings", "error", err) - } - } - if data.LegacyUserData.Email != "" { - oldEmail := settings.GetString(settings.EmailKey) - changed = changed && oldEmail != data.LegacyUserData.Email - if err := settings.Set(settings.EmailKey, data.LegacyUserData.Email); err != nil { - slog.Error("failed to set email in settings", "error", err) - } - } - if data.LegacyID != 0 { - oldUserID := settings.GetInt64(settings.UserIDKey) - changed = changed && oldUserID != data.LegacyID - if err := settings.Set(settings.UserIDKey, data.LegacyID); err != nil { - slog.Error("failed to set user ID in settings", "error", err) - } - } - if data.LegacyToken != "" { - oldToken := settings.GetString(settings.TokenKey) - changed = changed && oldToken != data.LegacyToken - if err := settings.Set(settings.TokenKey, data.LegacyToken); err != nil { - slog.Error("failed to set token in settings", "error", err) - } - } - if data.Token != "" { - oldJwtToken := settings.GetString(settings.JwtTokenKey) - changed = changed && oldJwtToken != data.Token - if err := settings.Set(settings.JwtTokenKey, data.Token); err != nil { - slog.Error("failed to set JWT token in settings", "error", err) - } - } - if data.Devices != nil && len(data.Devices) > 0 { - devices := []settings.Device{} - for _, d := range data.Devices { - devices = append(devices, settings.Device{ - Name: d.Name, - ID: d.Id, - }) - } - if err := settings.Set(settings.DevicesKey, devices); err != nil { - slog.Error("failed to set devices in settings", "error", err) - } - } - - if err := settings.Set(settings.LoginResponseKey, data); err != nil { - slog.Error("failed to set login response in settings", "error", err) - } - - // We only consider the user to have changed if there was a previous user. - if existingUser && changed { - events.Emit(settings.UserChangeEvent{}) - } -} - -func (a *APIClient) Reset() { - // Clear user data - settings.Set(settings.UserIDKey, int64(0)) - settings.Set(settings.TokenKey, "") - settings.Set(settings.UserLevelKey, "") - settings.Set(settings.EmailKey, "") - settings.Set(settings.DevicesKey, []settings.Device{}) -} diff --git a/api/user_test.go b/api/user_test.go deleted file mode 100644 index 5a3ceb22..00000000 --- a/api/user_test.go +++ /dev/null @@ -1,346 +0,0 @@ -package api - -import ( - "context" - "encoding/hex" - "errors" - "math/big" - "path/filepath" - "testing" - - "github.com/1Password/srp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/api/protos" - "github.com/getlantern/radiance/common/settings" -) - -func TestSignUp(t *testing.T) { - settings.InitSettings(t.TempDir()) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - salt, signupResponse, err := ac.SignUp(context.Background(), "test@example.com", "password") - assert.NoError(t, err) - assert.NotNil(t, salt) - assert.NotNil(t, signupResponse) -} - -func TestSignupEmailResendCode(t *testing.T) { - ac := &APIClient{ - salt: []byte("salt"), - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.SignupEmailResendCode(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestSignupEmailConfirmation(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.SignupEmailConfirmation(context.Background(), "test@example.com", "code") - assert.NoError(t, err) -} - -func TestLogin(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - _, err := ac.Login(context.Background(), "test@example.com", "password") - assert.NoError(t, err) -} - -func TestLogout(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - _, err := ac.Logout(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestStartRecoveryByEmail(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.StartRecoveryByEmail(context.Background(), "test@example.com") - assert.NoError(t, err) -} - -func TestCompleteRecoveryByEmail(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.CompleteRecoveryByEmail(context.Background(), "test@example.com", "newPassword", "code") - assert.NoError(t, err) -} - -func TestValidateEmailRecoveryCode(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err := ac.ValidateEmailRecoveryCode(context.Background(), "test@example.com", "code") - assert.NoError(t, err) -} - -func TestStartChangeEmail(t *testing.T) { - email := "test@example.com" - settings.Set(settings.EmailKey, email) - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - err := ac.StartChangeEmail(context.Background(), "new@example.com", "password") - assert.NoError(t, err) -} - -func TestCompleteChangeEmail(t *testing.T) { - old := "old@example.com" - tmp := t.TempDir() - err := settings.InitSettings(tmp) - require.NoError(t, err) - settings.Set(settings.EmailKey, old) - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - err = ac.CompleteChangeEmail(context.Background(), "new@example.com", "password", "code") - assert.NoError(t, err) -} - -func TestDeleteAccount(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", false) - assert.NoError(t, err) -} - -func TestDeleteAccount_OAuthUser(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - settings.Set(settings.JwtTokenKey, "jwt-token") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", true) - assert.NoError(t, err) -} -func TestDeleteAccount_OAuthUser_MissingJwtToken(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - email := "test@example.com" - authClient := mockAuthClientNew(t, email, "password") - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: authClient, - salt: authClient.salt[email], - } - _, err := ac.DeleteAccount(context.Background(), "test@example.com", "password", true) - assert.Error(t, err) -} - -func TestOAuthLoginUrl(t *testing.T) { - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - } - url, err := ac.OAuthLoginUrl(context.Background(), "google") - assert.NoError(t, err) - assert.NotEmpty(t, url) -} - -func TestOAuthLoginCallback(t *testing.T) { - settings.InitSettings(t.TempDir()) - settings.Set(settings.DeviceIDKey, "deviceId") - t.Cleanup(settings.Reset) - - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - - // Create a mock JWT token - mockToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20iLCJsZWdhY3lVc2VySUQiOjEyMzQ1LCJsZWdhY3lUb2tlbiI6InRlc3QtdG9rZW4ifQ.test" - - _, err := ac.OAuthLoginCallback(context.Background(), mockToken) - // This will fail because decodeJWT is not mocked, but demonstrates the test structure - assert.Error(t, err) -} - -func TestOAuthLoginCallback_InvalidToken(t *testing.T) { - settings.InitSettings(t.TempDir()) - t.Cleanup(settings.Reset) - - ac := &APIClient{ - saltPath: filepath.Join(t.TempDir(), saltFileName), - authClient: &mockAuthClient{}, - } - - _, err := ac.OAuthLoginCallback(context.Background(), "invalid-token") - assert.Error(t, err) - assert.Contains(t, err.Error(), "error decoding JWT") -} - -// Mock implementation of AuthClient for testing purposes -type mockAuthClient struct { - cache map[string]string - salt map[string][]byte - verifier []byte -} - -func mockAuthClientNew(t *testing.T, email, password string) *mockAuthClient { - salt, err := generateSalt() - require.NoError(t, err) - - encKey, err := generateEncryptedKey(password, email, salt) - require.NoError(t, err) - - srpClient := srp.NewSRPClient(srp.KnownGroups[group], encKey, nil) - verifierKey, err := srpClient.Verifier() - require.NoError(t, err) - - m := &mockAuthClient{ - salt: map[string][]byte{email: salt}, - verifier: verifierKey.Bytes(), - cache: make(map[string]string), - } - return m -} - -func (m *mockAuthClient) SignUp(ctx context.Context, email, password string) ([]byte, *protos.SignupResponse, error) { - return []byte("salt"), &protos.SignupResponse{}, nil -} - -func (m *mockAuthClient) SignupEmailResendCode(ctx context.Context, req *protos.SignupEmailResendRequest) error { - return nil -} - -func (m *mockAuthClient) SignupEmailConfirmation(ctx context.Context, req *protos.ConfirmSignupRequest) error { - return nil -} - -func (m *mockAuthClient) GetSalt(ctx context.Context, email string) (*protos.GetSaltResponse, error) { - return &protos.GetSaltResponse{Salt: []byte("salt")}, nil -} - -func (m *mockAuthClient) Login(ctx context.Context, email, password, deviceId string, salt []byte) (*protos.LoginResponse, error) { - return &protos.LoginResponse{ - LegacyUserData: &protos.LoginResponse_UserData{ - DeviceID: "deviceId", - }, - }, nil -} - -func (m *mockAuthClient) SignOut(ctx context.Context, req *protos.LogoutRequest) error { - return nil -} - -func (m *mockAuthClient) StartRecoveryByEmail(ctx context.Context, req *protos.StartRecoveryByEmailRequest) error { - return nil -} - -func (m *mockAuthClient) CompleteRecoveryByEmail(ctx context.Context, req *protos.CompleteRecoveryByEmailRequest) error { - return nil -} - -func (m *mockAuthClient) ValidateEmailRecoveryCode(ctx context.Context, req *protos.ValidateRecoveryCodeRequest) (*protos.ValidateRecoveryCodeResponse, error) { - return &protos.ValidateRecoveryCodeResponse{Valid: true}, nil -} - -func (m *mockAuthClient) ChangeEmail(ctx context.Context, req *protos.ChangeEmailRequest) error { - return nil -} - -func (m *mockAuthClient) CompleteChangeEmail(ctx context.Context, req *protos.CompleteChangeEmailRequest) error { - return nil -} - -func (m *mockAuthClient) DeleteAccount(ctx context.Context, req *protos.DeleteUserRequest) error { - return nil -} - -func (m *mockAuthClient) LoginPrepare(ctx context.Context, req *protos.PrepareRequest) (*protos.PrepareResponse, error) { - A := big.NewInt(0).SetBytes(req.A) - verifier := big.NewInt(0).SetBytes(m.verifier) - - server := srp.NewSRPServer(srp.KnownGroups[srp.RFC5054Group3072], verifier, nil) - if err := server.SetOthersPublic(A); err != nil { - return nil, err - } - B := server.EphemeralPublic() - if B == nil { - return nil, errors.New("cannot generate B") - } - if _, err := server.Key(); err != nil { - return nil, errors.New("cannot generate key") - } - proof, err := server.M(m.salt[req.Email], req.Email) - if err != nil { - return nil, errors.New("cannot generate Proof") - } - state, err := server.MarshalBinary() - if err != nil { - return nil, err - } - m.cache[req.Email] = hex.EncodeToString(state) - return &protos.PrepareResponse{B: B.Bytes(), Proof: proof}, nil -} - -func TestSetData_NilLegacyUserData(t *testing.T) { - settings.InitSettings(t.TempDir()) - t.Cleanup(settings.Reset) - - ac := &APIClient{} - - data := &protos.LoginResponse{ - LegacyID: 12345, - LegacyToken: "test-pro-token", - Devices: []*protos.LoginResponse_Device{ - {Id: "device-1", Name: "Phone"}, - {Id: "device-2", Name: "Laptop"}, - }, - // LegacyUserData intentionally nil to simulate device-limit flow - } - - ac.setData(data) - - // Verify legacy ID and pro token are persisted - assert.Equal(t, int64(12345), settings.GetInt64(settings.UserIDKey)) - assert.Equal(t, "test-pro-token", settings.GetString(settings.TokenKey)) - - // Verify devices are persisted - devices, err := settings.Devices() - require.NoError(t, err) - assert.Len(t, devices, 2) - assert.Equal(t, "device-1", devices[0].ID) - assert.Equal(t, "Phone", devices[0].Name) - assert.Equal(t, "device-2", devices[1].ID) - assert.Equal(t, "Laptop", devices[1].Name) -} diff --git a/api/webclient.go b/api/webclient.go deleted file mode 100644 index bbe5f2d2..00000000 --- a/api/webclient.go +++ /dev/null @@ -1,156 +0,0 @@ -package api - -import ( - "bytes" - "context" - "encoding/json" - "log/slog" - "unicode" - "unicode/utf8" - - "fmt" - "net/http" - - "github.com/go-resty/resty/v2" - "google.golang.org/protobuf/proto" - - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/env" -) - -type webClient struct { - client *resty.Client -} - -func newWebClient(httpClient *http.Client, baseURL string) *webClient { - if httpClient == nil { - httpClient = &http.Client{ - Timeout: common.DefaultHTTPTimeout, - } - } - client := resty.NewWithClient(httpClient) - if baseURL != "" { - client.SetBaseURL(baseURL) - } - - client.SetHeaders(map[string]string{ - backend.AppNameHeader: common.Name, - backend.VersionHeader: common.Version, - backend.PlatformHeader: common.Platform, - }) - - // Include detected public IP on every request - client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { - if ip := backend.GetClientIP(); ip != "" { - req.SetHeader(backend.ClientIPHeader, ip) - } - return nil - }) - - // Add a request middleware to marshal the request body to protobuf or JSON - client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { - if req.Body == nil { - return nil - } - if pb, ok := req.Body.(proto.Message); ok { - data, err := proto.Marshal(pb) - if err != nil { - return err - } - req.Body = data - req.Header.Set("Content-Type", "application/x-protobuf") - req.Header.Set("Accept", "application/x-protobuf") - } else { - data, err := json.Marshal(req.Body) - if err != nil { - return err - } - req.Body = data - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - } - - return nil - }) - - // Add a response middleware to unmarshal the response body from protobuf or JSON - client.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { - if len(resp.Body()) == 0 || resp.Request.Result == nil { - return nil - } - switch ct := resp.RawResponse.Header.Get("Content-Type"); ct { - case "application/x-protobuf": - pb, ok := resp.Request.Result.(proto.Message) - if !ok { - return fmt.Errorf("response body is not a protobuf message") - } - return proto.Unmarshal(resp.Body(), pb) - case "application/json": - body := sanitizeResponseBody(resp.Body()) - return json.Unmarshal(body, resp.Request.Result) - } - return nil - }) - return &webClient{client: client} -} - -func (wc *webClient) NewRequest(queryParams, headers map[string]string, body any) *resty.Request { - req := wc.client.NewRequest().SetQueryParams(queryParams).SetHeaders(headers).SetBody(body) - if curl, _ := env.Get[bool](env.PrintCurl); curl { - req = req.SetDebug(true).EnableGenerateCurlOnDebug() - } - return req -} - -func (wc *webClient) Get(ctx context.Context, path string, req *resty.Request, res any) error { - return wc.send(ctx, resty.MethodGet, path, req, res) -} - -func (wc *webClient) Post(ctx context.Context, path string, req *resty.Request, res any) error { - return wc.send(ctx, resty.MethodPost, path, req, res) -} - -func (wc *webClient) send(ctx context.Context, method, path string, req *resty.Request, res any) error { - if req == nil { - req = wc.client.NewRequest() - } - req.SetContext(ctx) - if res != nil { - req.SetResult(res) - } - - resp, err := req.Execute(method, path) - if err != nil { - return fmt.Errorf("error sending request: %w", err) - } - // print curl command for debugging - slog.Debug("CURL command", "curl", req.GenerateCurlCommand()) - if resp.StatusCode() < 200 || resp.StatusCode() >= 300 { - sanitizedBody := sanitizeResponseBody(resp.Body()) - slog.Debug("error sending request", "path", path, "status", resp.StatusCode(), "body", string(sanitizedBody)) - return fmt.Errorf("unexpected status %v body %s ", resp.StatusCode(), sanitizedBody) - } - return nil -} - -func sanitizeResponseBody(data []byte) []byte { - var out bytes.Buffer - r := bytes.NewReader(data) - for { - ch, size, err := r.ReadRune() - if err != nil { - break - } - // Skip invalid UTF-8 sequences - if ch == utf8.RuneError && size == 1 { - continue - } - // Skip control characters (optional) - if unicode.IsControl(ch) && ch != '\n' && ch != '\r' && ch != '\t' { - continue - } - out.WriteRune(ch) - } - return out.Bytes() -} diff --git a/backend/radiance.go b/backend/radiance.go new file mode 100644 index 00000000..2c76f98f --- /dev/null +++ b/backend/radiance.go @@ -0,0 +1,1082 @@ +// Package backend provides the main interface for all the major components of Radiance. +package backend + +import ( + "context" + "errors" + "fmt" + "log/slog" + "maps" + "os" + "path/filepath" + "reflect" + "slices" + "strings" + "sync" + + "time" + + "github.com/Xuanwo/go-locale" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + C "github.com/getlantern/common" + "github.com/getlantern/publicip" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/deviceid" + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/issue" + "github.com/getlantern/radiance/kindling" + "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/telemetry" + "github.com/getlantern/radiance/traces" + "github.com/getlantern/radiance/vpn" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" +) + +const tracerName = "github.com/getlantern/radiance/backend" + +// LocalBackend ties all the core functionality of Radiance together. It manages the configuration, +// servers, VPN connection, account management, issue reporting, and telemetry for the application. +type LocalBackend struct { + ctx context.Context + cancel context.CancelFunc + + confHandler *config.ConfigHandler + issueReporter *issue.IssueReporter + accountClient *account.Client + + srvManager *servers.Manager + vpnClient *vpn.VPNClient + splitTunnelMgr *vpn.SplitTunnel + + shutdownFuncs []func() error + closeOnce sync.Once + stopChan chan struct{} + + deviceID string + + telemetryCfgSub *events.Subscription[config.NewConfigEvent] + stopConnMetrics context.CancelFunc + connMetricsMu sync.Mutex + + dataCapCh chan *account.DataCapInfo // latest datacap update; nil when stream not running + stopDataCap context.CancelFunc + dataCapMu sync.Mutex + + stopURLTestListener context.CancelFunc + urlTestMu sync.Mutex +} + +type Options struct { + DataDir string + LogDir string + Locale string + LogLevel string + // this should be the platform device ID on mobile devices, desktop platforms will generate their + // own device ID and ignore this value + DeviceID string + // User choice for telemetry consent + TelemetryConsent bool + PlatformInterface vpn.PlatformInterface + // EnvOverrides are applied via os.Setenv before common.Init so sandboxed + // system extensions (macOS/iOS), which don't inherit shell env, still see + // RADIANCE_* vars from the host process. Entries are set verbatim — no + // filtering. + EnvOverrides map[string]string +} + +// NewLocalBackend performs global initialization and returns a new LocalBackend instance. +// It should be called once at the start of the application. +func NewLocalBackend(ctx context.Context, opts Options) (*LocalBackend, error) { + // Must run before common.Init: it reads RADIANCE_VERSION once and + // freezes it, so a later Setenv is ignored by the header-fill path. + var envOverrideErrs error + for k, v := range opts.EnvOverrides { + if err := os.Setenv(k, v); err != nil { + envOverrideErrs = errors.Join(envOverrideErrs, fmt.Errorf("apply env override %q: %w", k, err)) + } + } + if envOverrideErrs != nil { + return nil, fmt.Errorf("failed to apply environment overrides: %w", envOverrideErrs) + } + if err := common.Init(opts.DataDir, opts.LogDir, opts.LogLevel); err != nil { + return nil, fmt.Errorf("failed to initialize common components: %w", err) + } + if opts.Locale == "" { + if tag, err := locale.Detect(); err != nil { + opts.Locale = "en-US" + } else { + opts.Locale = tag.String() + } + } + + var platformDeviceID string + switch common.Platform { + case "ios", "android": + platformDeviceID = opts.DeviceID + default: + platformDeviceID = deviceid.Get(settings.GetString(settings.DataPathKey)) + } + + dataDir := settings.GetString(settings.DataPathKey) + disableFetch := env.GetBool(env.DisableFetch) + settings.Patch(settings.Settings{ + settings.LocaleKey: opts.Locale, + settings.DeviceIDKey: platformDeviceID, + settings.ConfigFetchDisabledKey: disableFetch, + settings.TelemetryKey: opts.TelemetryConsent, + }) + + accountClient := account.NewClient(kindling.HTTPClient(), dataDir) + + svrMgr, err := servers.NewManager( + dataDir, slog.Default().With("service", "server_manager"), + ) + if err != nil { + return nil, fmt.Errorf("failed to create server manager: %w", err) + } + + splitTunnelMgr, err := vpn.NewSplitTunnelHandler( + dataDir, slog.Default().With("service", "split_tunnel"), + ) + if err != nil { + return nil, fmt.Errorf("failed to create split tunnel manager: %w", err) + } + + vpnClient := vpn.NewVPNClient(dataDir, slog.Default().With("service", "vpn"), opts.PlatformInterface) + ctx, cancel := context.WithCancel(ctx) + cOpts := config.Options{ + DataPath: dataDir, + Locale: opts.Locale, + AccountClient: accountClient, + HTTPClient: kindling.HTTPClient(), + Logger: slog.Default().With("service", "config_handler"), + } + r := &LocalBackend{ + ctx: ctx, + cancel: cancel, + issueReporter: issue.NewIssueReporter(kindling.HTTPClient()), + accountClient: accountClient, + confHandler: config.NewConfigHandler(ctx, cOpts), + srvManager: svrMgr, + vpnClient: vpnClient, + splitTunnelMgr: splitTunnelMgr, + shutdownFuncs: []func() error{ + telemetry.Close, kindling.Close, + }, + stopChan: make(chan struct{}), + closeOnce: sync.Once{}, + deviceID: platformDeviceID, + dataCapCh: make(chan *account.DataCapInfo, 1), + } + return r, nil +} + +func (r *LocalBackend) Start() { + // eagerly start kindling so it's ready by the time we need to make network requests + kindling.Init() + // QA: when an upstream outbound SOCKS5 is set, publicip.Detect would + // leak the host's real IP via direct calls to AWS/ifconfig.me, and the + // resulting X-Lantern-Config-Client-IP header would override our Russia + // egress for the API's bandit lookup. Skip detection in that mode. + if addr, _ := env.Get(env.OutboundSocksAddress); addr == "" { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + result, err := publicip.Detect(ctx, &publicip.Config{ + Timeout: 2 * time.Second, + MinConsensus: 1, + }) + cancel() + if err != nil { + slog.Warn("Failed to get public IP", "error", err) + } else { + common.SetPublicIP(result.IP.String()) + slog.Debug("Detected public IP", "confidence", result.Confidence, "sources", result.Sources) + } + }() + } else { + slog.Info("Skipping publicip.Detect because RADIANCE_OUTBOUND_SOCKS_ADDRESS is set", "addr", addr) + } + + if settings.GetBool(settings.TelemetryKey) { + if err := r.startTelemetry(); err != nil { + slog.Error("Failed to start telemetry", "error", err) + } + } + r.startVPNStatusListeners() + r.startAutoSelectedListener() + + // set country code in settings when new config is received so it can be included in issue reports + events.SubscribeOnce(func(evt config.NewConfigEvent) { + if env.GetString(env.Country) != "" { + return // respect env override if set + } + if evt.New != nil && evt.New.Country != "" { + if err := settings.Set(settings.CountryCodeKey, evt.New.Country); err != nil { + slog.Error("failed to set country code in settings", "error", err) + } + slog.Info("Set country code from config response", "country_code", evt.New.Country) + } + }) + // update VPN outbounds when new config is received + events.SubscribeContext(r.ctx, func(evt config.NewConfigEvent) { + if evt.New == nil { + return + } + cfg := evt.New + locs := make(map[string]C.ServerLocation, len(cfg.OutboundLocations)) + // Track which cities are already covered by active outbounds. + coveredCities := make(map[string]bool, len(cfg.OutboundLocations)) + for k, v := range cfg.OutboundLocations { + if v == nil { + slog.Warn("Server location is nil, skipping", "tag", k) + continue + } + locs[k] = *v + coveredCities[v.City+"|"+v.CountryCode] = true + } + // Include available server locations not already covered by active + // outbounds so the client's location picker shows every location. + for _, sl := range cfg.Servers { + if coveredCities[sl.City+"|"+sl.CountryCode] { + continue + } + key := strings.ToLower(strings.ReplaceAll(sl.City, " ", "-") + "-" + sl.CountryCode) + locs[key] = sl + } + var srvs []*servers.Server + for _, out := range cfg.Options.Outbounds { + srvs = append(srvs, &servers.Server{ + Tag: out.Tag, Type: out.Type, IsLantern: true, + Options: out, Location: locs[out.Tag], + }) + } + for _, ep := range cfg.Options.Endpoints { + srvs = append(srvs, &servers.Server{ + Tag: ep.Tag, Type: ep.Type, IsLantern: true, + Options: ep, Location: locs[ep.Tag], + }) + } + list := servers.ServerList{Servers: srvs, URLOverrides: cfg.BanditURLOverrides} + if len(cfg.BanditURLOverrides) > 0 { + // Create a marker span linked to the API's bandit trace so the + // config fetch appears in the same distributed trace as the callback. + if ctx, ok := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides); ok { + _, span := otel.Tracer(tracerName).Start(ctx, "radiance.config_received", + trace.WithAttributes( + attribute.Int("bandit.override_count", len(cfg.BanditURLOverrides)), + attribute.Int("bandit.outbound_count", len(cfg.Options.Outbounds)), + ), + ) + span.End() // point-in-time marker — config was received at this timestamp + } + } + if err := r.setServers(list, true); err != nil { + slog.Error("setting servers in manager", "error", err) + } + if err := r.RunOfflineURLTests(); err != nil && !errors.Is(err, vpn.ErrTunnelAlreadyConnected) { + // ErrTunnelAlreadyConnected is the expected, non-error case while + // the VPN is up: setServers above already pushed the new outbounds + // (and any bandit URL overrides) into the running tunnel, and + // addOutbounds triggers an immediate URL test cycle for them via + // MutableURLTest.CheckOutbounds. The "offline" pre-warm path here + // is for the not-yet-connected case only — running both would + // duplicate work and conflict with the live URLTest selector. + slog.Error("Failed to run offline URL tests after config update", "error", err) + } + }) + go r.confHandler.Start() +} + +func (r *LocalBackend) Close() { + r.closeOnce.Do(func() { + slog.Debug("Closing Radiance") + r.cancel() // cancels context, unsubscribes all event listeners and stops child goroutines + close(r.stopChan) + for _, shutdown := range r.shutdownFuncs { + if err := shutdown(); err != nil { + slog.Error("Failed to shutdown", "error", err) + } + } + }) + <-r.stopChan +} + +func (r *LocalBackend) startVPNStatusListeners() { + events.SubscribeContext(r.ctx, func(evt vpn.StatusUpdateEvent) { + r.updateConnMetrics(evt.Status) + }) + events.SubscribeContext(r.ctx, func(evt vpn.StatusUpdateEvent) { + r.updateDataCapStream(evt.Status) + }) + events.SubscribeContext(r.ctx, func(evt vpn.StatusUpdateEvent) { + r.updateURLTestListener(evt.Status) + }) +} + +////////////////// +// Issue Report // +////////////////// + +// ReportIssue allows the user to report an issue with the application. It collects relevant +// information about the user's environment such as country, device ID, user ID, subscription level, +// and locale, and log files to include in the report. The additionalAttachments parameter allows +// the caller to include any extra files they want to attach to the issue report. +func (r *LocalBackend) ReportIssue(issueType issue.IssueType, description, email string, additionalAttachments []string) error { + ctx, span := otel.Tracer(tracerName).Start(context.Background(), "report_issue") + defer span.End() + // get country from the config returned by the backend + var country string + cfg, err := r.confHandler.GetConfig() + if err != nil { + slog.Warn("Failed to get config", "error", err) + } else { + country = cfg.Country + } + + attachments := baseIssueAttachments() + if r.splitTunnelMgr.IsEnabled() { + attachments = append(attachments, filepath.Join(settings.GetString(settings.DataPathKey), internal.SplitTunnelFileName)) + } + attachments = append(attachments, additionalAttachments...) + + report := issue.IssueReport{ + Type: issueType, + Description: description, + Email: email, + CountryCode: country, + DeviceID: r.deviceID, + UserID: settings.GetString(settings.UserIDKey), + SubscriptionLevel: settings.GetString(settings.UserLevelKey), + Locale: settings.GetString(settings.LocaleKey), + AdditionalAttachments: attachments, + } + err = r.issueReporter.Report(ctx, report) + if err != nil { + slog.Error("Failed to report issue", "error", err) + return traces.RecordError(ctx, fmt.Errorf("failed to report issue: %w", err)) + } + slog.Info("Issue reported successfully") + return nil +} + +// baseIssueAttachments returns a list of file paths to include as attachments in every issue report +// in order of importance. +func baseIssueAttachments() []string { + logPath := settings.GetString(settings.LogPathKey) + dataPath := settings.GetString(settings.DataPathKey) + // TODO: any other files we want to include?? + return []string{ + filepath.Join(logPath, internal.CrashLogFileName), + filepath.Join(dataPath, internal.ConfigFileName), + filepath.Join(dataPath, internal.ServersFileName), + filepath.Join(dataPath, internal.DebugBoxOptionsFileName), + } +} + +///////////////// +// Settings // +///////////////// + +// UpdateConfig forces an immediate fetch of the latest configuration. It returns +// [config.ErrConfigFetchDisabled] if config fetching is disabled in settings. +func (r *LocalBackend) UpdateConfig() error { + return r.confHandler.Update() +} + +// Features returns the features available in the current configuration, returned from the server in the +// config response. +func (r *LocalBackend) Features() map[string]bool { + _, span := otel.Tracer(tracerName).Start(context.Background(), "features") + defer span.End() + cfg, err := r.confHandler.GetConfig() + if err != nil { + slog.Info("Failed to get config for features", "error", err) + return map[string]bool{} + } + if cfg == nil { + slog.Info("No config available for features, returning empty map") + return map[string]bool{} + } + slog.Debug("Returning features from config", "features", cfg.Features) + if cfg.Features == nil { + slog.Info("No features available in config, returning empty map") + return map[string]bool{} + } + return cfg.Features +} + +func (r *LocalBackend) PatchSettings(updates settings.Settings) error { + curr := settings.GetAllFor(slices.Collect(maps.Keys(updates))...) + diff := updates.Diff(curr) + slog.Log(nil, log.LevelTrace, "Patching settings", "updates", updates, "current", curr, "diff", diff) + if len(diff) == 0 { + return nil + } + if err := settings.Patch(diff); err != nil { + return fmt.Errorf("failed to update settings: %w", err) + } + // telemetry settings + if _, ok := diff[settings.TelemetryKey]; ok { + if settings.GetBool(settings.TelemetryKey) { + if err := r.startTelemetry(); err != nil { + slog.Error("Failed to start telemetry", "error", err) + } + } else { + r.stopTelemetry() + } + } + + // vpn settings + k := settings.SplitTunnelKey + if _, ok := diff[k]; ok { + r.splitTunnelMgr.SetEnabled(settings.GetBool(k)) + } + r.maybeRestartVPN(diff) + + return nil +} + +// maybeRestartVPN restarts the VPN connection if either the ad block or smart routing settings +// were changed and the VPN is currently connected. +func (r *LocalBackend) maybeRestartVPN(updates settings.Settings) { + _, adBlockChanged := updates[settings.AdBlockKey] + _, smartRoutingChanged := updates[settings.SmartRoutingKey] + if (adBlockChanged || smartRoutingChanged) && r.vpnClient.Status() == vpn.Connected { + slog.Info("Restarting VPN to apply new settings", "ad_block_changed", adBlockChanged, "smart_routing_changed", smartRoutingChanged) + bOptions := r.getBoxOptions() + go r.vpnClient.Restart(bOptions) + } +} + +///////////////// +// telemetry // +///////////////// + +func (r *LocalBackend) startTelemetry() error { + cfg, err := r.confHandler.GetConfig() + if err == nil { + if err := telemetry.Initialize(r.deviceID, *cfg, settings.IsPro()); err != nil { + return fmt.Errorf("failed to initialize telemetry: %w", err) + } + } + if r.telemetryCfgSub != nil { + return nil + } + // subscribe to config changes to update telemetry config + r.telemetryCfgSub = events.SubscribeContext(r.ctx, func(evt config.NewConfigEvent) { + if !settings.GetBool(settings.TelemetryKey) { + return + } + if evt.Old != nil && reflect.DeepEqual(evt.Old.OTEL, evt.New.OTEL) { + // no changes to telemetry config, no need to update + return + } + if err := telemetry.Initialize(r.deviceID, *evt.New, settings.IsPro()); err != nil { + slog.Error("Failed to update telemetry config", "error", err) + } + }) + return nil +} + +func (r *LocalBackend) stopTelemetry() { + if r.telemetryCfgSub != nil { + r.telemetryCfgSub.Unsubscribe() + r.telemetryCfgSub = nil + } + r.updateConnMetrics(vpn.Disconnected) + telemetry.Close() +} + +func (r *LocalBackend) updateConnMetrics(status vpn.VPNStatus) { + if !settings.GetBool(settings.TelemetryKey) { + return + } + r.connMetricsMu.Lock() + defer r.connMetricsMu.Unlock() + if status == vpn.Connected { + if r.stopConnMetrics != nil { + return // already running + } + ctx, cancel := context.WithCancel(r.ctx) + telemetry.StartConnectionMetrics(ctx, r.vpnClient, 1*time.Minute) + r.stopConnMetrics = cancel + slog.Debug("Started connection metrics collection") + } else if r.stopConnMetrics != nil { + r.stopConnMetrics() + r.stopConnMetrics = nil + slog.Debug("Stopped connection metrics collection") + } +} + +/////////////////////// +// Server management // +/////////////////////// + +func (r *LocalBackend) AllServers() []*servers.Server { + return r.srvManager.AllServers() +} + +func (r *LocalBackend) GetServerByTag(tag string) (*servers.Server, bool) { + return r.srvManager.GetServerByTag(tag) +} + +func (r *LocalBackend) AddServers(list servers.ServerList) error { + if err := r.srvManager.AddServers(list, false); err != nil { + return fmt.Errorf("failed to add servers to ServerManager: %w", err) + } + if err := r.vpnClient.AddOutbounds(list); err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + return fmt.Errorf("failed to add outbounds to VPN client: %w", err) + } + return nil +} + +func (r *LocalBackend) RemoveServers(tags []string) error { + removed, err := r.srvManager.RemoveServers(tags) + if err != nil { + return fmt.Errorf("failed to remove servers from ServerManager: %w", err) + } + removedTags := make([]string, 0, len(removed)) + for _, srv := range removed { + removedTags = append(removedTags, srv.Tag) + } + if len(removedTags) > 0 { + var selected servers.Server + if err := settings.GetStruct(settings.SelectedServerKey, &selected); err == nil { + if slices.Contains(removedTags, selected.Tag) { + // clear selected server from settings if it's being removed + if err := settings.Set(settings.SelectedServerKey, nil); err != nil { + slog.Warn("Failed to clear selected server from settings after it was removed", "error", err) + } + } + } + if err := r.vpnClient.RemoveOutbounds(removedTags); err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + return fmt.Errorf("failed to remove outbounds: %w", err) + } + } + return nil +} + +func (r *LocalBackend) setServers(list servers.ServerList, isLantern bool) error { + if err := r.srvManager.SetServers(list, isLantern); err != nil { + return fmt.Errorf("failed to set servers in ServerManager: %w", err) + } + err := r.vpnClient.UpdateOutbounds(list) + if err != nil && !errors.Is(err, vpn.ErrTunnelNotConnected) { + slog.Error("Failed to update VPN outbounds after config change", "error", err) + } + return nil +} + +func (r *LocalBackend) AddServersByJSON(config string) ([]string, error) { + return r.srvManager.AddServersByJSON(context.Background(), []byte(config)) +} + +func (r *LocalBackend) AddServersByURL(urls []string, skipCertVerification bool) ([]string, error) { + return r.srvManager.AddServersByURL(context.Background(), urls, skipCertVerification) +} + +func (r *LocalBackend) AddPrivateServer(tag, ip string, port int, accessToken string, loc C.ServerLocation, joined bool) error { + return r.srvManager.AddPrivateServer(tag, ip, port, accessToken, loc, joined) +} + +func (r *LocalBackend) InviteToPrivateServer(ip string, port int, accessToken string, inviteName string) (string, error) { + return r.srvManager.InviteToPrivateServer(ip, port, accessToken, inviteName) +} + +func (r *LocalBackend) RevokePrivateServerInvite(ip string, port int, accessToken string, inviteName string) error { + return r.srvManager.RevokePrivateServerInvite(ip, port, accessToken, inviteName) +} + +// urlTestFlushInterval bounds how often URL test results are written back to the servers manager +// (and to disk). URL test cycles run on the order of minutes and notify per-result, so coalescing +// into a periodic flush avoids re-marshalling and re-writing the servers file for each parallel result. +const urlTestFlushInterval = 5 * time.Second + +func (r *LocalBackend) updateURLTestListener(status vpn.VPNStatus) { + r.urlTestMu.Lock() + defer r.urlTestMu.Unlock() + // Status events are dispatched in unordered goroutines, so reacting to + // intermediate statuses (Connecting, Disconnecting, Restarting) risks a + // stale handler tearing down a listener a concurrent Connected handler + // just attached to the new tunnel. + switch status { + case vpn.Connected: + if r.stopURLTestListener != nil { + r.stopURLTestListener() + r.stopURLTestListener = nil + } + storage := r.vpnClient.HistoryStorage() + if storage == nil { + return + } + ctx, cancel := context.WithCancel(r.ctx) + r.stopURLTestListener = cancel + hook := make(chan struct{}, 1) + storage.SetHook(hook) + go r.runURLTestListener(ctx, storage, hook) + slog.Debug("Started URL test result listener") + case vpn.Disconnected, vpn.ErrorStatus: + if r.stopURLTestListener != nil { + r.stopURLTestListener() + r.stopURLTestListener = nil + slog.Debug("Stopped URL test result listener") + } + } +} + +// runURLTestListener coalesces per-result hook notifications into a periodic flush so the servers +// file isn't rewritten for each parallel URL test completion. A final flush runs on shutdown so any +// results that arrived since the last tick are persisted. +func (r *LocalBackend) runURLTestListener(ctx context.Context, storage vpn.URLTestHistoryStorage, hook <-chan struct{}) { + ticker := time.NewTicker(urlTestFlushInterval) + defer ticker.Stop() + dirty := true // start dirty so we persist any results that arrived before the listener started + for { + select { + case <-ctx.Done(): + if dirty { + r.flushURLTestResults(storage) + } + return + case <-hook: + dirty = true + case <-ticker.C: + if dirty { + r.flushURLTestResults(storage) + dirty = false + } + } + } +} + +func (r *LocalBackend) flushURLTestResults(storage vpn.URLTestHistoryStorage) { + results := make(map[string]servers.URLTestResult) + for _, srv := range r.srvManager.AllServers() { + if h := storage.LoadURLTestHistory(srv.Tag); h != nil { + results[srv.Tag] = servers.URLTestResult{Delay: h.Delay, Time: h.Time} + } + } + if len(results) > 0 { + if err := r.srvManager.UpdateURLTestResults(results); err != nil { + slog.Warn("Failed to persist URL test results", "error", err) + } + } +} + +///////////////// +// VPN // +///////////////// + +func (r *LocalBackend) VPNStatus() vpn.VPNStatus { + return r.vpnClient.Status() +} + +func (r *LocalBackend) ConnectVPN(tag string) error { + if tag == "" { + tag = vpn.AutoSelectTag + } + if tag != vpn.AutoSelectTag { + if _, found := r.srvManager.GetServerByTag(tag); !found { + return fmt.Errorf("no server found with tag %s", tag) + } + } + bOptions := r.getBoxOptions() + bOptions.InitialServer = tag + if err := r.vpnClient.Connect(bOptions); err != nil { + return fmt.Errorf("failed to connect VPN: %w", err) + } + r.persistSelection(tag) + return nil +} + +func (r *LocalBackend) getBoxOptions() vpn.BoxOptions { + // ignore error, we can still connect with default options if config is not available for some reason + cfg, _ := r.confHandler.GetConfig() + bOptions := vpn.BoxOptions{ + BasePath: settings.GetString(settings.DataPathKey), + } + if cfg != nil { + bOptions.Options = cfg.Options + bOptions.BanditURLOverrides = cfg.BanditURLOverrides + bOptions.BanditThroughputURL = cfg.BanditThroughputURL + if settings.GetBool(settings.SmartRoutingKey) { + bOptions.SmartRouting = cfg.SmartRouting + } + if settings.GetBool(settings.AdBlockKey) { + bOptions.AdBlock = cfg.AdBlock + } + } + seed := make(map[string]adapter.URLTestHistory) + for _, srv := range r.srvManager.AllServers() { + if !srv.IsLantern { + switch opts := srv.Options.(type) { + case option.Outbound: + bOptions.Options.Outbounds = append(bOptions.Options.Outbounds, opts) + case option.Endpoint: + bOptions.Options.Endpoints = append(bOptions.Options.Endpoints, opts) + } + } + if srv.URLTestResult != nil { + seed[srv.Tag] = adapter.URLTestHistory{ + Time: srv.URLTestResult.Time, + Delay: srv.URLTestResult.Delay, + } + } + } + if len(seed) > 0 { + bOptions.URLTestSeed = seed + } + return bOptions +} + +func (r *LocalBackend) DisconnectVPN() error { + return r.vpnClient.Disconnect() +} + +func (r *LocalBackend) RestartVPN() error { + bOptions := r.getBoxOptions() + return r.vpnClient.Restart(bOptions) +} + +// SelectServer selects the server identified by tag. The empty string is +// treated as [vpn.AutoSelectTag]. +func (r *LocalBackend) SelectServer(tag string) error { + if tag == "" { + tag = vpn.AutoSelectTag + } + if err := r.vpnClient.SelectServer(tag); err != nil { + return fmt.Errorf("failed to select server: %w", err) + } + r.persistSelection(tag) + return nil +} + +// persistSelection records the user's server choice in settings. tag must be +// AutoSelectTag or the tag of a server known to the manager. +func (r *LocalBackend) persistSelection(tag string) { + if tag == vpn.AutoSelectTag { + if err := settings.Patch(settings.Settings{ + settings.AutoConnectKey: true, + settings.SelectedServerKey: nil, + }); err != nil { + slog.Warn("failed to update settings", "error", err) + } + return + } + server, found := r.srvManager.GetServerByTag(tag) + if !found { + slog.Warn("no server found for tag, skipping settings persistence", "tag", tag) + return + } + server.Options = nil + if err := settings.Patch(settings.Settings{ + settings.AutoConnectKey: false, + settings.SelectedServerKey: server, + }); err != nil { + slog.Warn("Failed to save selected server in settings", "error", err) + return + } + slog.Info("Selected server", "tag", tag, "type", server.Type) +} + +// VPNConnections returns a list of all connections, both active and recently closed. If there are no +// connections and the tunnel is open, an empty slice is returned without an error. +func (r *LocalBackend) VPNConnections() ([]vpn.Connection, error) { + return r.vpnClient.Connections() +} + +// ActiveVPNConnections returns a list of currently active connections, ordered from newest to oldest. +func (r *LocalBackend) ActiveVPNConnections() ([]vpn.Connection, error) { + connections, err := r.vpnClient.Connections() + if err != nil { + return nil, fmt.Errorf("failed to get VPN connections: %w", err) + } + connections = slices.DeleteFunc(connections, func(conn vpn.Connection) bool { + return conn.ClosedAt != 0 + }) + slices.SortFunc(connections, func(a, b vpn.Connection) int { + return int(b.CreatedAt - a.CreatedAt) + }) + return connections, nil +} + +// TODO: handle case where selected server is no longer available (e.g. removed from manager) more +// gracefully, currently we just return that the server is no longer available but maybe we should +// also clear the selected server from settings and select a new server in the VPN client. +// should we not remove a lantern server if it's currently selected in the VPN client and instead +// mark it as unavailable in the manager until it's no longer selected in the VPN client? + +// SelectedServer returns the currently selected server and whether the server is still available. +// The server may no longer be available if it was removed from the manager since it was selected. +func (r *LocalBackend) SelectedServer() (*servers.Server, bool, error) { + if settings.GetBool(settings.AutoConnectKey) { + tag, err := r.vpnClient.CurrentAutoSelectedServer() + if err != nil { + return nil, false, fmt.Errorf("failed to get current auto-selected server: %w", err) + } + server, found := r.srvManager.GetServerByTag(tag) + return server, found, nil + } + if !settings.Exists(settings.SelectedServerKey) { + return nil, false, fmt.Errorf("no selected server") + } + var selected servers.Server + if err := settings.GetStruct(settings.SelectedServerKey, &selected); err != nil { + return nil, false, fmt.Errorf("failed to get selected server from settings: %w", err) + } + server, found := r.srvManager.GetServerByTag(selected.Tag) + stillExists := found && + server.IsLantern == selected.IsLantern && + server.Type == selected.Type && + server.Location == selected.Location + return &selected, stillExists, nil +} + +// CurrentAutoSelectedServer returns the tag of the server that is currently auto-selected by the +// VPN client. +func (r *LocalBackend) CurrentAutoSelectedServer() (string, error) { + return r.vpnClient.CurrentAutoSelectedServer() +} + +func (r *LocalBackend) startAutoSelectedListener() { + var ( + mu sync.Mutex + cancel context.CancelFunc + ) + events.SubscribeContext(r.ctx, func(evt vpn.StatusUpdateEvent) { + mu.Lock() + defer mu.Unlock() + if cancel != nil { + cancel() + cancel = nil + } + if evt.Status == vpn.Connected { + var ctx context.Context + ctx, cancel = context.WithCancel(r.ctx) + r.vpnClient.AutoSelectedChangeListener(ctx) + } + }) +} + +func (r *LocalBackend) RunOfflineURLTests() error { + cfg, err := r.confHandler.GetConfig() + if err != nil { + return fmt.Errorf("no config available: %w", err) + } + svrs := r.srvManager.AllServers() + slog.Debug("Running offline URL tests", "server_count", len(svrs), "url_override_count", len(cfg.BanditURLOverrides)) + results, err := r.vpnClient.RunOfflineURLTests( + settings.GetString(settings.DataPathKey), + servers.ServerList{Servers: svrs}.Outbounds(), + cfg.BanditURLOverrides, + ) + if err != nil { + return err + } + now := time.Now() + urlResults := make(map[string]servers.URLTestResult, len(results)) + for tag, delay := range results { + urlResults[tag] = servers.URLTestResult{Delay: delay, Time: now} + } + if len(urlResults) > 0 { + if err := r.srvManager.UpdateURLTestResults(urlResults); err != nil { + slog.Warn("Failed to persist offline URL test results", "error", err) + } + selected, err := r.vpnClient.CurrentAutoSelectedServer() + if err != nil { + slog.Warn("Failed to get current auto-selected server after URL tests", "error", err) + } else { + events.Emit(vpn.AutoSelectedEvent{Selected: selected}) + } + } + return nil +} + +////////////////// +// Split Tunnel // +///////////////// + +func (r *LocalBackend) SplitTunnelFilters() vpn.SplitTunnelFilter { + return r.splitTunnelMgr.Filters() +} + +func (r *LocalBackend) AddSplitTunnelItems(items vpn.SplitTunnelFilter) error { + return r.splitTunnelMgr.AddItems(items) +} + +func (r *LocalBackend) RemoveSplitTunnelItems(items vpn.SplitTunnelFilter) error { + return r.splitTunnelMgr.RemoveItems(items) +} + +///////////// +// Account // +///////////// + +func (r *LocalBackend) NewUser(ctx context.Context) (*account.UserData, error) { + return r.accountClient.NewUser(ctx) +} + +func (r *LocalBackend) Login(ctx context.Context, email, password string) (*account.UserData, error) { + return r.accountClient.Login(ctx, email, password) +} + +func (r *LocalBackend) Logout(ctx context.Context, email string) (*account.UserData, error) { + return r.accountClient.Logout(ctx, email) +} + +func (r *LocalBackend) FetchUserData(ctx context.Context) (*account.UserData, error) { + return r.accountClient.FetchUserData(ctx) +} + +func (r *LocalBackend) StartChangeEmail(ctx context.Context, newEmail, password string) error { + return r.accountClient.StartChangeEmail(ctx, newEmail, password) +} + +func (r *LocalBackend) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + return r.accountClient.CompleteChangeEmail(ctx, newEmail, password, code) +} + +func (r *LocalBackend) StartRecoveryByEmail(ctx context.Context, email string) error { + return r.accountClient.StartRecoveryByEmail(ctx, email) +} + +func (r *LocalBackend) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + return r.accountClient.CompleteRecoveryByEmail(ctx, email, newPassword, code) +} + +func (r *LocalBackend) DeleteAccount(ctx context.Context, email, password string) (*account.UserData, error) { + return r.accountClient.DeleteAccount(ctx, email, password) +} + +func (r *LocalBackend) SignUp(ctx context.Context, email, password string) ([]byte, *account.SignupResponse, error) { + return r.accountClient.SignUp(ctx, email, password) +} + +func (r *LocalBackend) SignupEmailConfirmation(ctx context.Context, email, code string) error { + return r.accountClient.SignupEmailConfirmation(ctx, email, code) +} + +func (r *LocalBackend) SignupEmailResendCode(ctx context.Context, email string) error { + return r.accountClient.SignupEmailResendCode(ctx, email) +} + +func (r *LocalBackend) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + return r.accountClient.ValidateEmailRecoveryCode(ctx, email, code) +} + +func (r *LocalBackend) DataCapInfo(ctx context.Context) (*account.DataCapInfo, error) { + return r.accountClient.DataCapInfo(ctx) +} + +// DataCapUpdates returns the channel that receives datacap updates from the +// upstream SSE stream. The stream runs while the VPN is connected; the channel +// is never closed so callers should select on it alongside a context or other +// signal. +func (r *LocalBackend) DataCapUpdates() <-chan *account.DataCapInfo { + return r.dataCapCh +} + +func (r *LocalBackend) updateDataCapStream(status vpn.VPNStatus) { + r.dataCapMu.Lock() + defer r.dataCapMu.Unlock() + if status == vpn.Connected { + if r.stopDataCap != nil { + return // already running + } + ctx, cancel := context.WithCancel(r.ctx) + r.stopDataCap = cancel + go func() { + _ = r.accountClient.DataCapStream(ctx, func(info *account.DataCapInfo) { + // Non-blocking send; drops stale updates if the reader is slow. + select { + case r.dataCapCh <- info: + default: + select { + case <-r.dataCapCh: + default: + } + r.dataCapCh <- info + } + }) + }() + slog.Debug("Started datacap SSE stream") + } else if r.stopDataCap != nil { + r.stopDataCap() + r.stopDataCap = nil + slog.Debug("Stopped datacap SSE stream") + } +} + +func (r *LocalBackend) RemoveDevice(ctx context.Context, deviceID string) (*account.LinkResponse, error) { + return r.accountClient.RemoveDevice(ctx, deviceID) +} + +func (r *LocalBackend) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*account.UserData, error) { + return r.accountClient.OAuthLoginCallback(ctx, oAuthToken) +} + +func (r *LocalBackend) OAuthLoginURL(ctx context.Context, provider string) (string, error) { + return r.accountClient.OAuthLoginURL(ctx, provider) +} + +func (r *LocalBackend) UserDevices() ([]settings.Device, error) { + return settings.Devices() +} + +func (r *LocalBackend) UserData() (*account.UserData, error) { + var userData account.UserData + if err := settings.GetStruct(settings.UserDataKey, &userData); err != nil { + return nil, fmt.Errorf("failed to get user data from settings: %w", err) + } + return &userData, nil +} + +/////////////////// +// Subscriptions // +/////////////////// + +func (r *LocalBackend) ActivationCode(ctx context.Context, email, resellerCode string) (*account.PurchaseResponse, error) { + return r.accountClient.ActivationCode(ctx, email, resellerCode) +} + +func (r *LocalBackend) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { + return r.accountClient.NewStripeSubscription(ctx, email, planID) +} + +func (r *LocalBackend) PaymentRedirect(ctx context.Context, data account.PaymentRedirectData) (string, error) { + return r.accountClient.PaymentRedirect(ctx, data) +} + +func (r *LocalBackend) ReferralAttach(ctx context.Context, code string) (bool, error) { + return r.accountClient.ReferralAttach(ctx, code) +} + +func (r *LocalBackend) StripeBillingPortalURL(ctx context.Context) (string, error) { + return r.accountClient.StripeBillingPortalURL(ctx, + common.GetProServerURL(), settings.GetString(settings.UserIDKey), settings.GetString(settings.TokenKey), + ) +} + +func (r *LocalBackend) SubscriptionPaymentRedirectURL(ctx context.Context, data account.PaymentRedirectData) (string, error) { + return r.accountClient.SubscriptionPaymentRedirectURL(ctx, data) +} + +func (r *LocalBackend) SubscriptionPlans(ctx context.Context, channel string) (string, error) { + return r.accountClient.SubscriptionPlans(ctx, channel) +} + +func (r *LocalBackend) VerifySubscription(ctx context.Context, service account.SubscriptionService, data map[string]string) (string, error) { + return r.accountClient.VerifySubscription(ctx, service, data) +} diff --git a/backend/radiance_test.go b/backend/radiance_test.go new file mode 100644 index 00000000..dd6eaa62 --- /dev/null +++ b/backend/radiance_test.go @@ -0,0 +1,8 @@ +package backend + +import ( + "testing" +) + +func TestBackend(t *testing.T) { +} diff --git a/bypass/bypass.go b/bypass/bypass.go index 2332856f..15140cd9 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -8,7 +8,13 @@ import ( "net" "net/http" "net/url" + "sync" "time" + + "golang.org/x/net/proxy" + + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/log" ) const ( @@ -21,21 +27,6 @@ const ( // BypassInboundTag is the sing-box inbound tag used for routing bypass traffic to direct. BypassInboundTag = "bypass-in" - // TunnelProxyPort is the port for the local tunnel proxy listener. - // Traffic entering this inbound is routed through the active VPN proxy, - // unlike bypass traffic which is routed directly. - TunnelProxyPort = 14986 - - // TunnelInboundTag is the sing-box inbound tag for the tunnel proxy. - // Unlike BypassInboundTag, this has no routing rule sending it to direct, - // so traffic falls through to the active proxy group. - TunnelInboundTag = "tunnel-in" -) - -// TunnelProxyAddr is the address of the local tunnel proxy listener. -var TunnelProxyAddr = net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", TunnelProxyPort)) - -const ( // connectTimeout is the default timeout for the HTTP CONNECT handshake // when the caller's context has no deadline. connectTimeout = 10 * time.Second @@ -49,14 +40,21 @@ const ( // DialContext tries to connect through the local bypass proxy. If the proxy is // not reachable (VPN not running), it falls back to a direct dial. +// +// QA: when env.OutboundSocksAddress is set, both the bypass-proxy path and the +// direct-fallback path are replaced by a dial through that upstream SOCKS5, +// so every dial out of radiance goes via the same residential egress. func DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if d, ok := outboundSocksDialer(); ok { + return d.DialContext(ctx, network, addr) + } dialer := &net.Dialer{ Timeout: dialTimeout, KeepAlive: dialKeepAlive, } proxyConn, err := dialer.DialContext(ctx, "tcp", ProxyAddr) if err != nil { - slog.Debug("bypass proxy not reachable, falling back to direct dial", "addr", addr, "error", err) + slog.Log(nil, log.LevelTrace, "bypass proxy not reachable, falling back to direct dial", "addr", addr, "error", err) return dialer.DialContext(ctx, network, addr) } tunnelConn, err := httpConnect(ctx, proxyConn, addr) @@ -67,33 +65,37 @@ func DialContext(ctx context.Context, network, addr string) (net.Conn, error) { return tunnelConn, nil } +var ( + outboundSocksOnce sync.Once + outboundSocksDialFn proxy.ContextDialer +) + +// outboundSocksDialer returns a SOCKS5 ContextDialer for env.OutboundSocksAddress +// if set, cached after the first successful build. +func outboundSocksDialer() (proxy.ContextDialer, bool) { + outboundSocksOnce.Do(func() { + addr, ok := env.Get(env.OutboundSocksAddress) + if !ok || addr == "" { + return + } + d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct) + if err != nil { + slog.Error("invalid RADIANCE_OUTBOUND_SOCKS_ADDRESS for bypass dialer", slog.Any("error", err), slog.String("addr", addr)) + return + } + if cd, ok := d.(proxy.ContextDialer); ok { + outboundSocksDialFn = cd + } + }) + return outboundSocksDialFn, outboundSocksDialFn != nil +} + // Dial is a convenience wrapper without context, suitable for use with // amp.WithDialer which expects func(network, addr string) (net.Conn, error). func Dial(network, addr string) (net.Conn, error) { return DialContext(context.Background(), network, addr) } -// TunnelDialContext connects through the local tunnel proxy, which routes -// traffic through the active VPN proxy outbound. Unlike DialContext, it does -// NOT fall back to a direct dial when the proxy is unreachable — if the VPN is -// not running, the connection fails. -func TunnelDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: dialTimeout, - KeepAlive: dialKeepAlive, - } - proxyConn, err := dialer.DialContext(ctx, "tcp", TunnelProxyAddr) - if err != nil { - return nil, fmt.Errorf("tunnel proxy not reachable: %w", err) - } - tunnelConn, err := httpConnect(ctx, proxyConn, addr) - if err != nil { - proxyConn.Close() - return nil, err - } - return tunnelConn, nil -} - // bufferedConn wraps a net.Conn with a bufio.Reader so that any bytes // buffered during the HTTP CONNECT response read are not lost. type bufferedConn struct { diff --git a/cmd/Makefile b/cmd/Makefile index e39104dd..c57f1edf 100644 --- a/cmd/Makefile +++ b/cmd/Makefile @@ -1,7 +1,39 @@ -TAGS=with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale +TAGS=with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_conntrack +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + TAGS := standalone,$(TAGS) +endif + +ifeq ($(OS),Windows_NT) + LANTERND := lanternd.exe + LANTERN := lantern.exe +else + LANTERND := lanternd + LANTERN := lantern +endif + +VERSION ?= +LDFLAGS := $(if $(VERSION),-ldflags "-X 'github.com/getlantern/radiance/common.Version=$(VERSION)'") + +.PHONY: build-daemon build-daemon: - go build -tags "$(TAGS)" -o ../bin/lanternd ./lanternd/lanternd.go + go build -tags "$(TAGS)" $(LDFLAGS) -o ../bin/$(LANTERND) ./lanternd +.PHONY: run-daemon run-daemon: - go run -tags=$(TAGS) ./lanternd/lanternd.go $(args) + go run -tags=$(TAGS) ./lanternd run \ + $(if $(data-path),--data-path=$(data-path)) \ + $(if $(log-path),--log-path=$(log-path)) \ + $(if $(log-level),--log-level=$(log-level)) + +.PHONY: build-cli +build-cli: +ifeq ($(UNAME_S),Darwin) + go build -tags "standalone" $(LDFLAGS) -o ../bin/$(LANTERN) ./lantern +else + go build $(LDFLAGS) -o ../bin/$(LANTERN) ./lantern +endif + +.PHONY: build +build: build-daemon build-cli diff --git a/cmd/justfile b/cmd/justfile new file mode 100644 index 00000000..a2fb8b55 --- /dev/null +++ b/cmd/justfile @@ -0,0 +1,19 @@ +base_tags := "with_gvisor,with_quic,with_dhcp,with_wireguard,with_utls,with_acme,with_clash_api,with_tailscale,with_conntrack" +tags := if os() == "macos" { "standalone," + base_tags } else { base_tags } +lanternd := if os() == "windows" { "lanternd.exe" } else { "lanternd" } +lantern := if os() == "windows" { "lantern.exe" } else { "lantern" } +version := env("VERSION", "") +ldflags := if version != "" { "-ldflags \"-X 'github.com/getlantern/radiance/common.Version=" + version + "'\"" } else { "" } + +build-daemon: + go build -tags "{{tags}}" {{ldflags}} -o ../bin/{{lanternd}} ./lanternd + +run-daemon *args: + go run -tags={{tags}} ./lanternd run {{args}} + +cli_tags := if os() == "macos" { "standalone" } else { "" } + +build-cli: + go build {{ if cli_tags != "" { "-tags " + cli_tags } else { "" } }} {{ldflags}} -o ../bin/{{lantern}} ./lantern + +build: build-daemon build-cli diff --git a/cmd/kindling-tester/main.go b/cmd/kindling-tester/main.go index 77e1e036..d1b8569b 100644 --- a/cmd/kindling-tester/main.go +++ b/cmd/kindling-tester/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "io" "log/slog" @@ -13,7 +12,7 @@ import ( "github.com/getlantern/radiance/kindling" ) -func performKindlingPing(ctx context.Context, urlToHit string, runID string, deviceID string, userID int64, token string, dataDir string) error { +func performKindlingPing(urlToHit string, runID string, deviceID string, userID int64, token string, dataDir string) error { os.MkdirAll(dataDir, 0o755) settings.Set(settings.DataPathKey, dataDir) settings.Set(settings.UserIDKey, userID) @@ -28,14 +27,14 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev }) t1 := time.Now() - newK, err := kindling.NewKindling() + newK, err := kindling.NewKindling(dataDir) if err != nil { slog.Error("failed to initialize kindling", slog.Any("error", err)) } if newK != nil { kindling.SetKindling(newK) } - defer kindling.Close(ctx) + defer kindling.Close() cli := kindling.HTTPClient() t2 := time.Now() @@ -61,7 +60,7 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev if err := os.WriteFile(dataDir+"/output.txt", responseBody, 0o644); err != nil { slog.Error("failed to write output file", slog.Any("error", err), slog.String("path", dataDir+"/output.txt")) } - return os.WriteFile(dataDir+"/timing.txt", []byte(fmt.Sprintf(` + return os.WriteFile(dataDir+"/timing.txt", fmt.Appendf([]byte{}, ` result: %v run-id: %s err: %v @@ -69,7 +68,7 @@ func performKindlingPing(ctx context.Context, urlToHit string, runID string, dev connected: %d fetched: %d url: %s`, - true, runID, nil, t1, int32(t2.Sub(t1).Milliseconds()), int32(t3.Sub(t1).Milliseconds()), urlToHit)), 0o644) + true, runID, nil, t1, int32(t2.Sub(t1).Milliseconds()), int32(t3.Sub(t1).Milliseconds()), urlToHit), 0o644) } func main() { @@ -101,8 +100,6 @@ func main() { } } - ctx := context.Background() - // disabling all other transports before enabling the selected for name := range kindling.EnabledTransports { kindling.EnabledTransports[name] = false @@ -110,7 +107,7 @@ func main() { kindling.EnabledTransports[transport] = true slog.Debug("enabled transports", slog.Any("enabled_transports", kindling.EnabledTransports)) - if err := performKindlingPing(ctx, targetURL, runID, deviceID, uid, token, data); err != nil { + if err := performKindlingPing(targetURL, runID, deviceID, uid, token, data); err != nil { slog.Error("failed to perform kindling ping", slog.Any("error", err)) os.Exit(1) } diff --git a/cmd/lantern/account.go b/cmd/lantern/account.go new file mode 100644 index 00000000..f1148cdf --- /dev/null +++ b/cmd/lantern/account.go @@ -0,0 +1,345 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strconv" + "strings" + "syscall" + + "golang.org/x/term" + + "github.com/getlantern/radiance/ipc" +) + +type AccountCmd struct { + Login *LoginCmd `arg:"subcommand:login" help:"log in to your account"` + Logout *LogoutCmd `arg:"subcommand:logout" help:"log out of your account"` + Signup *SignupCmd `arg:"subcommand:signup" help:"create a new account"` + Recover *RecoverAccountCmd `arg:"subcommand:recover" help:"recover existing account"` + + Usage *UsageCmd `arg:"subcommand:usage" help:"view data usage"` + Devices *DevicesCmd `arg:"subcommand:devices" help:"manage user devices"` + SetEmail *SetEmailCmd `arg:"subcommand:set-email" help:"change account email"` +} + +type LoginCmd struct { + OAuth bool `arg:"--oauth" help:"log in with OAuth provider"` + Provider string `arg:"--provider" help:"OAuth provider"` +} + +type LogoutCmd struct{} + +type SignupCmd struct{} + +type RecoverAccountCmd struct{} + +type SetEmailCmd struct{} + +type UsageCmd struct{} + +type DevicesCmd struct { + List bool `arg:"-l,--list" help:"list user devices"` + Remove string `arg:"-r,--remove" help:"remove a device by ID"` +} + +func runAccount(ctx context.Context, c *ipc.Client, cmd *AccountCmd) error { + switch { + case cmd.Login != nil: + return accountLogin(ctx, c, cmd.Login) + case cmd.Logout != nil: + return accountLogout(ctx, c) + case cmd.Signup != nil: + return accountSignup(ctx, c) + case cmd.Recover != nil: + return accountRecover(ctx, c) + case cmd.Usage != nil: + return accountDataUsage(ctx, c) + case cmd.Devices != nil: + return accountDevices(ctx, c, cmd.Devices) + case cmd.SetEmail != nil: + return accountSetEmail(ctx, c) + default: + return fmt.Errorf("no subcommand specified") + } +} + +// isLoggedIn returns the current user's email if logged in, or empty string if not. +func isLoggedIn(ctx context.Context, c *ipc.Client) (string, error) { + userData, err := c.UserData(ctx) + if err != nil { + return "", err + } + return userData.GetLegacyUserData().GetEmail(), nil +} + +func requireLoggedOut(ctx context.Context, c *ipc.Client) error { + email, err := isLoggedIn(ctx, c) + if err != nil { + return fmt.Errorf("failed to check login status: %w", err) + } + if email != "" { + return fmt.Errorf("already logged in as %s — log out first", email) + } + return nil +} + +func requireLoggedIn(ctx context.Context, c *ipc.Client) (string, error) { + email, err := isLoggedIn(ctx, c) + if err != nil { + return "", fmt.Errorf("failed to check login status: %w", err) + } + if email == "" { + return "", fmt.Errorf("no user is currently logged in") + } + return email, nil +} + +func accountLogin(ctx context.Context, c *ipc.Client, cmd *LoginCmd) error { + if err := requireLoggedOut(ctx, c); err != nil { + return err + } + + if cmd.OAuth { + provider := cmd.Provider + if provider == "" { + provider = "google" + } + url, err := c.OAuthLoginURL(ctx, provider) + if err != nil { + return err + } + fmt.Println("Open this URL in your browser to log in:") + fmt.Println(url) + fmt.Print("Enter OAuth token: ") + token, err := readLine() + if err != nil { + return err + } + userData, err := c.OAuthLoginCallback(ctx, token) + if err != nil { + return err + } + return printJSON(userData) + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + + userData, err := c.Login(ctx, email, password) + if err != nil { + return err + } + fmt.Println("Logged in successfully.") + return printJSON(userData) +} + +func accountLogout(ctx context.Context, c *ipc.Client) error { + email, err := requireLoggedIn(ctx, c) + if err != nil { + return err + } + _, err = c.Logout(ctx, email) + if err != nil { + return err + } + fmt.Println("Logged out successfully.") + return nil +} + +func accountSignup(ctx context.Context, c *ipc.Client) error { + if err := requireLoggedOut(ctx, c); err != nil { + return err + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + confirm, err := promptPassword("Confirm password: ") + if err != nil { + return err + } + if password != confirm { + return fmt.Errorf("passwords do not match") + } + + _, resp, err := c.SignUp(ctx, email, password) + if err != nil { + return err + } + fmt.Println("Account created successfully.") + + fmt.Println("A confirmation code has been sent to your email.") + code, err := prompt("Confirmation code: ") + if err != nil { + return err + } + if err := c.SignupEmailConfirmation(ctx, email, code); err != nil { + return fmt.Errorf("email confirmation failed: %w", err) + } + fmt.Println("Email confirmed.") + _ = resp + return nil +} + +func accountRecover(ctx context.Context, c *ipc.Client) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + email, err := prompt("Email: ") + if err != nil { + return err + } + + if err := c.StartRecoveryByEmail(ctx, email); err != nil { + return err + } + fmt.Println("A recovery code has been sent to your email.") + + code, err := prompt("Recovery code: ") + if err != nil { + return err + } + if err := c.ValidateEmailRecoveryCode(ctx, email, code); err != nil { + return fmt.Errorf("invalid recovery code: %w", err) + } + + newPassword, err := promptPassword("New password: ") + if err != nil { + return err + } + confirm, err := promptPassword("Confirm new password: ") + if err != nil { + return err + } + if newPassword != confirm { + return fmt.Errorf("passwords do not match") + } + + if err := c.CompleteRecoveryByEmail(ctx, email, newPassword, code); err != nil { + return err + } + fmt.Println("Account recovered successfully. You can now log in with your new password.") + return nil +} + +func accountSetEmail(ctx context.Context, c *ipc.Client) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + newEmail, err := prompt("New email: ") + if err != nil { + return err + } + password, err := promptPassword("Password: ") + if err != nil { + return err + } + + if err := c.StartChangeEmail(ctx, newEmail, password); err != nil { + return err + } + fmt.Println("A confirmation code has been sent to your new email.") + + code, err := prompt("Confirmation code: ") + if err != nil { + return err + } + if err := c.CompleteChangeEmail(ctx, newEmail, password, code); err != nil { + return err + } + fmt.Println("Email changed successfully.") + return nil +} + +func accountDataUsage(ctx context.Context, c *ipc.Client) error { + info, err := c.DataCapInfo(ctx) + if err != nil { + return err + } + fmt.Printf("Enabled: %t\n", info.Enabled) + if !info.Enabled { + return nil + } + if info.Usage == nil { + return fmt.Errorf("data usage info is unavailable") + } + bytesAllowed, _ := strconv.Atoi(info.Usage.BytesAllotted) + BytesUsed, _ := strconv.Atoi(info.Usage.BytesUsed) + resetTime := info.Usage.AllotmentEndTime + + fmt.Printf( + "Data used: %.2f MB / %.2f MB (%.2f%%)\n", + float64(BytesUsed)/1e6, float64(bytesAllowed)/1e6, + 100*float64(BytesUsed)/float64(bytesAllowed), + ) + fmt.Printf("Resets at: %s\n", resetTime) + return nil +} + +func accountDevices(ctx context.Context, c *ipc.Client, cmd *DevicesCmd) error { + if _, err := requireLoggedIn(ctx, c); err != nil { + return err + } + + switch { + case cmd.Remove != "": + resp, err := c.RemoveDevice(ctx, cmd.Remove) + if err != nil { + return err + } + fmt.Println("Device removed.") + return printJSON(resp) + default: + // Default to listing devices + devices, err := c.UserDevices(ctx) + if err != nil { + return err + } + return printJSON(devices) + } +} + +// prompt prints a prompt and reads a line of input from stdin. +func prompt(label string) (string, error) { + fmt.Print(label) + return readLine() +} + +// readLine reads a single line from stdin, trimming the trailing newline. +func readLine() (string, error) { + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("unexpected end of input") + } + return strings.TrimSpace(scanner.Text()), nil +} + +// promptPassword prints a prompt and reads a password without echoing it. +func promptPassword(label string) (string, error) { + fmt.Print(label) + password, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() // newline after hidden input + if err != nil { + return "", fmt.Errorf("failed to read password: %w", err) + } + return string(password), nil +} diff --git a/cmd/lantern/config.go b/cmd/lantern/config.go new file mode 100644 index 00000000..9e4228a0 --- /dev/null +++ b/cmd/lantern/config.go @@ -0,0 +1,13 @@ +package main + +import ( + "context" + + "github.com/getlantern/radiance/ipc" +) + +type UpdateConfigCmd struct{} + +func runUpdateConfig(ctx context.Context, c *ipc.Client) error { + return c.UpdateConfig(ctx) +} diff --git a/cmd/lantern/ip.go b/cmd/lantern/ip.go new file mode 100644 index 00000000..5e40b38d --- /dev/null +++ b/cmd/lantern/ip.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "fmt" + "net/netip" + "time" + + "github.com/getlantern/publicip" +) + +var ( + // list of extra public IP services to query in addition to the default ones provided by the publicip package + ipURLs = []string{ + "https://ip.me", + "https://ifconfig.me/ip", + "https://checkip.amazonaws.com", + "https://ifconfig.io/ip", + "https://ident.me", + "https://ipinfo.io/ip", + } + + publicIPCfg = &publicip.Config{ + Timeout: 5 * time.Second, + MinConsensus: 2, + Methods: publicip.DefaultMethods(), + } +) + +func init() { + for _, url := range ipURLs { + publicIPCfg.Methods = append(publicIPCfg.Methods, publicip.NewHTTP(url, publicip.FormatPlainText)) + } +} + +type IPCmd struct{} + +func runIP(ctx context.Context) error { + tctx, tcancel := context.WithTimeout(ctx, 10*time.Second) + defer tcancel() + ip, err := getPublicIP(tctx) + if err != nil { + return err + } + fmt.Println(ip) + return nil +} + +// fakeIPRange is the CIDR used by sing-box's fake-ip DNS. Addresses in this +// range can briefly appear as the "public IP" right after the VPN connects, +// before the tunnel is fully established. +var fakeIPRange = netip.MustParsePrefix("198.18.0.0/15") + +// getPublicIP fetches the public IP address +func getPublicIP(ctx context.Context) (string, error) { + result, err := publicip.Detect(ctx, publicIPCfg) + if err != nil { + return "", err + } + ip := result.IP + addr, ok := netip.AddrFromSlice(ip) + if ok { + addr = addr.Unmap() // normalize IPv4-mapped IPv6 to IPv4 + } + if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() || (ok && fakeIPRange.Contains(addr)) { + return "", fmt.Errorf("detected IP is not a valid public IP: %s", ip.String()) + } + return ip.String(), nil +} diff --git a/cmd/lantern/lantern.go b/cmd/lantern/lantern.go new file mode 100644 index 00000000..25ac2c3a --- /dev/null +++ b/cmd/lantern/lantern.go @@ -0,0 +1,130 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + + "context" + + "github.com/alexflint/go-arg" + + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/issue" + rlog "github.com/getlantern/radiance/log" +) + +type args struct { + Connect *ConnectCmd `arg:"subcommand:connect" help:"connect to VPN"` + Disconnect *DisconnectCmd `arg:"subcommand:disconnect" help:"disconnect VPN"` + Status *StatusCmd `arg:"subcommand:status" help:"show VPN status"` + Servers *ServersCmd `arg:"subcommand:servers" help:"manage servers"` + Features *FeaturesCmd `arg:"subcommand:features" help:"list available features and their status"` + Set *SetCmd `arg:"subcommand:set" help:"update one or more settings"` + Get *GetCmd `arg:"subcommand:get" help:"show one or all settings"` + UpdateConfig *UpdateConfigCmd `arg:"subcommand:update-config" help:"force an immediate config fetch"` + SplitTunnel *SplitTunnelCmd `arg:"subcommand:split-tunnel" help:"split-tunnel filter management"` + Account *AccountCmd `arg:"subcommand:account" help:"login, signup, user data, devices, recovery"` + Subscription *SubscriptionCmd `arg:"subcommand:subscription" help:"plans, payments, and billing"` + ReportIssue *ReportIssueCmd `arg:"subcommand:report-issue" help:"report an issue"` + Logs *LogsCmd `arg:"subcommand:logs" help:"tail daemon logs"` + IP *IPCmd `arg:"subcommand:ip" help:"show public IP address"` + Version *VersionCmd `arg:"subcommand:version" help:"print version"` +} + +func (args) Description() string { + return "Lantern CLI — command-line interface for the Lantern VPN daemon" +} + +type ReportIssueCmd struct { + Type int `arg:"-t,--type,required" help:"0=purchase 1=signin 2=spinner 3=blocked-sites 4=slow 5=link-device 6=crash 9=other 10=update"` + Description string `arg:"-d,--desc,required" help:"issue description"` + Email string `arg:"-e,--email" help:"email address"` +} + +func runReportIssue(ctx context.Context, c *ipc.Client, cmd *ReportIssueCmd) error { + return c.ReportIssue(ctx, issue.IssueType(cmd.Type), cmd.Description, cmd.Email, nil) +} + +type LogsCmd struct{} + +func tailLogs(ctx context.Context, c *ipc.Client) error { + err := c.TailLogs(ctx, func(entry rlog.LogEntry) { + fmt.Println(entry) + }) + if ctx.Err() != nil { + fmt.Fprintln(os.Stderr, "\nStopped tailing logs.") + return nil + } + return err +} + +type VersionCmd struct{} + +func main() { + var a args + p := arg.MustParse(&a) + if p.Subcommand() == nil { + p.WriteHelp(os.Stdout) + os.Exit(1) + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + client := ipc.NewClient() + defer client.Close() + + if err := run(ctx, client, &a); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n\n", err) + p.WriteHelpForSubcommand(os.Stdout, p.SubcommandNames()...) + os.Exit(1) + } +} + +func run(ctx context.Context, c *ipc.Client, a *args) error { + switch { + case a.Connect != nil: + return vpnConnect(ctx, c, a.Connect.Name, a.Connect.Wait) + case a.Disconnect != nil: + return c.DisconnectVPN(ctx) + case a.Status != nil: + return vpnStatus(ctx, c) + case a.Servers != nil: + return runServers(ctx, c, a.Servers) + case a.Features != nil: + return runFeatures(ctx, c) + case a.Set != nil: + return runSet(ctx, c, a.Set) + case a.Get != nil: + return runGet(ctx, c, a.Get) + case a.UpdateConfig != nil: + return runUpdateConfig(ctx, c) + case a.SplitTunnel != nil: + return runSplitTunnel(ctx, c, a.SplitTunnel) + case a.Account != nil: + return runAccount(ctx, c, a.Account) + case a.Subscription != nil: + return runSubscription(ctx, c, a.Subscription) + case a.ReportIssue != nil: + return runReportIssue(ctx, c, a.ReportIssue) + case a.Logs != nil: + return tailLogs(ctx, c) + case a.IP != nil: + return runIP(ctx) + case a.Version != nil: + fmt.Println(common.Version) + return nil + default: + return fmt.Errorf("no subcommand specified") + } +} + +func printJSON(v any) error { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(v) +} diff --git a/cmd/lantern/servers.go b/cmd/lantern/servers.go new file mode 100644 index 00000000..920110db --- /dev/null +++ b/cmd/lantern/servers.go @@ -0,0 +1,155 @@ +package main + +import ( + "context" + "fmt" + "strings" + + C "github.com/getlantern/common" + + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/vpn" +) + +type ServersCmd struct { + Show string `arg:"-s,--show" help:"display server by tag"` + AddJSON string `arg:"--add-json" help:"add servers from JSON config"` + AddURL string `arg:"--add-url" help:"add servers from comma-separated URLs"` + SkipCertVerify bool `arg:"--skip-cert-verify" help:"skip cert verification (with --add-url)"` + Remove string `arg:"--remove" help:"comma-separated list of servers to remove"` + List bool `arg:"-l,--list" help:"list servers"` + Latency bool `arg:"--latency" help:"include URL test latency results (with --list)"` + + PrivateServer *PrivateServerCmd `arg:"subcommand:private" help:"private server operations"` +} + +type PrivateServerCmd struct { + Add string `arg:"-a,--add" help:"add private server with given tag"` + Invite string `arg:"-i,--invite" help:"invite to private server"` + RevokeInvite string `arg:"-r,--revoke-invite" help:"revoke invite"` + IP string `arg:"--ip" help:"server IP"` + Port int `arg:"--port" help:"server port"` + Token string `arg:"--token" help:"access token"` +} + +func runServers(ctx context.Context, c *ipc.Client, cmd *ServersCmd) error { + switch { + case cmd.Show != "": + return serversGet(ctx, c, cmd.Show) + case cmd.AddJSON != "": + return printAddedServers(c.AddServersByJSON(ctx, cmd.AddJSON)) + case cmd.AddURL != "": + urls := strings.Split(cmd.AddURL, ",") + return printAddedServers(c.AddServersByURL(ctx, urls, cmd.SkipCertVerify)) + case cmd.Remove != "": + return serversRemove(ctx, c, cmd.Remove) + case cmd.List: + return serversList(ctx, c, cmd.Latency) + case cmd.PrivateServer != nil: + return runPrivateServer(ctx, c, cmd.PrivateServer) + default: + return fmt.Errorf("must specify one of --get, --add-json, --add-url, --remove, or --list") + } +} + +func runPrivateServer(ctx context.Context, c *ipc.Client, cmd *PrivateServerCmd) error { + switch { + case cmd.Add != "": + return c.AddPrivateServer(ctx, cmd.Add, cmd.IP, cmd.Port, cmd.Token) + case cmd.Invite != "": + code, err := c.InviteToPrivateServer(ctx, cmd.IP, cmd.Port, cmd.Token, cmd.Invite) + if err != nil { + return err + } + fmt.Println(code) + return nil + case cmd.RevokeInvite != "": + return c.RevokePrivateServerInvite(ctx, cmd.IP, cmd.Port, cmd.Token, cmd.RevokeInvite) + default: + return fmt.Errorf("must specify one of --add, --invite, or --revoke-invite") + } +} + +func serversList(ctx context.Context, c *ipc.Client, showLatency bool) error { + srvs, err := c.Servers(ctx) + if err != nil { + return err + } + if len(srvs) == 0 { + fmt.Println("No servers available") + return nil + } + for _, s := range srvs { + printServerEntry(s, showLatency) + } + return nil +} + +func printServerEntry(s *servers.Server, showLatency bool) { + fmt.Printf(" %s [%s]", s.Tag, s.Type) + if s.Location != (C.ServerLocation{}) { + fmt.Printf(" — %s, %s", s.Location.City, s.Location.Country) + } + if !showLatency { + fmt.Println() + return + } + if s.URLTestResult != nil { + fmt.Printf(" (%dms)\n", s.URLTestResult.Delay) + } else { + fmt.Println(" (n/a)") + } +} + +func serversGet(ctx context.Context, c *ipc.Client, tag string) error { + svr, exists, err := c.GetServerByTag(ctx, tag) + if err != nil { + return err + } + if !exists { + fmt.Println("Server not found") + return nil + } + return printJSON(svr) +} + +func serversSelected(ctx context.Context, c *ipc.Client) error { + svr, exists, err := c.SelectedServer(ctx) + if err != nil { + return err + } + if !exists { + fmt.Println("No server selected") + return nil + } + return printJSON(svr) +} + +func serversAutoSelections(ctx context.Context, c *ipc.Client, watch bool) error { + if watch { + return c.AutoSelectedEvents(ctx, func(ev vpn.AutoSelectedEvent) { + s := ev.Selected + fmt.Printf("Selected: %s\n", s) + }) + } + sel, err := c.AutoSelected(ctx) + if err != nil { + return err + } + fmt.Printf("Selected: %s\n", sel.Tag) + return nil +} + +func printAddedServers(tags []string, err error) error { + if err != nil { + return err + } + fmt.Printf("Added %d server(s): %s\n", len(tags), strings.Join(tags, ", ")) + return nil +} + +func serversRemove(ctx context.Context, c *ipc.Client, tags string) error { + tagList := strings.Split(tags, ",") + return c.RemoveServers(ctx, tagList) +} diff --git a/cmd/lantern/settings.go b/cmd/lantern/settings.go new file mode 100644 index 00000000..6ef512b3 --- /dev/null +++ b/cmd/lantern/settings.go @@ -0,0 +1,122 @@ +package main + +import ( + "context" + "fmt" + + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/ipc" + rlog "github.com/getlantern/radiance/log" +) + +type FeaturesCmd struct{} + +func runFeatures(ctx context.Context, c *ipc.Client) error { + f, err := c.Features(ctx) + if err != nil { + return err + } + for k, v := range f { + fmt.Printf("%s: %v\n", k, v) + } + return nil +} + +// settingViews is the single source of truth for which settings the CLI exposes +// under `set`/`get` and how their user-facing values map to the underlying +// settings keys. +var settingViews = []struct { + name string + get func(settings.Settings) any +}{ + {"smart-routing", func(s settings.Settings) any { return orBool(s[settings.SmartRoutingKey]) }}, + {"ad-block", func(s settings.Settings) any { return orBool(s[settings.AdBlockKey]) }}, + {"telemetry", func(s settings.Settings) any { return orBool(s[settings.TelemetryKey]) }}, + {"split-tunnel", func(s settings.Settings) any { return orBool(s[settings.SplitTunnelKey]) }}, + {"config-fetch", func(s settings.Settings) any { return !toBool(s[settings.ConfigFetchDisabledKey]) }}, + {"log-level", func(s settings.Settings) any { return orString(s[settings.LogLevelKey]) }}, +} + +type SetCmd struct { + SmartRouting *bool `arg:"--smart-routing" help:"enable or disable smart routing (true|false)"` + AdBlock *bool `arg:"--ad-block" help:"enable or disable ad blocking (true|false)"` + Telemetry *bool `arg:"--telemetry" help:"enable or disable telemetry (true|false)"` + SplitTunnel *bool `arg:"--split-tunnel" help:"enable or disable split tunneling (true|false)"` + ConfigFetch *bool `arg:"--config-fetch" help:"enable or disable periodic config fetching (true|false)"` + LogLevel *string `arg:"--log-level" help:"log level (trace|debug|info|warn|error|fatal|panic|disable)"` +} + +func runSet(ctx context.Context, c *ipc.Client, cmd *SetCmd) error { + updates := settings.Settings{} + if cmd.SmartRouting != nil { + updates[settings.SmartRoutingKey] = *cmd.SmartRouting + } + if cmd.AdBlock != nil { + updates[settings.AdBlockKey] = *cmd.AdBlock + } + if cmd.Telemetry != nil { + updates[settings.TelemetryKey] = *cmd.Telemetry + } + if cmd.SplitTunnel != nil { + updates[settings.SplitTunnelKey] = *cmd.SplitTunnel + } + if cmd.ConfigFetch != nil { + updates[settings.ConfigFetchDisabledKey] = !*cmd.ConfigFetch + } + if cmd.LogLevel != nil { + if _, err := rlog.ParseLogLevel(*cmd.LogLevel); err != nil { + return err + } + updates[settings.LogLevelKey] = *cmd.LogLevel + } + if len(updates) == 0 { + return fmt.Errorf("no settings provided; pass one or more flags (see `lantern set --help`)") + } + _, err := c.PatchSettings(ctx, updates) + return err +} + +type GetCmd struct { + Name string `arg:"positional" help:"setting name (smart-routing, ad-block, telemetry, split-tunnel, config-fetch, log-level); omit to list all"` +} + +func runGet(ctx context.Context, c *ipc.Client, cmd *GetCmd) error { + s, err := c.Settings(ctx) + if err != nil { + return err + } + if cmd.Name == "" { + for _, v := range settingViews { + fmt.Printf("%s: %v\n", v.name, v.get(s)) + } + return nil + } + for _, v := range settingViews { + if v.name == cmd.Name { + fmt.Printf("%s: %v\n", v.name, v.get(s)) + return nil + } + } + return fmt.Errorf("unknown setting %q", cmd.Name) +} + +func orBool(v any) any { + if v == nil { + return false + } + return v +} + +func orString(v any) any { + if v == nil { + return "" + } + return v +} + +func toBool(v any) bool { + if v == nil { + return false + } + return fmt.Sprintf("%v", v) == "true" +} diff --git a/cmd/lantern/split_tunnel.go b/cmd/lantern/split_tunnel.go new file mode 100644 index 00000000..24fe9660 --- /dev/null +++ b/cmd/lantern/split_tunnel.go @@ -0,0 +1,129 @@ +package main + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/vpn" +) + +type SplitTunnelCmd struct { + List *SplitTunnelListCmd `arg:"subcommand:list" help:"list current filters"` + Add *SplitTunnelAddCmd `arg:"subcommand:add" help:"add a filter"` + Remove *SplitTunnelRemoveCmd `arg:"subcommand:remove" help:"remove a filter"` +} + +type SplitTunnelListCmd struct{} + +type SplitTunnelAddCmd struct { + Type string `arg:"-t,--type,required" help:"filter type: domain, domain-suffix, domain-keyword, domain-regex, process-name, process-path, process-path-regex, package-name"` + Value string `arg:"-v,--value,required" help:"filter value (e.g. example.com)"` +} + +type SplitTunnelRemoveCmd struct { + Type string `arg:"-t,--type,required" help:"filter type: domain, domain-suffix, domain-keyword, domain-regex, process-name, process-path, process-path-regex, package-name"` + Value string `arg:"-v,--value,required" help:"filter value (e.g. example.com)"` +} + +func runSplitTunnel(ctx context.Context, c *ipc.Client, cmd *SplitTunnelCmd) error { + switch { + case cmd.Add != nil: + typ := filterTypeFromArg(cmd.Add.Type) + return c.AddSplitTunnelItems(ctx, buildFilter(typ, cmd.Add.Value)) + case cmd.Remove != nil: + typ := filterTypeFromArg(cmd.Remove.Type) + return c.RemoveSplitTunnelItems(ctx, buildFilter(typ, cmd.Remove.Value)) + default: + return splitTunnelList(ctx, c) + } +} + +func splitTunnelList(ctx context.Context, c *ipc.Client) error { + s, err := c.Settings(ctx) + if err != nil { + return err + } + enabled, _ := strconv.ParseBool(fmt.Sprintf("%v", s[settings.SplitTunnelKey])) + fmt.Printf("Split tunneling: %v\n", enabled) + filters, err := c.SplitTunnelFilters(ctx) + if err != nil { + return err + } + printFilters(filters) + return nil +} + +func printFilters(f vpn.SplitTunnelFilter) { + type entry struct { + label string + values []string + } + entries := []entry{ + {"domain", f.Domain}, + {"domain-suffix", f.DomainSuffix}, + {"domain-keyword", f.DomainKeyword}, + {"domain-regex", f.DomainRegex}, + {"process-name", f.ProcessName}, + {"process-path", f.ProcessPath}, + {"process-path-regex", f.ProcessPathRegex}, + {"package-name", f.PackageName}, + } + hasAny := false + for _, e := range entries { + for _, v := range e.values { + if !hasAny { + fmt.Println("Filters:") + hasAny = true + } + fmt.Printf(" %s: %s\n", e.label, v) + } + } + if !hasAny { + fmt.Println("Filters: none") + } +} + +// parseFilter splits "TYPE:VALUE" into the internal filter type and value. +func parseFilter(spec string) (string, string, error) { + typ, val, ok := strings.Cut(spec, ":") + if !ok || val == "" { + return "", "", fmt.Errorf("filter format: TYPE:VALUE (e.g. domain-suffix:example.com)") + } + return filterTypeFromArg(typ), val, nil +} + +// filterTypeFromArg converts a CLI arg like "domain-suffix" to the internal type "domainSuffix". +func filterTypeFromArg(a string) string { + s, rest, _ := strings.Cut(a, "-") + if rest != "" { + s += strings.ToUpper(rest[:1]) + rest[1:] + } + return s +} + +func buildFilter(filterType, value string) vpn.SplitTunnelFilter { + var f vpn.SplitTunnelFilter + switch filterType { + case vpn.TypeDomain: + f.Domain = []string{value} + case vpn.TypeDomainSuffix: + f.DomainSuffix = []string{value} + case vpn.TypeDomainKeyword: + f.DomainKeyword = []string{value} + case vpn.TypeDomainRegex: + f.DomainRegex = []string{value} + case vpn.TypeProcessName: + f.ProcessName = []string{value} + case vpn.TypeProcessPath: + f.ProcessPath = []string{value} + case vpn.TypeProcessPathRegex: + f.ProcessPathRegex = []string{value} + case vpn.TypePackageName: + f.PackageName = []string{value} + } + return f +} diff --git a/cmd/lantern/subscription.go b/cmd/lantern/subscription.go new file mode 100644 index 00000000..efb047ed --- /dev/null +++ b/cmd/lantern/subscription.go @@ -0,0 +1,247 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/ipc" +) + +type SubscriptionCmd struct { + Plans *SubscriptionPlansCmd `arg:"subcommand:plans" help:"list subscription plans for a channel"` + Activate *ActivateCmd `arg:"subcommand:activate" help:"activate with reseller code"` + StripeSub *StripeSubCmd `arg:"subcommand:stripe-sub" help:"create Stripe subscription"` + PaymentRedirect *PaymentRedirectCmd `arg:"subcommand:redirect" help:"get payment redirect URL"` + Referral *ReferralCmd `arg:"subcommand:referral" help:"attach referral code"` + StripeBilling *StripeBillingCmd `arg:"subcommand:stripe-billing" help:"get Stripe billing portal URL"` + Verify *VerifySubscriptionCmd `arg:"subcommand:verify" help:"verify subscription"` +} + +type SubscriptionPlansCmd struct { + Channel string `arg:"-c,--channel" help:"subscription channel"` +} + +type ActivateCmd struct { + Email string `arg:"-e,--email" help:"email address"` + Code string `arg:"-c,--code" help:"reseller code"` +} + +type StripeSubCmd struct { + Email string `arg:"-e,--email" help:"email address"` + PlanID string `arg:"-p,--plan" help:"plan ID"` +} + +type PaymentRedirectCmd struct { + PlanID string `arg:"-p,--plan" help:"plan ID"` + Provider string `arg:"-P,--provider" help:"payment provider"` + Email string `arg:"-e,--email" help:"email address"` + DeviceName string `arg:"-d,--device" help:"device name"` + BillingType string `arg:"-b,--billing-type" default:"subscription" help:"one_time or subscription"` +} + +type ReferralCmd struct { + Code string `arg:"-c,--code" help:"referral code"` +} + +type StripeBillingCmd struct{} + +type VerifySubscriptionCmd struct { + Service string `arg:"-s,--service" help:"stripe, apple, or google"` + VerifyData string `arg:"-d,--data" help:"verification data as JSON"` +} + +func runSubscription(ctx context.Context, c *ipc.Client, cmd *SubscriptionCmd) error { + switch { + case cmd.Plans != nil: + return subPlans(ctx, c, cmd.Plans) + case cmd.Activate != nil: + return subActivate(ctx, c, cmd.Activate) + case cmd.StripeSub != nil: + return subStripeSub(ctx, c, cmd.StripeSub) + case cmd.PaymentRedirect != nil: + return subRedirect(ctx, c, cmd.PaymentRedirect) + case cmd.Referral != nil: + return subReferral(ctx, c, cmd.Referral) + case cmd.StripeBilling != nil: + return subStripeBilling(ctx, c, cmd.StripeBilling) + case cmd.Verify != nil: + return subVerify(ctx, c, cmd.Verify) + default: + return fmt.Errorf("no subcommand specified") + } +} + +func subPlans(ctx context.Context, c *ipc.Client, cmd *SubscriptionPlansCmd) error { + channel := cmd.Channel + if channel == "" { + var err error + channel, err = prompt("Channel: ") + if err != nil { + return err + } + } + plans, err := c.SubscriptionPlans(ctx, channel) + if err != nil { + return err + } + fmt.Println(plans) + return nil +} + +func subActivate(ctx context.Context, c *ipc.Client, cmd *ActivateCmd) error { + email := cmd.Email + code := cmd.Code + var err error + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return err + } + } + if code == "" { + code, err = prompt("Reseller code: ") + if err != nil { + return err + } + } + resp, err := c.ActivationCode(ctx, email, code) + if err != nil { + return err + } + return printJSON(resp) +} + +func subStripeSub(ctx context.Context, c *ipc.Client, cmd *StripeSubCmd) error { + email := cmd.Email + planID := cmd.PlanID + var err error + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return err + } + } + if planID == "" { + planID, err = prompt("Plan ID: ") + if err != nil { + return err + } + } + secret, err := c.NewStripeSubscription(ctx, email, planID) + if err != nil { + return err + } + fmt.Println(secret) + return nil +} + +func promptRedirectData(planID, provider, email, deviceName, billingType string) (account.PaymentRedirectData, error) { + var err error + if planID == "" { + planID, err = prompt("Plan ID: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if provider == "" { + provider, err = prompt("Provider: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if email == "" { + email, err = prompt("Email: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if deviceName == "" { + deviceName, err = prompt("Device name: ") + if err != nil { + return account.PaymentRedirectData{}, err + } + } + if billingType == "" { + billingType = "subscription" + } + return account.PaymentRedirectData{ + Plan: planID, + Provider: provider, + Email: email, + DeviceName: deviceName, + BillingType: account.SubscriptionType(billingType), + }, nil +} + +func subRedirect(ctx context.Context, c *ipc.Client, cmd *PaymentRedirectCmd) error { + data, err := promptRedirectData(cmd.PlanID, cmd.Provider, cmd.Email, cmd.DeviceName, cmd.BillingType) + if err != nil { + return err + } + url, err := c.PaymentRedirect(ctx, data) + if err != nil { + return err + } + fmt.Println(url) + return nil +} + +func subReferral(ctx context.Context, c *ipc.Client, cmd *ReferralCmd) error { + code := cmd.Code + if code == "" { + var err error + code, err = prompt("Referral code: ") + if err != nil { + return err + } + } + ok, err := c.ReferralAttach(ctx, code) + if err != nil { + return err + } + if ok { + fmt.Println("Referral attached successfully") + } else { + fmt.Println("Referral was not attached") + } + return nil +} + +func subStripeBilling(ctx context.Context, c *ipc.Client, cmd *StripeBillingCmd) error { + url, err := c.StripeBillingPortalURL(ctx) + if err != nil { + return err + } + fmt.Println(url) + return nil +} + +func subVerify(ctx context.Context, c *ipc.Client, cmd *VerifySubscriptionCmd) error { + service := cmd.Service + verifyData := cmd.VerifyData + var err error + if service == "" { + service, err = prompt("Service (stripe, apple, or google): ") + if err != nil { + return err + } + } + if verifyData == "" { + verifyData, err = prompt("Verification data (JSON): ") + if err != nil { + return err + } + } + var data map[string]string + if err := json.Unmarshal([]byte(verifyData), &data); err != nil { + return fmt.Errorf("invalid JSON for verification data: %w", err) + } + result, err := c.VerifySubscription(ctx, account.SubscriptionService(service), data) + if err != nil { + return err + } + fmt.Println(result) + return nil +} diff --git a/cmd/lantern/vpn.go b/cmd/lantern/vpn.go new file mode 100644 index 00000000..cd83f14e --- /dev/null +++ b/cmd/lantern/vpn.go @@ -0,0 +1,103 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/getlantern/radiance/ipc" + "github.com/getlantern/radiance/vpn" +) + +type ConnectCmd struct { + Name string `arg:"-n,--name" default:"auto" help:"server name to connect to"` + Wait bool `arg:"-w,--wait" default:"false" help:"wait for IP change after connecting"` +} + +type DisconnectCmd struct{} + +type StatusCmd struct{} + +func vpnConnect(ctx context.Context, c *ipc.Client, tag string, wait bool) error { + tctx, tcancel := context.WithTimeout(ctx, 5*time.Second) + var prevIP string + if wait { + prevIP, _ = getPublicIP(tctx) + } + tcancel() + + status, err := c.VPNStatus(ctx) + if err != nil { + return err + } + switch status { + case vpn.Connected: + if err := c.SelectServer(ctx, tag); err != nil { + return err + } + case vpn.Disconnected: + if err := c.ConnectVPN(ctx, tag); err != nil { + return err + } + default: + return fmt.Errorf("busy with VPN status: %s", status) + } + + fmt.Printf("Connected (tag: %s)\n", tag) + if !wait { + return nil + } + + fmt.Print("Waiting for IP change...") + waitCtx, waitCancel := context.WithTimeout(ctx, 30*time.Second) + defer waitCancel() + start := time.Now() + ip, err := waitForIPChange(waitCtx, prevIP, 100*time.Millisecond) + if err == nil && ip != "" { + fmt.Printf("\rPublic IP: %s (took %v)\n", ip, time.Since(start).Truncate(time.Millisecond)) + } else { + fmt.Printf("\rIP change not detected after %v\n", time.Since(start).Truncate(time.Second)) + } + return nil +} + +func waitForIPChange(ctx context.Context, current string, interval time.Duration) (string, error) { + for { + select { + case <-ctx.Done(): + return "", nil + case <-time.After(interval): + ip, err := getPublicIP(ctx) + if err != nil { + return "", nil + } + if ip != current { + return ip, nil + } + } + } +} + +func vpnStatus(ctx context.Context, c *ipc.Client) error { + status, err := c.VPNStatus(ctx) + if err != nil { + return err + } + line := string(status) + line = strings.ToUpper(line[:1]) + line[1:] // capitalize first letter + if status == vpn.Connected { + if sel, exists, err := c.SelectedServer(ctx); err == nil && exists { + line += "\nServer: " + sel.Tag + } else { + fmt.Printf("error getting selected server: err=%v, sel=%v, exists=%v\n", err, sel, exists) + } + } + tctx, tcancel := context.WithTimeout(ctx, 5*time.Second) + if ip, err := getPublicIP(tctx); err == nil { + line += "\nIP: " + ip + } + tcancel() + fmt.Println(line) + return nil +} diff --git a/cmd/lanternd/lanternd.go b/cmd/lanternd/lanternd.go index ba159528..63d7f2ab 100644 --- a/cmd/lanternd/lanternd.go +++ b/cmd/lanternd/lanternd.go @@ -1,78 +1,367 @@ package main import ( + "bufio" "context" - "flag" + "errors" "fmt" + "io" "log" "log/slog" "os" + "os/exec" "os/signal" + "path/filepath" "syscall" "time" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" + "github.com/alexflint/go-arg" + "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/traces" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/ipc" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/vpn" - "github.com/getlantern/radiance/vpn/ipc" ) -const tracerName = "github.com/getlantern/radiance/cmd/lanternd" +type runCmd struct { + DataPath string `arg:"--data-path" help:"path to store data"` + LogPath string `arg:"--log-path" help:"path to store logs"` + LogLevel string `arg:"--log-level" default:"info" help:"logging level (trace, debug, info, warn, error)"` +} -var ( - dataPath = flag.String("data-path", "$HOME/.lantern", "Path to store data") - logPath = flag.String("log-path", "$HOME/.lantern", "Path to store logs") - logLevel = flag.String("log-level", "info", "Logging level (trace, debug, info, warn, error)") -) +type installCmd struct { + DataPath string `arg:"--data-path" help:"path to store data"` + LogPath string `arg:"--log-path" help:"path to store logs"` + LogLevel string `arg:"--log-level" default:"info" help:"logging level (trace, debug, info, warn, error)"` +} + +type uninstallCmd struct{} + +type versionCmd struct{} + +type daemonArgs struct { + Run *runCmd `arg:"subcommand:run" help:"run the daemon"` + Install *installCmd `arg:"subcommand:install" help:"install as system service"` + Uninstall *uninstallCmd `arg:"subcommand:uninstall" help:"uninstall system service"` + Version *versionCmd `arg:"subcommand:version" help:"print version"` +} + +func (daemonArgs) Description() string { + return "lanternd — Lantern VPN daemon" +} + +func init() { + log.SetFlags(log.Lshortfile | log.LstdFlags) +} func main() { - flag.Parse() + if maybePlatformService() { + return + } - dataPath := os.ExpandEnv(*dataPath) - logPath := os.ExpandEnv(*logPath) - logLevel := *logLevel + var a daemonArgs + p := arg.MustParse(&a) + if p.Subcommand() == nil { + p.WriteHelp(os.Stdout) + os.Exit(1) + } - slog.Info("Starting lanternd", "version", common.Version, "dataPath", dataPath) - if err := common.Init(dataPath, logPath, logLevel); err != nil { - log.Fatalf("Failed to initialize common: %v\n", err) + defaultDataPath := internal.DefaultDataPath() + defaultLogPath := internal.DefaultLogPath() + var err error + switch { + case a.Run != nil: + dataPath := os.ExpandEnv(withDefault(a.Run.DataPath, defaultDataPath)) + logPath := os.ExpandEnv(withDefault(a.Run.LogPath, defaultLogPath)) + if os.Getenv("_LANTERND_CHILD") != "1" { + err = babysit(os.Args[1:], dataPath, logPath, a.Run.LogLevel) + break + } + ctx, cancel := context.WithCancel(context.Background()) + // Shut down on stdin closure (babysit parent signals us) or OS signal. + go func() { + io.Copy(io.Discard, os.Stdin) + cancel() + }() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + cancel() + // Restore default signal behavior so a second signal terminates immediately. + signal.Reset(syscall.SIGINT, syscall.SIGTERM) + }() + err = runDaemon(ctx, dataPath, logPath, a.Run.LogLevel) + case a.Install != nil: + err = install( + os.ExpandEnv(withDefault(a.Install.DataPath, defaultDataPath)), + os.ExpandEnv(withDefault(a.Install.LogPath, defaultLogPath)), + a.Install.LogLevel, + ) + case a.Uninstall != nil: + err = uninstall() + case a.Version != nil: + fmt.Println(common.Version) + } + if err != nil { + log.Fatalf("Error: %v\n", err) + } +} + +func withDefault(val, def string) string { + if val == "" { + return def + } + return val +} + +// copyBin copies the current executable to binPath, creating parent directories +// as needed. It returns the destination path. +func copyBin() (string, error) { + src, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get executable path: %w", err) + } + src, err = filepath.EvalSymlinks(src) + if err != nil { + return "", fmt.Errorf("failed to resolve executable path: %w", err) + } + + dst := binPath + if src == dst { + return dst, nil + } + + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return "", fmt.Errorf("failed to create directory for %s: %w", dst, err) + } + + sf, err := os.Open(src) + if err != nil { + return "", fmt.Errorf("failed to open source binary: %w", err) + } + defer sf.Close() + + df, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return "", fmt.Errorf("failed to create %s: %w", dst, err) + } + defer df.Close() + + if _, err := io.Copy(df, sf); err != nil { + return "", fmt.Errorf("failed to copy binary to %s: %w", dst, err) + } + + slog.Info("Copied binary", "src", src, "dst", dst) + return dst, nil +} + +// childProcess manages a daemon child process. The parent spawns the child, drains its output, +// and can signal graceful shutdown by closing its stdin pipe. If the child crashes, the parent +// cleans up stale VPN network state immediately. +type childProcess struct { + cmd *exec.Cmd + stdin io.Closer + done chan error + dataPath string + logger *slog.Logger +} + +// spawnChild creates and starts a daemon child process with piped I/O. The child's stdout and +// stderr are merged and drained through the provided logger (or os.Stdout as fallback). +func spawnChild(args []string, dataPath, logPath, logLevel string) (*childProcess, error) { + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("failed to get executable path: %w", err) } - ipcServer, err := initIPC(dataPath, logPath, logLevel) + cmd := exec.Command(exe, args...) + cmd.Env = append(os.Environ(), "_LANTERND_CHILD=1") + stdinPipe, err := cmd.StdinPipe() if err != nil { - log.Fatalf("Failed to initialize IPC: %v\n", err) + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + cmd.Stderr = cmd.Stdout // merge stderr into the same pipe + + logger := rlog.NewLogger(rlog.Config{ + LogPath: filepath.Join(logPath, internal.LogFileName), + Level: logLevel, + Prod: true, + DisablePublisher: true, + }) + + go func() { + defer stdoutPipe.Close() + var w io.Writer = os.Stdout + if h, ok := logger.Handler().(rlog.Handler); ok { + w = h.Writer() + } + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + if s := scanner.Text(); s != "" { + w.Write([]byte(s + "\n")) + } + } + if err := scanner.Err(); err != nil { + logger.Error("Error reading child process output", "error", err) + } + }() + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start daemon process: %w", err) + } + logger.Info("Started daemon process", "pid", cmd.Process.Pid) + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + return &childProcess{ + cmd: cmd, + stdin: stdinPipe, + done: done, + dataPath: dataPath, + logger: logger, + }, nil +} + +// RequestShutdown signals the child to shut down gracefully by closing its stdin pipe. +func (c *childProcess) RequestShutdown() { + c.logger.Info("Requesting child process shutdown") + c.stdin.Close() +} + +// Done returns a channel that receives the child's exit error (nil on clean exit). +func (c *childProcess) Done() <-chan error { + return c.done +} + +// WaitOrKill waits for the child to exit, killing it if it doesn't exit within the timeout. +func (c *childProcess) WaitOrKill(timeout time.Duration) error { + select { + case err := <-c.done: + return err + case <-time.After(timeout): + c.logger.Warn("Child did not exit in time, killing") + c.cmd.Process.Kill() + return <-c.done + } +} - // Wait for a signal to gracefully shut down. +// HandleCrash cleans up stale VPN network state left by a crashed child. +func (c *childProcess) HandleCrash(err error) { + c.logger.Warn("Daemon process exited unexpectedly, cleaning up network state", "error", err) + vpn.AttemptFixNetState() +} + +// babysit runs the daemon as a child process and monitors it. If the child exits unexpectedly +// (crash, panic, etc.), the parent immediately cleans up any stale VPN network state and +// automatically restarts the child process with exponential backoff. +// +// Graceful shutdown is signaled by closing the child's stdin pipe — this works cross-platform, +// including inside a Windows service where there is no console for signal delivery. +func babysit(args []string, dataPath, logPath, logLevel string) error { + // On termination signal, request graceful shutdown of the current child. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh + stopping := false - slog.Info("Shutting down...") - time.AfterFunc(15*time.Second, func() { - log.Fatal("Failed to shut down in time, forcing exit.") - }) - ipcServer.Close() + const resetAfter = 2 * time.Minute // reset backoff if child ran longer than this + bo := common.NewBackoff(60 * time.Second) + + for { + child, err := spawnChild(args, dataPath, logPath, logLevel) + if err != nil { + if stopping { + return nil + } + return err + } + child.logger.Info("Monitoring daemon process") + startedAt := time.Now() + + // Wait for either a termination signal or child exit. + select { + case sig := <-sigCh: + stopping = true + child.logger.Info("Received signal, shutting down child", "signal", sig) + child.RequestShutdown() + err = child.WaitOrKill(15 * time.Second) + case err = <-child.Done(): + } + + if stopping { + signal.Stop(sigCh) + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitCode()) + } + return err + } + + // Unexpected exit — clean up and restart. + if err != nil { + child.HandleCrash(err) + } + + // Reset backoff if the child ran for a while (i.e. it wasn't a fast crash loop). + if time.Since(startedAt) > resetAfter { + bo.Reset() + } + + child.logger.Info("Restarting child process") + bo.Wait(context.Background()) + } } -func initIPC(dataPath, logPath, logLevel string) (*ipc.Server, error) { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "initIPC", - trace.WithAttributes(attribute.String("dataPath", dataPath)), - ) - defer span.End() +func runDaemon(ctx context.Context, dataPath, logPath, logLevel string) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - span.AddEvent("initializing IPC server") + slog.Info("Starting lanternd", "version", common.Version, "dataPath", dataPath) + be, err := backend.NewLocalBackend(ctx, backend.Options{ + DataDir: dataPath, + LogDir: logPath, + LogLevel: logLevel, + }) + if err != nil { + return fmt.Errorf("failed to create backend: %w", err) + } + user, err := be.UserData() + if err != nil { + return fmt.Errorf("failed to get current data: %w", err) + } + if user == nil { + if _, err := be.NewUser(ctx); err != nil { + return fmt.Errorf("failed to create new user: %w", err) + } + } - server := ipc.NewServer(vpn.NewTunnelService(dataPath, slog.Default().With("service", "ipc"), nil)) - slog.Debug("starting IPC server") + be.Start() + server := ipc.NewServer(be, !common.IsMobile()) if err := server.Start(); err != nil { - slog.Error("failed to start IPC server", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("start IPC server: %w", err)) + return fmt.Errorf("failed to start IPC server: %w", err) + } + + // Wait for context cancellation to gracefully shut down. + <-ctx.Done() + + slog.Info("Shutting down...") + + time.AfterFunc(15*time.Second, func() { + slog.Error("Failed to shut down in time, forcing exit") + os.Exit(1) + }) + + be.Close() + if err := server.Close(); err != nil { + slog.Error("Error closing IPC server", "error", err) } - return server, nil + slog.Info("Shutdown complete") + return nil } diff --git a/cmd/lanternd/lanternd.service b/cmd/lanternd/lanternd.service deleted file mode 100644 index de147401..00000000 --- a/cmd/lanternd/lanternd.service +++ /dev/null @@ -1,19 +0,0 @@ -[Unit] -Description=Lantern VPN Daemon -Wants=network-online.target -After=network-online.target - -[Service] -Type=simple -ExecStart=/usr/sbin/lanternd -data-path /var/lib/lantern -log-path /var/log/lantern -log-level trace -Restart=on-failure -RestartSec=5s - -RuntimeDirectory=lantern -RuntimeDirectoryMode=0755 -StateDirectory=lantern -CacheDirectory=lantern -LogsDirectory=lantern - -[Install] -WantedBy=multi-user.target diff --git a/cmd/lanternd/lanternd_darwin.go b/cmd/lanternd/lanternd_darwin.go new file mode 100644 index 00000000..c4c10628 --- /dev/null +++ b/cmd/lanternd/lanternd_darwin.go @@ -0,0 +1,111 @@ +//go:build darwin && !ios + +package main + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "text/template" + + "github.com/getlantern/radiance/common" +) + +const ( + serviceName = "com.lantern.lanternd" + binPath = "/usr/local/bin/" + serviceName +) + +func maybePlatformService() bool { + return false +} + +var launchdPlistTmpl = template.Must(template.New("plist").Parse(` + + + + Label + {{.ServiceName}} + ProgramArguments + + {{.ExePath}} + run + --data-path + {{.DataPath}} + --log-path + {{.LogPath}} + --log-level + {{.LogLevel}} + + RunAtLoad + + KeepAlive + + StandardOutPath + {{.LogPath}}/lanternd.stdout.log + StandardErrorPath + {{.LogPath}}/lanternd.stderr.log + + +`)) + +func plistPath() string { + return fmt.Sprintf("/Library/LaunchDaemons/%s.plist", serviceName) +} + +func install(dataPath, logPath, logLevel string) error { + slog.Info("Installing launchd service..", "version", common.Version) + + // Remove any existing service so we can recreate it cleanly. + // Errors are expected on first install when no service exists yet. + if err := uninstall(); err != nil { + slog.Debug("No existing service to remove (expected on first install)", "error", err) + } + + exe, err := copyBin() + if err != nil { + return err + } + + plist := plistPath() + f, err := os.Create(plist) + if err != nil { + return fmt.Errorf("failed to create plist %s: %w", plist, err) + } + defer f.Close() + + err = launchdPlistTmpl.Execute(f, struct { + ServiceName, ExePath, DataPath, LogPath, LogLevel string + }{serviceName, exe, dataPath, logPath, logLevel}) + if err != nil { + return fmt.Errorf("failed to write plist: %w", err) + } + + slog.Info("Installing launchd service", "plist", plist) + if out, err := exec.Command("launchctl", "load", "-w", plist).CombinedOutput(); err != nil { + return fmt.Errorf("launchctl load: %w\n%s", err, out) + } + + slog.Info("Launchd service installed and started") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling launchd service") + plist := plistPath() + + if out, err := exec.Command("launchctl", "unload", "-w", plist).CombinedOutput(); err != nil { + slog.Warn("Failed to unload service", "error", err, "output", string(out)) + } + + if err := os.Remove(plist); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove plist: %w", err) + } + + slog.Info("Launchd service uninstalled") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil +} diff --git a/cmd/lanternd/lanternd_linux.go b/cmd/lanternd/lanternd_linux.go new file mode 100644 index 00000000..4af6a3a7 --- /dev/null +++ b/cmd/lanternd/lanternd_linux.go @@ -0,0 +1,110 @@ +package main + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "text/template" + + "github.com/getlantern/radiance/common" +) + +const ( + serviceName = "lanternd" + binPath = "/usr/bin/" + serviceName + systemdUnitPath = "/usr/lib/systemd/system/" + serviceName + ".service" +) + +func maybePlatformService() bool { + return false +} + +var systemdUnitTmpl = template.Must(template.New("unit").Parse(`[Unit] +Description=Lantern VPN Daemon +Wants=network-online.target +After=network-online.target + +[Service] +Type=simple +ExecStart={{.ExePath}} run --data-path {{.DataPath}} --log-path {{.LogPath}} --log-level {{.LogLevel}} +Restart=on-failure +RestartSec=5s + +RuntimeDirectory=lantern +RuntimeDirectoryMode=0755 +StateDirectory=lantern +CacheDirectory=lantern +LogsDirectory=lantern + +[Install] +WantedBy=multi-user.target +`)) + +func install(dataPath, logPath, logLevel string) error { + slog.Info("Installing systemd service..", "version", common.Version) + + // Remove any existing service so we can recreate it cleanly. + // Errors are expected on first install when no service exists yet. + if err := uninstall(); err != nil { + slog.Debug("No existing service to remove (expected on first install)", "error", err) + } + + exe, err := copyBin() + if err != nil { + return err + } + + f, err := os.Create(systemdUnitPath) + if err != nil { + return fmt.Errorf("failed to create unit file %s: %w", systemdUnitPath, err) + } + defer f.Close() + + err = systemdUnitTmpl.Execute(f, struct { + ExePath, DataPath, LogPath, LogLevel string + }{exe, dataPath, logPath, logLevel}) + if err != nil { + return fmt.Errorf("failed to write unit file: %w", err) + } + + slog.Info("Installing systemd service", "unit", systemdUnitPath) + for _, args := range [][]string{ + {"systemctl", "daemon-reload"}, + {"systemctl", "enable", serviceName}, + {"systemctl", "start", serviceName}, + } { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + return fmt.Errorf("%v: %w\n%s", args, err, out) + } + } + + slog.Info("Systemd service installed and started") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling systemd service") + for _, args := range [][]string{ + {"systemctl", "stop", serviceName}, + {"systemctl", "disable", serviceName}, + } { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + slog.Warn("Command failed", "cmd", args, "error", err, "output", string(out)) + } + } + + if err := os.Remove(systemdUnitPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove unit file: %w", err) + } + + if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil { + return fmt.Errorf("systemctl daemon-reload: %w\n%s", err, out) + } + + slog.Info("Systemd service uninstalled") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil +} diff --git a/cmd/lanternd/lanternd_windows.go b/cmd/lanternd/lanternd_windows.go new file mode 100644 index 00000000..5d941989 --- /dev/null +++ b/cmd/lanternd/lanternd_windows.go @@ -0,0 +1,231 @@ +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/internal" +) + +const ( + serviceName = "LanternSvc" + binPath = "C:\\Program Files\\Lantern\\" + serviceName + ".exe" +) + +var isWindowsService bool + +func init() { + isSvc, err := svc.IsWindowsService() + if err != nil { + log.Fatalf("Failed to determine if running as Windows service: %v\n", err) + } + isWindowsService = isSvc +} + +func install(dataPath, logPath, logLevel string) error { + dataPath = os.ExpandEnv(dataPath) + logPath = os.ExpandEnv(logPath) + + slog.Info("Installing Windows service..", "version", common.Version) + + // Remove any existing service so we can recreate it cleanly. + // Errors are expected on first install when no service exists yet. + if err := uninstall(); err != nil { + slog.Debug("No existing service to remove (expected on first install)", "error", err) + } + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %w", err) + } + defer m.Disconnect() + + exe, err := copyBin() + if err != nil { + return err + } + + config := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceName, + Description: "Lantern Daemon Service", + } + + args := []string{ + "run", + "--data-path", dataPath, + "--log-path", logPath, + "--log-level", logLevel, + } + + slog.Info("Creating Windows service", "exe", exe, "args", args) + service, err := m.CreateService(serviceName, exe, config, args...) + if err != nil { + return fmt.Errorf("failed to create %q service: %w", serviceName, err) + } + defer service.Close() + + err = service.SetRecoveryActions([]mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 1 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 2 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 4 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 8 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 16 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 32 * time.Second}, + {Type: mgr.ServiceRestart, Delay: 64 * time.Second}, + }, 60) + if err != nil { + return fmt.Errorf("failed to set service recovery actions: %w", err) + } + if err := service.Start(); err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + + slog.Info("Windows service installed successfully") + return nil +} + +func uninstall() error { + slog.Info("Uninstalling Windows service..") + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %w", err) + } + defer m.Disconnect() + + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("failed to open %q service: %w", serviceName, err) + } + + status, err := service.Query() + if err != nil { + service.Close() + return fmt.Errorf("failed to query service state: %w", err) + } + if status.State != svc.Stopped { + service.Control(svc.Stop) + } + err = service.Delete() + service.Close() + if err != nil { + return fmt.Errorf("failed to delete service: %w", err) + } + + slog.Info("Waiting for service to be removed...") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for service to be removed") + case <-time.After(100 * time.Millisecond): + if service, err = m.OpenService(serviceName); err != nil { + slog.Info("Windows service uninstalled successfully") + if err := os.Remove(binPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove binary: %w", err) + } + return nil + } + service.Close() + } + } +} + +func maybePlatformService() bool { + if !isWindowsService { + return false + } + if err := startWindowsService(); err != nil { + log.Fatalf("Failed to start Windows service: %v\n", err) + } + return true +} + +type service struct{} + +func startWindowsService() error { + return svc.Run(serviceName, &service{}) +} + +func (s *service) Execute(args []string, r <-chan svc.ChangeRequest, status chan<- svc.Status) (bool, uint32) { + status <- svc.Status{State: svc.StartPending} + + // The Execute args from the SCM dispatcher only contain runtime start parameters + // (typically just [serviceName]). The actual configured arguments are baked into + // os.Args via the service ImagePath. Parse from os.Args to get the real values, + // falling back to defaults if not present. + dataPath, logPath, logLevel := parseServiceArgs(os.Args[1:]) + + // Run the daemon as a child process so we can clean up network state if it crashes, + // regardless of whether the SCM is configured to restart the service. + childArgs := []string{"run", "--data-path", dataPath, "--log-path", logPath, "--log-level", logLevel} + child, err := spawnChild(childArgs, dataPath, logPath, logLevel) + if err != nil { + slog.Error("Failed to start daemon", "error", err) + return true, 1 + } + + status <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} + child.logger.Info("Running as Windows service") + + for { + select { + case err := <-child.Done(): + if err != nil { + child.HandleCrash(err) + } + return true, 1 + case change := <-r: + switch change.Cmd { + case svc.Stop, svc.Shutdown: + status <- svc.Status{State: svc.StopPending} + child.logger.Info("Service stop requested") + child.RequestShutdown() + child.WaitOrKill(15 * time.Second) + return false, windows.NO_ERROR + case svc.Interrogate: + status <- change.CurrentStatus + case svc.SessionChange: + status <- change.CurrentStatus + } + } + } +} + +func parseServiceArgs(args []string) (dataPath, logPath, logLevel string) { + dataPath = internal.DefaultDataPath() + logPath = internal.DefaultLogPath() + logLevel = "info" + for i := 0; i < len(args); i++ { + switch args[i] { + case "--data-path": + if i+1 < len(args) { + dataPath = os.ExpandEnv(args[i+1]) + i++ + } + case "--log-path": + if i+1 < len(args) { + logPath = os.ExpandEnv(args[i+1]) + i++ + } + case "--log-level": + if i+1 < len(args) { + logLevel = args[i+1] + i++ + } + } + } + return +} diff --git a/cmd/qa-bandit/main.go b/cmd/qa-bandit/main.go new file mode 100644 index 00000000..892e63c6 --- /dev/null +++ b/cmd/qa-bandit/main.go @@ -0,0 +1,304 @@ +// Command qa-bandit is a focused QA driver for the bandit assignment path. +// It boots a radiance backend that impersonates an Android client, captures +// the first /v1/config-new response from the bandit, prints the assignment, +// then optionally connects the VPN and probes a target URL through the +// resulting tunnel to confirm both the API view of the client and the +// outbound dials originate from the country we're simulating. +// +// Pair with `pinger bridge --country ru`: +// +// # in lantern-cloud-bridge: +// ./cmd/pinger/bridge.sh +// # in radiance: +// RADIANCE_OUTBOUND_SOCKS_ADDRESS=127.0.0.1:1080 \ +// go run -tags 'with_quic,with_gvisor,with_wireguard,with_utls' ./cmd/qa-bandit +// +// The build tags are needed by sing-box outbounds (hysteria2 needs QUIC, +// etc.) — without them ConnectVPN fails with "X is not included in this +// build, rebuild with -tags with_X". +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "os" + "strconv" + "time" + + "golang.org/x/net/proxy" + + "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/vpn" +) + +func main() { + var ( + outboundSocks = flag.String("outbound-socks", os.Getenv("RADIANCE_OUTBOUND_SOCKS_ADDRESS"), + "upstream SOCKS5 to route ALL radiance egress through (e.g. 127.0.0.1:1080 — pinger bridge)") + platform = flag.String("platform", "android", "platform to advertise to the API (sent in the body and X-Lantern-Platform)") + version = flag.String("version", "9.0.30-qa-bandit", + "app version to advertise (X-Lantern-App-Version / X-Lantern-Version)") + deviceID = flag.String("device-id", "qa-bandit-android-0001", "device ID to advertise") + userID = flag.String("user-id", "0", "user ID to advertise (string; 0 = no specific user)") + token = flag.String("token", "", "pro token (optional — empty = free tier)") + probeURL = flag.String("probe-url", "https://api.ipify.org", + "URL to fetch through the bandit-assigned tunnel to verify egress IP") + doConnect = flag.Bool("connect", true, + "actually ConnectVPN(AutoSelect) and probe — false = just dump the bandit response and exit") + socksIn = flag.String("socks-inbound", "127.0.0.1:46666", + "local SOCKS5 inbound that radiance exposes for the probe (avoids needing a TUN / root)") + // The API's GeoIP→country logic overrides the IP-derived country + // with the timezone-derived one (treats mismatches as VPN users). + // Without spoofing these to Russia equivalents, the bandit will + // keep serving US-tier outbounds even though the TCP egress is + // Russia. See cmd/api/maxmind.go LookupCountryASNState. + tz = flag.String("tz", "Europe/Moscow", "TZ env var sent as the request's X-Lantern-Time-Zone") + locale = flag.String("locale", "ru_RU", "locale to pass to the radiance backend (X-Lantern-Locale)") + timeout = flag.Duration("timeout", 180*time.Second, "overall timeout (covers config fetch + URLTest convergence + probe retries)") + ) + flag.Parse() + + // Plumb the QA env vars BEFORE common.Init runs (i.e. before NewLocalBackend). + // All three of these are honored by code on qa/outbound-socks-egress branch. + if *outboundSocks != "" { + os.Setenv("RADIANCE_OUTBOUND_SOCKS_ADDRESS", *outboundSocks) + } + os.Setenv("RADIANCE_PLATFORM", *platform) + os.Setenv("RADIANCE_VERSION", *version) + // Use a SOCKS5 inbound listener instead of a TUN device — no root/sudo + // needed, and gives us a clean address to probe through. + os.Setenv("RADIANCE_USE_SOCKS_PROXY", "true") + os.Setenv("RADIANCE_SOCKS_ADDRESS", *socksIn) + // Spoof TZ so the X-Lantern-Time-Zone radiance sends matches the country + // we're impersonating. The API's MaxMind logic overrides the GeoIP-derived + // country with the timezone-derived one when they disagree, so without + // this the bandit thinks "user behind a VPN, return their real country". + if *tz != "" { + os.Setenv("TZ", *tz) + } + + ctx, cancel := context.WithTimeout(context.Background(), *timeout) + defer cancel() + + dataDir, err := os.MkdirTemp("", "qa-bandit-") + if err != nil { + fatal("mktempdir", err) + } + defer os.RemoveAll(dataDir) + + banner(*outboundSocks, *platform, *version, dataDir, *socksIn) + + be, err := backend.NewLocalBackend(ctx, backend.Options{ + DataDir: dataDir, + LogDir: dataDir, + Locale: *locale, + }) + if err != nil { + fatal("NewLocalBackend", err) + } + defer be.Close() + + uid, err := strconv.ParseInt(*userID, 10, 64) + if err != nil { + fatal("parse user-id", err) + } + settings.Set(settings.UserIDKey, uid) + settings.Set(settings.TokenKey, *token) + settings.Set(settings.UserLevelKey, "") + settings.Set(settings.EmailKey, "qa-bandit@local") + // Need both: DeviceIDKey is what common.NewRequestWithHeaders pulls for + // the X-Lantern-DeviceID header (and the user-create body field), while + // DevicesKey is the canonical list used elsewhere. + settings.Set(settings.DeviceIDKey, *deviceID) + settings.Set(settings.DevicesKey, []settings.Device{{ID: *deviceID, Name: *deviceID}}) + + // Subscribe BEFORE Start() so we don't race the first config event. + cfgCh := make(chan *config.Config, 1) + go events.SubscribeOnce(func(evt config.NewConfigEvent) { + cfgCh <- evt.New + }) + + be.Start() + + // Note: we deliberately do NOT bring up the IPC server here. It's there + // for client UIs (Lantern Flutter, etc.) to talk to the backend — we're + // calling backend methods directly, and on macOS its default Unix-socket + // path (/var/run/lantern/lanternd.sock) requires root. + + fmt.Println("[qa-bandit] waiting for first /v1/config-new response (bandit assignment)...") + var cfg *config.Config + select { + case cfg = <-cfgCh: + case <-ctx.Done(): + fatal("waiting for config", ctx.Err()) + } + + dumpAssignment(cfg) + + if !*doConnect { + return + } + + fmt.Println("\n[qa-bandit] connecting VPN with bandit auto-pick...") + if err := be.ConnectVPN(vpn.AutoSelectTag); err != nil { + fmt.Printf("[qa-bandit] ConnectVPN FAILED: %v\n", err) + os.Exit(1) + } + defer be.DisconnectVPN() + + // URLTest needs a few seconds to converge on a working outbound. UDP + // outbounds (hysteria2/wireguard/tuic) fail immediately through our + // bridge — it only does TCP CONNECT. After URLTest marks them dead, + // AutoSelect prefers TCP-based ones (samizdat/reflex/vmess/etc.). + fmt.Printf("[qa-bandit] VPN connected; waiting up to 30s for URLTest to converge, then probing %s through %s...\n", *probeURL, *socksIn) + var ( + body string + dur time.Duration + err2 error + ) + deadline := time.Now().Add(30 * time.Second) + for attempt := 1; ; attempt++ { + body, dur, err2 = probeViaSocks(ctx, *socksIn, *probeURL) + if err2 == nil { + fmt.Printf("[qa-bandit] probe OK in %.2fs (attempt %d) — egress IP: %s\n", dur.Seconds(), attempt, body) + return + } + if time.Now().After(deadline) || ctx.Err() != nil { + fmt.Printf("[qa-bandit] probe FAILED after %d attempts: %v\n", attempt, err2) + os.Exit(1) + } + fmt.Printf("[qa-bandit] attempt %d failed (%v) — retrying in 3s...\n", attempt, err2) + time.Sleep(3 * time.Second) + } +} + +func banner(outboundSocks, platform, version, dataDir, socksIn string) { + fmt.Println() + fmt.Println("======================================================================") + fmt.Println(" qa-bandit — radiance bandit-assignment probe") + fmt.Println("======================================================================") + fmt.Printf(" Platform : %s\n", platform) + fmt.Printf(" App version : %s\n", version) + fmt.Printf(" Time zone : %s\n", os.Getenv("TZ")) + if outboundSocks == "" { + fmt.Println(" Outbound SOCKS5 : (unset — radiance will dial DIRECTLY, NOT through any country)") + } else { + fmt.Printf(" Outbound SOCKS5 : %s (every radiance dial goes here)\n", outboundSocks) + } + fmt.Printf(" Probe inbound SOCKS: %s\n", socksIn) + fmt.Printf(" Data dir : %s\n", dataDir) + fmt.Println() +} + +// dumpAssignment prints the parts of the config response the bandit decided. +func dumpAssignment(cfg *config.Config) { + fmt.Println("=========================== bandit assignment ===========================") + fmt.Printf(" API saw client as : country=%s ip=%s\n", cfg.Country, cfg.IP) + fmt.Printf(" Servers (%d) :\n", len(cfg.Servers)) + for _, s := range cfg.Servers { + fmt.Printf(" %-2s %s / %s\n", s.CountryCode, s.Country, s.City) + } + fmt.Printf(" Outbounds (%d):\n", len(cfg.Options.Outbounds)) + for _, o := range cfg.Options.Outbounds { + loc := cfg.OutboundLocations[o.Tag] + fmt.Printf(" %-12s %s (%s / %s)\n", o.Type, o.Tag, loc.CountryCode, loc.City) + } + if len(cfg.BanditURLOverrides) > 0 { + fmt.Printf(" Bandit callback URLs : %d outbounds tagged with per-proxy callbacks\n", len(cfg.BanditURLOverrides)) + } + if cfg.PollIntervalSeconds > 0 { + fmt.Printf(" Server-suggested poll: %ds\n", cfg.PollIntervalSeconds) + } + if raw, err := json.MarshalIndent(struct { + Country string `json:"country"` + IP string `json:"ip"` + Outbounds int `json:"outbounds"` + Servers int `json:"servers"` + BanditURLOverrides int `json:"bandit_url_overrides"` + PollIntervalSeconds int `json:"poll_interval_seconds"` + OutboundLocations map[string]string `json:"outbound_locations,omitempty"` + }{ + Country: cfg.Country, + IP: cfg.IP, + Outbounds: len(cfg.Options.Outbounds), + Servers: len(cfg.Servers), + BanditURLOverrides: len(cfg.BanditURLOverrides), + PollIntervalSeconds: cfg.PollIntervalSeconds, + OutboundLocations: shortOutboundLocations(cfg), + }, "", " "); err == nil { + fmt.Printf(" Summary JSON :\n%s\n", raw) + } + fmt.Println("==========================================================================") +} + +func shortOutboundLocations(cfg *config.Config) map[string]string { + out := make(map[string]string, len(cfg.OutboundLocations)) + for tag, loc := range cfg.OutboundLocations { + out[tag] = fmt.Sprintf("%s / %s", loc.CountryCode, loc.City) + } + return out +} + +func probeViaSocks(ctx context.Context, socksAddr, target string) (string, time.Duration, error) { + d, err := proxy.SOCKS5("tcp", socksAddr, nil, proxy.Direct) + if err != nil { + return "", 0, fmt.Errorf("building SOCKS5 dialer to %s: %w", socksAddr, err) + } + cd := d.(proxy.ContextDialer) + tr := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return cd.DialContext(ctx, network, addr) + }, + } + defer tr.CloseIdleConnections() + client := &http.Client{Transport: tr, Timeout: 30 * time.Second} + + parsed, err := url.Parse(target) + if err != nil { + return "", 0, fmt.Errorf("parsing %q: %w", target, err) + } + if parsed.Scheme == "" { + return "", 0, fmt.Errorf("probe URL must include scheme (https://...): %q", target) + } + + t0 := time.Now() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil) + if err != nil { + return "", 0, err + } + resp, err := client.Do(req) + if err != nil { + return "", time.Since(t0), err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", time.Since(t0), fmt.Errorf("probe returned status %d", resp.StatusCode) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil { + return "", time.Since(t0), err + } + return string(body), time.Since(t0), nil +} + +func fatal(stage string, err error) { + slog.Error(stage, "error", err) + fmt.Fprintf(os.Stderr, "[qa-bandit] FAILED at %s: %v\n", stage, err) + os.Exit(1) +} + +// Compile-time check that common.Platform is a var (not const) — see +// common/platform.go. If this stops compiling, the override env var +// (RADIANCE_PLATFORM, set in main()) won't take effect. +var _ = func() string { return common.Platform } diff --git a/common/constants.go b/common/constants.go index d38cbd0c..6c2b1ca6 100644 --- a/common/constants.go +++ b/common/constants.go @@ -2,22 +2,18 @@ package common import ( "time" + + "github.com/getlantern/radiance/common/env" ) // Version is the application version, injected at build time via ldflags: // // -X 'github.com/getlantern/radiance/common.Version=x.y.z' -// -// Can also be overridden at runtime via the RADIANCE_VERSION environment variable. var Version = "dev" const ( Name = "lantern" - // filenames - LogFileName = "lantern.log" - ConfigFileName = "config.json" - ServersFileName = "servers.json" DefaultHTTPTimeout = (60 * time.Second) // API URLs @@ -27,8 +23,13 @@ const ( StageBaseURL = "https://api.staging.iantem.io/v1" ) +func GetVersion() string { + if v := env.GetString(env.AppVersion); v != "" { + return v + } + return Version +} -// GetProServerURL returns the pro server URL based on the current environment. func GetProServerURL() string { if Stage() { return StageProServerURL @@ -36,7 +37,6 @@ func GetProServerURL() string { return ProServerURL } -// GetBaseURL returns the auth/user base URL based on the current environment. func GetBaseURL() string { if Stage() { return StageBaseURL diff --git a/common/deviceid/deviceid_nonwindows.go b/common/deviceid/deviceid_nonwindows.go index 3e2fb8c0..03e4df5e 100644 --- a/common/deviceid/deviceid_nonwindows.go +++ b/common/deviceid/deviceid_nonwindows.go @@ -9,41 +9,63 @@ import ( "path/filepath" "github.com/google/uuid" + + "github.com/getlantern/radiance/common/atomicfile" + "github.com/getlantern/radiance/common/fileperm" ) // Get returns a unique identifier for this device. The identifier is a random UUID that's stored on -// disk at $HOME/.lanternsecrets/.deviceid. If unable to read/write to that location, this defaults to the +// disk at {path}/.lanternsecrets/.deviceid. If unable to read/write to that location, this defaults to the // old-style device ID derived from MAC address. -func Get() string { - home, err := os.UserHomeDir() - if err != nil { - slog.Error("Could not get home dir", "error", err) - return OldStyleDeviceID() - } - path := filepath.Join(home, ".lanternsecrets") - err = os.Mkdir(path, 0o755) +func Get(path string) string { + path = filepath.Join(path, ".lanternsecrets") + err := os.Mkdir(path, 0o755) if err != nil && !os.IsExist(err) { slog.Error("Unable to create folder to store deviceID, defaulting to old-style device ID", "error", err) return OldStyleDeviceID() } filename := filepath.Join(path, ".deviceid") - existing, err := os.ReadFile(filename) + existing, err := atomicfile.ReadFile(filename) + if err == nil { + return string(existing) + } + + if migrated, ok := migrateLegacyDeviceID(filename); ok { + return migrated + } + + slog.Debug("Storing new deviceID") + _deviceID, err := uuid.NewRandom() if err != nil { - slog.Debug("Storing new deviceID") - _deviceID, err := uuid.NewRandom() - if err != nil { - slog.Error("Error generating new deviceID, defaulting to old-style device ID", "error", err) - return OldStyleDeviceID() - } - deviceID := _deviceID.String() - err = os.WriteFile(filename, []byte(deviceID), 0o644) - if err != nil { - slog.Error("Error storing new deviceID, defaulting to old-style device ID", "error", err) - return OldStyleDeviceID() - } - return deviceID - } - - return string(existing) + slog.Error("Error generating new deviceID, defaulting to old-style device ID", "error", err) + return OldStyleDeviceID() + } + deviceID := _deviceID.String() + if err := atomicfile.WriteFile(filename, []byte(deviceID), fileperm.File); err != nil { + slog.Error("Error storing new deviceID, defaulting to old-style device ID", "error", err) + return OldStyleDeviceID() + } + return deviceID +} + +// migrateLegacyDeviceID copies a device ID from the pre-refactor location ($HOME/.lanternsecrets/.deviceid) +// to dst, returning the migrated ID on success. The legacy file is left in place. +// TODO(2026-04-20): remove this migration code after a few releases. +func migrateLegacyDeviceID(dst string) (string, bool) { + home, err := os.UserHomeDir() + if err != nil { + return "", false + } + legacy := filepath.Join(home, ".lanternsecrets", ".deviceid") + contents, err := atomicfile.ReadFile(legacy) + if err != nil { + return "", false + } + if err := atomicfile.WriteFile(dst, contents, fileperm.File); err != nil { + slog.Warn("Failed to migrate legacy deviceID", "error", err) + return "", false + } + slog.Info("Migrated legacy deviceID", "from", legacy, "to", dst) + return string(contents), true } diff --git a/common/deviceid/deviceid_test.go b/common/deviceid/deviceid_test.go index d2d98c5c..2a916b9d 100644 --- a/common/deviceid/deviceid_test.go +++ b/common/deviceid/deviceid_test.go @@ -1,14 +1,47 @@ package deviceid import ( + "os" + "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/require" ) func TestGet(t *testing.T) { - id1 := Get() + tmp := t.TempDir() + t.Setenv("HOME", tmp) // isolate from any real legacy deviceID on the dev machine + id1 := Get(tmp) require.True(t, len(id1) > 8) - id2 := Get() + id2 := Get(tmp) require.Equal(t, id1, id2) } + +func TestMigrateLegacyDeviceID(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("migration is non-windows only") + } + + home := t.TempDir() + t.Setenv("HOME", home) + legacyDir := filepath.Join(home, ".lanternsecrets") + require.NoError(t, os.Mkdir(legacyDir, 0o755)) + legacyID := "legacy-device-id-12345" + require.NoError(t, os.WriteFile(filepath.Join(legacyDir, ".deviceid"), []byte(legacyID), 0o644)) + + data := t.TempDir() + require.Equal(t, legacyID, Get(data), "should return the migrated legacy ID") + + newFile := filepath.Join(data, ".lanternsecrets", ".deviceid") + contents, err := os.ReadFile(newFile) + require.NoError(t, err) + require.Equal(t, legacyID, string(contents), "legacy ID should be copied to new location") + + // Legacy file should remain in place. + _, err = os.Stat(filepath.Join(legacyDir, ".deviceid")) + require.NoError(t, err, "legacy file should not be deleted") + + // Second call reads from the new location and returns the same ID. + require.Equal(t, legacyID, Get(data)) +} diff --git a/common/deviceid/deviceid_windows.go b/common/deviceid/deviceid_windows.go index 84f578b9..65fa6b16 100644 --- a/common/deviceid/deviceid_windows.go +++ b/common/deviceid/deviceid_windows.go @@ -11,18 +11,19 @@ import ( ) const ( - keyPath = `Sofware\\Lantern` + keyPath = `Software\Lantern` ) // Get returns a unique identifier for this device. The identifier is a random UUID that's stored in the registry // at HKEY_CURRENT_USERS\Software\Lantern\deviceid. If unable to read/write to the registry, this defaults to the // old-style device ID derived from MAC address. -func Get() string { +func Get(_ string) string { key, _, err := registry.CreateKey(registry.CURRENT_USER, keyPath, registry.QUERY_VALUE|registry.SET_VALUE|registry.WRITE) if err != nil { - slog.Error("Unable to create registry entry to store deviceID, defaulting to old-style device ID: %v", "error", err) + slog.Error("Unable to create registry entry to store deviceID, defaulting to old-style device ID", "error", err) return OldStyleDeviceID() } + defer key.Close() existing, _, err := key.GetStringValue("deviceid") if err != nil { @@ -39,7 +40,7 @@ func Get() string { deviceID := _deviceID.String() err = key.SetStringValue("deviceid", deviceID) if err != nil { - slog.Error("Error storing new deviceID, defaulting to old-style device IDL", "error", err) + slog.Error("Error storing new deviceID, defaulting to old-style device ID", "error", err) return OldStyleDeviceID() } return deviceID diff --git a/common/env/env.go b/common/env/env.go index 0364962e..a65f01e5 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -6,47 +6,50 @@ import ( "errors" "io/fs" "log/slog" + "maps" "os" "strconv" "strings" + "sync" "testing" - - "github.com/getlantern/radiance/internal" ) -type Key = string +type _key string -const ( - LogLevel Key = "RADIANCE_LOG_LEVEL" - LogPath Key = "RADIANCE_LOG_PATH" - DataPath Key = "RADIANCE_DATA_PATH" - DisableFetch Key = "RADIANCE_DISABLE_FETCH_CONFIG" - PrintCurl Key = "RADIANCE_PRINT_CURL" - DisableStdout Key = "RADIANCE_DISABLE_STDOUT_LOG" - ENV Key = "RADIANCE_ENV" - UseSocks Key = "RADIANCE_USE_SOCKS_PROXY" - SocksAddress Key = "RADIANCE_SOCKS_ADDRESS" - AppVersion Key = "RADIANCE_VERSION" +var ( + LogLevel _key = "RADIANCE_LOG_LEVEL" + LogPath _key = "RADIANCE_LOG_PATH" + DataPath _key = "RADIANCE_DATA_PATH" + DisableFetch _key = "RADIANCE_DISABLE_FETCH_CONFIG" + PrintCurl _key = "RADIANCE_PRINT_CURL" + DisableStdout _key = "RADIANCE_DISABLE_STDOUT_LOG" + ENV _key = "RADIANCE_ENV" + UseSocks _key = "RADIANCE_USE_SOCKS_PROXY" + SocksAddress _key = "RADIANCE_SOCKS_ADDRESS" + // OutboundSocksAddress, when set to host:port of a SOCKS5 server, routes + // every outbound connection that radiance opens (kindling HTTP client, + // sing-box outbound tunnel dials, the bypass dialer) through that server. + // Distinct from SocksAddress, which sets up an inbound listener for other + // apps to use radiance as a SOCKS proxy. Intended for censorship- + // circumvention QA — point it at a SOCKS server that egresses through a + // residential proxy in the country we want to simulate. + OutboundSocksAddress _key = "RADIANCE_OUTBOUND_SOCKS_ADDRESS" + // Platform overrides common.Platform for QA scenarios that want to + // impersonate a different OS (e.g. test the Android bandit path from a + // Linux/macOS process). Honored in common.Init(). + Platform _key = "RADIANCE_PLATFORM" + Country _key = "RADIANCE_COUNTRY" + FeatureOverrides _key = "RADIANCE_FEATURE_OVERRIDES" + AppVersion _key = "RADIANCE_VERSION" - Testing Key = "RADIANCE_TESTING" -) + Testing _key = "RADIANCE_TESTING" -var ( - keys = []Key{ - LogLevel, - LogPath, - DataPath, - DisableFetch, - PrintCurl, - DisableStdout, - SocksAddress, - UseSocks, - ENV, - AppVersion, - } - envVars = map[string]any{} + mu sync.RWMutex + dotenv = map[string]string{} ) +func (k _key) String() string { return string(k) } + func init() { buf, err := os.ReadFile(".env") if err != nil && !errors.Is(err, fs.ErrNotExist) { @@ -63,65 +66,73 @@ func init() { if len(parts) == 2 { key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) - parseAndSet(key, value) + dotenv[key] = value } } } - - // Check for environment variables and populate envVars, overriding any values from the .env file - for _, key := range keys { - if value, exists := os.LookupEnv(key); exists { - parseAndSet(key, value) - } - } if testing.Testing() { - envVars[Testing] = true - envVars[LogLevel] = "DISABLE" - slog.SetLogLoggerLevel(internal.Disable) + dotenv[Testing.String()] = "true" + dotenv[LogLevel.String()] = "disable" } } -// Get retrieves the value associated with the given key and attempts to cast it to type T. If the -// key does not exist or the type does not match, it returns the zero value of T and false. -func Get[T any](key Key) (T, bool) { - if value, exists := envVars[key]; exists { - if v, ok := value.(T); ok { - return v, true - } +// Get returns the value for key. OS env takes precedence over .env / runtime +// Set values (matching the package docstring); dotenv is the fallback. +func Get(key _key) (string, bool) { + if value, exists := os.LookupEnv(key.String()); exists { + return value, true + } + mu.RLock() + value, exists := dotenv[key.String()] + mu.RUnlock() + if exists { + return value, true } - var zero T - return zero, false + return "", false } -// SetStagingEnv sets the environment to staging if it has not already been set. -// This is used for testing that need to interact with staging services, -func SetStagingEnv() { - slog.Info("setting environment to staging for testing") - envVars[ENV] = "staging" - envVars[PrintCurl] = true +// Set writes a key to the in-memory dotenv map. If the same key is set in +// the OS env, Get still returns the OS value — shell env wins. +func Set(key string, value string) { + mu.Lock() + dotenv[key] = value + mu.Unlock() } -// alwaysString is the set of keys whose values must be stored as strings even if they -// look numeric or boolean (e.g. RADIANCE_VERSION="9" should remain a string). -var alwaysString = map[Key]bool{ - AppVersion: true, +// GetAll returns a copy of the in-memory dotenv map. +func GetAll() map[string]string { + mu.RLock() + defer mu.RUnlock() + m := make(map[string]string, len(dotenv)) + maps.Copy(m, dotenv) + return m } -func parseAndSet(key, value string) { - if alwaysString[key] { - envVars[key] = value - return - } - // Attempt to parse as a boolean - if b, err := strconv.ParseBool(value); err == nil { - envVars[key] = b - return +func GetString(key _key) string { + value, _ := Get(key) + return value +} + +func GetBool(key _key) bool { + value, exists := Get(key) + if !exists { + return false } - // Attempt to parse as an integer - if i, err := strconv.Atoi(value); err == nil { - envVars[key] = i - return + v, _ := strconv.ParseBool(value) + return v +} + +func GetInt(key _key) int { + value, exists := Get(key) + if !exists { + return 0 } - // Otherwise, store as a string - envVars[key] = value + v, _ := strconv.Atoi(value) + return v +} + +func SetStagingEnv() { + slog.Info("setting environment to staging for testing") + Set(ENV.String(), "staging") + Set(PrintCurl.String(), "true") } diff --git a/common/env/env_test.go b/common/env/env_test.go new file mode 100644 index 00000000..7d72181f --- /dev/null +++ b/common/env/env_test.go @@ -0,0 +1,61 @@ +package env + +import ( + "os" + "testing" +) + +// Guards the precedence promised by the package docstring: OS env > dotenv. +func TestGet_OSEnvWinsOverDotenv(t *testing.T) { + saved := cloneDotenv() + defer restoreDotenv(saved) + + // Test-only key — don't mutate real RADIANCE_* vars that sibling + // packages may read during parallel test execution. + const testKey = "RADIANCE_UNIT_TEST_OS_WINS_KEY_DOES_NOT_EXIST" + t.Setenv(testKey, "prod") + Set(testKey, "staging") + + got, ok := Get(_key(testKey)) + if !ok { + t.Fatal("Get returned ok=false") + } + if got != "prod" { + t.Fatalf("OS env should win; got %q, want %q", got, "prod") + } +} + +// Other half of the contract: dotenv is still consulted when OS env is unset, +// so runtime instrumentation like SetStagingEnv keeps working. +func TestGet_DotenvFallsBackWhenOSUnset(t *testing.T) { + saved := cloneDotenv() + defer restoreDotenv(saved) + + const testKey = "RADIANCE_UNIT_TEST_KEY_DOES_NOT_EXIST" + _ = os.Unsetenv(testKey) + + Set(testKey, "from-dotenv") + got, ok := Get(_key(testKey)) + if !ok { + t.Fatal("Get returned ok=false when only dotenv had the value") + } + if got != "from-dotenv" { + t.Fatalf("dotenv should be used when OS env unset; got %q", got) + } +} + +func cloneDotenv() map[string]string { + mu.RLock() + defer mu.RUnlock() + out := make(map[string]string, len(dotenv)) + for k, v := range dotenv { + out[k] = v + } + return out +} + +func restoreDotenv(m map[string]string) { + mu.Lock() + defer mu.Unlock() + dotenv = m +} diff --git a/common/errors.go b/common/errors.go index 3793d4a7..6e21d921 100644 --- a/common/errors.go +++ b/common/errors.go @@ -2,6 +2,6 @@ package common import "errors" -// ErrNotImplemented is returned by functions which have not yet been implemented. The existence of -// this error is temporary; this will go away when the API stabilized. +// ErrNotImplemented is returned by functions that have not yet been implemented. +// It is temporary and will be removed once the API stabilizes. var ErrNotImplemented = errors.New("not yet implemented") diff --git a/common/fileperm/fileperm_mobile.go b/common/fileperm/fileperm_mobile.go new file mode 100644 index 00000000..7636d14e --- /dev/null +++ b/common/fileperm/fileperm_mobile.go @@ -0,0 +1,8 @@ +//go:build android || ios || (darwin && !standalone) + +// Package fileperm provides the permission bits used when creating files owned by radiance. +package fileperm + +import "os" + +const File os.FileMode = 0o644 diff --git a/common/fileperm/fileperm_nonmobile.go b/common/fileperm/fileperm_nonmobile.go new file mode 100644 index 00000000..b279d730 --- /dev/null +++ b/common/fileperm/fileperm_nonmobile.go @@ -0,0 +1,8 @@ +//go:build (!android && !ios && !darwin) || (darwin && standalone) + +// Package fileperm provides the permission bits used when creating files owned by radiance. +package fileperm + +import "os" + +const File os.FileMode = 0o644 // temporarily set to 644 to during developement, will be set to 600 for production builds. diff --git a/common/gostack.go b/common/gostack.go deleted file mode 100644 index 52273b30..00000000 --- a/common/gostack.go +++ /dev/null @@ -1,38 +0,0 @@ -package common - -import ( - "fmt" - "log/slog" - "runtime/debug" -) - -// RunOffCgoStack executes fn on a new goroutine and returns its result. -// A new goroutine is spawned per call; there is no persistent worker. -// -// Gomobile-exported functions run on a CGo callback stack whose memory isn't -// covered by the GC heap bitmap. When the gomobile-generated wrapper copies Go -// pointer-containing return values to the C thread stack, bulkBarrierPreWrite -// can panic. Running the body on a real Go goroutine avoids this entirely. -// -// If fn panics, the panic is recovered and a zero value + error are returned -// instead of blocking the caller forever. -func RunOffCgoStack[T any](fn func() (T, error)) (T, error) { - type result struct { - val T - err error - } - ch := make(chan result, 1) - go func() { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in RunOffCgoStack", "panic", r, "stack", string(debug.Stack())) - var zero T - ch <- result{val: zero, err: fmt.Errorf("panic: %v", r)} - } - }() - v, err := fn() - ch <- result{val: v, err: err} - }() - r := <-ch - return r.val, r.err -} diff --git a/backend/headers.go b/common/headers.go similarity index 63% rename from backend/headers.go rename to common/headers.go index d1e5a99c..80dcee8a 100644 --- a/backend/headers.go +++ b/common/headers.go @@ -1,4 +1,4 @@ -package backend +package common import ( "context" @@ -6,31 +6,25 @@ import ( "io" "math/big" "net/http" - "strconv" "sync/atomic" "time" "github.com/getlantern/timezone" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" ) -// clientIP holds the detected public IP address, set once at startup. -var clientIP atomic.Value // string +// publicIP holds the detected public IP address, set once at startup. +var publicIP atomic.Value // string -// SetClientIP stores the detected public IP for inclusion in API requests. -func SetClientIP(ip string) { - clientIP.Store(ip) +func init() { + publicIP.Store("") // ensure publicIP is type string } -// GetClientIP returns the detected public IP, or empty string if not yet detected. -func GetClientIP() string { - v := clientIP.Load() - if v == nil { - return "" - } - return v.(string) +// SetPublicIP stores the detected public IP for inclusion in API requests. It should only be called +// once at startup after successfully detecting the public IP. +func SetPublicIP(ip string) { + publicIP.Store(ip) } const ( @@ -62,36 +56,21 @@ func NewRequestWithHeaders(ctx context.Context, method, url string, body io.Read // based on consistent packet lengths. req.Header.Add(RandomNoiseHeader, randomizedString()) - req.Header.Set(AppVersionHeader, common.Version) - req.Header.Set(VersionHeader, common.Version) - req.Header.Set(UserIDHeader, strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10)) - req.Header.Set(PlatformHeader, common.Platform) - req.Header.Set(AppNameHeader, common.Name) + req.Header.Set(AppVersionHeader, GetVersion()) + req.Header.Set(VersionHeader, GetVersion()) + req.Header.Set(UserIDHeader, settings.GetString(settings.UserIDKey)) + req.Header.Set(PlatformHeader, Platform) + req.Header.Set(AppNameHeader, Name) req.Header.Set(DeviceIDHeader, settings.GetString(settings.DeviceIDKey)) if tz, err := timezone.IANANameForTime(time.Now()); err == nil { req.Header.Set(TimeZoneHeader, tz) } - if ip := GetClientIP(); ip != "" { + if ip := publicIP.Load().(string); ip != "" { req.Header.Set(ClientIPHeader, ip) } return req, nil } -// NewIssueRequest creates a new HTTP request with the required headers for issue reporting. -func NewIssueRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { - req, err := NewRequestWithHeaders(ctx, method, url, body) - if err != nil { - return nil, err - } - - req.Header.Set("content-type", "application/x-protobuf") - - // data caps - req.Header.Set(SupportedDataCapsHeader, "monthly,weekly,daily") - - return req, nil -} - // randomizedString returns a random string to avoid consistent packet lengths censors // may use to detect Lantern. func randomizedString() string { diff --git a/common/init.go b/common/init.go index 7ef8fbcb..7d377d75 100644 --- a/common/init.go +++ b/common/init.go @@ -3,25 +3,19 @@ package common import ( "fmt" - "io" "log/slog" "os" "path/filepath" - "runtime" "runtime/debug" "strings" "sync/atomic" - "time" - "unicode" - "unicode/utf8" - - "github.com/getlantern/appdir" - "gopkg.in/natefinch/lumberjack.v2" "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) var ( @@ -29,89 +23,83 @@ var ( ) func Env() string { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) - return e + return strings.ToLower(env.GetString(env.ENV)) } // Prod returns true if the application is running in production environment. // Treating ENV == "" as production is intentional: if RADIANCE_ENV is unset, // we default to production mode to ensure the application runs with safe, non-debug settings. func Prod() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "production" || e == "prod" || e == "" } // Dev returns true if the application is running in development environment. func Dev() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "development" || e == "dev" } // Stage returns true if the application is running in staging environment. func Stage() bool { - e, _ := env.Get[string](env.ENV) - e = strings.ToLower(e) + e := Env() return e == "stage" || e == "staging" } +func init() { + if env.GetBool(env.Testing) { + slog.SetDefault(log.NoOpLogger()) + slog.SetLogLoggerLevel(log.Disable) + } +} + // Init initializes the common components of the application. This includes setting up the directories // for data and logs, initializing the logger, and setting up reporting. -func Init(dataDir, logDir, logLevel string) error { +func Init(dataDir, logDir, logLevel string) (err error) { slog.Info("Initializing common package") - return initialize(dataDir, logDir, logLevel, false) -} - -// InitReadOnly locates the settings file in provided directory and initializes the common components -// in read-only mode using the necessary settings from the settings file. This is used in contexts -// where settings should not be modified, such as in the IPC server or other auxiliary processes. -func InitReadOnly(dataDir, logDir, logLevel string) error { - slog.Info("Initializing in read-only") - return initialize(dataDir, logDir, logLevel, true) -} - -func initialize(dataDir, logDir, logLevel string, readonly bool) error { if initialized.Swap(true) { return nil } + defer func() { + if err != nil { + initialized.Store(false) + } + }() - if v, ok := env.Get[string](env.AppVersion); ok && v != "" { + if v, ok := env.Get(env.AppVersion); ok && v != "" { Version = v slog.Info("Version overridden via RADIANCE_VERSION", "version", Version) } - reporting.Init(Version) + if v, ok := env.Get(env.Platform); ok && v != "" { + Platform = v + slog.Info("Platform overridden via RADIANCE_PLATFORM", "platform", Platform) + } + reporting.Init(GetVersion()) data, logs, err := setupDirectories(dataDir, logDir) if err != nil { return fmt.Errorf("failed to setup directories: %w", err) } - if readonly { - // in read-only mode, favor settings from the settings file if given parameters are empty - if logDir == "" && settings.GetString(settings.LogPathKey) != "" { - logs = settings.GetString(settings.LogPathKey) - } - if settings.GetString(settings.LogLevelKey) != "" { - logLevel = settings.GetString(settings.LogLevelKey) - } - } - err = initLogger(filepath.Join(logs, LogFileName), logLevel) - if err != nil { - slog.Error("Error initializing logger", "error", err) - return fmt.Errorf("initialize log: %w", err) + + if err = settings.InitSettings(data); err != nil { + return fmt.Errorf("failed to initialize settings: %w", err) } - if readonly { - settings.SetReadOnly(true) - if err := settings.StartWatching(); err != nil { - return fmt.Errorf("start watching settings file: %w", err) - } - } else { - settings.Set(settings.DataPathKey, data) - settings.Set(settings.LogPathKey, logs) + settings.Set(settings.DataPathKey, data) + settings.Set(settings.LogPathKey, logs) + // env override wins; otherwise preserve any persisted value; otherwise seed from the arg. + if v := env.GetString(env.LogLevel); v != "" { + settings.Set(settings.LogLevelKey, v) + } else if !settings.Exists(settings.LogLevelKey) { settings.Set(settings.LogLevelKey, logLevel) } + logger := log.NewLogger(log.Config{ + LogPath: filepath.Join(logs, internal.LogFileName), + Level: logLevel, + Prod: Prod(), + }) + slog.SetDefault(logger) + slog.Info("Using data and log directories", "dataDir", data, "logDir", logs) createCrashReporter() if Dev() { @@ -137,8 +125,8 @@ func logModuleInfo() { } func createCrashReporter() { - crashFilePath := filepath.Join(settings.GetString(settings.LogPathKey), "lantern_crash.log") - f, err := os.OpenFile(crashFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + crashFilePath := filepath.Join(settings.GetString(settings.LogPathKey), internal.CrashLogFileName) + f, err := os.OpenFile(crashFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, fileperm.File) if err != nil { slog.Error("Failed to open crash log file", "error", err) } else { @@ -148,151 +136,23 @@ func createCrashReporter() { } } -// initLogger reconfigures the default slog.Logger to write to a file and stdout and sets the log level. -// The log level is determined, first by the environment variable if set and valid, then by the provided level. -// If both are invalid and/or not set, it defaults to "info". -func initLogger(logPath, level string) error { - if elevel, hasLevel := env.Get[string](env.LogLevel); hasLevel { - level = elevel - } - var lvl slog.Level - if level != "" { - var err error - lvl, err = internal.ParseLogLevel(level) - if err != nil { - slog.Warn("Failed to parse log level", "error", err) - } else { - slog.SetLogLoggerLevel(lvl) - } - } - if lvl == internal.Disable { - return nil - } - - // lumberjack will create the log file if it does not exist with permissions 0600 otherwise it - // carries over the existing permissions. So we create it here with 0644 so we don't need root/admin - // privileges or chown/chmod to read it. - f, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - slog.Warn("Failed to pre-create log file", "error", err, "path", logPath) - } else { - f.Close() - } - - logRotator := &lumberjack.Logger{ - Filename: logPath, // Log file path - MaxSize: 25, // Rotate log when it reaches 25 MB - MaxBackups: 2, // Keep up to 2 rotated log files - MaxAge: 30, // Retain old log files for up to 30 days - Compress: Prod(), // Compress rotated log files - } - - loggingToStdOut := true - var logWriter io.Writer - if noStdout, _ := env.Get[bool](env.DisableStdout); noStdout { - logWriter = logRotator - loggingToStdOut = false - } else if isWindowsProd() { - // For some reason, logging to both stdout and a file on Windows - // causes issues with some Windows services where the logs - // do not get written to the file. So in prod mode on Windows, - // we log to file only. See: - // https://www.reddit.com/r/golang/comments/1fpo3cg/golang_windows_service_cannot_write_log_files/ - logWriter = logRotator - loggingToStdOut = false - } else { - logWriter = io.MultiWriter(os.Stdout, logRotator) - } - runtime.AddCleanup(&logWriter, func(f *os.File) { - f.Close() - }, f) - logger := slog.New(slog.NewTextHandler(logWriter, &slog.HandlerOptions{ - AddSource: true, - Level: lvl, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - switch a.Key { - case slog.TimeKey: - if t, ok := a.Value.Any().(time.Time); ok { - a.Value = slog.StringValue(t.UTC().Format("2006-01-02 15:04:05.000 UTC")) - } - return a - case slog.SourceKey: - source, ok := a.Value.Any().(*slog.Source) - if !ok { - return a - } - // remove github.com/ to get pkg name - var service, fn string - fields := strings.SplitN(source.Function, "/", 4) - switch len(fields) { - case 0, 1, 2: - file := filepath.Base(source.File) - a.Value = slog.StringValue(fmt.Sprintf("%s:%d", file, source.Line)) - return a - case 3: - pf := strings.SplitN(fields[2], ".", 2) - service, fn = pf[0], pf[1] - default: - service = fields[2] - fn = strings.SplitN(fields[3], ".", 2)[1] - } - - _, file, fnd := strings.Cut(source.File, service+"/") - if !fnd { - file = filepath.Base(source.File) - } - src := slog.GroupValue( - slog.String("func", fn), - slog.String("file", fmt.Sprintf("%s:%d", file, source.Line)), - ) - a.Value = slog.GroupValue( - slog.String("service", service), - slog.Any("source", src), - ) - a.Key = "" - case slog.LevelKey: - // format the log level to account for the custom levels defined in internal/util.go, i.e. trace - // otherwise, slog will print as "DEBUG-4" (trace) or similar - level := a.Value.Any().(slog.Level) - a.Value = slog.StringValue(internal.FormatLogLevel(level)) - } - return a - }, - })) - if !loggingToStdOut { - if IsWindows() { - fmt.Printf("Logging to file only on Windows prod -- run with RADIANCE_ENV=dev to enable stdout path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } else { - fmt.Printf("Logging to file only -- RADIANCE_DISABLE_STDOUT_LOG is set path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } - } else { - fmt.Printf("Logging to file and stdout path: %s, level: %s\n", logPath, internal.FormatLogLevel(lvl)) - } - slog.SetDefault(logger) - return nil -} - -func isWindowsProd() bool { - if !IsWindows() { - return false - } - return !Dev() -} - // setupDirectories creates the data and logs directories, and needed subdirectories if they do // not exist. If data or logs are the empty string, it will use the user's config directory retrieved // from the OS. func setupDirectories(data, logs string) (dataDir, logDir string, err error) { - if d, ok := env.Get[string](env.DataPath); ok { - data = d + if path := env.GetString(env.DataPath); path != "" { + data = path } else if data == "" { - data = outDir("data") + data = internal.DefaultDataPath() } - if l, ok := env.Get[string](env.LogPath); ok { - logs = l + if path := env.GetString(env.LogPath); path != "" { + logs = path } else if logs == "" { - logs = outDir("logs") + logs = internal.DefaultLogPath() } + // ensure the data and logs directories end with the correct suffix + data = maybeAddSuffix(data, "data") + logs = maybeAddSuffix(logs, "logs") data, _ = filepath.Abs(data) logs, _ = filepath.Abs(logs) for _, path := range []string{data, logs} { @@ -300,44 +160,11 @@ func setupDirectories(data, logs string) (dataDir, logDir string, err error) { return data, logs, fmt.Errorf("failed to create directory %s: %w", path, err) } } - if err := settings.InitSettings(data); err != nil { - return data, logs, fmt.Errorf("failed to initialize settings: %w", err) - } return data, logs, nil } -func outDir(subdir string) string { - var data string - var name string - if IsWindows() || IsMacOS() { - name = capitalizeFirstLetter(Name) - } else { - name = Name - } - if IsWindows() { - publicDir := os.Getenv("Public") - data = filepath.Join(publicDir, name) - } else { - data = appdir.General(name) - } - return maybeAddSuffix(data, subdir) -} - -func capitalizeFirstLetter(s string) string { - if s == "" { - return "" - } - - r, size := utf8.DecodeRuneInString(s) - if r == utf8.RuneError { // Handle invalid UTF-8 sequences - return s // Or handle error as needed - } - - return string(unicode.ToUpper(r)) + s[size:] -} - func maybeAddSuffix(path, suffix string) string { - if filepath.Base(path) != suffix { + if !strings.EqualFold(filepath.Base(path), suffix) { path = filepath.Join(path, suffix) } return path diff --git a/common/platform.go b/common/platform.go index e6ab3198..020a1624 100644 --- a/common/platform.go +++ b/common/platform.go @@ -2,7 +2,11 @@ package common import "runtime" -const Platform = runtime.GOOS +// Platform is the runtime platform string, defaulting to runtime.GOOS but +// overridable via RADIANCE_PLATFORM (handled in common.Init) for QA scenarios +// that need to impersonate a different platform — e.g. running radiance as a +// Go process on macOS while making the API see us as an Android client. +var Platform = runtime.GOOS func IsAndroid() bool { return Platform == "android" diff --git a/common/settings/settings.go b/common/settings/settings.go index 96d4d837..7e9d29f0 100644 --- a/common/settings/settings.go +++ b/common/settings/settings.go @@ -1,3 +1,4 @@ +// Package settings provides a simple interface for storing and retrieving user settings. package settings import ( @@ -7,9 +8,9 @@ import ( "log/slog" "os" "path/filepath" + "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/knadh/koanf/parsers/json" @@ -17,38 +18,57 @@ import ( "github.com/knadh/koanf/v2" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/common/fileperm" ) -// Keys for various settings. +type _key string + const ( - CountryCodeKey = "country_code" - LocaleKey = "locale" - DeviceIDKey = "device_id" - DataPathKey = "data_path" - LogPathKey = "log_path" - EmailKey = "email" - UserLevelKey = "user_level" - TokenKey = "token" - JwtTokenKey = "jwt_token" - UserIDKey = "user_id" - DevicesKey = "devices" - LogLevelKey = "log_level" - LoginResponseKey = "login_response" - SmartRoutingKey = "smart_routing" - AdBlockKey = "ad_block" - DataCapUsageKey = "datacap_usage" - filePathKey = "file_path" - - settingsFileName = "local.json" + // Keys for various settings. + // General settings keys. + DataPathKey _key = "data_path" // string + LogPathKey _key = "log_path" // string + LogLevelKey _key = "log_level" // string + CountryCodeKey _key = "country_code" // string + LocaleKey _key = "locale" // string + DeviceIDKey _key = "device_id" // string/int + + // Application behavior related keys. + TelemetryKey _key = "telemetry_enabled" // bool + ConfigFetchDisabledKey _key = "config_fetch_disabled" // bool + FeatureOverridesKey _key = "feature_overrides" // string + + // User account related keys. + EmailKey _key = "email" // string + UserIDKey _key = "user_id" // string + UserLevelKey _key = "user_level" // string + TokenKey _key = "token" // string + JwtTokenKey _key = "jwt_token" // string + DevicesKey _key = "devices" // []Device + UserDataKey _key = "user_data" // [account.UserData] + OAuthLoginKey _key = "oauth_login" // bool + OAuthProviderKey _key = "oauth_provider" // string (e.g. "google", "apple", "email") + + // VPN related keys. + SmartRoutingKey _key = "smart_routing" // bool + SplitTunnelKey _key = "split_tunnel" // bool + AdBlockKey _key = "ad_block" // bool + AutoConnectKey _key = "auto_connect" // bool + SelectedServerKey _key = "selected_server" // [servers.Server] Server.Options is not stored + + PreferredLocationKey _key = "preferred_location" // [common.PreferredLocation] + + settingsFileName = "settings.json" ) +var ErrNotExist = errors.New("key does not exist") + +func (k _key) String() string { return string(k) } + type settings struct { k *koanf.Koanf - readOnly atomic.Bool initialized bool - watcher *internal.FileWatcher + filePath string mu sync.Mutex } @@ -56,60 +76,38 @@ var k = &settings{ k: koanf.New("."), } -var ErrReadOnly = errors.New("read-only") +func init() { + // set default values. + k.k.Set(LocaleKey.String(), "fa-IR") + k.k.Set(UserLevelKey.String(), "free") +} -// InitSettings initializes the config for user settings, which can be used by both the tunnel process and -// the main application process to read user preferences like locale. +// InitSettings initializes the config for user settings. func InitSettings(fileDir string) error { k.mu.Lock() defer k.mu.Unlock() if k.initialized { return nil } - if err := initialize(fileDir); err != nil { - return fmt.Errorf("initializing settings: %w", err) - } - k.initialized = true - return nil -} - -func initialize(fileDir string) error { - k.k = koanf.New(".") if err := os.MkdirAll(fileDir, 0755); err != nil { return fmt.Errorf("failed to create data directory: %v", err) } - filePath := filepath.Join(fileDir, settingsFileName) - switch err := loadSettings(filePath); { + k.filePath = filepath.Join(fileDir, settingsFileName) + switch err := loadSettings(k.filePath); { case errors.Is(err, fs.ErrNotExist): - slog.Warn("settings file not found", "path", filePath) // file may not have been created yet - if err := setDefaults(filePath); err != nil { - return fmt.Errorf("setting default settings: %w", err) - } + slog.Warn("settings file not found", "path", k.filePath) // file may not have been created yet return save() case err != nil: return fmt.Errorf("loading settings: %w", err) } - return nil -} - -func setDefaults(filePath string) error { - // We need to set the file path first because the save function reads it as soon as we set any key. - if err := k.k.Set(filePathKey, filePath); err != nil { - return fmt.Errorf("failed to set file path: %w", err) - } - if err := k.k.Set(LocaleKey, "fa-IR"); err != nil { - return fmt.Errorf("failed to set default locale: %w", err) - } - if err := k.k.Set(UserLevelKey, "free"); err != nil { - return fmt.Errorf("failed to set default user level: %w", err) - } + k.initialized = true return nil } func loadSettings(path string) error { contents, err := atomicfile.ReadFile(path) if err != nil { - return fmt.Errorf("loading settings (read-only): %w", err) + return fmt.Errorf("loading settings: %w", err) } kk := koanf.New(".") if err := kk.Load(rawbytes.Provider(contents), json.Parser()); err != nil { @@ -119,107 +117,120 @@ func loadSettings(path string) error { return nil } -func SetReadOnly(readOnly bool) { - k.readOnly.Store(readOnly) +func Get(key _key) any { + return k.k.Get(key.String()) } -func StartWatching() error { - k.mu.Lock() - defer k.mu.Unlock() - if !k.initialized { - return errors.New("settings not initialized") +func GetString(key _key) string { + // JSON round-trip turns all numbers into float64 and since koanf uses Sprintf("%v") for string + // conversion, large integers (i.e. userID) get converted to scientific notation (e.g. 3.87286618e+08) + // so we handle float64 separately + value := Get(key) + if value == nil { + return "" } - if k.watcher != nil { - return errors.New("settings file watcher already started") + switch v := value.(type) { + case float64: + return strconv.FormatInt(int64(v), 10) + case string: + return v + default: + return fmt.Sprintf("%v", v) } +} - path := k.k.String(filePathKey) - watcher := internal.NewFileWatcher(path, func() { - if err := loadSettings(path); err != nil { - slog.Error("reloading settings file", "error", err) - } - }) - if err := watcher.Start(); err != nil { - return fmt.Errorf("starting settings file watcher: %w", err) - } - k.watcher = watcher - // reload settings once at start in case there were changes before we started watching - if err := loadSettings(path); err != nil && !errors.Is(err, fs.ErrNotExist) { - return err - } - return nil +func GetBool(key _key) bool { + return k.k.Bool(key.String()) } -// StopWatching stops watching the settings file for changes. This is only relevant in read-only mode. -func StopWatching() { - k.mu.Lock() - defer k.mu.Unlock() - if k.watcher != nil { - k.watcher.Close() - k.watcher = nil - } +func GetInt(key _key) int { + return k.k.Int(key.String()) } -func Get(key string) any { - return k.k.Get(key) +func GetInt64(key _key) int64 { + return k.k.Int64(key.String()) } -func GetString(key string) string { - return k.k.String(key) +func GetFloat64(key _key) float64 { + return k.k.Float64(key.String()) } -func GetBool(key string) bool { - return k.k.Bool(key) +func GetStringSlice(key _key) []string { + return k.k.Strings(key.String()) } -func GetInt(key string) int { - return k.k.Int(key) +func GetDuration(key _key) time.Duration { + return k.k.Duration(key.String()) } -func GetInt64(key string) int64 { - return k.k.Int64(key) +func GetStruct(key _key, out any) error { + return k.k.Unmarshal(key.String(), out) } -func GetFloat64(key string) float64 { - return k.k.Float64(key) +func Exists(key _key) bool { + return k.k.Exists(key.String()) } -func GetStringSlice(key string) []string { - return k.k.Strings(key) +func Set(key _key, value any) error { + err := k.k.Set(key.String(), value) + if err != nil { + return fmt.Errorf("could not set key %s: %w", key, err) + } + return save() } -func GetDuration(key string) time.Duration { - return k.k.Duration(key) +func Clear(key _key) { + k.k.Delete(key.String()) } -func GetStruct(key string, out any) error { - return k.k.Unmarshal(key, out) +type Settings map[_key]any + +func (s Settings) Diff(s2 Settings) Settings { + diff := make(Settings) + for k, v1 := range s { + if v2, ok := s2[k]; !ok || v1 != v2 { + diff[k] = v1 + } + } + return diff } -func Set(key string, value any) error { - if k.readOnly.Load() { - return ErrReadOnly +func GetAll() Settings { + s := make(Settings) + for key, value := range k.k.All() { + s[_key(key)] = value } - err := k.k.Set(key, value) - if err != nil { - return fmt.Errorf("could not set key %s: %w", key, err) + return s +} + +func GetAllFor(keys ..._key) Settings { + if len(keys) == 0 { + return GetAll() + } + s := make(Settings) + for _, key := range keys { + s[key] = k.k.Get(key.String()) + } + return s +} + +// Patch takes a map of settings to update and applies them all at once. +func Patch(updates Settings) error { + for key, value := range updates { + if err := k.k.Set(_key(key).String(), value); err != nil { + return fmt.Errorf("could not set key %s: %w", key, err) + } } return save() } func save() error { - if k.readOnly.Load() { - return ErrReadOnly - } - if GetString(filePathKey) == "" { - return errors.New("settings file path is not set") - } out, err := k.k.Marshal(json.Parser()) if err != nil { return fmt.Errorf("could not marshal koanf file: %w", err) } - err = atomicfile.WriteFile(GetString(filePathKey), out, 0644) + err = atomicfile.WriteFile(k.filePath, out, fileperm.File) if err != nil { return fmt.Errorf("could not write koanf file: %w", err) } @@ -230,14 +241,8 @@ func save() error { func Reset() { k.mu.Lock() defer k.mu.Unlock() - if !k.readOnly.Load() { - if k.watcher != nil { - k.watcher.Close() - k.watcher = nil - } - k.k = koanf.New(".") - k.initialized = false - } + k.k = koanf.New(".") + k.initialized = false } func IsPro() bool { @@ -255,7 +260,3 @@ func Devices() ([]Device, error) { err := GetStruct(DevicesKey, &devices) return devices, err } - -type UserChangeEvent struct { - events.Event -} diff --git a/common/settings/settings_test.go b/common/settings/settings_test.go index 21f16bd2..585205c2 100644 --- a/common/settings/settings_test.go +++ b/common/settings/settings_test.go @@ -5,190 +5,28 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "github.com/getlantern/radiance/common/env" ) func TestInitSettings(t *testing.T) { - t.Run("first run - no config file exists", func(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - // Verify default locale was set - locale := Get(LocaleKey) - if locale != "fa-IR" { - t.Errorf("expected default locale 'fa-IR', got %s", locale) - } - }) - t.Run("existing valid config file", func(t *testing.T) { - // Create a temporary directory tempDir := t.TempDir() + path := filepath.Join(tempDir, settingsFileName) + content := []byte(`{"locale": "en-US", "country_code": "US"}`) + require.NoError(t, os.WriteFile(path, content, 0644), "failed to create test config file") - // Create a valid config file - configPath := filepath.Join(tempDir, "local.json") - configContent := []byte(`{"locale": "en-US", "country_code": "US"}`) - if err := os.WriteFile(configPath, configContent, 0644); err != nil { - t.Fatalf("failed to create test config file: %v", err) - } - - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - // Verify config was loaded - locale := Get(LocaleKey) - if locale != "en-US" { - t.Errorf("expected locale 'en-US', got %s", locale) - } - - countryCode := Get(CountryCodeKey) - if countryCode != "US" { - t.Errorf("expected country_code 'US', got %s", countryCode) - } + require.NoError(t, InitSettings(tempDir), "failed to initialize settings") + assert.Equal(t, "en-US", Get(LocaleKey)) + assert.Equal(t, "US", Get(CountryCodeKey)) }) t.Run("invalid config file", func(t *testing.T) { - // Create a temporary directory - tempDir := t.TempDir() - - // Create an invalid config file - configPath := filepath.Join(tempDir, "local.json") - configContent := []byte(`{invalid json}`) - if err := os.WriteFile(configPath, configContent, 0644); err != nil { - t.Fatalf("failed to create test config file: %v", err) - } - - err := initialize(tempDir) - if err == nil { - t.Fatal("expected error for invalid config file, got nil") - } - }) - - t.Run("non-existent directory", func(t *testing.T) { - // Use a non-existent directory - nonExistentDir := filepath.Join(os.TempDir(), "non-existent-dir-123456789") - - err := initialize(nonExistentDir) - if err != nil { - t.Fatalf("expected no error for non-existent directory (first run), got %v", err) - } - }) -} - -func TestSetStruct(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error initializing settings, got %v", err) - } - - err = Set("testStruct", struct { - Field1 string - Field2 int - }{ - Field1: "value1", - Field2: 42, + path := filepath.Join(t.TempDir(), settingsFileName) + content := []byte(`{invalid json}`) + require.NoError(t, os.WriteFile(path, content, 0644), "failed to create test config file") + require.Error(t, loadSettings(path), "expected error for invalid config file") }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - var result struct { - Field1 string - Field2 int - } - err = GetStruct("testStruct", &result) - if err != nil { - t.Fatalf("expected no error retrieving struct, got %v", err) - } - - if result.Field1 != "value1" || result.Field2 != 42 { - t.Errorf("expected struct {Field1: 'value1', Field2: 42}, got %+v", result) - } - - // Reset koanf state - Reset() - result.Field1 = "" - result.Field2 = 0 - - // At first, the struct should not be present. - err = GetStruct("testStruct", &result) - if err != nil { - t.Fatalf("expected no error retrieving struct, got %v", err) - } - - if result.Field1 != "" || result.Field2 != 0 { - t.Errorf("expected struct {Field1: '', Field2: 0}, got %+v", result) - } - - err = initialize(tempDir) - if err != nil { - t.Fatalf("expected no error re-initializing settings, got %v", err) - } - - var result2 struct { - Field1 string - Field2 int - } - err = GetStruct("testStruct", &result2) - if err != nil { - t.Fatalf("expected no error retrieving struct after re-init, got %v", err) - } - - if result2.Field1 != "value1" || result2.Field2 != 42 { - t.Errorf("expected struct {Field1: 'value1', Field2: 42} after re-init, got %+v", result2) - } -} - -func TestStructSlicePersistence(t *testing.T) { - tempDir := t.TempDir() - err := initialize(tempDir) - if err != nil { - t.Fatalf("expected no error initializing settings, got %v", err) - } - - type Item struct { - Name string - Value int - } - - items := []Item{ - {Name: "item1", Value: 1}, - {Name: "item2", Value: 2}, - } - - err = Set("itemList", items) - if err != nil { - t.Fatalf("expected no error setting struct slice, got %v", err) - } - - var retrievedItems []Item - err = GetStruct("itemList", &retrievedItems) - if err != nil { - t.Fatalf("expected no error retrieving struct slice, got %v", err) - } - - if len(retrievedItems) != 2 || retrievedItems[0].Name != "item1" || retrievedItems[1].Value != 2 { - t.Errorf("retrieved struct slice does not match expected values: %+v", retrievedItems) - } - - retrievedItems = nil - err = initialize(tempDir) - if err != nil { - t.Fatalf("expected no error re-initializing settings, got %v", err) - } - - var retrievedItems2 []Item - err = GetStruct("itemList", &retrievedItems2) - if err != nil { - t.Fatalf("expected no error retrieving struct slice after re-init, got %v", err) - } - - if len(retrievedItems2) != 2 || retrievedItems2[0].Name != "item1" || retrievedItems2[1].Value != 2 { - t.Errorf("retrieved struct slice after re-init does not match expected values: %+v", retrievedItems2) - } } diff --git a/common/types.go b/common/types.go new file mode 100644 index 00000000..fc1db8fb --- /dev/null +++ b/common/types.go @@ -0,0 +1,7 @@ +package common + +import ( + C "github.com/getlantern/common" +) + +type PreferredLocation = C.ServerLocation diff --git a/config/config.go b/config/config.go index f90fa8ba..c5fffd7c 100644 --- a/config/config.go +++ b/config/config.go @@ -10,6 +10,7 @@ import ( "fmt" "io/fs" "log/slog" + "net/http" "os" "path/filepath" "reflect" @@ -27,107 +28,109 @@ import ( box "github.com/getlantern/lantern-box" lbO "github.com/getlantern/lantern-box/option" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/api" + "github.com/getlantern/radiance/account" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/traces" + "github.com/getlantern/radiance/internal" ) const ( - maxRetryDelay = 2 * time.Minute + maxRetryDelay = 2 * time.Minute + defaultPollInterval = 10 * time.Minute ) var ( // ErrFetchingConfig is returned by [ConfigHandler.GetConfig] when if there was an error // fetching the configuration. ErrFetchingConfig = errors.New("failed to fetch config") -) - -// Config includes all configuration data from the Lantern API as well as any stored local preferences. -type Config struct { - ConfigResponse C.ConfigResponse - PreferredLocation C.ServerLocation -} -type ServerManager interface { - SetServers(serverGroup string, opts servers.Options) error -} + // ErrConfigFetchDisabled is returned by [ConfigHandler.Update] when config fetching + // is disabled via settings. + ErrConfigFetchDisabled = errors.New("config fetching is disabled") +) -// ListenerFunc is a function that is called when the configuration changes. -type ListenerFunc func(oldConfig, newConfig *Config) error +// Config includes all configuration data from the Lantern API +type Config = C.ConfigResponse type Options struct { - PollInterval time.Duration - SvrManager ServerManager - DataDir string - Locale string - APIHandler *api.APIClient + PollInterval time.Duration + DataPath string + Locale string + AccountClient *account.Client + Logger *slog.Logger + HTTPClient *http.Client } // ConfigHandler handles fetching the proxy configuration from the proxy server. It provides access // to the most recent configuration. type ConfigHandler struct { // config holds a configResult. - config atomic.Pointer[Config] - ftr Fetcher - svrManager ServerManager - - ctx context.Context - cancel context.CancelFunc - fetchDisabled bool - configPath string - wgKeyPath string - preferredLocation atomic.Pointer[C.ServerLocation] - configMu sync.RWMutex + config atomic.Pointer[Config] + ftr Fetcher + started atomic.Bool + fetching atomic.Bool + logger *slog.Logger + options Options + + ctx context.Context + cancel context.CancelFunc + pollInterval time.Duration + configPath string + wgKeyPath string + startOnce sync.Once } // NewConfigHandler creates a new ConfigHandler that fetches the proxy configuration every pollInterval. -func NewConfigHandler(options Options) *ConfigHandler { - configPath := filepath.Join(options.DataDir, common.ConfigFileName) - ctx, cancel := context.WithCancel(context.Background()) +func NewConfigHandler(ctx context.Context, options Options) *ConfigHandler { + ctx, cancel := context.WithCancel(ctx) + pollInterval := options.PollInterval + if pollInterval <= 0 { + pollInterval = defaultPollInterval + } + logger := options.Logger + if logger == nil { + logger = slog.Default() + } + dir := options.DataPath ch := &ConfigHandler{ - fetchDisabled: options.PollInterval <= 0, - ctx: ctx, - cancel: cancel, - configPath: configPath, - wgKeyPath: filepath.Join(options.DataDir, "wg.key"), - svrManager: options.SvrManager, + ctx: ctx, + cancel: cancel, + pollInterval: pollInterval, + configPath: filepath.Join(dir, internal.ConfigFileName), + wgKeyPath: filepath.Join(dir, "wg.key"), + logger: logger, + options: options, } - // Set the preferred location to an empty struct to define the underlying type. - ch.preferredLocation.Store(&C.ServerLocation{}) - - if err := os.MkdirAll(filepath.Dir(options.DataDir), 0o755); err != nil { - slog.Error("creating config directory", "error", err) + if err := os.MkdirAll(dir, 0o755); err != nil { + ch.logger.Error("creating config directory", "error", err) } - if err := ch.loadConfig(); err != nil { - slog.Error("failed to load config", "error", err) + ch.logger.Error("failed to load config", "error", err) } + return ch +} - if !ch.fetchDisabled { - ch.ftr = newFetcher(options.Locale, options.APIHandler) - go ch.fetchLoop(options.PollInterval) - events.Subscribe(func(evt settings.UserChangeEvent) { - slog.Debug("User change detected that requires config refetch") +func (ch *ConfigHandler) Start() { + ch.startOnce.Do(func() { + ch.ftr = newFetcher(ch.options.Locale, ch.options.AccountClient, ch.options.HTTPClient) + ch.started.Store(true) + go ch.fetchLoop(ch.pollInterval) + events.Subscribe(func(evt account.UserChangeEvent) { + ch.logger.Debug("User change detected that requires config refetch") if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config", "error", err) + ch.logger.Error("Failed to fetch config", "error", err) } }) - } - return ch + }) } var ErrNoWGKey = errors.New("no wg key") func (ch *ConfigHandler) loadWGKey() (wgtypes.Key, error) { - buf, err := os.ReadFile(ch.wgKeyPath) + buf, err := atomicfile.ReadFile(ch.wgKeyPath) if os.IsNotExist(err) { return wgtypes.Key{}, ErrNoWGKey } @@ -141,43 +144,19 @@ func (ch *ConfigHandler) loadWGKey() (wgtypes.Key, error) { return key, nil } -// SetPreferredServerLocation sets the preferred server location to connect to -func (ch *ConfigHandler) SetPreferredServerLocation(country, city string) { - preferred := &C.ServerLocation{ - Country: country, - City: city, - } - // We store the preferred location in memory in case we haven't fetched a config yet. - ch.preferredLocation.Store(preferred) - ch.modifyConfig(func(cfg *Config) { - cfg.PreferredLocation = *preferred - }) - // fetch the config with the new preferred location on a separate goroutine - go func() { - if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config: %v", "error", err) - } - }() -} - func (ch *ConfigHandler) fetchConfig() error { - if ch.fetchDisabled { - return fmt.Errorf("fetching config is disabled") + if settings.GetBool(settings.ConfigFetchDisabledKey) { + ch.logger.Info("config fetch disabled, skipping") + return nil } if ch.isClosed() { return fmt.Errorf("config handler is closed") } - var preferred C.ServerLocation - oldConfig, err := ch.GetConfig() - if err != nil { - slog.Info("No stored config yet -- using in-memory server location", "error", err) - storedLocation := ch.preferredLocation.Load() - if storedLocation != nil { - preferred = *storedLocation - } - } else { - preferred = oldConfig.PreferredLocation + if !ch.fetching.CompareAndSwap(false, true) { + ch.logger.Info("config fetch already in flight, skipping") + return nil } + defer ch.fetching.Store(false) privateKey, err := ch.loadWGKey() if err != nil && !errors.Is(err, ErrNoWGKey) { @@ -190,25 +169,30 @@ func (ch *ConfigHandler) fetchConfig() error { return fmt.Errorf("failed to generate wg keys: %w", keyErr) } - if writeErr := os.WriteFile(ch.wgKeyPath, []byte(privateKey.String()), 0o600); writeErr != nil { + if writeErr := atomicfile.WriteFile(ch.wgKeyPath, []byte(privateKey.String()), fileperm.File); writeErr != nil { return fmt.Errorf("writing wg key file: %w", writeErr) } } - slog.Info("Fetching config") + ch.logger.Info("Fetching config") + preferred := common.PreferredLocation{} + if err := settings.GetStruct(settings.PreferredLocationKey, &preferred); err != nil { + ch.logger.Error("failed to get preferred location from settings", "error", err) + } + resp, err := ch.ftr.fetchConfig(ch.ctx, preferred, privateKey.PublicKey().String()) if err != nil { return fmt.Errorf("%w: %w", ErrFetchingConfig, err) } if resp == nil { - slog.Info("no new config available") + ch.logger.Info("no new config available") return nil } - slog.Info("Config fetched from server") + ch.logger.Info("Config fetched from server") // Save the raw config for debugging - if writeErr := os.WriteFile(strings.TrimSuffix(ch.configPath, ".json")+"_raw.json", resp, 0o600); writeErr != nil { - slog.Error("writing raw config file", "error", writeErr) + if writeErr := atomicfile.WriteFile(strings.TrimSuffix(ch.configPath, ".json")+"_raw.json", resp, fileperm.File); writeErr != nil { + ch.logger.Error("writing raw config file", "error", writeErr) } // Otherwise, we keep the previous config and store any error that might have occurred. @@ -218,68 +202,18 @@ func (ch *ConfigHandler) fetchConfig() error { // On the other hand, if we have a new config, we want to overwrite any previous error. confResp, err := singjson.UnmarshalExtendedContext[C.ConfigResponse](box.BaseContext(), resp) if err != nil { - slog.Error("failed to parse config", "error", err) + ch.logger.Error("failed to parse config", "error", err) return fmt.Errorf("parsing config: %w", err) } cleanTags(&confResp) - if err = setWireGuardKeyInOptions(confResp.Options.Endpoints, privateKey); err != nil { - slog.Error("failed to replace private key", "error", err) - return fmt.Errorf("setting wireguard private key: %w", err) - } + setWireGuardKeyInOptions(confResp.Options.Endpoints, privateKey) setCustomProtocolOptions(confResp.Options.Outbounds) - if err := ch.setConfig(&Config{ConfigResponse: confResp}); err == nil { - cfg := ch.config.Load().ConfigResponse - locs := make(map[string]C.ServerLocation, len(cfg.OutboundLocations)+len(cfg.Servers)) - // Track which cities are already covered by active outbounds. - coveredCities := make(map[string]bool, len(cfg.OutboundLocations)) - for k, v := range cfg.OutboundLocations { - if v == nil { - slog.Warn("Server location is nil, skipping", "tag", k) - continue - } - locs[k] = *v - coveredCities[v.City+"|"+v.CountryCode] = true - } - // Include available server locations not already covered by active - // outbounds so the client's location picker shows every location. - for _, sl := range cfg.Servers { - if coveredCities[sl.City+"|"+sl.CountryCode] { - continue - } - key := strings.ToLower(strings.ReplaceAll(sl.City, " ", "-") + "-" + sl.CountryCode) - locs[key] = sl - } - opts := servers.Options{ - Outbounds: cfg.Options.Outbounds, - Endpoints: cfg.Options.Endpoints, - Locations: locs, - URLOverrides: cfg.BanditURLOverrides, - } - if len(cfg.BanditURLOverrides) > 0 { - slog.Info("Config includes bandit URL overrides", - "override_count", len(cfg.BanditURLOverrides), - "outbound_count", len(cfg.Options.Outbounds), - "endpoint_count", len(cfg.Options.Endpoints), - ) - // Create a marker span linked to the API's bandit trace so the - // config fetch appears in the same distributed trace as the callback. - if ctx, ok := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides); ok { - _, span := otel.Tracer(tracerName).Start(ctx, "radiance.config_received", - trace.WithAttributes( - attribute.Int("bandit.override_count", len(cfg.BanditURLOverrides)), - attribute.Int("bandit.outbound_count", len(cfg.Options.Outbounds)), - ), - ) - span.End() // point-in-time marker — config was received at this timestamp - } - } - if err := ch.svrManager.SetServers(servers.SGLantern, opts); err != nil { - slog.Error("setting servers in manager", "error", err) - } + if err := ch.setConfig(&confResp); err != nil { + ch.logger.Error("failed to set config", "error", err) + return fmt.Errorf("setting config: %w", err) } - - slog.Info("Config fetched") + ch.logger.Info("Config fetched") return nil } @@ -296,7 +230,6 @@ func setCustomProtocolOptions(outbounds []option.Outbound) { } } -// TODO: move this to lantern-cloud func cleanTags(cfg *C.ConfigResponse) { opts := cfg.Options locs := cfg.OutboundLocations @@ -316,7 +249,7 @@ func cleanTags(cfg *C.ConfigResponse) { cfg.OutboundLocations = nlocs } -func setWireGuardKeyInOptions(endpoints []option.Endpoint, privateKey wgtypes.Key) error { +func setWireGuardKeyInOptions(endpoints []option.Endpoint, privateKey wgtypes.Key) { // Requires privilege and cannot conflict with existing system interfaces // System tries to use system env; for mobile we need to tun device system := !(common.IsAndroid() || common.IsIOS() || common.IsMacOS()) @@ -331,7 +264,6 @@ func setWireGuardKeyInOptions(endpoints []option.Endpoint, privateKey wgtypes.Ke default: } } - return nil } // fetchLoop fetches the configuration periodically. It uses the server's @@ -342,7 +274,7 @@ func (ch *ConfigHandler) fetchLoop(defaultPollInterval time.Duration) { backoff := common.NewBackoff(maxRetryDelay) for { if err := ch.fetchConfig(); err != nil { - slog.Error("Failed to fetch config. Retrying", "error", err) + ch.logger.Error("Failed to fetch config. Retrying", "error", err) backoff.Wait(ch.ctx) if ch.ctx.Err() != nil { return @@ -354,13 +286,13 @@ func (ch *ConfigHandler) fetchLoop(defaultPollInterval time.Duration) { // Use server-recommended poll interval if available, clamped to a // minimum of 10s to prevent excessive polling. interval := defaultPollInterval - if cfg := ch.config.Load(); cfg != nil && cfg.ConfigResponse.PollIntervalSeconds > 0 { - serverInterval := time.Duration(cfg.ConfigResponse.PollIntervalSeconds) * time.Second + if cfg := ch.config.Load(); cfg != nil && cfg.PollIntervalSeconds > 0 { + serverInterval := time.Duration(cfg.PollIntervalSeconds) * time.Second if serverInterval < 10*time.Second { serverInterval = 10 * time.Second } interval = serverInterval - slog.Debug("Using server-recommended poll interval", + ch.logger.Debug("Using server-recommended poll interval", "interval", interval, "default", defaultPollInterval, ) @@ -374,6 +306,18 @@ func (ch *ConfigHandler) fetchLoop(defaultPollInterval time.Duration) { } } +// Update immediately fetches the latest config. It returns [ErrConfigFetchDisabled] +// if config fetching is disabled in settings. +func (ch *ConfigHandler) Update() error { + if settings.GetBool(settings.ConfigFetchDisabledKey) { + return ErrConfigFetchDisabled + } + if !ch.started.Load() { + return fmt.Errorf("config handler not started") + } + return ch.fetchConfig() +} + // Stop stops the ConfigHandler from fetching new configurations. func (ch *ConfigHandler) Stop() { ch.cancel() @@ -391,8 +335,8 @@ func (ch *ConfigHandler) isClosed() bool { // loadConfig loads the config file from the disk. If the config file is not found, it returns // nil. func (ch *ConfigHandler) loadConfig() error { - slog.Debug("reading config file") - cfg, err := Load(ch.configPath) + ch.logger.Debug("reading config file") + cfg, err := load(ch.configPath) if err != nil { return fmt.Errorf("reading config file: %w", err) } @@ -403,7 +347,7 @@ func (ch *ConfigHandler) loadConfig() error { return nil } -func Load(path string) (*Config, error) { +func load(path string) (*Config, error) { buf, err := atomicfile.ReadFile(path) if errors.Is(err, fs.ErrNotExist) { return nil, nil // No config file yet @@ -411,14 +355,22 @@ func Load(path string) (*Config, error) { if err != nil { return nil, fmt.Errorf("reading config file: %w", err) } - cfg, err := unmarshalConfig(buf) + ctx := box.BaseContext() + cfg, err := singjson.UnmarshalExtendedContext[*Config](ctx, buf) + if err != nil { + // try to migrate from old format if parsing fails + // TODO(3/06, garmr-ulfr): remove this migration code after a few releases + if cfg, err = migrateToNewFmt(buf); err == nil { + saveConfig(cfg, path) + } + } if err != nil { return nil, fmt.Errorf("parsing config: %w", err) } return cfg, nil } -func unmarshalConfig(data []byte) (*Config, error) { +func migrateToNewFmt(data []byte) (*Config, error) { type T struct { ConfigResponse json.RawMessage PreferredLocation C.ServerLocation @@ -431,10 +383,8 @@ func unmarshalConfig(data []byte) (*Config, error) { if err != nil { return nil, err } - return &Config{ - ConfigResponse: opts, - PreferredLocation: tmp.PreferredLocation, - }, nil + settings.Set(settings.PreferredLocationKey, &tmp.PreferredLocation) + return &opts, nil } // saveConfig saves the config to the disk. It creates the config file if it doesn't exist. @@ -446,7 +396,7 @@ func saveConfig(cfg *Config, path string) error { if err != nil { return fmt.Errorf("marshalling config: %w", err) } - return atomicfile.WriteFile(path, buf, 0644) + return atomicfile.WriteFile(path, buf, fileperm.File) } // GetConfig returns the current configuration. It returns an error if the config is not yet available. @@ -459,27 +409,20 @@ func (ch *ConfigHandler) GetConfig() (*Config, error) { } func (ch *ConfigHandler) setConfig(cfg *Config) error { - slog.Info("Setting config") + ch.logger.Info("Setting config") if cfg == nil { - slog.Warn("Config is nil, not setting") + ch.logger.Warn("Config is nil, not setting") return nil } oldConfig, _ := ch.GetConfig() - if cfg.PreferredLocation == (C.ServerLocation{}) { - storedLocation := ch.preferredLocation.Load() - if storedLocation != nil { - cfg.PreferredLocation = *storedLocation - } - } - ch.config.Store(cfg) - slog.Debug("Saving config", "path", ch.configPath) + ch.logger.Debug("Saving config", "path", ch.configPath) if err := saveConfig(cfg, ch.configPath); err != nil { - slog.Error("saving config", "error", err) + ch.logger.Error("saving config", "error", err) return fmt.Errorf("saving config: %w", err) } - slog.Info("saved new config") - slog.Info("Config set") + ch.logger.Info("saved new config") + ch.logger.Info("Config set") if !ch.isClosed() { emit(oldConfig, cfg) } @@ -498,21 +441,3 @@ func emit(old, new *Config) { events.Emit(NewConfigEvent{Old: old, New: new}) } } - -// modifyConfig saves the config to the disk with the given config. It creates the config file -// if it doesn't exist. -func (ch *ConfigHandler) modifyConfig(fn func(cfg *Config)) { - ch.configMu.Lock() - cfg, err := ch.GetConfig() - if err != nil { - // This could happen if we haven't successfully fetched the config yet. - slog.Error("getting config", "error", err) - ch.configMu.Unlock() - return - } - // Call the function with the config - // and save the config to the disk. - fn(cfg) - ch.configMu.Unlock() - ch.setConfig(cfg) -} diff --git a/config/config_test.go b/config/config_test.go index 2282d8cd..36907ae1 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -6,30 +6,27 @@ import ( "errors" "os" "path/filepath" - "sync/atomic" "testing" C "github.com/getlantern/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) func TestSaveConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Create a sample config to save expectedConfig := Config{ - ConfigResponse: C.ConfigResponse{ - // Populate with sample data - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, + // Populate with sample data + Servers: []C.ServerLocation{ + {Country: "US", City: "New York"}, + {Country: "UK", City: "London"}, }, } // Save the config @@ -50,7 +47,7 @@ func TestSaveConfig(t *testing.T) { func TestGetConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Create a ConfigHandler with the mock parser ch := &ConfigHandler{ @@ -67,11 +64,9 @@ func TestGetConfig(t *testing.T) { // Test case: Valid config set t.Run("ValidConfigSet", func(t *testing.T) { expectedConfig := &Config{ - ConfigResponse: C.ConfigResponse{ - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, + Servers: []C.ServerLocation{ + {Country: "US", City: "New York"}, + {Country: "UK", City: "London"}, }, } @@ -84,53 +79,10 @@ func TestGetConfig(t *testing.T) { }) } -func TestSetPreferredServerLocation(t *testing.T) { - // Setup temporary directory for testing - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) - - // Create a ConfigHandler with the mock parser - ctx, cancel := context.WithCancel(context.Background()) - ch := &ConfigHandler{ - configPath: configPath, - ftr: newFetcher("en-US", nil), - ctx: ctx, - cancel: cancel, - } - - ch.config.Store(&Config{ - ConfigResponse: C.ConfigResponse{ - Servers: []C.ServerLocation{ - {Country: "US", City: "New York"}, - {Country: "UK", City: "London"}, - }, - }, - PreferredLocation: C.ServerLocation{ - Country: "US", - City: "New York", - }, - }) - - // Test case: Set preferred server location - t.Run("SetPreferredServerLocation", func(t *testing.T) { - country := "US" - city := "Los Angeles" - - // Call SetPreferredServerLocation - ch.SetPreferredServerLocation(country, city) - - // Verify the preferred location is updated - actualConfig, err := ch.GetConfig() - require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, country, actualConfig.PreferredLocation.Country, "Preferred country should match") - assert.Equal(t, city, actualConfig.PreferredLocation.City, "Preferred city should match") - }) -} - func TestHandlerFetchConfig(t *testing.T) { // Setup temporary directory for testing tempDir := t.TempDir() - configPath := filepath.Join(tempDir, common.ConfigFileName) + configPath := filepath.Join(tempDir, internal.ConfigFileName) // Mock fetcher mockFetcher := &MockFetcher{} @@ -138,13 +90,12 @@ func TestHandlerFetchConfig(t *testing.T) { // Create a ConfigHandler with the mock parser and fetcher ctx, cancel := context.WithCancel(context.Background()) ch := &ConfigHandler{ - configPath: configPath, - preferredLocation: atomic.Pointer[C.ServerLocation]{}, - ftr: mockFetcher, - wgKeyPath: filepath.Join(tempDir, "wg.key"), - svrManager: &mockSrvManager{}, - ctx: ctx, - cancel: cancel, + configPath: configPath, + ftr: mockFetcher, + wgKeyPath: filepath.Join(tempDir, "wg.key"), + ctx: ctx, + cancel: cancel, + logger: log.NoOpLogger(), } // Test case: No server location set @@ -160,8 +111,8 @@ func TestHandlerFetchConfig(t *testing.T) { require.NoError(t, err, "Should not return an error when no server location is set") actualConfig, err := ch.GetConfig() require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, "US", actualConfig.ConfigResponse.Servers[0].Country, "First server country should match") - assert.Equal(t, "New York", actualConfig.ConfigResponse.Servers[0].City, "First server city should match") + assert.Equal(t, "US", actualConfig.Servers[0].Country, "First server country should match") + assert.Equal(t, "New York", actualConfig.Servers[0].City, "First server city should match") }) // Test case: No stored config, fetch succeeds @@ -174,15 +125,13 @@ func TestHandlerFetchConfig(t *testing.T) { }`) mockFetcher.err = nil - ch.preferredLocation.Store(&C.ServerLocation{Country: "US", City: "New York"}) - err := ch.fetchConfig() require.NoError(t, err, "Should not return an error when fetch succeeds") actualConfig, err := ch.GetConfig() require.NoError(t, err, "Should not return an error when getting config") - assert.Equal(t, "US", actualConfig.ConfigResponse.Servers[0].Country, "First server country should match") - assert.Equal(t, "New York", actualConfig.ConfigResponse.Servers[0].City, "First server city should match") + assert.Equal(t, "US", actualConfig.Servers[0].Country, "First server country should match") + assert.Equal(t, "New York", actualConfig.Servers[0].City, "First server city should match") }) // Test case: Fetch fails @@ -215,10 +164,6 @@ func TestHandlerFetchConfig(t *testing.T) { }) } -type mockSrvManager struct{} - -func (m *mockSrvManager) SetServers(_ string, _ servers.Options) error { return nil } - // Make sure MockFetcher implements the Fetcher interface var _ Fetcher = (*MockFetcher)(nil) diff --git a/config/fetcher.go b/config/fetcher.go index b9419b75..dc09a9cd 100644 --- a/config/fetcher.go +++ b/config/fetcher.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "errors" - "os" "fmt" "io" @@ -23,12 +22,11 @@ import ( "github.com/getlantern/lantern-box/protocol" - "github.com/getlantern/radiance/api" - "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/account" "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/env" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/kindling" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" ) @@ -40,7 +38,7 @@ type Fetcher interface { // preferred is used to select the server location. // If preferred is empty, the server will select the best location. // The lastModified time is used to check if the configuration has changed since the last request. - fetchConfig(ctx context.Context, preferred C.ServerLocation, wgPublicKey string) ([]byte, error) + fetchConfig(ctx context.Context, preferred common.PreferredLocation, wgPublicKey string) ([]byte, error) } // fetcher is responsible for fetching the configuration from the server. @@ -48,20 +46,27 @@ type fetcher struct { lastModified time.Time locale string etag string - apiClient *api.APIClient + baseURL string + apiClient *account.Client + httpClient *http.Client } // newFetcher creates a new fetcher with the given http client. -func newFetcher(locale string, apiClient *api.APIClient) Fetcher { +func newFetcher(locale string, apiClient *account.Client, httpClient *http.Client) Fetcher { + if httpClient == nil { + httpClient = &http.Client{Timeout: common.DefaultHTTPTimeout} + } return &fetcher{ lastModified: time.Time{}, locale: locale, + baseURL: common.GetBaseURL(), apiClient: apiClient, + httpClient: httpClient, } } // fetchConfig fetches the configuration from the server. Nil is returned if no new config is available. -func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, wgPublicKey string) ([]byte, error) { +func (f *fetcher) fetchConfig(ctx context.Context, preferred common.PreferredLocation, wgPublicKey string) ([]byte, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "config_fetcher.fetchConfig") defer span.End() // If we don't have a user ID or token, create a new user. @@ -78,7 +83,7 @@ func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, w WGPublicKey: wgPublicKey, Backend: C.SINGBOX, Locale: f.locale, - Protocols: protocol.SupportedProtocols(), + Protocols: filterProtocolsForBridge(protocol.SupportedProtocols()), } if preferred.Country != "" { confReq.PreferredLocation = &preferred @@ -98,7 +103,7 @@ func (f *fetcher) fetchConfig(ctx context.Context, preferred C.ServerLocation, w if buf == nil { // no new config available return nil, nil } - slog.Log(nil, internal.LevelTrace, "received config", "config", string(buf)) + slog.Log(nil, log.LevelTrace, "received config", "config", string(buf)) f.lastModified = time.Now() return buf, nil @@ -142,18 +147,18 @@ func (f *fetcher) ensureUser(ctx context.Context) error { func (f *fetcher) send(ctx context.Context, body io.Reader) ([]byte, error) { ctx, span := otel.Tracer(tracerName).Start(ctx, "config_fetcher.send") defer span.End() - req, err := backend.NewRequestWithHeaders(ctx, http.MethodPost, common.GetBaseURL()+"/config-new", body) + req, err := common.NewRequestWithHeaders(ctx, http.MethodPost, f.baseURL+"/config-new", body) if err != nil { return nil, fmt.Errorf("could not create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Cache-Control", "no-cache") - if val, exists := os.LookupEnv("RADIANCE_COUNTRY"); exists { + if val := settings.GetString(settings.CountryCodeKey); val != "" { slog.Info("Setting x-lantern-client-country header", "country", val) req.Header.Set("x-lantern-client-country", val) } - if val, exists := os.LookupEnv("RADIANCE_FEATURE_OVERRIDE"); exists && val != "" { + if val := settings.GetString(settings.FeatureOverridesKey); val != "" { slog.Info("Setting X-Lantern-Feature-Override header", "features", val) req.Header.Set("X-Lantern-Feature-Override", val) } @@ -165,7 +170,7 @@ func (f *fetcher) send(ctx context.Context, body io.Reader) ([]byte, error) { req.Header.Set("If-None-Match", f.etag) } - resp, err := kindling.HTTPClient().Do(req) + resp, err := f.httpClient.Do(req) if err != nil { return nil, traces.RecordError(ctx, fmt.Errorf("could not send request: %w", err)) } @@ -225,3 +230,36 @@ func moduleVersion(modulePath ...string) (string, error) { return "", fmt.Errorf("module %s not found", modulePath) } + +// udpOnlyProtocols are sing-box outbound protocols whose entry-server +// connection is UDP-only. When radiance's outbound dials are detoured +// through an upstream SOCKS5 (the QA path), our bridge SOCKS5 listener +// only implements TCP CONNECT — UDP ASSOCIATE isn't wired — so these +// outbounds can't be reached and would just clutter URLTest with +// failures. Drop them from the request so the bandit doesn't pick them. +var udpOnlyProtocols = map[string]struct{}{ + "hysteria": {}, + "hysteria2": {}, + "tuic": {}, + "wireguard": {}, + "amnezia": {}, // wireguard-based; same UDP constraint +} + +// filterProtocolsForBridge returns the input slice unchanged unless +// RADIANCE_OUTBOUND_SOCKS_ADDRESS is set, in which case UDP-only +// protocols are filtered out. +func filterProtocolsForBridge(in []string) []string { + if addr, _ := env.Get(env.OutboundSocksAddress); addr == "" { + return in + } + out := in[:0:0] + for _, p := range in { + if _, drop := udpOnlyProtocols[p]; drop { + continue + } + out = append(out, p) + } + slog.Info("RADIANCE_OUTBOUND_SOCKS_ADDRESS set — dropping UDP-only protocols from config request", + "kept", len(out), "dropped", len(in)-len(out)) + return out +} diff --git a/config/fetcher_test.go b/config/fetcher_test.go index 4a078ffa..86fafdaa 100644 --- a/config/fetcher_test.go +++ b/config/fetcher_test.go @@ -1,83 +1,22 @@ package config import ( - "bytes" - "context" "encoding/json" + "fmt" "io" "net/http" - "path/filepath" + "net/http/httptest" "testing" C "github.com/getlantern/common" - "github.com/getlantern/kindling" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/getlantern/radiance/api" "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" - rkindling "github.com/getlantern/radiance/kindling" - "github.com/getlantern/radiance/kindling/fronted" ) -func TestDomainFrontingFetchConfig(t *testing.T) { - // Disable this test for now since it depends on external service. - t.Skip("Skipping TestDomainFrontingFetchConfig since it depends on external service.") - dataDir := t.TempDir() - f, err := fronted.NewFronted(context.Background(), reporting.PanicListener, filepath.Join(dataDir, "fronted_cache.json"), io.Discard) - require.NoError(t, err) - k, err := kindling.NewKindling( - "radiance-df-test", - kindling.WithDomainFronting(f), - ) - require.NoError(t, err) - rkindling.SetKindling(k) - fetcher := newFetcher("en-US", &api.APIClient{}) - - privateKey, err := wgtypes.GenerateKey() - require.NoError(t, err) - - _, err = fetcher.fetchConfig(context.Background(), C.ServerLocation{Country: "US"}, privateKey.PublicKey().String()) - // We expect a 500 error since the user does not have any matching tracks. - require.Error(t, err) - assert.Contains(t, err.Error(), "no lantern-cloud tracks") -} - -func TestProxylessFetchConfig(t *testing.T) { - // Disable this test for now since it depends on external service. - t.Skip("Skipping TestProxylessFetchConfig since it depends on external service.") - k, err := kindling.NewKindling( - "radiance-df-test", - kindling.WithProxyless("df.iantem.io"), - ) - require.NoError(t, err) - rkindling.SetKindling(k) - fetcher := newFetcher("en-US", &api.APIClient{}) - - privateKey, err := wgtypes.GenerateKey() - require.NoError(t, err) - - _, err = fetcher.fetchConfig(context.Background(), C.ServerLocation{Country: "US"}, privateKey.PublicKey().String()) - // We expect a 500 error since the user does not have any matching tracks. - require.Error(t, err) - assert.Contains(t, err.Error(), "no lantern-cloud tracks") - -} - -type mockRoundTripper struct { - req *http.Request - resp *http.Response - err error -} - -func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - m.req = req - return m.resp, m.err -} - func TestFetchConfig(t *testing.T) { settings.InitSettings(t.TempDir()) settings.Set(settings.DeviceIDKey, "mock-device-id") @@ -88,25 +27,20 @@ func TestFetchConfig(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - preferredServerLoc *C.ServerLocation - mockResponse *http.Response - mockError error - expectedConfig []byte - expectedErrorMessage string + name string + preferredServerLoc *C.ServerLocation + serverStatus int + serverBody string + expectedConfig []byte + expectError bool }{ { - name: "successful fetch with new config", + name: "successful fetch", preferredServerLoc: &C.ServerLocation{ Country: "US", }, - mockResponse: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(func() []byte { - data := []byte(`{"key":"value"}`) - return data - }())), - }, + serverStatus: http.StatusOK, + serverBody: `{"key":"value"}`, expectedConfig: []byte(`{"key":"value"}`), }, { @@ -114,81 +48,86 @@ func TestFetchConfig(t *testing.T) { preferredServerLoc: &C.ServerLocation{ Country: "US", }, - mockResponse: &http.Response{ - StatusCode: http.StatusNotModified, - Body: io.NopCloser(bytes.NewReader(nil)), - }, + serverStatus: http.StatusNotModified, expectedConfig: nil, }, - { - name: "error during request", - preferredServerLoc: &C.ServerLocation{ - Country: "US", - }, - mockError: context.DeadlineExceeded, - expectedErrorMessage: "context deadline exceeded", - }, } - apiClient := &api.APIClient{} - defer apiClient.Reset() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockRT := &mockRoundTripper{ - resp: tt.mockResponse, - err: tt.mockError, - } - rkindling.SetKindling(&mockKindling{ - &http.Client{ - Transport: mockRT, - }, - }) - fetcher := newFetcher("en-US", &api.APIClient{}) + var capturedReq *http.Request + var capturedBody []byte + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + capturedBody = body + capturedReq = r + w.WriteHeader(tt.serverStatus) + if tt.serverBody != "" { + w.Write([]byte(tt.serverBody)) + } + })) + defer srv.Close() - gotConfig, err := fetcher.fetchConfig(t.Context(), *tt.preferredServerLoc, privateKey.PublicKey().String()) + f := newFetcher("en-US", nil, srv.Client()).(*fetcher) + f.baseURL = srv.URL - if tt.expectedErrorMessage != "" { + gotConfig, err := f.fetchConfig(t.Context(), *tt.preferredServerLoc, privateKey.PublicKey().String()) + + if tt.expectError { require.Error(t, err) - assert.Contains(t, err.Error(), tt.expectedErrorMessage) } else { require.NoError(t, err) assert.Equal(t, tt.expectedConfig, gotConfig) } - if tt.mockResponse != nil { - require.NotNil(t, mockRT.req) - assert.Equal(t, "application/json", mockRT.req.Header.Get("Content-Type")) - assert.Equal(t, "no-cache", mockRT.req.Header.Get("Cache-Control")) - - body, err := io.ReadAll(mockRT.req.Body) - require.NoError(t, err) - - var confReq C.ConfigRequest - err = json.Unmarshal(body, &confReq) - require.NoError(t, err) - - assert.Equal(t, common.Platform, confReq.Platform) - assert.Equal(t, common.Name, confReq.AppName) - assert.Equal(t, settings.GetString(settings.DeviceIDKey), confReq.DeviceID) - assert.Equal(t, privateKey.PublicKey().String(), confReq.WGPublicKey) - if tt.preferredServerLoc != nil { - assert.Equal(t, tt.preferredServerLoc, confReq.PreferredLocation) - } + require.NotNil(t, capturedReq) + assert.Equal(t, "application/json", capturedReq.Header.Get("Content-Type")) + assert.Equal(t, "no-cache", capturedReq.Header.Get("Cache-Control")) + + var confReq C.ConfigRequest + err = json.Unmarshal(capturedBody, &confReq) + require.NoError(t, err) + + assert.Equal(t, common.Platform, confReq.Platform) + assert.Equal(t, common.Name, confReq.AppName) + assert.Equal(t, settings.GetString(settings.DeviceIDKey), confReq.DeviceID) + assert.Equal(t, "1234567890", confReq.UserID, + "UserID must serialize as a base-10 decimal string matching main's format") + assert.Equal(t, privateKey.PublicKey().String(), confReq.WGPublicKey) + if tt.preferredServerLoc != nil { + assert.Equal(t, tt.preferredServerLoc, confReq.PreferredLocation) } }) } } -type mockKindling struct { - c *http.Client -} - -// NewHTTPClient returns a new HTTP client that is configured to use kindling. -func (m *mockKindling) NewHTTPClient() *http.Client { - return m.c -} - -// ReplaceTransport replaces an existing transport RoundTripper generator with the provided one. -func (m *mockKindling) ReplaceTransport(name string, rt func(ctx context.Context, addr string) (http.RoundTripper, error)) error { - panic("not implemented") // TODO: Implement +// TestUserIDFormatMatchesMain exercises the same expression used in +// fetchConfig to build ConfigRequest.UserID. It guards the regression +// fixed in this PR: on main the value is serialized as a base-10 +// decimal string ("0" when unset, "" when set), and we need +// refactor to match so server-side strconv.ParseInt doesn't treat an +// empty string as malformed. +func TestUserIDFormatMatchesMain(t *testing.T) { + cases := []struct { + name string + set bool + value int64 + expect string + }{ + {name: "unset -> zero", set: false, expect: "0"}, + {name: "small id", set: true, value: 42, expect: "42"}, + {name: "large id (exercises float64 JSON round-trip)", set: true, value: 1234567890, expect: "1234567890"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + require.NoError(t, settings.InitSettings(t.TempDir())) + settings.Clear(settings.UserIDKey) + if tc.set { + require.NoError(t, settings.Set(settings.UserIDKey, tc.value)) + } + got := fmt.Sprintf("%d", settings.GetInt64(settings.UserIDKey)) + assert.Equal(t, tc.expect, got) + }) + } } diff --git a/events/events.go b/events/events.go index 7889f1f1..fba0d7a6 100644 --- a/events/events.go +++ b/events/events.go @@ -27,7 +27,10 @@ package events import ( + "context" + "reflect" "sync" + "sync/atomic" ) type Event interface { @@ -36,7 +39,7 @@ type Event interface { } var ( - subscriptions = make(map[any]map[*Subscription[Event]]func(any)) + subscriptions = make(map[reflect.Type]map[*Subscription[Event]]func(any)) subscriptionsMu sync.RWMutex ) @@ -50,26 +53,48 @@ type Subscription[T Event] struct { func Subscribe[T Event](callback func(evt T)) *Subscription[T] { subscriptionsMu.Lock() defer subscriptionsMu.Unlock() - var evt T - if subscriptions[evt] == nil { - subscriptions[evt] = make(map[*Subscription[Event]]func(any)) + key := reflect.TypeFor[T]() + if subscriptions[key] == nil { + subscriptions[key] = make(map[*Subscription[Event]]func(any)) } sub := &Subscription[T]{} - subscriptions[evt][(*Subscription[Event])(sub)] = func(e any) { callback(e.(T)) } + subscriptions[key][(*Subscription[Event])(sub)] = func(e any) { callback(e.(T)) } return sub } // SubscribeOnce registers a callback function for the given event type T that will be invoked only // once. Returns a Subscription handle that can be used to unsubscribe if needed. func SubscribeOnce[T Event](callback func(evt T)) *Subscription[T] { - ready := make(chan struct{}) + return SubscribeUntil(callback, func(evt T) bool { return true }) +} + +// SubscribeUntil registers a callback function for the given event type T that will be invoked until +// the provided condition function returns true for an event. Returns a Subscription handle that can +// be used to unsubscribe if needed. +func SubscribeUntil[T Event](callback func(evt T), cond func(evt T) bool) *Subscription[T] { + var done atomic.Bool var sub *Subscription[T] sub = Subscribe(func(evt T) { - <-ready + if done.Load() { + return + } callback(evt) - sub.Unsubscribe() + if cond(evt) { + done.Store(true) + sub.Unsubscribe() + } }) - close(ready) + return sub +} + +// SubscribeContext registers a callback for event type T that is automatically unsubscribed when +// the provided context is cancelled. +func SubscribeContext[T Event](ctx context.Context, callback func(evt T)) *Subscription[T] { + sub := Subscribe(callback) + go func() { + <-ctx.Done() + sub.Unsubscribe() + }() return sub } @@ -77,11 +102,11 @@ func SubscribeOnce[T Event](callback func(evt T)) *Subscription[T] { func Unsubscribe[T Event](sub *Subscription[T]) { subscriptionsMu.Lock() defer subscriptionsMu.Unlock() - var evt T - if subs, ok := subscriptions[evt]; ok { + key := reflect.TypeFor[T]() + if subs, ok := subscriptions[key]; ok { delete(subs, (*Subscription[Event])(sub)) if len(subs) == 0 { - delete(subscriptions, evt) + delete(subscriptions, key) } } } @@ -95,8 +120,7 @@ func (e *Subscription[T]) Unsubscribe() { func Emit[T Event](evt T) { subscriptionsMu.RLock() defer subscriptionsMu.RUnlock() - var e T - if subs, ok := subscriptions[e]; ok { + if subs, ok := subscriptions[reflect.TypeFor[T]()]; ok { for _, cb := range subs { go cb(evt) } diff --git a/go.mod b/go.mod index ad756f65..e33f23ee 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module github.com/getlantern/radiance -go 1.25.1 +go 1.26.2 replace github.com/sagernet/sing => github.com/getlantern/sing v0.7.18-lantern -replace github.com/sagernet/sing-box => github.com/getlantern/sing-box-minimal v1.12.21-lantern +replace github.com/sagernet/sing-box => github.com/getlantern/sing-box-minimal v1.12.22-lantern replace github.com/sagernet/wireguard-go => github.com/getlantern/wireguard-go v0.0.1-beta.7.0.20251208214020-d78e69f1eff4 @@ -23,20 +23,19 @@ replace github.com/refraction-networking/water => github.com/getlantern/water v0 require ( github.com/1Password/srp v0.2.0 github.com/Microsoft/go-winio v0.6.2 + github.com/alexflint/go-arg v1.6.1 github.com/alitto/pond v1.9.2 github.com/getlantern/amp v0.0.0-20260305201851-782bc8045e58 - github.com/getlantern/appdir v0.0.0-20250324200952-507a0625eb01 github.com/getlantern/common v1.2.1-0.20260326210434-cb69537aaf46 github.com/getlantern/dnstt v0.0.0-20260112160750-05100563bd0d github.com/getlantern/fronted v0.0.0-20260325003030-cb5041ba1538 - github.com/getlantern/keepcurrent v0.0.0-20260304213122-017d542145ae + github.com/getlantern/keepcurrent v0.0.0-20260422161259-54a4d9a93694 github.com/getlantern/kindling v0.0.0-20260329144042-b1825b9cb1bb - github.com/getlantern/lantern-box v0.0.67 + github.com/getlantern/lantern-box v0.0.74 github.com/getlantern/pluriconfig v0.0.0-20251126214241-8cc8bc561535 github.com/getlantern/publicip v0.0.0-20260328175246-2c460fe80c6b github.com/getlantern/semconv v0.0.0-20260327040646-21845dda05cb github.com/getlantern/timezone v0.0.0-20210901200113-3f9de9d360c9 - github.com/go-resty/resty/v2 v2.16.5 github.com/goccy/go-yaml v1.19.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 @@ -52,7 +51,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.41.0 go.opentelemetry.io/otel/sdk v1.41.0 go.opentelemetry.io/otel/sdk/metric v1.41.0 - go.uber.org/mock v0.5.0 + golang.org/x/term v0.40.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -67,6 +66,7 @@ require ( github.com/akutz/memconn v0.1.0 // indirect github.com/alecthomas/atomic v0.1.0-alpha2 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect + github.com/alexflint/go-scalar v1.2.0 // indirect github.com/alitto/pond/v2 v2.1.5 // indirect github.com/anacrolix/btree v0.0.0-20251201064447-d86c3fa41bd8 // indirect github.com/anacrolix/chansync v0.7.0 // indirect @@ -106,14 +106,17 @@ require ( github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/edsrzf/mmap-go v1.1.0 // indirect + github.com/enobufs/go-nats v0.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gaissmai/bart v0.11.1 // indirect github.com/gaukas/wazerofs v0.1.0 // indirect github.com/getlantern/algeneva v0.0.0-20250307163401-1824e7b54f52 // indirect + github.com/getlantern/broflake v0.0.0-20260421172440-caea0799b63a // indirect github.com/getlantern/lantern-water v0.0.0-20260317143726-e0ee64a11d90 // indirect github.com/getlantern/samizdat v0.0.3-0.20260327203406-ef7323341974 // indirect + github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 // indirect github.com/go-llsqlite/adapter v0.0.0-20230927005056-7f5ce7f0c916 // indirect github.com/go-llsqlite/crawshaw v0.5.6-0.20250312230104-194977a03421 // indirect @@ -154,25 +157,31 @@ require ( github.com/multiformats/go-multihash v0.2.3 // indirect github.com/multiformats/go-varint v0.0.6 // indirect github.com/nwaples/rardecode/v2 v2.2.0 // indirect - github.com/pion/datachannel v1.5.10 // indirect - github.com/pion/dtls/v3 v3.0.4 // indirect - github.com/pion/ice/v4 v4.0.7 // indirect - github.com/pion/interceptor v0.1.40 // indirect - github.com/pion/logging v0.2.3 // indirect - github.com/pion/mdns/v2 v2.0.7 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v2 v2.2.12 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/ice/v4 v4.2.2 // indirect + github.com/pion/interceptor v0.1.44 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/rtcp v1.2.15 // indirect - github.com/pion/rtp v1.8.18 // indirect - github.com/pion/sctp v1.8.37 // indirect - github.com/pion/sdp/v3 v3.0.11 // indirect - github.com/pion/srtp/v3 v3.0.4 // indirect - github.com/pion/stun/v3 v3.0.0 // indirect - github.com/pion/transport/v3 v3.0.7 // indirect - github.com/pion/turn/v4 v4.0.0 // indirect - github.com/pion/webrtc/v4 v4.0.13 // indirect + github.com/pion/rtcp v1.2.16 // indirect + github.com/pion/rtp v1.10.1 // indirect + github.com/pion/sctp v1.9.4 // indirect + github.com/pion/sdp/v3 v3.0.18 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun v0.6.1 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pion/transport v0.14.1 // indirect + github.com/pion/transport/v2 v2.2.10 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect + github.com/pion/turn v1.3.7 // indirect + github.com/pion/turn/v4 v4.1.4 // indirect + github.com/pion/webrtc/v4 v4.2.11 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/prometheus-community/pro-bing v0.4.0 // indirect github.com/protolambda/ctxlock v0.1.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect github.com/refraction-networking/utls v1.8.2 // indirect github.com/refraction-networking/water v0.7.1-alpha // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -194,6 +203,7 @@ require ( github.com/templexxx/xorsimd v0.4.3 // indirect github.com/tetratelabs/wazero v1.11.0 // indirect github.com/tevino/abool/v2 v2.1.0 // indirect + github.com/theodorsm/covert-dtls v1.5.0 // indirect github.com/tidwall/btree v1.8.1 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect github.com/tkuchiki/go-timezone v0.2.0 // indirect @@ -206,12 +216,12 @@ require ( go.etcd.io/bbolt v1.3.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect + go.uber.org/mock v0.5.2 // indirect go.uber.org/zap/exp v0.3.0 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect golang.getoutline.org/sdk v0.0.21 // indirect golang.getoutline.org/sdk/x v0.1.0 // indirect - golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.34.0 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect @@ -235,14 +245,13 @@ require ( github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534 // indirect github.com/getlantern/osversion v0.0.0-20240418205916-2e84a4a4e175 github.com/getsentry/sentry-go v0.31.1 - github.com/go-chi/chi/v5 v5.2.2 github.com/go-chi/render v1.0.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gofrs/uuid/v5 v5.3.2 + github.com/gofrs/uuid/v5 v5.3.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect @@ -259,7 +268,7 @@ require ( github.com/miekg/dns v1.1.67 github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/quic-go/qpack v0.5.1 // indirect + github.com/quic-go/qpack v0.6.0 // indirect github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a // indirect github.com/sagernet/cors v1.2.1 // indirect github.com/sagernet/fswatch v0.1.1 // indirect @@ -302,3 +311,19 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.3.0 // indirect ) + +// Keep qpack on v0.5.1 for sing-box-minimal/sagernet quic-go HTTP/3 compatibility: +// sing-box-minimal@v1.12.21-lantern pins sagernet/quic-go@v0.52.0-sing-box-mod.3, +// whose http3 package (used by hysteria2, DoQ, v2rayquic) was compiled against +// qpack's v0.5.1 API (NewDecoder(cb) + DecodeFull, both removed in v0.6.0). +// Without this override, MVS picks v0.6.0 via quic-go/quic-go v0.59.0's require +// and the build breaks. +// +// NB: the require block above may still show qpack v0.6.0 — that's expected +// (it's what quic-go v0.59.0 declares). This replace is what actually forces +// the resolved version down to v0.5.1 at build time. +// +// Mirrors the same replace in lantern-box/go.mod; remove once sing-box-minimal +// bumps to a sagernet/quic-go release that uses the qpack v0.6.0 API +// (v0.59.0-sing-box-mod.4 or later). +replace github.com/quic-go/qpack => github.com/quic-go/qpack v0.5.1 diff --git a/go.sum b/go.sum index f2e3eac0..a2b2fa48 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,10 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alexflint/go-arg v1.6.1 h1:uZogJ6VDBjcuosydKgvYYRhh9sRCusjOvoOLZopBlnA= +github.com/alexflint/go-arg v1.6.1/go.mod h1:nQ0LFYftLJ6njcaee0sU+G0iS2+2XJQfA8I062D0LGc= +github.com/alexflint/go-scalar v1.2.0 h1:WR7JPKkeNpnYIOfHRa7ivM21aWAdHD0gEWHCx+WQBRw= +github.com/alexflint/go-scalar v1.2.0/go.mod h1:LoFvNMqS1CPrMVltza4LvnGKhaSpc3oyLEBUZVhhS2o= github.com/alitto/pond v1.9.2 h1:9Qb75z/scEZVCoSU+osVmQ0I0JOeLfdTDafrbcJ8CLs= github.com/alitto/pond v1.9.2/go.mod h1:xQn3P/sHTYcU/1BR3i86IGIrilcrGC2LiS+E2+CJWsI= github.com/alitto/pond/v2 v2.1.5 h1:2pp/KAPcb02NSpHsjjnxnrTDzogMLsq+vFf/L0DB84A= @@ -195,6 +199,8 @@ github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1 github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.1.0 h1:6EUwBLQ/Mcr1EYLE4Tn1VdW1A4ckqCQWZBw8Hr0kjpQ= github.com/edsrzf/mmap-go v1.1.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q= +github.com/enobufs/go-nats v0.0.1 h1:uzC0mxan4hyGzUFG7cShFmk6c+XYgfoT8yTBgF5CJYw= +github.com/enobufs/go-nats v0.0.1/go.mod h1:ZF0vpSk02ALIMFsHkIO4MHXUN1v3nLZssTaG+fgX/io= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -222,8 +228,8 @@ github.com/getlantern/algeneva v0.0.0-20250307163401-1824e7b54f52 h1:w2/RqYPw7Pb github.com/getlantern/algeneva v0.0.0-20250307163401-1824e7b54f52/go.mod h1:PrNR8tMXO26YNs8K9653XCUH7u2Kv4OdfFC3Ke1GsX0= github.com/getlantern/amp v0.0.0-20260305201851-782bc8045e58 h1:3wxMKw90adxiEzsJmAmMHqBJQr/P/9Goqy/U2a1l/sg= github.com/getlantern/amp v0.0.0-20260305201851-782bc8045e58/go.mod h1:p6WdG48YAz5SCUpiMSGLy616A6YghKToc63y3NP7avI= -github.com/getlantern/appdir v0.0.0-20250324200952-507a0625eb01 h1:Mmeh4/DA1OKN9tVWRAvTL5efFx4c7v9/55hoK17NclA= -github.com/getlantern/appdir v0.0.0-20250324200952-507a0625eb01/go.mod h1:3vR6+jQdWfWojZ77w+htCqEF5MO/Y2twJOpAvFuM9po= +github.com/getlantern/broflake v0.0.0-20260421172440-caea0799b63a h1:WQ11Ms5jGvBaH6v/u1QBvmnnzRY0ckMiifCnDM/x6TI= +github.com/getlantern/broflake v0.0.0-20260421172440-caea0799b63a/go.mod h1:bZGGfTwne9NIsy3Kc1avcXNWn/yA8ghUwlXdS2z+AlA= github.com/getlantern/common v1.2.1-0.20260326210434-cb69537aaf46 h1:Ab2esudqgFz2K1WYQKtX+58kaiVMX0UohjW2XmdEgf4= github.com/getlantern/common v1.2.1-0.20260326210434-cb69537aaf46/go.mod h1:eSSuV4bMPgQJnczBw+KWWqWNo1itzmVxC++qUBPRTt0= github.com/getlantern/context v0.0.0-20220418194847-3d5e7a086201 h1:oEZYEpZo28Wdx+5FZo4aU7JFXu0WG/4wJWese5reQSA= @@ -240,12 +246,12 @@ github.com/getlantern/hex v0.0.0-20220104173244-ad7e4b9194dc h1:sue+aeVx7JF5v36H github.com/getlantern/hex v0.0.0-20220104173244-ad7e4b9194dc/go.mod h1:D9RWpXy/EFPYxiKUURo2TB8UBosbqkiLhttRrZYtvqM= github.com/getlantern/hidden v0.0.0-20220104173330-f221c5a24770 h1:cSrD9ryDfTV2yaur9Qk3rHYD414j3Q1rl7+L0AylxrE= github.com/getlantern/hidden v0.0.0-20220104173330-f221c5a24770/go.mod h1:GOQsoDnEHl6ZmNIL+5uVo+JWRFWozMEp18Izcb++H+A= -github.com/getlantern/keepcurrent v0.0.0-20260304213122-017d542145ae h1:NMq3K7h3N/usgEtUMQs8WBzvhKKOfBvHo+18pXgtpds= -github.com/getlantern/keepcurrent v0.0.0-20260304213122-017d542145ae/go.mod h1:ag5g9aWUw2FJcX5RVRpJ9EBQBy5yJuy2WXDouIn/m4w= +github.com/getlantern/keepcurrent v0.0.0-20260422161259-54a4d9a93694 h1:iLWm6S/47Hfk7FjW6yaD+1h6kO7C/iauV0DkVia/bXU= +github.com/getlantern/keepcurrent v0.0.0-20260422161259-54a4d9a93694/go.mod h1:ag5g9aWUw2FJcX5RVRpJ9EBQBy5yJuy2WXDouIn/m4w= github.com/getlantern/kindling v0.0.0-20260329144042-b1825b9cb1bb h1:A92dC/E/HvkEb1r4tAwCFNlcMsGdqKe5GMmxeUFid9M= github.com/getlantern/kindling v0.0.0-20260329144042-b1825b9cb1bb/go.mod h1:c5cFjpNrqX8wQ0PUE2blHrO7knAlRCVx3j1/G6zaVlY= -github.com/getlantern/lantern-box v0.0.67 h1:0uDILTY2fVzy47IoEecsMoeplqdxFU/KE/izaZXwM/Q= -github.com/getlantern/lantern-box v0.0.67/go.mod h1:n5NzI/rqr1USYIQPnEy3oZBYNPDyi8EODXNg8jPsQqY= +github.com/getlantern/lantern-box v0.0.74 h1:3LgqcjHX/lLJO4BCEg21vzFaDwiAcUyhdn5o6M6VAaQ= +github.com/getlantern/lantern-box v0.0.74/go.mod h1:lRpNV/lDbsQ2NfA747Oa3mdZXzc0rDsgtlN0lDHh9pM= github.com/getlantern/lantern-water v0.0.0-20260317143726-e0ee64a11d90 h1:P9JX1yAu2uq3b5YiT0sLtHkTrkZuttV8gPZh81nUuag= github.com/getlantern/lantern-water v0.0.0-20260317143726-e0ee64a11d90/go.mod h1:3JpJgwi4KEI6rS9loOAvcBp+F2jP65d0tTg2GQcTPBU= github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534 h1:3BwvWj0JZzFEvNNiMhCu4bf60nqcIuQpTYb00Ezm1ag= @@ -262,8 +268,8 @@ github.com/getlantern/semconv v0.0.0-20260327040646-21845dda05cb h1:c5YM7b3a4r2J github.com/getlantern/semconv v0.0.0-20260327040646-21845dda05cb/go.mod h1:GkPT5P9JoOTIRXRmFWxYgu1hhXgTFFTNc2hoG7WQc3g= github.com/getlantern/sing v0.7.18-lantern h1:QKGgIUA3LwmKYP/7JlQTRkxj9jnP4cX2Q/B+nd8XEjo= github.com/getlantern/sing v0.7.18-lantern/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= -github.com/getlantern/sing-box-minimal v1.12.21-lantern h1:DUlwWDHrU60hd/83mvU/fR9aASiq4KaN5Z1wa8gaRtM= -github.com/getlantern/sing-box-minimal v1.12.21-lantern/go.mod h1:LzlFRel9E92gX0HXWCdsxgeg+kuAEPzLR+Znixk9EI4= +github.com/getlantern/sing-box-minimal v1.12.22-lantern h1:dZXg3jJu8dZGvBAptoJ7L2Gmwe9bSPFRRZlUVT/O8CM= +github.com/getlantern/sing-box-minimal v1.12.22-lantern/go.mod h1:LzlFRel9E92gX0HXWCdsxgeg+kuAEPzLR+Znixk9EI4= github.com/getlantern/timezone v0.0.0-20210901200113-3f9de9d360c9 h1:VTNjZxSuAHUzu13lYpEVB8gc3xz5hZePGNHG5enHYLY= github.com/getlantern/timezone v0.0.0-20210901200113-3f9de9d360c9/go.mod h1:7uvbzuoOr3uYGHZx5QWlI8/C52XEf/aTb/tJFEe41Ak= github.com/getlantern/waitforserver v1.0.1 h1:xBjqJ3GgEk9JMWnDgRSiNHXINi6Lv2tGNjJR0hCkHFY= @@ -311,8 +317,6 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= -github.com/go-resty/resty/v2 v2.16.5 h1:hBKqmWrr7uRc3euHVqmh1HTHcKn99Smr7o5spptdhTM= -github.com/go-resty/resty/v2 v2.16.5/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= @@ -548,38 +552,60 @@ github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= -github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= -github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= -github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= -github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= -github.com/pion/ice/v4 v4.0.7 h1:mnwuT3n3RE/9va41/9QJqN5+Bhc0H/x/ZyiVlWMw35M= -github.com/pion/ice/v4 v4.0.7/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= -github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4= -github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic= -github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= -github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= -github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= -github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= +github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= +github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/ice/v4 v4.2.2 h1:dQJzzcgTFHDYyV3BoCfjPeX+JEtr58BWPi4PGyo6Vjg= +github.com/pion/ice/v4 v4.2.2/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c= +github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I= +github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY= +github.com/pion/logging v0.2.1/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= -github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= -github.com/pion/rtp v1.8.18 h1:yEAb4+4a8nkPCecWzQB6V/uEU18X1lQCGAQCjP+pyvU= -github.com/pion/rtp v1.8.18/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk= -github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs= -github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= -github.com/pion/sdp/v3 v3.0.11 h1:VhgVSopdsBKwhCFoyyPmT1fKMeV9nLMrEKxNOdy3IVI= -github.com/pion/sdp/v3 v3.0.11/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= -github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= -github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ= -github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= -github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= -github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= -github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= -github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= -github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= -github.com/pion/webrtc/v4 v4.0.13 h1:XuUaWTjRufsiGJRC+G71OgiSMe7tl7mQ0kkd4bAqIaQ= -github.com/pion/webrtc/v4 v4.0.13/go.mod h1:Fadzxm0CbY99YdCEfxrgiVr0L4jN1l8bf8DBkPPpJbs= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= +github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA= +github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.9.4 h1:cMxEu0F5tbP4qH07bKf1Zjf4rUih9LIo0qQt424e258= +github.com/pion/sctp v1.9.4/go.mod h1:N20Dq6LY+JvJDAh9VVh1JELngb2rQ8dPgds5yBWiPgw= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun v0.3.1/go.mod h1:xrCld6XM+6GWDZdvjPlLMsTU21rNxnO6UO8XsAvHr/M= +github.com/pion/stun v0.3.2/go.mod h1:xrCld6XM+6GWDZdvjPlLMsTU21rNxnO6UO8XsAvHr/M= +github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= +github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= +github.com/pion/transport v0.8.6/go.mod h1:nAmRRnn+ArVtsoNuwktvAD+jrjSD7pA+H3iRmZwdUno= +github.com/pion/transport v0.8.8/go.mod h1:lpeSM6KJFejVtZf8k0fgeN7zE73APQpTF83WvA1FVP8= +github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= +github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= +github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= +github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= +github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q= +github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn v1.3.5/go.mod h1:zGPB7YYB/HTE9MWn0Sbznz8NtyfeVeanZ834cG/MXu0= +github.com/pion/turn v1.3.7 h1:/nyM2XrlZILD7KKfnh0oYEBTRG5JlbH21ibjluRoCeo= +github.com/pion/turn v1.3.7/go.mod h1:js0LBFqMcKAlaWAXoYqNjefGI7kfJCrkCBfHGuTToXE= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.11 h1:QUX1QZKlNIn4O7U5JxLPGP0sV5RTncZkzu9SPR3jVNU= +github.com/pion/webrtc/v4 v4.2.11/go.mod h1:s/rAiyy77GyRFrZMx+Ls6aua26dIBPudH8/ZHYbIRWY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -611,6 +637,8 @@ github.com/protolambda/ctxlock v0.1.0 h1:rCUY3+vRdcdZXqT07iXgyr744J2DU2LCBIXowYA github.com/protolambda/ctxlock v0.1.0/go.mod h1:vefhX6rIZH8rsg5ZpOJfEDYQOppZi19SfPiGOFrNnwM= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= @@ -696,6 +724,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= @@ -722,6 +752,8 @@ github.com/templexxx/xorsimd v0.4.3 h1:9AQTFHd7Bhk3dIT7Al2XeBX5DWOvsUPZCuhyAtNbH github.com/templexxx/xorsimd v0.4.3/go.mod h1:oZQcD6RFDisW2Am58dSAGwwL6rHjbzrlu25VDqfWkQg= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +github.com/theodorsm/covert-dtls v1.5.0 h1:kGUnCuGB65kLrga0e1mYv8t2RA4vfRMN0iYlakY0z/c= +github.com/theodorsm/covert-dtls v1.5.0/go.mod h1:MTb9IO4aqSxrcrh569UGO4PlC1Yel37M440z+gcm13E= github.com/things-go/go-socks5 v0.0.5 h1:qvKaGcBkfDrUL33SchHN93srAmYGzb4CxSM2DPYufe8= github.com/things-go/go-socks5 v0.0.5/go.mod h1:mtzInf8v5xmsBpHZVbIw2YQYhc4K0jRwzfsH64Uh0IQ= github.com/tidwall/btree v1.8.1 h1:27ehoXvm5AG/g+1VxLS1SD3vRhp/H7LuEfwNvddEdmA= @@ -743,6 +775,7 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= @@ -797,8 +830,8 @@ go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjce go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= @@ -826,7 +859,10 @@ golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -887,10 +923,14 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -942,18 +982,26 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -964,9 +1012,12 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/log.go b/internal/log.go deleted file mode 100644 index 3f1c6195..00000000 --- a/internal/log.go +++ /dev/null @@ -1,73 +0,0 @@ -package internal - -import ( - "fmt" - "io" - "log/slog" - "strings" -) - -const ( - // slog does not define trace and fatal levels, so we define them here. - LevelTrace = slog.LevelDebug - 4 - LevelDebug = slog.LevelDebug - LevelInfo = slog.LevelInfo - LevelWarn = slog.LevelWarn - LevelError = slog.LevelError - LevelFatal = slog.LevelError + 4 - LevelPanic = slog.LevelError + 8 - - Disable = slog.LevelInfo + 1000 // A level that disables logging, used for testing or no-op logger. -) - -// ParseLogLevel parses a string representation of a log level and returns the corresponding slog.Level. -// If the level is not recognized, it returns LevelInfo. -func ParseLogLevel(level string) (slog.Level, error) { - switch strings.ToLower(level) { - case "trace": - return LevelTrace, nil - case "debug": - return LevelDebug, nil - case "info": - return LevelInfo, nil - case "warn", "warning": - return LevelWarn, nil - case "error": - return LevelError, nil - case "fatal": - return LevelFatal, nil - case "panic": - return LevelPanic, nil - case "disable", "none", "off": - return Disable, nil - default: - return LevelInfo, fmt.Errorf("unknown log level: %s", level) - } -} - -func FormatLogLevel(level slog.Level) string { - switch { - case level < LevelDebug: - return "TRACE" - case level < LevelInfo: - return "DEBUG" - case level < LevelWarn: - return "INFO" - case level < LevelError: - return "WARN" - case level < LevelFatal: - return "ERROR" - case level < LevelPanic: - return "FATAL" - default: - return "PANIC" - } -} - -// NoOpLogger returns a no-op logger that does not log anything. -func NoOpLogger() *slog.Logger { - // Create a no-op logger that does nothing. - return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ - Level: Disable, - })) -} diff --git a/internal/paths.go b/internal/paths.go new file mode 100644 index 00000000..25391076 --- /dev/null +++ b/internal/paths.go @@ -0,0 +1,42 @@ +package internal + +import ( + "os" + "path/filepath" + "runtime" +) + +const ( + DebugBoxOptionsFileName = "debug-box-options.json" + ConfigFileName = "config.json" + ServersFileName = "servers.json" + SplitTunnelFileName = "split-tunnel.json" + LogFileName = "lantern.log" + CrashLogFileName = "lantern-crash.log" +) + +func DefaultDataPath() string { + switch runtime.GOOS { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), "Lantern") + case "darwin": + return "/Library/Application Support/Lantern" + case "linux": + return "/var/lib/lantern" + default: + return "" + } +} + +func DefaultLogPath() string { + switch runtime.GOOS { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), "Lantern") + case "darwin": + return "/Library/Logs/Lantern" + case "linux": + return "/var/log/lantern" + default: + return "" + } +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 44dbc0a8..8bdb3f05 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -2,7 +2,6 @@ package testutil import ( "testing" - _ "unsafe" // for go:linkname "github.com/getlantern/radiance/common/settings" ) @@ -15,8 +14,4 @@ func SetPathsForTesting(t *testing.T) { tmp := t.TempDir() settings.Set(settings.DataPathKey, tmp) settings.Set(settings.LogPathKey, tmp) - ipc_serverTestSetup(tmp + "/lantern.sock") } - -//go:linkname ipc_serverTestSetup -func ipc_serverTestSetup(path string) diff --git a/ipc/client.go b/ipc/client.go new file mode 100644 index 00000000..8ab04c17 --- /dev/null +++ b/ipc/client.go @@ -0,0 +1,736 @@ +package ipc + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "os" + "sync" + "syscall" + "time" + + box "github.com/getlantern/lantern-box" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/issue" + rlog "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/vpn" + + sjson "github.com/sagernet/sing/common/json" +) + +func newClient() *Client { + return &Client{ + http: &http.Client{ + Transport: &http.Transport{ + DialContext: dialContext, + ForceAttemptHTTP2: true, + Protocols: &protocols, + }, + }, + } +} + +// marshalBody encodes body as a JSON reader suitable for an HTTP request body. +// Returns nil if body is nil. +func marshalBody(body any) (io.Reader, error) { + if body == nil { + return nil, nil + } + switch body := body.(type) { + case []byte: + return bytes.NewReader(body), nil + default: + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + return bytes.NewReader(data), nil + } +} + +// doJSON executes an HTTP request and decodes the JSON response into dst. +func (c *Client) doJSON(ctx context.Context, method, endpoint string, body, dst any) error { + data, err := c.do(ctx, method, endpoint, body) + if err != nil { + return err + } + if dst == nil { + return nil + } + return json.Unmarshal(data, dst) +} + +// Error is returned by Client methods when the server responds with an error status. +type Error struct { + Status int + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("ipc: status %d: %s", e.Status, e.Message) +} + +// IsNotFound reports whether the error is a 404 response. +func IsNotFound(err error) bool { + var e *Error + return errors.As(err, &e) && e.Status == http.StatusNotFound +} + +///////////// +// VPN // +///////////// + +func (c *Client) VPNStatus(ctx context.Context) (vpn.VPNStatus, error) { + var status vpn.VPNStatus + err := c.doJSON(ctx, http.MethodGet, vpnStatusEndpoint, nil, &status) + return status, err +} + +// ConnectVPN connects the VPN using the given server tag. +func (c *Client) ConnectVPN(ctx context.Context, tag string) error { + _, err := c.do(ctx, http.MethodPost, vpnConnectEndpoint, TagRequest{Tag: tag}) + return err +} + +// DisconnectVPN disconnects the VPN. +func (c *Client) DisconnectVPN(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnDisconnectEndpoint, nil) + return err +} + +// RestartVPN restarts the VPN connection. +func (c *Client) RestartVPN(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnRestartEndpoint, nil) + return err +} + +// VPNConnections returns all VPN connections (active and recently closed). +func (c *Client) VPNConnections(ctx context.Context) ([]vpn.Connection, error) { + var conns []vpn.Connection + err := c.doJSON(ctx, http.MethodGet, vpnConnectionsEndpoint, nil, &conns) + return conns, err +} + +// ActiveVPNConnections returns currently active VPN connections. +func (c *Client) ActiveVPNConnections(ctx context.Context) ([]vpn.Connection, error) { + var conns []vpn.Connection + err := c.doJSON(ctx, http.MethodGet, vpnConnectionsEndpoint+"?active=true", nil, &conns) + return conns, err +} + +// RunOfflineURLTests runs URL performance tests when offline (VPN disconnected) and caches the +// results. This enables autoconnect to select the best server for the initial connection. +func (c *Client) RunOfflineURLTests(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, vpnOfflineTestsEndpoint, nil) + return err +} + +/////////////////////// +// Server selection // +/////////////////////// + +var boxCtx = box.BaseContext() + +// SelectServer selects the server with the given tag. +func (c *Client) SelectServer(ctx context.Context, tag string) error { + _, err := c.do(ctx, http.MethodPost, serverSelectedEndpoint, TagRequest{Tag: tag}) + return err +} + +// SelectedServer returns the currently selected server and whether it still exists. +func (c *Client) SelectedServer(ctx context.Context) (*servers.Server, bool, error) { + data, err := c.do(ctx, http.MethodGet, serverSelectedEndpoint, nil) + if err != nil { + return nil, false, err + } + resp, err := sjson.UnmarshalExtendedContext[SelectedServerResponse](boxCtx, data) + return resp.Server, resp.Exists, err +} + +// SelectedServerJSON returns the currently selected server as raw JSON bytes. +func (c *Client) SelectedServerJSON(ctx context.Context) ([]byte, error) { + return c.do(ctx, http.MethodGet, serverSelectedEndpoint, nil) +} + +// AutoSelected returns the server that's currently auto-selected. +func (c *Client) AutoSelected(ctx context.Context) (*servers.Server, error) { + data, err := c.do(ctx, http.MethodGet, serverAutoSelectedEndpoint, nil) + if err != nil { + return nil, err + } + return sjson.UnmarshalExtendedContext[*servers.Server](boxCtx, data) +} + +//////////// +// Config // +//////////// + +// UpdateConfig forces an immediate config fetch on the daemon. Returns an error +// if config fetching is disabled. +func (c *Client) UpdateConfig(ctx context.Context) error { + _, err := c.do(ctx, http.MethodPost, configUpdateEndpoint, nil) + return err +} + +/////////////////////// +// Server management // +/////////////////////// + +// Servers returns all servers. +func (c *Client) Servers(ctx context.Context) ([]*servers.Server, error) { + data, err := c.do(ctx, http.MethodGet, serversEndpoint, nil) + if err != nil { + return nil, err + } + return sjson.UnmarshalExtendedContext[[]*servers.Server](boxCtx, data) +} + +// ServersJSON returns all servers as raw JSON bytes. +// This is useful when the caller needs to forward the JSON without re-marshaling, +// since the server options require sing-box's context-aware JSON encoder. +func (c *Client) ServersJSON(ctx context.Context) ([]byte, error) { + return c.do(ctx, http.MethodGet, serversEndpoint, nil) +} + +// GetServerByTag returns the server with the given tag. +func (c *Client) GetServerByTag(ctx context.Context, tag string) (*servers.Server, bool, error) { + q := url.Values{"tag": {tag}} + data, err := c.do(ctx, http.MethodGet, serversEndpoint+"?"+q.Encode(), nil) + if err != nil { + if IsNotFound(err) { + return nil, false, nil + } + return nil, false, err + } + server, err := sjson.UnmarshalExtendedContext[*servers.Server](boxCtx, data) + if err != nil { + return nil, false, err + } + return server, true, nil +} + +// GetServerByTagJSON returns the server with the given tag as raw JSON bytes. +func (c *Client) GetServerByTagJSON(ctx context.Context, tag string) ([]byte, bool, error) { + q := url.Values{"tag": {tag}} + data, err := c.do(ctx, http.MethodGet, serversEndpoint+"?"+q.Encode(), nil) + if err != nil { + if IsNotFound(err) { + return nil, false, nil + } + return nil, false, err + } + return data, true, nil +} + +// AddServers adds servers. +func (c *Client) AddServers(ctx context.Context, list servers.ServerList) error { + req := AddServersRequest{Servers: list} + body, err := sjson.MarshalContext(boxCtx, req) + if err != nil { + return fmt.Errorf("marshal add servers request: %w", err) + } + _, err = c.do(ctx, http.MethodPost, serversAddEndpoint, body) + return err +} + +// RemoveServers removes servers by tag from the given group. +func (c *Client) RemoveServers(ctx context.Context, tags []string) error { + _, err := c.do(ctx, http.MethodPost, serversRemoveEndpoint, RemoveServersRequest{Tags: tags}) + return err +} + +// AddServersByJSON adds servers from a JSON configuration string and returns the tags of the added servers. +func (c *Client) AddServersByJSON(ctx context.Context, config string) ([]string, error) { + data, err := c.do(ctx, http.MethodPost, serversFromJSONEndpoint, JSONConfigRequest{Config: config}) + if err != nil { + return nil, err + } + var tags []string + if err := json.Unmarshal(data, &tags); err != nil { + return nil, err + } + return tags, nil +} + +// AddServersByURL adds servers from the given URLs and returns the tags of the added servers. +func (c *Client) AddServersByURL(ctx context.Context, urls []string, skipCertVerification bool) ([]string, error) { + data, err := c.do(ctx, http.MethodPost, serversFromURLsEndpoint, URLsRequest{URLs: urls, SkipCertVerification: skipCertVerification}) + if err != nil { + return nil, err + } + var tags []string + if err := json.Unmarshal(data, &tags); err != nil { + return nil, err + } + return tags, nil +} + +// AddPrivateServer adds a private server. +func (c *Client) AddPrivateServer(ctx context.Context, tag, ip string, port int, accessToken string) error { + _, err := c.do(ctx, http.MethodPost, serversPrivateEndpoint, PrivateServerRequest{Tag: tag, IP: ip, Port: port, AccessToken: accessToken}) + return err +} + +// InviteToPrivateServer creates an invite for a private server and returns the invite code. +func (c *Client) InviteToPrivateServer(ctx context.Context, ip string, port int, accessToken, inviteName string) (string, error) { + var resp CodeResponse + err := c.doJSON(ctx, http.MethodPost, serversPrivateInviteEndpoint, + PrivateServerInviteRequest{IP: ip, Port: port, AccessToken: accessToken, InviteName: inviteName}, &resp) + return resp.Code, err +} + +// RevokePrivateServerInvite revokes an invite for a private server. +func (c *Client) RevokePrivateServerInvite(ctx context.Context, ip string, port int, accessToken, inviteName string) error { + _, err := c.do(ctx, http.MethodDelete, serversPrivateInviteEndpoint, + PrivateServerInviteRequest{IP: ip, Port: port, AccessToken: accessToken, InviteName: inviteName}) + return err +} + +////////////// +// Settings // +////////////// + +func (c *Client) Features(ctx context.Context) (map[string]bool, error) { + var features map[string]bool + err := c.doJSON(ctx, http.MethodGet, featuresEndpoint, nil, &features) + return features, err +} + +// Settings returns the current settings as a map of key-value pairs. +func (c *Client) Settings(ctx context.Context) (settings.Settings, error) { + var s settings.Settings + err := c.doJSON(ctx, http.MethodGet, settingsEndpoint, nil, &s) + return s, err +} + +// PatchSettings updates settings with the given key-value pairs and returns the full updates settings. +func (c *Client) PatchSettings(ctx context.Context, updates settings.Settings) (settings.Settings, error) { + var s settings.Settings + err := c.doJSON(ctx, http.MethodPatch, settingsEndpoint, updates, &s) + return s, err +} + +func (c *Client) EnableTelemetry(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.TelemetryKey: enable}) + return err +} + +func (c *Client) EnableSplitTunneling(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.SplitTunnelKey: enable}) + return err +} + +func (c *Client) EnableSmartRouting(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.SmartRoutingKey: enable}) + return err +} + +func (c *Client) EnableAdBlocking(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.AdBlockKey: enable}) + return err +} + +// EnableConfigFetch toggles periodic config fetching. Passing false sets +// settings.ConfigFetchDisabledKey to true on the daemon. +func (c *Client) EnableConfigFetch(ctx context.Context, enable bool) error { + _, err := c.PatchSettings(ctx, settings.Settings{settings.ConfigFetchDisabledKey: !enable}) + return err +} + +// SetLogLevel sets the daemon's log level. Valid values: trace, debug, info, +// warn, error, fatal, panic, disable. +func (c *Client) SetLogLevel(ctx context.Context, level string) error { + if _, err := rlog.ParseLogLevel(level); err != nil { + return err + } + _, err := c.PatchSettings(ctx, settings.Settings{settings.LogLevelKey: level}) + return err +} + +///////// +// Env // +///////// + +// PatchEnvVars updates the daemon's in-memory environment variables. +// This is intended for dev/testing use only. +func (c *Client) PatchEnvVars(ctx context.Context, updates map[string]string) (map[string]string, error) { + var result map[string]string + err := c.doJSON(ctx, http.MethodPatch, envEndpoint, updates, &result) + return result, err +} + +////////////////// +// Split Tunnel // +///////////////// + +// SplitTunnelFilters returns the current split tunnel configuration. +func (c *Client) SplitTunnelFilters(ctx context.Context) (vpn.SplitTunnelFilter, error) { + var filter vpn.SplitTunnelFilter + err := c.doJSON(ctx, http.MethodGet, splitTunnelEndpoint, nil, &filter) + return filter, err +} + +// AddSplitTunnelItems adds items to the split tunnel filter. +func (c *Client) AddSplitTunnelItems(ctx context.Context, items vpn.SplitTunnelFilter) error { + _, err := c.do(ctx, http.MethodPost, splitTunnelEndpoint, items) + return err +} + +// RemoveSplitTunnelItems removes items from the split tunnel filter. +func (c *Client) RemoveSplitTunnelItems(ctx context.Context, items vpn.SplitTunnelFilter) error { + _, err := c.do(ctx, http.MethodDelete, splitTunnelEndpoint, items) + return err +} + +///////////// +// Account // +///////////// + +// NewUser creates a new anonymous user. +func (c *Client) NewUser(ctx context.Context) (*account.UserData, error) { + var userData account.UserData + if err := c.doJSON(ctx, http.MethodPost, accountNewUserEndpoint, nil, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// Login authenticates the user with email and password. +func (c *Client) Login(ctx context.Context, email, password string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodPost, accountLoginEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// Logout logs the user out. +func (c *Client) Logout(ctx context.Context, email string) (*account.UserData, error) { + var userData account.UserData + if err := c.doJSON(ctx, http.MethodPost, accountLogoutEndpoint, EmailRequest{Email: email}, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// FetchUserData fetches fresh user data from the remote server. +func (c *Client) FetchUserData(ctx context.Context) (*account.UserData, error) { + return c.userData(ctx, true) +} + +// UserData returns locally cached user data. +func (c *Client) UserData(ctx context.Context) (*account.UserData, error) { + return c.userData(ctx, false) +} + +func (c *Client) userData(ctx context.Context, fetch bool) (*account.UserData, error) { + var userData account.UserData + url := fmt.Sprintf("%s?fetch=%v", accountUserDataEndpoint, fetch) + if err := c.doJSON(ctx, http.MethodGet, url, nil, &userData); err != nil { + return nil, err + } + return &userData, nil +} + +// UserDevices returns the list of devices linked to the user's account. +func (c *Client) UserDevices(ctx context.Context) ([]settings.Device, error) { + var devices []settings.Device + err := c.doJSON(ctx, http.MethodGet, accountDevicesEndpoint, nil, &devices) + return devices, err +} + +// RemoveDevice removes a device from the user's account. +func (c *Client) RemoveDevice(ctx context.Context, deviceID string) (*account.LinkResponse, error) { + var resp account.LinkResponse + if err := c.doJSON(ctx, http.MethodDelete, accountDevicesEndpoint+url.PathEscape(deviceID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// SignUp creates a new account with the given email and password. +func (c *Client) SignUp(ctx context.Context, email, password string) ([]byte, *account.SignupResponse, error) { + var resp SignupResponse + err := c.doJSON( + ctx, http.MethodPost, accountSignupEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &resp, + ) + if err != nil { + return nil, nil, err + } + return resp.Salt, resp.Response, nil +} + +// SignupEmailConfirmation confirms the signup email with the given code. +func (c *Client) SignupEmailConfirmation(ctx context.Context, email, code string) error { + _, err := c.do(ctx, http.MethodPost, accountSignupEndpoint+"confirm", EmailCodeRequest{Email: email, Code: code}) + return err +} + +// SignupEmailResendCode requests a resend of the signup confirmation email. +func (c *Client) SignupEmailResendCode(ctx context.Context, email string) error { + _, err := c.do(ctx, http.MethodPost, accountSignupEndpoint+"resend", EmailRequest{Email: email}) + return err +} + +// StartChangeEmail initiates an email address change. +func (c *Client) StartChangeEmail(ctx context.Context, newEmail, password string) error { + _, err := c.do(ctx, http.MethodPost, accountEmailEndpoint+"/start", ChangeEmailStartRequest{NewEmail: newEmail, Password: password}) + return err +} + +// CompleteChangeEmail completes an email address change. +func (c *Client) CompleteChangeEmail(ctx context.Context, newEmail, password, code string) error { + _, err := c.do(ctx, http.MethodPost, accountEmailEndpoint+"/complete", + ChangeEmailCompleteRequest{NewEmail: newEmail, Password: password, Code: code}) + return err +} + +// StartRecoveryByEmail initiates account recovery by email. +func (c *Client) StartRecoveryByEmail(ctx context.Context, email string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/start", EmailRequest{Email: email}) + return err +} + +// CompleteRecoveryByEmail completes account recovery with a new password and code. +func (c *Client) CompleteRecoveryByEmail(ctx context.Context, email, newPassword, code string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/complete", + RecoveryCompleteRequest{Email: email, NewPassword: newPassword, Code: code}) + return err +} + +// ValidateEmailRecoveryCode validates the recovery code without completing the recovery. +func (c *Client) ValidateEmailRecoveryCode(ctx context.Context, email, code string) error { + _, err := c.do(ctx, http.MethodPost, accountRecoveryEndpoint+"/validate", EmailCodeRequest{Email: email, Code: code}) + return err +} + +// DeleteAccount deletes the user's account. +func (c *Client) DeleteAccount(ctx context.Context, email, password string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodDelete, accountDeleteEndpoint, + EmailPasswordRequest{Email: email, Password: password}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// OAuthLoginURL returns the OAuth login URL for the given provider. +func (c *Client) OAuthLoginURL(ctx context.Context, provider string) (string, error) { + var resp URLResponse + q := url.Values{"provider": {provider}} + err := c.doJSON(ctx, http.MethodGet, accountOAuthEndpoint+"?"+q.Encode(), nil, &resp) + return resp.URL, err +} + +// OAuthLoginCallback exchanges an OAuth token for user data. +func (c *Client) OAuthLoginCallback(ctx context.Context, oAuthToken string) (*account.UserData, error) { + var userData account.UserData + err := c.doJSON(ctx, http.MethodPost, accountOAuthEndpoint, + OAuthTokenRequest{OAuthToken: oAuthToken}, &userData) + if err != nil { + return nil, err + } + return &userData, nil +} + +// DataCapInfo returns the current data cap information as a JSON string. +func (c *Client) DataCapInfo(ctx context.Context) (*account.DataCapInfo, error) { + var resp account.DataCapInfo + err := c.doJSON(ctx, http.MethodGet, accountDataCapEndpoint, nil, &resp) + return &resp, err +} + +/////////////////// +// Subscriptions // +/////////////////// + +// ActivationCode purchases a subscription using a reseller code. +func (c *Client) ActivationCode(ctx context.Context, email, resellerCode string) (*account.PurchaseResponse, error) { + var resp account.PurchaseResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionActivationEndpoint, + ActivationRequest{Email: email, ResellerCode: resellerCode}, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +// NewStripeSubscription creates a new Stripe subscription and returns the client secret. +func (c *Client) NewStripeSubscription(ctx context.Context, email, planID string) (string, error) { + var resp ClientSecretResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionStripeEndpoint, + StripeSubscriptionRequest{Email: email, PlanID: planID}, &resp) + return resp.ClientSecret, err +} + +// ReferralAttach attaches a referral code to the current user. +func (c *Client) ReferralAttach(ctx context.Context, code string) (bool, error) { + var resp SuccessResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionReferralEndpoint, CodeRequest{Code: code}, &resp) + return resp.Success, err +} + +// StripeBillingPortalURL returns the Stripe billing portal URL. +func (c *Client) StripeBillingPortalURL(ctx context.Context) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodGet, subscriptionBillingPortalEndpoint, nil, &resp) + return resp.URL, err +} + +// PaymentRedirect returns a payment redirect URL. +func (c *Client) PaymentRedirect(ctx context.Context, data account.PaymentRedirectData) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionPaymentRedirectEndpoint, data, &resp) + return resp.URL, err +} + +// SubscriptionPaymentRedirectURL returns a subscription payment redirect URL. +func (c *Client) SubscriptionPaymentRedirectURL(ctx context.Context, data account.PaymentRedirectData) (string, error) { + var resp URLResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionPaymentRedirectURLEndpoint, data, &resp) + return resp.URL, err +} + +// SubscriptionPlans returns available subscription plans for the given channel. +func (c *Client) SubscriptionPlans(ctx context.Context, channel string) (string, error) { + var resp PlansResponse + q := url.Values{"channel": {channel}} + err := c.doJSON(ctx, http.MethodGet, subscriptionPlansEndpoint+"?"+q.Encode(), nil, &resp) + return resp.Plans, err +} + +// VerifySubscription verifies a subscription purchase. +func (c *Client) VerifySubscription(ctx context.Context, service account.SubscriptionService, data map[string]string) (string, error) { + var resp ResultResponse + err := c.doJSON(ctx, http.MethodPost, subscriptionVerifyEndpoint, + VerifySubscriptionRequest{Service: service, Data: data}, &resp) + return resp.Result, err +} + +/////////// +// Issue // +/////////// + +// ReportIssue submits an issue report. additionalAttachments is a list of file paths for additional +// files to include. Logs, diagnostics, and the config response are included automatically and do +// not need to be specified. +func (c *Client) ReportIssue(ctx context.Context, issueType issue.IssueType, description, email string, additionalAttachments []string) error { + _, err := c.do(ctx, http.MethodPost, issueEndpoint, + IssueReportRequest{IssueType: issueType, Description: description, Email: email, AdditionalAttachments: additionalAttachments}) + return err +} + +///////////// +// streams // +///////////// + +// sseRetryLoop runs sseStream in a retry loop until ctx is cancelled. +func (c *Client) sseRetryLoop(ctx context.Context, endpoint string, handler func([]byte)) error { + bo := common.NewBackoff(30 * time.Second) + for ctx.Err() == nil { + err := c.sseStream(ctx, endpoint, handler) + if ctx.Err() != nil { + return ctx.Err() + } + // silently ignore IPC not running errors, since they are expected when the daemon is not running. + // prevent spamming the logs with errors until the daemon starts. + if err != nil && !errors.Is(err, ErrIPCNotRunning) { + slog.Warn("SSE stream ended, retrying", "endpoint", endpoint, "error", err) + } + bo.Wait(ctx) + } + return ctx.Err() +} + +// dataCapStream runs the data-cap SSE stream only while the VPN is connected. Blocks until ctx +// is cancelled. +func (c *Client) dataCapStream(ctx context.Context, handler func(account.DataCapInfo)) error { + var ( + mu sync.Mutex + cancelFn context.CancelFunc + ) + + decode := func(data []byte) { + var info account.DataCapInfo + if err := json.Unmarshal(data, &info); err == nil { + handler(info) + } + } + + start := func() { + mu.Lock() + defer mu.Unlock() + if cancelFn != nil { + return + } + subCtx, cancel := context.WithCancel(ctx) + cancelFn = cancel + go c.sseRetryLoop(subCtx, accountDataCapStreamEndpoint, decode) + } + + stop := func() { + mu.Lock() + defer mu.Unlock() + if cancelFn != nil { + cancelFn() + cancelFn = nil + } + } + defer stop() + + // check if VPN is already connected before starting the stream, otherwise we might miss the + // "connected" event that triggers the stream start + if status, err := c.VPNStatus(ctx); err == nil && status == vpn.Connected { + start() + } + + return c.VPNStatusEvents(ctx, func(evt vpn.StatusUpdateEvent) { + if evt.Status == vpn.Connected { + start() + } else { + stop() + } + }) +} + +///////////// +// helpers // +///////////// + +// isConnectionError reports whether err indicates that the IPC socket is unreachable +// (e.g. connection refused or socket file not found). +func isConnectionError(err error) bool { + var opErr *net.OpError + if errors.As(err, &opErr) { + // connection refused (server not listening) + if errors.Is(opErr.Err, syscall.ECONNREFUSED) { + return true + } + // socket file does not exist (server never started / was cleaned up) + if errors.Is(opErr.Err, syscall.ENOENT) { + return true + } + // check wrapped syscall errors + var sysErr *os.SyscallError + if errors.As(opErr.Err, &sysErr) { + return errors.Is(sysErr.Err, syscall.ECONNREFUSED) || errors.Is(sysErr.Err, syscall.ENOENT) + } + } + // Also check the unwrapped error directly for cases where the wrapping differs by platform + return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) +} diff --git a/ipc/client_events_mobile.go b/ipc/client_events_mobile.go new file mode 100644 index 00000000..d03fca66 --- /dev/null +++ b/ipc/client_events_mobile.go @@ -0,0 +1,62 @@ +//go:build android || ios || (darwin && !standalone) + +package ipc + +import ( + "context" + "encoding/json" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/vpn" +) + +// AutoSelectedEvents streams auto-selection changes. Blocks until ctx is cancelled. +func (c *Client) AutoSelectedEvents(ctx context.Context, handler func(vpn.AutoSelectedEvent)) error { + events.SubscribeContext(ctx, handler) + if c.localOnly { + <-ctx.Done() + return ctx.Err() + } + return c.sseRetryLoop(ctx, serverAutoSelectedEventsEndpoint, func(data []byte) { + var evt vpn.AutoSelectedEvent + if err := json.Unmarshal(data, &evt); err == nil { + handler(evt) + } + }) +} + +// ConfigEvents streams config-updated notifications. Payloads are empty — callers should treat each +// call as a "refresh" signal. Blocks until ctx is cancelled. +func (c *Client) ConfigEvents(ctx context.Context, handler func()) error { + events.SubscribeContext(ctx, func(config.NewConfigEvent) { handler() }) + if c.localOnly { + <-ctx.Done() + return ctx.Err() + } + return c.sseRetryLoop(ctx, configEventsEndpoint, func([]byte) { handler() }) +} + +// VPNStatusEvents streams VPN status changes. Blocks until ctx is cancelled. +func (c *Client) VPNStatusEvents(ctx context.Context, handler func(vpn.StatusUpdateEvent)) error { + if c.localOnly { + <-ctx.Done() + return ctx.Err() + } + return c.sseRetryLoop(ctx, vpnStatusEventsEndpoint, func(data []byte) { + var evt vpn.StatusUpdateEvent + if err := json.Unmarshal(data, &evt); err == nil { + handler(evt) + } + }) +} + +// DataCapStream streams data-cap updates while the VPN is connected. Blocks until ctx is cancelled. +func (c *Client) DataCapStream(ctx context.Context, handler func(account.DataCapInfo)) error { + if c.localOnly { + <-ctx.Done() + return ctx.Err() + } + return c.dataCapStream(ctx, handler) +} diff --git a/ipc/client_events_nonmobile.go b/ipc/client_events_nonmobile.go new file mode 100644 index 00000000..16d3184e --- /dev/null +++ b/ipc/client_events_nonmobile.go @@ -0,0 +1,42 @@ +//go:build (!android && !ios && !darwin) || (darwin && standalone) + +package ipc + +import ( + "context" + "encoding/json" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/vpn" +) + +// AutoSelectedEvents streams auto-selection changes. Blocks until ctx is cancelled. +func (c *Client) AutoSelectedEvents(ctx context.Context, handler func(vpn.AutoSelectedEvent)) error { + return c.sseRetryLoop(ctx, serverAutoSelectedEventsEndpoint, func(data []byte) { + var evt vpn.AutoSelectedEvent + if err := json.Unmarshal(data, &evt); err == nil { + handler(evt) + } + }) +} + +// ConfigEvents streams config-updated notifications. Payloads are empty — callers should treat each +// call as a "refresh" signal. Blocks until ctx is cancelled. +func (c *Client) ConfigEvents(ctx context.Context, handler func()) error { + return c.sseRetryLoop(ctx, configEventsEndpoint, func([]byte) { handler() }) +} + +// VPNStatusEvents streams VPN status changes. Blocks until ctx is cancelled. +func (c *Client) VPNStatusEvents(ctx context.Context, handler func(vpn.StatusUpdateEvent)) error { + return c.sseRetryLoop(ctx, vpnStatusEventsEndpoint, func(data []byte) { + var evt vpn.StatusUpdateEvent + if err := json.Unmarshal(data, &evt); err == nil { + handler(evt) + } + }) +} + +// DataCapStream streams data-cap updates while the VPN is connected. Blocks until ctx is cancelled. +func (c *Client) DataCapStream(ctx context.Context, handler func(account.DataCapInfo)) error { + return c.dataCapStream(ctx, handler) +} diff --git a/ipc/client_mobile.go b/ipc/client_mobile.go new file mode 100644 index 00000000..d477cc1f --- /dev/null +++ b/ipc/client_mobile.go @@ -0,0 +1,241 @@ +//go:build android || ios || (darwin && !standalone) + +package ipc + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "time" + + "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/common/settings" + rlog "github.com/getlantern/radiance/log" +) + +type Client struct { + http *http.Client + localapi *localapi + localOnly bool // when true, serve all requests in-process; never attempt the IPC socket + mu sync.RWMutex +} + +func NewClient(ctx context.Context, opts backend.Options) (*Client, error) { + b, err := backend.NewLocalBackend(ctx, opts) + if err != nil { + return nil, fmt.Errorf("create local backend: %w", err) + } + b.Start() + c := newClient() + c.localapi = newLocalAPI(b, false) + return c, nil +} + +// NewLoopbackClient creates a Client that serves all requests in-process +// through the given LocalBackend without attempting IPC socket connections. +// The backend is NOT owned by this client — Close will not shut it down. +func NewLoopbackClient(b *backend.LocalBackend) *Client { + c := newClient() + c.localapi = newLocalAPI(b, false) + c.localOnly = true + return c +} + +// Close releases resources held by the client, including any local backend. +func (c *Client) Close() { + if c.localOnly { + return + } + c.stopLocal() + c.http.CloseIdleConnections() +} + +func (c *Client) stopLocal() { + c.mu.Lock() + defer c.mu.Unlock() + if be := c.localapi.setBackend(nil); be != nil { + be.Close() + } +} + +// do executes an HTTP request with an optional JSON body and returns the raw response body. If +// body needs to be marshaled using sing/json, it should be pre-marshaled to []byte before passing +// to do. do returns an error if the response status is >= 400. +func (c *Client) do(ctx context.Context, method, endpoint string, body any) ([]byte, error) { + bodyReader, err := marshalBody(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + if c.localOnly { + return c.doLocal(req) + } + + resp, err := c.http.Do(req) + if err != nil { + if isConnectionError(err) { + c.mu.Lock() + defer c.mu.Unlock() + if be := c.localapi.be.Load(); be == nil { + opts := backend.Options{ + DataDir: settings.GetString(settings.DataPathKey), + LogDir: settings.GetString(settings.LogPathKey), + Locale: settings.GetString(settings.LocaleKey), + DeviceID: settings.GetString(settings.DeviceIDKey), + LogLevel: settings.GetString(settings.LogLevelKey), + TelemetryConsent: settings.GetBool(settings.TelemetryKey), + } + be, err = backend.NewLocalBackend(ctx, opts) + if err != nil { + return nil, fmt.Errorf("create local backend: %w", err) + } + c.localapi.setBackend(be) + } + if br, ok := bodyReader.(*bytes.Reader); ok { + br.Seek(0, io.SeekStart) + } + req, _ = http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + return c.doLocal(req) + } + return nil, fmt.Errorf("ipc request %s %s: %w", method, endpoint, err) + } + c.stopLocal() // IPC is reachable; shut down local backend if still running + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, &Error{ + Status: resp.StatusCode, + Message: strings.TrimSpace(string(respBody)), + } + } + return respBody, nil +} + +// doLocal serves the request through the given in-process handler. +func (c *Client) doLocal(req *http.Request) ([]byte, error) { + rec := httptest.NewRecorder() + c.localapi.ServeHTTP(rec, req) + + body := rec.Body.Bytes() + if rec.Code >= 400 { + return nil, &Error{ + Status: rec.Code, + Message: strings.TrimSpace(string(body)), + } + } + return body, nil +} + +// TailLogs connects to the log stream endpoint and calls handler for each log +// entry received until ctx is cancelled or the connection is closed. +func (c *Client) TailLogs(ctx context.Context, handler func(rlog.LogEntry)) error { + merged := make(chan rlog.LogEntry, 64) + + // Always tail local logs. + localCh, unsub := rlog.Subscribe() + defer unsub() + go func() { + for { + select { + case entry := <-localCh: + select { + case merged <- entry: + default: + } + case <-ctx.Done(): + return + } + } + }() + + // Tail server logs whenever the IPC server is reachable. + go func() { + for ctx.Err() == nil { + c.sseStream(ctx, logsStreamEndpoint, func(data []byte) { + select { + case merged <- string(data): + default: + } + }) + // Server unavailable or disconnected; wait before retrying. + select { + case <-time.After(500 * time.Millisecond): + case <-ctx.Done(): + return + } + } + }() + + for { + select { + case entry := <-merged: + handler(entry) + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// sseStream connects to an SSE endpoint and calls handler for each event data line. +// Blocks until ctx is cancelled or the connection is closed. +func (c *Client) sseStream(ctx context.Context, endpoint string, handler func([]byte)) error { + if c.localOnly { + return ErrIPCNotRunning + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL+endpoint, nil) + if err != nil { + return fmt.Errorf("create SSE request: %w", err) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := c.http.Do(req) + if err != nil { + c.mu.RLock() + hasFallback := c.localapi != nil + c.mu.RUnlock() + if hasFallback && isConnectionError(err) { + return ErrIPCNotRunning + } + return fmt.Errorf("SSE connect %s: %w", endpoint, err) + } + c.stopLocal() // IPC is reachable; shut down local backend if still running + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &Error{Status: resp.StatusCode, Message: strings.TrimSpace(string(body))} + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if data, ok := strings.CutPrefix(line, "data: "); ok { + handler([]byte(data)) + } + } + if err := scanner.Err(); err != nil && ctx.Err() == nil { + return fmt.Errorf("SSE %s: read: %w", endpoint, err) + } + return nil +} diff --git a/ipc/client_nonmobile.go b/ipc/client_nonmobile.go new file mode 100644 index 00000000..93e1add7 --- /dev/null +++ b/ipc/client_nonmobile.go @@ -0,0 +1,112 @@ +//go:build (!android && !ios && !darwin) || (darwin && standalone) + +package ipc + +import ( + "bufio" + "context" + "fmt" + "io" + "net/http" + "strings" + + rlog "github.com/getlantern/radiance/log" +) + +// Client communicates with the IPC server over a local socket. +type Client struct { + http *http.Client +} + +// NewClient creates a new IPC client that communicates exclusively through the IPC server. +func NewClient() *Client { + return newClient() +} + +// Close releases resources held by the client, including any local backend. +func (c *Client) Close() { + c.http.CloseIdleConnections() +} + +// do executes an HTTP request with an optional JSON body and returns the raw response body. If +// body needs to be marshaled using sing/json, it should be pre-marshaled to []byte before passing +// to do. do returns an error if the response status is >= 400. +func (c *Client) do(ctx context.Context, method, endpoint string, body any) ([]byte, error) { + bodyReader, err := marshalBody(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("ipc request %s %s: %w", method, endpoint, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, &Error{ + Status: resp.StatusCode, + Message: strings.TrimSpace(string(respBody)), + } + } + return respBody, nil +} + +// TailLogs connects to the log stream endpoint and calls handler for each log +// entry received until ctx is cancelled or the connection is closed. +func (c *Client) TailLogs(ctx context.Context, handler func(rlog.LogEntry)) error { + return c.sseStream(ctx, logsStreamEndpoint, func(data []byte) { + handler(string(data)) + }) +} + +// sseStream connects to an SSE endpoint and calls handler for each event data line. +// Blocks until ctx is cancelled or the connection is closed. +func (c *Client) sseStream(ctx context.Context, endpoint string, handler func([]byte)) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL+endpoint, nil) + if err != nil { + return fmt.Errorf("create SSE request: %w", err) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := c.http.Do(req) + if err != nil { + if isConnectionError(err) { + return ErrIPCNotRunning + } + return fmt.Errorf("SSE connect %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &Error{Status: resp.StatusCode, Message: strings.TrimSpace(string(body))} + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if data, ok := strings.CutPrefix(line, "data: "); ok { + data = strings.TrimSpace(data) + if data != "" { + handler([]byte(data)) + } + } + } + if err := scanner.Err(); err != nil && ctx.Err() == nil { + return fmt.Errorf("SSE %s: read: %w", endpoint, err) + } + return nil +} diff --git a/vpn/ipc/conn_nonwindows.go b/ipc/conn_nonwindows.go similarity index 91% rename from vpn/ipc/conn_nonwindows.go rename to ipc/conn_nonwindows.go index 76266fd8..6aee1c41 100644 --- a/vpn/ipc/conn_nonwindows.go +++ b/ipc/conn_nonwindows.go @@ -15,11 +15,9 @@ import ( const apiURL = "http://lantern" -func dialContext(_ context.Context, _, _ string) (net.Conn, error) { - return net.DialUnix("unix", nil, &net.UnixAddr{ - Name: socketPath(), - Net: "unix", - }) +func dialContext(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath()) } type sockListener struct { diff --git a/vpn/ipc/conn_windows.go b/ipc/conn_windows.go similarity index 100% rename from vpn/ipc/conn_windows.go rename to ipc/conn_windows.go diff --git a/vpn/ipc/middlewares.go b/ipc/middlewares.go similarity index 58% rename from vpn/ipc/middlewares.go rename to ipc/middlewares.go index e7e20d7f..74595845 100644 --- a/vpn/ipc/middlewares.go +++ b/ipc/middlewares.go @@ -6,17 +6,16 @@ import ( "log/slog" "net/http" - "github.com/go-chi/chi/v5/middleware" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" semconv "github.com/getlantern/semconv" "go.opentelemetry.io/otel/trace" - "github.com/getlantern/radiance/internal" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" ) -func log(h http.Handler) http.Handler { +func logger(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Pull the trace ID from the request, if it exists. ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) @@ -24,7 +23,7 @@ func log(h http.Handler) http.Handler { span := trace.SpanFromContext(r.Context()) span.SetAttributes(semconv.HTTPRouteKey.String(r.URL.Path)) - slog.Log(r.Context(), internal.LevelTrace, "IPC request", "method", r.Method, "path", r.URL.Path) + slog.Log(r.Context(), rlog.LevelTrace, "IPC request", "method", r.Method, "path", r.URL.Path) h.ServeHTTP(w, r) }) } @@ -36,15 +35,41 @@ func tracer(next http.Handler) http.Handler { r = r.WithContext(ctx) var buf bytes.Buffer - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - ww.Tee(&buf) + ww := &statusRecorder{ResponseWriter: w, body: &buf} next.ServeHTTP(ww, r) - if ww.Status() >= 400 { - traces.RecordError(ctx, fmt.Errorf("status %d: %s", ww.Status(), buf.String())) + if ww.status >= 400 { + traces.RecordError(ctx, fmt.Errorf("status %d: %s", ww.status, buf.String())) } }) } +// statusRecorder wraps http.ResponseWriter to capture the status code and response body. +type statusRecorder struct { + http.ResponseWriter + status int + body *bytes.Buffer +} + +func (r *statusRecorder) WriteHeader(code int) { + r.status = code + r.ResponseWriter.WriteHeader(code) +} + +func (r *statusRecorder) Write(b []byte) (int, error) { + if r.status == 0 { + r.status = http.StatusOK + } + r.body.Write(b) + return r.ResponseWriter.Write(b) +} + +// Flush implements http.Flusher if the underlying ResponseWriter supports it. +func (r *statusRecorder) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + func authPeer(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { peer := usrFromContext(r.Context()) diff --git a/ipc/outbound_test.go b/ipc/outbound_test.go new file mode 100644 index 00000000..d366ac2b --- /dev/null +++ b/ipc/outbound_test.go @@ -0,0 +1,74 @@ +package ipc + +import ( + "testing" + + box "github.com/getlantern/lantern-box" + LO "github.com/getlantern/lantern-box/option" + O "github.com/sagernet/sing-box/option" + singjson "github.com/sagernet/sing/common/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getlantern/radiance/servers" +) + +// TestSamizdatOptionsRoundTrip verifies that samizdat outbound options +// (specifically public_key) survive JSON serialization/deserialization +// through the IPC path. This was the root cause of the "public_key must +// be 64 hex characters (32 bytes), got len=0" bug -- standard encoding/json +// doesn't preserve typed Options on option.Outbound's any interface. +func TestSamizdatOptionsRoundTrip(t *testing.T) { + const testPubKey = "20ebb18d5fdf9bff27fe32ef9501035d8f0bb8dfb481a0a2363181560e0e8115" + const testShortID = "3b1a8fc7f1edf914" + + outbound := O.Outbound{ + Type: "samizdat", + Tag: "samizdat-out-test-route", + Options: &LO.SamizdatOutboundOptions{ + ServerOptions: O.ServerOptions{ + Server: "1.2.3.4", + ServerPort: 443, + }, + PublicKey: testPubKey, + ShortID: testShortID, + ServerName: "example.com", + }, + } + + // Verify that outbound options survive round-trip through sing-box context JSON + // when used within a ServerList (the transfer type for IPC). + t.Run("singbox_json_preserves_public_key_in_serverlist", func(t *testing.T) { + ctx := box.BaseContext() + + list := servers.ServerList{ + Servers: []*servers.Server{ + { + Tag: outbound.Tag, + Type: outbound.Type, + IsLantern: true, + Options: outbound, + }, + }, + } + + buf, err := singjson.MarshalContext(ctx, list) + require.NoError(t, err) + + // Verify public_key is in the serialized JSON + assert.Contains(t, string(buf), testPubKey, "serialized JSON should contain public_key") + + decoded, err := singjson.UnmarshalExtendedContext[servers.ServerList](ctx, buf) + require.NoError(t, err) + + require.Len(t, decoded.Servers, 1) + outOpts, ok := decoded.Servers[0].Options.(O.Outbound) + require.True(t, ok, "sing-box json should preserve typed Outbound Options") + + samOpts, ok := outOpts.Options.(*LO.SamizdatOutboundOptions) + require.True(t, ok, "sing-box json should preserve typed SamizdatOutboundOptions") + assert.Equal(t, testPubKey, samOpts.PublicKey, "public_key should survive round-trip") + assert.Equal(t, testShortID, samOpts.ShortID, "short_id should survive round-trip") + assert.Equal(t, "example.com", samOpts.ServerName, "server_name should survive round-trip") + }) +} diff --git a/ipc/server.go b/ipc/server.go new file mode 100644 index 00000000..eac08cf5 --- /dev/null +++ b/ipc/server.go @@ -0,0 +1,1110 @@ +// Package ipc implements the IPC server for communicating between the client and the VPN service. +// It provides HTTP endpoints for retrieving statistics, managing groups, selecting outbounds, +// changing modes, and closing connections. +package ipc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/backend" + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/events" + rlog "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/vpn" + + sjson "github.com/sagernet/sing/common/json" +) + +const ( + tracerName = "github.com/getlantern/radiance/ipc" + + // VPN endpoints + vpnStatusEndpoint = "/vpn/status" + vpnConnectEndpoint = "/vpn/connect" + vpnDisconnectEndpoint = "/vpn/disconnect" + vpnRestartEndpoint = "/vpn/restart" + vpnConnectionsEndpoint = "/vpn/connections" + vpnOfflineTestsEndpoint = "/vpn/offline-tests" + vpnStatusEventsEndpoint = "/vpn/status/events" + + // Server selection endpoints + serverSelectedEndpoint = "/server/selected" + serverAutoSelectedEndpoint = "/server/auto-selected" + serverAutoSelectedEventsEndpoint = "/server/auto-selected/events" + + // Config endpoints + configEventsEndpoint = "/config/events" + configUpdateEndpoint = "/config/update" + + // Server management endpoints + serversEndpoint = "/servers" + serversAddEndpoint = "/servers/add" + serversRemoveEndpoint = "/servers/remove" + serversFromJSONEndpoint = "/servers/json" + serversFromURLsEndpoint = "/servers/urls" + serversPrivateEndpoint = "/servers/private" + serversPrivateInviteEndpoint = "/servers/private/invite" + + // Settings endpoints + featuresEndpoint = "/settings/features" + settingsEndpoint = "/settings" + + // Split tunnel endpoint + splitTunnelEndpoint = "/split-tunnel" + + // Account endpoints + accountNewUserEndpoint = "/account/new-user" + accountLoginEndpoint = "/account/login" + accountLogoutEndpoint = "/account/logout" + accountUserDataEndpoint = "/account/user" + accountDevicesEndpoint = "/account/devices/" + accountSignupEndpoint = "/account/signup/" + accountEmailEndpoint = "/account/email" + accountRecoveryEndpoint = "/account/recovery" + accountDeleteEndpoint = "/account/delete" + accountOAuthEndpoint = "/account/oauth" + accountDataCapEndpoint = "/account/datacap" + accountDataCapStreamEndpoint = "/account/datacap/stream" + + // Subscription endpoints + subscriptionActivationEndpoint = "/subscription/activation" + subscriptionStripeEndpoint = "/subscription/stripe" + subscriptionPaymentRedirectEndpoint = "/subscription/payment-redirect" + subscriptionReferralEndpoint = "/subscription/referral" + subscriptionBillingPortalEndpoint = "/subscription/billing-portal" + subscriptionPaymentRedirectURLEndpoint = "/subscription/payment-redirect-url" + subscriptionPlansEndpoint = "/subscription/plans" + subscriptionVerifyEndpoint = "/subscription/verify" + + // Issue endpoint + issueEndpoint = "/issue" + + // Logs endpoint + logsStreamEndpoint = "/logs/stream" + + // Env endpoint (dev/testing) + envEndpoint = "/env" +) + +var ( + protocols = func() http.Protocols { + var p http.Protocols + p.SetUnencryptedHTTP2(true) + return p + }() + + ErrServiceIsNotReady = errors.New("service is not ready") + ErrIPCNotRunning = errors.New("IPC not running") +) + +// Server represents the IPC server that communicates over a Unix domain socket for Unix-like +// systems, and a named pipe for Windows. +type Server struct { + svr *http.Server + closed atomic.Bool +} + +// NewServer returns an IPC server backed by b. When withAuth is true, the +// server authenticates each connection; when false, it accepts all connections. +func NewServer(b *backend.LocalBackend, withAuth bool) *Server { + svr := &http.Server{ + Handler: newLocalAPI(b, withAuth), + ReadTimeout: 5 * time.Second, + Protocols: &protocols, + } + if withAuth { + svr.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + peer, err := getConnPeer(c) + if err != nil { + slog.Error("Failed to get peer credentials", "error", err) + } + return contextWithUsr(ctx, peer) + } + } + return &Server{svr: svr} +} + +// Start begins listening for incoming IPC requests. +func (s *Server) Start() error { + if s.closed.Load() { + return errors.New("IPC server is closed") + } + l, err := listen() + if err != nil { + return fmt.Errorf("IPC server: listen: %w", err) + } + go func() { + slog.Info("IPC server started", "address", l.Addr().String()) + if err := s.svr.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Error("IPC server error", "error", err) + } + s.closed.Store(true) + }() + return nil +} + +// Close shuts down the IPC server. +func (s *Server) Close() error { + if s.closed.Swap(true) { + return nil + } + slog.Info("Closing IPC server") + return s.svr.Close() +} + +type backendKey struct{} + +type localapi struct { + be atomic.Pointer[backend.LocalBackend] + handler http.Handler +} + +// backend returns the LocalBackend snapshotted at the start of the request. +func (s *localapi) backend(ctx context.Context) *backend.LocalBackend { + return ctx.Value(backendKey{}).(*backend.LocalBackend) +} + +func newLocalAPI(b *backend.LocalBackend, withAuth bool) *localapi { + s := &localapi{} + s.be.Store(b) + + mux := http.NewServeMux() + + // traced wraps a handler with the tracer middleware. + traced := func(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + tracer(http.HandlerFunc(h)).ServeHTTP(w, r) + } + } + + // VPN + mux.HandleFunc("GET "+vpnStatusEndpoint, traced(s.vpnStatusHandler)) + mux.HandleFunc("POST "+vpnConnectEndpoint, traced(s.vpnConnectHandler)) + mux.HandleFunc("POST "+vpnDisconnectEndpoint, traced(s.vpnDisconnectHandler)) + mux.HandleFunc("POST "+vpnRestartEndpoint, traced(s.vpnRestartHandler)) + mux.HandleFunc("GET "+vpnConnectionsEndpoint, traced(s.vpnConnectionsHandler)) + mux.HandleFunc("POST "+vpnOfflineTestsEndpoint, traced(s.vpnOfflineTestsHandler)) + + // SSE routes skip the tracer middleware since it buffers the entire response body. + mux.HandleFunc("GET "+vpnStatusEventsEndpoint, s.vpnStatusEventsHandler) + + // Server selection + mux.HandleFunc(serverSelectedEndpoint, traced(s.serverSelectedHandler)) + mux.HandleFunc("GET "+serverAutoSelectedEndpoint, traced(s.serverAutoSelectedHandler)) + mux.HandleFunc("GET "+serverAutoSelectedEventsEndpoint, s.serverAutoSelectedEventsHandler) + mux.HandleFunc("GET "+configEventsEndpoint, s.configEventsHandler) + mux.HandleFunc("POST "+configUpdateEndpoint, traced(s.configUpdateHandler)) + + // Server management + mux.HandleFunc("GET "+serversEndpoint, traced(s.serversHandler)) + mux.HandleFunc("POST "+serversAddEndpoint, traced(s.serversAddHandler)) + mux.HandleFunc("POST "+serversRemoveEndpoint, traced(s.serversRemoveHandler)) + mux.HandleFunc("POST "+serversFromJSONEndpoint, traced(s.serversFromJSONHandler)) + mux.HandleFunc("POST "+serversFromURLsEndpoint, traced(s.serversFromURLsHandler)) + mux.HandleFunc("POST "+serversPrivateEndpoint, traced(s.serversPrivateAddHandler)) + mux.HandleFunc(serversPrivateInviteEndpoint, traced(s.serversPrivateInviteHandler)) + + // Settings + mux.HandleFunc("GET "+featuresEndpoint, traced(s.featuresHandler)) + mux.HandleFunc(settingsEndpoint, traced(s.settingsHandler)) + + // Split tunnel + mux.HandleFunc(splitTunnelEndpoint, traced(s.splitTunnelHandler)) + + // Account + mux.HandleFunc("POST "+accountNewUserEndpoint, traced(s.accountNewUserHandler)) + mux.HandleFunc("POST "+accountLoginEndpoint, traced(s.accountLoginHandler)) + mux.HandleFunc("POST "+accountLogoutEndpoint, traced(s.accountLogoutHandler)) + mux.HandleFunc("GET "+accountUserDataEndpoint, traced(s.accountUserDataHandler)) + mux.HandleFunc(accountDevicesEndpoint+"{deviceID...}", traced(s.accountDevicesHandler)) + mux.HandleFunc("POST "+accountSignupEndpoint+"{action...}", traced(s.accountSignupHandler)) + mux.HandleFunc("POST "+accountEmailEndpoint+"/{action}", traced(s.accountEmailHandler)) + mux.HandleFunc("POST "+accountRecoveryEndpoint+"/{action}", traced(s.accountRecoveryHandler)) + mux.HandleFunc("DELETE "+accountDeleteEndpoint, traced(s.accountDeleteHandler)) + mux.HandleFunc(accountOAuthEndpoint, traced(s.accountOAuthHandler)) + mux.HandleFunc("GET "+accountDataCapEndpoint, traced(s.accountDataCapHandler)) + + // SSE routes skip the tracer middleware since it buffers the entire response body. + mux.HandleFunc("GET "+accountDataCapStreamEndpoint, s.accountDataCapStreamHandler) + + // Subscriptions + mux.HandleFunc("POST "+subscriptionActivationEndpoint, traced(s.subscriptionActivationHandler)) + mux.HandleFunc("POST "+subscriptionStripeEndpoint, traced(s.subscriptionStripeHandler)) + mux.HandleFunc("POST "+subscriptionPaymentRedirectEndpoint, traced(s.subscriptionPaymentRedirectHandler)) + mux.HandleFunc("POST "+subscriptionReferralEndpoint, traced(s.subscriptionReferralHandler)) + mux.HandleFunc("GET "+subscriptionBillingPortalEndpoint, traced(s.subscriptionBillingPortalHandler)) + mux.HandleFunc("POST "+subscriptionPaymentRedirectURLEndpoint, traced(s.subscriptionPaymentRedirectURLHandler)) + mux.HandleFunc("GET "+subscriptionPlansEndpoint, traced(s.subscriptionPlansHandler)) + mux.HandleFunc("POST "+subscriptionVerifyEndpoint, traced(s.subscriptionVerifyHandler)) + + // Issue + mux.HandleFunc("POST "+issueEndpoint, traced(s.issueReportHandler)) + + // Logs (SSE, skip tracer) + mux.HandleFunc("GET "+logsStreamEndpoint, s.logsStreamHandler) + + // Env (dev/testing) + mux.HandleFunc(envEndpoint, traced(s.envHandler)) + + // Build the middleware chain: log -> (optional auth) -> mux + var handler http.Handler = mux + if withAuth { + handler = authPeer(handler) + } + handler = logger(handler) + s.handler = handler + + return s +} + +func (s *localapi) setBackend(b *backend.LocalBackend) *backend.LocalBackend { + return s.be.Swap(b) +} + +func (s *localapi) ServeHTTP(w http.ResponseWriter, r *http.Request) { + b := s.be.Load() + if b == nil { + http.Error(w, "service is not ready", http.StatusServiceUnavailable) + return + } + ctx := context.WithValue(r.Context(), backendKey{}, b) + s.handler.ServeHTTP(w, r.WithContext(ctx)) +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(v); err != nil { + slog.Error("IPC: failed to write JSON response", "error", err) + } +} + +func decodeJSON(r *http.Request, v any) error { + return json.NewDecoder(r.Body).Decode(v) +} + +func writeSingJSON[T any](w http.ResponseWriter, status int, v T) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := sjson.NewEncoderContext(boxCtx, w).Encode(v); err != nil { + slog.Error("IPC: failed to write JSON response", "error", err) + } +} + +func decodeSingJSON(r *http.Request, v any) error { + return sjson.NewDecoderContext(boxCtx, r.Body).Decode(v) +} + +// sseWriter sets headers for a Server-Sent Events response and returns the flusher. +// Returns nil if the ResponseWriter does not support flushing. +func sseWriter(w http.ResponseWriter) http.Flusher { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return nil + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + return flusher +} + +///////////// +// VPN // +///////////// + +func (s *localapi) vpnStatusHandler(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, s.backend(r.Context()).VPNStatus()) +} + +func (s *localapi) vpnConnectHandler(w http.ResponseWriter, r *http.Request) { + var req TagRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).ConnectVPN(req.Tag); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnDisconnectHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).DisconnectVPN(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnRestartHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).RestartVPN(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// vpnConnectionsHandler handles GET /vpn/connections/ (all) and GET /vpn/connections/active. +func (s *localapi) vpnConnectionsHandler(w http.ResponseWriter, r *http.Request) { + var ( + conns []vpn.Connection + err error + ) + if r.URL.Query().Get("active") == "true" { + conns, err = s.backend(r.Context()).ActiveVPNConnections() + } else { + conns, err = s.backend(r.Context()).VPNConnections() + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, conns) +} + +func (s *localapi) vpnOfflineTestsHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).RunOfflineURLTests(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) vpnStatusEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := make(chan []byte, 16) + sub := events.Subscribe(func(evt vpn.StatusUpdateEvent) { + data, err := json.Marshal(evt) + if err != nil { + return + } + select { + case ch <- data: + default: + } + }) + defer sub.Unsubscribe() + for { + select { + case data := <-ch: + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +/////////////////////// +// Server selection // +/////////////////////// + +// serverSelectedHandler handles GET /server/selected (read) and POST /server/selected (set). +func (s *localapi) serverSelectedHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var req TagRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SelectServer(req.Tag); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + return + } + server, exists, err := s.backend(r.Context()).SelectedServer() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeSingJSON(w, http.StatusOK, SelectedServerResponse{Server: server, Exists: exists}) +} + +func (s *localapi) serverAutoSelectedHandler(w http.ResponseWriter, r *http.Request) { + tag, err := s.backend(r.Context()).CurrentAutoSelectedServer() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + server, found := s.backend(r.Context()).GetServerByTag(tag) + if !found { + http.Error(w, "auto-selected server not found", http.StatusNotFound) + return + } + writeSingJSON(w, http.StatusOK, server) +} + +func (s *localapi) serverAutoSelectedEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := make(chan []byte, 16) + sub := events.Subscribe(func(evt vpn.AutoSelectedEvent) { + data, err := json.Marshal(evt) + if err != nil { + return + } + select { + case ch <- data: + default: + } + }) + defer sub.Unsubscribe() + for { + select { + case data := <-ch: + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +func (s *localapi) configUpdateHandler(w http.ResponseWriter, r *http.Request) { + if err := s.backend(r.Context()).UpdateConfig(); err != nil { + status := http.StatusInternalServerError + if errors.Is(err, config.ErrConfigFetchDisabled) { + status = http.StatusConflict + } + http.Error(w, err.Error(), status) + return + } + w.WriteHeader(http.StatusOK) +} + +// configEventsHandler streams a notification on every config.NewConfigEvent. +// The payload is always "{}" — subscribers only need to know a change +// occurred and fetch fresh state through the other GET endpoints, so we don't +// serialize the (potentially large) full Config. +func (s *localapi) configEventsHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := make(chan struct{}, 16) + sub := events.Subscribe(func(evt config.NewConfigEvent) { + select { + case ch <- struct{}{}: + default: + } + }) + defer sub.Unsubscribe() + for { + select { + case <-ch: + fmt.Fprint(w, "data: {}\n\n") + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +/////////////////////// +// Server management // +/////////////////////// + +// serversHandler handles GET /servers +func (s *localapi) serversHandler(w http.ResponseWriter, r *http.Request) { + if tag := r.URL.Query().Get("tag"); tag != "" { + server, found := s.backend(r.Context()).GetServerByTag(tag) + if !found { + http.Error(w, "server not found", http.StatusNotFound) + return + } + writeSingJSON(w, http.StatusOK, server) + return + } + writeSingJSON(w, http.StatusOK, s.backend(r.Context()).AllServers()) +} + +func (s *localapi) serversAddHandler(w http.ResponseWriter, r *http.Request) { + var req AddServersRequest + if err := decodeSingJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).AddServers(req.Servers); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversRemoveHandler(w http.ResponseWriter, r *http.Request) { + var req RemoveServersRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).RemoveServers(req.Tags); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) serversFromJSONHandler(w http.ResponseWriter, r *http.Request) { + var req JSONConfigRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + tags, err := s.backend(r.Context()).AddServersByJSON(req.Config) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, tags) +} + +func (s *localapi) serversFromURLsHandler(w http.ResponseWriter, r *http.Request) { + var req URLsRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + tags, err := s.backend(r.Context()).AddServersByURL(req.URLs, req.SkipCertVerification) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, tags) +} + +func (s *localapi) serversPrivateAddHandler(w http.ResponseWriter, r *http.Request) { + var req PrivateServerRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err := s.backend(r.Context()).AddPrivateServer(req.Tag, req.IP, req.Port, req.AccessToken, req.Location, req.Joined) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// serversPrivateInviteHandler handles POST (create) and DELETE (revoke) on /servers/private/invite. +func (s *localapi) serversPrivateInviteHandler(w http.ResponseWriter, r *http.Request) { + var req PrivateServerInviteRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.Method == http.MethodDelete { + if err := s.backend(r.Context()).RevokePrivateServerInvite(req.IP, req.Port, req.AccessToken, req.InviteName); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + return + } + code, err := s.backend(r.Context()).InviteToPrivateServer(req.IP, req.Port, req.AccessToken, req.InviteName) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, CodeResponse{Code: code}) +} + +////////////// +// Settings // +////////////// + +func (s *localapi) featuresHandler(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, s.backend(r.Context()).Features()) +} + +func (s *localapi) settingsHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPatch: + var updates settings.Settings + if err := decodeJSON(r, &updates); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).PatchSettings(updates); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fallthrough + case http.MethodGet: + writeJSON(w, http.StatusOK, settings.GetAll()) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *localapi) envHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPatch: + var updates map[string]string + if err := decodeJSON(r, &updates); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + for k, v := range updates { + env.Set(k, v) + switch k { + case env.Country.String(): + if err := settings.Set(settings.CountryCodeKey, v); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + case env.FeatureOverrides.String(): + if err := settings.Set(settings.FeatureOverridesKey, v); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + } + fallthrough + case http.MethodGet: + writeJSON(w, http.StatusOK, env.GetAll()) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +///////////////// +// Split Tunnel // +///////////////// + +// splitTunnelHandler handles GET (read), POST (add), and DELETE (remove) on /split-tunnel. +func (s *localapi) splitTunnelHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + writeJSON(w, http.StatusOK, s.backend(r.Context()).SplitTunnelFilters()) + return + } + var items vpn.SplitTunnelFilter + if err := decodeJSON(r, &items); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var err error + switch r.Method { + case http.MethodPost: + err = s.backend(r.Context()).AddSplitTunnelItems(items) + case http.MethodDelete: + err = s.backend(r.Context()).RemoveSplitTunnelItems(items) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +///////////// +// Account // +///////////// + +func (s *localapi) accountNewUserHandler(w http.ResponseWriter, r *http.Request) { + userData, err := s.backend(r.Context()).NewUser(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountLoginHandler(w http.ResponseWriter, r *http.Request) { + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).Login(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountLogoutHandler(w http.ResponseWriter, r *http.Request) { + var req EmailRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).Logout(r.Context(), req.Email) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +func (s *localapi) accountUserDataHandler(w http.ResponseWriter, r *http.Request) { + var userData *account.UserData + var err error + if r.URL.Query().Get("fetch") == "true" { + userData, err = s.backend(r.Context()).FetchUserData(r.Context()) + } else { + userData, err = s.backend(r.Context()).UserData() + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +// accountDevicesHandler handles GET /account/devices (list) and DELETE /account/devices/{deviceID} (remove). +func (s *localapi) accountDevicesHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodDelete { + resp, err := s.backend(r.Context()).RemoveDevice(r.Context(), r.PathValue("deviceID")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, resp) + return + } + devices, err := s.backend(r.Context()).UserDevices() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, devices) +} + +// accountSignupHandler handles POST /account/signup, /account/signup/confirm, and /account/signup/resend. +func (s *localapi) accountSignupHandler(w http.ResponseWriter, r *http.Request) { + switch r.PathValue("action") { + case "confirm": + var req EmailCodeRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SignupEmailConfirmation(r.Context(), req.Email, req.Code); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + case "resend": + var req EmailRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).SignupEmailResendCode(r.Context(), req.Email); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + default: + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + salt, resp, err := s.backend(r.Context()).SignUp(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, SignupResponse{Salt: salt, Response: resp}) + } +} + +// accountEmailHandler handles POST /account/email/{action} for start and complete. +func (s *localapi) accountEmailHandler(w http.ResponseWriter, r *http.Request) { + var err error + switch r.PathValue("action") { + case "start": + var req ChangeEmailStartRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).StartChangeEmail(r.Context(), req.NewEmail, req.Password) + case "complete": + var req ChangeEmailCompleteRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).CompleteChangeEmail(r.Context(), req.NewEmail, req.Password, req.Code) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +// accountRecoveryHandler handles POST /account/recovery/{action} for start, complete, and validate. +func (s *localapi) accountRecoveryHandler(w http.ResponseWriter, r *http.Request) { + var err error + switch r.PathValue("action") { + case "start": + var req EmailRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).StartRecoveryByEmail(r.Context(), req.Email) + case "complete": + var req RecoveryCompleteRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).CompleteRecoveryByEmail(r.Context(), req.Email, req.NewPassword, req.Code) + case "validate": + var req EmailCodeRequest + if err = decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = s.backend(r.Context()).ValidateEmailRecoveryCode(r.Context(), req.Email, req.Code) + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *localapi) accountDeleteHandler(w http.ResponseWriter, r *http.Request) { + var req EmailPasswordRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).DeleteAccount(r.Context(), req.Email, req.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) +} + +// accountOAuthHandler handles GET /account/oauth (login URL) and POST /account/oauth (callback). +func (s *localapi) accountOAuthHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var req OAuthTokenRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userData, err := s.backend(r.Context()).OAuthLoginCallback(r.Context(), req.OAuthToken) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, userData) + return + } + provider := r.URL.Query().Get("provider") + if provider == "" { + http.Error(w, "provider is required", http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).OAuthLoginURL(r.Context(), provider) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) accountDataCapHandler(w http.ResponseWriter, r *http.Request) { + info, err := s.backend(r.Context()).DataCapInfo(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, info) +} + +func (s *localapi) accountDataCapStreamHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch := s.backend(r.Context()).DataCapUpdates() + for { + select { + case info := <-ch: + data, err := json.Marshal(info) + if err != nil { + continue + } + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +/////////////////// +// Subscriptions // +/////////////////// + +func (s *localapi) subscriptionActivationHandler(w http.ResponseWriter, r *http.Request) { + var req ActivationRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp, err := s.backend(r.Context()).ActivationCode(r.Context(), req.Email, req.ResellerCode) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, resp) +} + +func (s *localapi) subscriptionStripeHandler(w http.ResponseWriter, r *http.Request) { + var req StripeSubscriptionRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + clientSecret, err := s.backend(r.Context()).NewStripeSubscription(r.Context(), req.Email, req.PlanID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, ClientSecretResponse{ClientSecret: clientSecret}) +} + +func (s *localapi) subscriptionPaymentRedirectHandler(w http.ResponseWriter, r *http.Request) { + var req account.PaymentRedirectData + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).PaymentRedirect(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionReferralHandler(w http.ResponseWriter, r *http.Request) { + var req CodeRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ok, err := s.backend(r.Context()).ReferralAttach(r.Context(), req.Code) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, SuccessResponse{Success: ok}) +} + +func (s *localapi) subscriptionBillingPortalHandler(w http.ResponseWriter, r *http.Request) { + u, err := s.backend(r.Context()).StripeBillingPortalURL(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionPaymentRedirectURLHandler(w http.ResponseWriter, r *http.Request) { + var req account.PaymentRedirectData + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + u, err := s.backend(r.Context()).SubscriptionPaymentRedirectURL(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, URLResponse{URL: u}) +} + +func (s *localapi) subscriptionPlansHandler(w http.ResponseWriter, r *http.Request) { + plans, err := s.backend(r.Context()).SubscriptionPlans(r.Context(), r.URL.Query().Get("channel")) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, PlansResponse{Plans: plans}) +} + +func (s *localapi) subscriptionVerifyHandler(w http.ResponseWriter, r *http.Request) { + var req VerifySubscriptionRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + result, err := s.backend(r.Context()).VerifySubscription(r.Context(), req.Service, req.Data) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, ResultResponse{Result: result}) +} + +/////////// +// Issue // +/////////// + +func (s *localapi) issueReportHandler(w http.ResponseWriter, r *http.Request) { + var req IssueReportRequest + if err := decodeJSON(r, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := s.backend(r.Context()).ReportIssue(req.IssueType, req.Description, req.Email, req.AdditionalAttachments); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +/////////// +// Logs // +/////////// + +func (s *localapi) logsStreamHandler(w http.ResponseWriter, r *http.Request) { + flusher := sseWriter(w) + if flusher == nil { + return + } + ch, unsub := rlog.Subscribe() + defer unsub() + for { + select { + case entry := <-ch: + fmt.Fprintf(w, "data: %s\n\n", entry) + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} diff --git a/ipc/socket.go b/ipc/socket.go new file mode 100644 index 00000000..98e52710 --- /dev/null +++ b/ipc/socket.go @@ -0,0 +1,24 @@ +//go:build !android && !ios && !windows + +package ipc + +import ( + "os" +) + +// use a var so it can be overridden in tests +var _socketPath = "/var/run/lantern/lanternd.sock" + +// setSocketPathForTesting is only used for testing. +func setSocketPathForTesting(path string) { + _socketPath = path +} + +func socketPath() string { + return _socketPath +} + +func setPermissions() error { + // we'll check if user is sudoer to restrict access + return os.Chmod(socketPath(), 0666) +} diff --git a/vpn/ipc/socket_mobile.go b/ipc/socket_mobile.go similarity index 74% rename from vpn/ipc/socket_mobile.go rename to ipc/socket_mobile.go index 6383a570..c7289f1e 100644 --- a/vpn/ipc/socket_mobile.go +++ b/ipc/socket_mobile.go @@ -11,7 +11,7 @@ import ( "syscall" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) // this is a no-op on mobile @@ -41,7 +41,7 @@ func getNonRootOwner(path string) (uid, gid int) { return uid, gid } - slog.Log(context.Background(), internal.LevelTrace, "searching for non-root owner of", "path", path) + slog.Log(context.Background(), log.LevelTrace, "searching for non-root owner of", "path", path) for { parentDir := filepath.Dir(path) if parentDir == path || parentDir == "/" { @@ -51,7 +51,7 @@ func getNonRootOwner(path string) (uid, gid int) { fInfo, err := os.Stat(path) if err != nil { - slog.Log(context.Background(), internal.LevelTrace, "stat error", "path", path, "error", err) + slog.Log(context.Background(), log.LevelTrace, "stat error", "path", path, "error", err) continue } stat, ok := fInfo.Sys().(*syscall.Stat_t) @@ -59,11 +59,11 @@ func getNonRootOwner(path string) (uid, gid int) { continue } if int(stat.Uid) != 0 { - slog.Log(context.Background(), internal.LevelTrace, "found non-root owner", "path", path, "uid", stat.Uid, "gid", stat.Gid) + slog.Log(context.Background(), log.LevelTrace, "found non-root owner", "path", path, "uid", stat.Uid, "gid", stat.Gid) return int(stat.Uid), int(stat.Gid) } } - if slog.Default().Enabled(context.Background(), internal.LevelTrace) { + if slog.Default().Enabled(context.Background(), log.LevelTrace) { slog.Warn("falling back to root owner for", "path", path) } return uid, gid diff --git a/vpn/ipc/testsetup.go b/ipc/testsetup.go similarity index 100% rename from vpn/ipc/testsetup.go rename to ipc/testsetup.go diff --git a/ipc/types.go b/ipc/types.go new file mode 100644 index 00000000..b79f9c30 --- /dev/null +++ b/ipc/types.go @@ -0,0 +1,145 @@ +package ipc + +import ( + "github.com/getlantern/common" + + "github.com/getlantern/radiance/account" + "github.com/getlantern/radiance/issue" + "github.com/getlantern/radiance/servers" +) + +// Shared request types used by both client and server. + +type TagRequest struct { + Tag string `json:"tag"` +} + +type EmailRequest struct { + Email string `json:"email"` +} + +type EmailPasswordRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type EmailCodeRequest struct { + Email string `json:"email"` + Code string `json:"code"` +} + +type OAuthTokenRequest struct { + OAuthToken string `json:"oAuthToken"` +} + +type CodeRequest struct { + Code string `json:"code"` +} + +type JSONConfigRequest struct { + Config string `json:"config"` +} + +type AddServersRequest struct { + Servers servers.ServerList `json:"servers"` +} + +type RemoveServersRequest struct { + Tags []string `json:"tags"` +} + +type URLsRequest struct { + URLs []string `json:"urls"` + SkipCertVerification bool `json:"skipCertVerification"` +} + +type PrivateServerRequest struct { + Tag string `json:"tag"` + IP string `json:"ip"` + Port int `json:"port"` + AccessToken string `json:"accessToken"` + Location common.ServerLocation `json:"location"` + Joined bool `json:"joined"` +} + +type PrivateServerInviteRequest struct { + IP string `json:"ip"` + Port int `json:"port"` + AccessToken string `json:"accessToken"` + InviteName string `json:"inviteName"` +} + +type ChangeEmailStartRequest struct { + NewEmail string `json:"newEmail"` + Password string `json:"password"` +} + +type ChangeEmailCompleteRequest struct { + NewEmail string `json:"newEmail"` + Password string `json:"password"` + Code string `json:"code"` +} + +type RecoveryCompleteRequest struct { + Email string `json:"email"` + NewPassword string `json:"newPassword"` + Code string `json:"code"` +} + +type ActivationRequest struct { + Email string `json:"email"` + ResellerCode string `json:"resellerCode"` +} + +type StripeSubscriptionRequest struct { + Email string `json:"email"` + PlanID string `json:"planID"` +} + +type VerifySubscriptionRequest struct { + Service account.SubscriptionService `json:"service"` + Data map[string]string `json:"data"` +} + +type IssueReportRequest struct { + IssueType issue.IssueType `json:"issueType"` + Description string `json:"description"` + Email string `json:"email"` + AdditionalAttachments []string `json:"additionalAttachments"` +} + +// Shared response types used by both client and server. + +type SelectedServerResponse struct { + Server *servers.Server `json:"server"` + Exists bool `json:"exists"` +} + +type SignupResponse struct { + Salt []byte `json:"salt"` + Response *account.SignupResponse `json:"response"` +} + +type URLResponse struct { + URL string `json:"url"` +} + +type CodeResponse struct { + Code string `json:"code"` +} + +type ClientSecretResponse struct { + ClientSecret string `json:"clientSecret"` +} + +type SuccessResponse struct { + Success bool `json:"success"` +} + +type PlansResponse struct { + Plans string `json:"plans"` +} + +type ResultResponse struct { + Result string `json:"result"` +} diff --git a/vpn/ipc/usr.go b/ipc/usr.go similarity index 100% rename from vpn/ipc/usr.go rename to ipc/usr.go diff --git a/vpn/ipc/usr_darwin.go b/ipc/usr_darwin.go similarity index 100% rename from vpn/ipc/usr_darwin.go rename to ipc/usr_darwin.go diff --git a/vpn/ipc/usr_linux.go b/ipc/usr_linux.go similarity index 100% rename from vpn/ipc/usr_linux.go rename to ipc/usr_linux.go diff --git a/vpn/ipc/usr_windows.go b/ipc/usr_windows.go similarity index 100% rename from vpn/ipc/usr_windows.go rename to ipc/usr_windows.go diff --git a/vpn/ipc/zsyscall_windows.go b/ipc/zsyscall_windows.go similarity index 100% rename from vpn/ipc/zsyscall_windows.go rename to ipc/zsyscall_windows.go diff --git a/issue/archive.go b/issue/archive.go new file mode 100644 index 00000000..1ee7d642 --- /dev/null +++ b/issue/archive.go @@ -0,0 +1,277 @@ +package issue + +import ( + "archive/zip" + "bytes" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" +) + +// buildIssueArchive creates a zip archive containing all .log files found in +// logDir plus additional attachment files. The primary log (lantern.log) is +// given truncation priority; secondary log files and attachments are included +// greedily if space permits. The total compressed archive size will not exceed +// maxSize bytes. +func buildIssueArchive(logDir string, additionalFiles []string, maxSize int64) ([]byte, error) { + logFiles := globLogFiles(logDir) + + var primaryLogData []byte + var secondaryLogs []extraFile + + for _, lf := range logFiles { + data, err := snapshotLogFile(lf, maxSize) + if err != nil { + slog.Warn("unable to snapshot log file", "path", lf, "error", err) + continue + } + if len(data) == 0 { + continue + } + if filepath.Base(lf) == logArchiveName { + primaryLogData = data + } else { + secondaryLogs = append(secondaryLogs, extraFile{ + name: filepath.Base(lf), + data: data, + }) + } + } + + attachments := readExtraFiles(additionalFiles) + + return fitArchive(primaryLogData, secondaryLogs, attachments, maxSize) +} + +// globLogFiles returns all .log files in dir, sorted by filepath.Glob order. +func globLogFiles(dir string) []string { + matches, err := filepath.Glob(filepath.Join(dir, "*.log")) + if err != nil { + slog.Warn("unable to glob log files", "dir", dir, "error", err) + return nil + } + return matches +} + +// snapshotLogFile opens the log file, records its current size, and reads the tail +// up to a reasonable cap. +func snapshotLogFile(logPath string, maxCompressed int64) ([]byte, error) { + f, err := os.Open(logPath) + if err != nil { + return nil, err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return nil, err + } + + size := fi.Size() + if size == 0 { + return nil, nil + } + + // Cap the amount we read: even with poor compression, we'd never need more + // than maxCompressed * 20 bytes of uncompressed log to fill the archive. + maxRead := maxCompressed * 20 + readSize := size + if readSize > maxRead { + readSize = maxRead + } + + // Seek to read only the tail (most recent logs). + if size > readSize { + if _, err := f.Seek(size-readSize, io.SeekStart); err != nil { + return nil, err + } + } + + data := make([]byte, readSize) + n, err := io.ReadFull(f, data) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, fmt.Errorf("reading log file: %w", err) + } + return data[:n], nil +} + +type extraFile struct { + name string + data []byte +} + +func readExtraFiles(paths []string) []extraFile { + var files []extraFile + for _, p := range paths { + data, err := os.ReadFile(p) + if err != nil { + slog.Warn("unable to read additional file", "path", p, "error", err) + continue + } + files = append(files, extraFile{ + name: filepath.Base(p), + data: data, + }) + } + return files +} + +// fitArchive builds a zip archive that fits within maxSize. The primary log +// (lantern.log) is given truncation priority, followed by secondary log files, +// then attachments. +func fitArchive(primaryLog []byte, secondaryLogs []extraFile, attachments []extraFile, maxSize int64) ([]byte, error) { + allLogs := logsFromPrimary(primaryLog, secondaryLogs) + + if len(allLogs) == 0 && len(attachments) == 0 { + return nil, nil + } + + // Try everything. + buf, err := writeArchive(allLogs, attachments) + if err != nil { + return nil, err + } + if int64(buf.Len()) <= maxSize { + return buf.Bytes(), nil + } + + // Try primary log only. + primaryLogs := logsFromPrimary(primaryLog, nil) + if len(primaryLog) > 0 { + buf, err = writeArchive(primaryLogs, nil) + if err != nil { + return nil, err + } + if int64(buf.Len()) <= maxSize { + // Full primary fits — greedily add secondary logs, then attachments. + return addExtrasGreedily(primaryLogs, secondaryLogs, attachments, maxSize) + } + + // Full primary doesn't fit — binary search for the maximum tail. + tailSize := searchMaxLogTail(primaryLog, maxSize) + tail := primaryLog[len(primaryLog)-tailSize:] + trimmedPrimary := logsFromPrimary(tail, nil) + return addExtrasGreedily(trimmedPrimary, secondaryLogs, attachments, maxSize) + } + + // No primary log — greedily add secondary logs and attachments. + return addExtrasGreedily(nil, secondaryLogs, attachments, maxSize) +} + +// logsFromPrimary builds a combined log entry list with the primary log first. +func logsFromPrimary(primaryLog []byte, secondaryLogs []extraFile) []extraFile { + var logs []extraFile + if len(primaryLog) > 0 { + logs = append(logs, extraFile{name: logArchiveName, data: primaryLog}) + } + logs = append(logs, secondaryLogs...) + return logs +} + +const logArchiveName = "lantern.log" + +func writeArchive(logs []extraFile, attachments []extraFile) (*bytes.Buffer, error) { + buf := new(bytes.Buffer) + w := zip.NewWriter(buf) + + for _, l := range logs { + fw, err := w.Create(l.name) + if err != nil { + return nil, err + } + if _, err := fw.Write(l.data); err != nil { + return nil, err + } + } + + for _, f := range attachments { + fw, err := w.Create("attachments/" + f.name) + if err != nil { + return nil, err + } + if _, err := fw.Write(f.data); err != nil { + return nil, err + } + } + + if err := w.Close(); err != nil { + return nil, err + } + return buf, nil +} + +// searchMaxLogTail binary-searches for the largest tail of logData (in 256KB chunks) +// that compresses into a zip archive not exceeding maxSize. +func searchMaxLogTail(logData []byte, maxSize int64) int { + const chunkSize = 256 * 1024 + n := len(logData) + lo, hi := 1, (n+chunkSize-1)/chunkSize + best := 0 + + for lo <= hi { + mid := lo + (hi-lo)/2 + tailBytes := mid * chunkSize + if tailBytes > n { + tailBytes = n + } + + logs := []extraFile{{name: logArchiveName, data: logData[n-tailBytes:]}} + buf, err := writeArchive(logs, nil) + if err != nil { + hi = mid - 1 + continue + } + if int64(buf.Len()) <= maxSize { + best = tailBytes + lo = mid + 1 + } else { + hi = mid - 1 + } + } + return best +} + +// addExtrasGreedily starts from the given base logs and greedily adds secondary +// log files then attachment files, keeping each only if the archive still fits +// within maxSize. +func addExtrasGreedily(baseLogs []extraFile, secondaryLogs []extraFile, attachments []extraFile, maxSize int64) ([]byte, error) { + currentLogs := make([]extraFile, len(baseLogs)) + copy(currentLogs, baseLogs) + var currentAttachments []extraFile + + buf, err := writeArchive(currentLogs, nil) + if err != nil { + return nil, err + } + lastGood := buf.Bytes() + + // Greedily add secondary log files. + for _, sl := range secondaryLogs { + trial := append(currentLogs[:len(currentLogs):len(currentLogs)], sl) + buf, err := writeArchive(trial, currentAttachments) + if err != nil { + continue + } + if int64(buf.Len()) <= maxSize { + currentLogs = trial + lastGood = buf.Bytes() + } + } + + // Greedily add attachment files. + for _, a := range attachments { + trial := append(currentAttachments[:len(currentAttachments):len(currentAttachments)], a) + buf, err := writeArchive(currentLogs, trial) + if err != nil { + continue + } + if int64(buf.Len()) <= maxSize { + currentAttachments = trial + lastGood = buf.Bytes() + } + } + + return lastGood, nil +} diff --git a/issue/archive_test.go b/issue/archive_test.go new file mode 100644 index 00000000..b2d6be9e --- /dev/null +++ b/issue/archive_test.go @@ -0,0 +1,457 @@ +package issue + +import ( + "archive/zip" + "bytes" + "crypto/rand" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSnapshotLogFile(t *testing.T) { + t.Run("reads full file when small", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + content := "line1\nline2\nline3\n" + require.NoError(t, os.WriteFile(logPath, []byte(content), 0644)) + + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + assert.Equal(t, content, string(data)) + }) + + t.Run("reads only tail when file exceeds cap", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + + // maxCompressed=100 → maxRead = 100*20 = 2000 + // Write 5000 bytes so the file exceeds the cap. + full := bytes.Repeat([]byte("X"), 5000) + require.NoError(t, os.WriteFile(logPath, full, 0644)) + + data, err := snapshotLogFile(logPath, 100) + require.NoError(t, err) + assert.Equal(t, 2000, len(data)) + // Should be the tail of the file. + assert.Equal(t, string(full[3000:]), string(data)) + }) + + t.Run("returns nil for empty file", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "empty.log") + require.NoError(t, os.WriteFile(logPath, nil, 0644)) + + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + assert.Nil(t, data) + }) + + t.Run("returns error for missing file", func(t *testing.T) { + _, err := snapshotLogFile("/nonexistent/path.log", 1024*1024) + assert.Error(t, err) + }) + + t.Run("snapshot is stable after file rotation", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "test.log") + original := "original log content\n" + require.NoError(t, os.WriteFile(logPath, []byte(original), 0644)) + + // Open and snapshot size (simulating what snapshotLogFile does internally). + f, err := os.Open(logPath) + require.NoError(t, err) + defer f.Close() + + fi, err := f.Stat() + require.NoError(t, err) + size := fi.Size() + + // Simulate rotation: rename the file and create a new one. + require.NoError(t, os.Rename(logPath, logPath+".1")) + require.NoError(t, os.WriteFile(logPath, []byte("new log content\n"), 0644)) + + // The original fd should still read the original data. + data := make([]byte, size) + n, err := f.Read(data) + require.NoError(t, err) + assert.Equal(t, original, string(data[:n])) + }) +} + +func TestGlobLogFiles(t *testing.T) { + t.Run("finds all log files", func(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern.log"), []byte("main"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern-crash.log"), []byte("crash"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "other.txt"), []byte("not a log"), 0644)) + + files := globLogFiles(dir) + require.Len(t, files, 2) + bases := make([]string, len(files)) + for i, f := range files { + bases[i] = filepath.Base(f) + } + assert.Contains(t, bases, "lantern.log") + assert.Contains(t, bases, "lantern-crash.log") + }) + + t.Run("returns nil for empty dir", func(t *testing.T) { + dir := t.TempDir() + files := globLogFiles(dir) + assert.Nil(t, files) + }) + + t.Run("returns nil for nonexistent dir", func(t *testing.T) { + files := globLogFiles("/nonexistent/dir") + assert.Nil(t, files) + }) +} + +func TestReadExtraFiles(t *testing.T) { + t.Run("reads existing files", func(t *testing.T) { + dir := t.TempDir() + f1 := filepath.Join(dir, "a.txt") + f2 := filepath.Join(dir, "b.txt") + require.NoError(t, os.WriteFile(f1, []byte("aaa"), 0644)) + require.NoError(t, os.WriteFile(f2, []byte("bbb"), 0644)) + + files := readExtraFiles([]string{f1, f2}) + require.Len(t, files, 2) + assert.Equal(t, "a.txt", files[0].name) + assert.Equal(t, "aaa", string(files[0].data)) + assert.Equal(t, "b.txt", files[1].name) + assert.Equal(t, "bbb", string(files[1].data)) + }) + + t.Run("skips missing files", func(t *testing.T) { + dir := t.TempDir() + existing := filepath.Join(dir, "exists.txt") + require.NoError(t, os.WriteFile(existing, []byte("data"), 0644)) + + files := readExtraFiles([]string{"/no/such/file", existing}) + require.Len(t, files, 1) + assert.Equal(t, "exists.txt", files[0].name) + }) + + t.Run("nil input returns nil", func(t *testing.T) { + files := readExtraFiles(nil) + assert.Nil(t, files) + }) +} + +func TestWriteArchive(t *testing.T) { + t.Run("log only", func(t *testing.T) { + logs := []extraFile{{name: logArchiveName, data: []byte("some log content")}} + buf, err := writeArchive(logs, nil) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "some log content", entries[0].content) + }) + + t.Run("multiple logs with attachments", func(t *testing.T) { + logs := []extraFile{ + {name: "lantern.log", data: []byte("main log")}, + {name: "lantern-crash.log", data: []byte("crash log")}, + } + attachments := []extraFile{ + {name: "config.json", data: []byte(`{"key":"val"}`)}, + } + buf, err := writeArchive(logs, attachments) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 3) + assert.Equal(t, "lantern.log", entries[0].name) + assert.Equal(t, "lantern-crash.log", entries[1].name) + assert.Equal(t, "attachments/config.json", entries[2].name) + }) + + t.Run("attachments only", func(t *testing.T) { + attachments := []extraFile{{name: "file.txt", data: []byte("hello")}} + buf, err := writeArchive(nil, attachments) + require.NoError(t, err) + + entries := readZipEntries(t, buf.Bytes()) + require.Len(t, entries, 1) + assert.Equal(t, "attachments/file.txt", entries[0].name) + }) + + t.Run("empty inputs", func(t *testing.T) { + buf, err := writeArchive(nil, nil) + require.NoError(t, err) + // Should produce a valid but empty zip. + entries := readZipEntries(t, buf.Bytes()) + assert.Empty(t, entries) + }) +} + +func TestFitArchive(t *testing.T) { + t.Run("everything fits", func(t *testing.T) { + logData := []byte("small log") + secondary := []extraFile{{name: "crash.log", data: []byte("crash")}} + attachments := []extraFile{{name: "a.txt", data: []byte("small")}} + result, err := fitArchive(logData, secondary, attachments, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 3) + }) + + t.Run("nil log and nil extras returns nil", func(t *testing.T) { + result, err := fitArchive(nil, nil, nil, 1024*1024) + require.NoError(t, err) + assert.Nil(t, result) + }) + + t.Run("attachments dropped when too large", func(t *testing.T) { + logData := []byte("log data") + // Make an attachment that's big enough to push past a small maxSize. + bigAttachment := extraFile{name: "big.bin", data: bytes.Repeat([]byte{0xFF}, 50*1024)} + + // Find the compressed size of just the log. + logs := []extraFile{{name: logArchiveName, data: logData}} + logOnly, err := writeArchive(logs, nil) + require.NoError(t, err) + maxSize := int64(logOnly.Len()) + 100 // just barely enough for log, not the extra + + result, err := fitArchive(logData, nil, []extraFile{bigAttachment}, maxSize) + require.NoError(t, err) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "log data", entries[0].content) + }) + + t.Run("log truncated to tail when too large", func(t *testing.T) { + // Use incompressible random data (2MB) with a budget that fits ~1-2 + // chunks (256KB each) but not the full log. + logData := make([]byte, 2*1024*1024) // 2MB + _, err := rand.Read(logData) + require.NoError(t, err) + + maxSize := int64(512 * 1024) // 512KB + + result, err := fitArchive(logData, nil, nil, maxSize) + require.NoError(t, err) + assert.LessOrEqual(t, int64(len(result)), maxSize) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + + // The included content should be a tail of the original. + content := entries[0].content + assert.True(t, len(content) < len(logData), "log should be truncated") + assert.Equal(t, string(logData[len(logData)-len(content):]), content, + "included content should be the tail of the original log") + }) + + t.Run("secondary logs and attachments only when no primary", func(t *testing.T) { + secondary := []extraFile{{name: "crash.log", data: []byte("crash")}} + attachments := []extraFile{{name: "a.txt", data: []byte("aaa")}} + result, err := fitArchive(nil, secondary, attachments, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 2) + }) +} + +func TestSearchMaxLogTail(t *testing.T) { + t.Run("all fits", func(t *testing.T) { + logData := []byte("small log data") + tailSize := searchMaxLogTail(logData, 1024*1024) + assert.Equal(t, len(logData), tailSize) + }) + + t.Run("truncates incompressible data", func(t *testing.T) { + logData := make([]byte, 1024*1024) // 1MB random + _, err := rand.Read(logData) + require.NoError(t, err) + + maxSize := int64(300 * 1024) // 300KB + tailSize := searchMaxLogTail(logData, maxSize) + assert.Greater(t, tailSize, 0) + assert.Less(t, tailSize, len(logData)) + + // Verify the result actually fits. + logs := []extraFile{{name: logArchiveName, data: logData[len(logData)-tailSize:]}} + buf, err := writeArchive(logs, nil) + require.NoError(t, err) + assert.LessOrEqual(t, int64(buf.Len()), maxSize) + }) +} + +func TestAddExtrasGreedily(t *testing.T) { + t.Run("adds all when they fit", func(t *testing.T) { + baseLogs := []extraFile{{name: logArchiveName, data: []byte("log")}} + secondary := []extraFile{{name: "crash.log", data: []byte("crash")}} + attachments := []extraFile{{name: "a.txt", data: []byte("aaa")}} + result, err := addExtrasGreedily(baseLogs, secondary, attachments, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + assert.Len(t, entries, 3) + }) + + t.Run("skips extras that would exceed limit", func(t *testing.T) { + baseLogs := []extraFile{{name: logArchiveName, data: []byte("log")}} + small := extraFile{name: "small.txt", data: []byte("s")} + big := extraFile{name: "big.bin", data: bytes.Repeat([]byte{0xFF}, 50*1024)} + + // Budget enough for log + small, but not big. + bufWithSmall, err := writeArchive(baseLogs, []extraFile{small}) + require.NoError(t, err) + maxSize := int64(bufWithSmall.Len()) + 50 // tight budget + + result, err := addExtrasGreedily(baseLogs, nil, []extraFile{small, big}, maxSize) + require.NoError(t, err) + + entries := readZipEntries(t, result) + names := make([]string, len(entries)) + for i, e := range entries { + names[i] = e.name + } + assert.Contains(t, names, logArchiveName) + assert.Contains(t, names, "attachments/small.txt") + assert.NotContains(t, names, "attachments/big.bin") + }) + + t.Run("no extras returns log only", func(t *testing.T) { + baseLogs := []extraFile{{name: logArchiveName, data: []byte("log content")}} + result, err := addExtrasGreedily(baseLogs, nil, nil, 1024*1024) + require.NoError(t, err) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, logArchiveName, entries[0].name) + }) +} + +func TestBuildIssueArchive(t *testing.T) { + t.Run("end to end with log and extras", func(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern.log"), []byte("log line 1\nlog line 2\n"), 0644)) + + extra := filepath.Join(dir, "extra.txt") + require.NoError(t, os.WriteFile(extra, []byte("extra content"), 0644)) + + result, err := buildIssueArchive(dir, []string{extra}, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + require.Len(t, entries, 2) + assert.Equal(t, logArchiveName, entries[0].name) + assert.Equal(t, "log line 1\nlog line 2\n", entries[0].content) + assert.Equal(t, "attachments/extra.txt", entries[1].name) + }) + + t.Run("includes all log files in directory", func(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern.log"), []byte("main log"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern-crash.log"), []byte("crash log"), 0644)) + + result, err := buildIssueArchive(dir, nil, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + require.Len(t, entries, 2) + names := make([]string, len(entries)) + for i, e := range entries { + names[i] = e.name + } + assert.Contains(t, names, "lantern.log") + assert.Contains(t, names, "lantern-crash.log") + }) + + t.Run("missing log dir still includes extras", func(t *testing.T) { + dir := t.TempDir() + extra := filepath.Join(dir, "extra.txt") + require.NoError(t, os.WriteFile(extra, []byte("data"), 0644)) + + result, err := buildIssueArchive(filepath.Join(dir, "nonexistent"), []string{extra}, 1024*1024) + require.NoError(t, err) + require.NotNil(t, result) + + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + assert.Equal(t, "attachments/extra.txt", entries[0].name) + }) + + t.Run("archive respects maxSize", func(t *testing.T) { + dir := t.TempDir() + // Write incompressible data (2MB). + logContent := make([]byte, 2*1024*1024) + _, err := rand.Read(logContent) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, "lantern.log"), logContent, 0644)) + + maxSize := int64(512 * 1024) + result, err := buildIssueArchive(dir, nil, maxSize) + require.NoError(t, err) + assert.LessOrEqual(t, int64(len(result)), maxSize) + + // Verify it contains the tail. + entries := readZipEntries(t, result) + require.Len(t, entries, 1) + content := entries[0].content + assert.Equal(t, string(logContent[len(logContent)-len(content):]), content) + }) + + t.Run("snapshot excludes data written after call", func(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "lantern.log") + original := "before snapshot\n" + require.NoError(t, os.WriteFile(logPath, []byte(original), 0644)) + + // Snapshot the file. + data, err := snapshotLogFile(logPath, 1024*1024) + require.NoError(t, err) + + // Append after snapshot. + f, err := os.OpenFile(logPath, os.O_APPEND|os.O_WRONLY, 0644) + require.NoError(t, err) + _, err = f.WriteString("after snapshot\n") + require.NoError(t, err) + f.Close() + + // Snapshot should only contain original content. + assert.Equal(t, original, string(data)) + }) +} + +// --- test helpers --- + +type zipEntry struct { + name string + content string +} + +func readZipEntries(t *testing.T, data []byte) []zipEntry { + t.Helper() + r, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + require.NoError(t, err) + + var entries []zipEntry + for _, f := range r.File { + rc, err := f.Open() + require.NoError(t, err) + body, err := io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + entries = append(entries, zipEntry{name: f.Name, content: string(body)}) + } + return entries +} diff --git a/issue/issue.go b/issue/issue.go index 711f61f5..12e67c8d 100644 --- a/issue/issue.go +++ b/issue/issue.go @@ -4,140 +4,129 @@ import ( "bytes" "context" "fmt" + "io" "log/slog" - "math/rand" + "math/rand/v2" "net/http" "net/http/httputil" - "strconv" + "runtime" "time" "github.com/getlantern/osversion" + "github.com/getlantern/timezone" "go.opentelemetry.io/otel" - "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" "github.com/getlantern/radiance/traces" "google.golang.org/protobuf/proto" ) const ( - maxUncompressedLogSize = 50 * 1024 * 1024 // 50 MB - tracerName = "github.com/getlantern/radiance/issue" + maxCompressedSize = 20 * 1024 * 1024 // 20 MB + tracerName = "github.com/getlantern/radiance/issue" ) -// IssueReporter is used to send issue reports to backend -type IssueReporter struct{} +// IssueReporter is used to send issue reports to backend. +type IssueReporter struct { + httpClient *http.Client +} // NewIssueReporter creates a new IssueReporter that can be used to send issue reports // to the backend. -func NewIssueReporter() *IssueReporter { - return &IssueReporter{} +func NewIssueReporter(httpClient *http.Client) *IssueReporter { + return &IssueReporter{httpClient: httpClient} } -func randStr(n int) string { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - var hexStr string - for i := 0; i < n; i++ { - hexStr += fmt.Sprintf("%x", r.Intn(16)) - } - return hexStr -} +type IssueType int -// Attachment is a file attachment -type Attachment struct { - Name string - Data []byte -} +const ( + CannotCompletePurchase IssueType = iota + CannotSignIn + SpinnerLoadsEndlessly + CannotAccessBlockedSites + Slow + CannotLinkDevice + ApplicationCrashes + Other IssueType = iota + 2 + UpdateFails +) + +// // issue text to type mapping +// var issueTypeMap = map[string]IssueType{ +// "Cannot complete purchase": CannotCompletePurchase, +// "Cannot sign in": CannotSignIn, +// "Spinner loads endlessly": SpinnerLoadsEndlessly, +// "Cannot access blocked sites": CannotAccessBlockedSites, +// "Slow": Slow, +// "Cannot link device": CannotLinkDevice, +// "Application crashes": ApplicationCrashes, +// "Other": Other, +// "Update fails": UpdateFails, +// } type IssueReport struct { // Type is one of the predefined issue type strings - Type string - // Issue description + Type IssueType Description string - // Attachment is a list of issue attachments - Attachments []*Attachment + Email string + CountryCode string // device common name - Device string + Device string + DeviceID string + UserID string + SubscriptionLevel string + Locale string // device alphanumeric name Model string -} - -// issue text to type mapping -var issueTypeMap = map[string]int{ - "Cannot complete purchase": 0, - "Cannot sign in": 1, - "Spinner loads endlessly": 2, - "Cannot access blocked sites": 3, - "Slow": 4, - "Cannot link device": 5, - "Application crashes": 6, - "Other": 9, - "Update fails": 10, + // AdditionalAttachments is a list of additional files to be attached. The log file will be + // automatically included. + AdditionalAttachments []string } // Report sends an issue report to lantern-cloud/issue, which is then forwarded to ticket system via API -func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEmail, country string) error { +func (ir *IssueReporter) Report(ctx context.Context, report IssueReport) error { ctx, span := otel.Tracer(tracerName).Start(ctx, "Report") defer span.End() // set a random email if it's empty - if userEmail == "" { - userEmail = "support+" + randStr(8) + "@getlantern.org" + if report.Email == "" { + report.Email = "support+" + randStr(8) + "@getlantern.org" } - userStatus := settings.GetString(settings.UserLevelKey) + // userStatus := settings.GetString(settings.UserLevelKey) osVersion, err := osversion.GetHumanReadable() if err != nil { slog.Error("Unable to get OS version", "error", err) + osVersion = runtime.GOOS + " " + runtime.GOARCH } - // get issue type as integer - iType, ok := issueTypeMap[report.Type] - if !ok { - slog.Error("Unknown issue type, setting to 'Other'", "type", report.Type) - iType = 9 - } - r := &ReportIssueRequest{ - Type: ReportIssueRequest_ISSUE_TYPE(iType), - CountryCode: country, + Type: ReportIssueRequest_ISSUE_TYPE(report.Type), AppVersion: common.Version, - SubscriptionLevel: userStatus, Platform: common.Platform, + CountryCode: report.CountryCode, + SubscriptionLevel: report.SubscriptionLevel, Description: report.Description, - UserEmail: userEmail, - DeviceId: settings.GetString(settings.DeviceIDKey), - UserId: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), + UserEmail: report.Email, + DeviceId: report.DeviceID, + UserId: report.UserID, Device: report.Device, Model: report.Model, + Language: report.Locale, OsVersion: osVersion, - Language: settings.GetString(settings.LocaleKey), - } - - for _, attachment := range report.Attachments { - r.Attachments = append(r.Attachments, &ReportIssueRequest_Attachment{ - Type: "application/zip", - Name: attachment.Name, - Content: attachment.Data, - }) } - // Zip logs - slog.Debug("zipping log files for issue report") - buf := &bytes.Buffer{} - // zip * under folder common.LogDir logDir := settings.GetString(settings.LogPathKey) - slog.Debug("zipping log files", "logDir", logDir, "maxSize", maxUncompressedLogSize) - if _, zipErr := zipLogFiles(buf, logDir, maxUncompressedLogSize, int64(maxUncompressedLogSize)); zipErr == nil { - r.Attachments = append(r.Attachments, &ReportIssueRequest_Attachment{ + archive, err := buildIssueArchive(logDir, report.AdditionalAttachments, maxCompressedSize) + if err != nil { + slog.Error("failed to build issue archive", "error", err) + } + if len(archive) > 0 { + r.Attachments = []*ReportIssueRequest_Attachment{{ Type: "application/zip", Name: "logs.zip", - Content: buf.Bytes(), - }) - slog.Debug("log files zipped for issue report", "size", len(buf.Bytes())) - } else { - slog.Error("unable to zip log files", "error", err, "logDir", logDir, "maxSize", maxUncompressedLogSize) + Content: archive, + }} } // send message to lantern-cloud @@ -148,7 +137,7 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma } issueURL := common.GetBaseURL() + "/issue" - req, err := backend.NewIssueRequest( + req, err := newIssueRequest( ctx, http.MethodPost, issueURL, @@ -159,7 +148,7 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma return traces.RecordError(ctx, err) } - resp, err := kindling.HTTPClient().Do(req) + resp, err := ir.httpClient.Do(req) if err != nil { slog.Error("failed to send issue report", "error", err, "requestURL", issueURL) return traces.RecordError(ctx, err) @@ -178,3 +167,27 @@ func (ir *IssueReporter) Report(ctx context.Context, report IssueReport, userEma slog.Debug("issue report sent") return nil } + +// newIssueRequest creates a new HTTP request with the required headers for issue reporting. +func newIssueRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + req, err := common.NewRequestWithHeaders(ctx, method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("content-type", "application/x-protobuf") + req.Header.Set(common.SupportedDataCapsHeader, "monthly,weekly,daily") + if tz, err := timezone.IANANameForTime(time.Now()); err == nil { + req.Header.Set(common.TimeZoneHeader, tz) + } + + return req, nil +} + +func randStr(n int) string { + var hexStr string + for range n { + hexStr += fmt.Sprintf("%x", rand.IntN(16)) + } + return hexStr +} diff --git a/issue/issue_test.go b/issue/issue_test.go index 7e6b4634..58609fa3 100644 --- a/issue/issue_test.go +++ b/issue/issue_test.go @@ -1,12 +1,15 @@ package issue import ( + "archive/zip" + "bytes" "context" "io" "net/http" "net/http/httptest" "net/url" - "strconv" + "os" + "path/filepath" "testing" "github.com/getlantern/osversion" @@ -16,7 +19,6 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/kindling" ) func TestSendReport(t *testing.T) { @@ -26,7 +28,13 @@ func TestSendReport(t *testing.T) { osVer, err := osversion.GetHumanReadable() require.NoError(t, err) - // Build expected report + // Create a temp file to use as an additional attachment + tmpDir := t.TempDir() + attachPath := filepath.Join(tmpDir, "Hello.txt") + err = os.WriteFile(attachPath, []byte("Hello World"), 0644) + require.NoError(t, err) + + // Build expected report (without attachments — we verify those separately) want := &ReportIssueRequest{ Type: ReportIssueRequest_NO_ACCESS, CountryCode: "US", @@ -36,53 +44,40 @@ func TestSendReport(t *testing.T) { Description: "Description placeholder-test only", UserEmail: "radiancetest@getlantern.org", DeviceId: settings.GetString(settings.DeviceIDKey), - UserId: strconv.FormatInt(settings.GetInt64(settings.UserIDKey), 10), + UserId: settings.GetString(settings.UserIDKey), Device: "Samsung Galaxy S10", Model: "SM-G973F", OsVersion: osVer, Language: settings.GetString(settings.LocaleKey), - Attachments: []*ReportIssueRequest_Attachment{ - { - Type: "application/zip", - Name: "Hello.txt", - Content: []byte("Hello World"), - }, - }, } srv := newTestServer(t, want) defer srv.Close() - reporter := &IssueReporter{} - kindling.SetKindling(&mockKindling{newTestClient(t, srv.URL)}) - report := IssueReport{ - Type: "Cannot access blocked sites", - Description: "Description placeholder-test only", - Attachments: []*Attachment{ - { - Name: "Hello.txt", - Data: []byte("Hello World"), - }, - }, - Device: "Samsung Galaxy S10", - Model: "SM-G973F", - } - - err = reporter.Report(context.Background(), report, "radiancetest@getlantern.org", "US") - require.NoError(t, err) -} - -func newTestClient(t *testing.T, testURL string) *http.Client { - return &http.Client{ + reporter := NewIssueReporter(&http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - parsedURL, err := url.Parse(testURL) - if err != nil { - t.Fatalf("failed to parse testURL: %v", err) - } + parsedURL, err := url.Parse(srv.URL) + require.NoError(t, err, "failed to parse test server URL") req.URL = parsedURL return http.DefaultTransport.RoundTrip(req) }), + }) + report := IssueReport{ + Type: CannotAccessBlockedSites, + Description: "Description placeholder-test only", + Email: "radiancetest@getlantern.org", + CountryCode: "US", + SubscriptionLevel: "free", + DeviceID: settings.GetString(settings.DeviceIDKey), + UserID: settings.GetString(settings.UserIDKey), + Locale: settings.GetString(settings.LocaleKey), + Device: "Samsung Galaxy S10", + Model: "SM-G973F", + AdditionalAttachments: []string{attachPath}, } + + err = reporter.Report(context.Background(), report) + require.NoError(t, err) } // roundTripperFunc allows using a function as http.RoundTripper @@ -109,18 +104,29 @@ func newTestServer(t *testing.T, want *ReportIssueRequest) *testServer { err = proto.Unmarshal(body, &got) require.NoError(t, err, "should unmarshal protobuf request") - // Filter got.Attachments to only include the ones we're testing - // (exclude logs.zip and other dynamic attachments) - filteredAttachments := make([]*ReportIssueRequest_Attachment, 0) - for _, gotAtt := range got.Attachments { - for _, wantAtt := range ts.want.Attachments { - if gotAtt.Name == wantAtt.Name { - filteredAttachments = append(filteredAttachments, gotAtt) - break + // Verify logs.zip attachment contains the additional file + var foundHello bool + for _, att := range got.Attachments { + if att.Name == "logs.zip" { + zr, err := zip.NewReader(bytes.NewReader(att.Content), int64(len(att.Content))) + require.NoError(t, err, "should open logs.zip") + for _, f := range zr.File { + if f.Name == "attachments/Hello.txt" { + rc, err := f.Open() + require.NoError(t, err) + data, err := io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + assert.Equal(t, "Hello World", string(data)) + foundHello = true + } } } } - got.Attachments = filteredAttachments + assert.True(t, foundHello, "logs.zip should contain attachments/Hello.txt") + + // Clear attachments for field-level comparison + got.Attachments = nil // Compare received report with expected report using proto.Equal if assert.True(t, proto.Equal(ts.want, &got), "received report should match expected report") { @@ -131,17 +137,3 @@ func newTestServer(t *testing.T, want *ReportIssueRequest) *testServer { })) return ts } - -type mockKindling struct { - c *http.Client -} - -// NewHTTPClient returns a new HTTP client that is configured to use kindling. -func (m *mockKindling) NewHTTPClient() *http.Client { - return m.c -} - -// ReplaceTransport replaces an existing transport RoundTripper generator with the provided one. -func (m *mockKindling) ReplaceTransport(name string, rt func(ctx context.Context, addr string) (http.RoundTripper, error)) error { - panic("not implemented") // TODO: Implement -} diff --git a/issue/logzipper.go b/issue/logzipper.go deleted file mode 100644 index 693a1a87..00000000 --- a/issue/logzipper.go +++ /dev/null @@ -1,111 +0,0 @@ -package issue - -// copied from flashlight/logging/logging.go - -import ( - "io" - "log/slog" - "os" - "path/filepath" - "sort" -) - -type fileInfo struct { - file string - size int64 - modTime int64 -} -type byDate []*fileInfo - -func (a byDate) Len() int { return len(a) } -func (a byDate) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a byDate) Less(i, j int) bool { return a[i].modTime > a[j].modTime } - -// zipLogFiles zips the Lantern log files to the writer. All files will be -// placed under the folder in the archieve. It will stop and return if the -// newly added file would make the extracted files exceed maxBytes in total. -// -// It also returns up to maxTextBytes of plain text from the end of the most recent log file. -func zipLogFiles(w io.Writer, logDir string, maxBytes int64, maxTextBytes int64) (string, error) { - return zipLogFilesFrom(w, maxBytes, maxTextBytes, map[string]string{"logs": logDir}) -} - -// zipLogFilesFrom zips the log files from the given dirs to the writer. It will -// stop and return if the newly added file would make the extracted files exceed -// maxBytes in total. -// -// It also returns up to maxTextBytes of plain text from the end of the most recent log file. -func zipLogFilesFrom(w io.Writer, maxBytes int64, maxTextBytes int64, dirs map[string]string) (string, error) { - globs := make(map[string]string, len(dirs)) - for baseDir, dir := range dirs { - globs[baseDir] = filepath.Join(dir, "*") - } - err := zipFiles(w, zipOptions{ - Globs: globs, - MaxBytes: maxBytes, - }) - if err != nil { - return "", err - } - - if maxTextBytes <= 0 { - return "", nil - } - - // Get info for all log files - allFiles := make(byDate, 0) - for _, glob := range globs { - matched, err := filepath.Glob(glob) - if err != nil { - slog.Error("Unable to glob log files", "glob", glob, "error", err) - continue - } - for _, file := range matched { - fi, err := os.Stat(file) - if err != nil { - slog.Error("Unable to stat log file", "file", file, "error", err) - continue - } - allFiles = append(allFiles, &fileInfo{ - file: file, - size: fi.Size(), - modTime: fi.ModTime().Unix(), - }) - } - } - - if len(allFiles) > 0 { - // Sort by recency - sort.Sort(allFiles) - - mostRecent := allFiles[0] - slog.Debug("Grabbing log tail", "file", mostRecent.file) - - mostRecentFile, err := os.Open(mostRecent.file) - if err != nil { - slog.Error("Unable to open most recent log file", "file", mostRecent.file, "error", err) - return "", nil - } - defer mostRecentFile.Close() - - seekTo := mostRecent.size - maxTextBytes - if seekTo > 0 { - slog.Debug("Seeking to tail of log file", "file", mostRecent.file, "seekTo", seekTo) - _, err = mostRecentFile.Seek(seekTo, io.SeekCurrent) - if err != nil { - slog.Error("Unable to seek to tail of log file", "file", mostRecent.file, "error", err) - return "", nil - } - } - tail, err := io.ReadAll(mostRecentFile) - if err != nil { - slog.Error("Unable to read tail of log file", "file", mostRecent.file, "error", err) - return "", nil - } - - slog.Debug("Returning log tail", "file", mostRecent.file, "tailSize", len(tail)) - return string(tail), nil - } - - return "", nil -} diff --git a/issue/zip.go b/issue/zip.go deleted file mode 100644 index 28731eb0..00000000 --- a/issue/zip.go +++ /dev/null @@ -1,118 +0,0 @@ -package issue - -import ( - "archive/zip" - "fmt" - "io" - "math" - "os" - "path/filepath" -) - -// zipOptions is a set of options for zipFiles. -type zipOptions struct { - // The search patterns for the files / directories to be zipped, keyed to the - // directory prefix used for storing the associated files in the ZIP, - // The search pattern is described at the comments of path/filepath.Match. - // As a special note, "**/*" doesn't match files not under a subdirectory. - Globs map[string]string - // The limit of total bytes of all the files in the archive. - // All remaining files will be ignored if the limit would be hit. - MaxBytes int64 -} - -// zipFiles creates a zip archive per the options and writes to the writer. -func zipFiles(writer io.Writer, opts zipOptions) (err error) { - w := zip.NewWriter(writer) - defer func() { - if e := w.Close(); e != nil { - err = e - } - }() - - maxBytes := opts.MaxBytes - if maxBytes == 0 { - maxBytes = math.MaxInt64 - } - - var totalBytes int64 - for baseDir, glob := range opts.Globs { - matched, e := filepath.Glob(glob) - if e != nil { - return e - } - for _, source := range matched { - nextTotal, e := zipFile(w, baseDir, source, maxBytes, totalBytes) - if e != nil || nextTotal > maxBytes { - return e - } - totalBytes = nextTotal - } - } - return -} - -func zipFile(w *zip.Writer, baseDir string, source string, limit int64, prevBytes int64) (newBytes int64, err error) { - _, e := os.Stat(source) - if e != nil { - return prevBytes, fmt.Errorf("%s: stat: %v", source, e) - } - - walkErr := filepath.Walk(source, func(fpath string, info os.FileInfo, err error) error { - if err != nil { - return fmt.Errorf("walking to %s: %v", fpath, err) - } - - newBytes = prevBytes + info.Size() - if newBytes > limit { - return filepath.SkipDir - } - header, err := zip.FileInfoHeader(info) - if err != nil { - return fmt.Errorf("%s: getting header: %v", fpath, err) - } - - dir, filename := filepath.Split(fpath) - if baseDir != "" { - dir = baseDir - } else { - dir = dir[:len(dir)-1] // strip trailing slash - } - if info.IsDir() { - header.Name = fmt.Sprintf("%v/", dir) - header.Method = zip.Store - } else { - header.Name = fmt.Sprintf("%v/%v", dir, filename) - header.Method = zip.Deflate - } - - writer, err := w.CreateHeader(header) - if err != nil { - return fmt.Errorf("%s: making header: %v", fpath, err) - } - - if info.IsDir() { - return nil - } - - if !header.Mode().IsRegular() { - return nil - } - file, err := os.Open(fpath) - if err != nil { - return fmt.Errorf("%s: opening: %v", fpath, err) - } - defer file.Close() - - _, err = io.Copy(writer, file) - if err != nil && err != io.EOF { - return fmt.Errorf("%s: copying contents: %v", fpath, err) - } - return nil - }) - - if walkErr != filepath.SkipDir { - return newBytes, walkErr - } - return newBytes, nil -} diff --git a/issue/zip_test.go b/issue/zip_test.go deleted file mode 100644 index 76a21238..00000000 --- a/issue/zip_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package issue - -import ( - "archive/zip" - "bytes" - "io" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestZipFilesWithoutPath(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, zipOptions{Globs: map[string]string{"": "**/*.txt*"}}) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "test_data/hello.txt", - "test_data/hello.txt.1", - "test_data/large.txt", - "test_data/zzzz.txt.2", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func TestZipFilesWithMaxBytes(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, - zipOptions{ - Globs: map[string]string{"": "test_data/*.txt*"}, - MaxBytes: 1024, // 1KB - }, - ) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "test_data/hello.txt", - "test_data/hello.txt.1", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func TestZipFilesWithNewRoot(t *testing.T) { - var buf bytes.Buffer - err := zipFiles(&buf, zipOptions{Globs: map[string]string{"new_root": "**/*.txt*"}}) - if !assert.NoError(t, err) { - return - } - expectedFiles := []string{ - "new_root/hello.txt", - "new_root/hello.txt.1", - "new_root/large.txt", - "new_root/zzzz.txt.2", - } - testZipFiles(t, buf.Bytes(), expectedFiles) -} - -func testZipFiles(t *testing.T, zipped []byte, expectedFiles []string) { - reader, eread := zip.NewReader(bytes.NewReader(zipped), int64(len(zipped))) - if !assert.NoError(t, eread) { - return - } - if !assert.Equal(t, len(expectedFiles), len(reader.File), "should not include extra files and files that would exceed MaxBytes") { - return - } - for idx, file := range reader.File { - t.Log(file.Name) - assert.Equal(t, expectedFiles[idx], file.Name) - if !strings.Contains(file.Name, "hello.txt") { - continue - } - fileReader, err := file.Open() - if !assert.NoError(t, err) { - return - } - defer fileReader.Close() - actual, _ := io.ReadAll(fileReader) - assert.Equal(t, []byte("world\n"), actual) - } -} diff --git a/kindling/client.go b/kindling/client.go index c234719f..d2474067 100644 --- a/kindling/client.go +++ b/kindling/client.go @@ -1,27 +1,35 @@ +// Package kindling provides a wrapper around the kindling library to create an HTTP client with +// various transports (domain fronting, AMP, DNS tunneling, proxyless) from a shared kindling instance. package kindling import ( "context" + "fmt" "log/slog" + "net" "net/http" "path/filepath" + "strings" "sync" "github.com/getlantern/kindling" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/net/proxy" + "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/common/env" "github.com/getlantern/radiance/common/reporting" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/kindling/dnstt" "github.com/getlantern/radiance/kindling/fronted" "github.com/getlantern/radiance/traces" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) var ( k kindling.Kindling - kindlingMutex sync.Mutex + initOnce sync.Once stopUpdater func() closeTransports []func() error // EnabledTransports is used for testing purposes for enabling/disabling kindling transports @@ -31,36 +39,92 @@ var ( "proxyless": true, "fronted": true, } + defaultTransportClone = http.DefaultTransport.(*http.Transport).Clone() + + // transport is the shared http.RoundTripper set once by initOnce. + transport http.RoundTripper ) -// HTTPClient returns a http client with kindling transport. -// Thread-safe: uses kindlingMutex to guard lazy initialization. -func HTTPClient() *http.Client { - kindlingMutex.Lock() - if k == nil { - newK, err := NewKindling() +// initKindling initializes the package-level kindling instance and shared +// transport. +func initKindling() { + // Censorship-circumvention QA path: when OutboundSocksAddress is set, + // every outbound HTTP dial goes through that SOCKS5 server. Kindling's + // stacked transports (fronted/AMP/dnstt/proxyless) are skipped — the + // SOCKS5 is providing egress, and kindling's per-transport internal + // dialers don't expose an override hook today. As a result, when this + // var is set we are testing "does the bandit/tunnel path work given a + // reachable API channel" rather than the full anti-censorship stack. + if addr, ok := env.Get(env.OutboundSocksAddress); ok && addr != "" { + t, err := socksOnlyTransport(addr) if err != nil { - slog.Error("failed to create kindling client", slog.Any("error", err)) - } - if newK != nil { - k = newK + slog.Error("invalid RADIANCE_OUTBOUND_SOCKS_ADDRESS, falling back to default transport", slog.Any("error", err)) + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(defaultTransportClone)) + return } + slog.Info("RADIANCE_OUTBOUND_SOCKS_ADDRESS set — routing all radiance HTTP through upstream SOCKS5", slog.String("addr", addr)) + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(t)) + return } - localK := k - kindlingMutex.Unlock() + newK, err := NewKindling(settings.GetString(settings.DataPathKey)) + if err != nil { + slog.Error("failed to create kindling client", slog.Any("error", err)) + } + if newK != nil { + k = newK + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(newK.NewHTTPClient().Transport)) + } else { + slog.Warn("kindling unavailable, using default transport clone") + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(defaultTransportClone)) + } +} - if localK == nil { - slog.Warn("kindling unavailable, returning bare HTTP client") - return &http.Client{Timeout: common.DefaultHTTPTimeout} +// socksOnlyTransport returns an http.Transport that dials through the given +// SOCKS5 server for every connection. +func socksOnlyTransport(socksAddr string) (*http.Transport, error) { + d, err := proxy.SOCKS5("tcp", socksAddr, nil, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("building SOCKS5 dialer for %s: %w", socksAddr, err) } - httpClient := localK.NewHTTPClient() - httpClient.Timeout = common.DefaultHTTPTimeout - httpClient.Transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(httpClient.Transport)) - return httpClient + ctxDialer, ok := d.(proxy.ContextDialer) + if !ok { + return nil, fmt.Errorf("SOCKS5 dialer does not support context") + } + t := defaultTransportClone.Clone() + t.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + return ctxDialer.DialContext(ctx, network, address) + } + // Disable HTTP_PROXY env-based proxying — we route via DialContext instead. + // (x/net/proxy's SOCKS5 sends the hostname to the upstream as ATYP=domain, + // so DNS resolution also happens at the SOCKS5 server, no local leak.) + t.Proxy = nil + return t, nil +} + +func Init() { + go initOnce.Do(initKindling) +} + +// HTTPClient returns an HTTP client whose transport blocks on first use +// until kindling is initialized. +func HTTPClient() *http.Client { + return &http.Client{ + Timeout: common.DefaultHTTPTimeout, + Transport: readyTransport{}, + } +} + +// readyTransport blocks until initOnce has completed, then delegates to the +// shared transport. +type readyTransport struct{} + +func (readyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + initOnce.Do(initKindling) + return transport.RoundTrip(req) } -// Close stop all concurrent config fetches that can be happening in background -func Close(_ context.Context) error { +// Close stops any in-flight config fetches and releases kindling transports. +func Close() error { if stopUpdater != nil { stopUpdater() } @@ -72,19 +136,24 @@ func Close(_ context.Context) error { return nil } -// SetKindling sets the kindling method used for building the HTTP client -// This function is useful for testing purposes. +// SetKindling installs a kindling instance for tests, bypassing the normal +// initialization path. Call it before any HTTPClient usage; otherwise +// initOnce will have already run and this call becomes a no-op. func SetKindling(a kindling.Kindling) { - kindlingMutex.Lock() - defer kindlingMutex.Unlock() - k = a + initOnce.Do(func() { + k = a + if a != nil { + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(a.NewHTTPClient().Transport)) + } else { + transport = traces.NewRoundTripper(traces.NewHeaderAnnotatingRoundTripper(defaultTransportClone)) + } + }) } const tracerName = "github.com/getlantern/radiance/kindling" // NewKindling build a kindling client and bootstrap this package -func NewKindling() (kindling.Kindling, error) { - dataDir := settings.GetString(settings.DataPathKey) +func NewKindling(dataDir string) (kindling.Kindling, error) { logger := &slogWriter{Logger: slog.Default()} ctx, span := otel.Tracer(tracerName).Start( @@ -163,6 +232,8 @@ type slogWriter struct { func (w *slogWriter) Write(p []byte) (n int, err error) { // Convert the byte slice to a string and log it - w.Info(string(p)) + s := string(p) + s = strings.TrimSpace(s) + w.Info(s) return len(p), nil } diff --git a/kindling/client_test.go b/kindling/client_test.go index 0dd2b172..735b589f 100644 --- a/kindling/client_test.go +++ b/kindling/client_test.go @@ -1,30 +1,21 @@ package kindling import ( - "context" - "log/slog" "net/http" - "os" "testing" - "github.com/getlantern/radiance/common/settings" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewClient(t *testing.T) { - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - }))) - settings.Set(settings.DataPathKey, t.TempDir()) - newK, err := NewKindling() + newK, err := NewKindling(t.TempDir()) require.NoError(t, err) require.NotNil(t, newK) SetKindling(newK) t.Cleanup(func() { - Close(context.Background()) + Close() k = nil }) diff --git a/kindling/dnstt/parser.go b/kindling/dnstt/parser.go index f9577344..e9690a0c 100644 --- a/kindling/dnstt/parser.go +++ b/kindling/dnstt/parser.go @@ -20,11 +20,14 @@ import ( "github.com/alitto/pond" "github.com/getlantern/dnstt" "github.com/getlantern/keepcurrent" + "github.com/goccy/go-yaml" + "go.opentelemetry.io/otel" + + "github.com/getlantern/radiance/common/atomicfile" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/events" "github.com/getlantern/radiance/kindling/smart" "github.com/getlantern/radiance/traces" - "github.com/goccy/go-yaml" - "go.opentelemetry.io/otel" ) type dnsttConfig struct { @@ -195,7 +198,7 @@ func onNewDNSTTConfig(configFilepath string, gzippedYML []byte) error { localConfigMutex.Lock() defer localConfigMutex.Unlock() - return os.WriteFile(configFilepath, gzippedYML, 0644) + return atomicfile.WriteFile(configFilepath, gzippedYML, fileperm.File) } func newDNSTT(cfg dnsttConfig) (dnstt.DNSTT, error) { diff --git a/kindling/dnstt/parser_test.go b/kindling/dnstt/parser_test.go index efa0f1e4..99d1dc0a 100644 --- a/kindling/dnstt/parser_test.go +++ b/kindling/dnstt/parser_test.go @@ -5,16 +5,16 @@ import ( "compress/gzip" "context" "io" - "log/slog" "net/http" "os" "path/filepath" "testing" "time" - "github.com/getlantern/radiance/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/getlantern/radiance/events" ) type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -125,10 +125,6 @@ dnsttConfigs: func TestDNSTTOptions(t *testing.T) { logger := bytes.NewBuffer(nil) - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - AddSource: true, - Level: slog.LevelDebug, - }))) waitFor = 15 * time.Second t.Run("embedded config only", func(t *testing.T) { dnst, err := DNSTTOptions(context.Background(), "", logger) diff --git a/log/log.go b/log/log.go new file mode 100644 index 00000000..0350fbbd --- /dev/null +++ b/log/log.go @@ -0,0 +1,247 @@ +package log + +import ( + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "gopkg.in/natefinch/lumberjack.v2" + + "github.com/getlantern/radiance/common/env" + "github.com/getlantern/radiance/common/fileperm" + "github.com/getlantern/radiance/common/settings" +) + +const ( + // slog does not define trace and fatal levels, so we define them here. + LevelTrace = slog.LevelDebug - 4 + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelWarn = slog.LevelWarn + LevelError = slog.LevelError + LevelFatal = slog.LevelError + 4 + LevelPanic = slog.LevelError + 8 + + Disable = slog.LevelInfo + 1000 // A level that disables logging, used for testing or no-op logger. +) + +type Config struct { + // LogPath is the full path to the log file. + LogPath string + // Level is the log level string (e.g., "info", "debug"). + Level string + // Prod indicates whether the application is running in production mode. + Prod bool + // DisablePublisher indicates whether to disable the log publisher which is used for real-time + // log streaming. + DisablePublisher bool +} + +// NewLogger creates and returns a configured *slog.Logger that writes to a rotating log file +// and optionally to stdout. +// Returns noop logger if log level is set to disable. +func NewLogger(cfg Config) *slog.Logger { + if env.GetBool(env.Testing) { + return NoOpLogger() + } + level := settings.GetString(settings.LogLevelKey) + if level == "" { + level = env.GetString(env.LogLevel) + } + if level == "" && cfg.Level != "" { + level = cfg.Level + } + slevel, err := ParseLogLevel(level) + if err != nil { + slog.Warn("Failed to parse log level", "error", err) + } + slog.SetLogLoggerLevel(slevel) + leveler := settingsLeveler{fallback: slevel} + + // lumberjack creates the log file with 0600 if it does not exist, otherwise it carries over + // the existing permissions. Pre-create with [fileperm.File] so the platform-appropriate mode is + // applied. + f, err := os.OpenFile(cfg.LogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, fileperm.File) + if err != nil { + slog.Warn("Failed to pre-create log file", "error", err, "path", cfg.LogPath) + } else { + f.Close() + } + + logRotator := &lumberjack.Logger{ + Filename: cfg.LogPath, // Log file path + MaxSize: 25, // Rotate log when it reaches 25 MB + MaxBackups: 2, // Keep up to 2 rotated log files + MaxAge: 30, // Retain old log files for up to 30 days + Compress: cfg.Prod, // Compress rotated log files + } + + isWindows := runtime.GOOS == "windows" + isWindowsProd := isWindows && cfg.Prod + + loggingToStdOut := true + var logWriter io.Writer + if env.GetBool(env.DisableStdout) { + logWriter = logRotator + loggingToStdOut = false + } else if isWindowsProd { + // For some reason, logging to both stdout and a file on Windows + // causes issues with some Windows services where the logs + // do not get written to the file. So in prod mode on Windows, + // we log to file only. See: + // https://www.reddit.com/r/golang/comments/1fpo3cg/golang_windows_service_cannot_write_log_files/ + logWriter = logRotator + loggingToStdOut = false + } else { + logWriter = io.MultiWriter(os.Stdout, logRotator) + } + if !cfg.DisablePublisher { + logWriter = io.MultiWriter(logWriter, Publisher()) + } + var handler slog.Handler = slog.NewTextHandler(logWriter, &slog.HandlerOptions{ + AddSource: true, + Level: leveler, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + switch a.Key { + case slog.TimeKey: + if t, ok := a.Value.Any().(time.Time); ok { + a.Value = slog.StringValue(t.UTC().Format("2006-01-02 15:04:05.000 UTC")) + } + return a + case slog.SourceKey: + source, ok := a.Value.Any().(*slog.Source) + if !ok { + return a + } + // remove github.com/ to get pkg name + var pkg, fn string + fields := strings.SplitN(source.Function, "/", 4) + switch len(fields) { + case 0, 1, 2: + file := filepath.Base(source.File) + a.Value = slog.StringValue(fmt.Sprintf("%s:%d", file, source.Line)) + return a + case 3: + pf := strings.SplitN(fields[2], ".", 2) + pkg, fn = pf[0], pf[1] + default: + pkg = fields[2] + fn = strings.SplitN(fields[3], ".", 2)[1] + } + + _, file, fnd := strings.Cut(source.File, pkg+"/") + if !fnd { + file = filepath.Base(source.File) + } + src := slog.GroupValue( + slog.String("func", fn), + slog.String("file", fmt.Sprintf("%s:%d", file, source.Line)), + ) + a.Value = slog.GroupValue( + slog.String("pkg", pkg), + slog.Any("source", src), + ) + a.Key = "" + case slog.LevelKey: + // format the log level to account for the custom levels defined in internal/util.go, i.e. trace + // otherwise, slog will print as "DEBUG-4" (trace) or similar + level := a.Value.Any().(slog.Level) + a.Value = slog.StringValue(FormatLogLevel(level)) + } + return a + }, + }) + handler = &Handler{Handler: handler, w: logWriter} + logger := slog.New(handler) + if !loggingToStdOut { + if isWindows { + fmt.Printf("Logging to file only on Windows prod -- run with RADIANCE_ENV=dev to enable stdout path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } else { + fmt.Printf("Logging to file only -- RADIANCE_DISABLE_STDOUT_LOG is set path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } + } else { + fmt.Printf("Logging to file and stdout path: %s, level: %s\n", cfg.LogPath, FormatLogLevel(slevel)) + } + return logger +} + +type Handler struct { + slog.Handler + w io.Writer +} + +func (h *Handler) Writer() io.Writer { + return h.w +} + +// settingsLeveler reads the current log level from settings on each call, +// so changes to settings.LogLevelKey take effect without rebuilding the logger. +type settingsLeveler struct { + fallback slog.Level +} + +func (s settingsLeveler) Level() slog.Level { + if v := settings.GetString(settings.LogLevelKey); v != "" { + if lvl, err := ParseLogLevel(v); err == nil { + return lvl + } + } + return s.fallback +} + +// ParseLogLevel parses a string representation of a log level and returns the corresponding slog.Level. +// If the level is not recognized, it returns LevelInfo. +func ParseLogLevel(level string) (slog.Level, error) { + switch strings.ToLower(level) { + case "trace": + return LevelTrace, nil + case "debug": + return LevelDebug, nil + case "info": + return LevelInfo, nil + case "warn", "warning": + return LevelWarn, nil + case "error": + return LevelError, nil + case "fatal": + return LevelFatal, nil + case "panic": + return LevelPanic, nil + case "disable", "none", "off": + return Disable, nil + default: + return LevelInfo, fmt.Errorf("unknown log level: %s", level) + } +} + +func FormatLogLevel(level slog.Level) string { + switch { + case level < LevelDebug: + return "TRACE" + case level < LevelInfo: + return "DEBUG" + case level < LevelWarn: + return "INFO" + case level < LevelError: + return "WARN" + case level < LevelFatal: + return "ERROR" + case level < LevelPanic: + return "FATAL" + default: + return "PANIC" + } +} + +// NoOpLogger returns a no-op logger that does not log anything. +func NoOpLogger() *slog.Logger { + // Create a no-op logger that does nothing. + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: Disable, + })) +} diff --git a/log/publish_handler.go b/log/publish_handler.go new file mode 100644 index 00000000..1e1b7f96 --- /dev/null +++ b/log/publish_handler.go @@ -0,0 +1,86 @@ +package log + +import ( + "sync" +) + +// LogEntry is a formatted log line streamed to clients. +type LogEntry = string + +// Subscribe returns a channel that receives log entries from the default logger +// and an unsubscribe function. Recent entries from the ring buffer are sent +// immediately. +func Subscribe() (chan LogEntry, func()) { + return defaultPublisher.subscribe() +} + +var defaultPublisher = newPublisher(200) + +// Publisher returns the default log publisher. Include it in the handler's +// writer chain so published entries share the same format. +func Publisher() *publisher { + return defaultPublisher +} + +// publisher fans out log lines to connected SSE clients. It implements io.Writer +// so it can be included in the handler's writer chain. It maintains a ring buffer +// of recent entries so new subscribers get immediate context. +type publisher struct { + clients map[chan LogEntry]struct{} + ring []LogEntry + ringSize int + ringIdx int + mu sync.RWMutex +} + +func newPublisher(ringSize int) *publisher { + return &publisher{ + clients: make(map[chan LogEntry]struct{}), + ring: make([]LogEntry, ringSize), + ringSize: ringSize, + } +} + +// Write implements io.Writer. Each call is treated as a single log line. +func (lb *publisher) Write(b []byte) (int, error) { + entry := string(b) + lb.publish(entry) + return len(b), nil +} + +func (lb *publisher) publish(entry LogEntry) { + lb.mu.Lock() + lb.ring[lb.ringIdx%lb.ringSize] = entry + lb.ringIdx++ + lb.mu.Unlock() + + lb.mu.RLock() + defer lb.mu.RUnlock() + for ch := range lb.clients { + select { + case ch <- entry: + default: // drop if client is slow + } + } +} + +func (lb *publisher) subscribe() (chan LogEntry, func()) { + ch := make(chan LogEntry, lb.ringSize) + lb.mu.Lock() + start := max(0, lb.ringIdx-lb.ringSize) + for i := start; i < lb.ringIdx; i++ { + entry := lb.ring[i%lb.ringSize] + if entry != "" { + ch <- entry + } + } + lb.clients[ch] = struct{}{} + lb.mu.Unlock() + + unsub := func() { + lb.mu.Lock() + delete(lb.clients, ch) + lb.mu.Unlock() + } + return ch, unsub +} diff --git a/log/publish_test.go b/log/publish_test.go new file mode 100644 index 00000000..ad74e916 --- /dev/null +++ b/log/publish_test.go @@ -0,0 +1,135 @@ +package log + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPublisher(t *testing.T) { + p := newPublisher(10) + + ch, unsub := p.subscribe() + defer unsub() + + entry := "time=2025-01-01T00:00:00.000Z level=INFO msg=hello\n" + p.publish(entry) + + select { + case got := <-ch: + assert.Equal(t, entry, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } +} + +func TestPublisherWrite(t *testing.T) { + p := newPublisher(10) + + ch, unsub := p.subscribe() + defer unsub() + + line := "time=2025-01-01T00:00:00.000Z level=INFO msg=hello\n" + n, err := p.Write([]byte(line)) + require.NoError(t, err) + assert.Equal(t, len(line), n) + + select { + case got := <-ch: + assert.Equal(t, line, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } +} + +func TestMultipleSubscribers(t *testing.T) { + p := newPublisher(10) + + ch1, unsub1 := p.subscribe() + defer unsub1() + ch2, unsub2 := p.subscribe() + defer unsub2() + + entry := "time=2025-01-01T00:00:00.000Z level=DEBUG msg=multi\n" + p.publish(entry) + + for _, ch := range []chan LogEntry{ch1, ch2} { + select { + case got := <-ch: + assert.Equal(t, entry, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcast") + } + } +} + +func TestUnsubscribe(t *testing.T) { + p := newPublisher(10) + + ch, unsub := p.subscribe() + unsub() + + p.publish("time=2025-01-01T00:00:00.000Z level=INFO msg=\"after unsub\"\n") + + select { + case <-ch: + t.Fatal("should not receive after unsubscribe") + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestRingBuffer(t *testing.T) { + p := newPublisher(3) + + // Fill the ring buffer with 5 entries, so only the last 3 should be available. + for i := range 5 { + p.publish(string(rune('a'+i)) + "\n") + } + + ch, unsub := p.subscribe() + defer unsub() + + // New subscriber should get the 3 ring buffer entries. + var msgs []string + for range 3 { + select { + case e := <-ch: + msgs = append(msgs, e) + case <-time.After(time.Second): + t.Fatal("timed out reading ring buffer entries") + } + } + assert.Equal(t, []string{"c\n", "d\n", "e\n"}, msgs) +} + +func TestConcurrentBroadcast(t *testing.T) { + p := newPublisher(100) + ch, unsub := p.subscribe() + defer unsub() + + var wg sync.WaitGroup + n := 50 + wg.Add(n) + for i := range n { + go func(i int) { + defer wg.Done() + p.publish("msg\n") + }(i) + } + wg.Wait() + + received := 0 + for { + select { + case <-ch: + received++ + default: + require.Equal(t, n, received) + return + } + } +} diff --git a/option/algeneva.go b/option/algeneva.go deleted file mode 100644 index dd638ee7..00000000 --- a/option/algeneva.go +++ /dev/null @@ -1,12 +0,0 @@ -package option - -import "github.com/sagernet/sing-box/option" - -type ALGenevaInboundOptions struct { - option.HTTPMixedInboundOptions -} - -type ALGenevaOutboundOptions struct { - option.HTTPOutboundOptions - Strategy string `json:"strategy,omitempty"` -} diff --git a/option/amnezia.go b/option/amnezia.go deleted file mode 100644 index 2cc4c95f..00000000 --- a/option/amnezia.go +++ /dev/null @@ -1,85 +0,0 @@ -package option - -import ( - O "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common/json/badoption" - "net/netip" -) - -/************* ADDED FOR AMNEZIA *************/ -/* -WireGuardAdvancedSecurityOptions provides advanced security options for WireGuard required to activate AmneziaWG. - -In AmneziaWG, random bytes are appended to every auth packet to alter their size. - -Thus, "init and response handshake packets" have added "junk" at the beginning of their data, the size of which -is determined by the values S1 and S2. - -By default, the initiating handshake packet has a fixed size (148 bytes). After adding the junk, its size becomes 148 bytes + S1. -AmneziaWG also incorporates another trick for more reliable masking. Before initiating a session, Amnezia sends a - -certain number of "junk" packets to thoroughly confuse DPI systems. The number of these packets and their -minimum and maximum byte sizes can also be adjusted in the settings, using parameters Jc, Jmin, and Jmax. - -*/ - -type WireGuardAdvancedSecurityOptions struct { - JunkPacketCount int `json:"junk_packet_count,omitempty"` // jc - JunkPacketMinSize int `json:"junk_packet_min_size,omitempty"` // jmin - JunkPacketMaxSize int `json:"junk_packet_max_size,omitempty"` // jmax - InitPacketJunkSize int `json:"init_packet_junk_size,omitempty"` // s1 - ResponsePacketJunkSize int `json:"response_packet_junk_size,omitempty"` // s2 - InitPacketMagicHeader uint32 `json:"init_packet_magic_header,omitempty"` // h1 - ResponsePacketMagicHeader uint32 `json:"response_packet_magic_header,omitempty"` // h2 - UnderloadPacketMagicHeader uint32 `json:"underload_packet_magic_header,omitempty"` // h3 - TransportPacketMagicHeader uint32 `json:"transport_packet_magic_header,omitempty"` // h4 -} -/******************** END ********************/ -type WireGuardEndpointOptions struct { - System bool `json:"system,omitempty"` - Name string `json:"name,omitempty"` - MTU uint32 `json:"mtu,omitempty"` - Address badoption.Listable[netip.Prefix] `json:"address"` - PrivateKey string `json:"private_key"` - ListenPort uint16 `json:"listen_port,omitempty"` - Peers []WireGuardPeer `json:"peers,omitempty"` - UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` - Workers int `json:"workers,omitempty"` - WireGuardAdvancedSecurityOptions /** ADDED FOR AMNEZIA **/ - O.DialerOptions -} - -type WireGuardPeer struct { - Address string `json:"address,omitempty"` - Port uint16 `json:"port,omitempty"` - PublicKey string `json:"public_key,omitempty"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` - PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` -} - -type LegacyWireGuardOutboundOptions struct { - O.DialerOptions - SystemInterface bool `json:"system_interface,omitempty"` - GSO bool `json:"gso,omitempty"` - InterfaceName string `json:"interface_name,omitempty"` - LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"` - PrivateKey string `json:"private_key"` - Peers []LegacyWireGuardPeer `json:"peers,omitempty"` - O.ServerOptions - PeerPublicKey string `json:"peer_public_key"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` - Workers int `json:"workers,omitempty"` - MTU uint32 `json:"mtu,omitempty"` - Network O.NetworkList `json:"network,omitempty"` -} - -type LegacyWireGuardPeer struct { - O.ServerOptions - PublicKey string `json:"public_key,omitempty"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` -} diff --git a/option/outline.go b/option/outline.go deleted file mode 100644 index a5fb93aa..00000000 --- a/option/outline.go +++ /dev/null @@ -1,55 +0,0 @@ -package option - -import O "github.com/sagernet/sing-box/option" - -// OutboundOutlineOptions set the outbound options used by the outline-sdk -// smart dialer. You can find more details about the parameters by looking -// through the implementation: https://github.com/Jigsaw-Code/outline-sdk/blob/v0.0.18/x/smart/stream_dialer.go#L65-L100 -// Or check the documentation README: https://github.com/Jigsaw-Code/outline-sdk/tree/v0.0.18/x/smart -type OutboundOutlineOptions struct { - O.DialerOptions - DNSResolvers []DNSEntryConfig `json:"dns,omitempty" yaml:"dns,omitempty"` - TLS []string `json:"tls,omitempty" yaml:"tls,omitempty"` - TestTimeout string `json:"test_timeout" yaml:"-"` - Domains []string `json:"domains" yaml:"-"` -} - -// DNSEntryConfig specifies a list of resolvers to test and they can be one of -// the attributes (system, https, tls, udp or tcp) -type DNSEntryConfig struct { - // System is used for using the system as a resolver, if you want to use it - // provide an empty object. - System *struct{} `json:"system,omitempty"` - // HTTPS use an encrypted DNS over HTTPS (DoH) resolver. - HTTPS *HTTPSEntryConfig `json:"https,omitempty"` - // TLS use an encrypted DNS over TLS (DoT) resolver. - TLS *TLSEntryConfig `json:"tls,omitempty"` - // UDP use a UDP resolver - UDP *UDPEntryConfig `json:"udp,omitempty"` - // TCP use a TCP resolver - TCP *TCPEntryConfig `json:"tcp,omitempty"` -} - -type HTTPSEntryConfig struct { - // Domain name of the host. - Name string `json:"name,omitempty"` - // Host:port. Defaults to Name:443. - Address string `json:"address,omitempty"` -} - -type TLSEntryConfig struct { - // Domain name of the host. - Name string `json:"name,omitempty"` - // Host:port. Defaults to Name:853. - Address string `json:"address,omitempty"` -} - -type UDPEntryConfig struct { - // Host:port. - Address string `json:"address,omitempty"` -} - -type TCPEntryConfig struct { - // Host:port. - Address string `json:"address,omitempty"` -} diff --git a/radiance.go b/radiance.go deleted file mode 100644 index 5bb2666d..00000000 --- a/radiance.go +++ /dev/null @@ -1,359 +0,0 @@ -// Package radiance provides a local server that proxies all requests to a remote proxy server using different -// protocols meant to circumvent censorship. Radiance uses a [transport.StreamDialer] to dial the target server -// over the desired protocol. The [config.Config] is used to configure the dialer for a proxy server. -package radiance - -import ( - "context" - "fmt" - "log/slog" - "sync" - "sync/atomic" - "time" - - "github.com/Xuanwo/go-locale" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/metric/noop" - "go.opentelemetry.io/otel/trace" - traceNoop "go.opentelemetry.io/otel/trace/noop" - - lcommon "github.com/getlantern/common" - "github.com/getlantern/publicip" - - "github.com/getlantern/radiance/api" - "github.com/getlantern/radiance/backend" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/deviceid" - "github.com/getlantern/radiance/common/env" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/config" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/issue" - "github.com/getlantern/radiance/kindling" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/telemetry" - "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn" -) - -const configPollInterval = 10 * time.Minute -const tracerName = "github.com/getlantern/radiance" - -//go:generate mockgen -destination=radiance_mock_test.go -package=radiance github.com/getlantern/radiance configHandler - -// configHandler is an interface that abstracts the config.ConfigHandler struct for easier testing. -type configHandler interface { - // Stop stops the config handler from fetching new configurations. - Stop() - // SetPreferredServerLocation sets the preferred server location. If not set - it's auto selected by the API - SetPreferredServerLocation(country, city string) - // GetConfig returns the current configuration. - // It returns an error if the configuration is not yet available. - GetConfig() (*config.Config, error) -} - -type issueReporter interface { - Report(ctx context.Context, report issue.IssueReport, userEmail, country string) error -} - -// Radiance is a local server that proxies all requests to a remote proxy server over a transport.StreamDialer. -type Radiance struct { - confHandler configHandler - issueReporter issueReporter - apiHandler *api.APIClient - srvManager *servers.Manager - shutdownFuncs []func(context.Context) error - closeOnce sync.Once - stopChan chan struct{} - telemetryConsent atomic.Bool -} - -type Options struct { - DataDir string - LogDir string - Locale string - DeviceID string - LogLevel string - // User choice for telemetry consent - TelemetryConsent bool -} - -// NewRadiance creates a new Radiance VPN client. opts includes the platform interface used to -// interact with the underlying platform on iOS, Android, and MacOS. On other platforms, it is -// ignored and can be nil. -func NewRadiance(opts Options) (*Radiance, error) { - if opts.Locale == "" { - // It is preferable to use the locale from the frontend, as locale is a requirement for lots - // of frontend code and therefore is more reliably supported there. - // However, if the frontend locale is not available, we can use the system locale as a fallback. - if tag, err := locale.Detect(); err != nil { - opts.Locale = "en-US" - } else { - opts.Locale = tag.String() - } - } - - var platformDeviceID string - switch common.Platform { - case "ios", "android": - platformDeviceID = opts.DeviceID - default: - platformDeviceID = deviceid.Get() - } - - shutdownFuncs := []func(context.Context) error{} - if err := common.Init(opts.DataDir, opts.LogDir, opts.LogLevel); err != nil { - return nil, fmt.Errorf("failed to initialize: %w", err) - } - settings.Set(settings.LocaleKey, opts.Locale) - - dataDir := settings.GetString(settings.DataPathKey) - newK, err := kindling.NewKindling() - if err != nil { - slog.Error("failed to initialize kindling", slog.Any("error", err)) - } - if newK != nil { - kindling.SetKindling(newK) - } - setUserConfig(platformDeviceID, dataDir, opts.Locale) - - // Detect public IP before the first config fetch so the server knows our - // real IP even when requests arrive via domain fronting. Uses STUN, DNS, - // and HTTP in parallel — typically completes in 25-200ms. The 2-second - // timeout ensures startup isn't significantly delayed on constrained networks. - ipCtx, ipCancel := context.WithTimeout(context.Background(), 2*time.Second) - result, err := publicip.Detect(ipCtx, &publicip.Config{ - Timeout: 2 * time.Second, - MinConsensus: 1, // accept the first result to minimize delay - }) - ipCancel() - if err != nil { - slog.Warn("Failed to detect public IP", "error", err) - } else { - backend.SetClientIP(result.IP.String()) - slog.Debug("Detected public IP", "confidence", result.Confidence, "sources", result.Sources) - } - - apiHandler := api.NewAPIClient(dataDir) - issueReporter := issue.NewIssueReporter() - - svrMgr, err := servers.NewManager(dataDir) - if err != nil { - return nil, fmt.Errorf("failed to create server manager: %w", err) - } - cOpts := config.Options{ - PollInterval: configPollInterval, - SvrManager: svrMgr, - DataDir: dataDir, - Locale: opts.Locale, - APIHandler: apiHandler, - } - if disableFetch, ok := env.Get[bool](env.DisableFetch); ok && disableFetch { - cOpts.PollInterval = -1 - slog.Info("Disabling config fetch") - } - r := &Radiance{ - issueReporter: issueReporter, - apiHandler: apiHandler, - srvManager: svrMgr, - shutdownFuncs: shutdownFuncs, - stopChan: make(chan struct{}), - closeOnce: sync.Once{}, - } - r.telemetryConsent.Store(opts.TelemetryConsent) - events.Subscribe(func(evt config.NewConfigEvent) { - if r.telemetryConsent.Load() { - slog.Info("Telemetry consent given; handling new config for telemetry") - if err := telemetry.OnNewConfig(evt.Old, evt.New, platformDeviceID); err != nil { - slog.Error("Failed to handle new config for telemetry", "error", err) - } - } else { - slog.Info("Telemetry consent not given; skipping telemetry initialization") - } - }) - r.confHandler = config.NewConfigHandler(cOpts) - // Register AFTER NewConfigHandler so the disk-load event is already - // consumed. Runs whenever a new config is applied to provide continuous - // bandit callback data even when the VPN tunnel is not active. - sub := events.Subscribe(func(evt config.NewConfigEvent) { - vpn.RunURLTests(dataDir) - }) - r.addShutdownFunc(telemetry.Close, kindling.Close, func(_ context.Context) error { - sub.Unsubscribe() - return nil - }) - return r, nil -} - -// addShutdownFunc adds a shutdown function(s) to the Radiance instance. -// This function is called when the Radiance instance is closed to ensure that all -// resources are cleaned up properly. -func (r *Radiance) addShutdownFunc(fns ...func(context.Context) error) { - for _, fn := range fns { - if fn != nil { - r.shutdownFuncs = append(r.shutdownFuncs, fn) - } - } -} - -func (r *Radiance) Close() { - r.closeOnce.Do(func() { - slog.Debug("Closing Radiance") - r.confHandler.Stop() - close(r.stopChan) - for _, shutdown := range r.shutdownFuncs { - if err := shutdown(context.Background()); err != nil { - slog.Error("Failed to shutdown", "error", err) - } - } - }) - <-r.stopChan -} - -// APIHandler returns the API handler for the Radiance client. -func (r *Radiance) APIHandler() *api.APIClient { - return r.apiHandler -} - -// SetPreferredServer sets the preferred server location for the VPN connection. -// pass empty strings to auto select the server location -func (r *Radiance) SetPreferredServer(ctx context.Context, country, city string) { - r.confHandler.SetPreferredServerLocation(country, city) -} - -// ServerManager returns the server manager for the Radiance client. -func (r *Radiance) ServerManager() *servers.Manager { - return r.srvManager -} - -type IssueReport = issue.IssueReport - -// ReportIssue submits an issue report to the back-end with an optional user email -func (r *Radiance) ReportIssue(email string, report IssueReport) error { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "report_issue") - defer span.End() - if report.Type == "" && report.Description == "" { - return fmt.Errorf("issue report should contain at least type or description") - } - var country string - // get country from the config returned by the backend - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Warn("Failed to get config", "error", err) - } else { - country = cfg.ConfigResponse.Country - } - - err = r.issueReporter.Report(ctx, report, email, country) - if err != nil { - slog.Error("Failed to report issue", "error", err) - return traces.RecordError(ctx, fmt.Errorf("failed to report issue: %w", err)) - } - slog.Info("Issue reported successfully") - return nil -} - -// Features returns the features available in the current configuration, returned from the server in the -// config response. -func (r *Radiance) Features() map[string]bool { - _, span := otel.Tracer(tracerName).Start(context.Background(), "features") - defer span.End() - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Info("Failed to get config for features", "error", err) - return map[string]bool{} - } - if cfg == nil { - slog.Info("No config available for features, returning empty map") - return map[string]bool{} - } - slog.Debug("Returning features from config", "features", cfg.ConfigResponse.Features) - // Return the features from the config - if cfg.ConfigResponse.Features == nil { - slog.Info("No features available in config, returning empty map") - return map[string]bool{} - } - return cfg.ConfigResponse.Features -} - -// EnableTelemetry enable OpenTelemetry instrumentation for the Radiance client. -// After enabling it, it should initialize telemetry again once a new config -// is available -func (r *Radiance) EnableTelemetry() { - slog.Info("Enabling telemetry") - r.telemetryConsent.Store(true) - // If a config is already available, initialize telemetry immediately instead of - // waiting for the next config change event. - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Warn("Failed to get config while enabling telemetry; telemetry will be initialized on next config update", "error", err) - return - } - if cfg == nil { - slog.Info("No config available while enabling telemetry; telemetry will be initialized on next config update") - return - } - cErr := telemetry.OnNewConfig(nil, cfg, settings.GetString(settings.DeviceIDKey)) - if cErr != nil { - slog.Warn("Failed to initialize telemetry on enabling", "error", cErr) - } -} - -// DisableTelemetry disables OpenTelemetry instrumentation for the Radiance client. -func (r *Radiance) DisableTelemetry() { - slog.Info("Disabling telemetry") - r.telemetryConsent.Store(false) - otel.SetTracerProvider(traceNoop.NewTracerProvider()) - otel.SetMeterProvider(noop.NewMeterProvider()) -} - -// ServerLocations returns the list of server locations where the user can connect to proxies. -func (r *Radiance) ServerLocations() ([]lcommon.ServerLocation, error) { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "server_locations") - defer span.End() - cfg, err := r.confHandler.GetConfig() - if err != nil { - slog.Error("Failed to get config for server locations", "error", err) - traces.RecordError(ctx, err, trace.WithStackTrace(true)) - return nil, fmt.Errorf("failed to get config: %w", err) - } - if cfg == nil { - slog.Info("No config available for server locations, returning error") - traces.RecordError(ctx, err, trace.WithStackTrace(true)) - return nil, fmt.Errorf("no config available") - } - slog.Debug("Returning server locations from config", "locations", cfg.ConfigResponse.Servers) - return cfg.ConfigResponse.Servers, nil -} - -type slogWriter struct { - *slog.Logger -} - -func (w *slogWriter) Write(p []byte) (n int, err error) { - // Convert the byte slice to a string and log it - w.Info(string(p)) - return len(p), nil -} - -// setUserConfig creates a new UserInfo object -func setUserConfig(deviceID, dataDir, locale string) { - if err := settings.Set(settings.DeviceIDKey, deviceID); err != nil { - slog.Error("failed to set device ID in settings", "error", err) - } - if err := settings.Set(settings.DataPathKey, dataDir); err != nil { - slog.Error("failed to set data path in settings", "error", err) - } - if err := settings.Set(settings.LocaleKey, locale); err != nil { - slog.Error("failed to set locale in settings", "error", err) - } - - events.SubscribeOnce(func(evt config.NewConfigEvent) { - if evt.New != nil && evt.New.ConfigResponse.Country != "" { - if err := settings.Set(settings.CountryCodeKey, evt.New.ConfigResponse.Country); err != nil { - slog.Error("failed to set country code in settings", "error", err) - } - slog.Info("Set country code from config response", "country_code", evt.New.ConfigResponse.Country) - } - }) -} diff --git a/radiance_mock_test.go b/radiance_mock_test.go deleted file mode 100644 index 1f455bae..00000000 --- a/radiance_mock_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/getlantern/radiance (interfaces: configHandler) -// -// Generated by this command: -// -// mockgen -destination=radiance_mock_test.go -package=radiance github.com/getlantern/radiance configHandler -// - -// Package radiance is a generated GoMock package. -package radiance - -import ( - context "context" - reflect "reflect" - - config "github.com/getlantern/common" - gomock "go.uber.org/mock/gomock" -) - -// MockconfigHandler is a mock of configHandler interface. -type MockconfigHandler struct { - ctrl *gomock.Controller - recorder *MockconfigHandlerMockRecorder - isgomock struct{} -} - -// MockconfigHandlerMockRecorder is the mock recorder for MockconfigHandler. -type MockconfigHandlerMockRecorder struct { - mock *MockconfigHandler -} - -// NewMockconfigHandler creates a new mock instance. -func NewMockconfigHandler(ctrl *gomock.Controller) *MockconfigHandler { - mock := &MockconfigHandler{ctrl: ctrl} - mock.recorder = &MockconfigHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockconfigHandler) EXPECT() *MockconfigHandlerMockRecorder { - return m.recorder -} - -// GetConfig mocks base method. -func (m *MockconfigHandler) GetConfig(ctx context.Context) (*config.ConfigResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetConfig", ctx) - ret0, _ := ret[0].(*config.ConfigResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetConfig indicates an expected call of GetConfig. -func (mr *MockconfigHandlerMockRecorder) GetConfig(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockconfigHandler)(nil).GetConfig), ctx) -} - -// ListAvailableServers mocks base method. -func (m *MockconfigHandler) ListAvailableServers(ctx context.Context) ([]config.ServerLocation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAvailableServers", ctx) - ret0, _ := ret[0].([]config.ServerLocation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListAvailableServers indicates an expected call of ListAvailableServers. -func (mr *MockconfigHandlerMockRecorder) ListAvailableServers(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAvailableServers", reflect.TypeOf((*MockconfigHandler)(nil).ListAvailableServers), ctx) -} - -// SetPreferredServerLocation mocks base method. -func (m *MockconfigHandler) SetPreferredServerLocation(country, city string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetPreferredServerLocation", country, city) -} - -// SetPreferredServerLocation indicates an expected call of SetPreferredServerLocation. -func (mr *MockconfigHandlerMockRecorder) SetPreferredServerLocation(country, city any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPreferredServerLocation", reflect.TypeOf((*MockconfigHandler)(nil).SetPreferredServerLocation), country, city) -} - -// Stop mocks base method. -func (m *MockconfigHandler) Stop() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Stop") -} - -// Stop indicates an expected call of Stop. -func (mr *MockconfigHandlerMockRecorder) Stop() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockconfigHandler)(nil).Stop)) -} diff --git a/radiance_test.go b/radiance_test.go deleted file mode 100644 index a153e9f3..00000000 --- a/radiance_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package radiance - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/getlantern/radiance/config" -) - -func TestNewRadiance(t *testing.T) { - t.Run("it should create a new Radiance instance successfully", func(t *testing.T) { - dir := t.TempDir() - r, err := NewRadiance(Options{ - DataDir: dir, - Locale: "en-US", - }) - assert.NoError(t, err) - r.Close() - - assert.NotNil(t, r) - assert.NotNil(t, r.confHandler) - assert.NotNil(t, r.stopChan) - assert.NotNil(t, r.issueReporter) - }) -} - -func TestReportIssue(t *testing.T) { - var tests = []struct { - name string - email string - report IssueReport - assert func(*testing.T, error) - }{ - { - name: "return error when missing type and description", - email: "", - report: IssueReport{}, - assert: func(t *testing.T, err error) { - assert.Error(t, err) - }, - }, - { - name: "return nil when issue report is valid", - email: "radiancetest@getlantern.org", - report: IssueReport{ - Type: "Application crashes", - Description: "internal test only", - Device: "test device", - Model: "a123", - }, - assert: func(t *testing.T, err error) { - assert.NoError(t, err) - }, - }, - { - name: "return nil when issue report is valid with empty email", - email: "", - report: IssueReport{ - Type: "Cannot sign in", - Description: "internal test only", - Device: "test device 2", - Model: "b456", - }, - assert: func(t *testing.T, err error) { - assert.NoError(t, err) - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &Radiance{ - issueReporter: &mockIssueReporter{}, - confHandler: &mockConfigHandler{}, - } - err := r.ReportIssue(tt.email, tt.report) - tt.assert(t, err) - }) - } -} - -type mockIssueReporter struct{} - -func (m *mockIssueReporter) Report(_ context.Context, _ IssueReport, _, _ string) error { return nil } - -type mockConfigHandler struct{} - -func (m *mockConfigHandler) Stop() {} - -func (m *mockConfigHandler) SetPreferredServerLocation(country string, city string) {} - -func (m *mockConfigHandler) GetConfig() (*config.Config, error) { - return &config.Config{}, nil -} - -func (m *mockConfigHandler) AddConfigListener(listener config.ListenerFunc) { - listener(&config.Config{}, &config.Config{}) -} diff --git a/servers/manager.go b/servers/manager.go index 404d83e2..0c6302f6 100644 --- a/servers/manager.go +++ b/servers/manager.go @@ -4,20 +4,20 @@ package servers import ( + "bytes" "context" "errors" "fmt" "io" "log/slog" - "maps" "net" "net/http" "net/url" "os" "path/filepath" "runtime" - "slices" "strconv" + "strings" "sync" "time" @@ -28,10 +28,10 @@ import ( C "github.com/getlantern/common" "github.com/getlantern/radiance/bypass" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/traces" "github.com/getlantern/pluriconfig" @@ -59,7 +59,7 @@ const ( saveCriticalThreshold = 15 * time.Second // readerWaitThreshold: log a WARN with a goroutine stack dump if a - // reader (ServersJSON / GetServerByTagJSON) waits longer than this to + // reader (AllServers / GetServerByTag) waits longer than this to // acquire the RLock. Direct evidence of reader starvation. readerWaitThreshold = 1 * time.Second ) @@ -75,14 +75,7 @@ func dumpAllGoroutines() string { return string(buf[:n]) } -type ServerGroup = string - -const ( - SGLantern ServerGroup = "lantern" - SGUser ServerGroup = "user" - - tracerName = "github.com/getlantern/radiance/servers" -) +const tracerName = "github.com/getlantern/radiance/servers" // ServerCredentials holds the access token and invite status for a private server. type ServerCredentials struct { @@ -91,89 +84,137 @@ type ServerCredentials struct { IsJoined bool `json:"is_joined,omitempty"` // whether the user has joined the server (i.e. accepted the invite) } -type Options struct { - Outbounds []option.Outbound `json:"outbounds,omitempty"` - Endpoints []option.Endpoint `json:"endpoints,omitempty"` - Locations map[string]C.ServerLocation `json:"locations,omitempty"` - URLOverrides map[string]string `json:"url_overrides,omitempty"` - Credentials map[string]ServerCredentials `json:"credentials,omitempty"` +type Server struct { + Tag string `json:"tag"` + Type string `json:"type"` + IsLantern bool `json:"isLantern"` + Options any `json:"options"` + Location C.ServerLocation `json:"location,omitempty"` + Credentials *ServerCredentials `json:"credentials,omitempty"` + URLTestResult *URLTestResult `json:"urlTestResult,omitempty"` } -// MarshalJSON encodes Options using the sing-box context so that type-specific outbound/endpoint -// options (server, port, password, etc.) are included in the output. -func (o Options) MarshalJSON() ([]byte, error) { - type Alias Options - return json.MarshalContext(box.BaseContext(), Alias(o)) +// serverJSON is the on-wire representation of a Server. The Options field is split into +// explicit Outbound/Endpoint fields so that the sing-box context-aware JSON marshaler can +// properly serialize/deserialize the typed options (e.g. SamizdatOutboundOptions). +type serverJSON struct { + Tag string `json:"tag"` + Type string `json:"type"` + IsLantern bool `json:"isLantern"` + Outbound *option.Outbound `json:"outbound,omitempty"` + Endpoint *option.Endpoint `json:"endpoint,omitempty"` + Location C.ServerLocation `json:"location,omitempty"` + Credentials *ServerCredentials `json:"credentials,omitempty"` + URLTestResult *URLTestResult `json:"urlTestResult,omitempty"` } -// AllTags returns a slice of all tags from both endpoints and outbounds in the Options. -func (o Options) AllTags() []string { - tags := make([]string, 0, len(o.Outbounds)+len(o.Endpoints)) - for _, ep := range o.Endpoints { - tags = append(tags, ep.Tag) +func (s Server) MarshalJSON() ([]byte, error) { + sj := serverJSON{ + Tag: s.Tag, + Type: s.Type, + IsLantern: s.IsLantern, + Location: s.Location, + Credentials: s.Credentials, + URLTestResult: s.URLTestResult, } - for _, out := range o.Outbounds { - tags = append(tags, out.Tag) + switch opts := s.Options.(type) { + case option.Outbound: + sj.Outbound = &opts + case option.Endpoint: + sj.Endpoint = &opts + } + return json.MarshalContext(box.BaseContext(), sj) +} + +func (s *Server) UnmarshalJSON(data []byte) error { + sj, err := json.UnmarshalExtendedContext[serverJSON](box.BaseContext(), data) + if err != nil { + return err + } + s.Tag = sj.Tag + s.Type = sj.Type + s.IsLantern = sj.IsLantern + s.Location = sj.Location + s.Credentials = sj.Credentials + s.URLTestResult = sj.URLTestResult + if sj.Outbound != nil { + s.Options = *sj.Outbound + } else if sj.Endpoint != nil { + s.Options = *sj.Endpoint + } + return nil +} + +// ServerList is a batch of servers with optional URL overrides for bulk operations. +type ServerList struct { + Servers []*Server `json:"servers"` + URLOverrides map[string]string `json:"url_overrides,omitempty"` +} + +func (sl ServerList) Tags() []string { + tags := make([]string, 0, len(sl.Servers)) + for _, s := range sl.Servers { + tags = append(tags, s.Tag) } return tags } -type Servers map[ServerGroup]Options +func (sl ServerList) Outbounds() []option.Outbound { + var out []option.Outbound + for _, s := range sl.Servers { + if o, ok := s.Options.(option.Outbound); ok { + out = append(out, o) + } + } + return out +} + +func (sl ServerList) Endpoints() []option.Endpoint { + var eps []option.Endpoint + for _, s := range sl.Servers { + if e, ok := s.Options.(option.Endpoint); ok { + eps = append(eps, e) + } + } + return eps +} // Manager manages server configurations, including endpoints and outbounds. type Manager struct { - access sync.RWMutex - servers Servers - optsMaps map[ServerGroup]map[string]any // map of tag to option for quick access + access sync.RWMutex + servers map[string]*Server // tag -> Server // saveMu serializes disk writes in saveServers. This is separate from access - // so that readers (e.g. ServersJSON) aren't blocked during disk I/O — only + // so that readers (e.g. AllServers) aren't blocked during disk I/O — only // during the brief JSON marshalling step. saveMu sync.Mutex + logger *slog.Logger serversFile string httpClient *http.Client } // NewManager creates a new Manager instance, loading server options from disk. -func NewManager(dataPath string) (*Manager, error) { +func NewManager(dataPath string, logger *slog.Logger) (*Manager, error) { mgr := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - access: sync.RWMutex{}, - + servers: make(map[string]*Server), + serversFile: filepath.Join(dataPath, internal.ServersFileName), + logger: logger, // Use the bypass proxy dialer to route requests outside the VPN tunnel. // This client is only used to access private servers the user has created. - httpClient: retryableHTTPClient().StandardClient(), + httpClient: retryableHTTPClient(logger).StandardClient(), } - slog.Debug("Loading servers", "file", mgr.serversFile) + mgr.logger.Debug("Loading servers", "file", mgr.serversFile) if err := mgr.loadServers(); err != nil { - slog.Error("Failed to load servers", "file", mgr.serversFile, "error", err) + mgr.logger.Error("Failed to load servers", "file", mgr.serversFile, "error", err) return nil, fmt.Errorf("failed to load servers from file: %w", err) } - slog.Log(nil, internal.LevelTrace, "Loaded servers", "servers", mgr.servers) + mgr.logger.Log(nil, log.LevelTrace, "Loaded servers") return mgr, nil } -func retryableHTTPClient() *retryablehttp.Client { +func retryableHTTPClient(logger *slog.Logger) *retryablehttp.Client { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: bypass.DialContext, @@ -192,95 +233,61 @@ func retryableHTTPClient() *retryablehttp.Client { client.RetryMax = 10 client.RetryWaitMin = 1 * time.Second client.RetryWaitMax = 10 * time.Second + client.Logger = logger return client } -// Servers returns the current server configurations for both groups ([SGLantern] and [SGUser]). -func (m *Manager) Servers() Servers { +// AllServers returns a deep-copied slice of all servers. +func (m *Manager) AllServers() []*Server { + start := time.Now() m.access.RLock() + wait := time.Since(start) defer m.access.RUnlock() - - result := make(Servers, len(m.servers)) - for group, opts := range m.servers { - result[group] = Options{ - Outbounds: append([]option.Outbound{}, opts.Outbounds...), - Endpoints: append([]option.Endpoint{}, opts.Endpoints...), - Locations: maps.Clone(opts.Locations), - URLOverrides: maps.Clone(opts.URLOverrides), - Credentials: maps.Clone(opts.Credentials), - } + warnIfReaderStarved("AllServers", wait) + result := make([]*Server, 0, len(m.servers)) + for _, srv := range m.servers { + cp := *srv + result = append(result, &cp) } return result } -type Server struct { - Group ServerGroup - Tag string - Type string - Options any // will be either [option.Endpoint] or [option.Outbound] - Location C.ServerLocation -} - -// GetServerByTag returns the server configuration for a given tag and a boolean indicating whether -// the server was found. -func (m *Manager) GetServerByTag(tag string) (Server, bool) { - m.access.RLock() - defer m.access.RUnlock() - return m.getServerByTagLocked(tag) +// URLTestResult holds the result of a single URL test. +type URLTestResult struct { + Delay uint16 `json:"delay"` + Time time.Time `json:"time"` } -// getServerByTagLocked performs the tag lookup. Caller must hold access.RLock. -func (m *Manager) getServerByTagLocked(tag string) (Server, bool) { - group := SGLantern - opts, ok := m.optsMaps[SGLantern][tag] - if !ok { - if opts, ok = m.optsMaps[SGUser][tag]; !ok { - return Server{}, false +// UpdateURLTestResults updates the URL test results for servers matching the +// provided tags and persists the change to disk. +func (m *Manager) UpdateURLTestResults(results map[string]URLTestResult) error { + func() { + m.access.Lock() + defer m.access.Unlock() + for tag, result := range results { + if srv, exists := m.servers[tag]; exists { + r := result + srv.URLTestResult = &r + } } - group = SGUser - } - s := Server{ - Group: group, - Tag: tag, - Options: opts, - Location: m.servers[group].Locations[tag], - } - switch v := opts.(type) { - case option.Endpoint: - s.Type = v.Type - case option.Outbound: - s.Type = v.Type - } - return s, true -} - -// ServersJSON returns the current server configurations as pre-marshalled JSON. -func (m *Manager) ServersJSON() ([]byte, error) { - start := time.Now() - m.access.RLock() - wait := time.Since(start) - defer m.access.RUnlock() - warnIfReaderStarved("ServersJSON", wait) - return json.MarshalContext(box.BaseContext(), m.servers) + }() + return m.saveServers() } -// GetServerByTagJSON returns the server configuration for a given tag as pre-marshalled JSON. -func (m *Manager) GetServerByTagJSON(tag string) ([]byte, bool, error) { +// GetServerByTag returns the server configuration for a given tag and a boolean indicating whether +// the server was found. +func (m *Manager) GetServerByTag(tag string) (*Server, bool) { start := time.Now() m.access.RLock() wait := time.Since(start) defer m.access.RUnlock() - warnIfReaderStarved("GetServerByTagJSON", wait) - - s, ok := m.getServerByTagLocked(tag) - if !ok { - return nil, false, nil - } - b, err := json.MarshalContext(box.BaseContext(), s) - if err != nil { - return nil, false, fmt.Errorf("marshal server %q: %w", tag, err) + warnIfReaderStarved("GetServerByTag", wait) + s, exists := m.servers[tag] + if !exists { + return nil, false } - return b, true, nil + cp := *s + return &cp, true } // warnIfReaderStarved logs a WARN with a goroutine stack dump when a reader @@ -297,209 +304,89 @@ func warnIfReaderStarved(caller string, wait time.Duration) { ) } -type ServersUpdatedEvent struct { - events.Event - Group ServerGroup - Options *Options -} - -type ServersAddedEvent struct { - events.Event - Group ServerGroup - Options *Options -} - -type ServersRemovedEvent struct { - events.Event - Group ServerGroup - Tag string -} - -// SetServers sets the server options for a specific group. -// Important: this will overwrite any existing servers for that group. To add new servers without -// overwriting existing ones, use [AddServers] instead. -func (m *Manager) SetServers(group ServerGroup, options Options) error { - if err := m.setServers(group, options); err != nil { - return fmt.Errorf("set servers: %w", err) - } +// SetServers sets the server options for servers with a matching IsLantern value. +// Important: this will overwrite any existing servers with the same IsLantern value. To add new +// servers without overwriting existing ones, use [AddServers] instead. +func (m *Manager) SetServers(list ServerList, isLantern bool) error { + func() { + m.access.Lock() + defer m.access.Unlock() + // Remove existing with matching IsLantern + for tag, srv := range m.servers { + if srv.IsLantern == isLantern { + delete(m.servers, tag) + } + } + // Add new + for _, srv := range list.Servers { + srv.IsLantern = isLantern + m.servers[srv.Tag] = srv + } + }() // saveServers acquires its own locks; don't hold the write lock across it. - if err := m.saveServers(); err != nil { - return fmt.Errorf("failed to save servers: %w", err) - } - events.Emit(ServersUpdatedEvent{ - Group: group, - Options: &options, - }) - return nil + return m.saveServers() } -func (m *Manager) setServers(group ServerGroup, options Options) error { - switch group { - case SGLantern, SGUser: - default: - return fmt.Errorf("invalid server group: %s", group) - } - - m.access.Lock() - defer m.access.Unlock() - - slog.Log(nil, internal.LevelTrace, "Setting servers", "group", group, "options", options) - opts := Options{ - Outbounds: append([]option.Outbound{}, options.Outbounds...), - Endpoints: append([]option.Endpoint{}, options.Endpoints...), - Locations: make(map[string]C.ServerLocation, len(options.Locations)), - URLOverrides: maps.Clone(options.URLOverrides), - Credentials: make(map[string]ServerCredentials, len(options.Credentials)), - } - maps.Copy(opts.Locations, options.Locations) - maps.Copy(opts.Credentials, options.Credentials) - - m.servers[group] = opts - oMap := make(map[string]any, len(options.Endpoints)+len(options.Outbounds)) - for _, ep := range options.Endpoints { - oMap[ep.Tag] = ep - } - for _, out := range options.Outbounds { - oMap[out.Tag] = out - } - m.optsMaps[group] = oMap - return nil -} - -// AddServers adds new servers to the specified group. If a server with the same tag already exists, -// it will be skipped. -func (m *Manager) AddServers(group ServerGroup, opts Options) error { - switch group { - case SGLantern, SGUser: - default: - return fmt.Errorf("invalid server group: %s", group) +// AddServers adds new servers. If force is true, it will overwrite any +// existing servers with the same tags. If force is false, it returns an error +// if any of the tags already exist. +func (m *Manager) AddServers(list ServerList, force bool) error { + if len(list.Servers) == 0 { + return nil } // Perform the in-memory mutation under the write lock, then release it // before saving to disk (saveServers acquires its own locks). Scoped // in a closure so defer Unlock is robust against future early returns. - existingTags := func() []string { + if err := func() error { m.access.Lock() defer m.access.Unlock() - slog.Log(nil, internal.LevelTrace, "Adding servers", "group", group, "options", opts) - return m.merge(group, opts) - }() - - if len(existingTags) > 0 { - slog.Warn("Some servers were not added because they already exist", "tags", existingTags) - } - if err := m.saveServers(); err != nil { - return fmt.Errorf("failed to save servers: %w", err) - } - if len(existingTags) > 0 { - slog.Warn("Tried to add some servers that already exist", "tags", existingTags) - return fmt.Errorf("some servers were not added because they already exist: %v", existingTags) - } - slog.Debug("Server configs added", "group", group, "newCount", len(opts.AllTags())) - events.Emit(ServersAddedEvent{ - Group: group, - Options: &opts, - }) - return nil -} - -// merge adds new endpoints and outbounds to the specified group, skipping any that already exist. -// It returns the tags that were skipped. -func (m *Manager) merge(group ServerGroup, options Options) []string { - if len(options.Endpoints) == 0 && len(options.Outbounds) == 0 { - return nil - } - var existingTags []string - opts := m.optsMaps[group] - servers := m.servers[group] - for _, ep := range options.Endpoints { - if _, exists := opts[ep.Tag]; exists { - existingTags = append(existingTags, ep.Tag) - continue - } - opts[ep.Tag] = ep - servers.Endpoints = append(servers.Endpoints, ep) - servers.Locations[ep.Tag] = options.Locations[ep.Tag] - if creds, ok := options.Credentials[ep.Tag]; ok { - servers.Credentials[ep.Tag] = creds - } - } - for _, out := range options.Outbounds { - if _, exists := opts[out.Tag]; exists { - existingTags = append(existingTags, out.Tag) - continue - } - opts[out.Tag] = out - servers.Outbounds = append(servers.Outbounds, out) - servers.Locations[out.Tag] = options.Locations[out.Tag] - if creds, ok := options.Credentials[out.Tag]; ok { - servers.Credentials[out.Tag] = creds + if !force { + for _, srv := range list.Servers { + if _, exists := m.servers[srv.Tag]; exists { + return fmt.Errorf("server %q already exists", srv.Tag) + } + } } - } - for k, v := range options.URLOverrides { - if servers.URLOverrides == nil { - servers.URLOverrides = make(map[string]string) + for _, srv := range list.Servers { + m.servers[srv.Tag] = srv } - servers.URLOverrides[k] = v + return nil + }(); err != nil { + return err } - m.servers[group] = servers - return existingTags + // saveServers acquires its own locks; don't hold the write lock across it. + return m.saveServers() } // RemoveServer removes a server config by its tag. func (m *Manager) RemoveServer(tag string) error { + _, err := m.RemoveServers([]string{tag}) + return err +} + +// RemoveServers removes multiple server configs by their tags and returns the removed servers. +func (m *Manager) RemoveServers(tags []string) ([]*Server, error) { // Perform the in-memory mutation under the write lock, then release it // before saving to disk (saveServers acquires its own locks). Scoped in // a closure so defer Unlock is robust against future early returns. - group, err := func() (ServerGroup, error) { + removed := func() []*Server { m.access.Lock() defer m.access.Unlock() - slog.Log(nil, internal.LevelTrace, "Removing server", "tag", tag) - // check which group the server belongs to - g := SGLantern - if _, exists := m.optsMaps[g][tag]; !exists { - g = SGUser - if _, exists := m.optsMaps[g][tag]; !exists { - return "", fmt.Errorf("server with tag %q not found", tag) + r := make([]*Server, 0, len(tags)) + for _, tag := range tags { + if srv, exists := m.servers[tag]; exists { + r = append(r, srv) + delete(m.servers, tag) } } - // remove the server from the optsMaps and servers - servers := m.servers[g] - switch v := m.optsMaps[g][tag].(type) { - case option.Endpoint: - servers.Endpoints = remove(servers.Endpoints, v) - case option.Outbound: - servers.Outbounds = remove(servers.Outbounds, v) - } - delete(m.optsMaps[g], tag) - delete(servers.Locations, tag) - delete(servers.Credentials, tag) - m.servers[g] = servers - return g, nil + return r }() - if err != nil { - slog.Warn("Tried to remove non-existent server", "tag", tag) - return err - } - + // saveServers acquires its own locks; don't hold the write lock across it. if err := m.saveServers(); err != nil { - return fmt.Errorf("failed to save servers after removing %q: %w", tag, err) + return nil, fmt.Errorf("failed to save servers: %w", err) } - slog.Debug("Server config removed", "group", group, "tag", tag) - events.Emit(ServersRemovedEvent{ - Group: group, - Tag: tag, - }) - return nil -} - -func remove[T comparable](slice []T, item T) []T { - i := slices.Index(slice, item) - if i == -1 { - return slice - } - slice[i] = slice[len(slice)-1] - return slice[:len(slice)-1] + return removed, nil } // saveServers marshals the current server state to JSON and writes it to disk. @@ -507,7 +394,7 @@ func remove[T comparable](slice []T, item T) []T { // The access write lock is NOT held across this function; only a brief RLock // around marshalling. saveMu serializes the full marshal+write sequence so // concurrent callers can't reorder and overwrite a newer snapshot with an -// older one. Readers (e.g. ServersJSON) are not blocked by the disk write — +// older one. Readers (e.g. AllServers) are not blocked by the disk write — // only by the brief marshal window (see getlantern/engineering#3176). // // Each phase (saveMu wait, RLock+marshal, disk write) is timed so we can @@ -526,8 +413,11 @@ func (m *Manager) saveServers() error { marshalStart := time.Now() m.access.RLock() rlockWait := time.Since(marshalStart) - ctx := box.BaseContext() - buf, err := json.MarshalContext(ctx, m.servers) + servers := make([]*Server, 0, len(m.servers)) + for _, srv := range m.servers { + servers = append(servers, srv) + } + buf, err := json.MarshalContext(box.BaseContext(), servers) m.access.RUnlock() marshalDur := time.Since(marshalStart) - rlockWait if err != nil { @@ -535,11 +425,11 @@ func (m *Manager) saveServers() error { } writeStart := time.Now() - werr := atomicfile.WriteFile(m.serversFile, buf, 0644) + werr := atomicfile.WriteFile(m.serversFile, buf, fileperm.File) writeDur := time.Since(writeStart) total := time.Since(start) - slog.Log(nil, internal.LevelTrace, "saveServers timing", + slog.Log(nil, log.LevelTrace, "saveServers timing", "file", m.serversFile, "size", len(buf), "total_ms", total.Milliseconds(), @@ -573,6 +463,11 @@ func (m *Manager) saveServers() error { return werr } +const ( + modeLantern = "lantern" + modeUser = "user" +) + func (m *Manager) loadServers() error { buf, err := atomicfile.ReadFile(m.serversFile) if errors.Is(err, os.ErrNotExist) { @@ -581,19 +476,62 @@ func (m *Manager) loadServers() error { if err != nil { return fmt.Errorf("read server file %q: %w", m.serversFile, err) } - servers, err := json.UnmarshalExtendedContext[Servers](box.BaseContext(), buf) + buf = bytes.TrimSpace(buf) + ctx := box.BaseContext() + + if len(buf) > 0 && buf[0] == '[' { + loaded, err := json.UnmarshalExtendedContext[[]*Server](ctx, buf) + if err != nil { + return fmt.Errorf("unmarshal servers: %w", err) + } + for _, srv := range loaded { + m.servers[srv.Tag] = srv + } + return nil + } + + // Fall back to old format: map[string]Options and mirgrate to new format on save. + type oldOptions struct { + Outbounds []option.Outbound `json:"outbounds,omitempty"` + Endpoints []option.Endpoint `json:"endpoints,omitempty"` + Locations map[string]C.ServerLocation `json:"locations,omitempty"` + Credentials map[string]ServerCredentials `json:"credentials,omitempty"` + } + old, err := json.UnmarshalExtendedContext[map[string]oldOptions](ctx, buf) if err != nil { return fmt.Errorf("unmarshal server options: %w", err) } - m.setServers(SGLantern, servers[SGLantern]) - m.setServers(SGUser, servers[SGUser]) - return nil + for group, opts := range old { + isLantern := group == modeLantern + for _, out := range opts.Outbounds { + srv := &Server{ + Tag: out.Tag, Type: out.Type, IsLantern: isLantern, + Options: out, Location: opts.Locations[out.Tag], + } + if creds, ok := opts.Credentials[out.Tag]; ok { + srv.Credentials = &creds + } + m.servers[out.Tag] = srv + } + for _, ep := range opts.Endpoints { + srv := &Server{ + Tag: ep.Tag, Type: ep.Type, IsLantern: isLantern, + Options: ep, Location: opts.Locations[ep.Tag], + } + if creds, ok := opts.Credentials[ep.Tag]; ok { + srv.Credentials = &creds + } + m.servers[ep.Tag] = srv + } + } + // Re-save in new format + return m.saveServers() } // Lantern Server Manager Integration // AddPrivateServer fetches VPN connection info from a remote server manager and adds it as a server. -func (m *Manager) AddPrivateServer(tag string, ip string, port int, accessToken string, serverLocation *C.ServerLocation, isJoined bool) error { +func (m *Manager) AddPrivateServer(tag, ip string, port int, accessToken string, loc C.ServerLocation, joined bool) error { u := &url.URL{ Scheme: "https", Host: net.JoinHostPort(ip, strconv.Itoa(port)), @@ -615,28 +553,34 @@ func (m *Manager) AddPrivateServer(tag string, ip string, port int, accessToken return fmt.Errorf("failed to read response body: %w", err) } + type remoteConfig struct { + Outbounds []option.Outbound `json:"outbounds,omitempty"` + Endpoints []option.Endpoint `json:"endpoints,omitempty"` + } ctx := box.BaseContext() - servers, err := json.UnmarshalExtendedContext[Options](ctx, body) + cfg, err := json.UnmarshalExtendedContext[remoteConfig](ctx, body) if err != nil { return fmt.Errorf("decode response: %w", err) } - if len(servers.Endpoints) == 0 && len(servers.Outbounds) == 0 { + if len(cfg.Endpoints) == 0 && len(cfg.Outbounds) == 0 { return fmt.Errorf("no endpoints or outbounds in response") } - servers.Outbounds[0].Tag = tag - // If the server location is provided, set it for the server's tag. - if serverLocation != nil { - servers.Locations = map[string]C.ServerLocation{ - tag: *serverLocation, - } - } - // Store the credentials for the server's tag. - servers.Credentials = map[string]ServerCredentials{ - tag: {AccessToken: accessToken, Port: port, IsJoined: isJoined}, + // TODO: update when we support endpoints + cfg.Outbounds[0].Tag = tag + srv := &Server{ + Tag: tag, + Type: cfg.Outbounds[0].Type, + IsLantern: false, + Options: cfg.Outbounds[0], + Location: loc, + Credentials: &ServerCredentials{ + AccessToken: accessToken, Port: port, IsJoined: joined, + }, } - slog.Info("Adding private server from remote manager", "tag", tag, "ip", ip, "port", port, "location", serverLocation, "is_joined", isJoined) - return m.AddServers(SGUser, servers) + slog.Info("Adding private server from remote manager", "tag", tag, "ip", ip, "port", port, "location", loc, "is_joined", joined) + list := ServerList{Servers: []*Server{srv}} + return m.AddServers(list, false) } // InviteToPrivateServer invites another user to the server manager instance and returns a connection @@ -680,55 +624,60 @@ func (m *Manager) RevokePrivateServerInvite(ip string, port int, accessToken str return nil } -// AddServerWithSingboxJSON parse a value that can be a JSON sing-box config. -// It parses the config into a sing-box config and add it to the user managed group. -func (m *Manager) AddServerWithSingboxJSON(ctx context.Context, value []byte) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerWithSingboxJSON") +// AddServersByJSON adds any outbounds and endpoints defined in the provided sing-box JSON config. +func (m *Manager) AddServersByJSON(ctx context.Context, config []byte) ([]string, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerBySingboxJSON") defer span.End() - var opts Options - if err := json.UnmarshalContext(box.BaseContext(), value, &opts); err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to parse config: %w", err)) + type singboxConfig struct { + Outbounds []option.Outbound `json:"outbounds,omitempty"` + Endpoints []option.Endpoint `json:"endpoints,omitempty"` } - if len(opts.Endpoints) == 0 && len(opts.Outbounds) == 0 { - return traces.RecordError(ctx, fmt.Errorf("no endpoints or outbounds found in the provided configuration")) + cfg, err := json.UnmarshalExtendedContext[singboxConfig](box.BaseContext(), config) + if err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to parse config: %w", err)) } - if err := m.AddServers(SGUser, opts); err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to add servers: %w", err)) + if len(cfg.Endpoints) == 0 && len(cfg.Outbounds) == 0 { + return nil, traces.RecordError(ctx, fmt.Errorf("no endpoints or outbounds found in the provided configuration")) } - return nil + servers := make([]*Server, 0, len(cfg.Outbounds)+len(cfg.Endpoints)) + tags := make([]string, 0, len(cfg.Outbounds)+len(cfg.Endpoints)) + for _, out := range cfg.Outbounds { + if out.Tag == "" { + return nil, traces.RecordError(ctx, fmt.Errorf("outbound missing tag")) + } + servers = append(servers, &Server{Tag: out.Tag, Type: out.Type, Options: out}) + tags = append(tags, out.Tag) + } + for _, ep := range cfg.Endpoints { + if ep.Tag == "" { + return nil, traces.RecordError(ctx, fmt.Errorf("endpoint missing tag")) + } + servers = append(servers, &Server{Tag: ep.Tag, Type: ep.Type, Options: ep}) + tags = append(tags, ep.Tag) + } + if err := m.AddServers(ServerList{Servers: servers}, false); err != nil { + return nil, traces.RecordError(ctx, fmt.Errorf("failed to add servers: %w", err)) + } + return tags, nil } -// AddServerBasedOnURLs adds a server(s) based on the provided URL string. -// The URL can be a comma-separated list of URLs, URLs separated by new lines, or a single URL. -// Note that the UI allows the user to specify a server name. If there is only one URL, the server name overrides -// the tag typically included in the URL. If there are multiple URLs, the server name is ignored. -func (m *Manager) AddServerBasedOnURLs(ctx context.Context, urls string, skipCertVerification bool, serverName string) error { - ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerBasedOnURLs") +// AddServersByURL adds a server(s) by downloading and parsing the config from a list of URLs. +func (m *Manager) AddServersByURL(ctx context.Context, urls []string, skipCertVerification bool) ([]string, error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "Manager.AddServerByURLs") defer span.End() urlProvider, loaded := pluriconfig.GetProvider(string(model.ProviderURL)) if !loaded { - return traces.RecordError(ctx, fmt.Errorf("URL config provider not loaded")) + return nil, traces.RecordError(ctx, fmt.Errorf("URL config provider not loaded")) } - cfg, err := urlProvider.Parse(ctx, []byte(urls)) + cfg, err := urlProvider.Parse(ctx, []byte(strings.Join(urls, "\n"))) if err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to parse URLs: %w", err)) + return nil, traces.RecordError(ctx, fmt.Errorf("failed to parse URLs: %w", err)) } cfgURLs, ok := cfg.Options.([]url.URL) if !ok || len(cfgURLs) == 0 { - return traces.RecordError(ctx, fmt.Errorf("no valid URLs found in the provided configuration")) + return nil, traces.RecordError(ctx, fmt.Errorf("no valid URLs found in the provided configuration")) } - // If we only have a single URL, and the server name is specified, use that - // to override the tag specified in the anchor hash fragment. - if len(cfgURLs) == 1 && serverName != "" { - // override the tag, which is specified in the anchor hash fragment or - // in the tag query parameter. - q := cfgURLs[0].Query() - q.Del("tag") - cfgURLs[0].Fragment = serverName - cfgURLs[0].RawQuery = q.Encode() - cfg.Options = cfgURLs - } if skipCertVerification { urlsWithCustomOptions := make([]url.URL, 0, len(cfgURLs)) for _, v := range cfgURLs { @@ -742,12 +691,12 @@ func (m *Manager) AddServerBasedOnURLs(ctx context.Context, urls string, skipCer singBoxProvider, loaded := pluriconfig.GetProvider(string(model.ProviderSingBox)) if !loaded { - return traces.RecordError(ctx, fmt.Errorf("singbox config provider not loaded")) + return nil, traces.RecordError(ctx, fmt.Errorf("singbox config provider not loaded")) } singBoxCfg, err := singBoxProvider.Serialize(ctx, cfg) if err != nil { - return traces.RecordError(ctx, fmt.Errorf("failed to serialize sing-box config: %w", err)) + return nil, traces.RecordError(ctx, fmt.Errorf("failed to serialize sing-box config: %w", err)) } - slog.Info("Adding servers based on URLs", "serverCount", len(cfgURLs), "skipCertVerification", skipCertVerification, "serverName", serverName) - return m.AddServerWithSingboxJSON(ctx, singBoxCfg) + m.logger.Info("Added servers based on URLs", "serverCount", len(cfgURLs), "skipCertVerification", skipCertVerification) + return m.AddServersByJSON(ctx, singBoxCfg) } diff --git a/servers/manager_test.go b/servers/manager_test.go index b94bd36b..6a6ce0e6 100644 --- a/servers/manager_test.go +++ b/servers/manager_test.go @@ -1,150 +1,35 @@ package servers import ( - "context" "crypto/tls" - "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" - "os" "path/filepath" "strconv" "strings" "testing" C "github.com/getlantern/common" + box "github.com/getlantern/lantern-box" + + _ "github.com/getlantern/radiance/common" + "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/common" ) -func newTestManager(t *testing.T) *Manager { - t.Helper() - dataPath := t.TempDir() - mgr := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: []option.Outbound{ - {Tag: "ss-denver", Type: "shadowsocks", Options: &option.ShadowsocksOutboundOptions{ - ServerOptions: option.ServerOptions{ - Server: "1.2.3.4", - ServerPort: 1080, - }, - Method: "chacha20-ietf-poly1305", - Password: "testpass", - }}, - }, - Endpoints: make([]option.Endpoint, 0), - Locations: map[string]C.ServerLocation{ - "ss-denver": {Country: "US", City: "Denver", CountryCode: "US"}, - }, - Credentials: make(map[string]ServerCredentials), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: {"ss-denver": option.Outbound{Tag: "ss-denver", Type: "shadowsocks", Options: &option.ShadowsocksOutboundOptions{ - ServerOptions: option.ServerOptions{Server: "1.2.3.4", ServerPort: 1080}, - Method: "chacha20-ietf-poly1305", - Password: "testpass", - }}}, - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - return mgr -} - -func TestServersJSON(t *testing.T) { - mgr := newTestManager(t) - - b, err := mgr.ServersJSON() - require.NoError(t, err) - require.NotEmpty(t, b) - - // Must be valid JSON - var raw map[string]json.RawMessage - require.NoError(t, json.Unmarshal(b, &raw), "ServersJSON must return valid JSON") - assert.Contains(t, raw, "lantern") - assert.Contains(t, raw, "user") - - // Lantern group must include the sing-box type-specific fields - lanternJSON := string(raw["lantern"]) - assert.Contains(t, lanternJSON, "shadowsocks", "should contain outbound type") - assert.Contains(t, lanternJSON, "1.2.3.4", "should contain server address") - assert.Contains(t, lanternJSON, "1080", "should contain server port") - assert.Contains(t, lanternJSON, "chacha20-ietf-poly1305", "should contain method") -} - -func TestGetServerByTagJSON(t *testing.T) { - mgr := newTestManager(t) - - t.Run("existing tag", func(t *testing.T) { - b, ok, err := mgr.GetServerByTagJSON("ss-denver") - require.NoError(t, err) - require.True(t, ok) - require.NotEmpty(t, b) - - // Must be valid JSON - var raw map[string]json.RawMessage - require.NoError(t, json.Unmarshal(b, &raw), "GetServerByTagJSON must return valid JSON") - assert.Contains(t, raw, "Tag") - assert.Contains(t, raw, "Type") - assert.Contains(t, raw, "Options") - assert.Contains(t, raw, "Location") - - // Verify the correct tag and type - fullJSON := string(b) - assert.Contains(t, fullJSON, "ss-denver") - assert.Contains(t, fullJSON, "shadowsocks") - assert.Contains(t, fullJSON, "Denver") - }) - - t.Run("missing tag", func(t *testing.T) { - b, ok, err := mgr.GetServerByTagJSON("nonexistent") - assert.NoError(t, err) - assert.False(t, ok) - assert.Nil(t, b) - }) -} - func TestPrivateServerIntegration(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - Credentials: make(map[string]ServerCredentials), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - httpClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, + manager := testManager(t) + manager.httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, }, }, } @@ -155,24 +40,27 @@ func TestPrivateServerIntegration(t *testing.T) { port, _ := strconv.Atoi(parsedURL.Port()) t.Run("convert a token into a custom server", func(t *testing.T) { - require.NoError(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, "rootToken", nil, false)) - require.Contains(t, manager.optsMaps[SGUser], "s1", "server should be added to the manager") + require.NoError(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, "rootToken", C.ServerLocation{}, false)) + _, exists := manager.servers["s1"] + require.True(t, exists, "server should be added to the manager") }) - t.Run("invite a user", func(t *testing.T) { + t.Run("invite user", func(t *testing.T) { inviteToken, err := manager.InviteToPrivateServer(parsedURL.Hostname(), port, "rootToken", "invite1") assert.NoError(t, err) assert.NotEmpty(t, inviteToken) - require.NoError(t, manager.AddPrivateServer("s2", parsedURL.Hostname(), port, inviteToken, nil, true)) - require.Contains(t, manager.optsMaps[SGUser], "s2", "server should be added for the invited user") + require.NoError(t, manager.AddPrivateServer("s2", parsedURL.Hostname(), port, inviteToken, C.ServerLocation{}, true)) + _, exists := manager.servers["s2"] + require.True(t, exists, "server should be added for the invited user") t.Run("revoke user access", func(t *testing.T) { - delete(manager.optsMaps[SGUser], "s1") + delete(manager.servers, "s1") require.NoError(t, manager.RevokePrivateServerInvite(parsedURL.Hostname(), port, "rootToken", "invite1")) // trying to access again with the same token should fail - assert.Error(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, inviteToken, nil, true)) - assert.NotContains(t, manager.optsMaps[SGUser], "s1", "server should not be added after revoking invite") + assert.Error(t, manager.AddPrivateServer("s1", parsedURL.Hostname(), port, inviteToken, C.ServerLocation{}, true)) + _, exists := manager.servers["s1"] + assert.False(t, exists, "server should not be added after revoking invite") }) }) @@ -186,8 +74,6 @@ type lanternServerManagerMock struct { func newLanternServerManagerMock() *httptest.Server { testConfig := ` { - "inbounds": [ - ], "outbounds": [ { "tag": "testing-out", @@ -244,223 +130,83 @@ func (s *lanternServerManagerMock) ServeHTTP(w http.ResponseWriter, r *http.Requ w.WriteHeader(http.StatusNotFound) } -func TestAddServerWithSingBoxJSON(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - - ctx := context.Background() - jsonConfig := ` - { - "outbounds": [ - { - "type": "shadowsocks", - "tag": "ss-out", - "server": "127.0.0.1", - "server_port": 8388, - "method": "chacha20-ietf-poly1305", - "password": "randompasswordwith24char", - "network": "tcp" - } - ] - }` - - t.Run("adding server with a sing-box json config should work", func(t *testing.T) { - require.NoError(t, manager.AddServerWithSingboxJSON(ctx, []byte(jsonConfig))) - }) - t.Run("using a empty config should return an error", func(t *testing.T) { - require.Error(t, manager.AddServerWithSingboxJSON(ctx, []byte{})) +func TestAddServersByJSON(t *testing.T) { + t.Run("valid config", func(t *testing.T) { + testConfig := []byte(` +{ + "outbounds": [ + { + "tag": "out", + "type": "shadowsocks", + "server": "127.0.0.1", + "server_port": 1080, + "method": "chacha20-ietf-poly1305", + "password": "", + } + ] +}`) + type singboxConfig struct { + Outbounds []option.Outbound `json:"outbounds,omitempty"` + } + cfg, err := json.UnmarshalExtendedContext[singboxConfig](box.BaseContext(), testConfig) + require.NoError(t, err, "failed to unmarshal test config") + want := &Server{ + Tag: "out", + Type: "shadowsocks", + IsLantern: false, + Options: cfg.Outbounds[0], + } + m := testManager(t) + tags, err := m.AddServersByJSON(t.Context(), testConfig) + require.NoError(t, err) + assert.Equal(t, []string{"out"}, tags) + got, exists := m.GetServerByTag("out") + assert.True(t, exists, "server was not added") + assert.Equal(t, want.Tag, got.Tag) + assert.Equal(t, want.Type, got.Type) + assert.Equal(t, want.IsLantern, got.IsLantern) }) - t.Run("providing a json that doesn't have any endpoints or outbounds should return a error", func(t *testing.T) { - require.Error(t, manager.AddServerWithSingboxJSON(ctx, json.RawMessage("{}"))) + t.Run("empty config", func(t *testing.T) { + m := testManager(t) + _, err := m.AddServersByJSON(t.Context(), []byte("{}")) + assert.Error(t, err) + assert.Empty(t, m.servers, "no servers should have been added") }) } -func TestAddServerBasedOnURLs(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - ctx := context.Background() - after := func() { - manager.RemoveServer("VLESS+over+WS+with+TLS") - manager.RemoveServer("Trojan+with+TLS") - manager.RemoveServer("SpecialName") - } - - urls := strings.Join([]string{ +func TestAddServersByURL(t *testing.T) { + urls := []string{ "vless://uuid@host:443?encryption=none&security=tls&type=ws&host=example.com&path=/vless#VLESS+over+WS+with+TLS", "trojan://password@host:443?security=tls&sni=example.com#Trojan+with+TLS", - }, "\n") - t.Run("adding server based on URLs should work", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, urls, false, "")) - assert.Contains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - assert.Contains(t, manager.optsMaps[SGUser], "Trojan+with+TLS") - after() - }) - - t.Run("using empty URLs should return an error", func(t *testing.T) { - require.Error(t, manager.AddServerBasedOnURLs(ctx, "", false, "")) - }) - - t.Run("skip certificate verification option works", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, urls, true, "")) - opts, isOutbound := manager.optsMaps[SGUser]["Trojan+with+TLS"].(option.Outbound) - require.True(t, isOutbound) - trojanSettings, ok := opts.Options.(*option.TrojanOutboundOptions) - require.True(t, ok) - require.NotNil(t, trojanSettings) - require.NotNil(t, trojanSettings.TLS) - assert.True(t, trojanSettings.OutboundTLSOptionsContainer.TLS.Insecure, trojanSettings.OutboundTLSOptionsContainer.TLS) - after() - }) - - url := "vless://uuid@host:443?encryption=none&security=tls&type=ws&host=example.com&path=/vless#VLESS+over+WS+with+TLS" - t.Run("adding single URL should work", func(t *testing.T) { - require.NoError(t, manager.AddServerBasedOnURLs(ctx, url, false, "SpecialName")) - assert.Contains(t, manager.optsMaps[SGUser], "SpecialName") - assert.NotContains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - - require.NoError(t, manager.AddServerBasedOnURLs(ctx, url, false, "")) - assert.Contains(t, manager.optsMaps[SGUser], "VLESS+over+WS+with+TLS") - assert.Contains(t, manager.optsMaps[SGUser], "SpecialName") - after() - }) -} -func TestServers(t *testing.T) { - dataPath := t.TempDir() - manager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: []option.Outbound{ - {Tag: "lantern-out", Type: "shadowsocks"}, - }, - Endpoints: []option.Endpoint{ - {Tag: "lantern-ep", Type: "shadowsocks"}, - }, - Locations: map[string]C.ServerLocation{ - "lantern-out": {City: "New York", Country: "US"}, - }, - }, - SGUser: Options{ - Outbounds: []option.Outbound{ - {Tag: "user-out", Type: "trojan"}, - }, - Endpoints: []option.Endpoint{ - {Tag: "user-ep", Type: "vless"}, - }, - Locations: map[string]C.ServerLocation{ - "user-out": {City: "London", Country: "GB"}, - }, - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: { - "lantern-out": option.Outbound{Tag: "lantern-out", Type: "shadowsocks"}, - "lantern-ep": option.Endpoint{Tag: "lantern-ep", Type: "shadowsocks"}, - }, - SGUser: { - "user-out": option.Outbound{Tag: "user-out", Type: "trojan"}, - "user-ep": option.Endpoint{Tag: "user-ep", Type: "vless"}, - }, - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), } - - t.Run("returns copy of servers", func(t *testing.T) { - servers := manager.Servers() - - require.NotNil(t, servers) - require.Contains(t, servers, SGLantern) - require.Contains(t, servers, SGUser) - - assert.Len(t, servers[SGLantern].Outbounds, 1) - assert.Len(t, servers[SGLantern].Endpoints, 1) - assert.Equal(t, "lantern-out", servers[SGLantern].Outbounds[0].Tag) - assert.Equal(t, "lantern-ep", servers[SGLantern].Endpoints[0].Tag) - - assert.Len(t, servers[SGUser].Outbounds, 1) - assert.Len(t, servers[SGUser].Endpoints, 1) - assert.Equal(t, "user-out", servers[SGUser].Outbounds[0].Tag) - assert.Equal(t, "user-ep", servers[SGUser].Endpoints[0].Tag) - - assert.Equal(t, "New York", servers[SGLantern].Locations["lantern-out"].City) - assert.Equal(t, "London", servers[SGUser].Locations["user-out"].City) + t.Run("valid urls", func(t *testing.T) { + m := testManager(t) + tags, err := m.AddServersByURL(t.Context(), urls, false) + require.NoError(t, err) + assert.Len(t, tags, 2) + _, exists := m.GetServerByTag("VLESS+over+WS+with+TLS") + assert.True(t, exists, "VLESS server should be added") + _, exists = m.GetServerByTag("Trojan+with+TLS") + assert.True(t, exists, "Trojan server should be added") }) - - t.Run("modifications to returned copy don't affect original", func(t *testing.T) { - servers := manager.Servers() - assert.Len(t, servers[SGLantern].Outbounds, 1) - assert.Len(t, servers[SGUser].Endpoints, 1) - - // Modify the copy - servers[SGLantern].Outbounds[0].Tag = "modified-out" - - // Original should remain unchanged - originalServers := manager.Servers() - assert.NotEqual(t, originalServers[SGLantern].Outbounds[0].Tag, "modified-out") + t.Run("skip certificate", func(t *testing.T) { + m := testManager(t) + _, err := m.AddServersByURL(t.Context(), urls, true) + require.NoError(t, err) + server, exists := m.GetServerByTag("Trojan+with+TLS") + require.True(t, exists, "Trojan server should be added") + + options := server.Options.(option.Outbound).Options + require.IsType(t, &option.TrojanOutboundOptions{}, options) + trojanOpts := options.(*option.TrojanOutboundOptions) + require.NotNil(t, trojanOpts.TLS) + assert.True(t, trojanOpts.TLS.Insecure, "TLS.Insecure should be true") }) - - t.Run("handles empty servers", func(t *testing.T) { - emptyManager := &Manager{ - servers: Servers{ - SGLantern: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - SGUser: Options{ - Outbounds: make([]option.Outbound, 0), - Endpoints: make([]option.Endpoint, 0), - Locations: make(map[string]C.ServerLocation), - }, - }, - optsMaps: map[ServerGroup]map[string]any{ - SGLantern: make(map[string]any), - SGUser: make(map[string]any), - }, - serversFile: filepath.Join(dataPath, common.ServersFileName), - } - - servers := emptyManager.Servers() - require.NotNil(t, servers) - assert.Len(t, servers[SGLantern].Outbounds, 0) - assert.Len(t, servers[SGLantern].Endpoints, 0) - assert.Len(t, servers[SGUser].Outbounds, 0) - assert.Len(t, servers[SGUser].Endpoints, 0) + t.Run("empty urls", func(t *testing.T) { + m := testManager(t) + _, err := m.AddServersByURL(t.Context(), []string{}, false) + assert.Error(t, err) + assert.Empty(t, m.servers, "no servers should have been added") }) } @@ -469,7 +215,7 @@ func TestServers(t *testing.T) { // previous implementation, two concurrent saveServers calls could reorder // their marshal/write sequence and leave the older snapshot on disk. func TestSaveServersConcurrent(t *testing.T) { - mgr := newTestManager(t) + mgr := testManager(t) // Run many concurrent mutations. const concurrency = 20 @@ -480,19 +226,23 @@ func TestSaveServersConcurrent(t *testing.T) { defer func() { done <- struct{}{} }() for j := 0; j < opsPerGoroutine; j++ { tag := fmt.Sprintf("concurrent-%d-%d", id, j) - opts := Options{ - Outbounds: []option.Outbound{{Tag: tag, Type: "shadowsocks", Options: &option.ShadowsocksOutboundOptions{ - ServerOptions: option.ServerOptions{Server: "9.9.9.9", ServerPort: 443}, - Method: "chacha20-ietf-poly1305", - Password: "pw", - }}}, - Endpoints: []option.Endpoint{}, - Locations: map[string]C.ServerLocation{ - tag: {Country: "US", City: "X", CountryCode: "US"}, - }, - Credentials: map[string]ServerCredentials{}, + list := ServerList{ + Servers: []*Server{{ + Tag: tag, + Type: "shadowsocks", + Options: option.Outbound{ + Tag: tag, + Type: "shadowsocks", + Options: &option.ShadowsocksOutboundOptions{ + ServerOptions: option.ServerOptions{Server: "9.9.9.9", ServerPort: 443}, + Method: "chacha20-ietf-poly1305", + Password: "pw", + }, + }, + Location: C.ServerLocation{Country: "US", City: "X", CountryCode: "US"}, + }}, } - _ = mgr.AddServers(SGUser, opts) + _ = mgr.AddServers(list, true) } }(i) } @@ -500,24 +250,36 @@ func TestSaveServersConcurrent(t *testing.T) { <-done } - // After all concurrent operations, the on-disk file must match the - // current in-memory state. If saves can reorder, the file would lag. - inMem, err := mgr.ServersJSON() - require.NoError(t, err) - - // Force a final save so we compare against the last committed snapshot. + // After all concurrent operations, force a save and reload into a fresh + // manager. The reloaded state must have exactly the same servers as the + // original — if saves can reorder, the file would lag behind. require.NoError(t, mgr.saveServers()) - onDiskBytes, err := os.ReadFile(mgr.serversFile) - require.NoError(t, err) - assert.JSONEq(t, string(inMem), string(onDiskBytes), "disk file must match in-memory state after concurrent saves") + mgr2 := testManager(t) + mgr2.serversFile = mgr.serversFile + require.NoError(t, mgr2.loadServers()) + assert.Equal(t, len(mgr.AllServers()), len(mgr2.AllServers()), + "reloaded server count must match in-memory count") + + for _, srv := range mgr.AllServers() { + _, ok := mgr2.GetServerByTag(srv.Tag) + assert.True(t, ok, "server %q must survive save/reload", srv.Tag) + } } func TestRetryableHTTPClient(t *testing.T) { - cli := retryableHTTPClient().StandardClient() + cli := retryableHTTPClient(log.NoOpLogger()).StandardClient() request, err := http.NewRequest(http.MethodGet, "https://www.gstatic.com/generate_204", http.NoBody) require.NoError(t, err) resp, err := cli.Do(request) require.NoError(t, err) assert.Equal(t, http.StatusNoContent, resp.StatusCode) } + +func testManager(t *testing.T) *Manager { + return &Manager{ + servers: make(map[string]*Server), + serversFile: filepath.Join(t.TempDir(), internal.ServersFileName), + logger: log.NoOpLogger(), + } +} diff --git a/telemetry/connections.go b/telemetry/connections.go index acc55d1e..9029f15f 100644 --- a/telemetry/connections.go +++ b/telemetry/connections.go @@ -9,32 +9,40 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" - "github.com/getlantern/radiance/vpn/ipc" + "github.com/getlantern/radiance/vpn" ) -// harvestConnectionMetrics periodically polls the number of active connections and their total -// upload and download bytes, setting the corresponding OpenTelemetry metrics. It returns a function -// that can be called to stop the polling. -func harvestConnectionMetrics(pollInterval time.Duration) func() { +// ConnectionSource provides access to the current VPN connections for metrics collection. +type ConnectionSource interface { + Connections() ([]vpn.Connection, error) +} + +// StartConnectionMetrics periodically polls the number of active connections and their total +// upload and download bytes, setting the corresponding OpenTelemetry metrics. It runs until the +// provided context is canceled. +func StartConnectionMetrics(ctx context.Context, src ConnectionSource, pollInterval time.Duration) { ticker := time.NewTicker(pollInterval) meter := otel.Meter("github.com/getlantern/radiance/metrics") currentActiveConnections, err := meter.Int64Counter("current_active_connections", metric.WithDescription("Current number of active connections")) if err != nil { slog.Warn("failed to create current_active_connections metric", slog.Any("error", err)) + return } connectionDuration, err := meter.Float64Histogram("connection_duration_seconds", metric.WithDescription("Duration of connections in seconds"), metric.WithUnit("s")) if err != nil { slog.Warn("failed to create connection_duration_seconds metric", slog.Any("error", err)) + return } downlinkBytes, err := meter.Int64Counter("downlink_bytes", metric.WithDescription("Total downlink bytes across all connections"), metric.WithUnit("By")) if err != nil { slog.Warn("failed to create downlink_bytes metric", slog.Any("error", err)) + return } uplinkBytes, err := meter.Int64Counter("uplink_bytes", metric.WithDescription("Total uplink bytes across all connections"), metric.WithUnit("By")) if err != nil { slog.Warn("failed to create uplink_bytes metric", slog.Any("error", err)) + return } - ctx, cancel := context.WithCancel(context.Background()) go func() { seenConnections := make(map[string]bool) for { @@ -44,20 +52,16 @@ func harvestConnectionMetrics(pollInterval time.Duration) func() { return case <-ticker.C: slog.Debug("polling connections for metrics", slog.Int("seen_connections", len(seenConnections)), slog.Duration("poll_interval", pollInterval)) - vpnStatus, err := ipc.GetStatus(ctx) + conns, err := src.Connections() if err != nil { - slog.Warn("failed to get service status", "error", err) - } - if vpnStatus != ipc.Connected { - continue - } - conns, err := ipc.GetConnections(ctx) - if err != nil { - slog.Warn("failed to retrieve connections", slog.Any("error", err)) + slog.Debug("failed to retrieve connections for metrics", slog.Any("error", err)) continue } + // Track which connections are still reported so we can prune stale entries. + currentIDs := make(map[string]struct{}, len(conns)) for _, c := range conns { + currentIDs[c.ID] = struct{}{} attributes := attribute.NewSet( attribute.String("from_outbound", c.FromOutbound), attribute.String("outbound_name", c.Outbound), @@ -92,8 +96,14 @@ func harvestConnectionMetrics(pollInterval time.Duration) func() { downlinkBytes.Add(ctx, c.Downlink, metric.WithAttributeSet(attributes)) uplinkBytes.Add(ctx, c.Uplink, metric.WithAttributeSet(attributes)) } + + // Remove entries for connections no longer reported by the source. + for id := range seenConnections { + if _, ok := currentIDs[id]; !ok { + delete(seenConnections, id) + } + } } } }() - return cancel } diff --git a/telemetry/otel.go b/telemetry/otel.go index 0114006f..d3de6082 100644 --- a/telemetry/otel.go +++ b/telemetry/otel.go @@ -18,10 +18,12 @@ import ( "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + metricNoop "go.opentelemetry.io/otel/metric/noop" "go.opentelemetry.io/otel/propagation" sdkmetric "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/resource" sdktrace "go.opentelemetry.io/otel/sdk/trace" + traceNoop "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc/credentials" semconv "github.com/getlantern/semconv" @@ -33,10 +35,8 @@ import ( ) var ( - initMutex sync.Mutex - shutdownOTEL func(context.Context) error - harvestConnections sync.Once - harvestConnectionTickerStop func() + initMutex sync.Mutex + shutdownOTEL func(context.Context) error ) type Attributes struct { @@ -58,18 +58,18 @@ type Attributes struct { // OnNewConfig handles OpenTelemetry re-initialization when the configuration changes. func OnNewConfig(oldConfig, newConfig *config.Config, deviceID string) error { // Check if the old OTEL configuration is the same as the new one. - if oldConfig != nil && reflect.DeepEqual(oldConfig.ConfigResponse.OTEL, newConfig.ConfigResponse.OTEL) { + if oldConfig != nil && reflect.DeepEqual(oldConfig.OTEL, newConfig.OTEL) { slog.Debug("OpenTelemetry configuration has not changed, skipping initialization") return nil } - if err := initialize(deviceID, newConfig.ConfigResponse, settings.IsPro()); err != nil { + if err := Initialize(deviceID, *newConfig, settings.IsPro()); err != nil { slog.Error("Failed to initialize OpenTelemetry", "error", err) - return fmt.Errorf("Failed to initialize OpenTelemetry: %w", err) + return fmt.Errorf("failed to initialize OpenTelemetry: %w", err) } return nil } -func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) error { +func Initialize(deviceID string, configResponse config.Config, pro bool) error { initMutex.Lock() defer initMutex.Unlock() @@ -78,6 +78,15 @@ func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) return nil } + // QA: when env.OutboundSocksAddress is set, the OTLP gRPC exporters do + // NOT honor the radiance dialer override and would phone home directly, + // leaking the test process's real IP and bypassing the SOCKS5 egress. + // Skip telemetry init in that mode. + if addr, ok := env.Get(env.OutboundSocksAddress); ok && addr != "" { + slog.Info("RADIANCE_OUTBOUND_SOCKS_ADDRESS set — skipping OpenTelemetry init (gRPC exporters cannot be routed via SOCKS5)", "addr", addr) + return nil + } + if shutdownOTEL != nil { slog.Info("Shutting down existing OpenTelemetry SDK") if err := shutdownOTEL(context.Background()); err != nil { @@ -90,7 +99,7 @@ func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) attrs := Attributes{ App: "radiance", DeviceID: deviceID, - AppVersion: rcommon.Version, + AppVersion: rcommon.GetVersion(), Platform: rcommon.Platform, GoVersion: runtime.Version(), OSName: runtime.GOOS, @@ -109,24 +118,19 @@ func initialize(deviceID string, configResponse common.ConfigResponse, pro bool) } shutdownOTEL = shutdown - - harvestConnections.Do(func() { - harvestConnectionTickerStop = harvestConnectionMetrics(1 * time.Minute) - }) return nil } -func Close(ctx context.Context) error { +func Close() error { + return CloseContext(context.Background()) +} + +func CloseContext(ctx context.Context) error { initMutex.Lock() defer initMutex.Unlock() var errs error - // stop collecting connection metrics - if harvestConnectionTickerStop != nil { - harvestConnectionTickerStop() - } - if shutdownOTEL != nil { slog.Info("Shutting down existing OpenTelemetry SDK") if err := shutdownOTEL(ctx); err != nil { @@ -135,12 +139,14 @@ func Close(ctx context.Context) error { } shutdownOTEL = nil } + otel.SetTracerProvider(traceNoop.NewTracerProvider()) + otel.SetMeterProvider(metricNoop.NewMeterProvider()) return errs } func buildResources(serviceName string, a Attributes) []attribute.KeyValue { e := "prod" - if v, ok := env.Get[string](env.ENV); ok { + if v := env.GetString(env.ENV); v != "" { e = v } return []attribute.KeyValue{ @@ -164,7 +170,7 @@ func buildResources(serviceName string, a Attributes) []attribute.KeyValue { // setupOTelSDK bootstraps the OpenTelemetry pipeline. // If it does not return an error, make sure to call shutdown for proper cleanup. -func setupOTelSDK(ctx context.Context, attributes Attributes, cfg common.ConfigResponse) (func(context.Context) error, error) { +func setupOTelSDK(ctx context.Context, attributes Attributes, cfg config.Config) (func(context.Context) error, error) { if cfg.Features == nil { cfg.Features = make(map[string]bool) } diff --git a/tester/main.go b/tester/main.go index 10120f8c..7a0d7470 100644 --- a/tester/main.go +++ b/tester/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log/slog" "os" @@ -8,10 +9,11 @@ import ( "strconv" "time" - "github.com/getlantern/radiance" + "github.com/getlantern/radiance/backend" "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/config" "github.com/getlantern/radiance/events" + "github.com/getlantern/radiance/ipc" "github.com/getlantern/radiance/vpn" ) @@ -20,7 +22,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i os.RemoveAll(dataDir) } os.MkdirAll(dataDir, 0o755) - r, err := radiance.NewRadiance(radiance.Options{ + be, err := backend.NewLocalBackend(context.Background(), backend.Options{ DataDir: dataDir, LogDir: dataDir, Locale: "en-US", @@ -28,7 +30,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i if err != nil { return fmt.Errorf("failed to create radiance instance: %w", err) } - defer r.Close() + defer be.Close() settings.Set(settings.UserIDKey, userId) settings.Set(settings.TokenKey, token) settings.Set(settings.UserLevelKey, "") @@ -40,14 +42,16 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i }, }) - ipcServer, err := vpn.InitIPC(dataDir, "", "trace", nil) + be.Start() + + ipcServer := ipc.NewServer(be, false) + err = ipcServer.Start() if err != nil { return fmt.Errorf("failed to initialize IPC server: %w", err) } exit := func() { - status, _ := vpn.GetStatus() - if status.TunnelOpen { - vpn.Disconnect() + if be.VPNStatus() != vpn.Disconnected { + be.DisconnectVPN() } ipcServer.Close() } @@ -70,7 +74,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i } } t1 := time.Now() - if err = vpn.QuickConnect("all", nil); err != nil { + if err = be.ConnectVPN(vpn.AutoSelectTag); err != nil { return fmt.Errorf("quick connect failed: %w", err) } fmt.Println("Quick connect successful") @@ -79,7 +83,7 @@ func performLanternPing(urlToHit string, runId string, deviceId string, userId i proxyAddr := os.Getenv("RADIANCE_SOCKS_ADDRESS") if proxyAddr == "" { - proxyAddr = "127.0.0.1:6666" + proxyAddr = "127.0.0.1:6666" } cmd := exec.Command("curl", "-v", "-x", proxyAddr, "-s", urlToHit) diff --git a/traces/errors.go b/traces/errors.go index cccef67e..6ed6c319 100644 --- a/traces/errors.go +++ b/traces/errors.go @@ -7,6 +7,7 @@ import ( "go.opentelemetry.io/otel/trace" ) +// RecordError records the given error in the current span. If error is nil, it is noop. func RecordError(ctx context.Context, err error, options ...trace.EventOption) error { if err == nil { return nil diff --git a/vpn/boxoptions.go b/vpn/boxoptions.go index 06918e6a..3a742b8f 100644 --- a/vpn/boxoptions.go +++ b/vpn/boxoptions.go @@ -7,20 +7,23 @@ import ( "errors" "fmt" "log/slog" + "net" "net/netip" "path/filepath" + "slices" "strconv" "strings" "time" - lcommon "github.com/getlantern/common" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + lcommon "github.com/getlantern/common" box "github.com/getlantern/lantern-box" lbC "github.com/getlantern/lantern-box/constant" lbO "github.com/getlantern/lantern-box/option" + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" O "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" @@ -30,17 +33,14 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" "github.com/getlantern/radiance/common/env" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/config" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/servers" + "github.com/getlantern/radiance/log" ) const ( - autoAllTag = "auto" - - autoLanternTag = "auto-lantern" - autoUserTag = "auto-user" + AutoSelectTag = "auto" + ManualSelectTag = "manual" urlTestInterval = 3 * time.Minute // must be less than urlTestIdleTimeout urlTestIdleTimeout = 15 * time.Minute @@ -54,11 +54,43 @@ const ( minAndroidSystemStackKernel = "5.10" ) -// this is the base options that is need for everything to work correctly. this should not be -// changed unless you know what you're doing. +var reservedTags = []string{AutoSelectTag, ManualSelectTag, "direct", "block"} + +func ReservedTags() []string { + return slices.Clone(reservedTags) +} + +type BoxOptions struct { + BasePath string `json:"base_path,omitempty"` + // Options contains the main options that are merged into the base options with the exception of + // DNS, which overrides the base DNS options entirely instead of being merged. Options should + // contain all servers (both lantern and user). + Options O.Options `json:"options"` + // SmartRouting contains smart routing rules to merge into the final options. + SmartRouting lcommon.SmartRoutingRules `json:"smart_routing,omitempty"` + // AdBlock contains ad block rules to merge into the final options. + AdBlock lcommon.AdBlockRules `json:"ad_block,omitempty"` + // InitialServer chooses the outbound selected when the tunnel starts. + // Empty or AutoSelectTag puts the tunnel in auto mode; any other tag + // must match an outbound or endpoint and forces manual selection. + InitialServer string `json:"initial_server,omitempty"` + // BanditURLOverrides maps outbound tags to per-proxy callback URLs for + // the bandit Thompson sampling system. When set, these override the + // default MutableURLTest URL for each specific outbound, allowing the + // server to detect which proxies successfully connected. + BanditURLOverrides map[string]string `json:"bandit_url_overrides,omitempty"` + BanditThroughputURL string `json:"bandit_throughput_url,omitempty"` + // URLTestSeed seeds the tunnel's URL test history storage at startup so + // prior latency results survive across tunnel close/open. Keyed by + // outbound/endpoint tag. + URLTestSeed map[string]adapter.URLTestHistory `json:"-"` +} + +// baseOpts returns the minimum sing-box options required for the tunnel to +// function. Do not modify without understanding the downstream effects. func baseOpts(basePath string) O.Options { splitTunnelPath := filepath.Join(basePath, splitTunnelFile) - + cacheFile := filepath.Join(basePath, cacheFileName) loopbackAddr := badoption.Addr(netip.MustParseAddr("127.0.0.1")) return O.Options{ Log: &O.LogOptions{ @@ -99,16 +131,6 @@ func baseOpts(basePath string) O.Options { }, }, }, - { - Type: C.TypeMixed, - Tag: bypass.TunnelInboundTag, - Options: &O.HTTPMixedInboundOptions{ - ListenOptions: O.ListenOptions{ - Listen: &loopbackAddr, - ListenPort: bypass.TunnelProxyPort, - }, - }, - }, }, Outbounds: []O.Outbound{ { @@ -138,13 +160,13 @@ func baseOpts(basePath string) O.Options { }, Experimental: &O.ExperimentalOptions{ ClashAPI: &O.ClashAPIOptions{ - DefaultMode: autoAllTag, - ModeList: []string{servers.SGLantern, servers.SGUser, autoAllTag}, + DefaultMode: AutoSelectTag, + ModeList: []string{ManualSelectTag, AutoSelectTag}, ExternalController: "", // intentionally left empty }, CacheFile: &O.CacheFileOptions{ Enabled: true, - Path: cacheFileName, + Path: cacheFile, CacheID: cacheID, }, }, @@ -245,25 +267,21 @@ func baseRoutingRules() []O.Rule { } // buildOptions builds the box options using the config options and user servers. -func buildOptions(ctx context.Context, path string) (O.Options, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "buildOptions") +func buildOptions(bOptions BoxOptions) (O.Options, error) { + _, span := otel.Tracer(tracerName).Start(context.Background(), "buildOptions") defer span.End() - slog.Log(nil, internal.LevelTrace, "Starting buildOptions", "path", path) - - opts := baseOpts(path) - slog.Debug("Base options initialized") + if len(bOptions.Options.Outbounds) == 0 && len(bOptions.Options.Endpoints) == 0 { + return O.Options{}, errors.New("no outbounds or endpoints found in config or user servers") + } - // update default options and paths - opts.Experimental.CacheFile.Path = filepath.Join(path, cacheFileName) + slog.Log(nil, log.LevelTrace, "Starting buildOptions", "path", bOptions.BasePath) - slog.Log(nil, internal.LevelTrace, "Updated default options and paths", - "cacheFilePath", opts.Experimental.CacheFile.Path, - "clashAPIDefaultMode", opts.Experimental.ClashAPI.DefaultMode, - ) + opts := baseOpts(bOptions.BasePath) + slog.Debug("Base options initialized") - if _, useSocks := env.Get[bool](env.UseSocks); useSocks { - socksAddr, _ := env.Get[string](env.SocksAddress) + if env.GetBool(env.UseSocks) { + socksAddr, _ := env.Get(env.SocksAddress) slog.Info("Using SOCKS proxy for inbound as per environment variable", "socksAddr", socksAddr) addrPort, err := netip.ParseAddrPort(socksAddr) if err != nil { @@ -282,7 +300,6 @@ func buildOptions(ctx context.Context, path string) (O.Options, error) { } opts.Inbounds = []O.Inbound{socksIn} } else { - switch common.Platform { case "android": opts.Route.OverrideAndroidVPN = true @@ -301,82 +318,64 @@ func buildOptions(ctx context.Context, path string) (O.Options, error) { } } - // Load config file - confPath := filepath.Join(path, common.ConfigFileName) - slog.Debug("Loading config file", "confPath", confPath) - cfg, err := loadConfig(confPath) - if err != nil { - slog.Error("Failed to load config options", "error", err) - return O.Options{}, err - } - // add smart routing and ad block rules - if settings.GetBool(settings.SmartRoutingKey) && len(cfg.SmartRouting) > 0 { - slog.Debug("Adding smart-routing rules") - smartRoutingRules := normalizeSmartRoutingRules(cfg.SmartRouting) + smartRoutingRules := normalizeSmartRoutingRules(bOptions.SmartRouting) + if len(smartRoutingRules) > 0 { + slog.Info("Adding smart-routing rules") outbounds, rules, rulesets := smartRoutingRules.ToOptions(urlTestInterval, urlTestIdleTimeout) opts.Outbounds = append(opts.Outbounds, outbounds...) opts.Route.Rules = append(opts.Route.Rules, rules...) opts.Route.RuleSet = append(opts.Route.RuleSet, rulesets...) + } else if len(bOptions.SmartRouting) > 0 && len(smartRoutingRules) == 0 { + slog.Warn("No valid smart-routing rules found after normalization, skipping smart-routing configuration") } - if settings.GetBool(settings.AdBlockKey) && len(cfg.AdBlock) > 0 { - slog.Debug("Adding ad-block rules") - rule, rulesets := cfg.AdBlock.ToOptions() + adBlockRules := normalizeAdBlockRules(bOptions.AdBlock) + if len(adBlockRules) > 0 { + slog.Info("Adding ad-block rules") + rule, rulesets := bOptions.AdBlock.ToOptions() opts.Route.Rules = append(opts.Route.Rules, rule) opts.Route.RuleSet = append(opts.Route.RuleSet, rulesets...) + } else if len(bOptions.AdBlock) > 0 && len(adBlockRules) == 0 { + slog.Warn("No valid ad-block rules found after normalization, skipping ad-block configuration") } - var lanternTags []string - configOpts := cfg.Options - if len(configOpts.Outbounds) == 0 && len(configOpts.Endpoints) == 0 { - slog.Warn("Config loaded but no outbounds or endpoints found") - } - lanternTags = mergeAndCollectTags(&opts, &configOpts) - slog.Debug("Merged config options", "tags", lanternTags) - - appendGroupOutbounds(&opts, servers.SGLantern, autoLanternTag, lanternTags, cfg.BanditURLOverrides) - - // Load user servers - slog.Debug("Loading user servers") - userOpts, err := loadUserOptions(path) - if err != nil { - slog.Error("Failed to load user servers", "error", err) - return O.Options{}, err - } - var userTags []string - if len(userOpts.Outbounds) == 0 && len(userOpts.Endpoints) == 0 { - slog.Info("No user servers found") + tags := mergeAndCollectTags(&opts, &bOptions.Options) + initial := bOptions.InitialServer + if initial == "" || initial == AutoSelectTag { + opts.Experimental.ClashAPI.DefaultMode = AutoSelectTag } else { - userTags = mergeAndCollectTags(&opts, &userOpts) - slog.Debug("Merged user server options", "tags", userTags) + // The manual selector defaults to its first tag, so place initial at index 0. + i := slices.Index(tags, initial) + if i == -1 { + return O.Options{}, fmt.Errorf("initial server tag %q not found in outbounds or endpoints", initial) + } + tags[0], tags[i] = tags[i], tags[0] + opts.Experimental.ClashAPI.DefaultMode = ManualSelectTag } - appendGroupOutbounds(&opts, servers.SGUser, autoUserTag, userTags, nil) - if len(lanternTags) == 0 && len(userTags) == 0 { - return O.Options{}, errors.New("no outbounds or endpoints found in config or user servers") + // QA: route every leaf outbound through an upstream SOCKS5 (e.g. one that + // egresses through a residential proxy in the country we want to simulate) + // before reaching its real destination. See env.OutboundSocksAddress. + if err := applyOutboundSocksDetour(&opts); err != nil { + return O.Options{}, err } - // Add auto all outbound - opts.Outbounds = append(opts.Outbounds, urlTestOutbound(autoAllTag, []string{autoLanternTag, autoUserTag}, nil)) - - // Add routing rules for the groups - opts.Route.Rules = append(opts.Route.Rules, groupRule(autoAllTag)) - opts.Route.Rules = append(opts.Route.Rules, groupRule(servers.SGLantern)) - opts.Route.Rules = append(opts.Route.Rules, groupRule(servers.SGUser)) + // add mode selector outbounds and rules + opts.Outbounds = append(opts.Outbounds, urlTestOutbound(AutoSelectTag, tags, bOptions.BanditURLOverrides)) + opts.Outbounds = append(opts.Outbounds, selectorOutbound(ManualSelectTag, tags)) + opts.Route.Rules = append(opts.Route.Rules, selectModeRule(AutoSelectTag)) + opts.Route.Rules = append(opts.Route.Rules, selectModeRule(ManualSelectTag)) // catch-all rule to ensure no fallthrough opts.Route.Rules = append(opts.Route.Rules, catchAllBlockerRule()) - slog.Debug("Finished building options", slog.String("env", common.Env())) + slog.Debug("Finished building options", "env", common.Env()) span.AddEvent("finished building options", trace.WithAttributes( - attribute.String("options", string(writeBoxOptions(path, opts))), - attribute.String("env", common.Env()), + attribute.String("options", string(writeBoxOptions(bOptions.BasePath, opts))), )) return opts, nil } -const debugLanternBoxOptionsFilename = "debug-lantern-box-options.json" - // writeBoxOptions marshals the options as JSON and stores them in a file so we can debug them // we can ignore the errors here since the tunnel will error out anyway if something is wrong func writeBoxOptions(path string, opts O.Options) []byte { @@ -391,35 +390,72 @@ func writeBoxOptions(path string, opts O.Options) []byte { slog.Warn("failed to indent marshaled options while writing debug box options", slog.Any("error", err)) return buf } - if err := atomicfile.WriteFile(filepath.Join(path, debugLanternBoxOptionsFilename), b.Bytes(), 0644); err != nil { + if err := atomicfile.WriteFile(filepath.Join(path, internal.DebugBoxOptionsFileName), b.Bytes(), fileperm.File); err != nil { slog.Warn("failed to write options file", slog.Any("error", err)) return buf } return b.Bytes() } -/////////////////////// +////////////////////// // Helper functions // ////////////////////// -func loadConfig(path string) (lcommon.ConfigResponse, error) { - cfg, err := config.Load(path) +// devOutboundSocksTag is the tag of the synthetic SOCKS5 outbound injected +// when env.OutboundSocksAddress is set. Other outbounds get DialerOptions.Detour +// pointing at this tag, so every real dial is wrapped in a SOCKS5 connection. +const devOutboundSocksTag = "_dev_outbound_socks" + +// applyOutboundSocksDetour appends a SOCKS5 outbound to opts and rewrites every +// other leaf outbound to dial through it, when env.OutboundSocksAddress is set. +// Selector / urltest / block / dns outbounds are skipped — they don't dial +// directly. No-op when the env var is unset. +func applyOutboundSocksDetour(opts *O.Options) error { + addr, ok := env.Get(env.OutboundSocksAddress) + if !ok || addr == "" { + return nil + } + host, portStr, err := net.SplitHostPort(addr) if err != nil { - return lcommon.ConfigResponse{}, fmt.Errorf("load config: %w", err) + return fmt.Errorf("invalid RADIANCE_OUTBOUND_SOCKS_ADDRESS %q: %w", addr, err) } - if cfg == nil { - return lcommon.ConfigResponse{}, nil + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return fmt.Errorf("invalid RADIANCE_OUTBOUND_SOCKS_ADDRESS port %q: %w", portStr, err) } - return cfg.ConfigResponse, nil -} -func loadUserOptions(path string) (O.Options, error) { - mgr, err := servers.NewManager(path) - if err != nil { - return O.Options{}, fmt.Errorf("server manager: %w", err) + for i := range opts.Outbounds { + out := &opts.Outbounds[i] + switch out.Type { + case C.TypeSelector, C.TypeURLTest, C.TypeBlock, C.TypeDNS, C.TypeDirect: + // selector/urltest wrap others; block/dns/direct don't dial + // real upstream proxies. `direct` in particular rejects Detour + // at runtime ("detour is not supported in direct context"). + continue + } + if w, ok := out.Options.(O.DialerOptionsWrapper); ok { + d := w.TakeDialerOptions() + d.Detour = devOutboundSocksTag + w.ReplaceDialerOptions(d) + } } - u := mgr.Servers()[servers.SGUser] - return O.Options{Outbounds: u.Outbounds, Endpoints: u.Endpoints}, nil + + opts.Outbounds = append(opts.Outbounds, O.Outbound{ + Type: C.TypeSOCKS, + Tag: devOutboundSocksTag, + Options: &O.SOCKSOutboundOptions{ + ServerOptions: O.ServerOptions{ + Server: host, + ServerPort: uint16(port), + }, + Version: "5", + }, + }) + + slog.Info("RADIANCE_OUTBOUND_SOCKS_ADDRESS set — every sing-box outbound will dial via this SOCKS5", + slog.String("addr", addr), + slog.Int("rewritten_outbounds", len(opts.Outbounds)-1)) + return nil } // mergeAndCollectTags merges src into dst and returns all outbound/endpoint tags from src. @@ -470,38 +506,18 @@ func normalizeSmartRoutingRules(rules lcommon.SmartRoutingRules) lcommon.SmartRo return normalized } -func useIfNotZero[T comparable](newVal, oldVal T) T { - var zero T - if newVal != zero { - return newVal - } - return oldVal -} - -func appendGroupOutbounds(opts *O.Options, serverGroup, autoTag string, tags []string, urlOverrides map[string]string) { - // All outbounds go in the URL test group — the server now sends callback - // URLs for every outbound, and the dependency's worker pool bounds memory. - opts.Outbounds = append(opts.Outbounds, urlTestOutbound(autoTag, tags, urlOverrides)) - opts.Outbounds = append(opts.Outbounds, selectorOutbound(serverGroup, append([]string{autoTag}, tags...))) - slog.Log( - nil, internal.LevelTrace, "Added group outbounds", - "serverGroup", serverGroup, - "tags", tags, - "outbounds", opts.Outbounds[len(opts.Outbounds)-2:], - ) -} - -func groupAutoTag(group string) string { - switch group { - case servers.SGLantern: - return autoLanternTag - case servers.SGUser: - return autoUserTag - case "all", "": - return autoAllTag - default: - return "" +func normalizeAdBlockRules(rules lcommon.AdBlockRules) lcommon.AdBlockRules { + normalized := make(lcommon.AdBlockRules, 0, len(rules)) + for _, rule := range rules { + tag := strings.TrimSpace(rule.Tag) + if tag == "" { + slog.Warn("Skipping ad-block rule with empty tag") + continue + } + rule.Tag = tag + normalized = append(normalized, rule) } + return normalized } func urlTestOutbound(tag string, outbounds []string, urlOverrides map[string]string) O.Outbound { @@ -518,27 +534,27 @@ func urlTestOutbound(tag string, outbounds []string, urlOverrides map[string]str } } -func selectorOutbound(group string, outbounds []string) O.Outbound { +func selectorOutbound(tag string, outbounds []string) O.Outbound { return O.Outbound{ Type: lbC.TypeMutableSelector, - Tag: group, + Tag: tag, Options: &lbO.MutableSelectorOutboundOptions{ Outbounds: outbounds, }, } } -func groupRule(group string) O.Rule { +func selectModeRule(mode string) O.Rule { return O.Rule{ Type: C.RuleTypeDefault, DefaultOptions: O.DefaultRule{ RawDefaultRule: O.RawDefaultRule{ - ClashMode: group, + ClashMode: mode, }, RuleAction: O.RuleAction{ Action: C.RuleActionTypeRoute, RouteOptions: O.RouteActionOptions{ - Outbound: group, + Outbound: mode, }, }, }, @@ -637,4 +653,3 @@ func newDNSServerOptions(typ, tag, server, domainResolver string) O.DNSServerOpt Options: serverOpts, } } - diff --git a/vpn/boxoptions_test.go b/vpn/boxoptions_test.go index 1f60b6d6..6b4ef59c 100644 --- a/vpn/boxoptions_test.go +++ b/vpn/boxoptions_test.go @@ -1,118 +1,56 @@ package vpn import ( - "context" - "fmt" "os" - "path/filepath" "slices" "testing" - "github.com/sagernet/sing-box/constant" O "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - LC "github.com/getlantern/common" box "github.com/getlantern/lantern-box" lbO "github.com/getlantern/lantern-box/option" - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/config" - "github.com/getlantern/radiance/servers" ) func TestBuildOptions(t *testing.T) { - testOpts, _, err := testBoxOptions("") - require.NoError(t, err, "get test box options") - lanternTags, lanternOuts := filterOutbounds(*testOpts, constant.TypeHTTP) - userTags, userOuts := filterOutbounds(*testOpts, constant.TypeSOCKS) - cfg := config.Config{ - ConfigResponse: LC.ConfigResponse{ - Options: O.Options{ - Outbounds: lanternOuts, - }, - }, - } - svrs := servers.Servers{ - servers.SGUser: servers.Options{ - Outbounds: userOuts, - }, - } + options, tags := testBoxOptions(t) tests := []struct { name string - lanternTags []string - userTags []string + boxOptions BoxOptions shouldError bool }{ { - name: "config without user servers", - lanternTags: lanternTags, - }, - { - name: "user servers without config", - userTags: userTags, - }, - { - name: "config and user servers", - lanternTags: lanternTags, - userTags: userTags, + name: "success", + boxOptions: BoxOptions{ + BasePath: t.TempDir(), + Options: options, + }, }, { - name: "neither config nor user servers", + name: "no servers available", + boxOptions: BoxOptions{ + BasePath: t.TempDir(), + }, shouldError: true, }, } - hasGroupWithTags := func(t *testing.T, outs []O.Outbound, group string, tags []string) { - out := findOutbound(outs, group) - if !assert.NotNilf(t, out, "group %s not found", group) { - return - } - switch opts := out.Options.(type) { - case *lbO.MutableSelectorOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *O.SelectorOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *lbO.MutableURLTestOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - case *O.URLTestOutboundOptions: - assert.ElementsMatchf(t, tags, opts.Outbounds, "group %s does not have correct outbounds", group) - default: - assert.Failf(t, fmt.Sprintf("%s[%T] is not a group outbound", group, opts), "") - } - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - path := t.TempDir() - if len(tt.lanternTags) > 0 { - testOptsToFile(t, cfg, filepath.Join(path, common.ConfigFileName)) - } - if len(tt.userTags) > 0 { - testOptsToFile(t, svrs, filepath.Join(path, common.ServersFileName)) - } - opts, err := buildOptions(context.Background(), path) + opts, err := buildOptions(tt.boxOptions) if tt.shouldError { require.Error(t, err, "expected error but got none") return } require.NoError(t, err) - gotOutbounds := opts.Outbounds - require.NotEmpty(t, gotOutbounds, "no outbounds in built options") - - assert.NotNil(t, findOutbound(gotOutbounds, constant.TypeDirect), "direct outbound not found") - assert.NotNil(t, findOutbound(gotOutbounds, constant.TypeBlock), "block outbound not found") - - hasGroupWithTags(t, gotOutbounds, servers.SGLantern, append(tt.lanternTags, autoLanternTag)) - hasGroupWithTags(t, gotOutbounds, servers.SGUser, append(tt.userTags, autoUserTag)) - - hasGroupWithTags(t, gotOutbounds, autoLanternTag, tt.lanternTags) - hasGroupWithTags(t, gotOutbounds, autoUserTag, tt.userTags) - hasGroupWithTags(t, gotOutbounds, autoAllTag, []string{autoLanternTag, autoUserTag}) - - assert.FileExists(t, filepath.Join(path, debugLanternBoxOptionsFilename), "debug option file must be written") + urlTest := urlTestOutbound(AutoSelectTag, tags, nil) + assert.Contains(t, opts.Outbounds, urlTest, "options should contain auto-select URL test outbound") + selector := selectorOutbound(ManualSelectTag, tags) + assert.Contains(t, opts.Outbounds, selector, "options should contain manual-select selector outbound") }) } } @@ -125,13 +63,17 @@ func TestBuildOptions_Rulesets(t *testing.T) { "type": "urltest", "tag": "sr-openai", "outbounds": ["http1-out", "socks1-out"], - "url": "https://google.com/generate_204", + "url": "https://google.com/generate_204", "interval": "3m0s", "idle_timeout": "15m" } ], "route": { "rules": [ + { + "rule_set": "sr-direct", + "outbound": "direct" + }, { "rule_set": "openai", "outbound": "sr-openai" @@ -140,39 +82,15 @@ func TestBuildOptions_Rulesets(t *testing.T) { "rule_set": [ { "type": "remote", - "tag": "openai", - "url": "https://ruleset.com/openai.srs", - "download_detour": "direct", - "update_interval": "24h0m0s" - } - ] - } - } - ` - adBlockJSON := ` - { - "route": { - "rules": [ - { - "rule_set": [ - "adblock-1", - "adblock-2" - ], - "action": "reject" - } - ], - "rule_set": [ - { - "type": "remote", - "tag": "adblock-1", - "url": "https://ruleset.com/adblock-1.srs", + "tag": "sr-direct", + "url": "https://ruleset.com/direct.srs", "download_detour": "direct", "update_interval": "24h0m0s" }, { "type": "remote", - "tag": "adblock-2", - "url": "https://ruleset.com/adblock-2.srs", + "tag": "openai", + "url": "https://ruleset.com/openai.srs", "download_detour": "direct", "update_interval": "24h0m0s" } @@ -182,112 +100,73 @@ func TestBuildOptions_Rulesets(t *testing.T) { ` wantSmartRoutingOpts, err := json.UnmarshalExtendedContext[O.Options](box.BaseContext(), []byte(smartRouteJSON)) require.NoError(t, err) - wantAdBlockOpts, err := json.UnmarshalExtendedContext[O.Options](box.BaseContext(), []byte(adBlockJSON)) - require.NoError(t, err) - - buf, err := os.ReadFile("testdata/config.json") - require.NoError(t, err, "read test config file") t.Run("with smart routing", func(t *testing.T) { - tmp := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(tmp, common.ConfigFileName), buf, 0644), "write test config file to temp dir") - - require.NoError(t, settings.InitSettings(tmp)) - t.Cleanup(settings.Reset) - - settings.Set(settings.SmartRoutingKey, true) - options, err := buildOptions(context.Background(), tmp) + cfg := testConfig(t) + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + SmartRouting: cfg.SmartRouting, + } + options, err := buildOptions(boxOptions) require.NoError(t, err) // check rules, rulesets, and outbounds are correctly built into options - assert.True(t, contains(t, options.Route.Rules, wantSmartRoutingOpts.Route.Rules[0]), "missing smart routing rule") - assert.True(t, contains(t, options.Route.RuleSet, wantSmartRoutingOpts.Route.RuleSet[0]), "missing smart routing ruleset") - assert.True(t, contains(t, options.Outbounds, wantSmartRoutingOpts.Outbounds[0]), "missing smart routing outbound") + assert.Subset(t, options.Route.Rules, wantSmartRoutingOpts.Route.Rules, "missing smart routing rule") + assert.Subset(t, options.Route.RuleSet, wantSmartRoutingOpts.Route.RuleSet, "missing smart routing ruleset") + assert.Subset(t, options.Outbounds, wantSmartRoutingOpts.Outbounds, "missing smart routing outbound") }) t.Run("with smart routing and missing outbounds", func(t *testing.T) { - tmp := t.TempDir() - cfg, err := json.UnmarshalExtendedContext[config.Config](box.BaseContext(), buf) - require.NoError(t, err, "parse test config") - require.NotEmpty(t, cfg.ConfigResponse.SmartRouting, "test config missing smart routing rules") - - cfg.ConfigResponse.SmartRouting[0].Outbounds = nil - cfgBuf, err := json.Marshal(cfg) - require.NoError(t, err, "marshal modified test config") - require.NoError( - t, - os.WriteFile(filepath.Join(tmp, common.ConfigFileName), cfgBuf, 0o644), - "write modified config to temp dir", - ) - - require.NoError(t, settings.InitSettings(tmp)) - t.Cleanup(settings.Reset) - - settings.Set(settings.SmartRoutingKey, true) - options, err := buildOptions(context.Background(), tmp) + cfg := testConfig(t) + cfg.SmartRouting[1].Outbounds = nil + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + SmartRouting: cfg.SmartRouting, + } + options, err := buildOptions(boxOptions) require.NoError(t, err) - - assert.Nil( - t, - findOutbound(options.Outbounds, "sr-openai"), - "should not create smart-routing outbound when rule outbounds are missing", - ) - - assert.False( - t, - contains(t, options.Route.RuleSet, wantSmartRoutingOpts.Route.RuleSet[0]), - "should skip smart-routing ruleset when rule outbounds are missing", - ) - hasSmartRoutingRule := slices.ContainsFunc(options.Route.Rules, func(rule O.Rule) bool { - return len(rule.DefaultOptions.RuleSet) == 1 && rule.DefaultOptions.RuleSet[0] == "openai" - }) - assert.False(t, hasSmartRoutingRule, "should skip smart-routing rule when rule outbounds are missing") + // sr-direct rule and ruleset should still be present (category still has outbounds) + assert.Contains(t, options.Route.Rules, wantSmartRoutingOpts.Route.Rules[0], "missing sr-direct rule") + assert.Contains(t, options.Route.RuleSet, wantSmartRoutingOpts.Route.RuleSet[0], "missing sr-direct ruleset") + // openai rule/ruleset and sr-openai outbound should be dropped (outbounds were nilled) + assert.NotContains(t, options.Route.Rules, wantSmartRoutingOpts.Route.Rules[1], "unexpected openai rule") + assert.NotContains(t, options.Route.RuleSet, wantSmartRoutingOpts.Route.RuleSet[1], "unexpected openai ruleset") + assert.NotContains(t, options.Outbounds, wantSmartRoutingOpts.Outbounds[0], "unexpected sr-openai outbound") }) t.Run("with ad block", func(t *testing.T) { - tmp := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(tmp, common.ConfigFileName), buf, 0644), "write test config file to temp dir") - - require.NoError(t, settings.InitSettings(tmp)) - t.Cleanup(settings.Reset) - - settings.Set(settings.AdBlockKey, true) - options, err := buildOptions(context.Background(), tmp) + cfg := testConfig(t) + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + AdBlock: cfg.AdBlock, + } + wantRule, wantRulesets := cfg.AdBlock.ToOptions() + options, err := buildOptions(boxOptions) require.NoError(t, err) // check reject rule and rulesets are correctly built into options - for _, rs := range wantAdBlockOpts.Route.RuleSet { - assert.True(t, contains(t, options.Route.RuleSet, rs), "missing ad block ruleset") - } - - adRule := wantAdBlockOpts.Route.Rules[0] - assert.True(t, contains(t, options.Route.Rules, adRule), "missing ad block rule") + assert.Contains(t, options.Route.Rules, wantRule, "missing ad block rule") + assert.Subset(t, options.Route.RuleSet, wantRulesets, "missing ad block ruleset") }) } func TestBuildOptions_BanditURLOverrides(t *testing.T) { - testOpts, _, err := testBoxOptions("") - require.NoError(t, err) - lanternTags, lanternOuts := filterOutbounds(*testOpts, constant.TypeHTTP) - require.NotEmpty(t, lanternTags, "need at least one HTTP outbound for test") - + cfg := testConfig(t) overrides := map[string]string{ - lanternTags[0]: "https://example.com/callback?token=abc", + cfg.Options.Outbounds[0].Tag: "https://example.com/callback?token=abc", } - cfg := config.Config{ - ConfigResponse: LC.ConfigResponse{ - Options: O.Options{Outbounds: lanternOuts}, - BanditURLOverrides: overrides, - }, + boxOptions := BoxOptions{ + BasePath: t.TempDir(), + Options: cfg.Options, + BanditURLOverrides: overrides, } - - path := t.TempDir() - testOptsToFile(t, cfg, filepath.Join(path, common.ConfigFileName)) - - opts, err := buildOptions(context.Background(), path) + opts, err := buildOptions(boxOptions) require.NoError(t, err) - out := findOutbound(opts.Outbounds, autoLanternTag) - require.NotNil(t, out, "auto-lantern outbound not found") + out := findOutbound(opts.Outbounds, AutoSelectTag) + require.NotNil(t, out, "missing auto-select outbound") - mutOpts, ok := out.Options.(*lbO.MutableURLTestOutboundOptions) - require.True(t, ok, "auto-lantern outbound should be MutableURLTestOutboundOptions") + require.IsType(t, &lbO.MutableURLTestOutboundOptions{}, out.Options, "auto-select outbound options should be MutableURLTestOutboundOptions") + mutOpts := out.Options.(*lbO.MutableURLTestOutboundOptions) assert.Equal(t, overrides, mutOpts.URLOverrides, "URLOverrides should be wired from config") } @@ -330,34 +209,33 @@ func findOutbound(outs []O.Outbound, tag string) *O.Outbound { return &outs[idx] } -func testOptsToFile[T any](t *testing.T, opts T, path string) { - buf, err := json.Marshal(opts) - require.NoError(t, err, "marshal options") - require.NoError(t, os.WriteFile(path, buf, 0644), "write options to file") +func testConfig(t *testing.T) config.Config { + buf, err := os.ReadFile("testdata/config.json") + require.NoError(t, err, "read test config file") + + cfg, err := json.UnmarshalExtendedContext[config.Config](box.BaseContext(), buf) + require.NoError(t, err, "unmarshal test config") + return cfg } -func testBoxOptions(tmpPath string) (*O.Options, string, error) { - content, err := os.ReadFile("testdata/boxopts.json") - if err != nil { - return nil, "", err +func testBoxOptions(t *testing.T) (O.Options, []string) { + cfg := testConfig(t) + var tags []string + for _, o := range cfg.Options.Outbounds { + tags = append(tags, o.Tag) } - opts, err := json.UnmarshalExtendedContext[O.Options](box.BaseContext(), content) - if err != nil { - return nil, "", err + for _, ep := range cfg.Options.Endpoints { + tags = append(tags, ep.Tag) } - - opts.Experimental.CacheFile.Path = filepath.Join(tmpPath, cacheFileName) - opts.Experimental.CacheFile.CacheID = cacheID - buf, _ := json.Marshal(opts) - return &opts, string(buf), nil + return cfg.Options, tags } func TestKernelBelow(t *testing.T) { tests := []struct { - name string - v string - min string - want bool + name string + v string + min string + want bool }{ {"below major", "4.19.0", "5.10", true}, {"below minor", "5.4.0", "5.10", true}, @@ -378,4 +256,3 @@ func TestKernelBelow(t *testing.T) { }) } } - diff --git a/vpn/clash.go b/vpn/clash.go new file mode 100644 index 00000000..ff10b9b0 --- /dev/null +++ b/vpn/clash.go @@ -0,0 +1,126 @@ +package vpn + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "slices" + "strings" + "sync" + + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +var _ adapter.ClashServer = (*clashServer)(nil) + +// clashServer is a stub adapter.ClashServer: it exposes the traffic manager +// and URL-test history hook the rest of the tunnel depends on, but does not +// run the Clash HTTP API. Start and Close are no-ops because there are no +// owned resources beyond what's wired in via the sing-box service context. +type clashServer struct { + ctx context.Context + dnsRouter adapter.DNSRouter + outbound adapter.OutboundManager + endpoint adapter.EndpointManager + + urlTestHistory adapter.URLTestHistoryStorage + trafficManager *trafficontrol.Manager + + mode string + modeList []string + + mu sync.RWMutex +} + +func newClashServer(ctx context.Context, _ log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { + modeList := options.ModeList + initial := options.DefaultMode + if len(modeList) == 0 { + return nil, errors.New("mode list is empty") + } + if initial == "" { + initial = modeList[0] + } else if !slices.Contains(modeList, initial) { + return nil, fmt.Errorf("initial mode %q is not in mode list", initial) + } + + return &clashServer{ + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), + endpoint: service.FromContext[adapter.EndpointManager](ctx), + urlTestHistory: service.FromContext[adapter.URLTestHistoryStorage](ctx), + trafficManager: trafficontrol.NewManager(), + modeList: modeList, + mode: initial, + }, nil +} + +func (s *clashServer) SetMode(mode string) error { + s.mu.Lock() + defer s.mu.Unlock() + i := slices.IndexFunc(s.modeList, func(m string) bool { + return strings.EqualFold(m, mode) + }) + if i == -1 { + return fmt.Errorf("mode %q is not in mode list", mode) + } + mode = s.modeList[i] + if s.mode != mode { + slog.Info("Switching mode", "from", s.mode, "to", mode) + s.mode = mode + s.dnsRouter.ClearCache() + } + return nil +} + +func (s *clashServer) Mode() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.mode +} + +func (s *clashServer) ModeList() []string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.modeList +} + +func (s *clashServer) Start(stage adapter.StartStage) error { + return nil +} + +func (s *clashServer) Close() error { + return nil +} + +func (s *clashServer) HistoryStorage() adapter.URLTestHistoryStorage { + s.mu.RLock() + defer s.mu.RUnlock() + return s.urlTestHistory +} + +func (s *clashServer) TrafficManager() *trafficontrol.Manager { + s.mu.RLock() + defer s.mu.RUnlock() + return s.trafficManager +} + +func (s *clashServer) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + return trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outbound, matchedRule, matchOutbound) +} + +func (s *clashServer) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn { + return trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outbound, matchedRule, matchOutbound) +} + +func (s *clashServer) Name() string { + return "clash" +} diff --git a/vpn/dnsoptions_test.go b/vpn/dnsoptions_test.go index 06b49f1b..9f5866b8 100644 --- a/vpn/dnsoptions_test.go +++ b/vpn/dnsoptions_test.go @@ -3,6 +3,8 @@ package vpn import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/getlantern/radiance/common/settings" ) @@ -62,9 +64,7 @@ func TestNormalizeLocale(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := normalizeLocale(tt.locale) - if result != tt.expected { - t.Errorf("normalizeLocale(%q) = %q, expected %q", tt.locale, result, tt.expected) - } + assert.Equalf(t, tt.expected, result, "normalizeLocale(%q) should return %q", tt.locale, tt.expected) }) } } @@ -138,43 +138,7 @@ func TestLocalDNSIP(t *testing.T) { settings.Set(settings.LocaleKey, tt.locale) result := localDNSIP() - if result != tt.expected { - t.Errorf("localDNSIP() with locale %q = %q, expected %q", tt.locale, result, tt.expected) - } + assert.Equalf(t, tt.expected, result, "localDNSIP() with locale %q should return %q", tt.locale, tt.expected) }) } } -func TestBuildDNSRules(t *testing.T) { - rules := buildDNSRules() - - if len(rules) != 1 { - t.Fatalf("expected 1 DNS rule, got %d", len(rules)) - } - - rule := rules[0] - - if rule.Type != "default" { - t.Errorf("expected rule type 'default', got %q", rule.Type) - } - - if rule.DefaultOptions.DNSRuleAction.Action != "route" { - t.Errorf("expected action 'route', got %q", rule.DefaultOptions.DNSRuleAction.Action) - } - - if rule.DefaultOptions.DNSRuleAction.RouteOptions.Server != "dns_fakeip" { - t.Errorf("expected server 'dns_fakeip', got %q", rule.DefaultOptions.DNSRuleAction.RouteOptions.Server) - } - - queryTypes := rule.DefaultOptions.RawDefaultDNSRule.QueryType - if len(queryTypes) != 2 { - t.Fatalf("expected 2 query types, got %d", len(queryTypes)) - } - - if queryTypes[0] != 1 { // dns.TypeA - t.Errorf("expected first query type to be TypeA (1), got %d", queryTypes[0]) - } - - if queryTypes[1] != 28 { // dns.TypeAAAA - t.Errorf("expected second query type to be TypeAAAA (28), got %d", queryTypes[1]) - } -} diff --git a/vpn/ipc.go b/vpn/ipc.go deleted file mode 100644 index 795cbd88..00000000 --- a/vpn/ipc.go +++ /dev/null @@ -1,45 +0,0 @@ -package vpn - -import ( - "context" - "fmt" - "log/slog" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn/ipc" - "github.com/getlantern/radiance/vpn/rvpn" -) - -// InitIPC initializes and returns a started IPC server. -func InitIPC(dataPath, logPath, logLevel string, platformIfce rvpn.PlatformInterface) (*ipc.Server, error) { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "initIPC", - trace.WithAttributes(attribute.String("dataPath", dataPath)), - ) - defer span.End() - - span.AddEvent("initializing IPC server") - - if err := common.InitReadOnly(dataPath, logPath, logLevel); err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("init common ro: %w", err)) - } - if path := settings.GetString(settings.DataPathKey); path != "" && path != dataPath { - dataPath = path - } - - server := ipc.NewServer(NewTunnelService(dataPath, slog.Default().With("service", "ipc"), platformIfce)) - slog.Debug("starting IPC server") - if err := server.Start(); err != nil { - slog.Error("failed to start IPC server", "error", err) - return nil, traces.RecordError(ctx, fmt.Errorf("start IPC server: %w", err)) - } - - return server, nil -} diff --git a/vpn/ipc/clash_mode.go b/vpn/ipc/clash_mode.go deleted file mode 100644 index ec0f9e97..00000000 --- a/vpn/ipc/clash_mode.go +++ /dev/null @@ -1,64 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "log/slog" - "net/http" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/internal" -) - -type m struct { - Mode string `json:"mode"` -} - -// GetClashMode retrieves the current mode from the Clash server. -func GetClashMode(ctx context.Context) (string, error) { - res, err := sendRequest[m](ctx, "GET", clashModeEndpoint, nil) - if err != nil { - return "", err - } - return res.Mode, nil -} - -// SetClashMode sets the mode of the Clash server. -func SetClashMode(ctx context.Context, mode string) error { - _, err := sendRequest[empty](ctx, "POST", clashModeEndpoint, m{Mode: mode}) - return err -} - -// clashModeHandler handles HTTP requests for getting or setting the Clash server mode. -func (s *Server) clashModeHandler(w http.ResponseWriter, req *http.Request) { - span := trace.SpanFromContext(req.Context()) - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - switch req.Method { - case "GET": - mode := cs.Mode() - span.SetAttributes(attribute.String("mode", mode)) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(m{Mode: mode}); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case "POST": - var mode m - if err := json.NewDecoder(req.Body).Decode(&mode); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - span.SetAttributes(attribute.String("mode", mode.Mode)) - slog.Log(nil, internal.LevelTrace, "Setting clash mode", "mode", mode.Mode) - cs.SetMode(mode.Mode) - w.WriteHeader(http.StatusOK) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} diff --git a/vpn/ipc/connections.go b/vpn/ipc/connections.go deleted file mode 100644 index 125c8017..00000000 --- a/vpn/ipc/connections.go +++ /dev/null @@ -1,126 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "net/http" - runtimeDebug "runtime/debug" - "time" - - "github.com/gofrs/uuid/v5" - "github.com/sagernet/sing-box/common/conntrack" - "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" -) - -// CloseConnections closes connections by their IDs. If connIDs is empty, all connections will be closed. -func CloseConnections(ctx context.Context, connIDs []string) error { - _, err := sendRequest[empty](ctx, "POST", closeConnectionsEndpoint, connIDs) - return err -} - -func (s *Server) closeConnectionHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var cids []string - err := json.NewDecoder(r.Body).Decode(&cids) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if len(cids) > 0 { - tm := s.service.ClashServer().TrafficManager() - for _, cid := range cids { - targetConn := tm.Connection(uuid.FromStringOrNil(cid)) - if targetConn == nil { - continue - } - targetConn.Close() - } - } else { - conntrack.Close() - } - go func() { - time.Sleep(time.Second) - runtimeDebug.FreeOSMemory() - }() - w.WriteHeader(http.StatusOK) -} - -// GetConnections retrieves the list of current and recently closed connections. -func GetConnections(ctx context.Context) ([]Connection, error) { - return sendRequest[[]Connection](ctx, "GET", connectionsEndpoint, nil) -} - -func (s *Server) connectionsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - w.Header().Set("Content-Type", "application/json") - tm := s.service.ClashServer().TrafficManager() - activeConns := tm.Connections() - closedConns := tm.ClosedConnections() - connections := make([]Connection, 0, len(activeConns)+len(closedConns)) - for _, connection := range activeConns { - connections = append(connections, newConnection(connection)) - } - for _, connection := range closedConns { - connections = append(connections, newConnection(connection)) - } - if err := json.NewEncoder(w).Encode(connections); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// Connection represents a network connection with relevant metadata. -type Connection struct { - ID string - Inbound string - IPVersion int - Network string - Source string - Destination string - Domain string - Protocol string - FromOutbound string - CreatedAt int64 - ClosedAt int64 - Uplink int64 - Downlink int64 - Rule string - Outbound string - ChainList []string -} - -func newConnection(metadata trafficontrol.TrackerMetadata) Connection { - var rule string - if metadata.Rule != nil { - rule = metadata.Rule.String() + " => " + metadata.Rule.Action().String() - } - var closedAt int64 - if !metadata.ClosedAt.IsZero() { - closedAt = metadata.ClosedAt.UnixMilli() - } - md := metadata.Metadata - return Connection{ - ID: metadata.ID.String(), - Inbound: md.InboundType + "/" + md.Inbound, - IPVersion: int(md.IPVersion), - Network: md.Network, - Source: md.Source.String(), - Destination: md.Destination.String(), - Domain: md.Domain, - Protocol: md.Protocol, - FromOutbound: md.Outbound, - CreatedAt: metadata.CreatedAt.UnixMilli(), - ClosedAt: closedAt, - Uplink: metadata.Upload.Load(), - Downlink: metadata.Download.Load(), - Rule: rule, - Outbound: metadata.OutboundType + "/" + metadata.Outbound, - ChainList: metadata.Chain, - } -} diff --git a/vpn/ipc/endpoints.go b/vpn/ipc/endpoints.go deleted file mode 100644 index b55c43d2..00000000 --- a/vpn/ipc/endpoints.go +++ /dev/null @@ -1,19 +0,0 @@ -package ipc - -const ( - statusEndpoint = "/status" - metricsEndpoint = "/metrics" - startServiceEndpoint = "/service/start" - stopServiceEndpoint = "/service/stop" - restartServiceEndpoint = "/service/restart" - groupsEndpoint = "/groups" - selectEndpoint = "/outbound/select" - activeEndpoint = "/outbound/active" - updateOutboundsEndpoint = "/outbound/update" - addOutboundsEndpoint = "/outbound/add" - removeOutboundsEndpoint = "/outbound/remove" - clashModeEndpoint = "/clash/mode" - connectionsEndpoint = "/connections" - closeConnectionsEndpoint = "/connections/close" - setSettingsPathEndpoint = "/set" -) diff --git a/vpn/ipc/group.go b/vpn/ipc/group.go deleted file mode 100644 index 48ede66a..00000000 --- a/vpn/ipc/group.go +++ /dev/null @@ -1,83 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "net/http" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing/service" -) - -// GetGroups retrieves the list of group outbounds. -func GetGroups(ctx context.Context) ([]OutboundGroup, error) { - return sendRequest[[]OutboundGroup](ctx, "GET", groupsEndpoint, nil) -} - -func (s *Server) groupHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - groups, err := getGroups(s.service.Ctx()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(groups); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// OutboundGroup represents a group of outbounds. -type OutboundGroup struct { - Tag string - Type string - Selected string - Outbounds []Outbounds -} - -// Outbounds represents outbounds within a group. -type Outbounds struct { - Tag string - Type string -} - -func getGroups(ctx context.Context) ([]OutboundGroup, error) { - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) - if outboundMgr == nil { - return nil, errors.New("outbound manager not found") - } - outbounds := outboundMgr.Outbounds() - var iGroups []adapter.OutboundGroup - for _, it := range outbounds { - if group, isGroup := it.(adapter.OutboundGroup); isGroup { - iGroups = append(iGroups, group) - } - } - var groups []OutboundGroup - for _, iGroup := range iGroups { - group := OutboundGroup{ - Tag: iGroup.Tag(), - Type: iGroup.Type(), - Selected: iGroup.Now(), - } - for _, itemTag := range iGroup.All() { - itemOutbound, isLoaded := outboundMgr.Outbound(itemTag) - if !isLoaded { - continue - } - - item := Outbounds{ - Tag: itemTag, - Type: itemOutbound.Type(), - } - group.Outbounds = append(group.Outbounds, item) - } - groups = append(groups, group) - } - return groups, nil -} diff --git a/vpn/ipc/http.go b/vpn/ipc/http.go deleted file mode 100644 index fb91b4ec..00000000 --- a/vpn/ipc/http.go +++ /dev/null @@ -1,79 +0,0 @@ -package ipc - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - - box "github.com/getlantern/lantern-box" - singjson "github.com/sagernet/sing/common/json" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - - "github.com/getlantern/radiance/traces" -) - -const tracerName = "github.com/getlantern/radiance/vpn/ipc" - -// empty is a placeholder type for requests that do not expect a response body. -type empty struct{} - -// sendRequest sends an HTTP request to the specified endpoint with the given method and data. -func sendRequest[T any](ctx context.Context, method, endpoint string, data any) (T, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "vpn.ipc", - trace.WithAttributes(attribute.String("endpoint", endpoint)), - ) - defer span.End() - - // Use sing-box's context-aware JSON so custom outbound types - // (e.g., samizdat Options) are serialized with their type-specific - // fields. Standard json.Marshal loses typed Options on any interface. - buf, err := singjson.MarshalContext(box.BaseContext(), data) - var res T - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("failed to marshal payload: %w", err)) - } - req, err := http.NewRequestWithContext(ctx, method, apiURL+endpoint, bytes.NewReader(buf)) - if err != nil { - return res, err - } - client := &http.Client{ - Transport: &http.Transport{ - DialContext: dialContext, - }, - } - resp, err := client.Do(req) - if errors.Is(err, os.ErrNotExist) { - err = ErrIPCNotRunning - } - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("request failed: %w", err)) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - return res, traces.RecordError(ctx, readErrorResponse(resp)) - } - if _, ok := any(&res).(*empty); ok { - return res, nil - } - - err = json.NewDecoder(resp.Body).Decode(&res) - if err != nil { - return res, traces.RecordError(ctx, fmt.Errorf("failed to decode response: %w", err)) - } - return res, nil -} - -func readErrorResponse(resp *http.Response) error { - buf, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read error response body: %w, status: %s", err, resp.Status) - } - return fmt.Errorf("%s: %s", resp.Status, buf) -} diff --git a/vpn/ipc/outbound.go b/vpn/ipc/outbound.go deleted file mode 100644 index d9eb947c..00000000 --- a/vpn/ipc/outbound.go +++ /dev/null @@ -1,291 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - runtimeDebug "runtime/debug" - "time" - - box "github.com/getlantern/lantern-box" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/conntrack" - singjson "github.com/sagernet/sing/common/json" - "github.com/sagernet/sing/service" - - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/servers" -) - -const maxIPCBodySize = 10 << 20 // 10 MB - -type selection struct { - GroupTag string `json:"groupTag"` - OutboundTag string `json:"outboundTag"` -} - -// SelectOutbound selects an outbound within a group. -func SelectOutbound(ctx context.Context, groupTag, outboundTag string) error { - _, err := sendRequest[empty](ctx, "POST", selectEndpoint, selection{groupTag, outboundTag}) - return err -} - -func (s *Server) selectHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var p selection - err := json.NewDecoder(r.Body).Decode(&p) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - defer func() { - if r := recover(); r != nil { - http.Error(w, fmt.Sprint(r), http.StatusInternalServerError) - } - }() - slog.Log(nil, internal.LevelTrace, "selecting outbound", "group", p.GroupTag, "outbound", p.OutboundTag) - outbound, err := getGroupOutbound(s.service.Ctx(), p.GroupTag) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - selector, isSelector := outbound.(Selector) - if !isSelector { - http.Error(w, fmt.Sprintf("outbound %q is not a selector", p.GroupTag), http.StatusBadRequest) - return - } - slog.Log(nil, internal.LevelTrace, "setting outbound", "outbound", p.OutboundTag) - if !selector.SelectOutbound(p.OutboundTag) { - http.Error(w, fmt.Sprintf("outbound %q not found in group", p.OutboundTag), http.StatusBadRequest) - return - } - cs := s.service.ClashServer() - if mode := cs.Mode(); mode != p.GroupTag { - slog.Log(nil, internal.LevelDebug, "changing clash mode", "new", p.GroupTag, "old", mode) - s.service.ClashServer().SetMode(p.GroupTag) - conntrack.Close() - go func() { - time.Sleep(time.Second) - runtimeDebug.FreeOSMemory() - }() - } - w.WriteHeader(http.StatusOK) -} - -// Selector is helper interface to check if an outbound is a selector or wrapper of selector. -type Selector interface { - adapter.OutboundGroup - SelectOutbound(tag string) bool -} - -// GetSelected retrieves the currently selected outbound and its group. -func GetSelected(ctx context.Context) (group, tag string, err error) { - res, err := sendRequest[selection](ctx, "GET", selectEndpoint, nil) - if err != nil { - return "", "", err - } - return res.GroupTag, res.OutboundTag, nil -} - -func (s *Server) selectedHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - mode := cs.Mode() - selector, err := getGroupOutbound(s.service.Ctx(), mode) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - res := selection{ - GroupTag: mode, - OutboundTag: selector.Now(), - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(res); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -// GetActiveOutbound retrieves the outbound that is actively being used, resolving nested groups -// if necessary. -func GetActiveOutbound(ctx context.Context) (group, tag string, err error) { - res, err := sendRequest[selection](ctx, "GET", activeEndpoint, nil) - if err != nil { - return "", "", err - } - return res.GroupTag, res.OutboundTag, nil -} - -func (s *Server) activeOutboundHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - cs := s.service.ClashServer() - mode := cs.Mode() - group, err := getGroupOutbound(s.service.Ctx(), mode) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - tag := group.Now() - // if the selected outbound is also a group, retrieve its selected outbound - // continue until we reach a non-group outbound - for { - group, err = getGroupOutbound(s.service.Ctx(), tag) - if err != nil { - break - } - tag = group.Now() - } - if tag == "" { - tag = "unavailable" - } - res := selection{ - GroupTag: mode, - OutboundTag: tag, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(res); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -func getGroupOutbound(ctx context.Context, tag string) (adapter.OutboundGroup, error) { - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) - if outboundMgr == nil { - return nil, errors.New("outbound manager not found") - } - - outbound, loaded := outboundMgr.Outbound(tag) - if !loaded { - return nil, fmt.Errorf("group not found: %s", tag) - } - group, isGroup := outbound.(adapter.OutboundGroup) - if !isGroup { - return nil, fmt.Errorf("outbound is not a group: %s", tag) - } - return group, nil -} - -func UpdateOutbounds(ctx context.Context, servers servers.Servers) error { - _, err := sendRequest[empty](ctx, "POST", updateOutboundsEndpoint, servers) - return err -} - -func (s *Server) updateOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - // Use sing-box's context-aware JSON decoder so custom outbound types - // (e.g., samizdat) are deserialized into their typed Options structs - // instead of generic map[string]any. Without this, fields like - // public_key are lost during the IPC round-trip. - body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxIPCBodySize)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - data, err := singjson.UnmarshalExtendedContext[servers.Servers](box.BaseContext(), body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - slog.Debug("Updating outbounds") - w.WriteHeader(http.StatusAccepted) - go func() { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in UpdateOutbounds", "recover", r, "stack", string(runtimeDebug.Stack())) - } - }() - if err := s.service.UpdateOutbounds(data); err != nil { - slog.Error("Failed to update outbounds", "error", err) - } - }() -} - -type newOutbounds struct { - Group string `json:"group"` - Servers servers.Options `json:"servers"` -} - -func AddOutbounds(ctx context.Context, group string, servers servers.Options) error { - _, err := sendRequest[empty](ctx, "POST", addOutboundsEndpoint, newOutbounds{Group: group, Servers: servers}) - return err -} - -func (s *Server) addOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxIPCBodySize)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - data, err := singjson.UnmarshalExtendedContext[newOutbounds](box.BaseContext(), body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - slog.Debug("Adding outbounds", "group", data.Group) - w.WriteHeader(http.StatusAccepted) - go func() { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in AddOutbounds", "recover", r, "stack", string(runtimeDebug.Stack())) - } - }() - if err := s.service.AddOutbounds(data.Group, data.Servers); err != nil { - slog.Error("Failed to add outbounds", "error", err) - } - }() -} - -type outboundsToRemove struct { - Group string `json:"group"` - Tags []string `json:"tags"` -} - -func RemoveOutbounds(ctx context.Context, group string, tags []string) error { - _, err := sendRequest[empty](ctx, "POST", removeOutboundsEndpoint, outboundsToRemove{Group: group, Tags: tags}) - return err -} - -func (s *Server) removeOutboundsHandler(w http.ResponseWriter, r *http.Request) { - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusServiceUnavailable) - return - } - var data outboundsToRemove - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - w.WriteHeader(http.StatusAccepted) - go func() { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in RemoveOutbounds", "recover", r, "stack", string(runtimeDebug.Stack())) - } - }() - if err := s.service.RemoveOutbounds(data.Group, data.Tags); err != nil { - slog.Error("Failed to remove outbounds", "error", err) - } - }() -} diff --git a/vpn/ipc/outbound_test.go b/vpn/ipc/outbound_test.go deleted file mode 100644 index 175f3396..00000000 --- a/vpn/ipc/outbound_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package ipc - -import ( - "encoding/json" - "testing" - - box "github.com/getlantern/lantern-box" - LO "github.com/getlantern/lantern-box/option" - O "github.com/sagernet/sing-box/option" - singjson "github.com/sagernet/sing/common/json" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/getlantern/radiance/servers" -) - -// TestSamizdatOptionsRoundTrip verifies that samizdat outbound options -// (specifically public_key) survive JSON serialization/deserialization -// through the IPC path. This was the root cause of the "public_key must -// be 64 hex characters (32 bytes), got len=0" bug — standard encoding/json -// doesn't preserve typed Options on option.Outbound's any interface. -func TestSamizdatOptionsRoundTrip(t *testing.T) { - const testPubKey = "20ebb18d5fdf9bff27fe32ef9501035d8f0bb8dfb481a0a2363181560e0e8115" - const testShortID = "3b1a8fc7f1edf914" - - original := servers.Servers{ - "lantern": servers.Options{ - Outbounds: []O.Outbound{ - { - Type: "samizdat", - Tag: "samizdat-out-test-route", - Options: &LO.SamizdatOutboundOptions{ - ServerOptions: O.ServerOptions{ - Server: "1.2.3.4", - ServerPort: 443, - }, - PublicKey: testPubKey, - ShortID: testShortID, - ServerName: "example.com", - }, - }, - }, - }, - } - - // Demonstrate the bug: standard json.Marshal/Unmarshal loses the typed Options - t.Run("standard_json_loses_public_key", func(t *testing.T) { - buf, err := json.Marshal(original) - require.NoError(t, err) - - var decoded servers.Servers - err = json.Unmarshal(buf, &decoded) - require.NoError(t, err) - - outbounds := decoded["lantern"].Outbounds - require.Len(t, outbounds, 1) - - // Standard json deserializes Options as map[string]any, not *SamizdatOutboundOptions - _, ok := outbounds[0].Options.(*LO.SamizdatOutboundOptions) - assert.False(t, ok, "standard json should NOT preserve typed Options") - }) - - // Verify the fix: sing-box context-aware JSON preserves typed Options - t.Run("singbox_json_preserves_public_key", func(t *testing.T) { - ctx := box.BaseContext() - - buf, err := singjson.MarshalContext(ctx, original) - require.NoError(t, err) - - // Verify public_key is in the serialized JSON - assert.Contains(t, string(buf), testPubKey, "serialized JSON should contain public_key") - - decoded, err := singjson.UnmarshalExtendedContext[servers.Servers](ctx, buf) - require.NoError(t, err) - - outbounds := decoded["lantern"].Outbounds - require.Len(t, outbounds, 1) - - samOpts, ok := outbounds[0].Options.(*LO.SamizdatOutboundOptions) - require.True(t, ok, "sing-box json should preserve typed Options") - assert.Equal(t, testPubKey, samOpts.PublicKey, "public_key should survive round-trip") - assert.Equal(t, testShortID, samOpts.ShortID, "short_id should survive round-trip") - assert.Equal(t, "example.com", samOpts.ServerName, "server_name should survive round-trip") - }) -} diff --git a/vpn/ipc/server.go b/vpn/ipc/server.go deleted file mode 100644 index 187aa57d..00000000 --- a/vpn/ipc/server.go +++ /dev/null @@ -1,263 +0,0 @@ -// Package ipc implements the IPC server for communicating between the client and the VPN service. -// It provides HTTP endpoints for retrieving statistics, managing groups, selecting outbounds, -// changing modes, and closing connections. -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net" - "net/http" - "sync/atomic" - "time" - - "github.com/go-chi/chi/v5" - "github.com/sagernet/sing-box/experimental/clashapi" - "go.opentelemetry.io/otel" - - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/servers" -) - -var ( - ErrServiceIsNotReady = errors.New("service is not ready") - ErrIPCNotRunning = errors.New("IPC not running") -) - -// Service defines the interface that the IPC server uses to interact with the underlying VPN service. -type Service interface { - Ctx() context.Context - Status() VPNStatus - Start(ctx context.Context, options string) error - Restart(ctx context.Context, options string) error - Close() error - ClashServer() *clashapi.Server - UpdateOutbounds(options servers.Servers) error - AddOutbounds(group string, options servers.Options) error - RemoveOutbounds(group string, tags []string) error -} - -// Server represents the IPC server that communicates over a Unix domain socket for Unix-like -// systems, and a named pipe for Windows. -type Server struct { - svr *http.Server - service Service - router chi.Router - vpnStatus atomic.Value // string - closed atomic.Bool -} - -// StatusUpdateEvent is emitted when the VPN status changes. -type StatusUpdateEvent struct { - events.Event - Status VPNStatus - Error error -} - -type VPNStatus string - -// Possible VPN statuses -const ( - Connected VPNStatus = "connected" - Disconnected VPNStatus = "disconnected" - Connecting VPNStatus = "connecting" - Disconnecting VPNStatus = "disconnecting" - ErrorStatus VPNStatus = "error" -) - -func (vpn *VPNStatus) String() string { - return string(*vpn) -} - -// NewServer creates a new Server instance with the provided Service. -func NewServer(service Service) *Server { - s := &Server{ - service: service, - router: chi.NewMux(), - } - s.vpnStatus.Store(Disconnected) - s.router.Use(log, tracer) - - // Only add auth middleware if not running on mobile, since mobile platforms have their own - // sandboxing and permission models. - addAuth := !common.IsMobile() && !_testing - if addAuth { - s.router.Use(authPeer) - } - - s.router.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - s.router.Get(statusEndpoint, s.statusHandler) - s.router.Get(metricsEndpoint, s.metricsHandler) - s.router.Get(groupsEndpoint, s.groupHandler) - s.router.Get(connectionsEndpoint, s.connectionsHandler) - s.router.Get(selectEndpoint, s.selectedHandler) - s.router.Get(activeEndpoint, s.activeOutboundHandler) - s.router.Post(selectEndpoint, s.selectHandler) - s.router.Get(clashModeEndpoint, s.clashModeHandler) - s.router.Post(clashModeEndpoint, s.clashModeHandler) - s.router.Post(startServiceEndpoint, s.startServiceHandler) - s.router.Post(stopServiceEndpoint, s.stopServiceHandler) - s.router.Post(restartServiceEndpoint, s.restartServiceHandler) - s.router.Post(updateOutboundsEndpoint, s.updateOutboundsHandler) - s.router.Post(addOutboundsEndpoint, s.addOutboundsHandler) - s.router.Post(removeOutboundsEndpoint, s.removeOutboundsHandler) - s.router.Post(closeConnectionsEndpoint, s.closeConnectionHandler) - - svr := &http.Server{ - Handler: s.router, - ReadTimeout: time.Second * 5, - WriteTimeout: time.Second * 5, - } - if addAuth { - svr.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - peer, err := getConnPeer(c) - if err != nil { - slog.Error("Failed to get peer credentials", "error", err) - } - return contextWithUsr(ctx, peer) - } - } - s.svr = svr - return s -} - -// Start begins listening for incoming IPC requests. -func (s *Server) Start() error { - if s.closed.Load() { - return errors.New("IPC server is closed") - } - l, err := listen() - if err != nil { - return fmt.Errorf("IPC server: listen: %w", err) - } - go func() { - slog.Info("IPC server started", "address", l.Addr().String()) - err := s.svr.Serve(l) - if err != nil && err != http.ErrServerClosed { - slog.Error("IPC server", "error", err) - } - s.closed.Store(true) - if s.service.Status() != Disconnected { - slog.Warn("IPC server stopped unexpectedly, closing service") - s.service.Close() - s.setVPNStatus(ErrorStatus, errors.New("IPC server stopped unexpectedly")) - } - }() - - return nil -} - -// Close shuts down the IPC server. -func (s *Server) Close() error { - if s.closed.Swap(true) { - return nil - } - defer s.service.Close() - - slog.Info("Closing IPC server") - return s.svr.Close() -} - -func (s *Server) IsClosed() bool { - return s.closed.Load() -} - -type opts struct { - Options string `json:"options"` -} - -// StartService sends a request to start the service -func StartService(ctx context.Context, options string) error { - _, err := sendRequest[empty](ctx, "POST", startServiceEndpoint, opts{Options: options}) - return err -} - -func (s *Server) startServiceHandler(w http.ResponseWriter, r *http.Request) { - ctx, span := otel.Tracer(tracerName).Start(r.Context(), "ipc.Server.StartService") - defer span.End() - switch s.service.Status() { - case Disconnected: - // proceed to start - case Connected: - w.WriteHeader(http.StatusOK) - return - case Disconnecting: - http.Error(w, "service is disconnecting, please wait", http.StatusConflict) - return - default: - http.Error(w, "service is already starting", http.StatusConflict) - return - } - var p opts - if err := json.NewDecoder(r.Body).Decode(&p); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - s.setVPNStatus(Connecting, nil) - if err := s.service.Start(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - s.setVPNStatus(Connected, nil) - w.WriteHeader(http.StatusOK) -} - -// StopService sends a request to stop the service (IPC server stays up) -func StopService(ctx context.Context) error { - _, err := sendRequest[empty](ctx, "POST", stopServiceEndpoint, nil) - return err -} - -func (s *Server) stopServiceHandler(w http.ResponseWriter, r *http.Request) { - slog.Debug("Received request to stop service via IPC") - s.setVPNStatus(Disconnecting, nil) - if err := s.service.Close(); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.setVPNStatus(Disconnected, nil) - w.WriteHeader(http.StatusOK) -} - -func RestartService(ctx context.Context, options string) error { - _, err := sendRequest[empty](ctx, "POST", restartServiceEndpoint, opts{Options: options}) - return err -} - -func (s *Server) restartServiceHandler(w http.ResponseWriter, r *http.Request) { - ctx, span := otel.Tracer(tracerName).Start(r.Context(), "ipc.Server.restartServiceHandler") - defer span.End() - - if s.service.Status() != Connected { - http.Error(w, ErrServiceIsNotReady.Error(), http.StatusInternalServerError) - return - } - var p opts - if err := json.NewDecoder(r.Body).Decode(&p); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - s.setVPNStatus(Disconnected, nil) - if err := s.service.Restart(ctx, p.Options); err != nil { - s.setVPNStatus(ErrorStatus, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - s.setVPNStatus(Connected, nil) - w.WriteHeader(http.StatusOK) -} - -func (s *Server) setVPNStatus(status VPNStatus, err error) { - s.vpnStatus.Store(status) - events.Emit(StatusUpdateEvent{Status: status, Error: err}) -} diff --git a/vpn/ipc/socket.go b/vpn/ipc/socket.go deleted file mode 100644 index e04ae70a..00000000 --- a/vpn/ipc/socket.go +++ /dev/null @@ -1,48 +0,0 @@ -//go:build !android && !ios && !windows - -package ipc - -import ( - "fmt" - "os" - "os/user" - "runtime" - "strconv" -) - -// use a var so it can be overridden in tests -var _socketPath = "/var/run/lantern/lanternd.sock" - -// setSocketPathForTesting is only used for testing. -func setSocketPathForTesting(path string) { - _socketPath = path -} - -func socketPath() string { - return _socketPath -} - -func setPermissions() error { - path := socketPath() - if runtime.GOOS == "linux" { - // we'll check if user is sudoer to restrict access - return os.Chmod(socketPath(), 0666) - } - - // chown admin group and let the OS restrict access - group, err := user.LookupGroup("admin") - if err != nil { - return fmt.Errorf("lookup admin group: %w", err) - } - gid, err := strconv.Atoi(group.Gid) - if err != nil { - return fmt.Errorf("convert admin gid %s: %w", group.Gid, err) - } - if err := os.Chown(path, 0, gid); err != nil { - return fmt.Errorf("chown %s: %w", path, err) - } - if err := os.Chmod(path, 0660); err != nil { - return fmt.Errorf("chmod %s: %w", path, err) - } - return nil -} diff --git a/vpn/ipc/status.go b/vpn/ipc/status.go deleted file mode 100644 index aef35029..00000000 --- a/vpn/ipc/status.go +++ /dev/null @@ -1,99 +0,0 @@ -package ipc - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "runtime" - - "github.com/sagernet/sing-box/common/conntrack" - "github.com/sagernet/sing/common/memory" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -// Metrics represents the runtime metrics of the service. -type Metrics struct { - Memory uint64 - Goroutines int - Connections int - - // UplinkTotal and DownlinkTotal are only available when the service is running and there are - // active connections. - // In bytes. - UplinkTotal int64 - // In bytes. - DownlinkTotal int64 -} - -// GetMetrics retrieves the current runtime metrics of the service. -func GetMetrics(ctx context.Context) (Metrics, error) { - return sendRequest[Metrics](ctx, "GET", metricsEndpoint, nil) -} - -func (s *Server) metricsHandler(w http.ResponseWriter, r *http.Request) { - _, span := otel.Tracer(tracerName).Start(r.Context(), "server.metricsHandler") - defer span.End() - stats := Metrics{ - Memory: memory.Inuse(), - Goroutines: runtime.NumGoroutine(), - Connections: conntrack.Count(), - } - if s.service.Status() == Connected { - up, down := s.service.ClashServer().TrafficManager().Total() - stats.UplinkTotal, stats.DownlinkTotal = up, down - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(stats); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} - -type state struct { - State VPNStatus `json:"state"` -} - -// GetStatus retrieves the current status of the service. -func GetStatus(ctx context.Context) (VPNStatus, error) { - // try to dial first to check if IPC server is even running and avoid waiting for timeout - if canDial, err := tryDial(ctx); !canDial { - return Disconnected, err - } - - res, err := sendRequest[state](ctx, "GET", statusEndpoint, nil) - if errors.Is(err, ErrIPCNotRunning) || errors.Is(err, ErrServiceIsNotReady) { - return Disconnected, nil - } - if err != nil { - return "", fmt.Errorf("error getting status: %w", err) - } - return res.State, nil -} - -func tryDial(ctx context.Context) (bool, error) { - conn, err := dialContext(ctx, "", "") - if err == nil { - conn.Close() - return true, nil - } - if errors.Is(err, os.ErrNotExist) { - return false, nil // IPC server is not running so don't treat it as an error - } - return false, err -} - -func (s *Server) statusHandler(w http.ResponseWriter, r *http.Request) { - span := trace.SpanFromContext(r.Context()) - status := s.service.Status() - span.SetAttributes(attribute.String("status", string(status))) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(state{status}); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} diff --git a/vpn/rvpn/platform.go b/vpn/rvpn/platform.go deleted file mode 100644 index 72275218..00000000 --- a/vpn/rvpn/platform.go +++ /dev/null @@ -1,9 +0,0 @@ -package rvpn - -import "github.com/sagernet/sing-box/experimental/libbox" - -type PlatformInterface interface { - libbox.PlatformInterface - RestartService() error - PostServiceClose() -} diff --git a/vpn/service.go b/vpn/service.go deleted file mode 100644 index 4df2f133..00000000 --- a/vpn/service.go +++ /dev/null @@ -1,223 +0,0 @@ -package vpn - -import ( - "context" - "errors" - "fmt" - "io" - "log/slog" - "os" - "path/filepath" - "runtime" - "sync" - - "github.com/sagernet/sing-box/experimental/clashapi" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" - "github.com/getlantern/radiance/vpn/rvpn" -) - -var _ ipc.Service = (*TunnelService)(nil) - -// TunnelService manages the lifecycle of the VPN tunnel. -type TunnelService struct { - tunnel *tunnel - - platformIfce rvpn.PlatformInterface - logger *slog.Logger - - mu sync.Mutex -} - -// NewTunnelService creates a new TunnelService instance with the provided configuration paths, log -// level, and platform interface. -func NewTunnelService(dataPath string, logger *slog.Logger, platformIfce rvpn.PlatformInterface) *TunnelService { - if logger == nil { - logger = slog.Default() - } - switch logger.Handler().(type) { - case *slog.TextHandler, *slog.JSONHandler: - default: - os.MkdirAll(dataPath, 0o755) - path := filepath.Join(dataPath, "radiance_vpn.log") - var writer io.Writer - f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - slog.Error("Failed to open log file", "error", err) - writer = os.Stdout - } else { - writer = f - } - logger = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{AddSource: true, Level: internal.LevelTrace})) - runtime.AddCleanup(logger, func(file *os.File) { - file.Close() - }, f) - } - return &TunnelService{ - platformIfce: platformIfce, - logger: logger, - } -} - -// Start initializes and starts the tunnel with the specified options. Returns an error if the -// tunnel is already running or initialization fails. -func (s *TunnelService) Start(ctx context.Context, options string) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel != nil { - s.logger.Warn("tunnel already started") - return errors.New("tunnel already started") - } - s.logger.Debug("Starting tunnel", "options", options) - if err := s.start(ctx, options); err != nil { - return err - } - return nil -} - -func (s *TunnelService) start(ctx context.Context, options string) error { - path := settings.GetString(settings.DataPathKey) - t := tunnel{ - dataPath: path, - } - if err := t.start(options, s.platformIfce); err != nil { - return fmt.Errorf("failed to start tunnel: %w", err) - } - s.tunnel = &t - return nil -} - -// Close shuts down the currently running tunnel, if any. Returns an error if closing the tunnel fails. -func (s *TunnelService) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - if err := s.close(); err != nil { - return err - } - if s.platformIfce != nil { - s.platformIfce.PostServiceClose() - } - return nil -} - -func (s *TunnelService) close() error { - t := s.tunnel - s.tunnel = nil - - s.logger.Info("Closing tunnel") - if err := t.close(); err != nil { - return err - } - s.logger.Debug("Tunnel closed") - runtime.GC() - return nil -} - -// Restart closes and restarts the tunnel if it is currently running. Returns an error if the tunnel -// is not running or restart fails. -func (s *TunnelService) Restart(ctx context.Context, options string) error { - s.mu.Lock() - if s.tunnel == nil { - s.mu.Unlock() - return errors.New("tunnel not started") - } - if s.tunnel.Status() != ipc.Connected { - s.mu.Unlock() - return errors.New("tunnel not running") - } - - s.logger.Info("Restarting tunnel") - if s.platformIfce != nil { - s.mu.Unlock() - if err := s.platformIfce.RestartService(); err != nil { - s.logger.Error("Failed to restart tunnel via platform interface", "error", err) - return fmt.Errorf("platform interface restart failed: %w", err) - } - return nil - } - - defer s.mu.Unlock() - if err := s.close(); err != nil { - return fmt.Errorf("closing tunnel: %w", err) - } - if err := s.start(ctx, options); err != nil { - s.logger.Error("starting tunnel", "error", err) - return fmt.Errorf("starting tunnel: %w", err) - } - s.logger.Info("Tunnel restarted successfully") - return nil -} - -// Status returns the current status of the tunnel (e.g., running, closed). -func (s *TunnelService) Status() ipc.VPNStatus { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return ipc.Disconnected - } - return s.tunnel.Status() -} - -// Ctx returns the context associated with the tunnel, or nil if no tunnel is running. -func (s *TunnelService) Ctx() context.Context { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - return s.tunnel.ctx -} - -// ClashServer returns the Clash server instance associated with the tunnel, or nil if no tunnel is -// running. -func (s *TunnelService) ClashServer() *clashapi.Server { - s.mu.Lock() - defer s.mu.Unlock() - if s.tunnel == nil { - return nil - } - return s.tunnel.clashServer -} - -var errTunnelNotStarted = errors.New("tunnel not started") - -// activeTunnel returns the running tunnel or errTunnelNotStarted. -func (s *TunnelService) activeTunnel() (*tunnel, error) { - s.mu.Lock() - t := s.tunnel - s.mu.Unlock() - if t == nil { - return nil, errTunnelNotStarted - } - return t, nil -} - -func (s *TunnelService) UpdateOutbounds(newOpts servers.Servers) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.updateOutbounds(newOpts) -} - -func (s *TunnelService) AddOutbounds(group string, options servers.Options) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.addOutbounds(group, options) -} - -func (s *TunnelService) RemoveOutbounds(group string, tags []string) error { - t, err := s.activeTunnel() - if err != nil { - return err - } - return t.removeOutbounds(group, tags) -} diff --git a/vpn/split_tunnel.go b/vpn/split_tunnel.go index fa4e0ffd..2f8334e5 100644 --- a/vpn/split_tunnel.go +++ b/vpn/split_tunnel.go @@ -2,7 +2,6 @@ package vpn import ( "context" - "encoding/json" "errors" "fmt" "io/fs" @@ -16,17 +15,17 @@ import ( C "github.com/sagernet/sing-box/constant" O "github.com/sagernet/sing-box/option" - singjson "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json" - "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" + "github.com/getlantern/radiance/common/fileperm" "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" ) const ( splitTunnelTag = "split-tunnel" - splitTunnelFile = splitTunnelTag + ".json" + splitTunnelFile = internal.SplitTunnelFileName TypeDomain = "domain" TypeDomainSuffix = "domainSuffix" @@ -47,17 +46,18 @@ type SplitTunnel struct { ruleMap map[string]*O.DefaultHeadlessRule enabled *atomic.Bool access sync.Mutex + logger *slog.Logger } -func NewSplitTunnelHandler() (*SplitTunnel, error) { - s := newSplitTunnel(settings.GetString(settings.DataPathKey)) +func NewSplitTunnelHandler(dataPath string, logger *slog.Logger) (*SplitTunnel, error) { + s := newSplitTunnel(dataPath, logger) if err := s.loadRule(); err != nil { return nil, fmt.Errorf("loading split tunnel rule file %s: %w", s.ruleFile, err) } return s, nil } -func newSplitTunnel(path string) *SplitTunnel { +func newSplitTunnel(path string, logger *slog.Logger) *SplitTunnel { rule := defaultRule() s := &SplitTunnel{ rule: rule, @@ -65,24 +65,17 @@ func newSplitTunnel(path string) *SplitTunnel { activeFilter: &(rule.Rules[1].LogicalOptions), ruleMap: make(map[string]*O.DefaultHeadlessRule), enabled: &atomic.Bool{}, + logger: logger, } s.initRuleMap() if _, err := os.Stat(s.ruleFile); errors.Is(err, fs.ErrNotExist) { - slog.Debug("Creating initial split tunnel rule file", "file", s.ruleFile) + logger.Debug("Creating initial split tunnel rule file", "file", s.ruleFile) s.saveToFile() } return s } -func (s *SplitTunnel) Enable() error { - return s.setEnabled(true) -} - -func (s *SplitTunnel) Disable() error { - return s.setEnabled(false) -} - -func (s *SplitTunnel) setEnabled(enabled bool) error { +func (s *SplitTunnel) SetEnabled(enabled bool) error { if s.enabled.Load() == enabled { return nil } @@ -97,7 +90,7 @@ func (s *SplitTunnel) setEnabled(enabled bool) error { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } s.enabled.Store(enabled) - slog.Log(context.Background(), internal.LevelTrace, "Updated split-tunneling", "enabled", enabled) + s.logger.Log(context.Background(), log.LevelTrace, "Updated split-tunneling", "enabled", enabled) return nil } @@ -105,10 +98,10 @@ func (s *SplitTunnel) IsEnabled() bool { return s.enabled.Load() } -func (s *SplitTunnel) Filters() Filter { +func (s *SplitTunnel) Filters() SplitTunnelFilter { s.access.Lock() defer s.access.Unlock() - return Filter{ + return SplitTunnelFilter{ Domain: slices.Clone(s.ruleMap[TypeDomain].Domain), DomainSuffix: slices.Clone(s.ruleMap[TypeDomainSuffix].DomainSuffix), DomainKeyword: slices.Clone(s.ruleMap[TypeDomainKeyword].DomainKeyword), @@ -120,95 +113,12 @@ func (s *SplitTunnel) Filters() Filter { } } -// ItemsJSON returns the items for the given filter type as a JSON-encoded []string. -func (s *SplitTunnel) ItemsJSON(filterType string) (string, error) { - items, err := s.Filters().Items(filterType) - if err != nil { - return "", err - } - if items == nil { - items = []string{} - } - b, err := json.Marshal(items) - if err != nil { - return "", err - } - return string(b), nil -} - -// EnabledAppsJSON returns all enabled app/process identifiers from the split -// tunnel configuration as a JSON-encoded []string. It first extracts values -// from the parsed rule set (current sing-box format with snake_case keys), -// then falls back to scanning the raw file for legacy camelCase keys. -func (s *SplitTunnel) EnabledAppsJSON() (string, error) { - seen := map[string]struct{}{} - out := make([]string, 0, 16) - isWindows := common.IsWindows() - - addString := func(str string) { - str = strings.TrimSpace(str) - if str == "" { - return - } - key := str - if isWindows { - key = strings.ToLower(str) - } - if _, exists := seen[key]; exists { - return - } - seen[key] = struct{}{} - out = append(out, str) - } - - addSlice := func(items []string) { - for _, str := range items { - addString(str) - } - } - - // Extract from the parsed rule set (current format). - f := s.Filters() - addSlice(f.ProcessPath) - addSlice(f.ProcessPathRegex) - addSlice(f.ProcessName) - addSlice(f.PackageName) - - // Fall back to legacy camelCase top-level keys in the raw file. - b, err := atomicfile.ReadFile(s.ruleFile) - if err == nil && len(b) > 0 { - m, parseErr := singjson.UnmarshalExtended[map[string]any](b) - if parseErr == nil { - legacyKeys := []string{ - "processPathRegex", "processPath", "packageName", - } - for _, k := range legacyKeys { - arr, ok := m[k].([]any) - if !ok { - continue - } - for _, it := range arr { - if str, ok := it.(string); ok { - addString(str) - } - } - } - } - } - - encoded, err := json.Marshal(out) - if err != nil { - return "", err - } - return string(encoded), nil -} - // AddItem adds a new item to the filter of the given type. func (s *SplitTunnel) AddItem(filterType, item string) error { if err := s.updateFilter(filterType, item, merge); err != nil { return err } - slog.Debug("added item to filter", "filterType", filterType, "item", item) + s.logger.Debug("added item to filter", "filterType", filterType, "item", item) if err := s.saveToFile(); err != nil { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } @@ -220,7 +130,7 @@ func (s *SplitTunnel) RemoveItem(filterType, item string) error { if err := s.updateFilter(filterType, item, remove); err != nil { return err } - slog.Debug("removed item from filter", "filterType", filterType, "item", item) + s.logger.Debug("removed item from filter", "filterType", filterType, "item", item) if err := s.saveToFile(); err != nil { return fmt.Errorf("writing rule to %s: %w", s.ruleFile, err) } @@ -228,20 +138,20 @@ func (s *SplitTunnel) RemoveItem(filterType, item string) error { } // AddItems adds multiple items to the filter. -func (s *SplitTunnel) AddItems(items Filter) error { +func (s *SplitTunnel) AddItems(items SplitTunnelFilter) error { s.updateFilters(items, merge) - slog.Debug("added items to filter", "items", items.String()) + s.logger.Debug("added items to filter", "items", items.String()) return s.saveToFile() } // RemoveItems removes multiple items from the filter. -func (s *SplitTunnel) RemoveItems(items Filter) error { +func (s *SplitTunnel) RemoveItems(items SplitTunnelFilter) error { s.updateFilters(items, remove) - slog.Debug("removed items from filter", "items", items.String()) + s.logger.Debug("removed items from filter", "items", items.String()) return s.saveToFile() } -type Filter struct { +type SplitTunnelFilter struct { Domain []string DomainSuffix []string DomainKeyword []string @@ -252,31 +162,7 @@ type Filter struct { PackageName []string } -// Items returns the items for the given filter type. -func (f Filter) Items(filterType string) ([]string, error) { - switch filterType { - case TypeDomain: - return f.Domain, nil - case TypeDomainSuffix: - return f.DomainSuffix, nil - case TypeDomainKeyword: - return f.DomainKeyword, nil - case TypeDomainRegex: - return f.DomainRegex, nil - case TypeProcessName: - return f.ProcessName, nil - case TypeProcessPath: - return f.ProcessPath, nil - case TypeProcessPathRegex: - return f.ProcessPathRegex, nil - case TypePackageName: - return f.PackageName, nil - default: - return nil, fmt.Errorf("unsupported filter type: %s", filterType) - } -} - -func (f Filter) String() string { +func (f SplitTunnelFilter) String() string { var str []string if len(f.Domain) > 0 { str = append(str, fmt.Sprintf("domain: %v", f.Domain)) @@ -337,7 +223,7 @@ func (s *SplitTunnel) updateFilter(filterType string, item string, fn actionFn) return nil } -func (s *SplitTunnel) updateFilters(diff Filter, fn actionFn) { +func (s *SplitTunnel) updateFilters(diff SplitTunnelFilter, fn actionFn) { s.access.Lock() defer s.access.Unlock() @@ -409,11 +295,11 @@ func (s *SplitTunnel) saveToFile() error { }, }, } - buf, err := singjson.Marshal(rs) + buf, err := json.Marshal(rs) if err != nil { return fmt.Errorf("marshalling rule set: %w", err) } - if err := atomicfile.WriteFile(s.ruleFile, buf, 0644); err != nil { + if err := atomicfile.WriteFile(s.ruleFile, buf, fileperm.File); err != nil { return fmt.Errorf("writing rule file %s: %w", s.ruleFile, err) } return nil @@ -432,13 +318,13 @@ func (s *SplitTunnel) loadRule() error { if err != nil { return fmt.Errorf("reading rule file %s: %w", s.ruleFile, err) } - ruleSet, err := singjson.UnmarshalExtended[O.PlainRuleSetCompat](content) + ruleSet, err := json.UnmarshalExtended[O.PlainRuleSetCompat](content) if err != nil { return fmt.Errorf("unmarshalling rule file %s: %w", s.ruleFile, err) } rules := ruleSet.Options.Rules if len(rules) == 0 { - slog.Warn("split tunnel rule file format is invalid, using empty rule") + s.logger.Warn("split tunnel rule file format is invalid, using empty rule") return nil } @@ -454,7 +340,7 @@ func (s *SplitTunnel) loadRule() error { } else if len(s.rule.Rules) > 1 && s.rule.Rules[1].Type == C.RuleTypeDefault { // Migrate legacy format: wrap DefaultOptions into LogicalOptions // TODO(2/10): remove in future commit - slog.Debug("Migrating legacy split tunnel rule format") + s.logger.Debug("Migrating legacy split tunnel rule format") legacyRule := s.rule.Rules[1].DefaultOptions s.rule.Rules[1] = O.HeadlessRule{ Type: C.RuleTypeLogical, @@ -514,7 +400,7 @@ func (s *SplitTunnel) loadRule() error { s.initRuleMap() s.enabled.Store(s.rule.Mode == C.LogicalTypeOr) - slog.Log(context.Background(), internal.LevelTrace, "loaded split tunnel rules", + s.logger.Log(context.Background(), log.LevelTrace, "loaded split tunnel rules", "file", s.ruleFile, "filters", s.Filters().String(), "enabled", s.IsEnabled(), ) return nil @@ -604,7 +490,6 @@ func (s *SplitTunnel) initRuleMap() { for i := range s.activeFilter.Rules { rule := &s.activeFilter.Rules[i].DefaultOptions matched := false - if len(rule.Domain) > 0 || len(rule.DomainSuffix) > 0 || len(rule.DomainKeyword) > 0 || len(rule.DomainRegex) > 0 { s.ruleMap[TypeDomain] = rule diff --git a/vpn/split_tunnel_test.go b/vpn/split_tunnel_test.go index b849fad2..13b63f3b 100644 --- a/vpn/split_tunnel_test.go +++ b/vpn/split_tunnel_test.go @@ -2,7 +2,6 @@ package vpn import ( "context" - stdjson "encoding/json" "testing" "time" @@ -17,29 +16,22 @@ import ( "github.com/stretchr/testify/require" "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" + rlog "github.com/getlantern/radiance/log" ) -func setupTestSplitTunnel(t *testing.T) *SplitTunnel { - testutil.SetPathsForTesting(t) - s := newSplitTunnel(settings.GetString(settings.DataPathKey)) - return s -} - func TestEnableDisableIsEnabled(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - if assert.NoError(t, st.Disable()) { + if assert.NoError(t, st.SetEnabled(false)) { assert.False(t, st.IsEnabled(), "split tunnel should be disabled") } - if assert.NoError(t, st.Enable()) { + if assert.NoError(t, st.SetEnabled(true)) { assert.True(t, st.IsEnabled(), "split tunnel should be enabled") } } func TestAddRemoveItem(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) domain := "example.com" domain2 := "example2.com" @@ -72,18 +64,18 @@ func TestAddRemoveItem(t *testing.T) { } func TestRemoveItems(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - require.NoError(t, st.RemoveItems(Filter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}})) + require.NoError(t, st.RemoveItems(SplitTunnelFilter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}})) f := st.Filters() assert.Empty(t, f.Domain) assert.Empty(t, f.ProcessName) } func TestAddRemoveItems(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) - items := Filter{ + items := SplitTunnelFilter{ Domain: []string{"a.com", "b.com"}, DomainSuffix: []string{"suffix"}, ProcessName: []string{"proc"}, @@ -97,7 +89,7 @@ func TestAddRemoveItems(t *testing.T) { assert.Equal(t, []string{"proc"}, f.ProcessName) assert.Equal(t, []string{"pkg"}, f.PackageName) - err = st.RemoveItems(Filter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}}) + err = st.RemoveItems(SplitTunnelFilter{Domain: []string{"a.com"}, ProcessName: []string{"proc"}}) require.NoError(t, err) f = st.Filters() assert.Equal(t, []string{"b.com"}, f.Domain) @@ -105,22 +97,39 @@ func TestAddRemoveItems(t *testing.T) { } func TestFilterPersistence(t *testing.T) { - testutil.SetPathsForTesting(t) - st, err := NewSplitTunnelHandler() + tmpDir := t.TempDir() + st, err := NewSplitTunnelHandler(tmpDir, rlog.NoOpLogger()) require.NoError(t, err) require.NoError(t, st.AddItem("domain", "example.com")) f := st.Filters() assert.Equal(t, []string{"example.com"}, f.Domain) - st, err = NewSplitTunnelHandler() + st, err = NewSplitTunnelHandler(tmpDir, rlog.NoOpLogger()) require.NoError(t, err) f = st.Filters() assert.Equal(t, []string{"example.com"}, f.Domain, "expected filters to persist after reloading from file") } +func TestFilterPersistenceAfterLoad(t *testing.T) { + tmpDir := t.TempDir() + // Simulate the daemon path: NewSplitTunnelHandler (newSplitTunnel + loadRule), then AddItems + st, err := NewSplitTunnelHandler(tmpDir, rlog.NoOpLogger()) + require.NoError(t, err) + + require.NoError(t, st.AddItems(SplitTunnelFilter{Domain: []string{"example.com"}})) + f := st.Filters() + assert.Equal(t, []string{"example.com"}, f.Domain, "filter should be set in memory after AddItems") + + // Reload from disk to verify persistence + st2, err := NewSplitTunnelHandler(tmpDir, rlog.NoOpLogger()) + require.NoError(t, err) + f = st2.Filters() + assert.Equal(t, []string{"example.com"}, f.Domain, "filter should persist to disk after AddItems") +} + func TestUpdateFilterUnsupportedType(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) err := st.AddItem("unsupported", "foo") assert.Error(t, err) } @@ -145,7 +154,7 @@ func TestRemoveEdgeCases(t *testing.T) { } func TestMatch(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) require.NoError(t, st.AddItem("domain", "example.com")) ruleOpts := O.Rule{ @@ -193,7 +202,7 @@ func TestMatch(t *testing.T) { metadata := &adapter.InboundContext{Domain: "example.com"} rsStr := ruleSet.String() - require.NoError(t, st.Enable()) + require.NoError(t, st.SetEnabled(true)) require.Eventually(t, func() bool { return ruleSet.String() != rsStr }, time.Second, 50*time.Millisecond, "timed out waiting for rule reload") @@ -201,7 +210,7 @@ func TestMatch(t *testing.T) { assert.True(t, rule.Match(metadata), "rule should match when split tunnel is enabled") rsStr = ruleSet.String() - require.NoError(t, st.Disable()) + require.NoError(t, st.SetEnabled(false)) require.Eventually(t, func() bool { return ruleSet.String() != rsStr }, time.Second, 50*time.Millisecond, "timed out waiting for rule reload") @@ -219,7 +228,7 @@ func (r *mockRouter) RuleSet(tag string) (adapter.RuleSet, bool) { } func TestMigration(t *testing.T) { - st := setupTestSplitTunnel(t) + st := newSplitTunnel(t.TempDir(), rlog.NoOpLogger()) // Create a legacy format rule file legacyRule := O.LogicalHeadlessRule{ @@ -324,99 +333,3 @@ func TestMigration(t *testing.T) { rule, _ := json.UnmarshalExtended[O.LogicalHeadlessRule]([]byte(want)) assert.Equal(t, rule, st.rule) } - -// unmarshalItems is a test helper that unmarshals a JSON string into []string. -func unmarshalItems(t *testing.T, jsonStr string) []string { - t.Helper() - var items []string - require.NoError(t, stdjson.Unmarshal([]byte(jsonStr), &items)) - return items -} - -func TestItemsJSON(t *testing.T) { - st := setupTestSplitTunnel(t) - - t.Run("returns items for valid filter type", func(t *testing.T) { - require.NoError(t, st.AddItem(TypeDomain, "example.com")) - require.NoError(t, st.AddItem(TypeDomain, "test.org")) - - result, err := st.ItemsJSON(TypeDomain) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Equal(t, []string{"example.com", "test.org"}, items) - }) - - t.Run("returns empty array when no items", func(t *testing.T) { - result, err := st.ItemsJSON(TypeDomainKeyword) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Empty(t, items) - }) - - t.Run("returns error for unsupported filter type", func(t *testing.T) { - _, err := st.ItemsJSON("unsupported") - assert.Error(t, err) - assert.Contains(t, err.Error(), "unsupported filter type") - }) - - t.Run("returns items for package names", func(t *testing.T) { - require.NoError(t, st.AddItem(TypePackageName, "com.example.app")) - result, err := st.ItemsJSON(TypePackageName) - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Equal(t, []string{"com.example.app"}, items) - }) -} - -func TestEnabledAppsJSON(t *testing.T) { - st := setupTestSplitTunnel(t) - - t.Run("returns empty array when no apps configured", func(t *testing.T) { - result, err := st.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Empty(t, items) - }) - - t.Run("returns apps from current format", func(t *testing.T) { - require.NoError(t, st.AddItem(TypePackageName, "com.example.app")) - require.NoError(t, st.AddItem(TypeProcessPath, "/usr/bin/firefox")) - - result, err := st.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Contains(t, items, "com.example.app") - assert.Contains(t, items, "/usr/bin/firefox") - }) - - t.Run("picks up legacy camelCase keys from raw file", func(t *testing.T) { - st2 := setupTestSplitTunnel(t) - require.NoError(t, st2.AddItem(TypePackageName, "com.current.app")) - - // Patch the file with legacy camelCase keys alongside current format - b, err := atomicfile.ReadFile(st2.ruleFile) - require.NoError(t, err) - var raw map[string]any - require.NoError(t, stdjson.Unmarshal(b, &raw)) - raw["packageName"] = []string{"com.legacy.app"} - raw["processPath"] = []string{"/opt/legacy"} - patched, err := stdjson.Marshal(raw) - require.NoError(t, err) - require.NoError(t, atomicfile.WriteFile(st2.ruleFile, patched, 0644)) - - result, err := st2.EnabledAppsJSON() - require.NoError(t, err) - items := unmarshalItems(t, result) - assert.Contains(t, items, "com.current.app") - assert.Contains(t, items, "com.legacy.app") - assert.Contains(t, items, "/opt/legacy") - // Deduplication: com.current.app should appear exactly once - count := 0 - for _, app := range items { - if app == "com.current.app" { - count++ - } - } - assert.Equal(t, 1, count, "com.current.app should appear exactly once") - }) -} diff --git a/vpn/testdata/boxopts.json b/vpn/testdata/boxopts.json index fae1b2e1..cc593d72 100644 --- a/vpn/testdata/boxopts.json +++ b/vpn/testdata/boxopts.json @@ -14,90 +14,44 @@ "type": "direct", "tag": "direct" }, - { - "type": "block", - "tag": "block" - }, { "type": "http", - "tag": "http1-out", + "tag": "http-out", "server": "127.0.0.1", "server_port": 4080 }, - { - "type": "http", - "tag": "http2-out", - "server": "127.0.0.1", - "server_port": 4443 - }, { "type": "socks", - "tag": "socks1-out", + "tag": "socks-out", "server": "127.0.0.1", "server_port": 5080 }, - { - "type": "socks", - "tag": "socks2-out", - "server": "127.0.0.1", - "server_port": 5443 - }, { "type": "mutableurltest", - "tag": "auto-http", + "tag": "auto", "outbounds": [ - "http1-out", - "http2-out" - ] - }, - { - "type": "mutableurltest", - "tag": "auto-socks", - "outbounds": [ - "socks1-out", - "socks2-out" + "http-out", + "socks-out" ] }, { "type": "mutableselector", - "tag": "http", + "tag": "manual", "outbounds": [ - "auto-http", - "http1-out", - "http2-out" - ] - }, - { - "type": "mutableselector", - "tag": "socks", - "outbounds": [ - "auto-socks", - "socks1-out", - "socks2-out" - ] - }, - { - "type": "mutableurltest", - "tag": "auto-all", - "outbounds": [ - "auto-http", - "auto-socks" + "http-out", + "socks-out" ] } ], "route": { "rules": [ { - "clash_mode": "direct", - "outbound": "direct" - }, - { - "clash_mode": "http", - "outbound": "http" + "clash_mode": "auto", + "outbound": "auto" }, { - "clash_mode": "socks", - "outbound": "socks" + "clash_mode": "manual", + "outbound": "manual" } ] }, @@ -107,7 +61,7 @@ "cache_id": "test_cache" }, "clash_api": { - "default_mode": "Rule" + "default_mode": "auto" } } } diff --git a/vpn/testdata/config.json b/vpn/testdata/config.json index 8519af94..097768d1 100644 --- a/vpn/testdata/config.json +++ b/vpn/testdata/config.json @@ -1,58 +1,67 @@ { - "ConfigResponse": { - "smart_routing": [ - { - "category": "openai", - "rule_sets": [ - { - "tag": "openai", - "url": "https://ruleset.com/openai.srs" - } - ], - "outbounds": [ - "http1-out", - "socks1-out" - ] - } - ], - "ad_block": [ - { - "tag": "adblock-1", - "url": "https://ruleset.com/adblock-1.srs" - }, - { - "tag": "adblock-2", - "url": "https://ruleset.com/adblock-2.srs" - } - ], - "options": { - "outbounds": [ - { - "type": "http", - "tag": "http1-out", - "server": "127.0.0.1", - "server_port": 4080 - }, - { - "type": "http", - "tag": "http2-out", - "server": "127.0.0.1", - "server_port": 4443 - }, + "smart_routing": [ + { + "category": "direct", + "rule_sets": [ { - "type": "socks", - "tag": "socks1-out", - "server": "127.0.0.1", - "server_port": 5080 - }, + "tag": "sr-direct", + "url": "https://ruleset.com/direct.srs" + } + ], + "outbounds": [ + "direct" + ] + }, + { + "category": "openai", + "rule_sets": [ { - "type": "socks", - "tag": "socks2-out", - "server": "127.0.0.1", - "server_port": 5443 + "tag": "openai", + "url": "https://ruleset.com/openai.srs" } + ], + "outbounds": [ + "http1-out", + "socks1-out" ] } - }, - "PreferredLocation": {} + ], + "ad_block": [ + { + "tag": "adblock-1", + "url": "https://ruleset.com/adblock-1.srs" + }, + { + "tag": "adblock-2", + "url": "https://ruleset.com/adblock-2.srs" + } + ], + "options": { + "outbounds": [ + { + "type": "http", + "tag": "http1-out", + "server": "127.0.0.1", + "server_port": 4080 + }, + { + "type": "http", + "tag": "http2-out", + "server": "127.0.0.1", + "server_port": 4443 + }, + { + "type": "socks", + "tag": "socks1-out", + "server": "127.0.0.1", + "server_port": 5080 + }, + { + "type": "socks", + "tag": "socks2-out", + "server": "127.0.0.1", + "server_port": 5443 + } + ] + } } diff --git a/vpn/tunnel.go b/vpn/tunnel.go index bd577885..274d0731 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -7,15 +7,28 @@ import ( "fmt" "io" "log/slog" + "net/http" "path/filepath" + runtimeDebug "runtime/debug" "slices" - "sync/atomic" "time" - lcommon "github.com/getlantern/common" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/conntrack" + "github.com/sagernet/sing-box/common/urltest" + "github.com/sagernet/sing-box/experimental" + "github.com/sagernet/sing-box/experimental/libbox" + sblog "github.com/sagernet/sing-box/log" + O "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/service" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + lsync "github.com/getlantern/common/sync" box "github.com/getlantern/lantern-box" - lbA "github.com/getlantern/lantern-box/adapter" "github.com/getlantern/lantern-box/adapter/groups" lblog "github.com/getlantern/lantern-box/log" @@ -23,62 +36,87 @@ import ( "github.com/getlantern/radiance/common" "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/kindling" + rlog "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/urltest" - "github.com/sagernet/sing-box/experimental/clashapi" - "github.com/sagernet/sing-box/experimental/libbox" - sblog "github.com/sagernet/sing-box/log" - O "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common/json" - "github.com/sagernet/sing/service" ) type tunnel struct { - ctx context.Context - lbService *libbox.BoxService - clashServer *clashapi.Server - logFactory sblog.ObservableFactory + ctx context.Context + lbService *libbox.BoxService + clashServer *clashServer + urltestHistory adapter.URLTestHistoryStorage + urlTestSeed map[string]adapter.URLTestHistory + logFactory sblog.ObservableFactory dataPath string // optsMap is a map of current outbound/endpoint options JSON, used to deduplicate when adding // outbounds/endpoints - optsMap *lsync.TypedMap[string, []byte] - mutGrpMgr *groups.MutableGroupManager + optsMap *lsync.TypedMap[string, []byte] + mutGrpMgr *groups.MutableGroupManager + outboundMgr adapter.OutboundManager clientContextTracker *clientcontext.ClientContextInjector - status atomic.Value cancel context.CancelFunc closers []io.Closer } -func (t *tunnel) start(options string, platformIfce libbox.PlatformInterface) error { - t.status.Store(ipc.Connecting) - t.ctx, t.cancel = context.WithCancel(box.BaseContext()) - - if err := t.init(options, platformIfce); err != nil { +func (t *tunnel) start(ctx context.Context, options string, platformIfce libbox.PlatformInterface, isRestart bool) error { + ctx, span := otel.Tracer(tracerName).Start(ctx, "tunnel.start", + trace.WithAttributes( + attribute.Int("options_size", len(options)), + attribute.String("platform", common.Platform), + attribute.Bool("is_restart", isRestart), + )) + defer span.End() + + // Unbounded signaling must dial freddie outside the VPN tunnel or it + // recursively re-enters itself. streamingRoundTripper forces kindling to + // skip AMP (non-streamable) so freddie's long-poll genesis stream works. + baseCtx := lbA.ContextWithDirectTransport(box.BaseContext(), streamingRoundTripper{inner: kindling.HTTPClient().Transport}) + t.ctx, t.cancel = context.WithCancel(baseCtx) + + if err := t.init(ctx, options, platformIfce); err != nil { t.close() slog.Error("Failed to initialize tunnel", "error", err) return fmt.Errorf("initializing tunnel: %w", err) } - if err := t.connect(); err != nil { + if err := t.connect(ctx); err != nil { t.close() slog.Error("Failed to connect tunnel", "error", err) return fmt.Errorf("connecting tunnel: %w", err) } - t.status.Store(ipc.Connected) t.optsMap = makeOutboundOptsMap(t.ctx, options) return nil } -func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) error { - slog.Log(nil, internal.LevelTrace, "Initializing tunnel") +// traceSpan wraps fn in a child span of the caller's context and records any +// error on the child span so failures show up per-phase in the trace. +func traceSpan(ctx context.Context, name string, fn func() error) error { + _, span := otel.Tracer(tracerName).Start(ctx, name) + defer span.End() + err := fn() + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + return err +} + +func (t *tunnel) init(ctx context.Context, options string, platformIfce libbox.PlatformInterface) (err error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "tunnel.init") + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + span.End() + }() + + slog.Log(nil, rlog.LevelTrace, "Initializing tunnel") // setup libbox service dataPath := t.dataPath @@ -93,19 +131,36 @@ func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) err setupOpts.FixAndroidStack = true } - slog.Log(nil, internal.LevelTrace, "Setting up libbox", "setup_options", setupOpts) - if err := libbox.Setup(setupOpts); err != nil { + slog.Log(nil, rlog.LevelTrace, "Setting up libbox", "setup_options", setupOpts) + if err := traceSpan(ctx, "libbox.Setup", func() error { + return libbox.Setup(setupOpts) + }); err != nil { return fmt.Errorf("setup libbox: %w", err) } t.logFactory = lblog.NewFactory(slog.Default().Handler()) service.MustRegister[sblog.Factory](t.ctx, t.logFactory) - slog.Log(nil, internal.LevelTrace, "Creating libbox service") - lb, err := libbox.NewServiceWithContext(t.ctx, options, platformIfce) - if err != nil { + experimental.RegisterClashServerConstructor(newClashServer) + + t.urltestHistory = urltest.NewHistoryStorage() + for tag, h := range t.urlTestSeed { + t.urltestHistory.StoreURLTestHistory(tag, &h) + } + service.MustRegister[adapter.URLTestHistoryStorage](t.ctx, t.urltestHistory) + t.closers = append(t.closers, t.urltestHistory) + + slog.Log(nil, rlog.LevelTrace, "Creating libbox service") + var lb *libbox.BoxService + if err := traceSpan(ctx, "libbox.NewServiceWithContext", func() error { + var err error + lb, err = libbox.NewServiceWithContext(t.ctx, options, platformIfce) + return err + }); err != nil { return fmt.Errorf("create libbox service: %w", err) } + cacheFile := service.FromContext[adapter.CacheFile](t.ctx) + service.MustRegister[adapter.CacheFile](t.ctx, &cacheFileWrapper{CacheFile: cacheFile}) // setup client info tracker outboundMgr := service.FromContext[adapter.OutboundManager](t.ctx) @@ -118,11 +173,6 @@ func (t *tunnel) init(options string, platformIfce libbox.PlatformInterface) err t.closers = append(t.closers, lb) t.lbService = lb - history := service.PtrFromContext[urltest.HistoryStorage](t.ctx) - if err := loadURLTestHistory(history, filepath.Join(dataPath, urlTestHistoryFileName)); err != nil { - return fmt.Errorf("load urltest history: %w", err) - } - // set memory limit for Android and iOS switch common.Platform { case "android", "ios": @@ -143,19 +193,15 @@ func newClientContextInjector(outboundMgr adapter.OutboundManager, dataPath stri Platform: common.Platform, IsPro: settings.IsPro(), CountryCode: settings.GetString(settings.CountryCodeKey), - Version: common.Version, + Version: common.GetVersion(), } } + // Outbound match bounds start empty and are populated when lantern servers are added via + // addOutbounds. Only lantern servers support client context tracking. matchBounds := clientcontext.MatchBounds{ Inbound: []string{"any"}, Outbound: []string{}, } - if outbound, exists := outboundMgr.Outbound(servers.SGLantern); exists { - // Note: this should only contain lantern outbounds with servers that support client context - // tracking. otherwise, the connections will fail. - tags := outbound.(adapter.OutboundGroup).All() - matchBounds.Outbound = append(tags, servers.SGLantern, groupAutoTag(servers.SGLantern)) - } return clientcontext.NewClientContextInjector(infoFn, matchBounds) } @@ -179,8 +225,17 @@ func newMutableGroupManager( return groups.NewMutableGroupManager(logger, oMgr, epMgr, connMgr, mutGroups), nil } -func (t *tunnel) connect() (err error) { - slog.Log(nil, internal.LevelTrace, "Starting libbox service") +func (t *tunnel) connect(ctx context.Context) (err error) { + ctx, span := otel.Tracer(tracerName).Start(ctx, "tunnel.connect") + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + span.End() + }() + + slog.Log(nil, rlog.LevelTrace, "Starting libbox service") defer func() { if r := recover(); r != nil { @@ -188,41 +243,65 @@ func (t *tunnel) connect() (err error) { err = fmt.Errorf("panic starting libbox service: %v", r) } }() - if err := t.lbService.Start(); err != nil { + if err := traceSpan(ctx, "libbox.BoxService.Start", func() error { + return t.lbService.Start() + }); err != nil { slog.Error("Failed to start libbox service", "error", err) return fmt.Errorf("starting libbox service: %w", err) } slog.Debug("Libbox service started") - t.clashServer = service.FromContext[adapter.ClashServer](t.ctx).(*clashapi.Server) + t.clashServer = service.FromContext[adapter.ClashServer](t.ctx).(*clashServer) + t.outboundMgr = service.FromContext[adapter.OutboundManager](t.ctx) - mutGrpMgr, err := newMutableGroupManager( - t.ctx, t.logFactory.NewLogger("groupsManager"), t.clashServer.TrafficManager(), - ) - if err != nil { + var mutGrpMgr *groups.MutableGroupManager + if err := traceSpan(ctx, "newMutableGroupManager", func() error { + var err error + mutGrpMgr, err = newMutableGroupManager( + t.ctx, t.logFactory.NewLogger("groupsManager"), t.clashServer.TrafficManager(), + ) + return err + }); err != nil { + t.close() return fmt.Errorf("creating mutable group manager: %w", err) } t.mutGrpMgr = mutGrpMgr + // Prepend: mgm's removalQueue reads from libbox-managed state, so close it first. + t.closers = append([]io.Closer{ + closerFunc(func() error { mutGrpMgr.Close(); return nil }), + }, t.closers...) slog.Info("Tunnel connection established") return nil } -func (t *tunnel) selectOutbound(group, tag string) error { - if status := t.Status(); status != ipc.Connected { - return fmt.Errorf("tunnel not running: status %v", status) +func (t *tunnel) selectMode(mode string) error { + if t.lbService == nil { + return fmt.Errorf("tunnel not running") } - t.clashServer.SetMode(group) - if tag == "" { - return nil + if t.clashServer.Mode() != mode { + t.clashServer.SetMode(mode) + conntrack.Close() + go func() { + time.Sleep(time.Second) + runtimeDebug.FreeOSMemory() + }() } + return nil +} + +func (t *tunnel) selectOutbound(tag string) error { + if err := t.selectMode(ManualSelectTag); err != nil { + return err + } + outboundMgr := service.FromContext[adapter.OutboundManager](t.ctx) - outbound, loaded := outboundMgr.Outbound(group) + outbound, loaded := outboundMgr.Outbound(ManualSelectTag) if !loaded { - return fmt.Errorf("selector group not found: %s", group) + return fmt.Errorf("manual select group not found") } - outbound.(ipc.Selector).SelectOutbound(tag) + outbound.(Selector).SelectOutbound(tag) return nil } @@ -235,7 +314,7 @@ func (t *tunnel) close() error { go func() { var errs []error for _, closer := range t.closers { - slog.Log(nil, internal.LevelTrace, "Closing tunnel resource", "type", fmt.Sprintf("%T", closer)) + slog.Log(nil, rlog.LevelTrace, "Closing tunnel resource", "type", fmt.Sprintf("%T", closer)) errs = append(errs, closer.Close()) } done <- errors.Join(errs...) @@ -249,61 +328,69 @@ func (t *tunnel) close() error { t.closers = nil t.lbService = nil - t.status.Store(ipc.Disconnected) return err } -func (t *tunnel) Status() ipc.VPNStatus { - return t.status.Load().(ipc.VPNStatus) -} - var errLibboxClosed = errors.New("libbox closed") -func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) { - if len(options.Outbounds) == 0 && len(options.Endpoints) == 0 { - slog.Debug("No outbounds or endpoints to add", "group", group) +func (t *tunnel) addOutbounds(list servers.ServerList) (err error) { + outbounds := list.Outbounds() + endpoints := list.Endpoints() + if len(outbounds) == 0 && len(endpoints) == 0 { + slog.Debug("No outbounds or endpoints to add") return nil } - slog.Info("Adding servers to group", "group", group, "tags", options.AllTags()) + slog.Info("Adding servers", "tags", list.Tags()) // remove duplicates from newOpts before adding to avoid unnecessary reloads - newOptions := removeDuplicates(t.ctx, t.optsMap, options, group) + newList := removeDuplicates(t.ctx, t.optsMap, list) + newOutbounds := newList.Outbounds() + newEndpoints := newList.Endpoints() ctx := t.ctx router := service.FromContext[adapter.Router](ctx) var errs []error - if group == servers.SGLantern && t.clientContextTracker != nil { + if t.clientContextTracker != nil { // preemptively merge the new lantern tags into the clientContextInjector match bounds to // capture any new connections before finished adding the servers. - if tags := options.AllTags(); len(tags) > 0 { - slog.Log(nil, internal.LevelTrace, "Temporarily merging new lantern tags into ClientContextInjector") + lanternTags := make([]string, 0, len(newList.Servers)) + for _, srv := range newList.Servers { + if srv.IsLantern { + lanternTags = append(lanternTags, srv.Tag) + } + } + if len(lanternTags) > 0 { + slog.Log(nil, rlog.LevelTrace, "Temporarily merging new lantern tags into ClientContextInjector") matchBounds := t.clientContextTracker.MatchBounds() - matchBounds.Outbound = append(matchBounds.Outbound, tags...) + matchBounds.Outbound = append(matchBounds.Outbound, lanternTags...) t.clientContextTracker.SetBounds(matchBounds) } defer func() { - if !errors.Is(err, errLibboxClosed) { - t.updateClientContextTracker() + if errors.Is(err, errLibboxClosed) { + return } + // Remove any lantern tags that failed to load from the match bounds. + mb := t.clientContextTracker.MatchBounds() + mb.Outbound = slices.DeleteFunc(mb.Outbound, func(tag string) bool { + _, loaded := t.optsMap.Load(tag) + return slices.Contains(lanternTags, tag) && !loaded + }) + t.clientContextTracker.SetBounds(mb) }() } var ( mutGrpMgr = t.mutGrpMgr - autoTag = groupAutoTag(group) added = 0 ) - // for each outbound/endpoint in new add to group. - // All outbounds go in the URL test group — the server now sends callback - // URLs for every outbound, and the bounded worker pool helps limit memory usage. - for _, outbound := range newOptions.Outbounds { + for _, outbound := range newOutbounds { logger := t.logFactory.NewLogger("outbound/" + outbound.Tag + "[" + outbound.Type + "]") err := mutGrpMgr.CreateOutboundForGroup( - ctx, router, logger, group, outbound.Tag, outbound.Type, outbound.Options, + ctx, router, logger, ManualSelectTag, outbound.Tag, outbound.Type, outbound.Options, ) if err == nil { - err = mutGrpMgr.AddToGroup(autoTag, outbound.Tag) + err = mutGrpMgr.AddToGroup(AutoSelectTag, outbound.Tag) } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed @@ -312,7 +399,6 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) slog.Warn("Failed to load outbound", "tag", outbound.Tag, "type", outbound.Type, - "group", group, "error", err, ) errs = append(errs, err) @@ -327,13 +413,13 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) return ctx.Err() } - for _, endpoint := range newOptions.Endpoints { + for _, endpoint := range newEndpoints { logger := t.logFactory.NewLogger("endpoint/" + endpoint.Tag + "[" + endpoint.Type + "]") err := mutGrpMgr.CreateEndpointForGroup( - ctx, router, logger, group, endpoint.Tag, endpoint.Type, endpoint.Options, + ctx, router, logger, ManualSelectTag, endpoint.Tag, endpoint.Type, endpoint.Options, ) if err == nil { - err = mutGrpMgr.AddToGroup(autoTag, endpoint.Tag) + err = mutGrpMgr.AddToGroup(AutoSelectTag, endpoint.Tag) } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed @@ -342,7 +428,6 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) slog.Warn("Failed to load endpoint", "tag", endpoint.Tag, "type", endpoint.Type, - "group", group, "error", err, ) errs = append(errs, err) @@ -353,34 +438,32 @@ func (t *tunnel) addOutbounds(group string, options servers.Options) (err error) } } - if len(options.URLOverrides) > 0 { + if len(list.URLOverrides) > 0 { slog.Info("Applying bandit URL overrides to URL test group", - "group", autoTag, - "override_count", len(options.URLOverrides), + "override_count", len(list.URLOverrides), ) } - if err := t.mutGrpMgr.SetURLOverrides(autoTag, options.URLOverrides); err != nil { - slog.Warn("Failed to set URL overrides", "group", autoTag, "error", err) - } else if len(options.URLOverrides) > 0 { + if err := t.mutGrpMgr.SetURLOverrides(AutoSelectTag, list.URLOverrides); err != nil { + slog.Warn("Failed to set URL overrides", "error", err) + } else if len(list.URLOverrides) > 0 { // Trigger an immediate URL test cycle when we have bandit overrides so // callback probes are hit within seconds of config receipt rather than // waiting for the next scheduled interval (3 min). - if err := t.mutGrpMgr.CheckOutbounds(autoTag); err != nil { - slog.Warn("Failed to trigger immediate URL test after bandit overrides", "group", autoTag, "error", err) + if err := t.mutGrpMgr.CheckOutbounds(AutoSelectTag); err != nil { + slog.Warn("Failed to trigger immediate URL test after bandit overrides", "error", err) } else { - slog.Info("Triggered immediate URL test for bandit callbacks", "group", autoTag) + slog.Info("Triggered immediate URL test for bandit callbacks") } } - slog.Debug("Added servers to group", "group", group, "added", added) + slog.Debug("Added servers", "added", added) return errors.Join(errs...) } -func (t *tunnel) removeOutbounds(group string, tags []string) error { +func (t *tunnel) removeOutbounds(tags []string) error { var ( mutGrpMgr = t.mutGrpMgr - autoTag = groupAutoTag(group) - removed = 0 + removed []string errs []error ) for _, tag := range tags { @@ -389,157 +472,116 @@ func (t *tunnel) removeOutbounds(group string, tags []string) error { continue // skip nested urltests } } - err := mutGrpMgr.RemoveFromGroup(group, tag) + err := mutGrpMgr.RemoveFromGroup(ManualSelectTag, tag) + if err == nil { + // remove from urltest + err = mutGrpMgr.RemoveFromGroup(AutoSelectTag, tag) + } if errors.Is(err, groups.ErrIsClosed) { return errLibboxClosed } if err != nil { errs = append(errs, err) - continue - } - // Best-effort removal from URL test group — extra outbounds - // (non-smart) are only in the selector, not the URL test group, - // so this removal is expected to fail for them. - if utErr := mutGrpMgr.RemoveFromGroup(autoTag, tag); utErr != nil { - if errors.Is(utErr, groups.ErrIsClosed) { - return errLibboxClosed - } - slog.Debug("Failed best-effort removal from URL test group", - "tag", tag, "group", autoTag, "error", utErr) + } else { + t.optsMap.Delete(tag) + removed = append(removed, tag) } - t.optsMap.Delete(tag) - removed++ } - if t.clientContextTracker != nil { - t.updateClientContextTracker() + if t.clientContextTracker != nil && len(removed) > 0 { + mb := t.clientContextTracker.MatchBounds() + mb.Outbound = slices.DeleteFunc(mb.Outbound, func(s string) bool { + return slices.Contains(removed, s) + }) + t.clientContextTracker.SetBounds(mb) } - slog.Debug("Removed servers from group", "group", group, "removed", removed) + slog.Debug("Removed servers", "removed", len(removed)) return errors.Join(errs...) } -func (t *tunnel) updateClientContextTracker() { - outboundMgr := service.FromContext[adapter.OutboundManager](t.ctx) - outbound, exists := outboundMgr.Outbound(servers.SGLantern) - if !exists { - return +func (t *tunnel) updateOutbounds(list servers.ServerList) error { + var errs []error + outbounds := list.Outbounds() + endpoints := list.Endpoints() + if len(outbounds) == 0 && len(endpoints) == 0 && len(list.URLOverrides) == 0 { + slog.Debug("No outbounds, endpoints, or bandit overrides to update, skipping") + return nil } - outGroup := outbound.(adapter.OutboundGroup) - slog.Debug("Setting updated lantern tags into ClientContextInjector") - t.clientContextTracker.SetBounds(clientcontext.MatchBounds{ - Inbound: []string{"any"}, - Outbound: append(outGroup.All(), servers.SGLantern, groupAutoTag(servers.SGLantern)), - }) -} + slog.Log(nil, rlog.LevelTrace, "Updating servers") -func (t *tunnel) updateOutbounds(new servers.Servers) error { - var errs []error - for _, group := range []string{servers.SGLantern, servers.SGUser} { - newOpts := new[group] - if len(newOpts.Outbounds) == 0 && len(newOpts.Endpoints) == 0 && len(newOpts.URLOverrides) == 0 { - slog.Debug("No outbounds, endpoints, or URL overrides to update, skipping", "group", group) - continue - } - slog.Log(nil, internal.LevelTrace, "Updating servers", "group", group) - - autoTag := groupAutoTag(group) - selector, selectorExists := t.mutGrpMgr.OutboundGroup(group) - _, urltestExists := t.mutGrpMgr.OutboundGroup(autoTag) - if !selectorExists || !urltestExists { - // Yes, panic. And, yes, it's intentional. Both selector and URLtest should always exist - // if the tunnel is running, so this is a "world no longer makes sense" situation. This - // should be caught during testing and will not panic in release builds. - slog.Log( - nil, internal.LevelPanic, "selector or urltest group missing", "group", group, - "selector_exists", selectorExists, "urltest_exists", urltestExists, - ) - panic(fmt.Errorf( - "selector or urltest group missing for %q. selector_exists=%v, urltest_exists=%v", - group, selectorExists, urltestExists, - )) - } + selector, selectorExists := t.mutGrpMgr.OutboundGroup(ManualSelectTag) + _, urltestExists := t.mutGrpMgr.OutboundGroup(AutoSelectTag) + if !selectorExists || !urltestExists { + slog.Error("Selector or URL test group not found when updating outbounds") + return errors.New("selector or url test group not found") + } - if contextDone(t.ctx) { - return t.ctx.Err() - } + if contextDone(t.ctx) { + return t.ctx.Err() + } - // collect tags present in the current group but absent from the new config - newTags := newOpts.AllTags() - var toRemove []string - for _, tag := range selector.All() { - if !slices.Contains(newTags, tag) { - toRemove = append(toRemove, tag) - } + // collect tags present in the current group but absent from the new config + newTags := list.Tags() + var toRemove []string + for _, tag := range selector.All() { + if !slices.Contains(newTags, tag) { + toRemove = append(toRemove, tag) } + } - // Add new outbounds first, before removing old ones. If all new - // outbounds fail to load (e.g. invalid config), we keep the old - // working outbounds to maintain connectivity. - addErr := t.addOutbounds(group, newOpts) - if errors.Is(addErr, errLibboxClosed) { - return addErr - } - if addErr != nil { - errs = append(errs, addErr) - } + // Add new outbounds first, before removing old ones. If all new + // outbounds fail to load (e.g. invalid config), we keep the old + // working outbounds to maintain connectivity. + addErr := t.addOutbounds(list) + if errors.Is(addErr, errLibboxClosed) { + return addErr + } + if addErr != nil { + errs = append(errs, addErr) + } - // Check if any new outbound actually loaded into the group. - hasNewOutbound := false - for _, tag := range newTags { - if slices.Contains(selector.All(), tag) { - hasNewOutbound = true - break - } + // Check if any new outbound actually loaded + hasNewOutbound := false + for _, tag := range newTags { + if slices.Contains(selector.All(), tag) { + hasNewOutbound = true + break } + } - if hasNewOutbound { - if err := t.removeOutbounds(group, toRemove); errors.Is(err, errLibboxClosed) { - return err - } else if err != nil { - errs = append(errs, err) - } - } else { - slog.Warn("All new outbounds failed to load, keeping old outbounds", - "group", group, "failed_tags", newTags, "would_remove_tags", toRemove) + if hasNewOutbound { + if err := t.removeOutbounds(toRemove); errors.Is(err, errLibboxClosed) { + return err + } else if err != nil { + errs = append(errs, err) } + } else { + slog.Warn("All new outbounds failed to load, keeping old outbounds", + "failed_tags", newTags, "would_remove_tags", toRemove) } return errors.Join(errs...) } -func removeDuplicates(ctx context.Context, curr *lsync.TypedMap[string, []byte], new servers.Options, group string) servers.Options { - slog.Log(nil, internal.LevelTrace, "Removing duplicate outbounds/endpoints", "group", group) - deduped := servers.Options{ - Outbounds: []O.Outbound{}, - Endpoints: []O.Endpoint{}, - Locations: map[string]lcommon.ServerLocation{}, - URLOverrides: new.URLOverrides, - Credentials: new.Credentials, - } +func removeDuplicates(ctx context.Context, curr *lsync.TypedMap[string, []byte], list servers.ServerList) servers.ServerList { + slog.Log(nil, rlog.LevelTrace, "Removing duplicate outbounds/endpoints") + var deduped []*servers.Server var dropped []string - for _, out := range new.Outbounds { - if currOpts, exists := curr.Load(out.Tag); exists { - if outBytes, _ := json.MarshalContext(ctx, out); bytes.Equal(currOpts, outBytes) { - dropped = append(dropped, out.Tag) + for _, srv := range list.Servers { + if currOpts, exists := curr.Load(srv.Tag); exists { + if srvBytes, _ := json.MarshalContext(ctx, srv.Options); bytes.Equal(currOpts, srvBytes) { + dropped = append(dropped, srv.Tag) continue } } - deduped.Outbounds = append(deduped.Outbounds, out) - deduped.Locations[out.Tag] = new.Locations[out.Tag] - } - for _, ep := range new.Endpoints { - if currOpts, exists := curr.Load(ep.Tag); exists { - if epBytes, _ := json.MarshalContext(ctx, ep); bytes.Equal(currOpts, epBytes) { - dropped = append(dropped, ep.Tag) - continue - } - } - deduped.Endpoints = append(deduped.Endpoints, ep) - deduped.Locations[ep.Tag] = new.Locations[ep.Tag] + deduped = append(deduped, srv) } if len(dropped) > 0 { - slog.Log(nil, internal.LevelDebug, "Dropped duplicate outbounds/endpoints", "group", group, "tags", dropped) + slog.Debug("Dropped duplicate outbounds/endpoints", "tags", dropped) + } + return servers.ServerList{ + Servers: deduped, + URLOverrides: list.URLOverrides, } - return deduped } func makeOutboundOptsMap(ctx context.Context, options string) *lsync.TypedMap[string, []byte] { @@ -558,6 +600,10 @@ func makeOutboundOptsMap(ctx context.Context, options string) *lsync.TypedMap[st return &optsMap } +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + func contextDone(ctx context.Context) bool { select { case <-ctx.Done(): @@ -566,3 +612,40 @@ func contextDone(ctx context.Context) bool { return false } } + +// streamingRoundTripper defaults Accept to text/event-stream so kindling's +// race pipeline drops non-streamable transports (AMP) that would otherwise +// buffer freddie's long-poll body and break broflake's genesis subscription. +type streamingRoundTripper struct { + inner http.RoundTripper +} + +func (s streamingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Header.Get("Accept") == "" { + req = req.Clone(req.Context()) + req.Header.Set("Accept", "text/event-stream") + } + resp, err := s.inner.RoundTrip(req) + if err != nil { + slog.Error("unbounded signaling RoundTrip error", + slog.String("url", req.URL.String()), + slog.Any("error", err)) + return nil, err + } + return resp, nil +} + +// cacheFileWrapper suppresses libbox's persistence of the selected outbound +// so BoxOptions.InitialServer controls the selection on each connect rather +// than a stale value from disk. +type cacheFileWrapper struct { + adapter.CacheFile +} + +func (c *cacheFileWrapper) LoadSelected(_ string) string { + return "" +} + +func (c *cacheFileWrapper) StoreSelected(_, _ string) error { + return nil +} diff --git a/vpn/tunnel_test.go b/vpn/tunnel_test.go index 03dd03b2..6165e338 100644 --- a/vpn/tunnel_test.go +++ b/vpn/tunnel_test.go @@ -1,163 +1,111 @@ package vpn import ( - "path/filepath" + "context" "testing" - "time" - sbA "github.com/sagernet/sing-box/adapter" - sbC "github.com/sagernet/sing-box/constant" - sbO "github.com/sagernet/sing-box/option" + lsync "github.com/getlantern/common/sync" + box "github.com/getlantern/lantern-box" + O "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/json" - "github.com/sagernet/sing/service" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/getlantern/lantern-box/adapter" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" ) -func TestConnection(t *testing.T) { - testutil.SetPathsForTesting(t) - opts, optsStr, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to get test box options") +type errCloser struct{ err error } - tmp := settings.GetString(settings.DataPathKey) +func (c errCloser) Close() error { return c.err } - opts.Route.RuleSet = baseOpts(settings.GetString(settings.DataPathKey)).Route.RuleSet - opts.Route.RuleSet[0].LocalOptions.Path = filepath.Join(tmp, splitTunnelFile) - opts.Route.Rules = append([]sbO.Rule{baseOpts(settings.GetString(settings.DataPathKey)).Route.Rules[2]}, opts.Route.Rules...) - newSplitTunnel(tmp) +func TestTunnelClose(t *testing.T) { + t.Run("no resources", func(t *testing.T) { + tun := &tunnel{} + err := tun.close() + assert.NoError(t, err) + assert.Nil(t, tun.closers) + assert.Nil(t, tun.lbService) + }) - tun := &tunnel{ - dataPath: tmp, - } + t.Run("cancels context", func(t *testing.T) { + tun := &tunnel{} + ctx, cancel := context.WithCancel(context.Background()) + tun.cancel = cancel - require.NoError(t, tun.start(optsStr, nil), "failed to establish connection") - t.Cleanup(func() { - tun.close() + err := tun.close() + assert.NoError(t, err) + assert.Error(t, ctx.Err(), "context should be cancelled after close") }) - require.Equal(t, ipc.Connected, tun.Status(), "tunnel should be running") + t.Run("propagates closer errors", func(t *testing.T) { + tun := &tunnel{} + tun.closers = append(tun.closers, errCloser{err: assert.AnError}) - assert.NoError(t, tun.selectOutbound("http", "http1-out"), "failed to select http outbound") - assert.NoError(t, tun.close(), "failed to close lbService") - assert.Equal(t, ipc.Disconnected, tun.Status(), "tun should be closed") + err := tun.close() + assert.ErrorIs(t, err, assert.AnError) + }) } -func TestUpdateServers(t *testing.T) { - testutil.SetPathsForTesting(t) - testOpts, _, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to get test box options") - - baseOuts := baseOpts(settings.GetString(settings.DataPathKey)).Outbounds - allOutbounds := map[string]sbO.Outbound{ - "direct": baseOuts[0], - "block": baseOuts[1], - } - for _, out := range testOpts.Outbounds { - switch out.Type { - case sbC.TypeHTTP, sbC.TypeSOCKS: - allOutbounds[out.Tag] = out - default: - } - } - - lanternTags := []string{"http1-out", "http2-out", "socks1-out"} - userTags := []string{} - outs := []sbO.Outbound{ - allOutbounds["direct"], allOutbounds["block"], - allOutbounds["http1-out"], allOutbounds["http2-out"], allOutbounds["socks1-out"], - urlTestOutbound(autoLanternTag, lanternTags, nil), urlTestOutbound(autoUserTag, userTags, nil), - selectorOutbound(servers.SGLantern, append(lanternTags, autoLanternTag)), - selectorOutbound(servers.SGUser, append(userTags, autoUserTag)), - urlTestOutbound(autoAllTag, []string{autoLanternTag, autoUserTag}, nil), - } - - testOpts.Outbounds = outs - tun := testConnection(t, *testOpts) - defer func() { - tun.close() - }() - - time.Sleep(500 * time.Millisecond) - - err = tun.removeOutbounds(servers.SGLantern, []string{"http2-out", "socks1-out"}) - require.NoError(t, err, "failed to remove servers from lantern") - - newOpts := servers.Options{ - Outbounds: []sbO.Outbound{ - allOutbounds["http1-out"], allOutbounds["socks2-out"], - }, - } - err = tun.addOutbounds(servers.SGLantern, newOpts) - require.NoError(t, err, "failed to update servers for lantern") - - time.Sleep(250 * time.Millisecond) - - outboundMgr := service.FromContext[sbA.OutboundManager](tun.ctx) - require.NotNil(t, outboundMgr, "outbound manager should not be nil") - - groups := tun.mutGrpMgr.OutboundGroups() - - want := map[string][]string{ - autoAllTag: {autoLanternTag, autoUserTag}, - servers.SGLantern: {autoLanternTag, "http1-out", "socks2-out"}, - autoLanternTag: {"http1-out", "socks2-out"}, - servers.SGUser: {autoUserTag}, - autoUserTag: {}, - } - got := make(map[string][]string) - allTags := []string{"direct", "block", autoAllTag, autoLanternTag, autoUserTag, servers.SGLantern, servers.SGUser} - for _, g := range groups { - tags := g.All() - got[g.Tag()] = tags - allTags = append(allTags, tags...) - } - for _, tag := range allTags { - if _, found := outboundMgr.Outbound(tag); !found { - assert.Failf(t, "outbound missing from outbound manager", "outbound %s not found", tag) - } - } - for group, tags := range want { - assert.ElementsMatchf(t, tags, got[group], "group %s does not have correct outbounds", group) - } +func TestSelectMode_NotConnected(t *testing.T) { + // A tunnel without an active libbox service is not running. + tun := &tunnel{} + err := tun.selectMode(AutoSelectTag) + require.Error(t, err) + assert.Contains(t, err.Error(), "tunnel not running") } -func getGroups(outboundMgr sbA.OutboundManager) []adapter.MutableOutboundGroup { - outbounds := outboundMgr.Outbounds() - var iGroups []adapter.MutableOutboundGroup - for _, it := range outbounds { - if group, isGroup := it.(adapter.MutableOutboundGroup); isGroup { - iGroups = append(iGroups, group) +func TestRemoveDuplicates(t *testing.T) { + ctx := box.BaseContext() + out1 := O.Outbound{Type: "http", Tag: "http-1", Options: &O.HTTPOutboundOptions{}} + out2 := O.Outbound{Type: "http", Tag: "http-2", Options: &O.HTTPOutboundOptions{}} + socks := O.Outbound{Type: "socks", Tag: "socks-1", Options: &O.SOCKSOutboundOptions{}} + ep1 := O.Endpoint{Type: "wireguard", Tag: "wg-1", Options: &O.WireGuardEndpointOptions{}} + + t.Run("drops duplicates against current map", func(t *testing.T) { + var curr lsync.TypedMap[string, []byte] + b1, _ := json.MarshalContext(ctx, out1) + curr.Store(out1.Tag, b1) + bEp1, _ := json.MarshalContext(ctx, ep1) + curr.Store(ep1.Tag, bEp1) + + list := servers.ServerList{ + Servers: []*servers.Server{ + {Tag: out1.Tag, Type: out1.Type, Options: out1}, + {Tag: out2.Tag, Type: out2.Type, Options: out2}, + {Tag: ep1.Tag, Type: ep1.Type, Options: ep1}, + }, } - } - return iGroups -} -func testConnection(t *testing.T, opts sbO.Options) *tunnel { - tmp := settings.GetString(settings.DataPathKey) + result := removeDuplicates(ctx, &curr, list) + assert.Len(t, result.Servers, 1) + assert.Equal(t, "http-2", result.Servers[0].Tag) + }) - opts.Route.RuleSet = baseOpts(settings.GetString(settings.DataPathKey)).Route.RuleSet - opts.Route.RuleSet[0].LocalOptions.Path = filepath.Join(tmp, splitTunnelFile) - opts.Route.Rules = append([]sbO.Rule{baseOpts(settings.GetString(settings.DataPathKey)).Route.Rules[2]}, opts.Route.Rules...) - newSplitTunnel(tmp) + t.Run("keeps all servers when none are duplicates", func(t *testing.T) { + var curr lsync.TypedMap[string, []byte] + list := servers.ServerList{ + Servers: []*servers.Server{ + {Tag: out1.Tag, Type: out1.Type, Options: out1}, + {Tag: socks.Tag, Type: socks.Type, Options: socks}, + }, + } - tun := &tunnel{ - dataPath: tmp, - } + result := removeDuplicates(ctx, &curr, list) + assert.Len(t, result.Servers, 2) + }) - options, _ := json.Marshal(opts) - err := tun.start(string(options), nil) - require.NoError(t, err, "failed to establish connection") - t.Cleanup(func() { - tun.close() + t.Run("empty list yields empty result", func(t *testing.T) { + var curr lsync.TypedMap[string, []byte] + result := removeDuplicates(ctx, &curr, servers.ServerList{}) + assert.Empty(t, result.Servers) }) +} + +func TestContextDone(t *testing.T) { + ctx := context.Background() + assert.False(t, contextDone(ctx)) - assert.Equal(t, ipc.Connected, tun.Status(), "tunnel should be running") - return tun + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.True(t, contextDone(ctx)) } diff --git a/vpn/types.go b/vpn/types.go new file mode 100644 index 00000000..fccbdd6a --- /dev/null +++ b/vpn/types.go @@ -0,0 +1,86 @@ +package vpn + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/experimental/clashapi/trafficontrol" + + "github.com/getlantern/radiance/events" +) + +// URLTestHistoryStorage is an alias for the sing-box adapter interface. +type URLTestHistoryStorage = adapter.URLTestHistoryStorage + +// StatusUpdateEvent is emitted when the VPN status changes. +type StatusUpdateEvent struct { + events.Event + Status VPNStatus `json:"status"` + Error string `json:"error,omitempty"` +} + +// Selector is helper interface to check if an outbound is a selector or wrapper of selector. +type Selector interface { + adapter.OutboundGroup + SelectOutbound(tag string) bool +} + +type OutboundGroup struct { + Tag string + Type string + Selected string + Outbounds []Outbounds +} + +type Outbounds struct { + Tag string + Type string +} + +type Connection struct { + ID string + Inbound string + IPVersion int + Network string + Source string + Destination string + Domain string + Protocol string + FromOutbound string + CreatedAt int64 + ClosedAt int64 + Uplink int64 + Downlink int64 + Rule string + Outbound string + ChainList []string +} + +// NewConnection creates a Connection from tracker metadata. +func newConnection(metadata trafficontrol.TrackerMetadata) Connection { + var rule string + if metadata.Rule != nil { + rule = metadata.Rule.String() + " => " + metadata.Rule.Action().String() + } + var closedAt int64 + if !metadata.ClosedAt.IsZero() { + closedAt = metadata.ClosedAt.UnixMilli() + } + md := metadata.Metadata + return Connection{ + ID: metadata.ID.String(), + Inbound: md.InboundType + "/" + md.Inbound, + IPVersion: int(md.IPVersion), + Network: md.Network, + Source: md.Source.String(), + Destination: md.Destination.String(), + Domain: md.Domain, + Protocol: md.Protocol, + FromOutbound: md.Outbound, + CreatedAt: metadata.CreatedAt.UnixMilli(), + ClosedAt: closedAt, + Uplink: metadata.Upload.Load(), + Downlink: metadata.Download.Load(), + Rule: rule, + Outbound: metadata.OutboundType + "/" + metadata.Outbound, + ChainList: metadata.Chain, + } +} diff --git a/vpn/vpn.go b/vpn/vpn.go index 35a2bb38..b2f7a628 100644 --- a/vpn/vpn.go +++ b/vpn/vpn.go @@ -5,20 +5,20 @@ package vpn import ( "context" - "encoding/json" "errors" "fmt" "log/slog" - "os" "path/filepath" - "slices" + "runtime" "strings" "sync" + "sync/atomic" "time" sbox "github.com/sagernet/sing-box" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/urltest" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental/libbox" "github.com/sagernet/sing-box/option" sbjson "github.com/sagernet/sing/common/json" @@ -29,357 +29,366 @@ import ( "go.opentelemetry.io/otel/trace" box "github.com/getlantern/lantern-box" - - "github.com/getlantern/radiance/common" - "github.com/getlantern/radiance/common/atomicfile" - "github.com/getlantern/radiance/common/settings" "github.com/getlantern/radiance/events" - "github.com/getlantern/radiance/internal" + "github.com/getlantern/radiance/log" "github.com/getlantern/radiance/servers" "github.com/getlantern/radiance/traces" - "github.com/getlantern/radiance/vpn/ipc" ) const ( tracerName = "github.com/getlantern/radiance/vpn" ) -func init() { - forwardToTunnel := func(action func(ctx context.Context) error, desc string) { - ctx := context.Background() - status, err := ipc.GetStatus(ctx) - if err != nil { - slog.Warn("Event received but failed to get tunnel status", "event", desc, "error", err) - return - } - if status != ipc.Connected { - return - } - if err := action(ctx); err != nil { - slog.Error("Failed to forward event to tunnel", "event", desc, "error", err) - } - } +var ( + ErrTunnelNotConnected = errors.New("tunnel not connected") + ErrTunnelAlreadyConnected = errors.New("tunnel already connected") +) - events.Subscribe(func(e servers.ServersUpdatedEvent) { - forwardToTunnel(func(ctx context.Context) error { - svrs := map[string]servers.Options{e.Group: *e.Options} - return ipc.UpdateOutbounds(ctx, svrs) - }, "servers-updated") - }) - events.Subscribe(func(e servers.ServersAddedEvent) { - forwardToTunnel(func(ctx context.Context) error { - return ipc.AddOutbounds(ctx, e.Group, *e.Options) - }, "servers-added") - }) - events.Subscribe(func(e servers.ServersRemovedEvent) { - forwardToTunnel(func(ctx context.Context) error { - return ipc.RemoveOutbounds(ctx, e.Group, []string{e.Tag}) - }, "servers-removed") - }) -} +type VPNStatus string + +// Possible VPN statuses +const ( + Connecting VPNStatus = "connecting" + Connected VPNStatus = "connected" + Disconnecting VPNStatus = "disconnecting" + Disconnected VPNStatus = "disconnected" + Restarting VPNStatus = "restarting" + ErrorStatus VPNStatus = "error" +) -// Deprecated: Use AutoConnect instead with the desired group. -func QuickConnect(group string, _ libbox.PlatformInterface) (err error) { - return AutoConnect(group) +func (s *VPNStatus) String() string { + return string(*s) } -// AutoConnect automatically connects to the best available server in the specified group. Valid -// groups are [servers.ServerGroupLantern], [servers.ServerGroupUser], "all", or the empty string. -// Using "all" or the empty string will connect to the best available server across all groups. -func AutoConnect(group string) error { - ctx, span := otel.Tracer(tracerName).Start( - context.Background(), - "quick_connect", - trace.WithAttributes(attribute.String("group", group))) - defer span.End() +// VPNClient manages the lifecycle of the VPN tunnel. +type VPNClient struct { + tunnel *tunnel - switch group { - case servers.SGLantern: - return traces.RecordError(ctx, ConnectToServer(servers.SGLantern, autoLanternTag, nil)) - case servers.SGUser: - return traces.RecordError(ctx, ConnectToServer(servers.SGUser, autoUserTag, nil)) - case autoAllTag, "all", "": - if isOpen(ctx) { - if err := ipc.SetClashMode(ctx, autoAllTag); err != nil { - return fmt.Errorf("failed to set auto mode: %w", err) - } - return nil - } - return traces.RecordError(ctx, connect(autoAllTag, "")) - default: - return traces.RecordError(ctx, fmt.Errorf("invalid group: %s", group)) - } + platformIfce PlatformInterface + logger *slog.Logger + + offlineTestCancel context.CancelFunc + offlineTestDone chan struct{} + + status atomic.Value // VPNStatus + + mu sync.RWMutex +} + +// PlatformInterface defines the methods to interact with platform-specific services +type PlatformInterface interface { + libbox.PlatformInterface + // RestartService is called when the VPNClient wants to restart the tunnel instead of direct + // disconnect/reconnect. This allows platforms to perform any necessary extra steps to restart + // the tunnel. RestartService should block until the tunnel has been restarted and is ready for + // use, or return an error if restart fails. + RestartService() error + // PostServiceClose is called after the tunnel has been closed. This allows platforms to perform + // any necessary cleanup. + PostServiceClose() } -// Deprecated: Use Connect instead with the desired group and tag. -func ConnectToServer(group, tag string, _ libbox.PlatformInterface) error { - return Connect(group, tag) +// NewVPNClient creates a new VPNClient instance with the provided configuration paths, log +// level, and platform interface. +func NewVPNClient(dataPath string, logger *slog.Logger, platformIfce PlatformInterface) *VPNClient { + if logger == nil { + logger = slog.Default() + } + _ = newSplitTunnel(dataPath, logger) + done := make(chan struct{}) + close(done) + c := &VPNClient{ + platformIfce: platformIfce, + logger: logger, + offlineTestCancel: func() {}, + offlineTestDone: done, + } + c.status.Store(Disconnected) + return c } -// Connect connects to a specific server identified by the group and tag. Valid groups are -// [servers.SGLantern] and [servers.SGUser]. -func Connect(group, tag string) error { +func (c *VPNClient) Connect(boxOptions BoxOptions) error { ctx, span := otel.Tracer(tracerName).Start( context.Background(), - "connect_to_server", - trace.WithAttributes( - attribute.String("group", group), - attribute.String("tag", tag))) + "connect", + ) defer span.End() - switch group { - case servers.SGLantern, servers.SGUser: - default: - return traces.RecordError(ctx, fmt.Errorf("invalid group: %s", group)) + c.mu.Lock() + // Cancel any running offline tests and wait for them to finish. + c.offlineTestCancel() + done := c.offlineTestDone + c.mu.Unlock() + <-done + + c.mu.Lock() + defer c.mu.Unlock() + if c.tunnel != nil { + switch status := c.Status(); status { + case Connected: + return ErrTunnelAlreadyConnected + case Restarting, Connecting, Disconnecting: + return fmt.Errorf("tunnel is currently %s", status) + case Disconnected, ErrorStatus: + // Clean up the stale tunnel so we can reconnect. + c.tunnel = nil + default: + return fmt.Errorf("tunnel is in unexpected state: %s", status) + } + } + + options, err := buildOptions(boxOptions) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to build options: %w", err)) } - if tag == "" { - return traces.RecordError(ctx, errors.New("tag must be specified")) + opts, err := sbjson.Marshal(options) + if err != nil { + return traces.RecordError(ctx, fmt.Errorf("failed to marshal options: %w", err)) } - return traces.RecordError(ctx, connect(group, tag)) + return traces.RecordError(ctx, c.start(ctx, boxOptions.BasePath, string(opts), false, boxOptions.URLTestSeed)) } -func connect(group, tag string) error { - ctx := context.Background() - if isOpen(ctx) { - return SelectServer(ctx, group, tag) +// Disconnect closes the tunnel and all active connections. +func (c *VPNClient) Disconnect() error { + ctx, span := otel.Tracer(tracerName).Start(context.Background(), "disconnect") + defer span.End() + c.mu.Lock() + defer c.mu.Unlock() + if c.tunnel == nil { + return nil } - dataPath := settings.GetString(settings.DataPathKey) - _ = newSplitTunnel(dataPath) - options, err := getOptions() - if err != nil { + c.logger.Info("Disconnecting VPN") + return traces.RecordError(ctx, c.close()) +} + +func (c *VPNClient) start(ctx context.Context, path, options string, isRestart bool, urlTestSeed map[string]adapter.URLTestHistory) error { + c.logger.Debug("Starting tunnel", "options", options) + c.setStatus(Connecting, nil) + t := tunnel{dataPath: path, urlTestSeed: urlTestSeed} + if err := t.start(ctx, options, c.platformIfce, isRestart); err != nil { + c.setStatus(ErrorStatus, err) return err } - if err := ipc.StartService(ctx, options); err != nil { + c.tunnel = &t + c.setStatus(Connected, nil) + return nil +} + +func (c *VPNClient) close() error { + t := c.tunnel + c.tunnel = nil + + c.logger.Info("Closing tunnel") + c.setStatus(Disconnecting, nil) + if err := t.close(); err != nil { + c.setStatus(ErrorStatus, err) return err } - return SelectServer(ctx, group, tag) + c.setStatus(Disconnected, nil) + if c.platformIfce != nil { + c.platformIfce.PostServiceClose() + } + c.logger.Debug("Tunnel closed") + runtime.GC() + return nil } -// Restart restarts the tunnel by reconnecting to the currently selected server. -func Restart() error { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "restart") +// Restart closes and restarts the tunnel if it is currently running. Returns an error if the tunnel +// is not running or restart fails. +func (c *VPNClient) Restart(boxOptions BoxOptions) error { + ctx, span := otel.Tracer(tracerName).Start(context.Background(), "VPNClient.Restart") defer span.End() - options, err := getOptions() - if err != nil { - return err + c.mu.Lock() + if c.tunnel == nil || c.Status() != Connected { + c.mu.Unlock() + return ErrTunnelNotConnected } - return traces.RecordError(ctx, ipc.RestartService(ctx, options)) -} -func getOptions() (string, error) { - dataPath := settings.GetString(settings.DataPathKey) - options, err := buildOptions(context.Background(), dataPath) + c.setStatus(Restarting, nil) + c.logger.Info("Restarting tunnel") + + if c.platformIfce != nil { + span.SetAttributes(attribute.String("path", "platform_ifce")) + c.mu.Unlock() + if err := c.platformIfce.RestartService(); err != nil { + c.logger.Error("Failed to restart tunnel via platform interface", "error", err) + err = fmt.Errorf("platform interface restart failed: %w", err) + c.setStatus(ErrorStatus, err) + return traces.RecordError(ctx, err) + } + c.logger.Info("Tunnel restarted successfully") + return nil + } + span.SetAttributes(attribute.String("path", "direct")) + + defer c.mu.Unlock() + if err := c.close(); err != nil { + return traces.RecordError(ctx, fmt.Errorf("closing tunnel: %w", err)) + } + options, err := buildOptions(boxOptions) if err != nil { - return "", fmt.Errorf("failed to build options: %w", err) + c.setStatus(ErrorStatus, err) + return traces.RecordError(ctx, fmt.Errorf("failed to build options: %w", err)) } opts, err := sbjson.Marshal(options) if err != nil { - return "", fmt.Errorf("failed to marshal options: %w", err) + c.setStatus(ErrorStatus, err) + return traces.RecordError(ctx, fmt.Errorf("failed to marshal options: %w", err)) } - return string(opts), nil + if err := c.start(ctx, boxOptions.BasePath, string(opts), true, boxOptions.URLTestSeed); err != nil { + c.logger.Error("starting tunnel", "error", err) + // c.start already set ErrorStatus; the guard lets Restarting→ErrorStatus through. + return traces.RecordError(ctx, fmt.Errorf("starting tunnel: %w", err)) + } + c.logger.Info("Tunnel restarted successfully") + return nil } // isOpen returns true if the tunnel is open, false otherwise. // Note, this does not check if the tunnel can connect to a server. -func isOpen(ctx context.Context) bool { - state, err := ipc.GetStatus(ctx) - if err != nil { - slog.Error("Failed to get tunnel state", "error", err) - } - return state == ipc.Connected +func (c *VPNClient) isOpen() bool { + return c.Status() == Connected } -// Disconnect closes the tunnel and all active connections. -func Disconnect() error { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "disconnect") - defer span.End() - slog.Info("Disconnecting VPN") - return traces.RecordError(ctx, ipc.StopService(ctx)) +func (c *VPNClient) Status() VPNStatus { + s, _ := c.status.Load().(VPNStatus) + return s } -// SelectServer selects the specified server for the tunnel. The tunnel must already be open. -func SelectServer(ctx context.Context, group, tag string) error { - if !isOpen(ctx) { - return errors.New("tunnel is not open") - } - if group == autoAllTag { - slog.Info("Switching to auto mode", "group", group) - if err := ipc.SetClashMode(ctx, group); err != nil { - slog.Error("Failed to set auto mode", "group", group, "error", err) - return fmt.Errorf("failed to set auto mode: %w", err) - } - return nil +// setStatus stores and emits a status event. If the current status is Restarting, only allow +// transitions to Connected or ErrorStatus to avoid emitting intermediate states during a restart. +func (c *VPNClient) setStatus(s VPNStatus, err error) { + if cur, _ := c.status.Load().(VPNStatus); cur == Restarting && s != Connected && s != ErrorStatus { + return } - slog.Info("Selecting server", "group", group, "tag", tag) - if err := ipc.SelectOutbound(ctx, group, tag); err != nil { - slog.Error("Failed to select server", "group", group, "tag", tag, "error", err) - return fmt.Errorf("failed to select server %s/%s: %w", group, tag, err) + c.status.Store(s) + evt := StatusUpdateEvent{Status: s} + if err != nil { + evt.Error = err.Error() } - return nil + events.Emit(evt) } -// Status represents the current status of the tunnel, including whether it is open, the selected -// server, and the active server. Active is only set if the tunnel is open. -type Status struct { - TunnelOpen bool - // SelectedServer is the server that is currently selected for the tunnel. - SelectedServer string - // ActiveServer is the server that is currently active for the tunnel. This will differ from - // SelectedServer if using auto-select mode. - ActiveServer string +// HistoryStorage returns the tunnel's URL test history storage or nil if the tunnel is not connected. +func (c *VPNClient) HistoryStorage() adapter.URLTestHistoryStorage { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return nil + } + return c.tunnel.urltestHistory } -func GetStatus() (Status, error) { - ctx, span := otel.Tracer(tracerName).Start(context.Background(), "get_status") - defer span.End() - slog.Debug("Retrieving tunnel status") - s := Status{ - TunnelOpen: isOpen(ctx), +// SelectServer changes the currently selected server to the one specified by tag. If tag is +// AutoSelectTag or the empty string, the tunnel will switch to auto-select mode and automatically +// choose the best server. +func (c *VPNClient) SelectServer(tag string) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil || c.Status() != Connected { + return ErrTunnelNotConnected } - if !s.TunnelOpen { - return s, nil + t := c.tunnel + if tag == AutoSelectTag || tag == "" { + return c.tunnel.selectMode(AutoSelectTag) } - slog.Log(nil, internal.LevelTrace, "Tunnel is open, retrieving selected and active servers") - group, tag, err := ipc.GetSelected(ctx) - if err != nil { - return s, fmt.Errorf("failed to get selected server: %w", err) - } - if group == autoAllTag { - s.SelectedServer = autoAllTag - } else { - s.SelectedServer = tag + c.logger.Info("Selecting server", "tag", tag) + if err := t.selectOutbound(tag); err != nil { + c.logger.Error("Failed to select server", "tag", tag, "error", err) + return fmt.Errorf("failed to select server %s: %w", tag, err) } + return nil +} - _, active, err := ipc.GetActiveOutbound(ctx) - if err != nil { - return s, fmt.Errorf("failed to get active server: %w", err) +func (c *VPNClient) UpdateOutbounds(list servers.ServerList) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - s.ActiveServer = active - slog.Log(nil, internal.LevelTrace, "retrieved tunnel status", "tunnelOpen", s.TunnelOpen, "selectedServer", s.SelectedServer, "activeServer", s.ActiveServer) - return s, nil + return c.tunnel.updateOutbounds(list) } -func ActiveServer(ctx context.Context) (group, tag string, err error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "active_server") - defer span.End() - slog.Log(nil, internal.LevelTrace, "Retrieving active server") - group, tag, err = ipc.GetActiveOutbound(ctx) - if err != nil { - return "", "", fmt.Errorf("failed to get active server: %w", err) +func (c *VPNClient) AddOutbounds(list servers.ServerList) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - return group, tag, nil + return c.tunnel.addOutbounds(list) } -// ActiveConnections returns a list of currently active connections, ordered from newest to oldest. -// A non-nil error is only returned if there was an error retrieving the connections, or if the -// tunnel is closed. If there are no active connections and the tunnel is open, an empty slice is -// returned without an error. -func ActiveConnections(ctx context.Context) ([]ipc.Connection, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "active_connections") - defer span.End() - connections, err := Connections(ctx) - if err != nil { - return nil, traces.RecordError(ctx, fmt.Errorf("failed to get active connections: %w", err)) +func (c *VPNClient) RemoveOutbounds(tags []string) error { + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return ErrTunnelNotConnected } - - connections = slices.DeleteFunc(connections, func(c ipc.Connection) bool { - return c.ClosedAt != 0 - }) - slices.SortFunc(connections, func(a, b ipc.Connection) int { - return int(b.CreatedAt - a.CreatedAt) - }) - return connections, nil + return c.tunnel.removeOutbounds(tags) } // Connections returns a list of all connections, both active and recently closed. A non-nil error // is only returned if there was an error retrieving the connections, or if the tunnel is closed. // If there are no connections and the tunnel is open, an empty slice is returned without an error. -func Connections(ctx context.Context) ([]ipc.Connection, error) { - ctx, span := otel.Tracer(tracerName).Start(ctx, "connections") +func (c *VPNClient) Connections() ([]Connection, error) { + _, span := otel.Tracer(tracerName).Start(context.Background(), "connections") defer span.End() - connections, err := ipc.GetConnections(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get connections: %w", err) + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return nil, fmt.Errorf("failed to get connections: %w", ErrTunnelNotConnected) + } + tm := c.tunnel.clashServer.TrafficManager() + activeConns := tm.Connections() + closedConns := tm.ClosedConnections() + connections := make([]Connection, 0, len(activeConns)+len(closedConns)) + for _, conn := range activeConns { + connections = append(connections, newConnection(conn)) + } + for _, conn := range closedConns { + connections = append(connections, newConnection(conn)) } return connections, nil } -// AutoSelections represents the currently active servers for each auto server group. -type AutoSelections struct { - Lantern string - User string - AutoAll string -} - -// AutoSelectionsEvent is emitted when server location changes for any auto server group. -type AutoSelectionsEvent struct { +// AutoSelectedEvent is emitted when the auto-selected server changes. +type AutoSelectedEvent struct { events.Event - Selections AutoSelections + Selected string `json:"selected"` } -// SelectionUnavailable is the sentinel value returned for an auto-selection -// group that has no active server (tunnel not running, group not found, etc.). -const SelectionUnavailable = "Unavailable" - -// AutoServerSelections returns the currently active server for each auto server group. If the group -// is not found or has no active server, SelectionUnavailable is returned for that group. -func AutoServerSelections() (AutoSelections, error) { - as := AutoSelections{ - Lantern: SelectionUnavailable, - User: SelectionUnavailable, - AutoAll: SelectionUnavailable, - } - ctx := context.Background() - if !isOpen(ctx) { - slog.Log(ctx, internal.LevelTrace, "Tunnel not running, cannot get auto selections") - return as, nil +func (c *VPNClient) CurrentAutoSelectedServer() (string, error) { + if !c.isOpen() { + c.logger.Log(nil, log.LevelTrace, "Tunnel not running, cannot get auto selections") + return "", nil } - groups, err := ipc.GetGroups(ctx) - if err != nil { - return as, fmt.Errorf("failed to get groups: %w", err) - } - slog.Log(ctx, internal.LevelTrace, "Retrieved groups", "groups", groups) - selected := func(tag string) string { - idx := slices.IndexFunc(groups, func(g ipc.OutboundGroup) bool { - return g.Tag == tag - }) - if idx < 0 || groups[idx].Selected == "" { - slog.Log(ctx, internal.LevelTrace, "Group not found or has no selection", "tag", tag) - return SelectionUnavailable - } - return groups[idx].Selected - } - auto := AutoSelections{ - Lantern: selected(autoLanternTag), - User: selected(autoUserTag), + c.mu.RLock() + defer c.mu.RUnlock() + if c.tunnel == nil { + return "", ErrTunnelNotConnected } - - switch all := selected(autoAllTag); all { - case autoLanternTag: - auto.AutoAll = auto.Lantern - case autoUserTag: - auto.AutoAll = auto.User - default: - auto.AutoAll = all + outbound, loaded := c.tunnel.outboundMgr.Outbound(AutoSelectTag) + if !loaded { + return "", fmt.Errorf("auto select group not found") } - return auto, nil + return outbound.(adapter.OutboundGroup).Now(), nil } const ( - rapidPollInterval = 500 * time.Millisecond - rapidPollWindow = 15 * time.Second + rapidPollInterval = 500 * time.Millisecond + rapidPollWindow = 15 * time.Second steadyPollInterval = 10 * time.Second ) -// AutoSelectionsChangeListener polls for auto-selection changes and emits an -// AutoSelectionsEvent whenever the selection differs from the previous value. +// AutoSelectedChangeListener polls for auto-selection changes and emits an +// AutoSelectedEvent whenever the selection differs from the previous value. // It performs an initial rapid poll to catch the first selection soon after // tunnel connect, then settles into a slower steady-state interval. -func AutoSelectionsChangeListener(ctx context.Context) { +func (c *VPNClient) AutoSelectedChangeListener(ctx context.Context) { go func() { - var prev AutoSelections + var prev string // Rapid initial poll to emit the first selection promptly after connect. initialDeadline := time.NewTimer(rapidPollWindow) @@ -394,17 +403,15 @@ func AutoSelectionsChangeListener(ctx context.Context) { case <-initialDeadline.C: break initial case <-tick.C: - curr, err := AutoServerSelections() + curr, err := c.CurrentAutoSelectedServer() if err != nil { tick.Reset(rapidPollInterval) continue } if curr != prev { prev = curr - events.Emit(AutoSelectionsEvent{Selections: curr}) - if curr.Lantern != SelectionUnavailable || - curr.User != SelectionUnavailable || - curr.AutoAll != SelectionUnavailable { + events.Emit(AutoSelectedEvent{Selected: curr}) + if curr != "" { break initial } } @@ -427,14 +434,14 @@ func AutoSelectionsChangeListener(ctx context.Context) { case <-ctx.Done(): return case <-tick.C: - curr, err := AutoServerSelections() + curr, err := c.CurrentAutoSelectedServer() if err != nil { tick.Reset(steadyPollInterval) continue } if curr != prev { prev = curr - events.Emit(AutoSelectionsEvent{Selections: curr}) + events.Emit(AutoSelectedEvent{Selected: curr}) } tick.Reset(steadyPollInterval) } @@ -442,217 +449,146 @@ func AutoSelectionsChangeListener(ctx context.Context) { }() } -const urlTestHistoryFileName = "url_test_history.json" - -var urlTestMu sync.Mutex - -// RunURLTests performs URL tests for all outbounds defined in configs. It is intended to run in -// response to configuration updates to provide continuous bandit callback data even when the VPN -// tunnel is not active. When the tunnel IS active, its own CheckOutbounds handles URL testing, so -// this is skipped. -func RunURLTests(path string) { - // Skip if the tunnel is handling URL tests - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if isOpen(ctx) { - slog.Debug("Tunnel is active, skipping standalone URL tests") - return - } - - // Prevent overlapping runs - if !urlTestMu.TryLock() { - return - } - defer urlTestMu.Unlock() - - results, traceCtx, hasTrace, err := preTest(path) - if err != nil { - slog.Error("URL test failed", "error", err) - if len(results) == 0 { - return - } - // Tests ran but a non-critical step (e.g. saving history) failed. - // Continue to emit the span and log the results we do have. - } - - // Record URL test results in a span linked to the bandit's trace. - if hasTrace { - _, span := otel.Tracer(tracerName).Start(traceCtx, "radiance.url_tests_complete", - trace.WithAttributes( - attribute.Int("bandit.test_count", len(results)), - ), - ) - for tag, delay := range results { - span.AddEvent("url_test_result", trace.WithAttributes( - attribute.String("outbound", tag), - attribute.Int("latency_ms", int(delay)), - )) - } - span.End() - } - - var formattedResults []string - for tag, delay := range results { - formattedResults = append(formattedResults, fmt.Sprintf("%s: [%dms]", tag, delay)) - } - slog.Log(nil, internal.LevelTrace, "URL test complete", "results", strings.Join(formattedResults, "; ")) -} - -func preTest(path string) (map[string]uint16, context.Context, bool, error) { - slog.Info("Performing pre-start URL tests") - - confPath := filepath.Join(path, common.ConfigFileName) - slog.Debug("Loading config file", "confPath", confPath) - cfg, err := loadConfig(confPath) - if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to load config: %w", err) +// RunOfflineURLTests will run URL tests for all outbounds if the tunnel is not currently connected. +// This can improve initial connection times by pre-determining reachability and latency to servers. +// +// If [VPNClient.Connect] is called while RunOfflineURLTests is running, the tests will be cancelled and +// any results will be discarded. +func (c *VPNClient) RunOfflineURLTests(basePath string, outbounds []option.Outbound, banditURLs map[string]string) (map[string]uint16, error) { + c.mu.Lock() + if c.tunnel != nil { + c.mu.Unlock() + return nil, ErrTunnelAlreadyConnected + } + select { + case <-c.offlineTestDone: + // no tests currently running, safe to start new tests + default: + c.mu.Unlock() + return nil, errors.New("offline tests already running") } + ctx, cancel := context.WithCancel(box.BaseContext()) + c.offlineTestCancel = cancel + done := make(chan struct{}) + c.offlineTestDone = done + c.mu.Unlock() + defer close(done) // Extract bandit trace context for distributed tracing - traceCtx, hasTrace := traces.ExtractBanditTraceContext(cfg.BanditURLOverrides) - - cfgOpts := cfg.Options - - slog.Debug("Loading user servers") - userOpts, err := loadUserOptions(path) - if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to load user options: %w", err) - } + traceCtx, hasTrace := traces.ExtractBanditTraceContext(banditURLs) - // since we are only doing URL tests, we only need the outbounds from both configs; we skip - // endpoints as most/all require elevated privileges to use. just using outbounds is sufficient - // to improve initial connect times. - outbounds := append(cfgOpts.Outbounds, userOpts.Outbounds...) + c.logger.Info("Performing offline URL tests") tags := make([]string, 0, len(outbounds)) for _, ob := range outbounds { tags = append(tags, ob.Tag) } - // All outbounds get URL-tested — the server now sends callback - // URLs for every outbound, and the dependency's worker pool bounds memory. - outbounds = append(outbounds, urlTestOutbound("preTest", tags, cfg.BanditURLOverrides)) + outbounds = append(outbounds, urlTestOutbound("offline-test", tags, banditURLs)) options := option.Options{ Log: &option.LogOptions{Disabled: true}, Outbounds: outbounds, + Experimental: &option.ExperimentalOptions{ + CacheFile: &option.CacheFileOptions{ + Enabled: true, + Path: filepath.Join(basePath, cacheFileName), + CacheID: cacheID, + }, + }, } - // create pre-started box instance. we just use the standard box since we don't need a + // create offlineed box instance. we just use the standard box since we don't need a // platform interface for testing. - ctx := box.BaseContext() ctx = service.ContextWith[filemanager.Manager](ctx, nil) urlTestHistoryStorage := urltest.NewHistoryStorage() ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage) service.MustRegister[adapter.URLTestHistoryStorage](ctx, urlTestHistoryStorage) // for good measure - ctx, cancel := context.WithTimeout(ctx, 15*time.Second) // enough time for bandit callback tests through proxies + ctx, cancel = context.WithTimeout(ctx, 5*time.Second) // enough time for tests to complete or fail defer cancel() instance, err := sbox.New(sbox.Options{ Context: ctx, Options: options, }) if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to create sing-box instance: %w", err) + return nil, fmt.Errorf("failed to create sing-box instance: %w", err) } defer instance.Close() - if err := instance.PreStart(); err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to start sing-box instance: %w", err) - } - outbound, ok := instance.Outbound().Outbound("preTest") - if !ok { - return nil, context.Background(), false, errors.New("preTest outbound not found") + // connect may have been called while we were setting up, so check if we should abort before + // starting the instance. + select { + case <-ctx.Done(): + return nil, fmt.Errorf("offline tests cancelled: %w", ctx.Err()) + default: } - tester, ok := outbound.(adapter.URLTestGroup) - if !ok { - return nil, context.Background(), false, errors.New("preTest outbound is not a URLTestGroup") + if err := instance.PreStart(); err != nil { + return nil, fmt.Errorf("failed to start sing-box instance: %w", err) } + outbound, _ := instance.Outbound().Outbound("offline-test") + tester, _ := outbound.(adapter.URLTestGroup) // run URL tests results, err := tester.URLTest(ctx) if err != nil { - return nil, context.Background(), false, fmt.Errorf("failed to perform URL tests: %w", err) - } - - historyPath := filepath.Join(path, urlTestHistoryFileName) - if err := saveURLTestResults(urlTestHistoryStorage, historyPath, results); err != nil { - return results, traceCtx, hasTrace, fmt.Errorf("failed to save URL test results: %w", err) - } - return results, traceCtx, hasTrace, nil -} - - -func saveURLTestResults(storage *urltest.HistoryStorage, path string, results map[string]uint16) error { - slog.Debug("Saving URL test history", "path", path) - history := make(map[string]*adapter.URLTestHistory, len(results)) - for tag := range results { - history[tag] = storage.LoadURLTestHistory(tag) - } - buf, err := json.Marshal(history) - if err != nil { - return fmt.Errorf("failed to marshal URL test history: %w", err) - } - return atomicfile.WriteFile(path, buf, 0o644) -} - -func loadURLTestHistory(storage *urltest.HistoryStorage, path string) error { - slog.Debug("Loading URL test history", "path", path) - buf, err := atomicfile.ReadFile(path) - if errors.Is(err, os.ErrNotExist) { - return nil - } - if err != nil { - return fmt.Errorf("failed to read URL test history file: %w", err) - } - - history := make(map[string]*adapter.URLTestHistory) - if err := json.Unmarshal(buf, &history); err != nil { - return fmt.Errorf("failed to unmarshal URL test history: %w", err) - } - for tag, result := range history { - storage.StoreURLTestHistory(tag, result) + c.logger.Error("offline URL test failed", "error", err) + return nil, fmt.Errorf("offline URL test failed: %w", err) } - return nil -} - -func SmartRoutingEnabled() bool { - return settings.GetBool(settings.SmartRoutingKey) -} -func SetSmartRouting(enable bool) error { - if SmartRoutingEnabled() == enable { - return nil - } - if err := settings.Set(settings.SmartRoutingKey, enable); err != nil { - return err + // Record URL test results in a span linked to the bandit's trace. + if hasTrace { + _, span := otel.Tracer(tracerName).Start(traceCtx, "url_tests_complete", + trace.WithAttributes( + attribute.Int("bandit.test_count", len(results)), + ), + ) + for tag, delay := range results { + span.AddEvent("url_test_result", trace.WithAttributes( + attribute.String("outbound", tag), + attribute.Int("latency_ms", int(delay)), + )) + } + span.End() } - slog.Info("Updated Smart-Routing", "enabled", enable) - return restartTunnel() -} - -func AdBlockEnabled() bool { - return settings.GetBool(settings.AdBlockKey) -} -func SetAdBlock(enable bool) error { - if AdBlockEnabled() == enable { - return nil - } - if err := settings.Set(settings.AdBlockKey, enable); err != nil { - return err + var fmttedResults []string + for tag, delay := range results { + fmttedResults = append(fmttedResults, fmt.Sprintf("%s: [%dms]", tag, delay)) } - slog.Info("Updated Ad-Block", "enabled", enable) - return restartTunnel() + c.logger.Info("offline URL test complete") + c.logger.Log(nil, log.LevelTrace, "offline URL test results", "results", strings.Join(fmttedResults, "; ")) + return results, nil } -func restartTunnel() error { - ctx := context.Background() - if !isOpen(ctx) { - return nil - } - slog.Info("Restarting tunnel") - options, err := getOptions() +// AttemptFixNetState attempts to clear any error state left by a previous unclean shutdown, such +// as from a crash. No errors are returned and this fails silently. +func AttemptFixNetState() { + options := baseOpts("") + options = option.Options{ + DNS: options.DNS, + Inbounds: options.Inbounds, + Route: &option.RouteOptions{ + AutoDetectInterface: true, + Rules: []option.Rule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultRule{ + RawDefaultRule: option.RawDefaultRule{ + Protocol: []string{"dns"}, + }, + RuleAction: option.RuleAction{ + Action: C.RuleActionTypeHijackDNS, + }, + }, + }, + }, + }, + } + ctx, cancel := context.WithCancel(box.BaseContext()) + defer cancel() + b, err := sbox.New(sbox.Options{ + Context: ctx, + Options: options, + }) if err != nil { - return err - } - if err := ipc.RestartService(ctx, options); err != nil { - return fmt.Errorf("failed to restart tunnel: %w", err) + return } - return nil + defer b.Close() + b.Start() } diff --git a/vpn/vpn_test.go b/vpn/vpn_test.go index a3b2c8fc..eca0c09d 100644 --- a/vpn/vpn_test.go +++ b/vpn/vpn_test.go @@ -1,235 +1,244 @@ package vpn import ( - "context" - "slices" + "errors" + "log/slog" + "sync" "testing" - box "github.com/getlantern/lantern-box" - - "github.com/getlantern/radiance/common/settings" - "github.com/getlantern/radiance/internal/testutil" - "github.com/getlantern/radiance/servers" - "github.com/getlantern/radiance/vpn/ipc" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/experimental/cachefile" - "github.com/sagernet/sing-box/experimental/clashapi" "github.com/sagernet/sing-box/experimental/libbox" - "github.com/sagernet/sing/service" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) -func TestSelectServer(t *testing.T) { - var tests = []struct { - name string - initialGroup string - wantGroup string - wantTag string - }{ - { - name: "select in same group", - initialGroup: "socks", - wantGroup: "socks", - wantTag: "socks2-out", - }, - { - name: "select in different group", - initialGroup: "socks", - wantGroup: "http", - wantTag: "http2-out", - }, - } + rlog "github.com/getlantern/radiance/log" + "github.com/getlantern/radiance/servers" +) - testutil.SetPathsForTesting(t) - mservice := setupVpnTest(t) +// stubPlatform implements PlatformInterface for testing without real VPN operations. +type stubPlatform struct { + libbox.PlatformInterface - ctx := mservice.Ctx() - clashServer := service.FromContext[adapter.ClashServer](ctx).(*clashapi.Server) - outboundMgr := service.FromContext[adapter.OutboundManager](ctx) + restartErr error + restartCalled bool + postCloseCalled bool + restartFn func() error // optional hook invoked inside RestartService + mu sync.Mutex +} - type _selector interface { - adapter.OutboundGroup - Start() error - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // set initial group - clashServer.SetMode(tt.initialGroup) - - // start the selector - outbound, ok := outboundMgr.Outbound(tt.wantGroup) - require.True(t, ok, tt.wantGroup+" selector should exist") - selector := outbound.(_selector) - require.NoError(t, selector.Start(), "failed to start selector") - - mservice.status = ipc.Connected - require.NoError(t, SelectServer(context.Background(), tt.wantGroup, tt.wantTag)) - assert.Equal(t, tt.wantTag, selector.Now(), tt.wantTag+" should be selected") - assert.Equal(t, tt.wantGroup, clashServer.Mode(), "clash mode should be "+tt.wantGroup) - }) +func (s *stubPlatform) RestartService() error { + s.mu.Lock() + s.restartCalled = true + fn := s.restartFn + errRet := s.restartErr + s.mu.Unlock() + if fn != nil { + return fn() } + return errRet } -func TestSelectedServer(t *testing.T) { - wantGroup := "socks" - wantTag := "socks2-out" - - testutil.SetPathsForTesting(t) - opts, _, err := testBoxOptions(settings.GetString(settings.DataPathKey)) - require.NoError(t, err, "failed to load test box options") - cacheFile := cachefile.New(context.Background(), *opts.Experimental.CacheFile) - require.NoError(t, cacheFile.Start(adapter.StartStateInitialize)) - - require.NoError(t, cacheFile.StoreMode(wantGroup)) - require.NoError(t, cacheFile.StoreSelected(wantGroup, wantTag)) - _ = cacheFile.Close() - - t.Run("with tunnel open", func(t *testing.T) { - mservice := setupVpnTest(t) - outboundMgr := service.FromContext[adapter.OutboundManager](mservice.Ctx()) - require.NoError(t, outboundMgr.Start(adapter.StartStateStart), "failed to start outbound manager") - - group, tag, err := ipc.GetSelected(context.Background()) - require.NoError(t, err, "should not error when getting selected server") - assert.Equal(t, wantGroup, group, "group should match") - assert.Equal(t, wantTag, tag, "tag should match") - }) +func (s *stubPlatform) PostServiceClose() { + s.mu.Lock() + defer s.mu.Unlock() + s.postCloseCalled = true } -func TestAutoServerSelections(t *testing.T) { - testutil.SetPathsForTesting(t) - mgr := &mockOutMgr{ - outbounds: []adapter.Outbound{ - &mockOutbound{tag: "socks1-out"}, - &mockOutbound{tag: "socks2-out"}, - &mockOutbound{tag: "http1-out"}, - &mockOutbound{tag: "http2-out"}, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoLanternTag}, - now: "socks1-out", - all: []string{"socks1-out", "socks2-out"}, - }, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoUserTag}, - now: "http2-out", - all: []string{"http1-out", "http2-out"}, - }, - &mockOutboundGroup{ - mockOutbound: mockOutbound{tag: autoAllTag}, - now: autoLanternTag, - all: []string{autoLanternTag, autoUserTag}, - }, - }, - } - want := AutoSelections{ - Lantern: "socks1-out", - User: "http2-out", - AutoAll: "socks1-out", - } - ctx := box.BaseContext() - service.MustRegister[adapter.OutboundManager](ctx, mgr) - m := &mockService{ - ctx: ctx, - status: ipc.Connected, - } - ipcServer := ipc.NewServer(m) - require.NoError(t, ipcServer.Start()) +func TestNewVPNClient(t *testing.T) { + t.Run("nil logger defaults to slog.Default", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), nil, nil) + require.NotNil(t, c) + assert.Equal(t, slog.Default(), c.logger) + assert.Equal(t, Disconnected, c.Status()) + }) - got, err := AutoServerSelections() - require.NoError(t, err, "should not error when getting auto server selections") - require.Equal(t, want, got, "selections should match") + t.Run("custom logger is retained", func(t *testing.T) { + logger := rlog.NoOpLogger() + c := NewVPNClient(t.TempDir(), logger, nil) + require.NotNil(t, c) + assert.Equal(t, logger, c.logger) + }) } -type mockOutMgr struct { - adapter.OutboundManager - outbounds []adapter.Outbound -} +func TestStatus(t *testing.T) { + t.Run("disconnected when no tunnel", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + assert.Equal(t, Disconnected, c.Status()) + assert.False(t, c.isOpen()) + }) -func (o *mockOutMgr) Outbounds() []adapter.Outbound { - return o.outbounds + t.Run("concurrent reads", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = c.Status() + }() + } + wg.Wait() + }) } -func (o *mockOutMgr) Outbound(tag string) (adapter.Outbound, bool) { - idx := slices.IndexFunc(o.outbounds, func(ob adapter.Outbound) bool { - return ob.Tag() == tag +func TestConnect(t *testing.T) { + t.Run("already connected", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(Connected) + c.tunnel = &tunnel{} + + err := c.Connect(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelAlreadyConnected) + }) + + t.Run("transient state refused", func(t *testing.T) { + for _, status := range []VPNStatus{Restarting, Connecting, Disconnecting} { + t.Run(string(status), func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(status) + c.tunnel = &tunnel{} + + err := c.Connect(BoxOptions{}) + require.Error(t, err) + assert.Contains(t, err.Error(), string(status)) + }) + } + }) + + t.Run("cleans up stale tunnel", func(t *testing.T) { + for _, status := range []VPNStatus{Disconnected, ErrorStatus} { + t.Run(string(status), func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(status) + c.tunnel = &tunnel{} + + // Connect fails because BoxOptions has no outbounds, but the stale + // tunnel should be cleared first so the error comes from buildOptions. + err := c.Connect(BoxOptions{BasePath: t.TempDir()}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no outbounds") + }) + } }) - if idx == -1 { - return nil, false - } - return o.outbounds[idx], true } -type mockOutbound struct { - adapter.Outbound - tag string +func TestDisconnect_NoTunnel(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + assert.NoError(t, c.Disconnect()) } -func (o *mockOutbound) Tag() string { return o.tag } -func (o *mockOutbound) Type() string { return "mock" } +func TestRestart(t *testing.T) { + t.Run("no tunnel", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.Restart(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) + }) + + t.Run("tunnel not connected", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(Disconnected) + c.tunnel = &tunnel{} + + err := c.Restart(BoxOptions{}) + assert.ErrorIs(t, err, ErrTunnelNotConnected) + }) + + t.Run("platform interface success", func(t *testing.T) { + // While RestartService is in flight, VPNClient.Status() must report + // Restarting — bridging the window where the old tunnel is torn down and + // the new one has not yet reached Connected. Once RestartService returns + // successfully, status reflects the new tunnel's Connected state — which a + // real platform drives by calling VPNClient.Disconnect + Connect + // internally. The stub simulates that via a direct setStatus(Connected). + entered := make(chan struct{}) + release := make(chan struct{}) + p := &stubPlatform{} + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), p) + c.status.Store(Connected) + c.tunnel = &tunnel{} + + p.restartFn = func() error { + close(entered) + <-release + c.setStatus(Connected, nil) + return nil + } + + done := make(chan error, 1) + go func() { done <- c.Restart(BoxOptions{}) }() + + <-entered + assert.Equal(t, Restarting, c.Status(), "status should report Restarting while RestartService runs") + close(release) + + require.NoError(t, <-done) + + p.mu.Lock() + assert.True(t, p.restartCalled) + p.mu.Unlock() + assert.Equal(t, Connected, c.Status(), "status should reflect the new tunnel after restart completes") + }) -type mockOutboundGroup struct { - mockOutbound - now string - all []string + t.Run("platform interface error", func(t *testing.T) { + restartErr := errors.New("restart failed") + p := &stubPlatform{restartErr: restartErr} + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), p) + c.status.Store(Connected) + c.tunnel = &tunnel{} + + err := c.Restart(BoxOptions{}) + require.Error(t, err) + assert.ErrorIs(t, err, restartErr) + assert.Equal(t, ErrorStatus, c.Status()) + }) } -func (o *mockOutboundGroup) Now() string { return o.now } -func (o *mockOutboundGroup) All() []string { return o.all } +func TestSelectServer(t *testing.T) { + t.Run("no tunnel", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + err := c.SelectServer("some-tag") + assert.ErrorIs(t, err, ErrTunnelNotConnected) + }) -var _ ipc.Service = (*mockService)(nil) + t.Run("tunnel disconnected", func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(Disconnected) + c.tunnel = &tunnel{} -type mockService struct { - ctx context.Context - status ipc.VPNStatus - clash *clashapi.Server + err := c.SelectServer("some-tag") + assert.ErrorIs(t, err, ErrTunnelNotConnected) + }) } -func (m *mockService) Ctx() context.Context { return m.ctx } -func (m *mockService) Status() ipc.VPNStatus { return m.status } -func (m *mockService) ClashServer() *clashapi.Server { return m.clash } -func (m *mockService) Close() error { return nil } -func (m *mockService) Start(context.Context, string) error { return nil } -func (m *mockService) Restart(context.Context, string) error { return nil } -func (m *mockService) UpdateOutbounds(options servers.Servers) error { return nil } -func (m *mockService) AddOutbounds(group string, options servers.Options) error { return nil } -func (m *mockService) RemoveOutbounds(group string, tags []string) error { return nil } - -func setupVpnTest(t *testing.T) *mockService { - path := settings.GetString(settings.DataPathKey) - setupOpts := libbox.SetupOptions{ - BasePath: path, - WorkingPath: path, - TempPath: path, +func TestNoTunnelOperations(t *testing.T) { + ops := map[string]func(*VPNClient) error{ + "UpdateOutbounds": func(c *VPNClient) error { return c.UpdateOutbounds(servers.ServerList{}) }, + "AddOutbounds": func(c *VPNClient) error { return c.AddOutbounds(servers.ServerList{}) }, + "RemoveOutbounds": func(c *VPNClient) error { return c.RemoveOutbounds([]string{"tag1"}) }, + "Connections": func(c *VPNClient) error { + _, err := c.Connections() + return err + }, } - require.NoError(t, libbox.Setup(&setupOpts)) - - _, boxOpts, err := testBoxOptions(path) - require.NoError(t, err, "failed to load test box options") + for name, op := range ops { + t.Run(name, func(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + assert.ErrorIs(t, op(c), ErrTunnelNotConnected) + }) + } +} - ctx := box.BaseContext() +func TestCurrentAutoSelectedServer_NotOpen(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + selected, err := c.CurrentAutoSelectedServer() + assert.NoError(t, err) + assert.Empty(t, selected) +} - lb, err := libbox.NewServiceWithContext(ctx, boxOpts, nil) - require.NoError(t, err) - clashServer := service.FromContext[adapter.ClashServer](ctx) - cacheFile := service.FromContext[adapter.CacheFile](ctx) +func TestRunOfflineURLTests_AlreadyConnected(t *testing.T) { + c := NewVPNClient(t.TempDir(), rlog.NoOpLogger(), nil) + c.status.Store(Connected) + c.tunnel = &tunnel{} - m := &mockService{ - ctx: ctx, - status: ipc.Connected, - clash: clashServer.(*clashapi.Server), - } - ipcServer := ipc.NewServer(m) - require.NoError(t, ipcServer.Start()) - - t.Cleanup(func() { - lb.Close() - ipcServer.Close() - cacheFile.Close() - clashServer.Close() - }) - require.NoError(t, cacheFile.Start(adapter.StartStateInitialize)) - require.NoError(t, clashServer.Start(adapter.StartStateStart)) - return m + _, err := c.RunOfflineURLTests("", nil, nil) + assert.ErrorIs(t, err, ErrTunnelAlreadyConnected) }