diff --git a/trphysx/config/arg_parser.py b/trphysx/config/arg_parser.py index 7f26dc6..bdf79a6 100755 --- a/trphysx/config/arg_parser.py +++ b/trphysx/config/arg_parser.py @@ -8,143 +8,330 @@ ===== """ import dataclasses -from dataclasses import dataclass +import json import sys -from argparse import ArgumentParser +import types +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError +from copy import copy from enum import Enum +from inspect import isclass from pathlib import Path -from typing import Any, Iterable, List, Dict, Optional, Tuple, Union -from typing_extensions import Protocol +from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints +import yaml -class DataClass(Protocol): - # checking for this attribute is currently - # the most reliable way to ascertain that something is a dataclass - __dataclass_fields__: Dict + +DataClass = NewType("DataClass", Any) +DataClassType = NewType("DataClassType", Any) + + +# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse +def string_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise ArgumentTypeError( + f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." + ) + + +def make_choice_type_function(choices: list) -> Callable[[str], Any]: + """ + Creates a mapping function from each choices string representation to the actual value. Used to support multiple + value types for a single argument. + + Args: + choices (list): List of choices. + + Returns: + Callable[[str], Any]: Mapping function from string representation to actual value for each choice. + """ + str_to_choice = {str(choice): choice for choice in choices} + return lambda arg: str_to_choice.get(arg, arg) + + +def HfArg( + *, + aliases: Union[str, List[str]] = None, + help: str = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[[], Any] = dataclasses.MISSING, + metadata: dict = None, + **kwargs, +) -> dataclasses.Field: + """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`. + + Example comparing the use of `HfArg` and `dataclasses.field`: + ``` + @dataclass + class Args: + regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"}) + hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!") + ``` + + Args: + aliases (Union[str, List[str]], optional): + Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`. + Defaults to None. + help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None. + default (Any, optional): + Default value for the argument. If not default or default_factory is specified, the argument is required. + Defaults to dataclasses.MISSING. + default_factory (Callable[[], Any], optional): + The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide + default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`. + Defaults to dataclasses.MISSING. + metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None. + + Returns: + Field: A `dataclasses.Field` with the desired properties. + """ + if metadata is None: + # Important, don't use as default param in function signature because dict is mutable and shared across function calls + metadata = {} + if aliases is not None: + metadata["aliases"] = aliases + if help is not None: + metadata["help"] = help + + return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) class HfArgumentParser(ArgumentParser): """ This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. + The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) arguments to the parser after initialization and you'll get the output back after parsing as an additional - namespace. + namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. - Code originally from the Huggingface Transformers repository: - https://github.com/huggingface/transformers/blob/master/src/transformers/hf_argparser.py + Code copied from the Huggingface Transformers repository on 04/08/23, + https://github.com/huggingface/transformers/blob/master/src/transformers/hf_argparser.py + version v4.31.0 commit hash 2642d8d04b14c18199ebe7b35f976da02df61752 - Args: - dataclass_types (Union[DataClass, Iterable[DataClass]]): - Dataclass type, or list of dataclass types for which - we will "fill" instances with the parsed args. - kwargs (optional): Passed to `argparse.ArgumentParser()` in the regular way. """ + + dataclass_types: Iterable[DataClassType] - dataclass_types: Iterable[DataClass] - - def __init__(self, dataclass_types: Union[DataClass, Iterable[DataClass]], **kwargs): - """Constructor + def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): """ + Args: + dataclass_types: + Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. + kwargs (`Dict[str, Any]`, *optional*): + Passed to `argparse.ArgumentParser()` in the regular way. + """ + # To make the default appear when using --help + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter super().__init__(**kwargs) if dataclasses.is_dataclass(dataclass_types): dataclass_types = [dataclass_types] - self.dataclass_types = dataclass_types + self.dataclass_types = list(dataclass_types) for dtype in self.dataclass_types: self._add_dataclass_arguments(dtype) - def _add_dataclass_arguments(self, dtype: DataClass): - for field in dataclasses.fields(dtype): - field_name = f"--{field.name}" - kwargs = field.metadata.copy() - # field.metadata is not used at all by Data Classes, - # it is provided as a third-party extension mechanism. - if isinstance(field.type, str): - raise ImportError( - "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563)," - "which can be opted in from Python 3.7 with `from __future__ import annotations`." - "We will add compatibility when Python 3.9 is released." + @staticmethod + def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): + field_name = f"--{field.name}" + kwargs = field.metadata.copy() + # field.metadata is not used at all by Data Classes, + # it is provided as a third-party extension mechanism. + if isinstance(field.type, str): + raise RuntimeError( + "Unresolved type detected, which should have been done with the help of " + "`typing.get_type_hints` method by default" + ) + + aliases = kwargs.pop("aliases", []) + if isinstance(aliases, str): + aliases = [aliases] + + origin_type = getattr(field.type, "__origin__", field.type) + if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)): + if str not in field.type.__args__ and ( + len(field.type.__args__) != 2 or type(None) not in field.type.__args__ + ): + raise ValueError( + "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" + " the argument parser only supports one type per argument." + f" Problem encountered in field '{field.name}'." + ) + if type(None) not in field.type.__args__: + # filter `str` in Union + field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] + origin_type = getattr(field.type, "__origin__", field.type) + elif bool not in field.type.__args__: + # filter `NoneType` in Union (except for `Union[bool, NoneType]`) + field.type = ( + field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] ) - typestring = str(field.type) - for prim_type in (int, float, str): - for collection in (List,): - if typestring == f"typing.Union[{collection[prim_type]}, NoneType]": - field.type = collection[prim_type] - if typestring == f"typing.Union[{prim_type.__name__}, NoneType]": - field.type = prim_type - - if isinstance(field.type, type) and issubclass(field.type, Enum): - kwargs["choices"] = list(field.type) - kwargs["type"] = field.type - if field.default is not dataclasses.MISSING: - kwargs["default"] = field.default - elif field.type is bool: - if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): - kwargs["action"] = "store_false" if field.default is True else "store_true" - if field.default is True: - field_name = f"--no_{field.name}" - kwargs["dest"] = field.name - # Python 3.9 fix - elif hasattr(field.type, "__origin__") and "List" in str(field.type): - kwargs["nargs"] = "+" - kwargs["type"] = field.type.__args__[0] - assert all( - x == kwargs["type"] for x in field.type.__args__ - ), "{} cannot be a List of mixed types".format(field.name) - if field.default_factory is not dataclasses.MISSING: - kwargs["default"] = field.default_factory() + origin_type = getattr(field.type, "__origin__", field.type) + + # A variable to store kwargs for a boolean field, if needed + # so that we can init a `no_*` complement argument (see below) + bool_kwargs = {} + if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): + if origin_type is Literal: + kwargs["choices"] = field.type.__args__ + else: + kwargs["choices"] = [x.value for x in field.type] + + kwargs["type"] = make_choice_type_function(kwargs["choices"]) + + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default else: - kwargs["type"] = field.type - if field.default is not dataclasses.MISSING: - kwargs["default"] = field.default - elif field.default_factory is not dataclasses.MISSING: - kwargs["default"] = field.default_factory() - else: - kwargs["required"] = True - self.add_argument(field_name, **kwargs) + kwargs["required"] = True + elif field.type is bool or field.type == Optional[bool]: + # Copy the currect kwargs to use to instantiate a `no_*` complement argument below. + # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument + bool_kwargs = copy(kwargs) + + # Hack because type=bool in argparse does not behave as we want. + kwargs["type"] = string_to_bool + if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): + # Default value is False if we have no default when of type bool. + default = False if field.default is dataclasses.MISSING else field.default + # This is the value that will get picked if we don't include --field_name in any way + kwargs["default"] = default + # This tells argparse we accept 0 or 1 value after --field_name + kwargs["nargs"] = "?" + # This is the value that will get picked if we do --field_name (without value) + kwargs["const"] = True + elif isclass(origin_type) and issubclass(origin_type, list): + kwargs["type"] = field.type.__args__[0] + kwargs["nargs"] = "+" + if field.default_factory is not dataclasses.MISSING: + kwargs["default"] = field.default_factory() + elif field.default is dataclasses.MISSING: + kwargs["required"] = True + else: + kwargs["type"] = field.type + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kwargs["default"] = field.default_factory() + else: + kwargs["required"] = True + parser.add_argument(field_name, *aliases, **kwargs) + + # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. + # Order is important for arguments with the same destination! + # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down + # here and we do not need those changes/additional keys. + if field.default is True and (field.type is bool or field.type == Optional[bool]): + bool_kwargs["default"] = False + parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs) + + def _add_dataclass_arguments(self, dtype: DataClassType): + if hasattr(dtype, "_argument_group_name"): + parser = self.add_argument_group(dtype._argument_group_name) + else: + parser = self + + try: + type_hints: Dict[str, type] = get_type_hints(dtype) + except NameError: + raise RuntimeError( + f"Type resolution failed for {dtype}. Try declaring the class in global scope or " + "removing line of `from __future__ import annotations` which opts in Postponed " + "Evaluation of Annotations (PEP 563)" + ) + except TypeError as ex: + # Remove this block when we drop Python 3.9 support + if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex): + python_version = ".".join(map(str, sys.version_info[:3])) + raise RuntimeError( + f"Type resolution failed for {dtype} on Python {python_version}. Try removing " + "line of `from __future__ import annotations` which opts in union types as " + "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To " + "support Python versions that lower than 3.10, you need to use " + "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of " + "`X | None`." + ) from ex + raise + + for field in dataclasses.fields(dtype): + if not field.init: + continue + field.type = type_hints[field.name] + self._parse_dataclass_field(parser, field) def parse_args_into_dataclasses( self, - args: Iterable[str] = None, - return_remaining_strings: bool = False, - look_for_args_file: bool = True, - args_filename: str = None - ) -> Tuple[DataClass]: + args=None, + return_remaining_strings=False, + look_for_args_file=True, + args_filename=None, + args_file_flag=None, + ) -> Tuple[DataClass, ...]: """ Parse command-line args into instances of the specified dataclass types. + This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at: + docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args + Args: - args (Iterable[str]): + args: List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser) - return_remaining_strings (bool): + return_remaining_strings: If true, also return a list of remaining argument strings. - look_for_args_file (bool): + look_for_args_file: If true, will look for a ".args" file with the same base name as the entry point script for this process, and will append its potential content to the command line args. - args_filename (str): + args_filename: If not None, will uses this file instead of the ".args" file specified in the previous argument. + args_file_flag: + If not None, will look for a file in the command-line args specified with this flag. The flag can be + specified multiple times and precedence is determined by the order (last one wins). + Returns: - (Tuple[DataClass]): + Tuple consisting of: + - the dataclass instances in the same order as they were passed to the initializer.abspath - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser after initialization. - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) """ - if args_filename or (look_for_args_file and len(sys.argv)): + + if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): + args_files = [] + if args_filename: - args_file = Path(args_filename) - else: - args_file = Path(sys.argv[0]).with_suffix(".args") - - if args_file.exists(): - fargs = args_file.read_text().split() - args = fargs + \ - args if args is not None else fargs + sys.argv[1:] - # in case of duplicate arguments the first one has precedence - # so we append rather than prepend. + args_files.append(Path(args_filename)) + elif look_for_args_file and len(sys.argv): + args_files.append(Path(sys.argv[0]).with_suffix(".args")) + + # args files specified via command line flag should overwrite default args files so we add them last + if args_file_flag: + # Create special parser just to extract the args_file_flag values + args_file_parser = ArgumentParser() + args_file_parser.add_argument(args_file_flag, type=str, action="append") + + # Use only remaining args for further parsing (remove the args_file_flag) + cfg, args = args_file_parser.parse_known_args(args=args) + cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) + + if cmd_args_file_paths: + args_files.extend([Path(p) for p in cmd_args_file_paths]) + + file_args = [] + for args_file in args_files: + if args_file.exists(): + file_args += args_file.read_text().split() + + # in case of duplicate arguments the last one has precedence + # args specified via the command line should overwrite args from files, so we add them last + args = file_args + args if args is not None else file_args + sys.argv[1:] namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype)} + keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in vars(namespace).items() if k in keys} for k in keys: delattr(namespace, k) @@ -157,7 +344,76 @@ def parse_args_into_dataclasses( return (*outputs, remaining_args) else: if remaining_args: - raise ValueError( - f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") return (*outputs,) + + def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass + types. + + Args: + args (`dict`): + dict containing config values + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + unused_keys = set(args.keys()) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype) if f.init} + inputs = {k: v for k, v in args.items() if k in keys} + unused_keys.difference_update(inputs.keys()) + obj = dtype(**inputs) + outputs.append(obj) + if not allow_extra_keys and unused_keys: + raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") + return tuple(outputs) + + def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the + dataclass types. + + Args: + json_file (`str` or `os.PathLike`): + File name of the json file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + with open(Path(json_file), encoding="utf-8") as open_json_file: + data = json.loads(open_json_file.read()) + outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys) + return tuple(outputs) + + def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the + dataclass types. + + Args: + yaml_file (`str` or `os.PathLike`): + File name of the yaml file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys) + return tuple(outputs)