Add broadcast-compatible scale support for RmsNorm #332

Open
rsuderman wants to merge 1 commit into iree-org:main from rsuderman:rmsnorm_broadcast

@rsuderman
Contributor

Allow RmsNorm scale tensors to have broadcast-compatible shapes (e.g. {1,c,1,1} instead of {1,c,h,w}), letting torch.aten.rms_norm handle the broadcasting internally. Also update getScalarConstantAsm to emit the actual tensor shape instead of hardcoding tensor<1x...>.
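To make the shape relaxation concrete, here is a minimal sketch of a broadcast-compatibility check. It assumes the usual rule (equal rank, and each scale dimension is either 1 or equal to the matching input dimension); the PR text does not spell out the exact rule, and `is_broadcast_compatible` is a hypothetical helper for illustration, not a Fusilli function.

```python
def is_broadcast_compatible(scale_shape, input_shape):
    """Return True if scale_shape can broadcast against input_shape.

    Assumed rule (not confirmed by the PR): ranks must match, and every
    scale dimension is either 1 or equal to the matching input dimension.
    """
    if len(scale_shape) != len(input_shape):
        return False
    return all(s == 1 or s == d for s, d in zip(scale_shape, input_shape))


# Example with n=8, c=4, h=16, w=16.
input_shape = [8, 4, 16, 16]
full_scale = [1, 4, 16, 16]       # {1,c,h,w}: the per-element scale
channel_scale = [1, 4, 1, 1]      # {1,c,1,1}: a per-channel scale
bad_scale = [1, 3, 1, 1]          # channel count does not match

print(is_broadcast_compatible(full_scale, input_shape))     # True
print(is_broadcast_compatible(channel_scale, input_shape))  # True
print(is_broadcast_compatible(bad_scale, input_shape))      # False
```

Under this assumed rule the existing {1,c,h,w} case stays valid, and {1,c,1,1} becomes valid as well.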

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
Member

@sjain-stanford left a comment

In isolation this change is fine, but I'm not sure this is a level of flexibility we should allow. Both PyTorch and cuDNN consistently use a per-element scale for RMS norm (i.e. one of the normalized shape, matching the input shape modulo batch dims). Why would we want to differ from that, or be more general than that? If this change was prompted by the need to bridge hipDNN -> Fusilli, then it's a placebo, and the real issue lies in the hipDNN scale not matching normalized shapes.
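For reference, the per-element alternative described in this comment can be sketched as materializing a broadcast scale to the normalized shape before it reaches the op. This is only an illustration under assumptions: tensors are flat row-major lists, there is one leading batch dimension by default, and both `per_element_scale_shape` and `expand_scale` are hypothetical helpers, not hipDNN or Fusilli APIs.

```python
from itertools import product


def per_element_scale_shape(input_shape, num_batch_dims=1):
    """Scale shape matching the input with batch dims set to 1 (assumed layout)."""
    return [1] * num_batch_dims + list(input_shape[num_batch_dims:])


def expand_scale(scale, scale_shape, target_shape):
    """Materialize a broadcast scale (flat, row-major) to target_shape."""
    if len(scale_shape) != len(target_shape):
        raise ValueError("rank mismatch")
    if any(s != 1 and s != t for s, t in zip(scale_shape, target_shape)):
        raise ValueError("scale is not broadcast-compatible with target")

    # Row-major strides of the scale; size-1 dims get stride 0 so the
    # same element is reused along that axis.
    strides = []
    step = 1
    for dim in reversed(scale_shape):
        strides.append(0 if dim == 1 else step)
        step *= dim
    strides.reverse()

    # Visit target indices in row-major order and gather from the scale.
    return [
        scale[sum(i * s for i, s in zip(idx, strides))]
        for idx in product(*(range(d) for d in target_shape))
    ]


# A {1,c,1,1} scale with c=2 expanded to {1,c,h,w} with h=w=2.
target = per_element_scale_shape([8, 2, 2, 2])   # [1, 2, 2, 2]
full = expand_scale([10.0, 20.0], [1, 2, 1, 1], target)
print(full)  # [10.0, 10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 20.0]
```

The trade-off is between expanding the scale upstream, which keeps the op's contract identical to the per-element convention, and accepting the smaller shape directly, which is what this PR proposes.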
