Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ optional arguments:
we can't generate Argparser for them.
- You can't have a function argument named `__command`.
- If you don't like the generated parser, you can modify it using `override` function.
- If you'd like to customize parser generation process:
- Make your own `ArgumentParser` generator by subclassing `ArgparserGenerator`
- Activate it by calling `set_default_generator(MyArgparserGenerator)`


## Alternatives
Expand Down
202 changes: 157 additions & 45 deletions func_argparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
import typing
from argparse import ArgumentParser
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)

AnyCallable = Callable[..., Any]
_GenericAlias = type(Union[int, str])
Expand Down Expand Up @@ -246,69 +258,161 @@ def _get_parser(t: Parser, flags: List[str]) -> Parser:
return t


def func_argparser(
fn: AnyCallable, parser: Optional[ArgumentParser] = None
) -> ArgumentParser:
"""Creates an ArgumentParser for the given function."""
if not parser:
parser = ArgumentParser(description=get_fn_description(fn))
parser.set_defaults(**{COMMAND_KEY: fn})
class ArgumentSpec:
"""A class defining an argument to be added to an ArgumentParser."""

def __init__(self, *flags: str, **kwargs: Any):
self.flags = flags
self.kwargs = kwargs

def add_to_parser(self, parser: ArgumentParser) -> None:
parser.add_argument(*self.flags, **self.kwargs)


class ArgparserGenerator:
"""A class for generating an ArgumentParser from a function."""

ArgParser = ArgumentParser

def generate_argparser(
self, fn: AnyCallable, parser: Optional[ArgumentParser] = None
) -> ArgumentParser:
"""Creates an ArgumentParser for the given function."""
arg_specs = self._get_arg_specs(fn)
if not parser:
parser = self._create_parser(fn)
parser.set_defaults(**{COMMAND_KEY: fn})
self._add_arguments_to_parser(parser, arg_specs)
return parser

def _get_arg_specs(self, fn: AnyCallable) -> List[ArgumentSpec]:
spec = inspect.getfullargspec(fn)
args = self._get_arg_names(fn)

if spec.defaults:
defaults = dict(zip(reversed(args), reversed(spec.defaults)))
else:
defaults = {}
args_desc = _get_arguments_description(fn, spec, defaults)

prefixes = self._resolve_prefixes(fn)

# Get all arguments to add
arg_specs = []
for a, t in self._get_args_and_annotations(fn):
doc = args_desc.get(a)
for arg_spec in self._gen_param_arguments(
a, t, doc, defaults.get(a), a in defaults, prefixes.get(a)
):
arg_specs.append(arg_spec)

return arg_specs

def _gen_param_arguments(
self,
arg_name: str,
arg_type: Type[Any],
doc: Optional[str],
default: Any,
has_default: bool,
prefix: Optional[str],
) -> Iterator[ArgumentSpec]:

a = arg_name
t = arg_type

spec = inspect.getfullargspec(fn)
args = spec.args
if isinstance(fn, type):
# Ignore `self` from `__init__` method.
args = args[1:]
for a in args:
assert a in spec.annotations, f"Need a type annotation for argument {a} of {fn}"

if spec.defaults:
defaults = dict(zip(reversed(args), reversed(spec.defaults)))
else:
defaults = {}
args_desc = _get_arguments_description(fn, spec, defaults)

# One letter arguments are given the short flags.
prefixes: Set[str] = set(a for a in args if len(a) == 1)
# -h is always for help.
prefixes.add("h")
for a, t in spec.annotations.items():
if a == "return":
continue
doc = args_desc.get(a)
flags = [f"--{a}"]
if len(a) == 1 or a[0] not in prefixes:
flags.insert(0, f"-{a[0]}")
prefixes.add(a[0])
if prefix is not None:
flags.insert(0, f"-{prefix}")

if t is bool:
d = defaults.get(a, False)
parser.add_argument(*flags, default=d, action="store_true", help=doc)
if default is None:
default = False
yield ArgumentSpec(*flags, default=default, action="store_true", help=doc)
# The --no flags are hidden
parser.add_argument(
yield ArgumentSpec(
f"--no-{a}", dest=a, action="store_false", help=argparse.SUPPRESS
)
continue
return

if _is_option_type(t):
if a not in defaults:
defaults[a] = None
required = not has_default
if required and _is_option_type(t):
required = False
if not has_default:
default = None
has_default = True

action = "store"
t_contained = _get_list_contained_type(t)
if t_contained is not None:
action = "append"
t = t_contained

parser.add_argument(
*flags,
kwargs = dict(
type=_get_parser(t, flags),
action=action,
default=defaults.get(a),
required=a not in defaults,
default=default,
required=required,
help=doc,
)
return parser
yield ArgumentSpec(*flags, **kwargs)

def _create_parser(self, fn: AnyCallable) -> ArgumentParser:
return self.ArgParser(description=get_fn_description(fn))

def _add_arguments_to_parser(
self, parser: ArgumentParser, arg_specs: List[ArgumentSpec]
) -> None:
for arg_spec in arg_specs:
arg_spec.add_to_parser(parser)

def _get_arg_names(self, fn: AnyCallable) -> List[str]:
spec = inspect.getfullargspec(fn)
args = spec.args
if isinstance(fn, type):
# Ignore `self` from `__init__` method.
args = args[1:]
return args

def _get_args_and_annotations(self, fn: AnyCallable) -> List[Tuple[str, Type[Any]]]:
spec = inspect.getfullargspec(fn)
args = spec.args
if isinstance(fn, type):
# Ignore `self` from `__init__` method.
args = args[1:]
for a in args:
assert (
a in spec.annotations
), f"Need a type annotation for argument {a} of {fn}"
arg_specs = []
for a, t in spec.annotations.items():
if a == "return":
continue
arg_specs.append((a, t))
return arg_specs

def _resolve_prefixes(self, fn: AnyCallable) -> Dict[str, str]:
# One letter arguments are given the short flags.
prefixes: Set[str] = set(a for a in self._get_arg_names(fn) if len(a) == 1)
# -h is always for help.
prefixes.add("h")
extra_prefixes = {}
for a in self._get_arg_names(fn):
if len(a) == 1 or a[0] not in prefixes:
extra_prefixes[a] = a[0]
prefixes.add(a[0])
return extra_prefixes


def func_argparser(
fn: AnyCallable,
parser: Optional[ArgumentParser] = None,
argparser_generator: Optional[ArgparserGenerator] = None,
) -> ArgumentParser:
"""Creates an ArgumentParser for the given function."""
if argparser_generator is None:
argparser_generator = _DEFAULT_GENERATOR_CLASS()
return argparser_generator.generate_argparser(fn, parser)


def override(
Expand Down Expand Up @@ -358,3 +462,11 @@ def override(
action.help = None
if metavar is not None:
action.metavar = metavar


_DEFAULT_GENERATOR_CLASS = ArgparserGenerator


def set_default_generator(generator_class: Type[Any]) -> None:
global _DEFAULT_GENERATOR_CLASS
_DEFAULT_GENERATOR_CLASS = generator_class