diff --git a/go.mod b/go.mod index 904ec3f..0496185 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,13 @@ module github.com/containeroo/never -go 1.24.2 +go 1.24.4 toolchain go1.24.5 require ( - github.com/containeroo/dynflags v0.1.1 github.com/containeroo/httputils v0.0.1 github.com/containeroo/resolver v0.1.0 - github.com/spf13/pflag v1.0.7 + github.com/containeroo/tinyflags v0.0.26 github.com/stretchr/testify v1.10.0 golang.org/x/net v0.42.0 golang.org/x/sync v0.16.0 diff --git a/go.sum b/go.sum index d3e5479..c3728ac 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,15 @@ -github.com/containeroo/dynflags v0.1.1 h1:p6039AhBuKjfeylNwQwlQVsCEgufJDgFncrtCQmrjKE= -github.com/containeroo/dynflags v0.1.1/go.mod h1:YmSfpL9trViFWdXOz+RUXQMTFiUBkcOeVajzoh+rMVw= github.com/containeroo/httputils v0.0.1 h1:W9SbW6nbmnGgaEOXRH5nY9ZwLarBo3+FLUCx6EtS2mc= github.com/containeroo/httputils v0.0.1/go.mod h1:TrmDptapH6SBe6CeHFWY2kxy63O/9ZWBWrLKpc2SFjI= github.com/containeroo/resolver v0.1.0 h1:2QGuMeY9H1T6GE53zPCnM3Zgpp/4XtGjN0eBSR8epSs= github.com/containeroo/resolver v0.1.0/go.mod h1:98QSZmbWP7A2YR2EU2SJZe34iFPKh2SDI9zJr2vO+Z8= +github.com/containeroo/tinyflags v0.0.26 h1:iEDu1sl5I13zgqYK6E6eKDVfNR7uldps38lpBqdTjD4= +github.com/containeroo/tinyflags v0.0.26/go.mod h1:eiIeqRE1uTvO0n4oVkA27oQuKVBj46Rm+TiuNv5lT60= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spf13/pflag v1.0.7 h1:vN6T9TfwStFPFM5XzjsvmzZkLuaLX+HS+0SeFLRgU6M= -github.com/spf13/pflag v1.0.7/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= diff --git a/internal/app/run.go b/internal/app/run.go index 829a767..f9cd9f6 100644 --- a/internal/app/run.go +++ b/internal/app/run.go @@ -9,10 +9,11 @@ import ( "os/signal" "syscall" - "github.com/containeroo/never/internal/config" "github.com/containeroo/never/internal/factory" + "github.com/containeroo/never/internal/flag" "github.com/containeroo/never/internal/logging" "github.com/containeroo/never/internal/wait" + "github.com/containeroo/tinyflags" "golang.org/x/sync/errgroup" ) @@ -23,16 +24,18 @@ func Run(ctx context.Context, version string, args []string, output io.Writer) e defer cancel() // Parse command-line flags - parsedFlags, err := config.ParseFlags(args, version) + flags, err := flag.ParseFlags(args, version) if err != nil { - if config.IsHelpRequested(err, output) { - return nil + if tinyflags.IsHelpRequested(err) || tinyflags.IsVersionRequested(err) { + fmt.Fprint(output, err.Error()) // nolint:errcheck + os.Exit(0) } + return fmt.Errorf("configuration error: %w", err) } // Initialize target checkers - checkers, err := factory.BuildCheckers(parsedFlags.DynFlags, parsedFlags.DefaultCheckInterval) + checkers, err := factory.BuildCheckers(flags.DynamicGroups, flags.DefaultCheckInterval) if err != nil { return fmt.Errorf("failed to initialize target checkers: %w", err) } diff --git a/internal/app/run_test.go b/internal/app/run_test.go index bb0ad32..0b4bac1 100644 --- a/internal/app/run_test.go +++ b/internal/app/run_test.go @@ -108,7 +108,7 @@ func TestRunConfigErrorUnsupportedCheckType(t *testing.T) { err := Run(ctx, version, args, &output) assert.Error(t, err) - assert.EqualError(t, err, "configuration error: flag parsing error: unknown flag: --target.unsupported.name") + assert.EqualError(t, err, "configuration error: unknown dynamic group: target") } func TestRunConfigErrorInvalidHeaders(t *testing.T) { @@ -149,24 +149,5 @@ func TestRunParseError(t *testing.T) { err := Run(ctx, version, args, &output) assert.Error(t, err) - assert.EqualError(t, err, "configuration error: flag parsing error: unknown flag: --invalid") -} - -func TestRunShowVersion(t *testing.T) { - t.Parallel() - - args := []string{ - "--http.invalidheaders.name=TestService", - "--http.invalidheaders.address=http://localhost:8080", - "--version", - } - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - var output bytes.Buffer - - err := Run(ctx, version, args, &output) - - assert.NoError(t, err) + assert.EqualError(t, err, "configuration error: unknown flag: --invalid") } diff --git a/internal/config/config.go b/internal/config/config.go deleted file mode 100644 index 0241a65..0000000 --- a/internal/config/config.go +++ /dev/null @@ -1,187 +0,0 @@ -package config - -import ( - "bytes" - "errors" - "fmt" - "io" - "strings" - "time" - - "github.com/containeroo/dynflags" - - flag "github.com/spf13/pflag" -) - -const ( - paramDefaultInterval string = "default-interval" - defaultCheckInterval time.Duration = 2 * time.Second - defaultHTTPAllowDuplicateHeaders bool = false - defaultHTTPSkipTLSVerify bool = false -) - -type HelpRequested struct { - Message string -} - -func (e *HelpRequested) Error() string { - return e.Message -} - -// ParsedFlags holds the parsed command-line flags. -type ParsedFlags struct { - ShowHelp bool - ShowVersion bool - Version string - DefaultCheckInterval time.Duration - DynFlags *dynflags.DynFlags -} - -// ParseFlags parses command-line arguments and returns the parsed flags. -func ParseFlags(args []string, version string) (*ParsedFlags, error) { - // Create global fs and dynamic flags - fs := setupGlobalFlags() - df := setupDynamicFlags() - - // Set up custom usage function - setupUsage(fs, df) - - // Parse unknown arguments with dynamic flags - if err := df.Parse(args); err != nil { - return nil, fmt.Errorf("error parsing dynamic flags: %w", err) - } - - unknownArgs := df.UnknownArgs() - - // Parse known flags - if err := fs.Parse(unknownArgs); err != nil { - return nil, fmt.Errorf("flag parsing error: %s", err.Error()) - } - - // Handle special flags (e.g., --help or --version) - if err := handleSpecialFlags(fs, version); err != nil { - return nil, err - } - - // Retrieve the default interval value - defaultInterval, err := getDurationFlag(fs, paramDefaultInterval, defaultCheckInterval) - if err != nil { - return nil, err - } - - return &ParsedFlags{ - DefaultCheckInterval: defaultInterval, - DynFlags: df, - }, nil -} - -// setupGlobalFlags sets up global application flags. -func setupGlobalFlags() *flag.FlagSet { - fs := flag.NewFlagSet("never", flag.ContinueOnError) - fs.SortFlags = false - - fs.Duration(paramDefaultInterval, defaultCheckInterval, "Default interval between checks. Can be overridden for each target.") - fs.Bool("version", false, "Show version and exit.") - fs.BoolP("help", "h", false, "Show help.") - - return fs -} - -// setupDynamicFlags sets up dynamic flags for HTTP, TCP, ICMP. -func setupDynamicFlags() *dynflags.DynFlags { - df := dynflags.New(dynflags.ContinueOnError) - df.Epilog("For more information, see https://github.com/containeroo/never") - df.SortGroups = true - df.SortFlags = true - - // HTTP flags - http := df.Group("http") - http.String("name", "", "Name of the HTTP checker") - http.String("method", "GET", "HTTP method to use") - http.String("address", "", "HTTP target URL") - http.Duration("interval", 1*time.Second, "Time between HTTP requests. Can be overwritten with --default-interval.") - http.StringSlices("header", nil, "HTTP headers to send") - http.Bool("allow-duplicate-headers", defaultHTTPAllowDuplicateHeaders, "Allow duplicate HTTP headers") - http.String("expected-status-codes", "200", "Expected HTTP status codes") - http.Bool("skip-tls-verify", defaultHTTPSkipTLSVerify, "Skip TLS verification") - http.Duration("timeout", 2*time.Second, "Timeout in seconds") - - // ICMP flags - icmp := df.Group("icmp") - icmp.String("name", "", "Name of the ICMP checker") - icmp.String("address", "", "ICMP target address") - icmp.Duration("interval", 1*time.Second, "Time between ICMP requests. Can be overwritten with --default-interval.") - icmp.Duration("read-timeout", 2*time.Second, "Timeout for ICMP read") - icmp.Duration("write-timeout", 2*time.Second, "Timeout for ICMP write") - - // TCP flags - tcp := df.Group("tcp") - tcp.String("name", "", "Name of the TCP checker") - tcp.String("address", "", "TCP target address") - tcp.Duration("timeout", 2*time.Second, "Timeout for TCP connection") - tcp.Duration("interval", 1*time.Second, "Time between TCP requests. Can be overwritten with --default-interval.") - - return df -} - -// setupUsage sets the custom usage function. -func setupUsage(fs *flag.FlagSet, df *dynflags.DynFlags) { - fs.Usage = func() { - out := fs.Output() // capture writer ONCE - - fmt.Fprintf(out, "Usage: %s [FLAGS] [DYNAMIC FLAGS..]\n", strings.ToLower(fs.Name())) // nolint:errcheck - - fmt.Fprintln(out, "\nGlobal Flags:") // nolint:errcheck - fs.SetOutput(out) - fs.PrintDefaults() - - fmt.Fprintln(out, "\nDynamic Flags:") // nolint:errcheck - df.SetOutput(out) - df.PrintDefaults() - } -} - -// handleSpecialFlags handles help and version flags. -func handleSpecialFlags(fs *flag.FlagSet, version string) error { - help := fs.Lookup("help") - if help != nil && help.Value.String() == "true" { - // create a buffer to capture the output to pass to the HelpRequested error message - buffer := &bytes.Buffer{} - fs.SetOutput(buffer) - fs.Usage() - return &HelpRequested{Message: buffer.String()} - } - - versionFlag := fs.Lookup("version") - if versionFlag != nil && versionFlag.Value.String() == "true" { - return &HelpRequested{Message: fmt.Sprintf("%s version %s\n", fs.Name(), version)} - } - - return nil -} - -// Example of getting a flag value as a time.Duration -func getDurationFlag(flagSet *flag.FlagSet, name string, defaultValue time.Duration) (time.Duration, error) { - flag := flagSet.Lookup(name) - if flag == nil { - return defaultValue, nil - } - - // Parse the flag val from string to time.Duration - val, err := time.ParseDuration(flag.Value.String()) - if err != nil { - return defaultValue, fmt.Errorf("invalid duration for flag '%s'", flag.Value.String()) - } - - return val, nil -} - -// IsHelpRequested checks if the error is a HelpRequested sentinel and prints it. -func IsHelpRequested(err error, w io.Writer) bool { - var helpErr *HelpRequested - if errors.As(err, &helpErr) { - fmt.Fprint(w, helpErr.Error()) // nolint:errcheck - return true - } - return false -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index c1245ed..0000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,223 +0,0 @@ -package config - -import ( - "bytes" - "errors" - "fmt" - "strings" - "testing" - "time" - - "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" -) - -func TestParseFlags(t *testing.T) { - t.Parallel() - - t.Run("Successful Parsing", func(t *testing.T) { - t.Parallel() - - args := []string{"--default-interval=5s"} - - parsedFlags, err := ParseFlags(args, "1.0.0") - assert.NoError(t, err) - assert.Equal(t, 5*time.Second, parsedFlags.DefaultCheckInterval) - }) - - t.Run("Handle Help Flag", func(t *testing.T) { - t.Parallel() - - var output strings.Builder - flagSet := setupGlobalFlags() - flagSet.SetOutput(&output) // Ensure output is properly set - _ = flagSet.Parse([]string{"--help"}) - - flagSet.Usage = func() { - fmt.Fprintln(&output, "Usage: never [FLAGS] [DYNAMIC FLAGS..]") - } - - err := handleSpecialFlags(flagSet, "1.0.0") - assert.Error(t, err) - assert.IsType(t, &HelpRequested{}, err) - assert.Contains(t, output.String(), "Usage: never [FLAGS] [DYNAMIC FLAGS..]") - }) - - t.Run("Show Version Flag", func(t *testing.T) { - t.Parallel() - - args := []string{"--version"} - - _, err := ParseFlags(args, "1.0.0") - assert.Error(t, err) - assert.IsType(t, &HelpRequested{}, err) - assert.Contains(t, err.Error(), "never version 1.0.0") - }) - - t.Run("Invalid Duration Flag", func(t *testing.T) { - t.Parallel() - - args := []string{"--default-interval=invalid"} - - _, err := ParseFlags(args, "1.0.0") - assert.Error(t, err) - - assert.EqualError(t, err, "flag parsing error: invalid argument \"invalid\" for \"--default-interval\" flag: time: invalid duration \"invalid\"") - }) -} - -func TestSetupGlobalFlags(t *testing.T) { - t.Parallel() - - flagSet := setupGlobalFlags() - assert.NotNil(t, flagSet.Lookup("default-interval")) - assert.NotNil(t, flagSet.Lookup("version")) - assert.NotNil(t, flagSet.Lookup("help")) -} - -func TestSetupDynamicFlags(t *testing.T) { - t.Parallel() - - dynFlags := setupDynamicFlags() - assert.NotNil(t, dynFlags.Group("http")) - assert.NotNil(t, dynFlags.Group("tcp")) - assert.NotNil(t, dynFlags.Group("icmp")) - - httpGroup := dynFlags.Group("http") - assert.NotNil(t, httpGroup.Lookup("name")) - assert.NotNil(t, httpGroup.Lookup("method")) - assert.NotNil(t, httpGroup.Lookup("address")) -} - -func TestSetupUsage(t *testing.T) { - t.Parallel() - - var output strings.Builder - flagSet := setupGlobalFlags() - flagSet.SetOutput(&output) - - dynFlags := setupDynamicFlags() - dynFlags.SetOutput(&output) - - setupUsage(flagSet, dynFlags) - flagSet.Usage() - - usageOutput := output.String() - assert.Contains(t, usageOutput, "Usage: never [FLAGS] [DYNAMIC FLAGS..]") - assert.Contains(t, usageOutput, "Global Flags:") - assert.Contains(t, usageOutput, "--default-interval") - assert.Contains(t, usageOutput, "Dynamic Flags:") - assert.Contains(t, usageOutput, "http") -} - -func TestHandleSpecialFlags(t *testing.T) { - t.Parallel() - - t.Run("Handle Help Flag", func(t *testing.T) { - t.Parallel() - - var output strings.Builder - flagSet := setupGlobalFlags() - flagSet.SetOutput(&output) - - flagSet.Usage = func() { - fmt.Fprintln(&output, "Usage: never [FLAGS] [DYNAMIC FLAGS..]") - } - - args := []string{"--help"} - err := flagSet.Parse(args) - assert.NoError(t, err) - - err = handleSpecialFlags(flagSet, "1.0.0") - assert.Error(t, err) - }) - - t.Run("Handle Version Flag", func(t *testing.T) { - t.Parallel() - - flagSet := setupGlobalFlags() - _ = flagSet.Parse([]string{"--version"}) - - err := handleSpecialFlags(flagSet, "1.0.0") - assert.Error(t, err) - assert.IsType(t, &HelpRequested{}, err) - assert.Contains(t, err.Error(), "never version 1.0.0") - }) - - t.Run("No Special Flags", func(t *testing.T) { - t.Parallel() - - flagSet := setupGlobalFlags() - _ = flagSet.Parse([]string{}) - - err := handleSpecialFlags(flagSet, "1.0.0") - assert.NoError(t, err) - }) -} - -func TestGetDurationFlag(t *testing.T) { - t.Parallel() - - t.Run("Valid Duration Flag", func(t *testing.T) { - t.Parallel() - - flagSet := setupGlobalFlags() - _ = flagSet.Set("default-interval", "10s") - - duration, err := getDurationFlag(flagSet, "default-interval", time.Second) - assert.NoError(t, err) - assert.Equal(t, 10*time.Second, duration) - }) - - t.Run("Invalid Duration", func(t *testing.T) { - t.Parallel() - - flagSet := pflag.NewFlagSet("never", pflag.ContinueOnError) - flagSet.String("invalid-flag", "invalid", "Invalid flag") - err := flagSet.Set("invalid-flag", "invalid") - assert.NoError(t, err) - - _, err = getDurationFlag(flagSet, "invalid-flag", time.Second) - assert.Error(t, err) - assert.EqualError(t, err, "invalid duration for flag 'invalid'") - }) - - t.Run("Missing Duration Flag", func(t *testing.T) { - t.Parallel() - - flagSet := setupGlobalFlags() - - duration, err := getDurationFlag(flagSet, "non-existent-flag", time.Second) - assert.NoError(t, err) - assert.Equal(t, time.Second, duration) - }) -} - -func TestIsHelpRequested(t *testing.T) { - t.Parallel() - - t.Run("returns true and writes message for HelpRequested error", func(t *testing.T) { - t.Parallel() - - buf := &bytes.Buffer{} - helpMsg := "this is the help message\n" - err := &HelpRequested{Message: helpMsg} - - ok := IsHelpRequested(err, buf) - - assert.True(t, ok) - assert.Equal(t, helpMsg, buf.String()) - }) - - t.Run("returns false and writes nothing for unrelated error", func(t *testing.T) { - t.Parallel() - - buf := &bytes.Buffer{} - err := errors.New("some other error") - - ok := IsHelpRequested(err, buf) - - assert.False(t, ok) - assert.Equal(t, "", buf.String()) - }) -} diff --git a/internal/factory/factory.go b/internal/factory/factory.go index 66859f7..fef3791 100644 --- a/internal/factory/factory.go +++ b/internal/factory/factory.go @@ -5,10 +5,10 @@ import ( "strings" "time" - "github.com/containeroo/dynflags" "github.com/containeroo/httputils" "github.com/containeroo/never/internal/checker" "github.com/containeroo/resolver" + "github.com/containeroo/tinyflags" ) // CheckerWithInterval represents a checker with its interval. @@ -18,94 +18,77 @@ type CheckerWithInterval struct { } // BuildCheckers creates a list of CheckerWithInterval from the parsed dynflags configuration. -func BuildCheckers(dynFlags *dynflags.DynFlags, defaultInterval time.Duration) ([]CheckerWithInterval, error) { +func BuildCheckers(dynamicGroups []*tinyflags.DynamicGroup, defaultInterval time.Duration) ([]CheckerWithInterval, error) { var checkers []CheckerWithInterval - // Iterate over all parsed groups - for parentName, childGroups := range dynFlags.Parsed().Groups() { - checkType, err := checker.ParseCheckType(parentName) + for _, group := range dynamicGroups { + checkType, err := checker.ParseCheckType(group.Name()) if err != nil { - return nil, fmt.Errorf("invalid check type '%s': %w", parentName, err) + return nil, err } - // Process each parsed group (child) under the parent group - for _, group := range childGroups { - address, err := group.GetString("address") + for _, id := range group.Instances() { + address := tinyflags.GetOrDefaultDynamic[string](group, id, "address") + resolvedAddr, err := resolver.ResolveVariable(address) if err != nil { - return nil, fmt.Errorf("missing address for %s checker: %w", parentName, err) + return nil, fmt.Errorf("invalid variable in address: %w", err) } - resolvedAddress, err := resolver.ResolveVariable(address) - if err != nil { - return nil, fmt.Errorf("failed to resolve variable in address: %w", err) - } - - // Default interval for the checker interval := defaultInterval - if customInterval, err := group.GetDuration("interval"); err == nil { - interval = customInterval + if v, _ := tinyflags.GetDynamic[time.Duration](group, id, "interval"); v != 0 { + interval = v } - // Prepare options based on the checker type var opts []checker.Option switch checkType { case checker.HTTP: - if method, err := group.GetString("method"); err == nil { - opts = append(opts, checker.WithHTTPMethod(method)) + method := tinyflags.GetOrDefaultDynamic[string](group, id, "method") + opts = append(opts, checker.WithHTTPMethod(method)) + + headers := tinyflags.GetOrDefaultDynamic[[]string](group, id, "header") + allowDup := tinyflags.GetOrDefaultDynamic[bool](group, id, "allow-duplicate-headers") + headersMap, err := createHTTPHeadersMap(headers, allowDup) + if err != nil { + return nil, fmt.Errorf("invalid \"--%s.%s.header\": %w", group.Name(), id, err) } + opts = append(opts, checker.WithHTTPHeaders(headersMap)) - allowDuplicateHeaders, _ := group.GetBool("allow-duplicate-headers") // Type is checked when parsing - if headers, err := group.GetStringSlices("header"); err == nil { - headersMap, err := createHTTPHeadersMap(headers, allowDuplicateHeaders) - if err != nil { - return nil, fmt.Errorf("invalid \"--%s.%s.header\": %w", parentName, group.Name, err) - } - opts = append(opts, checker.WithHTTPHeaders(headersMap)) - } - - if allowedStatusCodes, err := group.GetString("expected-status-codes"); err == nil { - statusCodes, err := httputils.ParseStatusCodes(allowedStatusCodes) - if err != nil { - return nil, fmt.Errorf("invalid \"--%s.%s.expected-status-codes\": %w", parentName, group.Name, err) - } - - opts = append(opts, checker.WithExpectedStatusCodes(statusCodes)) + codeStr := tinyflags.GetOrDefaultDynamic[string](group, id, "expected-status-codes") + codes, err := httputils.ParseStatusCodes(codeStr) + if err != nil { + return nil, fmt.Errorf("invalid --%s.%s.expected-status-codes: %w", group.Name(), id, err) } + opts = append(opts, checker.WithExpectedStatusCodes(codes)) - if skipTLS, err := group.GetBool("skip-tls-verify"); err == nil { - opts = append(opts, checker.WithHTTPSkipTLSVerify(skipTLS)) - } + skipTLS := tinyflags.GetOrDefaultDynamic[bool](group, id, "skip-tls-verify") + opts = append(opts, checker.WithHTTPSkipTLSVerify(skipTLS)) - if timeout, err := group.GetDuration("timeout"); err == nil { - opts = append(opts, checker.WithHTTPTimeout(timeout)) - } + timeout := tinyflags.GetOrDefaultDynamic[time.Duration](group, id, "timeout") + opts = append(opts, checker.WithHTTPTimeout(timeout)) case checker.TCP: - if timeout, err := group.GetDuration("timeout"); err == nil { - opts = append(opts, checker.WithHTTPTimeout(timeout)) // Could have a TCP-specific timeout option - } + timeout := tinyflags.GetOrDefaultDynamic[time.Duration](group, id, "timeout") + opts = append(opts, checker.WithHTTPTimeout(timeout)) case checker.ICMP: - if readTimeout, err := group.GetDuration("read-timeout"); err == nil { - opts = append(opts, checker.WithICMPReadTimeout(readTimeout)) - } - if writeTimeout, err := group.GetDuration("write-timeout"); err == nil { - opts = append(opts, checker.WithICMPWriteTimeout(writeTimeout)) - } + rt := tinyflags.GetOrDefaultDynamic[time.Duration](group, id, "read-timeout") + opts = append(opts, checker.WithICMPReadTimeout(rt)) + + wt := tinyflags.GetOrDefaultDynamic[time.Duration](group, id, "write-timeout") + opts = append(opts, checker.WithICMPWriteTimeout(wt)) } - name, _ := group.GetString("name") - if name == "" { - name = group.Name + name := id + if n := tinyflags.GetOrDefaultDynamic[string](group, id, "name"); n != "" { + name = n } - instance, err := checker.NewChecker(checkType, name, resolvedAddress, opts...) + instance, err := checker.NewChecker(checkType, name, resolvedAddr, opts...) if err != nil { - return nil, fmt.Errorf("failed to create %s checker: %w", parentName, err) + return nil, fmt.Errorf("failed to create %s checker: %w", checkType, err) } - // Wrap the checker with its interval and add to the list checkers = append(checkers, CheckerWithInterval{ Interval: interval, Checker: instance, diff --git a/internal/factory/factory_test.go b/internal/factory/factory_test.go index 02bc189..6dae773 100644 --- a/internal/factory/factory_test.go +++ b/internal/factory/factory_test.go @@ -4,9 +4,10 @@ import ( "testing" "time" - "github.com/containeroo/dynflags" "github.com/containeroo/never/internal/factory" + "github.com/containeroo/tinyflags" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBuildCheckers(t *testing.T) { @@ -15,14 +16,17 @@ func TestBuildCheckers(t *testing.T) { t.Run("Valid HTTP Checker", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") - httpGroup.String("address", "http://example.com", "HTTP target address") - httpGroup.String("method", "GET", "HTTP method") - httpGroup.Duration("interval", 5*time.Second, "Request interval") - httpGroup.StringSlices("header", nil, "HTTP header") - httpGroup.Bool("skip-tls-verify", false, "Skip TLS verification") - httpGroup.Duration("timeout", 2*time.Second, "Timeout") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + http := tf.DynamicGroup("http") + http.String("name", "", "Name of the HTTP checker") + http.String("method", "GET", "HTTP method to use") + http.String("address", "", "HTTP target URL") + http.Duration("interval", 1*time.Second, "Time between HTTP requests. Can be overwritten with --default-interval.") + http.StringSlice("header", nil, "HTTP headers to send") + http.Bool("allow-duplicate-headers", true, "Allow duplicate HTTP headers") + http.String("expected-status-codes", "200", "Expected HTTP status codes") + http.Bool("skip-tls-verify", true, "Skip TLS verification") + http.Duration("timeout", 22*time.Second, "Timeout in seconds") args := []string{ "--http.mygroup.address=http://example.com", @@ -30,66 +34,52 @@ func TestBuildCheckers(t *testing.T) { "--http.mygroup.interval=5s", "--http.mygroup.header=Content-Type=application/json", "--http.mygroup.skip-tls-verify=true", - "--http.mygroup.timeout=2s", + "--http.mygroup.timeout=33s", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 9*time.Second) assert.NoError(t, err) assert.Len(t, checkers, 1) assert.Equal(t, "http://example.com", checkers[0].Checker.Address()) assert.Equal(t, 5*time.Second, checkers[0].Interval) }) - t.Run("Missing Address", func(t *testing.T) { - t.Parallel() - - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") - httpGroup.String("method", "GET", "HTTP method") - - args := []string{"--http.mygroup.method=GET"} - err := df.Parse(args) - assert.NoError(t, err) - - checkers, err := factory.BuildCheckers(df, 2*time.Second) - assert.Nil(t, checkers) - assert.ErrorContains(t, err, "missing address for http checker") - }) - t.Run("Invalid Check Type", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - invalidGroup := df.Group("invalid") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + invalidGroup := tf.DynamicGroup("invalid") invalidGroup.String("address", "invalid-address", "Invalid target address") args := []string{"--invalid.mygroup.address=invalid-address"} - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.Nil(t, checkers) - assert.ErrorContains(t, err, "invalid check type 'invalid'") + assert.EqualError(t, err, "unsupported check type: invalid") }) t.Run("Invalid Header Parsing", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + httpGroup := tf.DynamicGroup("http") httpGroup.String("address", "http://example.com", "HTTP target address") - httpGroup.StringSlices("header", []string{}, "HTTP headers") + httpGroup.StringSlice("header", []string{}, "HTTP headers") + httpGroup.String("method", "GET", "HTTP method to use") + httpGroup.Bool("allow-duplicate-headers", true, "Allow duplicate HTTP headers") args := []string{ "--http.mygroup.address=http://example.com", "--http.mygroup.header=InvalidHeaderFormat", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.Error(t, err) assert.EqualError(t, err, "invalid \"--http.mygroup.header\": invalid header format: \"InvalidHeaderFormat\"") @@ -100,21 +90,26 @@ func TestBuildCheckers(t *testing.T) { t.Run("Inalid HTTP Status codes", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") - httpGroup.String("address", "http://example.com", "HTTP target address") - httpGroup.String("expected-status-codes", "400,401", "HTTP expected status codes") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + http := tf.DynamicGroup("http") + http.String("name", "", "Name of the HTTP checker") + http.String("method", "GET", "HTTP method to use") + http.String("address", "", "HTTP target URL") + http.Duration("interval", 1*time.Second, "Time between HTTP requests. Can be overwritten with --default-interval.") + http.StringSlice("header", nil, "HTTP headers to send") + http.Bool("allow-duplicate-headers", true, "Allow duplicate HTTP headers") + http.String("expected-status-codes", "200", "Expected HTTP status codes") + http.Bool("skip-tls-verify", true, "Skip TLS verification") + http.Duration("timeout", 2*time.Second, "Timeout in seconds") args := []string{ - "--http.mygroup.address=http://example.com", - "--http.mygroup.expected-status-codes=201-200", + "--http.myid.address=http://example.com", + "--http.myid.expected-status-codes=201-200", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - res := httpGroup.Lookup("expected-status-codes").GetValue() - assert.Equal(t, "201-200", res) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.Error(t, err) assert.Len(t, checkers, 0) }) @@ -122,19 +117,25 @@ func TestBuildCheckers(t *testing.T) { t.Run("Valid HTTP Status codes", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + httpGroup := tf.DynamicGroup("http") + httpGroup.String("name", "", "Name of the HTTP checker. Defaults to .") httpGroup.String("address", "http://example.com", "HTTP target address") httpGroup.String("expected-status-codes", "200,201", "HTTP expected status codes") + httpGroup.StringSlice("header", []string{}, "HTTP headers") + httpGroup.String("method", "GET", "HTTP method to use") + httpGroup.Bool("allow-duplicate-headers", true, "Allow duplicate HTTP headers") + httpGroup.Bool("skip-tls-verify", true, "Skip TLS verification") + httpGroup.Duration("timeout", 2*time.Second, "Timeout in seconds") args := []string{ "--http.mygroup.address=http://example.com", "--http.mygroup.expected-status-codes=200,201", } - err := df.Parse(args) - assert.NoError(t, err) + err := tf.Parse(args) + require.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.NoError(t, err) assert.Len(t, checkers, 1) }) @@ -142,8 +143,9 @@ func TestBuildCheckers(t *testing.T) { t.Run("Valid TCP Checker", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - tcpGroup := df.Group("tcp") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + tcpGroup := tf.DynamicGroup("tcp") + tcpGroup.String("name", "", "Name of the HTTP checker. Defaults to .") tcpGroup.String("address", "127.0.0.1:8080", "TCP target address") tcpGroup.Duration("timeout", 3*time.Second, "Timeout") @@ -151,10 +153,10 @@ func TestBuildCheckers(t *testing.T) { "--tcp.mygroup.address=127.0.0.1:8080", "--tcp.mygroup.timeout=3s", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.NoError(t, err) assert.Len(t, checkers, 1) assert.Equal(t, "127.0.0.1:8080", checkers[0].Checker.Address()) @@ -163,8 +165,9 @@ func TestBuildCheckers(t *testing.T) { t.Run("Valid ICMP Checker", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - icmpGroup := df.Group("icmp") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + icmpGroup := tf.DynamicGroup("icmp") + icmpGroup.String("name", "", "Name of the ICMP checker. Defaults to .") icmpGroup.String("address", "8.8.8.8", "ICMP target address") icmpGroup.Duration("read-timeout", 2*time.Second, "Read timeout") icmpGroup.Duration("write-timeout", 2*time.Second, "Write timeout") @@ -174,10 +177,10 @@ func TestBuildCheckers(t *testing.T) { "--icmp.mygroup.read-timeout=2s", "--icmp.mygroup.write-timeout=2s", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checkers, err := factory.BuildCheckers(df, 2*time.Second) + checkers, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.NoError(t, err) assert.Len(t, checkers, 1) assert.Equal(t, "8.8.8.8", checkers[0].Checker.Address()) @@ -186,35 +189,22 @@ func TestBuildCheckers(t *testing.T) { t.Run("Invalid ICMP Checker", func(t *testing.T) { t.Parallel() - df := dynflags.New(dynflags.ContinueOnError) - icmpGroup := df.Group("icmp") + tf := tinyflags.NewFlagSet("test.exe", tinyflags.ContinueOnError) + icmpGroup := tf.DynamicGroup("icmp") + icmpGroup.String("name", "", "Name of the TCP checker. Defaults to .") icmpGroup.String("address", "8.8.8.8", "ICMP target address") + icmpGroup.Duration("read-timeout", 2*time.Second, "Read timeout") + icmpGroup.Duration("write-timeout", 2*time.Second, "Write timeout") args := []string{ "--icmp.mygroup.address=://invalid-url", } - err := df.Parse(args) + err := tf.Parse(args) assert.NoError(t, err) - checker, err := factory.BuildCheckers(df, 2*time.Second) + checker, err := factory.BuildCheckers(tf.DynamicGroups(), 2*time.Second) assert.Nil(t, checker) assert.Error(t, err) }) - - t.Run("Checker Creation Failure", func(t *testing.T) { - t.Parallel() - - df := dynflags.New(dynflags.ContinueOnError) - httpGroup := df.Group("http") - httpGroup.String("address", "", "HTTP target address") - - args := []string{"--http.mygroup.address="} - err := df.Parse(args) - assert.NoError(t, err) - - checkers, err := factory.BuildCheckers(df, 2*time.Second) - assert.NotNil(t, checkers) - assert.NoError(t, err) - }) } diff --git a/internal/flag/flags.go b/internal/flag/flags.go new file mode 100644 index 0000000..1b48a2c --- /dev/null +++ b/internal/flag/flags.go @@ -0,0 +1,122 @@ +package flag + +import ( + "errors" + "fmt" + "net" + "net/url" + "time" + + "github.com/containeroo/tinyflags" +) + +const ( + paramDefaultInterval string = "default-interval" + defaultCheckInterval time.Duration = 2 * time.Second + defaultHTTPAllowDuplicateHeaders bool = false + defaultHTTPSkipTLSVerify bool = false +) + +// ParsedFlags holds the parsed command-line flags. +type ParsedFlags struct { + ShowHelp bool + ShowVersion bool + Version string + DefaultCheckInterval time.Duration + DynamicGroups []*tinyflags.DynamicGroup +} + +// ParseFlags parses command-line arguments and returns the parsed flags. +func ParseFlags(args []string, version string) (*ParsedFlags, error) { + tf := tinyflags.NewFlagSet("never", tinyflags.ContinueOnError) + tf.Version(version) + tf.SortedFlags() + tf.SortedGroups() + + interval := tf.Duration( + paramDefaultInterval, + 15*time.Second, + "Default interval between checks. Can be overridden for each target.", + ). + Placeholder("DURATION"). + Value() + + tf.Note("\nFor more information, see https://github.com/containeroo/never") + + // HTTP flags + http := tf.DynamicGroup("http").Title("HTTP") + http.String("name", "", "Name of the HTTP checker. Defaults to .") + http.String("method", "GET", "HTTP method to use") + http.String("address", "", "HTTP target URL"). + Validate(func(s string) error { + u, err := url.Parse(s) + if err != nil || u.Host == "" { + return fmt.Errorf("invalid URL: %q", s) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported scheme: %q", u.Scheme) + } + return nil + }). + Required() + http.Duration("interval", 1*time.Second, "Time between HTTP requests. Can be overwritten with --default-interval."). + Placeholder("DURATION") + http.StringSlice("header", []string{}, "HTTP headers to send") + http.Bool("allow-duplicate-headers", defaultHTTPAllowDuplicateHeaders, "Allow duplicate HTTP headers") + http.String("expected-status-codes", "200", "Expected HTTP status codes"). + Placeholder("CODE") + http.Bool("skip-tls-verify", defaultHTTPSkipTLSVerify, "Skip TLS verification") + http.Duration("timeout", 2*time.Second, "Timeout in seconds"). + Placeholder("DURATION") + + // ICMP flags + icmp := tf.DynamicGroup("icmp").Title("ICMP") + icmp.String("name", "", "Name of the ICMP checker. Defaults to .") + icmp.String("address", "", "ICMP target address"). + Validate(func(s string) error { + if ip := net.ParseIP(s); ip != nil { + return nil + } + u, err := url.Parse(s) + if err != nil || u.Host == "" { + return fmt.Errorf("invalid URL: %q", s) + } + if u.Scheme != "" { + return errors.New("ICMP check cannot have a scheme") + } + return nil + }). + Required() + icmp.Duration("interval", 1*time.Second, "Time between ICMP requests. Can be overwritten with --default-interval."). + Placeholder("DURATION") + icmp.Duration("read-timeout", 2*time.Second, "Timeout for ICMP read"). + Placeholder("DURATION") + icmp.Duration("write-timeout", 2*time.Second, "Timeout for ICMP write"). + Placeholder("DURATION") + + // TCP flags + tcp := tf.DynamicGroup("tcp").Title("TCP") + tcp.String("name", "", "Name of the TCP checker. Defaults to .") + tcp.String("address", "", "TCP target address"). + Validate(func(s string) error { + if _, _, err := net.SplitHostPort(s); err != nil { + return fmt.Errorf("TCP address must be host:port (e.g. 127.0.0.1:80): %w", err) + } + return nil + }). + Required() + tcp.Duration("timeout", 2*time.Second, "Timeout for TCP connection"). + Placeholder("DURATION") + tcp.Duration("interval", 1*time.Second, "Time between TCP requests. Can be overwritten with --default-interval."). + Placeholder("DURATION") + + // Parse unknown arguments with dynamic flags + if err := tf.Parse(args); err != nil { + return nil, err + } + + return &ParsedFlags{ + DefaultCheckInterval: *interval, + DynamicGroups: tf.DynamicGroups(), + }, nil +} diff --git a/internal/flag/flags_test.go b/internal/flag/flags_test.go new file mode 100644 index 0000000..62800a8 --- /dev/null +++ b/internal/flag/flags_test.go @@ -0,0 +1,54 @@ +package flag + +import ( + "errors" + "testing" + "time" + + "github.com/containeroo/tinyflags" + "github.com/stretchr/testify/assert" +) + +func TestParseFlags(t *testing.T) { + t.Parallel() + + t.Run("Successful Parsing", func(t *testing.T) { + t.Parallel() + + args := []string{"--default-interval=5s"} + + parsedFlags, err := ParseFlags(args, "1.0.0") + assert.NoError(t, err) + assert.Equal(t, 5*time.Second, parsedFlags.DefaultCheckInterval) + }) + + t.Run("Handle Help Flag", func(t *testing.T) { + t.Parallel() + + _, err := ParseFlags([]string{"--help"}, "1.0.0") + assert.Error(t, err) + }) + + t.Run("Show Version Flag", func(t *testing.T) { + t.Parallel() + + args := []string{"--version"} + + _, err := ParseFlags(args, "1.0.0") + assert.Error(t, err) + var verr *tinyflags.VersionRequested + assert.True(t, errors.As(err, &verr), "expected VersionRequested error") + assert.EqualError(t, err, "1.0.0") + }) + + t.Run("Invalid Duration Flag", func(t *testing.T) { + t.Parallel() + + args := []string{"--default-interval=invalid"} + + _, err := ParseFlags(args, "1.0.0") + assert.Error(t, err) + + assert.EqualError(t, err, "invalid value for flag --default-interval: time: invalid duration \"invalid\".") + }) +}