diff --git a/flag_groups.go b/flag_groups.go index 560612fd3..ebfd60757 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -83,31 +83,46 @@ func (c *Command) ValidateFlagGroups() error { return nil } - flags := c.Flags() - - // groupStatus format is the list of flags as a unique ID, - // then a map of each flag name and whether it is set or not. - groupStatus := map[string]map[string]bool{} - oneRequiredGroupStatus := map[string]map[string]bool{} - mutuallyExclusiveGroupStatus := map[string]map[string]bool{} - flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - }) + statuses := c.getFlagGroupStatuses() - if err := validateRequiredFlagGroups(groupStatus); err != nil { + if err := validateRequiredFlagGroups(statuses.Required); err != nil { return err } - if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { + if err := validateOneRequiredFlagGroups(statuses.OneRequired); err != nil { return err } - if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { + if err := validateExclusiveFlagGroups(statuses.MutuallyExclusive); err != nil { return err } return nil } +type flagGroupStatuses struct { + Required map[string]map[string]bool + OneRequired map[string]map[string]bool + MutuallyExclusive map[string]map[string]bool +} + +// getFlagGroupStatuses collects the status of all flags belonging to any flag group. +func (c *Command) getFlagGroupStatuses() flagGroupStatuses { + flags := c.Flags() + required := map[string]map[string]bool{} + oneRequired := map[string]map[string]bool{} + mutuallyExclusive := map[string]map[string]bool{} + + flags.VisitAll(func(pflag *flag.Flag) { + processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, required) + processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequired) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusive) + }) + + return flagGroupStatuses{ + Required: required, + OneRequired: oneRequired, + MutuallyExclusive: mutuallyExclusive, + } +} + func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { for _, fname := range flagnames { f := fs.Lookup(fname) @@ -227,19 +242,11 @@ func (c *Command) enforceFlagGroupsForCompletion() { return } - flags := c.Flags() - groupStatus := map[string]map[string]bool{} - oneRequiredGroupStatus := map[string]map[string]bool{} - mutuallyExclusiveGroupStatus := map[string]map[string]bool{} - c.Flags().VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - }) + statuses := c.getFlagGroupStatuses() // If a flag that is part of a group is present, we make all the other flags // of that group required so that the shell completion suggests them automatically - for flagList, flagnameAndStatus := range groupStatus { + for flagList, flagnameAndStatus := range statuses.Required { for _, isSet := range flagnameAndStatus { if isSet { // One of the flags of the group is set, mark the other ones as required @@ -252,7 +259,7 @@ func (c *Command) enforceFlagGroupsForCompletion() { // If none of the flags of a one-required group are present, we make all the flags // of that group required so that the shell completion suggests them automatically - for flagList, flagnameAndStatus := range oneRequiredGroupStatus { + for flagList, flagnameAndStatus := range statuses.OneRequired { isSet := false for _, isSet = range flagnameAndStatus { @@ -272,7 +279,7 @@ func (c *Command) enforceFlagGroupsForCompletion() { // If a flag that is mutually exclusive to others is present, we hide the other // flags of that group so the shell completion does not suggest them - for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { + for flagList, flagnameAndStatus := range statuses.MutuallyExclusive { for flagName, isSet := range flagnameAndStatus { if isSet { // One of the flags of the mutually exclusive group is set, mark the other ones as hidden