diff --git a/.golangci.yaml b/.golangci.yaml index 333b1851..d3fab5d1 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -50,11 +50,11 @@ linters: - $test allow: - $gostd - - github.com/golang/mock/gomock - github.com/openfga/api/proto - github.com/openfga/cli - github.com/openfga/go-sdk - github.com/openfga/openfga + - github.com/spf13/cobra - github.com/stretchr - go.uber.org/mock/gomock funlen: diff --git a/cmd/model/get.go b/cmd/model/get.go index a080e2ba..94262290 100644 --- a/cmd/model/get.go +++ b/cmd/model/get.go @@ -24,6 +24,7 @@ import ( "github.com/openfga/cli/internal/authorizationmodel" "github.com/openfga/cli/internal/cmdutils" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" ) @@ -77,8 +78,8 @@ func init() { getCmd.Flags().StringArray("field", []string{"model"}, "Fields to display, choices are: id, created_at and model") //nolint:lll getCmd.Flags().Var(&getOutputFormat, "format", `Authorization model output format. Can be "fga" or "json"`) - if err := getCmd.MarkFlagRequired("store-id"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/get", err) + if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/models/get", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/model/list.go b/cmd/model/list.go index 3481ffb7..66114847 100644 --- a/cmd/model/list.go +++ b/cmd/model/list.go @@ -27,6 +27,7 @@ import ( "github.com/openfga/cli/internal/authorizationmodel" "github.com/openfga/cli/internal/cmdutils" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" ) @@ -112,8 +113,8 @@ func init() { listCmd.Flags().String("store-id", "", "Store ID") listCmd.Flags().StringArray("field", []string{"id", "created_at"}, "Fields to display, choices are: id, created_at and model") //nolint:lll - if err := listCmd.MarkFlagRequired("store-id"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/list", err) + if err := flags.SetFlagRequired(listCmd, "store-id", "cmd/models/list", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/model/test.go b/cmd/model/test.go index a3d97bc4..9d7f35ba 100644 --- a/cmd/model/test.go +++ b/cmd/model/test.go @@ -24,6 +24,7 @@ import ( "github.com/spf13/cobra" "github.com/openfga/cli/internal/cmdutils" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" "github.com/openfga/cli/internal/storetest" ) @@ -99,8 +100,8 @@ func init() { testCmd.Flags().Bool("verbose", false, "Print verbose JSON output") testCmd.Flags().Bool("suppress-summary", false, "Suppress the plain text summary output") - if err := testCmd.MarkFlagRequired("tests"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/test", err) + if err := flags.SetFlagRequired(testCmd, "tests", "cmd/models/test", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/model/write.go b/cmd/model/write.go index 23e7e682..b21b4fa2 100644 --- a/cmd/model/write.go +++ b/cmd/model/write.go @@ -27,6 +27,7 @@ import ( "github.com/openfga/cli/internal/authorizationmodel" "github.com/openfga/cli/internal/cmdutils" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" "github.com/openfga/cli/internal/utils" ) @@ -106,8 +107,8 @@ func init() { writeCmd.Flags().String("file", "", "File Name. The file should have the model in the JSON or DSL format") writeCmd.Flags().Var(&writeInputFormat, "format", `Authorization model input format. Can be "fga", "json", or "modular"`) //nolint:lll - if err := writeCmd.MarkFlagRequired("store-id"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err) + if err := flags.SetFlagRequired(writeCmd, "store-id", "cmd/model/write", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/query/list-users.go b/cmd/query/list-users.go index a13284d2..55e0cef1 100644 --- a/cmd/query/list-users.go +++ b/cmd/query/list-users.go @@ -27,6 +27,7 @@ import ( "github.com/spf13/cobra" "github.com/openfga/cli/internal/cmdutils" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" ) @@ -138,18 +139,11 @@ func init() { listUsersCmd.Flags().String("relation", "", "Relation to evaluate on") listUsersCmd.Flags().String("user-filter", "", "Filter the responses can be in the formats (to filter objects and typed public bound access) or # (to filter usersets)") //nolint:lll - if err := listUsersCmd.MarkFlagRequired("object"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err) - os.Exit(1) - } - - if err := listUsersCmd.MarkFlagRequired("relation"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err) - os.Exit(1) - } - - if err := listUsersCmd.MarkFlagRequired("user-filter"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err) + if err := flags.SetFlagsRequired( + listUsersCmd, + []string{"object", "relation", "user-filter"}, + "cmd/query/list-users", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/query/query.go b/cmd/query/query.go index 7df4834b..38987a52 100644 --- a/cmd/query/query.go +++ b/cmd/query/query.go @@ -22,6 +22,8 @@ import ( "os" "github.com/spf13/cobra" + + "github.com/openfga/cli/internal/flags" ) // QueryCmd represents the query command. @@ -48,9 +50,8 @@ func init() { "Consistency preference for the request. Valid options are HIGHER_CONSISTENCY and MINIMIZE_LATENCY.", ) - err := QueryCmd.MarkPersistentFlagRequired("store-id") - if err != nil { - fmt.Print(err) + if err := flags.SetFlagRequired(QueryCmd, "store-id", "cmd/query/query", true); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/store/delete.go b/cmd/store/delete.go index 40b0885d..3580f5ed 100644 --- a/cmd/store/delete.go +++ b/cmd/store/delete.go @@ -25,6 +25,7 @@ import ( "github.com/openfga/cli/internal/cmdutils" "github.com/openfga/cli/internal/confirmation" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" ) @@ -72,9 +73,8 @@ func init() { deleteCmd.Flags().String("store-id", "", "Store ID") deleteCmd.Flags().Bool("force", false, "Force delete without confirmation") - err := deleteCmd.MarkFlagRequired("store-id") - if err != nil { - fmt.Print(err) + if err := flags.SetFlagRequired(deleteCmd, "store-id", "cmd/store/delete", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/store/export.go b/cmd/store/export.go index 9217bf1b..201b3e81 100644 --- a/cmd/store/export.go +++ b/cmd/store/export.go @@ -30,6 +30,7 @@ import ( "github.com/openfga/cli/internal/cmdutils" "github.com/openfga/cli/internal/confirmation" "github.com/openfga/cli/internal/fga" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" "github.com/openfga/cli/internal/storetest" "github.com/openfga/cli/internal/tuple" @@ -193,9 +194,8 @@ func init() { exportCmd.Flags().String("model-id", "", "Authorization Model ID") exportCmd.Flags().Uint("max-tuples", defaultMaxTupleCount, "max number of tuples to return in the output") - err := exportCmd.MarkFlagRequired("store-id") - if err != nil { - fmt.Print(err) + if err := flags.SetFlagRequired(exportCmd, "store-id", "cmd/store/export", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/store/get.go b/cmd/store/get.go index dbe3a89e..aa4ecbcc 100644 --- a/cmd/store/get.go +++ b/cmd/store/get.go @@ -26,6 +26,7 @@ import ( "github.com/openfga/cli/internal/cmdutils" "github.com/openfga/cli/internal/fga" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" ) @@ -65,9 +66,8 @@ var getCmd = &cobra.Command{ func init() { getCmd.Flags().String("store-id", "", "Store ID") - err := getCmd.MarkFlagRequired("store-id") - if err != nil { - fmt.Print(err) + if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/store/get", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/store/import.go b/cmd/store/import.go index b334c725..675dc2ac 100644 --- a/cmd/store/import.go +++ b/cmd/store/import.go @@ -34,6 +34,7 @@ import ( "github.com/openfga/cli/internal/authorizationmodel" "github.com/openfga/cli/internal/cmdutils" "github.com/openfga/cli/internal/fga" + "github.com/openfga/cli/internal/flags" "github.com/openfga/cli/internal/output" "github.com/openfga/cli/internal/storetest" "github.com/openfga/cli/internal/tuple" @@ -339,8 +340,8 @@ func init() { importCmd.Flags().Int("max-tuples-per-write", tuple.MaxTuplesPerWrite, "Max tuples per write chunk.") importCmd.Flags().Int("max-parallel-requests", tuple.MaxParallelRequests, "Max number of requests to issue to the server in parallel.") //nolint:lll - if err := importCmd.MarkFlagRequired("file"); err != nil { - fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err) + if err := flags.SetFlagRequired(importCmd, "file", "cmd/store/import", false); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/tuple/tuple.go b/cmd/tuple/tuple.go index d5af2254..358a6783 100644 --- a/cmd/tuple/tuple.go +++ b/cmd/tuple/tuple.go @@ -22,6 +22,8 @@ import ( "os" "github.com/spf13/cobra" + + "github.com/openfga/cli/internal/flags" ) // TupleCmd represents the tuple command. @@ -39,9 +41,8 @@ func init() { TupleCmd.PersistentFlags().String("store-id", "", "Store ID") - err := TupleCmd.MarkPersistentFlagRequired("store-id") - if err != nil { //nolint:wsl - fmt.Print(err) + if err := flags.SetFlagRequired(TupleCmd, "store-id", "cmd/tuple/tuple", true); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/internal/flags/flags.go b/internal/flags/flags.go new file mode 100644 index 00000000..28fff562 --- /dev/null +++ b/internal/flags/flags.go @@ -0,0 +1,101 @@ +// Package flags provides utility functions for working with cobra command flags. +// It simplifies the process of marking flags as required and handling related errors. +package flags + +import ( + "errors" + "fmt" + "strings" + + "github.com/spf13/cobra" +) + +var ( + // ErrFlagRequired is returned when a flag cannot be marked as required. + ErrFlagRequired = errors.New("error setting flag as required") + + // ErrInvalidInput is returned when invalid input is provided. + ErrInvalidInput = errors.New("invalid input") +) + +// buildFlagRequiredError creates a consistent error message for flag requirement failures. +// It wraps the original error with context about which flag and location failed. +func buildFlagRequiredError(flag, location string, err error) error { + if err == nil { + return nil + } + + return fmt.Errorf("%w - (flag: %s, file: %s): %v", ErrFlagRequired, flag, location, err) +} + +// SetFlagRequired marks a single flag as required for a cobra command. +// +// Parameters: +// - cmd: The cobra command to modify +// - flag: The name of the flag to mark as required +// - location: A string identifying the calling location (for error context) +// - isPersistent: If true, marks the persistent flag as required; otherwise marks the regular flag +// +// Returns an error if: +// - cmd is nil +// - flag is empty +// - the flag cannot be marked as required (e.g., flag doesn't exist) +func SetFlagRequired(cmd *cobra.Command, flag string, location string, isPersistent bool) error { + if cmd == nil { + return fmt.Errorf("%w: command cannot be nil", ErrInvalidInput) + } + + if strings.TrimSpace(flag) == "" { + return fmt.Errorf("%w: flag name cannot be empty", ErrInvalidInput) + } + + if isPersistent { + if err := cmd.MarkPersistentFlagRequired(flag); err != nil { + return buildFlagRequiredError(flag, location, err) + } + } else { + if err := cmd.MarkFlagRequired(flag); err != nil { + return buildFlagRequiredError(flag, location, err) + } + } + + return nil +} + +// SetFlagsRequired marks multiple flags as required for a cobra command. +// +// Parameters: +// - cmd: The cobra command to modify +// - flags: A slice of flag names to mark as required +// - location: A string identifying the calling location (for error context) +// - isPersistent: If true, marks the persistent flags as required; otherwise marks the regular flags +// +// Returns a joined error containing all individual flag requirement failures. +// If no flags are provided or all succeed, returns nil. +// +// Note: This function continues processing all flags even if some fail, +// allowing you to see all failures at once rather than stopping at the first error. +func SetFlagsRequired(cmd *cobra.Command, flags []string, location string, isPersistent bool) error { + if cmd == nil { + return fmt.Errorf("%w: command cannot be nil", ErrInvalidInput) + } + + if len(flags) == 0 { + return nil + } + + // Pre-allocate slice with exact capacity needed + flagErrors := make([]error, 0, len(flags)) + + for _, flag := range flags { + if err := SetFlagRequired(cmd, flag, location, isPersistent); err != nil { + flagErrors = append(flagErrors, err) + } + } + + if len(flagErrors) > 0 { + return errors.Join(flagErrors...) + } + + return nil +} diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go new file mode 100644 index 00000000..41410015 --- /dev/null +++ b/internal/flags/flags_test.go @@ -0,0 +1,226 @@ +package flags + +import ( + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetFlagRequired_NonPersistent_Success(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("foo", "", "foo flag") + + err := SetFlagRequired(cmd, "foo", "TestLocation", false) + + assert.NoError(t, err) +} + +func TestSetFlagRequired_Persistent_Success(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.PersistentFlags().String("bar", "", "bar flag") + + err := SetFlagRequired(cmd, "bar", "TestLocation", true) + + assert.NoError(t, err) +} + +func TestSetFlagRequired_NonPersistent_FlagNotFound(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + err := SetFlagRequired(cmd, "missing", "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "error setting flag as required - (flag: missing, file: TestLocation):") +} + +func TestSetFlagRequired_Persistent_FlagNotFound(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + err := SetFlagRequired(cmd, "missing", "TestLocation", true) + + require.Error(t, err) + assert.Contains(t, err.Error(), "error setting flag as required - (flag: missing, file: TestLocation):") +} + +func TestSetFlagRequired_ErrorMessageFormat(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + flagName := "nonexistent" + location := "SomeFunction" + + err := SetFlagRequired(cmd, flagName, location, false) + + require.Error(t, err) + + expectedPrefix := fmt.Sprintf("error setting flag as required - (flag: %s, file: %s):", flagName, location) + assert.Contains(t, err.Error(), expectedPrefix) +} + +func TestSetFlagsRequired_AllSuccess(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("foo", "", "foo flag") + cmd.Flags().String("bar", "", "bar flag") + cmd.Flags().String("baz", "", "baz flag") + + flags := []string{"foo", "bar", "baz"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + assert.NoError(t, err) +} + +func TestSetFlagsRequired_PersistentAllSuccess(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.PersistentFlags().String("foo", "", "foo flag") + cmd.PersistentFlags().String("bar", "", "bar flag") + + flags := []string{"foo", "bar"} + err := SetFlagsRequired(cmd, flags, "TestLocation", true) + + assert.NoError(t, err) +} + +func TestSetFlagsRequired_SomeSuccess_SomeFail(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("foo", "", "foo flag") + // "missing" flag is not defined + + flags := []string{"foo", "missing"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "error setting flag as required - (flag: missing, file: TestLocation):") + // The error should not contain "foo" since that one succeeded + assert.NotContains(t, err.Error(), "error setting flag as required - (flag: foo, file: TestLocation):") +} + +func TestSetFlagsRequired_AllFail(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + flags := []string{"missing1", "missing2"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "error setting flag as required - (flag: missing1, file: TestLocation):") + assert.Contains(t, err.Error(), "error setting flag as required - (flag: missing2, file: TestLocation):") +} + +func TestSetFlagsRequired_EmptySlice(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + err := SetFlagsRequired(cmd, []string{}, "TestLocation", false) + + assert.NoError(t, err) +} + +func TestSetFlagsRequired_NilSlice(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + err := SetFlagsRequired(cmd, nil, "TestLocation", false) + + assert.NoError(t, err) +} + +func TestSetFlagsRequired_ErrorJoining(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + + flags := []string{"missing1", "missing2", "missing3"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + require.Error(t, err) + + for _, flag := range flags { + expectedError := fmt.Sprintf("error setting flag as required - (flag: %s, file: TestLocation):", flag) + assert.Contains(t, err.Error(), expectedError) + } +} + +func TestSetFlagsRequired_MixedPersistentAndNonPersistent(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("regular", "", "regular flag") + cmd.PersistentFlags().String("persistent", "", "persistent flag") + + flags := []string{"regular", "persistent"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + // The regular flag should succeed, but persistent flag should fail + require.Error(t, err) + assert.Contains(t, err.Error(), "error setting flag as required - (flag: persistent, file: TestLocation):") + assert.NotContains(t, err.Error(), "error setting flag as required - (flag: regular, file: TestLocation):") +} + +func TestSetFlagRequired_NilCommand(t *testing.T) { + t.Parallel() + + err := SetFlagRequired(nil, "foo", "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid input: command cannot be nil") +} + +func TestSetFlagRequired_EmptyFlag(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + err := SetFlagRequired(cmd, "", "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid input: flag name cannot be empty") +} + +func TestSetFlagRequired_WhitespaceFlag(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + err := SetFlagRequired(cmd, " ", "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid input: flag name cannot be empty") +} + +func TestSetFlagsRequired_NilCommand(t *testing.T) { + t.Parallel() + + err := SetFlagsRequired(nil, []string{"foo"}, "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid input: command cannot be nil") +} + +func TestSetFlagsRequired_EmptyFlagInSlice(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "test"} + flags := []string{"valid", "", "also-valid"} + err := SetFlagsRequired(cmd, flags, "TestLocation", false) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid input: flag name cannot be empty") +}