Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability
float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device
float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability

# multi-token prediction configs
Expand Down Expand Up @@ -320,6 +320,7 @@ scan_pipeline_repeats: false
scan_layers_per_stage: false
set_remat_policy_on_pipeline_iterations: true
set_remat_policy_on_layers_per_stage: false
pipeline_save_decoder_layer_input: true # set to false to reduce pipeline tmem at cost of recomputing decoder layer inputs in backward pass


# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
Expand Down
14 changes: 11 additions & 3 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ class MoEGeneral(BaseModel):
description="Enable top-k probability normalization for router weights (Qwen3-specific).",
)
float32_weight_sum: bool = Field(
True,
description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
False,
description="Whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device.",
)
float32_gate_logits: bool = Field(
False,
Expand Down Expand Up @@ -1011,6 +1011,14 @@ class PipelineParallelism(BaseModel):
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
pipeline_save_decoder_layer_input: bool = Field(
True,
description=(
"Whether to save 'decoder_layer_input' activations in the pipeline remat policy. "
"Setting to False reduces temporary memory (tmem) during pipeline execution at the cost "
"of recomputing decoder layer inputs in the backward pass."
),
)


class RematAndOffload(BaseModel):
Expand Down Expand Up @@ -2850,7 +2858,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining.
for rule in self.logical_axis_rules:
if rule and rule[0] == "activation_embed_and_logits_batch":
rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"]
rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes]
break

if "stage" in self.mesh_axes:
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/kernels/gather_reduce_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __getitem__(self, shape):
_BF16 = VectorTypeHelper(ir.BF16Type.get)


# fmt: off
@jax.jit(
static_argnames=[
"reduce_group_size",
Expand All @@ -69,6 +70,7 @@ def __getitem__(self, shape):
"topk_wgt_zero_nan",
],
)
# fmt: on
def sc_gather_reduce(
op: jax.Array,
idx: jax.Array,
Expand Down
25 changes: 16 additions & 9 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,13 +1624,22 @@ def _sequence_descriptor(segment_ids):
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
# Default case: no packing, no context parallelism.
# For synthetic data, segment IDs are always all-ones (one segment per sequence), so
# the segment mask is all-True and the combined mask reduces to pure causal masking.
# Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that
# XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory).
if self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)

dpa_layer = DotProductAttention(
head_dim=head_dim,
Expand All @@ -1643,12 +1652,10 @@ def _sequence_descriptor(segment_ids):
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout,
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis=self.config.context_sharding,
context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def __init__(
mesh=mesh,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
skip_trivial_specs=True,
)

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
Expand Down
39 changes: 23 additions & 16 deletions src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding

from flax import linen as nn
from flax import nnx

from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
Expand Down Expand Up @@ -157,30 +158,36 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
self.dtype,
)

output_axis_names = (
(
"activation_embed_and_logits_batch",
"prefill_activation_length",
"activation_embed",
)
if model_mode == MODEL_MODE_PREFILL
else (
"activation_embed_and_logits_batch",
"activation_length",
"activation_embed",
)
)
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None))
output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed")
output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed")

out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None
if self.config.shard_mode == ShardMode.EXPLICIT:
output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names
out_pspec = logical_to_mesh_axes(
output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)
)
out_sharding = NamedSharding(self.mesh, out_pspec)
else:
out_sharding = None

if cfg.use_iota_embed:
one_hot_elements = 1
for d in inputs.shape:
one_hot_elements *= d
one_hot_elements *= self.num_embeddings
one_hot_bytes = one_hot_elements * jnp.dtype(self.dtype).itemsize
use_iota = cfg.use_iota_embed and one_hot_bytes <= 2 * 1024**3

if use_iota:
iota = lax.iota(jnp.int32, self.num_embeddings)
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
output = jnp.dot(one_hot, embedding, out_sharding=out_sharding)
else:
output = embedding.at[inputs].get(out_sharding=out_sharding)

if model_mode == MODEL_MODE_PREFILL:
output = nn.with_logical_constraint(output, output_prefill_axis_names)
else:
output = nn.with_logical_constraint(output, output_default_axis_names)
return output

def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array:
Expand Down
17 changes: 13 additions & 4 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import NamedSharding, reshard
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
Expand Down Expand Up @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->

if not self.with_scale:
if out_sharding is not None:
y = jax.lax.with_sharding_constraint(y, out_sharding)
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y

scale = self.scale.get_value()
Expand All @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
scale = jax.device_put(scale, max_utils.device_space())

scale = jnp.asarray(scale, self.dtype)
effective_scale = scale + self.scale_offset
return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)
effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale
y = y * effective_scale
if out_sharding is not None:
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y


class GlobalRMSNorm(RMSNorm):
Expand Down
Loading
Loading