From 904e78a7ec399906c481b5e10c2837e3662a4f81 Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Thu, 11 Jun 2026 15:14:07 +0000 Subject: [PATCH 1/2] feat(cli): expand test coverage, unify quota delegation, improve docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add behavioral tests across all CLI commands covering service calls, parameter handling, auth gates, and error paths - Add delegation error-path and profiling tests for admin service - Unify quota admin methods with generic helpers (with2/with3/with0/with2i), reducing admin_service.go boilerplate - Extract commandGetter interface and shared mockCommand for consistent test patterns across handlers - Extract prompt interfaces (Prompt/Confirm/Select) from handlers into testable abstractions - Add Makefile with git-based ldflags (version, commit, branch, build time, platform) replacing hardcoded build info - Reorder install docs: releases → go install → go build - Fix README alias framing: "not deprecated" → "first-class shortcuts" --- AGENTS.md | 11 +- Makefile | 28 + README.md | 35 +- pkg/cli/account.go | 8 +- pkg/cli/account_api_keys.go | 46 +- pkg/cli/account_api_keys_service_test.go | 37 + pkg/cli/account_api_keys_test.go | 271 ++++ pkg/cli/account_test.go | 144 +- pkg/cli/admin.go | 32 +- pkg/cli/admin_billing_credits.go | 73 +- pkg/cli/admin_billing_credits_test.go | 473 +++++- pkg/cli/admin_billing_overview.go | 7 +- pkg/cli/admin_billing_overview_test.go | 227 +++ pkg/cli/admin_billing_price_lines.go | 89 +- pkg/cli/admin_billing_price_lines_test.go | 283 +--- pkg/cli/admin_billing_pricing_plans.go | 125 +- pkg/cli/admin_billing_pricing_plans_test.go | 590 +++---- pkg/cli/admin_billing_subscribers.go | 80 +- pkg/cli/admin_billing_subscribers_test.go | 165 +- pkg/cli/admin_pprof.go | 36 +- pkg/cli/admin_pprof_test.go | 267 ++++ pkg/cli/admin_quota.go | 131 +- pkg/cli/admin_quota_actions_test.go | 517 +++++++ pkg/cli/admin_quota_create_update_test.go | 321 +--- pkg/cli/admin_quota_test.go | 227 ++- pkg/cli/admin_service.go | 111 +- pkg/cli/admin_service_delegates_test.go | 473 ++++++ pkg/cli/admin_service_test.go | 16 +- pkg/cli/admin_service_unauth_test.go | 193 +++ pkg/cli/admin_test.go | 38 +- pkg/cli/admin_token_provider_test.go | 79 + pkg/cli/admin_websites.go | 16 +- pkg/cli/admin_websites_test.go | 4 +- pkg/cli/auth.go | 14 +- pkg/cli/auth_service_test.go | 263 +++- pkg/cli/auth_status_test.go | 189 ++- pkg/cli/bench.go | 48 +- pkg/cli/bench_service.go | 2 + pkg/cli/bench_test.go | 583 +++++-- pkg/cli/command_docs_test.go | 121 ++ pkg/cli/command_getter.go | 86 ++ pkg/cli/command_helper.go | 42 + pkg/cli/command_helper_test.go | 87 ++ pkg/cli/command_registration_test.go | 803 ++++++++++ pkg/cli/command_wrapper.go | 13 + pkg/cli/command_wrapper_test.go | 93 ++ pkg/cli/config.go | 12 +- pkg/cli/config_test.go | 299 ++++ pkg/cli/confirm_email.go | 4 +- pkg/cli/confirm_email_test.go | 126 ++ pkg/cli/dns.go | 260 +--- pkg/cli/dns_helpers_test.go | 148 ++ pkg/cli/dns_service.go | 91 +- pkg/cli/dns_service_crud_test.go | 169 ++ pkg/cli/dns_service_test.go | 115 ++ pkg/cli/dns_test.go | 875 +++++++++++ pkg/cli/doctor.go | 4 +- pkg/cli/doctor_test.go | 127 +- pkg/cli/download.go | 71 +- pkg/cli/download_client_test.go | 444 ++++++ pkg/cli/download_test.go | 175 +-- pkg/cli/error_formatter_test.go | 111 ++ pkg/cli/flags_test.go | 93 ++ pkg/cli/ipfs_service_base.go | 35 + pkg/cli/ipns.go | 132 +- pkg/cli/ipns_service.go | 90 +- pkg/cli/ipns_service_crud_test.go | 430 ++++++ pkg/cli/ipns_service_test.go | 120 +- pkg/cli/ipns_test.go | 1376 ++++------------- pkg/cli/list.go | 35 +- pkg/cli/list_test.go | 104 +- pkg/cli/metadata_removed_test.go | 29 + pkg/cli/operations.go | 11 +- pkg/cli/operations_test.go | 446 ++++-- pkg/cli/output_test.go | 105 +- pkg/cli/pin.go | 34 +- pkg/cli/pin_test.go | 129 +- pkg/cli/pinning_client_batch_test.go | 522 +++++++ pkg/cli/pinning_client_test.go | 66 + pkg/cli/pinning_service_test.go | 72 +- pkg/cli/pins_add.go | 24 +- pkg/cli/pins_add_test.go | 170 ++ pkg/cli/pins_ls.go | 10 +- pkg/cli/pins_rm.go | 10 +- pkg/cli/pins_status.go | 7 +- pkg/cli/pins_update.go | 38 +- pkg/cli/pins_update_test.go | 220 +-- pkg/cli/progress_test.go | 30 + pkg/cli/prompt_interfaces.go | 34 + pkg/cli/prompt_mock.go | 60 + pkg/cli/prompt_pterm.go | 95 ++ pkg/cli/register.go | 4 +- pkg/cli/register_test.go | 179 +++ pkg/cli/restructure_integration_test.go | 30 +- pkg/cli/root_test.go | 73 + pkg/cli/setup.go | 4 +- pkg/cli/setup_mock.go | 13 + pkg/cli/setup_pterm.go | 76 +- pkg/cli/setup_test.go | 32 + pkg/cli/setup_ui.go | 2 + pkg/cli/sources_test.go | 28 + pkg/cli/status.go | 31 +- pkg/cli/status_service_test.go | 18 +- pkg/cli/status_test.go | 193 ++- pkg/cli/testutils.go | 292 ++++ pkg/cli/testutils_test.go | 31 + pkg/cli/unpin.go | 33 +- pkg/cli/unpin_all.go | 51 +- pkg/cli/unpin_all_test.go | 215 ++- pkg/cli/unpin_test.go | 106 +- pkg/cli/upload.go | 35 +- pkg/cli/upload_client_test.go | 191 ++- pkg/cli/upload_client_tus_integration_test.go | 2 +- pkg/cli/upload_test.go | 95 +- pkg/cli/utils_test.go | 43 +- pkg/cli/version.go | 4 + pkg/cli/version_test.go | 80 + pkg/cli/websites.go | 136 +- pkg/cli/websites_handler_test.go | 791 ++++++++++ pkg/cli/websites_required_records_test.go | 145 ++ pkg/cli/websites_service.go | 91 +- pkg/cli/websites_service_crud_test.go | 124 ++ pkg/cli/websites_service_test.go | 71 + pkg/cli/websites_ssl.go | 29 +- pkg/cli/websites_test.go | 487 +++--- pkg/cli/websites_wizard.go | 8 +- pkg/cli/websites_wizard_mock.go | 42 +- pkg/cli/websites_wizard_pterm.go | 48 +- pkg/cli/websites_wizard_test.go | 34 +- pkg/cli/websites_wizard_ui.go | 3 + 130 files changed, 13943 insertions(+), 4982 deletions(-) create mode 100644 Makefile create mode 100644 pkg/cli/account_api_keys_service_test.go create mode 100644 pkg/cli/admin_billing_overview_test.go create mode 100644 pkg/cli/admin_pprof_test.go create mode 100644 pkg/cli/admin_quota_actions_test.go create mode 100644 pkg/cli/admin_service_delegates_test.go create mode 100644 pkg/cli/admin_service_unauth_test.go create mode 100644 pkg/cli/admin_token_provider_test.go create mode 100644 pkg/cli/command_docs_test.go create mode 100644 pkg/cli/command_getter.go create mode 100644 pkg/cli/command_helper_test.go create mode 100644 pkg/cli/command_registration_test.go create mode 100644 pkg/cli/command_wrapper_test.go create mode 100644 pkg/cli/confirm_email_test.go create mode 100644 pkg/cli/dns_helpers_test.go create mode 100644 pkg/cli/dns_service_crud_test.go create mode 100644 pkg/cli/dns_service_test.go create mode 100644 pkg/cli/dns_test.go create mode 100644 pkg/cli/download_client_test.go create mode 100644 pkg/cli/ipfs_service_base.go create mode 100644 pkg/cli/ipns_service_crud_test.go create mode 100644 pkg/cli/metadata_removed_test.go create mode 100644 pkg/cli/pinning_client_batch_test.go create mode 100644 pkg/cli/pinning_client_test.go create mode 100644 pkg/cli/pins_add_test.go create mode 100644 pkg/cli/prompt_interfaces.go create mode 100644 pkg/cli/prompt_mock.go create mode 100644 pkg/cli/prompt_pterm.go create mode 100644 pkg/cli/register_test.go create mode 100644 pkg/cli/root_test.go create mode 100644 pkg/cli/sources_test.go create mode 100644 pkg/cli/testutils_test.go create mode 100644 pkg/cli/version_test.go create mode 100644 pkg/cli/websites_handler_test.go create mode 100644 pkg/cli/websites_required_records_test.go create mode 100644 pkg/cli/websites_service_crud_test.go diff --git a/AGENTS.md b/AGENTS.md index 4629dfe..8edcd44 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,16 +5,19 @@ This file provides guidance to various AI agents when working with code in this ### Building ```bash -# Build for current platform +# Build with version info (recommended) +make build + +# Or install to $GOPATH/bin +make install + +# Build without make go build -o pinner ./cmd/pinner # Cross-compile for different platforms GOOS=linux GOARCH=amd64 go build -o pinner-linux-amd64 ./cmd/pinner GOOS=darwin GOARCH=arm64 go build -o pinner-darwin-arm64 ./cmd/pinner GOOS=windows GOARCH=amd64 go build -o pinner-windows-amd64.exe ./cmd/pinner - -# Build with version info -go build -ldflags="-X 'build.Version=1.0.0' -X 'build.GitCommit=abc123'" -o pinner ./cmd/pinner ``` ### Testing diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a154cb7 --- /dev/null +++ b/Makefile @@ -0,0 +1,28 @@ +.PHONY: build install clean + +VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo dev) +GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) +GIT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null || echo unknown) +BUILD_TIME := $(shell date -u +%Y-%m-%dT%H:%M:%SZ) +GO_VERSION := $(shell go version | sed 's/go version //') +PLATFORM := $(shell go env GOOS) +ARCH := $(shell go env GOARCH) + +PKG := go.lumeweb.com/pinner-cli/build + +LDFLAGS := -X '$(PKG).Version=$(VERSION)' \ + -X '$(PKG).GitCommit=$(GIT_COMMIT)' \ + -X '$(PKG).GitBranch=$(GIT_BRANCH)' \ + -X '$(PKG).BuildTime=$(BUILD_TIME)' \ + -X '$(PKG).GoVersion=$(GO_VERSION)' \ + -X '$(PKG).Platform=$(PLATFORM)' \ + -X '$(PKG).Architecture=$(ARCH)' + +build: + go build -ldflags="$(LDFLAGS)" -o pinner ./cmd/pinner + +install: + go install -ldflags="$(LDFLAGS)" ./cmd/pinner + +clean: + rm -f pinner diff --git a/README.md b/README.md index 7a19e77..eea78ca 100644 --- a/README.md +++ b/README.md @@ -14,24 +14,35 @@ A developer-focused CLI for pinning content to IPFS, managing websites, DNS, and ## Installation +### From Releases + +Download the latest binary from the [releases page](https://github.com/lumeweb/pinner-cli/releases). + +### From go install + +```bash +go install github.com/lumeweb/pinner-cli/cmd/pinner@latest +``` + ### From Source ```bash -# Build for current platform +# Build with version info from git (recommended) +make build + +# Or install to $GOPATH/bin +make install + +# Build without make go build -o pinner ./cmd/pinner # Cross-compile for different platforms GOOS=linux GOARCH=amd64 go build -o pinner-linux-amd64 ./cmd/pinner GOOS=darwin GOARCH=arm64 go build -o pinner-darwin-arm64 ./cmd/pinner GOOS=windows GOARCH=amd64 go build -o pinner-windows-amd64.exe ./cmd/pinner - -# Build with version info -go build -ldflags="-X 'build.Version=1.0.0' -X 'build.GitCommit=abc123'" -o pinner ./cmd/pinner ``` -### From Pre-built Binaries - -Download the latest release from the [releases page](https://github.com/lumeweb/pinner-cli/releases). +The Makefile injects git commit, branch, version, build time, and platform info via ldflags. ## Quick Start @@ -277,7 +288,7 @@ pinner pins rm --all --status failed --force pinner pins add bafybeig... --dry-run ``` -**Note**: `pin`, `list`, `status`, and `unpin` are aliases for `pins add`, `pins ls`, `pins status`, and `pins rm` respectively. They work identically and are not deprecated. +**Note**: `pin`, `list`, `status`, and `unpin` are first-class shortcuts for `pins add`, `pins ls`, `pins status`, and `pins rm` respectively — use whichever you prefer. ### Pin @@ -837,16 +848,16 @@ mockery --name=AuthService ### Building ```bash -# Build for current platform +# Build with version info (recommended) +make build + +# Build without make go build -o pinner ./cmd/pinner # Cross-compile for different platforms GOOS=linux GOARCH=amd64 go build -o pinner-linux-amd64 ./cmd/pinner GOOS=darwin GOARCH=arm64 go build -o pinner-darwin-arm64 ./cmd/pinner GOOS=windows GOARCH=amd64 go build -o pinner-windows-amd64.exe ./cmd/pinner - -# Build with version info -go build -ldflags="-X 'build.Version=1.0.0' -X 'build.GitCommit=abc123'" -o pinner ./cmd/pinner ``` ### Running the CLI diff --git a/pkg/cli/account.go b/pkg/cli/account.go index 3f381da..7b33b53 100644 --- a/pkg/cli/account.go +++ b/pkg/cli/account.go @@ -54,7 +54,7 @@ After successful verification, 2FA will be required for all future logins.`, }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return accountOTPEnable(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory) + return accountOTPEnable(ctx, newCLICommandWrapper(cmd), output, defaultConfigManagerFactory, defaultAuthServiceFactory) }, }, { @@ -78,14 +78,14 @@ WARNING: This reduces your account security. Consider re-enabling 2FA.`, }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return accountOTPDisable(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory) + return accountOTPDisable(ctx, newCLICommandWrapper(cmd), output, defaultConfigManagerFactory, defaultAuthServiceFactory) }, }, }, } } -func accountOTPEnable(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { +func accountOTPEnable(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to initialize config manager: %w", err) @@ -99,7 +99,7 @@ func accountOTPEnable(ctx context.Context, cmd *cli.Command, output Output, cfgM return authService.EnableOTP(ctx, otpCode) } -func accountOTPDisable(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { +func accountOTPDisable(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to initialize config manager: %w", err) diff --git a/pkg/cli/account_api_keys.go b/pkg/cli/account_api_keys.go index d1e663e..55c60b5 100644 --- a/pkg/cli/account_api_keys.go +++ b/pkg/cli/account_api_keys.go @@ -6,6 +6,7 @@ import ( "github.com/urfave/cli/v3" portalsdk "go.lumeweb.com/portal-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newAccountAPIKeysCommand() *cli.Command { @@ -41,7 +42,12 @@ Use --search to filter keys by name.`, }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return accountAPIKeysList(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(cmd, cfgMgr) + return accountAPIKeysList(ctx, newCLICommandWrapper(cmd), output, cfgMgr, authToken, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) }, }, { @@ -58,7 +64,12 @@ This key can be used with: PINNER_AUTH_TOKEN= pinner `, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return accountAPIKeysCreate(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(cmd, cfgMgr) + return accountAPIKeysCreate(ctx, newCLICommandWrapper(cmd), output, cfgMgr, authToken, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) }, }, { @@ -79,21 +90,20 @@ current key, you must re-authenticate with 'pinner auth'.`, }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return accountAPIKeysDelete(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(cmd, cfgMgr) + return accountAPIKeysDelete(ctx, newCLICommandWrapper(cmd), output, cfgMgr, authToken, defaultAuthServiceFactory, defaultAPIKeyServiceFactory) }, }, }, } } -func accountAPIKeysList(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return fmt.Errorf("failed to initialize config manager: %w", err) - } - +func accountAPIKeysList(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, authToken string, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { apiEndpoint := cfgMgr.Config().GetAPIEndpoint() - authToken := GetAuthToken(cmd, cfgMgr) authService := authServiceFactory(cfgMgr, output, apiEndpoint) svc := svcFactory(authService, authToken) @@ -136,14 +146,8 @@ func accountAPIKeysList(ctx context.Context, cmd *cli.Command, output Output, cf return nil } -func accountAPIKeysCreate(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return fmt.Errorf("failed to initialize config manager: %w", err) - } - +func accountAPIKeysCreate(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { apiEndpoint := cfgMgr.Config().GetAPIEndpoint() - authToken := GetAuthToken(cmd, cfgMgr) authService := authServiceFactory(cfgMgr, output, apiEndpoint) svc := svcFactory(authService, authToken) @@ -177,14 +181,8 @@ func accountAPIKeysCreate(ctx context.Context, cmd *cli.Command, output Output, return nil } -func accountAPIKeysDelete(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return fmt.Errorf("failed to initialize config manager: %w", err) - } - +func accountAPIKeysDelete(ctx context.Context, cmd argsFlagGetterWithBool, output Output, cfgMgr config.Manager, authToken string, authServiceFactory AuthServiceFactory, svcFactory APIKeyServiceFactory) error { apiEndpoint := cfgMgr.Config().GetAPIEndpoint() - authToken := GetAuthToken(cmd, cfgMgr) authService := authServiceFactory(cfgMgr, output, apiEndpoint) svc := svcFactory(authService, authToken) diff --git a/pkg/cli/account_api_keys_service_test.go b/pkg/cli/account_api_keys_service_test.go new file mode 100644 index 0000000..729ca85 --- /dev/null +++ b/pkg/cli/account_api_keys_service_test.go @@ -0,0 +1,37 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAPIKeyService(t *testing.T) { + mockAuth := NewMockAuthService(t) + svc := NewAPIKeyService(mockAuth, "test-token") + require.NotNil(t, svc) +} + +func TestDefaultAPIKeyServiceFactory(t *testing.T) { + mockAuth := NewMockAuthService(t) + svc := defaultAPIKeyServiceFactory(mockAuth, "test-token") + require.NotNil(t, svc) +} + +func TestAPIKeyServiceRequireAuthenticated(t *testing.T) { + t.Run("authenticated with token", func(t *testing.T) { + mockAuth := NewMockAuthService(t) + svc := NewAPIKeyService(mockAuth, "test-token") + err := svc.RequireAuthenticated() + assert.NoError(t, err) + }) + + t.Run("not authenticated without token", func(t *testing.T) { + mockAuth := NewMockAuthService(t) + svc := NewAPIKeyService(mockAuth, "") + err := svc.RequireAuthenticated() + assert.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + }) +} diff --git a/pkg/cli/account_api_keys_test.go b/pkg/cli/account_api_keys_test.go index 0a4271e..a4bd8b2 100644 --- a/pkg/cli/account_api_keys_test.go +++ b/pkg/cli/account_api_keys_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" portalsdk "go.lumeweb.com/portal-sdk" portalsdkmocks "go.lumeweb.com/portal-sdk/mocks" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newTestAPIKey(name, uuidStr string) *portalsdk.APIKey { @@ -326,6 +327,276 @@ func TestIsUUIDString(t *testing.T) { } } +func TestNewAccountAPIKeysCommand(t *testing.T) { + cmd := newAccountAPIKeysCommand() + require.Equal(t, "api-keys", cmd.Name) + require.Contains(t, cmd.Aliases, "apikey") + require.Contains(t, cmd.Aliases, "api-key") + require.Len(t, cmd.Commands, 3) +} + +type mockAPIKeyServiceForCLI struct { + listFunc func(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) + createFunc func(ctx context.Context, name string) (*portalsdk.APIKey, error) + deleteFunc func(ctx context.Context, idOrName string, force bool) error + currentUUIDFunc func() string + requireAuthErr error +} + +func (m *mockAPIKeyServiceForCLI) ListAPIKeys(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) { + if m.listFunc != nil { + return m.listFunc(ctx, search) + } + return nil, 0, nil +} + +func (m *mockAPIKeyServiceForCLI) CreateAPIKey(ctx context.Context, name string) (*portalsdk.APIKey, error) { + if m.createFunc != nil { + return m.createFunc(ctx, name) + } + return nil, nil +} + +func (m *mockAPIKeyServiceForCLI) DeleteAPIKey(ctx context.Context, idOrName string, force bool) error { + if m.deleteFunc != nil { + return m.deleteFunc(ctx, idOrName, force) + } + return nil +} + +func (m *mockAPIKeyServiceForCLI) GetCurrentAPIKeyUUID() string { + if m.currentUUIDFunc != nil { + return m.currentUUIDFunc() + } + return "" +} + +func (m *mockAPIKeyServiceForCLI) RequireAuthenticated() error { + return m.requireAuthErr +} + +func setupAPIKeyHandlerTest(t *testing.T) (*mockAPIKeyServiceForCLI, config.Manager) { + t.Helper() + mockSvc := &mockAPIKeyServiceForCLI{} + cfgMgr := newTestConfigMgr(t) + return mockSvc, cfgMgr +} + +func TestAccountAPIKeysList_Success(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) { + return []*portalsdk.APIKey{ + newTestAPIKey("my-key", "00000000-0000-0000-0000-000000000001"), + }, 1, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := accountAPIKeysList(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + +func TestAccountAPIKeysList_Empty(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) { + return []*portalsdk.APIKey{}, 0, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := accountAPIKeysList(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + +func TestAccountAPIKeysList_WithSearch(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) { + require.Equal(t, "my-key", search) + return []*portalsdk.APIKey{}, 0, nil + } + + output := newTestOutput() + cmd := newMockCommand().withString(FlagSearch, "my-key") + err := accountAPIKeysList(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + +func TestAccountAPIKeysList_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context, search string) ([]*portalsdk.APIKey, int, error) { + return nil, 0, fmt.Errorf("server error") + } + + output := newTestOutput() + cmd := newMockCommand() + err := accountAPIKeysList(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to list API keys") +} + +func TestAccountAPIKeysCreate_Success(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.createFunc = func(ctx context.Context, name string) (*portalsdk.APIKey, error) { + require.Equal(t, "my-new-key", name) + return portalsdk.NewAPIKey("my-new-key", "generated-token"), nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("my-new-key") + err := accountAPIKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + +func TestAccountAPIKeysCreate_MissingName(t *testing.T) { + _, cfgMgr := setupAPIKeyHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := accountAPIKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return &mockAPIKeyServiceForCLI{} + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "API key name is required") +} + +func TestAccountAPIKeysCreate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.createFunc = func(ctx context.Context, name string) (*portalsdk.APIKey, error) { + return nil, fmt.Errorf("duplicate key name") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("dup-key") + err := accountAPIKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create API key") +} + +func TestAccountAPIKeysDelete_Success(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.deleteFunc = func(ctx context.Context, idOrName string, force bool) error { + require.Equal(t, "00000000-0000-0000-0000-000000000001", idOrName) + require.False(t, force) + return nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("00000000-0000-0000-0000-000000000001") + err := accountAPIKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + +func TestAccountAPIKeysDelete_MissingArg(t *testing.T) { + _, cfgMgr := setupAPIKeyHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := accountAPIKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return &mockAPIKeyServiceForCLI{} + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "API key UUID or name is required") +} + +func TestAccountAPIKeysDelete_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.deleteFunc = func(ctx context.Context, idOrName string, force bool) error { + return fmt.Errorf("not found") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("nonexistent").withBool(FlagForce, true) + err := accountAPIKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") +} + +func TestAccountAPIKeysDelete_WithForce(t *testing.T) { + mockSvc, cfgMgr := setupAPIKeyHandlerTest(t) + mockSvc.deleteFunc = func(ctx context.Context, idOrName string, force bool) error { + require.True(t, force) + return nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("my-key").withBool(FlagForce, true) + err := accountAPIKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", + func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return NewMockAuthService(t) + }, + func(authService AuthService, authToken string) APIKeyService { + return mockSvc + }, + ) + require.NoError(t, err) +} + // makeAPIKeyJWT creates a minimal JWT string with the given subject and audience. func makeAPIKeyJWT(sub, aud string) string { header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) diff --git a/pkg/cli/account_test.go b/pkg/cli/account_test.go index 3c334e4..64956b6 100644 --- a/pkg/cli/account_test.go +++ b/pkg/cli/account_test.go @@ -54,7 +54,7 @@ func TestAccountOTPEnable(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) authService := NewMockAuthService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() var cfgMgrFactory ConfigManagerFactory if tt.name == "config manager factory fails" { @@ -85,7 +85,7 @@ func TestAccountOTPEnable(t *testing.T) { } } - err := accountOTPEnable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + err := accountOTPEnable(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) if tt.wantErr { require.Error(t, err) @@ -139,7 +139,7 @@ func TestAccountOTPDisable(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) authService := NewMockAuthService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() var cfgMgrFactory ConfigManagerFactory if tt.name == "config manager factory fails" { @@ -170,7 +170,7 @@ func TestAccountOTPDisable(t *testing.T) { } } - err := accountOTPDisable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + err := accountOTPDisable(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) if tt.wantErr { require.Error(t, err) @@ -229,7 +229,7 @@ func TestAuthService_EnableOTP(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) prompter := NewMockAuthPrompter(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Mock Config() to return a config with a login JWT cfg := config.NewConfig() @@ -294,7 +294,7 @@ func TestAuthService_EnableOTP_Interactive(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) prompter := NewMockAuthPrompter(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Mock Config() to return a config with a login JWT cfg := config.NewConfig() @@ -359,7 +359,7 @@ func TestAuthService_DisableOTP(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Mock Config() to return a config with a login JWT cfg := config.NewConfig() @@ -391,6 +391,134 @@ func TestAuthService_DisableOTP(t *testing.T) { } } +func TestAccountOTPEnable_MockCommand_Success(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().EnableOTP(context.Background(), "123456").Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand().withString(FlagOTP, "123456") + err := accountOTPEnable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestAccountOTPEnable_MockCommand_NoOTP(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().EnableOTP(context.Background(), "").Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand() + err := accountOTPEnable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestAccountOTPEnable_MockCommand_ServiceError(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().EnableOTP(context.Background(), "000000"). + Return(errors.New("invalid OTP code")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand().withString(FlagOTP, "000000") + err := accountOTPEnable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid OTP code") +} + +func TestAccountOTPEnable_MockCommand_ConfigError(t *testing.T) { + output := newTestOutput() + + cmd := newMockCommand().withString(FlagOTP, "123456") + err := accountOTPEnable(context.Background(), cmd, output, failingConfigMgrFactory(), func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return nil + }) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestAccountOTPDisable_MockCommand_Success(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().DisableOTP(context.Background(), "mypassword").Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand().withString("password", "mypassword") + err := accountOTPDisable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestAccountOTPDisable_MockCommand_EmptyPassword(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().DisableOTP(context.Background(), "").Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand() + err := accountOTPDisable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestAccountOTPDisable_MockCommand_ServiceError(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().DisableOTP(context.Background(), "wrong"). + Return(errors.New("invalid password")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand().withString("password", "wrong") + err := accountOTPDisable(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid password") +} + +func TestAccountOTPDisable_MockCommand_ConfigError(t *testing.T) { + output := newTestOutput() + + cmd := newMockCommand().withString("password", "test") + err := accountOTPDisable(context.Background(), cmd, output, failingConfigMgrFactory(), func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return nil + }) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize config manager") +} + func TestAuthService_DisableOTP_Interactive(t *testing.T) { tests := []struct { name string @@ -421,7 +549,7 @@ func TestAuthService_DisableOTP_Interactive(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) prompter := NewMockAuthPrompter(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Mock Config() to return a config with a login JWT (optional, for early exit cases) cfg := config.NewConfig() diff --git a/pkg/cli/admin.go b/pkg/cli/admin.go index 85fc017..a87f7e7 100644 --- a/pkg/cli/admin.go +++ b/pkg/cli/admin.go @@ -129,7 +129,7 @@ Examples: if err != nil { return err } - return quotaPlansListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansListAction(ctx, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -146,7 +146,7 @@ Examples: if err != nil { return err } - return quotaPlansGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansGetAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -199,7 +199,7 @@ Examples: if err != nil { return err } - return quotaPlansCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansCreateAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -250,7 +250,7 @@ Examples: if err != nil { return err } - return quotaPlansUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansUpdateAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -266,7 +266,7 @@ Examples: if err != nil { return err } - return quotaPlansDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansDeleteAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -282,7 +282,7 @@ Examples: if err != nil { return err } - return quotaPlansSetDefaultAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaPlansSetDefaultAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, }, @@ -312,7 +312,7 @@ Examples: if err != nil { return err } - return quotaAllowancesListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaAllowancesListAction(ctx, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -357,7 +357,7 @@ Examples: if err != nil { return err } - return quotaAllowancesCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaAllowancesCreateAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -403,7 +403,7 @@ Examples: if err != nil { return err } - return quotaAllowancesUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaAllowancesUpdateAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -419,7 +419,7 @@ Examples: if err != nil { return err } - return quotaAllowancesDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaAllowancesDeleteAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, }, @@ -450,7 +450,7 @@ Examples: if err != nil { return err } - return quotaUserConfigsListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaUserConfigsListAction(ctx, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -522,7 +522,7 @@ Examples: if err != nil { return err } - return quotaUserConfigsUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaUserConfigsUpdateAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, { @@ -538,7 +538,7 @@ Examples: if err != nil { return err } - return quotaUserConfigsResetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaUserConfigsResetAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, }, }, @@ -559,7 +559,7 @@ Examples: if err != nil { return err } - return quotaStatsAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaStatsAction(ctx, output, cfgMgr, defaultQuotaAdminServiceFactory) }, } } @@ -584,7 +584,7 @@ Examples: if err != nil { return err } - return quotaReconcileAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaReconcileAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, } } @@ -609,7 +609,7 @@ Examples: if err != nil { return err } - return quotaCleanupAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultQuotaAdminServiceFactory) + return quotaCleanupAction(ctx, cmd, output, cfgMgr, defaultQuotaAdminServiceFactory) }, } } diff --git a/pkg/cli/admin_billing_credits.go b/pkg/cli/admin_billing_credits.go index cf61b08..4eac0c1 100644 --- a/pkg/cli/admin_billing_credits.go +++ b/pkg/cli/admin_billing_credits.go @@ -69,17 +69,12 @@ Examples: if err != nil { return err } - return billingCreditsListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsListAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsListCmdGetter defines the interface for getting list command flags. -type billingCreditsListCmdGetter interface { - String(name string) string -} - -func billingCreditsListAction(ctx context.Context, cmd billingCreditsListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsListAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -147,17 +142,12 @@ Examples: if err != nil { return err } - return billingCreditsGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsGetAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsGetCmdGetter defines the interface for getting get command args. -type billingCreditsGetCmdGetter interface { - Args() cli.Args -} - -func billingCreditsGetAction(ctx context.Context, cmd billingCreditsGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("credit ID is required") } @@ -248,17 +238,12 @@ Examples: if err != nil { return err } - return billingCreditsCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsCreateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsCreateCmdGetter defines the interface for getting create command flags. -type billingCreditsCreateCmdGetter interface { - String(name string) string -} - -func billingCreditsCreateAction(ctx context.Context, cmd billingCreditsCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsCreateAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -323,17 +308,12 @@ Examples: if err != nil { return err } - return billingCreditsDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsDeleteAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsDeleteCmdGetter defines the interface for getting delete command args. -type billingCreditsDeleteCmdGetter interface { - Args() cli.Args -} - -func billingCreditsDeleteAction(ctx context.Context, cmd billingCreditsDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("credit ID is required") } @@ -375,17 +355,12 @@ Examples: if err != nil { return err } - return billingCreditsRestoreAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsRestoreAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsRestoreCmdGetter defines the interface for getting restore command args. -type billingCreditsRestoreCmdGetter interface { - Args() cli.Args -} - -func billingCreditsRestoreAction(ctx context.Context, cmd billingCreditsRestoreCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsRestoreAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("credit ID is required") } @@ -432,17 +407,12 @@ Examples: if err != nil { return err } - return billingCreditsPurgeAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsPurgeAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsPurgeCmdGetter defines the interface for getting purge command flags. -type billingCreditsPurgeCmdGetter interface { - String(name string) string -} - -func billingCreditsPurgeAction(ctx context.Context, cmd billingCreditsPurgeCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsPurgeAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -483,17 +453,12 @@ Examples: if err != nil { return err } - return billingCreditsUserBalanceAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsUserBalanceAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsUserBalanceCmdGetter defines the interface for getting user-balance command args. -type billingCreditsUserBalanceCmdGetter interface { - Args() cli.Args -} - -func billingCreditsUserBalanceAction(ctx context.Context, cmd billingCreditsUserBalanceCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsUserBalanceAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("user ID is required") } @@ -548,18 +513,12 @@ Examples: if err != nil { return err } - return billingCreditsUserDeletedCreditsAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingCreditsUserDeletedCreditsAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingCreditsUserDeletedCreditsCmdGetter defines the interface for getting user-deleted-credits command args and flags. -type billingCreditsUserDeletedCreditsCmdGetter interface { - Args() cli.Args - String(name string) string -} - -func billingCreditsUserDeletedCreditsAction(ctx context.Context, cmd billingCreditsUserDeletedCreditsCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingCreditsUserDeletedCreditsAction(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("user ID is required") } diff --git a/pkg/cli/admin_billing_credits_test.go b/pkg/cli/admin_billing_credits_test.go index b1e732c..9147f5f 100644 --- a/pkg/cli/admin_billing_credits_test.go +++ b/pkg/cli/admin_billing_credits_test.go @@ -8,31 +8,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" ) -// Mock command getters for billing credits tests -type mockBillingCreditsListCmd struct { - userID string - direction string - creditType string -} - -func (m *mockBillingCreditsListCmd) String(name string) string { - switch name { - case FlagUserID: - return m.userID - case FlagDirection: - return m.direction - case FlagType: - return m.creditType - } - return "" -} - func unmarshalCreditItemJSON(data string) *admin.CreditItem { var item admin.CreditItem if err := json.Unmarshal([]byte(data), &item); err != nil { @@ -97,7 +77,7 @@ func TestBillingCreditsList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -107,7 +87,7 @@ func TestBillingCreditsList(t *testing.T) { return service } - cmd := &mockBillingCreditsListCmd{} + cmd := newMockCommand() err := billingCreditsListAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -123,15 +103,6 @@ func TestBillingCreditsList(t *testing.T) { } } -// mockBillingCreditsGetCmd implements billingCreditsGetCmdGetter -type mockBillingCreditsGetCmd struct { - args cli.Args -} - -func (m *mockBillingCreditsGetCmd) Args() cli.Args { - return m.args -} - func TestBillingCreditsGet(t *testing.T) { tests := []struct { name string @@ -174,17 +145,16 @@ func TestBillingCreditsGet(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.creditID != "" { - args.args = []string{tt.creditID} + cmd = cmd.withArgs(tt.creditID) } - cmd := &mockBillingCreditsGetCmd{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -204,6 +174,408 @@ func TestBillingCreditsGet(t *testing.T) { } } +func TestBillingCreditsDelete(t *testing.T) { + tests := []struct { + name string + creditID string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful delete", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().DeleteCredit(context.Background(), "123e4567-e89b-12d3-a456-426614174000").Return(nil) + }, + wantErr: false, + }, + { + name: "returns error when credit ID is missing", + creditID: "", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) {}, + wantErr: true, + errContains: "credit ID is required", + }, + { + name: "returns error when not authenticated", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().DeleteCredit(context.Background(), "123e4567-e89b-12d3-a456-426614174000").Return(errors.New("service error")) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + cmd := newMockCommand() + if tt.creditID != "" { + cmd = cmd.withArgs(tt.creditID) + } + + err := billingCreditsDeleteAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBillingCreditsRestore(t *testing.T) { + tests := []struct { + name string + creditID string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful restore", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().RestoreCredit(context.Background(), "123e4567-e89b-12d3-a456-426614174000").Return( + unmarshalCreditJSON(`{"id":"123e4567-e89b-12d3-a456-426614174000","user_id":123,"amount":"100.00","type":"manual","direction":"credit"}`), + nil, + ) + }, + wantErr: false, + }, + { + name: "returns error when credit ID is missing", + creditID: "", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) {}, + wantErr: true, + errContains: "credit ID is required", + }, + { + name: "returns error when not authenticated", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + creditID: "123e4567-e89b-12d3-a456-426614174000", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().RestoreCredit(context.Background(), "123e4567-e89b-12d3-a456-426614174000").Return( + nil, + errors.New("service error"), + ) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + cmd := newMockCommand() + if tt.creditID != "" { + cmd = cmd.withArgs(tt.creditID) + } + + err := billingCreditsRestoreAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBillingCreditsPurge(t *testing.T) { + tests := []struct { + name string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful purge", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().PurgeCredits(context.Background(), &admin.CreditPurgeRequest{ + OlderThan: "30d", + }).Return(5, nil) + }, + wantErr: false, + }, + { + name: "returns error when not authenticated", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().PurgeCredits(context.Background(), &admin.CreditPurgeRequest{ + OlderThan: "30d", + }).Return(0, errors.New("service error")) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + cmd := newMockCommand().withString(FlagOlderThan, "30d") + + err := billingCreditsPurgeAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBillingCreditsUserBalance(t *testing.T) { + tests := []struct { + name string + userID string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful user balance", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetUserBalance(context.Background(), "123").Return( + &admin.UserBalance{}, + nil, + ) + }, + wantErr: false, + }, + { + name: "returns error when user ID is missing", + userID: "", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) {}, + wantErr: true, + errContains: "user ID is required", + }, + { + name: "returns error when not authenticated", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetUserBalance(context.Background(), "123").Return( + nil, + errors.New("service error"), + ) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + cmd := newMockCommand() + if tt.userID != "" { + cmd = cmd.withArgs(tt.userID) + } + + err := billingCreditsUserBalanceAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBillingCreditsUserDeletedCredits(t *testing.T) { + tests := []struct { + name string + userID string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful user deleted credits", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetUserDeletedCredits(context.Background(), "123", &admin.GetApiBillingUsersUserIdDeletedCreditsParams{}).Return( + []*admin.CreditItem{ + unmarshalCreditItemJSON(`{"id":"123e4567-e89b-12d3-a456-426614174000","user_id":123,"amount":"100.00","type":"manual","direction":"credit"}`), + }, + 1, + nil, + ) + }, + wantErr: false, + }, + { + name: "returns error when user ID is missing", + userID: "", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) {}, + wantErr: true, + errContains: "user ID is required", + }, + { + name: "returns error when not authenticated", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + userID: "123", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetUserDeletedCredits(context.Background(), "123", &admin.GetApiBillingUsersUserIdDeletedCreditsParams{}).Return( + nil, + 0, + errors.New("service error"), + ) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + cmd := newMockCommand() + if tt.userID != "" { + cmd = cmd.withArgs(tt.userID) + } + + err := billingCreditsUserDeletedCreditsAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + func TestBillingCreditsCreate(t *testing.T) { tests := []struct { name string @@ -241,7 +613,7 @@ func TestBillingCreditsCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -251,7 +623,11 @@ func TestBillingCreditsCreate(t *testing.T) { return service } - cmd := &mockBillingCreditsCreateCmd{} + cmd := newMockCommand(). + withString(FlagUserID, "123"). + withString(FlagAmount, "100.00"). + withString(FlagType, "manual"). + withString(FlagDirection, "credit") err := billingCreditsCreateAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -266,26 +642,3 @@ func TestBillingCreditsCreate(t *testing.T) { }) } } - -type mockBillingCreditsCreateCmd struct {} - -func (m *mockBillingCreditsCreateCmd) String(name string) string { - switch name { - case FlagUserID: - return "123" - case FlagAmount: - return "100.00" - case FlagType: - return "manual" - case FlagDirection: - return "credit" - } - return "" -} - -// Ensure mockBillingCreditsListCmd implements the interface -var _ billingCreditsListCmdGetter = (*mockBillingCreditsListCmd)(nil) -// Ensure mockBillingCreditsCreateCmd implements the interface -var _ billingCreditsCreateCmdGetter = (*mockBillingCreditsCreateCmd)(nil) -// Ensure mockBillingCreditsGetCmd implements the interface -var _ billingCreditsGetCmdGetter = (*mockBillingCreditsGetCmd)(nil) diff --git a/pkg/cli/admin_billing_overview.go b/pkg/cli/admin_billing_overview.go index 3dbbe91..a6edb3b 100644 --- a/pkg/cli/admin_billing_overview.go +++ b/pkg/cli/admin_billing_overview.go @@ -24,15 +24,12 @@ Examples: if err != nil { return err } - return billingOverviewAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory, defaultQuotaAdminServiceFactory) + return billingOverviewAction(ctx, output, cfgMgr, defaultBillingAdminServiceFactory, defaultQuotaAdminServiceFactory) }, } } -// billingOverviewCmdGetter defines the interface for the overview command. -type billingOverviewCmdGetter interface{} - -func billingOverviewAction(ctx context.Context, cmd billingOverviewCmdGetter, output Output, cfgMgr config.Manager, billingFactory BillingAdminServiceFactory, quotaFactory QuotaAdminServiceFactory) error { +func billingOverviewAction(ctx context.Context, output Output, cfgMgr config.Manager, billingFactory BillingAdminServiceFactory, quotaFactory QuotaAdminServiceFactory) error { billingService := billingFactory(cfgMgr, output) if err := billingService.RequireAuthenticated(); err != nil { return err diff --git a/pkg/cli/admin_billing_overview_test.go b/pkg/cli/admin_billing_overview_test.go new file mode 100644 index 0000000..e49050e --- /dev/null +++ b/pkg/cli/admin_billing_overview_test.go @@ -0,0 +1,227 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/portal-sdk/admin" +) + +func unmarshalQuotaPlanJSON(data string) *admin.QuotaPlan { + var item admin.QuotaPlan + if err := json.Unmarshal([]byte(data), &item); err != nil { + panic(err) + } + return &item +} + +func TestBillingOverview(t *testing.T) { + tests := []struct { + name string + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockBillingAdminService, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "successful overview", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + []*admin.QuotaPlan{ + unmarshalQuotaPlanJSON(`{"id":1,"name":"Free","is_active":true}`), + unmarshalQuotaPlanJSON(`{"id":2,"name":"Pro","is_active":false}`), + }, + 2, + nil, + ) + billingSvc.EXPECT().ListPriceLines(context.Background()).Return( + []*admin.PriceLine{ + unmarshalPriceLineJSON(`{"id":1,"name":"Storage","is_active":true}`), + }, + 1, + nil, + ) + billingSvc.EXPECT().ListPricingPlans(context.Background()).Return( + []*admin.PricingPlanItem{ + unmarshalPricingPlanItemJSON(`{"id":1,"name":"Monthly","is_active":true}`), + }, + 1, + nil, + ) + billingSvc.EXPECT().ListPricingPlanPeriods(context.Background()).Return( + []*admin.PricingPlanPeriod{ + unmarshalPricingPlanPeriodJSON(`{"id":1,"pricing_plan_id":1,"quota_plan_id":1}`), + }, + 1, + nil, + ) + }, + wantErr: false, + }, + { + name: "successful overview json output", + jsonOutput: true, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + []*admin.QuotaPlan{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPriceLines(context.Background()).Return( + []*admin.PriceLine{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPricingPlans(context.Background()).Return( + []*admin.PricingPlanItem{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPricingPlanPeriods(context.Background()).Return( + []*admin.PricingPlanPeriod{}, + 0, + nil, + ) + }, + wantErr: false, + }, + { + name: "returns error when billing not authenticated", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when quota service fails", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + nil, + 0, + errors.New("quota api error"), + ) + }, + wantErr: true, + errContains: "failed to list quota plans", + }, + { + name: "returns error when billing list price lines fails", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + []*admin.QuotaPlan{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPriceLines(context.Background()).Return( + nil, + 0, + errors.New("price lines api error"), + ) + }, + wantErr: true, + errContains: "failed to list price lines", + }, + { + name: "returns error when billing list pricing plans fails", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + []*admin.QuotaPlan{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPriceLines(context.Background()).Return( + []*admin.PriceLine{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPricingPlans(context.Background()).Return( + nil, + 0, + errors.New("pricing plans api error"), + ) + }, + wantErr: true, + errContains: "failed to list pricing plans", + }, + { + name: "returns error when billing list pricing plan periods fails", + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, billingSvc *MockBillingAdminService, quotaSvc *MockQuotaAdminService) { + billingSvc.EXPECT().RequireAuthenticated().Return(nil) + quotaSvc.EXPECT().ListPlans(context.Background()).Return( + []*admin.QuotaPlan{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPriceLines(context.Background()).Return( + []*admin.PriceLine{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPricingPlans(context.Background()).Return( + []*admin.PricingPlanItem{}, + 0, + nil, + ) + billingSvc.EXPECT().ListPricingPlanPeriods(context.Background()).Return( + nil, + 0, + errors.New("periods api error"), + ) + }, + wantErr: true, + errContains: "failed to list pricing plan periods", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + billingSvc := NewMockBillingAdminService(t) + quotaSvc := NewMockQuotaAdminService(t) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, billingSvc, quotaSvc) + } + + billingFactory := func(cm config.Manager, out Output) BillingAdminService { + return billingSvc + } + quotaFactory := func(cm config.Manager, out Output) QuotaAdminService { + return quotaSvc + } + + err := billingOverviewAction(context.Background(), output, cfgMgr, billingFactory, quotaFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/cli/admin_billing_price_lines.go b/pkg/cli/admin_billing_price_lines.go index e4a82f5..15b15f5 100644 --- a/pkg/cli/admin_billing_price_lines.go +++ b/pkg/cli/admin_billing_price_lines.go @@ -52,17 +52,12 @@ Examples: if err != nil { return err } - return billingPriceLinesListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesListAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesListCmdGetter defines the interface for getting list command args and flags. -type billingPriceLinesListCmdGetter interface { - Args() cli.Args -} - -func billingPriceLinesListAction(ctx context.Context, cmd billingPriceLinesListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesListAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -117,17 +112,12 @@ Examples: if err != nil { return err } - return billingPriceLinesGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesGetAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesGetCmdGetter defines the interface for getting get command args and flags. -type billingPriceLinesGetCmdGetter interface { - Args() cli.Args -} - -func billingPriceLinesGetAction(ctx context.Context, cmd billingPriceLinesGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } @@ -205,19 +195,12 @@ Examples: if err != nil { return err } - return billingPriceLinesCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesCreateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesCreateCmdGetter defines the interface for getting create command args and flags. -type billingPriceLinesCreateCmdGetter interface { - Args() cli.Args - String(name string) string - Bool(name string) bool -} - -func billingPriceLinesCreateAction(ctx context.Context, cmd billingPriceLinesCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesCreateAction(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -292,20 +275,15 @@ Examples: if err != nil { return err } - return billingPriceLinesUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesUpdateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesUpdateCmdGetter defines the interface for getting update command args and flags. -type billingPriceLinesUpdateCmdGetter interface { - Args() cli.Args - String(name string) string - Bool(name string) bool - IsSet(name string) bool -} - -func billingPriceLinesUpdateAction(ctx context.Context, cmd billingPriceLinesUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesUpdateAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } @@ -376,17 +354,12 @@ Examples: if err != nil { return err } - return billingPriceLinesDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesDeleteAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesDeleteCmdGetter defines the interface for getting delete command args and flags. -type billingPriceLinesDeleteCmdGetter interface { - Args() cli.Args -} - -func billingPriceLinesDeleteAction(ctx context.Context, cmd billingPriceLinesDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } @@ -439,20 +412,15 @@ Examples: if err != nil { return err } - return billingPriceLinesAddPlanAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesAddPlanAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesAddPlanCmdGetter defines the interface for getting add-plan command args and flags. -type billingPriceLinesAddPlanCmdGetter interface { - Args() cli.Args - String(name string) string - Int(name string) int - IsSet(name string) bool -} - -func billingPriceLinesAddPlanAction(ctx context.Context, cmd billingPriceLinesAddPlanCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesAddPlanAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } @@ -536,18 +504,12 @@ Examples: if err != nil { return err } - return billingPriceLinesDeletePlanAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesDeletePlanAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesDeletePlanCmdGetter defines the interface for getting delete-plan command args and flags. -type billingPriceLinesDeletePlanCmdGetter interface { - Args() cli.Args - String(name string) string -} - -func billingPriceLinesDeletePlanAction(ctx context.Context, cmd billingPriceLinesDeletePlanCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesDeletePlanAction(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } @@ -599,19 +561,12 @@ Examples: if err != nil { return err } - return billingPriceLinesUpdatePlanPositionAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPriceLinesUpdatePlanPositionAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPriceLinesUpdatePlanPositionCmdGetter defines the interface for getting update-plan-position command args and flags. -type billingPriceLinesUpdatePlanPositionCmdGetter interface { - Args() cli.Args - String(name string) string - Int(name string) int -} - -func billingPriceLinesUpdatePlanPositionAction(ctx context.Context, cmd billingPriceLinesUpdatePlanPositionCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPriceLinesUpdatePlanPositionAction(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("price line ID is required") } diff --git a/pkg/cli/admin_billing_price_lines_test.go b/pkg/cli/admin_billing_price_lines_test.go index fa0320d..493f4d7 100644 --- a/pkg/cli/admin_billing_price_lines_test.go +++ b/pkg/cli/admin_billing_price_lines_test.go @@ -5,13 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" @@ -73,7 +71,7 @@ func TestBillingPriceLinesList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -83,7 +81,7 @@ func TestBillingPriceLinesList(t *testing.T) { return service } - var cmd billingPriceLinesListCmdGetter + cmd := newMockCommand() err := billingPriceLinesListAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -165,21 +163,20 @@ func TestAddPlan_AutoPosition(t *testing.T) { tt.setupMocks(cfgMgr, service) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesAddPlanArgs{ - args: args, - planID: tt.planID, - position: tt.position, - isSetPosition: tt.isSetPosition, + cmd = cmd.withString("plan-id", tt.planID) + if tt.isSetPosition { + cmd = cmd.withInt("position", tt.position) + cmd = cmd.withIsSet("position", true) } err := billingPriceLinesAddPlanAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -196,22 +193,13 @@ func TestAddPlan_AutoPosition(t *testing.T) { } } -// billingPriceLinesGetArgs implements billingPriceLinesGetCmdGetter -type billingPriceLinesGetArgs struct { - args cli.Args -} - -func (m *billingPriceLinesGetArgs) Args() cli.Args { - return m.args -} - func TestBillingPriceLinesGet(t *testing.T) { tests := []struct { - name string - priceLineID string - setupMocks func(*configmocks.MockManager, *MockBillingAdminService) - wantErr bool - errContains string + name string + priceLineID string + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string }{ { name: "successful get", @@ -238,17 +226,16 @@ func TestBillingPriceLinesGet(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesGetArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -268,39 +255,6 @@ func TestBillingPriceLinesGet(t *testing.T) { } } -// billingPriceLinesCreateCmd implements billingPriceLinesCreateCmdGetter -type billingPriceLinesCreateCmd struct { - args cli.Args - name string - description string - isActive bool - isDefault bool -} - -func (m *billingPriceLinesCreateCmd) Args() cli.Args { - return m.args -} - -func (m *billingPriceLinesCreateCmd) String(name string) string { - switch name { - case "name": - return m.name - case "description": - return m.description - } - return "" -} - -func (m *billingPriceLinesCreateCmd) Bool(name string) bool { - switch name { - case "is-active": - return m.isActive - case "is-default": - return m.isDefault - } - return false -} - func TestBillingPriceLinesCreate(t *testing.T) { tests := []struct { name string @@ -338,7 +292,7 @@ func TestBillingPriceLinesCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -348,12 +302,11 @@ func TestBillingPriceLinesCreate(t *testing.T) { return service } - cmd := &billingPriceLinesCreateCmd{ - name: "Storage", - description: "Storage pricing", - isActive: true, - isDefault: false, - } + cmd := newMockCommand(). + withString("name", "Storage"). + withString("description", "Storage pricing"). + withBool("is-active", true). + withBool("is-default", false) err := billingPriceLinesCreateAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -369,47 +322,6 @@ func TestBillingPriceLinesCreate(t *testing.T) { } } -// billingPriceLinesUpdateCmd implements billingPriceLinesUpdateCmdGetter -type billingPriceLinesUpdateCmd struct { - args cli.Args - name string - description string - isActive bool - isDefault bool - isSet map[string]bool -} - -func (m *billingPriceLinesUpdateCmd) Args() cli.Args { - return m.args -} - -func (m *billingPriceLinesUpdateCmd) String(name string) string { - switch name { - case "name": - return m.name - case "description": - return m.description - } - return "" -} - -func (m *billingPriceLinesUpdateCmd) Bool(name string) bool { - switch name { - case "is-active": - return m.isActive - case "is-default": - return m.isDefault - } - return false -} - -func (m *billingPriceLinesUpdateCmd) IsSet(name string) bool { - if m.isSet == nil { - return false - } - return m.isSet[name] -} - func TestBillingPriceLinesUpdate(t *testing.T) { tests := []struct { name string @@ -446,25 +358,21 @@ func TestBillingPriceLinesUpdate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} - } - cmd := &billingPriceLinesUpdateCmd{ - args: args, - name: "Updated Storage", - description: "Updated description", - isSet: map[string]bool{ - "name": true, - "description": true, - }, + cmd = cmd.withArgs(tt.priceLineID) } + cmd = cmd. + withString("name", "Updated Storage"). + withString("description", "Updated description"). + withIsSet("name", true). + withIsSet("description", true) serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -484,15 +392,6 @@ func TestBillingPriceLinesUpdate(t *testing.T) { } } -// billingPriceLinesDeleteArgs implements billingPriceLinesDeleteCmdGetter -type billingPriceLinesDeleteArgs struct { - args cli.Args -} - -func (m *billingPriceLinesDeleteArgs) Args() cli.Args { - return m.args -} - func TestBillingPriceLinesDelete(t *testing.T) { tests := []struct { name string @@ -523,17 +422,16 @@ func TestBillingPriceLinesDelete(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesDeleteArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -553,48 +451,6 @@ func TestBillingPriceLinesDelete(t *testing.T) { } } -// billingPriceLinesAddPlanArgs implements billingPriceLinesAddPlanCmdGetter -type billingPriceLinesAddPlanArgs struct { - args cli.Args - planID string - position int - isSetPosition bool -} - -func (m *billingPriceLinesAddPlanArgs) Args() cli.Args { - return m.args -} - -func (m *billingPriceLinesAddPlanArgs) String(name string) string { - if name == "plan-id" { - return m.planID - } - return "" -} - -func (m *billingPriceLinesAddPlanArgs) Int(name string) int { - switch name { - case "plan-id": - if m.planID != "" { - v, _ := strconv.Atoi(m.planID) - return v - } - return 0 - case "position": - return m.position - } - return 0 -} - -func (m *billingPriceLinesAddPlanArgs) IsSet(name string) bool { - switch name { - case "position": - return m.isSetPosition - default: - return false - } -} - func TestBillingPriceLinesAddPlan(t *testing.T) { tests := []struct { name string @@ -631,17 +487,20 @@ func TestBillingPriceLinesAddPlan(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesAddPlanArgs{args: args, planID: "1", position: 1, isSetPosition: true} + cmd = cmd. + withString("plan-id", "1"). + withInt("position", 1). + withIsSet("position", true) serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -661,22 +520,6 @@ func TestBillingPriceLinesAddPlan(t *testing.T) { } } -// billingPriceLinesDeletePlanArgs implements billingPriceLinesDeletePlanCmdGetter -type billingPriceLinesDeletePlanArgs struct { - args cli.Args -} - -func (m *billingPriceLinesDeletePlanArgs) Args() cli.Args { - return m.args -} - -func (m *billingPriceLinesDeletePlanArgs) String(name string) string { - if name == "plan-id" { - return "1" - } - return "" -} - func TestBillingPriceLinesDeletePlan(t *testing.T) { tests := []struct { name string @@ -707,17 +550,17 @@ func TestBillingPriceLinesDeletePlan(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesDeletePlanArgs{args: args} + cmd = cmd.withString("plan-id", "1") serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -737,32 +580,6 @@ func TestBillingPriceLinesDeletePlan(t *testing.T) { } } -// billingPriceLinesUpdatePlanPositionArgs implements billingPriceLinesUpdatePlanPositionCmdGetter -type billingPriceLinesUpdatePlanPositionArgs struct { - args cli.Args -} - -func (m *billingPriceLinesUpdatePlanPositionArgs) Args() cli.Args { - return m.args -} - -func (m *billingPriceLinesUpdatePlanPositionArgs) String(name string) string { - if name == "plan-id" { - return "1" - } - return "" -} - -func (m *billingPriceLinesUpdatePlanPositionArgs) Int(name string) int { - switch name { - case "plan-id": - return 1 - case "position": - return 2 - } - return 0 -} - func TestBillingPriceLinesUpdatePlanPosition(t *testing.T) { tests := []struct { name string @@ -795,17 +612,19 @@ func TestBillingPriceLinesUpdatePlanPosition(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.priceLineID != "" { - args.args = []string{tt.priceLineID} + cmd = cmd.withArgs(tt.priceLineID) } - cmd := &billingPriceLinesUpdatePlanPositionArgs{args: args} + cmd = cmd. + withString("plan-id", "1"). + withInt("position", 2) serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service diff --git a/pkg/cli/admin_billing_pricing_plans.go b/pkg/cli/admin_billing_pricing_plans.go index a0b1b99..5e26a01 100644 --- a/pkg/cli/admin_billing_pricing_plans.go +++ b/pkg/cli/admin_billing_pricing_plans.go @@ -24,15 +24,12 @@ Examples: if err != nil { return err } - return billingPricingPlansListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlansListAction(ctx, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlansListCmdGetter is an empty interface for list command (no args/flags needed) -type billingPricingPlansListCmdGetter interface{} - -func billingPricingPlansListAction(ctx context.Context, cmd billingPricingPlansListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlansListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -91,17 +88,12 @@ Examples: if err != nil { return err } - return billingPricingPlansGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlansGetAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlansGetCmdGetter defines the interface for getting get command args. -type billingPricingPlansGetCmdGetter interface { - Args() cli.Args -} - -func billingPricingPlansGetAction(ctx context.Context, cmd billingPricingPlansGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlansGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("plan ID is required") } @@ -238,21 +230,15 @@ Examples: if err != nil { return err } - return billingPricingPlansCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlansCreateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlansCreateCmdGetter defines the interface for getting create command flags. -type billingPricingPlansCreateCmdGetter interface { - String(name string) string - Bool(name string) bool - Int(name string) int +func billingPricingPlansCreateAction(ctx context.Context, cmd interface { + flagGetterWithIsSet Float(name string) float64 - IsSet(name string) bool -} - -func billingPricingPlansCreateAction(ctx context.Context, cmd billingPricingPlansCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -382,20 +368,15 @@ Examples: if err != nil { return err } - return billingPricingPlansUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlansUpdateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlansUpdateCmdGetter defines the interface for getting update command args and flags. -type billingPricingPlansUpdateCmdGetter interface { - Args() cli.Args - String(name string) string - Bool(name string) bool - IsSet(name string) bool -} - -func billingPricingPlansUpdateAction(ctx context.Context, cmd billingPricingPlansUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlansUpdateAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("pricing plan ID is required") } @@ -466,17 +447,12 @@ Examples: if err != nil { return err } - return billingPricingPlansDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlansDeleteAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlansDeleteCmdGetter defines the interface for getting delete command args. -type billingPricingPlansDeleteCmdGetter interface { - Args() cli.Args -} - -func billingPricingPlansDeleteAction(ctx context.Context, cmd billingPricingPlansDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlansDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("pricing plan ID is required") } @@ -517,15 +493,12 @@ Examples: if err != nil { return err } - return billingPricingPlanPeriodsListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlanPeriodsListAction(ctx, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlanPeriodsListCmdGetter is an empty interface for list command (no args/flags needed) -type billingPricingPlanPeriodsListCmdGetter interface{} - -func billingPricingPlanPeriodsListAction(ctx context.Context, cmd billingPricingPlanPeriodsListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlanPeriodsListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -577,17 +550,12 @@ Examples: if err != nil { return err } - return billingPricingPlanPeriodsGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlanPeriodsGetAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlanPeriodsGetCmdGetter defines the interface for getting get command args. -type billingPricingPlanPeriodsGetCmdGetter interface { - Args() cli.Args -} - -func billingPricingPlanPeriodsGetAction(ctx context.Context, cmd billingPricingPlanPeriodsGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlanPeriodsGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("period ID is required") } @@ -668,21 +636,15 @@ Examples: if err != nil { return err } - return billingPricingPlanPeriodsCreateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlanPeriodsCreateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlanPeriodsCreateCmdGetter defines the interface for getting create command flags. -type billingPricingPlanPeriodsCreateCmdGetter interface { - Int(name string) int +func billingPricingPlanPeriodsCreateAction(ctx context.Context, cmd interface { + flagGetterWithIsSet Float(name string) float64 - String(name string) string - Bool(name string) bool - IsSet(name string) bool -} - -func billingPricingPlanPeriodsCreateAction(ctx context.Context, cmd billingPricingPlanPeriodsCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -769,22 +731,16 @@ Examples: if err != nil { return err } - return billingPricingPlanPeriodsUpdateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlanPeriodsUpdateAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlanPeriodsUpdateCmdGetter defines the interface for getting update command args and flags. -type billingPricingPlanPeriodsUpdateCmdGetter interface { - Args() cli.Args +func billingPricingPlanPeriodsUpdateAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet Float(name string) float64 - String(name string) string - Int(name string) int - Bool(name string) bool - IsSet(name string) bool -} - -func billingPricingPlanPeriodsUpdateAction(ctx context.Context, cmd billingPricingPlanPeriodsUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +}, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("period ID is required") } @@ -856,17 +812,12 @@ Examples: if err != nil { return err } - return billingPricingPlanPeriodsDeleteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingPricingPlanPeriodsDeleteAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingPricingPlanPeriodsDeleteCmdGetter defines the interface for getting delete command args. -type billingPricingPlanPeriodsDeleteCmdGetter interface { - Args() cli.Args -} - -func billingPricingPlanPeriodsDeleteAction(ctx context.Context, cmd billingPricingPlanPeriodsDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingPricingPlanPeriodsDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("period ID is required") } @@ -913,17 +864,12 @@ Examples: if err != nil { return err } - return billingSyncPricingPlanAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSyncPricingPlanAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSyncPricingPlanCmdGetter defines the interface for getting sync command args. -type billingSyncPricingPlanCmdGetter interface { - Args() cli.Args -} - -func billingSyncPricingPlanAction(ctx context.Context, cmd billingSyncPricingPlanCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSyncPricingPlanAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("plan ID is required") } @@ -967,15 +913,12 @@ Examples: if err != nil { return err } - return billingSyncAllPricingPlansAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSyncAllPricingPlansAction(ctx, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSyncAllPricingPlansCmdGetter is an empty interface for sync-all command (no args needed). -type billingSyncAllPricingPlansCmdGetter interface{} - -func billingSyncAllPricingPlansAction(ctx context.Context, cmd billingSyncAllPricingPlansCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSyncAllPricingPlansAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err diff --git a/pkg/cli/admin_billing_pricing_plans_test.go b/pkg/cli/admin_billing_pricing_plans_test.go index e48b790..03e65d3 100644 --- a/pkg/cli/admin_billing_pricing_plans_test.go +++ b/pkg/cli/admin_billing_pricing_plans_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" @@ -39,6 +38,123 @@ func unmarshalPricingPlanPeriodJSON(data string) *admin.PricingPlanPeriod { return &item } +func TestBillingPricingPlansGet(t *testing.T) { + tests := []struct { + name string + planID string + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockBillingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful get", + planID: "1", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetPricingPlan(context.Background(), "1").Return( + unmarshalPricingPlanJSON(`{"id":1,"name":"Pro Plan","currency":"USD","is_active":true,"is_public":true,"description":"A test plan"}`), + nil, + ) + }, + wantErr: false, + }, + { + name: "successful get with json output", + planID: "1", + jsonOutput: true, + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetPricingPlan(context.Background(), "1").Return( + unmarshalPricingPlanJSON(`{"id":1,"name":"Pro Plan","currency":"USD","is_active":true,"is_public":true}`), + nil, + ) + }, + wantErr: false, + }, + { + name: "successful get with pricing periods", + planID: "1", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetPricingPlan(context.Background(), "1").Return( + unmarshalPricingPlanJSON(`{"id":1,"name":"Pro Plan","currency":"USD","is_active":true,"is_public":true,"pricing_periods":[{"id":1,"pricing_plan_id":1,"price_usd":9.99,"cadence":"monthly","quota_plan_id":1,"is_active":true}]}`), + nil, + ) + }, + wantErr: false, + }, + { + name: "returns error when plan ID is missing", + planID: "", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + }, + wantErr: true, + errContains: "plan ID is required", + }, + { + name: "returns error when not authenticated", + planID: "1", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + planID: "1", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetPricingPlan(context.Background(), "1").Return( + nil, + errors.New("service error"), + ) + }, + wantErr: true, + errContains: "service error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockBillingAdminService(t) + + var output Output + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } else { + output = newTestOutput() + } + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + cmd := newMockCommand() + if tt.planID != "" { + cmd = cmd.withArgs(tt.planID) + } + + serviceFactory := func(cm config.Manager, out Output) BillingAdminService { + return service + } + + err := billingPricingPlansGetAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + func TestBillingPricingPlansList(t *testing.T) { tests := []struct { name string @@ -87,7 +203,7 @@ func TestBillingPricingPlansList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -97,9 +213,7 @@ func TestBillingPricingPlansList(t *testing.T) { return service } - var cmd billingPricingPlansListCmdGetter - - err := billingPricingPlansListAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + err := billingPricingPlansListAction(context.Background(), output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) @@ -116,20 +230,19 @@ func TestBillingPricingPlansList(t *testing.T) { func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { tests := []struct { name string - cmd *billingPricingPlanPeriodsCreateCmd + cmd *mockCommand setupMocks func(*configmocks.MockManager, *MockBillingAdminService) wantErr bool errContains string }{ { name: "zero price without allow-free rejected", - cmd: &billingPricingPlanPeriodsCreateCmd{ - planID: 1, - price: 0, - cadence: "monthly", - quotaPlanID: 1, - allowFree: false, - }, + cmd: newMockCommand(). + withInt(FlagPlanID, 1). + withFloat(FlagPrice, 0). + withString(FlagCadence, "monthly"). + withInt(FlagQuotaPlanID, 1). + withBool(FlagAllowFree, false), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockBillingAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) }, @@ -138,13 +251,12 @@ func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { }, { name: "zero price with allow-free accepted", - cmd: &billingPricingPlanPeriodsCreateCmd{ - planID: 1, - price: 0, - cadence: "monthly", - quotaPlanID: 1, - allowFree: true, - }, + cmd: newMockCommand(). + withInt(FlagPlanID, 1). + withFloat(FlagPrice, 0). + withString(FlagCadence, "monthly"). + withInt(FlagQuotaPlanID, 1). + withBool(FlagAllowFree, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockBillingAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePricingPlanPeriod(mock.Anything, mock.AnythingOfType("*admin.PricingPlanPeriodCreateRequest")).Return(&admin.PricingPlanPeriod{}, nil) @@ -153,13 +265,12 @@ func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { }, { name: "positive price works without allow-free", - cmd: &billingPricingPlanPeriodsCreateCmd{ - planID: 1, - price: 9.99, - cadence: "monthly", - quotaPlanID: 1, - allowFree: false, - }, + cmd: newMockCommand(). + withInt(FlagPlanID, 1). + withFloat(FlagPrice, 9.99). + withString(FlagCadence, "monthly"). + withInt(FlagQuotaPlanID, 1). + withBool(FlagAllowFree, false), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockBillingAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePricingPlanPeriod(mock.Anything, mock.AnythingOfType("*admin.PricingPlanPeriodCreateRequest")).Return(&admin.PricingPlanPeriod{}, nil) @@ -168,13 +279,12 @@ func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { }, { name: "negative price rejected", - cmd: &billingPricingPlanPeriodsCreateCmd{ - planID: 1, - price: -5.0, - cadence: "monthly", - quotaPlanID: 1, - allowFree: false, - }, + cmd: newMockCommand(). + withInt(FlagPlanID, 1). + withFloat(FlagPrice, -5.0). + withString(FlagCadence, "monthly"). + withInt(FlagQuotaPlanID, 1). + withBool(FlagAllowFree, false), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockBillingAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) }, @@ -190,7 +300,7 @@ func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { tt.setupMocks(cfgMgr, service) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -210,119 +320,20 @@ func TestPricingPlanPeriodsCreate_PriceValidation(t *testing.T) { } } -// billingPricingPlansCreateCmd implements billingPricingPlansCreateCmdGetter -type billingPricingPlansCreateCmd struct { - name string - currency string - description string - isActive bool - isPublic bool - pricelineID int - // Period creation fields - price float64 - cadence string - quotaPlanID int - rollingDays int - allowFree bool - isSet map[string]bool -} - -func (m *billingPricingPlansCreateCmd) String(name string) string { - switch name { - case FlagName: - return m.name - case FlagCurrency: - return m.currency - case FlagDescription: - return m.description - case FlagCadence: - return m.cadence - } - return "" -} - -func (m *billingPricingPlansCreateCmd) Bool(name string) bool { - switch name { - case FlagIsActive: - return m.isActive - case FlagIsPublic: - return m.isPublic - case FlagAllowFree: - return m.allowFree - } - return false -} - -func (m *billingPricingPlansCreateCmd) Int(name string) int { - switch name { - case FlagPricelineID: - return m.pricelineID - case FlagQuotaPlanID: - return m.quotaPlanID - case FlagRollingDays: - return m.rollingDays - } - return 0 -} - -func (m *billingPricingPlansCreateCmd) Float(name string) float64 { - if name == FlagPrice { - return m.price - } - return 0 -} - -func (m *billingPricingPlansCreateCmd) IsSet(name string) bool { - if m.isSet == nil { - return false - } - return m.isSet[name] -} - -func (m *billingPricingPlansUpdateCmd) String(name string) string { - switch name { - case FlagName: - return m.name - case FlagCurrency: - return m.currency - case FlagDescription: - return m.description - } - return "" -} - -func (m *billingPricingPlansUpdateCmd) Bool(name string) bool { - switch name { - case FlagIsActive: - return m.isActive - case FlagIsPublic: - return m.isPublic - } - return false -} - -func (m *billingPricingPlansUpdateCmd) IsSet(name string) bool { - if m.isSet == nil { - return false - } - return m.isSet[name] -} - func TestBillingPricingPlansCreate(t *testing.T) { tests := []struct { name string - cmd *billingPricingPlansCreateCmd + cmd *mockCommand setupMocks func(*configmocks.MockManager, *MockBillingAdminService) wantErr bool errContains string }{ { name: "successful create", - cmd: &billingPricingPlansCreateCmd{ - name: "Pro Plan", - currency: "USD", - isActive: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Pro Plan"). + withString(FlagCurrency, "USD"). + withBool(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().CreatePricingPlan(context.Background(), mock.MatchedBy(func(req *admin.PricingPlanCreateRequest) bool { @@ -336,11 +347,10 @@ func TestBillingPricingPlansCreate(t *testing.T) { }, { name: "returns error when not authenticated", - cmd: &billingPricingPlansCreateCmd{ - name: "Pro Plan", - currency: "USD", - isActive: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Pro Plan"). + withString(FlagCurrency, "USD"). + withBool(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) }, @@ -349,19 +359,16 @@ func TestBillingPricingPlansCreate(t *testing.T) { }, { name: "successful create with period", - cmd: &billingPricingPlansCreateCmd{ - name: "Starter Plan", - currency: "USD", - isActive: true, - quotaPlanID: 1, - price: 9.99, - cadence: "monthly", - isSet: map[string]bool{ - FlagQuotaPlanID: true, - FlagPrice: true, - FlagCadence: true, - }, - }, + cmd: newMockCommand(). + withString(FlagName, "Starter Plan"). + withString(FlagCurrency, "USD"). + withBool(FlagIsActive, true). + withInt(FlagQuotaPlanID, 1). + withFloat(FlagPrice, 9.99). + withString(FlagCadence, "monthly"). + withIsSet(FlagQuotaPlanID, true). + withIsSet(FlagPrice, true). + withIsSet(FlagCadence, true), setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().CreatePricingPlan(context.Background(), mock.MatchedBy(func(req *admin.PricingPlanCreateRequest) bool { @@ -381,15 +388,12 @@ func TestBillingPricingPlansCreate(t *testing.T) { }, { name: "create plan only when quota-plan-id set but price missing", - cmd: &billingPricingPlansCreateCmd{ - name: "Partial Plan", - currency: "USD", - isActive: true, - quotaPlanID: 1, - isSet: map[string]bool{ - FlagQuotaPlanID: true, - }, - }, + cmd: newMockCommand(). + withString(FlagName, "Partial Plan"). + withString(FlagCurrency, "USD"). + withBool(FlagIsActive, true). + withInt(FlagQuotaPlanID, 1). + withIsSet(FlagQuotaPlanID, true), setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().CreatePricingPlan(context.Background(), mock.MatchedBy(func(req *admin.PricingPlanCreateRequest) bool { @@ -398,7 +402,6 @@ func TestBillingPricingPlansCreate(t *testing.T) { unmarshalPricingPlanJSON(`{"id":1,"name":"Partial Plan","currency":"USD","is_active":true}`), nil, ) - // Period creation skipped because price/cadence not set }, wantErr: false, }, @@ -408,7 +411,7 @@ func TestBillingPricingPlansCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -432,21 +435,6 @@ func TestBillingPricingPlansCreate(t *testing.T) { } } -// billingPricingPlansUpdateCmd implements billingPricingPlansUpdateCmdGetter -type billingPricingPlansUpdateCmd struct { - args cli.Args - name string - currency string - description string - isActive bool - isPublic bool - isSet map[string]bool -} - -func (m *billingPricingPlansUpdateCmd) Args() cli.Args { - return m.args -} - func TestBillingPricingPlansUpdate(t *testing.T) { tests := []struct { name string @@ -482,23 +470,19 @@ func TestBillingPricingPlansUpdate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.planID != "" { - args.args = []string{tt.planID} - } - cmd := &billingPricingPlansUpdateCmd{ - args: args, - name: "Updated Plan", - isSet: map[string]bool{ - "name": true, - }, + cmd = cmd.withArgs(tt.planID) } + cmd = cmd. + withString(FlagName, "Updated Plan"). + withIsSet(FlagName, true) serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -518,15 +502,6 @@ func TestBillingPricingPlansUpdate(t *testing.T) { } } -// billingPricingPlansDeleteArgs implements billingPricingPlansDeleteCmdGetter -type billingPricingPlansDeleteArgs struct { - args cli.Args -} - -func (m *billingPricingPlansDeleteArgs) Args() cli.Args { - return m.args -} - func TestBillingPricingPlansDelete(t *testing.T) { tests := []struct { name string @@ -557,17 +532,16 @@ func TestBillingPricingPlansDelete(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.planID != "" { - args.args = []string{tt.planID} + cmd = cmd.withArgs(tt.planID) } - cmd := &billingPricingPlansDeleteArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -635,7 +609,7 @@ func TestBillingPricingPlanPeriodsList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -645,9 +619,7 @@ func TestBillingPricingPlanPeriodsList(t *testing.T) { return service } - var cmd billingPricingPlanPeriodsListCmdGetter - - err := billingPricingPlanPeriodsListAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + err := billingPricingPlanPeriodsListAction(context.Background(), output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) @@ -661,15 +633,6 @@ func TestBillingPricingPlanPeriodsList(t *testing.T) { } } -// billingPricingPlanPeriodsGetArgs implements billingPricingPlanPeriodsGetCmdGetter -type billingPricingPlanPeriodsGetArgs struct { - args cli.Args -} - -func (m *billingPricingPlanPeriodsGetArgs) Args() cli.Args { - return m.args -} - func TestBillingPricingPlanPeriodsGet(t *testing.T) { tests := []struct { name string @@ -703,17 +666,16 @@ func TestBillingPricingPlanPeriodsGet(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.periodID != "" { - args.args = []string{tt.periodID} + cmd = cmd.withArgs(tt.periodID) } - cmd := &billingPricingPlanPeriodsGetArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -733,57 +695,6 @@ func TestBillingPricingPlanPeriodsGet(t *testing.T) { } } -// billingPricingPlanPeriodsCreateCmd implements billingPricingPlanPeriodsCreateCmdGetter -type billingPricingPlanPeriodsCreateCmd struct { - planID int - price float64 - cadence string - quotaPlanID int - rollingDays int - allowFree bool - isSet map[string]bool -} - -func (m *billingPricingPlanPeriodsCreateCmd) Int(name string) int { - switch name { - case FlagPlanID: - return m.planID - case FlagQuotaPlanID: - return m.quotaPlanID - case FlagRollingDays: - return m.rollingDays - } - return 0 -} - -func (m *billingPricingPlanPeriodsCreateCmd) Float(name string) float64 { - if name == FlagPrice { - return m.price - } - return 0 -} - -func (m *billingPricingPlanPeriodsCreateCmd) String(name string) string { - if name == FlagCadence { - return m.cadence - } - return "" -} - -func (m *billingPricingPlanPeriodsCreateCmd) Bool(name string) bool { - if name == FlagAllowFree { - return m.allowFree - } - return false -} - -func (m *billingPricingPlanPeriodsCreateCmd) IsSet(name string) bool { - if m.isSet == nil { - return false - } - return m.isSet[name] -} - func TestBillingPricingPlanPeriodsCreate(t *testing.T) { tests := []struct { name string @@ -821,7 +732,7 @@ func TestBillingPricingPlanPeriodsCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -831,12 +742,11 @@ func TestBillingPricingPlanPeriodsCreate(t *testing.T) { return service } - cmd := &billingPricingPlanPeriodsCreateCmd{ - planID: 1, - price: 9.99, - cadence: "monthly", - quotaPlanID: 1, - } + cmd := newMockCommand(). + withInt(FlagPlanID, 1). + withFloat(FlagPrice, 9.99). + withString(FlagCadence, "monthly"). + withInt(FlagQuotaPlanID, 1) err := billingPricingPlanPeriodsCreateAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -852,59 +762,6 @@ func TestBillingPricingPlanPeriodsCreate(t *testing.T) { } } -// billingPricingPlanPeriodsUpdateCmd implements billingPricingPlanPeriodsUpdateCmdGetter -type billingPricingPlanPeriodsUpdateCmd struct { - args cli.Args - price float64 - cadence string - quotaPlanID int - rollingDays int - allowFree bool - isSet map[string]bool -} - -func (m *billingPricingPlanPeriodsUpdateCmd) Args() cli.Args { - return m.args -} - -func (m *billingPricingPlanPeriodsUpdateCmd) Float(name string) float64 { - if name == FlagPrice { - return m.price - } - return 0 -} - -func (m *billingPricingPlanPeriodsUpdateCmd) String(name string) string { - if name == FlagCadence { - return m.cadence - } - return "" -} - -func (m *billingPricingPlanPeriodsUpdateCmd) Int(name string) int { - switch name { - case FlagQuotaPlanID: - return m.quotaPlanID - case FlagRollingDays: - return m.rollingDays - } - return 0 -} - -func (m *billingPricingPlanPeriodsUpdateCmd) Bool(name string) bool { - if name == FlagAllowFree { - return m.allowFree - } - return false -} - -func (m *billingPricingPlanPeriodsUpdateCmd) IsSet(name string) bool { - if m.isSet == nil { - return false - } - return m.isSet[name] -} - func TestBillingPricingPlanPeriodsUpdate(t *testing.T) { tests := []struct { name string @@ -933,11 +790,11 @@ func TestBillingPricingPlanPeriodsUpdate(t *testing.T) { wantErr: false, }, { - name: "allow-free sets AllowFree on request", - periodID: "1", - price: 0, + name: "allow-free sets AllowFree on request", + periodID: "1", + price: 0, allowFree: true, - isSet: map[string]bool{"allow-free": true}, + isSet: map[string]bool{"allow-free": true}, setupMocks: func(cfgMgr *configmocks.MockManager, service *MockBillingAdminService) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePricingPlanPeriod(context.Background(), "1", mock.MatchedBy(func(req *admin.PricingPlanPeriodUpdateRequest) bool { @@ -962,21 +819,21 @@ func TestBillingPricingPlanPeriodsUpdate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.periodID != "" { - args.args = []string{tt.periodID} + cmd = cmd.withArgs(tt.periodID) } - cmd := &billingPricingPlanPeriodsUpdateCmd{ - args: args, - price: tt.price, - allowFree: tt.allowFree, - isSet: tt.isSet, + cmd = cmd. + withFloat(FlagPrice, tt.price). + withBool(FlagAllowFree, tt.allowFree) + for k, v := range tt.isSet { + cmd = cmd.withIsSet(k, v) } serviceFactory := func(cm config.Manager, out Output) BillingAdminService { @@ -997,15 +854,6 @@ func TestBillingPricingPlanPeriodsUpdate(t *testing.T) { } } -// billingPricingPlanPeriodsDeleteArgs implements billingPricingPlanPeriodsDeleteCmdGetter -type billingPricingPlanPeriodsDeleteArgs struct { - args cli.Args -} - -func (m *billingPricingPlanPeriodsDeleteArgs) Args() cli.Args { - return m.args -} - func TestBillingPricingPlanPeriodsDelete(t *testing.T) { tests := []struct { name string @@ -1036,17 +884,16 @@ func TestBillingPricingPlanPeriodsDelete(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.periodID != "" { - args.args = []string{tt.periodID} + cmd = cmd.withArgs(tt.periodID) } - cmd := &billingPricingPlanPeriodsDeleteArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -1066,18 +913,6 @@ func TestBillingPricingPlanPeriodsDelete(t *testing.T) { } } -// Mock command getters for sync commands -type mockBillingSyncCmd struct { - args cli.Args -} - -func (m *mockBillingSyncCmd) Args() cli.Args { - return m.args -} - -// mockBillingSyncAllCmd is an empty struct for sync-all command (no args needed) -type mockBillingSyncAllCmd struct{} - func TestBillingSyncPricingPlan(t *testing.T) { tests := []struct { name string @@ -1128,7 +963,7 @@ func TestBillingSyncPricingPlan(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -1138,11 +973,10 @@ func TestBillingSyncPricingPlan(t *testing.T) { return service } - args := &mockArgs{} + cmd := newMockCommand() if tt.planID != "" { - args.args = []string{tt.planID} + cmd = cmd.withArgs(tt.planID) } - cmd := &mockBillingSyncCmd{args: args} err := billingSyncPricingPlanAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -1196,7 +1030,7 @@ func TestBillingSyncAllPricingPlans(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -1206,9 +1040,7 @@ func TestBillingSyncAllPricingPlans(t *testing.T) { return service } - cmd := &mockBillingSyncAllCmd{} - - err := billingSyncAllPricingPlansAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + err := billingSyncAllPricingPlansAction(context.Background(), output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) diff --git a/pkg/cli/admin_billing_subscribers.go b/pkg/cli/admin_billing_subscribers.go index 29468ba..7e7d1b3 100644 --- a/pkg/cli/admin_billing_subscribers.go +++ b/pkg/cli/admin_billing_subscribers.go @@ -24,15 +24,12 @@ Examples: if err != nil { return err } - return billingSubscribersListAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersListAction(ctx, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersListCmdGetter is an empty interface for list command (no args/flags needed) -type billingSubscribersListCmdGetter interface{} - -func billingSubscribersListAction(ctx context.Context, cmd billingSubscribersListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -90,17 +87,12 @@ Examples: if err != nil { return err } - return billingSubscribersGetAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersGetAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersGetCmdGetter defines the interface for getting get command args. -type billingSubscribersGetCmdGetter interface { - Args() cli.Args -} - -func billingSubscribersGetAction(ctx context.Context, cmd billingSubscribersGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("subscriber ID is required") } @@ -175,17 +167,12 @@ Examples: if err != nil { return err } - return billingSubscribersListGatewayAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersListGatewayAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersListGatewayCmdGetter defines the interface for getting list-gateway command args. -type billingSubscribersListGatewayCmdGetter interface { - Args() cli.Args -} - -func billingSubscribersListGatewayAction(ctx context.Context, cmd billingSubscribersListGatewayCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersListGatewayAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("gateway ID is required") } @@ -249,17 +236,12 @@ Examples: if err != nil { return err } - return billingSubscribersListUserAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersListUserAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersListUserCmdGetter defines the interface for getting list-user command args. -type billingSubscribersListUserCmdGetter interface { - Args() cli.Args -} - -func billingSubscribersListUserAction(ctx context.Context, cmd billingSubscribersListUserCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersListUserAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("user ID is required") } @@ -335,17 +317,12 @@ Examples: if err != nil { return err } - return billingSubscribersCancelAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersCancelAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersCancelCmdGetter defines the interface for getting cancel command flags. -type billingSubscribersCancelCmdGetter interface { - String(name string) string -} - -func billingSubscribersCancelAction(ctx context.Context, cmd billingSubscribersCancelCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersCancelAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -392,17 +369,12 @@ Examples: if err != nil { return err } - return billingSubscribersAbortCancelAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersAbortCancelAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersAbortCancelCmdGetter defines the interface for getting abort-cancel command flags. -type billingSubscribersAbortCancelCmdGetter interface { - String(name string) string -} - -func billingSubscribersAbortCancelAction(ctx context.Context, cmd billingSubscribersAbortCancelCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersAbortCancelAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -449,18 +421,12 @@ Examples: if err != nil { return err } - return billingSubscribersChangePlanAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersChangePlanAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersChangePlanCmdGetter defines the interface for getting change-plan command flags. -type billingSubscribersChangePlanCmdGetter interface { - String(name string) string - Int(name string) int -} - -func billingSubscribersChangePlanAction(ctx context.Context, cmd billingSubscribersChangePlanCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersChangePlanAction(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -519,17 +485,12 @@ Examples: if err != nil { return err } - return billingSubscribersPauseAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersPauseAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersPauseCmdGetter defines the interface for getting pause command flags. -type billingSubscribersPauseCmdGetter interface { - String(name string) string -} - -func billingSubscribersPauseAction(ctx context.Context, cmd billingSubscribersPauseCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersPauseAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -571,17 +532,12 @@ Examples: if err != nil { return err } - return billingSubscribersResumeAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultBillingAdminServiceFactory) + return billingSubscribersResumeAction(ctx, cmd, output, cfgMgr, defaultBillingAdminServiceFactory) }, } } -// billingSubscribersResumeCmdGetter defines the interface for getting resume command flags. -type billingSubscribersResumeCmdGetter interface { - String(name string) string -} - -func billingSubscribersResumeAction(ctx context.Context, cmd billingSubscribersResumeCmdGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { +func billingSubscribersResumeAction(ctx context.Context, cmd flagGetter, output Output, cfgMgr config.Manager, serviceFactory BillingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err diff --git a/pkg/cli/admin_billing_subscribers_test.go b/pkg/cli/admin_billing_subscribers_test.go index e26bbe8..ad3f1ce 100644 --- a/pkg/cli/admin_billing_subscribers_test.go +++ b/pkg/cli/admin_billing_subscribers_test.go @@ -9,7 +9,6 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" @@ -23,8 +22,6 @@ func unmarshalSubscriberJSON(data string) *admin.Subscriber { return &item } - - func TestBillingSubscribersList(t *testing.T) { tests := []struct { name string @@ -80,7 +77,7 @@ func TestBillingSubscribersList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -90,9 +87,7 @@ func TestBillingSubscribersList(t *testing.T) { return service } - var cmd billingSubscribersListCmdGetter - - err := billingSubscribersListAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + err := billingSubscribersListAction(context.Background(), output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) @@ -106,15 +101,6 @@ func TestBillingSubscribersList(t *testing.T) { } } -// billingSubscribersGetArgs implements billingSubscribersGetCmdGetter -type billingSubscribersGetArgs struct { - args cli.Args -} - -func (m *billingSubscribersGetArgs) Args() cli.Args { - return m.args -} - func TestBillingSubscribersGet(t *testing.T) { tests := []struct { name string @@ -148,17 +134,16 @@ func TestBillingSubscribersGet(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.subscriberID != "" { - args.args = []string{tt.subscriberID} + cmd = cmd.withArgs(tt.subscriberID) } - cmd := &billingSubscribersGetArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -178,15 +163,6 @@ func TestBillingSubscribersGet(t *testing.T) { } } -// billingSubscribersListGatewayArgs implements billingSubscribersListGatewayCmdGetter -type billingSubscribersListGatewayArgs struct { - args cli.Args -} - -func (m *billingSubscribersListGatewayArgs) Args() cli.Args { - return m.args -} - func TestBillingSubscribersListGateway(t *testing.T) { tests := []struct { name string @@ -223,17 +199,16 @@ func TestBillingSubscribersListGateway(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.gatewayID != "" { - args.args = []string{tt.gatewayID} + cmd = cmd.withArgs(tt.gatewayID) } - cmd := &billingSubscribersListGatewayArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -253,15 +228,6 @@ func TestBillingSubscribersListGateway(t *testing.T) { } } -// billingSubscribersListUserArgs implements billingSubscribersListUserCmdGetter -type billingSubscribersListUserArgs struct { - args cli.Args -} - -func (m *billingSubscribersListUserArgs) Args() cli.Args { - return m.args -} - func TestBillingSubscribersListUser(t *testing.T) { tests := []struct { name string @@ -298,17 +264,16 @@ func TestBillingSubscribersListUser(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - args := &mockArgs{} + cmd := newMockCommand() if tt.userID != "" { - args.args = []string{tt.userID} + cmd = cmd.withArgs(tt.userID) } - cmd := &billingSubscribersListUserArgs{args: args} serviceFactory := func(cm config.Manager, out Output) BillingAdminService { return service @@ -328,22 +293,6 @@ func TestBillingSubscribersListUser(t *testing.T) { } } -// billingSubscribersCancelCmd implements billingSubscribersCancelCmdGetter -type billingSubscribersCancelCmd struct { - userID string - mode string -} - -func (m *billingSubscribersCancelCmd) String(name string) string { - switch name { - case FlagUserID: - return m.userID - case FlagMode: - return m.mode - } - return "" -} - func TestBillingSubscribersCancel(t *testing.T) { tests := []struct { name string @@ -381,7 +330,7 @@ func TestBillingSubscribersCancel(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -391,10 +340,9 @@ func TestBillingSubscribersCancel(t *testing.T) { return service } - cmd := &billingSubscribersCancelCmd{ - userID: "123", - mode: "end_of_billing_period", - } + cmd := newMockCommand(). + withString(FlagUserID, "123"). + withString(FlagMode, "end_of_billing_period") err := billingSubscribersCancelAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -410,18 +358,6 @@ func TestBillingSubscribersCancel(t *testing.T) { } } -// billingSubscribersAbortCancelCmd implements billingSubscribersAbortCancelCmdGetter -type billingSubscribersAbortCancelCmd struct { - userID string -} - -func (m *billingSubscribersAbortCancelCmd) String(name string) string { - if name == FlagUserID { - return m.userID - } - return "" -} - func TestBillingSubscribersAbortCancel(t *testing.T) { tests := []struct { name string @@ -456,7 +392,7 @@ func TestBillingSubscribersAbortCancel(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -466,9 +402,7 @@ func TestBillingSubscribersAbortCancel(t *testing.T) { return service } - cmd := &billingSubscribersAbortCancelCmd{ - userID: "123", - } + cmd := newMockCommand().withString(FlagUserID, "123") err := billingSubscribersAbortCancelAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -484,26 +418,6 @@ func TestBillingSubscribersAbortCancel(t *testing.T) { } } -// billingSubscribersChangePlanCmd implements billingSubscribersChangePlanCmdGetter -type billingSubscribersChangePlanCmd struct { - userID string - periodID int -} - -func (m *billingSubscribersChangePlanCmd) String(name string) string { - if name == FlagUserID { - return m.userID - } - return "" -} - -func (m *billingSubscribersChangePlanCmd) Int(name string) int { - if name == FlagPlanID { - return m.periodID - } - return 0 -} - func TestBillingSubscribersChangePlan(t *testing.T) { tests := []struct { name string @@ -542,7 +456,7 @@ func TestBillingSubscribersChangePlan(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -552,10 +466,9 @@ func TestBillingSubscribersChangePlan(t *testing.T) { return service } - cmd := &billingSubscribersChangePlanCmd{ - userID: "123", - periodID: 1, - } + cmd := newMockCommand(). + withString(FlagUserID, "123"). + withInt(FlagPlanID, 1) err := billingSubscribersChangePlanAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -571,18 +484,6 @@ func TestBillingSubscribersChangePlan(t *testing.T) { } } -// billingSubscribersPauseCmd implements billingSubscribersPauseCmdGetter -type billingSubscribersPauseCmd struct { - userID string -} - -func (m *billingSubscribersPauseCmd) String(name string) string { - if name == FlagUserID { - return m.userID - } - return "" -} - func TestBillingSubscribersPause(t *testing.T) { tests := []struct { name string @@ -617,7 +518,7 @@ func TestBillingSubscribersPause(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -627,9 +528,7 @@ func TestBillingSubscribersPause(t *testing.T) { return service } - cmd := &billingSubscribersPauseCmd{ - userID: "123", - } + cmd := newMockCommand().withString(FlagUserID, "123") err := billingSubscribersPauseAction(context.Background(), cmd, output, cfgMgr, serviceFactory) @@ -645,18 +544,6 @@ func TestBillingSubscribersPause(t *testing.T) { } } -// billingSubscribersResumeCmd implements billingSubscribersResumeCmdGetter -type billingSubscribersResumeCmd struct { - userID string -} - -func (m *billingSubscribersResumeCmd) String(name string) string { - if name == FlagUserID { - return m.userID - } - return "" -} - func TestBillingSubscribersResume(t *testing.T) { tests := []struct { name string @@ -691,7 +578,7 @@ func TestBillingSubscribersResume(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockBillingAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -701,9 +588,7 @@ func TestBillingSubscribersResume(t *testing.T) { return service } - cmd := &billingSubscribersResumeCmd{ - userID: "123", - } + cmd := newMockCommand().withString(FlagUserID, "123") err := billingSubscribersResumeAction(context.Background(), cmd, output, cfgMgr, serviceFactory) diff --git a/pkg/cli/admin_pprof.go b/pkg/cli/admin_pprof.go index a347584..2f6e4e5 100644 --- a/pkg/cli/admin_pprof.go +++ b/pkg/cli/admin_pprof.go @@ -59,7 +59,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetProfileIndex(ctx) }) @@ -80,7 +80,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetBlockProfile(ctx) }) @@ -103,7 +103,7 @@ Examples: if err != nil { return err } - return adminPprofSetRateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, "block profile rate", + return adminPprofSetRateAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, "block profile rate", func(svc ProfilingAdminService, ctx context.Context, rate int) error { return svc.SetBlockProfileRate(ctx, rate) }) @@ -124,7 +124,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetCmdline(ctx) }) @@ -145,7 +145,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetGoroutineProfile(ctx) }) @@ -166,7 +166,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetHeapProfile(ctx) }) @@ -187,7 +187,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetMutexProfile(ctx) }) @@ -210,7 +210,7 @@ Examples: if err != nil { return err } - return adminPprofSetRateAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, "mutex profile fraction", + return adminPprofSetRateAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, "mutex profile fraction", func(svc ProfilingAdminService, ctx context.Context, rate int) error { return svc.SetMutexProfileFraction(ctx, rate) }) @@ -231,7 +231,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetCPUProfile(ctx) }) @@ -253,7 +253,7 @@ Examples: if err != nil { return err } - return adminPprofStatusAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory) + return adminPprofStatusAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory) }, } } @@ -271,7 +271,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetSymbol(ctx) }) @@ -292,7 +292,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetThreadcreate(ctx) }) @@ -313,7 +313,7 @@ Examples: if err != nil { return err } - return adminPprofByteAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultProfilingAdminServiceFactory, + return adminPprofByteAction(ctx, cmd, output, cfgMgr, defaultProfilingAdminServiceFactory, func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { return svc.GetTrace(ctx) }) @@ -321,11 +321,7 @@ Examples: } } -type adminPprofCmdGetter interface { - Args() cli.Args -} - -func adminPprofByteAction(ctx context.Context, cmd adminPprofCmdGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory, fn func(ProfilingAdminService, context.Context) ([]byte, error)) error { +func adminPprofByteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory, fn func(ProfilingAdminService, context.Context) ([]byte, error)) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err @@ -346,7 +342,7 @@ func adminPprofByteAction(ctx context.Context, cmd adminPprofCmdGetter, output O return err } -func adminPprofSetRateAction(ctx context.Context, cmd adminPprofCmdGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory, label string, fn func(ProfilingAdminService, context.Context, int) error) error { +func adminPprofSetRateAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory, label string, fn func(ProfilingAdminService, context.Context, int) error) error { if cmd.Args().Len() < 1 { return fmt.Errorf("%s value is required", label) } @@ -374,7 +370,7 @@ func adminPprofSetRateAction(ctx context.Context, cmd adminPprofCmdGetter, outpu return nil } -func adminPprofStatusAction(ctx context.Context, cmd adminPprofCmdGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory) error { +func adminPprofStatusAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory ProfilingAdminServiceFactory) error { service := serviceFactory(cfgMgr, output) if err := service.RequireAuthenticated(); err != nil { return err diff --git a/pkg/cli/admin_pprof_test.go b/pkg/cli/admin_pprof_test.go new file mode 100644 index 0000000..efcfb7f --- /dev/null +++ b/pkg/cli/admin_pprof_test.go @@ -0,0 +1,267 @@ +package cli + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/portal-sdk/admin" +) + +func TestAdminPprofByteAction(t *testing.T) { + tests := []struct { + name string + setupMocks func(*configmocks.MockManager, *MockProfilingAdminService) + fn func(ProfilingAdminService, context.Context) ([]byte, error) + wantErr bool + errContains string + }{ + { + name: "successful byte profile", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetHeapProfile(context.Background()).Return([]byte("heap-data"), nil) + }, + fn: func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { + return svc.GetHeapProfile(ctx) + }, + wantErr: false, + }, + { + name: "returns error when not authenticated", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + fn: func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { + return svc.GetHeapProfile(ctx) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetHeapProfile(context.Background()).Return(nil, errors.New("profile fetch failed")) + }, + fn: func(svc ProfilingAdminService, ctx context.Context) ([]byte, error) { + return svc.GetHeapProfile(ctx) + }, + wantErr: true, + errContains: "profile fetch failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockProfilingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) ProfilingAdminService { + return service + } + + cmd := newMockCommand() + + err := adminPprofByteAction(context.Background(), cmd, output, cfgMgr, serviceFactory, tt.fn) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAdminPprofSetRateAction(t *testing.T) { + tests := []struct { + name string + args []string + label string + setupMocks func(*configmocks.MockManager, *MockProfilingAdminService) + fn func(ProfilingAdminService, context.Context, int) error + wantErr bool + errContains string + }{ + { + name: "successful set block rate", + args: []string{"1"}, + label: "block profile rate", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().SetBlockProfileRate(context.Background(), 1).Return(nil) + }, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { + return svc.SetBlockProfileRate(ctx, rate) + }, + wantErr: false, + }, + { + name: "successful set mutex fraction", + args: []string{"100"}, + label: "mutex profile fraction", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().SetMutexProfileFraction(context.Background(), 100).Return(nil) + }, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { + return svc.SetMutexProfileFraction(ctx, rate) + }, + wantErr: false, + }, + { + name: "returns error when rate arg is missing", + args: nil, + label: "block profile rate", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) {}, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { return nil }, + wantErr: true, + errContains: "block profile rate value is required", + }, + { + name: "returns error when rate arg is not a number", + args: []string{"abc"}, + label: "block profile rate", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) {}, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { return nil }, + wantErr: true, + errContains: "invalid block profile rate value: abc", + }, + { + name: "returns error when not authenticated", + args: []string{"1"}, + label: "block profile rate", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { + return svc.SetBlockProfileRate(ctx, rate) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + args: []string{"1"}, + label: "block profile rate", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().SetBlockProfileRate(context.Background(), 1).Return(errors.New("set rate failed")) + }, + fn: func(svc ProfilingAdminService, ctx context.Context, rate int) error { + return svc.SetBlockProfileRate(ctx, rate) + }, + wantErr: true, + errContains: "set rate failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockProfilingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) ProfilingAdminService { + return service + } + + cmd := newMockCommand() + if tt.args != nil { + cmd = cmd.withArgs(tt.args...) + } + + err := adminPprofSetRateAction(context.Background(), cmd, output, cfgMgr, serviceFactory, tt.label, tt.fn) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAdminPprofStatusAction(t *testing.T) { + tests := []struct { + name string + setupMocks func(*configmocks.MockManager, *MockProfilingAdminService) + wantErr bool + errContains string + }{ + { + name: "successful status", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetStatus(context.Background()).Return(&admin.ProfilingStatus{}, nil) + }, + wantErr: false, + }, + { + name: "returns error when not authenticated", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "returns error when service fails", + setupMocks: func(cfgMgr *configmocks.MockManager, service *MockProfilingAdminService) { + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().GetStatus(context.Background()).Return(nil, errors.New("status fetch failed")) + }, + wantErr: true, + errContains: "status fetch failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockProfilingAdminService(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, service) + } + + serviceFactory := func(cm config.Manager, out Output) ProfilingAdminService { + return service + } + + cmd := newMockCommand() + + err := adminPprofStatusAction(context.Background(), cmd, output, cfgMgr, serviceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/cli/admin_quota.go b/pkg/cli/admin_quota.go index 115220f..c328f12 100644 --- a/pkg/cli/admin_quota.go +++ b/pkg/cli/admin_quota.go @@ -8,19 +8,13 @@ import ( "strings" "time" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" "go.lumeweb.com/portal-sdk/admin" ) var quotaAdminServiceFactory QuotaAdminServiceFactory = defaultQuotaAdminServiceFactory -// quotaPlansListCmdGetter interface for quota plans list command -type quotaPlansListCmdGetter interface { -} - -// quotaPlansListAction lists all quota plans -func quotaPlansListAction(ctx context.Context, cmd quotaPlansListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -70,13 +64,8 @@ func quotaPlansListAction(ctx context.Context, cmd quotaPlansListCmdGetter, outp return nil } -// quotaPlansGetCmdGetter interface for quota plans get command -type quotaPlansGetCmdGetter interface { - Args() cli.Args -} - // quotaPlansGetAction gets a quota plan by ID -func quotaPlansGetAction(ctx context.Context, cmd quotaPlansGetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansGetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -126,16 +115,8 @@ func quotaPlansGetAction(ctx context.Context, cmd quotaPlansGetCmdGetter, output return nil } -// quotaPlansCreateCmdGetter interface for quota plans create command -type quotaPlansCreateCmdGetter interface { - String(string) string - Int(string) int - Bool(string) bool - IsSet(string) bool -} - // quotaPlansCreateAction creates a new quota plan -func quotaPlansCreateAction(ctx context.Context, cmd quotaPlansCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansCreateAction(ctx context.Context, cmd flagGetterWithIsSet, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -203,17 +184,11 @@ func quotaPlansCreateAction(ctx context.Context, cmd quotaPlansCreateCmdGetter, return nil } -// quotaPlansUpdateCmdGetter interface for quota plans update command -type quotaPlansUpdateCmdGetter interface { - Args() cli.Args - String(string) string - Int(string) int - Bool(string) bool - IsSet(string) bool -} - // quotaPlansUpdateAction updates a quota plan -func quotaPlansUpdateAction(ctx context.Context, cmd quotaPlansUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansUpdateAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet +}, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -321,13 +296,8 @@ func quotaPlansUpdateAction(ctx context.Context, cmd quotaPlansUpdateCmdGetter, return nil } -// quotaPlansDeleteCmdGetter interface for quota plans delete command -type quotaPlansDeleteCmdGetter interface { - Args() cli.Args -} - // quotaPlansDeleteAction deletes a quota plan -func quotaPlansDeleteAction(ctx context.Context, cmd quotaPlansDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -359,13 +329,8 @@ func quotaPlansDeleteAction(ctx context.Context, cmd quotaPlansDeleteCmdGetter, return nil } -// quotaPlansSetDefaultCmdGetter interface for quota plans set-default command -type quotaPlansSetDefaultCmdGetter interface { - Args() cli.Args -} - // quotaPlansSetDefaultAction sets a quota plan as default -func quotaPlansSetDefaultAction(ctx context.Context, cmd quotaPlansSetDefaultCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaPlansSetDefaultAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -400,12 +365,8 @@ func quotaPlansSetDefaultAction(ctx context.Context, cmd quotaPlansSetDefaultCmd return nil } -// quotaAllowancesListCmdGetter interface for quota allowances list command -type quotaAllowancesListCmdGetter interface { -} - // quotaAllowancesListAction lists all quota allowances -func quotaAllowancesListAction(ctx context.Context, cmd quotaAllowancesListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaAllowancesListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -457,15 +418,8 @@ func quotaAllowancesListAction(ctx context.Context, cmd quotaAllowancesListCmdGe return nil } -// quotaAllowancesCreateCmdGetter interface for quota allowances create command -type quotaAllowancesCreateCmdGetter interface { - Int(string) int - String(string) string - IsSet(string) bool -} - // quotaAllowancesCreateAction creates a quota allowance -func quotaAllowancesCreateAction(ctx context.Context, cmd quotaAllowancesCreateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaAllowancesCreateAction(ctx context.Context, cmd flagGetterWithIsSet, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -525,16 +479,11 @@ func quotaAllowancesCreateAction(ctx context.Context, cmd quotaAllowancesCreateC return nil } -// quotaAllowancesUpdateCmdGetter interface for quota allowances update command -type quotaAllowancesUpdateCmdGetter interface { - Args() cli.Args - Int(string) int - String(string) string - IsSet(string) bool -} - // quotaAllowancesUpdateAction updates a quota allowance -func quotaAllowancesUpdateAction(ctx context.Context, cmd quotaAllowancesUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaAllowancesUpdateAction(ctx context.Context, cmd interface { + argsGetter + flagGetterWithIsSet +}, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -623,13 +572,8 @@ func quotaAllowancesUpdateAction(ctx context.Context, cmd quotaAllowancesUpdateC return nil } -// quotaAllowancesDeleteCmdGetter interface for quota allowances delete command -type quotaAllowancesDeleteCmdGetter interface { - Args() cli.Args -} - // quotaAllowancesDeleteAction deletes a quota allowance -func quotaAllowancesDeleteAction(ctx context.Context, cmd quotaAllowancesDeleteCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaAllowancesDeleteAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -661,12 +605,8 @@ func quotaAllowancesDeleteAction(ctx context.Context, cmd quotaAllowancesDeleteC return nil } -// quotaStatsCmdGetter interface for quota stats command -type quotaStatsCmdGetter interface { -} - // quotaStatsAction gets quota system statistics -func quotaStatsAction(ctx context.Context, cmd quotaStatsCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaStatsAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -701,14 +641,8 @@ func quotaStatsAction(ctx context.Context, cmd quotaStatsCmdGetter, output Outpu return nil } -// quotaReconcileCmdGetter interface for quota reconcile command -type quotaReconcileCmdGetter interface { - Int(string) int - IsSet(string) bool -} - // quotaReconcileAction reconciles quota data -func quotaReconcileAction(ctx context.Context, cmd quotaReconcileCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaReconcileAction(ctx context.Context, cmd flagGetterWithIsSet, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() @@ -742,13 +676,8 @@ func quotaReconcileAction(ctx context.Context, cmd quotaReconcileCmdGetter, outp return nil } -// quotaCleanupCmdGetter interface for quota cleanup command -type quotaCleanupCmdGetter interface { - Int(string) int -} - // quotaCleanupAction cleans up expired quota data -func quotaCleanupAction(ctx context.Context, cmd quotaCleanupCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaCleanupAction(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() @@ -777,12 +706,8 @@ func quotaCleanupAction(ctx context.Context, cmd quotaCleanupCmdGetter, output O return nil } -// quotaUserConfigsListCmdGetter interface for quota user configs list command -type quotaUserConfigsListCmdGetter interface { -} - // quotaUserConfigsListAction lists all user quota configs -func quotaUserConfigsListAction(ctx context.Context, cmd quotaUserConfigsListCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaUserConfigsListAction(ctx context.Context, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -844,13 +769,8 @@ func quotaUserConfigsListAction(ctx context.Context, cmd quotaUserConfigsListCmd return nil } -// quotaUserConfigsResetCmdGetter interface for quota user configs reset command -type quotaUserConfigsResetCmdGetter interface { - Args() cli.Args -} - // quotaUserConfigsResetAction resets a user's quota config to default -func quotaUserConfigsResetAction(ctx context.Context, cmd quotaUserConfigsResetCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaUserConfigsResetAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -887,15 +807,8 @@ func quotaUserConfigsResetAction(ctx context.Context, cmd quotaUserConfigsResetC return nil } -// quotaUserConfigsUpdateCmdGetter interface for quota user configs update command -type quotaUserConfigsUpdateCmdGetter interface { - Int(name string) int - String(name string) string - IsSet(name string) bool -} - // quotaUserConfigsUpdateAction updates a user's quota config -func quotaUserConfigsUpdateAction(ctx context.Context, cmd quotaUserConfigsUpdateCmdGetter, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { +func quotaUserConfigsUpdateAction(ctx context.Context, cmd flagGetterWithIsSet, output Output, cfgMgr config.Manager, serviceFactory QuotaAdminServiceFactory) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/pkg/cli/admin_quota_actions_test.go b/pkg/cli/admin_quota_actions_test.go new file mode 100644 index 0000000..0cc7d3d --- /dev/null +++ b/pkg/cli/admin_quota_actions_test.go @@ -0,0 +1,517 @@ +package cli + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/portal-sdk/admin" +) + +func TestQuotaAllowancesCreate(t *testing.T) { + tests := []struct { + name string + cmd *mockCommand + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "success with all flags", + cmd: newMockCommand(). + withInt(FlagUserID, 42). + withIsSet(FlagUserID, true). + withString(FlagSource, "admin"). + withString(FlagQuotaType, "monthly"). + withInt(FlagUploadLimit, 1048576), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().CreateAllowance( + mock.Anything, 42, "admin", "monthly", 1048576, 1048576, 0, time.Time{}, + ).Return(&admin.QuotaAllowance{}, nil) + }, + wantErr: false, + }, + { + name: "success with expiry flag", + cmd: newMockCommand(). + withInt(FlagUserID, 10). + withIsSet(FlagUserID, true). + withString(FlagSource, "grant"). + withString(FlagQuotaType, "one-time"). + withInt(FlagUploadLimit, 2048). + withInt(FlagExpiry, 30). + withIsSet(FlagExpiry, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().CreateAllowance( + mock.Anything, 10, "grant", "one-time", 2048, 2048, 0, mock.AnythingOfType("time.Time"), + ).Return(&admin.QuotaAllowance{}, nil) + }, + wantErr: false, + }, + { + name: "success json output", + cmd: newMockCommand(). + withInt(FlagUserID, 5). + withIsSet(FlagUserID, true). + withString(FlagSource, "promo"). + withString(FlagQuotaType, "annual"). + withInt(FlagUploadLimit, 512), + jsonOutput: true, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().CreateAllowance( + mock.Anything, 5, "promo", "annual", 512, 512, 0, time.Time{}, + ).Return(&admin.QuotaAllowance{}, nil) + }, + wantErr: false, + }, + { + name: "missing user-id flag", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + }, + wantErr: true, + errContains: "--user-id is required", + }, + { + name: "not authenticated", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(errors.New("authentication required")) + }, + wantErr: true, + errContains: "authentication required", + }, + { + name: "service error", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true). + withString(FlagSource, "admin"). + withString(FlagQuotaType, "monthly"). + withInt(FlagUploadLimit, 1024), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().CreateAllowance( + mock.Anything, 1, "admin", "monthly", 1024, 1024, 0, time.Time{}, + ).Return(nil, errors.New("api error")) + }, + wantErr: true, + errContains: "api error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockQuotaAdminService(t) + + tt.setupMocks(cfgMgr, service) + + savedFactory := quotaAdminServiceFactory + quotaAdminServiceFactory = func(cm config.Manager, out Output) QuotaAdminService { + return service + } + defer func() { quotaAdminServiceFactory = savedFactory }() + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaAllowancesCreateAction(context.Background(), tt.cmd, output, cfgMgr, quotaAdminServiceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestQuotaAllowancesUpdate(t *testing.T) { + tests := []struct { + name string + cmd *mockCommand + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "success with upload-limit", + cmd: newMockCommand(). + withArgs("grant-123"). + withInt(FlagUploadLimit, 2048). + withIsSet(FlagUploadLimit, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateAllowance( + mock.Anything, "grant-123", 0, "", "", 2048, 2048, 0, time.Time{}, + ).Return(&admin.QuotaAllowance{}, nil) + }, + wantErr: false, + }, + { + name: "success with multiple fields", + cmd: newMockCommand(). + withArgs("grant-456"). + withInt(FlagUserID, 7). + withIsSet(FlagUserID, true). + withString(FlagSource, "admin"). + withIsSet(FlagSource, true). + withString(FlagQuotaType, "monthly"). + withIsSet(FlagQuotaType, true). + withInt(FlagUploadLimit, 4096). + withIsSet(FlagUploadLimit, true). + withInt(FlagDownloadLimit, 8192). + withIsSet(FlagDownloadLimit, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateAllowance( + mock.Anything, "grant-456", 7, "admin", "monthly", 4096, 8192, 0, time.Time{}, + ).Return(&admin.QuotaAllowance{}, nil) + }, + wantErr: false, + }, + { + name: "missing grant ID", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + }, + wantErr: true, + errContains: "grant ID is required", + }, + { + name: "no update fields provided", + cmd: newMockCommand(). + withArgs("grant-789"), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + }, + wantErr: true, + errContains: "at least one field must be provided for update", + }, + { + name: "not authenticated", + cmd: newMockCommand(). + withArgs("grant-1"). + withInt(FlagUploadLimit, 1024). + withIsSet(FlagUploadLimit, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(errors.New("not authenticated")) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "service error", + cmd: newMockCommand(). + withArgs("grant-1"). + withInt(FlagUploadLimit, 1024). + withIsSet(FlagUploadLimit, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateAllowance( + mock.Anything, "grant-1", 0, "", "", 1024, 1024, 0, time.Time{}, + ).Return(nil, errors.New("allowance not found")) + }, + wantErr: true, + errContains: "allowance not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockQuotaAdminService(t) + + tt.setupMocks(cfgMgr, service) + + savedFactory := quotaAdminServiceFactory + quotaAdminServiceFactory = func(cm config.Manager, out Output) QuotaAdminService { + return service + } + defer func() { quotaAdminServiceFactory = savedFactory }() + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaAllowancesUpdateAction(context.Background(), tt.cmd, output, cfgMgr, quotaAdminServiceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestQuotaReconcile(t *testing.T) { + tests := []struct { + name string + cmd *mockCommand + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "success without user-id", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().Reconcile(mock.Anything, (*int)(nil)).Return("all quotas reconciled", 5, nil) + }, + wantErr: false, + }, + { + name: "success with user-id", + cmd: newMockCommand(). + withInt(FlagUserID, 42). + withIsSet(FlagUserID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().Reconcile(mock.Anything, mock.AnythingOfType("*int")).Return("user quotas reconciled", 2, nil) + }, + wantErr: false, + }, + { + name: "success json output", + cmd: newMockCommand(), + jsonOutput: true, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().Reconcile(mock.Anything, (*int)(nil)).Return("reconciled", 3, nil) + }, + wantErr: false, + }, + { + name: "not authenticated", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(errors.New("authentication required")) + }, + wantErr: true, + errContains: "authentication required", + }, + { + name: "service error", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().Reconcile(mock.Anything, (*int)(nil)).Return("", 0, errors.New("reconciliation failed")) + }, + wantErr: true, + errContains: "reconciliation failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockQuotaAdminService(t) + + tt.setupMocks(cfgMgr, service) + + savedFactory := quotaAdminServiceFactory + quotaAdminServiceFactory = func(cm config.Manager, out Output) QuotaAdminService { + return service + } + defer func() { quotaAdminServiceFactory = savedFactory }() + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaReconcileAction(context.Background(), tt.cmd, output, cfgMgr, quotaAdminServiceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestQuotaUserConfigsUpdate(t *testing.T) { + tests := []struct { + name string + cmd *mockCommand + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "success with plan-id", + cmd: newMockCommand(). + withInt(FlagUserID, 10). + withIsSet(FlagUserID, true). + withInt(FlagPlanID, 3). + withIsSet(FlagPlanID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateUserConfig( + mock.Anything, 10, mock.Anything, + ).Return(&admin.UserQuotaConfig{}, nil) + }, + wantErr: false, + }, + { + name: "success with enforcement-policy", + cmd: newMockCommand(). + withInt(FlagUserID, 5). + withIsSet(FlagUserID, true). + withString(FlagEnforcementPolicy, "HARD_LIMITS"). + withIsSet(FlagEnforcementPolicy, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateUserConfig( + mock.Anything, 5, mock.Anything, + ).Return(&admin.UserQuotaConfig{}, nil) + }, + wantErr: false, + }, + { + name: "missing user-id", + cmd: newMockCommand(), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) {}, + wantErr: true, + errContains: "--user-id is required", + }, + { + name: "no update fields provided", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + }, + wantErr: true, + errContains: "at least one field must be provided for update", + }, + { + name: "invalid enforcement-policy value", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true). + withString(FlagEnforcementPolicy, "INVALID"). + withIsSet(FlagEnforcementPolicy, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + }, + wantErr: true, + errContains: "invalid --enforcement-policy value", + }, + { + name: "not authenticated", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true). + withInt(FlagPlanID, 2). + withIsSet(FlagPlanID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(errors.New("not authenticated")) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "service error", + cmd: newMockCommand(). + withInt(FlagUserID, 1). + withIsSet(FlagUserID, true). + withInt(FlagPlanID, 2). + withIsSet(FlagPlanID, true), + jsonOutput: false, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().UpdateUserConfig( + mock.Anything, 1, mock.Anything, + ).Return(nil, errors.New("update failed")) + }, + wantErr: true, + errContains: "update failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockQuotaAdminService(t) + + tt.setupMocks(cfgMgr, service) + + savedFactory := quotaAdminServiceFactory + quotaAdminServiceFactory = func(cm config.Manager, out Output) QuotaAdminService { + return service + } + defer func() { quotaAdminServiceFactory = savedFactory }() + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaUserConfigsUpdateAction(context.Background(), tt.cmd, output, cfgMgr, quotaAdminServiceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/cli/admin_quota_create_update_test.go b/pkg/cli/admin_quota_create_update_test.go index f59585d..6a5d2a4 100644 --- a/pkg/cli/admin_quota_create_update_test.go +++ b/pkg/cli/admin_quota_create_update_test.go @@ -9,159 +9,25 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" ) -// mockQuotaPlansCreateCmd implements quotaPlansCreateCmdGetter -type mockQuotaPlansCreateCmd struct { - name string - description string - upload int - download int - storage int - windowType string - isActive bool - isDefault bool - isSetName bool -} - -func (m *mockQuotaPlansCreateCmd) String(s string) string { - switch s { - case FlagName: - return m.name - case FlagDescription: - return m.description - case FlagWindowType: - return m.windowType - default: - return "" - } -} - -func (m *mockQuotaPlansCreateCmd) Int(s string) int { - switch s { - case FlagUploadLimit: - return m.upload - case FlagDownloadLimit: - return m.download - case FlagStorageLimit: - return m.storage - default: - return 0 - } -} - -func (m *mockQuotaPlansCreateCmd) Bool(s string) bool { - switch s { - case FlagIsActive: - return m.isActive - case FlagIsDefault: - return m.isDefault - default: - return false - } -} - -func (m *mockQuotaPlansCreateCmd) IsSet(s string) bool { - switch s { - case FlagName: - return m.isSetName - default: - return false - } -} - -// mockQuotaPlansUpdateCmd implements quotaPlansUpdateCmdGetter -type mockQuotaPlansUpdateCmd struct { - args *mockArgs - name string - description string - upload int - download int - storage int - windowType string - isActive bool - isDefault bool - isSetActive bool - isSetDefault bool - isSetWindowType bool -} - -func (m *mockQuotaPlansUpdateCmd) Args() cli.Args { - return m.args -} - -func (m *mockQuotaPlansUpdateCmd) String(s string) string { - switch s { - case FlagName: - return m.name - case FlagDescription: - return m.description - case FlagWindowType: - return m.windowType - default: - return "" - } -} - -func (m *mockQuotaPlansUpdateCmd) Int(s string) int { - switch s { - case FlagUploadLimit: - return m.upload - case FlagDownloadLimit: - return m.download - case FlagStorageLimit: - return m.storage - default: - return 0 - } -} - -func (m *mockQuotaPlansUpdateCmd) Bool(s string) bool { - switch s { - case FlagIsActive: - return m.isActive - case FlagIsDefault: - return m.isDefault - default: - return false - } -} - -func (m *mockQuotaPlansUpdateCmd) IsSet(s string) bool { - switch s { - case FlagIsActive: - return m.isSetActive - case FlagIsDefault: - return m.isSetDefault - case FlagWindowType: - return m.isSetWindowType - case FlagUploadLimit, FlagDownloadLimit, FlagStorageLimit: - return true - default: - return false - } -} - func TestQuotaPlansCreate(t *testing.T) { tests := []struct { name string - cmd *mockQuotaPlansCreateCmd - jsonOutput bool + cmd *mockCommand setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) wantErr bool errContains string }{ { name: "success with is-active flag", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - isActive: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true). + withBool(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePlan(mock.Anything, mock.AnythingOfType("*admin.QuotaPlan")).Return(&admin.QuotaPlan{}, nil) @@ -170,12 +36,11 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "success with is-active and is-default flags", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - isActive: true, - isDefault: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true). + withBool(FlagIsActive, true). + withBool(FlagIsDefault, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePlan(mock.Anything, mock.AnythingOfType("*admin.QuotaPlan")).Return(&admin.QuotaPlan{}, nil) @@ -185,12 +50,11 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "is-default fails but plan still created", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - isActive: true, - isDefault: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true). + withBool(FlagIsActive, true). + withBool(FlagIsDefault, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePlan(mock.Anything, mock.AnythingOfType("*admin.QuotaPlan")).Return(&admin.QuotaPlan{}, nil) @@ -201,12 +65,11 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "success with description", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - description: "Free tier plan", - isActive: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true). + withString(FlagDescription, "Free tier plan"). + withBool(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePlan(mock.Anything, mock.AnythingOfType("*admin.QuotaPlan")).Return(&admin.QuotaPlan{}, nil) @@ -215,10 +78,9 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "not authenticated", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(errors.New("authentication required")) }, @@ -227,11 +89,10 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "service error", - cmd: &mockQuotaPlansCreateCmd{ - name: "Free", - isSetName: true, - isActive: true, - }, + cmd: newMockCommand(). + withString(FlagName, "Free"). + withIsSet(FlagName, true). + withBool(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().CreatePlan(mock.Anything, mock.AnythingOfType("*admin.QuotaPlan")).Return(nil, errors.New("api error")) @@ -241,7 +102,7 @@ func TestQuotaPlansCreate(t *testing.T) { }, { name: "returns error when no fields provided for create", - cmd: &mockQuotaPlansCreateCmd{}, + cmd: newMockCommand(), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) }, @@ -257,7 +118,7 @@ func TestQuotaPlansCreate(t *testing.T) { tt.setupMocks(cfgMgr, service) - output := NewOutputFormatter(tt.jsonOutput, false, false, false) + output := newTestOutput() serviceFactory := func(cm config.Manager, out Output) QuotaAdminService { return service @@ -277,89 +138,21 @@ func TestQuotaPlansCreate(t *testing.T) { } } -func TestQuotaPlansSetDefault_Enhanced(t *testing.T) { - tests := []struct { - name string - args []string - setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) - wantErr bool - errContains string - }{ - { - name: "success", - args: []string{"2"}, - setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { - svc.EXPECT().RequireAuthenticated().Return(nil) - svc.EXPECT().SetDefaultPlan(mock.Anything, "2").Return(nil) - }, - wantErr: false, - }, - { - name: "plan not found with helpful error", - args: []string{"3"}, - setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { - svc.EXPECT().RequireAuthenticated().Return(nil) - svc.EXPECT().SetDefaultPlan(mock.Anything, "3").Return(fmt.Errorf("%w: plan not found", admin.ErrNotFound)) - }, - wantErr: true, - errContains: "ensure the plan is active", - }, - { - name: "other error passes through", - args: []string{"2"}, - setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { - svc.EXPECT().RequireAuthenticated().Return(nil) - svc.EXPECT().SetDefaultPlan(mock.Anything, "2").Return(errors.New("server error")) - }, - wantErr: true, - errContains: "server error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfgMgr := configmocks.NewMockManager(t) - service := NewMockQuotaAdminService(t) - - tt.setupMocks(cfgMgr, service) - - output := NewOutputFormatter(false, false, false, false) - - serviceFactory := func(cm config.Manager, out Output) QuotaAdminService { - return service - } - - cmd := &mockQuotaPlansGetCmd{args: &mockArgs{args: tt.args}} - err := quotaPlansSetDefaultAction(context.Background(), cmd, output, cfgMgr, serviceFactory) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - assert.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } -} - func TestQuotaPlansUpdate(t *testing.T) { tests := []struct { name string - cmd *mockQuotaPlansUpdateCmd - jsonOutput bool + planID string + cmd *mockCommand setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) wantErr bool errContains string }{ { - name: "success with is-active flag", - cmd: &mockQuotaPlansUpdateCmd{ - args: &mockArgs{args: []string{"2"}}, - isActive: true, - isSetActive: true, - }, + name: "success with is-active flag", + planID: "2", + cmd: newMockCommand(). + withBool(FlagIsActive, true). + withIsSet(FlagIsActive, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().GetPlan(mock.Anything, "2").Return(&admin.QuotaPlan{}, nil) @@ -368,14 +161,13 @@ func TestQuotaPlansUpdate(t *testing.T) { wantErr: false, }, { - name: "success with is-active and is-default flags", - cmd: &mockQuotaPlansUpdateCmd{ - args: &mockArgs{args: []string{"2"}}, - isActive: true, - isDefault: true, - isSetActive: true, - isSetDefault: true, - }, + name: "success with is-active and is-default flags", + planID: "2", + cmd: newMockCommand(). + withBool(FlagIsActive, true). + withIsSet(FlagIsActive, true). + withBool(FlagIsDefault, true). + withIsSet(FlagIsDefault, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().GetPlan(mock.Anything, "2").Return(&admin.QuotaPlan{}, nil) @@ -385,10 +177,9 @@ func TestQuotaPlansUpdate(t *testing.T) { wantErr: false, }, { - name: "missing plan ID", - cmd: &mockQuotaPlansUpdateCmd{ - args: &mockArgs{args: []string{}}, - }, + name: "missing plan ID", + planID: "", + cmd: newMockCommand(), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) }, @@ -396,12 +187,11 @@ func TestQuotaPlansUpdate(t *testing.T) { errContains: "plan ID is required", }, { - name: "is-default fails but update succeeds", - cmd: &mockQuotaPlansUpdateCmd{ - args: &mockArgs{args: []string{"2"}}, - isDefault: true, - isSetDefault: true, - }, + name: "is-default fails but update succeeds", + planID: "2", + cmd: newMockCommand(). + withBool(FlagIsDefault, true). + withIsSet(FlagIsDefault, true), setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { svc.EXPECT().RequireAuthenticated().Return(nil) svc.EXPECT().GetPlan(mock.Anything, "2").Return(&admin.QuotaPlan{}, nil) @@ -420,13 +210,18 @@ func TestQuotaPlansUpdate(t *testing.T) { tt.setupMocks(cfgMgr, service) - output := NewOutputFormatter(tt.jsonOutput, false, false, false) + output := newTestOutput() serviceFactory := func(cm config.Manager, out Output) QuotaAdminService { return service } - err := quotaPlansUpdateAction(context.Background(), tt.cmd, output, cfgMgr, serviceFactory) + cmd := tt.cmd + if tt.planID != "" { + cmd = cmd.withArgs(tt.planID) + } + + err := quotaPlansUpdateAction(context.Background(), cmd, output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) diff --git a/pkg/cli/admin_quota_test.go b/pkg/cli/admin_quota_test.go index d8ef4c7..a219948 100644 --- a/pkg/cli/admin_quota_test.go +++ b/pkg/cli/admin_quota_test.go @@ -3,26 +3,17 @@ package cli import ( "context" "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/portal-sdk/admin" ) -// Mock command getters for quota tests -type mockQuotaPlansGetCmd struct { - args cli.Args -} - -func (m *mockQuotaPlansGetCmd) Args() cli.Args { - return m.args -} - func TestQuotaPlansList(t *testing.T) { tests := []struct { name string @@ -92,8 +83,13 @@ func TestQuotaPlansList(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - // Use an empty struct that implements quotaPlansListCmdGetter - err := quotaPlansListAction(context.Background(), struct{}{}, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + + } + + err := quotaPlansListAction(context.Background(), output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -180,13 +176,17 @@ func TestQuotaPlansGet(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - args := &mockArgs{} + cmd := newMockCommand() if len(tt.args) > 0 { - args.args = tt.args + cmd = cmd.withArgs(tt.args...) + } + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) } - cmd := &mockQuotaPlansGetCmd{args: args} - err := quotaPlansGetAction(context.Background(), cmd, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + err := quotaPlansGetAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -253,13 +253,17 @@ func TestQuotaPlansDelete(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - args := &mockArgs{} + cmd := newMockCommand() if len(tt.args) > 0 { - args.args = tt.args + cmd = cmd.withArgs(tt.args...) } - cmd := &mockQuotaPlansGetCmd{args: args} - err := quotaPlansDeleteAction(context.Background(), cmd, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaPlansDeleteAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -326,13 +330,17 @@ func TestQuotaPlansSetDefault(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - args := &mockArgs{} + cmd := newMockCommand() if len(tt.args) > 0 { - args.args = tt.args + cmd = cmd.withArgs(tt.args...) + } + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) } - cmd := &mockQuotaPlansGetCmd{args: args} - err := quotaPlansSetDefaultAction(context.Background(), cmd, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + err := quotaPlansSetDefaultAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -397,7 +405,12 @@ func TestQuotaAllowancesList(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - err := quotaAllowancesListAction(context.Background(), struct{}{}, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaAllowancesListAction(context.Background(), output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -464,13 +477,17 @@ func TestQuotaAllowancesDelete(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - args := &mockArgs{} + cmd := newMockCommand() if len(tt.args) > 0 { - args.args = tt.args + cmd = cmd.withArgs(tt.args...) } - cmd := &mockQuotaPlansGetCmd{args: args} - err := quotaAllowancesDeleteAction(context.Background(), cmd, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaAllowancesDeleteAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -535,7 +552,12 @@ func TestQuotaStats(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - err := quotaStatsAction(context.Background(), struct{}{}, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaStatsAction(context.Background(), output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -551,12 +573,12 @@ func TestQuotaStats(t *testing.T) { func TestQuotaCleanup(t *testing.T) { tests := []struct { - name string - retentionDays int64 - jsonOutput bool - setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) - wantErr bool - errContains string + name string + retentionDays int64 + jsonOutput bool + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string }{ { name: "success", @@ -594,9 +616,14 @@ func TestQuotaCleanup(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - cmdWrapper := &cleanupCmdMock{retentionDays: int(tt.retentionDays)} + cmd := newMockCommand().withInt("retention-days", int(tt.retentionDays)) + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } - err := quotaCleanupAction(context.Background(), cmdWrapper, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + err := quotaCleanupAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -610,17 +637,6 @@ func TestQuotaCleanup(t *testing.T) { } } -type cleanupCmdMock struct { - retentionDays int -} - -func (c *cleanupCmdMock) Int(name string) int { - if name == "retention-days" { - return c.retentionDays - } - return 0 -} - func TestQuotaUserConfigsList(t *testing.T) { tests := []struct { name string @@ -672,7 +688,12 @@ func TestQuotaUserConfigsList(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - err := quotaUserConfigsListAction(context.Background(), struct{}{}, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaUserConfigsListAction(context.Background(), output, cfgMgr, quotaAdminServiceFactory) if tt.wantErr { require.Error(t, err) @@ -747,13 +768,88 @@ func TestQuotaUserConfigsReset(t *testing.T) { } defer func() { quotaAdminServiceFactory = savedFactory }() - args := &mockArgs{} + cmd := newMockCommand() + if len(tt.args) > 0 { + cmd = cmd.withArgs(tt.args...) + } + + output := newTestOutput() + if tt.jsonOutput { + output = NewOutputFormatter(true, false, false, false) + } + + err := quotaUserConfigsResetAction(context.Background(), cmd, output, cfgMgr, quotaAdminServiceFactory) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestQuotaPlansSetDefault_Enhanced(t *testing.T) { + tests := []struct { + name string + args []string + setupMocks func(*configmocks.MockManager, *MockQuotaAdminService) + wantErr bool + errContains string + }{ + { + name: "success", + args: []string{"2"}, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().SetDefaultPlan(mock.Anything, "2").Return(nil) + }, + wantErr: false, + }, + { + name: "plan not found with helpful error", + args: []string{"3"}, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().SetDefaultPlan(mock.Anything, "3").Return(fmt.Errorf("%w: plan not found", admin.ErrNotFound)) + }, + wantErr: true, + errContains: "ensure the plan is active", + }, + { + name: "other error passes through", + args: []string{"2"}, + setupMocks: func(cfgMgr *configmocks.MockManager, svc *MockQuotaAdminService) { + svc.EXPECT().RequireAuthenticated().Return(nil) + svc.EXPECT().SetDefaultPlan(mock.Anything, "2").Return(errors.New("server error")) + }, + wantErr: true, + errContains: "server error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockQuotaAdminService(t) + + tt.setupMocks(cfgMgr, service) + + output := newTestOutput() + + serviceFactory := func(cm config.Manager, out Output) QuotaAdminService { + return service + } + + cmd := newMockCommand() if len(tt.args) > 0 { - args.args = tt.args + cmd = cmd.withArgs(tt.args...) } - cmd := &mockQuotaPlansGetCmd{args: args} - err := quotaUserConfigsResetAction(context.Background(), cmd, NewOutputFormatter(tt.jsonOutput, false, false, false), cfgMgr, quotaAdminServiceFactory) + err := quotaPlansSetDefaultAction(context.Background(), cmd, output, cfgMgr, serviceFactory) if tt.wantErr { require.Error(t, err) @@ -766,3 +862,26 @@ func TestQuotaUserConfigsReset(t *testing.T) { }) } } + +func TestFormatBytes(t *testing.T) { + tests := []struct { + input int + expected string + }{ + {-1, "unlimited"}, + {0, "0 B"}, + {512, "512 B"}, + {1024, "1.00 KB"}, + {1536, "1.50 KB"}, + {1048576, "1.00 MB"}, + {1073741824, "1.00 GB"}, + {1099511627776, "1.00 TB"}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%d_bytes", tt.input), func(t *testing.T) { + result := formatBytes(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/cli/admin_service.go b/pkg/cli/admin_service.go index ecefa71..7e3e308 100644 --- a/pkg/cli/admin_service.go +++ b/pkg/cli/admin_service.go @@ -130,6 +130,17 @@ func with0[S any](svc authedService[S], ctx context.Context, fn func(S) error) e return fn(s) } +func with2i[S any](svc authedService[S], ctx context.Context, fn func(S) (int, error)) (int, error) { + if err := svc.RequireAuthenticated(); err != nil { + return 0, err + } + s, err := svc.getService(ctx) + if err != nil { + return 0, err + } + return fn(s) +} + // NewQuotaAdminService creates a new QuotaAdminService instance. func NewQuotaAdminService(cfgMgr config.Manager, output Output, apiEndpoint string) QuotaAdminService { return "aAdminService{ @@ -343,65 +354,41 @@ func (s *quotaAdminService) SetDefaultPlan(ctx context.Context, planID string) e // ListAllowances lists all quota allowances. func (s *quotaAdminService) ListAllowances(ctx context.Context) ([]*admin.QuotaAllowance, int, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, 0, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, 0, err - } - return svc.ListAllowances(ctx) + return with3(s, ctx, func(svc *admin.QuotaService) ([]*admin.QuotaAllowance, int, error) { + return svc.ListAllowances(ctx) + }) } // CreateAllowance creates a new quota allowance for a user. func (s *quotaAdminService) CreateAllowance(ctx context.Context, userID int, source, allowanceType string, upload, download, storage int, expiryDate time.Time) (*admin.QuotaAllowance, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, err - } - return svc.CreateAllowance(ctx, userID, source, allowanceType, upload, download, storage, expiryDate) + return with2(s, ctx, func(svc *admin.QuotaService) (*admin.QuotaAllowance, error) { + return svc.CreateAllowance(ctx, userID, source, allowanceType, upload, download, storage, expiryDate) + }) } // UpdateAllowance updates an existing quota allowance. func (s *quotaAdminService) UpdateAllowance(ctx context.Context, grantID string, userID int, source, allowanceType string, upload, download, storage int, expiryDate time.Time) (*admin.QuotaAllowance, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, err - } - return svc.UpdateAllowance(ctx, grantID, userID, source, allowanceType, upload, download, storage, expiryDate) + return with2(s, ctx, func(svc *admin.QuotaService) (*admin.QuotaAllowance, error) { + return svc.UpdateAllowance(ctx, grantID, userID, source, allowanceType, upload, download, storage, expiryDate) + }) } // DeleteAllowance deletes a quota allowance. func (s *quotaAdminService) DeleteAllowance(ctx context.Context, grantID string) error { - if err := s.RequireAuthenticated(); err != nil { - return err - } - svc, err := s.getService(ctx) - if err != nil { - return err - } - return svc.DeleteAllowance(ctx, grantID) + return with0(s, ctx, func(svc *admin.QuotaService) error { + return svc.DeleteAllowance(ctx, grantID) + }) } // GetStats retrieves system-wide quota statistics. func (s *quotaAdminService) GetStats(ctx context.Context) (*admin.SystemStats, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, err - } - return svc.GetStats(ctx) + return with2(s, ctx, func(svc *admin.QuotaService) (*admin.SystemStats, error) { + return svc.GetStats(ctx) + }) } // Reconcile performs quota reconciliation for users. +// NOTE: Reconcile returns (string, int, error) which doesn't fit with0/with2/with3/with2i. func (s *quotaAdminService) Reconcile(ctx context.Context, userID *int) (string, int, error) { if err := s.RequireAuthenticated(); err != nil { return "", 0, err @@ -415,50 +402,30 @@ func (s *quotaAdminService) Reconcile(ctx context.Context, userID *int) (string, // Cleanup performs quota cleanup based on retention policy. func (s *quotaAdminService) Cleanup(ctx context.Context, retentionDays int) (int, error) { - if err := s.RequireAuthenticated(); err != nil { - return 0, err - } - svc, err := s.getService(ctx) - if err != nil { - return 0, err - } - return svc.Cleanup(ctx, retentionDays) + return with2i(s, ctx, func(svc *admin.QuotaService) (int, error) { + return svc.Cleanup(ctx, retentionDays) + }) } // ListUserConfigs lists all user quota configurations with pagination. func (s *quotaAdminService) ListUserConfigs(ctx context.Context) ([]*admin.UserQuotaConfig, int, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, 0, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, 0, err - } - return svc.ListUserConfigs(ctx) + return with3(s, ctx, func(svc *admin.QuotaService) ([]*admin.UserQuotaConfig, int, error) { + return svc.ListUserConfigs(ctx) + }) } // UpdateUserConfig updates a user's quota configuration. func (s *quotaAdminService) UpdateUserConfig(ctx context.Context, userID int, config *admin.UserQuotaConfigUpdate) (*admin.UserQuotaConfig, error) { - if err := s.RequireAuthenticated(); err != nil { - return nil, err - } - svc, err := s.getService(ctx) - if err != nil { - return nil, err - } - return svc.UpdateUserConfig(ctx, userID, config) + return with2(s, ctx, func(svc *admin.QuotaService) (*admin.UserQuotaConfig, error) { + return svc.UpdateUserConfig(ctx, userID, config) + }) } // ResetUserPlan removes a user's assigned quota plan (sets to NULL). func (s *quotaAdminService) ResetUserPlan(ctx context.Context, userID int) error { - if err := s.RequireAuthenticated(); err != nil { - return err - } - svc, err := s.getService(ctx) - if err != nil { - return err - } - return svc.ResetUserPlan(ctx, userID) + return with0(s, ctx, func(svc *admin.QuotaService) error { + return svc.ResetUserPlan(ctx, userID) + }) } // getService returns the billing service, lazily initializing with token exchange if needed. diff --git a/pkg/cli/admin_service_delegates_test.go b/pkg/cli/admin_service_delegates_test.go new file mode 100644 index 0000000..4d843c3 --- /dev/null +++ b/pkg/cli/admin_service_delegates_test.go @@ -0,0 +1,473 @@ +package cli + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +type mockAuthedService[S any] struct { + authErr error + getServiceFn func(ctx context.Context) (S, error) +} + +func (m *mockAuthedService[S]) RequireAuthenticated() error { + return m.authErr +} + +func (m *mockAuthedService[S]) getService(ctx context.Context) (S, error) { + return m.getServiceFn(ctx) +} + +func TestWith2(t *testing.T) { + t.Run("auth failure returns zero value and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: errors.New("not authenticated"), + } + + result, err := with2(mockSvc, context.Background(), func(s string) (int, error) { + return 42, nil + }) + + require.Error(t, err) + assert.Equal(t, "not authenticated", err.Error()) + assert.Equal(t, 0, result) + }) + + t.Run("getService failure returns zero value and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "", errors.New("service unavailable") + }, + } + + result, err := with2(mockSvc, context.Background(), func(s string) (int, error) { + return 42, nil + }) + + require.Error(t, err) + assert.Equal(t, "service unavailable", err.Error()) + assert.Equal(t, 0, result) + }) + + t.Run("success calls fn and returns result", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "service", nil + }, + } + + result, err := with2(mockSvc, context.Background(), func(s string) (int, error) { + assert.Equal(t, "service", s) + return 42, nil + }) + + require.NoError(t, err) + assert.Equal(t, 42, result) + }) +} + +func TestWith3(t *testing.T) { + t.Run("auth failure returns nil, 0, and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: errors.New("not authenticated"), + } + + result, count, err := with3(mockSvc, context.Background(), func(s string) ([]string, int, error) { + return []string{"a"}, 1, nil + }) + + require.Error(t, err) + assert.Equal(t, "not authenticated", err.Error()) + assert.Nil(t, result) + assert.Equal(t, 0, count) + }) + + t.Run("getService failure returns nil, 0, and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "", errors.New("service unavailable") + }, + } + + result, count, err := with3(mockSvc, context.Background(), func(s string) ([]string, int, error) { + return []string{"a"}, 1, nil + }) + + require.Error(t, err) + assert.Equal(t, "service unavailable", err.Error()) + assert.Nil(t, result) + assert.Equal(t, 0, count) + }) + + t.Run("success calls fn and returns result", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "service", nil + }, + } + + result, count, err := with3(mockSvc, context.Background(), func(s string) ([]string, int, error) { + assert.Equal(t, "service", s) + return []string{"a", "b"}, 2, nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, result) + assert.Equal(t, 2, count) + }) +} + +func TestWith0(t *testing.T) { + t.Run("auth failure returns error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: errors.New("not authenticated"), + } + + err := with0(mockSvc, context.Background(), func(s string) error { + return nil + }) + + require.Error(t, err) + assert.Equal(t, "not authenticated", err.Error()) + }) + + t.Run("getService failure returns error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "", errors.New("service unavailable") + }, + } + + err := with0(mockSvc, context.Background(), func(s string) error { + return nil + }) + + require.Error(t, err) + assert.Equal(t, "service unavailable", err.Error()) + }) + + t.Run("success calls fn and returns nil", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "service", nil + }, + } + + called := false + err := with0(mockSvc, context.Background(), func(s string) error { + assert.Equal(t, "service", s) + called = true + return nil + }) + + require.NoError(t, err) + assert.True(t, called) + }) +} + +func TestWith2i(t *testing.T) { + t.Run("auth failure returns 0 and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: errors.New("not authenticated"), + } + + result, err := with2i(mockSvc, context.Background(), func(s string) (int, error) { + return 42, nil + }) + + require.Error(t, err) + assert.Equal(t, "not authenticated", err.Error()) + assert.Equal(t, 0, result) + }) + + t.Run("getService failure returns 0 and error", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "", errors.New("service unavailable") + }, + } + + result, err := with2i(mockSvc, context.Background(), func(s string) (int, error) { + return 42, nil + }) + + require.Error(t, err) + assert.Equal(t, "service unavailable", err.Error()) + assert.Equal(t, 0, result) + }) + + t.Run("success calls fn and returns result", func(t *testing.T) { + mockSvc := &mockAuthedService[string]{ + authErr: nil, + getServiceFn: func(ctx context.Context) (string, error) { + return "service", nil + }, + } + + result, err := with2i(mockSvc, context.Background(), func(s string) (int, error) { + assert.Equal(t, "service", s) + return 42, nil + }) + + require.NoError(t, err) + assert.Equal(t, 42, result) + }) +} + +func TestQuotaAdminService_Delegation(t *testing.T) { + ctx := context.Background() + + apiPurposeToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhcGkifQ." + + newQuotaServiceWithGetServiceError := func(t *testing.T) *quotaAdminService { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: apiPurposeToken, + BaseEndpoint: "http://127.0.0.1:1", + }).Maybe() + + svc := NewQuotaAdminService(cfgMgr, nil, "http://127.0.0.1:1").(*quotaAdminService) + return svc + } + + t.Run("ListAllowances getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, count, err := svc.ListAllowances(ctx) + require.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, 0, count) + }) + + t.Run("CreateAllowance getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, err := svc.CreateAllowance(ctx, 1, "src", "type", 100, 100, 100, zeroTime()) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("UpdateAllowance getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, err := svc.UpdateAllowance(ctx, "grant-1", 1, "src", "type", 100, 100, 100, zeroTime()) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("DeleteAllowance getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + err := svc.DeleteAllowance(ctx, "grant-1") + require.Error(t, err) + }) + + t.Run("GetStats getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, err := svc.GetStats(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("Cleanup getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, err := svc.Cleanup(ctx, 30) + require.Error(t, err) + assert.Equal(t, 0, result) + }) + + t.Run("ListUserConfigs getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, count, err := svc.ListUserConfigs(ctx) + require.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, 0, count) + }) + + t.Run("UpdateUserConfig getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, err := svc.UpdateUserConfig(ctx, 1, nil) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("ResetUserPlan getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + err := svc.ResetUserPlan(ctx, 1) + require.Error(t, err) + }) + + t.Run("Reconcile not authenticated", func(t *testing.T) { + svc := newUnauthQuotaAdminService() + result, count, err := svc.Reconcile(ctx, nil) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Equal(t, "", result) + assert.Equal(t, 0, count) + }) + + t.Run("Reconcile getService error", func(t *testing.T) { + svc := newQuotaServiceWithGetServiceError(t) + + result, count, err := svc.Reconcile(ctx, nil) + require.Error(t, err) + assert.Equal(t, "", result) + assert.Equal(t, 0, count) + }) +} + +func zeroTime() time.Time { + return time.Time{} +} + +func TestProfilingAdminService_Delegation(t *testing.T) { + ctx := context.Background() + + apiPurposeToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhcGkifQ." + + newProfilingServiceWithGetServiceError := func(t *testing.T) *profilingAdminService { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: apiPurposeToken, + BaseEndpoint: "http://127.0.0.1:1", + }).Maybe() + + svc := NewProfilingAdminService(cfgMgr, nil, "http://127.0.0.1:1").(*profilingAdminService) + return svc + } + + t.Run("GetProfileIndex not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetProfileIndex(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetProfileIndex getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetProfileIndex(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("GetBlockProfile not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetBlockProfile(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetBlockProfile getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetBlockProfile(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("SetBlockProfileRate not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + err := svc.SetBlockProfileRate(ctx, 1) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + }) + + t.Run("SetBlockProfileRate getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + err := svc.SetBlockProfileRate(ctx, 1) + require.Error(t, err) + }) + + t.Run("GetCmdline not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetCmdline(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetCmdline getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetCmdline(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("GetGoroutineProfile not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetGoroutineProfile(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetGoroutineProfile getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetGoroutineProfile(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("GetHeapProfile not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetHeapProfile(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetHeapProfile getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetHeapProfile(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("GetMutexProfile not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + result, err := svc.GetMutexProfile(ctx) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + assert.Nil(t, result) + }) + + t.Run("GetMutexProfile getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + result, err := svc.GetMutexProfile(ctx) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("SetMutexProfileFraction not authenticated", func(t *testing.T) { + svc := newUnauthProfilingAdminService() + err := svc.SetMutexProfileFraction(ctx, 1) + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + }) + + t.Run("SetMutexProfileFraction getService error", func(t *testing.T) { + svc := newProfilingServiceWithGetServiceError(t) + err := svc.SetMutexProfileFraction(ctx, 1) + require.Error(t, err) + }) +} diff --git a/pkg/cli/admin_service_test.go b/pkg/cli/admin_service_test.go index f63eb80..aa23c9b 100644 --- a/pkg/cli/admin_service_test.go +++ b/pkg/cli/admin_service_test.go @@ -50,7 +50,7 @@ func TestDefaultQuotaAdminServiceFactory(t *testing.T) { tt.setupMocks(cfgMgr) } - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := defaultQuotaAdminServiceFactory(cfgMgr, output) assert.NotNil(t, service) @@ -97,7 +97,7 @@ func TestDefaultBillingAdminServiceFactory(t *testing.T) { tt.setupMocks(cfgMgr) } - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := defaultBillingAdminServiceFactory(cfgMgr, output) assert.NotNil(t, service) @@ -133,7 +133,7 @@ func TestNewQuotaAdminService(t *testing.T) { AuthToken: tt.authToken, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewQuotaAdminService(cfgMgr, output, tt.apiEndpoint) assert.NotNil(t, service) @@ -176,7 +176,7 @@ func TestNewBillingAdminService(t *testing.T) { AuthToken: tt.authToken, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewBillingAdminService(cfgMgr, output, tt.apiEndpoint) assert.NotNil(t, service) @@ -220,7 +220,7 @@ func TestQuotaAdminService_RequireAuthenticated(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewQuotaAdminService(cfgMgr, output, "https://api.test.com") err := service.RequireAuthenticated() @@ -266,7 +266,7 @@ func TestBillingAdminService_RequireAuthenticated(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewBillingAdminService(cfgMgr, output, "https://api.test.com") err := service.RequireAuthenticated() @@ -291,7 +291,7 @@ func TestQuotaAdminService_HasTokenProvider(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewQuotaAdminService(cfgMgr, output, "https://api.test.com") qs := service.(*quotaAdminService) @@ -306,7 +306,7 @@ func TestBillingAdminService_HasTokenProvider(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewBillingAdminService(cfgMgr, output, "https://api.test.com") bs := service.(*billingAdminService) diff --git a/pkg/cli/admin_service_unauth_test.go b/pkg/cli/admin_service_unauth_test.go new file mode 100644 index 0000000..cfa221e --- /dev/null +++ b/pkg/cli/admin_service_unauth_test.go @@ -0,0 +1,193 @@ +package cli + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/portal-sdk/admin" +) + +func newUnauthQuotaAdminService() *quotaAdminService { + return "aAdminService{ + adminServiceBase: &adminServiceBase{authenticated: false}, + } +} + +func newUnauthBillingAdminService() *billingAdminService { + return &billingAdminService{ + adminServiceBase: &adminServiceBase{authenticated: false}, + } +} + +func newUnauthWebsiteAdminService() *websiteAdminService { + return &websiteAdminService{ + adminServiceBase: &adminServiceBase{authenticated: false}, + } +} + +func newUnauthProfilingAdminService() *profilingAdminService { + return &profilingAdminService{ + adminServiceBase: &adminServiceBase{authenticated: false}, + } +} + +func TestQuotaAdminService_Unauthenticated(t *testing.T) { + svc := newUnauthQuotaAdminService() + ctx := context.Background() + + _, _, err := svc.ListPlans(ctx) + require.Error(t, err) + _, err = svc.CreatePlan(ctx, &admin.QuotaPlan{}) + require.Error(t, err) + _, err = svc.GetPlan(ctx, "p1") + require.Error(t, err) + _, err = svc.UpdatePlan(ctx, "p1", &admin.QuotaPlan{}) + require.Error(t, err) + err = svc.DeletePlan(ctx, "p1") + require.Error(t, err) + err = svc.SetDefaultPlan(ctx, "p1") + require.Error(t, err) + _, _, err = svc.ListAllowances(ctx) + require.Error(t, err) + _, err = svc.CreateAllowance(ctx, 1, "src", "type", 100, 100, 100, time.Now()) + require.Error(t, err) + _, err = svc.UpdateAllowance(ctx, "g1", 1, "src", "type", 100, 100, 100, time.Now()) + require.Error(t, err) + err = svc.DeleteAllowance(ctx, "g1") + require.Error(t, err) + _, err = svc.GetStats(ctx) + require.Error(t, err) + _, _, err = svc.Reconcile(ctx, nil) + require.Error(t, err) + _, err = svc.Cleanup(ctx, 30) + require.Error(t, err) + _, err = svc.UpdateUserConfig(ctx, 1, &admin.UserQuotaConfigUpdate{}) + require.Error(t, err) + _, _, err = svc.ListUserConfigs(ctx) + require.Error(t, err) + err = svc.ResetUserPlan(ctx, 1) + require.Error(t, err) +} + +func TestBillingAdminService_Unauthenticated(t *testing.T) { + svc := newUnauthBillingAdminService() + ctx := context.Background() + + _, _, err := svc.ListCredits(ctx, nil) + require.Error(t, err) + _, err = svc.CreateCredit(ctx, &admin.CreditCreateRequest{}) + require.Error(t, err) + _, err = svc.GetCredit(ctx, "c1") + require.Error(t, err) + err = svc.DeleteCredit(ctx, "c1") + require.Error(t, err) + _, err = svc.RestoreCredit(ctx, "c1") + require.Error(t, err) + _, err = svc.PurgeCredits(ctx, &admin.CreditPurgeRequest{}) + require.Error(t, err) + _, err = svc.GetUserBalance(ctx, "1") + require.Error(t, err) + _, _, err = svc.GetUserDeletedCredits(ctx, "1", nil) + require.Error(t, err) + _, _, err = svc.ListPriceLines(ctx) + require.Error(t, err) + _, err = svc.CreatePriceLine(ctx, &admin.PriceLineCreateRequest{}) + require.Error(t, err) + _, err = svc.GetPriceLine(ctx, "pl1") + require.Error(t, err) + _, err = svc.UpdatePriceLine(ctx, "pl1", &admin.PriceLineUpdateRequest{}) + require.Error(t, err) + err = svc.DeletePriceLine(ctx, "pl1") + require.Error(t, err) + _, _, err = svc.ListPricingPlans(ctx) + require.Error(t, err) + _, err = svc.GetPricingPlan(ctx, "pp1") + require.Error(t, err) + _, err = svc.CreatePricingPlan(ctx, &admin.PricingPlanCreateRequest{}) + require.Error(t, err) + _, err = svc.UpdatePricingPlan(ctx, "pp1", &admin.PricingPlanUpdateRequest{}) + require.Error(t, err) + err = svc.DeletePricingPlan(ctx, "pp1") + require.Error(t, err) + _, _, err = svc.ListPricingPlanPeriods(ctx) + require.Error(t, err) + _, err = svc.CreatePricingPlanPeriod(ctx, &admin.PricingPlanPeriodCreateRequest{}) + require.Error(t, err) + _, err = svc.GetPricingPlanPeriod(ctx, "per1") + require.Error(t, err) + _, err = svc.UpdatePricingPlanPeriod(ctx, "per1", &admin.PricingPlanPeriodUpdateRequest{}) + require.Error(t, err) + err = svc.DeletePricingPlanPeriod(ctx, "per1") + require.Error(t, err) + _, _, err = svc.ListSubscribers(ctx) + require.Error(t, err) + _, err = svc.GetSubscriber(ctx, "sub1") + require.Error(t, err) + _, _, err = svc.ListGatewaySubscribers(ctx, "gw1") + require.Error(t, err) + _, _, err = svc.GetUserSubscribers(ctx, "1") + require.Error(t, err) + _, err = svc.CancelUserSubscription(ctx, "1", &admin.CancelSubscriptionRequest{}) + require.Error(t, err) + _, err = svc.AbortUserSubscriptionCancellation(ctx, "1") + require.Error(t, err) + _, err = svc.ChangeUserPlan(ctx, "1", &admin.ChangePlanRequest{}) + require.Error(t, err) +} + +func TestWebsiteAdminService_Unauthenticated(t *testing.T) { + svc := newUnauthWebsiteAdminService() + ctx := context.Background() + _, err := svc.BlockWebsite(ctx, "example.com") + require.Error(t, err) + _, err = svc.UnblockWebsite(ctx, "example.com") + require.Error(t, err) +} + +func TestProfilingAdminService_Unauthenticated(t *testing.T) { + svc := newUnauthProfilingAdminService() + ctx := context.Background() + _, err := svc.GetProfileIndex(ctx) + require.Error(t, err) + _, err = svc.GetBlockProfile(ctx) + require.Error(t, err) + err = svc.SetBlockProfileRate(ctx, 1) + require.Error(t, err) + _, err = svc.GetCmdline(ctx) + require.Error(t, err) + _, err = svc.GetGoroutineProfile(ctx) + require.Error(t, err) + _, err = svc.GetHeapProfile(ctx) + require.Error(t, err) + _, err = svc.GetMutexProfile(ctx) + require.Error(t, err) + err = svc.SetMutexProfileFraction(ctx, 1) + require.Error(t, err) + _, err = svc.GetCPUProfile(ctx) + require.Error(t, err) + _, err = svc.GetStatus(ctx) + require.Error(t, err) + _, err = svc.GetSymbol(ctx) + require.Error(t, err) + _, err = svc.GetThreadcreate(ctx) + require.Error(t, err) + _, err = svc.GetTrace(ctx) + require.Error(t, err) +} + +func TestAdminServiceBase_RequireAuthenticated(t *testing.T) { + t.Run("not authenticated", func(t *testing.T) { + base := &adminServiceBase{authenticated: false} + err := base.RequireAuthenticated() + require.Error(t, err) + assert.Equal(t, ErrNotAuthenticated, err) + }) + t.Run("authenticated", func(t *testing.T) { + base := &adminServiceBase{authenticated: true} + err := base.RequireAuthenticated() + require.NoError(t, err) + }) +} diff --git a/pkg/cli/admin_test.go b/pkg/cli/admin_test.go index 1fc836b..336eacb 100644 --- a/pkg/cli/admin_test.go +++ b/pkg/cli/admin_test.go @@ -24,7 +24,7 @@ func TestNewAdminCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 4) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "quota") assert.Contains(t, subcommandNames, "billing") assert.Contains(t, subcommandNames, "websites") @@ -48,7 +48,7 @@ func TestNewQuotaCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 6) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "plans") assert.Contains(t, subcommandNames, "allowances") assert.Contains(t, subcommandNames, "user-configs") @@ -74,7 +74,7 @@ func TestNewBillingCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 6) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "overview") assert.Contains(t, subcommandNames, "credits") assert.Contains(t, subcommandNames, "price-lines") @@ -99,7 +99,7 @@ func TestNewQuotaPlansCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 6) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "get") assert.Contains(t, subcommandNames, "create") @@ -116,7 +116,7 @@ func TestNewQuotaPlansCommand(t *testing.T) { require.NotNil(t, createCmd.Flags) assert.Len(t, createCmd.Flags, 8) - flagNames := getFlagNames(createCmd.Flags) + flagNames := getFlagNames(createCmd) assert.Contains(t, flagNames, "name") assert.Contains(t, flagNames, "description") assert.Contains(t, flagNames, "upload-limit") @@ -143,7 +143,7 @@ func TestNewQuotaAllowancesCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 4) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "create") assert.Contains(t, subcommandNames, "update") @@ -158,7 +158,7 @@ func TestNewQuotaAllowancesCommand(t *testing.T) { require.NotNil(t, createCmd.Flags) assert.Len(t, createCmd.Flags, 7) - flagNames := getFlagNames(createCmd.Flags) + flagNames := getFlagNames(createCmd) assert.Contains(t, flagNames, "user-id") assert.Contains(t, flagNames, "source") assert.Contains(t, flagNames, "quota-type") @@ -184,7 +184,7 @@ func TestNewQuotaUserConfigsCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 3) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "update") assert.Contains(t, subcommandNames, "reset") @@ -262,7 +262,7 @@ func TestNewBillingPricingPlansCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 7) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "get") assert.Contains(t, subcommandNames, "create") @@ -288,7 +288,7 @@ func TestNewBillingPricingPlanPeriodsCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 5) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "get") assert.Contains(t, subcommandNames, "create") @@ -312,7 +312,7 @@ func TestNewBillingSubscribersCommand(t *testing.T) { require.NotNil(t, cmd.Commands) assert.Len(t, cmd.Commands, 9) - subcommandNames := getSubcommandNames(cmd.Commands) + subcommandNames := getSubcommandNames(cmd) assert.Contains(t, subcommandNames, "list") assert.Contains(t, subcommandNames, "get") assert.Contains(t, subcommandNames, "list-gateway") @@ -327,22 +327,6 @@ func TestNewBillingSubscribersCommand(t *testing.T) { // Helper functions -func getSubcommandNames(commands []*cli.Command) []string { - names := make([]string, len(commands)) - for i, cmd := range commands { - names[i] = cmd.Name - } - return names -} - -func getFlagNames(flags []cli.Flag) []string { - names := make([]string, len(flags)) - for i, flag := range flags { - names[i] = flag.Names()[0] - } - return names -} - func findSubcommand(commands []*cli.Command, name string) *cli.Command { for _, cmd := range commands { if cmd.Name == name { diff --git a/pkg/cli/admin_token_provider_test.go b/pkg/cli/admin_token_provider_test.go new file mode 100644 index 0000000..3147c76 --- /dev/null +++ b/pkg/cli/admin_token_provider_test.go @@ -0,0 +1,79 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +func TestNewAdminTokenProvider(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "test-token", + }).Maybe() + + provider := NewAdminTokenProvider(cfgMgr) + require.NotNil(t, provider) + assert.Equal(t, "test-token", provider.cfgMgr.Config().AuthToken) +} + +func TestAdminTokenProvider_Invalidate(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "token"}).Maybe() + + provider := &AdminTokenProvider{ + cfgMgr: cfgMgr, + baseToken: "old-token", + loginToken: "old-login", + } + + provider.Invalidate() + assert.Equal(t, "", provider.loginToken) + assert.Equal(t, "", provider.baseToken) +} + +func TestAdminTokenProvider_GetLoginToken_NonAPIKeyJWT(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "token"}).Maybe() + + provider := &AdminTokenProvider{ + cfgMgr: cfgMgr, + apiEndpoint: "http://localhost:8080", + } + + token, err := provider.GetLoginToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "token", token) +} + +func TestAdminTokenProvider_GetLoginToken_Cached(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "same-token"}).Maybe() + + provider := &AdminTokenProvider{ + cfgMgr: cfgMgr, + baseToken: "same-token", + loginToken: "cached-login", + } + + token, err := provider.GetLoginToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "cached-login", token) +} + +func TestAdminTokenProvider_GetLoginToken_EmptyToken(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + + provider := &AdminTokenProvider{ + cfgMgr: cfgMgr, + } + + token, err := provider.GetLoginToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "", token) +} diff --git a/pkg/cli/admin_websites.go b/pkg/cli/admin_websites.go index ac23a65..451b67d 100644 --- a/pkg/cli/admin_websites.go +++ b/pkg/cli/admin_websites.go @@ -39,16 +39,12 @@ Examples: if err != nil { return err } - return adminWebsitesBlockAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultWebsiteAdminServiceFactory) + return adminWebsitesBlockAction(ctx, cmd, output, cfgMgr, defaultWebsiteAdminServiceFactory) }, } } -type adminWebsitesBlockCmdGetter interface { - Args() cli.Args -} - -func adminWebsitesBlockAction(ctx context.Context, cmd adminWebsitesBlockCmdGetter, output Output, cfgMgr config.Manager, serviceFactory WebsiteAdminServiceFactory) error { +func adminWebsitesBlockAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory WebsiteAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("website ID is required") } @@ -88,16 +84,12 @@ Examples: if err != nil { return err } - return adminWebsitesUnblockAction(ctx, newCLICommandWrapper(cmd), output, cfgMgr, defaultWebsiteAdminServiceFactory) + return adminWebsitesUnblockAction(ctx, cmd, output, cfgMgr, defaultWebsiteAdminServiceFactory) }, } } -type adminWebsitesUnblockCmdGetter interface { - Args() cli.Args -} - -func adminWebsitesUnblockAction(ctx context.Context, cmd adminWebsitesUnblockCmdGetter, output Output, cfgMgr config.Manager, serviceFactory WebsiteAdminServiceFactory) error { +func adminWebsitesUnblockAction(ctx context.Context, cmd argsGetter, output Output, cfgMgr config.Manager, serviceFactory WebsiteAdminServiceFactory) error { if cmd.Args().Len() < 1 { return fmt.Errorf("website ID is required") } diff --git a/pkg/cli/admin_websites_test.go b/pkg/cli/admin_websites_test.go index 24fb158..2d44fc8 100644 --- a/pkg/cli/admin_websites_test.go +++ b/pkg/cli/admin_websites_test.go @@ -62,7 +62,7 @@ func TestAdminWebsitesBlock(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockWebsiteAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) @@ -149,7 +149,7 @@ func TestAdminWebsitesUnblock(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockWebsiteAdminService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) diff --git a/pkg/cli/auth.go b/pkg/cli/auth.go index 3f5f5f1..ddf7691 100644 --- a/pkg/cli/auth.go +++ b/pkg/cli/auth.go @@ -42,12 +42,6 @@ func runPrompt(fn func() (string, error)) (string, error) { return result, nil } -// commandGetter defines the interface for getting command flags and arguments. -type commandGetter interface { - String(name string) string - Bool(name string) bool -} - // AuthPrompter defines the interface for interactive user input. type AuthPrompter interface { // PromptEmail prompts for and validates an email address. @@ -258,13 +252,13 @@ type ConfigManagerFactory func() (config.Manager, error) type AuthServiceFactory func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService // authLogin handles authentication with interactive, semi-interactive, and non-interactive modes. -func authLogin(ctx context.Context, cmd commandGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { +func authLogin(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { return authLoginWithFactories(ctx, cmd, output, cfgMgrFactory, authServiceFactory, nil) } // authLoginWithFactories is the testable implementation of authLogin with prompter injection. // The factories and prompter allow dependency injection for testing. -func authLoginWithFactories(ctx context.Context, cmd commandGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, prompter AuthPrompter) error { +func authLoginWithFactories(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, prompter AuthPrompter) error { email := cmd.String("email") password := cmd.String("password") otpCode := cmd.String("otp-code") @@ -391,13 +385,13 @@ Examples: pinner auth status --verbose`, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return authStatus(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory) + return authStatus(ctx, output, defaultConfigManagerFactory, defaultAuthServiceFactory) }, } } // authStatus checks if the user is authenticated. -func authStatus(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { +func authStatus(ctx context.Context, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to initialize config manager: %w", err) diff --git a/pkg/cli/auth_service_test.go b/pkg/cli/auth_service_test.go index 1f962c6..65a8848 100644 --- a/pkg/cli/auth_service_test.go +++ b/pkg/cli/auth_service_test.go @@ -6,7 +6,9 @@ import ( "errors" "testing" + "github.com/golang-jwt/jwt/v5" mock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" @@ -70,7 +72,7 @@ func TestAuthService_LoginCheck(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(acc) @@ -205,7 +207,7 @@ func TestAuthService_CompleteLogin(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) authAcc := portalsdkmocks.NewMockAccountAPI(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, acc, authAcc) @@ -271,7 +273,7 @@ func TestAuthService_SaveToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr) @@ -294,7 +296,7 @@ func TestAuthService_SaveToken(t *testing.T) { func TestAuthService_GetAPIEndpoint(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() authService := NewAuthService(cfgMgr, output, "https://api.test.com") require.Equal(t, "https://api.test.com", authService.GetAPIEndpoint()) @@ -341,7 +343,7 @@ func TestAuthService_CompleteLogin_JSONOutput(t *testing.T) { func TestNewAuthService(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() authService := NewAuthService(cfgMgr, output, "https://api.test.com") @@ -431,7 +433,7 @@ func TestInteractiveLogin(t *testing.T) { tt.setupMocks(prompter, authService) } - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() err := interactiveLogin(context.Background(), authService, output, tt.keyName, tt.noCreateKey, tt.force, prompter) if tt.wantErr { @@ -538,16 +540,15 @@ func TestAuthLogin(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) authService := NewMockAuthService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Create a mock cli.Command - cmd := &mockCommand{ - email: tt.email, - password: tt.password, - keyName: tt.keyName, - noCreateKey: tt.noCreateKey, - force: tt.force, - } + cmd := newMockCommand(). + withString(FlagEmail, tt.email). + withString(FlagPassword, tt.password). + withString(FlagKeyName, tt.keyName). + withBool(FlagNoCreateKey, tt.noCreateKey). + withBool(FlagForce, tt.force) // Setup config manager factory var cfgMgrFactory ConfigManagerFactory @@ -600,41 +601,7 @@ func TestAuthLogin(t *testing.T) { } } -// mockCommand is a mock implementation of commandGetter for testing. -type mockCommand struct { - email string - password string - otpCode string - keyName string - noCreateKey bool - force bool -} -func (m *mockCommand) String(name string) string { - switch name { - case FlagEmail: - return m.email - case FlagPassword: - return m.password - case FlagOTPCode: - return m.otpCode - case FlagKeyName: - return m.keyName - default: - return "" - } -} - -func (m *mockCommand) Bool(name string) bool { - switch name { - case FlagNoCreateKey: - return m.noCreateKey - case FlagForce: - return m.force - default: - return false - } -} func TestAuthService_LoginWithOTP(t *testing.T) { tests := []struct { @@ -716,7 +683,7 @@ func TestAuthService_LoginWithOTP(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) authAcc := portalsdkmocks.NewMockAccountAPI(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, acc, authAcc) @@ -789,7 +756,7 @@ func TestSaveAuthToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) authService := NewMockAuthService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Setup config manager factory var cfgMgrFactory ConfigManagerFactory @@ -825,3 +792,199 @@ func TestSaveAuthToken(t *testing.T) { }) } } + +func TestAuthServiceDefault_Register(t *testing.T) { + tests := []struct { + name string + email string + firstName string + lastName string + password string + setupMocks func(*portalsdkmocks.MockAccountAPI) + wantErr bool + errContains string + }{ + { + name: "successful registration", + email: "test@example.com", + firstName: "John", + lastName: "Doe", + password: "password123", + setupMocks: func(acc *portalsdkmocks.MockAccountAPI) { + acc.EXPECT().Register(context.Background(), "test@example.com", "John", "Doe", "password123"). + Return(nil) + }, + wantErr: false, + }, + { + name: "registration fails with service error", + email: "test@example.com", + firstName: "John", + lastName: "Doe", + password: "password123", + setupMocks: func(acc *portalsdkmocks.MockAccountAPI) { + acc.EXPECT().Register(context.Background(), "test@example.com", "John", "Doe", "password123"). + Return(portalsdk.ErrUnauthorized) + }, + wantErr: true, + errContains: "registration failed", + }, + { + name: "registration fails with network error", + email: "test@example.com", + firstName: "Jane", + lastName: "Smith", + password: "secret", + setupMocks: func(acc *portalsdkmocks.MockAccountAPI) { + acc.EXPECT().Register(context.Background(), "test@example.com", "Jane", "Smith", "secret"). + Return(errors.New("connection refused")) + }, + wantErr: true, + errContains: "registration failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + acc := portalsdkmocks.NewMockAccountAPI(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(acc) + } + + authService := NewAuthService(cfgMgr, output, "https://api.test.com", + WithAuthAccountClient(acc), + ) + + err := authService.Register(context.Background(), tt.email, tt.firstName, tt.lastName, tt.password) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAuthServiceDefault_GetLoginToken(t *testing.T) { + // Helper to create a signed JWT with a given audience (purpose) + makeJWT := func(audience string) string { + claims := jwt.RegisteredClaims{ + Audience: []string{audience}, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString([]byte("test-secret")) + require.NoError(t, err) + return signed + } + + tests := []struct { + name string + setupMocks func(*configmocks.MockManager, *portalsdkmocks.MockAccountAPI) + wantErr bool + errContains string + wantToken string + }{ + { + name: "login JWT returned directly", + setupMocks: func(cfgMgr *configmocks.MockManager, acc *portalsdkmocks.MockAccountAPI) { + loginJWT := makeJWT("login") + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: loginJWT}) + }, + wantErr: false, + wantToken: makeJWT("login"), + }, + { + name: "API key JWT exchanged for login token", + setupMocks: func(cfgMgr *configmocks.MockManager, acc *portalsdkmocks.MockAccountAPI) { + apiKeyJWT := makeJWT("api") + loginJWT := makeJWT("login") + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: apiKeyJWT}) + acc.EXPECT().LoginWithAPIKey(context.Background(), apiKeyJWT). + Return(loginJWT, nil) + }, + wantErr: false, + wantToken: makeJWT("login"), + }, + { + name: "empty token returns not authenticated", + setupMocks: func(cfgMgr *configmocks.MockManager, acc *portalsdkmocks.MockAccountAPI) { + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}) + }, + wantErr: true, + errContains: "not authenticated", + }, + { + name: "invalid JWT treated as login token", + setupMocks: func(cfgMgr *configmocks.MockManager, acc *portalsdkmocks.MockAccountAPI) { + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "not-a-valid-jwt"}) + }, + wantErr: false, + wantToken: "not-a-valid-jwt", + }, + { + name: "API key exchange fails", + setupMocks: func(cfgMgr *configmocks.MockManager, acc *portalsdkmocks.MockAccountAPI) { + apiKeyJWT := makeJWT("api") + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: apiKeyJWT}) + acc.EXPECT().LoginWithAPIKey(context.Background(), apiKeyJWT). + Return("", errors.New("API key expired")) + }, + wantErr: true, + errContains: "failed to authenticate with API key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + acc := portalsdkmocks.NewMockAccountAPI(t) + output := newTestOutput() + + if tt.setupMocks != nil { + tt.setupMocks(cfgMgr, acc) + } + + authService := NewAuthService(cfgMgr, output, "https://api.test.com", + WithAuthAccountClient(acc), + ) + + token, err := authService.GetLoginToken(context.Background()) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + require.Equal(t, tt.wantToken, token) + } + }) + } +} + +func TestGetJWTPurpose(t *testing.T) { + t.Run("invalid token returns error", func(t *testing.T) { + _, err := GetJWTPurpose("not-a-jwt") + require.Error(t, err) + }) + + t.Run("token with no audience returns empty", func(t *testing.T) { + claims := jwt.RegisteredClaims{ + Subject: "test", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString([]byte("secret")) + require.NoError(t, err) + purpose, err := GetJWTPurpose(signed) + require.NoError(t, err) + assert.Equal(t, "", purpose) + }) +} diff --git a/pkg/cli/auth_status_test.go b/pkg/cli/auth_status_test.go index ba2f4e6..ecba3be 100644 --- a/pkg/cli/auth_status_test.go +++ b/pkg/cli/auth_status_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" + "github.com/manifoldco/promptui" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" portalsdk "go.lumeweb.com/portal-sdk" @@ -60,7 +60,7 @@ func TestAuthStatus(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) authService := NewMockAuthService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() var cfgMgrFactory ConfigManagerFactory if tt.name == "config manager factory fails" { @@ -81,9 +81,7 @@ func TestAuthStatus(t *testing.T) { tt.setupMocks(cfgMgr, authService) } - cmd := &cli.Command{} - - err := authStatus(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + err := authStatus(context.Background(), output, cfgMgrFactory, authServiceFactory) if tt.wantErr { require.Error(t, err) @@ -135,7 +133,7 @@ func TestAuthService_Status(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) acc := portalsdkmocks.NewMockAccountAPI(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() // Mock Config() to return a config with a login JWT and portal URL cfg := config.NewConfig() @@ -297,9 +295,7 @@ func TestAuthStatusCommand(t *testing.T) { tt.setupMocks(cfgMgr, authService) } - cmd := &cli.Command{} - - err := authStatus(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory) + err := authStatus(context.Background(), output, cfgMgrFactory, authServiceFactory) if tt.wantErr { require.Error(t, err) @@ -312,3 +308,178 @@ func TestAuthStatusCommand(t *testing.T) { }) } } + +func TestHandleInterrupt(t *testing.T) { + t.Run("returns cancelled error for interrupt", func(t *testing.T) { + err := handleInterrupt(promptui.ErrInterrupt) + require.Error(t, err) + require.Contains(t, err.Error(), "cancelled") + }) + + t.Run("returns original error for non-interrupt", func(t *testing.T) { + origErr := errors.New("some error") + err := handleInterrupt(origErr) + require.Error(t, err) + require.Equal(t, origErr, err) + }) + + t.Run("returns nil for nil", func(t *testing.T) { + err := handleInterrupt(nil) + require.NoError(t, err) + }) +} + +func TestAuthLogin_MockCommand_WithEmailPassword(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().LoginCheck(context.Background(), "user@example.com", "secret"). + Return(&portalsdk.LoginResult{Token: "jwt-token", OTPRequired: false}, nil) + authService.EXPECT().CompleteLogin(context.Background(), "jwt-token", "cli-generated", false).Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand(). + withString("email", "user@example.com"). + withString("password", "secret"). + withString("key-name", "cli-generated"). + withBool("no-create-key", false). + withBool("force", false) + + err := authLoginWithFactories(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory, nil) + require.NoError(t, err) +} + +func TestAuthLogin_MockCommand_WithOTPFlow(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().LoginCheck(context.Background(), "user@example.com", "secret"). + Return(&portalsdk.LoginResult{IntermediateJWT: "intermediate-jwt", OTPRequired: true}, nil) + authService.EXPECT().LoginWithOTP(context.Background(), "intermediate-jwt", "123456", "cli-generated", false).Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand(). + withString("email", "user@example.com"). + withString("password", "secret"). + withString("otp-code", "123456"). + withString("key-name", "cli-generated"). + withBool("no-create-key", false). + withBool("force", false) + + err := authLoginWithFactories(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory, nil) + require.NoError(t, err) +} + +func TestAuthLogin_MockCommand_LoginCheckError(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().LoginCheck(context.Background(), "user@example.com", "wrong"). + Return(nil, errors.New("invalid credentials")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand(). + withString("email", "user@example.com"). + withString("password", "wrong"). + withString("key-name", "cli-generated"). + withBool("no-create-key", false). + withBool("force", false) + + err := authLoginWithFactories(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory, nil) + require.Error(t, err) +} + +func TestAuthLogin_MockCommand_ConfigError(t *testing.T) { + output := newTestOutput() + + cmd := newMockCommand(). + withString("email", "user@example.com"). + withString("password", "secret") + + err := authLoginWithFactories(context.Background(), cmd, output, failingConfigMgrFactory(), + func(cm config.Manager, out Output, apiEndpoint string) AuthService { return nil }, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestAuthLogin_MockCommand_NoCreateKey(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().LoginCheck(context.Background(), "user@example.com", "secret"). + Return(&portalsdk.LoginResult{Token: "jwt-token", OTPRequired: false}, nil) + authService.EXPECT().CompleteLogin(context.Background(), "jwt-token", "cli-generated", true).Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + cmd := newMockCommand(). + withString("email", "user@example.com"). + withString("password", "secret"). + withString("key-name", "cli-generated"). + withBool("no-create-key", true). + withBool("force", false) + + err := authLoginWithFactories(context.Background(), cmd, output, cfgMgrFactory, authServiceFactory, nil) + require.NoError(t, err) +} + +func TestSaveAuthTokenWithFactories_Success(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().SaveToken("my-jwt-token").Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + err := saveAuthTokenWithFactories(output, "my-jwt-token", cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestSaveAuthTokenWithFactories_ConfigError(t *testing.T) { + output := newTestOutput() + + err := saveAuthTokenWithFactories(output, "my-jwt-token", failingConfigMgrFactory(), + func(cm config.Manager, out Output, apiEndpoint string) AuthService { return nil }) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestSaveAuthTokenWithFactories_SaveError(t *testing.T) { + authService := NewMockAuthService(t) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() + + authService.EXPECT().SaveToken("bad-token").Return(errors.New("invalid token format")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + authServiceFactory := func(cm config.Manager, out Output, apiEndpoint string) AuthService { + return authService + } + + err := saveAuthTokenWithFactories(output, "bad-token", cfgMgrFactory, authServiceFactory) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid token format") +} diff --git a/pkg/cli/bench.go b/pkg/cli/bench.go index 141a805..7a97024 100644 --- a/pkg/cli/bench.go +++ b/pkg/cli/bench.go @@ -7,6 +7,7 @@ import ( "github.com/docker/go-units" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) // Bench flag constants @@ -71,7 +72,13 @@ min/max/avg/median statistics across all iterations.`, }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return bench(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return bench(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure) }, } } @@ -129,23 +136,18 @@ func BenchPollIntervalFlag() *cli.DurationFlag { } } -// benchCommandGetter defines the interface for getting bench command flags. -type benchCommandGetter interface { - String(name string) string - Int(name string) int +type benchAuthServiceFactoryFunc func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService + +var benchAuthServiceFactory benchAuthServiceFactoryFunc = func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return NewAuthService(cfgMgr, output, apiEndpoint) +} + +func bench(ctx context.Context, cmd interface { + argsFlagGetter Int64(name string) int64 - Bool(name string) bool Uint64(name string) uint64 Duration(name string) time.Duration - Args() cli.Args -} - -func bench(ctx context.Context, cmd benchCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +}, output Output, cfgMgr config.Manager, authToken string, secure bool) error { // Parse size flag sizeStr := cmd.String(FlagBenchSize) sizeBytes, err := units.RAMInBytes(sizeStr) @@ -161,19 +163,13 @@ func bench(ctx context.Context, cmd benchCommandGetter, output Output, cfgMgrFac // Create services var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - secure := GetSecureSetting(c.Command, cfgMgr) - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) - } else { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure()) + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) } - authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) + authService := benchAuthServiceFactory(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []UploadServiceOption svcOpts = append(svcOpts, WithMemoryLimit(memoryLimit), WithUploadAuthService(authService)) @@ -271,7 +267,7 @@ func bench(ctx context.Context, cmd benchCommandGetter, output Output, cfgMgrFac return nil } - benchService := NewBenchService(cfgMgr, output, uploadService, pinningService, accountClient) + benchService := defaultBenchServiceFactory(cfgMgr, output, uploadService, pinningService, accountClient) result, err := benchService.Run(ctx, opts) if err != nil { diff --git a/pkg/cli/bench_service.go b/pkg/cli/bench_service.go index 157a60a..bf8440a 100644 --- a/pkg/cli/bench_service.go +++ b/pkg/cli/bench_service.go @@ -126,6 +126,8 @@ type BenchService interface { // BenchServiceFactory creates a BenchService with dependencies. type BenchServiceFactory func(cfgMgr config.Manager, output Output, uploadService UploadService, pinningService PinningService, accountClient portalsdk.AccountAPI) BenchService +var defaultBenchServiceFactory BenchServiceFactory = NewBenchService + // BenchServiceDefault provides benchmark operations. type BenchServiceDefault struct { configMgr config.Manager diff --git a/pkg/cli/bench_test.go b/pkg/cli/bench_test.go index d622891..e6a758a 100644 --- a/pkg/cli/bench_test.go +++ b/pkg/cli/bench_test.go @@ -2,13 +2,15 @@ package cli import ( "context" + "errors" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" + portalsdk "go.lumeweb.com/portal-sdk" portalsdkmocks "go.lumeweb.com/portal-sdk/mocks" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" @@ -60,6 +62,240 @@ func TestBenchFilePath(t *testing.T) { } } +type mockBenchServiceForCLI struct { + runFunc func(ctx context.Context, opts BenchOptions) (*BenchResult, error) + requireAuthenticatedFn func() error +} + +func (m *mockBenchServiceForCLI) Run(ctx context.Context, opts BenchOptions) (*BenchResult, error) { + if m.runFunc != nil { + return m.runFunc(ctx, opts) + } + return &BenchResult{}, nil +} + +func (m *mockBenchServiceForCLI) RequireAuthenticated() error { + if m.requireAuthenticatedFn != nil { + return m.requireAuthenticatedFn() + } + return nil +} + +func setupBenchHandlerTest(t *testing.T) (*mockBenchServiceForCLI, *configmocks.MockManager) { + t.Helper() + + mockSvc := &mockBenchServiceForCLI{} + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Maybe().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + MemoryLimit: 100, + }).Maybe() + + origBenchFactory := defaultBenchServiceFactory + origAuthFactory := benchAuthServiceFactory + t.Cleanup(func() { + defaultBenchServiceFactory = origBenchFactory + benchAuthServiceFactory = origAuthFactory + }) + + mockAuthSvc := NewMockAuthService(t) + mockAuthSvc.EXPECT().GetAuthenticatedClient(mock.Anything).Maybe().Return(portalsdkmocks.NewMockAccountAPI(t), nil) + + benchAuthServiceFactory = func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return mockAuthSvc + } + + defaultBenchServiceFactory = func(cfgMgr config.Manager, output Output, uploadService UploadService, pinningService PinningService, accountClient portalsdk.AccountAPI) BenchService { + return mockSvc + } + + return mockSvc, cfgMgr +} + +func defaultBenchCmd() *mockCommand { + return newMockCommand(). + withString(FlagBenchSize, "1MB"). + withInt(FlagBenchFiles, 1). + withInt(FlagBenchDepth, 0). + withInt(FlagBenchIterations, 1). + withInt(FlagParallel, 1). + withBool(FlagBenchNoCleanup, false). + withDuration(FlagBenchPollInterval, 500*time.Millisecond). + withUint64(FlagMemoryLimit, 100). + withBool(FlagDryRun, false) +} + +func TestBenchHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupBenchHandlerTest(t) + mockSvc.runFunc = func(ctx context.Context, opts BenchOptions) (*BenchResult, error) { + assert.Equal(t, int64(1048576), opts.SizeBytes) + assert.Equal(t, 1, opts.Files) + assert.Equal(t, 1, opts.Iterations) + return &BenchResult{ + Input: BenchInput{Type: "random", Size: 1048576, Files: 1}, + Iterations: []BenchIteration{{Number: 1, CID: "QmTest", Size: 1048576, Total: 100 * time.Millisecond}}, + Summary: BenchSummary{TotalDuration: 100 * time.Millisecond, UploadDuration: 100 * time.Millisecond}, + }, nil + } + + output := newTestOutput() + cmd := defaultBenchCmd() + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.NoError(t, err) +} + +func TestBenchHandler_InvalidSize(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withString(FlagBenchSize, "notasize") + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid size") +} + +func TestBenchHandler_ZeroSizeNoPath(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withString(FlagBenchSize, "0B") + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "size must be positive or provide a path") +} + +func TestBenchHandler_ZeroFiles(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withInt(FlagBenchFiles, 0) + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "files must be at least 1") +} + +func TestBenchHandler_ZeroIterations(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withInt(FlagBenchIterations, 0) + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "iterations must be at least 1") +} + +func TestBenchHandler_NegativeDepth(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withInt(FlagBenchDepth, -1) + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "depth must be non-negative") +} + +func TestBenchHandler_ZeroPollInterval(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withDuration(FlagBenchPollInterval, 0) + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll-interval must be positive") +} + +func TestBenchHandler_DryRun(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withBool(FlagDryRun, true) + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.NoError(t, err) +} + +func TestBenchHandler_DryRunWithPath(t *testing.T) { + _, cfgMgr := setupBenchHandlerTest(t) + + output := newTestOutput() + cmd := defaultBenchCmd().withBool(FlagDryRun, true).withArgs("./testdata") + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.NoError(t, err) +} + +func TestBenchHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupBenchHandlerTest(t) + mockSvc.runFunc = func(ctx context.Context, opts BenchOptions) (*BenchResult, error) { + return nil, errors.New("benchmark failed") + } + + output := newTestOutput() + cmd := defaultBenchCmd() + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "benchmark failed") +} + +func TestBenchHandler_AuthError(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Maybe().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + MemoryLimit: 100, + }).Maybe() + + origAuthFactory := benchAuthServiceFactory + t.Cleanup(func() { benchAuthServiceFactory = origAuthFactory }) + + mockAuthSvc := NewMockAuthService(t) + mockAuthSvc.EXPECT().GetAuthenticatedClient(mock.Anything).Return(nil, errors.New("auth failed")) + + benchAuthServiceFactory = func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return mockAuthSvc + } + + output := newTestOutput() + cmd := defaultBenchCmd() + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to authenticate for operation polling") +} + +func TestBenchHandler_JSONOutput(t *testing.T) { + mockSvc, cfgMgr := setupBenchHandlerTest(t) + mockSvc.runFunc = func(ctx context.Context, opts BenchOptions) (*BenchResult, error) { + return &BenchResult{ + Input: BenchInput{Type: "random", Size: 1048576, Files: 1}, + Iterations: []BenchIteration{{Number: 1, CID: "QmTest", Size: 1048576, Total: 100 * time.Millisecond}}, + Summary: BenchSummary{TotalDuration: 100 * time.Millisecond, UploadDuration: 100 * time.Millisecond}, + }, nil + } + + output := NewOutputFormatter(true, false, false, false) + cmd := defaultBenchCmd() + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.NoError(t, err) +} + +func TestBenchHandler_WithPath(t *testing.T) { + mockSvc, cfgMgr := setupBenchHandlerTest(t) + mockSvc.runFunc = func(ctx context.Context, opts BenchOptions) (*BenchResult, error) { + assert.Equal(t, "./testdata", opts.Path) + return &BenchResult{ + Input: BenchInput{Type: "path", Path: "./testdata"}, + Iterations: []BenchIteration{{Number: 1, CID: "QmTest", Size: 1024, Total: 50 * time.Millisecond}}, + Summary: BenchSummary{TotalDuration: 50 * time.Millisecond, UploadDuration: 50 * time.Millisecond}, + }, nil + } + + output := newTestOutput() + cmd := defaultBenchCmd().withString(FlagBenchSize, "0B").withArgs("./testdata") + err := bench(context.Background(), cmd, output, cfgMgr, "test-token", true) + require.NoError(t, err) +} + func TestGenerateRandomData(t *testing.T) { t.Run("single file", func(t *testing.T) { opts := BenchOptions{ @@ -94,18 +330,15 @@ func TestGenerateRandomData(t *testing.T) { }) t.Run("disk fallback for large data", func(t *testing.T) { - // Use a size larger than available memory to force disk path - avail := availableMemory() - if avail <= 0 { - t.Skip("cannot detect available memory") - } + // Test the disk fallback path directly with a small size to avoid + // filling the disk on systems with large amounts of RAM. opts := BenchOptions{ - SizeBytes: avail + 1, // just over available memory + SizeBytes: 1024, // 1KB — small size, but we're testing the disk path directly Files: 1, Depth: 0, } - fsys, name, cleanup, err := generateRandomData(opts) + fsys, name, cleanup, err := generateRandomDataDisk(opts) require.NoError(t, err) assert.Equal(t, "bench", name) assert.NotNil(t, fsys) @@ -242,7 +475,7 @@ func TestBenchService_RequireAuthenticated(t *testing.T) { AuthToken: "", }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningService := NewPinningService(cfgMgr, output, "https://api.test.com") uploadService := NewUploadService(cfgMgr, output) @@ -259,7 +492,7 @@ func TestBenchService_Run_NotAuthenticated(t *testing.T) { AuthToken: "", }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningService := NewPinningService(cfgMgr, output, "https://api.test.com") uploadService := NewUploadService(cfgMgr, output) @@ -273,95 +506,7 @@ func TestBenchService_Run_NotAuthenticated(t *testing.T) { }) } -// mockBenchCommand is a test mock for benchCommandGetter. -type mockBenchCommand struct { - size string - files int - depth int - iterations int - parallel int - noCleanup bool - pollInterval time.Duration - memoryLimit uint64 - dryRun bool - path string - chunkSize int64 - chunkerStrategy string - maxLinks int -} - -func (m *mockBenchCommand) String(name string) string { - switch name { - case FlagBenchSize: - return m.size - case FlagChunker: - return m.chunkerStrategy - default: - return "" - } -} - -func (m *mockBenchCommand) Int(name string) int { - switch name { - case FlagBenchFiles: - return m.files - case FlagBenchDepth: - return m.depth - case FlagBenchIterations: - return m.iterations - case FlagParallel: - return m.parallel - case FlagMaxLinks: - return m.maxLinks - default: - return 0 - } -} -func (m *mockBenchCommand) Int64(name string) int64 { - switch name { - case FlagChunkSize: - return m.chunkSize - default: - return 0 - } -} - -func (m *mockBenchCommand) Bool(name string) bool { - switch name { - case FlagBenchNoCleanup: - return m.noCleanup - case FlagDryRun: - return m.dryRun - default: - return false - } -} - -func (m *mockBenchCommand) Uint64(name string) uint64 { - switch name { - case FlagMemoryLimit: - return m.memoryLimit - default: - return 0 - } -} - -func (m *mockBenchCommand) Duration(name string) time.Duration { - switch name { - case FlagBenchPollInterval: - return m.pollInterval - default: - return 0 - } -} - -func (m *mockBenchCommand) Args() cli.Args { - return &mockArgs{args: []string{m.path}} -} - -// Ensure mock types satisfy interfaces at compile time. -var _ benchCommandGetter = (*mockBenchCommand)(nil) func TestFormatBenchResult(t *testing.T) { t.Run("single iteration with random data", func(t *testing.T) { @@ -399,7 +544,7 @@ func TestFormatBenchResult(t *testing.T) { } // Just verify it doesn't panic and produces output - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() formatBenchResult(output, result) }) @@ -438,7 +583,7 @@ func TestFormatBenchResult(t *testing.T) { }, } - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() formatBenchResult(output, result) }) } @@ -487,3 +632,241 @@ func TestNewBenchError(t *testing.T) { assert.Equal(t, "internal server error", benchErr.Detail["body"]) }) } + +func TestIsUnrecoverableError(t *testing.T) { + t.Run("ErrUnauthorized is unrecoverable", func(t *testing.T) { + assert.True(t, isUnrecoverableError(portalsdk.ErrUnauthorized)) + }) + + t.Run("ErrForbidden is unrecoverable", func(t *testing.T) { + assert.True(t, isUnrecoverableError(portalsdk.ErrForbidden)) + }) + + t.Run("wrapped ErrUnauthorized is unrecoverable", func(t *testing.T) { + err := fmt.Errorf("request failed: %w", portalsdk.ErrUnauthorized) + assert.True(t, isUnrecoverableError(err)) + }) + + t.Run("wrapped ErrForbidden is unrecoverable", func(t *testing.T) { + err := fmt.Errorf("access denied: %w", portalsdk.ErrForbidden) + assert.True(t, isUnrecoverableError(err)) + }) + + t.Run("generic error is recoverable", func(t *testing.T) { + assert.False(t, isUnrecoverableError(errors.New("something went wrong"))) + }) + + t.Run("HTTPError 500 is recoverable", func(t *testing.T) { + assert.False(t, isUnrecoverableError(NewHTTPError(500, "internal server error"))) + }) + + t.Run("HTTPError 401 is recoverable (not a sentinel error)", func(t *testing.T) { + assert.False(t, isUnrecoverableError(NewHTTPError(401, "unauthorized"))) + }) + + t.Run("nil error is recoverable", func(t *testing.T) { + assert.False(t, isUnrecoverableError(nil)) + }) +} + +func TestBenchError_String(t *testing.T) { + err := &BenchError{Message: "test error message"} + assert.Equal(t, "test error message", err.String()) +} + +func TestBenchServiceDefault_RunIteration(t *testing.T) { + ctx := context.Background() + + newCompletedOp := func(id int) *portalsdk.Operation { + op := &portalsdk.Operation{} + op.Id = id + op.Status = string(portalsdk.OperationStatusCompleted) + op.Operation = "Pin" + op.Protocol = "IPFS" + op.ProgressPercent = 100 + op.StartedAt = time.Now() + op.UpdatedAt = time.Now() + return op + } + + t.Run("successful upload and pin", func(t *testing.T) { + uploadSvc := NewMockUploadService(t) + pinningSvc := NewMockPinningService(t) + accountAPI := portalsdkmocks.NewMockAccountAPI(t) + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + uploadSvc.EXPECT().Upload(mock.Anything, mock.Anything, mock.Anything, false). + Return(&UploadResult{CID: "QmTestCID", Size: 1024}, nil) + + accountAPI.EXPECT().ListOperations(mock.Anything, mock.Anything). + Return([]*portalsdk.Operation{newCompletedOp(1)}, nil) + accountAPI.EXPECT().GetOperation(mock.Anything, int64(1)). + Return(newCompletedOp(1), nil) + + svc := &BenchServiceDefault{ + configMgr: cfgMgr, + output: output, + uploadService: uploadSvc, + pinningService: pinningSvc, + accountClient: accountAPI, + } + + opts := BenchOptions{ + SizeBytes: 1024, + Files: 1, + Iterations: 1, + PollInterval: 100 * time.Millisecond, + } + + iter := svc.runIteration(ctx, opts, 0) + assert.Equal(t, 1, iter.Number) + assert.Equal(t, "QmTestCID", iter.CID) + assert.Equal(t, int64(1024), iter.Size) + assert.Nil(t, iter.Error) + assert.Greater(t, iter.Total, time.Duration(0)) + + stageNames := make([]string, len(iter.Stages)) + for i, s := range iter.Stages { + stageNames[i] = s.Name + } + assert.Contains(t, stageNames, "generate") + assert.Contains(t, stageNames, "upload") + }) + + t.Run("upload failure returns error", func(t *testing.T) { + uploadSvc := NewMockUploadService(t) + pinningSvc := NewMockPinningService(t) + accountAPI := portalsdkmocks.NewMockAccountAPI(t) + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + uploadSvc.EXPECT().Upload(mock.Anything, mock.Anything, mock.Anything, false). + Return(nil, errors.New("network error")) + + svc := &BenchServiceDefault{ + configMgr: cfgMgr, + output: output, + uploadService: uploadSvc, + pinningService: pinningSvc, + accountClient: accountAPI, + } + + opts := BenchOptions{ + SizeBytes: 1024, + Files: 1, + Iterations: 1, + PollInterval: 100 * time.Millisecond, + } + + iter := svc.runIteration(ctx, opts, 0) + assert.Equal(t, 1, iter.Number) + assert.Empty(t, iter.CID) + assert.NotNil(t, iter.Error) + assert.Equal(t, "network error", iter.Error.Message) + assert.NotNil(t, iter.err) + + stageNames := make([]string, len(iter.Stages)) + for i, s := range iter.Stages { + stageNames[i] = s.Name + } + assert.Contains(t, stageNames, "generate") + assert.Contains(t, stageNames, "upload") + }) + + t.Run("upload failure with HTTPError has structured detail", func(t *testing.T) { + uploadSvc := NewMockUploadService(t) + pinningSvc := NewMockPinningService(t) + accountAPI := portalsdkmocks.NewMockAccountAPI(t) + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + httpErr := NewHTTPError(429, "rate limit exceeded") + uploadSvc.EXPECT().Upload(mock.Anything, mock.Anything, mock.Anything, false). + Return(nil, httpErr) + + svc := &BenchServiceDefault{ + configMgr: cfgMgr, + output: output, + uploadService: uploadSvc, + pinningService: pinningSvc, + accountClient: accountAPI, + } + + opts := BenchOptions{ + SizeBytes: 1024, + Files: 1, + Iterations: 1, + PollInterval: 100 * time.Millisecond, + } + + iter := svc.runIteration(ctx, opts, 0) + assert.NotNil(t, iter.Error) + assert.Equal(t, "upload failed", iter.Error.Message) + assert.Equal(t, 429, iter.Error.Detail["http_status"]) + }) + + t.Run("upload failure with auth error is unrecoverable", func(t *testing.T) { + uploadSvc := NewMockUploadService(t) + pinningSvc := NewMockPinningService(t) + accountAPI := portalsdkmocks.NewMockAccountAPI(t) + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + uploadSvc.EXPECT().Upload(mock.Anything, mock.Anything, mock.Anything, false). + Return(nil, portalsdk.ErrUnauthorized) + + svc := &BenchServiceDefault{ + configMgr: cfgMgr, + output: output, + uploadService: uploadSvc, + pinningService: pinningSvc, + accountClient: accountAPI, + } + + opts := BenchOptions{ + SizeBytes: 1024, + Files: 1, + Iterations: 1, + PollInterval: 100 * time.Millisecond, + } + + iter := svc.runIteration(ctx, opts, 0) + assert.NotNil(t, iter.Error) + assert.True(t, isUnrecoverableError(iter.err)) + }) + + t.Run("iteration number is 1-indexed", func(t *testing.T) { + uploadSvc := NewMockUploadService(t) + pinningSvc := NewMockPinningService(t) + accountAPI := portalsdkmocks.NewMockAccountAPI(t) + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + uploadSvc.EXPECT().Upload(mock.Anything, mock.Anything, mock.Anything, false). + Return(&UploadResult{CID: "QmTestCID", Size: 1024}, nil) + + accountAPI.EXPECT().ListOperations(mock.Anything, mock.Anything). + Return([]*portalsdk.Operation{newCompletedOp(1)}, nil) + accountAPI.EXPECT().GetOperation(mock.Anything, int64(1)). + Return(newCompletedOp(1), nil) + + svc := &BenchServiceDefault{ + configMgr: cfgMgr, + output: output, + uploadService: uploadSvc, + pinningService: pinningSvc, + accountClient: accountAPI, + } + + opts := BenchOptions{ + SizeBytes: 1024, + Files: 1, + Iterations: 1, + PollInterval: 100 * time.Millisecond, + } + + iter := svc.runIteration(ctx, opts, 4) + assert.Equal(t, 5, iter.Number) + }) +} diff --git a/pkg/cli/command_docs_test.go b/pkg/cli/command_docs_test.go new file mode 100644 index 0000000..57f0891 --- /dev/null +++ b/pkg/cli/command_docs_test.go @@ -0,0 +1,121 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" +) + +func TestTutorialCommandsEmpty(t *testing.T) { + root := &cli.Command{} + cmds := TutorialCommands(root) + assert.Empty(t, cmds) +} + +func TestTutorialCommandsFromRoot(t *testing.T) { + root := &cli.Command{ + Commands: []*cli.Command{ + {Name: "upload", Metadata: WithTutorial(2, "Upload a file", "pinner upload file.txt")}, + {Name: "auth", Metadata: WithTutorial(1, "Authenticate", "pinner auth")}, + {Name: "list", Metadata: WithTutorial(3, "List pins", "pinner list")}, + {Name: "config"}, + }, + } + + cmds := TutorialCommands(root) + require.Len(t, cmds, 3) + assert.Equal(t, "auth", cmds[0].Name) + assert.Equal(t, "upload", cmds[1].Name) + assert.Equal(t, "list", cmds[2].Name) +} + +func TestBuildTutorialCommandsTable(t *testing.T) { + root := &cli.Command{ + Commands: []*cli.Command{ + {Name: "upload", Usage: "Upload files", Description: "Upload desc", Metadata: WithTutorial(1, "Upload a file", "pinner upload file.txt")}, + }, + } + + headers, rows := BuildTutorialCommandsTable(root) + assert.Equal(t, []string{"Command", "Usage", "Description"}, headers) + require.Len(t, rows, 1) + assert.Equal(t, "upload", rows[0][0]) + assert.Equal(t, "Upload a file", rows[0][2]) +} + +func TestBuildTutorialCommandsTableFallbackDescription(t *testing.T) { + root := &cli.Command{ + Commands: []*cli.Command{ + {Name: "upload", Usage: "Upload files", Description: "Fallback desc", Metadata: WithTutorial(1, "", "pinner upload file.txt")}, + }, + } + + _, rows := BuildTutorialCommandsTable(root) + require.Len(t, rows, 1) + assert.Equal(t, "Fallback desc", rows[0][2]) +} + +func TestBuildTutorialExamplesTable(t *testing.T) { + root := &cli.Command{ + Commands: []*cli.Command{ + {Name: "upload", Metadata: WithTutorial(1, "Upload a file", "pinner upload myfile.txt")}, + }, + } + + headers, rows := BuildTutorialExamplesTable(root) + assert.Equal(t, []string{"Example"}, headers) + require.Len(t, rows, 1) + assert.Equal(t, "pinner upload myfile.txt", rows[0][0]) +} + +func TestBuildTutorialExamplesTableFallback(t *testing.T) { + root := &cli.Command{ + Commands: []*cli.Command{ + {Name: "upload", Metadata: WithTutorial(1, "Upload a file", "")}, + }, + } + + _, rows := BuildTutorialExamplesTable(root) + require.Len(t, rows, 1) + assert.Equal(t, "pinner upload", rows[0][0]) +} + +func TestWithTutorial(t *testing.T) { + meta := WithTutorial(1, "desc", "example") + require.Contains(t, meta, "tutorial") + tm, ok := meta["tutorial"].(*TutorialMetadata) + require.True(t, ok) + assert.Equal(t, 1, tm.Priority) + assert.Equal(t, "desc", tm.Description) + assert.Equal(t, "example", tm.Example) +} + +func TestAbbreviateCID(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"short CID", "bafybeie7", "bafybeie7"}, + {"long CID", "bafybeie7m2fsbt6sjtn7tymyb6sim7iiyz6szl4ethtn7anzx4frzfzipu", "bafybeie7m..."}, + {"empty", "", ""}, + {"exactly 10", "1234567890", "1234567890"}, + {"11 chars", "12345678901", "1234567890..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, abbreviateCID(tt.input)) + }) + } +} + +func TestDocumentationURL(t *testing.T) { + assert.Equal(t, "https://docs.pinner.xyz", DocumentationURL) +} + +func TestTutorialCID(t *testing.T) { + assert.NotEmpty(t, TutorialCID) +} diff --git a/pkg/cli/command_getter.go b/pkg/cli/command_getter.go new file mode 100644 index 0000000..0d9bd65 --- /dev/null +++ b/pkg/cli/command_getter.go @@ -0,0 +1,86 @@ +package cli + +import ( + "time" + + "github.com/urfave/cli/v3" +) + +type flagGetter interface { + String(name string) string + Bool(name string) bool +} + +type flagGetterWithInt interface { + flagGetter + Int(name string) int +} + +type flagGetterWithIsSet interface { + flagGetterWithInt + IsSet(name string) bool +} + +type flagGetterWithUint interface { + flagGetterWithInt + Uint(name string) uint +} + +type flagGetterWithDuration interface { + flagGetterWithUint + Duration(name string) time.Duration +} + +// commandGetter is the broadest interface satisfied by cliCommandWrapper. +// It encompasses all flag, arg, and CID access methods used across handlers. +type commandGetter interface { + flagGetterWithIsSet + argsGetter + cidGetter + Uint(name string) uint + Duration(name string) time.Duration +} + +type argsGetter interface { + Args() cli.Args +} + +type cidGetter interface { + GetCID() string +} + +type argsFlagGetter interface { + argsGetter + flagGetterWithInt +} + +type cidFlagGetter interface { + cidGetter + flagGetterWithInt +} + +// argsFlagGetterWithBool combines argsGetter and flagGetter for commands +// that need both positional args and Bool flag access (e.g., setConfig). +type argsFlagGetterWithBool interface { + argsGetter + flagGetter +} + +// dnsCommandGetter combines all interfaces needed by DNS handlers. +type dnsCommandGetter interface { + flagGetterWithIsSet + argsGetter + Uint(name string) uint +} + +// benchCommandGetter combines all interfaces needed by bench handlers. +type benchCommandGetter interface { + flagGetterWithDuration + argsGetter +} + +// websitesCommandGetter combines all interfaces needed by websites handlers. +type websitesCommandGetter interface { + flagGetterWithIsSet + argsGetter +} diff --git a/pkg/cli/command_helper.go b/pkg/cli/command_helper.go index 2e5285b..c01d7c2 100644 --- a/pkg/cli/command_helper.go +++ b/pkg/cli/command_helper.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "strings" @@ -71,3 +72,44 @@ func requireSetInt(cmd intFlagChecker, name string) (int, error) { } return v, nil } + +// commandContext carries all resolved dependencies for a command handler. +// It replaces the repeated setupCommandContext + GetAuthToken + GetSecureSetting + newCLICommandWrapper pattern. +type commandContext struct { + Cmd commandGetter + Output Output + CfgMgr config.Manager + AuthToken string + Secure bool +} + +// newCommandContext creates a commandContext from a *cli.Command. +// It resolves output, config, auth token, and secure settings in one call. +func newCommandContext(c *cli.Command) (*commandContext, error) { + cfgMgr, output, err := setupCommandContext(c) + if err != nil { + return nil, err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + cmd := newCLICommandWrapper(c) + return &commandContext{ + Cmd: cmd, + Output: output, + CfgMgr: cfgMgr, + AuthToken: authToken, + Secure: secure, + }, nil +} + +// withContext wraps a handler that takes *commandContext into a cli.ActionFunc. +// This is the DRY mechanism for Action closures — replaces 5-line boilerplate with 2-3 lines. +func withContext(handler func(ctx context.Context, cc *commandContext) error) cli.ActionFunc { + return func(ctx context.Context, c *cli.Command) error { + cc, err := newCommandContext(c) + if err != nil { + return err + } + return handler(ctx, cc) + } +} diff --git a/pkg/cli/command_helper_test.go b/pkg/cli/command_helper_test.go new file mode 100644 index 0000000..0b96101 --- /dev/null +++ b/pkg/cli/command_helper_test.go @@ -0,0 +1,87 @@ +package cli + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" + + "go.lumeweb.com/pinner-cli/pkg/config" +) + +func TestSetupCommandContext(t *testing.T) { + orig := configManagerFactory + defer func() { configManagerFactory = orig }() + + configManagerFactory = func() (config.Manager, error) { return newTestConfigMgr(t), nil } + + cmd := &cli.Command{} + cfgMgr, output, err := setupCommandContext(cmd) + require.NoError(t, err) + require.NotNil(t, cfgMgr) + require.NotNil(t, output) +} + +func TestSetupCommandContextError(t *testing.T) { + orig := configManagerFactory + defer func() { configManagerFactory = orig }() + + configManagerFactory = func() (config.Manager, error) { + return nil, errors.New("config error") + } + + cmd := &cli.Command{} + cfgMgr, output, err := setupCommandContext(cmd) + require.Error(t, err) + assert.Nil(t, cfgMgr) + assert.Nil(t, output) +} + +func TestSetupOutput(t *testing.T) { + orig := configManagerFactory + defer func() { configManagerFactory = orig }() + + configManagerFactory = func() (config.Manager, error) { return newTestConfigMgr(t), nil } + + cmd := &cli.Command{} + output := setupOutput(cmd) + require.NotNil(t, output) +} + +func TestRequireUpdateFieldsSet(t *testing.T) { + cmd := newMockCommand().withIsSet("name", true) + err := requireUpdateFields(cmd, "name", "email") + require.NoError(t, err) +} + +func TestRequireUpdateFieldsNotSet(t *testing.T) { + cmd := newMockCommand() + err := requireUpdateFields(cmd, "name", "email") + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one field must be provided") + assert.Contains(t, err.Error(), "name") + assert.Contains(t, err.Error(), "email") +} + +func TestRequireSetInt(t *testing.T) { + cmd := newMockCommand().withIsSet("limit", true).withInt("limit", 10) + v, err := requireSetInt(cmd, "limit") + require.NoError(t, err) + assert.Equal(t, 10, v) +} + +func TestRequireSetIntNotSet(t *testing.T) { + cmd := newMockCommand() + _, err := requireSetInt(cmd, "limit") + require.Error(t, err) + assert.Contains(t, err.Error(), "--limit is required") +} + +func TestRequireSetIntZero(t *testing.T) { + cmd := newMockCommand().withIsSet("limit", true).withInt("limit", 0) + _, err := requireSetInt(cmd, "limit") + require.Error(t, err) + assert.Contains(t, err.Error(), "--limit must be greater than zero") +} diff --git a/pkg/cli/command_registration_test.go b/pkg/cli/command_registration_test.go new file mode 100644 index 0000000..e11b713 --- /dev/null +++ b/pkg/cli/command_registration_test.go @@ -0,0 +1,803 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" +) + +// findCommand finds a command by name in a list of commands. +func findCommand(cmds []*cli.Command, name string) *cli.Command { + for _, cmd := range cmds { + if cmd.Name == name { + return cmd + } + } + return nil +} + +// commandNames extracts all command names from a slice. +func commandNames(cmds []*cli.Command) []string { + names := make([]string, len(cmds)) + for i, cmd := range cmds { + names[i] = cmd.Name + } + return names +} + +// collectAllCommandNames walks the entire command tree and collects all command names. +func collectAllCommandNames(cmd *cli.Command) []string { + var names []string + walkCommands(cmd, func(c *cli.Command) { + names = append(names, c.Name) + }) + return names +} + +// walkCommands recursively walks the command tree, calling fn for each command (including root). +func walkCommands(cmd *cli.Command, fn func(*cli.Command)) { + fn(cmd) + for _, sub := range cmd.Commands { + walkCommands(sub, fn) + } +} + +// countCommands recursively counts all commands in the tree (including root). +func countCommands(cmd *cli.Command) int { + count := 1 + for _, sub := range cmd.Commands { + count += countCommands(sub) + } + return count +} + +func TestCommandRegistration_RootSubcommands(t *testing.T) { + root := NewRootCommand() + + expectedRootSubcommands := []string{ + "setup", + "auth", + "register", + "confirm-email", + "account", + "upload", + "download", + "cat", + "ls", + "pin", + "pins", + "list", + "status", + "unpin", + "metadata", + "operations", + "config", + "doctor", + "bench", + "dns", + "ipns", + "websites", + "admin", + } + + names := commandNames(root.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedRootSubcommands { + assert.True(t, nameSet[expected], "root should have subcommand %q", expected) + } + + // Verify exact count — no unexpected commands + assert.Len(t, root.Commands, len(expectedRootSubcommands), + "root should have exactly %d subcommands, got %d: %v", + len(expectedRootSubcommands), len(root.Commands), names) +} + +func TestCommandRegistration_Categories(t *testing.T) { + root := NewRootCommand() + + expectedCategories := map[string]string{ + "setup": "Setup", + "auth": "Setup", + "register": "Setup", + "confirm-email": "Setup", + "account": "Setup", + "upload": "Content", + "download": "Content", + "cat": "Content", + "ls": "Content", + "pin": "Pinning", + "pins": "Pinning", + "list": "Pinning", + "status": "Pinning", + "unpin": "Pinning", + "metadata": "Pinning", + "operations": "Management", + "dns": "Management", + "ipns": "Management", + "websites": "Management", + "config": "System", + "doctor": "System", + "bench": "System", + "admin": "Admin", + } + + for name, expectedCat := range expectedCategories { + cmd := findCommand(root.Commands, name) + require.NotNil(t, cmd, "command %q should exist", name) + assert.Equal(t, expectedCat, cmd.Category, + "command %q should have category %q, got %q", name, expectedCat, cmd.Category) + } +} + +func TestCommandRegistration_PinsSubcommands(t *testing.T) { + root := NewRootCommand() + pins := findCommand(root.Commands, "pins") + require.NotNil(t, pins, "pins command should exist") + + expectedPinsSubs := []string{"add", "rm", "ls", "status", "update"} + names := commandNames(pins.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedPinsSubs { + assert.True(t, nameSet[expected], "pins should have subcommand %q", expected) + } + assert.Len(t, pins.Commands, len(expectedPinsSubs), + "pins should have exactly %d subcommands, got %d: %v", + len(expectedPinsSubs), len(pins.Commands), names) +} + +func TestCommandRegistration_AuthSubcommands(t *testing.T) { + root := NewRootCommand() + auth := findCommand(root.Commands, "auth") + require.NotNil(t, auth, "auth command should exist") + + expectedAuthSubs := []string{"status"} + names := commandNames(auth.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedAuthSubs { + assert.True(t, nameSet[expected], "auth should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AccountSubcommands(t *testing.T) { + root := NewRootCommand() + account := findCommand(root.Commands, "account") + require.NotNil(t, account, "account command should exist") + + expectedAccountSubs := []string{"otp", "api-keys"} + names := commandNames(account.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedAccountSubs { + assert.True(t, nameSet[expected], "account should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AccountOTPSubcommands(t *testing.T) { + root := NewRootCommand() + account := findCommand(root.Commands, "account") + require.NotNil(t, account, "account command should exist") + + otp := findCommand(account.Commands, "otp") + require.NotNil(t, otp, "otp command should exist") + + expectedOTPSubs := []string{"enable", "disable"} + names := commandNames(otp.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedOTPSubs { + assert.True(t, nameSet[expected], "otp should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AccountAPIKeysSubcommands(t *testing.T) { + root := NewRootCommand() + account := findCommand(root.Commands, "account") + require.NotNil(t, account, "account command should exist") + + apiKeys := findCommand(account.Commands, "api-keys") + require.NotNil(t, apiKeys, "api-keys command should exist") + + expectedSubs := []string{"list", "create", "delete"} + names := commandNames(apiKeys.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "api-keys should have subcommand %q", expected) + } +} + +func TestCommandRegistration_DNSSubcommands(t *testing.T) { + root := NewRootCommand() + dns := findCommand(root.Commands, "dns") + require.NotNil(t, dns, "dns command should exist") + + expectedDNSSubs := []string{"zones", "records"} + names := commandNames(dns.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedDNSSubs { + assert.True(t, nameSet[expected], "dns should have subcommand %q", expected) + } +} + +func TestCommandRegistration_DNSZonesSubcommands(t *testing.T) { + root := NewRootCommand() + dns := findCommand(root.Commands, "dns") + require.NotNil(t, dns, "dns command should exist") + + zones := findCommand(dns.Commands, "zones") + require.NotNil(t, zones, "dns zones command should exist") + + expectedSubs := []string{"list", "create", "get", "delete", "validate"} + names := commandNames(zones.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "dns zones should have subcommand %q", expected) + } +} + +func TestCommandRegistration_DNSRecordsSubcommands(t *testing.T) { + root := NewRootCommand() + dns := findCommand(root.Commands, "dns") + require.NotNil(t, dns, "dns command should exist") + + records := findCommand(dns.Commands, "records") + require.NotNil(t, records, "dns records command should exist") + + expectedSubs := []string{"list", "create", "get", "update", "delete"} + names := commandNames(records.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "dns records should have subcommand %q", expected) + } +} + +func TestCommandRegistration_IPNSSubcommands(t *testing.T) { + root := NewRootCommand() + ipns := findCommand(root.Commands, "ipns") + require.NotNil(t, ipns, "ipns command should exist") + + expectedIPNSSubs := []string{"keys", "publish", "republish", "resolve"} + names := commandNames(ipns.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedIPNSSubs { + assert.True(t, nameSet[expected], "ipns should have subcommand %q", expected) + } +} + +func TestCommandRegistration_IPNSKeysSubcommands(t *testing.T) { + root := NewRootCommand() + ipns := findCommand(root.Commands, "ipns") + require.NotNil(t, ipns, "ipns command should exist") + + keys := findCommand(ipns.Commands, "keys") + require.NotNil(t, keys, "ipns keys command should exist") + + expectedSubs := []string{"list", "create", "get", "delete"} + names := commandNames(keys.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "ipns keys should have subcommand %q", expected) + } +} + +func TestCommandRegistration_WebsitesSubcommands(t *testing.T) { + root := NewRootCommand() + websites := findCommand(root.Commands, "websites") + require.NotNil(t, websites, "websites command should exist") + + expectedSubs := []string{ + "list", "create", "get", "update", "enable-ipns", + "delete", "validate", "ssl", "config", "wizard", + } + names := commandNames(websites.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "websites should have subcommand %q", expected) + } +} + +func TestCommandRegistration_WebsitesSSLSubcommands(t *testing.T) { + root := NewRootCommand() + websites := findCommand(root.Commands, "websites") + require.NotNil(t, websites, "websites command should exist") + + ssl := findCommand(websites.Commands, "ssl") + require.NotNil(t, ssl, "websites ssl command should exist") + + expectedSubs := []string{"status"} + names := commandNames(ssl.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "websites ssl should have subcommand %q", expected) + } +} + +func TestCommandRegistration_OperationsSubcommands(t *testing.T) { + root := NewRootCommand() + ops := findCommand(root.Commands, "operations") + require.NotNil(t, ops, "operations command should exist") + + expectedSubs := []string{"list", "get"} + names := commandNames(ops.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "operations should have subcommand %q", expected) + } +} + +func TestCommandRegistration_UnpinSubcommands(t *testing.T) { + root := NewRootCommand() + unpin := findCommand(root.Commands, "unpin") + require.NotNil(t, unpin, "unpin command should exist") + + expectedSubs := []string{"all"} + names := commandNames(unpin.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "unpin should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + expectedAdminSubs := []string{"quota", "billing", "websites", "pprof"} + names := commandNames(admin.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedAdminSubs { + assert.True(t, nameSet[expected], "admin should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminQuotaSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + quota := findCommand(admin.Commands, "quota") + require.NotNil(t, quota, "admin quota command should exist") + + expectedSubs := []string{"plans", "allowances", "user-configs", "stats", "reconcile", "cleanup"} + names := commandNames(quota.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin quota should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + expectedSubs := []string{ + "overview", "credits", "price-lines", + "pricing-plans", "pricing-plan-periods", "subscribers", + } + names := commandNames(billing.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminWebsitesSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + websites := findCommand(admin.Commands, "websites") + require.NotNil(t, websites, "admin websites command should exist") + + expectedSubs := []string{"block", "unblock"} + names := commandNames(websites.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin websites should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminPprofSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + pprof := findCommand(admin.Commands, "pprof") + require.NotNil(t, pprof, "admin pprof command should exist") + + expectedSubs := []string{ + "index", "block", "set-block-rate", "cmdline", + "goroutine", "heap", "mutex", "set-mutex-fraction", + "cpu", "status", "symbol", "threadcreate", "trace", + } + names := commandNames(pprof.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin pprof should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminQuotaPlansSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + quota := findCommand(admin.Commands, "quota") + require.NotNil(t, quota, "admin quota command should exist") + + plans := findCommand(quota.Commands, "plans") + require.NotNil(t, plans, "admin quota plans command should exist") + + expectedSubs := []string{"list", "get", "create", "update", "delete", "set-default"} + names := commandNames(plans.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin quota plans should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingCreditsSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + credits := findCommand(billing.Commands, "credits") + require.NotNil(t, credits, "admin billing credits command should exist") + + expectedSubs := []string{ + "list", "get", "create", "delete", "restore", "purge", + "user-balance", "user-deleted-credits", + } + names := commandNames(credits.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing credits should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingSubscribersSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + subscribers := findCommand(billing.Commands, "subscribers") + require.NotNil(t, subscribers, "admin billing subscribers command should exist") + + expectedSubs := []string{ + "list", "get", "list-gateway", "list-user", + "cancel", "abort-cancel", "change-plan", "pause", "resume", + } + names := commandNames(subscribers.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing subscribers should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingPricingPlansSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + plans := findCommand(billing.Commands, "pricing-plans") + require.NotNil(t, plans, "admin billing pricing-plans command should exist") + + expectedSubs := []string{ + "list", "get", "create", "update", "delete", "sync", "sync-all", + } + names := commandNames(plans.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing pricing-plans should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingPriceLinesSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + priceLines := findCommand(billing.Commands, "price-lines") + require.NotNil(t, priceLines, "admin billing price-lines command should exist") + + expectedSubs := []string{ + "list", "get", "create", "update", "delete", + "add-plan", "delete-plan", "update-plan-position", + } + names := commandNames(priceLines.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing price-lines should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminBillingPricingPlanPeriodsSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + billing := findCommand(admin.Commands, "billing") + require.NotNil(t, billing, "admin billing command should exist") + + periods := findCommand(billing.Commands, "pricing-plan-periods") + require.NotNil(t, periods, "admin billing pricing-plan-periods command should exist") + + expectedSubs := []string{"list", "get", "create", "update", "delete"} + names := commandNames(periods.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin billing pricing-plan-periods should have subcommand %q", expected) + } +} + +func TestCommandRegistration_Aliases(t *testing.T) { + root := NewRootCommand() + + aliasTests := map[string][]string{ + "websites": {"website"}, + } + + for cmdName, expectedAliases := range aliasTests { + cmd := findCommand(root.Commands, cmdName) + require.NotNil(t, cmd, "command %q should exist", cmdName) + assert.Equal(t, expectedAliases, cmd.Aliases, + "command %q should have aliases %v, got %v", cmdName, expectedAliases, cmd.Aliases) + } +} + +func TestCommandRegistration_AccountAPIKeysAliases(t *testing.T) { + root := NewRootCommand() + account := findCommand(root.Commands, "account") + require.NotNil(t, account, "account command should exist") + + apiKeys := findCommand(account.Commands, "api-keys") + require.NotNil(t, apiKeys, "api-keys command should exist") + + expectedAliases := []string{"apikey", "api-key"} + assert.Equal(t, expectedAliases, apiKeys.Aliases, + "api-keys should have aliases %v, got %v", expectedAliases, apiKeys.Aliases) +} + +func TestCommandRegistration_MetadataHidden(t *testing.T) { + root := NewRootCommand() + metadata := findCommand(root.Commands, "metadata") + require.NotNil(t, metadata, "metadata command should exist") + + assert.True(t, metadata.Hidden, "metadata command should be hidden") +} + +func TestCommandRegistration_AllCommandsHaveUsage(t *testing.T) { + root := NewRootCommand() + + walkCommands(root, func(cmd *cli.Command) { + // Skip root — it has Usage but we only care about subcommands + if cmd.Name == "pinner" { + return + } + assert.NotEmpty(t, cmd.Usage, "command %q should have non-empty Usage", cmd.Name) + }) +} + +func TestCommandRegistration_AllCommandsHaveActionOrSubcommands(t *testing.T) { + root := NewRootCommand() + + walkCommands(root, func(cmd *cli.Command) { + if cmd.Name == "pinner" { + return + } + hasAction := cmd.Action != nil + hasSubcommands := len(cmd.Commands) > 0 + assert.True(t, hasAction || hasSubcommands, + "command %q should have an Action or subcommands", cmd.Name) + }) +} + +func TestCommandRegistration_NoDuplicateNames(t *testing.T) { + root := NewRootCommand() + + // Check each level for duplicate command names + walkCommands(root, func(cmd *cli.Command) { + seen := make(map[string]bool) + for _, sub := range cmd.Commands { + assert.False(t, seen[sub.Name], + "duplicate command name %q under %q", sub.Name, cmd.Name) + seen[sub.Name] = true + } + }) +} + +func TestCommandRegistration_CommandTreeDepth(t *testing.T) { + root := NewRootCommand() + + // The command tree should have a reasonable depth. + // Root → admin → billing → credits → list = depth 4 (max expected) + totalCommands := countCommands(root) + allNames := collectAllCommandNames(root) + + // Verify we have a substantial number of commands (the tree is well-populated) + assert.Greater(t, totalCommands, 50, + "command tree should have more than 50 commands total, got %d", totalCommands) + + // Verify no empty-named commands + for _, name := range allNames { + assert.NotEmpty(t, name, "all commands should have non-empty names") + } +} + +func TestCommandRegistration_DownloadCatLsCategories(t *testing.T) { + root := NewRootCommand() + + contentCmds := []string{"download", "cat", "ls"} + for _, name := range contentCmds { + cmd := findCommand(root.Commands, name) + require.NotNil(t, cmd, "command %q should exist", name) + assert.Equal(t, "Content", cmd.Category, + "command %q should have category Content, got %q", name, cmd.Category) + } +} + +func TestCommandRegistration_AdminQuotaAllowancesSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + quota := findCommand(admin.Commands, "quota") + require.NotNil(t, quota, "admin quota command should exist") + + allowances := findCommand(quota.Commands, "allowances") + require.NotNil(t, allowances, "admin quota allowances command should exist") + + expectedSubs := []string{"list", "create", "update", "delete"} + names := commandNames(allowances.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin quota allowances should have subcommand %q", expected) + } +} + +func TestCommandRegistration_AdminQuotaUserConfigsSubcommands(t *testing.T) { + root := NewRootCommand() + admin := findCommand(root.Commands, "admin") + require.NotNil(t, admin, "admin command should exist") + + quota := findCommand(admin.Commands, "quota") + require.NotNil(t, quota, "admin quota command should exist") + + userConfigs := findCommand(quota.Commands, "user-configs") + require.NotNil(t, userConfigs, "admin quota user-configs command should exist") + + expectedSubs := []string{"list", "update", "reset"} + names := commandNames(userConfigs.Commands) + nameSet := make(map[string]bool, len(names)) + for _, n := range names { + nameSet[n] = true + } + + for _, expected := range expectedSubs { + assert.True(t, nameSet[expected], "admin quota user-configs should have subcommand %q", expected) + } +} diff --git a/pkg/cli/command_wrapper.go b/pkg/cli/command_wrapper.go index 647aa82..ba4285f 100644 --- a/pkg/cli/command_wrapper.go +++ b/pkg/cli/command_wrapper.go @@ -1,6 +1,8 @@ package cli import ( + "time" + "github.com/urfave/cli/v3" ) @@ -33,6 +35,14 @@ func (w *cliCommandWrapper) Uint64(name string) uint64 { return w.Command.Uint64(name) } +func (w *cliCommandWrapper) Uint(name string) uint { + return w.Command.Uint(name) +} + +func (w *cliCommandWrapper) Duration(name string) time.Duration { + return w.Command.Duration(name) +} + func (w *cliCommandWrapper) Args() cli.Args { return w.Command.Args() } @@ -41,3 +51,6 @@ func (w *cliCommandWrapper) Args() cli.Args { func newCLICommandWrapper(c *cli.Command) *cliCommandWrapper { return &cliCommandWrapper{c} } + +// Compile-time interface satisfaction check. +var _ commandGetter = (*cliCommandWrapper)(nil) diff --git a/pkg/cli/command_wrapper_test.go b/pkg/cli/command_wrapper_test.go new file mode 100644 index 0000000..bc7e7dc --- /dev/null +++ b/pkg/cli/command_wrapper_test.go @@ -0,0 +1,93 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/urfave/cli/v3" +) + +func TestCLICommandWrapperString(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: "test", Value: "hello"}, + }, + } + w := newCLICommandWrapper(cmd) + assert.Equal(t, "hello", w.String("test")) +} + +func TestCLICommandWrapperInt(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.IntFlag{Name: "count", Value: 42}, + }, + } + w := newCLICommandWrapper(cmd) + assert.Equal(t, 42, w.Int("count")) +} + +func TestCLICommandWrapperBool(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.BoolFlag{Name: "flag", Value: true}, + }, + } + w := newCLICommandWrapper(cmd) + assert.True(t, w.Bool("flag")) +} + +func TestCLICommandWrapperUint64(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.Uint64Flag{Name: "size", Value: 1024}, + }, + } + w := newCLICommandWrapper(cmd) + assert.Equal(t, uint64(1024), w.Uint64("size")) +} + +func TestCLICommandWrapperStringSlice(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringSliceFlag{Name: "tags", Value: []string{"a", "b"}}, + }, + } + w := newCLICommandWrapper(cmd) + assert.Equal(t, []string{"a", "b"}, w.StringSlice("tags")) +} + +func TestCLICommandWrapperGetCID(t *testing.T) { + cmd := &cli.Command{ + Action: func(ctx context.Context, cmd *cli.Command) error { return nil }, + } + _ = cmd.Run(context.Background(), []string{"test", "bafybeig123"}) + w := newCLICommandWrapper(cmd) + assert.Equal(t, "bafybeig123", w.GetCID()) +} + +func TestCLICommandWrapperArgs(t *testing.T) { + cmd := &cli.Command{ + Action: func(ctx context.Context, cmd *cli.Command) error { return nil }, + } + _ = cmd.Run(context.Background(), []string{"test", "arg1", "arg2"}) + w := newCLICommandWrapper(cmd) + args := w.Args() + assert.Equal(t, 2, args.Len()) + assert.Equal(t, "arg1", args.Get(0)) + assert.Equal(t, "arg2", args.Get(1)) +} + +func TestCLICommandWrapperImplementsInterfaces(t *testing.T) { + cmd := &cli.Command{} + w := newCLICommandWrapper(cmd) + + var _ flagGetter = w + var _ flagGetterWithInt = w + var _ flagGetterWithIsSet = w + var _ argsGetter = w + var _ cidGetter = w + var _ argsFlagGetter = w + var _ cidFlagGetter = w +} diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 78391b7..5616c44 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -52,14 +52,14 @@ Common keys: auth_token - Authentication token (managed by 'pinner auth')`, ArgsUsage: "[get | set ]", Flags: append(GlobalFlags(), DryRunFlag()), - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return configAction(ctx, cmd, output, defaultConfigManagerFactory) + Action: func(ctx context.Context, c *cli.Command) error { + output := setupOutput(c) + return configAction(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory) }, } } -func configAction(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory) error { +func configAction(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgrFactory ConfigManagerFactory) error { args := cmd.Args() if args.Len() == 0 { @@ -125,7 +125,7 @@ func showAllConfig(output Output, cfgMgrFactory ConfigManagerFactory) error { return nil } -func getConfig(cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory) error { +func getConfig(cmd argsGetter, output Output, cfgMgrFactory ConfigManagerFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to initialize config manager: %w", err) @@ -161,7 +161,7 @@ func getConfig(cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFacto return nil } -func setConfig(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory) error { +func setConfig(ctx context.Context, cmd argsFlagGetterWithBool, output Output, cfgMgrFactory ConfigManagerFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to initialize config manager: %w", err) diff --git a/pkg/cli/config_test.go b/pkg/cli/config_test.go index 85de21b..bfc3204 100644 --- a/pkg/cli/config_test.go +++ b/pkg/cli/config_test.go @@ -1,9 +1,14 @@ package cli import ( + "context" + "errors" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" ) func TestConfigKeyToEnvVar(t *testing.T) { @@ -28,3 +33,297 @@ func TestConfigKeyToEnvVar(t *testing.T) { }) } } + +func TestNewConfigCommand(t *testing.T) { + cmd := newConfigCommand() + + assert.Equal(t, "config", cmd.Name) + assert.Equal(t, "System", cmd.Category) + assert.NotEmpty(t, cmd.Usage) + assert.NotNil(t, cmd.Action) +} + +func TestShowAllConfigError(t *testing.T) { + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("config error") + } + + err := showAllConfig(output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestShowAllConfigWithBoolValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().GetAllDescriptions().Return(map[string]string{"secure": "Use HTTPS"}) + cfgMgr.EXPECT().All().Return(map[string]any{"secure": true}) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + + err := showAllConfig(output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestShowAllConfigWithIntValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().GetAllDescriptions().Return(map[string]string{"max_retries": "Max retries"}) + cfgMgr.EXPECT().All().Return(map[string]any{"max_retries": 3}) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + + err := showAllConfig(output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestShowAllConfigWithEmptyDescription(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().GetAllDescriptions().Return(map[string]string{"custom_key": ""}) + cfgMgr.EXPECT().All().Return(map[string]any{"custom_key": "value"}) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + + err := showAllConfig(output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestShowAllConfigWithNotSetValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().GetAllDescriptions().Return(map[string]string{"auth_token": "Auth token"}) + cfgMgr.EXPECT().All().Return(map[string]any{}) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + + err := showAllConfig(output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestGetConfigError(t *testing.T) { + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("config error") + } + + cmd := newMockCommand().withArgs("get", "auth_token") + err := getConfig(cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestSetConfigError(t *testing.T) { + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("config error") + } + + cmd := newMockCommand().withArgs("set", "base_endpoint", "test.com") + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestConfigActionNoArgs(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().GetAllDescriptions().Return(map[string]string{"secure": "Use HTTPS"}) + cfgMgr.EXPECT().All().Return(map[string]any{"secure": true}) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand() + + err := configAction(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestConfigActionGetSubcommand(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Get("auth_token").Return("my-token", true, nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("get", "auth_token") + + err := configAction(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestConfigActionSetSubcommand(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("max_retries").Return(true) + cfgMgr.EXPECT().Get("max_retries").Return(3, true, nil) + cfgMgr.EXPECT().Set(context.Background(), "max_retries", int64(5)).Return(nil) + cfgMgr.EXPECT().Persist().Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "max_retries", "5") + + err := configAction(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestConfigActionInvalidAction(t *testing.T) { + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("should not be called") + } + cmd := newMockCommand().withArgs("delete") + + err := configAction(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid action: delete") +} + +func TestGetConfigMissingKey(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("get") + + err := getConfig(cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "key is required") +} + +func TestGetConfigSuccess(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Get("base_endpoint").Return("pinner.xyz", true, nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("get", "base_endpoint") + + err := getConfig(cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestGetConfigGetFails(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Get("unknown_key").Return(nil, false, errors.New("not found")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("get", "unknown_key") + + err := getConfig(cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get config key") +} + +func TestSetConfigMissingKeyOrValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "key") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "key and value are required") +} + +func TestSetConfigNewKey(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("custom_key").Return(false) + cfgMgr.EXPECT().Set(context.Background(), "custom_key", "custom_value").Return(nil) + cfgMgr.EXPECT().Persist().Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "custom_key", "custom_value") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestSetConfigBoolValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("secure").Return(true) + cfgMgr.EXPECT().Get("secure").Return(true, true, nil) + cfgMgr.EXPECT().Set(context.Background(), "secure", false).Return(nil) + cfgMgr.EXPECT().Persist().Return(nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "secure", "false") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestSetConfigInvalidBoolValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("secure").Return(true) + cfgMgr.EXPECT().Get("secure").Return(true, true, nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "secure", "notabool") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be true or false") +} + +func TestSetConfigInvalidIntValue(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("max_retries").Return(true) + cfgMgr.EXPECT().Get("max_retries").Return(3, true, nil) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "max_retries", "notanint") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be an integer") +} + +func TestSetConfigDryRun(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("max_retries").Return(true) + cfgMgr.EXPECT().Get("max_retries").Return(3, true, nil) + cfgMgr.EXPECT().GetDescription("max_retries").Return("Max retries") + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "max_retries", "5").withBool(FlagDryRun, true) + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) +} + +func TestSetConfigPersistFails(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + cfgMgr.EXPECT().Exists("custom_key").Return(false) + cfgMgr.EXPECT().Set(context.Background(), "custom_key", "value").Return(nil) + cfgMgr.EXPECT().Persist().Return(errors.New("disk full")) + + cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } + cmd := newMockCommand().withArgs("set", "custom_key", "value") + + err := setConfig(context.Background(), cmd, output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to save config") +} diff --git a/pkg/cli/confirm_email.go b/pkg/cli/confirm_email.go index e4020ab..b25a8d8 100644 --- a/pkg/cli/confirm_email.go +++ b/pkg/cli/confirm_email.go @@ -40,12 +40,12 @@ After confirmation, authenticate with: }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return confirmEmail(ctx, cmd, output, defaultConfigManagerFactory) + return confirmEmail(ctx, newCLICommandWrapper(cmd), output, defaultConfigManagerFactory) }, } } -func confirmEmail(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory) error { +func confirmEmail(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory) error { email := cmd.String(FlagEmail) token := cmd.String(FlagToken) diff --git a/pkg/cli/confirm_email_test.go b/pkg/cli/confirm_email_test.go new file mode 100644 index 0000000..ceb523e --- /dev/null +++ b/pkg/cli/confirm_email_test.go @@ -0,0 +1,126 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" + + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +func TestNewConfirmEmailCommand(t *testing.T) { + cmd := newConfirmEmailCommand() + + assert.Equal(t, "confirm-email", cmd.Name) + assert.Equal(t, "Setup", cmd.Category) + assert.NotEmpty(t, cmd.Usage) + assert.NotEmpty(t, cmd.Description) + assert.NotNil(t, cmd.Action) + + flagNames := getFlagNames(cmd) + nameSet := make(map[string]bool) + for _, n := range flagNames { + nameSet[n] = true + } + + expectedFlags := []string{FlagEmail, FlagToken} + for _, f := range expectedFlags { + assert.True(t, nameSet[f], "confirm-email command should have flag --%s", f) + } +} + +func TestConfirmEmailConfigManagerError(t *testing.T) { + output := newTestOutput() + + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagToken, Value: "abc123"}, + }, + } + + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("config error") + } + + err := confirmEmail(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create config manager") +} + +func TestConfirmEmailFlagAliases(t *testing.T) { + cmd := newConfirmEmailCommand() + + emailFlag := findFlag(cmd, FlagEmail) + require.NotNil(t, emailFlag) + assert.Contains(t, emailFlag.Names(), "e", "email flag should have -e alias") + + tokenFlag := findFlag(cmd, FlagToken) + require.NotNil(t, tokenFlag) + assert.Contains(t, tokenFlag.Names(), "t", "token flag should have -t alias") +} + +func TestConfirmEmail_MockCommand_ConfigError(t *testing.T) { + output := newTestOutput() + + cmd := newMockCommand(). + withString(FlagEmail, "user@example.com"). + withString(FlagToken, "abc123") + + err := confirmEmail(context.Background(), cmd, output, failingConfigMgrFactory()) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create config manager") +} + +func TestConfirmEmail_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/account/verify-email", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var buf bytes.Buffer + output := newTestOutput() + output.SetWriter(&buf) + + cmd := newMockCommand(). + withString(FlagEmail, "user@example.com"). + withString(FlagToken, "abc123") + + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + BaseEndpoint: server.URL, + Secure: false, + }).Maybe() + + cfgMgrFactory := func() (config.Manager, error) { + return cfgMgr, nil + } + + err := confirmEmail(context.Background(), cmd, output, cfgMgrFactory) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, "Email verified successfully!") + assert.Contains(t, result, "pinner auth --email user@example.com") +} + +func findFlag(cmd *cli.Command, name string) cli.Flag { + for _, f := range cmd.Flags { + for _, n := range f.Names() { + if n == name { + return f + } + } + } + return nil +} diff --git a/pkg/cli/dns.go b/pkg/cli/dns.go index a300153..7ec397c 100644 --- a/pkg/cli/dns.go +++ b/pkg/cli/dns.go @@ -83,13 +83,9 @@ func newDNSZonesListCommand() *cli.Command { Examples: pinner dns zones list pinner dns zones list --json`, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsZonesList(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsZonesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -107,13 +103,9 @@ Examples: RequiredDomainFlag(), NameserversFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsZonesCreate(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsZonesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -127,13 +119,9 @@ Examples: pinner dns zones get example.com pinner dns zones get example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsZonesGet(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsZonesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -146,13 +134,9 @@ func newDNSZonesDeleteCommand() *cli.Command { Examples: pinner dns zones delete example.com`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsZonesDelete(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsZonesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -167,13 +151,9 @@ Examples: pinner dns zones validate example.com pinner dns zones validate example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsZonesValidate(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsZonesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -211,13 +191,9 @@ Examples: pinner dns records list example.com pinner dns records list example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsRecordsList(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsRecordsList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -239,13 +215,9 @@ Examples: TTLFlag(), DisabledFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsRecordsCreate(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsRecordsCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -263,13 +235,9 @@ Examples: RequiredNameFlag("DNS record name"), RequiredTypeFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsRecordsGet(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsRecordsGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -291,13 +259,9 @@ Examples: TTLFlag(), DisabledFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsRecordsUpdate(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsRecordsUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -314,31 +278,17 @@ Examples: RequiredNameFlag("DNS record name"), RequiredTypeFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - cfgMgr, output, err := setupCommandContext(cmd) - if err != nil { - return err - } - return dnsRecordsDelete(ctx, cmd, output, cfgMgr) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return dnsRecordsDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } // ===== HANDLERS ===== -func dnsZonesList(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { +func dnsZonesList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -375,7 +325,7 @@ func dnsZonesList(ctx context.Context, cmd *cli.Command, output Output, cfgMgr c return nil } -func dnsZonesCreate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsZonesCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { domain := cmd.String(FlagDomain) if err := validateDomain(domain); err != nil { @@ -388,18 +338,8 @@ func dnsZonesCreate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr nameservers = parseCommaSeparated(nameserversStr) } - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -427,7 +367,7 @@ func dnsZonesCreate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr return nil } -func dnsZonesGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsZonesGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -435,18 +375,8 @@ func dnsZonesGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr co arg := args.First() - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -478,7 +408,7 @@ func dnsZonesGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr co return nil } -func dnsZonesDelete(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsZonesDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -486,18 +416,8 @@ func dnsZonesDelete(ctx context.Context, cmd *cli.Command, output Output, cfgMgr arg := args.First() - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -515,7 +435,7 @@ func dnsZonesDelete(ctx context.Context, cmd *cli.Command, output Output, cfgMgr return nil } -func dnsZonesValidate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsZonesValidate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -523,18 +443,8 @@ func dnsZonesValidate(ctx context.Context, cmd *cli.Command, output Output, cfgM arg := args.First() - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -584,7 +494,7 @@ func dnsZonesValidate(ctx context.Context, cmd *cli.Command, output Output, cfgM return nil } -func dnsRecordsList(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsRecordsList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -592,18 +502,8 @@ func dnsRecordsList(ctx context.Context, cmd *cli.Command, output Output, cfgMgr arg := args.First() - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -646,7 +546,7 @@ func dnsRecordsList(ctx context.Context, cmd *cli.Command, output Output, cfgMgr return nil } -func dnsRecordsCreate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsRecordsCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -676,18 +576,8 @@ func dnsRecordsCreate(ctx context.Context, cmd *cli.Command, output Output, cfgM Disabled: &disabled, } - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -721,7 +611,7 @@ func dnsRecordsCreate(ctx context.Context, cmd *cli.Command, output Output, cfgM return nil } -func dnsRecordsGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsRecordsGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -731,18 +621,8 @@ func dnsRecordsGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr name := cmd.String(FlagName) recordType := cmd.String(FlagType) - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -777,7 +657,7 @@ func dnsRecordsGet(ctx context.Context, cmd *cli.Command, output Output, cfgMgr return nil } -func dnsRecordsUpdate(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsRecordsUpdate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -807,18 +687,8 @@ func dnsRecordsUpdate(ctx context.Context, cmd *cli.Command, output Output, cfgM Disabled: &disabled, } - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } @@ -852,7 +722,7 @@ func dnsRecordsUpdate(ctx context.Context, cmd *cli.Command, output Output, cfgM return nil } -func dnsRecordsDelete(ctx context.Context, cmd *cli.Command, output Output, cfgMgr config.Manager) error { +func dnsRecordsDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -862,18 +732,8 @@ func dnsRecordsDelete(ctx context.Context, cmd *cli.Command, output Output, cfgM name := cmd.String(FlagName) recordType := cmd.String(FlagType) - var dnsService DNSService - - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - - if authToken != "" { - dnsService = NewDNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - dnsService = defaultDNSServiceFactory(cfgMgr, output) - } - - if err := dnsService.RequireAuthenticated(); err != nil { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + if err != nil { return err } diff --git a/pkg/cli/dns_helpers_test.go b/pkg/cli/dns_helpers_test.go new file mode 100644 index 0000000..d6de595 --- /dev/null +++ b/pkg/cli/dns_helpers_test.go @@ -0,0 +1,148 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsValidIPv4(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"192.168.1.1", true}, + {"0.0.0.0", true}, + {"255.255.255.255", true}, + {"10.0.0.1", true}, + {"::1", false}, + {"2001:db8::1", false}, + {"not-an-ip", false}, + {"", false}, + {"256.1.1.1", false}, + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + assert.Equal(t, tt.expected, isValidIPv4(tt.ip)) + }) + } +} + +func TestIsValidIPv6(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"::1", true}, + {"2001:db8::1", true}, + {"fe80::1", true}, + {"192.168.1.1", false}, + {"not-an-ip", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + assert.Equal(t, tt.expected, isValidIPv6(tt.ip)) + }) + } +} + +func TestIsValidDomain(t *testing.T) { + tests := []struct { + domain string + expected bool + }{ + {"example.com", true}, + {"sub.example.com", true}, + {"a.b.c.d", true}, + {"", false}, + {"single", false}, + {".com", false}, + {"example.", false}, + } + + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + assert.Equal(t, tt.expected, isValidDomain(tt.domain)) + }) + } +} + +func TestValidateDomain(t *testing.T) { + t.Run("valid domain", func(t *testing.T) { + err := validateDomain("example.com") + require.NoError(t, err) + }) + + t.Run("empty domain", func(t *testing.T) { + err := validateDomain("") + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot be empty") + }) +} + +func TestValidateDNSRecord(t *testing.T) { + tests := []struct { + name string + rtype string + content string + wantErr bool + errMatch string + }{ + {"valid A record", "A", "1.2.3.4", false, ""}, + {"invalid A record", "A", "not-an-ip", true, "invalid IPv4"}, + {"valid AAAA record", "AAAA", "::1", false, ""}, + {"invalid AAAA record", "AAAA", "1.2.3.4", true, "invalid IPv6"}, + {"valid CNAME record", "CNAME", "example.com", false, ""}, + {"invalid CNAME record", "CNAME", "single", true, "invalid domain"}, + {"valid MX record", "MX", "mail.example.com", false, ""}, + {"valid NS record", "NS", "ns1.example.com", false, ""}, + {"valid TXT record", "TXT", "some text", false, ""}, + {"TXT too long", "TXT", string(make([]byte, 256)), true, "too long"}, + {"unsupported type", "SRV", "whatever", true, "unsupported record type"}, + {"lowercase type", "a", "1.2.3.4", false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDNSRecord(tt.rtype, tt.content) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestParseCommaSeparated(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"empty string", "", nil}, + {"single value", "ns1.example.com", []string{"ns1.example.com"}}, + {"two values", "ns1.example.com,ns2.example.com", []string{"ns1.example.com", "ns2.example.com"}}, + {"with spaces", " ns1.example.com , ns2.example.com ", []string{"ns1.example.com", "ns2.example.com"}}, + {"trailing comma", "a,b,", []string{"a", "b"}}, + {"only commas", ",,,", []string{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseCommaSeparated(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestResolveZoneID_NumericArg(t *testing.T) { + id, err := resolveZoneID(nil, nil, "42") + require.NoError(t, err) + assert.Equal(t, "42", id) +} diff --git a/pkg/cli/dns_service.go b/pkg/cli/dns_service.go index ba484fd..acc5bdc 100644 --- a/pkg/cli/dns_service.go +++ b/pkg/cli/dns_service.go @@ -28,44 +28,56 @@ type DNSService interface { // dnsServiceCLI wraps the SDK DNS service with CLI-specific functionality. type dnsServiceCLI struct { - service ipfs.DNSService - cfgMgr config.Manager - output Output - authToken string - authenticated bool + ipfsServiceBase + service ipfs.DNSService + output Output + client *ipfs.Client // injected client (nil = create default) } -// NewDNSService creates a new DNSService instance with the provided configuration. -func NewDNSService(cfgMgr config.Manager, output Output, apiEndpoint string) DNSService { - authToken := cfgMgr.Config().AuthToken +// DNSServiceOption is a function that configures a dnsServiceCLI. +type DNSServiceOption func(*dnsServiceCLI) - client, err := ipfs.NewClient(apiEndpoint, authToken) - if err != nil { - output.PrintError(err) - return &dnsServiceCLI{ - service: nil, - cfgMgr: cfgMgr, - output: output, - authToken: authToken, - authenticated: false, - } +// WithDNSAuthToken sets an auth token override that takes precedence over config. +func WithDNSAuthToken(token string) DNSServiceOption { + return func(s *dnsServiceCLI) { + withAuthToken(token)(&s.ipfsServiceBase) } +} - return &dnsServiceCLI{ - service: client.DNS(), - cfgMgr: cfgMgr, - output: output, - authToken: authToken, - authenticated: authToken != "", +// WithDNSClient sets a pre-configured ipfs.Client, bypassing the default ipfs.NewClient() call. +func WithDNSClient(client *ipfs.Client) DNSServiceOption { + return func(s *dnsServiceCLI) { + s.client = client } } -// RequireAuthenticated checks if the user is authenticated. -func (s *dnsServiceCLI) RequireAuthenticated() error { - if !s.authenticated { - return ErrNotAuthenticated +// NewDNSService creates a new DNSService instance with the provided configuration. +func NewDNSService(cfgMgr config.Manager, output Output, apiEndpoint string, opts ...DNSServiceOption) DNSService { + authToken := cfgMgr.Config().AuthToken + + s := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: authToken, + }, + output: output, + } + for _, opt := range opts { + opt(s) + } + + if s.client != nil { + s.service = s.client.DNS() + } else { + client, err := ipfs.NewClient(apiEndpoint, authToken) + if err != nil { + output.PrintError(err) + s.service = nil + return s + } + s.service = client.DNS() } - return nil + return s } // CreateZone creates a new DNS zone. @@ -178,8 +190,23 @@ func (s *dnsServiceCLI) DeleteRecord(ctx context.Context, id string, name string return s.service.DeleteRecord(ctx, id, name, recordType) } -// defaultDNSServiceFactory creates a default DNS service instance. -func defaultDNSServiceFactory(cfgMgr config.Manager, output Output) DNSService { +type dnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, opts ...DNSServiceOption) DNSService + +var dnsServiceFactory dnsServiceFactoryFunc = defaultDNSServiceFactory + +func defaultDNSServiceFactory(cfgMgr config.Manager, output Output, opts ...DNSServiceOption) DNSService { apiEndpoint := cfgMgr.Config().GetIPFSEndpointSecure() - return NewDNSService(cfgMgr, output, apiEndpoint) + return NewDNSService(cfgMgr, output, apiEndpoint, opts...) +} + +func newAuthenticatedDNSService(cfgMgr config.Manager, output Output, authToken string) (DNSService, error) { + var svcOpts []DNSServiceOption + if authToken != "" { + svcOpts = append(svcOpts, WithDNSAuthToken(authToken)) + } + dnsService := dnsServiceFactory(cfgMgr, output, svcOpts...) + if err := dnsService.RequireAuthenticated(); err != nil { + return nil, err + } + return dnsService, nil } diff --git a/pkg/cli/dns_service_crud_test.go b/pkg/cli/dns_service_crud_test.go new file mode 100644 index 0000000..2b52e7e --- /dev/null +++ b/pkg/cli/dns_service_crud_test.go @@ -0,0 +1,169 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +func newUnauthDNSService(t *testing.T) *dnsServiceCLI { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + return &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: ""}, + } +} + +func newAuthedNilDNSService(t *testing.T) *dnsServiceCLI { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "token"}).Maybe() + return &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: "token"}, + service: nil, + } +} + +func TestDNSService_CreateZone_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.CreateZone(context.Background(), "example.com", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_CreateZone_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.CreateZone(context.Background(), "example.com", nil) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_ListZones_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.ListZones(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_ListZones_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.ListZones(context.Background()) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_GetZone_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.GetZone(context.Background(), "123") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_GetZone_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.GetZone(context.Background(), "123") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_DeleteZone_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + err := svc.DeleteZone(context.Background(), "123") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_DeleteZone_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + err := svc.DeleteZone(context.Background(), "123") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_ValidateZone_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.ValidateZone(context.Background(), "123") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_ValidateZone_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.ValidateZone(context.Background(), "123") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_CreateRecord_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.CreateRecord(context.Background(), "123", ipfs.RecordRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_CreateRecord_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.CreateRecord(context.Background(), "123", ipfs.RecordRequest{}) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_ListRecords_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.ListRecords(context.Background(), "123") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_ListRecords_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.ListRecords(context.Background(), "123") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_GetRecord_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.GetRecord(context.Background(), "123", "www", "A") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_GetRecord_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.GetRecord(context.Background(), "123", "www", "A") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_UpdateRecord_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + _, err := svc.UpdateRecord(context.Background(), "123", "www", "A", ipfs.RecordRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_UpdateRecord_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + _, err := svc.UpdateRecord(context.Background(), "123", "www", "A", ipfs.RecordRequest{}) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestDNSService_DeleteRecord_Unauthenticated(t *testing.T) { + svc := newUnauthDNSService(t) + err := svc.DeleteRecord(context.Background(), "123", "www", "A") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestDNSService_DeleteRecord_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilDNSService(t) + err := svc.DeleteRecord(context.Background(), "123", "www", "A") + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} diff --git a/pkg/cli/dns_service_test.go b/pkg/cli/dns_service_test.go new file mode 100644 index 0000000..26ae9d5 --- /dev/null +++ b/pkg/cli/dns_service_test.go @@ -0,0 +1,115 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/pinner-cli/pkg/config" +) + +func TestDNSService_RequireAuthenticated(t *testing.T) { + t.Run("authenticated with override token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "test-token", + }, + } + + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) + + t.Run("not authenticated when no token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "", + }, + } + + err := svc.RequireAuthenticated() + require.Error(t, err) + require.Contains(t, err.Error(), "not authenticated") + }) +} + +func TestDNSService_AuthTokenOverride(t *testing.T) { + t.Run("override token takes precedence over empty config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) + + t.Run("override token takes precedence over config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + require.Equal(t, "override-token", svc.getAuthToken()) + }) + + t.Run("falls back to config token when override is empty", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "", + }, + } + + require.Equal(t, "config-token", svc.getAuthToken()) + }) + + t.Run("WithDNSAuthToken functional option sets override", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &dnsServiceCLI{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + }, + } + WithDNSAuthToken("override-token")(svc) + + require.Equal(t, "override-token", svc.getAuthToken()) + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) +} diff --git a/pkg/cli/dns_test.go b/pkg/cli/dns_test.go new file mode 100644 index 0000000..50b3b8b --- /dev/null +++ b/pkg/cli/dns_test.go @@ -0,0 +1,875 @@ +package cli + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +type mockDNSServiceForCLI struct { + requireAuthenticatedErr error + listZonesFunc func(ctx context.Context) ([]ipfs.ZoneListResponse, error) + createZoneFunc func(ctx context.Context, domain string, nameservers []string) (*ipfs.ZoneResponse, error) + getZoneFunc func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) + deleteZoneFunc func(ctx context.Context, id string) error + validateZoneFunc func(ctx context.Context, id string) (*ipfs.ValidationResponse, error) + createRecordFunc func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) + listRecordsFunc func(ctx context.Context, id string) ([]ipfs.RecordResponse, error) + getRecordFunc func(ctx context.Context, id string, name string, recordType string) (*ipfs.RecordResponse, error) + updateRecordFunc func(ctx context.Context, id string, name string, recordType string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) + deleteRecordFunc func(ctx context.Context, id string, name string, recordType string) error +} + +func (m *mockDNSServiceForCLI) RequireAuthenticated() error { + return m.requireAuthenticatedErr +} + +func (m *mockDNSServiceForCLI) ListZones(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + if m.listZonesFunc != nil { + return m.listZonesFunc(ctx) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) CreateZone(ctx context.Context, domain string, nameservers []string) (*ipfs.ZoneResponse, error) { + if m.createZoneFunc != nil { + return m.createZoneFunc(ctx, domain, nameservers) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) GetZone(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + if m.getZoneFunc != nil { + return m.getZoneFunc(ctx, id) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) DeleteZone(ctx context.Context, id string) error { + if m.deleteZoneFunc != nil { + return m.deleteZoneFunc(ctx, id) + } + return nil +} + +func (m *mockDNSServiceForCLI) ValidateZone(ctx context.Context, id string) (*ipfs.ValidationResponse, error) { + if m.validateZoneFunc != nil { + return m.validateZoneFunc(ctx, id) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) CreateRecord(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + if m.createRecordFunc != nil { + return m.createRecordFunc(ctx, id, record) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) ListRecords(ctx context.Context, id string) ([]ipfs.RecordResponse, error) { + if m.listRecordsFunc != nil { + return m.listRecordsFunc(ctx, id) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) GetRecord(ctx context.Context, id string, name string, recordType string) (*ipfs.RecordResponse, error) { + if m.getRecordFunc != nil { + return m.getRecordFunc(ctx, id, name, recordType) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) UpdateRecord(ctx context.Context, id string, name string, recordType string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + if m.updateRecordFunc != nil { + return m.updateRecordFunc(ctx, id, name, recordType, record) + } + return nil, nil +} + +func (m *mockDNSServiceForCLI) DeleteRecord(ctx context.Context, id string, name string, recordType string) error { + if m.deleteRecordFunc != nil { + return m.deleteRecordFunc(ctx, id, name, recordType) + } + return nil +} + +func setupDNSHandlerTest(t *testing.T) (*mockDNSServiceForCLI, *configmocks.MockManager) { + t.Helper() + mockSvc := &mockDNSServiceForCLI{} + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + }).Maybe() + + origFactory := dnsServiceFactory + t.Cleanup(func() { dnsServiceFactory = origFactory }) + dnsServiceFactory = func(config.Manager, Output, ...DNSServiceOption) DNSService { + return mockSvc + } + + return mockSvc, cfgMgr +} + +// ===== dnsZonesList ===== + +func TestDnsZonesList_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{ + {Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, + {Id: 2, Domain: "test.org", Status: "pending", CreatedAt: now, UpdatedAt: now}, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesList_Empty(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesList_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to list zones") +} + +func TestDnsZonesList_Unauthenticated(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.requireAuthenticatedErr = ErrNotAuthenticated + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNotAuthenticated)) +} + +// ===== dnsZonesCreate ===== + +func TestDnsZonesCreate_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.createZoneFunc = func(ctx context.Context, domain string, nameservers []string) (*ipfs.ZoneResponse, error) { + assert.Equal(t, "example.com", domain) + assert.Nil(t, nameservers) + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withString(FlagDomain, "example.com") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesCreate_WithNameservers(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.createZoneFunc = func(ctx context.Context, domain string, nameservers []string) (*ipfs.ZoneResponse, error) { + assert.Equal(t, "example.com", domain) + assert.Equal(t, []string{"ns1.example.com", "ns2.example.com"}, nameservers) + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withString(FlagDomain, "example.com"). + withString(FlagNameservers, "ns1.example.com,ns2.example.com") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesCreate_EmptyDomain(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand().withString(FlagDomain, "") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain cannot be empty") +} + +func TestDnsZonesCreate_InvalidDomain(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand().withString(FlagDomain, "a..b") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid domain format") +} + +func TestDnsZonesCreate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createZoneFunc = func(ctx context.Context, domain string, nameservers []string) (*ipfs.ZoneResponse, error) { + return nil, errors.New("conflict") + } + + output := newTestOutput() + cmd := newMockCommand().withString(FlagDomain, "example.com") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create zone") +} + +// ===== dnsZonesGet ===== + +func TestDnsZonesGet_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}}, nil + } + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesGet_NumericID(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + assert.Equal(t, "42", id) + return &ipfs.ZoneResponse{Id: 42, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("42") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesGet_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsZonesGet_ZoneNotFound(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("nonexistent.com") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "zone not found") +} + +func TestDnsZonesGet_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "server error") +} + +// ===== dnsZonesDelete ===== + +func TestDnsZonesDelete_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.deleteZoneFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "1", id) + return nil + } + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + now := time.Now() + return []ipfs.ZoneListResponse{{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesDelete_NumericID(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.deleteZoneFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "42", id) + return nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("42") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesDelete_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsZonesDelete_ZoneNotFound(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("nonexistent.com") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "zone not found") +} + +func TestDnsZonesDelete_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.deleteZoneFunc = func(ctx context.Context, id string) error { + return errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to delete zone") +} + +// ===== dnsZonesValidate ===== + +func TestDnsZonesValidate_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}}, nil + } + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + mockSvc.validateZoneFunc = func(ctx context.Context, id string) (*ipfs.ValidationResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.ValidationResponse{Valid: true, Message: "Nameservers are properly delegated", CheckedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsZonesValidate_ValidationFailure(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + ns := []string{"ns1.pinner.xyz", "ns2.pinner.xyz"} + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}}, nil + } + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + mockSvc.validateZoneFunc = func(ctx context.Context, id string) (*ipfs.ValidationResponse, error) { + return &ipfs.ValidationResponse{Valid: false, Message: "Nameservers not delegated", Nameservers: &ns, CheckedAt: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) // validation failure is not an error, it's a result +} + +func TestDnsZonesValidate_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsZonesValidate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.getZoneFunc = func(ctx context.Context, id string) (*ipfs.ZoneResponse, error) { + return &ipfs.ZoneResponse{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, nil + } + mockSvc.validateZoneFunc = func(ctx context.Context, id string) (*ipfs.ValidationResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate zone") +} + +// ===== dnsRecordsList ===== + +func TestDnsRecordsList_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listRecordsFunc = func(ctx context.Context, id string) ([]ipfs.RecordResponse, error) { + assert.Equal(t, "1", id) + return []ipfs.RecordResponse{ + {ZoneId: 1, Name: "www", Type: "CNAME", Content: "example.com", Ttl: 3600, Disabled: false}, + {ZoneId: 1, Name: "@", Type: "A", Content: "1.2.3.4", Ttl: 3600, Disabled: false}, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsList_Empty(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listRecordsFunc = func(ctx context.Context, id string) ([]ipfs.RecordResponse, error) { + return []ipfs.RecordResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsList_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsRecordsList_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.listRecordsFunc = func(ctx context.Context, id string) ([]ipfs.RecordResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to list records") +} + +func TestDnsRecordsList_DomainArg(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + now := time.Now() + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{{Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}}, nil + } + mockSvc.listRecordsFunc = func(ctx context.Context, id string) ([]ipfs.RecordResponse, error) { + assert.Equal(t, "1", id) + return []ipfs.RecordResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +// ===== dnsRecordsCreate ===== + +func TestDnsRecordsCreate_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createRecordFunc = func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.Equal(t, "1", id) + assert.Equal(t, "www", record.Name) + assert.Equal(t, "CNAME", record.Type) + assert.Equal(t, "example.com", record.Content) + return &ipfs.RecordResponse{ZoneId: 1, Name: "www", Type: "CNAME", Content: "example.com", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "example.com") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsCreate_ARecord(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createRecordFunc = func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.Equal(t, "A", record.Type) + assert.Equal(t, "1.2.3.4", record.Content) + return &ipfs.RecordResponse{ZoneId: 1, Name: "@", Type: "A", Content: "1.2.3.4", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "@"). + withString(FlagType, "A"). + withString(FlagContent, "1.2.3.4") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsCreate_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "example.com") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsRecordsCreate_InvalidRecordType(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "INVALID"). + withString(FlagContent, "example.com") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported record type") +} + +func TestDnsRecordsCreate_InvalidARecordContent(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "A"). + withString(FlagContent, "not-an-ip") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid IPv4 address") +} + +func TestDnsRecordsCreate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createRecordFunc = func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "example.com") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create record") +} + +func TestDnsRecordsCreate_DefaultTTL(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createRecordFunc = func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.NotNil(t, record.Ttl) + assert.Equal(t, 3600, *record.Ttl) // default TTL + return &ipfs.RecordResponse{ZoneId: 1, Name: "www", Type: "CNAME", Content: "example.com", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "example.com") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsCreate_CustomTTL(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.createRecordFunc = func(ctx context.Context, id string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.NotNil(t, record.Ttl) + assert.Equal(t, 7200, *record.Ttl) + return &ipfs.RecordResponse{ZoneId: 1, Name: "www", Type: "CNAME", Content: "example.com", Ttl: 7200, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "example.com"). + withUint(FlagTTL, 7200) + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +// ===== dnsRecordsGet ===== + +func TestDnsRecordsGet_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.getRecordFunc = func(ctx context.Context, id string, name string, recordType string) (*ipfs.RecordResponse, error) { + assert.Equal(t, "1", id) + assert.Equal(t, "www", name) + assert.Equal(t, "CNAME", recordType) + return &ipfs.RecordResponse{ZoneId: 1, Name: "www", Type: "CNAME", Content: "example.com", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsGet_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withString(FlagName, "www"). + withString(FlagType, "CNAME") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsRecordsGet_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.getRecordFunc = func(ctx context.Context, id string, name string, recordType string) (*ipfs.RecordResponse, error) { + return nil, errors.New("record not found") + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "nonexistent"). + withString(FlagType, "A") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get record") +} + +// ===== dnsRecordsUpdate ===== + +func TestDnsRecordsUpdate_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.updateRecordFunc = func(ctx context.Context, id string, name string, recordType string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.Equal(t, "1", id) + assert.Equal(t, "www", name) + assert.Equal(t, "CNAME", recordType) + assert.Equal(t, "new.example.com", record.Content) + return &ipfs.RecordResponse{ZoneId: 1, Name: "www", Type: "CNAME", Content: "new.example.com", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "new.example.com") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsUpdate_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "new.example.com") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsRecordsUpdate_InvalidRecordType(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "BOGUS"). + withString(FlagContent, "example.com") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported record type") +} + +func TestDnsRecordsUpdate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.updateRecordFunc = func(ctx context.Context, id string, name string, recordType string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME"). + withString(FlagContent, "new.example.com") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to update record") +} + +func TestDnsRecordsUpdate_ARecordWithValidIP(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.updateRecordFunc = func(ctx context.Context, id string, name string, recordType string, record ipfs.RecordRequest) (*ipfs.RecordResponse, error) { + assert.Equal(t, "5.6.7.8", record.Content) + return &ipfs.RecordResponse{ZoneId: 1, Name: "@", Type: "A", Content: "5.6.7.8", Ttl: 3600, Disabled: false}, nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "@"). + withString(FlagType, "A"). + withString(FlagContent, "5.6.7.8") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +// ===== dnsRecordsDelete ===== + +func TestDnsRecordsDelete_Success(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.deleteRecordFunc = func(ctx context.Context, id string, name string, recordType string) error { + assert.Equal(t, "1", id) + assert.Equal(t, "www", name) + assert.Equal(t, "CNAME", recordType) + return nil + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestDnsRecordsDelete_MissingArg(t *testing.T) { + _, cfgMgr := setupDNSHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand(). + withString(FlagName, "www"). + withString(FlagType, "CNAME") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain or zone ID is required") +} + +func TestDnsRecordsDelete_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupDNSHandlerTest(t) + mockSvc.deleteRecordFunc = func(ctx context.Context, id string, name string, recordType string) error { + return errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand(). + withArgs("1"). + withString(FlagName, "www"). + withString(FlagType, "CNAME") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to delete record") +} + +// ===== resolveZoneID (handler-level integration) ===== + +func TestDnsResolveZoneID_DomainArg(t *testing.T) { + mockSvc := &mockDNSServiceForCLI{} + now := time.Now() + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{ + {Id: 1, Domain: "example.com", Status: "active", CreatedAt: now, UpdatedAt: now}, + {Id: 2, Domain: "other.com", Status: "active", CreatedAt: now, UpdatedAt: now}, + }, nil + } + id, err := resolveZoneID(context.Background(), mockSvc, "example.com") + require.NoError(t, err) + assert.Equal(t, "1", id) +} + +func TestDnsResolveZoneID_NumericArg(t *testing.T) { + mockSvc := &mockDNSServiceForCLI{} + id, err := resolveZoneID(context.Background(), mockSvc, "42") + require.NoError(t, err) + assert.Equal(t, "42", id) +} + +func TestDnsResolveZoneID_NotFound(t *testing.T) { + mockSvc := &mockDNSServiceForCLI{} + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return []ipfs.ZoneListResponse{}, nil + } + _, err := resolveZoneID(context.Background(), mockSvc, "nonexistent.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "zone not found") +} + +func TestDnsResolveZoneID_ListZonesError(t *testing.T) { + mockSvc := &mockDNSServiceForCLI{} + mockSvc.listZonesFunc = func(ctx context.Context) ([]ipfs.ZoneListResponse, error) { + return nil, errors.New("server error") + } + _, err := resolveZoneID(context.Background(), mockSvc, "example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to look up zone by domain") +} diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 0f802e4..86744d8 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -79,12 +79,12 @@ Use this command when: Metadata: WithTutorial(6, "Show diagnostic info", "pinner doctor"), Action: func(ctx context.Context, cmd *cli.Command) error { output := NewOutputFormatter(cmd.Bool("json"), false, false, false) - return doctor(ctx, cmd, output, defaultConfigManagerFactory) + return doctor(ctx, newCLICommandWrapper(cmd), output, defaultConfigManagerFactory) }, } } -func doctor(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory) error { +func doctor(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return fmt.Errorf("failed to create config manager: %w", err) diff --git a/pkg/cli/doctor_test.go b/pkg/cli/doctor_test.go index 9479441..264e37c 100644 --- a/pkg/cli/doctor_test.go +++ b/pkg/cli/doctor_test.go @@ -8,6 +8,7 @@ import ( "runtime" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" @@ -103,7 +104,7 @@ func TestDoctor(t *testing.T) { } } - err := doctor(context.Background(), cmd, output, cfgMgrFactory) + err := doctor(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory) if tt.wantErr { require.Error(t, err) @@ -273,7 +274,6 @@ func TestBashCompletionDetector(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Skip shell tests that don't match the current OS if tt.skipIfNotUnix && runtime.GOOS == "windows" { t.Skip("Skipping Unix-specific test on Windows") } @@ -291,6 +291,13 @@ func TestBashCompletionDetector(t *testing.T) { } } +func TestBashCompletionDetector_Accessors(t *testing.T) { + d := &BashCompletionDetector{homeDir: "/home/user"} + require.Equal(t, "bash", d.Name()) + require.Equal(t, "source <(pinner completion bash)", d.InstallCommand()) + require.Equal(t, "/home/user/.bashrc", d.ConfigPath()) +} + func TestZshCompletionDetector(t *testing.T) { tests := []struct { name string @@ -320,7 +327,6 @@ func TestZshCompletionDetector(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Skip shell tests that don't match the current OS if tt.skipIfNotUnix && runtime.GOOS == "windows" { t.Skip("Skipping Unix-specific test on Windows") } @@ -338,6 +344,13 @@ func TestZshCompletionDetector(t *testing.T) { } } +func TestZshCompletionDetector_Accessors(t *testing.T) { + d := &ZshCompletionDetector{homeDir: "/home/user"} + require.Equal(t, "zsh", d.Name()) + require.Equal(t, "source <(pinner completion zsh)", d.InstallCommand()) + require.Equal(t, "/home/user/.zshrc", d.ConfigPath()) +} + func TestFishCompletionDetector(t *testing.T) { tests := []struct { name string @@ -366,7 +379,6 @@ func TestFishCompletionDetector(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Skip shell tests that don't match the current OS if tt.skipIfNotUnix && runtime.GOOS == "windows" { t.Skip("Skipping Unix-specific test on Windows") } @@ -390,6 +402,20 @@ func TestFishCompletionDetector(t *testing.T) { } } +func TestFishCompletionDetector_Accessors(t *testing.T) { + d := &FishCompletionDetector{homeDir: "/home/user"} + require.Equal(t, "fish", d.Name()) + require.Contains(t, d.InstallCommand(), "pinner completion fish") + require.Contains(t, d.ConfigPath(), "pinner.fish") +} + +func TestPowerShellCompletionDetector_Accessors(t *testing.T) { + d := &PowerShellCompletionDetector{} + require.Equal(t, "pwsh", d.Name()) + require.Contains(t, d.InstallCommand(), "pinner completion pwsh") + require.Equal(t, "$PROFILE", d.ConfigPath()) +} + func TestPowerShellCompletionDetector(t *testing.T) { tests := []struct { name string @@ -547,3 +573,96 @@ func TestCheckCompletion(t *testing.T) { require.Contains(t, info.Configured, "zsh") }) } + +func TestPowerShellIsConfigured(t *testing.T) { + t.Run("returns false on non-windows", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping on windows") + } + d := &PowerShellCompletionDetector{} + configured, err := d.IsConfigured() + require.NoError(t, err) + assert.False(t, configured) + }) +} + +func TestDoctor_MockCommand_Success(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + Secure: true, + BaseEndpoint: "pinner.xyz", + AuthToken: "test-token", + MaxRetries: 3, + MemoryLimit: 256, + }) + output := newTestOutput() + + cmd := newMockCommand() + err := doctor(context.Background(), cmd, output, func() (config.Manager, error) { + return cfgMgr, nil + }) + require.NoError(t, err) +} + +func TestDoctor_MockCommand_JSONOutput(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + Secure: false, + BaseEndpoint: "api.example.com", + AuthToken: "", + MaxRetries: 5, + MemoryLimit: 512, + }) + output := NewOutputFormatter(true, false, false, false) + + cmd := newMockCommand().withBool("json", true) + err := doctor(context.Background(), cmd, output, func() (config.Manager, error) { + return cfgMgr, nil + }) + require.NoError(t, err) +} + +func TestDoctor_MockCommand_ConfigError(t *testing.T) { + output := newTestOutput() + + cmd := newMockCommand() + err := doctor(context.Background(), cmd, output, failingConfigMgrFactory()) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create config manager") +} + +func TestDoctor_MockCommand_DefaultEndpoint(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + Secure: true, + BaseEndpoint: "", + AuthToken: "", + MaxRetries: 3, + MemoryLimit: 0, + }) + output := newTestOutput() + + cmd := newMockCommand() + err := doctor(context.Background(), cmd, output, func() (config.Manager, error) { + return cfgMgr, nil + }) + require.NoError(t, err) +} + +func TestDoctor_MockCommand_Authenticated(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + Secure: true, + BaseEndpoint: "pinner.xyz", + AuthToken: "my-jwt-token", + MaxRetries: 3, + MemoryLimit: 100, + }) + output := newTestOutput() + + cmd := newMockCommand() + err := doctor(context.Background(), cmd, output, func() (config.Manager, error) { + return cfgMgr, nil + }) + require.NoError(t, err) +} diff --git a/pkg/cli/download.go b/pkg/cli/download.go index daf21ec..2f5887d 100644 --- a/pkg/cli/download.go +++ b/pkg/cli/download.go @@ -8,6 +8,7 @@ import ( "time" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) const ( @@ -55,7 +56,12 @@ The output includes: Metadata: WithTutorial(7, "Download pinned content", fmt.Sprintf("pinner download %s", abbreviateCID(TutorialCID))), Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return handleDownload(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultDownloadServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return handleDownload(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) }, } } @@ -81,7 +87,12 @@ Use --verbose or redirect stderr for progress info.`, Flags: []cli.Flag{}, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return handleCat(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultDownloadServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return handleCat(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) }, } } @@ -108,34 +119,24 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return handleLs(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultDownloadServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return handleLs(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) }, } } -type downloadCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool - Args() cli.Args -} - -func handleDownload(ctx context.Context, cmd downloadCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, downloadServiceFactory DownloadServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func handleDownload(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption svcOpts = append(svcOpts, WithDownloadAuthService(authService)) - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) - } + if authToken != "" { + svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) } cidStr := cmd.Args().First() @@ -207,22 +208,14 @@ func handleDownload(ctx context.Context, cmd downloadCommandGetter, output Outpu return nil } -func handleCat(ctx context.Context, cmd downloadCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, downloadServiceFactory DownloadServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func handleCat(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption svcOpts = append(svcOpts, WithDownloadAuthService(authService)) - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) - } + if authToken != "" { + svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) } cidStr := cmd.Args().First() @@ -246,22 +239,14 @@ func handleCat(ctx context.Context, cmd downloadCommandGetter, output Output, cf return err } -func handleLs(ctx context.Context, cmd downloadCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, downloadServiceFactory DownloadServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func handleLs(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption svcOpts = append(svcOpts, WithDownloadAuthService(authService)) - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) - } + if authToken != "" { + svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) } cidStr := cmd.Args().First() diff --git a/pkg/cli/download_client_test.go b/pkg/cli/download_client_test.go new file mode 100644 index 0000000..bbcfdc9 --- /dev/null +++ b/pkg/cli/download_client_test.go @@ -0,0 +1,444 @@ +package cli + +import ( + "context" + "io" + "os" + "path/filepath" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + portalsdk "go.lumeweb.com/portal-sdk" + portalsdkmocks "go.lumeweb.com/portal-sdk/mocks" +) + +func TestDownloadService_RequireAuthenticated(t *testing.T) { + t.Run("not authenticated when no token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + err := svc.RequireAuthenticated() + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("authenticated with override token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr, authToken: "test-token"} + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) + + t.Run("authenticated with config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "config-token"}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) +} + +func TestDownloadService_getAuthToken(t *testing.T) { + t.Run("override takes precedence", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "config-token"}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr, authToken: "override-token"} + assert.Equal(t, "override-token", svc.getAuthToken()) + }) + + t.Run("falls back to config", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "config-token"}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + assert.Equal(t, "config-token", svc.getAuthToken()) + }) +} + +func TestParseIPFSPath(t *testing.T) { + t.Run("CID only", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.NoError(t, err) + assert.Equal(t, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", p.cid.String()) + assert.Equal(t, "", p.path) + }) + + t.Run("CID with path", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG/subdir/file.txt") + require.NoError(t, err) + assert.Equal(t, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", p.cid.String()) + assert.Equal(t, "subdir/file.txt", p.path) + }) + + t.Run("CID with trailing slash path", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG/subdir/") + require.NoError(t, err) + assert.Equal(t, "subdir", p.path) + }) + + t.Run("invalid CID", func(t *testing.T) { + _, err := parseIPFSPath("not-a-cid") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidCID) + }) +} + +func TestIsNotDirectoryError(t *testing.T) { + assert.False(t, isNotDirectoryError(nil)) + assert.True(t, isNotDirectoryError(newErrorWithString("path is not a directory"))) + assert.True(t, isNotDirectoryError(newErrorWithString("CID is not a directory"))) + assert.False(t, isNotDirectoryError(newErrorWithString("some other error"))) +} + +type stringError struct { + msg string +} + +func (e *stringError) Error() string { return e.msg } + +func newErrorWithString(msg string) error { return &stringError{msg: msg} } + +func TestWrapDownloadError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + assert.Nil(t, svc.wrapDownloadError(nil)) + }) +} + +func TestWithDownloadAuthToken(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + WithDownloadAuthToken("test-token")(svc) + assert.Equal(t, "test-token", svc.authToken) +} + +func TestWithDownloadAuthService(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{}).Maybe() + svc := &DownloadServiceDefault{configMgr: cfgMgr} + mockAuth := NewMockAuthService(t) + WithDownloadAuthService(mockAuth)(svc) + assert.NotNil(t, svc.authService) +} + +func makeJWTWithAudience(audience string) string { + claims := jwt.RegisteredClaims{ + Audience: []string{audience}, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, _ := token.SignedString([]byte("test-secret")) + return signed +} + +func newDownloadSvc(t *testing.T, authToken string, opts ...DownloadServiceOption) *DownloadServiceDefault { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: authToken, + BaseEndpoint: "pinner.xyz", + Secure: true, + }).Maybe() + svc := &DownloadServiceDefault{ + configMgr: cfgMgr, + output: newTestOutput(), + ipfsEndpoint: "https://ipfs.pinner.xyz", + authToken: authToken, + } + for _, opt := range opts { + opt(svc) + } + return svc +} + +func TestDownloadService_resolveAuthToken(t *testing.T) { + t.Run("returns token from config when no auth service", func(t *testing.T) { + svc := newDownloadSvc(t, "config-token") + token, err := svc.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "config-token", token) + }) + + t.Run("returns override token when no auth service", func(t *testing.T) { + svc := newDownloadSvc(t, "") + svc.authToken = "override-token" + token, err := svc.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "override-token", token) + }) + + t.Run("exchanges API key JWT for login token", func(t *testing.T) { + apiKeyJWT := makeJWTWithAudience("api") + mockAccountAPI := portalsdkmocks.NewMockAccountAPI(t) + mockAccountAPI.EXPECT().LoginWithAPIKey(context.Background(), apiKeyJWT).Return("login-jwt", nil) + + svc := newDownloadSvc(t, apiKeyJWT, WithDownloadAuthService(NewMockAuthService(t))) + svc.accountClient = mockAccountAPI + + token, err := svc.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "login-jwt", token) + }) + + t.Run("returns raw token when JWT decode fails", func(t *testing.T) { + svc := newDownloadSvc(t, "not-a-jwt", WithDownloadAuthService(NewMockAuthService(t))) + + token, err := svc.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "not-a-jwt", token) + }) + + t.Run("returns login token as-is when purpose is login", func(t *testing.T) { + loginJWT := makeJWTWithAudience("login") + svc := newDownloadSvc(t, loginJWT, WithDownloadAuthService(NewMockAuthService(t))) + + token, err := svc.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, loginJWT, token) + }) + + t.Run("returns error when API key exchange fails", func(t *testing.T) { + apiKeyJWT := makeJWTWithAudience("api") + mockAccountAPI := portalsdkmocks.NewMockAccountAPI(t) + mockAccountAPI.EXPECT().LoginWithAPIKey(context.Background(), apiKeyJWT).Return("", portalsdk.ErrUnauthorized) + + svc := newDownloadSvc(t, apiKeyJWT, WithDownloadAuthService(NewMockAuthService(t))) + svc.accountClient = mockAccountAPI + + _, err := svc.resolveAuthToken(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to exchange API key for download") + }) +} + +func TestDownloadService_Cat(t *testing.T) { + t.Run("error when not authenticated", func(t *testing.T) { + svc := newDownloadSvc(t, "") + _, err := svc.Cat(context.Background(), "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("error on invalid CID", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + _, err := svc.Cat(context.Background(), "not-a-cid") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidCID) + }) + + t.Run("service error from unreachable endpoint", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + svc.ipfsEndpoint = "http://127.0.0.1:1" + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := svc.Cat(ctx, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + }) +} + +func TestDownloadService_Download(t *testing.T) { + t.Run("error when not authenticated", func(t *testing.T) { + svc := newDownloadSvc(t, "") + _, err := svc.Download(context.Background(), "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", "", false) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("error on invalid CID", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + _, err := svc.Download(context.Background(), "not-a-cid", "", false) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidCID) + }) + + t.Run("error when file already exists without force", func(t *testing.T) { + tmpDir := t.TempDir() + existingFile := filepath.Join(tmpDir, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + err := os.WriteFile(existingFile, []byte("existing"), 0644) + require.NoError(t, err) + + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "test-token", + BaseEndpoint: "pinner.xyz", + Secure: true, + }).Maybe() + svc := &DownloadServiceDefault{ + configMgr: cfgMgr, + output: newTestOutput(), + ipfsEndpoint: "https://ipfs.pinner.xyz", + authToken: "test-token", + } + + _, err = svc.Download(context.Background(), "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", existingFile, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "file already exists") + }) + + t.Run("service error from unreachable endpoint", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + svc.ipfsEndpoint = "http://127.0.0.1:1" + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := svc.Download(ctx, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", filepath.Join(t.TempDir(), "out"), true) + require.Error(t, err) + }) +} + +func TestDownloadService_FileSize(t *testing.T) { + t.Run("error when not authenticated", func(t *testing.T) { + svc := newDownloadSvc(t, "") + _, err := svc.FileSize(context.Background(), "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("error on invalid CID", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + _, err := svc.FileSize(context.Background(), "not-a-cid") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidCID) + }) + + t.Run("service error from unreachable endpoint", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + svc.ipfsEndpoint = "http://127.0.0.1:1" + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := svc.FileSize(ctx, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + }) +} + +func TestDownloadService_ListDirectory(t *testing.T) { + t.Run("error when not authenticated", func(t *testing.T) { + svc := newDownloadSvc(t, "") + _, err := svc.ListDirectory(context.Background(), "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("error on invalid CID", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + _, err := svc.ListDirectory(context.Background(), "not-a-cid") + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidCID) + }) + + t.Run("service error from unreachable endpoint", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + svc.ipfsEndpoint = "http://127.0.0.1:1" + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := svc.ListDirectory(ctx, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.Error(t, err) + }) +} + +func TestDownloadService_listFileEntry(t *testing.T) { + t.Run("CID as name when no path", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.NoError(t, err) + + name := p.cid.String() + if p.path != "" { + name = filepath.Base(p.path) + } + assert.Equal(t, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", name) + }) + + t.Run("path basename as name when path present", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG/subdir/file.txt") + require.NoError(t, err) + + name := p.cid.String() + if p.path != "" { + name = filepath.Base(p.path) + } + assert.Equal(t, "file.txt", name) + }) +} + +func TestDownloadService_newSDKDownloadService(t *testing.T) { + t.Run("error when not authenticated", func(t *testing.T) { + svc := newDownloadSvc(t, "") + _, err := svc.newSDKDownloadService(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("error when resolveAuthToken fails", func(t *testing.T) { + apiKeyJWT := makeJWTWithAudience("api") + mockAccountAPI := portalsdkmocks.NewMockAccountAPI(t) + mockAccountAPI.EXPECT().LoginWithAPIKey(context.Background(), apiKeyJWT).Return("", portalsdk.ErrUnauthorized) + + svc := newDownloadSvc(t, apiKeyJWT, WithDownloadAuthService(NewMockAuthService(t))) + svc.accountClient = mockAccountAPI + + _, err := svc.newSDKDownloadService(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to exchange API key for download") + }) + + t.Run("creates SDK service with valid endpoint", func(t *testing.T) { + svc := newDownloadSvc(t, "test-token") + dlService, err := svc.newSDKDownloadService(context.Background()) + require.NoError(t, err) + require.NotNil(t, dlService) + assert.Equal(t, "test-token", dlService.AuthToken()) + }) +} + +func TestDownloadService_Download_defaultOutputPath(t *testing.T) { + t.Run("CID string as default filename when no path", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG") + require.NoError(t, err) + + outputPath := "" + if outputPath == "" { + if p.path != "" { + outputPath = filepath.Base(p.path) + } else { + outputPath = p.cid.String() + } + } + assert.Equal(t, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", outputPath) + }) + + t.Run("path basename as default filename when path present", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG/subdir/myfile.txt") + require.NoError(t, err) + + outputPath := "" + if outputPath == "" { + if p.path != "" { + outputPath = filepath.Base(p.path) + } else { + outputPath = p.cid.String() + } + } + assert.Equal(t, "myfile.txt", outputPath) + }) +} + +type mockReadCloser struct { + io.Reader +} + +func (m *mockReadCloser) Close() error { return nil } + +func TestDownloadService_Cat_withPath(t *testing.T) { + t.Run("CID with subpath triggers GetFile branch", func(t *testing.T) { + p, err := parseIPFSPath("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG/subdir/file.txt") + require.NoError(t, err) + assert.Equal(t, "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG", p.cid.String()) + assert.Equal(t, "subdir/file.txt", p.path) + assert.NotEmpty(t, p.path) + }) +} diff --git a/pkg/cli/download_test.go b/pkg/cli/download_test.go index 473c2f8..f5bbefa 100644 --- a/pkg/cli/download_test.go +++ b/pkg/cli/download_test.go @@ -10,56 +10,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" ) -type mockDownloadCommand struct { - cid string - output string - force bool - dryRun bool - limit int - args []string -} - -func (m *mockDownloadCommand) Args() cli.Args { - if m.args == nil { - m.args = []string{m.cid} - } - return &mockArgs{m.args} -} - -func (m *mockDownloadCommand) String(name string) string { - switch name { - case FlagOutput: - return m.output - default: - return "" - } -} - -func (m *mockDownloadCommand) Int(name string) int { - switch name { - case FlagLimit: - return m.limit - default: - return 0 - } -} - -func (m *mockDownloadCommand) Bool(name string) bool { - switch name { - case FlagForce: - return m.force - case FlagDryRun: - return m.dryRun - default: - return false - } -} - func testDownloadConfigMgr(t *testing.T) *configmocks.MockManager { t.Helper() cfgMgr := configmocks.NewMockManager(t) @@ -108,24 +62,20 @@ func TestHandleDownload_DryRun(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() - cmd := &mockDownloadCommand{ - cid: tt.cid, - output: tt.output, - force: tt.force, - dryRun: true, - } + cmd := newMockCommand(). + withArgs(tt.cid). + withString(FlagOutput, tt.output). + withBool(FlagForce, tt.force). + withBool(FlagDryRun, true) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return NewMockDownloadService(t) } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) if tt.wantErr { require.Error(t, err) @@ -138,26 +88,23 @@ func TestHandleDownload_DryRun(t *testing.T) { func TestHandleDownload_RequiresCID(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() - cmd := &mockDownloadCommand{cid: ""} + cmd := newMockCommand() - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return NewMockDownloadService(t) } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } func TestHandleDownload_Success(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().Download(context.Background(), "QmXxx", "", false).Return(&DownloadResult{ @@ -167,65 +114,56 @@ func TestHandleDownload_Success(t *testing.T) { Duration: 100 * time.Millisecond, }, nil) - cmd := &mockDownloadCommand{cid: "QmXxx"} + cmd := newMockCommand().withArgs("QmXxx") - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } func TestHandleDownload_NotAuthenticated(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(errors.New("not authenticated")) - cmd := &mockDownloadCommand{cid: "QmXxx"} + cmd := newMockCommand().withArgs("QmXxx") - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) } func TestHandleDownload_FileExists_NoForce(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().Download(context.Background(), "QmXxx", "", false).Return(nil, errors.New("file already exists")) - cmd := &mockDownloadCommand{cid: "QmXxx"} + cmd := newMockCommand().withArgs("QmXxx") - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) assert.Contains(t, err.Error(), "file already exists") } func TestHandleDownload_WithForce(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().Download(context.Background(), "QmXxx", "existing.txt", true).Return(&DownloadResult{ @@ -235,17 +173,17 @@ func TestHandleDownload_WithForce(t *testing.T) { Duration: 50 * time.Millisecond, }, nil) - cmd := &mockDownloadCommand{cid: "QmXxx", output: "existing.txt", force: true} + cmd := newMockCommand(). + withArgs("QmXxx"). + withString(FlagOutput, "existing.txt"). + withBool(FlagForce, true) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } @@ -256,62 +194,53 @@ func TestHandleCat_Success(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().Cat(context.Background(), "QmXxx").Return(io.NopCloser(strings.NewReader("hello world")), nil) - cmd := &mockDownloadCommand{cid: "QmXxx"} + cmd := newMockCommand().withArgs("QmXxx") - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleCat(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } func TestHandleCat_RequiresCID(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() - cmd := &mockDownloadCommand{cid: ""} + cmd := newMockCommand() - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return NewMockDownloadService(t) } - err := handleCat(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } func TestHandleCat_NotAuthenticated(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(errors.New("not authenticated")) - cmd := &mockDownloadCommand{cid: "QmXxx"} + cmd := newMockCommand().withArgs("QmXxx") - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleCat(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) } func TestHandleLs_Success(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().ListDirectory(context.Background(), "QmXxx").Return([]DirEntry{ @@ -319,63 +248,54 @@ func TestHandleLs_Success(t *testing.T) { {Name: "subdir", Size: -1, Type: "directory"}, }, nil) - cmd := &mockDownloadCommand{cid: "QmXxx", limit: 10} + cmd := newMockCommand().withArgs("QmXxx").withInt(FlagLimit, 10) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleLs(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } func TestHandleLs_EmptyDirectory(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().ListDirectory(context.Background(), "QmXxx").Return([]DirEntry{}, nil) - cmd := &mockDownloadCommand{cid: "QmXxx", limit: 10} + cmd := newMockCommand().withArgs("QmXxx").withInt(FlagLimit, 10) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleLs(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } func TestHandleLs_RequiresCID(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() - cmd := &mockDownloadCommand{cid: "", limit: 10} + cmd := newMockCommand().withInt(FlagLimit, 10) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return NewMockDownloadService(t) } - err := handleLs(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } func TestHandleLs_WithLimit(t *testing.T) { cfgMgr := testDownloadConfigMgr(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewMockDownloadService(t) service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().ListDirectory(context.Background(), "QmXxx").Return([]DirEntry{ @@ -384,17 +304,14 @@ func TestHandleLs_WithLimit(t *testing.T) { {Name: "file3.txt", Size: 300, Type: "file"}, }, nil) - cmd := &mockDownloadCommand{cid: "QmXxx", limit: 2} + cmd := newMockCommand().withArgs("QmXxx").withInt(FlagLimit, 2) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } downloadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return service } - err := handleLs(context.Background(), cmd, output, cfgMgrFactory, downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) require.NoError(t, err) } @@ -452,7 +369,7 @@ func TestDefaultDownloadServiceFactory(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := defaultDownloadServiceFactory(cfgMgr, output) diff --git a/pkg/cli/error_formatter_test.go b/pkg/cli/error_formatter_test.go index 686f1aa..eda8167 100644 --- a/pkg/cli/error_formatter_test.go +++ b/pkg/cli/error_formatter_test.go @@ -1,9 +1,16 @@ package cli import ( + "context" + "errors" + "fmt" + "net" + "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + portalsdk "go.lumeweb.com/portal-sdk" ) func TestExtractErrorMessage(t *testing.T) { @@ -67,3 +74,107 @@ func TestHTTPError(t *testing.T) { assert.Equal(t, "HTTP 429: quota exceeded", err.Error()) }) } + +func TestFormatError(t *testing.T) { + tests := []struct { + name string + err error + verbose bool + contains string + }{ + {"nil error", nil, false, ""}, + {"known error", ErrNotAuthenticated, false, "Not authenticated"}, + {"known error verbose", ErrNotAuthenticated, true, "Details:"}, + {"wrapped known error", fmt.Errorf("wrap: %w", ErrPinningFailed), false, "Pinning operation failed"}, + {"unknown error", errors.New("something weird"), false, "something weird"}, + {"unknown error verbose", errors.New("something weird"), true, "Details:"}, + {"context canceled", context.Canceled, false, "cancelled"}, + {"context deadline", context.DeadlineExceeded, false, "timed out"}, + {"sdk unauthorized", portalsdk.ErrUnauthorized, false, "re-authenticate"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FormatError(tt.err, tt.verbose) + if tt.err == nil { + assert.Empty(t, result) + } else { + assert.Contains(t, result, tt.contains) + } + }) + } +} + +type testNetErr struct { + timeout bool +} + +func (e *testNetErr) Error() string { return "network error" } +func (e *testNetErr) Timeout() bool { return e.timeout } +func (e *testNetErr) Temporary() bool { return e.timeout } + +func TestIsNetworkError(t *testing.T) { + t.Run("timeout net.Error", func(t *testing.T) { + var netErr net.Error = &testNetErr{timeout: true} + assert.True(t, isNetworkError(netErr)) + }) + + t.Run("non-timeout net.Error", func(t *testing.T) { + var netErr net.Error = &testNetErr{timeout: false} + assert.False(t, isNetworkError(netErr)) + }) + + t.Run("non-network error", func(t *testing.T) { + assert.False(t, isNetworkError(errors.New("not network"))) + }) +} + +func TestFormatError_NetworkError(t *testing.T) { + var netErr net.Error = &testNetErr{timeout: true} + result := FormatError(netErr, false) + assert.Contains(t, result, "Connection error") +} + +func TestWrapAuthError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + assert.NoError(t, WrapAuthError("test", nil)) + }) + + t.Run("non-auth error", func(t *testing.T) { + err := WrapAuthError("upload", errors.New("disk full")) + require.Error(t, err) + assert.Contains(t, err.Error(), "upload failed") + assert.NotContains(t, err.Error(), "authentication") + }) + + t.Run("sdk unauthorized error", func(t *testing.T) { + err := WrapAuthError("pin", portalsdk.ErrUnauthorized) + require.Error(t, err) + assert.Contains(t, err.Error(), "authentication expired") + assert.True(t, errors.Is(err, ErrNotAuthenticated)) + }) +} + +func TestWrapFileError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + assert.NoError(t, WrapFileError("read", "file.txt", nil)) + }) + + t.Run("file not found", func(t *testing.T) { + err := WrapFileError("read", "missing.txt", os.ErrNotExist) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrFileNotFound)) + }) + + t.Run("permission denied", func(t *testing.T) { + err := WrapFileError("write", "protected.txt", os.ErrPermission) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrPermissionDenied)) + }) + + t.Run("other error", func(t *testing.T) { + err := WrapFileError("read", "file.txt", errors.New("io error")) + require.Error(t, err) + assert.Contains(t, err.Error(), "read failed") + }) +} diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go index d6138f7..59c756e 100644 --- a/pkg/cli/flags_test.go +++ b/pkg/cli/flags_test.go @@ -603,3 +603,96 @@ func TestCommandTree(t *testing.T) { assert.True(t, found, "otp subcommand %s should exist", subcommandName) } } + +func TestGetAuthToken(t *testing.T) { + t.Run("returns flag value when set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.StringFlag{Name: FlagAuthToken}} + _ = cmd.Set(FlagAuthToken, "flag-token") + result := GetAuthToken(cmd, cfgMgr) + assert.Equal(t, "flag-token", result) + }) + + t.Run("returns config value when flag not set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.StringFlag{Name: FlagAuthToken}} + result := GetAuthToken(cmd, cfgMgr) + assert.Equal(t, "test-token", result) + }) + + t.Run("returns config value when cmd is nil", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + result := GetAuthToken(nil, cfgMgr) + assert.Equal(t, "test-token", result) + }) +} + +func TestGetSecureSetting(t *testing.T) { + t.Run("returns config value when cmd is nil", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + result := GetSecureSetting(nil, cfgMgr) + assert.True(t, result) + }) + + t.Run("returns flag value when explicitly set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.BoolFlag{Name: FlagSecure}} + _ = cmd.Set(FlagSecure, "false") + result := GetSecureSetting(cmd, cfgMgr) + assert.False(t, result) + }) + + t.Run("returns config default when flag not set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.BoolFlag{Name: FlagSecure}} + result := GetSecureSetting(cmd, cfgMgr) + assert.True(t, result) + }) +} + +func TestApplySecureConfig(t *testing.T) { + t.Run("returns config value when cmd is nil", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + result := ApplySecureConfig(nil, cfgMgr) + assert.True(t, result) + }) + + t.Run("applies and returns flag value when set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cfgMgr.EXPECT().SetSecure(false).Return(nil) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.BoolFlag{Name: FlagSecure}} + _ = cmd.Set(FlagSecure, "false") + result := ApplySecureConfig(cmd, cfgMgr) + assert.False(t, result) + }) + + t.Run("returns config default when flag not set", func(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + cmd := &cli.Command{} + cmd.Flags = []cli.Flag{&cli.BoolFlag{Name: FlagSecure}} + result := ApplySecureConfig(cmd, cfgMgr) + assert.True(t, result) + }) +} + +func TestTargetHashFlag(t *testing.T) { + flag := TargetHashFlag() + assert.Equal(t, FlagTargetHash, flag.Name) +} + +func TestRequiredTargetHashFlag(t *testing.T) { + flag := RequiredTargetHashFlag() + assert.Equal(t, FlagTargetHash, flag.Name) + assert.True(t, flag.Required) +} + +func TestRequiredContentFlag(t *testing.T) { + flag := RequiredContentFlag() + assert.Equal(t, FlagContent, flag.Name) + assert.True(t, flag.Required) +} diff --git a/pkg/cli/ipfs_service_base.go b/pkg/cli/ipfs_service_base.go new file mode 100644 index 0000000..bee7173 --- /dev/null +++ b/pkg/cli/ipfs_service_base.go @@ -0,0 +1,35 @@ +package cli + +import ( + "go.lumeweb.com/pinner-cli/pkg/config" +) + +// ipfsServiceBase provides the shared auth/config pattern used by DNS, IPNS, and Websites services. +type ipfsServiceBase struct { + cfgMgr config.Manager + authToken string +} + +// getAuthToken returns the auth token to use, with override taking precedence over config. +func (b *ipfsServiceBase) getAuthToken() string { + if b.authToken != "" { + return b.authToken + } + return b.cfgMgr.Config().AuthToken +} + +// RequireAuthenticated checks if the user is authenticated. +func (b *ipfsServiceBase) RequireAuthenticated() error { + if b.getAuthToken() == "" { + return ErrNotAuthenticated + } + return nil +} + +// ipfsServiceOption applies a functional option to ipfsServiceBase. +type ipfsServiceOption func(*ipfsServiceBase) + +// withAuthToken returns an ipfsServiceOption that sets the auth token override. +func withAuthToken(token string) ipfsServiceOption { + return func(b *ipfsServiceBase) { b.authToken = token } +} diff --git a/pkg/cli/ipns.go b/pkg/cli/ipns.go index e56e5b9..9149bdb 100644 --- a/pkg/cli/ipns.go +++ b/pkg/cli/ipns.go @@ -6,6 +6,7 @@ import ( "time" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newIPNSCommand() *cli.Command { @@ -76,10 +77,9 @@ func newIPNSKeysListCommand() *cli.Command { Examples: pinner ipns keys list pinner ipns keys list --json`, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsKeysList(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsKeysList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -101,10 +101,9 @@ Examples: Usage: "Private key to import (optional, generates a new key if not provided)", }, }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsKeysCreate(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsKeysCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -119,10 +118,9 @@ Examples: pinner ipns keys get 1 pinner ipns keys get my-key --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsKeysGet(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsKeysGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -136,10 +134,9 @@ Examples: pinner ipns keys delete my-key pinner ipns keys delete 1`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsKeysDelete(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsKeysDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -171,10 +168,9 @@ Examples: }, WaitFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsPublish(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsPublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -192,10 +188,9 @@ Examples: pinner ipns republish 1 pinner ipns republish my-key --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsRepublish(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsRepublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -210,49 +205,24 @@ Examples: pinner ipns resolve k51qzi5uqu5djx... pinner ipns resolve k51qzi5uqu5djx... --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return ipnsResolve(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return ipnsResolve(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } -func initIPNSService(ctx context.Context, cmd *cli.Command, output Output) (context.Context, context.CancelFunc, IPNSService, error) { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - - cfgMgr, err := defaultConfigManagerFactory() - if err != nil { - cancel() - return ctx, func() {}, nil, err - } - - var ipnsService IPNSService - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - if authToken != "" { - ipnsService = NewIPNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - ipnsService = defaultIPNSServiceFactory(cfgMgr, output) - } - - if err := ipnsService.RequireAuthenticated(); err != nil { - cancel() - return ctx, func() {}, nil, err - } - - return ctx, cancel, ipnsService, nil -} - func resolveIPNSKeyArg(ctx context.Context, ipnsService IPNSService, arg string) (string, error) { return resolveIPNSKeyIDToString(ctx, ipnsService, arg) } -func ipnsKeysList(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsKeysList(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() keys, err := ipnsService.ListKeys(ctx) if err != nil { @@ -295,12 +265,14 @@ func ipnsKeysList(ctx context.Context, cmd *cli.Command, output Output) error { return nil } -func ipnsKeysCreate(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsKeysCreate(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() name := cmd.String(FlagName) if name == "" { @@ -345,12 +317,14 @@ func ipnsKeysCreate(ctx context.Context, cmd *cli.Command, output Output) error return nil } -func ipnsKeysGet(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsKeysGet(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { @@ -395,12 +369,14 @@ func ipnsKeysGet(ctx context.Context, cmd *cli.Command, output Output) error { return nil } -func ipnsKeysDelete(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsKeysDelete(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { @@ -430,12 +406,14 @@ func ipnsKeysDelete(ctx context.Context, cmd *cli.Command, output Output) error return nil } -func ipnsPublish(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsPublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { @@ -481,12 +459,14 @@ func ipnsPublish(ctx context.Context, cmd *cli.Command, output Output) error { return nil } -func ipnsRepublish(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsRepublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { @@ -509,12 +489,14 @@ func ipnsRepublish(ctx context.Context, cmd *cli.Command, output Output) error { return nil } -func ipnsResolve(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, ipnsService, err := initIPNSService(ctx, cmd, output) +func ipnsResolve(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { diff --git a/pkg/cli/ipns_service.go b/pkg/cli/ipns_service.go index 41a1bb7..8c7b757 100644 --- a/pkg/cli/ipns_service.go +++ b/pkg/cli/ipns_service.go @@ -21,38 +21,77 @@ type IPNSService interface { } type ipnsService struct { - service ipfs.IPNSService - cfgMgr config.Manager - authToken string - authenticated bool + ipfsServiceBase + service ipfs.IPNSService + client *ipfs.Client } -type IPNSServiceFactory func(cfgMgr config.Manager, output Output) IPNSService +// IPNSServiceOption is a function that configures an ipnsService. +type IPNSServiceOption func(*ipnsService) -func defaultIPNSServiceFactory(cfgMgr config.Manager, output Output) IPNSService { - return NewIPNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure()) +// WithIPNSAuthToken sets an auth token override that takes precedence over config. +func WithIPNSAuthToken(token string) IPNSServiceOption { + return func(s *ipnsService) { + withAuthToken(token)(&s.ipfsServiceBase) + } +} + +// WithIPNSClient sets a pre-configured ipfs.Client, bypassing the default ipfs.NewClient() call. +func WithIPNSClient(client *ipfs.Client) IPNSServiceOption { + return func(s *ipnsService) { + s.client = client + } } -func NewIPNSService(cfgMgr config.Manager, output Output, apiEndpoint string) IPNSService { +type IPNSServiceFactory func(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService + +func defaultIPNSServiceFactory(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService { + return NewIPNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure(), opts...) +} + +type ipnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService + +var ipnsServiceFactory ipnsServiceFactoryFunc = defaultIPNSServiceFactory + +// newAuthenticatedIPNSService creates an IPNSService with authentication. +// It returns an error if the user is not authenticated. +func newAuthenticatedIPNSService(cfgMgr config.Manager, output Output, authToken string) (IPNSService, error) { + var svcOpts []IPNSServiceOption + if authToken != "" { + svcOpts = append(svcOpts, WithIPNSAuthToken(authToken)) + } + ipnsService := ipnsServiceFactory(cfgMgr, output, svcOpts...) + if err := ipnsService.RequireAuthenticated(); err != nil { + return nil, err + } + return ipnsService, nil +} + +func NewIPNSService(cfgMgr config.Manager, output Output, apiEndpoint string, opts ...IPNSServiceOption) IPNSService { authToken := cfgMgr.Config().AuthToken - client, err := ipfs.NewClient(apiEndpoint, authToken) - if err != nil { - output.PrintError(err) - return &ipnsService{ - service: nil, - cfgMgr: cfgMgr, - authToken: authToken, - authenticated: false, - } + s := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: authToken, + }, + } + for _, opt := range opts { + opt(s) } - return &ipnsService{ - service: client.IPNS(), - cfgMgr: cfgMgr, - authToken: authToken, - authenticated: authToken != "", + if s.client != nil { + s.service = s.client.IPNS() + } else { + client, err := ipfs.NewClient(apiEndpoint, authToken) + if err != nil { + output.PrintError(err) + s.service = nil + return s + } + s.service = client.IPNS() } + return s } func (s *ipnsService) ListKeys(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { @@ -122,13 +161,6 @@ func (s *ipnsService) Resolve(ctx context.Context, name string) (*ipfs.IPNSResol return s.service.Resolve(ctx, name) } -func (s *ipnsService) RequireAuthenticated() error { - if !s.authenticated { - return ErrNotAuthenticated - } - return nil -} - func resolveIPNSKeyID(ctx context.Context, svc IPNSService, arg string) (int, error) { if id, err := strconv.Atoi(arg); err == nil { return id, nil diff --git a/pkg/cli/ipns_service_crud_test.go b/pkg/cli/ipns_service_crud_test.go new file mode 100644 index 0000000..b12c9e4 --- /dev/null +++ b/pkg/cli/ipns_service_crud_test.go @@ -0,0 +1,430 @@ +package cli + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +// mockIPNSSDKService implements ipfs.IPNSService for testing the ipnsService wrapper. +type mockIPNSSDKService struct { + listKeysFunc func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) + getKeyFunc func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) + createKeyFunc func(ctx context.Context, name string, opts ...ipfs.CreateKeyOption) (*ipfs.IPNSKeyResponse, error) + deleteKeyFunc func(ctx context.Context, id string) error + publishFunc func(ctx context.Context, keyID int, cid string, opts ...ipfs.PublishOption) (*ipfs.IPNSPublishResponse, error) + republishFunc func(ctx context.Context, id string) (*ipfs.IPNSRepublishResponse, error) + resolveFunc func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) + waitResolveFunc func(ctx context.Context, name string, expectedCID string, opts ...ipfs.PollOption) (*ipfs.IPNSResolveResponse, error) +} + +func (m *mockIPNSSDKService) ListKeys(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + if m.listKeysFunc != nil { + return m.listKeysFunc(ctx) + } + return nil, nil +} + +func (m *mockIPNSSDKService) GetKey(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + if m.getKeyFunc != nil { + return m.getKeyFunc(ctx, id) + } + return nil, nil +} + +func (m *mockIPNSSDKService) CreateKey(ctx context.Context, name string, opts ...ipfs.CreateKeyOption) (*ipfs.IPNSKeyResponse, error) { + if m.createKeyFunc != nil { + return m.createKeyFunc(ctx, name, opts...) + } + return nil, nil +} + +func (m *mockIPNSSDKService) DeleteKey(ctx context.Context, id string) error { + if m.deleteKeyFunc != nil { + return m.deleteKeyFunc(ctx, id) + } + return nil +} + +func (m *mockIPNSSDKService) Publish(ctx context.Context, keyID int, cid string, opts ...ipfs.PublishOption) (*ipfs.IPNSPublishResponse, error) { + if m.publishFunc != nil { + return m.publishFunc(ctx, keyID, cid, opts...) + } + return nil, nil +} + +func (m *mockIPNSSDKService) Republish(ctx context.Context, id string) (*ipfs.IPNSRepublishResponse, error) { + if m.republishFunc != nil { + return m.republishFunc(ctx, id) + } + return nil, nil +} + +func (m *mockIPNSSDKService) Resolve(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { + if m.resolveFunc != nil { + return m.resolveFunc(ctx, name) + } + return nil, nil +} + +func (m *mockIPNSSDKService) WaitForIPNSResolution(ctx context.Context, name string, expectedCID string, opts ...ipfs.PollOption) (*ipfs.IPNSResolveResponse, error) { + if m.waitResolveFunc != nil { + return m.waitResolveFunc(ctx, name, expectedCID, opts...) + } + return nil, nil +} + +func newAuthedIPNSService(t *testing.T, sdkSvc ipfs.IPNSService) *ipnsService { + t.Helper() + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "test-token"}).Maybe() + return &ipnsService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: "test-token"}, + service: sdkSvc, + } +} + +func newUnauthIPNSService(t *testing.T) *ipnsService { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + return &ipnsService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: ""}, + } +} + +func TestIPNSService_ListKeys_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.ListKeys(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_CreateKey_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.CreateKey(context.Background(), "my-key", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_CreateKey_WithKey_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + key := "base64key" + _, err := svc.CreateKey(context.Background(), "my-key", &key) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_GetKey_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.GetKey(context.Background(), "1") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_DeleteKey_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + err := svc.DeleteKey(context.Background(), "1") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_Publish_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.Publish(context.Background(), "QmHash", "my-key", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_Publish_WithTTL_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + ttl := "1h" + _, err := svc.Publish(context.Background(), "QmHash", "my-key", &ttl) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_Republish_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.Republish(context.Background(), "my-key") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_Resolve_Unauthenticated(t *testing.T) { + svc := newUnauthIPNSService(t) + _, err := svc.Resolve(context.Background(), "k51qzi5uqu5dg4vh") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestIPNSService_WithIPNSAuthToken(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + + svc := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr}, + } + WithIPNSAuthToken("override-token")(svc) + assert.Equal(t, "override-token", svc.getAuthToken()) +} + +func TestResolveIPNSKeyID_NumericArg(t *testing.T) { + id, err := resolveIPNSKeyID(context.Background(), nil, "42") + require.NoError(t, err) + assert.Equal(t, 42, id) +} + +func TestResolveIPNSKeyID_NumericString(t *testing.T) { + id, err := resolveIPNSKeyID(context.Background(), nil, "0") + require.NoError(t, err) + assert.Equal(t, 0, id) +} + +// ===== Behavioral tests for ipnsService CRUD methods ===== + +func TestIPNSService_CreateKey_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + createKeyFunc: func(ctx context.Context, name string, opts ...ipfs.CreateKeyOption) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "my-key", name) + assert.Empty(t, opts) + return &ipfs.IPNSKeyResponse{Id: 1, Name: "my-key"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + result, err := svc.CreateKey(context.Background(), "my-key", nil) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 1, result.Id) + assert.Equal(t, "my-key", result.Name) +} + +func TestIPNSService_CreateKey_WithKey_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + createKeyFunc: func(ctx context.Context, name string, opts ...ipfs.CreateKeyOption) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "imported-key", name) + require.Len(t, opts, 1) + return &ipfs.IPNSKeyResponse{Id: 2, Name: "imported-key"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + key := "base64key" + result, err := svc.CreateKey(context.Background(), "imported-key", &key) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 2, result.Id) +} + +func TestIPNSService_CreateKey_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + createKeyFunc: func(ctx context.Context, name string, opts ...ipfs.CreateKeyOption) (*ipfs.IPNSKeyResponse, error) { + return nil, errors.New("conflict: key already exists") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.CreateKey(context.Background(), "my-key", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "conflict") +} + +func TestIPNSService_GetKey_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + getKeyFunc: func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.IPNSKeyResponse{Id: 1, Name: "my-key", IpnsName: "k51qzi5uqu5djx123"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + result, err := svc.GetKey(context.Background(), "1") + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 1, result.Id) + assert.Equal(t, "my-key", result.Name) +} + +func TestIPNSService_GetKey_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + getKeyFunc: func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + return nil, errors.New("key not found") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.GetKey(context.Background(), "999") + require.Error(t, err) + assert.Contains(t, err.Error(), "key not found") +} + +func TestIPNSService_DeleteKey_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + deleteKeyFunc: func(ctx context.Context, id string) error { + assert.Equal(t, "1", id) + return nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + err := svc.DeleteKey(context.Background(), "1") + require.NoError(t, err) +} + +func TestIPNSService_DeleteKey_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + deleteKeyFunc: func(ctx context.Context, id string) error { + return errors.New("key not found") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + err := svc.DeleteKey(context.Background(), "999") + require.Error(t, err) + assert.Contains(t, err.Error(), "key not found") +} + +func TestIPNSService_Publish_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + publishFunc: func(ctx context.Context, keyID int, cid string, opts ...ipfs.PublishOption) (*ipfs.IPNSPublishResponse, error) { + assert.Equal(t, 1, keyID) + assert.Equal(t, "QmXxx", cid) + assert.Empty(t, opts) + return &ipfs.IPNSPublishResponse{Name: "k51qzi5uqu5djx123", Value: "QmXxx"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + result, err := svc.Publish(context.Background(), "QmXxx", "1", nil) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "QmXxx", result.Value) +} + +func TestIPNSService_Publish_WithTTL_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + publishFunc: func(ctx context.Context, keyID int, cid string, opts ...ipfs.PublishOption) (*ipfs.IPNSPublishResponse, error) { + assert.Equal(t, 1, keyID) + assert.Equal(t, "QmXxx", cid) + require.Len(t, opts, 1) + return &ipfs.IPNSPublishResponse{Name: "k51qzi5uqu5djx123", Value: "QmXxx"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + ttl := "1h" + result, err := svc.Publish(context.Background(), "QmXxx", "1", &ttl) + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestIPNSService_Publish_KeyResolutionError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + listKeysFunc: func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return nil, errors.New("service unavailable") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.Publish(context.Background(), "QmXxx", "my-key", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to resolve key") + assert.Contains(t, err.Error(), "my-key") +} + +func TestIPNSService_Publish_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + publishFunc: func(ctx context.Context, keyID int, cid string, opts ...ipfs.PublishOption) (*ipfs.IPNSPublishResponse, error) { + return nil, errors.New("invalid CID format") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.Publish(context.Background(), "invalid", "1", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid CID format") +} + +func TestIPNSService_Republish_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + republishFunc: func(ctx context.Context, id string) (*ipfs.IPNSRepublishResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.IPNSRepublishResponse{Count: 1, Message: "republished successfully"}, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + result, err := svc.Republish(context.Background(), "1") + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 1, result.Count) +} + +func TestIPNSService_Republish_KeyResolutionError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + listKeysFunc: func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return nil, errors.New("service unavailable") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.Republish(context.Background(), "my-key") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to resolve key") + assert.Contains(t, err.Error(), "my-key") +} + +func TestIPNSService_Republish_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + republishFunc: func(ctx context.Context, id string) (*ipfs.IPNSRepublishResponse, error) { + return nil, errors.New("republish failed") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.Republish(context.Background(), "1") + require.Error(t, err) + assert.Contains(t, err.Error(), "republish failed") +} + +func TestIPNSService_Resolve_Success(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + resolveFunc: func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { + assert.Equal(t, "k51qzi5uqu5djx123", name) + return &ipfs.IPNSResolveResponse{ + Name: "k51qzi5uqu5djx123", + Value: "QmXxx", + Sequence: 1, + Expired: false, + Expires: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), + }, nil + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + result, err := svc.Resolve(context.Background(), "k51qzi5uqu5djx123") + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "QmXxx", result.Value) + assert.Equal(t, 1, result.Sequence) +} + +func TestIPNSService_Resolve_ServiceError(t *testing.T) { + sdkMock := &mockIPNSSDKService{ + resolveFunc: func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { + return nil, errors.New("IPNS name not found") + }, + } + svc := newAuthedIPNSService(t, sdkMock) + + _, err := svc.Resolve(context.Background(), "k51qzi5uqu5djx999") + require.Error(t, err) + assert.Contains(t, err.Error(), "IPNS name not found") +} diff --git a/pkg/cli/ipns_service_test.go b/pkg/cli/ipns_service_test.go index 0344d5f..9440d0d 100644 --- a/pkg/cli/ipns_service_test.go +++ b/pkg/cli/ipns_service_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/stretchr/testify/require" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/pinner-cli/pkg/config" ipfs "go.lumeweb.com/ipfs-sdk" ) @@ -85,13 +87,14 @@ func TestIPNSService_ListKeys(t *testing.T) { } type mockIPNSServiceForCLI struct { - listKeysFunc func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) - createKeyFunc func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) - getKeyFunc func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) - deleteKeyFunc func(ctx context.Context, id string) error - publishFunc func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) - republishFunc func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) - resolveFunc func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) + requireAuthenticatedErr error + listKeysFunc func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) + createKeyFunc func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) + getKeyFunc func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) + deleteKeyFunc func(ctx context.Context, id string) error + publishFunc func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) + republishFunc func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) + resolveFunc func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) } func (m *mockIPNSServiceForCLI) ListKeys(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { @@ -173,7 +176,7 @@ func (m *mockIPNSServiceForCLI) Resolve(ctx context.Context, name string) (*ipfs } func (m *mockIPNSServiceForCLI) RequireAuthenticated() error { - return nil + return m.requireAuthenticatedErr } type unauthenticatedIPNSService struct { @@ -235,28 +238,36 @@ func (u *unauthenticatedIPNSService) Resolve(ctx context.Context, name string) ( func TestIPNSService_RequireAuthenticated(t *testing.T) { tests := []struct { - name string - authenticated bool - wantErr bool - errContains string + name string + authToken string + wantErr bool + errContains string }{ { - name: "authenticated", - authenticated: true, - wantErr: false, + name: "authenticated", + authToken: "test-token", + wantErr: false, }, { - name: "not authenticated", - authenticated: false, - wantErr: true, - errContains: "not authenticated", + name: "not authenticated", + authToken: "", + wantErr: true, + errContains: "not authenticated", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + svc := &ipnsService{ - authenticated: tt.authenticated, + ipfsServiceBase: ipfsServiceBase{ + authToken: tt.authToken, + cfgMgr: cfgMgr, + }, } err := svc.RequireAuthenticated() @@ -272,3 +283,72 @@ func TestIPNSService_RequireAuthenticated(t *testing.T) { }) } } + +func TestIPNSService_AuthTokenOverride(t *testing.T) { + t.Run("override token takes precedence over empty config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) + + t.Run("override token takes precedence over config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + require.Equal(t, "override-token", svc.getAuthToken()) + }) + + t.Run("falls back to config token when override is empty", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "", + }, + } + + require.Equal(t, "config-token", svc.getAuthToken()) + }) + + t.Run("WithIPNSAuthToken functional option sets override", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &ipnsService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + }, + } + WithIPNSAuthToken("override-token")(svc) + + require.Equal(t, "override-token", svc.getAuthToken()) + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) +} diff --git a/pkg/cli/ipns_test.go b/pkg/cli/ipns_test.go index 9e04183..54f1c43 100644 --- a/pkg/cli/ipns_test.go +++ b/pkg/cli/ipns_test.go @@ -3,1174 +3,454 @@ package cli import ( "context" "errors" - "fmt" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" ) -func TestIPNSKeysList(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - wantErr bool - errContains string - }{ - { - name: "successful list keys", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { - return []ipfs.IPNSKeyResponse{ - { - Id: 1, - Name: "my-key", - IpnsName: "k51qzi5uqu5djx123", - PeerId: "12D3KooWABC123", - Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - }, - { - Id: 2, - Name: "another-key", - IpnsName: "k51qzi5uqu5djx456", - PeerId: "12D3KooWDEF456", - Created: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, - }, nil - } - }, - wantErr: false, - }, - { - name: "no keys found", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { - return []ipfs.IPNSKeyResponse{}, nil - } - }, - wantErr: false, - }, - { - name: "service error", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { - return nil, errors.New("failed to list keys") - } - }, - wantErr: true, - errContains: "failed to list keys", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } +func setupIPNSHandlerTest(t *testing.T) (*mockIPNSServiceForCLI, *configmocks.MockManager) { + t.Helper() + mockSvc := &mockIPNSServiceForCLI{} + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + }).Maybe() - cmd := &cli.Command{} - - err := ipnsKeysListWithService(context.Background(), cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) + origFactory := ipnsServiceFactory + t.Cleanup(func() { ipnsServiceFactory = origFactory }) + ipnsServiceFactory = func(config.Manager, Output, ...IPNSServiceOption) IPNSService { + return mockSvc } -} -func ipnsKeysListWithService(ctx context.Context, cmd *cli.Command, output Output, ipnsService IPNSService) error { - if err := ipnsService.RequireAuthenticated(); err != nil { - return err - } - - keys, err := ipnsService.ListKeys(ctx) - if err != nil { - return err - } + return mockSvc, cfgMgr +} - if len(keys) == 0 { - output.Printf("No IPNS keys found") - return nil - } +// ===== ipnsKeysList ===== - output.Printf("Found %d IPNS key(s)", len(keys)) - - headers := []string{"ID", "NAME", "IPNS NAME", "PEER ID", "CREATED"} - rows := make([][]string, len(keys)) - for i, key := range keys { - rows[i] = []string{ - fmt.Sprintf("%d", key.Id), - key.Name, - key.IpnsName, - key.PeerId, - key.Created.Format("2006-01-02 15:04:05"), - } +func TestIpnsKeysList_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{ + {Id: 1, Name: "my-key", IpnsName: "k51qzi5uqu5djx123", PeerId: "12D3KooWABC123", Created: now}, + {Id: 2, Name: "another-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: now}, + }, nil } - output.PrintTable(headers, rows) - return nil + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSResolve(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSResolveCommand - wantErr bool - errContains string - }{ - { - name: "successful resolve", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { - return &ipfs.IPNSResolveResponse{ - Name: "k51qzi5uqu5djx123", - Value: "QmXxx", - Sequence: 1, - Expired: false, - Expires: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSResolveCommand{ipnsName: "k51qzi5uqu5djx123"}, - wantErr: false, - }, - { - name: "successful resolve with expired record", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { - return &ipfs.IPNSResolveResponse{ - Name: "k51qzi5uqu5djx456", - Value: "QmYyy", - Sequence: 2, - Expired: true, - Expires: time.Date(2023, 12, 31, 23, 59, 59, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSResolveCommand{ipnsName: "k51qzi5uqu5djx456"}, - wantErr: false, - }, - { - name: "missing IPNS name", - cmd: &mockIPNSResolveCommand{ipnsName: ""}, - wantErr: true, - errContains: "IPNS name is required", - }, - { - name: "service error - invalid IPNS name", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { - return nil, errors.New("invalid IPNS name format") - } - }, - cmd: &mockIPNSResolveCommand{ipnsName: "invalid"}, - wantErr: true, - errContains: "invalid IPNS name format", - }, - { - name: "service error - IPNS name not found", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { - return nil, errors.New("IPNS name not found") - } - }, - cmd: &mockIPNSResolveCommand{ipnsName: "k51qzi5uqu5djx999"}, - wantErr: true, - errContains: "IPNS name not found", - }, +func TestIpnsKeysList_Empty(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{}, nil } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsResolveWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSResolveJSON(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSResolveCommand - wantErr bool - errContains string - }{ - { - name: "successful resolve JSON output", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { - return &ipfs.IPNSResolveResponse{ - Name: "k51qzi5uqu5djx123", - Value: "QmXxx", - Sequence: 1, - Expired: false, - Expires: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSResolveCommand{ipnsName: "k51qzi5uqu5djx123"}, - wantErr: false, - }, +func TestIpnsKeysList_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return nil, errors.New("server error") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(true, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsResolveWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "server error") } -type mockIPNSResolveCommand struct { - ipnsName string -} +func TestIpnsKeysList_Unauthenticated(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.requireAuthenticatedErr = ErrNotAuthenticated -func (m *mockIPNSResolveCommand) Args() cli.Args { - if m.ipnsName == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.ipnsName}} + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNotAuthenticated)) } -func ipnsResolveWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, ipnsService IPNSService) error { - args := cmd.Args() - if args.Len() == 0 { - return fmt.Errorf("IPNS name is required") - } - - ipnsName := args.First() - if ipnsName == "" { - return fmt.Errorf("IPNS name is required") - } - - if err := ipnsService.RequireAuthenticated(); err != nil { - return err - } - - response, err := ipnsService.Resolve(ctx, ipnsName) - if err != nil { - return err - } - - if output.IsJSON() { - return output.PrintJSON(response) - } +// ===== ipnsKeysCreate ===== - output.Printf("IPNS name %s resolves to CID %s", response.Name, response.Value) - - headers := []string{"NAME", "CID", "SEQUENCE", "EXPIRED", "EXPIRES"} - rows := [][]string{ - { - response.Name, - response.Value, - fmt.Sprintf("%d", response.Sequence), - fmt.Sprintf("%t", response.Expired), - response.Expires.Format("2006-01-02 15:04:05"), - }, +func TestIpnsKeysCreate_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "my-key", name) + assert.Nil(t, key) + return &ipfs.IPNSKeyResponse{Id: 1, Name: "my-key", IpnsName: "k51qzi5uqu5djx123", PeerId: "12D3KooWABC123", Created: now}, nil } - output.PrintTable(headers, rows) - return nil + output := newTestOutput() + cmd := newMockCommand().withString(FlagName, "my-key") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSKeysCreate(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSCreateCommand - wantErr bool - errContains string - }{ - { - name: "successful create key", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { - return &ipfs.IPNSKeyResponse{ - Id: 1, - Name: name, - IpnsName: "k51qzi5uqu5djx123", - PeerId: "12D3KooWABC123", - Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - wantErr: false, - }, - { - name: "successful create key with import", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { - return &ipfs.IPNSKeyResponse{ - Id: 2, - Name: name, - IpnsName: "k51qzi5uqu5djx456", - PeerId: "12D3KooWDEF456", - Created: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - wantErr: false, - }, - { - name: "missing name", - cmd: &mockIPNSCreateCommand{name: ""}, - wantErr: true, - errContains: "name is required", - }, - { - name: "service error", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { - return nil, errors.New("failed to create key") - } - }, - cmd: &mockIPNSCreateCommand{name: "my-key"}, - wantErr: true, - errContains: "failed to create key", - }, +func TestIpnsKeysCreate_WithKeyImport(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "imported-key", name) + require.NotNil(t, key) + assert.Equal(t, "base64keydata", *key) + return &ipfs.IPNSKeyResponse{Id: 2, Name: "imported-key", IpnsName: "k51qzi5uqu5djx789", PeerId: "12D3KooWGHI789", Created: now}, nil } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - cmd := tt.cmd - if cmd == nil { - cmd = &mockIPNSCreateCommand{name: "my-key"} - } - - err := ipnsKeysCreateWithService(context.Background(), cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand().withString(FlagName, "imported-key").withString(FlagKey, "base64keydata") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -type mockIPNSCreateCommand struct { - name string - key string -} +func TestIpnsKeysCreate_MissingName(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) -func (m *mockIPNSCreateCommand) String(name string) string { - switch name { - case FlagName: - return m.name - case "key": - return m.key - default: - return "" - } + output := newTestOutput() + cmd := newMockCommand().withString(FlagName, "") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "name is required") } -func ipnsKeysCreateWithService(ctx context.Context, cmd interface{ String(name string) string }, output Output, ipnsService IPNSService) error { - if err := ipnsService.RequireAuthenticated(); err != nil { - return err +func TestIpnsKeysCreate_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.createKeyFunc = func(ctx context.Context, name string, key *string) (*ipfs.IPNSKeyResponse, error) { + return nil, errors.New("conflict") } - name := cmd.String(FlagName) - if name == "" { - return fmt.Errorf("name is required") - } - - var key *string - keyValue := cmd.String("key") - if keyValue != "" { - key = &keyValue - } - - createdKey, err := ipnsService.CreateKey(ctx, name, key) - if err != nil { - return err - } + output := newTestOutput() + cmd := newMockCommand().withString(FlagName, "my-key") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "conflict") +} - if output.IsJSON() { - return output.PrintJSON(createdKey) - } +// ===== ipnsKeysGet ===== - output.Printf("Successfully created IPNS key") - - headers := []string{"ID", "NAME", "IPNS NAME", "PEER ID", "CREATED"} - rows := [][]string{ - { - fmt.Sprintf("%d", createdKey.Id), - createdKey.Name, - createdKey.IpnsName, - createdKey.PeerId, - createdKey.Created.Format("2006-01-02 15:04:05"), - }, +func TestIpnsKeysGet_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.IPNSKeyResponse{Id: 1, Name: "my-key", IpnsName: "k51qzi5uqu5djx123", PeerId: "12D3KooWABC123", Created: now}, nil } - output.PrintTable(headers, rows) - return nil + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSKeysGet(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSGetCommand - wantErr bool - errContains string - }{ - { - name: "successful get key by numeric ID", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { - return &ipfs.IPNSKeyResponse{ - Id: 1, - Name: "my-key", - IpnsName: "k51qzi5uqu5djx123", - PeerId: "12D3KooWABC123", - Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSGetCommand{keyArg: "1"}, - wantErr: false, - }, - { - name: "successful get key by name", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { - return &ipfs.IPNSKeyResponse{ - Id: 2, - Name: "another-key", - IpnsName: "k51qzi5uqu5djx456", - PeerId: "12D3KooWDEF456", - Created: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - svc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { - return []ipfs.IPNSKeyResponse{ - {Id: 2, Name: "another-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)}, - }, nil - } - }, - cmd: &mockIPNSGetCommand{keyArg: "another-key"}, - wantErr: false, - }, - { - name: "missing key arg", - cmd: &mockIPNSGetCommand{keyArg: ""}, - wantErr: true, - errContains: "key name or ID is required", - }, - { - name: "service error", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { - return nil, errors.New("failed to get key") - } - }, - cmd: &mockIPNSGetCommand{keyArg: "1"}, - wantErr: true, - errContains: "failed to get key", - }, +func TestIpnsKeysGet_ByName(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{{Id: 2, Name: "my-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: now}}, nil } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsKeysGetWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) + mockSvc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + assert.Equal(t, "2", id) + return &ipfs.IPNSKeyResponse{Id: 2, Name: "my-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: now}, nil } -} -type mockIPNSGetCommand struct { - keyArg string + output := newTestOutput() + cmd := newMockCommand().withArgs("my-key") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func (m *mockIPNSGetCommand) String(name string) string { - return "" -} +func TestIpnsKeysGet_MissingArg(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) -func (m *mockIPNSGetCommand) Args() cli.Args { - if m.keyArg == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.keyArg}} + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key name or ID is required") } -func ipnsKeysGetWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, ipnsService IPNSService) error { - args := cmd.Args() - if args.Len() == 0 { - return fmt.Errorf("key name or ID is required") +func TestIpnsKeysGet_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.getKeyFunc = func(ctx context.Context, id string) (*ipfs.IPNSKeyResponse, error) { + return nil, errors.New("key not found") } - keyArg := args.First() - if keyArg == "" { - return fmt.Errorf("key name or ID is required") - } + output := newTestOutput() + cmd := newMockCommand().withArgs("999") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key not found") +} - if err := ipnsService.RequireAuthenticated(); err != nil { - return err - } +// ===== ipnsKeysDelete ===== - keyID, err := resolveIPNSKeyIDToString(ctx, ipnsService, keyArg) - if err != nil { - return err +func TestIpnsKeysDelete_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.deleteKeyFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "1", id) + return nil } - key, err := ipnsService.GetKey(ctx, keyID) - if err != nil { - return err - } + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} - if output.IsJSON() { - return output.PrintJSON(key) +func TestIpnsKeysDelete_ByName(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{{Id: 3, Name: "my-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: now}}, nil } - - output.Printf("IPNS Key Details") - - headers := []string{"ID", "NAME", "IPNS NAME", "PEER ID", "CREATED"} - rows := [][]string{ - { - fmt.Sprintf("%d", key.Id), - key.Name, - key.IpnsName, - key.PeerId, - key.Created.Format("2006-01-02 15:04:05"), - }, + mockSvc.deleteKeyFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "3", id) + return nil } - output.PrintTable(headers, rows) - return nil + output := newTestOutput() + cmd := newMockCommand().withArgs("my-key") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSKeysDelete(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSGetCommand - wantErr bool - errContains string - }{ - { - name: "successful delete key by ID", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.deleteKeyFunc = func(ctx context.Context, id string) error { - return nil - } - }, - cmd: &mockIPNSGetCommand{keyArg: "1"}, - wantErr: false, - }, - { - name: "successful delete key by name", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.deleteKeyFunc = func(ctx context.Context, id string) error { - return nil - } - svc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { - return []ipfs.IPNSKeyResponse{ - {Id: 2, Name: "my-key", IpnsName: "k51qzi5uqu5djx456", PeerId: "12D3KooWDEF456", Created: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)}, - }, nil - } - }, - cmd: &mockIPNSGetCommand{keyArg: "my-key"}, - wantErr: false, - }, - { - name: "missing key arg", - cmd: &mockIPNSGetCommand{keyArg: ""}, - wantErr: true, - errContains: "key name or ID is required", - }, - { - name: "service error", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.deleteKeyFunc = func(ctx context.Context, id string) error { - return errors.New("failed to delete key") - } - }, - cmd: &mockIPNSGetCommand{keyArg: "1"}, - wantErr: true, - errContains: "failed to delete key", - }, - } +func TestIpnsKeysDelete_MissingArg(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsKeysDeleteWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand() + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key name or ID is required") } -func TestIPNSKeysDeleteJSON(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSGetCommand - wantErr bool - errContains string - }{ - { - name: "successful delete key JSON output", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.deleteKeyFunc = func(ctx context.Context, id string) error { - return nil - } - }, - cmd: &mockIPNSGetCommand{keyArg: "1"}, - wantErr: false, - }, +func TestIpnsKeysDelete_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.deleteKeyFunc = func(ctx context.Context, id string) error { + return errors.New("key not found") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(true, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsKeysDeleteWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand().withArgs("999") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key not found") } -func ipnsKeysDeleteWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, ipnsService IPNSService) error { - if err := ipnsService.RequireAuthenticated(); err != nil { - return err - } - - args := cmd.Args() - if args.Len() == 0 { - return fmt.Errorf("key name or ID is required") - } - - keyArg := args.First() - - keyID, err := resolveIPNSKeyIDToString(ctx, ipnsService, keyArg) - if err != nil { - return err - } - - if err := ipnsService.DeleteKey(ctx, keyID); err != nil { - return err - } +// ===== ipnsPublish ===== - if output.IsJSON() { - result := map[string]any{ - "success": true, - "message": fmt.Sprintf("IPNS key %s deleted successfully", keyArg), - } - return output.PrintJSON(result) +func TestIpnsPublish_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { + assert.Equal(t, "QmXxx", cid) + assert.Equal(t, "1", keyName) + assert.Nil(t, ttl) + return &ipfs.IPNSPublishResponse{Name: "k51qzi5uqu5djx123", Value: "QmXxx", Published: now, Sequence: 1, Validity: now.Add(24 * time.Hour)}, nil } - output.Printf("IPNS key %s deleted successfully", keyArg) - - return nil + output := newTestOutput() + cmd := newMockCommand().withArgs("QmXxx").withString("key-name", "1") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSPublish(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSPublishCommand - wantErr bool - errContains string - }{ - { - name: "successful publish with key name", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { - return &ipfs.IPNSPublishResponse{ - Name: "k51qzi5uqu5djx123", - Value: "QmXxx", - Published: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - Sequence: 1, - Validity: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSPublishCommand{ - cid: "QmXxx", - keyName: "my-key", - ttl: "", - }, - wantErr: false, - }, - { - name: "successful publish with TTL", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { - return &ipfs.IPNSPublishResponse{ - Name: "k51qzi5uqu5djx456", - Value: "QmYyy", - Published: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - Sequence: 2, - Validity: time.Date(2024, 1, 8, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSPublishCommand{ - cid: "QmYyy", - keyName: "another-key", - ttl: "24h", - }, - wantErr: false, - }, - { - name: "missing CID", - cmd: &mockIPNSPublishCommand{ - cid: "", - keyName: "my-key", - }, - wantErr: true, - errContains: "CID is required", - }, - { - name: "missing key name", - cmd: &mockIPNSPublishCommand{ - cid: "QmXxx", - keyName: "", - }, - wantErr: true, - errContains: "key-name is required", - }, - { - name: "service error - invalid CID", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { - return nil, errors.New("invalid CID format") - } - }, - cmd: &mockIPNSPublishCommand{ - cid: "invalid", - keyName: "my-key", - }, - wantErr: true, - errContains: "invalid CID format", - }, - { - name: "service error - key not found", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { - return nil, errors.New("key not found") - } - }, - cmd: &mockIPNSPublishCommand{ - cid: "QmXxx", - keyName: "nonexistent", - }, - wantErr: true, - errContains: "key not found", - }, +func TestIpnsPublish_WithTTL(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + now := time.Now() + mockSvc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { + require.NotNil(t, ttl) + assert.Equal(t, "24h", *ttl) + return &ipfs.IPNSPublishResponse{Name: "k51qzi5uqu5djx123", Value: "QmYyy", Published: now, Sequence: 2, Validity: now.Add(24 * time.Hour)}, nil } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsPublishWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand().withArgs("QmYyy").withString("key-name", "1").withString("ttl", "24h") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSPublishJSON(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSPublishCommand - wantErr bool - errContains string - }{ - { - name: "successful publish JSON output", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { - return &ipfs.IPNSPublishResponse{ - Name: "k51qzi5uqu5djx123", - Value: "QmXxx", - Published: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), - Sequence: 1, - Validity: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), - }, nil - } - }, - cmd: &mockIPNSPublishCommand{ - cid: "QmXxx", - keyName: "my-key", - }, - wantErr: false, - }, - } +func TestIpnsPublish_MissingCID(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(true, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsPublishWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand().withString("key-name", "my-key") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "CID is required") } -type mockIPNSPublishCommand struct { - cid string - keyName string - ttl string -} +func TestIpnsPublish_MissingKeyName(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) -func (m *mockIPNSPublishCommand) Int(name string) int { - return 0 + output := newTestOutput() + cmd := newMockCommand().withArgs("QmXxx").withString("key-name", "") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key-name is required") } -func (m *mockIPNSPublishCommand) String(name string) string { - switch name { - case "key-name": - return m.keyName - case "ttl": - return m.ttl - default: - return "" +func TestIpnsPublish_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.publishFunc = func(ctx context.Context, cid string, keyName string, ttl *string) (*ipfs.IPNSPublishResponse, error) { + return nil, errors.New("invalid CID format") } -} -func (m *mockIPNSPublishCommand) Args() cli.Args { - if m.cid == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.cid}} + output := newTestOutput() + cmd := newMockCommand().withArgs("invalid").withString("key-name", "1") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid CID format") } -func ipnsPublishWithService(ctx context.Context, cmd interface { - Int(name string) int - String(name string) string - Args() cli.Args -}, output Output, ipnsService IPNSService) error { - if err := ipnsService.RequireAuthenticated(); err != nil { - return err - } +// ===== ipnsRepublish ===== - args := cmd.Args() - if args.Len() == 0 { - return fmt.Errorf("CID is required") +func TestIpnsRepublish_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { + assert.Equal(t, "my-key", keyName) + return &ipfs.IPNSRepublishResponse{Count: 1, Message: "republished successfully"}, nil } - cid := args.First() - if cid == "" { - return fmt.Errorf("CID is required") - } - - keyName := cmd.String("key-name") - if keyName == "" { - return fmt.Errorf("key-name is required") - } + output := newTestOutput() + cmd := newMockCommand().withArgs("my-key") + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} - var ttl *string - ttlValue := cmd.String("ttl") - if ttlValue != "" { - ttl = &ttlValue - } +func TestIpnsRepublish_MissingArg(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) - response, err := ipnsService.Publish(ctx, cid, keyName, ttl) - if err != nil { - return err - } + output := newTestOutput() + cmd := newMockCommand() + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "key name or ID is required") +} - if output.IsJSON() { - return output.PrintJSON(response) +func TestIpnsRepublish_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { + return nil, errors.New("republish failed") } - output.Printf("Published CID %s to IPNS name %s", response.Value, response.Name) - - headers := []string{"NAME", "VALUE", "PUBLISHED", "SEQUENCE", "VALIDITY"} - rows := [][]string{ - { - response.Name, - response.Value, - response.Published.Format("2006-01-02 15:04:05"), - fmt.Sprintf("%d", response.Sequence), - response.Validity.Format("2006-01-02 15:04:05"), - }, - } - output.PrintTable(headers, rows) + output := newTestOutput() + cmd := newMockCommand().withArgs("my-key") + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "republish failed") +} - return nil +// ===== ipnsResolve ===== + +func TestIpnsResolve_Success(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { + assert.Equal(t, "k51qzi5uqu5djx123", name) + return &ipfs.IPNSResolveResponse{ + Name: "k51qzi5uqu5djx123", + Value: "QmXxx", + Sequence: 1, + Expired: false, + Expires: time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC), + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("k51qzi5uqu5djx123") + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) } -func TestIPNSRepublish(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSRepublishCommand - wantErr bool - errContains string - }{ - { - name: "successful republish by name", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { - return &ipfs.IPNSRepublishResponse{ - Count: 1, - Message: "republished successfully", - }, nil - } - }, - cmd: &mockIPNSRepublishCommand{keyArg: "my-key"}, - wantErr: false, - }, - { - name: "successful republish by ID", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { - return &ipfs.IPNSRepublishResponse{ - Count: 1, - Message: "republished successfully", - }, nil - } - }, - cmd: &mockIPNSRepublishCommand{keyArg: "1"}, - wantErr: false, - }, - { - name: "missing key arg", - cmd: &mockIPNSRepublishCommand{keyArg: ""}, - wantErr: true, - errContains: "key name or ID is required", - }, - { - name: "service error", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { - return nil, errors.New("republish failed") - } - }, - cmd: &mockIPNSRepublishCommand{keyArg: "my-key"}, - wantErr: true, - errContains: "republish failed", - }, - } +func TestIpnsResolve_MissingArg(t *testing.T) { + _, cfgMgr := setupIPNSHandlerTest(t) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsRepublishWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand() + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "IPNS name is required") } -func TestIPNSRepublishJSON(t *testing.T) { - tests := []struct { - name string - setupMocks func(*mockIPNSServiceForCLI) - cmd *mockIPNSRepublishCommand - wantErr bool - errContains string - }{ - { - name: "successful republish JSON output", - setupMocks: func(svc *mockIPNSServiceForCLI) { - svc.republishFunc = func(ctx context.Context, keyName string) (*ipfs.IPNSRepublishResponse, error) { - return &ipfs.IPNSRepublishResponse{ - Count: 1, - Message: "republished successfully", - }, nil - } - }, - cmd: &mockIPNSRepublishCommand{keyArg: "my-key"}, - wantErr: false, - }, +func TestIpnsResolve_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupIPNSHandlerTest(t) + mockSvc.resolveFunc = func(ctx context.Context, name string) (*ipfs.IPNSResolveResponse, error) { + return nil, errors.New("IPNS name not found") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockSvc := &mockIPNSServiceForCLI{} - output := NewOutputFormatter(true, false, false, false) - - if tt.setupMocks != nil { - tt.setupMocks(mockSvc) - } - - err := ipnsRepublishWithService(context.Background(), tt.cmd, output, mockSvc) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - require.Contains(t, err.Error(), tt.errContains) - } - } else { - require.NoError(t, err) - } - }) - } + output := newTestOutput() + cmd := newMockCommand().withArgs("k51qzi5uqu5djx999") + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "IPNS name not found") } -type mockIPNSRepublishCommand struct { - keyArg string -} +// ===== resolveIPNSKeyID (helper function) ===== -func (m *mockIPNSRepublishCommand) Args() cli.Args { - if m.keyArg == "" { - return &mockArgs{} +func TestResolveIPNSKeyID_ByName(t *testing.T) { + mockSvc := &mockIPNSServiceForCLI{} + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{ + {Id: 7, Name: "my-key"}, + {Id: 8, Name: "other-key"}, + }, nil } - return &mockArgs{[]string{m.keyArg}} + id, err := resolveIPNSKeyID(context.Background(), mockSvc, "my-key") + require.NoError(t, err) + assert.Equal(t, 7, id) } -func ipnsRepublishWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, ipnsService IPNSService) error { - if err := ipnsService.RequireAuthenticated(); err != nil { - return err +func TestResolveIPNSKeyID_NotFound(t *testing.T) { + mockSvc := &mockIPNSServiceForCLI{} + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{}, nil } + _, err := resolveIPNSKeyID(context.Background(), mockSvc, "missing-key") + require.Error(t, err) + assert.Contains(t, err.Error(), "IPNS key not found for name") +} - args := cmd.Args() - if args.Len() == 0 { - return fmt.Errorf("key name or ID is required") +func TestResolveIPNSKeyID_ListError(t *testing.T) { + mockSvc := &mockIPNSServiceForCLI{} + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return nil, errors.New("service down") } + _, err := resolveIPNSKeyID(context.Background(), mockSvc, "my-key") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to look up IPNS key by name") +} - keyArg := args.First() +// ===== resolveIPNSKeyIDToString (helper function) ===== - response, err := ipnsService.Republish(ctx, keyArg) - if err != nil { - return err - } +func TestResolveIPNSKeyIDToString_NumericID(t *testing.T) { + mockSvc := &mockIPNSServiceForCLI{} + id, err := resolveIPNSKeyIDToString(context.Background(), mockSvc, "42") + require.NoError(t, err) + assert.Equal(t, "42", id) +} - if output.IsJSON() { - return output.PrintJSON(response) +func TestResolveIPNSKeyIDToString_ByName(t *testing.T) { + mockSvc := &mockIPNSServiceForCLI{} + mockSvc.listKeysFunc = func(ctx context.Context) ([]ipfs.IPNSKeyResponse, error) { + return []ipfs.IPNSKeyResponse{{Id: 7, Name: "my-key"}}, nil } - - output.Printf("Republished IPNS key %s: %s (%d record(s))", keyArg, response.Message, response.Count) - - return nil + id, err := resolveIPNSKeyIDToString(context.Background(), mockSvc, "my-key") + require.NoError(t, err) + assert.Equal(t, "7", id) } - - diff --git a/pkg/cli/list.go b/pkg/cli/list.go index 00a8326..66265e4 100644 --- a/pkg/cli/list.go +++ b/pkg/cli/list.go @@ -6,6 +6,7 @@ import ( "time" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newListCommand() *cli.Command { @@ -32,34 +33,22 @@ Examples: }, Metadata: WithTutorial(3, "List all pins", "pinner list --name my-pin"), Action: func(ctx context.Context, c *cli.Command) error { - output := NewOutputFormatter(c.Bool(FlagJSON), c.Bool(FlagVerbose), c.Bool(FlagQuiet), c.Bool(FlagUnmask)) - return list(ctx, c, output, defaultConfigManagerFactory, defaultPinningServiceFactory) + output := setupOutput(c) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return list(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } -// listCommandGetter defines the interface for getting list command flags. -type listCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool -} - -func list(ctx context.Context, cmd listCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func list(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory) error { var pinningService PinningService - if c, ok := cmd.(*cli.Command); ok { - secure := GetSecureSetting(c, cfgMgr) - authToken := GetAuthToken(c, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } diff --git a/pkg/cli/list_test.go b/pkg/cli/list_test.go index be0c3de..93c758c 100644 --- a/pkg/cli/list_test.go +++ b/pkg/cli/list_test.go @@ -14,13 +14,12 @@ import ( func TestList(t *testing.T) { tests := []struct { - name string - nameFilter string - limit int - setupMocks func(*configmocks.MockManager, *MockPinningService) - wantErr bool - errContains string - cfgMgrFactoryErr bool + name string + nameFilter string + limit int + setupMocks func(*configmocks.MockManager, *MockPinningService) + wantErr bool + errContains string }{ { name: "successful list operation", @@ -80,15 +79,6 @@ func TestList(t *testing.T) { }, wantErr: false, }, - { - name: "returns error when config manager factory fails", - nameFilter: "", - limit: 0, - setupMocks: func(cfgMgr *configmocks.MockManager, service *MockPinningService) {}, - wantErr: true, - errContains: "config error", - cfgMgrFactoryErr: true, - }, { name: "returns error when list fails", nameFilter: "", @@ -122,35 +112,23 @@ func TestList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfgMgr := configmocks.NewMockManager(t) + cfgMgr := newTestConfigMgr(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockListCommand{ - nameFilter: tt.nameFilter, - limit: tt.limit, - } - - var cfgMgrFactory ConfigManagerFactory - if tt.cfgMgrFactoryErr { - cfgMgrFactory = func() (config.Manager, error) { - return nil, errors.New("config error") - } - } else { - cfgMgrFactory = func() (config.Manager, error) { - return cfgMgr, nil - } - } + cmd := newMockCommand(). + withString(FlagName, tt.nameFilter). + withInt(FlagLimit, tt.limit) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := list(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := list(context.Background(), cmd, output, cfgMgr, "", false, pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -170,7 +148,6 @@ func TestNewListCommand(t *testing.T) { assert.Equal(t, "list", cmd.Name) - // Check flags flags := cmd.Flags assert.Len(t, flags, 4) @@ -192,30 +169,47 @@ func TestNewListCommand(t *testing.T) { }) } -// mockListCommand is a mock implementation of listCommandGetter for testing. -type mockListCommand struct { - nameFilter string - limit int -} +func TestList_WithStatusFilter(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + service := NewMockPinningService(t) + output := newTestOutput() -func (m *mockListCommand) String(name string) string { - switch name { - case FlagName: - return m.nameFilter - default: - return "" + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().List(context.Background(), "", 10, "pinned").Return( + []Pin{ + {CID: "QmXxx", Name: "test", Status: "pinned", Created: "2024-01-01T00:00:00Z"}, + }, + nil, + ) + + cmd := newMockCommand(). + withInt(FlagLimit, 10). + withString(FlagStatus, "pinned") + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service } + + err := list(context.Background(), cmd, output, cfgMgr, "", false, pinningServiceFactory) + require.NoError(t, err) } -func (m *mockListCommand) Int(name string) int { - switch name { - case FlagLimit: - return m.limit - default: - return 0 +func TestList_RequireAuthFails(t *testing.T) { + cfgMgr := newTestConfigMgr(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(errors.New("not authenticated")) + + cmd := newMockCommand() + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service } -} -func (m *mockListCommand) Bool(name string) bool { - return false + err := list(context.Background(), cmd, output, cfgMgr, "", false, pinningServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") } + + diff --git a/pkg/cli/metadata_removed_test.go b/pkg/cli/metadata_removed_test.go new file mode 100644 index 0000000..d89e44c --- /dev/null +++ b/pkg/cli/metadata_removed_test.go @@ -0,0 +1,29 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMetadataRemovedCommand(t *testing.T) { + cmd := newMetadataRemovedCommand() + + assert.Equal(t, "metadata", cmd.Name) + assert.Equal(t, "Pinning", cmd.Category) + assert.Contains(t, cmd.Usage, "REMOVED") + assert.True(t, cmd.Hidden, "metadata command should be hidden") + assert.NotNil(t, cmd.Action) +} + +func TestMetadataRemovedAction(t *testing.T) { + cmd := newMetadataRemovedCommand() + + err := cmd.Action(context.Background(), cmd) + + require.Error(t, err) + assert.Contains(t, err.Error(), "metadata") + assert.Contains(t, err.Error(), "pins update") +} diff --git a/pkg/cli/operations.go b/pkg/cli/operations.go index 235f0f1..3b9fb7c 100644 --- a/pkg/cli/operations.go +++ b/pkg/cli/operations.go @@ -102,18 +102,11 @@ Examples: } } -type operationsCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool - Args() cli.Args -} - func defaultOperationsServiceFactory(cfgMgr config.Manager, output Output, authService AuthService) OperationsService { return NewOperationsService(cfgMgr, output, authService) } -func operationsList(ctx context.Context, cmd operationsCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, serviceFactory OperationsServiceFactory) error { +func operationsList(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, serviceFactory OperationsServiceFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return err @@ -170,7 +163,7 @@ func operationsList(ctx context.Context, cmd operationsCommandGetter, output Out return nil } -func operationsGet(ctx context.Context, cmd operationsCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, serviceFactory OperationsServiceFactory) error { +func operationsGet(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, serviceFactory OperationsServiceFactory) error { cfgMgr, err := cfgMgrFactory() if err != nil { return err diff --git a/pkg/cli/operations_test.go b/pkg/cli/operations_test.go index 9d3eaeb..e5658ff 100644 --- a/pkg/cli/operations_test.go +++ b/pkg/cli/operations_test.go @@ -20,7 +20,7 @@ import ( func TestOperationsServiceDefault_List(t *testing.T) { t.Run("returns operations from account client", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -47,7 +47,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("returns multiple operations", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -75,7 +75,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() svc := NewOperationsService(cfgMgr, output, nil) @@ -86,7 +86,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("returns error when list operations fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) accountClient.EXPECT().ListOperations( @@ -103,7 +103,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("populates error field from operation", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -129,7 +129,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("resolves account client via auth service", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -155,7 +155,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { t.Run("reuses account client on subsequent calls", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -196,7 +196,7 @@ func TestOperationsServiceDefault_List(t *testing.T) { func TestOperationsServiceDefault_Get(t *testing.T) { t.Run("returns operation detail from account client", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -219,7 +219,7 @@ func TestOperationsServiceDefault_Get(t *testing.T) { t.Run("returns operation with error detail", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -238,7 +238,7 @@ func TestOperationsServiceDefault_Get(t *testing.T) { t.Run("returns error when get operation fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) accountClient.EXPECT().GetOperation(context.Background(), int64(404)).Return( @@ -254,7 +254,7 @@ func TestOperationsServiceDefault_Get(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() svc := NewOperationsService(cfgMgr, output, nil) @@ -267,7 +267,7 @@ func TestOperationsServiceDefault_Get(t *testing.T) { func TestOperationsServiceDefault_RequireAuthenticated(t *testing.T) { t.Run("returns error when auth service is nil", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() svc := NewOperationsService(cfgMgr, output, nil) @@ -278,7 +278,7 @@ func TestOperationsServiceDefault_RequireAuthenticated(t *testing.T) { t.Run("returns error when auth client resolution fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() authSvc := NewMockAuthService(t) authSvc.EXPECT().GetAuthenticatedClient(context.Background()).Return(nil, errors.New("auth failed")) @@ -291,7 +291,7 @@ func TestOperationsServiceDefault_RequireAuthenticated(t *testing.T) { t.Run("returns nil when authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) svc := NewOperationsService(cfgMgr, output, nil, WithOperationsAccountClient(accountClient)) @@ -304,7 +304,7 @@ func TestOperationsServiceDefault_RequireAuthenticated(t *testing.T) { func TestOperationsServiceDefault_Watch(t *testing.T) { t.Run("returns settled operation from WaitForOperation", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -330,7 +330,7 @@ func TestOperationsServiceDefault_Watch(t *testing.T) { t.Run("returns error when WaitForOperation fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) accountClient.EXPECT().WaitForOperation( @@ -373,7 +373,7 @@ func TestFormatOperationStatusWithColor(t *testing.T) { func TestOperationsListOptions_Filters(t *testing.T) { t.Run("applies status filter", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -398,7 +398,7 @@ func TestOperationsListOptions_Filters(t *testing.T) { func TestRenderOperationDetail(t *testing.T) { t.Run("renders operation with steps", func(t *testing.T) { - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() currentStep := 2 totalSteps := 5 op := &OperationDetail{ @@ -422,7 +422,7 @@ func TestRenderOperationDetail(t *testing.T) { }) t.Run("renders operation with error and message", func(t *testing.T) { - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() op := &OperationDetail{ ID: 2, CID: "QmErr", @@ -444,7 +444,7 @@ func TestRenderOperationDetail(t *testing.T) { }) t.Run("renders minimal operation", func(t *testing.T) { - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() op := &OperationDetail{ ID: 3, CID: "", @@ -467,7 +467,7 @@ func TestRenderOperationDetail(t *testing.T) { func TestOperationsServiceDefault_CIDPointer(t *testing.T) { t.Run("handles nil CID pointer from operation", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) now := time.Now() @@ -531,68 +531,12 @@ func TestNewOperationsGetCommand(t *testing.T) { }) } -type mockOperationsCommand struct { - status string - operation string - protocol string - cid string - limit int - watch bool - args []string -} - -func (m *mockOperationsCommand) String(name string) string { - switch name { - case FlagStatus: - return m.status - case FlagOperation: - return m.operation - case FlagProtocol: - return m.protocol - case FlagCID: - return m.cid - default: - return "" - } -} - -func (m *mockOperationsCommand) Int(name string) int { - switch name { - case FlagLimit: - return m.limit - default: - return 0 - } -} - -func (m *mockOperationsCommand) Bool(name string) bool { - switch name { - case FlagWatch: - return m.watch - default: - return false - } -} - -func (m *mockOperationsCommand) Args() cli.Args { - return &mockArgs{args: m.args} -} -func setupMockCfgMgr(t *testing.T) *configmocks.MockManager { - t.Helper() - cfgMgr := configmocks.NewMockManager(t) - cfgMgr.EXPECT().Config().Return(&config.Config{ - Secure: true, - BaseEndpoint: "pinner.xyz", - AuthToken: "test-token", - }).Maybe() - return cfgMgr -} func TestOperationsList(t *testing.T) { t.Run("successful list with results", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -609,7 +553,7 @@ func TestOperationsList(t *testing.T) { Total: 1, }, nil) - cmd := &mockOperationsCommand{} + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -620,8 +564,8 @@ func TestOperationsList(t *testing.T) { }) t.Run("successful list with empty results", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -630,7 +574,7 @@ func TestOperationsList(t *testing.T) { Total: 0, }, nil) - cmd := &mockOperationsCommand{} + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -641,13 +585,13 @@ func TestOperationsList(t *testing.T) { }) t.Run("returns error when not authenticated", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) - cmd := &mockOperationsCommand{} + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -659,14 +603,14 @@ func TestOperationsList(t *testing.T) { }) t.Run("returns error when list fails", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) opsSvc.EXPECT().List(context.Background(), OperationsListOptions{}).Return(nil, errors.New("server error")) - cmd := &mockOperationsCommand{} + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -678,8 +622,8 @@ func TestOperationsList(t *testing.T) { }) t.Run("passes filters to service", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -694,13 +638,12 @@ func TestOperationsList(t *testing.T) { Total: 0, }, nil) - cmd := &mockOperationsCommand{ - status: "running", - operation: "upload", - protocol: "ipfs", - cid: "QmTest", - limit: 5, - } + cmd := newMockCommand(). + withString(FlagStatus, "running"). + withString(FlagOperation, "upload"). + withString(FlagProtocol, "ipfs"). + withString(FlagCID, "QmTest"). + withInt(FlagLimit, 5) cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -711,8 +654,8 @@ func TestOperationsList(t *testing.T) { }) t.Run("returns error when cfgMgr factory fails", func(t *testing.T) { - output := NewOutputFormatter(false, false, false, false) - cmd := &mockOperationsCommand{} + output := newTestOutput() + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return nil, errors.New("config error") } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -726,8 +669,8 @@ func TestOperationsList(t *testing.T) { func TestOperationsGet(t *testing.T) { t.Run("successful get operation", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -745,7 +688,7 @@ func TestOperationsGet(t *testing.T) { UpdatedAt: "2024-01-01T00:00:00Z", }, nil) - cmd := &mockOperationsCommand{args: []string{"42"}} + cmd := newMockCommand().withArgs("42") cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -756,13 +699,13 @@ func TestOperationsGet(t *testing.T) { }) t.Run("returns error when no operation ID provided", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockOperationsCommand{args: []string{}} + cmd := newMockCommand() cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -774,13 +717,13 @@ func TestOperationsGet(t *testing.T) { }) t.Run("returns error for invalid operation ID", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockOperationsCommand{args: []string{"not-a-number"}} + cmd := newMockCommand().withArgs("not-a-number") cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -792,14 +735,14 @@ func TestOperationsGet(t *testing.T) { }) t.Run("returns error when get fails", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(nil) opsSvc.EXPECT().Get(context.Background(), int64(999)).Return(nil, fmt.Errorf("not found")) - cmd := &mockOperationsCommand{args: []string{"999"}} + cmd := newMockCommand().withArgs("999") cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -811,13 +754,13 @@ func TestOperationsGet(t *testing.T) { }) t.Run("returns error when not authenticated", func(t *testing.T) { - cfgMgr := setupMockCfgMgr(t) - output := NewOutputFormatter(false, false, false, false) + cfgMgr := newTestConfigMgr(t) + output := newTestOutput() opsSvc := NewMockOperationsService(t) opsSvc.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) - cmd := &mockOperationsCommand{args: []string{"1"}} + cmd := newMockCommand().withArgs("1") cfgMgrFactory := func() (config.Manager, error) { return cfgMgr, nil } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -829,8 +772,8 @@ func TestOperationsGet(t *testing.T) { }) t.Run("returns error when cfgMgr factory fails", func(t *testing.T) { - output := NewOutputFormatter(false, false, false, false) - cmd := &mockOperationsCommand{args: []string{"1"}} + output := newTestOutput() + cmd := newMockCommand().withArgs("1") cfgMgrFactory := func() (config.Manager, error) { return nil, errors.New("config error") } authSvcFactory := func(cm config.Manager, out Output, endpoint string) AuthService { return nil } @@ -895,3 +838,272 @@ func TestBuildOperationRows(t *testing.T) { assert.Equal(t, "IPFS", rows[0][2]) }) } + +func TestWatchOperationsList(t *testing.T) { + t.Run("exits cleanly when context cancelled after initial list", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).Return(&OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "running", Operation: "upload", OperationDisplayName: "Upload", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 50, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel context shortly after the initial list call completes, + // so the ticker loop picks up ctx.Done() before the 2s ticker fires. + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + err := watchOperationsList(ctx, opsSvc, output, OperationsListOptions{}) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("returns error when initial list fails", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).Return(nil, errors.New("server error")) + + err := watchOperationsList(context.Background(), opsSvc, output, OperationsListOptions{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "server error") + }) + + t.Run("exits when all operations are settled on initial list", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).Return(&OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "completed", Operation: "pin", OperationDisplayName: "Pin", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 100, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil) + + err := watchOperationsList(context.Background(), opsSvc, output, OperationsListOptions{}) + require.NoError(t, err) + }) + + t.Run("returns error when list fails during ticker loop", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + callCount := 0 + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).RunAndReturn( + func(ctx context.Context, opts OperationsListOptions) (*OperationsListResult, error) { + callCount++ + if callCount == 1 { + return &OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "running", Operation: "upload", OperationDisplayName: "Upload", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 50, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil + } + return nil, errors.New("network error") + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperationsList(ctx, opsSvc, output, OperationsListOptions{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "network error") + }) + + t.Run("exits when ticker loop finds empty operations", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + callCount := 0 + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).RunAndReturn( + func(ctx context.Context, opts OperationsListOptions) (*OperationsListResult, error) { + callCount++ + if callCount == 1 { + return &OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "running", Operation: "upload", OperationDisplayName: "Upload", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 50, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil + } + return &OperationsListResult{ + Operations: []OperationListItem{}, + Total: 0, + }, nil + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperationsList(ctx, opsSvc, output, OperationsListOptions{}) + require.NoError(t, err) + }) + + t.Run("exits when ticker loop finds all operations settled", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + callCount := 0 + opsSvc.EXPECT().List(mock.Anything, OperationsListOptions{}).RunAndReturn( + func(ctx context.Context, opts OperationsListOptions) (*OperationsListResult, error) { + callCount++ + if callCount == 1 { + return &OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "running", Operation: "upload", OperationDisplayName: "Upload", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 50, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil + } + return &OperationsListResult{ + Operations: []OperationListItem{ + {ID: 1, CID: "QmTest", Status: "completed", Operation: "upload", OperationDisplayName: "Upload", Protocol: "ipfs", ProtocolDisplayName: "IPFS", ProgressPercent: 100, StartedAt: "2024-01-01"}, + }, + Total: 1, + }, nil + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperationsList(ctx, opsSvc, output, OperationsListOptions{}) + require.NoError(t, err) + }) +} + +func TestWatchOperation(t *testing.T) { + t.Run("exits cleanly when context cancelled", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel quickly so the ticker loop exits via ctx.Done() + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + err := watchOperation(ctx, opsSvc, output, 42) + require.NoError(t, err) + }) + + t.Run("exits when operation is complete on first ticker check", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().Get(mock.Anything, int64(7)).Return(&OperationDetail{ + ID: 7, + CID: "QmDone", + Status: "completed", + StatusDisplayName: "Completed", + Operation: "pin", + OperationDisplayName: "Pin", + Protocol: "ipfs", + ProtocolDisplayName: "IPFS", + ProgressPercent: 100, + StartedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T00:01:00Z", + }, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperation(ctx, opsSvc, output, 7) + require.NoError(t, err) + }) + + t.Run("returns error when get fails during ticker loop", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().Get(mock.Anything, int64(99)).Return(nil, errors.New("not found")) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperation(ctx, opsSvc, output, 99) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("exits when operation reaches failed status", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + opsSvc.EXPECT().Get(mock.Anything, int64(5)).Return(&OperationDetail{ + ID: 5, + CID: "QmFail", + Status: "failed", + StatusDisplayName: "Failed", + Operation: "upload", + OperationDisplayName: "Upload", + Protocol: "ipfs", + ProtocolDisplayName: "IPFS", + ProgressPercent: 30, + StartedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T00:01:00Z", + Error: "disk full", + }, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := watchOperation(ctx, opsSvc, output, 5) + require.NoError(t, err) + }) + + t.Run("detects status change between checks", func(t *testing.T) { + opsSvc := NewMockOperationsService(t) + output := newTestOutput() + + callCount := 0 + opsSvc.EXPECT().Get(mock.Anything, int64(10)).RunAndReturn( + func(ctx context.Context, id int64) (*OperationDetail, error) { + callCount++ + if callCount == 1 { + return &OperationDetail{ + ID: 10, + CID: "QmProgress", + Status: "running", + StatusDisplayName: "Running", + Operation: "upload", + OperationDisplayName: "Upload", + Protocol: "ipfs", + ProtocolDisplayName: "IPFS", + ProgressPercent: 50, + StartedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T00:01:00Z", + }, nil + } + return &OperationDetail{ + ID: 10, + CID: "QmProgress", + Status: "completed", + StatusDisplayName: "Completed", + Operation: "upload", + OperationDisplayName: "Upload", + Protocol: "ipfs", + ProtocolDisplayName: "IPFS", + ProgressPercent: 100, + StartedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T00:02:00Z", + }, nil + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := watchOperation(ctx, opsSvc, output, 10) + require.NoError(t, err) + }) +} diff --git a/pkg/cli/output_test.go b/pkg/cli/output_test.go index 89787db..11827e6 100644 --- a/pkg/cli/output_test.go +++ b/pkg/cli/output_test.go @@ -7,6 +7,7 @@ import ( "errors" "strings" "testing" + "time" "github.com/ggwhite/go-masker" "github.com/stretchr/testify/assert" @@ -15,7 +16,7 @@ import ( func TestNewOutputFormatter(t *testing.T) { t.Run("returns humanFormatter when json is false", func(t *testing.T) { - formatter := NewOutputFormatter(false, false, false, false) + formatter := newTestOutput() assert.IsType(t, &humanFormatter{}, formatter) }) @@ -907,7 +908,7 @@ func TestNewOutputFormatterCombinations(t *testing.T) { func TestOutputFormatterSetWriter(t *testing.T) { t.Run("set writer for human formatter", func(t *testing.T) { var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) output.Print("test") @@ -951,7 +952,7 @@ func TestMaskToken(t *testing.T) { } func TestMaskSensitive(t *testing.T) { - humanFormatter := NewOutputFormatter(false, false, false, false).(*humanFormatter) + humanFormatter := newTestOutput().(*humanFormatter) jsonFormatter := NewOutputFormatter(true, false, false, false).(*jsonFormatter) testCases := []struct { @@ -1093,7 +1094,7 @@ func TestHumanFormatterWatch(t *testing.T) { } var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) err := output.Watch(ctx, fetcher, formatter) @@ -1117,7 +1118,7 @@ func TestHumanFormatterWatch(t *testing.T) { } var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) err := output.Watch(ctx, fetcher, formatter) @@ -1138,7 +1139,7 @@ func TestHumanFormatterWatch(t *testing.T) { } var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) err := output.Watch(ctx, fetcher, formatter) @@ -1159,7 +1160,7 @@ func TestHumanFormatterWatch(t *testing.T) { } var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) err := output.Watch(ctx, fetcher, formatter) @@ -1220,3 +1221,93 @@ func TestWrapLine(t *testing.T) { }) } } + +func TestJsonFormatterPrintBatchResult(t *testing.T) { + t.Run("prints batch result as JSON", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(true, false, false, false) + output.SetWriter(&buf) + + result := &BatchResult{ + Total: 3, + Succeeded: []OperationResult{{CID: "QmA", RequestID: "r1", Status: "pinned"}}, + Failed: []OperationError{{CID: "QmB", Error: "timeout"}}, + Skipped: []string{"QmC"}, + Duration: 5 * time.Second, + } + + output.PrintBatchResult(result) + + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, float64(3), parsed["total"]) + assert.Equal(t, float64(1), parsed["succeeded"]) + assert.Equal(t, float64(1), parsed["failed"]) + assert.Equal(t, float64(1), parsed["skipped"]) + }) + + t.Run("quiet mode suppresses output", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(true, false, true, false) + output.SetWriter(&buf) + + result := &BatchResult{Total: 1, Succeeded: []OperationResult{{CID: "QmA"}}} + output.PrintBatchResult(result) + + assert.Empty(t, buf.String()) + }) +} + +func TestHumanFormatterPrintBatchResult(t *testing.T) { + t.Run("prints batch result with failures", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(false, false, false, false) + output.SetWriter(&buf) + + result := &BatchResult{ + Total: 2, + Succeeded: []OperationResult{{CID: "QmA", RequestID: "r1", Status: "pinned"}}, + Failed: []OperationError{{CID: "QmB", Error: "timeout"}}, + Skipped: []string{}, + Duration: 5 * time.Second, + } + + output.PrintBatchResult(result) + + assert.Contains(t, buf.String(), "QmB") + assert.Contains(t, buf.String(), "timeout") + }) + + t.Run("quiet mode suppresses output", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(false, false, true, false) + output.SetWriter(&buf) + + result := &BatchResult{Total: 1, Succeeded: []OperationResult{{CID: "QmA"}}} + output.PrintBatchResult(result) + + assert.Empty(t, buf.String()) + }) +} + +func TestJsonFormatterWatch(t *testing.T) { + t.Run("watch with cancelled context returns immediately", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(true, false, false, false) + output.SetWriter(&buf) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := output.Watch(ctx, + func(ctx context.Context) (any, error) { + return []Pin{{CID: "QmA", Status: "pinned"}}, nil + }, + func(data any) (string, []string, [][]string) { + return "test", []string{"CID"}, [][]string{{"QmA"}} + }, + ) + assert.Error(t, err) + }) +} diff --git a/pkg/cli/pin.go b/pkg/cli/pin.go index f6c3e7b..386ffde 100644 --- a/pkg/cli/pin.go +++ b/pkg/cli/pin.go @@ -39,35 +39,22 @@ Examples: Metadata: WithTutorial(2, "Pin by CID", fmt.Sprintf("pinner pin %s", abbreviateCID(TutorialCID))), Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - _, err := pin(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + _, err = pin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) return err }, } } -// pinCommandGetter defines the interface for getting pin command flags. -type pinCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool - GetCID() string -} - -func pin(ctx context.Context, cmd pinCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) ([]string, error) { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return nil, err - } - +func pin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory) ([]string, error) { var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - secure := GetSecureSetting(c.Command, cfgMgr) - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } @@ -84,6 +71,7 @@ func pin(ctx context.Context, cmd pinCommandGetter, output Output, cfgMgrFactory dryRun := cmd.Bool(FlagDryRun) var cids []string + var err error if isStdinPipe() { cids, err = readLinesFromStdin() diff --git a/pkg/cli/pin_test.go b/pkg/cli/pin_test.go index c7dec64..ea071b6 100644 --- a/pkg/cli/pin_test.go +++ b/pkg/cli/pin_test.go @@ -69,37 +69,32 @@ func TestPinDryRun(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - var cmd pinCommandGetter + cmd := newMockCommand() if tt.name == "dry run with options" { - cmd = &mockPinCommand{ - cid: tt.cid, - name: "test-name", - parallel: 5, - continueOn: true, - dryRun: tt.dryRunFlag, - } + cmd = newMockCommand(). + withCID(tt.cid). + withString(FlagName, "test-name"). + withInt(FlagParallel, 5). + withBool(FlagContinue, true). + withBool(FlagDryRun, tt.dryRunFlag) } else { - cmd = &mockPinCommand{ - cid: tt.cid, - dryRun: tt.dryRunFlag, - } + cmd = newMockCommand(). + withCID(tt.cid). + withBool(FlagDryRun, tt.dryRunFlag) } - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } pinningServiceFactory := func(cfgMgr config.Manager, output Output) PinningService { return service } - _, err := pin(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + _, err := pin(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -122,7 +117,6 @@ func TestPin(t *testing.T) { setupMocks func(*configmocks.MockManager, *MockPinningService) wantErr bool errContains string - cfgMgrFactoryErr bool }{ { name: "successful pin operation", @@ -173,7 +167,6 @@ func TestPin(t *testing.T) { }, wantErr: true, errContains: "cid is required", - cfgMgrFactoryErr: false, }, { name: "returns error when no CIDs provided for batch", @@ -186,16 +179,6 @@ func TestPin(t *testing.T) { wantErr: true, errContains: "cid is required", }, - { - name: "returns error when config manager factory fails", - cid: "QmXxx", - nameFlag: "", - noWaitFlag: false, - setupMocks: func(cfgMgr *configmocks.MockManager, service *MockPinningService) {}, - wantErr: true, - errContains: "config error", - cfgMgrFactoryErr: true, - }, { name: "returns error when pinning fails", cid: "QmXxx", @@ -216,34 +199,22 @@ func TestPin(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockPinCommand{ - cid: tt.cid, - name: tt.nameFlag, - noWait: tt.noWaitFlag, - } - - var cfgMgrFactory ConfigManagerFactory - if tt.cfgMgrFactoryErr { - cfgMgrFactory = func() (config.Manager, error) { - return nil, errors.New("config error") - } - } else { - cfgMgrFactory = func() (config.Manager, error) { - return cfgMgr, nil - } - } + cmd := newMockCommand(). + withCID(tt.cid). + withString(FlagName, tt.nameFlag). + withBool(FlagNoWait, tt.noWaitFlag) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - _, err := pin(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + _, err := pin(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -304,27 +275,23 @@ func TestPinBatch(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockPinCommand{ - cid: tt.cids, - parallel: tt.parallel, - continueOn: tt.continueOn, - } + cmd := newMockCommand(). + withCID(tt.cids). + withInt(FlagParallel, tt.parallel). + withBool(FlagContinue, tt.continueOn) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - _, err := pin(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + _, err := pin(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -376,53 +343,7 @@ func TestNewPinCommand(t *testing.T) { }) } -// mockPinCommand is a mock implementation of commandGetter for testing. -type mockPinCommand struct { - cid string - name string - noWait bool - file string - parallel int - continueOn bool - dryRun bool -} - -func (m *mockPinCommand) GetCID() string { - return m.cid -} - -func (m *mockPinCommand) String(name string) string { - switch name { - case FlagName: - return m.name - case FlagFile: - return m.file - default: - return "" - } -} - -func (m *mockPinCommand) Int(name string) int { - switch name { - case FlagParallel: - return m.parallel - default: - return 0 - } -} -func (m *mockPinCommand) Bool(name string) bool { - switch name { - case FlagNoWait: - return m.noWait - case FlagContinue: - return m.continueOn - case FlagDryRun: - return m.dryRun - default: - return false - } -} func TestDefaultPinningServiceFactory(t *testing.T) { t.Run("creates pinning service with correct dependencies", func(t *testing.T) { @@ -433,7 +354,7 @@ func TestDefaultPinningServiceFactory(t *testing.T) { Secure: true, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := defaultPinningServiceFactory(cfgMgr, output) diff --git a/pkg/cli/pinning_client_batch_test.go b/pkg/cli/pinning_client_batch_test.go new file mode 100644 index 0000000..6e6249e --- /dev/null +++ b/pkg/cli/pinning_client_batch_test.go @@ -0,0 +1,522 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + go_pinning_service_http_client "github.com/ipfs/boxo/pinning/remote/client" + "github.com/ipfs/go-cid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + climocks "go.lumeweb.com/pinner-cli/pkg/cli/internal/mocks" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +var ( + batchCID1, _ = cid.Decode("QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn") + batchCID2, _ = cid.Decode("QmPZ9gcCEpqKTo6aq61g2nXGUhM4iCL3ewB6LDXZCtioEB") + batchCID3, _ = cid.Decode("QmSnuWmxptJZdLJpRarEy8h7u5ZdhbZaHjyspVvX7wVEQv") +) + +func newAuthenticatedBatchService(t *testing.T, client *climocks.MockPinningClient) *PinningServiceDefault { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Maybe().Return(&config.Config{ + AuthToken: testAuthToken, + }) + output := newTestOutput() + service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) + return service.(*PinningServiceDefault) +} + +func TestPinningServiceDefault_PinBatch(t *testing.T) { + t.Run("happy path with 3 CIDs", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + for _, c := range []cid.Cid{batchCID1, batchCID2, batchCID3} { + mockResult := NewMockPinStatusGetter(t, c, "batch-pin", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), c, mock.Anything).Return(mockResult, nil) + } + + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + "batch-pin", + BatchOptions{Parallel: 2}, + ) + + require.NoError(t, err) + assert.Equal(t, 3, result.Total) + assert.Len(t, result.Succeeded, 3) + assert.Empty(t, result.Failed) + assert.Empty(t, result.Skipped) + assert.Greater(t, result.Duration, time.Duration(0)) + }) + + t.Run("empty CID list returns empty result", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), []string{}, "batch-pin", BatchOptions{}) + + require.NoError(t, err) + assert.Equal(t, &BatchResult{}, result) + }) + + t.Run("ContinueOn=true collects failures", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + mockResult1 := NewMockPinStatusGetter(t, batchCID1, "", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), batchCID1, mock.Anything).Return(mockResult1, nil) + + client.EXPECT().Add(context.Background(), batchCID2, mock.Anything).Return( + nil, errors.New("pin failed for cid2"), + ) + + mockResult3 := NewMockPinStatusGetter(t, batchCID3, "", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), batchCID3, mock.Anything).Return(mockResult3, nil) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + "", + BatchOptions{Parallel: 1, ContinueOn: true}, + ) + + require.NoError(t, err) + assert.Equal(t, 3, result.Total) + assert.Len(t, result.Succeeded, 2) + assert.Len(t, result.Failed, 1) + assert.Equal(t, batchCID2.String(), result.Failed[0].CID) + assert.Contains(t, result.Failed[0].Error, "pin failed for cid2") + }) + + t.Run("ContinueOn=false returns first error", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + mockResult1 := NewMockPinStatusGetter(t, batchCID1, "", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), batchCID1, mock.Anything).Return(mockResult1, nil) + + client.EXPECT().Add(context.Background(), batchCID2, mock.Anything).Return( + nil, errors.New("pin failed for cid2"), + ) + + mockResult3 := NewMockPinStatusGetter(t, batchCID3, "", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), batchCID3, mock.Anything).Return(mockResult3, nil) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + "", + BatchOptions{Parallel: 1, ContinueOn: false}, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "pin failed for cid2") + assert.Equal(t, 3, result.Total) + assert.Empty(t, result.Failed) + }) + + t.Run("returns error when not authenticated", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + service := &PinningServiceDefault{ + pinningClient: nil, + configMgr: cfgMgr, + output: output, + apiEndpoint: "https://api.test.com", + } + + result, err := service.PinBatch(context.Background(), + []string{batchCID1.String()}, + "test", + BatchOptions{}, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + assert.Nil(t, result) + }) + + t.Run("defaults parallel to 1 when zero or negative", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + mockResult := NewMockPinStatusGetter(t, batchCID1, "", go_pinning_service_http_client.StatusPinned) + client.EXPECT().Add(context.Background(), batchCID1, mock.Anything).Return(mockResult, nil) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), + []string{batchCID1.String()}, + "", + BatchOptions{Parallel: 0}, + ) + + require.NoError(t, err) + assert.Len(t, result.Succeeded, 1) + }) + + t.Run("invalid CID in batch with ContinueOn", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + service := newAuthenticatedBatchService(t, client) + + result, err := service.PinBatch(context.Background(), + []string{"invalid-cid"}, + "", + BatchOptions{Parallel: 1, ContinueOn: true}, + ) + + require.NoError(t, err) + assert.Equal(t, 1, result.Total) + assert.Len(t, result.Failed, 1) + assert.Equal(t, "invalid-cid", result.Failed[0].CID) + assert.Contains(t, result.Failed[0].Error, "invalid CID") + }) +} + +func TestPinningServiceDefault_UnpinBatch(t *testing.T) { + t.Run("happy path with 3 CIDs", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + for _, c := range []cid.Cid{batchCID1, batchCID2, batchCID3} { + mockPin := NewMockPin(t, c, "") + mockResult := NewMockPinStatusGetterWithPin(t, mockPin, go_pinning_service_http_client.StatusPinned) + client.EXPECT().LsSync(context.Background(), mock.Anything).Return( + []go_pinning_service_http_client.PinStatusGetter{mockResult}, nil, + ) + client.EXPECT().DeleteByID(context.Background(), mock.Anything).Return(nil) + } + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + BatchOptions{Parallel: 2}, + ) + + require.NoError(t, err) + assert.Equal(t, 3, result.Total) + assert.Len(t, result.Succeeded, 3) + assert.Empty(t, result.Failed) + }) + + t.Run("empty CID list returns empty result", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinBatch(context.Background(), []string{}, BatchOptions{}) + + require.NoError(t, err) + assert.Equal(t, &BatchResult{}, result) + }) + + t.Run("ContinueOn=true collects failures", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + mockPin1 := NewMockPin(t, batchCID1, "") + mockResult1 := NewMockPinStatusGetterWithPin(t, mockPin1, go_pinning_service_http_client.StatusPinned) + mockPin3 := NewMockPin(t, batchCID3, "") + mockResult3 := NewMockPinStatusGetterWithPin(t, mockPin3, go_pinning_service_http_client.StatusPinned) + + lsCallCount := atomic.Int32{} + client.EXPECT().LsSync(context.Background(), mock.Anything).RunAndReturn( + func(ctx context.Context, opts ...go_pinning_service_http_client.LsOption) ([]go_pinning_service_http_client.PinStatusGetter, error) { + n := lsCallCount.Add(1) + switch n { + case 1: + return []go_pinning_service_http_client.PinStatusGetter{mockResult1}, nil + case 2: + return []go_pinning_service_http_client.PinStatusGetter{}, nil + case 3: + return []go_pinning_service_http_client.PinStatusGetter{mockResult3}, nil + default: + return nil, nil + } + }, + ).Times(3) + + deleteCallCount := atomic.Int32{} + client.EXPECT().DeleteByID(context.Background(), mock.Anything).RunAndReturn( + func(ctx context.Context, id string) error { + n := deleteCallCount.Add(1) + if n == 1 { + return nil + } + return nil + }, + ).Times(2) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + BatchOptions{Parallel: 1, ContinueOn: true}, + ) + + require.NoError(t, err) + assert.Equal(t, 3, result.Total) + assert.Len(t, result.Succeeded, 2) + assert.Len(t, result.Failed, 1) + assert.Equal(t, batchCID2.String(), result.Failed[0].CID) + }) + + t.Run("ContinueOn=false returns first error", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + mockPin1 := NewMockPin(t, batchCID1, "") + mockResult1 := NewMockPinStatusGetterWithPin(t, mockPin1, go_pinning_service_http_client.StatusPinned) + mockPin3 := NewMockPin(t, batchCID3, "") + mockResult3 := NewMockPinStatusGetterWithPin(t, mockPin3, go_pinning_service_http_client.StatusPinned) + + lsCallCount := atomic.Int32{} + client.EXPECT().LsSync(context.Background(), mock.Anything).RunAndReturn( + func(ctx context.Context, opts ...go_pinning_service_http_client.LsOption) ([]go_pinning_service_http_client.PinStatusGetter, error) { + n := lsCallCount.Add(1) + switch n { + case 1: + return []go_pinning_service_http_client.PinStatusGetter{mockResult1}, nil + case 2: + return []go_pinning_service_http_client.PinStatusGetter{}, nil + case 3: + return []go_pinning_service_http_client.PinStatusGetter{mockResult3}, nil + default: + return nil, nil + } + }, + ).Times(3) + + client.EXPECT().DeleteByID(context.Background(), mock.Anything).Return(nil).Times(2) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinBatch(context.Background(), + []string{batchCID1.String(), batchCID2.String(), batchCID3.String()}, + BatchOptions{Parallel: 1, ContinueOn: false}, + ) + + require.Error(t, err) + assert.Empty(t, result.Failed) + }) + + t.Run("returns error when not authenticated", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + service := &PinningServiceDefault{ + pinningClient: nil, + configMgr: cfgMgr, + output: output, + apiEndpoint: "https://api.test.com", + } + + result, err := service.UnpinBatch(context.Background(), + []string{batchCID1.String()}, + BatchOptions{}, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + assert.Nil(t, result) + }) +} + +func TestPinningServiceDefault_UnpinAll(t *testing.T) { + t.Run("unpins all listed pins", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + pins := make([]go_pinning_service_http_client.PinStatusGetter, 3) + cids := []cid.Cid{batchCID1, batchCID2, batchCID3} + for i, c := range cids { + p := &mockPin{ + cid: c, + name: fmt.Sprintf("pin-%d", i+1), + requestID: fmt.Sprintf("req-%d", i+1), + status: go_pinning_service_http_client.StatusPinned, + meta: map[string]string{}, + created: time.Now(), + origins: []string{}, + } + pins[i] = &mockPinStatusGetter{pin: p} + } + client.EXPECT().LsSync(context.Background(), mock.Anything).Return(pins, nil) + + for range cids { + client.EXPECT().DeleteByID(context.Background(), mock.Anything).Return(nil) + } + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{Parallel: 2}) + + require.NoError(t, err) + assert.Equal(t, 3, result.Total) + assert.Len(t, result.Succeeded, 3) + assert.Empty(t, result.Failed) + }) + + t.Run("empty list returns empty result", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + client.EXPECT().LsSync(context.Background(), mock.Anything).Return( + []go_pinning_service_http_client.PinStatusGetter{}, nil, + ) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{}) + + require.NoError(t, err) + assert.Equal(t, &BatchResult{}, result) + }) + + t.Run("returns error when list fails", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + client.EXPECT().LsSync(context.Background(), mock.Anything).Return( + nil, errors.New("list service error"), + ) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{}) + + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("returns error when not authenticated", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + output := newTestOutput() + + service := &PinningServiceDefault{ + pinningClient: nil, + configMgr: cfgMgr, + output: output, + apiEndpoint: "https://api.test.com", + } + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + assert.Nil(t, result) + }) + + t.Run("ContinueOn=true collects delete failures", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + pins := make([]go_pinning_service_http_client.PinStatusGetter, 2) + for i, c := range []cid.Cid{batchCID1, batchCID2} { + p := &mockPin{ + cid: c, + name: fmt.Sprintf("pin-%d", i+1), + requestID: fmt.Sprintf("req-%d", i+1), + status: go_pinning_service_http_client.StatusPinned, + meta: map[string]string{}, + created: time.Now(), + origins: []string{}, + } + pins[i] = &mockPinStatusGetter{pin: p} + } + client.EXPECT().LsSync(context.Background(), mock.Anything).Return(pins, nil) + + var deleteCallCount atomic.Int32 + client.EXPECT().DeleteByID(context.Background(), mock.Anything).RunAndReturn( + func(ctx context.Context, id string) error { + n := deleteCallCount.Add(1) + if n == 2 { + return errors.New("delete failed") + } + return nil + }, + ).Times(2) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{ + Parallel: 1, + ContinueOn: true, + }) + + require.NoError(t, err) + assert.Equal(t, 2, result.Total) + assert.Len(t, result.Succeeded, 1) + assert.Len(t, result.Failed, 1) + }) + + t.Run("ContinueOn=false returns first delete error", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + pins := make([]go_pinning_service_http_client.PinStatusGetter, 2) + for i, c := range []cid.Cid{batchCID1, batchCID2} { + p := &mockPin{ + cid: c, + name: fmt.Sprintf("pin-%d", i+1), + requestID: fmt.Sprintf("req-%d", i+1), + status: go_pinning_service_http_client.StatusPinned, + meta: map[string]string{}, + created: time.Now(), + origins: []string{}, + } + pins[i] = &mockPinStatusGetter{pin: p} + } + client.EXPECT().LsSync(context.Background(), mock.Anything).Return(pins, nil) + + var callCount atomic.Int32 + client.EXPECT().DeleteByID(context.Background(), mock.Anything).RunAndReturn( + func(ctx context.Context, id string) error { + n := callCount.Add(1) + if n == 1 { + return errors.New("delete failed") + } + return nil + }, + ).Times(2) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "", BatchOptions{ + Parallel: 1, + ContinueOn: false, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "delete failed") + assert.Empty(t, result.Failed) + }) + + t.Run("with status filter passes filter to list", func(t *testing.T) { + client := climocks.NewMockPinningClient(t) + + p := &mockPin{ + cid: batchCID1, + name: "failed-pin", + requestID: "req-failed", + status: go_pinning_service_http_client.StatusFailed, + meta: map[string]string{}, + created: time.Now(), + origins: []string{}, + } + client.EXPECT().LsSync(context.Background(), mock.Anything).Return( + []go_pinning_service_http_client.PinStatusGetter{&mockPinStatusGetter{pin: p}}, nil, + ) + client.EXPECT().DeleteByID(context.Background(), mock.Anything).Return(nil) + + service := newAuthenticatedBatchService(t, client) + + result, err := service.UnpinAll(context.Background(), "failed", BatchOptions{Parallel: 1}) + + require.NoError(t, err) + assert.Equal(t, 1, result.Total) + assert.Len(t, result.Succeeded, 1) + }) +} diff --git a/pkg/cli/pinning_client_test.go b/pkg/cli/pinning_client_test.go new file mode 100644 index 0000000..f5c46c1 --- /dev/null +++ b/pkg/cli/pinning_client_test.go @@ -0,0 +1,66 @@ +package cli + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWrapNetworkError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + result := WrapNetworkError("upload", nil) + assert.Nil(t, result) + }) + + t.Run("wraps error with operation context", func(t *testing.T) { + err := errors.New("connection refused") + result := WrapNetworkError("upload", err) + assert.Contains(t, result.Error(), "upload failed") + assert.Contains(t, result.Error(), "connection refused") + assert.Contains(t, result.Error(), "Check your internet connection") + }) +} + +func TestIsBoxoAuthError(t *testing.T) { + t.Run("nil error returns false", func(t *testing.T) { + assert.False(t, isBoxoAuthError(nil)) + }) + + t.Run("401 error matches", func(t *testing.T) { + err := errors.New("remote pinning service returned http error 401: unauthorized") + assert.True(t, isBoxoAuthError(err)) + }) + + t.Run("non-auth error does not match", func(t *testing.T) { + err := errors.New("connection refused") + assert.False(t, isBoxoAuthError(err)) + }) +} + +func TestWrapPinningError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + result := wrapPinningError("pin", nil, ErrNotAuthenticated) + assert.Nil(t, result) + }) + + t.Run("auth error wraps with auth message", func(t *testing.T) { + err := errors.New("remote pinning service returned http error 401: unauthorized") + result := wrapPinningError("pin", err, ErrNotAuthenticated) + assert.Contains(t, result.Error(), "authentication expired") + }) + + t.Run("non-auth error wraps with operation", func(t *testing.T) { + err := errors.New("timeout") + result := wrapPinningError("pin", err, ErrNotAuthenticated) + assert.Contains(t, result.Error(), "pin failed") + assert.Contains(t, result.Error(), "timeout") + }) +} + +func TestWithAuthToken(t *testing.T) { + opt := WithAuthToken("my-token") + s := &PinningServiceDefault{} + opt(s) + assert.Equal(t, "my-token", s.authToken) +} diff --git a/pkg/cli/pinning_service_test.go b/pkg/cli/pinning_service_test.go index 839165b..b4217d1 100644 --- a/pkg/cli/pinning_service_test.go +++ b/pkg/cli/pinning_service_test.go @@ -28,7 +28,7 @@ func TestNewPinningService(t *testing.T) { AuthToken: testAuthToken, }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com") assert.IsType(t, &PinningServiceDefault{}, service) @@ -42,7 +42,7 @@ func TestNewPinningService(t *testing.T) { AuthToken: "", }) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com") assert.IsType(t, &PinningServiceDefault{}, service) @@ -63,7 +63,7 @@ func TestPinningService_Pin(t *testing.T) { mockResult := NewMockPinStatusGetter(t, testCID, "", go_pinning_service_http_client.StatusPinned) client.EXPECT().Add(context.Background(), testCID, mock.Anything).Return(mockResult, nil) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) result, err := service.Pin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", "", false) @@ -83,7 +83,7 @@ func TestPinningService_Pin(t *testing.T) { mockResult := NewMockPinResult(t, "req-123", go_pinning_service_http_client.StatusPinned) client.EXPECT().Add(context.Background(), testCID, mock.Anything).Return(mockResult, nil) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) result, err := service.Pin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", "test-name", false) @@ -94,7 +94,7 @@ func TestPinningService_Pin(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: nil, @@ -115,7 +115,7 @@ func TestPinningService_Pin(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -137,7 +137,7 @@ func TestPinningService_Pin(t *testing.T) { errors.New("pinning service error"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Pin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", "", false) @@ -158,7 +158,7 @@ func TestPinningService_Pin(t *testing.T) { fmt.Errorf("remote pinning service returned http error 401: unauthorized"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Pin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", "", false) @@ -183,7 +183,7 @@ func TestPinningService_List(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) pins, err := service.List(context.Background(), "", 0, "") @@ -196,7 +196,7 @@ func TestPinningService_List(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: nil, @@ -223,7 +223,7 @@ func TestPinningService_List(t *testing.T) { errors.New("list service error"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.List(context.Background(), "", 0, "") @@ -244,7 +244,7 @@ func TestPinningService_List(t *testing.T) { fmt.Errorf("remote pinning service returned http error 401: unauthorized"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.List(context.Background(), "", 0, "") @@ -270,7 +270,7 @@ func TestPinningService_Status(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) status, err := service.Status(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", false) @@ -281,7 +281,7 @@ func TestPinningService_Status(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: nil, @@ -302,7 +302,7 @@ func TestPinningService_Status(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -324,7 +324,7 @@ func TestPinningService_Status(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Status(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", false) @@ -345,7 +345,7 @@ func TestPinningService_Status(t *testing.T) { errors.New("status check error"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Status(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", false) @@ -366,7 +366,7 @@ func TestPinningService_Status(t *testing.T) { fmt.Errorf("remote pinning service returned http error 401: unauthorized"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Status(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", false) @@ -393,7 +393,7 @@ func TestPinningService_Unpin(t *testing.T) { ) client.EXPECT().DeleteByID(context.Background(), "req-123").Return(nil) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) result, err := service.Unpin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", true) @@ -409,7 +409,7 @@ func TestPinningService_Unpin(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -421,7 +421,7 @@ func TestPinningService_Unpin(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: nil, @@ -442,7 +442,7 @@ func TestPinningService_Unpin(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -464,7 +464,7 @@ func TestPinningService_Unpin(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Unpin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", true) @@ -491,7 +491,7 @@ func TestPinningService_Unpin(t *testing.T) { errors.New("unpin service error"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) _, err := service.Unpin(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", true) @@ -521,7 +521,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) err := service.UpdateMetadata(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", []string{"key", "value"}, false) @@ -530,7 +530,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { t.Run("returns error when not authenticated", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: nil, @@ -551,7 +551,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -567,7 +567,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { }) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) @@ -589,7 +589,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { nil, ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) err := service.UpdateMetadata(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", []string{"key", "value"}, false) @@ -617,7 +617,7 @@ func TestPinningService_UpdateMetadata(t *testing.T) { errors.New("update service error"), ) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) err := service.UpdateMetadata(context.Background(), "QmUNLLsPACCz1vLxQVkXqqLX5R1X345qqfHbsf67hvA3Nn", []string{"key", "value"}, false) @@ -630,7 +630,7 @@ func TestPinningService_waitForPinCompletion(t *testing.T) { t.Run("successfully waits for pin completion", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() callCount := 0 @@ -661,7 +661,7 @@ func TestPinningService_waitForPinCompletion(t *testing.T) { t.Run("returns error on context cancellation", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := &PinningServiceDefault{ pinningClient: client, @@ -680,7 +680,7 @@ func TestPinningService_waitForPinCompletion(t *testing.T) { t.Run("returns error when pinning fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() client.EXPECT().GetStatusByID(mock.Anything, "req-123").Return( NewMockPinStatusGetter(t, testCID, "", go_pinning_service_http_client.StatusFailed), @@ -702,7 +702,7 @@ func TestPinningService_waitForPinCompletion(t *testing.T) { t.Run("returns error when status check fails", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) client := climocks.NewMockPinningClient(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() client.EXPECT().GetStatusByID(mock.Anything, "req-123").Return( nil, @@ -739,7 +739,7 @@ func TestPinningService_watchPinStatus(t *testing.T) { nil, ).Maybe() - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) ctx, cancel := context.WithCancel(context.Background()) @@ -772,7 +772,7 @@ func TestPinningService_watchPinStatus(t *testing.T) { nil, ).Maybe() - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewPinningService(cfgMgr, output, "https://api.test.com", WithPinningClient(client)) ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/cli/pins_add.go b/pkg/cli/pins_add.go index 2e22755..2f1402b 100644 --- a/pkg/cli/pins_add.go +++ b/pkg/cli/pins_add.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newPinsAddCommand() *cli.Command { @@ -35,13 +36,23 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return pinsAdd(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return pinsAdd(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } -func pinsAdd(ctx context.Context, cmd *cliCommandWrapper, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) error { - cids, err := pin(ctx, cmd, output, cfgMgrFactory, pinningServiceFactory) +func pinsAdd(ctx context.Context, cmd interface { + cidGetter + flagGetterWithIsSet + StringSlice(name string) []string +}, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory) error { + cids, err := pin(ctx, cmd, output, cfgMgr, authToken, secure, pinningServiceFactory) if err != nil { return err } @@ -64,14 +75,7 @@ func pinsAdd(ctx context.Context, cmd *cliCommandWrapper, output Output, cfgMgrF return err } - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - var pinningService PinningService - authToken := GetAuthToken(cmd.Command, cfgMgr) - secure := GetSecureSetting(cmd.Command, cfgMgr) if authToken != "" { pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { diff --git a/pkg/cli/pins_add_test.go b/pkg/cli/pins_add_test.go new file mode 100644 index 0000000..f5f903d --- /dev/null +++ b/pkg/cli/pins_add_test.go @@ -0,0 +1,170 @@ +package cli + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +func TestNewPinsAddCommandProperties(t *testing.T) { + cmd := newPinsAddCommand() + assert.Equal(t, "add", cmd.Name) + assert.NotNil(t, cmd.Action) + assert.NotEmpty(t, cmd.Flags) + assert.Contains(t, cmd.Usage, "Pin existing content") +} + +func TestPinsAddCommand_Flags(t *testing.T) { + cmd := newPinsAddCommand() + flagNames := getFlagNames(cmd) + require.Contains(t, flagNames, "name") + require.Contains(t, flagNames, "no-wait") + require.Contains(t, flagNames, "file") + require.Contains(t, flagNames, "parallel") + require.Contains(t, flagNames, "dry-run") + require.Contains(t, flagNames, "meta") +} + +func TestPinsAdd_DryRun(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + cfgMgr.EXPECT().Config().Return(&config.Config{ + Secure: true, + BaseEndpoint: "pinner.xyz", + AuthToken: "test-token", + }) + service.EXPECT().RequireAuthenticated().Return(nil) + + cmd := newMockCommand(). + withCID("QmXxx"). + withBool(FlagDryRun, true) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.NoError(t, err) +} + +func TestPinsAdd_NoMeta(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().Pin(context.Background(), "QmXxx", "", true).Return( + &PinResult{CID: "QmXxx", RequestID: "req-1", Status: "pinned"}, nil, + ) + + cmd := newMockCommand().withCID("QmXxx") + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.NoError(t, err) +} + +func TestPinsAdd_WithMetadata(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + // pin() calls factory + RequireAuthenticated + Pin + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().Pin(context.Background(), "QmXxx", "", true).Return( + &PinResult{CID: "QmXxx", RequestID: "req-1", Status: "pinned"}, nil, + ) + // pinsAdd metadata path calls factory + RequireAuthenticated + UpdateMetadata + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().UpdateMetadata(context.Background(), "QmXxx", []string{"owner", "alice"}, false).Return(nil) + + cmd := newMockCommand(). + withCID("QmXxx"). + withStringSlice(FlagMeta, []string{"owner=alice"}) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.NoError(t, err) +} + +func TestPinsAdd_PinError(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().Pin(context.Background(), "QmXxx", "", true).Return( + nil, errors.New("pinning failed"), + ) + + cmd := newMockCommand().withCID("QmXxx") + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "pinning failed") +} + +func TestPinsAdd_MetadataUpdateError(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().Pin(context.Background(), "QmXxx", "", true).Return( + &PinResult{CID: "QmXxx", RequestID: "req-1", Status: "pinned"}, nil, + ) + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().UpdateMetadata(context.Background(), "QmXxx", []string{"owner", "alice"}, false).Return(errors.New("metadata update failed")) + + cmd := newMockCommand(). + withCID("QmXxx"). + withStringSlice(FlagMeta, []string{"owner=alice"}) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "pin succeeded but metadata update failed") +} + +func TestPinsAdd_InvalidMetaFormat(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().Pin(context.Background(), "QmXxx", "", true).Return( + &PinResult{CID: "QmXxx", RequestID: "req-1", Status: "pinned"}, nil, + ) + + cmd := newMockCommand(). + withCID("QmXxx"). + withStringSlice(FlagMeta, []string{"invalid-no-equals"}) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + err := pinsAdd(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid metadata pair") +} diff --git a/pkg/cli/pins_ls.go b/pkg/cli/pins_ls.go index 9f57dc4..3fbb8c7 100644 --- a/pkg/cli/pins_ls.go +++ b/pkg/cli/pins_ls.go @@ -25,8 +25,14 @@ Examples: WatchFlag(), }, Action: func(ctx context.Context, c *cli.Command) error { - output := NewOutputFormatter(c.Bool(FlagJSON), c.Bool(FlagVerbose), c.Bool(FlagQuiet), c.Bool(FlagUnmask)) - return list(ctx, c, output, defaultConfigManagerFactory, defaultPinningServiceFactory) + output := setupOutput(c) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return list(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } diff --git a/pkg/cli/pins_rm.go b/pkg/cli/pins_rm.go index 9789414..c2a7628 100644 --- a/pkg/cli/pins_rm.go +++ b/pkg/cli/pins_rm.go @@ -36,10 +36,16 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) if c.Bool(FlagAll) { - return unpinAll(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + prompter := &PTermConfirmPrompter{} + return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, prompter) } - return unpin(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory) }, } } diff --git a/pkg/cli/pins_status.go b/pkg/cli/pins_status.go index 934e8cb..5e5a853 100644 --- a/pkg/cli/pins_status.go +++ b/pkg/cli/pins_status.go @@ -26,7 +26,12 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return status(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory, defaultStatusServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, defaultStatusServiceFactory) }, } } diff --git a/pkg/cli/pins_update.go b/pkg/cli/pins_update.go index 03170e0..77b290f 100644 --- a/pkg/cli/pins_update.go +++ b/pkg/cli/pins_update.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newPinsUpdateCommand() *cli.Command { @@ -28,36 +29,25 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return pinsUpdate(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return pinsUpdate(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } -// pinsUpdateCommandGetter defines the interface for getting pins update command flags. -type pinsUpdateCommandGetter interface { - String(name string) string +func pinsUpdate(ctx context.Context, cmd interface { + cidGetter + flagGetterWithIsSet StringSlice(name string) []string - Bool(name string) bool - IsSet(name string) bool - GetCID() string -} - -func pinsUpdate(ctx context.Context, cmd pinsUpdateCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - - var secure bool +}, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory) error { var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - secure = GetSecureSetting(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } diff --git a/pkg/cli/pins_update_test.go b/pkg/cli/pins_update_test.go index 3702a2c..b521067 100644 --- a/pkg/cli/pins_update_test.go +++ b/pkg/cli/pins_update_test.go @@ -11,65 +11,21 @@ import ( configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" ) -type mockPinsUpdateCommandGetter struct { - cid string - name string - meta []string - clearMeta bool - dryRun bool - isSet map[string]bool -} - -func (m *mockPinsUpdateCommandGetter) String(name string) string { - if name == FlagName { - return m.name - } - return "" -} - -func (m *mockPinsUpdateCommandGetter) StringSlice(name string) []string { - if name == FlagMeta { - return m.meta - } - return nil -} - -func (m *mockPinsUpdateCommandGetter) Bool(name string) bool { - switch name { - case FlagClearMeta: - return m.clearMeta - case FlagDryRun: - return m.dryRun - } - return false -} - -func (m *mockPinsUpdateCommandGetter) IsSet(name string) bool { - return m.isSet[name] -} - -func (m *mockPinsUpdateCommandGetter) GetCID() string { - return m.cid -} - func TestPinsUpdate(t *testing.T) { t.Run("returns error when cid is missing", func(t *testing.T) { service := NewMockPinningService(t) service.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "", - isSet: map[string]bool{FlagName: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return configmocks.NewMockManager(t), nil - } + cmd := newMockCommand(). + withIsSet(FlagName, true) + + output := newTestOutput() + cfgMgr := configmocks.NewMockManager(t) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.Error(t, err) assert.Contains(t, err.Error(), "cid is required") }) @@ -78,19 +34,16 @@ func TestPinsUpdate(t *testing.T) { service := NewMockPinningService(t) service.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - isSet: map[string]bool{}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return configmocks.NewMockManager(t), nil - } + cmd := newMockCommand(). + withCID("QmTest") + + output := newTestOutput() + cfgMgr := configmocks.NewMockManager(t) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.Error(t, err) assert.Contains(t, err.Error(), "at least one field must be provided for update") }) @@ -99,20 +52,18 @@ func TestPinsUpdate(t *testing.T) { service := NewMockPinningService(t) service.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - meta: []string{"invalid"}, - isSet: map[string]bool{FlagMeta: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return configmocks.NewMockManager(t), nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withStringSlice(FlagMeta, []string{"invalid"}). + withIsSet(FlagMeta, true) + + output := newTestOutput() + cfgMgr := configmocks.NewMockManager(t) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.Error(t, err) assert.Contains(t, err.Error(), "expected key=value format") }) @@ -124,20 +75,17 @@ func TestPinsUpdate(t *testing.T) { service := NewMockPinningService(t) service.EXPECT().RequireAuthenticated().Return(ErrNotAuthenticated) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - name: "renamed", - isSet: map[string]bool{FlagName: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withString(FlagName, "renamed"). + withIsSet(FlagName, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) }) @@ -152,20 +100,17 @@ func TestPinsUpdate(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePin(mock.Anything, "QmTest", "renamed", []string(nil), false).Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - name: "renamed", - isSet: map[string]bool{FlagName: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withString(FlagName, "renamed"). + withIsSet(FlagName, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) @@ -179,20 +124,17 @@ func TestPinsUpdate(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePin(mock.Anything, "QmTest", "", []string{"env", "prod"}, false).Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - meta: []string{"env=prod"}, - isSet: map[string]bool{FlagMeta: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withStringSlice(FlagMeta, []string{"env=prod"}). + withIsSet(FlagMeta, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) @@ -206,20 +148,17 @@ func TestPinsUpdate(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePin(mock.Anything, "QmTest", "", []string(nil), true).Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - clearMeta: true, - isSet: map[string]bool{FlagClearMeta: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withBool(FlagClearMeta, true). + withIsSet(FlagClearMeta, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) @@ -233,21 +172,19 @@ func TestPinsUpdate(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePin(mock.Anything, "QmTest", "", []string{"fresh", "start"}, true).Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - meta: []string{"fresh=start"}, - clearMeta: true, - isSet: map[string]bool{FlagClearMeta: true, FlagMeta: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withStringSlice(FlagMeta, []string{"fresh=start"}). + withBool(FlagClearMeta, true). + withIsSet(FlagClearMeta, true). + withIsSet(FlagMeta, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) @@ -261,21 +198,19 @@ func TestPinsUpdate(t *testing.T) { service.EXPECT().RequireAuthenticated().Return(nil) service.EXPECT().UpdatePin(mock.Anything, "QmTest", "renamed", []string{"env", "prod"}, false).Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - name: "renamed", - meta: []string{"env=prod"}, - isSet: map[string]bool{FlagName: true, FlagMeta: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withString(FlagName, "renamed"). + withStringSlice(FlagMeta, []string{"env=prod"}). + withIsSet(FlagName, true). + withIsSet(FlagMeta, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) @@ -290,22 +225,21 @@ func TestPinsUpdate(t *testing.T) { service := NewMockPinningService(t) service.EXPECT().RequireAuthenticated().Return(nil) - cmd := &mockPinsUpdateCommandGetter{ - cid: "QmTest", - name: "renamed", - meta: []string{"env=prod"}, - dryRun: true, - isSet: map[string]bool{FlagName: true, FlagMeta: true, FlagDryRun: true}, - } - output := NewOutputFormatter(false, false, false, false) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } + cmd := newMockCommand(). + withCID("QmTest"). + withString(FlagName, "renamed"). + withStringSlice(FlagMeta, []string{"env=prod"}). + withBool(FlagDryRun, true). + withIsSet(FlagName, true). + withIsSet(FlagMeta, true). + withIsSet(FlagDryRun, true) + + output := newTestOutput() pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := pinsUpdate(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := pinsUpdate(context.Background(), cmd, output, cfgMgr, "", true, pinningServiceFactory) assert.NoError(t, err) }) } diff --git a/pkg/cli/progress_test.go b/pkg/cli/progress_test.go index 00c2635..dc0b8a7 100644 --- a/pkg/cli/progress_test.go +++ b/pkg/cli/progress_test.go @@ -254,3 +254,33 @@ func TestProgressWriter_PipedOutput(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(data), n) } + +func TestProgressWriterStartStop(t *testing.T) { + t.Run("start and stop with disabled progress", func(t *testing.T) { + data := []byte("hello") + pw := NewProgressWriter(bytes.NewReader(data), 0, true, "test") + assert.False(t, pw.enabled) + + err := pw.Start() + assert.NoError(t, err) + + err = pw.Stop() + assert.NoError(t, err) + }) +} + +func TestBatchProgressTrackerStartStop(t *testing.T) { + t.Run("start and stop with disabled progress", func(t *testing.T) { + bt := NewBatchProgressTracker(0, true, "test") + assert.False(t, bt.enabled) + + err := bt.Start() + assert.NoError(t, err) + + bt.Increment() + assert.Equal(t, 1, bt.completed) + + err = bt.Stop() + assert.NoError(t, err) + }) +} diff --git a/pkg/cli/prompt_interfaces.go b/pkg/cli/prompt_interfaces.go new file mode 100644 index 0000000..37e37f9 --- /dev/null +++ b/pkg/cli/prompt_interfaces.go @@ -0,0 +1,34 @@ +package cli + +// SelectPrompter prompts the user to select from a list of items. +type SelectPrompter interface { + // Select displays a selection prompt and returns the index and value of the selected item. + Select(label string, items []string) (int, string, error) +} + +// ContinuePrompter prompts the user to continue. +type ContinuePrompter interface { + // Continue displays a "press enter to continue" prompt. + Continue() error +} + +// Spinner displays a spinner with status messages. +type Spinner interface { + // Start begins the spinner with the given message. + Start(message string) error + // UpdateText updates the spinner message. + UpdateText(message string) + // Success stops the spinner with a success message. + Success(message string) + // Fail stops the spinner with a failure message. + Fail(message string) + // Stop stops the spinner without a success or failure message. + Stop() error +} + +// ConfirmPrompter prompts the user to confirm by typing an expected value. +type ConfirmPrompter interface { + // Confirm displays a prompt requiring the user to type the expected value. + // Returns the user's input. + Confirm(label string, expected string) (string, error) +} diff --git a/pkg/cli/prompt_mock.go b/pkg/cli/prompt_mock.go new file mode 100644 index 0000000..565be70 --- /dev/null +++ b/pkg/cli/prompt_mock.go @@ -0,0 +1,60 @@ +package cli + +// MockSelectPrompter is a mock implementation of SelectPrompter for testing. +type MockSelectPrompter struct { + SelectResult int + SelectString string + SelectErr error +} + +func (m *MockSelectPrompter) Select(label string, items []string) (int, string, error) { + return m.SelectResult, m.SelectString, m.SelectErr +} + +// MockContinuePrompter is a mock implementation of ContinuePrompter for testing. +type MockContinuePrompter struct { + ContinueErr error +} + +func (m *MockContinuePrompter) Continue() error { + return m.ContinueErr +} + +// MockSpinner is a mock implementation of Spinner for testing. +type MockSpinner struct { + StartErr error + StopErr error + Messages []string +} + +func (m *MockSpinner) Start(message string) error { + m.Messages = append(m.Messages, "start:"+message) + return m.StartErr +} + +func (m *MockSpinner) UpdateText(message string) { + m.Messages = append(m.Messages, "update:"+message) +} + +func (m *MockSpinner) Success(message string) { + m.Messages = append(m.Messages, "success:"+message) +} + +func (m *MockSpinner) Fail(message string) { + m.Messages = append(m.Messages, "fail:"+message) +} + +func (m *MockSpinner) Stop() error { + m.Messages = append(m.Messages, "stop") + return m.StopErr +} + +// MockConfirmPrompter is a mock implementation of ConfirmPrompter for testing. +type MockConfirmPrompter struct { + ConfirmResult string + ConfirmErr error +} + +func (m *MockConfirmPrompter) Confirm(label string, expected string) (string, error) { + return m.ConfirmResult, m.ConfirmErr +} diff --git a/pkg/cli/prompt_pterm.go b/pkg/cli/prompt_pterm.go new file mode 100644 index 0000000..2e9fdd5 --- /dev/null +++ b/pkg/cli/prompt_pterm.go @@ -0,0 +1,95 @@ +package cli + +import ( + "fmt" + + "github.com/manifoldco/promptui" + "github.com/pterm/pterm" +) + +// PTermSelectPrompter implements SelectPrompter using promptui.Select. +type PTermSelectPrompter struct{} + +func (p *PTermSelectPrompter) Select(label string, items []string) (int, string, error) { + prompt := promptui.Select{ + Label: label, + Items: items, + } + idx, result, err := prompt.Run() + if err != nil { + return 0, "", handleInterrupt(err) + } + return idx, result, nil +} + +// PTermContinuePrompter implements ContinuePrompter using pterm.DefaultInteractiveContinue. +type PTermContinuePrompter struct{} + +func (p *PTermContinuePrompter) Continue() error { + _, err := pterm.DefaultInteractiveContinue.Show() + return err +} + +// PTermSpinner implements Spinner using pterm.DefaultSpinner. +type PTermSpinner struct { + spinner *pterm.SpinnerPrinter + started bool +} + +func (s *PTermSpinner) Start(message string) error { + spinner, err := pterm.DefaultSpinner.Start(message) + if err != nil { + return fmt.Errorf("failed to start spinner: %w", err) + } + s.spinner = spinner + s.started = true + return nil +} + +func (s *PTermSpinner) UpdateText(message string) { + if s.spinner != nil { + s.spinner.UpdateText(message) + } +} + +func (s *PTermSpinner) Success(message string) { + if s.spinner != nil { + s.spinner.Success(message) + s.started = false + } +} + +func (s *PTermSpinner) Fail(message string) { + if s.spinner != nil { + s.spinner.Fail(message) + s.started = false + } +} + +func (s *PTermSpinner) Stop() error { + if s.spinner != nil && s.started { + s.spinner.Stop() + s.started = false + } + return nil +} + +// PTermConfirmPrompter implements ConfirmPrompter using promptui.Prompt. +type PTermConfirmPrompter struct{} + +func (p *PTermConfirmPrompter) Confirm(label string, expected string) (string, error) { + prompt := promptui.Prompt{ + Label: label, + Validate: func(input string) error { + if input != expected { + return fmt.Errorf("must type %s to confirm", expected) + } + return nil + }, + } + result, err := prompt.Run() + if err != nil { + return "", handleInterrupt(err) + } + return result, nil +} diff --git a/pkg/cli/register.go b/pkg/cli/register.go index 5f2cf72..e360ea9 100644 --- a/pkg/cli/register.go +++ b/pkg/cli/register.go @@ -50,12 +50,12 @@ Examples: }, Action: func(ctx context.Context, cmd *cli.Command) error { output := setupOutput(cmd) - return register(ctx, cmd, output, defaultConfigManagerFactory, defaultAuthServiceFactory) + return register(ctx, newCLICommandWrapper(cmd), output, defaultConfigManagerFactory, defaultAuthServiceFactory) }, } } -func register(ctx context.Context, cmd *cli.Command, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { +func register(ctx context.Context, cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory) error { email := cmd.String(FlagEmail) firstName := cmd.String(FlagFirstName) lastName := cmd.String(FlagLastName) diff --git a/pkg/cli/register_test.go b/pkg/cli/register_test.go new file mode 100644 index 0000000..ada5a97 --- /dev/null +++ b/pkg/cli/register_test.go @@ -0,0 +1,179 @@ +package cli + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" +) + +func TestNewRegisterCommand(t *testing.T) { + cmd := newRegisterCommand() + + assert.Equal(t, "register", cmd.Name) + assert.Equal(t, "Setup", cmd.Category) + assert.NotEmpty(t, cmd.Usage) + assert.NotEmpty(t, cmd.Description) + assert.NotNil(t, cmd.Action) + + flagNames := getFlagNames(cmd) + nameSet := make(map[string]bool) + for _, n := range flagNames { + nameSet[n] = true + } + + expectedFlags := []string{FlagEmail, FlagFirstName, FlagLastName, FlagPassword} + for _, f := range expectedFlags { + assert.True(t, nameSet[f], "register command should have flag --%s", f) + } +} + +func TestRegisterAllFlagsProvided(t *testing.T) { + authService := NewMockAuthService(t) + output := newTestOutput() + + authService.EXPECT().Register(context.Background(), "user@example.com", "John", "Doe", "secret123").Return(nil) + + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return authService + } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.NoError(t, err) +} + +func TestRegisterConfigManagerError(t *testing.T) { + output := newTestOutput() + + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + cfgMgrFactory := func() (config.Manager, error) { + return nil, errors.New("config error") + } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return nil + } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create config manager") +} + +func TestRegisterAuthServiceError(t *testing.T) { + authService := NewMockAuthService(t) + output := newTestOutput() + + authService.EXPECT().Register(context.Background(), "user@example.com", "John", "Doe", "secret123"). + Return(errors.New("registration failed")) + + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { + return authService + } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration failed") +} + +func TestRegisterMissingEmailPrompts(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { return nil } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read email") +} + +func TestRegisterMissingFirstNamePrompts(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { return nil } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read first name") +} + +func TestRegisterMissingLastNamePrompts(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagPassword, Value: "secret123"}, + }, + } + + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { return nil } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read last name") +} + +func TestRegisterMissingPasswordPrompts(t *testing.T) { + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.StringFlag{Name: FlagEmail, Value: "user@example.com"}, + &cli.StringFlag{Name: FlagFirstName, Value: "John"}, + &cli.StringFlag{Name: FlagLastName, Value: "Doe"}, + }, + } + + output := newTestOutput() + cfgMgrFactory := func() (config.Manager, error) { return newTestConfigMgr(t), nil } + authServiceFactory := func(cfgMgr config.Manager, output Output, apiEndpoint string) AuthService { return nil } + + err := register(context.Background(), newCLICommandWrapper(cmd), output, cfgMgrFactory, authServiceFactory) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read password") +} diff --git a/pkg/cli/restructure_integration_test.go b/pkg/cli/restructure_integration_test.go index 12a0e4e..6fe1c6c 100644 --- a/pkg/cli/restructure_integration_test.go +++ b/pkg/cli/restructure_integration_test.go @@ -40,8 +40,8 @@ func TestIntegration_AliasEquivalence_PinAndPinsAdd(t *testing.T) { pinCmd := newPinCommand() pinsAddCmd := newPinsAddCommand() - pinFlags := getFlagNames(pinCmd.Flags) - pinsAddFlags := getFlagNames(pinsAddCmd.Flags) + pinFlags := getFlagNames(pinCmd) + pinsAddFlags := getFlagNames(pinsAddCmd) // Both must have --name and --no-wait for _, required := range []string{FlagName, FlagNoWait} { @@ -65,8 +65,8 @@ func TestIntegration_AliasEquivalence_UnpinAndPinsRm(t *testing.T) { unpinCmd := newUnpinCommand() pinsRmCmd := newPinsRmCommand() - unpinFlags := getFlagNames(unpinCmd.Flags) - pinsRmFlags := getFlagNames(pinsRmCmd.Flags) + unpinFlags := getFlagNames(unpinCmd) + pinsRmFlags := getFlagNames(pinsRmCmd) // Both must have --force and --confirm for _, required := range []string{FlagForce, FlagConfirm} { @@ -91,8 +91,8 @@ func TestIntegration_AliasEquivalence_ListAndPinsLs(t *testing.T) { listCmd := newListCommand() pinsLsCmd := newPinsLsCommand() - listFlags := getFlagNames(listCmd.Flags) - pinsLsFlags := getFlagNames(pinsLsCmd.Flags) + listFlags := getFlagNames(listCmd) + pinsLsFlags := getFlagNames(pinsLsCmd) // Both must have --name, --limit, --status, --watch for _, required := range []string{FlagName, FlagLimit, FlagStatus, FlagWatch} { @@ -107,8 +107,8 @@ func TestIntegration_AliasEquivalence_StatusAndPinsStatus(t *testing.T) { statusCmd := newStatusCommand() pinsStatusCmd := newPinsStatusCommand() - statusFlags := getFlagNames(statusCmd.Flags) - pinsStatusFlags := getFlagNames(pinsStatusCmd.Flags) + statusFlags := getFlagNames(statusCmd) + pinsStatusFlags := getFlagNames(pinsStatusCmd) // Both must have --watch assert.Contains(t, statusFlags, "watch", "status command should have --watch flag") @@ -175,7 +175,7 @@ func TestIntegration_MetaFlagOnCreationAndUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - flagNames := getFlagNames(tt.cmd.Flags) + flagNames := getFlagNames(tt.cmd) if tt.hasMeta { assert.Contains(t, flagNames, FlagMeta, "%s should have --meta flag", tt.name) } else { @@ -188,7 +188,7 @@ func TestIntegration_MetaFlagOnCreationAndUpdate(t *testing.T) { // TestIntegration_ClearMetaOnUpdate verifies that pins update has --clear-meta flag. func TestIntegration_ClearMetaOnUpdate(t *testing.T) { cmd := newPinsUpdateCommand() - flagNames := getFlagNames(cmd.Flags) + flagNames := getFlagNames(cmd) assert.Contains(t, flagNames, FlagClearMeta, "pins update should have --clear-meta flag") assert.Contains(t, flagNames, FlagName, "pins update should have --name flag") @@ -209,7 +209,7 @@ func TestIntegration_ForceFlagOnDestructiveCommands(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - flagNames := getFlagNames(tt.cmd.Flags) + flagNames := getFlagNames(tt.cmd) assert.Contains(t, flagNames, FlagForce, "%s should have --force flag", tt.name) }) } @@ -284,7 +284,7 @@ func TestIntegration_PinsSubcommands(t *testing.T) { assert.Len(t, cmd.Commands, 5, "pins should have exactly 5 subcommands") expected := []string{"add", "rm", "ls", "status", "update"} - names := getSubcommandNames(cmd.Commands) + names := getSubcommandNames(cmd) for _, name := range expected { assert.Contains(t, names, name, "pins should have %q subcommand", name) @@ -294,7 +294,7 @@ func TestIntegration_PinsSubcommands(t *testing.T) { // TestIntegration_PinsRmFlags verifies that pins rm has --all, --force, --status flags. func TestIntegration_PinsRmFlags(t *testing.T) { cmd := newPinsRmCommand() - flagNames := getFlagNames(cmd.Flags) + flagNames := getFlagNames(cmd) assert.Contains(t, flagNames, FlagAll, "pins rm should have --all flag") assert.Contains(t, flagNames, FlagForce, "pins rm should have --force flag") @@ -308,7 +308,7 @@ func TestIntegration_PinsRmFlags(t *testing.T) { // --clear-meta, and --dry-run flags. func TestIntegration_PinsUpdateFlags(t *testing.T) { cmd := newPinsUpdateCommand() - flagNames := getFlagNames(cmd.Flags) + flagNames := getFlagNames(cmd) assert.Contains(t, flagNames, FlagName, "pins update should have --name flag") assert.Contains(t, flagNames, FlagMeta, "pins update should have --meta flag") @@ -364,7 +364,7 @@ func TestIntegration_ShellCompletion(t *testing.T) { // TestIntegration_UploadHasNoWait verifies that upload command has --no-wait flag. func TestIntegration_UploadHasNoWait(t *testing.T) { cmd := newUploadCommand() - flagNames := getFlagNames(cmd.Flags) + flagNames := getFlagNames(cmd) assert.Contains(t, flagNames, FlagNoWait, "upload should have --no-wait flag") assert.Contains(t, flagNames, FlagMeta, "upload should have --meta flag") diff --git a/pkg/cli/root_test.go b/pkg/cli/root_test.go new file mode 100644 index 0000000..7085af6 --- /dev/null +++ b/pkg/cli/root_test.go @@ -0,0 +1,73 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRootCommandProperties(t *testing.T) { + cmd := NewRootCommand() + + assert.Equal(t, "pinner", cmd.Name) + assert.Equal(t, "Simple IPFS Pinning CLI", cmd.Usage) + assert.True(t, cmd.EnableShellCompletion) + assert.NotNil(t, cmd.Action) +} + +func TestRootCommandAction(t *testing.T) { + cmd := NewRootCommand() + require.NotNil(t, cmd.Action, "root command should have an action") +} + +func TestRootCommandAllSubcommands(t *testing.T) { + cmd := NewRootCommand() + + expectedSubcommands := []string{ + "setup", "auth", "register", "confirm-email", "account", + "upload", "download", "cat", "ls", + "pin", "pins", "list", "status", "unpin", + "metadata", "operations", "config", "doctor", "bench", + "dns", "ipns", "websites", "admin", + } + + subcommandNames := getSubcommandNames(cmd) + nameSet := make(map[string]bool) + for _, n := range subcommandNames { + nameSet[n] = true + } + + for _, expected := range expectedSubcommands { + assert.True(t, nameSet[expected], "root command should have subcommand %q", expected) + } +} + +func TestRootCommandHasGlobalFlags(t *testing.T) { + cmd := NewRootCommand() + + flagNames := getFlagNames(cmd) + nameSet := make(map[string]bool) + for _, n := range flagNames { + nameSet[n] = true + } + + requiredFlags := []string{FlagJSON, FlagVerbose, FlagQuiet, FlagUnmask, FlagAuthToken, FlagSecure} + for _, f := range requiredFlags { + assert.True(t, nameSet[f], "root command should have global flag --%s", f) + } +} + +func TestRootCommandDescription(t *testing.T) { + cmd := NewRootCommand() + + require.NotEmpty(t, cmd.Description) + assert.Contains(t, cmd.Description, "pinner setup") + assert.Contains(t, cmd.Description, "pinner upload") +} + +func TestRun(t *testing.T) { + err := Run(context.Background(), []string{"pinner", "--version"}) + require.NoError(t, err, "Run with --version should succeed") +} diff --git a/pkg/cli/setup.go b/pkg/cli/setup.go index 300b3eb..baa3af7 100644 --- a/pkg/cli/setup.go +++ b/pkg/cli/setup.go @@ -53,7 +53,7 @@ Examples: // runSetupCommand is the testable entry point for setup. func runSetupWizard( ctx context.Context, - cmd commandGetter, + cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, @@ -64,7 +64,7 @@ func runSetupWizard( // runSetupWizardWithFactories is the testable implementation with dependency injection. func runSetupWizardWithFactories( ctx context.Context, - cmd commandGetter, + cmd flagGetter, output Output, cfgMgrFactory ConfigManagerFactory, authServiceFactory AuthServiceFactory, diff --git a/pkg/cli/setup_mock.go b/pkg/cli/setup_mock.go index 618114f..bdcd22e 100644 --- a/pkg/cli/setup_mock.go +++ b/pkg/cli/setup_mock.go @@ -23,6 +23,11 @@ type MockSetupUI struct { ContinueError error + SelectResult int + SelectString string + SelectErr error + ContinueErr error + AuthExecuted bool ConfigExecuted bool TutorialExecuted bool @@ -174,3 +179,11 @@ func (m *MockSetupUI) ExecuteCompletionStep(_ *SetupWizard) error { return nil } + +func (m *MockSetupUI) Select(label string, items []string) (int, string, error) { + return m.SelectResult, m.SelectString, m.SelectErr +} + +func (m *MockSetupUI) Continue() error { + return m.ContinueErr +} diff --git a/pkg/cli/setup_pterm.go b/pkg/cli/setup_pterm.go index 25738f8..46ed62f 100644 --- a/pkg/cli/setup_pterm.go +++ b/pkg/cli/setup_pterm.go @@ -7,35 +7,27 @@ import ( "runtime" "strings" - "github.com/manifoldco/promptui" "github.com/pterm/pterm" "github.com/pterm/pterm/putils" "go.lumeweb.com/pinner-cli/pkg/cli/wizard" ) -// runSelect executes a select prompt and handles interrupts. -// Returns the index, selected item, or error. -func runSelect(prompt *promptui.Select) (int, string, error) { - idx, result, err := prompt.Run() - if err == promptui.ErrInterrupt { - cleanupTerminal() - return 0, "", fmt.Errorf("setup cancelled") - } - return idx, result, err -} - // PTermSetupUI implements SetupUI using PTerm for display. // This is the production UI layer - tests use mocks. type PTermSetupUI struct { *wizard.PTermUI + *PTermSelectPrompter + *PTermContinuePrompter output Output } // NewPTermSetupUI creates a new PTerm-based UI. func NewPTermSetupUI(output Output) *PTermSetupUI { return &PTermSetupUI{ - PTermUI: wizard.NewPTermUI("", ""), - output: output, + PTermUI: wizard.NewPTermUI("", ""), + PTermSelectPrompter: &PTermSelectPrompter{}, + PTermContinuePrompter: &PTermContinuePrompter{}, + output: output, } } @@ -61,8 +53,7 @@ func (ui *PTermSetupUI) ShowWelcome() error { pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } // ShowCompletion displays the completion message. @@ -76,7 +67,7 @@ func (ui *PTermSetupUI) ShowCompletion() error { " • Run 'pinner pin ' to pin by CID\n" + " • Run 'pinner list' to view your pins\n" + " • Run 'pinner --help' for more commands\n\n" + - "Need help? Visit " + DocumentationURL, + "Need help? visit " + DocumentationURL, ) pterm.DefaultCenter.Println(successBox) return nil @@ -94,12 +85,7 @@ func (ui *PTermSetupUI) ExecuteAuthStep(ctx context.Context, wizard *SetupWizard "Skip (configure later with 'pinner auth')", } - prompt := promptui.Select{ - Label: "What would you like to do?", - Items: choices, - } - - _, result, err := runSelect(&prompt) + _, result, err := ui.Select("What would you like to do?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -113,7 +99,7 @@ func (ui *PTermSetupUI) ExecuteAuthStep(ctx context.Context, wizard *SetupWizard pterm.Info.Println("After creating your account, we'll help you sign in.") pterm.Println() - if _, err := pterm.DefaultInteractiveContinue.Show(); err != nil { + if err := ui.Continue(); err != nil { return err } @@ -125,8 +111,7 @@ func (ui *PTermSetupUI) ExecuteAuthStep(ctx context.Context, wizard *SetupWizard case choices[2]: // Skip pterm.Warning.Println("Skipping authentication. You can run 'pinner auth' later.") pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } return fmt.Errorf("invalid choice") @@ -151,8 +136,8 @@ func (ui *PTermSetupUI) handleSignIn(ctx context.Context, wizard *SetupWizard) e pterm.Println() - spinner, err := pterm.DefaultSpinner.Start("Authenticating...") - if err != nil { + spinner := &PTermSpinner{} + if err := spinner.Start("Authenticating..."); err != nil { return fmt.Errorf("failed to start spinner: %w", err) } @@ -207,12 +192,7 @@ func (ui *PTermSetupUI) ExecuteConfigStep(ctx context.Context, wizard *SetupWiza "Skip (use defaults)", } - prompt := promptui.Select{ - Label: "What would you like to do?", - Items: choices, - } - - _, result, err := runSelect(&prompt) + _, result, err := ui.Select("What would you like to do?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -244,11 +224,7 @@ func (ui *PTermSetupUI) handleCustomConfig(wizard *SetupWizard) error { return fmt.Errorf("endpoint prompt failed: %w", err) } - securePrompt := promptui.Select{ - Label: "Use HTTPS?", - Items: []string{"Yes", "No"}, - } - _, secureChoice, err := runSelect(&securePrompt) + _, secureChoice, err := ui.Select("Use HTTPS?", []string{"Yes", "No"}) if err != nil { return fmt.Errorf("secure prompt failed: %w", err) } @@ -287,8 +263,7 @@ func (ui *PTermSetupUI) ExecuteTutorialStep(_ *SetupWizard) error { pterm.Printf("Documentation: %s\n", DocumentationURL) pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } // ExecuteCompletionStep offers to set up shell completion. @@ -301,12 +276,7 @@ func (ui *PTermSetupUI) ExecuteCompletionStep(_ *SetupWizard) error { "Skip (I'll set it up later with 'pinner completion')", } - prompt := promptui.Select{ - Label: "Would you like to enable shell completion?", - Items: choices, - } - - _, result, err := runSelect(&prompt) + _, result, err := ui.Select("Would you like to enable shell completion?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -323,8 +293,7 @@ func (ui *PTermSetupUI) ExecuteCompletionStep(_ *SetupWizard) error { pterm.Printf("To enable completion later, run: pinner completion \n") pterm.Printf(" Example: pinner completion bash\n") pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } return fmt.Errorf("invalid choice") @@ -339,8 +308,7 @@ func (ui *PTermSetupUI) handleCompletionSetup() error { pterm.Printf("To enable completion, run: pinner completion \n") pterm.Printf(" Example: pinner completion bash\n") pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } shell := detectShell() @@ -357,8 +325,7 @@ func (ui *PTermSetupUI) handleCompletionSetup() error { pterm.Printf("Detected shell: %s\n\n", shell) pterm.Printf("To enable completion, run: pinner completion %s\n", shell) pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } pterm.Printf("Detected shell: %s\n\n", detector.Name()) @@ -368,8 +335,7 @@ func (ui *PTermSetupUI) handleCompletionSetup() error { pterm.Printf(" echo '%s' >> %s\n", detector.InstallCommand(), detector.ConfigPath()) pterm.Println() - _, err = pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } // detectShell attempts to detect the current shell. diff --git a/pkg/cli/setup_test.go b/pkg/cli/setup_test.go index 5bc66ec..75a19f4 100644 --- a/pkg/cli/setup_test.go +++ b/pkg/cli/setup_test.go @@ -471,6 +471,38 @@ func TestSetupWizard_Accessors(t *testing.T) { require.Equal(t, options, wizard.Options()) } +func TestSetupWizard_NonInteractive(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfg := &config.Config{AuthToken: "token"} + cfgMgr.EXPECT().Config().Return(cfg).Maybe() + + output := newTestOutput() + cmd := &mockCommand{boolFields: map[string]bool{"non-interactive": true}} + + err := runSetupWizardWithFactories(context.Background(), cmd, output, func() (config.Manager, error) { + return cfgMgr, nil + }, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "interactive mode") +} + +func TestSetupWizard_ConfigError(t *testing.T) { + output := newTestOutput() + cmd := &mockCommand{boolFields: map[string]bool{}} + + err := runSetupWizardWithFactories(context.Background(), cmd, output, failingConfigMgrFactory(), nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize config manager") +} + +func TestNewSetupCommand(t *testing.T) { + cmd := newSetupCommand() + require.Equal(t, "setup", cmd.Name) + require.Equal(t, "Setup", cmd.Category) + require.NotNil(t, cmd.Action) + require.NotEmpty(t, cmd.Flags) +} + func TestMockSetupUI(t *testing.T) { t.Run("call tracking", func(t *testing.T) { mock := NewMockSetupUI() diff --git a/pkg/cli/setup_ui.go b/pkg/cli/setup_ui.go index 0e1f911..0f87af7 100644 --- a/pkg/cli/setup_ui.go +++ b/pkg/cli/setup_ui.go @@ -10,6 +10,8 @@ import ( // This allows for easy testing by providing mock implementations. type SetupUI interface { wizard.UI + SelectPrompter + ContinuePrompter // Step execution ExecuteAuthStep(ctx context.Context, wizard *SetupWizard) error diff --git a/pkg/cli/sources_test.go b/pkg/cli/sources_test.go new file mode 100644 index 0000000..3fd222a --- /dev/null +++ b/pkg/cli/sources_test.go @@ -0,0 +1,28 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStdinSourceString(t *testing.T) { + s := &StdinSource{} + assert.Equal(t, "stdin", s.String()) +} + +func TestStdinSourceGoString(t *testing.T) { + s := &StdinSource{} + assert.Equal(t, "cli.StdinSource", s.GoString()) +} + +func TestNewStdinSource(t *testing.T) { + s := NewStdinSource() + assert.NotNil(t, s) +} + +func TestStdin(t *testing.T) { + s := Stdin() + assert.NotNil(t, s) + assert.IsType(t, &StdinSource{}, s) +} diff --git a/pkg/cli/status.go b/pkg/cli/status.go index 01eef6a..8f6b42a 100644 --- a/pkg/cli/status.go +++ b/pkg/cli/status.go @@ -45,34 +45,27 @@ Operation status values (shown when pin is not found): Metadata: WithTutorial(4, "Check pin status", fmt.Sprintf("pinner status %s", abbreviateCID(TutorialCID))), Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return status(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory, defaultStatusServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, defaultStatusServiceFactory) }, } } -type statusCommandGetter interface { - Bool(name string) bool - GetCID() string -} - func defaultStatusServiceFactory(cfgMgr config.Manager, output Output, pinningService PinningService, authService AuthService) StatusService { return NewStatusService(cfgMgr, output, pinningService, authService) } -func status(ctx context.Context, cmd statusCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory, statusServiceFactory StatusServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func status(ctx context.Context, cmd interface { + cidGetter + Bool(name string) bool +}, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory, statusServiceFactory StatusServiceFactory) error { var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } diff --git a/pkg/cli/status_service_test.go b/pkg/cli/status_service_test.go index 0b0f1e7..ef06162 100644 --- a/pkg/cli/status_service_test.go +++ b/pkg/cli/status_service_test.go @@ -36,7 +36,7 @@ func makeOperation(id int, cid string, status string, opName string, protocol st func TestStatusServiceDefault_Status(t *testing.T) { t.Run("returns pin status when pin is found", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) pinningSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -56,7 +56,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("falls back to operation when pin not found", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -92,7 +92,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("returns pin not found when no operation exists either", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -120,7 +120,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("returns pin not found when auth service is nil", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) pinningSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -139,7 +139,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("returns error when pin status fails with non-ErrPinNotFound error", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) pinningSvc.EXPECT().RequireAuthenticated().Return(nil) @@ -158,7 +158,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("populates error field from operation", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -195,7 +195,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("returns error when auth service fails to get authenticated client", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) authSvc := NewMockAuthService(t) @@ -217,7 +217,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("reuses account client on subsequent lookups", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) authSvc := NewMockAuthService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) @@ -269,7 +269,7 @@ func TestStatusServiceDefault_Status(t *testing.T) { t.Run("uses pre-injected account client when available", func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() pinningSvc := NewMockPinningService(t) accountClient := portalsdkmocks.NewMockAccountAPI(t) diff --git a/pkg/cli/status_test.go b/pkg/cli/status_test.go index 6b39707..7e97d8f 100644 --- a/pkg/cli/status_test.go +++ b/pkg/cli/status_test.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "context" "errors" "testing" @@ -14,13 +15,12 @@ import ( func TestStatus(t *testing.T) { tests := []struct { - name string - cid string - watchFlag bool - setupMocks func(*configmocks.MockManager, *MockPinningService, *MockStatusService) - wantErr bool - errContains string - cfgMgrFactoryErr bool + name string + cid string + watchFlag bool + setupMocks func(*configmocks.MockManager, *MockPinningService, *MockStatusService) + wantErr bool + errContains string }{ { name: "successful pin status check", @@ -129,15 +129,6 @@ func TestStatus(t *testing.T) { }, wantErr: false, }, - { - name: "returns error when config manager factory fails", - cid: "QmXxx", - watchFlag: false, - setupMocks: func(cfgMgr *configmocks.MockManager, pinSvc *MockPinningService, statusSvc *MockStatusService) {}, - wantErr: true, - errContains: "config error", - cfgMgrFactoryErr: true, - }, { name: "returns error when status check fails", cid: "QmXxx", @@ -185,27 +176,15 @@ func TestStatus(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) pinningSvc := NewMockPinningService(t) statusSvc := NewMockStatusService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, pinningSvc, statusSvc) } - cmd := &mockStatusCommand{ - cid: tt.cid, - watch: tt.watchFlag, - } - - var cfgMgrFactory ConfigManagerFactory - if tt.cfgMgrFactoryErr { - cfgMgrFactory = func() (config.Manager, error) { - return nil, errors.New("config error") - } - } else { - cfgMgrFactory = func() (config.Manager, error) { - return cfgMgr, nil - } - } + cmd := newMockCommand(). + withCID(tt.cid). + withBool(FlagWatch, tt.watchFlag) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return pinningSvc @@ -215,7 +194,7 @@ func TestStatus(t *testing.T) { return statusSvc } - err := status(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory, statusServiceFactory) + err := status(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, statusServiceFactory) if tt.wantErr { require.Error(t, err) @@ -245,20 +224,142 @@ func TestNewStatusCommand(t *testing.T) { }) } -type mockStatusCommand struct { - cid string - watch bool -} +func TestRenderPinStatus(t *testing.T) { + t.Run("renders pin status without delegates", func(t *testing.T) { + var buf bytes.Buffer + output := newTestOutput() + output.SetWriter(&buf) + + pinStatus := &PinStatus{ + CID: "QmXxx", + Status: "pinned", + Created: "2024-01-01T00:00:00Z", + } + + err := renderPinStatus(output, pinStatus) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, "QmXxx") + assert.Contains(t, result, "pinned") + assert.Contains(t, result, "2024-01-01T00:00:00Z") + }) + + t.Run("renders pin status with delegates", func(t *testing.T) { + var buf bytes.Buffer + output := newTestOutput() + output.SetWriter(&buf) + + pinStatus := &PinStatus{ + CID: "QmXxx", + Status: "pinned", + Created: "2024-01-01T00:00:00Z", + Delegates: []string{"delegate1", "delegate2"}, + } + + err := renderPinStatus(output, pinStatus) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, "QmXxx") + assert.Contains(t, result, "Delegates:") + assert.Contains(t, result, "delegate1") + assert.Contains(t, result, "delegate2") + }) + + t.Run("renders pin status as JSON", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(true, false, false, false) + output.SetWriter(&buf) -func (m *mockStatusCommand) GetCID() string { - return m.cid + pinStatus := &PinStatus{ + CID: "QmXxx", + Status: "pinned", + Created: "2024-01-01T00:00:00Z", + } + + err := renderPinStatus(output, pinStatus) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, `"CID"`) + assert.Contains(t, result, `"QmXxx"`) + assert.Contains(t, result, `"pinned"`) + }) } -func (m *mockStatusCommand) Bool(name string) bool { - switch name { - case FlagWatch: - return m.watch - default: - return false - } +func TestRenderOperationStatus(t *testing.T) { + t.Run("renders operation status", func(t *testing.T) { + var buf bytes.Buffer + output := newTestOutput() + output.SetWriter(&buf) + + op := &OperationStatusResult{ + CID: "QmYyy", + StatusDisplayName: "Completed", + OperationDisplayName: "Pin", + ProtocolDisplayName: "IPFS", + ProgressPercent: 100, + StartedAt: "2024-01-01T00:00:00Z", + } + + err := renderOperationStatus(output, op) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, "QmYyy") + assert.Contains(t, result, "Completed") + assert.Contains(t, result, "Pin") + assert.Contains(t, result, "IPFS") + assert.Contains(t, result, "100%") + assert.Contains(t, result, "2024-01-01T00:00:00Z") + }) + + t.Run("renders operation status with message and error", func(t *testing.T) { + var buf bytes.Buffer + output := newTestOutput() + output.SetWriter(&buf) + + op := &OperationStatusResult{ + CID: "QmZzz", + StatusDisplayName: "Failed", + OperationDisplayName: "Pin", + ProtocolDisplayName: "IPFS", + ProgressPercent: 50, + StartedAt: "2024-01-01T00:00:00Z", + StatusMessage: "processing stalled", + Error: "upload failed", + } + + err := renderOperationStatus(output, op) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, "processing stalled") + assert.Contains(t, result, "upload failed") + }) + + t.Run("renders operation status as JSON", func(t *testing.T) { + var buf bytes.Buffer + output := NewOutputFormatter(true, false, false, false) + output.SetWriter(&buf) + + op := &OperationStatusResult{ + CID: "QmYyy", + StatusDisplayName: "Completed", + OperationDisplayName: "Pin", + ProtocolDisplayName: "IPFS", + ProgressPercent: 100, + StartedAt: "2024-01-01T00:00:00Z", + } + + err := renderOperationStatus(output, op) + require.NoError(t, err) + + result := buf.String() + assert.Contains(t, result, `"CID"`) + assert.Contains(t, result, `"QmYyy"`) + assert.Contains(t, result, `"Completed"`) + }) } + diff --git a/pkg/cli/testutils.go b/pkg/cli/testutils.go index 2445ab6..2f4a7e3 100644 --- a/pkg/cli/testutils.go +++ b/pkg/cli/testutils.go @@ -1,5 +1,15 @@ package cli +import ( + "errors" + "testing" + "time" + + "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + // mockArgs implements cli.Args interface for testing type mockArgs struct { args []string @@ -37,3 +47,285 @@ func (m *mockArgs) Tail() []string { } return []string{} } + +// newMockCommand creates a new mockCommand ready for builder method chaining. +func newMockCommand() *mockCommand { + return &mockCommand{} +} + +// mockCommand is a map-backed test double for commandGetter interfaces. +// It replaces hand-written mock command structs across test files. +type mockCommand struct { + stringFields map[string]string + intFields map[string]int + int64Fields map[string]int64 + uint64Fields map[string]uint64 + uintFields map[string]uint + durationFields map[string]time.Duration + floatFields map[string]float64 + boolFields map[string]bool + stringSlices map[string][]string + isSetFields map[string]bool + args cli.Args + cid string +} + +func (m *mockCommand) withString(name, value string) *mockCommand { + if m.stringFields == nil { + m.stringFields = make(map[string]string) + } + m.stringFields[name] = value + return m +} + +func (m *mockCommand) withInt(name string, value int) *mockCommand { + if m.intFields == nil { + m.intFields = make(map[string]int) + } + m.intFields[name] = value + return m +} + +func (m *mockCommand) withInt64(name string, value int64) *mockCommand { + if m.int64Fields == nil { + m.int64Fields = make(map[string]int64) + } + m.int64Fields[name] = value + return m +} + +func (m *mockCommand) withUint64(name string, value uint64) *mockCommand { + if m.uint64Fields == nil { + m.uint64Fields = make(map[string]uint64) + } + m.uint64Fields[name] = value + return m +} + +func (m *mockCommand) withUint(name string, value uint) *mockCommand { + if m.uintFields == nil { + m.uintFields = make(map[string]uint) + } + m.uintFields[name] = value + return m +} + +func (m *mockCommand) withDuration(name string, value time.Duration) *mockCommand { + if m.durationFields == nil { + m.durationFields = make(map[string]time.Duration) + } + m.durationFields[name] = value + return m +} + +func (m *mockCommand) withFloat(name string, value float64) *mockCommand { + if m.floatFields == nil { + m.floatFields = make(map[string]float64) + } + m.floatFields[name] = value + return m +} + +func (m *mockCommand) withBool(name string, value bool) *mockCommand { + if m.boolFields == nil { + m.boolFields = make(map[string]bool) + } + m.boolFields[name] = value + return m +} + +func (m *mockCommand) withIsSet(name string, value bool) *mockCommand { + if m.isSetFields == nil { + m.isSetFields = make(map[string]bool) + } + m.isSetFields[name] = value + return m +} + +func (m *mockCommand) withStringSlice(name string, value []string) *mockCommand { + if m.stringSlices == nil { + m.stringSlices = make(map[string][]string) + } + m.stringSlices[name] = value + return m +} + +func (m *mockCommand) withArgs(args ...string) *mockCommand { + m.args = &mockArgs{args: args} + return m +} + +func (m *mockCommand) withCID(cid string) *mockCommand { + m.cid = cid + return m +} + +func (m *mockCommand) String(name string) string { + if m.stringFields != nil { + if v, ok := m.stringFields[name]; ok { + return v + } + } + return "" +} + +func (m *mockCommand) Int(name string) int { + if m.intFields != nil { + if v, ok := m.intFields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Int64(name string) int64 { + if m.int64Fields != nil { + if v, ok := m.int64Fields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Uint64(name string) uint64 { + if m.uint64Fields != nil { + if v, ok := m.uint64Fields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Uint(name string) uint { + if m.uintFields != nil { + if v, ok := m.uintFields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Duration(name string) time.Duration { + if m.durationFields != nil { + if v, ok := m.durationFields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Float(name string) float64 { + if m.floatFields != nil { + if v, ok := m.floatFields[name]; ok { + return v + } + } + return 0 +} + +func (m *mockCommand) Bool(name string) bool { + if m.boolFields != nil { + if v, ok := m.boolFields[name]; ok { + return v + } + } + return false +} + +func (m *mockCommand) IsSet(name string) bool { + if m.isSetFields != nil { + if v, ok := m.isSetFields[name]; ok { + return v + } + } + return false +} + +func (m *mockCommand) StringSlice(name string) []string { + if m.stringSlices != nil { + if v, ok := m.stringSlices[name]; ok { + return v + } + } + return nil +} + +func (m *mockCommand) Args() cli.Args { + if m.args != nil { + return m.args + } + return &mockArgs{} +} + +func (m *mockCommand) GetCID() string { + return m.cid +} + +// Compile-time interface satisfaction checks +var ( + _ flagGetter = (*mockCommand)(nil) + _ flagGetterWithInt = (*mockCommand)(nil) + _ flagGetterWithIsSet = (*mockCommand)(nil) + _ flagGetterWithUint = (*mockCommand)(nil) + _ flagGetterWithDuration = (*mockCommand)(nil) + _ commandGetter = (*mockCommand)(nil) + _ argsGetter = (*mockCommand)(nil) + _ cidGetter = (*mockCommand)(nil) + _ argsFlagGetter = (*mockCommand)(nil) + _ cidFlagGetter = (*mockCommand)(nil) + _ dnsCommandGetter = (*mockCommand)(nil) + _ benchCommandGetter = (*mockCommand)(nil) + _ websitesCommandGetter = (*mockCommand)(nil) +) + +// newTestOutput creates a human-readable Output for testing. +func newTestOutput() Output { + return NewOutputFormatter(false, false, false, false) +} + +func newTestConfigMgr(t *testing.T) *configmocks.MockManager { + m := configmocks.NewMockManager(t) + m.EXPECT().Config().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + }).Maybe() + return m +} + +func getFlagNames(cmd *cli.Command) []string { + names := make([]string, len(cmd.Flags)) + for i, f := range cmd.Flags { + names[i] = f.Names()[0] + } + return names +} + +func getSubcommandNames(cmd *cli.Command) []string { + names := make([]string, len(cmd.Commands)) + for i, c := range cmd.Commands { + names[i] = c.Name + } + return names +} + +// Compile-time interface satisfaction checks for mockCommand. +var _ flagGetter = (*mockCommand)(nil) +var _ flagGetterWithInt = (*mockCommand)(nil) +var _ flagGetterWithIsSet = (*mockCommand)(nil) +var _ flagGetterWithUint = (*mockCommand)(nil) +var _ flagGetterWithDuration = (*mockCommand)(nil) +var _ commandGetter = (*mockCommand)(nil) +var _ argsGetter = (*mockCommand)(nil) +var _ cidGetter = (*mockCommand)(nil) +var _ argsFlagGetter = (*mockCommand)(nil) +var _ cidFlagGetter = (*mockCommand)(nil) +var _ dnsCommandGetter = (*mockCommand)(nil) +var _ benchCommandGetter = (*mockCommand)(nil) +var _ websitesCommandGetter = (*mockCommand)(nil) + +func failingConfigMgrFactory() ConfigManagerFactory { + return func() (config.Manager, error) { + return nil, errors.New("config error") + } +} diff --git a/pkg/cli/testutils_test.go b/pkg/cli/testutils_test.go new file mode 100644 index 0000000..9194a7f --- /dev/null +++ b/pkg/cli/testutils_test.go @@ -0,0 +1,31 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMockArgsPresent(t *testing.T) { + empty := &mockArgs{args: []string{}} + assert.False(t, empty.Present()) + + nonEmpty := &mockArgs{args: []string{"a"}} + assert.True(t, nonEmpty.Present()) +} + +func TestMockArgsTail(t *testing.T) { + empty := &mockArgs{args: []string{}} + assert.Empty(t, empty.Tail()) + + single := &mockArgs{args: []string{"a"}} + assert.Empty(t, single.Tail()) + + multi := &mockArgs{args: []string{"a", "b", "c"}} + assert.Equal(t, []string{"b", "c"}, multi.Tail()) +} + +func TestMockCommandWithInt64(t *testing.T) { + cmd := newMockCommand().withInt64("size", 1024) + assert.Equal(t, int64(1024), cmd.Int64("size")) +} diff --git a/pkg/cli/unpin.go b/pkg/cli/unpin.go index 90e2fde..79f3266 100644 --- a/pkg/cli/unpin.go +++ b/pkg/cli/unpin.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newUnpinCommand() *cli.Command { @@ -39,33 +40,20 @@ Examples: Metadata: WithTutorial(5, "Unpin content", fmt.Sprintf("pinner unpin %s", abbreviateCID(TutorialCID))), Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return unpin(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory) }, } } -// unpinCommandGetter defines the interface for getting unpin command flags. -type unpinCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool - GetCID() string -} - -func unpin(ctx context.Context, cmd unpinCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +func unpin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory) error { var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } @@ -81,6 +69,7 @@ func unpin(ctx context.Context, cmd unpinCommandGetter, output Output, cfgMgrFac dryRun := cmd.Bool(FlagDryRun) var cids []string + var err error if isStdinPipe() { cids, err = readLinesFromStdin() diff --git a/pkg/cli/unpin_all.go b/pkg/cli/unpin_all.go index b11655c..3d12bdf 100644 --- a/pkg/cli/unpin_all.go +++ b/pkg/cli/unpin_all.go @@ -5,8 +5,8 @@ import ( "fmt" "strconv" - "github.com/manifoldco/promptui" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ) func newUnpinAllCommand() *cli.Command { @@ -44,37 +44,27 @@ Examples: }, Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return unpinAll(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + prompter := &PTermConfirmPrompter{} + return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, prompter) }, } } -type unpinAllCommandGetter interface { - String(name string) string - Int(name string) int - Bool(name string) bool -} - -func unpinAll(ctx context.Context, cmd unpinAllCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, pinningServiceFactory PinningServiceFactory) error { +func unpinAll(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory, prompter ConfirmPrompter) error { confirm := cmd.Bool(FlagForce) || cmd.Bool(FlagConfirm) if !confirm { output.Printfln("Use --force to unpin all pins. This is a destructive operation.") return nil } - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - var pinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) - } else { - pinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) } else { pinningService = pinningServiceFactory(cfgMgr, output) } @@ -134,21 +124,12 @@ func unpinAll(ctx context.Context, cmd unpinAllCommandGetter, output Output, cfg if !yes { expected := strconv.Itoa(len(pins)) - prompt := promptui.Prompt{ - Label: fmt.Sprintf("Type %s to confirm unpinning all %d pins", expected, len(pins)), - Validate: func(input string) error { - if input != expected { - return fmt.Errorf("must type %s to confirm", expected) - } - return nil - }, - } - result, err := prompt.Run() + result, err := prompter.Confirm( + fmt.Sprintf("Type %s to confirm unpinning all %d pins", expected, len(pins)), + expected, + ) if err != nil { - if err == promptui.ErrInterrupt { - return ErrUnpinAllAborted - } - return fmt.Errorf("safety prompt failed: %w", err) + return ErrUnpinAllAborted } if result != expected { return ErrUnpinAllAborted diff --git a/pkg/cli/unpin_all_test.go b/pkg/cli/unpin_all_test.go index 7f6929f..7d5ae90 100644 --- a/pkg/cli/unpin_all_test.go +++ b/pkg/cli/unpin_all_test.go @@ -14,17 +14,16 @@ import ( func TestUnpinAll(t *testing.T) { tests := []struct { - name string - confirm bool - yes bool - dryRun bool - statusFilter string - parallel int - continueOn bool - setupMocks func(*configmocks.MockManager, *MockPinningService) - wantErr bool - errContains string - cfgMgrFactoryErr bool + name string + confirm bool + yes bool + dryRun bool + statusFilter string + parallel int + continueOn bool + setupMocks func(*configmocks.MockManager, *MockPinningService) + wantErr bool + errContains string }{ { name: "requires --confirm flag", @@ -150,15 +149,6 @@ func TestUnpinAll(t *testing.T) { wantErr: true, errContains: "unpin-all failed", }, - { - name: "returns error when config manager factory fails", - confirm: true, - yes: true, - setupMocks: func(cfgMgr *configmocks.MockManager, service *MockPinningService) {}, - wantErr: true, - errContains: "config error", - cfgMgrFactoryErr: true, - }, { name: "returns error when not authenticated", confirm: true, @@ -204,37 +194,27 @@ func TestUnpinAll(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockUnpinAllCommand{ - confirm: tt.confirm, - yes: tt.yes, - dryRun: tt.dryRun, - statusFilter: tt.statusFilter, - parallel: tt.parallel, - continueOn: tt.continueOn, - } - - var cfgMgrFactory ConfigManagerFactory - if tt.cfgMgrFactoryErr { - cfgMgrFactory = func() (config.Manager, error) { - return nil, errors.New("config error") - } - } else { - cfgMgrFactory = func() (config.Manager, error) { - return cfgMgr, nil - } - } + cmd := newMockCommand(). + withBool(FlagForce, tt.confirm || tt.yes). + withBool(FlagConfirm, tt.confirm). + withBool(FlagYes, tt.yes). + withBool(FlagDryRun, tt.dryRun). + withString(FlagStatus, tt.statusFilter). + withInt(FlagParallel, tt.parallel). + withBool(FlagContinue, tt.continueOn) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := unpinAll(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + prompter := &MockConfirmPrompter{} + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) if tt.wantErr { require.Error(t, err) @@ -248,6 +228,117 @@ func TestUnpinAll(t *testing.T) { } } +func TestUnpinAllConfirmPrompt(t *testing.T) { + t.Run("mismatch_aborts", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().List(context.Background(), "", 0, "").Return( + []Pin{ + {CID: "QmXxx1", Name: "test1", Status: "pinned", RequestID: "req-1"}, + {CID: "QmXxx2", Name: "test2", Status: "pinned", RequestID: "req-2"}, + }, + nil, + ) + + cmd := newMockCommand(). + withBool(FlagForce, false). + withBool(FlagConfirm, true). + withBool(FlagYes, false). + withBool(FlagDryRun, false). + withString(FlagStatus, ""). + withInt(FlagParallel, 0). + withBool(FlagContinue, false) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + prompter := &MockConfirmPrompter{ConfirmResult: "wrong"} + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + + assert.ErrorIs(t, err, ErrUnpinAllAborted) + }) + + t.Run("match_proceeds", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().List(context.Background(), "", 0, "").Return( + []Pin{ + {CID: "QmXxx1", Name: "test1", Status: "pinned", RequestID: "req-1"}, + {CID: "QmXxx2", Name: "test2", Status: "pinned", RequestID: "req-2"}, + }, + nil, + ) + service.EXPECT().UnpinAll(context.Background(), "", BatchOptions{ + Parallel: 0, + ContinueOn: false, + Progress: true, + }).Return(&BatchResult{ + Total: 2, + Succeeded: []OperationResult{{CID: "QmXxx1"}, {CID: "QmXxx2"}}, + Failed: []OperationError{}, + Skipped: []string{}, + }, nil) + + cmd := newMockCommand(). + withBool(FlagForce, false). + withBool(FlagConfirm, true). + withBool(FlagYes, false). + withBool(FlagDryRun, false). + withString(FlagStatus, ""). + withInt(FlagParallel, 0). + withBool(FlagContinue, false) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + prompter := &MockConfirmPrompter{ConfirmResult: "2"} + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + + assert.NoError(t, err) + }) + + t.Run("interrupt_aborts", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + service := NewMockPinningService(t) + output := newTestOutput() + + service.EXPECT().RequireAuthenticated().Return(nil) + service.EXPECT().List(context.Background(), "", 0, "").Return( + []Pin{ + {CID: "QmXxx1", Name: "test1", Status: "pinned", RequestID: "req-1"}, + {CID: "QmXxx2", Name: "test2", Status: "pinned", RequestID: "req-2"}, + }, + nil, + ) + + cmd := newMockCommand(). + withBool(FlagForce, false). + withBool(FlagConfirm, true). + withBool(FlagYes, false). + withBool(FlagDryRun, false). + withString(FlagStatus, ""). + withInt(FlagParallel, 0). + withBool(FlagContinue, false) + + pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + return service + } + + prompter := &MockConfirmPrompter{ConfirmErr: ErrUnpinAllAborted} + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + + assert.ErrorIs(t, err, ErrUnpinAllAborted) + }) +} + func TestNewUnpinAllCommand(t *testing.T) { t.Run("creates unpin all command with correct configuration", func(t *testing.T) { cmd := newUnpinAllCommand() @@ -289,46 +380,4 @@ func TestNewUnpinAllCommand(t *testing.T) { }) } -type mockUnpinAllCommand struct { - confirm bool - yes bool - dryRun bool - statusFilter string - parallel int - continueOn bool -} - -func (m *mockUnpinAllCommand) String(name string) string { - switch name { - case FlagStatus: - return m.statusFilter - default: - return "" - } -} - -func (m *mockUnpinAllCommand) Int(name string) int { - switch name { - case FlagParallel: - return m.parallel - default: - return 0 - } -} -func (m *mockUnpinAllCommand) Bool(name string) bool { - switch name { - case FlagForce: - return m.confirm || m.yes - case FlagConfirm: - return m.confirm - case FlagYes: - return m.yes - case FlagDryRun: - return m.dryRun - case FlagContinue: - return m.continueOn - default: - return false - } -} diff --git a/pkg/cli/unpin_test.go b/pkg/cli/unpin_test.go index a115296..aac1ad3 100644 --- a/pkg/cli/unpin_test.go +++ b/pkg/cli/unpin_test.go @@ -14,13 +14,12 @@ import ( func TestUnpin(t *testing.T) { tests := []struct { - name string - cid string - confirmFlag bool - setupMocks func(*configmocks.MockManager, *MockPinningService) - wantErr bool - errContains string - cfgMgrFactoryErr bool + name string + cid string + confirmFlag bool + setupMocks func(*configmocks.MockManager, *MockPinningService) + wantErr bool + errContains string }{ { name: "successful unpin operation", @@ -46,15 +45,6 @@ func TestUnpin(t *testing.T) { }, wantErr: false, }, - { - name: "returns error when config manager factory fails", - cid: "QmXxx", - confirmFlag: true, - setupMocks: func(cfgMgr *configmocks.MockManager, service *MockPinningService) {}, - wantErr: true, - errContains: "config error", - cfgMgrFactoryErr: true, - }, { name: "returns error when unpin fails", cid: "QmXxx", @@ -100,33 +90,22 @@ func TestUnpin(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockUnpinCommand{ - cid: tt.cid, - confirm: tt.confirmFlag, - } - - var cfgMgrFactory ConfigManagerFactory - if tt.cfgMgrFactoryErr { - cfgMgrFactory = func() (config.Manager, error) { - return nil, errors.New("config error") - } - } else { - cfgMgrFactory = func() (config.Manager, error) { - return cfgMgr, nil - } - } + cmd := newMockCommand(). + withCID(tt.cid). + withBool(FlagForce, tt.confirmFlag). + withBool(FlagConfirm, tt.confirmFlag) pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := unpin(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := unpin(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -188,28 +167,25 @@ func TestUnpinBatch(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfgMgr := configmocks.NewMockManager(t) service := NewMockPinningService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(cfgMgr, service) } - cmd := &mockUnpinCommand{ - cid: tt.cids, - confirm: tt.confirm, - parallel: tt.parallel, - continueOn: tt.continueOn, - } + cmd := newMockCommand(). + withCID(tt.cids). + withBool(FlagForce, tt.confirm). + withBool(FlagConfirm, tt.confirm). + withInt(FlagParallel, tt.parallel). + withBool(FlagContinue, tt.continueOn) - cfgMgrFactory := func() (config.Manager, error) { - return cfgMgr, nil - } pinningServiceFactory := func(cm config.Manager, out Output) PinningService { return service } - err := unpin(context.Background(), cmd, output, cfgMgrFactory, pinningServiceFactory) + err := unpin(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory) if tt.wantErr { require.Error(t, err) @@ -259,46 +235,4 @@ func TestNewUnpinCommand(t *testing.T) { }) } -// mockUnpinCommand is a mock implementation of unpinCommandGetter for testing. -type mockUnpinCommand struct { - cid string - confirm bool - file string - parallel int - continueOn bool -} - -func (m *mockUnpinCommand) GetCID() string { - return m.cid -} - -func (m *mockUnpinCommand) String(name string) string { - switch name { - case FlagFile: - return m.file - default: - return "" - } -} - -func (m *mockUnpinCommand) Int(name string) int { - switch name { - case FlagParallel: - return m.parallel - default: - return 0 - } -} -func (m *mockUnpinCommand) Bool(name string) bool { - switch name { - case FlagForce: - return m.confirm - case FlagConfirm: - return m.confirm - case FlagContinue: - return m.continueOn - default: - return false - } -} diff --git a/pkg/cli/upload.go b/pkg/cli/upload.go index 4e474d0..4c108cb 100644 --- a/pkg/cli/upload.go +++ b/pkg/cli/upload.go @@ -58,7 +58,13 @@ The output includes: Metadata: WithTutorial(1, "Upload and pin a file", "pinner upload myfile.txt"), Action: func(ctx context.Context, c *cli.Command) error { output := setupOutput(c) - return handleUpload(ctx, newCLICommandWrapper(c), output, defaultConfigManagerFactory, defaultUploadServiceFactory, defaultPinningServiceFactory) + cfgMgr, err := defaultConfigManagerFactory() + if err != nil { + return err + } + authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) + return handleUpload(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultUploadServiceFactory, defaultPinningServiceFactory) }, } } @@ -138,23 +144,12 @@ func detectInputType(path string) string { return "file" } -// uploadCommandGetter defines the interface for getting upload command flags. -type uploadCommandGetter interface { +func handleUpload(ctx context.Context, cmd interface { + argsFlagGetter Uint64(name string) uint64 Int64(name string) int64 - Int(name string) int - String(name string) string - Bool(name string) bool StringSlice(name string) []string - Args() cli.Args -} - -func handleUpload(ctx context.Context, cmd uploadCommandGetter, output Output, cfgMgrFactory ConfigManagerFactory, uploadServiceFactory UploadServiceFactory, pinningServiceFactory PinningServiceFactory) error { - cfgMgr, err := cfgMgrFactory() - if err != nil { - return err - } - +}, output Output, cfgMgr config.Manager, authToken string, secure bool, uploadServiceFactory UploadServiceFactory, pinningServiceFactory PinningServiceFactory) error { // Set memory limit from flag (overrides config if provided, runtime only) memoryLimit := cmd.Uint64(FlagMemoryLimit) if memoryLimit == 0 { @@ -252,14 +247,8 @@ func handleUpload(ctx context.Context, cmd uploadCommandGetter, output Output, c return fmt.Errorf("upload succeeded but metadata flag invalid: %w", err) } var metaPinningService PinningService - if c, ok := cmd.(*cliCommandWrapper); ok { - authToken := GetAuthToken(c.Command, cfgMgr) - secure := GetSecureSetting(c.Command, cfgMgr) - if authToken != "" { - metaPinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) - } else { - metaPinningService = pinningServiceFactory(cfgMgr, output) - } + if authToken != "" { + metaPinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { metaPinningService = pinningServiceFactory(cfgMgr, output) } diff --git a/pkg/cli/upload_client_test.go b/pkg/cli/upload_client_test.go index 5e93870..07b9e54 100644 --- a/pkg/cli/upload_client_test.go +++ b/pkg/cli/upload_client_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "testing" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -40,7 +41,7 @@ type uploadTestHelpers struct { func newUploadTestHelpers(t *testing.T) *uploadTestHelpers { cfgMgr := configmocks.NewMockManager(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() accountClient := portalsdkmocks.NewMockAccountAPI(t) cfg := &config.Config{ @@ -354,6 +355,194 @@ func TestUploadServiceDefault_Upload_WaitForPin(t *testing.T) { }) } +func TestUploadServiceDefault_RequireAuthenticated(t *testing.T) { + t.Run("returns nil when auth token is available", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = configToken + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + err := h.service.RequireAuthenticated() + assert.NoError(t, err) + }) + + t.Run("returns error when no auth token", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = "" + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + err := h.service.RequireAuthenticated() + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("returns nil when override token is set", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = "" + h.service.WithAuthToken(overrideToken) + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + err := h.service.RequireAuthenticated() + assert.NoError(t, err) + }) +} + +func TestUploadServiceDefault_getAuthToken(t *testing.T) { + t.Run("returns override token when set", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = configToken + h.service.WithAuthToken(overrideToken) + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token := h.service.getAuthToken() + assert.Equal(t, overrideToken, token) + }) + + t.Run("falls back to config token when no override", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = configToken + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token := h.service.getAuthToken() + assert.Equal(t, configToken, token) + }) + + t.Run("returns empty when neither available", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = "" + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token := h.service.getAuthToken() + assert.Empty(t, token) + }) +} + +func TestUploadServiceDefault_resolveAuthToken(t *testing.T) { + t.Run("returns config token when no auth service", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfg.AuthToken = configToken + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token, err := h.service.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, configToken, token) + }) + + t.Run("returns raw token when JWT decode fails", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.service.authService = NewMockAuthService(t) + h.cfg.AuthToken = "not-a-jwt" + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token, err := h.service.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "not-a-jwt", token) + }) + + t.Run("exchanges API key JWT for login JWT", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.service.authService = NewMockAuthService(t) + + apiKeyJWT := createUploadTestJWT(t, "api") + loginJWT := "login-jwt-token" + + h.cfg.AuthToken = apiKeyJWT + h.cfgMgr.EXPECT().Config().Return(h.cfg) + h.accountClient.EXPECT().LoginWithAPIKey(mock.Anything, apiKeyJWT).Return(loginJWT, nil) + + token, err := h.service.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, loginJWT, token) + }) + + t.Run("returns error when API key exchange fails", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.service.authService = NewMockAuthService(t) + + apiKeyJWT := createUploadTestJWT(t, "api") + h.cfg.AuthToken = apiKeyJWT + h.cfgMgr.EXPECT().Config().Return(h.cfg) + h.accountClient.EXPECT().LoginWithAPIKey(mock.Anything, apiKeyJWT).Return("", errors.New("exchange failed")) + + token, err := h.service.resolveAuthToken(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to exchange API key") + assert.Empty(t, token) + }) + + t.Run("returns login JWT as-is when purpose is login", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.service.authService = NewMockAuthService(t) + + loginJWT := createUploadTestJWT(t, "login") + h.cfg.AuthToken = loginJWT + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + token, err := h.service.resolveAuthToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, loginJWT, token) + }) +} + +func TestUploadServiceDefault_wrapUploadError(t *testing.T) { + t.Run("returns nil for nil error", func(t *testing.T) { + h := newUploadTestHelpers(t) + result := h.service.wrapUploadError(nil) + assert.Nil(t, result) + }) + + t.Run("wraps error with Upload context", func(t *testing.T) { + h := newUploadTestHelpers(t) + innerErr := errors.New("something went wrong") + result := h.service.wrapUploadError(innerErr) + require.Error(t, result) + assert.Contains(t, result.Error(), "Upload failed") + assert.True(t, errors.Is(result, innerErr)) + }) +} + +func TestUploadServiceDefault_waitForPin(t *testing.T) { + t.Run("returns error when account endpoint unreachable", func(t *testing.T) { + h := newUploadTestHelpers(t) + h.cfgMgr.EXPECT().Config().Return(h.cfg) + + err := h.service.waitForPin(context.Background(), "bafybeigtest", "test-token") + require.Error(t, err) + }) + + t.Run("returns error when no operations found for CID", func(t *testing.T) { + h := newUploadTestHelpers(t) + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[],"total":0}`)) + }) + server := httptest.NewServer(mux) + defer server.Close() + + u, _ := url.Parse(server.URL) + h.service.accountEndpoint = "http://localhost:" + u.Port() + + err := h.service.waitForPin(context.Background(), "bafybeigtest", "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "operation not found") + }) +} + +func createUploadTestJWT(t *testing.T, audience string) string { + t.Helper() + claims := &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{audience}, + Issuer: "test-issuer", + Subject: "test-subject", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString([]byte("test-secret")) + require.NoError(t, err) + return signed +} + func TestUploadServiceDefaultIntegration(t *testing.T) { t.Run("handles complex directory structure", func(t *testing.T) { baseEndpoint, server := createUploadMockServer(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/cli/upload_client_tus_integration_test.go b/pkg/cli/upload_client_tus_integration_test.go index c71a2c4..274e73e 100644 --- a/pkg/cli/upload_client_tus_integration_test.go +++ b/pkg/cli/upload_client_tus_integration_test.go @@ -184,7 +184,7 @@ func setupTUSTest(t *testing.T, uploadLimit int64) *tusTestSetup { accClient := portalsdkmocks.NewMockAccountAPI(t) accClient.EXPECT().UploadLimit(mock.Anything).Return(uploadLimit, nil) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() service := NewUploadService(cfgMgr, output, WithUploadAccountClient(accClient)) return &tusTestSetup{ diff --git a/pkg/cli/upload_test.go b/pkg/cli/upload_test.go index 9ac8b22..59c0b2b 100644 --- a/pkg/cli/upload_test.go +++ b/pkg/cli/upload_test.go @@ -10,9 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v3" "go.lumeweb.com/pinner-cli/pkg/config" - configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" "go.lumeweb.com/pinner-cli/pkg/internal/io" ) @@ -81,65 +79,7 @@ func TestResolveUploadInput_Stdin(t *testing.T) { } } -type mockUploadCommand struct { - path string - name string - noWait bool - dryRun bool - args []string -} - -func (m *mockUploadCommand) Args() cli.Args { - if m.args == nil { - m.args = []string{m.path} - } - return &mockArgs{m.args} -} - -func (m *mockUploadCommand) Uint64(name string) uint64 { - if name == FlagMemoryLimit { - return 100 - } - return 0 -} - -func (m *mockUploadCommand) Int64(name string) int64 { - return 0 -} -func (m *mockUploadCommand) Int(name string) int { - return 0 -} - -func (m *mockUploadCommand) String(name string) string { - switch name { - case FlagName: - return m.name - default: - return "" - } -} - -func (m *mockUploadCommand) Bool(name string) bool { - switch name { - case FlagNoWait: - return m.noWait - case FlagDryRun: - return m.dryRun - case FlagSecure: - return true - default: - return false - } -} - -func (m *mockUploadCommand) StringSlice(name string) []string { - return nil -} - -func (m *mockUploadCommand) IsSet(name string) bool { - return false -} func TestResolveUploadInput_File(t *testing.T) { // Create a temp file @@ -450,38 +390,27 @@ func TestUploadDryRun(t *testing.T) { } service := NewMockUploadService(t) - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() + cfgMgr := newTestConfigMgr(t) if tt.setupMocks != nil { tt.setupMocks(service) } - cmd := &mockUploadCommand{ - path: filepath.Join(tmpDir, tt.path), - name: "", - noWait: false, - dryRun: tt.dryRunFlag, - } + cmd := newMockCommand(). + withArgs(filepath.Join(tmpDir, tt.path)). + withString(FlagName, ""). + withBool(FlagNoWait, false). + withBool(FlagDryRun, tt.dryRunFlag). + withUint64(FlagMemoryLimit, 100). + withBool(FlagSecure, true) if tt.name == "dry run with custom name" { - cmd.name = "custom-name" + cmd = cmd.withString(FlagName, "custom-name") } if tt.name == "dry run with wait flag" { - cmd.noWait = false - } - - cfgMgrFactory := func() (config.Manager, error) { - cfgMgr := configmocks.NewMockManager(t) - cfgMgr.EXPECT().Config().Return(&config.Config{ - MemoryLimit: 100, - Secure: true, - BaseEndpoint: "pinner.xyz", - AuthToken: testAuthToken, - MaxRetries: 3, - GatewayEndpoint: "https://gateway.ipfs.io", - }) - return cfgMgr, nil + cmd = cmd.withBool(FlagNoWait, false) } uploadServiceFactory := func(cfgMgr config.Manager, output Output, opts ...UploadServiceOption) UploadService { @@ -492,7 +421,7 @@ func TestUploadDryRun(t *testing.T) { return NewMockPinningService(t) } - err := handleUpload(context.Background(), cmd, output, cfgMgrFactory, uploadServiceFactory, pinningServiceFactory) + err := handleUpload(context.Background(), cmd, output, cfgMgr, "test-token", true, uploadServiceFactory, pinningServiceFactory) if tt.wantErr { require.Error(t, err) diff --git a/pkg/cli/utils_test.go b/pkg/cli/utils_test.go index a6feea4..b0d0d11 100644 --- a/pkg/cli/utils_test.go +++ b/pkg/cli/utils_test.go @@ -229,10 +229,45 @@ func TestOperationError(t *testing.T) { }) } +func TestFormatStatusWithColor(t *testing.T) { + t.Run("colors pinned status", func(t *testing.T) { + result := formatStatusWithColor("pinned") + assert.NotEmpty(t, result) + }) + + t.Run("colors queued status", func(t *testing.T) { + result := formatStatusWithColor("queued") + assert.NotEmpty(t, result) + }) + + t.Run("colors pinning status", func(t *testing.T) { + result := formatStatusWithColor("pinning") + assert.NotEmpty(t, result) + }) + + t.Run("colors failed status", func(t *testing.T) { + result := formatStatusWithColor("failed") + assert.NotEmpty(t, result) + }) + + t.Run("returns unknown status unchanged", func(t *testing.T) { + result := formatStatusWithColor("unknown") + assert.Equal(t, "unknown", result) + }) +} + +func TestDryRunOption(t *testing.T) { + t.Run("creates option entry", func(t *testing.T) { + opt := DryRunOption("key", "value") + require.Len(t, opt, 1) + assert.Equal(t, "value", opt["key"]) + }) +} + func TestRenderDryRun(t *testing.T) { t.Run("renders dry run with items", func(t *testing.T) { var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) RenderDryRun(output, DryRunPreview{ @@ -262,7 +297,7 @@ func TestRenderDryRun(t *testing.T) { t.Run("renders dry run with truncated items", func(t *testing.T) { var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) items := make([]string, 15) @@ -284,7 +319,7 @@ func TestRenderDryRun(t *testing.T) { t.Run("renders dry run without items", func(t *testing.T) { var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) RenderDryRun(output, DryRunPreview{ @@ -302,7 +337,7 @@ func TestRenderDryRun(t *testing.T) { t.Run("renders dry run with custom max items", func(t *testing.T) { var buf bytes.Buffer - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() output.SetWriter(&buf) RenderDryRun(output, DryRunPreview{ diff --git a/pkg/cli/version.go b/pkg/cli/version.go index 781270b..626cfc2 100644 --- a/pkg/cli/version.go +++ b/pkg/cli/version.go @@ -21,6 +21,10 @@ func init() { } func printVersion(cmd *cli.Command) { + showVersion(newCLICommandWrapper(cmd)) +} + +func showVersion(cmd flagGetter) { info := build.GetInfo() if cmd.Bool(FlagJSON) { diff --git a/pkg/cli/version_test.go b/pkg/cli/version_test.go new file mode 100644 index 0000000..ae7f95f --- /dev/null +++ b/pkg/cli/version_test.go @@ -0,0 +1,80 @@ +package cli + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/build" +) + +func TestVersionFlagConfiguration(t *testing.T) { + assert.NotNil(t, cli.VersionFlag) + assert.Equal(t, "version", cli.VersionFlag.Names()[0]) + + names := cli.VersionFlag.Names() + assert.Contains(t, names, "V", "version flag should have -V alias") +} + +func TestVersionPrinterSet(t *testing.T) { + assert.NotNil(t, cli.VersionPrinter, "VersionPrinter should be set in init()") +} + +func TestPrintVersionHuman(t *testing.T) { + old := build.Default + defer func() { build.Default = old }() + + build.Default = build.New("1.2.3", "abc12345", "main", "", "go1.26", "linux", "amd64") + + cmd := &cli.Command{} + + var buf bytes.Buffer + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + cli.VersionPrinter(cmd) + + _ = w.Close() + os.Stdout = oldStdout + _, _ = buf.ReadFrom(r) + + output := buf.String() + require.Contains(t, output, "1.2.3") +} + +func TestPrintVersionJSON(t *testing.T) { + old := build.Default + defer func() { build.Default = old }() + + build.Default = build.New("2.0.0", "deadbeef", "main", "", "go1.26", "linux", "amd64") + + cmd := &cli.Command{ + Flags: []cli.Flag{ + &cli.BoolFlag{Name: FlagJSON, Value: true}, + }, + } + + var buf bytes.Buffer + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + cli.VersionPrinter(cmd) + + _ = w.Close() + os.Stdout = oldStdout + _, _ = buf.ReadFrom(r) + + output := buf.String() + require.Contains(t, output, "2.0.0") + require.Contains(t, output, "deadbeef") +} + +func TestNewVersionCommand(t *testing.T) { + root := NewRootCommand() + assert.NotEmpty(t, root.Version, "root command should have a version set") +} diff --git a/pkg/cli/websites.go b/pkg/cli/websites.go index 7a093b1..57a1400 100644 --- a/pkg/cli/websites.go +++ b/pkg/cli/websites.go @@ -5,9 +5,9 @@ import ( "fmt" "strconv" "strings" - "time" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ipfs "go.lumeweb.com/ipfs-sdk" ) @@ -69,10 +69,9 @@ func newWebsitesListCommand() *cli.Command { Examples: pinner websites list pinner websites list --json`, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesList(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -94,10 +93,9 @@ Examples: DNSHostingFlag(), NoDNSHostingFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesCreate(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -111,10 +109,9 @@ Examples: pinner websites get example.com pinner websites get example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesGet(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -148,10 +145,9 @@ Examples: DNSHostingFlag(), NoDNSHostingFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesUpdate(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -170,33 +166,7 @@ type WebsitesService interface { GetConfig(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) } -func initWebsitesService(ctx context.Context, cmd *cli.Command, output Output) (context.Context, context.CancelFunc, WebsitesService, error) { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - - cfgMgr, err := defaultConfigManagerFactory() - if err != nil { - cancel() - return ctx, func() {}, nil, err - } - - var websitesService WebsitesService - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - if authToken != "" { - websitesService = NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - websitesService = defaultWebsitesServiceFactory(cfgMgr, output) - } - - if err := websitesService.RequireAuthenticated(); err != nil { - cancel() - return ctx, func() {}, nil, err - } - - return ctx, cancel, websitesService, nil -} - -func resolveRequiredArg(ctx context.Context, websitesService WebsitesService, cmd *cli.Command) (string, error) { +func resolveRequiredArg(ctx context.Context, websitesService WebsitesService, cmd websitesCommandGetter) (string, error) { args := cmd.Args() if args.Len() == 0 { return "", fmt.Errorf("website ID or domain is required") @@ -238,12 +208,11 @@ func printWebsiteUpdateResult(output Output, website *ipfs.WebsiteItem, message } } -func websitesList(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesList(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() websites, err := websitesService.List(ctx) if err != nil { @@ -333,12 +302,11 @@ func resolveAndGetWebsite(ctx context.Context, websitesService WebsitesService, return websitesService.Get(ctx, id) } -func websitesUpdate(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesUpdate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() id, err := resolveRequiredArg(ctx, websitesService, cmd) if err != nil { @@ -420,19 +388,17 @@ Examples: Flags: []cli.Flag{ CIDFlag(), }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesEnableIPNS(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesEnableIPNS(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } -func websitesEnableIPNS(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesEnableIPNS(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() id, err := resolveRequiredArg(ctx, websitesService, cmd) if err != nil { @@ -463,12 +429,11 @@ func websitesEnableIPNS(ctx context.Context, cmd *cli.Command, output Output) er return nil } -func websitesGet(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesGet(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() id, err := resolveRequiredArg(ctx, websitesService, cmd) if err != nil { @@ -543,12 +508,11 @@ func websitesGet(ctx context.Context, cmd *cli.Command, output Output) error { return nil } -func websitesCreate(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesCreate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() args := cmd.Args() if args.Len() == 0 { @@ -726,10 +690,9 @@ Examples: pinner websites delete example.com pinner websites delete example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesDelete(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } @@ -743,19 +706,17 @@ Examples: pinner websites validate example.com pinner websites validate example.com --json`, ArgsUsage: "", - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesValidate(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } -func websitesDelete(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesDelete(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() id, err := resolveRequiredArg(ctx, websitesService, cmd) if err != nil { @@ -779,21 +740,16 @@ func websitesDelete(ctx context.Context, cmd *cli.Command, output Output) error return nil } -func websitesValidate(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesValidate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() return doWebsitesValidate(ctx, cmd, output, websitesService) } -func doWebsitesValidate(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, websitesService WebsitesService) error { - if err := websitesService.RequireAuthenticated(); err != nil { - return err - } - +func doWebsitesValidate(ctx context.Context, cmd websitesCommandGetter, output Output, websitesService WebsitesService) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("website ID or domain is required") @@ -934,23 +890,17 @@ Use this to find the gateway domain for setting up CNAME records with your DNS p Examples: pinner websites config pinner websites config --json`, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesConfig(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesConfig(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } -func websitesConfig(ctx context.Context, cmd *cli.Command, output Output) error { - ctx, cancel, websitesService, err := initWebsitesService(ctx, cmd, output) +func websitesConfig(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { return err } - defer cancel() - - if err := websitesService.RequireAuthenticated(); err != nil { - return err - } config, err := websitesService.GetConfig(ctx) if err != nil { diff --git a/pkg/cli/websites_handler_test.go b/pkg/cli/websites_handler_test.go new file mode 100644 index 0000000..610a448 --- /dev/null +++ b/pkg/cli/websites_handler_test.go @@ -0,0 +1,791 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +type mockWebsitesHandlerService struct { + requireAuthenticatedErr error + listFunc func(ctx context.Context) ([]ipfs.WebsiteItem, error) + createFunc func(ctx context.Context, domain, cid, targetType string) (*ipfs.WebsiteItem, error) + createWithOptionsFunc func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) + getFunc func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) + updateFunc func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) + updateWithOptionsFunc func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) + deleteFunc func(ctx context.Context, id string) error + validateFunc func(ctx context.Context, id string) (*ipfs.WebsiteValidateResponse, error) + getSSLStatusFunc func(ctx context.Context, domain string) (*ipfs.WebsiteResponse, error) + getConfigFunc func(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) +} + +func (m *mockWebsitesHandlerService) RequireAuthenticated() error { + return m.requireAuthenticatedErr +} + +func (m *mockWebsitesHandlerService) List(ctx context.Context) ([]ipfs.WebsiteItem, error) { + if m.listFunc != nil { + return m.listFunc(ctx) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) Create(ctx context.Context, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { + if m.createFunc != nil { + return m.createFunc(ctx, domain, cid, targetType) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) CreateWithOptions(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + if m.createWithOptionsFunc != nil { + return m.createWithOptionsFunc(ctx, req) + } + if m.createFunc != nil { + return m.createFunc(ctx, req.Domain, req.TargetHash, req.TargetType) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) Get(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + if m.getFunc != nil { + return m.getFunc(ctx, id) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) Update(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { + if m.updateFunc != nil { + return m.updateFunc(ctx, id, domain, cid, targetType) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) UpdateWithOptions(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + if m.updateWithOptionsFunc != nil { + return m.updateWithOptionsFunc(ctx, id, req) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) Delete(ctx context.Context, id string) error { + if m.deleteFunc != nil { + return m.deleteFunc(ctx, id) + } + return nil +} + +func (m *mockWebsitesHandlerService) Validate(ctx context.Context, id string) (*ipfs.WebsiteValidateResponse, error) { + if m.validateFunc != nil { + return m.validateFunc(ctx, id) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) GetSSLStatus(ctx context.Context, domain string) (*ipfs.WebsiteResponse, error) { + if m.getSSLStatusFunc != nil { + return m.getSSLStatusFunc(ctx, domain) + } + return nil, nil +} + +func (m *mockWebsitesHandlerService) GetConfig(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) { + if m.getConfigFunc != nil { + return m.getConfigFunc(ctx) + } + return nil, nil +} + +func setupWebsitesHandlerTest(t *testing.T) (*mockWebsitesHandlerService, *configmocks.MockManager) { + t.Helper() + mockSvc := &mockWebsitesHandlerService{} + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + BaseEndpoint: "pinner.xyz", + Secure: true, + AuthToken: "test-token", + }).Maybe() + + origFactory := websitesServiceFactory + t.Cleanup(func() { websitesServiceFactory = origFactory }) + websitesServiceFactory = func(config.Manager, Output, ...WebsitesServiceOption) WebsitesService { + return mockSvc + } + + return mockSvc, cfgMgr +} + +// ===== websitesList ===== + +func TestWebsitesListHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{ + {Id: 1, Domain: "example.com", TargetHash: "QmXxx", Status: "active", Created: now}, + {Id: 2, Domain: "test.org", TargetHash: "QmYyy", Status: "pending", Created: now}, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesListHandler_Empty(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{}, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesListHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return nil, errors.New("server error") + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "server error") +} + +func TestWebsitesListHandler_Unauthenticated(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.requireAuthenticatedErr = ErrNotAuthenticated + + output := newTestOutput() + cmd := newMockCommand() + err := websitesList(context.Background(), cmd, output, cfgMgr, "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNotAuthenticated)) +} + +// ===== websitesCreate ===== + +func TestWebsitesCreateHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.createWithOptionsFunc = func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "example.com", req.Domain) + assert.Equal(t, "QmXxx", req.TargetHash) + assert.Equal(t, "ipfs", req.TargetType) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, ValidationToken: "lumeweb-verify=abc123", + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesCreateHandler_MissingDomain(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand().withString(FlagCID, "QmXxx") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain is required") +} + +func TestWebsitesCreateHandler_WithDNSHosting(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.createWithOptionsFunc = func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + assert.NotNil(t, req.DnsHostingEnabled) + assert.True(t, *req.DnsHostingEnabled) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, DnsHostingEnabled: true, ValidationToken: "abc123", + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx"). + withBool(FlagDNSHosting, true).withIsSet(FlagDNSHosting, true) + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesCreateHandler_WithNoDNSHosting(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.createWithOptionsFunc = func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + assert.NotNil(t, req.DnsHostingEnabled) + assert.False(t, *req.DnsHostingEnabled) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, ValidationToken: "abc123", + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx"). + withBool(FlagNoDNSHosting, true).withIsSet(FlagNoDNSHosting, true) + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesCreateHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.createWithOptionsFunc = func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + return nil, errors.New("conflict") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "conflict") +} + +func TestWebsitesCreateHandler_DefaultTargetType(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.createWithOptionsFunc = func(ctx context.Context, req ipfs.WebsiteRequest) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "ipfs", req.TargetType) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, ValidationToken: "abc123", + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +// ===== websitesGet ===== + +func TestWebsitesGetHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "1", id) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesGetHandler_DomainArg(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{{Id: 1, Domain: "example.com", Status: "active", Created: now}}, nil + } + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "1", id) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesGetHandler_MissingArg(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website ID or domain is required") +} + +func TestWebsitesGetHandler_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return nil, errors.New("website not found") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("999") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website not found") +} + +func TestWebsitesGetHandler_DomainNotFound(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("nonexistent.com") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website not found for domain") +} + +// ===== websitesUpdate ===== + +func TestWebsitesUpdateHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "1", id) + assert.NotNil(t, req.TargetHash) + assert.Equal(t, "QmNewHash", *req.TargetHash) + assert.NotNil(t, req.TargetType) + assert.Equal(t, "ipfs", *req.TargetType) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmNewHash", TargetType: "ipfs", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1"). + withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true). + withString(FlagTargetType, "ipfs").withIsSet(FlagTargetType, true) + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesUpdateHandler_NoUpdateFields(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one field must be provided for update") +} + +func TestWebsitesUpdateHandler_CIDWithoutTargetType(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand().withArgs("1"). + withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true) + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "--target-type is required when --cid is provided") +} + +func TestWebsitesUpdateHandler_DNSHostingEnabled(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + assert.NotNil(t, req.DnsHostingEnabled) + assert.True(t, *req.DnsHostingEnabled) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, DnsHostingEnabled: true, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1"). + withBool(FlagDNSHosting, true).withIsSet(FlagDNSHosting, true) + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesUpdateHandler_DNSHostingDisabled(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + assert.NotNil(t, req.DnsHostingEnabled) + assert.False(t, *req.DnsHostingEnabled) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "QmXxx", TargetType: "ipfs", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1"). + withBool(FlagNoDNSHosting, true).withIsSet(FlagNoDNSHosting, true) + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesUpdateHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + return nil, errors.New("update failed") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1"). + withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true). + withString(FlagTargetType, "ipfs").withIsSet(FlagTargetType, true) + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "update failed") +} + +func TestWebsitesUpdateHandler_MissingArg(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website ID or domain is required") +} + +// ===== websitesEnableIPNS ===== + +func TestWebsitesEnableIPNSHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + assert.Equal(t, "1", id) + assert.NotNil(t, req.TargetType) + assert.Equal(t, "ipns", *req.TargetType) + assert.Nil(t, req.TargetHash) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "12D3KooWTest", TargetType: "ipns", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesEnableIPNSHandler_WithCID(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + assert.NotNil(t, req.TargetType) + assert.Equal(t, "ipns", *req.TargetType) + assert.NotNil(t, req.TargetHash) + assert.Equal(t, "QmNewHash", *req.TargetHash) + return &ipfs.WebsiteItem{ + Id: 1, Domain: "example.com", TargetHash: "12D3KooWTest", TargetType: "ipns", + Status: "active", Created: now, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1").withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true) + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesEnableIPNSHandler_MissingArg(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website ID or domain is required") +} + +func TestWebsitesEnableIPNSHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { + return nil, errors.New("not found") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// ===== websitesDelete ===== + +func TestWebsitesDeleteHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.deleteFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "1", id) + return nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesDeleteHandler_DomainArg(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{{Id: 1, Domain: "example.com", Status: "active", Created: now}}, nil + } + mockSvc.deleteFunc = func(ctx context.Context, id string) error { + assert.Equal(t, "1", id) + return nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesDeleteHandler_MissingArg(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website ID or domain is required") +} + +func TestWebsitesDeleteHandler_NotFound(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.deleteFunc = func(ctx context.Context, id string) error { + return errors.New("website not found") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("999") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website not found") +} + +func TestWebsitesDeleteHandler_DomainNotFound(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("nonexistent.com") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website not found for domain") +} + +// ===== websitesValidate ===== + +func TestWebsitesValidateHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.validateFunc = func(ctx context.Context, id string) (*ipfs.WebsiteValidateResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.WebsiteValidateResponse{ + Domain: "example.com", Id: 1, Valid: true, Message: "Website is valid", + }, nil + } + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return &ipfs.WebsiteItem{Id: 1, Domain: "example.com", Status: "active", Created: time.Now()}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesValidateHandler_ValidationFailure(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.validateFunc = func(ctx context.Context, id string) (*ipfs.WebsiteValidateResponse, error) { + return &ipfs.WebsiteValidateResponse{ + Domain: "example.com", Id: 1, Valid: false, Message: "DNS record not found", + }, nil + } + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return &ipfs.WebsiteItem{Id: 1, Domain: "example.com", Status: "pending", Created: time.Now()}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("1") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesValidateHandler_MissingArg(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "website ID or domain is required") +} + +func TestWebsitesValidateHandler_DomainArg(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{{Id: 1, Domain: "example.com", Status: "active", Created: now}}, nil + } + mockSvc.validateFunc = func(ctx context.Context, id string) (*ipfs.WebsiteValidateResponse, error) { + assert.Equal(t, "1", id) + return &ipfs.WebsiteValidateResponse{ + Domain: "example.com", Id: 1, Valid: true, Message: "Valid", + }, nil + } + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return &ipfs.WebsiteItem{Id: 1, Domain: "example.com", Status: "active", Created: now}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +// ===== websitesSSLStatus ===== + +func TestWebsitesSSLStatusHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + now := time.Now() + issuedAt := now.Format(time.RFC3339) + lastUpdated := now.Add(24 * time.Hour).Format(time.RFC3339) + mockSvc.getSSLStatusFunc = func(ctx context.Context, domain string) (*ipfs.WebsiteResponse, error) { + assert.Equal(t, "example.com", domain) + var resp ipfs.WebsiteResponse + if err := json.Unmarshal([]byte(fmt.Sprintf( + `{"domain":"example.com","ssl":{"status":"active","issued_at":"%s","last_updated_at":"%s"}}`, + issuedAt, lastUpdated, + )), &resp); err != nil { + panic(err) + } + return &resp, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesSSLStatusHandler_MissingDomain(t *testing.T) { + _, cfgMgr := setupWebsitesHandlerTest(t) + + output := newTestOutput() + cmd := newMockCommand() + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "domain is required") +} + +func TestWebsitesSSLStatusHandler_NoSSLInfo(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.getSSLStatusFunc = func(ctx context.Context, domain string) (*ipfs.WebsiteResponse, error) { + return &ipfs.WebsiteResponse{Domain: "example.com", Ssl: nil}, nil + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesSSLStatusHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.getSSLStatusFunc = func(ctx context.Context, domain string) (*ipfs.WebsiteResponse, error) { + return nil, errors.New("API error") + } + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "API error") +} + +func TestWebsitesSSLStatusHandler_Unauthenticated(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.requireAuthenticatedErr = ErrNotAuthenticated + + output := newTestOutput() + cmd := newMockCommand().withArgs("example.com") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNotAuthenticated)) +} + +// ===== websitesConfig ===== + +func TestWebsitesConfigHandler_Success(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + gateway := "gw.pinner.xyz" + ns := []string{"ns1.pinner.xyz", "ns2.pinner.xyz"} + mockSvc.getConfigFunc = func(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) { + return &ipfs.WebsiteConfigResponse{ + GatewayDomain: &gateway, + Nameservers: &ns, + }, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesConfigHandler_NoSites(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.getConfigFunc = func(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) { + return &ipfs.WebsiteConfigResponse{}, nil + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + require.NoError(t, err) +} + +func TestWebsitesConfigHandler_ServiceError(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.getConfigFunc = func(ctx context.Context) (*ipfs.WebsiteConfigResponse, error) { + return nil, errors.New("failed to get config") + } + + output := newTestOutput() + cmd := newMockCommand() + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get config") +} + +func TestWebsitesConfigHandler_Unauthenticated(t *testing.T) { + mockSvc, cfgMgr := setupWebsitesHandlerTest(t) + mockSvc.requireAuthenticatedErr = ErrNotAuthenticated + + output := newTestOutput() + cmd := newMockCommand() + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrNotAuthenticated)) +} diff --git a/pkg/cli/websites_required_records_test.go b/pkg/cli/websites_required_records_test.go new file mode 100644 index 0000000..9806fd1 --- /dev/null +++ b/pkg/cli/websites_required_records_test.go @@ -0,0 +1,145 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" +) + +func TestBuildRequiredRecords(t *testing.T) { + tests := []struct { + name string + website *ipfs.WebsiteItem + nameservers []string + want []map[string]string + }{ + { + name: "nil website returns nil", + website: nil, + nameservers: []string{"ns1.example.com"}, + want: nil, + }, + { + name: "DNS hosting enabled with two nameservers", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: true, + }, + nameservers: []string{"ns1.pinner.xyz", "ns2.pinner.xyz"}, + want: []map[string]string{ + {"name": "example.com", "type": "NS", "value": "ns1.pinner.xyz"}, + {"name": "example.com", "type": "NS", "value": "ns2.pinner.xyz"}, + }, + }, + { + name: "DNS hosting enabled with empty nameservers", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: true, + }, + nameservers: []string{}, + want: []map[string]string{}, + }, + { + name: "self-managed DNS with all fields present", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + TargetType: "ipfs", + TargetHash: "QmXxx", + GatewayDomain: strPtr("gateway.pinner.xyz"), + }, + want: []map[string]string{ + {"name": "example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipfs/QmXxx"}, + {"name": "example.com", "type": "CNAME", "value": "gateway.pinner.xyz"}, + }, + }, + { + name: "self-managed DNS without gateway domain", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + TargetType: "ipfs", + TargetHash: "QmXxx", + }, + want: []map[string]string{ + {"name": "example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipfs/QmXxx"}, + }, + }, + { + name: "self-managed DNS with custom validation host", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + ValidationRecordHost: strPtr("_validation.example.com"), + TargetType: "ipfs", + TargetHash: "QmXxx", + }, + want: []map[string]string{ + {"name": "_validation.example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipfs/QmXxx"}, + }, + }, + { + name: "self-managed DNS with IPNS target type", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + TargetType: "ipns", + TargetHash: "k51qzi5uqu5dg4vh...", + }, + want: []map[string]string{ + {"name": "example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipns/k51qzi5uqu5dg4vh..."}, + }, + }, + { + name: "self-managed DNS with empty gateway domain pointer", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + TargetType: "ipfs", + TargetHash: "QmXxx", + GatewayDomain: strPtr(""), + }, + want: []map[string]string{ + {"name": "example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipfs/QmXxx"}, + }, + }, + { + name: "self-managed DNS with empty validation record host pointer", + website: &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: false, + ValidationToken: "token123", + ValidationRecordHost: strPtr(""), + TargetType: "ipfs", + TargetHash: "QmXxx", + }, + want: []map[string]string{ + {"name": "example.com", "type": "TXT", "value": "token123"}, + {"name": "_dnslink.example.com", "type": "TXT", "value": "dnslink=/ipfs/QmXxx"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildRequiredRecords(tt.website, tt.nameservers) + require.Equal(t, tt.want, got) + }) + } +} + +func strPtr(s string) *string { + return &s +} diff --git a/pkg/cli/websites_service.go b/pkg/cli/websites_service.go index 4cee480..03b9fbe 100644 --- a/pkg/cli/websites_service.go +++ b/pkg/cli/websites_service.go @@ -9,41 +9,80 @@ import ( // websitesService implements the WebsitesService interface using the ipfs.WebsitesService. type websitesService struct { - service ipfs.WebsitesService - cfgMgr config.Manager - authToken string - authenticated bool + ipfsServiceBase + service ipfs.WebsitesService + client *ipfs.Client +} + +// WebsitesServiceOption is a function that configures a websitesService. +type WebsitesServiceOption func(*websitesService) + +// WithWebsitesAuthToken sets an auth token override that takes precedence over config. +func WithWebsitesAuthToken(token string) WebsitesServiceOption { + return func(s *websitesService) { + withAuthToken(token)(&s.ipfsServiceBase) + } +} + +// WithWebsitesClient sets a pre-configured ipfs.Client, bypassing the default ipfs.NewClient() call. +func WithWebsitesClient(client *ipfs.Client) WebsitesServiceOption { + return func(s *websitesService) { + s.client = client + } } // WebsitesServiceFactory creates a WebsitesService with dependencies. -type WebsitesServiceFactory func(cfgMgr config.Manager, output Output) WebsitesService +type WebsitesServiceFactory func(cfgMgr config.Manager, output Output, opts ...WebsitesServiceOption) WebsitesService + +// websitesServiceFactory is the factory function used by newAuthenticatedWebsitesService. +// It can be overridden in tests to inject mock services. +var websitesServiceFactory WebsitesServiceFactory = defaultWebsitesServiceFactory // defaultWebsitesServiceFactory creates a default WebsitesService instance. -func defaultWebsitesServiceFactory(cfgMgr config.Manager, output Output) WebsitesService { - return NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure()) +func defaultWebsitesServiceFactory(cfgMgr config.Manager, output Output, opts ...WebsitesServiceOption) WebsitesService { + return NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure(), opts...) +} + +// newAuthenticatedWebsitesService creates a WebsitesService with authentication. +// It returns an error if the user is not authenticated. +func newAuthenticatedWebsitesService(cfgMgr config.Manager, output Output, authToken string) (WebsitesService, error) { + var svcOpts []WebsitesServiceOption + if authToken != "" { + svcOpts = append(svcOpts, WithWebsitesAuthToken(authToken)) + } + websitesService := websitesServiceFactory(cfgMgr, output, svcOpts...) + if err := websitesService.RequireAuthenticated(); err != nil { + return nil, err + } + return websitesService, nil } // NewWebsitesService creates a new WebsitesService instance. -func NewWebsitesService(cfgMgr config.Manager, output Output, apiEndpoint string) WebsitesService { +func NewWebsitesService(cfgMgr config.Manager, output Output, apiEndpoint string, opts ...WebsitesServiceOption) WebsitesService { authToken := cfgMgr.Config().AuthToken - client, err := ipfs.NewClient(apiEndpoint, authToken) - if err != nil { - output.PrintError(err) - return &websitesService{ - service: nil, - cfgMgr: cfgMgr, - authToken: authToken, - authenticated: false, - } + s := &websitesService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: authToken, + }, + } + for _, opt := range opts { + opt(s) } - return &websitesService{ - service: client.Websites(), - cfgMgr: cfgMgr, - authToken: authToken, - authenticated: authToken != "", + if s.client != nil { + s.service = s.client.Websites() + } else { + client, err := ipfs.NewClient(apiEndpoint, authToken) + if err != nil { + output.PrintError(err) + s.service = nil + return s + } + s.service = client.Websites() } + return s } // List retrieves all websites for the authenticated user. @@ -152,10 +191,4 @@ func (s *websitesService) GetConfig(ctx context.Context) (*ipfs.WebsiteConfigRes return s.service.GetConfig(ctx) } -// RequireAuthenticated checks if the service is authenticated. -func (s *websitesService) RequireAuthenticated() error { - if !s.authenticated { - return ErrNotAuthenticated - } - return nil -} + diff --git a/pkg/cli/websites_service_crud_test.go b/pkg/cli/websites_service_crud_test.go new file mode 100644 index 0000000..b0661cc --- /dev/null +++ b/pkg/cli/websites_service_crud_test.go @@ -0,0 +1,124 @@ +package cli + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + ipfs "go.lumeweb.com/ipfs-sdk" + "go.lumeweb.com/pinner-cli/pkg/config" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" +) + +func newUnauthWebsitesService(t *testing.T) *websitesService { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + return &websitesService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: ""}, + } +} + +func newAuthedNilWebsitesService(t *testing.T) *websitesService { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: "token"}).Maybe() + return &websitesService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr, authToken: "token"}, + service: nil, + } +} + +func TestWebsitesService_List_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.List(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_List_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilWebsitesService(t) + _, err := svc.List(context.Background()) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestWebsitesService_Create_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.Create(context.Background(), "example.com", "QmHash", "cid") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_CreateWithOptions_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.CreateWithOptions(context.Background(), ipfs.WebsiteRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_Get_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.Get(context.Background(), "example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_Update_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.Update(context.Background(), "example.com", "QmHash", "cid", "ipns") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_UpdateWithOptions_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.UpdateWithOptions(context.Background(), "example.com", ipfs.WebsiteUpdateRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_Delete_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + err := svc.Delete(context.Background(), "example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_Validate_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.Validate(context.Background(), "example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_GetSSLStatus_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.GetSSLStatus(context.Background(), "example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_GetConfig_Unauthenticated(t *testing.T) { + svc := newUnauthWebsitesService(t) + _, err := svc.GetConfig(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "not authenticated") +} + +func TestWebsitesService_GetConfig_ServiceUnavailable(t *testing.T) { + svc := newAuthedNilWebsitesService(t) + _, err := svc.GetConfig(context.Background()) + require.Error(t, err) + assert.Equal(t, ErrServiceUnavailable, err) +} + +func TestWebsitesService_WithWebsitesAuthToken(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{AuthToken: ""}).Maybe() + + svc := &websitesService{ + ipfsServiceBase: ipfsServiceBase{cfgMgr: cfgMgr}, + } + WithWebsitesAuthToken("override-token")(svc) + assert.Equal(t, "override-token", svc.getAuthToken()) +} diff --git a/pkg/cli/websites_service_test.go b/pkg/cli/websites_service_test.go index 09ffcda..3605711 100644 --- a/pkg/cli/websites_service_test.go +++ b/pkg/cli/websites_service_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" + configmocks "go.lumeweb.com/pinner-cli/pkg/config/mocks" + "go.lumeweb.com/pinner-cli/pkg/config" ipfs "go.lumeweb.com/ipfs-sdk" ) @@ -116,3 +118,72 @@ func TestWebsitesService_RequireAuthenticated(t *testing.T) { }) } } + +func TestWebsitesService_AuthTokenOverride(t *testing.T) { + t.Run("override token takes precedence over empty config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &websitesService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) + + t.Run("override token takes precedence over config token", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &websitesService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "override-token", + }, + } + + require.Equal(t, "override-token", svc.getAuthToken()) + }) + + t.Run("falls back to config token when override is empty", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "config-token", + }).Maybe() + + svc := &websitesService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + authToken: "", + }, + } + + require.Equal(t, "config-token", svc.getAuthToken()) + }) + + t.Run("WithWebsitesAuthToken functional option sets override", func(t *testing.T) { + cfgMgr := configmocks.NewMockManager(t) + cfgMgr.EXPECT().Config().Return(&config.Config{ + AuthToken: "", + }).Maybe() + + svc := &websitesService{ + ipfsServiceBase: ipfsServiceBase{ + cfgMgr: cfgMgr, + }, + } + WithWebsitesAuthToken("override-token")(svc) + + require.Equal(t, "override-token", svc.getAuthToken()) + err := svc.RequireAuthenticated() + require.NoError(t, err) + }) +} diff --git a/pkg/cli/websites_ssl.go b/pkg/cli/websites_ssl.go index a43206d..e36edbc 100644 --- a/pkg/cli/websites_ssl.go +++ b/pkg/cli/websites_ssl.go @@ -6,11 +6,11 @@ import ( "time" "github.com/urfave/cli/v3" + "go.lumeweb.com/pinner-cli/pkg/config" ipfs "go.lumeweb.com/ipfs-sdk" ) -// newWebsitesSSLCommand creates the SSL subcommand for websites. func newWebsitesSSLCommand() *cli.Command { return &cli.Command{ Name: "ssl", @@ -31,7 +31,6 @@ Examples: } } -// newWebsitesSSLStatusCommand creates the SSL status command. func newWebsitesSSLStatusCommand() *cli.Command { return &cli.Command{ Name: "status", @@ -55,15 +54,13 @@ Examples: Usage: "Watch for SSL status changes", }, }, - Action: func(ctx context.Context, cmd *cli.Command) error { - output := setupOutput(cmd) - return websitesSSLStatus(ctx, cmd, output) - }, + Action: withContext(func(ctx context.Context, cc *commandContext) error { + return websitesSSLStatus(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + }), } } -// websitesSSLStatus retrieves and displays SSL certificate status for a website. -func websitesSSLStatus(ctx context.Context, cmd *cli.Command, output Output) error { +func websitesSSLStatus(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain is required") @@ -72,21 +69,8 @@ func websitesSSLStatus(ctx context.Context, cmd *cli.Command, output Output) err domain := args.First() watch := cmd.Bool("watch") - cfgMgr, err := defaultConfigManagerFactory() + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) if err != nil { - return fmt.Errorf("failed to get config: %w", err) - } - - var websitesService WebsitesService - authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) - if authToken != "" { - websitesService = NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - websitesService = defaultWebsitesServiceFactory(cfgMgr, output) - } - - if err := websitesService.RequireAuthenticated(); err != nil { return err } @@ -151,7 +135,6 @@ func websitesSSLStatus(ctx context.Context, cmd *cli.Command, output Output) err return nil } -// formatTimePtr formats a time pointer to a human-readable string. func formatTimePtr(t *time.Time) string { if t == nil { return "N/A" diff --git a/pkg/cli/websites_test.go b/pkg/cli/websites_test.go index ba5bff6..7082192 100644 --- a/pkg/cli/websites_test.go +++ b/pkg/cli/websites_test.go @@ -191,7 +191,7 @@ func TestWebsitesList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -403,17 +403,13 @@ func TestWebsitesCreate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) } - cmd := &mockWebsitesCreateCommand{ - domain: tt.domain, - cid: tt.cid, - targetType: tt.targetType, - } + cmd := newMockCommand().withString(FlagDomain, tt.domain).withString(FlagCID, tt.cid).withString(FlagTargetType, tt.targetType) err := websitesCreateWithService(context.Background(), cmd, output, mockSvc) @@ -469,11 +465,7 @@ func TestWebsitesCreateJSON(t *testing.T) { tt.setupMocks(mockSvc) } - cmd := &mockWebsitesCreateCommand{ - domain: tt.domain, - cid: tt.cid, - targetType: tt.targetType, - } + cmd := newMockCommand().withString(FlagDomain, tt.domain).withString(FlagCID, tt.cid).withString(FlagTargetType, tt.targetType) err := websitesCreateWithService(context.Background(), cmd, output, mockSvc) @@ -544,12 +536,11 @@ func websitesGetWithService(ctx context.Context, cmd interface{ Args() cli.Args } args := cmd.Args() - if args.Len() == 0 { + id := args.First() + if id == "" { return fmt.Errorf("website ID or domain is required") } - id := args.First() - website, err := websitesService.Get(ctx, id) if err != nil { return err @@ -591,7 +582,7 @@ func TestWebsitesGet(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -609,7 +600,7 @@ func TestWebsitesGet(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, { @@ -626,12 +617,12 @@ func TestWebsitesGet(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "2"}, + cmd: newMockCommand().withArgs("2"), wantErr: false, }, { name: "missing website ID", - cmd: &mockWebsitesGetCommand{id: ""}, + cmd: newMockCommand().withArgs(""), wantErr: true, errContains: "website ID or domain is required", }, @@ -642,7 +633,7 @@ func TestWebsitesGet(t *testing.T) { return nil, errors.New("website not found") } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: true, errContains: "website not found", }, @@ -651,7 +642,7 @@ func TestWebsitesGet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -675,7 +666,7 @@ func TestWebsitesGetJSON(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -693,7 +684,7 @@ func TestWebsitesGetJSON(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, } @@ -724,19 +715,15 @@ func TestWebsitesGetJSON(t *testing.T) { func TestWebsitesUpdate(t *testing.T) { tests := []struct { name string - cmd *mockWebsitesUpdateCommand + cmd *mockCommand setupMocks func(*mockWebsitesServiceForCLI) wantErr bool errContains string }{ { name: "successful update with all parameters", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "new-example.com", - cid: "QmNewHash", - targetType: "ipfs", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "new-example.com").withString(FlagCID, "QmNewHash").withString(FlagTargetType, "ipfs", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return &ipfs.WebsiteItem{ @@ -753,12 +740,8 @@ func TestWebsitesUpdate(t *testing.T) { }, { name: "successful update with domain only", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "new-domain.com", - cid: "", - targetType: "", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "new-domain.com").withString(FlagCID, "").withString(FlagTargetType, "", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return &ipfs.WebsiteItem{ @@ -775,12 +758,8 @@ func TestWebsitesUpdate(t *testing.T) { }, { name: "successful update with cid only", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "", - cid: "QmNewHash", - targetType: "", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "").withString(FlagCID, "QmNewHash").withString(FlagTargetType, "", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return &ipfs.WebsiteItem{ @@ -797,36 +776,24 @@ func TestWebsitesUpdate(t *testing.T) { }, { name: "missing website ID", - cmd: &mockWebsitesUpdateCommand{ - id: "", - domain: "new-example.com", - cid: "QmNewHash", - targetType: "ipfs", - }, + cmd: newMockCommand().withArgs("").withString(FlagDomain, "new-example.com").withString(FlagCID, "QmNewHash").withString(FlagTargetType, "ipfs", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) {}, wantErr: true, errContains: "website ID or domain is required", }, { name: "missing update fields (all empty)", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "", - cid: "", - targetType: "", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "").withString(FlagCID, "").withString(FlagTargetType, "", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) {}, wantErr: true, errContains: "at least one field must be provided for update", }, { name: "service error", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "new-example.com", - cid: "QmNewHash", - targetType: "ipfs", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "new-example.com").withString(FlagCID, "QmNewHash").withString(FlagTargetType, "ipfs", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return nil, errors.New("website not found") @@ -840,7 +807,7 @@ func TestWebsitesUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -863,19 +830,15 @@ func TestWebsitesUpdate(t *testing.T) { func TestWebsitesUpdateJSON(t *testing.T) { tests := []struct { name string - cmd *mockWebsitesUpdateCommand + cmd *mockCommand setupMocks func(*mockWebsitesServiceForCLI) wantErr bool errContains string }{ { name: "successful update with JSON output", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "new-example.com", - cid: "QmNewHash", - targetType: "ipfs", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "new-example.com").withString(FlagCID, "QmNewHash").withString(FlagTargetType, "ipfs", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return &ipfs.WebsiteItem{ @@ -892,12 +855,8 @@ func TestWebsitesUpdateJSON(t *testing.T) { }, { name: "successful update partial parameters with JSON output", - cmd: &mockWebsitesUpdateCommand{ - id: "1", - domain: "new-domain.com", - cid: "", - targetType: "", - }, + cmd: newMockCommand().withArgs("1").withString(FlagDomain, "new-domain.com").withString(FlagCID, "").withString(FlagTargetType, "", + ), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateFunc = func(ctx context.Context, id, domain, cid, targetType string) (*ipfs.WebsiteItem, error) { return &ipfs.WebsiteItem{ @@ -937,69 +896,6 @@ func TestWebsitesUpdateJSON(t *testing.T) { } } -// mockWebsitesCreateCommand is a mock implementation of commandGetter for testing. -type mockWebsitesCreateCommand struct { - domain string - cid string - targetType string -} - -func (m *mockWebsitesCreateCommand) String(name string) string { - switch name { - case FlagDomain: - return m.domain - case FlagCID: - return m.cid - case FlagTargetType: - return m.targetType - default: - return "" - } -} - -// mockWebsitesGetCommand is a mock implementation of commandGetter for testing. -type mockWebsitesGetCommand struct { - id string -} - -func (m *mockWebsitesGetCommand) String(name string) string { - return "" -} - -func (m *mockWebsitesGetCommand) Args() cli.Args { - if m.id == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.id}} -} - -// mockWebsitesUpdateCommand is a mock implementation of commandGetter for testing. -type mockWebsitesUpdateCommand struct { - id string - domain string - cid string - targetType string -} - -func (m *mockWebsitesUpdateCommand) String(name string) string { - switch name { - case FlagDomain: - return m.domain - case FlagCID: - return m.cid - case FlagTargetType: - return m.targetType - default: - return "" - } -} - -func (m *mockWebsitesUpdateCommand) Args() cli.Args { - if m.id == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.id}} -} // websitesUpdateWithService is a test helper that allows injecting a mock WebsitesService func websitesUpdateWithService(ctx context.Context, cmd interface { @@ -1011,12 +907,11 @@ func websitesUpdateWithService(ctx context.Context, cmd interface { } args := cmd.Args() - if args.Len() == 0 { + id := args.First() + if id == "" { return fmt.Errorf("website ID or domain is required") } - id := args.First() - domain := cmd.String(FlagDomain) cid := cmd.String(FlagCID) targetType := cmd.String(FlagTargetType) @@ -1056,7 +951,7 @@ func TestWebsitesDelete(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -1067,7 +962,7 @@ func TestWebsitesDelete(t *testing.T) { return nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, { @@ -1077,12 +972,12 @@ func TestWebsitesDelete(t *testing.T) { return nil } }, - cmd: &mockWebsitesGetCommand{id: "2"}, + cmd: newMockCommand().withArgs("2"), wantErr: false, }, { name: "missing website ID", - cmd: &mockWebsitesGetCommand{id: ""}, + cmd: newMockCommand().withArgs(""), wantErr: true, errContains: "website ID or domain is required", }, @@ -1093,7 +988,7 @@ func TestWebsitesDelete(t *testing.T) { return errors.New("website not found") } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: true, errContains: "website not found", }, @@ -1102,7 +997,7 @@ func TestWebsitesDelete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -1126,7 +1021,7 @@ func TestWebsitesDeleteJSON(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -1137,7 +1032,7 @@ func TestWebsitesDeleteJSON(t *testing.T) { return nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, } @@ -1172,12 +1067,11 @@ func websitesDeleteWithService(ctx context.Context, cmd interface{ Args() cli.Ar } args := cmd.Args() - if args.Len() == 0 { + id := args.First() + if id == "" { return fmt.Errorf("website ID or domain is required") } - id := args.First() - if err := websitesService.Delete(ctx, id); err != nil { return err } @@ -1199,7 +1093,7 @@ func TestWebsitesValidate(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -1215,7 +1109,7 @@ func TestWebsitesValidate(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, { @@ -1230,12 +1124,12 @@ func TestWebsitesValidate(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, { name: "missing website ID", - cmd: &mockWebsitesGetCommand{id: ""}, + cmd: newMockCommand().withArgs(), wantErr: true, errContains: "website ID or domain is required", }, @@ -1246,7 +1140,7 @@ func TestWebsitesValidate(t *testing.T) { return nil, errors.New("website not found") } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, } @@ -1254,7 +1148,7 @@ func TestWebsitesValidate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -1278,7 +1172,7 @@ func TestWebsitesValidateJSON(t *testing.T) { tests := []struct { name string setupMocks func(*mockWebsitesServiceForCLI) - cmd *mockWebsitesGetCommand + cmd *mockCommand wantErr bool errContains string }{ @@ -1294,7 +1188,7 @@ func TestWebsitesValidateJSON(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, { @@ -1309,7 +1203,7 @@ func TestWebsitesValidateJSON(t *testing.T) { }, nil } }, - cmd: &mockWebsitesGetCommand{id: "1"}, + cmd: newMockCommand().withArgs("1"), wantErr: false, }, } @@ -1338,7 +1232,7 @@ func TestWebsitesValidateJSON(t *testing.T) { } // websitesValidateWithService is a test helper that allows injecting a mock WebsitesService -func websitesValidateWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, websitesService WebsitesService) error { +func websitesValidateWithService(ctx context.Context, cmd websitesCommandGetter, output Output, websitesService WebsitesService) error { return doWebsitesValidate(ctx, cmd, output, websitesService) } @@ -1410,13 +1304,13 @@ func TestWebsitesSSLStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) } - cmd := &mockSSLStatusCommand{domain: "example.com"} + cmd := newMockCommand().withArgs("example.com") err := websitesSSLStatusWithService(context.Background(), cmd, output, mockSvc) if tt.wantErr { @@ -1431,24 +1325,15 @@ func TestWebsitesSSLStatus(t *testing.T) { } } -// mockSSLStatusCommand is a mock command for testing SSL status -type mockSSLStatusCommand struct { - domain string -} - -func (m *mockSSLStatusCommand) Args() cli.Args { - return &mockArgs{[]string{m.domain}} -} // websitesSSLStatusWithService is a test helper that allows injecting a mock WebsitesService -func websitesSSLStatusWithService(ctx context.Context, cmd interface{ Args() cli.Args }, output Output, websitesService WebsitesService) error { +func websitesSSLStatusWithService(ctx context.Context, cmd websitesCommandGetter, output Output, websitesService WebsitesService) error { args := cmd.Args() - if args.Len() == 0 { + domain := args.First() + if domain == "" { return fmt.Errorf("domain is required") } - domain := args.First() - website, err := websitesService.GetSSLStatus(ctx, domain) if err != nil { return err @@ -1524,7 +1409,7 @@ func TestWebsitesConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -1622,58 +1507,18 @@ func websitesConfigWithService(ctx context.Context, output Output, websitesServi return nil } -// mockEnableIPNSCommand is a mock for testing enable-ipns -type mockEnableIPNSCommand struct { - id string - cid string -} - -func (m *mockEnableIPNSCommand) IsSet(name string) bool { - switch name { - case FlagCID: - return m.cid != "" - default: - return false - } -} - -func (m *mockEnableIPNSCommand) String(name string) string { - switch name { - case FlagCID: - return m.cid - default: - return "" - } -} - -func (m *mockEnableIPNSCommand) Args() cli.Args { - if m.id == "" { - return &mockArgs{} - } - return &mockArgs{[]string{m.id}} -} - -func (m *mockEnableIPNSCommand) Bool(name string) bool { - return false -} - -func (m *mockEnableIPNSCommand) Int(name string) int { - return 0 -} func TestWebsitesEnableIPNS(t *testing.T) { tests := []struct { name string - cmd *mockEnableIPNSCommand + cmd *mockCommand setupMocks func(*mockWebsitesServiceForCLI) wantErr bool errContains string }{ { name: "enable ipns without cid", - cmd: &mockEnableIPNSCommand{ - id: "1", - }, + cmd: newMockCommand().withArgs("1"), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { require.NotNil(t, req.TargetType) @@ -1693,10 +1538,7 @@ func TestWebsitesEnableIPNS(t *testing.T) { }, { name: "enable ipns with cid", - cmd: &mockEnableIPNSCommand{ - id: "1", - cid: "QmNewHash", - }, + cmd: newMockCommand().withArgs("1").withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { require.NotNil(t, req.TargetType) @@ -1717,18 +1559,14 @@ func TestWebsitesEnableIPNS(t *testing.T) { }, { name: "missing website id", - cmd: &mockEnableIPNSCommand{ - id: "", - }, + cmd: newMockCommand().withArgs(""), setupMocks: func(svc *mockWebsitesServiceForCLI) {}, wantErr: true, errContains: "website ID or domain is required", }, { name: "service error", - cmd: &mockEnableIPNSCommand{ - id: "1", - }, + cmd: newMockCommand().withArgs("1"), setupMocks: func(svc *mockWebsitesServiceForCLI) { svc.updateWithOptionsFunc = func(ctx context.Context, id string, req ipfs.WebsiteUpdateRequest) (*ipfs.WebsiteItem, error) { return nil, errors.New("website not found") @@ -1742,7 +1580,7 @@ func TestWebsitesEnableIPNS(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSvc := &mockWebsitesServiceForCLI{} - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() if tt.setupMocks != nil { tt.setupMocks(mockSvc) @@ -1772,11 +1610,12 @@ func websitesEnableIPNSWithService(ctx context.Context, cmd interface { } args := cmd.Args() - if args.Len() == 0 { + idArg := args.First() + if idArg == "" { return fmt.Errorf("website ID or domain is required") } - id, err := resolveWebsiteID(ctx, websitesService, args.First()) + id, err := resolveWebsiteID(ctx, websitesService, idArg) if err != nil { return err } @@ -1804,3 +1643,189 @@ func websitesEnableIPNSWithService(ctx context.Context, cmd interface { return nil } + +func TestStripValidationPrefix(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"lumeweb-verify=abc123", "abc123"}, + {"key=value", "value"}, + {"no-prefix", "no-prefix"}, + {"=", ""}, + {"a=b=c", "b=c"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := stripValidationPrefix(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestResolveWebsiteID(t *testing.T) { + t.Run("numeric ID returned as-is", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + id, err := resolveWebsiteID(context.Background(), mockSvc, "42") + require.NoError(t, err) + require.Equal(t, "42", id) + }) + + t.Run("domain resolved via list", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{ + {Id: 7, Domain: "example.com"}, + {Id: 8, Domain: "other.com"}, + }, nil + } + id, err := resolveWebsiteID(context.Background(), mockSvc, "example.com") + require.NoError(t, err) + require.Equal(t, "7", id) + }) + + t.Run("domain not found", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{}, nil + } + _, err := resolveWebsiteID(context.Background(), mockSvc, "missing.com") + require.Error(t, err) + require.Contains(t, err.Error(), "website not found for domain") + }) + + t.Run("list service error", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return nil, errors.New("service down") + } + _, err := resolveWebsiteID(context.Background(), mockSvc, "example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to look up website by domain") + }) +} + +func TestResolveAndGetWebsite(t *testing.T) { + t.Run("numeric ID fetches directly", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return &ipfs.WebsiteItem{Id: 42, Domain: "example.com"}, nil + } + website, err := resolveAndGetWebsite(context.Background(), mockSvc, "42") + require.NoError(t, err) + require.Equal(t, 42, website.Id) + }) + + t.Run("domain resolves then fetches", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{{Id: 7, Domain: "example.com"}}, nil + } + mockSvc.getFunc = func(ctx context.Context, id string) (*ipfs.WebsiteItem, error) { + return &ipfs.WebsiteItem{Id: 7, Domain: "example.com"}, nil + } + website, err := resolveAndGetWebsite(context.Background(), mockSvc, "example.com") + require.NoError(t, err) + require.Equal(t, 7, website.Id) + }) + + t.Run("domain not found returns error", func(t *testing.T) { + mockSvc := &mockWebsitesServiceForCLI{} + mockSvc.listFunc = func(ctx context.Context) ([]ipfs.WebsiteItem, error) { + return []ipfs.WebsiteItem{}, nil + } + _, err := resolveAndGetWebsite(context.Background(), mockSvc, "missing.com") + require.Error(t, err) + }) +} + +func TestPrintWebsiteUpdateResult(t *testing.T) { + t.Run("active website without gateway", func(t *testing.T) { + output := newTestOutput() + website := &ipfs.WebsiteItem{ + Id: 1, + Domain: "example.com", + TargetHash: "QmXxx", + TargetType: "ipfs", + Status: "active", + Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + } + printWebsiteUpdateResult(output, website, "Website updated") + }) + + t.Run("inactive website shows token expired", func(t *testing.T) { + output := newTestOutput() + website := &ipfs.WebsiteItem{ + Id: 1, + Domain: "example.com", + TargetHash: "QmXxx", + TargetType: "ipfs", + Status: "pending", + Expired: true, + Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + } + printWebsiteUpdateResult(output, website, "Website updated") + }) + + t.Run("website with gateway domain", func(t *testing.T) { + output := newTestOutput() + gateway := "gw.pinner.xyz" + ipnsKeyID := 5 + website := &ipfs.WebsiteItem{ + Id: 1, + Domain: "example.com", + TargetHash: "QmXxx", + TargetType: "ipfs", + Status: "active", + GatewayDomain: &gateway, + IpnsKeyId: &ipnsKeyID, + Created: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + } + printWebsiteUpdateResult(output, website, "IPNS enabled") + }) +} + +func TestShowDNSRecordInstructions(t *testing.T) { + t.Run("nil website returns early", func(t *testing.T) { + output := newTestOutput() + showDNSRecordInstructions(output, nil, nil) + }) + + t.Run("dns hosting enabled", func(t *testing.T) { + output := newTestOutput() + website := &ipfs.WebsiteItem{ + Domain: "example.com", + DnsHostingEnabled: true, + } + showDNSRecordInstructions(output, website, []string{"ns1.pinner.xyz", "ns2.pinner.xyz"}) + }) + + t.Run("self-managed dns", func(t *testing.T) { + output := newTestOutput() + website := &ipfs.WebsiteItem{ + Domain: "example.com", + TargetHash: "QmXxx", + TargetType: "ipfs", + ValidationToken: "lumeweb-verify=abc123", + } + showDNSRecordInstructions(output, website, nil) + }) +} + +func TestShowConfigDNSRecords(t *testing.T) { + t.Run("with gateway domain", func(t *testing.T) { + output := newTestOutput() + gateway := "gw.pinner.xyz" + config := &ipfs.WebsiteConfigResponse{ + GatewayDomain: &gateway, + } + showConfigDNSRecords(output, config) + }) + + t.Run("without gateway domain", func(t *testing.T) { + output := newTestOutput() + config := &ipfs.WebsiteConfigResponse{} + showConfigDNSRecords(output, config) + }) +} diff --git a/pkg/cli/websites_wizard.go b/pkg/cli/websites_wizard.go index b500ac2..9c7a0a3 100644 --- a/pkg/cli/websites_wizard.go +++ b/pkg/cli/websites_wizard.go @@ -259,12 +259,12 @@ func runWebsitesWizard(ctx context.Context, cmd *cli.Command, output Output) err var websitesService WebsitesService authToken := GetAuthToken(cmd, cfgMgr) - secure := GetSecureSetting(cmd, cfgMgr) + + var svcOpts []WebsitesServiceOption if authToken != "" { - websitesService = NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) - } else { - websitesService = defaultWebsitesServiceFactory(cfgMgr, output) + svcOpts = append(svcOpts, WithWebsitesAuthToken(authToken)) } + websitesService = defaultWebsitesServiceFactory(cfgMgr, output, svcOpts...) ui := NewPTermWebsitesUI(output) diff --git a/pkg/cli/websites_wizard_mock.go b/pkg/cli/websites_wizard_mock.go index 73fbe62..bb149ef 100644 --- a/pkg/cli/websites_wizard_mock.go +++ b/pkg/cli/websites_wizard_mock.go @@ -14,7 +14,6 @@ type MockWebsitesUI struct { mu sync.Mutex - // Track execution choices ContentChoice ContentSourceChoice DNSChoice DNSModeChoice TargetChoice TargetTypeChoice @@ -22,10 +21,17 @@ type MockWebsitesUI struct { DomainInput string PromptError error - // Control behavior ContinueError error - // Track state + SelectResult int + SelectString string + SelectErr error + ContinueErr error + + StartErr error + StopErr error + Messages []string + AuthCheckExecuted bool ContentSourceExecuted bool TargetTypeExecuted bool @@ -237,3 +243,33 @@ func (m *MockWebsitesUI) ExecuteValidateStep(_ context.Context, w *WebsitesWizar w.SetValidateRetry(false) return nil } + +func (m *MockWebsitesUI) Select(label string, items []string) (int, string, error) { + return m.SelectResult, m.SelectString, m.SelectErr +} + +func (m *MockWebsitesUI) Continue() error { + return m.ContinueErr +} + +func (m *MockWebsitesUI) Start(message string) error { + m.Messages = append(m.Messages, "start:"+message) + return m.StartErr +} + +func (m *MockWebsitesUI) UpdateText(message string) { + m.Messages = append(m.Messages, "update:"+message) +} + +func (m *MockWebsitesUI) Success(message string) { + m.Messages = append(m.Messages, "success:"+message) +} + +func (m *MockWebsitesUI) Fail(message string) { + m.Messages = append(m.Messages, "fail:"+message) +} + +func (m *MockWebsitesUI) Stop() error { + m.Messages = append(m.Messages, "stop") + return m.StopErr +} diff --git a/pkg/cli/websites_wizard_pterm.go b/pkg/cli/websites_wizard_pterm.go index 26679b0..41d40f4 100644 --- a/pkg/cli/websites_wizard_pterm.go +++ b/pkg/cli/websites_wizard_pterm.go @@ -15,6 +15,9 @@ import ( // PTermWebsitesUI implements WebsitesUI using PTerm for display. type PTermWebsitesUI struct { *wizard.PTermUI + *PTermSelectPrompter + *PTermContinuePrompter + *PTermSpinner output Output wizard *WebsitesWizard } @@ -22,8 +25,11 @@ type PTermWebsitesUI struct { // NewPTermWebsitesUI creates a new PTerm-based websites UI. func NewPTermWebsitesUI(output Output) *PTermWebsitesUI { return &PTermWebsitesUI{ - PTermUI: wizard.NewPTermUI("", ""), - output: output, + PTermUI: wizard.NewPTermUI("", ""), + PTermSelectPrompter: &PTermSelectPrompter{}, + PTermContinuePrompter: &PTermContinuePrompter{}, + PTermSpinner: &PTermSpinner{}, + output: output, } } @@ -52,8 +58,7 @@ func (ui *PTermWebsitesUI) ShowWelcome() error { pterm.Println() - _, err := pterm.DefaultInteractiveContinue.Show() - return err + return ui.Continue() } // ShowCompletion displays the completion message. @@ -114,12 +119,7 @@ func (ui *PTermWebsitesUI) ExecuteContentSourceStep(_ context.Context, w *Websit "No, I need to upload content first", } - prompt := promptui.Select{ - Label: "Have you already uploaded content to IPFS?", - Items: choices, - } - - idx, _, err := runSelect(&prompt) + idx, _, err := ui.Select("Have you already uploaded content to IPFS?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -168,12 +168,7 @@ func (ui *PTermWebsitesUI) ExecuteTargetTypeStep(_ context.Context, w *WebsitesW "IPNS (mutable name, updates automatically)", } - prompt := promptui.Select{ - Label: "What type of content link do you want to use?", - Items: choices, - } - - idx, _, err := runSelect(&prompt) + idx, _, err := ui.Select("What type of content link do you want to use?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -230,12 +225,7 @@ func (ui *PTermWebsitesUI) ExecuteDNSModeStep(_ context.Context, w *WebsitesWiza "I'll manage DNS myself", } - prompt := promptui.Select{ - Label: "How would you like to manage DNS for this website?", - Items: choices, - } - - idx, _, err := runSelect(&prompt) + idx, _, err := ui.Select("How would you like to manage DNS for this website?", choices) if err != nil { return fmt.Errorf("prompt failed: %w", err) } @@ -259,14 +249,16 @@ func (ui *PTermWebsitesUI) ExecuteDNSModeStep(_ context.Context, w *WebsitesWiza } func (ui *PTermWebsitesUI) ExecuteCreateWebsiteStep(ctx context.Context, w *WebsitesWizard) error { - spinner, _ := pterm.DefaultSpinner.Start("Creating website...") + if err := ui.Start("Creating website..."); err != nil { + return fmt.Errorf("failed to start spinner: %w", err) + } if err := w.executeCreateWebsite(ctx); err != nil { - spinner.Fail("Failed to create website") + ui.Fail("Failed to create website") return err } - spinner.Success("Website created successfully!") + ui.Success("Website created successfully!") return nil } @@ -295,7 +287,9 @@ func (ui *PTermWebsitesUI) executeManagedDNSValidation(ctx context.Context, w *W pterm.Info.Println("Validating website configuration...") pterm.Println() - spinner, _ := pterm.DefaultSpinner.Start("Waiting for DNS records to propagate...") + if err := ui.Start("Waiting for DNS records to propagate..."); err != nil { + return fmt.Errorf("failed to start spinner: %w", err) + } var lastErr error err := retry.Do( @@ -323,7 +317,7 @@ func (ui *PTermWebsitesUI) executeManagedDNSValidation(ctx context.Context, w *W pterm.Debug.Printf("Validation attempt %d failed: %v\n", n+1, err) }), ) - spinner.Stop() + ui.Stop() if err != nil { if lastErr != nil { diff --git a/pkg/cli/websites_wizard_test.go b/pkg/cli/websites_wizard_test.go index 0642dad..9fdb65f 100644 --- a/pkg/cli/websites_wizard_test.go +++ b/pkg/cli/websites_wizard_test.go @@ -28,7 +28,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -59,7 +59,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -88,7 +88,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -115,7 +115,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -141,7 +141,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -163,7 +163,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) _, err := w.Run(context.Background()) @@ -184,7 +184,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) _, err := w.Run(context.Background()) @@ -208,7 +208,7 @@ func TestWebsitesWizard_Run(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -238,7 +238,7 @@ func TestWebsitesWizard_Run(t *testing.T) { }, } - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) _, err := w.Run(context.Background()) @@ -271,7 +271,7 @@ func TestWebsitesWizard_Run(t *testing.T) { }, } - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -307,7 +307,7 @@ func TestWebsitesWizard_Run(t *testing.T) { }, } - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -337,7 +337,7 @@ func TestWebsitesWizard_Run(t *testing.T) { }, } - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) result, err := w.Run(context.Background()) @@ -359,7 +359,7 @@ func TestWebsitesWizard_Accessors(t *testing.T) { cfgMgr.EXPECT().Config().Return(cfg).Maybe() mockUI := NewMockWebsitesUI() - output := NewOutputFormatter(false, false, false, false) + output := newTestOutput() mockWebsitesSvc := &mockWebsitesServiceForCLI{} w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, output) @@ -380,7 +380,7 @@ func TestWebsitesWizard_Setters(t *testing.T) { mockUI := NewMockWebsitesUI() mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) require.Equal(t, "", w.CID()) require.Equal(t, "", w.Domain()) @@ -423,7 +423,7 @@ func TestWebsitesWizard_StepCalls(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) _, err := w.Run(context.Background()) require.NoError(t, err) @@ -476,7 +476,7 @@ func TestWebsitesWizard_UIError(t *testing.T) { mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mockUI, newTestOutput()) _, err := w.Run(context.Background()) require.Error(t, err) @@ -497,7 +497,7 @@ func TestMockWebsitesUI(t *testing.T) { cfgMgr.EXPECT().Config().Return(cfg).Maybe() mockWebsitesSvc := &mockWebsitesServiceForCLI{} - w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mock, NewOutputFormatter(false, false, false, false)) + w := NewWebsitesWizard(mockWebsitesSvc, cfgMgr, mock, newTestOutput()) _ = mock.ExecuteAuthCheckStep(context.Background(), w) diff --git a/pkg/cli/websites_wizard_ui.go b/pkg/cli/websites_wizard_ui.go index bda10df..c3d83e9 100644 --- a/pkg/cli/websites_wizard_ui.go +++ b/pkg/cli/websites_wizard_ui.go @@ -9,6 +9,9 @@ import ( // WebsitesUI defines the interface for websites wizard UI interactions. type WebsitesUI interface { wizard.UI + SelectPrompter + ContinuePrompter + Spinner // Step execution ExecuteAuthCheckStep(ctx context.Context, w *WebsitesWizard) error From 799604f0d84302120fd85e3c88ec274f30804ed2 Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Thu, 11 Jun 2026 17:01:55 +0000 Subject: [PATCH 2/2] fix(cli): add nil guards and propagate --secure flag to all services - Add s.service nil guards to ipnsService and websitesService methods, preventing panics when ipfs.NewClient() fails (matching dnsServiceCLI) - Propagate --secure flag through newAuthenticated*Service constructors for DNS, Websites, and IPNS services - Pass secure bool to unpin, unpinAll, status, and download handlers - Use GetIPFSEndpointWithSecure(secure) instead of GetIPFSEndpoint() or GetIPFSEndpointSecure() in all handler and factory code - Fix dry-run displays to respect the secure flag - Update PinningServiceFactory and all callers to accept secure bool --- pkg/cli/bench.go | 2 +- pkg/cli/dns.go | 60 ++++++++++---------- pkg/cli/dns_service.go | 10 ++-- pkg/cli/dns_test.go | 96 ++++++++++++++++---------------- pkg/cli/download.go | 23 ++++---- pkg/cli/download_service.go | 6 ++ pkg/cli/download_test.go | 26 ++++----- pkg/cli/ipns.go | 42 +++++++------- pkg/cli/ipns_service.go | 31 +++++++++-- pkg/cli/ipns_test.go | 56 +++++++++---------- pkg/cli/list.go | 2 +- pkg/cli/list_test.go | 6 +- pkg/cli/pin.go | 8 +-- pkg/cli/pin_test.go | 8 +-- pkg/cli/pinning_service.go | 2 +- pkg/cli/pins_add.go | 2 +- pkg/cli/pins_add_test.go | 12 ++-- pkg/cli/pins_rm.go | 5 +- pkg/cli/pins_status.go | 3 +- pkg/cli/pins_update.go | 2 +- pkg/cli/pins_update_test.go | 20 +++---- pkg/cli/status.go | 9 +-- pkg/cli/status_test.go | 4 +- pkg/cli/unpin.go | 11 ++-- pkg/cli/unpin_all.go | 11 ++-- pkg/cli/unpin_all_test.go | 16 +++--- pkg/cli/unpin_test.go | 8 +-- pkg/cli/upload.go | 6 +- pkg/cli/upload_test.go | 2 +- pkg/cli/websites.go | 48 ++++++++-------- pkg/cli/websites_handler_test.go | 90 +++++++++++++++--------------- pkg/cli/websites_service.go | 34 +++++++++-- pkg/cli/websites_ssl.go | 6 +- pkg/cli/websites_wizard.go | 3 +- 34 files changed, 365 insertions(+), 305 deletions(-) diff --git a/pkg/cli/bench.go b/pkg/cli/bench.go index 7a97024..657266c 100644 --- a/pkg/cli/bench.go +++ b/pkg/cli/bench.go @@ -261,7 +261,7 @@ func bench(ctx context.Context, cmd interface { RenderDryRun(output, DryRunPreview{ Operation: "benchmark", - Endpoint: cfgMgr.Config().GetIPFSEndpointSecure(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Options: options, }) return nil diff --git a/pkg/cli/dns.go b/pkg/cli/dns.go index 7ec397c..7170414 100644 --- a/pkg/cli/dns.go +++ b/pkg/cli/dns.go @@ -84,7 +84,7 @@ Examples: pinner dns zones list pinner dns zones list --json`, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsZonesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsZonesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -104,7 +104,7 @@ Examples: NameserversFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsZonesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsZonesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -120,7 +120,7 @@ Examples: pinner dns zones get example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsZonesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsZonesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -135,7 +135,7 @@ Examples: pinner dns zones delete example.com`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsZonesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsZonesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -152,7 +152,7 @@ Examples: pinner dns zones validate example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsZonesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsZonesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -192,7 +192,7 @@ Examples: pinner dns records list example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsRecordsList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsRecordsList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -216,7 +216,7 @@ Examples: DisabledFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsRecordsCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsRecordsCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -236,7 +236,7 @@ Examples: RequiredTypeFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsRecordsGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsRecordsGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -260,7 +260,7 @@ Examples: DisabledFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsRecordsUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsRecordsUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -279,15 +279,15 @@ Examples: RequiredTypeFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return dnsRecordsDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return dnsRecordsDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } // ===== HANDLERS ===== -func dnsZonesList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) +func dnsZonesList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -325,7 +325,7 @@ func dnsZonesList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgM return nil } -func dnsZonesCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsZonesCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { domain := cmd.String(FlagDomain) if err := validateDomain(domain); err != nil { @@ -338,7 +338,7 @@ func dnsZonesCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cf nameservers = parseCommaSeparated(nameserversStr) } - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -367,7 +367,7 @@ func dnsZonesCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cf return nil } -func dnsZonesGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsZonesGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -375,7 +375,7 @@ func dnsZonesGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMg arg := args.First() - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -408,7 +408,7 @@ func dnsZonesGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMg return nil } -func dnsZonesDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsZonesDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -416,7 +416,7 @@ func dnsZonesDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cf arg := args.First() - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -435,7 +435,7 @@ func dnsZonesDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cf return nil } -func dnsZonesValidate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsZonesValidate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -443,7 +443,7 @@ func dnsZonesValidate(ctx context.Context, cmd dnsCommandGetter, output Output, arg := args.First() - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -494,7 +494,7 @@ func dnsZonesValidate(ctx context.Context, cmd dnsCommandGetter, output Output, return nil } -func dnsRecordsList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsRecordsList(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -502,7 +502,7 @@ func dnsRecordsList(ctx context.Context, cmd dnsCommandGetter, output Output, cf arg := args.First() - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -546,7 +546,7 @@ func dnsRecordsList(ctx context.Context, cmd dnsCommandGetter, output Output, cf return nil } -func dnsRecordsCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsRecordsCreate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -576,7 +576,7 @@ func dnsRecordsCreate(ctx context.Context, cmd dnsCommandGetter, output Output, Disabled: &disabled, } - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -611,7 +611,7 @@ func dnsRecordsCreate(ctx context.Context, cmd dnsCommandGetter, output Output, return nil } -func dnsRecordsGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsRecordsGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -621,7 +621,7 @@ func dnsRecordsGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfg name := cmd.String(FlagName) recordType := cmd.String(FlagType) - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -657,7 +657,7 @@ func dnsRecordsGet(ctx context.Context, cmd dnsCommandGetter, output Output, cfg return nil } -func dnsRecordsUpdate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsRecordsUpdate(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -687,7 +687,7 @@ func dnsRecordsUpdate(ctx context.Context, cmd dnsCommandGetter, output Output, Disabled: &disabled, } - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -722,7 +722,7 @@ func dnsRecordsUpdate(ctx context.Context, cmd dnsCommandGetter, output Output, return nil } -func dnsRecordsDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func dnsRecordsDelete(ctx context.Context, cmd dnsCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain or zone ID is required") @@ -732,7 +732,7 @@ func dnsRecordsDelete(ctx context.Context, cmd dnsCommandGetter, output Output, name := cmd.String(FlagName) recordType := cmd.String(FlagType) - dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken) + dnsService, err := newAuthenticatedDNSService(cfgMgr, output, authToken, secure) if err != nil { return err } diff --git a/pkg/cli/dns_service.go b/pkg/cli/dns_service.go index acc5bdc..56bdef4 100644 --- a/pkg/cli/dns_service.go +++ b/pkg/cli/dns_service.go @@ -190,21 +190,21 @@ func (s *dnsServiceCLI) DeleteRecord(ctx context.Context, id string, name string return s.service.DeleteRecord(ctx, id, name, recordType) } -type dnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, opts ...DNSServiceOption) DNSService +type dnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, secure bool, opts ...DNSServiceOption) DNSService var dnsServiceFactory dnsServiceFactoryFunc = defaultDNSServiceFactory -func defaultDNSServiceFactory(cfgMgr config.Manager, output Output, opts ...DNSServiceOption) DNSService { - apiEndpoint := cfgMgr.Config().GetIPFSEndpointSecure() +func defaultDNSServiceFactory(cfgMgr config.Manager, output Output, secure bool, opts ...DNSServiceOption) DNSService { + apiEndpoint := cfgMgr.Config().GetIPFSEndpointWithSecure(secure) return NewDNSService(cfgMgr, output, apiEndpoint, opts...) } -func newAuthenticatedDNSService(cfgMgr config.Manager, output Output, authToken string) (DNSService, error) { +func newAuthenticatedDNSService(cfgMgr config.Manager, output Output, authToken string, secure bool) (DNSService, error) { var svcOpts []DNSServiceOption if authToken != "" { svcOpts = append(svcOpts, WithDNSAuthToken(authToken)) } - dnsService := dnsServiceFactory(cfgMgr, output, svcOpts...) + dnsService := dnsServiceFactory(cfgMgr, output, secure, svcOpts...) if err := dnsService.RequireAuthenticated(); err != nil { return nil, err } diff --git a/pkg/cli/dns_test.go b/pkg/cli/dns_test.go index 50b3b8b..a5b4076 100644 --- a/pkg/cli/dns_test.go +++ b/pkg/cli/dns_test.go @@ -113,7 +113,7 @@ func setupDNSHandlerTest(t *testing.T) (*mockDNSServiceForCLI, *configmocks.Mock origFactory := dnsServiceFactory t.Cleanup(func() { dnsServiceFactory = origFactory }) - dnsServiceFactory = func(config.Manager, Output, ...DNSServiceOption) DNSService { + dnsServiceFactory = func(config.Manager, Output, bool, ...DNSServiceOption) DNSService { return mockSvc } @@ -134,7 +134,7 @@ func TestDnsZonesList_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -146,7 +146,7 @@ func TestDnsZonesList_Empty(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -158,7 +158,7 @@ func TestDnsZonesList_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to list zones") } @@ -169,7 +169,7 @@ func TestDnsZonesList_Unauthenticated(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "") + err := dnsZonesList(context.Background(), cmd, output, cfgMgr, "", true) require.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) } @@ -187,7 +187,7 @@ func TestDnsZonesCreate_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagDomain, "example.com") - err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -204,7 +204,7 @@ func TestDnsZonesCreate_WithNameservers(t *testing.T) { cmd := newMockCommand(). withString(FlagDomain, "example.com"). withString(FlagNameservers, "ns1.example.com,ns2.example.com") - err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -213,7 +213,7 @@ func TestDnsZonesCreate_EmptyDomain(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagDomain, "") - err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain cannot be empty") } @@ -223,7 +223,7 @@ func TestDnsZonesCreate_InvalidDomain(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagDomain, "a..b") - err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "invalid domain format") } @@ -236,7 +236,7 @@ func TestDnsZonesCreate_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagDomain, "example.com") - err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to create zone") } @@ -256,7 +256,7 @@ func TestDnsZonesGet_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -270,7 +270,7 @@ func TestDnsZonesGet_NumericID(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("42") - err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -279,7 +279,7 @@ func TestDnsZonesGet_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -292,7 +292,7 @@ func TestDnsZonesGet_ZoneNotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("nonexistent.com") - err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "zone not found") } @@ -305,7 +305,7 @@ func TestDnsZonesGet_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "server error") } @@ -325,7 +325,7 @@ func TestDnsZonesDelete_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -338,7 +338,7 @@ func TestDnsZonesDelete_NumericID(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("42") - err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -347,7 +347,7 @@ func TestDnsZonesDelete_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -360,7 +360,7 @@ func TestDnsZonesDelete_ZoneNotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("nonexistent.com") - err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "zone not found") } @@ -373,7 +373,7 @@ func TestDnsZonesDelete_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to delete zone") } @@ -396,7 +396,7 @@ func TestDnsZonesValidate_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -416,7 +416,7 @@ func TestDnsZonesValidate_ValidationFailure(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) // validation failure is not an error, it's a result } @@ -425,7 +425,7 @@ func TestDnsZonesValidate_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -442,7 +442,7 @@ func TestDnsZonesValidate_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsZonesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to validate zone") } @@ -461,7 +461,7 @@ func TestDnsRecordsList_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -473,7 +473,7 @@ func TestDnsRecordsList_Empty(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -482,7 +482,7 @@ func TestDnsRecordsList_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -495,7 +495,7 @@ func TestDnsRecordsList_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to list records") } @@ -513,7 +513,7 @@ func TestDnsRecordsList_DomainArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -535,7 +535,7 @@ func TestDnsRecordsCreate_Success(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "example.com") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -553,7 +553,7 @@ func TestDnsRecordsCreate_ARecord(t *testing.T) { withString(FlagName, "@"). withString(FlagType, "A"). withString(FlagContent, "1.2.3.4") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -565,7 +565,7 @@ func TestDnsRecordsCreate_MissingArg(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "example.com") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -579,7 +579,7 @@ func TestDnsRecordsCreate_InvalidRecordType(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "INVALID"). withString(FlagContent, "example.com") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported record type") } @@ -593,7 +593,7 @@ func TestDnsRecordsCreate_InvalidARecordContent(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "A"). withString(FlagContent, "not-an-ip") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "invalid IPv4 address") } @@ -610,7 +610,7 @@ func TestDnsRecordsCreate_ServiceError(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "example.com") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to create record") } @@ -629,7 +629,7 @@ func TestDnsRecordsCreate_DefaultTTL(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "example.com") - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -648,7 +648,7 @@ func TestDnsRecordsCreate_CustomTTL(t *testing.T) { withString(FlagType, "CNAME"). withString(FlagContent, "example.com"). withUint(FlagTTL, 7200) - err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -668,7 +668,7 @@ func TestDnsRecordsGet_Success(t *testing.T) { withArgs("1"). withString(FlagName, "www"). withString(FlagType, "CNAME") - err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -679,7 +679,7 @@ func TestDnsRecordsGet_MissingArg(t *testing.T) { cmd := newMockCommand(). withString(FlagName, "www"). withString(FlagType, "CNAME") - err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -695,7 +695,7 @@ func TestDnsRecordsGet_NotFound(t *testing.T) { withArgs("1"). withString(FlagName, "nonexistent"). withString(FlagType, "A") - err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to get record") } @@ -718,7 +718,7 @@ func TestDnsRecordsUpdate_Success(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "new.example.com") - err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -730,7 +730,7 @@ func TestDnsRecordsUpdate_MissingArg(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "new.example.com") - err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -744,7 +744,7 @@ func TestDnsRecordsUpdate_InvalidRecordType(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "BOGUS"). withString(FlagContent, "example.com") - err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported record type") } @@ -761,7 +761,7 @@ func TestDnsRecordsUpdate_ServiceError(t *testing.T) { withString(FlagName, "www"). withString(FlagType, "CNAME"). withString(FlagContent, "new.example.com") - err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to update record") } @@ -779,7 +779,7 @@ func TestDnsRecordsUpdate_ARecordWithValidIP(t *testing.T) { withString(FlagName, "@"). withString(FlagType, "A"). withString(FlagContent, "5.6.7.8") - err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -799,7 +799,7 @@ func TestDnsRecordsDelete_Success(t *testing.T) { withArgs("1"). withString(FlagName, "www"). withString(FlagType, "CNAME") - err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -810,7 +810,7 @@ func TestDnsRecordsDelete_MissingArg(t *testing.T) { cmd := newMockCommand(). withString(FlagName, "www"). withString(FlagType, "CNAME") - err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain or zone ID is required") } @@ -826,7 +826,7 @@ func TestDnsRecordsDelete_ServiceError(t *testing.T) { withArgs("1"). withString(FlagName, "www"). withString(FlagType, "CNAME") - err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := dnsRecordsDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to delete record") } diff --git a/pkg/cli/download.go b/pkg/cli/download.go index 2f5887d..def00d6 100644 --- a/pkg/cli/download.go +++ b/pkg/cli/download.go @@ -61,7 +61,8 @@ The output includes: return err } authToken := GetAuthToken(c, cfgMgr) - return handleDownload(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return handleDownload(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultDownloadServiceFactory) }, } } @@ -92,7 +93,8 @@ Use --verbose or redirect stderr for progress info.`, return err } authToken := GetAuthToken(c, cfgMgr) - return handleCat(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return handleCat(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultDownloadServiceFactory) }, } } @@ -124,16 +126,17 @@ Examples: return err } authToken := GetAuthToken(c, cfgMgr) - return handleLs(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultDownloadServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return handleLs(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultDownloadServiceFactory) }, } } -func handleDownload(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { +func handleDownload(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption - svcOpts = append(svcOpts, WithDownloadAuthService(authService)) + svcOpts = append(svcOpts, WithDownloadAuthService(authService), WithDownloadIPFSEndpoint(cfgMgr.Config().GetIPFSEndpointWithSecure(secure))) if authToken != "" { svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) @@ -160,7 +163,7 @@ func handleDownload(ctx context.Context, cmd argsFlagGetter, output Output, cfgM RenderDryRun(output, DryRunPreview{ Operation: "download operation", - Endpoint: cfgMgr.Config().GetIPFSEndpointSecure(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Options: options, }) return nil @@ -208,11 +211,11 @@ func handleDownload(ctx context.Context, cmd argsFlagGetter, output Output, cfgM return nil } -func handleCat(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { +func handleCat(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption - svcOpts = append(svcOpts, WithDownloadAuthService(authService)) + svcOpts = append(svcOpts, WithDownloadAuthService(authService), WithDownloadIPFSEndpoint(cfgMgr.Config().GetIPFSEndpointWithSecure(secure))) if authToken != "" { svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) @@ -239,11 +242,11 @@ func handleCat(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr co return err } -func handleLs(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, downloadServiceFactory DownloadServiceFactory) error { +func handleLs(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool, downloadServiceFactory DownloadServiceFactory) error { authService := NewAuthService(cfgMgr, output, cfgMgr.Config().GetAccountEndpointSecure()) var svcOpts []DownloadServiceOption - svcOpts = append(svcOpts, WithDownloadAuthService(authService)) + svcOpts = append(svcOpts, WithDownloadAuthService(authService), WithDownloadIPFSEndpoint(cfgMgr.Config().GetIPFSEndpointWithSecure(secure))) if authToken != "" { svcOpts = append(svcOpts, WithDownloadAuthToken(authToken)) diff --git a/pkg/cli/download_service.go b/pkg/cli/download_service.go index 0fe27e0..1468925 100644 --- a/pkg/cli/download_service.go +++ b/pkg/cli/download_service.go @@ -51,6 +51,12 @@ func WithDownloadAuthToken(token string) DownloadServiceOption { } } +func WithDownloadIPFSEndpoint(endpoint string) DownloadServiceOption { + return func(s *DownloadServiceDefault) { + s.ipfsEndpoint = endpoint + } +} + func defaultDownloadServiceFactory(cfgMgr config.Manager, output Output, opts ...DownloadServiceOption) DownloadService { return NewDownloadService(cfgMgr, output, opts...) } diff --git a/pkg/cli/download_test.go b/pkg/cli/download_test.go index f5bbefa..ef27ce2 100644 --- a/pkg/cli/download_test.go +++ b/pkg/cli/download_test.go @@ -75,7 +75,7 @@ func TestHandleDownload_DryRun(t *testing.T) { return NewMockDownloadService(t) } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) if tt.wantErr { require.Error(t, err) @@ -97,7 +97,7 @@ func TestHandleDownload_RequiresCID(t *testing.T) { return NewMockDownloadService(t) } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } @@ -121,7 +121,7 @@ func TestHandleDownload_Success(t *testing.T) { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } @@ -138,7 +138,7 @@ func TestHandleDownload_NotAuthenticated(t *testing.T) { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) } @@ -156,7 +156,7 @@ func TestHandleDownload_FileExists_NoForce(t *testing.T) { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) assert.Contains(t, err.Error(), "file already exists") } @@ -183,7 +183,7 @@ func TestHandleDownload_WithForce(t *testing.T) { return service } - err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleDownload(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } @@ -201,7 +201,7 @@ func TestHandleCat_Success(t *testing.T) { return service } - err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } @@ -216,7 +216,7 @@ func TestHandleCat_RequiresCID(t *testing.T) { return NewMockDownloadService(t) } - err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } @@ -234,7 +234,7 @@ func TestHandleCat_NotAuthenticated(t *testing.T) { return service } - err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleCat(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) } @@ -255,7 +255,7 @@ func TestHandleLs_Success(t *testing.T) { return service } - err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } @@ -273,7 +273,7 @@ func TestHandleLs_EmptyDirectory(t *testing.T) { return service } - err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } @@ -288,7 +288,7 @@ func TestHandleLs_RequiresCID(t *testing.T) { return NewMockDownloadService(t) } - err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.Error(t, err) assert.True(t, errors.Is(err, ErrCIDRequired)) } @@ -311,7 +311,7 @@ func TestHandleLs_WithLimit(t *testing.T) { return service } - err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", downloadServiceFactory) + err := handleLs(context.Background(), cmd, output, cfgMgr, "test-token", true, DownloadServiceFactory(downloadServiceFactory)) require.NoError(t, err) } diff --git a/pkg/cli/ipns.go b/pkg/cli/ipns.go index 9149bdb..3e7ccbf 100644 --- a/pkg/cli/ipns.go +++ b/pkg/cli/ipns.go @@ -78,7 +78,7 @@ Examples: pinner ipns keys list pinner ipns keys list --json`, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsKeysList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsKeysList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -102,7 +102,7 @@ Examples: }, }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsKeysCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsKeysCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -119,7 +119,7 @@ Examples: pinner ipns keys get my-key --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsKeysGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsKeysGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -135,7 +135,7 @@ Examples: pinner ipns keys delete 1`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsKeysDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsKeysDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -169,7 +169,7 @@ Examples: WaitFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsPublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsPublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -189,7 +189,7 @@ Examples: pinner ipns republish my-key --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsRepublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsRepublish(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -206,7 +206,7 @@ Examples: pinner ipns resolve k51qzi5uqu5djx... --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return ipnsResolve(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return ipnsResolve(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -215,11 +215,11 @@ func resolveIPNSKeyArg(ctx context.Context, ipnsService IPNSService, arg string) return resolveIPNSKeyIDToString(ctx, ipnsService, arg) } -func ipnsKeysList(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsKeysList(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -265,11 +265,11 @@ func ipnsKeysList(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr return nil } -func ipnsKeysCreate(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsKeysCreate(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -317,11 +317,11 @@ func ipnsKeysCreate(ctx context.Context, cmd argsFlagGetter, output Output, cfgM return nil } -func ipnsKeysGet(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsKeysGet(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -369,11 +369,11 @@ func ipnsKeysGet(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr return nil } -func ipnsKeysDelete(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsKeysDelete(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -406,11 +406,11 @@ func ipnsKeysDelete(ctx context.Context, cmd argsFlagGetter, output Output, cfgM return nil } -func ipnsPublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsPublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -459,11 +459,11 @@ func ipnsPublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr return nil } -func ipnsRepublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsRepublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -489,11 +489,11 @@ func ipnsRepublish(ctx context.Context, cmd argsFlagGetter, output Output, cfgMg return nil } -func ipnsResolve(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string) error { +func ipnsResolve(ctx context.Context, cmd argsFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken) + ipnsService, err := newAuthenticatedIPNSService(cfgMgr, output, authToken, secure) if err != nil { return err } diff --git a/pkg/cli/ipns_service.go b/pkg/cli/ipns_service.go index 8c7b757..d47e176 100644 --- a/pkg/cli/ipns_service.go +++ b/pkg/cli/ipns_service.go @@ -45,22 +45,22 @@ func WithIPNSClient(client *ipfs.Client) IPNSServiceOption { type IPNSServiceFactory func(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService -func defaultIPNSServiceFactory(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService { - return NewIPNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure(), opts...) +func defaultIPNSServiceFactory(cfgMgr config.Manager, output Output, secure bool, opts ...IPNSServiceOption) IPNSService { + return NewIPNSService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), opts...) } -type ipnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, opts ...IPNSServiceOption) IPNSService +type ipnsServiceFactoryFunc func(cfgMgr config.Manager, output Output, secure bool, opts ...IPNSServiceOption) IPNSService var ipnsServiceFactory ipnsServiceFactoryFunc = defaultIPNSServiceFactory // newAuthenticatedIPNSService creates an IPNSService with authentication. // It returns an error if the user is not authenticated. -func newAuthenticatedIPNSService(cfgMgr config.Manager, output Output, authToken string) (IPNSService, error) { +func newAuthenticatedIPNSService(cfgMgr config.Manager, output Output, authToken string, secure bool) (IPNSService, error) { var svcOpts []IPNSServiceOption if authToken != "" { svcOpts = append(svcOpts, WithIPNSAuthToken(authToken)) } - ipnsService := ipnsServiceFactory(cfgMgr, output, svcOpts...) + ipnsService := ipnsServiceFactory(cfgMgr, output, secure, svcOpts...) if err := ipnsService.RequireAuthenticated(); err != nil { return nil, err } @@ -98,6 +98,9 @@ func (s *ipnsService) ListKeys(ctx context.Context) ([]ipfs.IPNSKeyResponse, err if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } return s.service.ListKeys(ctx) } @@ -105,6 +108,9 @@ func (s *ipnsService) CreateKey(ctx context.Context, name string, key *string) ( if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } if key != nil { return s.service.CreateKey(ctx, name, ipfs.WithIPNSKey(*key)) } @@ -115,6 +121,9 @@ func (s *ipnsService) GetKey(ctx context.Context, id string) (*ipfs.IPNSKeyRespo if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } return s.service.GetKey(ctx, id) } @@ -122,6 +131,9 @@ func (s *ipnsService) DeleteKey(ctx context.Context, id string) error { if err := s.RequireAuthenticated(); err != nil { return err } + if s.service == nil { + return ErrServiceUnavailable + } return s.service.DeleteKey(ctx, id) } @@ -129,6 +141,9 @@ func (s *ipnsService) Publish(ctx context.Context, cid string, keyName string, t if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } keyID, err := resolveIPNSKeyID(ctx, s, keyName) if err != nil { @@ -145,6 +160,9 @@ func (s *ipnsService) Republish(ctx context.Context, keyName string) (*ipfs.IPNS if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } keyID, err := resolveIPNSKeyID(ctx, s, keyName) if err != nil { @@ -158,6 +176,9 @@ func (s *ipnsService) Resolve(ctx context.Context, name string) (*ipfs.IPNSResol if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } return s.service.Resolve(ctx, name) } diff --git a/pkg/cli/ipns_test.go b/pkg/cli/ipns_test.go index 54f1c43..25a2c4b 100644 --- a/pkg/cli/ipns_test.go +++ b/pkg/cli/ipns_test.go @@ -25,7 +25,7 @@ func setupIPNSHandlerTest(t *testing.T) (*mockIPNSServiceForCLI, *configmocks.Mo origFactory := ipnsServiceFactory t.Cleanup(func() { ipnsServiceFactory = origFactory }) - ipnsServiceFactory = func(config.Manager, Output, ...IPNSServiceOption) IPNSService { + ipnsServiceFactory = func(config.Manager, Output, bool, ...IPNSServiceOption) IPNSService { return mockSvc } @@ -46,7 +46,7 @@ func TestIpnsKeysList_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -58,7 +58,7 @@ func TestIpnsKeysList_Empty(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -70,7 +70,7 @@ func TestIpnsKeysList_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "server error") } @@ -81,7 +81,7 @@ func TestIpnsKeysList_Unauthenticated(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "") + err := ipnsKeysList(context.Background(), cmd, output, cfgMgr, "", true) require.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) } @@ -99,7 +99,7 @@ func TestIpnsKeysCreate_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagName, "my-key") - err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -115,7 +115,7 @@ func TestIpnsKeysCreate_WithKeyImport(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagName, "imported-key").withString(FlagKey, "base64keydata") - err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -124,7 +124,7 @@ func TestIpnsKeysCreate_MissingName(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagName, "") - err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "name is required") } @@ -137,7 +137,7 @@ func TestIpnsKeysCreate_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagName, "my-key") - err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "conflict") } @@ -154,7 +154,7 @@ func TestIpnsKeysGet_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -171,7 +171,7 @@ func TestIpnsKeysGet_ByName(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("my-key") - err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -180,7 +180,7 @@ func TestIpnsKeysGet_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key name or ID is required") } @@ -193,7 +193,7 @@ func TestIpnsKeysGet_NotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("999") - err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key not found") } @@ -209,7 +209,7 @@ func TestIpnsKeysDelete_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -226,7 +226,7 @@ func TestIpnsKeysDelete_ByName(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("my-key") - err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -235,7 +235,7 @@ func TestIpnsKeysDelete_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key name or ID is required") } @@ -248,7 +248,7 @@ func TestIpnsKeysDelete_NotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("999") - err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsKeysDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key not found") } @@ -267,7 +267,7 @@ func TestIpnsPublish_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("QmXxx").withString("key-name", "1") - err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -282,7 +282,7 @@ func TestIpnsPublish_WithTTL(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("QmYyy").withString("key-name", "1").withString("ttl", "24h") - err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -291,7 +291,7 @@ func TestIpnsPublish_MissingCID(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString("key-name", "my-key") - err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "CID is required") } @@ -301,7 +301,7 @@ func TestIpnsPublish_MissingKeyName(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("QmXxx").withString("key-name", "") - err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key-name is required") } @@ -314,7 +314,7 @@ func TestIpnsPublish_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("invalid").withString("key-name", "1") - err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsPublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "invalid CID format") } @@ -330,7 +330,7 @@ func TestIpnsRepublish_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("my-key") - err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -339,7 +339,7 @@ func TestIpnsRepublish_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "key name or ID is required") } @@ -352,7 +352,7 @@ func TestIpnsRepublish_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("my-key") - err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsRepublish(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "republish failed") } @@ -374,7 +374,7 @@ func TestIpnsResolve_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("k51qzi5uqu5djx123") - err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -383,7 +383,7 @@ func TestIpnsResolve_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "IPNS name is required") } @@ -396,7 +396,7 @@ func TestIpnsResolve_NotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("k51qzi5uqu5djx999") - err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token") + err := ipnsResolve(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "IPNS name not found") } diff --git a/pkg/cli/list.go b/pkg/cli/list.go index 66265e4..d91dd77 100644 --- a/pkg/cli/list.go +++ b/pkg/cli/list.go @@ -50,7 +50,7 @@ func list(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr conf if authToken != "" { pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { diff --git a/pkg/cli/list_test.go b/pkg/cli/list_test.go index 93c758c..7a80a33 100644 --- a/pkg/cli/list_test.go +++ b/pkg/cli/list_test.go @@ -124,7 +124,7 @@ func TestList(t *testing.T) { withString(FlagName, tt.nameFilter). withInt(FlagLimit, tt.limit) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -186,7 +186,7 @@ func TestList_WithStatusFilter(t *testing.T) { withInt(FlagLimit, 10). withString(FlagStatus, "pinned") - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -203,7 +203,7 @@ func TestList_RequireAuthFails(t *testing.T) { cmd := newMockCommand() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } diff --git a/pkg/cli/pin.go b/pkg/cli/pin.go index 386ffde..e6ac6fa 100644 --- a/pkg/cli/pin.go +++ b/pkg/cli/pin.go @@ -56,7 +56,7 @@ func pin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Ma if authToken != "" { pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { @@ -112,7 +112,7 @@ func pin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Ma RenderDryRun(output, DryRunPreview{ Operation: "pinning operations", - Endpoint: cfgMgr.Config().GetIPFSEndpointSecure(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Items: cids, ItemLabel: "CIDs to pin", Options: options, @@ -152,6 +152,6 @@ func pin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Ma return cids, nil } -func defaultPinningServiceFactory(cfgMgr config.Manager, output Output) PinningService { - return NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure()) +func defaultPinningServiceFactory(cfgMgr config.Manager, output Output, secure bool) PinningService { + return NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure)) } diff --git a/pkg/cli/pin_test.go b/pkg/cli/pin_test.go index ea071b6..9c34917 100644 --- a/pkg/cli/pin_test.go +++ b/pkg/cli/pin_test.go @@ -90,7 +90,7 @@ func TestPinDryRun(t *testing.T) { } - pinningServiceFactory := func(cfgMgr config.Manager, output Output) PinningService { + pinningServiceFactory := func(cfgMgr config.Manager, output Output, _ bool) PinningService { return service } @@ -210,7 +210,7 @@ func TestPin(t *testing.T) { withString(FlagName, tt.nameFlag). withBool(FlagNoWait, tt.noWaitFlag) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -287,7 +287,7 @@ func TestPinBatch(t *testing.T) { withBool(FlagContinue, tt.continueOn) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -356,7 +356,7 @@ func TestDefaultPinningServiceFactory(t *testing.T) { output := newTestOutput() - service := defaultPinningServiceFactory(cfgMgr, output) + service := defaultPinningServiceFactory(cfgMgr, output, true) assert.IsType(t, &PinningServiceDefault{}, service) ps := service.(*PinningServiceDefault) diff --git a/pkg/cli/pinning_service.go b/pkg/cli/pinning_service.go index edf2bdd..662ca96 100644 --- a/pkg/cli/pinning_service.go +++ b/pkg/cli/pinning_service.go @@ -131,4 +131,4 @@ type StatusService interface { } // PinningServiceFactory creates a PinningService with dependencies -type PinningServiceFactory func(cfgMgr config.Manager, output Output) PinningService +type PinningServiceFactory func(cfgMgr config.Manager, output Output, secure bool) PinningService diff --git a/pkg/cli/pins_add.go b/pkg/cli/pins_add.go index 2f1402b..8a3dd9d 100644 --- a/pkg/cli/pins_add.go +++ b/pkg/cli/pins_add.go @@ -79,7 +79,7 @@ func pinsAdd(ctx context.Context, cmd interface { if authToken != "" { pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { diff --git a/pkg/cli/pins_add_test.go b/pkg/cli/pins_add_test.go index f5f903d..f55e468 100644 --- a/pkg/cli/pins_add_test.go +++ b/pkg/cli/pins_add_test.go @@ -46,7 +46,7 @@ func TestPinsAdd_DryRun(t *testing.T) { withCID("QmXxx"). withBool(FlagDryRun, true) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -66,7 +66,7 @@ func TestPinsAdd_NoMeta(t *testing.T) { cmd := newMockCommand().withCID("QmXxx") - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -92,7 +92,7 @@ func TestPinsAdd_WithMetadata(t *testing.T) { withCID("QmXxx"). withStringSlice(FlagMeta, []string{"owner=alice"}) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -112,7 +112,7 @@ func TestPinsAdd_PinError(t *testing.T) { cmd := newMockCommand().withCID("QmXxx") - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -137,7 +137,7 @@ func TestPinsAdd_MetadataUpdateError(t *testing.T) { withCID("QmXxx"). withStringSlice(FlagMeta, []string{"owner=alice"}) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -160,7 +160,7 @@ func TestPinsAdd_InvalidMetaFormat(t *testing.T) { withCID("QmXxx"). withStringSlice(FlagMeta, []string{"invalid-no-equals"}) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } diff --git a/pkg/cli/pins_rm.go b/pkg/cli/pins_rm.go index c2a7628..ef36195 100644 --- a/pkg/cli/pins_rm.go +++ b/pkg/cli/pins_rm.go @@ -41,11 +41,12 @@ Examples: return err } authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) if c.Bool(FlagAll) { prompter := &PTermConfirmPrompter{} - return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, prompter) + return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory, prompter) } - return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory) + return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } diff --git a/pkg/cli/pins_status.go b/pkg/cli/pins_status.go index 5e5a853..b2e5ccb 100644 --- a/pkg/cli/pins_status.go +++ b/pkg/cli/pins_status.go @@ -31,7 +31,8 @@ Examples: return err } authToken := GetAuthToken(c, cfgMgr) - return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, defaultStatusServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory, defaultStatusServiceFactory) }, } } diff --git a/pkg/cli/pins_update.go b/pkg/cli/pins_update.go index 77b290f..9972788 100644 --- a/pkg/cli/pins_update.go +++ b/pkg/cli/pins_update.go @@ -49,7 +49,7 @@ func pinsUpdate(ctx context.Context, cmd interface { if authToken != "" { pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { diff --git a/pkg/cli/pins_update_test.go b/pkg/cli/pins_update_test.go index b521067..3a5348e 100644 --- a/pkg/cli/pins_update_test.go +++ b/pkg/cli/pins_update_test.go @@ -21,7 +21,7 @@ func TestPinsUpdate(t *testing.T) { output := newTestOutput() cfgMgr := configmocks.NewMockManager(t) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -39,7 +39,7 @@ func TestPinsUpdate(t *testing.T) { output := newTestOutput() cfgMgr := configmocks.NewMockManager(t) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -59,7 +59,7 @@ func TestPinsUpdate(t *testing.T) { output := newTestOutput() cfgMgr := configmocks.NewMockManager(t) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -81,7 +81,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagName, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -106,7 +106,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagName, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -130,7 +130,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagMeta, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -154,7 +154,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagClearMeta, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -180,7 +180,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagMeta, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -206,7 +206,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagMeta, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } @@ -235,7 +235,7 @@ func TestPinsUpdate(t *testing.T) { withIsSet(FlagDryRun, true) output := newTestOutput() - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } diff --git a/pkg/cli/status.go b/pkg/cli/status.go index 8f6b42a..814f596 100644 --- a/pkg/cli/status.go +++ b/pkg/cli/status.go @@ -50,7 +50,8 @@ Operation status values (shown when pin is not found): return err } authToken := GetAuthToken(c, cfgMgr) - return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, defaultStatusServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return status(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory, defaultStatusServiceFactory) }, } } @@ -62,12 +63,12 @@ func defaultStatusServiceFactory(cfgMgr config.Manager, output Output, pinningSe func status(ctx context.Context, cmd interface { cidGetter Bool(name string) bool -}, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory, statusServiceFactory StatusServiceFactory) error { +}, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory, statusServiceFactory StatusServiceFactory) error { var pinningService PinningService if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { diff --git a/pkg/cli/status_test.go b/pkg/cli/status_test.go index 7e97d8f..58b61ab 100644 --- a/pkg/cli/status_test.go +++ b/pkg/cli/status_test.go @@ -186,7 +186,7 @@ func TestStatus(t *testing.T) { withCID(tt.cid). withBool(FlagWatch, tt.watchFlag) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return pinningSvc } @@ -194,7 +194,7 @@ func TestStatus(t *testing.T) { return statusSvc } - err := status(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, statusServiceFactory) + err := status(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory), statusServiceFactory) if tt.wantErr { require.Error(t, err) diff --git a/pkg/cli/unpin.go b/pkg/cli/unpin.go index 79f3266..6ac5e81 100644 --- a/pkg/cli/unpin.go +++ b/pkg/cli/unpin.go @@ -45,17 +45,18 @@ Examples: return err } authToken := GetAuthToken(c, cfgMgr) - return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory) + secure := GetSecureSetting(c, cfgMgr) + return unpin(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory) }, } } -func unpin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory) error { +func unpin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory) error { var pinningService PinningService if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { @@ -109,7 +110,7 @@ func unpin(ctx context.Context, cmd cidFlagGetter, output Output, cfgMgr config. RenderDryRun(output, DryRunPreview{ Operation: "unpin operations", - Endpoint: cfgMgr.Config().GetIPFSEndpoint(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Items: cids, ItemLabel: "CIDs to unpin", Options: options, diff --git a/pkg/cli/unpin_all.go b/pkg/cli/unpin_all.go index 3d12bdf..5cbff0f 100644 --- a/pkg/cli/unpin_all.go +++ b/pkg/cli/unpin_all.go @@ -49,13 +49,14 @@ Examples: return err } authToken := GetAuthToken(c, cfgMgr) + secure := GetSecureSetting(c, cfgMgr) prompter := &PTermConfirmPrompter{} - return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, defaultPinningServiceFactory, prompter) + return unpinAll(ctx, newCLICommandWrapper(c), output, cfgMgr, authToken, secure, defaultPinningServiceFactory, prompter) }, } } -func unpinAll(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, authToken string, pinningServiceFactory PinningServiceFactory, prompter ConfirmPrompter) error { +func unpinAll(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr config.Manager, authToken string, secure bool, pinningServiceFactory PinningServiceFactory, prompter ConfirmPrompter) error { confirm := cmd.Bool(FlagForce) || cmd.Bool(FlagConfirm) if !confirm { output.Printfln("Use --force to unpin all pins. This is a destructive operation.") @@ -64,9 +65,9 @@ func unpinAll(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr var pinningService PinningService if authToken != "" { - pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpoint(), WithAuthToken(authToken)) + pinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - pinningService = pinningServiceFactory(cfgMgr, output) + pinningService = pinningServiceFactory(cfgMgr, output, secure) } if err := pinningService.RequireAuthenticated(); err != nil { @@ -114,7 +115,7 @@ func unpinAll(ctx context.Context, cmd flagGetterWithInt, output Output, cfgMgr RenderDryRun(output, DryRunPreview{ Operation: fmt.Sprintf("unpin-all (%d pins)", len(pins)), - Endpoint: cfgMgr.Config().GetIPFSEndpoint(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Items: items, ItemLabel: "Request IDs to unpin", Options: options, diff --git a/pkg/cli/unpin_all_test.go b/pkg/cli/unpin_all_test.go index 7d5ae90..6faf9f1 100644 --- a/pkg/cli/unpin_all_test.go +++ b/pkg/cli/unpin_all_test.go @@ -209,12 +209,12 @@ func TestUnpinAll(t *testing.T) { withInt(FlagParallel, tt.parallel). withBool(FlagContinue, tt.continueOn) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } prompter := &MockConfirmPrompter{} - err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory), prompter) if tt.wantErr { require.Error(t, err) @@ -252,12 +252,12 @@ func TestUnpinAllConfirmPrompt(t *testing.T) { withInt(FlagParallel, 0). withBool(FlagContinue, false) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } prompter := &MockConfirmPrompter{ConfirmResult: "wrong"} - err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory), prompter) assert.ErrorIs(t, err, ErrUnpinAllAborted) }) @@ -295,12 +295,12 @@ func TestUnpinAllConfirmPrompt(t *testing.T) { withInt(FlagParallel, 0). withBool(FlagContinue, false) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } prompter := &MockConfirmPrompter{ConfirmResult: "2"} - err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory), prompter) assert.NoError(t, err) }) @@ -328,12 +328,12 @@ func TestUnpinAllConfirmPrompt(t *testing.T) { withInt(FlagParallel, 0). withBool(FlagContinue, false) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } prompter := &MockConfirmPrompter{ConfirmErr: ErrUnpinAllAborted} - err := unpinAll(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory, prompter) + err := unpinAll(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory), prompter) assert.ErrorIs(t, err, ErrUnpinAllAborted) }) diff --git a/pkg/cli/unpin_test.go b/pkg/cli/unpin_test.go index aac1ad3..f059b8a 100644 --- a/pkg/cli/unpin_test.go +++ b/pkg/cli/unpin_test.go @@ -101,11 +101,11 @@ func TestUnpin(t *testing.T) { withBool(FlagForce, tt.confirmFlag). withBool(FlagConfirm, tt.confirmFlag) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } - err := unpin(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory) + err := unpin(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory)) if tt.wantErr { require.Error(t, err) @@ -181,11 +181,11 @@ func TestUnpinBatch(t *testing.T) { withBool(FlagContinue, tt.continueOn) - pinningServiceFactory := func(cm config.Manager, out Output) PinningService { + pinningServiceFactory := func(cm config.Manager, out Output, _ bool) PinningService { return service } - err := unpin(context.Background(), cmd, output, cfgMgr, "", pinningServiceFactory) + err := unpin(context.Background(), cmd, output, cfgMgr, "", true, PinningServiceFactory(pinningServiceFactory)) if tt.wantErr { require.Error(t, err) diff --git a/pkg/cli/upload.go b/pkg/cli/upload.go index 4c108cb..6185d09 100644 --- a/pkg/cli/upload.go +++ b/pkg/cli/upload.go @@ -178,7 +178,7 @@ func handleUpload(ctx context.Context, cmd interface { wait := !cmd.Bool(FlagNoWait) if wait { - svcOpts = append(svcOpts, WithUploadPinningService(pinningServiceFactory(cfgMgr, output))) + svcOpts = append(svcOpts, WithUploadPinningService(pinningServiceFactory(cfgMgr, output, secure))) } uploadService := uploadServiceFactory(cfgMgr, output, svcOpts...) @@ -219,7 +219,7 @@ func handleUpload(ctx context.Context, cmd interface { RenderDryRun(output, DryRunPreview{ Operation: "upload operation", - Endpoint: cfgMgr.Config().GetIPFSEndpointSecure(), + Endpoint: cfgMgr.Config().GetIPFSEndpointWithSecure(secure), Options: options, }) return nil @@ -250,7 +250,7 @@ func handleUpload(ctx context.Context, cmd interface { if authToken != "" { metaPinningService = NewPinningService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), WithAuthToken(authToken)) } else { - metaPinningService = pinningServiceFactory(cfgMgr, output) + metaPinningService = pinningServiceFactory(cfgMgr, output, secure) } slice := metaMapToSlice(meta) if err := metaPinningService.UpdateMetadata(ctx, result.CID, slice, false); err != nil { diff --git a/pkg/cli/upload_test.go b/pkg/cli/upload_test.go index 59c0b2b..de504f5 100644 --- a/pkg/cli/upload_test.go +++ b/pkg/cli/upload_test.go @@ -417,7 +417,7 @@ func TestUploadDryRun(t *testing.T) { return service } - pinningServiceFactory := func(cfgMgr config.Manager, output Output) PinningService { + pinningServiceFactory := func(cfgMgr config.Manager, output Output, _ bool) PinningService { return NewMockPinningService(t) } diff --git a/pkg/cli/websites.go b/pkg/cli/websites.go index 57a1400..f782fca 100644 --- a/pkg/cli/websites.go +++ b/pkg/cli/websites.go @@ -70,7 +70,7 @@ Examples: pinner websites list pinner websites list --json`, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesList(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -94,7 +94,7 @@ Examples: NoDNSHostingFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesCreate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -110,7 +110,7 @@ Examples: pinner websites get example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesGet(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -146,7 +146,7 @@ Examples: NoDNSHostingFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesUpdate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -208,8 +208,8 @@ func printWebsiteUpdateResult(output Output, website *ipfs.WebsiteItem, message } } -func websitesList(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesList(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -302,8 +302,8 @@ func resolveAndGetWebsite(ctx context.Context, websitesService WebsitesService, return websitesService.Get(ctx, id) } -func websitesUpdate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesUpdate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -389,13 +389,13 @@ Examples: CIDFlag(), }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesEnableIPNS(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesEnableIPNS(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } -func websitesEnableIPNS(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesEnableIPNS(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -429,8 +429,8 @@ func websitesEnableIPNS(ctx context.Context, cmd websitesCommandGetter, output O return nil } -func websitesGet(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesGet(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -508,8 +508,8 @@ func websitesGet(ctx context.Context, cmd websitesCommandGetter, output Output, return nil } -func websitesCreate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesCreate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -691,7 +691,7 @@ Examples: pinner websites delete example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesDelete(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } @@ -707,13 +707,13 @@ Examples: pinner websites validate example.com --json`, ArgsUsage: "", Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesValidate(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } -func websitesDelete(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesDelete(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -740,8 +740,8 @@ func websitesDelete(ctx context.Context, cmd websitesCommandGetter, output Outpu return nil } -func websitesValidate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesValidate(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } @@ -891,13 +891,13 @@ Examples: pinner websites config pinner websites config --json`, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesConfig(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesConfig(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } -func websitesConfig(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) +func websitesConfig(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } diff --git a/pkg/cli/websites_handler_test.go b/pkg/cli/websites_handler_test.go index 610a448..11e751b 100644 --- a/pkg/cli/websites_handler_test.go +++ b/pkg/cli/websites_handler_test.go @@ -118,7 +118,7 @@ func setupWebsitesHandlerTest(t *testing.T) (*mockWebsitesHandlerService, *confi origFactory := websitesServiceFactory t.Cleanup(func() { websitesServiceFactory = origFactory }) - websitesServiceFactory = func(config.Manager, Output, ...WebsitesServiceOption) WebsitesService { + websitesServiceFactory = func(config.Manager, Output, bool, ...WebsitesServiceOption) WebsitesService { return mockSvc } @@ -139,7 +139,7 @@ func TestWebsitesListHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -151,7 +151,7 @@ func TestWebsitesListHandler_Empty(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -163,7 +163,7 @@ func TestWebsitesListHandler_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesList(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "server error") } @@ -174,7 +174,7 @@ func TestWebsitesListHandler_Unauthenticated(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesList(context.Background(), cmd, output, cfgMgr, "") + err := websitesList(context.Background(), cmd, output, cfgMgr, "", true) require.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) } @@ -196,7 +196,7 @@ func TestWebsitesCreateHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -205,7 +205,7 @@ func TestWebsitesCreateHandler_MissingDomain(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withString(FlagCID, "QmXxx") - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain is required") } @@ -225,7 +225,7 @@ func TestWebsitesCreateHandler_WithDNSHosting(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx"). withBool(FlagDNSHosting, true).withIsSet(FlagDNSHosting, true) - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -244,7 +244,7 @@ func TestWebsitesCreateHandler_WithNoDNSHosting(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx"). withBool(FlagNoDNSHosting, true).withIsSet(FlagNoDNSHosting, true) - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -256,7 +256,7 @@ func TestWebsitesCreateHandler_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "conflict") } @@ -274,7 +274,7 @@ func TestWebsitesCreateHandler_DefaultTargetType(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com").withString(FlagCID, "QmXxx") - err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesCreate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -293,7 +293,7 @@ func TestWebsitesGetHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -313,7 +313,7 @@ func TestWebsitesGetHandler_DomainArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -322,7 +322,7 @@ func TestWebsitesGetHandler_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website ID or domain is required") } @@ -335,7 +335,7 @@ func TestWebsitesGetHandler_NotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("999") - err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website not found") } @@ -348,7 +348,7 @@ func TestWebsitesGetHandler_DomainNotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("nonexistent.com") - err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesGet(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website not found for domain") } @@ -374,7 +374,7 @@ func TestWebsitesUpdateHandler_Success(t *testing.T) { cmd := newMockCommand().withArgs("1"). withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true). withString(FlagTargetType, "ipfs").withIsSet(FlagTargetType, true) - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -383,7 +383,7 @@ func TestWebsitesUpdateHandler_NoUpdateFields(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "at least one field must be provided for update") } @@ -394,7 +394,7 @@ func TestWebsitesUpdateHandler_CIDWithoutTargetType(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1"). withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true) - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "--target-type is required when --cid is provided") } @@ -414,7 +414,7 @@ func TestWebsitesUpdateHandler_DNSHostingEnabled(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1"). withBool(FlagDNSHosting, true).withIsSet(FlagDNSHosting, true) - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -433,7 +433,7 @@ func TestWebsitesUpdateHandler_DNSHostingDisabled(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1"). withBool(FlagNoDNSHosting, true).withIsSet(FlagNoDNSHosting, true) - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -447,7 +447,7 @@ func TestWebsitesUpdateHandler_ServiceError(t *testing.T) { cmd := newMockCommand().withArgs("1"). withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true). withString(FlagTargetType, "ipfs").withIsSet(FlagTargetType, true) - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "update failed") } @@ -457,7 +457,7 @@ func TestWebsitesUpdateHandler_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesUpdate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website ID or domain is required") } @@ -480,7 +480,7 @@ func TestWebsitesEnableIPNSHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -500,7 +500,7 @@ func TestWebsitesEnableIPNSHandler_WithCID(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1").withString(FlagCID, "QmNewHash").withIsSet(FlagCID, true) - err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -509,7 +509,7 @@ func TestWebsitesEnableIPNSHandler_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website ID or domain is required") } @@ -522,7 +522,7 @@ func TestWebsitesEnableIPNSHandler_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesEnableIPNS(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "not found") } @@ -538,7 +538,7 @@ func TestWebsitesDeleteHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -555,7 +555,7 @@ func TestWebsitesDeleteHandler_DomainArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -564,7 +564,7 @@ func TestWebsitesDeleteHandler_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website ID or domain is required") } @@ -577,7 +577,7 @@ func TestWebsitesDeleteHandler_NotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("999") - err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website not found") } @@ -590,7 +590,7 @@ func TestWebsitesDeleteHandler_DomainNotFound(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("nonexistent.com") - err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesDelete(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website not found for domain") } @@ -611,7 +611,7 @@ func TestWebsitesValidateHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -628,7 +628,7 @@ func TestWebsitesValidateHandler_ValidationFailure(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("1") - err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -637,7 +637,7 @@ func TestWebsitesValidateHandler_MissingArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "website ID or domain is required") } @@ -660,7 +660,7 @@ func TestWebsitesValidateHandler_DomainArg(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesValidate(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -685,7 +685,7 @@ func TestWebsitesSSLStatusHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -694,7 +694,7 @@ func TestWebsitesSSLStatusHandler_MissingDomain(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "domain is required") } @@ -707,7 +707,7 @@ func TestWebsitesSSLStatusHandler_NoSSLInfo(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -719,7 +719,7 @@ func TestWebsitesSSLStatusHandler_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "API error") } @@ -730,7 +730,7 @@ func TestWebsitesSSLStatusHandler_Unauthenticated(t *testing.T) { output := newTestOutput() cmd := newMockCommand().withArgs("example.com") - err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "") + err := websitesSSLStatus(context.Background(), cmd, output, cfgMgr, "", true) require.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) } @@ -750,7 +750,7 @@ func TestWebsitesConfigHandler_Success(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -762,7 +762,7 @@ func TestWebsitesConfigHandler_NoSites(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token", true) require.NoError(t, err) } @@ -774,7 +774,7 @@ func TestWebsitesConfigHandler_ServiceError(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token") + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "test-token", true) require.Error(t, err) assert.Contains(t, err.Error(), "failed to get config") } @@ -785,7 +785,7 @@ func TestWebsitesConfigHandler_Unauthenticated(t *testing.T) { output := newTestOutput() cmd := newMockCommand() - err := websitesConfig(context.Background(), cmd, output, cfgMgr, "") + err := websitesConfig(context.Background(), cmd, output, cfgMgr, "", true) require.Error(t, err) assert.True(t, errors.Is(err, ErrNotAuthenticated)) } diff --git a/pkg/cli/websites_service.go b/pkg/cli/websites_service.go index 03b9fbe..a42d4c3 100644 --- a/pkg/cli/websites_service.go +++ b/pkg/cli/websites_service.go @@ -32,25 +32,25 @@ func WithWebsitesClient(client *ipfs.Client) WebsitesServiceOption { } // WebsitesServiceFactory creates a WebsitesService with dependencies. -type WebsitesServiceFactory func(cfgMgr config.Manager, output Output, opts ...WebsitesServiceOption) WebsitesService +type WebsitesServiceFactory func(cfgMgr config.Manager, output Output, secure bool, opts ...WebsitesServiceOption) WebsitesService // websitesServiceFactory is the factory function used by newAuthenticatedWebsitesService. // It can be overridden in tests to inject mock services. var websitesServiceFactory WebsitesServiceFactory = defaultWebsitesServiceFactory // defaultWebsitesServiceFactory creates a default WebsitesService instance. -func defaultWebsitesServiceFactory(cfgMgr config.Manager, output Output, opts ...WebsitesServiceOption) WebsitesService { - return NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointSecure(), opts...) +func defaultWebsitesServiceFactory(cfgMgr config.Manager, output Output, secure bool, opts ...WebsitesServiceOption) WebsitesService { + return NewWebsitesService(cfgMgr, output, cfgMgr.Config().GetIPFSEndpointWithSecure(secure), opts...) } // newAuthenticatedWebsitesService creates a WebsitesService with authentication. // It returns an error if the user is not authenticated. -func newAuthenticatedWebsitesService(cfgMgr config.Manager, output Output, authToken string) (WebsitesService, error) { +func newAuthenticatedWebsitesService(cfgMgr config.Manager, output Output, authToken string, secure bool) (WebsitesService, error) { var svcOpts []WebsitesServiceOption if authToken != "" { svcOpts = append(svcOpts, WithWebsitesAuthToken(authToken)) } - websitesService := websitesServiceFactory(cfgMgr, output, svcOpts...) + websitesService := websitesServiceFactory(cfgMgr, output, secure, svcOpts...) if err := websitesService.RequireAuthenticated(); err != nil { return nil, err } @@ -101,6 +101,9 @@ func (s *websitesService) Create(ctx context.Context, domain, targetHash, target if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } response, err := s.service.Create(ctx, domain, targetHash, targetType) if err != nil { return nil, err @@ -113,6 +116,9 @@ func (s *websitesService) CreateWithOptions(ctx context.Context, req ipfs.Websit if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } response, err := s.service.CreateWithOptions(ctx, req) if err != nil { return nil, err @@ -125,6 +131,9 @@ func (s *websitesService) Get(ctx context.Context, id string) (*ipfs.WebsiteItem if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } response, err := s.service.Get(ctx, id) if err != nil { return nil, err @@ -137,6 +146,9 @@ func (s *websitesService) Update(ctx context.Context, id, domain, targetHash, ta if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } response, err := s.service.Update(ctx, id, domain, targetHash, targetType) if err != nil { return nil, err @@ -149,6 +161,9 @@ func (s *websitesService) UpdateWithOptions(ctx context.Context, id string, req if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } response, err := s.service.UpdateWithOptions(ctx, id, req) if err != nil { return nil, err @@ -161,6 +176,9 @@ func (s *websitesService) Delete(ctx context.Context, id string) error { if err := s.RequireAuthenticated(); err != nil { return err } + if s.service == nil { + return ErrServiceUnavailable + } return s.service.Delete(ctx, id) } @@ -169,6 +187,9 @@ func (s *websitesService) Validate(ctx context.Context, id string) (*ipfs.Websit if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } return s.service.ValidateDNS(ctx, id) } @@ -177,6 +198,9 @@ func (s *websitesService) GetSSLStatus(ctx context.Context, domain string) (*ipf if err := s.RequireAuthenticated(); err != nil { return nil, err } + if s.service == nil { + return nil, ErrServiceUnavailable + } return s.service.GetSSLStatus(ctx, domain) } diff --git a/pkg/cli/websites_ssl.go b/pkg/cli/websites_ssl.go index e36edbc..efeaad0 100644 --- a/pkg/cli/websites_ssl.go +++ b/pkg/cli/websites_ssl.go @@ -55,12 +55,12 @@ Examples: }, }, Action: withContext(func(ctx context.Context, cc *commandContext) error { - return websitesSSLStatus(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken) + return websitesSSLStatus(ctx, cc.Cmd, cc.Output, cc.CfgMgr, cc.AuthToken, cc.Secure) }), } } -func websitesSSLStatus(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string) error { +func websitesSSLStatus(ctx context.Context, cmd websitesCommandGetter, output Output, cfgMgr config.Manager, authToken string, secure bool) error { args := cmd.Args() if args.Len() == 0 { return fmt.Errorf("domain is required") @@ -69,7 +69,7 @@ func websitesSSLStatus(ctx context.Context, cmd websitesCommandGetter, output Ou domain := args.First() watch := cmd.Bool("watch") - websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken) + websitesService, err := newAuthenticatedWebsitesService(cfgMgr, output, authToken, secure) if err != nil { return err } diff --git a/pkg/cli/websites_wizard.go b/pkg/cli/websites_wizard.go index 9c7a0a3..5d932d6 100644 --- a/pkg/cli/websites_wizard.go +++ b/pkg/cli/websites_wizard.go @@ -259,12 +259,13 @@ func runWebsitesWizard(ctx context.Context, cmd *cli.Command, output Output) err var websitesService WebsitesService authToken := GetAuthToken(cmd, cfgMgr) + secure := GetSecureSetting(cmd, cfgMgr) var svcOpts []WebsitesServiceOption if authToken != "" { svcOpts = append(svcOpts, WithWebsitesAuthToken(authToken)) } - websitesService = defaultWebsitesServiceFactory(cfgMgr, output, svcOpts...) + websitesService = defaultWebsitesServiceFactory(cfgMgr, output, secure, svcOpts...) ui := NewPTermWebsitesUI(output)