Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.. currentmodule:: pytato.transform.lower_to_index_lambda

.. autofunction:: to_index_lambda
.. autoclass:: MapAsIndexLambdaMixin
"""
from __future__ import annotations

Expand All @@ -28,8 +29,9 @@
THE SOFTWARE.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

from constantdict import constantdict
from typing_extensions import Never
Expand Down Expand Up @@ -65,6 +67,8 @@
from pytato.tags import AssumeNonNegative
from pytato.transform import (
Mapper,
P,
ResultT,
_verify_is_array,
)
from pytato.utils import normalized_slice_does_not_change_axis
Expand Down Expand Up @@ -790,4 +794,64 @@ def to_index_lambda(expr: Array) -> IndexLambda:
assert isinstance(res, IndexLambda)
return res


class MapAsIndexLambdaMixin(ABC, Generic[ResultT, P]):
"""
Mixin that, where possible, lowers arrays to :class:`~pytato.array.IndexLambda`
and calls :meth:`map_as_index_lambda` on them.

.. automethod:: map_as_index_lambda
"""
@abstractmethod
def map_as_index_lambda(
self, expr: Array, idx_lambda: IndexLambda,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
"""
Map *expr* via its :class:`~pytato.array.IndexLambda` representation
*idx_lambda*.
"""

def map_index_lambda(
self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, expr, *args, **kwargs)

def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_axis_permutation(
self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_basic_index(
self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_contiguous_advanced_index(
self, expr: AdvancedIndexInContiguousAxes,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_non_contiguous_advanced_index(
self, expr: AdvancedIndexInNoncontiguousAxes,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_concatenate(
self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)

def map_csr_matmul(
self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs)


# vim:fdm=marker
74 changes: 12 additions & 62 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)

from bidict import bidict
from typing_extensions import Never
from typing_extensions import Never, override

import pymbolic.primitives as prim
from pymbolic.typing import Expression
Expand All @@ -63,20 +63,14 @@

from pytato.array import (
AbstractResultWithNamedArrays,
AdvancedIndexInContiguousAxes,
Array,
AxisPermutation,
BasicIndex,
Concatenate,
CSRMatmul,
DictOfNamedArrays,
Einsum,
EinsumReductionAxis,
IndexLambda,
InputArgumentBase,
NamedArray,
Reshape,
Stack,
)
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.function import NamedCallResult
Expand All @@ -91,7 +85,7 @@
Mapper,
TransformMapperCache,
)
from pytato.transform.lower_to_index_lambda import to_index_lambda
from pytato.transform.lower_to_index_lambda import MapAsIndexLambdaMixin


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -156,7 +150,9 @@ def map_constant(self, expr: object) -> dict[BindingName,
# {{{ AxesTagsEquationCollector


class AxesTagsEquationCollector(Mapper[None, Never, []]):
class AxesTagsEquationCollector(
Mapper[None, Never, []],
MapAsIndexLambdaMixin[None, []]):
r"""
Records equations arising from operand/output axes equivalence for an array
operation. An equation is recorded for "straight-through" axes in expressions,
Expand All @@ -183,15 +179,12 @@ class AxesTagsEquationCollector(Mapper[None, Never, []]):
iaxis)`` to the :class:`str` by which it will be referenced in
:attr:`equations`.

.. automethod:: map_index_lambda
.. automethod:: map_placeholder
.. automethod:: map_data_wrapper
.. automethod:: map_size_param
.. automethod:: map_reshape
.. automethod:: map_basic_index
.. automethod:: map_contiguous_advanced_index
.. automethod:: map_stack
.. automethod:: map_concatenate

.. automethod:: map_as_index_lambda

.. note::

Expand Down Expand Up @@ -281,20 +274,16 @@ def _map_input_base(self, expr: InputArgumentBase) -> None:
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_index_lambda(self, expr: IndexLambda) -> None:
for bnd in expr.bindings.values():
self.rec(bnd)

self.add_equations_using_index_lambda_version_of_expr(expr)

def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None:
@override
def map_as_index_lambda(self, expr: Array, idx_lambda: IndexLambda) -> None:
"""
Equations are added between an axis of the bindings of *expr* and an axis of
*expr* if the binding's axis is indexed by by a :class:`~pymbolic.Variable`
which has a name that follows the reserved iname format, "_[0-9]+", and the axis
of the output specified by the iname.
"""
idx_lambda = expr if isinstance(expr, IndexLambda) else to_index_lambda(expr)
for bnd in idx_lambda.bindings.values():
self.rec(bnd)

index_expr_used = BindingSubscriptsCollector()(idx_lambda.expr)

Expand Down Expand Up @@ -327,32 +316,7 @@ def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None:
# Other cases are considered "complicated" and we won't
# handle them here.

def map_stack(self, expr: Stack) -> None:
for ary in expr.arrays:
self.rec(ary)

self.add_equations_using_index_lambda_version_of_expr(expr)

def map_concatenate(self, expr: Concatenate) -> None:
for ary in expr.arrays:
self.rec(ary)
self.add_equations_using_index_lambda_version_of_expr(expr)

def map_axis_permutation(self, expr: AxisPermutation
) -> None:
self.rec(expr.array)
self.add_equations_using_index_lambda_version_of_expr(expr)

def map_basic_index(self, expr: BasicIndex) -> None:
self.rec(expr.array)
self.add_equations_using_index_lambda_version_of_expr(expr)

def map_contiguous_advanced_index(self,
expr: AdvancedIndexInContiguousAxes
) -> None:
self.rec(expr.array)
self.add_equations_using_index_lambda_version_of_expr(expr)

@override
def map_reshape(self, expr: Reshape) -> None:
"""
Reshaping generally does not preserve the axis between its input and
Expand Down Expand Up @@ -386,20 +350,6 @@ def map_reshape(self, expr: Reshape) -> None:

assert i_in_axis == expr.array.ndim

def map_einsum(self, expr: Einsum) -> None:
for arg in expr.args:
self.rec(arg)
self.add_equations_using_index_lambda_version_of_expr(expr)

def map_csr_matmul(self, expr: CSRMatmul) -> None:
for ary in (
expr.matrix.elem_values,
expr.matrix.elem_col_indices,
expr.matrix.row_starts,
expr.array):
self.rec(ary)
self.add_equations_using_index_lambda_version_of_expr(expr)

def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
for _, subexpr in sorted(expr._data.items()):
self.rec(subexpr)
Expand Down
Loading