[ROCm]: fix: reduce MoE temp memory — embedding cap, weight sum default, skip trivial specs (PR3)#4193

Open
cj401-amd wants to merge 4 commits into AI-Hypercomputer:main from cj401-amd:cj/tmem-fixes-clean-3-moe-tmem

@cj401-amd (Collaborator)

Summary

  • Embeddings: restrict use_iota_embed to one-hot sizes of ≤2 GiB to prevent OOM
    on large vocabularies; add an explicit nn.with_logical_constraint after the
    embedding lookup
  • MoE config: change the float32_weight_sum default from true to false; the f32
    upcast adds ~2 GB of temp memory per device with minimal numerical benefit for
    most configs
  • DeepSeek: fix the activation PartitionSpec to include the fsdp_transpose and
    context axes; use the remove_size_one_mesh_axis helper; remove redundant
    jax.reshard calls
  • Mixtral: replace nn.with_logical_constraint with
    maybe_shard_with_logical(..., skip_trivial_specs=True) throughout
    MixtralDecoderLayer to avoid no-op sharding constraints that add XLA overhead
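
To make the embedding cap concrete, here is a minimal sketch of the kind of size check the first bullet describes. It is not the code from this PR: the function names, the size formula (batch × sequence length × vocabulary size × bytes per element), and the 2-byte default element size are assumptions for illustration only.

```python
# Hypothetical sketch of a one-hot size cap; names and formula are assumptions,
# not taken from the PR diff.

TWO_GIB = 2 * 1024**3  # the 2 GiB cap mentioned in the summary


def one_hot_bytes(batch, seq_len, vocab_size, bytes_per_element=2):
    """Estimated size of a dense one-hot tensor of shape [batch, seq_len, vocab_size]."""
    return batch * seq_len * vocab_size * bytes_per_element


def should_use_iota_embed(requested, batch, seq_len, vocab_size,
                          bytes_per_element=2, cap_bytes=TWO_GIB):
    """Keep the one-hot (iota) path only if it was requested and fits under the cap."""
    size = one_hot_bytes(batch, seq_len, vocab_size, bytes_per_element)
    return bool(requested) and size <= cap_bytes


if __name__ == "__main__":
    # 8 x 2048 tokens with a 32,000-entry vocabulary stays under the cap.
    print(should_use_iota_embed(True, 8, 2048, 32_000))
    # The same batch with a 129,280-entry vocabulary exceeds it.
    print(should_use_iota_embed(True, 8, 2048, 129_280))
```

Under these assumptions the estimate grows linearly with vocabulary size, which is why a check like this matters most for large-vocabulary configs.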

Test plan

  • python3 -m pytest tests/unit/train_compile_test.py -v -k "moe or deepseek or mixtral"
  • Smoke-test MoE model (e.g. mixtral-8x7b or deepseek3-test config)

@cj401-amd force-pushed the cj/tmem-fixes-clean-3-moe-tmem branch from 666bf09 to 0ed140e on June 18, 2026 22:42
@codecov

codecov Bot commented Jun 18, 2026
