Skip to content

[ROCm]: fix: reduce pipeline temp memory — replace ppermute collectives with lax.slice/pad (PR2) #4192

Open
cj401-amd wants to merge 3 commits into
AI-Hypercomputer:mainfrom
cj401-amd:cj/tmem-fixes-clean-2-pipeline-tmem
Open

[ROCm]: fix: reduce pipeline temp memory — replace ppermute collectives with lax.slice/pad (PR2) #4192
cj401-amd wants to merge 3 commits into
AI-Hypercomputer:mainfrom
cj401-amd:cj/tmem-fixes-clean-2-pipeline-tmem

Conversation

@cj401-amd

Copy link
Copy Markdown
Collaborator

Summary

Replace shard_map + ppermute collective operations in the pipeline with pure
lax.slice/jnp.pad/jnp.concatenate equivalents. This eliminates the shard_map
overhead and removes stage-axis sharding constraints that caused temp memory bloat
and shape-divisibility errors.

Changes in PipelineBase.get_new_loop_state:

  • _rotate_right: shard_map + ppermutelax.slice_in_dim + concatenate
  • _shift_right: shard_map + ppermute + wherepad + lax.slice
  • _update_state_io: shard_map + _rotate_left/_shift_leftpad +
    slice_in_dim + where (also removes extra stream_buf_idx arg)

Changes in PipelineBase.get_iteration_inputs:

  • Remove redundant _maybe_shard_with_logical calls on shift and first_stage_in
  • Remove out_sharding from broadcasted_iota

Other:

  • Remove _maybe_shard_with_name on microbatches_processed in get_microbatch_and_repeat_ids
  • Make get_pipeline_remat_policy conditional on pipeline_save_decoder_layer_input:
    when False, omit decoder_layer_input from saved names to reduce remat temp memory

Test plan

  • python3 -m pytest tests/unit/train_compile_test.py -v -k "pipeline"
  • python3 -m pytest tests/integration/pipeline_parallelism_test.py -v
  • Smoke-test pp=8 config with pipeline_save_decoder_layer_input=false

@cj401-amd cj401-amd force-pushed the cj/tmem-fixes-clean-2-pipeline-tmem branch from face36a to 62907cf Compare June 18, 2026 22:41
@codecov

codecov Bot commented Jun 18, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 40.69767% with 51 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/pipeline.py 38.59% 35 Missing ⚠️
src/maxtext/layers/attention_op.py 0.00% 7 Missing ⚠️
src/maxtext/layers/normalizations.py 54.54% 4 Missing and 1 partial ⚠️
src/maxtext/trainers/pre_train/train.py 42.85% 3 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant