diff --git a/src/defopt.py b/src/defopt.py index 8855c05..16f5254 100644 --- a/src/defopt.py +++ b/src/defopt.py @@ -243,6 +243,7 @@ def run( argparse_kwargs: dict = {}, intermixed: bool = False, argv: Optional[List[str]] = None, + defaults: Optional[Dict[str, Any]] = None ): """ Process command-line arguments and run the given functions. @@ -302,6 +303,10 @@ def run( listed in the `argparse` documentation. :param argv: Command line arguments to parse (default: ``sys.argv[1:]``). + :param defaults: + Mapping for argument defaults passed to + `~argparse.ArgumentParser.set_defaults`. Key must be the command name, + value is a mapping of argument name to default value. :return: The value returned by the function that was run. """ @@ -309,13 +314,15 @@ def run( funcs, parsers=parsers, short=short, cli_options=cli_options, show_defaults=show_defaults, show_types=show_types, no_negated_flags=no_negated_flags, version=version, - argparse_kwargs=argparse_kwargs, intermixed=intermixed, argv=argv)() + argparse_kwargs=argparse_kwargs, intermixed=intermixed, argv=argv, + defaults=defaults)() _DefoptOptions = namedtuple( '_DefoptOptions', ['parsers', 'short', 'cli_options', 'show_defaults', 'show_types', - 'no_negated_flags', 'version', 'argparse_kwargs', 'intermixed', 'argv']) + 'no_negated_flags', 'version', 'argparse_kwargs', 'intermixed', 'argv', + 'defaults']) def _options(**kwargs): @@ -359,12 +366,17 @@ def _create_parser(funcs, opts): **opts.argparse_kwargs}) version_sources = [] if callable(funcs): - _populate_parser(funcs, parser, opts) + _populate_parser(funcs, parser, opts, opts.defaults) version_sources.append(funcs) else: subparsers = parser.add_subparsers() for func, subparser in _recurse_functions(funcs, subparsers): - _populate_parser(func, subparser, opts) + # FIXME: need fully qualitified name for nested commands... + defaults = None if opts.defaults is None else opts.defaults.get( + func.__name__, + opts.defaults.get(func.__name__.replace('_', '-'), None) + ) + _populate_parser(func, subparser, opts, defaults) version_sources.append(func) if isinstance(opts.version, str): version_string = opts.version @@ -563,7 +575,7 @@ def _get_type_from_hint(hint): return hint -def _populate_parser(func, parser, opts): +def _populate_parser(func, parser, opts, defaults): sig = signature(func) parser.description = sig.doc @@ -605,6 +617,11 @@ def _populate_parser(func, parser, opts): raise ValueError(f'no type found for parameter {name}') hasdefault = param.default is not param.empty default = param.default if hasdefault else SUPPRESS + + if not hasdefault and defaults is not None and name in defaults: + hasdefault = True + default = defaults[name] + required = not hasdefault and param.kind != param.VAR_POSITIONAL positional = name in positionals