diff --git a/param/__init__.py b/param/__init__.py index 4aa051ee..b2373ba3 100644 --- a/param/__init__.py +++ b/param/__init__.py @@ -100,6 +100,7 @@ Event, ) from .reactive import bind, rx +from .typed import ParamField, ParamModel from ._utils import ( descendents, concrete_descendents, @@ -187,6 +188,8 @@ 'Number', 'NumericTuple', 'ObjectSelector', + 'ParamField', + 'ParamModel', 'ParamOverrides', 'Parameter', 'Parameterized', diff --git a/param/typed.py b/param/typed.py new file mode 100644 index 00000000..16c0298f --- /dev/null +++ b/param/typed.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import copy +import importlib +import sys +import types +import typing as t + +from collections.abc import Callable, Mapping +from typing import Any +from typing_extensions import dataclass_transform + +from .parameterized import ( + Parameter, + Parameterized, + ParameterizedMetaclass, + String, + Undefined, +) + +FT = t.TypeVar("FT") + + +class _ParamFieldSpec: + + __slots__ = ("default", "default_factory", "parameter", "kwargs") + + def __init__( + self, + *, + default: Any = Undefined, + default_factory: Callable[..., Any] | Any = Undefined, + parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None, + **kwargs: Any, + ): + self.default = default + self.default_factory = default_factory + self.parameter = parameter + self.kwargs = kwargs + + +@t.overload +def ParamField( + *, + default: FT, + default_factory: Callable[..., Any] | Any = Undefined, + parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None, + **kwargs: Any, +) -> FT: + ... + + +@t.overload +def ParamField( + *, + default: Any = Undefined, + default_factory: Callable[[], FT], + parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None, + **kwargs: Any, +) -> FT: + ... + + +@t.overload +def ParamField( + *, + default: Any = Undefined, + default_factory: Callable[..., Any] | Any = Undefined, + parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None, + **kwargs: Any, +) -> Any: + ... + + +def ParamField( + *, + default: Any = Undefined, + default_factory: Callable[..., Any] | Any = Undefined, + parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None, + **kwargs: Any, +) -> Any: + """ParamField specifier for ParamModel attributes.""" + return t.cast( + "Any", + _ParamFieldSpec( + default=default, + default_factory=default_factory, + parameter=parameter, + **kwargs, + ), + ) + + +def _annotation_parameter_factory(annotation: Any) -> tuple[type[Parameter], dict[str, Any]]: + from .parameters import ( + Boolean, + ClassSelector, + Dict, + Integer, + List, + Number, + Selector, + Tuple, + ) + + kwargs: dict[str, Any] = {} + ann = annotation + origin = t.get_origin(ann) + if origin is t.Annotated: + annotated_args = list(t.get_args(ann)) + ann = annotated_args[0] if annotated_args else ann + for meta in annotated_args[1:]: + if isinstance(meta, Mapping): + kwargs.update(dict(meta)) + origin = t.get_origin(ann) + + if origin in (t.Union, types.UnionType): + union_args = list(t.get_args(ann)) + non_none = [a for a in union_args if a is not type(None)] + if len(non_none) < len(union_args): + kwargs["allow_None"] = True + ann = non_none[0] if len(non_none) == 1 else ann + origin = t.get_origin(ann) + + if origin is t.Literal: + kwargs["objects"] = list(t.get_args(ann)) + return Selector, kwargs + + if ann is bool: + return Boolean, kwargs + if ann is int: + return Integer, kwargs + if ann is float: + return Number, kwargs + if ann is str: + return String, kwargs + + if origin in (list, t.List): + list_args = t.get_args(ann) + if list_args and isinstance(list_args[0], type): + kwargs["item_type"] = list_args[0] + return List, kwargs + + if origin in (tuple, t.Tuple): + tuple_args = t.get_args(ann) + if tuple_args and tuple_args[-1] is not Ellipsis: + kwargs["length"] = len(tuple_args) + return Tuple, kwargs + + if origin in (dict, t.Dict): + return Dict, kwargs + + if origin in (set, t.Set): + kwargs["class_"] = set + return ClassSelector, kwargs + + if ann in (Any, object): + return Parameter, kwargs + + return Parameter, kwargs + + +def _build_parameter_from_field( + annotation: Any, + *, + field_spec: _ParamFieldSpec | None, + explicit_value: Any = Undefined, + has_explicit_value: bool = False, +) -> Parameter: + factory_kwargs: dict[str, Any] = {} + if field_spec is not None and field_spec.parameter is not None: + factory: type[Parameter] | Callable[..., Parameter] | Parameter | None = field_spec.parameter + else: + factory, inferred = _annotation_parameter_factory(annotation) + factory_kwargs.update(inferred) + + if field_spec is not None: + factory_kwargs.update(field_spec.kwargs) + if field_spec.default is not Undefined: + factory_kwargs["default"] = field_spec.default + if field_spec.default_factory is not Undefined: + factory_kwargs["default_factory"] = field_spec.default_factory + + if has_explicit_value: + factory_kwargs["default"] = explicit_value + elif "default" not in factory_kwargs and "default_factory" not in factory_kwargs: + # Checker-facing required semantics for annotation-only declarations. + factory_kwargs["default"] = Undefined + + if isinstance(factory, Parameter): + pobj = copy.copy(factory) + for key, value in factory_kwargs.items(): + setattr(pobj, key, value) + return pobj + if factory is None: + return Parameter(**factory_kwargs) + return factory(**factory_kwargs) + + +def _extract_namespace_annotations(namespace: dict[str, Any]) -> dict[str, Any]: + annotations = dict(namespace.get("__annotations__", {})) + if annotations: + return annotations + + # Python 3.14 may defer class annotation materialization to __annotate_func__. + annotate_func = namespace.get("__annotate_func__") + if not callable(annotate_func): + return {} + + try: + annotationlib = importlib.import_module("annotationlib") + format_value = getattr(getattr(annotationlib, "Format", None), "VALUE", 1) + evaluated = annotate_func(format_value) + except Exception: + try: + # Fallback for runtimes where annotationlib is unavailable. + evaluated = annotate_func(1) + except Exception: + return {} + + return dict(evaluated) if isinstance(evaluated, Mapping) else {} + + +@dataclass_transform(field_specifiers=(ParamField,)) +class ParamModelMetaclass(ParameterizedMetaclass): + + def __new__( + mcs, name: str, bases: tuple[type, ...], dict_: dict[str, Any] + ) -> ParamModelMetaclass: + namespace = dict_ + annotations = _extract_namespace_annotations(namespace) + module_name = namespace.get("__module__", "") + module_globals = getattr(sys.modules.get(module_name), "__dict__", {}) + + for attr, annotation in annotations.items(): + if isinstance(annotation, str): + try: + annotation = eval(annotation, module_globals, namespace) + except Exception: + pass + if attr.startswith("_"): + continue + origin = t.get_origin(annotation) + if origin is t.ClassVar: + continue + + existing = namespace.get(attr, Undefined) + if isinstance(existing, Parameter): + continue + + field_spec = existing if isinstance(existing, _ParamFieldSpec) else None + has_explicit_value = ( + attr in namespace and not isinstance(existing, _ParamFieldSpec) + ) + explicit_value = existing if has_explicit_value else Undefined + namespace[attr] = _build_parameter_from_field( + annotation, + field_spec=field_spec, + explicit_value=explicit_value, + has_explicit_value=has_explicit_value, + ) + + return t.cast( + "ParamModelMetaclass", super().__new__(mcs, name, bases, namespace) + ) + + +class ParamModel(Parameterized, metaclass=ParamModelMetaclass): + """A Parameterized subclass that synthesizes Parameters from type annotations.""" diff --git a/tests/testparammodel.py b/tests/testparammodel.py new file mode 100644 index 00000000..f6697154 --- /dev/null +++ b/tests/testparammodel.py @@ -0,0 +1,101 @@ +import typing as t + +import param +import pytest + + +def test_literal_annotation_infers_selector(): + class P(param.ParamModel): + mode: t.Literal["read", "write"] + + assert isinstance(P.param.mode, param.Selector) + assert P.param.mode.objects == ["read", "write"] + assert P.param.mode.default == "read" + + p = P() + assert p.mode == "read" + p.mode = "write" + + with pytest.raises(ValueError): + p.mode = "delete" + + +def test_literal_annotation_supports_explicit_default_value(): + class P(param.ParamModel): + mode: t.Literal["read", "write"] = "write" + + assert isinstance(P.param.mode, param.Selector) + assert P.param.mode.objects == ["read", "write"] + assert P.param.mode.default == "write" + assert P().mode == "write" + + +def test_literal_field_specification_supports_default_and_optional(): + class P(param.ParamModel): + mode: t.Literal["light", "dark"] = param.ParamField(default="dark") + optional_mode: t.Literal["auto", "manual"] | None = param.ParamField(default=None) + + assert isinstance(P.param.mode, param.Selector) + assert P.param.mode.objects == ["light", "dark"] + assert P.param.mode.default == "dark" + + assert isinstance(P.param.optional_mode, param.Selector) + assert P.param.optional_mode.objects == ["auto", "manual"] + assert P.param.optional_mode.allow_None is True + assert P.param.optional_mode.default is None + assert P().optional_mode is None + + +def test_classvar_annotation_is_not_parameterized(): + class P(param.ParamModel): + shared: t.ClassVar[int] = 7 + value: int = 1 + + assert "shared" not in P.param + assert "value" in P.param + assert P.shared == 7 + assert P().value == 1 + + +def test_annotated_metadata_sets_doc_and_parameter_attributes(): + class P(param.ParamModel): + title: t.Annotated[str, {"doc": "Title text", "constant": True}] = "hello" + + assert isinstance(P.param.title, param.String) + assert P.param.title.doc == "Title text" + assert P.param.title.constant is True + assert P().title == "hello" + + +def test_annotated_metadata_supports_inferred_parameter_kwargs(): + class P(param.ParamModel): + value: t.Annotated[int, {"bounds": (0, 10)}] = 4 + + assert isinstance(P.param.value, param.Integer) + assert P.param.value.bounds == (0, 10) + assert P().value == 4 + + +def test_field_parameter_allows_overriding_inferred_parameter_class(): + class P(param.ParamModel): + value: int = param.ParamField(default=1.5, parameter=param.Number, bounds=(0, None)) + + assert isinstance(P.param.value, param.Number) + + p = P() + assert p.value == 1.5 + p.value = 2.25 + with pytest.raises(ValueError): + p.value = "not-a-number" + + +def test_field_parameter_override_can_replace_literal_selector_behavior(): + class P(param.ParamModel): + mode: t.Literal["light", "dark"] = param.ParamField( + default="sepia", parameter=param.String + ) + + assert isinstance(P.param.mode, param.String) + p = P() + assert p.mode == "sepia" + p.mode = "custom-theme"