diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..400d0d198d 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -186,7 +186,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly. float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax -float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability +float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability # multi-token prediction configs @@ -320,6 +320,7 @@ scan_pipeline_repeats: false scan_layers_per_stage: false set_remat_policy_on_pipeline_iterations: true set_remat_policy_on_layers_per_stage: false +pipeline_save_decoder_layer_input: true # set to false to reduce pipeline tmem at cost of recomputing decoder layer inputs in backward pass # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..2b44a395cc 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -759,8 +759,8 @@ class MoEGeneral(BaseModel): description="Enable top-k probability normalization for router weights (Qwen3-specific).", ) float32_weight_sum: bool = Field( - True, - description="Whether to use full fp32 precision to sum expert weights for numerical stability.", + False, + description="Whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device.", ) float32_gate_logits: bool = Field( False, @@ -1011,6 +1011,14 @@ class PipelineParallelism(BaseModel): scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.") set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.") set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.") + pipeline_save_decoder_layer_input: bool = Field( + True, + description=( + "Whether to save 'decoder_layer_input' activations in the pipeline remat policy. " + "Setting to False reduces temporary memory (tmem) during pipeline execution at the cost " + "of recomputing decoder layer inputs in the backward pass." + ), + ) class RematAndOffload(BaseModel): @@ -2850,7 +2858,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining. for rule in self.logical_axis_rules: if rule and rule[0] == "activation_embed_and_logits_batch": - rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"] + rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes] break if "stage" in self.mesh_axes: diff --git a/src/maxtext/kernels/gather_reduce_sc.py b/src/maxtext/kernels/gather_reduce_sc.py index 5b3b8e7597..c858b45bf5 100644 --- a/src/maxtext/kernels/gather_reduce_sc.py +++ b/src/maxtext/kernels/gather_reduce_sc.py @@ -55,6 +55,7 @@ def __getitem__(self, shape): _BF16 = VectorTypeHelper(ir.BF16Type.get) +# fmt: off @jax.jit( static_argnames=[ "reduce_group_size", @@ -69,6 +70,7 @@ def __getitem__(self, shape): "topk_wgt_zero_nan", ], ) +# fmt: on def sc_gather_reduce( op: jax.Array, idx: jax.Array, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index b3c3f296f4..2c937385d2 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1624,13 +1624,22 @@ def _sequence_descriptor(segment_ids): dummy_attn_mask = None mask_type = "causal" else: - # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros( - (1, 1, 1, self.max_target_length, self.max_target_length), - dtype=jnp.uint8, - ) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + # Default case: no packing, no context parallelism. + # For synthetic data, segment IDs are always all-ones (one segment per sequence), so + # the segment mask is all-True and the combined mask reduces to pure causal masking. + # Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that + # XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory). + if self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" + else: + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), + dtype=jnp.uint8, + ) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1643,12 +1652,10 @@ def _sequence_descriptor(segment_ids): dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis=self.config.context_sharding, - context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..062fcda34f 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -553,6 +553,7 @@ def __init__( mesh=mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + skip_trivial_specs=True, ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..77b590a527 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding +from flax import linen as nn from flax import nnx from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType @@ -157,30 +158,36 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: self.dtype, ) - output_axis_names = ( - ( - "activation_embed_and_logits_batch", - "prefill_activation_length", - "activation_embed", - ) - if model_mode == MODEL_MODE_PREFILL - else ( - "activation_embed_and_logits_batch", - "activation_length", - "activation_embed", - ) - ) - out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed") + output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed") - out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None + if self.config.shard_mode == ShardMode.EXPLICIT: + output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names + out_pspec = logical_to_mesh_axes( + output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None) + ) + out_sharding = NamedSharding(self.mesh, out_pspec) + else: + out_sharding = None - if cfg.use_iota_embed: + one_hot_elements = 1 + for d in inputs.shape: + one_hot_elements *= d + one_hot_elements *= self.num_embeddings + one_hot_bytes = one_hot_elements * jnp.dtype(self.dtype).itemsize + use_iota = cfg.use_iota_embed and one_hot_bytes <= 2 * 1024**3 + + if use_iota: iota = lax.iota(jnp.int32, self.num_embeddings) one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) output = jnp.dot(one_hot, embedding, out_sharding=out_sharding) else: output = embedding.at[inputs].get(out_sharding=out_sharding) + if model_mode == MODEL_MODE_PREFILL: + output = nn.with_logical_constraint(output, output_prefill_axis_names) + else: + output = nn.with_logical_constraint(output, output_default_axis_names) return output def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array: diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index e98977c60c..0b9d9ef6fc 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -22,7 +22,7 @@ import jax from jax import lax import jax.numpy as jnp -from jax.sharding import NamedSharding +from jax.sharding import NamedSharding, reshard from maxtext.common.common_types import Array, DType, ShardMode from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> if not self.with_scale: if out_sharding is not None: - y = jax.lax.with_sharding_constraint(y, out_sharding) + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) return y scale = self.scale.get_value() @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jax.device_put(scale, max_utils.device_space()) scale = jnp.asarray(scale, self.dtype) - effective_scale = scale + self.scale_offset - return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) + effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale + y = y * effective_scale + if out_sharding is not None: + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) + return y class GlobalRMSNorm(RMSNorm): diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 07083a824e..762fb28151 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -117,6 +117,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): rules=self.config.logical_axis_rules, debug_sharding=self.config.debug_sharding, extra_stack_level=1, + skip_trivial_specs=True, ) def _maybe_shard_with_name(self, inputs, sharding_name): @@ -138,7 +139,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) if self.use_circ_storage: # Setup potential input from circ_storage, which also has a rotating index for microbatch, @@ -153,7 +153,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) - first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are @@ -163,11 +162,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): def select_state_or_input(first_stage_in, shift): # Selects input for stage 0, shift for other stages - return jnp.where( - jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, - first_stage_in, - shift, - ) + return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) @@ -178,7 +173,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): non-circular""" # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) - microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids @@ -187,10 +181,133 @@ def get_pipeline_remat_policy(self): """Returns the pipeline remat policy for this pipeline.""" if self.config.remat_policy == "custom": return self.remat_policy - save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + + names_to_save = ["iteration_input"] + if self.config.pipeline_save_decoder_layer_input: + names_to_save.append("decoder_layer_input") + save_input_policy = jax.checkpoint_policies.save_only_these_names(*names_to_save) if self.remat_policy is not None: - return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) - return save_input_policy + remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) + else: + remat_policy = save_input_policy + return remat_policy + + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + key = jax.random.PRNGKey(0) + keys = {"params": key, "dropout": key, "aqt": key} + weights = self.init(keys, *init_args) + + def get_partition_spec(pytree): + def _is_leaf(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) + + def get_partition_spec_leaf(leaf): + return leaf.get_partition_spec() + + return jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) + + partition_spec_with_extra_layer = get_partition_spec(weights) + logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} + return logical_partition_spec + + def get_vmap_func_for_init(self): + """This vmap func is used to initialize the weights only on init.""" + + def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0, "_overwrite_with_gradient": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return vmap_func + + def get_main_vmap_func_for_iterations(self): + """ + Returns main stage function vmapped by number of stages. + This becomes a vmap over a single layer instance if body_instance is a single layer, + else a set of layers if body_instance is a set of layers. + """ + + def func_to_vmap( + body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode + ): + return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return vmap_func + + def _run_weight_initialization( + self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode + ): + """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" + vmap_func = self.get_vmap_func_for_init() + + if self.config.num_pipeline_repeats > 1: + vmap_func = nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0}, + split_rngs={"params": True, "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + }, + ) + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = ( + jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) + if example_segmentation is not None + else None + ) + example_position = ( + jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) + if example_position is not None + else None + ) + + example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) + stage_outputs = vmap_func( + self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode + ) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] + broadcasted_stage_outpus = jax.lax.broadcast( + stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] + ) + + return jnp.reshape( + broadcasted_stage_outpus, + [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + out_sharding=self.output_sharding, + ) @staticmethod def _remove_fsdp_from_physical_partition_spec(physical_partition_spec): @@ -355,9 +472,7 @@ def vmap_gather(self, xs, ids, ids_dim): ndim = xs.ndim def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) @@ -381,20 +496,16 @@ def get_new_loop_state(self, output, loop_state): loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) # Shift either rotates or shifts depending on if the last stage immediately must send to first or not # For non-circular pipelines, the last stage does not need to send to first @@ -437,29 +548,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] - def _rotate_left(arr, stage_size): - # we use -1 for left shifting - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): + def _update_state_io(state_in, stream_slice, output): # Shift the current slice to the left, then fill the last stage with the final output. - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) stream_slice = jnp.expand_dims(stream_slice, 1) return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) + new_state = _update_state_io(old_state_io, stream_slice, output) return { "state_io": new_state, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..c17be0c7c5 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -42,6 +42,7 @@ from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding from maxtext.utils.sharding import maybe_shard_with_logical +from maxtext.utils.sharding import remove_size_one_mesh_axis import transformers @@ -483,15 +484,14 @@ def __call__( return outputs, None # bf16 and fp8 code path for pure-JAX batch-split. - # fp8 code path supports both manual quantization and qwix - # quantization. - input_sharding = jax.typeof(inputs).sharding - activation_pspec = jax.sharding.PartitionSpec( - ("data", "fsdp", "expert"), - None, - None, + activation_pspec = remove_size_one_mesh_axis( + jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert", "context"), + None, + None, + ), + self.mesh, ) - inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec)) yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs( decoder_positions, embedding_dims=self.config.qk_rope_head_dim, @@ -563,7 +563,6 @@ def extract_fn(x): in_specs=([activation_pspec] * self.config.batch_split_factor,), out_specs=activation_pspec, )(outputs) - outputs = jax.reshard(outputs, input_sharding) return outputs, None x = self.with_logical_constraint(inputs) diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index faf69273c6..cf97ec93ad 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -31,6 +31,7 @@ from maxtext.layers.normalizations import RMSNorm from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils +from maxtext.utils.sharding import maybe_shard_with_logical # ----------------------------------------- # The Decoder Layer for Mixtral @@ -136,14 +137,26 @@ def __call__( kv_cache=None, attention_metadata=None, ): + cfg = self.config + + def shard(x): + return maybe_shard_with_logical( + x, + self.activation_axis_names, + mesh=self.mesh, + shard_mode=cfg.shard_mode, + rules=cfg.logical_axis_rules, + skip_trivial_specs=True, + ) + # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) + inputs = shard(inputs) inputs = checkpoint_name(inputs, "decoder_layer_input") lnx = self.pre_self_attention_layer_norm(inputs) - lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + lnx = shard(lnx) attention_lnx, kv_cache = self.self_attention( lnx, @@ -157,28 +170,28 @@ def __call__( attention_metadata=attention_metadata, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + attention_lnx = shard(attention_lnx) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + hidden_states = shard(hidden_states) load_balance_loss = None # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. # The `name` represents the weight name in JAX/checkpoints and so the class name # is just for readability. mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + mlp_lnx = shard(mlp_lnx) layer_output = mlp_lnx + intermediate_inputs layer_output = self.dropout(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + layer_output = shard(layer_output) - if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) - if self.config.record_internal_nn_metrics: + if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( @@ -187,7 +200,7 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - if self.config.scan_layers: + if cfg.scan_layers: return layer_output, None else: return layer_output, kv_cache diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..b6b4d20ec5 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,6 +36,13 @@ import jax.numpy as jnp from jax.sharding import NamedSharding + +import flax + +try: + flax.config.update("flax_always_shard_variable", False) +except LookupError: + pass from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -394,10 +401,11 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): (loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data) nnx.update(state.model, nnx.State.merge(custom_grads, new_rest)) - raw_grads = jax.tree_util.tree_map( - lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, - raw_grads, - ) + if config.grad_dtype != jnp.float32: + raw_grads = jax.tree_util.tree_map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) if config.parameter_memory_host_offload: raw_grads = jax.device_put( raw_grads, diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 836c425f09..cca5311a89 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -231,7 +231,11 @@ def jit_and_compile( def save_compiled(compiled, save_name): """Serialize and save the compiled function.""" - serialized, _, _ = serialize(compiled) + result = serialize(compiled) + # jax.experimental.serialize_executable.serialize() changed its return type: + # older JAX: (bytes, in_tree, out_tree) + # newer JAX: bytes + serialized = result[0] if isinstance(result, tuple) else result with open(save_name, "wb") as f: f.write(serialized) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 4a500e2fe1..0902717928 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -132,7 +132,15 @@ def maybe_shard_with_pspec( def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" + inputs, + logical_axes, + mesh, + shard_mode, + rules=None, + debug_sharding=False, + extra_stack_level=0, + sharding_desc="", + skip_trivial_specs=False, ): """ A wrapper of maybe_shard_with_name when logical axes are inputs @@ -147,6 +155,9 @@ def maybe_shard_with_logical( named_sharding = create_sharding(mesh, logical_axes, rules=rules) + if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec): + return inputs + return maybe_shard_with_name( inputs, named_sharding, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1975ad1abf..d2103951d4 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -950,6 +950,7 @@ def test_circular_pipeline_ag_per_repeat_ep_ds(self): "use_random_routing=true", "max_target_length=128", "pipeline_fsdp_ag_per_repeat=true", + "pipeline_save_decoder_layer_input=false", ) )