Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 142 additions & 146 deletions src/tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shutil
import sys
import warnings
from contextlib import nullcontext
from typing import Callable, Literal, Sequence, TypeVar, cast, overload

from typing_extensions import Annotated, assert_never, deprecated
Expand Down Expand Up @@ -502,23 +503,10 @@ def _cli_impl(
if write_completion:
completion_target_path = pathlib.Path(args[2])

# Map a callable to the relevant CLI arguments + subparsers.
with _settings.timing_context("Generate parser specification"):
if registry is not None:
with registry:
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
f,
markers=set(),
description=description,
parent_classes=set(), # Used for recursive calls.
default_instance=default_instance, # Overrides for default values.
intern_prefix="", # Used for recursive calls.
extern_prefix="", # Used for recursive calls.
subcommand_prefix="",
support_single_arg_types=False,
prog_suffix="",
)
else:
registry_context = registry if registry is not None else nullcontext()
with registry_context:
# Map a callable to the relevant CLI arguments + subparsers.
with _settings.timing_context("Generate parser specification"):
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
f,
markers=set(),
Expand All @@ -532,153 +520,161 @@ def _cli_impl(
prog_suffix="",
)

# Initialize backend.
if backend_name == "argparse":
from ._backends._argparse_backend import ArgparseBackend
# Initialize backend.
if backend_name == "argparse":
from ._backends._argparse_backend import ArgparseBackend

backend = ArgparseBackend()
elif backend_name == "tyro":
from ._backends._tyro_backend import TyroBackend
backend = ArgparseBackend()
elif backend_name == "tyro":
from ._backends._tyro_backend import TyroBackend

backend = TyroBackend()
else:
assert_never(backend_name)
backend = TyroBackend()
else:
assert_never(backend_name)

# Handle shell completion.
if print_completion or write_completion:
assert completion_shell in (
"bash",
"zsh",
"tcsh",
), (
f"Shell should be one `bash`, `zsh`, or `tcsh`, but got {completion_shell}"
)

# Handle shell completion.
if print_completion or write_completion:
assert completion_shell in (
"bash",
"zsh",
"tcsh",
), f"Shell should be one `bash`, `zsh`, or `tcsh`, but got {completion_shell}"
# Determine program name for completion script.
if prog is None:
prog = sys.argv[0]

# Determine program name for completion script.
if prog is None:
prog = sys.argv[0]
# Sanitize prog for use in function/variable names by replacing
# non-alphanumeric characters with underscores.
safe_prog = "".join(c if c.isalnum() or c == "_" else "_" for c in prog)

# Sanitize prog for use in function/variable names by replacing
# non-alphanumeric characters with underscores.
safe_prog = "".join(c if c.isalnum() or c == "_" else "_" for c in prog)
# Generate completion script using the backend's method.
completion_script = backend.generate_completion(
parser_spec,
prog=prog,
shell=completion_shell, # type: ignore
root_prefix=f"tyro_{safe_prog}",
)

# Generate completion script using the backend's method.
completion_script = backend.generate_completion(
parser_spec,
prog=prog,
shell=completion_shell, # type: ignore
root_prefix=f"tyro_{safe_prog}",
)
if write_completion and completion_target_path != pathlib.Path("-"):
assert completion_target_path is not None
completion_target_path.write_text(completion_script)
else:
print(completion_script)
sys.exit()

if write_completion and completion_target_path != pathlib.Path("-"):
assert completion_target_path is not None
completion_target_path.write_text(completion_script)
else:
print(completion_script)
sys.exit()
# For backwards compatibility with get_parser().
if return_parser:
return backend.get_parser_for_completion(
parser_spec, prog=prog, add_help=add_help
)

# For backwards compatibility with get_parser().
if return_parser:
return backend.get_parser_for_completion(
parser_spec, prog=prog, add_help=add_help
)
# Parse arguments using the backend.
if prog is None:
prog = sys.argv[0]

# Parse arguments using the backend.
if prog is None:
prog = sys.argv[0]

with _settings.timing_context("Parsing arguments"):
value_from_prefixed_field_name, unknown_args = backend.parse_args(
parser_spec=parser_spec,
args=args,
prog=prog,
return_unknown_args=return_unknown_args,
console_outputs=console_outputs,
add_help=add_help,
compact_help=compact_help,
)
with _settings.timing_context("Parsing arguments"):
value_from_prefixed_field_name, unknown_args = backend.parse_args(
parser_spec=parser_spec,
args=args,
prog=prog,
return_unknown_args=return_unknown_args,
console_outputs=console_outputs,
add_help=add_help,
compact_help=compact_help,
)

try:
# Attempt to call `f` using whatever was passed in.
get_out, consumed_keywords = _calling.callable_with_args(
f,
parser_spec,
default_instance,
value_from_prefixed_field_name,
field_name_prefix="",
)
except _calling.InstantiationError as e:
# Print prettier errors.
# This doesn't catch errors raised directly by get_out(), since that's
# called later! This is intentional, because we do less error handling
# for the root callable. Relevant: the `field_name_prefix == ""`
# condition in `callable_with_args()`!

# Emulate argparse's error behavior when invalid arguments are passed in.
error_box_rows: list[str | fmt.Element] = []
if isinstance(e.arg, _arguments.ArgumentDefinition):
display_name = (
str(e.arg.lowered.metavar)
if e.arg.is_positional()
else "/".join(e.arg.lowered.name_or_flags)
try:
# Attempt to call `f` using whatever was passed in.
get_out, consumed_keywords = _calling.callable_with_args(
f,
parser_spec,
default_instance,
value_from_prefixed_field_name,
field_name_prefix="",
)
error_box_rows.extend(
[
except _calling.InstantiationError as e:
# Print prettier errors.
# This doesn't catch errors raised directly by get_out(), since that's
# called later! This is intentional, because we do less error handling
# for the root callable. Relevant: the `field_name_prefix == ""`
# condition in `callable_with_args()`!

# Emulate argparse's error behavior when invalid arguments are passed in.
error_box_rows: list[str | fmt.Element] = []
if isinstance(e.arg, _arguments.ArgumentDefinition):
display_name = (
str(e.arg.lowered.metavar)
if e.arg.is_positional()
else "/".join(e.arg.lowered.name_or_flags)
)
error_box_rows.extend(
[
fmt.text(
fmt.text["bright_red", "bold"](
f"Error parsing {display_name}:"
),
" ",
e.message,
),
fmt.hr["red"](),
"Argument helptext:",
fmt.cols(
("", 4),
fmt.rows(
e.arg.get_invocation_text()[1],
_arguments.generate_argument_helptext(
e.arg, e.arg.lowered
),
),
),
]
)
else:
error_box_rows.append(
fmt.text(
fmt.text["bright_red", "bold"](
f"Error parsing {display_name}:"
f"Error parsing {e.arg}:",
),
" ",
e.message,
),
fmt.hr["red"](),
"Argument helptext:",
fmt.cols(
("", 4),
fmt.rows(
e.arg.get_invocation_text()[1],
_arguments.generate_argument_helptext(e.arg, e.arg.lowered),
)
)

if add_help:
error_box_rows.extend(
[
fmt.hr["red"](),
fmt.text(
"For full helptext, see ",
fmt.text["bold"](f"{prog} --help"),
),
),
]
)
else:
error_box_rows.append(
fmt.text(
fmt.text["bright_red", "bold"](
f"Error parsing {e.arg}:",
),
" ",
e.message,
]
)
print(
fmt.box["red"](
fmt.text["red"]("Value error"), fmt.rows(*error_box_rows)
),
file=sys.stderr,
flush=True,
)
sys.exit(2)

if add_help:
error_box_rows.extend(
[
fmt.hr["red"](),
fmt.text(
"For full helptext, see ",
fmt.text["bold"](f"{prog} --help"),
),
]
)
print(
fmt.box["red"](fmt.text["red"]("Value error"), fmt.rows(*error_box_rows)),
file=sys.stderr,
flush=True,
assert len(value_from_prefixed_field_name.keys() - consumed_keywords) == 0, (
f"Parsed {value_from_prefixed_field_name.keys()}, but only consumed"
f" {consumed_keywords}"
)
sys.exit(2)

assert len(value_from_prefixed_field_name.keys() - consumed_keywords) == 0, (
f"Parsed {value_from_prefixed_field_name.keys()}, but only consumed"
f" {consumed_keywords}"
)
if return_unknown_args:
assert unknown_args is not None, "Should have parsed with `parse_known_args()`"
# If we're parsed unknown args, we should return the original args, not
# the fixed ones.
if modified_args is not None:
unknown_args = [modified_args.get(arg, arg) for arg in unknown_args]
return get_out, unknown_args # type: ignore
else:
assert unknown_args is None, "Should have parsed with `parse_args()`"
return get_out # type: ignore
if return_unknown_args:
assert unknown_args is not None, (
"Should have parsed with `parse_known_args()`"
)
# If we're parsed unknown args, we should return the original args, not
# the fixed ones.
if modified_args is not None:
unknown_args = [modified_args.get(arg, arg) for arg in unknown_args]
return get_out, unknown_args # type: ignore
else:
assert unknown_args is None, "Should have parsed with `parse_args()`"
return get_out # type: ignore
27 changes: 27 additions & 0 deletions tests/test_custom_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,33 @@ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]:
}


def test_registry_parameter_subcommand_cli_from_dict() -> None:
"""Test that registry parameter works with subcommand_cli_from_dict()."""
registry = tyro.constructors.ConstructorRegistry()

@registry.primitive_rule
def json_dict_spec(
type_info: tyro.constructors.PrimitiveTypeInfo,
) -> tyro.constructors.PrimitiveConstructorSpec | None:
if not (
type_info.type_origin is dict and get_args(type_info.type) == (str, Any)
):
return None
return json_constructor_spec

def main(
x: Dict[str, Any],
y: Dict[str, Any] = {"hello": 5},
) -> tuple[Dict[str, Any], Dict[str, Any]]:
return x, y

assert tyro.extras.subcommand_cli_from_dict(
{"main": main},
args=["main", "--x", '{"a": 1}'],
registry=registry,
) == ({"a": 1}, {"hello": 5})


# Define a custom dataclass
@dataclass
class CustomDataWithPrefix:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_py311_generated/test_custom_constructors_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,33 @@ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]:
}


def test_registry_parameter_subcommand_cli_from_dict() -> None:
"""Test that registry parameter works with subcommand_cli_from_dict()."""
registry = tyro.constructors.ConstructorRegistry()

@registry.primitive_rule
def json_dict_spec(
type_info: tyro.constructors.PrimitiveTypeInfo,
) -> tyro.constructors.PrimitiveConstructorSpec | None:
if not (
type_info.type_origin is dict and get_args(type_info.type) == (str, Any)
):
return None
return json_constructor_spec

def main(
x: Dict[str, Any],
y: Dict[str, Any] = {"hello": 5},
) -> tuple[Dict[str, Any], Dict[str, Any]]:
return x, y

assert tyro.extras.subcommand_cli_from_dict(
{"main": main},
args=["main", "--x", '{"a": 1}'],
registry=registry,
) == ({"a": 1}, {"hello": 5})


# Define a custom dataclass
@dataclass
class CustomDataWithPrefix:
Expand Down
Loading