diff --git a/src/tyro/_cli.py b/src/tyro/_cli.py index 472223c9..8f48f7bb 100644 --- a/src/tyro/_cli.py +++ b/src/tyro/_cli.py @@ -6,6 +6,7 @@ import shutil import sys import warnings +from contextlib import nullcontext from typing import Callable, Literal, Sequence, TypeVar, cast, overload from typing_extensions import Annotated, assert_never, deprecated @@ -502,23 +503,10 @@ def _cli_impl( if write_completion: completion_target_path = pathlib.Path(args[2]) - # Map a callable to the relevant CLI arguments + subparsers. - with _settings.timing_context("Generate parser specification"): - if registry is not None: - with registry: - parser_spec = _parsers.ParserSpecification.from_callable_or_type( - f, - markers=set(), - description=description, - parent_classes=set(), # Used for recursive calls. - default_instance=default_instance, # Overrides for default values. - intern_prefix="", # Used for recursive calls. - extern_prefix="", # Used for recursive calls. - subcommand_prefix="", - support_single_arg_types=False, - prog_suffix="", - ) - else: + registry_context = registry if registry is not None else nullcontext() + with registry_context: + # Map a callable to the relevant CLI arguments + subparsers. + with _settings.timing_context("Generate parser specification"): parser_spec = _parsers.ParserSpecification.from_callable_or_type( f, markers=set(), @@ -532,153 +520,161 @@ def _cli_impl( prog_suffix="", ) - # Initialize backend. - if backend_name == "argparse": - from ._backends._argparse_backend import ArgparseBackend + # Initialize backend. + if backend_name == "argparse": + from ._backends._argparse_backend import ArgparseBackend - backend = ArgparseBackend() - elif backend_name == "tyro": - from ._backends._tyro_backend import TyroBackend + backend = ArgparseBackend() + elif backend_name == "tyro": + from ._backends._tyro_backend import TyroBackend - backend = TyroBackend() - else: - assert_never(backend_name) + backend = TyroBackend() + else: + assert_never(backend_name) + + # Handle shell completion. + if print_completion or write_completion: + assert completion_shell in ( + "bash", + "zsh", + "tcsh", + ), ( + f"Shell should be one `bash`, `zsh`, or `tcsh`, but got {completion_shell}" + ) - # Handle shell completion. - if print_completion or write_completion: - assert completion_shell in ( - "bash", - "zsh", - "tcsh", - ), f"Shell should be one `bash`, `zsh`, or `tcsh`, but got {completion_shell}" + # Determine program name for completion script. + if prog is None: + prog = sys.argv[0] - # Determine program name for completion script. - if prog is None: - prog = sys.argv[0] + # Sanitize prog for use in function/variable names by replacing + # non-alphanumeric characters with underscores. + safe_prog = "".join(c if c.isalnum() or c == "_" else "_" for c in prog) - # Sanitize prog for use in function/variable names by replacing - # non-alphanumeric characters with underscores. - safe_prog = "".join(c if c.isalnum() or c == "_" else "_" for c in prog) + # Generate completion script using the backend's method. + completion_script = backend.generate_completion( + parser_spec, + prog=prog, + shell=completion_shell, # type: ignore + root_prefix=f"tyro_{safe_prog}", + ) - # Generate completion script using the backend's method. - completion_script = backend.generate_completion( - parser_spec, - prog=prog, - shell=completion_shell, # type: ignore - root_prefix=f"tyro_{safe_prog}", - ) + if write_completion and completion_target_path != pathlib.Path("-"): + assert completion_target_path is not None + completion_target_path.write_text(completion_script) + else: + print(completion_script) + sys.exit() - if write_completion and completion_target_path != pathlib.Path("-"): - assert completion_target_path is not None - completion_target_path.write_text(completion_script) - else: - print(completion_script) - sys.exit() + # For backwards compatibility with get_parser(). + if return_parser: + return backend.get_parser_for_completion( + parser_spec, prog=prog, add_help=add_help + ) - # For backwards compatibility with get_parser(). - if return_parser: - return backend.get_parser_for_completion( - parser_spec, prog=prog, add_help=add_help - ) + # Parse arguments using the backend. + if prog is None: + prog = sys.argv[0] - # Parse arguments using the backend. - if prog is None: - prog = sys.argv[0] - - with _settings.timing_context("Parsing arguments"): - value_from_prefixed_field_name, unknown_args = backend.parse_args( - parser_spec=parser_spec, - args=args, - prog=prog, - return_unknown_args=return_unknown_args, - console_outputs=console_outputs, - add_help=add_help, - compact_help=compact_help, - ) + with _settings.timing_context("Parsing arguments"): + value_from_prefixed_field_name, unknown_args = backend.parse_args( + parser_spec=parser_spec, + args=args, + prog=prog, + return_unknown_args=return_unknown_args, + console_outputs=console_outputs, + add_help=add_help, + compact_help=compact_help, + ) - try: - # Attempt to call `f` using whatever was passed in. - get_out, consumed_keywords = _calling.callable_with_args( - f, - parser_spec, - default_instance, - value_from_prefixed_field_name, - field_name_prefix="", - ) - except _calling.InstantiationError as e: - # Print prettier errors. - # This doesn't catch errors raised directly by get_out(), since that's - # called later! This is intentional, because we do less error handling - # for the root callable. Relevant: the `field_name_prefix == ""` - # condition in `callable_with_args()`! - - # Emulate argparse's error behavior when invalid arguments are passed in. - error_box_rows: list[str | fmt.Element] = [] - if isinstance(e.arg, _arguments.ArgumentDefinition): - display_name = ( - str(e.arg.lowered.metavar) - if e.arg.is_positional() - else "/".join(e.arg.lowered.name_or_flags) + try: + # Attempt to call `f` using whatever was passed in. + get_out, consumed_keywords = _calling.callable_with_args( + f, + parser_spec, + default_instance, + value_from_prefixed_field_name, + field_name_prefix="", ) - error_box_rows.extend( - [ + except _calling.InstantiationError as e: + # Print prettier errors. + # This doesn't catch errors raised directly by get_out(), since that's + # called later! This is intentional, because we do less error handling + # for the root callable. Relevant: the `field_name_prefix == ""` + # condition in `callable_with_args()`! + + # Emulate argparse's error behavior when invalid arguments are passed in. + error_box_rows: list[str | fmt.Element] = [] + if isinstance(e.arg, _arguments.ArgumentDefinition): + display_name = ( + str(e.arg.lowered.metavar) + if e.arg.is_positional() + else "/".join(e.arg.lowered.name_or_flags) + ) + error_box_rows.extend( + [ + fmt.text( + fmt.text["bright_red", "bold"]( + f"Error parsing {display_name}:" + ), + " ", + e.message, + ), + fmt.hr["red"](), + "Argument helptext:", + fmt.cols( + ("", 4), + fmt.rows( + e.arg.get_invocation_text()[1], + _arguments.generate_argument_helptext( + e.arg, e.arg.lowered + ), + ), + ), + ] + ) + else: + error_box_rows.append( fmt.text( fmt.text["bright_red", "bold"]( - f"Error parsing {display_name}:" + f"Error parsing {e.arg}:", ), " ", e.message, - ), - fmt.hr["red"](), - "Argument helptext:", - fmt.cols( - ("", 4), - fmt.rows( - e.arg.get_invocation_text()[1], - _arguments.generate_argument_helptext(e.arg, e.arg.lowered), + ) + ) + + if add_help: + error_box_rows.extend( + [ + fmt.hr["red"](), + fmt.text( + "For full helptext, see ", + fmt.text["bold"](f"{prog} --help"), ), - ), - ] - ) - else: - error_box_rows.append( - fmt.text( - fmt.text["bright_red", "bold"]( - f"Error parsing {e.arg}:", - ), - " ", - e.message, + ] + ) + print( + fmt.box["red"]( + fmt.text["red"]("Value error"), fmt.rows(*error_box_rows) ), + file=sys.stderr, + flush=True, ) + sys.exit(2) - if add_help: - error_box_rows.extend( - [ - fmt.hr["red"](), - fmt.text( - "For full helptext, see ", - fmt.text["bold"](f"{prog} --help"), - ), - ] - ) - print( - fmt.box["red"](fmt.text["red"]("Value error"), fmt.rows(*error_box_rows)), - file=sys.stderr, - flush=True, + assert len(value_from_prefixed_field_name.keys() - consumed_keywords) == 0, ( + f"Parsed {value_from_prefixed_field_name.keys()}, but only consumed" + f" {consumed_keywords}" ) - sys.exit(2) - - assert len(value_from_prefixed_field_name.keys() - consumed_keywords) == 0, ( - f"Parsed {value_from_prefixed_field_name.keys()}, but only consumed" - f" {consumed_keywords}" - ) - if return_unknown_args: - assert unknown_args is not None, "Should have parsed with `parse_known_args()`" - # If we're parsed unknown args, we should return the original args, not - # the fixed ones. - if modified_args is not None: - unknown_args = [modified_args.get(arg, arg) for arg in unknown_args] - return get_out, unknown_args # type: ignore - else: - assert unknown_args is None, "Should have parsed with `parse_args()`" - return get_out # type: ignore + if return_unknown_args: + assert unknown_args is not None, ( + "Should have parsed with `parse_known_args()`" + ) + # If we're parsed unknown args, we should return the original args, not + # the fixed ones. + if modified_args is not None: + unknown_args = [modified_args.get(arg, arg) for arg in unknown_args] + return get_out, unknown_args # type: ignore + else: + assert unknown_args is None, "Should have parsed with `parse_args()`" + return get_out # type: ignore diff --git a/tests/test_custom_constructors.py b/tests/test_custom_constructors.py index f5213b00..9a67f229 100644 --- a/tests/test_custom_constructors.py +++ b/tests/test_custom_constructors.py @@ -236,6 +236,33 @@ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]: } +def test_registry_parameter_subcommand_cli_from_dict() -> None: + """Test that registry parameter works with subcommand_cli_from_dict().""" + registry = tyro.constructors.ConstructorRegistry() + + @registry.primitive_rule + def json_dict_spec( + type_info: tyro.constructors.PrimitiveTypeInfo, + ) -> tyro.constructors.PrimitiveConstructorSpec | None: + if not ( + type_info.type_origin is dict and get_args(type_info.type) == (str, Any) + ): + return None + return json_constructor_spec + + def main( + x: Dict[str, Any], + y: Dict[str, Any] = {"hello": 5}, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + return x, y + + assert tyro.extras.subcommand_cli_from_dict( + {"main": main}, + args=["main", "--x", '{"a": 1}'], + registry=registry, + ) == ({"a": 1}, {"hello": 5}) + + # Define a custom dataclass @dataclass class CustomDataWithPrefix: diff --git a/tests/test_py311_generated/test_custom_constructors_generated.py b/tests/test_py311_generated/test_custom_constructors_generated.py index 99306469..0fb5bd85 100644 --- a/tests/test_py311_generated/test_custom_constructors_generated.py +++ b/tests/test_py311_generated/test_custom_constructors_generated.py @@ -235,6 +235,33 @@ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]: } +def test_registry_parameter_subcommand_cli_from_dict() -> None: + """Test that registry parameter works with subcommand_cli_from_dict().""" + registry = tyro.constructors.ConstructorRegistry() + + @registry.primitive_rule + def json_dict_spec( + type_info: tyro.constructors.PrimitiveTypeInfo, + ) -> tyro.constructors.PrimitiveConstructorSpec | None: + if not ( + type_info.type_origin is dict and get_args(type_info.type) == (str, Any) + ): + return None + return json_constructor_spec + + def main( + x: Dict[str, Any], + y: Dict[str, Any] = {"hello": 5}, + ) -> tuple[Dict[str, Any], Dict[str, Any]]: + return x, y + + assert tyro.extras.subcommand_cli_from_dict( + {"main": main}, + args=["main", "--x", '{"a": 1}'], + registry=registry, + ) == ({"a": 1}, {"hello": 5}) + + # Define a custom dataclass @dataclass class CustomDataWithPrefix: