Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,30 @@ func (c *Command) ValidateArgs(args []string) error {
return c.Args(c, args)
}

// RequiredFlagError represents a failure to validate required flags.
type RequiredFlagError struct {
Err error
}

// Error satisfies the error interface.
func (r *RequiredFlagError) Error() string {
return r.Err.Error()
}

// Is satisfies the Is error interface.
func (r *RequiredFlagError) Is(target error) bool {
err, ok := target.(*RequiredFlagError)
if !ok {
return false
}
return r.Err == err
}

// Unwrap satisfies Unwrap error interface.
func (r *RequiredFlagError) Unwrap() error {
return r.Err
}
Comment on lines +1179 to +1201
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be overkill.

I'm doubtful about the implementation.

I'm not sure you need a structure, especially while the type is a simple wrapper here.

If you had stored the fields name, I would have understood, but here you don't.

Also I will have double check about the need of Is method. Usually, adding Unwrap is enough.

I will come back to you later with feedbacks and maybe another implementation


// ValidateRequiredFlags validates all required flags are present and returns an error otherwise
func (c *Command) ValidateRequiredFlags() error {
if c.DisableFlagParsing {
Expand All @@ -1195,7 +1219,9 @@ func (c *Command) ValidateRequiredFlags() error {
})

if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
return &RequiredFlagError{
Err: fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)),
}
}
return nil
}
Expand Down
12 changes: 12 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -2952,3 +2953,14 @@ func TestHelpFuncExecuted(t *testing.T) {

checkStringContains(t, output, helpText)
}

func TestValidateRequiredFlags(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
c.Flags().BoolP("boola", "a", false, "a boolean flag")
c.MarkFlagRequired("boola")
if err := c.ValidateRequiredFlags(); !errors.Is(err, &RequiredFlagError{
Err: errors.New("required flag(s) \"boola\" not set"),
}) {
t.Fatalf("Expected error: %v, got: %v", "boola", err)
}
}