Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
Event,
)
from .reactive import bind, rx
from .typed import ParamField, ParamModel
from ._utils import (
descendents,
concrete_descendents,
Expand Down Expand Up @@ -187,6 +188,8 @@
'Number',
'NumericTuple',
'ObjectSelector',
'ParamField',
'ParamModel',
'ParamOverrides',
'Parameter',
'Parameterized',
Expand Down
269 changes: 269 additions & 0 deletions param/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
from __future__ import annotations

import copy
import importlib
import sys
import types
import typing as t

from collections.abc import Callable, Mapping
from typing import Any
from typing_extensions import dataclass_transform

from .parameterized import (
Parameter,
Parameterized,
ParameterizedMetaclass,
String,
Undefined,
)

FT = t.TypeVar("FT")


class _ParamFieldSpec:

__slots__ = ("default", "default_factory", "parameter", "kwargs")

def __init__(
self,
*,
default: Any = Undefined,
default_factory: Callable[..., Any] | Any = Undefined,
parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None,
**kwargs: Any,
):
self.default = default
self.default_factory = default_factory
self.parameter = parameter
self.kwargs = kwargs


@t.overload
def ParamField(
*,
default: FT,
default_factory: Callable[..., Any] | Any = Undefined,
parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None,
**kwargs: Any,
) -> FT:
...


@t.overload
def ParamField(
*,
default: Any = Undefined,
default_factory: Callable[[], FT],
parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None,
**kwargs: Any,
) -> FT:
...


@t.overload
def ParamField(
*,
default: Any = Undefined,
default_factory: Callable[..., Any] | Any = Undefined,
parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None,
**kwargs: Any,
) -> Any:
...


def ParamField(
*,
default: Any = Undefined,
default_factory: Callable[..., Any] | Any = Undefined,
parameter: type[Parameter] | Callable[..., Parameter] | Parameter | None = None,
**kwargs: Any,
) -> Any:
"""ParamField specifier for ParamModel attributes."""
return t.cast(
"Any",
_ParamFieldSpec(
default=default,
default_factory=default_factory,
parameter=parameter,
**kwargs,
),
)


def _annotation_parameter_factory(annotation: Any) -> tuple[type[Parameter], dict[str, Any]]:
from .parameters import (
Boolean,
ClassSelector,
Dict,
Integer,
List,
Number,
Selector,
Tuple,
)

kwargs: dict[str, Any] = {}
ann = annotation
origin = t.get_origin(ann)
if origin is t.Annotated:
annotated_args = list(t.get_args(ann))
ann = annotated_args[0] if annotated_args else ann
for meta in annotated_args[1:]:
if isinstance(meta, Mapping):
kwargs.update(dict(meta))
origin = t.get_origin(ann)

if origin in (t.Union, types.UnionType):
union_args = list(t.get_args(ann))
non_none = [a for a in union_args if a is not type(None)]
if len(non_none) < len(union_args):
kwargs["allow_None"] = True
ann = non_none[0] if len(non_none) == 1 else ann
origin = t.get_origin(ann)

if origin is t.Literal:
kwargs["objects"] = list(t.get_args(ann))
return Selector, kwargs

if ann is bool:
return Boolean, kwargs
if ann is int:
return Integer, kwargs
if ann is float:
return Number, kwargs
if ann is str:
return String, kwargs

if origin in (list, t.List):
list_args = t.get_args(ann)
if list_args and isinstance(list_args[0], type):
kwargs["item_type"] = list_args[0]
return List, kwargs

if origin in (tuple, t.Tuple):
tuple_args = t.get_args(ann)
if tuple_args and tuple_args[-1] is not Ellipsis:
kwargs["length"] = len(tuple_args)
return Tuple, kwargs

if origin in (dict, t.Dict):
return Dict, kwargs

if origin in (set, t.Set):
kwargs["class_"] = set
return ClassSelector, kwargs

if ann in (Any, object):
return Parameter, kwargs

return Parameter, kwargs


def _build_parameter_from_field(
annotation: Any,
*,
field_spec: _ParamFieldSpec | None,
explicit_value: Any = Undefined,
has_explicit_value: bool = False,
) -> Parameter:
factory_kwargs: dict[str, Any] = {}
if field_spec is not None and field_spec.parameter is not None:
factory: type[Parameter] | Callable[..., Parameter] | Parameter | None = field_spec.parameter
else:
factory, inferred = _annotation_parameter_factory(annotation)
factory_kwargs.update(inferred)

if field_spec is not None:
factory_kwargs.update(field_spec.kwargs)
if field_spec.default is not Undefined:
factory_kwargs["default"] = field_spec.default
if field_spec.default_factory is not Undefined:
factory_kwargs["default_factory"] = field_spec.default_factory

if has_explicit_value:
factory_kwargs["default"] = explicit_value
elif "default" not in factory_kwargs and "default_factory" not in factory_kwargs:
# Checker-facing required semantics for annotation-only declarations.
factory_kwargs["default"] = Undefined

if isinstance(factory, Parameter):
pobj = copy.copy(factory)
for key, value in factory_kwargs.items():
setattr(pobj, key, value)
return pobj
if factory is None:
return Parameter(**factory_kwargs)
return factory(**factory_kwargs)


def _extract_namespace_annotations(namespace: dict[str, Any]) -> dict[str, Any]:
annotations = dict(namespace.get("__annotations__", {}))
if annotations:
return annotations

# Python 3.14 may defer class annotation materialization to __annotate_func__.
annotate_func = namespace.get("__annotate_func__")
if not callable(annotate_func):
return {}

try:
annotationlib = importlib.import_module("annotationlib")
format_value = getattr(getattr(annotationlib, "Format", None), "VALUE", 1)
evaluated = annotate_func(format_value)
except Exception:
try:
# Fallback for runtimes where annotationlib is unavailable.
evaluated = annotate_func(1)
except Exception:
return {}

return dict(evaluated) if isinstance(evaluated, Mapping) else {}


@dataclass_transform(field_specifiers=(ParamField,))
class ParamModelMetaclass(ParameterizedMetaclass):

def __new__(
mcs, name: str, bases: tuple[type, ...], dict_: dict[str, Any]
) -> ParamModelMetaclass:
namespace = dict_
annotations = _extract_namespace_annotations(namespace)
module_name = namespace.get("__module__", "")
module_globals = getattr(sys.modules.get(module_name), "__dict__", {})

for attr, annotation in annotations.items():
if isinstance(annotation, str):
try:
annotation = eval(annotation, module_globals, namespace)
except Exception:
pass
if attr.startswith("_"):
continue
origin = t.get_origin(annotation)
if origin is t.ClassVar:
continue

existing = namespace.get(attr, Undefined)
if isinstance(existing, Parameter):
continue

field_spec = existing if isinstance(existing, _ParamFieldSpec) else None
has_explicit_value = (
attr in namespace and not isinstance(existing, _ParamFieldSpec)
)
explicit_value = existing if has_explicit_value else Undefined
namespace[attr] = _build_parameter_from_field(
annotation,
field_spec=field_spec,
explicit_value=explicit_value,
has_explicit_value=has_explicit_value,
)

return t.cast(
"ParamModelMetaclass", super().__new__(mcs, name, bases, namespace)
)


class ParamModel(Parameterized, metaclass=ParamModelMetaclass):
"""A Parameterized subclass that synthesizes Parameters from type annotations."""
101 changes: 101 additions & 0 deletions tests/testparammodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import typing as t

import param
import pytest


def test_literal_annotation_infers_selector():
class P(param.ParamModel):
mode: t.Literal["read", "write"]

assert isinstance(P.param.mode, param.Selector)
assert P.param.mode.objects == ["read", "write"]
assert P.param.mode.default == "read"

p = P()
assert p.mode == "read"
p.mode = "write"

with pytest.raises(ValueError):
p.mode = "delete"


def test_literal_annotation_supports_explicit_default_value():
class P(param.ParamModel):
mode: t.Literal["read", "write"] = "write"

assert isinstance(P.param.mode, param.Selector)
assert P.param.mode.objects == ["read", "write"]
assert P.param.mode.default == "write"
assert P().mode == "write"


def test_literal_field_specification_supports_default_and_optional():
class P(param.ParamModel):
mode: t.Literal["light", "dark"] = param.ParamField(default="dark")
optional_mode: t.Literal["auto", "manual"] | None = param.ParamField(default=None)

assert isinstance(P.param.mode, param.Selector)
assert P.param.mode.objects == ["light", "dark"]
assert P.param.mode.default == "dark"

assert isinstance(P.param.optional_mode, param.Selector)
assert P.param.optional_mode.objects == ["auto", "manual"]
assert P.param.optional_mode.allow_None is True
assert P.param.optional_mode.default is None
assert P().optional_mode is None


def test_classvar_annotation_is_not_parameterized():
class P(param.ParamModel):
shared: t.ClassVar[int] = 7
value: int = 1

assert "shared" not in P.param
assert "value" in P.param
assert P.shared == 7
assert P().value == 1


def test_annotated_metadata_sets_doc_and_parameter_attributes():
class P(param.ParamModel):
title: t.Annotated[str, {"doc": "Title text", "constant": True}] = "hello"

assert isinstance(P.param.title, param.String)
assert P.param.title.doc == "Title text"
assert P.param.title.constant is True
assert P().title == "hello"


def test_annotated_metadata_supports_inferred_parameter_kwargs():
class P(param.ParamModel):
value: t.Annotated[int, {"bounds": (0, 10)}] = 4

assert isinstance(P.param.value, param.Integer)
assert P.param.value.bounds == (0, 10)
assert P().value == 4


def test_field_parameter_allows_overriding_inferred_parameter_class():
class P(param.ParamModel):
value: int = param.ParamField(default=1.5, parameter=param.Number, bounds=(0, None))

assert isinstance(P.param.value, param.Number)

p = P()
assert p.value == 1.5
p.value = 2.25
with pytest.raises(ValueError):
p.value = "not-a-number"


def test_field_parameter_override_can_replace_literal_selector_behavior():
class P(param.ParamModel):
mode: t.Literal["light", "dark"] = param.ParamField(
default="sepia", parameter=param.String
)

assert isinstance(P.param.mode, param.String)
p = P()
assert p.mode == "sepia"
p.mode = "custom-theme"