diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 773ee4cbe..31f092bfc 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -2,6 +2,7 @@ .. currentmodule:: pytato.transform.lower_to_index_lambda .. autofunction:: to_index_lambda +.. autoclass:: MapAsIndexLambdaMixin """ from __future__ import annotations @@ -28,8 +29,9 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from constantdict import constantdict from typing_extensions import Never @@ -65,6 +67,8 @@ from pytato.tags import AssumeNonNegative from pytato.transform import ( Mapper, + P, + ResultT, _verify_is_array, ) from pytato.utils import normalized_slice_does_not_change_axis @@ -790,4 +794,64 @@ def to_index_lambda(expr: Array) -> IndexLambda: assert isinstance(res, IndexLambda) return res + +class MapAsIndexLambdaMixin(ABC, Generic[ResultT, P]): + """ + Mixin that, where possible, lowers arrays to :class:`~pytato.array.IndexLambda` + and calls :meth:`map_as_index_lambda` on them. + + .. automethod:: map_as_index_lambda + """ + @abstractmethod + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + """ + Map *expr* via its :class:`~pytato.array.IndexLambda` representation + *idx_lambda*. + """ + + def map_index_lambda( + self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, expr, *args, **kwargs) + + def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_axis_permutation( + self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_basic_index( + self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_concatenate( + self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + # vim:fdm=marker diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 36aab2391..957c3ce35 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -54,7 +54,7 @@ ) from bidict import bidict -from typing_extensions import Never +from typing_extensions import Never, override import pymbolic.primitives as prim from pymbolic.typing import Expression @@ -63,12 +63,7 @@ from pytato.array import ( AbstractResultWithNamedArrays, - AdvancedIndexInContiguousAxes, Array, - AxisPermutation, - BasicIndex, - Concatenate, - CSRMatmul, DictOfNamedArrays, Einsum, EinsumReductionAxis, @@ -76,7 +71,6 @@ InputArgumentBase, NamedArray, Reshape, - Stack, ) from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult @@ -91,7 +85,7 @@ Mapper, TransformMapperCache, ) -from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.transform.lower_to_index_lambda import MapAsIndexLambdaMixin logger = logging.getLogger(__name__) @@ -156,7 +150,9 @@ def map_constant(self, expr: object) -> dict[BindingName, # {{{ AxesTagsEquationCollector -class AxesTagsEquationCollector(Mapper[None, Never, []]): +class AxesTagsEquationCollector( + Mapper[None, Never, []], + MapAsIndexLambdaMixin[None, []]): r""" Records equations arising from operand/output axes equivalence for an array operation. An equation is recorded for "straight-through" axes in expressions, @@ -183,15 +179,12 @@ class AxesTagsEquationCollector(Mapper[None, Never, []]): iaxis)`` to the :class:`str` by which it will be referenced in :attr:`equations`. - .. automethod:: map_index_lambda .. automethod:: map_placeholder .. automethod:: map_data_wrapper .. automethod:: map_size_param .. automethod:: map_reshape - .. automethod:: map_basic_index - .. automethod:: map_contiguous_advanced_index - .. automethod:: map_stack - .. automethod:: map_concatenate + + .. automethod:: map_as_index_lambda .. note:: @@ -281,20 +274,16 @@ def _map_input_base(self, expr: InputArgumentBase) -> None: map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_index_lambda(self, expr: IndexLambda) -> None: - for bnd in expr.bindings.values(): - self.rec(bnd) - - self.add_equations_using_index_lambda_version_of_expr(expr) - - def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None: + @override + def map_as_index_lambda(self, expr: Array, idx_lambda: IndexLambda) -> None: """ Equations are added between an axis of the bindings of *expr* and an axis of *expr* if the binding's axis is indexed by by a :class:`~pymbolic.Variable` which has a name that follows the reserved iname format, "_[0-9]+", and the axis of the output specified by the iname. """ - idx_lambda = expr if isinstance(expr, IndexLambda) else to_index_lambda(expr) + for bnd in idx_lambda.bindings.values(): + self.rec(bnd) index_expr_used = BindingSubscriptsCollector()(idx_lambda.expr) @@ -327,32 +316,7 @@ def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None: # Other cases are considered "complicated" and we won't # handle them here. - def map_stack(self, expr: Stack) -> None: - for ary in expr.arrays: - self.rec(ary) - - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_concatenate(self, expr: Concatenate) -> None: - for ary in expr.arrays: - self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_axis_permutation(self, expr: AxisPermutation - ) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_basic_index(self, expr: BasicIndex) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_contiguous_advanced_index(self, - expr: AdvancedIndexInContiguousAxes - ) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - + @override def map_reshape(self, expr: Reshape) -> None: """ Reshaping generally does not preserve the axis between its input and @@ -386,20 +350,6 @@ def map_reshape(self, expr: Reshape) -> None: assert i_in_axis == expr.array.ndim - def map_einsum(self, expr: Einsum) -> None: - for arg in expr.args: - self.rec(arg) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_csr_matmul(self, expr: CSRMatmul) -> None: - for ary in ( - expr.matrix.elem_values, - expr.matrix.elem_col_indices, - expr.matrix.row_starts, - expr.array): - self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) - def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: for _, subexpr in sorted(expr._data.items()): self.rec(subexpr)