From aca97de52ba0a46fb07d076f40c8bbfd9e7b5857 Mon Sep 17 00:00:00 2001 From: Ilya Voronin Date: Tue, 20 Jan 2026 22:28:50 +0200 Subject: [PATCH] Replace internal/argsieve with external ivoronin/argsieve library --- go.mod | 1 + go.sum | 2 + internal/app/list.go | 2 +- internal/app/scp.go | 2 +- internal/app/sftp.go | 2 +- internal/app/ssh.go | 2 +- internal/app/ssm.go | 2 +- internal/app/tunnel_session.go | 2 +- internal/argsieve/argsieve.go | 353 --------------- internal/argsieve/argsieve_test.go | 675 ----------------------------- 10 files changed, 9 insertions(+), 1034 deletions(-) delete mode 100644 internal/argsieve/argsieve.go delete mode 100644 internal/argsieve/argsieve_test.go diff --git a/go.mod b/go.mod index e804ddc..bebb659 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/hashicorp/hc-install v0.9.2 github.com/hashicorp/terraform-exec v0.24.0 + github.com/ivoronin/argsieve v0.0.2 github.com/mmmorris1975/ssm-session-client v0.403.0 github.com/rogpeppe/go-internal v1.14.1 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index 40d8d7a..53fc0e8 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,8 @@ github.com/hashicorp/terraform-exec v0.24.0 h1:mL0xlk9H5g2bn0pPF6JQZk5YlByqSqrO5 github.com/hashicorp/terraform-exec v0.24.0/go.mod h1:lluc/rDYfAhYdslLJQg3J0oDqo88oGQAdHR+wDqFvo4= github.com/hashicorp/terraform-json v0.27.1 h1:zWhEracxJW6lcjt/JvximOYyc12pS/gaKSy/wzzE7nY= github.com/hashicorp/terraform-json v0.27.1/go.mod h1:GzPLJ1PLdUG5xL6xn1OXWIjteQRT2CNT9o/6A9mi9hE= +github.com/ivoronin/argsieve v0.0.2 h1:7kIMMuNo00Y+rfLwYBQnA0ibajuUs6m1trurQ+DuHj8= +github.com/ivoronin/argsieve v0.0.2/go.mod h1:ZNinyee+AUAUdYSMDpKakjKTRu3orpH9dLAHbpdv+oU= github.com/ivoronin/ssm-session-client v0.0.0-20251210165256-7a67290e8efb h1:KgxBdscAIlTYg2ixIaN7T6ZxlxkxWPLq+XgFcqPiuhE= github.com/ivoronin/ssm-session-client v0.0.0-20251210165256-7a67290e8efb/go.mod h1:GUMBk2MQJbA5COsPY9LAK4v6g1XQFwUZ40g60rj393A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= diff --git a/internal/app/list.go b/internal/app/list.go index 2a38fd1..a9b71b2 100644 --- a/internal/app/list.go +++ b/internal/app/list.go @@ -10,7 +10,7 @@ import ( "text/tabwriter" "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/awsclient" "github.com/ivoronin/ec2ssh/internal/ec2client" ) diff --git a/internal/app/scp.go b/internal/app/scp.go index 0504e82..1345e23 100644 --- a/internal/app/scp.go +++ b/internal/app/scp.go @@ -3,7 +3,7 @@ package app import ( "fmt" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/ssh" ) diff --git a/internal/app/sftp.go b/internal/app/sftp.go index e7bf228..6f0962c 100644 --- a/internal/app/sftp.go +++ b/internal/app/sftp.go @@ -3,7 +3,7 @@ package app import ( "fmt" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/ssh" ) diff --git a/internal/app/ssh.go b/internal/app/ssh.go index 9158b42..91af3a5 100644 --- a/internal/app/ssh.go +++ b/internal/app/ssh.go @@ -3,7 +3,7 @@ package app import ( "fmt" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/ssh" ) diff --git a/internal/app/ssm.go b/internal/app/ssm.go index c9d5dbe..9dbb3c1 100644 --- a/internal/app/ssm.go +++ b/internal/app/ssm.go @@ -8,7 +8,7 @@ import ( "os" "time" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/awsclient" "github.com/ivoronin/ec2ssh/internal/ec2client" "github.com/ivoronin/ec2ssh/internal/ssh" diff --git a/internal/app/tunnel_session.go b/internal/app/tunnel_session.go index ae4fe2f..b7f71cf 100644 --- a/internal/app/tunnel_session.go +++ b/internal/app/tunnel_session.go @@ -8,7 +8,7 @@ import ( "os" "strconv" - "github.com/ivoronin/ec2ssh/internal/argsieve" + "github.com/ivoronin/argsieve" "github.com/ivoronin/ec2ssh/internal/tunnel" "github.com/mmmorris1975/ssm-session-client/ssmclient" ) diff --git a/internal/argsieve/argsieve.go b/internal/argsieve/argsieve.go deleted file mode 100644 index 6f10309..0000000 --- a/internal/argsieve/argsieve.go +++ /dev/null @@ -1,353 +0,0 @@ -// Package argsieve provides argument parsing with two modes: -// - Sift: extracts known flags, passes unknown flags through (for CLI wrappers) -// - Parse: strict parsing that errors on unknown flags -package argsieve - -import ( - "encoding" - "errors" - "fmt" - "iter" - "reflect" - "slices" - "strings" -) - -// ErrParse indicates a parsing error (e.g., missing value or unknown option). -var ErrParse = errors.New("argument parsing error") - -// textUnmarshalerType is used to check if a type implements encoding.TextUnmarshaler. -var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() - -// fieldInfo holds a reference to a struct field and whether it needs an argument. -type fieldInfo struct { - field reflect.Value - needsArg bool - isPtr bool // true if field is a pointer to TextUnmarshaler -} - -// sieve separates known flags from unknown flags and positional arguments. -type sieve struct { - fields map[string]fieldInfo // flag name → field info - passthrough map[string]struct{} - remaining []string - positional []string - strict bool -} - -// Sift extracts known flags into target, returning unknown flags and positional args. -// passthroughWithArg lists unknown flags that take values (e.g., []string{"-o", "-L"}). -// -// Panics if target is not a pointer to struct or if any tagged field -// has a type other than string or bool. -func Sift(target any, args []string, passthroughWithArg []string) (remaining, positional []string, err error) { - s := &sieve{ - fields: make(map[string]fieldInfo), - passthrough: make(map[string]struct{}), - } - - s.extractFields(target) - - for _, p := range passthroughWithArg { - s.passthrough[p] = struct{}{} - } - - return s.parse(args) -} - -// Parse parses args into target, returning positional args. -// Returns error if unknown flags are encountered. -// -// Panics if target is not a pointer to struct or if any tagged field -// has a type other than string or bool. -func Parse(target any, args []string) (positional []string, err error) { - s := &sieve{ - fields: make(map[string]fieldInfo), - passthrough: make(map[string]struct{}), - strict: true, - } - - s.extractFields(target) - - _, positional, err = s.parse(args) - - return positional, err -} - -// Helper methods for cleaner append patterns. -func (s *sieve) addRemaining(args ...string) { s.remaining = append(s.remaining, args...) } -func (s *sieve) addPositional(args ...string) { s.positional = append(s.positional, args...) } - -// extractFields reads struct tags and stores field references. -// Panics if target is not a pointer to a struct. -func (s *sieve) extractFields(target any) { - v := reflect.ValueOf(target) - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { - panic(fmt.Sprintf("argsieve: target must be a pointer to struct, got %T", target)) - } - - s.extractFieldsFromValue(v.Elem()) -} - -// extractFieldsFromValue recursively extracts fields from a struct value, -// including fields from embedded structs. -func (s *sieve) extractFieldsFromValue(v reflect.Value) { - t := v.Type() - - for i := 0; i < t.NumField(); i++ { - fieldType := t.Field(i) - fieldValue := v.Field(i) - - // Recursively process embedded structs - if fieldType.Anonymous && fieldType.Type.Kind() == reflect.Struct { - s.extractFieldsFromValue(fieldValue) - continue - } - - short := fieldType.Tag.Get("short") - long := fieldType.Tag.Get("long") - - // Skip fields without tags - if short == "" && long == "" { - continue - } - - // Determine field type and whether it needs an argument - kind := fieldType.Type.Kind() - var info fieldInfo - - switch { - case kind == reflect.Bool: - info = fieldInfo{field: fieldValue, needsArg: false} - case kind == reflect.String: - info = fieldInfo{field: fieldValue, needsArg: true} - case kind == reflect.Ptr: - // Pointer to TextUnmarshaler - nil when flag absent, allocated when present - elemType := fieldType.Type.Elem() - if reflect.PointerTo(elemType).Implements(textUnmarshalerType) { - info = fieldInfo{field: fieldValue, needsArg: true, isPtr: true} - } else { - panic(fmt.Sprintf("argsieve: pointer field %s must point to type implementing encoding.TextUnmarshaler", - fieldType.Name)) - } - case fieldValue.CanAddr() && reflect.PointerTo(fieldType.Type).Implements(textUnmarshalerType): - // Field's pointer type implements encoding.TextUnmarshaler - info = fieldInfo{field: fieldValue, needsArg: true} - default: - panic(fmt.Sprintf("argsieve: field %s has unsupported type %s (must be string, bool, or implement encoding.TextUnmarshaler)", - fieldType.Name, fieldType.Type)) - } - - if short != "" { - s.fields[short] = info - } - - if long != "" { - s.fields[long] = info - } - } -} - -// setField assigns a value to a field based on its type. -// Returns an error if TextUnmarshaler.UnmarshalText fails. -func (s *sieve) setField(info fieldInfo, value string) error { - // Handle pointer fields - allocate and set - if info.isPtr { - elemType := info.field.Type().Elem() - newVal := reflect.New(elemType) - if tu, ok := newVal.Interface().(encoding.TextUnmarshaler); ok { - if err := tu.UnmarshalText([]byte(value)); err != nil { - return err - } - info.field.Set(newVal) - return nil - } - } - - // Try TextUnmarshaler for value types - if info.field.CanAddr() { - if tu, ok := info.field.Addr().Interface().(encoding.TextUnmarshaler); ok { - return tu.UnmarshalText([]byte(value)) - } - } - - // Fall back to built-in types - if info.needsArg { - info.field.SetString(value) - } else { - info.field.SetBool(true) - } - - return nil -} - -// handleLong processes --name or --name=value arguments. -func (s *sieve) handleLong(arg string, next func() (string, bool)) error { - name, eqValue, hasEquals := strings.Cut(arg[2:], "=") - - info, known := s.fields[name] - - // Unknown flag - reject in strict mode or check passthrough list - if !known { - if s.strict { - return fmt.Errorf("%w: unknown option --%s", ErrParse, name) - } - - _, isPassthrough := s.passthrough["--"+name] - - if isPassthrough && !hasEquals { - if value, ok := next(); ok { - s.addRemaining(arg, value) - - return nil - } - } - - s.addRemaining(arg) - - return nil - } - - // Known bool flag - if !info.needsArg { - return s.setField(info, "") - } - - // Known string flag with equals - if hasEquals { - if err := s.setField(info, eqValue); err != nil { - return fmt.Errorf("%w: invalid value for --%s: %v", ErrParse, name, err) - } - - return nil - } - - // Known string flag - needs argument from next arg - value, ok := next() - if !ok { - return fmt.Errorf("%w: missing value for --%s", ErrParse, name) - } - - if err := s.setField(info, value); err != nil { - return fmt.Errorf("%w: invalid value for --%s: %v", ErrParse, name, err) - } - - return nil -} - -// handleShort processes -x, -xvalue, or -xyz combined arguments. -func (s *sieve) handleShort(arg string, next func() (string, bool)) error { - flags := arg[1:] - - for j := 0; j < len(flags); j++ { - flag := string(flags[j]) - tail := flags[j+1:] - - info, known := s.fields[flag] - - // Handle unknown flag first (guard clause) - if !known { - if err := s.handleUnknownShort(flag, tail, next); err != nil { - return err - } - - if len(tail) > 0 { - return nil // tail consumed by passthrough - } - - continue - } - - // Known bool flag - if !info.needsArg { - if err := s.setField(info, ""); err != nil { - return err - } - - continue - } - - // Known string flag - value attached - if len(tail) > 0 { - if err := s.setField(info, tail); err != nil { - return fmt.Errorf("%w: invalid value for -%s: %v", ErrParse, flag, err) - } - - return nil - } - - // Known string flag - value in next arg - value, ok := next() - if !ok { - return fmt.Errorf("%w: missing value for -%s", ErrParse, flag) - } - - if err := s.setField(info, value); err != nil { - return fmt.Errorf("%w: invalid value for -%s: %v", ErrParse, flag, err) - } - - return nil - } - - return nil -} - -// handleUnknownShort handles unknown short flags, checking passthrough list. -func (s *sieve) handleUnknownShort(flag, tail string, next func() (string, bool)) error { - if s.strict { - return fmt.Errorf("%w: unknown option -%s", ErrParse, flag) - } - - prefixedFlag := "-" + flag - _, isPassthrough := s.passthrough[prefixedFlag] - - if isPassthrough { - if len(tail) > 0 { - s.addRemaining("-" + flag + tail) - - return nil - } - - if value, ok := next(); ok { - s.addRemaining(prefixedFlag, value) - - return nil - } - } - - s.addRemaining(prefixedFlag) - - return nil -} - -// parse separates args into known flags (bound to target), unknown flags, and positionals. -// Arguments after "--" are treated as positional (the "--" itself is not included). -func (s *sieve) parse(args []string) (remaining, positional []string, err error) { - next, stop := iter.Pull(slices.Values(args)) - defer stop() - - for arg, ok := next(); ok; arg, ok = next() { - switch { - case arg == "--": - // Drain remaining args as positional (don't pass "--" through) - for arg, ok := next(); ok; arg, ok = next() { - s.addPositional(arg) - } - - case strings.HasPrefix(arg, "--"): - if err := s.handleLong(arg, next); err != nil { - return nil, nil, err - } - - case strings.HasPrefix(arg, "-") && len(arg) > 1: - if err := s.handleShort(arg, next); err != nil { - return nil, nil, err - } - - default: - s.addPositional(arg) - } - } - - return s.remaining, s.positional, nil -} diff --git a/internal/argsieve/argsieve_test.go b/internal/argsieve/argsieve_test.go deleted file mode 100644 index 0288921..0000000 --- a/internal/argsieve/argsieve_test.go +++ /dev/null @@ -1,675 +0,0 @@ -package argsieve - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testFlags is a test struct covering all supported field types. -type testFlags struct { - Region string `short:"r" long:"region"` - Profile string `short:"p" long:"profile"` - Verbose bool `short:"v" long:"verbose"` - Debug bool `short:"d" long:"debug"` -} - -// testEmbeddedBase is embedded in testEmbedded. -type testEmbeddedBase struct { - Region string `short:"r" long:"region"` -} - -// testEmbedded tests embedded struct field extraction. -type testEmbedded struct { - testEmbeddedBase - Profile string `short:"p" long:"profile"` -} - -func TestSift(t *testing.T) { - t.Parallel() - - tests := map[string]struct { - args []string - passthroughWithArg []string - wantRemaining []string - wantPositional []string - wantRegion string - wantProfile string - wantVerbose bool - wantDebug bool - wantErr bool - }{ - // Short flags with separate value - "short flag with separate value": { - args: []string{"-r", "us-west-2"}, - wantRegion: "us-west-2", - }, - // Short flag with attached value - "short flag with attached value": { - args: []string{"-rus-west-2"}, - wantRegion: "us-west-2", - }, - // Short bool flag - "short bool flag": { - args: []string{"-v"}, - wantVerbose: true, - }, - // Short flag chaining bools - "short flag chaining bools": { - args: []string{"-vd"}, - wantVerbose: true, - wantDebug: true, - }, - // Short flag chain with value at end - "short flag chain with value at end": { - args: []string{"-vdrus-west-2"}, - wantVerbose: true, - wantDebug: true, - wantRegion: "us-west-2", - }, - // Long flag with separate value - "long flag with separate value": { - args: []string{"--region", "us-west-2"}, - wantRegion: "us-west-2", - }, - // Long flag with equals value - "long flag with equals value": { - args: []string{"--region=us-west-2"}, - wantRegion: "us-west-2", - }, - // Long bool flag - "long bool flag": { - args: []string{"--verbose"}, - wantVerbose: true, - }, - // Unknown short flag passed through - "unknown short flag passed through": { - args: []string{"-x", "foo"}, - wantRemaining: []string{"-x"}, - wantPositional: []string{"foo"}, - }, - // Unknown long flag passed through - "unknown long flag passed through": { - args: []string{"--unknown", "foo"}, - wantRemaining: []string{"--unknown"}, - wantPositional: []string{"foo"}, - }, - // Passthrough flag with value - "passthrough flag with value": { - args: []string{"-o", "StrictHostKeyChecking=no"}, - passthroughWithArg: []string{"-o"}, - wantRemaining: []string{"-o", "StrictHostKeyChecking=no"}, - }, - // Passthrough flag with attached value - "passthrough flag with attached value": { - args: []string{"-oStrictHostKeyChecking=no"}, - passthroughWithArg: []string{"-o"}, - wantRemaining: []string{"-oStrictHostKeyChecking=no"}, - }, - // Passthrough long flag with value - "passthrough long flag with value": { - args: []string{"--option", "value"}, - passthroughWithArg: []string{"--option"}, - wantRemaining: []string{"--option", "value"}, - }, - // Positional only - "positional only": { - args: []string{"host1", "host2"}, - wantPositional: []string{"host1", "host2"}, - }, - // Mixed flags and positional - "mixed flags and positional": { - args: []string{"-r", "us-west-2", "host"}, - wantRegion: "us-west-2", - wantPositional: []string{"host"}, - }, - // Double dash terminator - "double dash terminator": { - args: []string{"-v", "--", "-r", "us-west-2"}, - wantVerbose: true, - wantPositional: []string{"-r", "us-west-2"}, - }, - // Empty args - "empty args": { - args: []string{}, - }, - // Single dash is positional - "single dash is positional": { - args: []string{"-"}, - wantPositional: []string{"-"}, - }, - // Multiple known and unknown mixed - "multiple known and unknown mixed": { - args: []string{"-v", "-x", "--region", "us-east-1", "--unknown", "host"}, - wantVerbose: true, - wantRegion: "us-east-1", - wantRemaining: []string{"-x", "--unknown"}, - wantPositional: []string{"host"}, - }, - // Missing value for short flag - "missing value for short flag": { - args: []string{"-r"}, - wantErr: true, - }, - // Missing value for long flag - "missing value for long flag": { - args: []string{"--region"}, - wantErr: true, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags testFlags - remaining, positional, err := Sift(&flags, tc.args, tc.passthroughWithArg) - - if tc.wantErr { - require.Error(t, err) - assert.ErrorIs(t, err, ErrParse) - return - } - - require.NoError(t, err) - assert.Equal(t, tc.wantRemaining, remaining, "remaining") - assert.Equal(t, tc.wantPositional, positional, "positional") - assert.Equal(t, tc.wantRegion, flags.Region, "region") - assert.Equal(t, tc.wantProfile, flags.Profile, "profile") - assert.Equal(t, tc.wantVerbose, flags.Verbose, "verbose") - assert.Equal(t, tc.wantDebug, flags.Debug, "debug") - }) - } -} - -func TestParse(t *testing.T) { - t.Parallel() - - tests := map[string]struct { - args []string - wantPositional []string - wantRegion string - wantProfile string - wantVerbose bool - wantErr bool - errContains string - }{ - "valid flags": { - args: []string{"--region", "us-west-2", "host"}, - wantRegion: "us-west-2", - wantPositional: []string{"host"}, - }, - "all flag types": { - args: []string{"-v", "-r", "us-west-2", "--profile", "myprofile", "host"}, - wantVerbose: true, - wantRegion: "us-west-2", - wantProfile: "myprofile", - wantPositional: []string{"host"}, - }, - "unknown short flag rejected": { - args: []string{"-x"}, - wantErr: true, - errContains: "unknown option -x", - }, - "unknown long flag rejected": { - args: []string{"--unknown"}, - wantErr: true, - errContains: "unknown option --unknown", - }, - "empty args": { - args: []string{}, - wantPositional: nil, - }, - "positional only": { - args: []string{"host1", "host2"}, - wantPositional: []string{"host1", "host2"}, - }, - "missing value rejected": { - args: []string{"--region"}, - wantErr: true, - errContains: "missing value for --region", - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags testFlags - positional, err := Parse(&flags, tc.args) - - if tc.wantErr { - require.Error(t, err) - assert.ErrorIs(t, err, ErrParse) - if tc.errContains != "" { - assert.Contains(t, err.Error(), tc.errContains) - } - return - } - - require.NoError(t, err) - assert.Equal(t, tc.wantPositional, positional) - assert.Equal(t, tc.wantRegion, flags.Region) - assert.Equal(t, tc.wantProfile, flags.Profile) - assert.Equal(t, tc.wantVerbose, flags.Verbose) - }) - } -} - -func TestSift_EmbeddedStruct(t *testing.T) { - t.Parallel() - - var flags testEmbedded - _, _, err := Sift(&flags, []string{"-r", "us-west-2", "-p", "myprofile"}, nil) - - require.NoError(t, err) - assert.Equal(t, "us-west-2", flags.Region) - assert.Equal(t, "myprofile", flags.Profile) -} - -func TestSift_PanicsOnInvalidTarget(t *testing.T) { - t.Parallel() - - tests := map[string]struct { - target any - }{ - "nil target": {target: nil}, - "non-pointer": {target: testFlags{}}, - "pointer to string": {target: new(string)}, - "pointer to int": {target: new(int)}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { - _, _, _ = Sift(tc.target, []string{}, nil) - }) - }) - } -} - -func TestSift_PanicsOnUnsupportedFieldType(t *testing.T) { - t.Parallel() - - type badStruct struct { - Count int `short:"c"` - } - - assert.Panics(t, func() { - var flags badStruct - _, _, _ = Sift(&flags, []string{}, nil) - }) -} - -func TestParse_PanicsOnInvalidTarget(t *testing.T) { - t.Parallel() - - tests := map[string]struct { - target any - }{ - "nil target": {target: nil}, - "non-pointer": {target: testFlags{}}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { - _, _ = Parse(tc.target, []string{}) - }) - }) - } -} - -func TestSift_LongFlagEqualsEmptyValue(t *testing.T) { - t.Parallel() - - var flags testFlags - _, _, err := Sift(&flags, []string{"--region="}, nil) - - require.NoError(t, err) - assert.Equal(t, "", flags.Region) -} - -func TestSift_ComplexPassthrough(t *testing.T) { - t.Parallel() - - var flags testFlags - remaining, positional, err := Sift(&flags, - []string{"-v", "-o", "opt1", "-L", "8080:localhost:80", "--region", "us-west-2", "host"}, - []string{"-o", "-L"}, - ) - - require.NoError(t, err) - assert.Equal(t, []string{"-o", "opt1", "-L", "8080:localhost:80"}, remaining) - assert.Equal(t, []string{"host"}, positional) - assert.True(t, flags.Verbose) - assert.Equal(t, "us-west-2", flags.Region) -} - -// logLevel implements encoding.TextUnmarshaler for testing custom type support. -type logLevel int - -const ( - logLevelInfo logLevel = iota - logLevelDebug - logLevelError -) - -func (l *logLevel) UnmarshalText(text []byte) error { - switch string(text) { - case "", "info": - *l = logLevelInfo - case "debug": - *l = logLevelDebug - case "error": - *l = logLevelError - default: - return assert.AnError - } - return nil -} - -// strictLevel is like logLevel but rejects empty strings (matches DstType/AddrType behavior) -type strictLevel int - -const ( - strictLevelLow strictLevel = iota - strictLevelHigh -) - -func (s *strictLevel) UnmarshalText(text []byte) error { - switch string(text) { - case "low": - *s = strictLevelLow - case "high": - *s = strictLevelHigh - default: - return fmt.Errorf("unknown strict level: %q", text) - } - return nil -} - -func TestSift_TextUnmarshaler(t *testing.T) { - t.Parallel() - - type customFlags struct { - Level logLevel `long:"level"` - Verbose bool `short:"v"` - } - - tests := map[string]struct { - args []string - wantLevel logLevel - wantErr bool - }{ - "valid debug": { - args: []string{"--level", "debug"}, - wantLevel: logLevelDebug, - }, - "valid error": { - args: []string{"--level", "error"}, - wantLevel: logLevelError, - }, - "valid info": { - args: []string{"--level", "info"}, - wantLevel: logLevelInfo, - }, - "empty defaults to info": { - args: []string{"--level", ""}, - wantLevel: logLevelInfo, - }, - "invalid value": { - args: []string{"--level", "invalid"}, - wantErr: true, - }, - "with equals": { - args: []string{"--level=debug"}, - wantLevel: logLevelDebug, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags customFlags - _, _, err := Sift(&flags, tc.args, nil) - - if tc.wantErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - assert.Equal(t, tc.wantLevel, flags.Level) - }) - } -} - -func TestParse_TextUnmarshaler(t *testing.T) { - t.Parallel() - - type customFlags struct { - Level logLevel `short:"l" long:"level"` - } - - tests := map[string]struct { - args []string - wantLevel logLevel - wantErr bool - }{ - "short flag with value": { - args: []string{"-l", "debug"}, - wantLevel: logLevelDebug, - }, - "short flag attached value": { - args: []string{"-ldebug"}, - wantLevel: logLevelDebug, - }, - "invalid short": { - args: []string{"-l", "invalid"}, - wantErr: true, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags customFlags - _, err := Parse(&flags, tc.args) - - if tc.wantErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - assert.Equal(t, tc.wantLevel, flags.Level) - }) - } -} - -func TestSift_PointerToTextUnmarshaler(t *testing.T) { - t.Parallel() - - type pointerFlags struct { - Level *logLevel `long:"level"` - Name string `long:"name"` - } - - tests := map[string]struct { - args []string - wantLevel *logLevel - wantNil bool - }{ - "flag absent - nil": { - args: []string{"--name", "test"}, - wantNil: true, - }, - "flag present - allocated": { - args: []string{"--level", "debug"}, - wantLevel: ptrTo(logLevelDebug), - }, - "flag with equals": { - args: []string{"--level=error"}, - wantLevel: ptrTo(logLevelError), - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags pointerFlags - _, _, err := Sift(&flags, tc.args, nil) - require.NoError(t, err) - - if tc.wantNil { - assert.Nil(t, flags.Level) - } else { - require.NotNil(t, flags.Level) - assert.Equal(t, *tc.wantLevel, *flags.Level) - } - }) - } -} - -func TestParse_PointerToTextUnmarshaler(t *testing.T) { - t.Parallel() - - type pointerFlags struct { - Level *logLevel `short:"l" long:"level"` - } - - tests := map[string]struct { - args []string - wantLevel *logLevel - wantNil bool - wantErr bool - }{ - "short flag with value": { - args: []string{"-l", "debug"}, - wantLevel: ptrTo(logLevelDebug), - }, - "short flag attached": { - args: []string{"-lerror"}, - wantLevel: ptrTo(logLevelError), - }, - "absent is nil": { - args: []string{}, - wantNil: true, - }, - "invalid value": { - args: []string{"--level", "invalid"}, - wantErr: true, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags pointerFlags - _, err := Parse(&flags, tc.args) - - if tc.wantErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - if tc.wantNil { - assert.Nil(t, flags.Level) - } else { - require.NotNil(t, flags.Level) - assert.Equal(t, *tc.wantLevel, *flags.Level) - } - }) - } -} - -// ptrTo returns a pointer to the value (generic helper for tests). -func ptrTo[T any](v T) *T { - return &v -} - -func TestSift_PanicsOnUnsupportedPointerType(t *testing.T) { - t.Parallel() - - // *string does not implement encoding.TextUnmarshaler - type badStruct struct { - Value *string `long:"value"` - } - - assert.Panics(t, func() { - var flags badStruct - _, _, _ = Sift(&flags, []string{}, nil) - }) -} - -func TestParse_PanicsOnUnsupportedPointerType(t *testing.T) { - t.Parallel() - - // *int does not implement encoding.TextUnmarshaler - type badStruct struct { - Count *int `long:"count"` - } - - assert.Panics(t, func() { - var flags badStruct - _, _ = Parse(&flags, []string{}) - }) -} - -// TestParse_PointerEmptyStringRejection tests that types rejecting empty strings -// (like DstType and AddrType) correctly propagate errors through pointer fields. -func TestParse_PointerEmptyStringRejection(t *testing.T) { - t.Parallel() - - type strictFlags struct { - Level *strictLevel `long:"level"` - } - - tests := map[string]struct { - args []string - wantErr bool - }{ - "valid value": { - args: []string{"--level", "high"}, - wantErr: false, - }, - "empty string with equals": { - args: []string{"--level="}, - wantErr: true, - }, - "empty string separate": { - args: []string{"--level", ""}, - wantErr: true, - }, - "absent is nil not error": { - args: []string{}, - wantErr: false, - }, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - t.Parallel() - - var flags strictFlags - _, err := Parse(&flags, tc.args) - - if tc.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), "unknown strict level") - } else { - require.NoError(t, err) - } - }) - } -}