Add AMAX, AVG, NORM1, NORM2, MUL, MUL_NO_ZEROS reduction modes #325

Open
rsuderman wants to merge 3 commits into iree-org:main from rsuderman:reduction_rest
Conversation

@rsuderman
Contributor

Enable the remaining cuDNN reduction modes in ReductionAttr and add the corresponding MLIR schemas to the asm emitter:

  • NORM1 lowers to abs + sum.dim_IntList.
  • AMAX lowers to abs + amax.
  • AVG lowers to mean.dim (float dtypes only — torch.aten.mean.dim is not defined on integer tensors, so the sample skips int32 for AVG).
  • NORM2 lowers to mul + sum.dim_IntList + sqrt.
  • MUL lowers directly to torch.prims.prod.
  • MUL_NO_ZEROS uses aten.ne.Scalar to build an i1 mask, then aten.where.ScalarOther to substitute 1 for zero entries before feeding the result to torch.prims.prod, so zero inputs are excluded from the product.
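
As a rough illustration of what these decompositions compute, here is a minimal scalar sketch that reduces a whole vector under each mode. The `Mode` enum and `reduceAll` function are hypothetical names for this example only, not the project's `ReductionAttr` or its emitter, and the sketch assumes `mul` in the NORM2 lowering means an elementwise square.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical mode enum for this sketch; not the project's ReductionAttr.
enum class Mode { NORM1, AMAX, AVG, NORM2, MUL, MUL_NO_ZEROS };

// Reduces every element of `x` to one scalar, following the decomposition
// listed for each mode in the description above.
inline double reduceAll(Mode mode, const std::vector<double> &x) {
  assert(!x.empty() && "sketch assumes a non-empty input");
  double acc = 0.0;
  switch (mode) {
  case Mode::NORM1: // abs, then sum
    for (double v : x) acc += std::fabs(v);
    return acc;
  case Mode::AMAX: // abs, then max (0 is a safe start since abs >= 0)
    for (double v : x) acc = std::fmax(acc, std::fabs(v));
    return acc;
  case Mode::AVG: // mean
    for (double v : x) acc += v;
    return acc / static_cast<double>(x.size());
  case Mode::NORM2: // square (assumed meaning of mul), sum, then sqrt
    for (double v : x) acc += v * v;
    return std::sqrt(acc);
  case Mode::MUL: // plain product; any zero makes the result zero
    acc = 1.0;
    for (double v : x) acc *= v;
    return acc;
  case Mode::MUL_NO_ZEROS: // mask = (v != 0); where(mask, v, 1); product
    acc = 1.0;
    for (double v : x) acc *= (v != 0.0) ? v : 1.0;
    return acc;
  }
  return acc; // not reached for valid modes
}
```

For the input {3, -4, 0}, MUL yields 0 while MUL_NO_ZEROS yields -12, which is the difference the mask-and-substitute step is meant to produce.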

Extend samples/reduction/reduction_ops.cpp to exercise every new mode. A per-mode generateReductionInputData helper builds the input data, so MUL and MUL_NO_ZEROS get a non-trivial pattern (mostly 1s with a 2 and a 3, plus injected zeros for MUL_NO_ZEROS) that stays in range for fp16 and int32. The expected value is computed by the existing reference reduction loop rather than hardcoded.
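
To make the input pattern concrete, the following is a sketch of how such data could be built and why it stays in range. `makeProductPattern` and `product` are hypothetical helpers written for this example; the actual signature and zero placement of generateReductionInputData are not shown in this description, so the positions chosen here are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not the sample's generateReductionInputData.
// Fills `n` elements with 1, sets one element to 2 and one to 3, and
// optionally overwrites every fourth element starting at index 3 with 0.
template <typename T>
std::vector<T> makeProductPattern(std::size_t n, bool injectZeros) {
  assert(n >= 4 && "sketch needs room for 2, 3, and at least one zero");
  std::vector<T> data(n, T(1));
  data[0] = T(2);
  data[1] = T(3);
  if (injectZeros) {
    for (std::size_t i = 3; i < n; i += 4) data[i] = T(0);
  }
  return data;
}

// Reference product; with `skipZeros` set, zero entries are ignored,
// which matches substituting 1 for them before multiplying.
template <typename T>
T product(const std::vector<T> &data, bool skipZeros) {
  T acc = T(1);
  for (T v : data) {
    if (skipZeros && v == T(0)) continue;
    acc *= v;
  }
  return acc;
}
```

With this pattern the product is 6 regardless of length, so it cannot overflow int32 or exceed the fp16 range, and the injected zeros separate the two modes: a plain product gives 0 while the zero-skipping product still gives 6.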

Add lit tests for each new mode under tests/lit/ and register them in tests/CMakeLists.txt.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>

# Conflicts:
#	include/fusilli/support/asm_emitter.h
#	samples/reduction/reduction_ops.cpp
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>