Skip to content
45 changes: 44 additions & 1 deletion src/psyclone/psyGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,11 @@ class Transformation(metaclass=abc.ABCMeta):
"User guide for more details."
)

#: List of transformations called inside this one that need to be
#: considered by the split_kwargs infrastructure when propagating
#: transformation options
_SUB_TRANSFORMATIONS = []

@property
def name(self):
'''
Expand All @@ -2353,6 +2358,36 @@ def name(self):
'''
return type(self).__name__

def split_kwargs(
self,
kwargs: dict[str, Any],
) -> tuple[dict[str, Any]]:
'''
:param kwargs: the list of kwargs to split.

:returns: a tuple of the kwargs dictionaries that are valid for this
transformation and every other transformation listed in the
_SUB_TRANSFORMATIONS list. The first kwargs (the ones for itself)
will also include any key that is not valid in any of the other
transformation (e.g. unsupported options).
Comment thread
LonelyCat124 marked this conversation as resolved.
Outdated
'''
# The first kwargs starts with all the items
first_dict = dict(kwargs)
# The following kwargs start empty
other_dicts = [{} for _ in self._SUB_TRANSFORMATIONS]

# Now copy each valid item into the transformation-specific kwargs
# and delete them from the first one if they are valid somewhere
# else but not in the self options
for key in kwargs:
for idx, trans in enumerate(self._SUB_TRANSFORMATIONS):
if key in trans.get_valid_options():
other_dicts[idx][key] = kwargs[key]
if key not in type(self).get_valid_options():
del first_dict[key]

return tuple([first_dict] + other_dicts)
Comment thread
LonelyCat124 marked this conversation as resolved.
Outdated

@abc.abstractmethod
def apply(self, node, options=None, **kwargs):
'''Abstract method that applies the transformation. This function
Expand Down Expand Up @@ -2509,10 +2544,18 @@ def validate_options(self, **kwargs):
for invalid in invalid_options:
invalid_options_detail.append(f"'{invalid}'")
invalid_options_list = ", ".join(invalid_options_detail)
extra_options = ""
if self._SUB_TRANSFORMATIONS:
sub_trans_names = [
tr.__name__ for tr in self._SUB_TRANSFORMATIONS
]
extra_options = (
f" or any other options supported in {sub_trans_names}"
)
raise ValueError(f"'{type(self).__name__}' received invalid "
f"options [{invalid_options_list}]. "
f"Valid options are "
f"'{list(valid_options.keys())}.")
f"{list(valid_options.keys())}{extra_options}.")
if len(wrong_types.keys()) > 0:
wrong_types_detail = []
for name in wrong_types.keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class MaximalOMPParallelRegionTrans(MaximalRegionTrans):
the discussion on #3205 for more detail.'''
# The type of parallel transformation to be applied to the input region.
_transformation = OMPParallelTrans
_SUB_TRANSFORMATIONS = [OMPParallelTrans]
# Tuple of statement nodes allowed inside the _transformation
_allowed_contiguous_statements = (
OMPTaskwaitDirective,
Expand Down
24 changes: 15 additions & 9 deletions src/psyclone/psyir/transformations/maximal_region_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Authors A. B. G. Chalk, STFC Daresbury Lab
# Author: A. B. G. Chalk, STFC Daresbury Lab
# Modified: S. Siso, STFC Daresbury Lab

'''This module contains the MaximalRegionTrans.'''

import abc
from typing import Union
from typing import Union, Any

from psyclone.psyir.nodes import (
Node,
Expand Down Expand Up @@ -148,13 +149,15 @@ def _can_be_in_region(self, node: Node) -> bool:
def _compute_transformable_sections(
self, node_list: list[Node],
trans: Transformation,
trans_kwargs: dict[str, Any]
) -> list[list[Node]]:
'''
Computes the sections of the input node_list to apply the
transformation to.

:param node_list: The node_list passed into this Transformation.
:param trans: The transformation applied to the regions found.
:param trans_kwargs: The kwargs applied to the transformation.
:returns: The list of node_lists to apply this class'
_transformation class to.
'''
Expand All @@ -168,7 +171,7 @@ def _compute_transformable_sections(
# Check that validation still succeeds if we add this child
# to the current block.
try:
trans.validate(current_block + [child])
trans.validate(current_block + [child], **trans_kwargs)
current_block.append(child)
except TransformationError:
# If validation now fails, then don't add this to the
Expand All @@ -189,17 +192,17 @@ def _compute_transformable_sections(
# Need to recurse on some node types
if isinstance(child, IfBlock):
if_blocks = self._compute_transformable_sections(
child.if_body, trans
child.if_body, trans, trans_kwargs
)
all_blocks.extend(if_blocks)
if child.else_body:
else_blocks = self._compute_transformable_sections(
child.else_body, trans
child.else_body, trans, trans_kwargs
)
all_blocks.extend(else_blocks)
if isinstance(child, (Loop, WhileLoop)):
loop_blocks = self._compute_transformable_sections(
child.loop_body, trans
child.loop_body, trans, trans_kwargs
)
all_blocks.extend(loop_blocks)
# If any nodes are left in the current block at the end of the
Expand All @@ -220,7 +223,8 @@ def validate(self, nodes: Union[Node, Schedule, list[Node]], **kwargs):
same parent and aren't consecutive.
'''

self.validate_options(**kwargs)
self_kwargs, _ = self.split_kwargs(kwargs)
self.validate_options(**self_kwargs)
node_list = self.get_node_list(nodes)

node_parent = node_list[0].parent
Expand All @@ -247,11 +251,13 @@ def apply(self, nodes: Union[Node, Schedule, list[Node]], **kwargs):

# Call validate.
self.validate(nodes, **kwargs)
_, tr_kwargs = self.split_kwargs(kwargs)

par_trans = self._transformation()

all_blocks = self._compute_transformable_sections(node_list, par_trans)
all_blocks = self._compute_transformable_sections(
node_list, par_trans, tr_kwargs)

# Apply the transformation to all of the blocks found.
for block in all_blocks:
par_trans.apply(block)
par_trans.apply(block, **tr_kwargs)
101 changes: 100 additions & 1 deletion src/psyclone/tests/psyGen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,105 @@ def apply(self, node, valid: bool = True):
"Valid options are '['valid']." in str(excinfo.value))


def test_transformation_split_kwargs():
''' Test that the kwargs can be split when they can be propagated to
multiple sub-transformations. '''

class Called1Trans(Transformation):
''' Transformation Example'''
def apply(
self,
node,
test_option: bool = True,
common_option: bool = True,
**kwargs
):
self.validate(
node,
test_option=test_option,
common_option=common_option,
**kwargs
)
# Asserts to prove that the True value was propagated until here
assert test_option
assert common_option

class Called2Trans(Transformation):
''' Transformation Example'''
def apply(
self,
node,
test2_option: bool = False,
common_option: bool = False,
**kwargs
):
self.validate(
node,
test2_option=test2_option,
common_option=common_option,
**kwargs
)
# Asserts to prove that the True value was propagated until here
assert test2_option
assert common_option

class TestMetaTrans(Transformation):
''' MetaTrans Example'''
_trans1 = Called1Trans
_trans2 = Called2Trans
_SUB_TRANSFORMATIONS = [Called1Trans, Called2Trans]

def validate(self, node, **kwargs):
self_kwargs, tr1_kwargs, tr2_kwargs = self.split_kwargs(kwargs)
self._trans1().validate(node, tr1_kwargs)
self._trans2().validate(node, tr2_kwargs)
self.validate_options(**self_kwargs)

super().validate(node, **self_kwargs)

def apply(
self,
node,
meta_option: bool = True,
common_option: bool = True,
**kwargs
):
# If we want a keyword argument to not be exclusively consumed by
# this transformation and propagate it, put it back into kwargs
kwargs['common_option'] = common_option
# If we want to consume it use it by name
self.validate(
node,
meta_option=meta_option,
**kwargs)

self._trans1().apply(node, **kwargs)
Comment thread
LonelyCat124 marked this conversation as resolved.
Outdated
self._trans2().apply(node, **kwargs)

# Asserts to prove that the True value was propagated until here
assert meta_option
assert common_option

test = TestMetaTrans()
test.apply(Node(), meta_option=True, common_option=True,
test_option=True, test2_option=True)
test.validate(Node(), meta_option=True, common_option=True,
test_option=True, test2_option=True)

with pytest.raises(ValueError) as err:
test.apply(Node(), invalid=True)
assert ("'TestMetaTrans' received invalid options ['invalid']. Valid "
"options are ['meta_option', 'common_option'] or any other "
"options supported in ['Called1Trans', 'Called2Trans']."
== str(err.value))
with pytest.raises(ValueError) as err:
test.validate(Node(), invalid=True)
assert ("'TestMetaTrans' received invalid options ['invalid']. Valid "
"options are ['meta_option', 'common_option'] or any other "
"options supported in ['Called1Trans', 'Called2Trans']."
== str(err.value))


def test_transformation_apply_deprecation_message(capsys):
'''Test that passing the options dict to the Transformation.apply
function gets the expected deprecation message.'''
Expand Down Expand Up @@ -313,7 +412,7 @@ def apply(self, node, valid: bool = True, options=None):
with pytest.raises(ValueError) as excinfo:
instance.validate_options(not_valid=True)
assert ("'TestTrans' received invalid options ['not_valid']. "
"Valid options are '['valid', 'options']." in str(excinfo.value))
"Valid options are ['valid', 'options']." in str(excinfo.value))


# TransInfo class unit tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class MaxParTrans(MaximalRegionTrans):
''' Dummy class to test MaxParallelRegionTrans' functionality. '''
# The apply function will do OMPParallelTrans around allowed regions.
_transformation = OMPParallelTrans
_SUB_TRANSFORMATIONS = [OMPParallelTrans]
# We're only allowing assignment because its straightforward to test with.
_allowed_contiguous_statements = (Assignment, )
# Should parallelise any found region that contains an assignment.
Expand Down Expand Up @@ -265,6 +266,7 @@ def apply(self, node, **kwargs):
class OneParTrans(MaximalRegionTrans):
'''Dummy MaximalRegionTrans that uses our FakeTrans'''
_transformation = Faketrans
_SUB_TRANSFORMATIONS = [Faketrans]
_allowed_contiguous_statements = (Assignment, )
_required_nodes = (Assignment, )

Expand Down Expand Up @@ -322,6 +324,7 @@ def apply(self, node: Assignment, **kwargs):
class OneParTrans(MaximalRegionTrans):
'''Dummy MaximalRegionTrans that uses our FakeTrans'''
_transformation = Faketrans
_SUB_TRANSFORMATIONS = [Faketrans]
_allowed_contiguous_statements = (Assignment, )
_required_nodes = (Assignment, )

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def test_omploop_trans_new_options(sample_psyir):
with pytest.raises(ValueError) as excinfo:
omplooptrans.apply(tree.walk(Loop)[0], fakeoption1=1, fakeoption2=2)
assert ("'OMPLoopTrans' received invalid options ['fakeoption1', "
"'fakeoption2']. Valid options are '['node_type_check', "
"'fakeoption2']. Valid options are ['node_type_check', "
"'verbose', 'collapse', 'force', 'ignore_dependencies_for', "
"'privatise_arrays', 'sequential', 'nowait', 'reduction_ops', "
"'force_private', 'options', 'reprod', 'enable_reductions']."
Expand Down
Loading