Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ class ReductionDescriptor(Taggable):
Records information about a reduction dimension in an
:class:`~pytato.Array`'.
"""
tags: frozenset[Tag]
tags: frozenset[Tag] = frozenset()

@override
def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor:
Expand Down Expand Up @@ -1766,7 +1766,7 @@ def einsum(subscripts: str, *operands: Array,
for descr in index_to_descr.values():
if (isinstance(descr, EinsumReductionAxis)
and descr not in redn_axis_to_redn_descr):
redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset())
redn_axis_to_redn_descr[descr] = ReductionDescriptor()

# }}}

Expand Down Expand Up @@ -2339,9 +2339,21 @@ class CSRMatmul(SparseMatmul):
.. attribute:: array

The :class:`Array` to which the sparse matrix is being applied.

.. attribute:: reduction_var

The index variable for the per-row reduction.

.. attribute:: reduction_descr

The :class:`ReductionDescriptor` for the per-row reduction.
"""
matrix: CSRMatrix
array: Array
reduction_var: str = "_r0"
reduction_descr: ReductionDescriptor = dataclasses.field(
kw_only=True,
default_factory=ReductionDescriptor)

@property
@override
Expand All @@ -2353,6 +2365,24 @@ def _matrix(self) -> SparseMatrix:
def _array(self) -> Array:
return self.array

def with_tagged_reduction(self, tags: Tag | Iterable[Tag]) -> CSRMatmul:
"""
Returns a copy of *self* with its :class:`ReductionDescriptor`
tagged with *tag*.
"""
new_redn_descr = self.reduction_descr.tagged(tags)
if new_redn_descr is not self.reduction_descr:
return type(self)(
matrix=self.matrix,
array=self.array,
axes=self.axes,
reduction_descr=new_redn_descr,
tags=self.tags,
non_equality_tags=self.non_equality_tags)
else:
return self


# }}}


Expand Down Expand Up @@ -3271,7 +3301,7 @@ def make_index_lambda(

for redn_var in redn_vars:
redn_descr = var_to_reduction_descr.get(redn_var,
ReductionDescriptor(frozenset()))
ReductionDescriptor())
if not isinstance(redn_descr, ReductionDescriptor):
raise TypeError(f"reduction_dim for {redn_var} expected to be"
f" of type ReductionDescriptor, got {type(redn_descr)}.")
Expand Down
2 changes: 1 addition & 1 deletion pytato/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _get_var_to_redn_descr(
idx = f"_r{n_redn_dims}"
redn_descr = axis_to_reduction_descr.get(
idim,
ReductionDescriptor(frozenset()))
ReductionDescriptor())
if not isinstance(redn_descr, ReductionDescriptor):
raise TypeError(f"'axis_to_reduction_descr[{idim}]': "
"expected an instance of ReductionDescriptor, "
Expand Down
2 changes: 1 addition & 1 deletion pytato/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _map_generic_array(self, expr: Array, depth: int) -> str:
fields = tuple(field for field in fields if field != "axes")

if (isinstance(expr, IndexLambda)
and all(redn_descr == ReductionDescriptor(frozenset())
and all(redn_descr == ReductionDescriptor()
for redn_descr in expr.var_to_reduction_descr.values())):
# prettify: if trivial 'expr.var_to_reduction_descr' => don't print.
fields = tuple(field
Expand Down
11 changes: 6 additions & 5 deletions pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
IndexExpr,
IndexLambda,
NormalizedSlice,
ReductionDescriptor,
Reshape,
Roll,
ShapeComponent,
Expand Down Expand Up @@ -732,6 +731,8 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda:
non_equality_tags=expr.non_equality_tags)

def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda:
redn_var = expr.reduction_var

rec_matrix_elem_values = self.rec(expr.matrix.elem_values)
rec_matrix_elem_col_indices = self.rec(expr.matrix.elem_col_indices)
rec_matrix_row_starts = self.rec(expr.matrix.row_starts)
Expand All @@ -740,15 +741,15 @@ def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda:
from pytato.reductions import SumReductionOperation
from pytato.scalar_expr import Reduce
index_expr = Reduce(
prim.Variable("_in0")[prim.Variable("_r0"),]
prim.Variable("_in0")[prim.Variable(redn_var),]
* prim.Variable("_in3")[(
prim.Variable("_in1")[prim.Variable("_r0"),],
prim.Variable("_in1")[prim.Variable(redn_var),],
*(
prim.Variable(f"_{idim}")
for idim in range(1, rec_array.ndim)))],
SumReductionOperation(),
constantdict({
"_r0": (
redn_var: (
prim.Variable("_in2")[prim.Variable("_0"),],
prim.Variable("_in2")[prim.Variable("_0") + 1,])}))

Expand All @@ -762,7 +763,7 @@ def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda:
"_in3": rec_array}),
axes=expr.axes,
var_to_reduction_descr=constantdict({
"_r0": ReductionDescriptor(tags=frozenset())}),
redn_var: expr.reduction_descr}),
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

Expand Down
5 changes: 5 additions & 0 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,11 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array:
self.axis_to_tags.get((expr, redn_var), [])
)

if isinstance(expr, CSRMatmul):
assert isinstance(result, CSRMatmul)
tags = self.axis_to_tags.get((expr, expr.reduction_var), [])
result = result.with_tagged_reduction(tags)

# }}}

return result
Expand Down
Loading