diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..71dbc105d4 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK4 = "deepseek4" class AttentionType(enum.Enum): diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..2f1a564c6d 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step): if self.config.num_experts > 1: moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0) - log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}") + log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}") if self.config.mtp_num_layers > 0: mtp_loss = scalars.get("learning/mtp_loss", 0.0) diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml new file mode 100644 index 0000000000..5cded0be59 --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -0,0 +1,68 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +vocab_size: 129280 +head_dim: 512 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 128 +indexer_n_heads: 64 +indexer_topk: 512 + +# Note: Layers (0, 1, 2) are prefix layers. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 3 prefix [0,0,4] + 40 scanned. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 256 +num_experts_per_tok: 6 +mlp_activations_limit: 10 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention_type: 'compressed' +q_lora_rank: 1024 +o_groups: 8 +o_lora_rank: 1024 +sliding_window_size: 128 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 1048576 diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..881043777b --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 +qk_rope_head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e43f34f247..0cbf5d751a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -227,7 +227,8 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", - "deepseek4", + "deepseek4-284b", + "deepseek4-tiny", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -553,7 +554,7 @@ class Attention(BaseModel): "autoselected", description="The attention algorithm to use (dot_product, flash, etc).", ) - attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field( + attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field( "global", description="The variant of attention to use." ) share_kv_projections: bool = Field( @@ -2925,6 +2926,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.") if self.moba and self.attention not in ("dot_product"): raise ValueError("MoBA is only supported with dot_product attention.") + if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product": + raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.") if self.use_indexer: if self.q_lora_rank == 0: raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.") @@ -2986,7 +2989,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK: + if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4): raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.") if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts: raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.") diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..391ec6cedd 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -680,24 +680,23 @@ def __init__( rngs: Optional[nnx.Rngs] = None, **kwargs, ): - """Initializes the CompressedAttention layer. + """Inherits all standard Attention hyperparameters and selectively instantiates + an underlying HCA or CSA compressor based on the provided `compress_ratio`. - Inherits all standard Attention hyperparameters and selectively instantiates - an underlying HCA or CSA compressor based on the provided `layer_type`. + Highlights of DeepSeek-V4 attention integration: + - Shared-KV: The layer supports decoupling Q and KV heads for heavy compression. + - MQA: Multi-Query Attention used alongside heavy KV compression. + - 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x). + - Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed. Args: (See maxtext.layers.attentions.Attention for standard attention arguments) q_lora_rank: The rank for the LoRA projection in the compressed query. - compress_ratio: The compression ratio for the compressor. + compress_ratio: The compression ratio (0, 4, or 128) for the compressor. """ - """Initializes the Compressed Attention module.""" self.q_lora_rank = q_lora_rank self.compress_ratio = compress_ratio - # Determine the correct underlying attention type based on the compress_ratio - if self.compress_ratio == 0: - attention_type = AttentionType.LOCAL_SLIDING - super().__init__( config=config, num_query_heads=num_query_heads, @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens. - # We must instantiate a dedicated rotary embedding for the compressors - self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding( + # Override the base rotary embedding with the correct theta for this layer. + # CSA / HCA layers use compressed_rope_max_timescale (160000). + # Sliding window prefix layers use rope_max_timescale (10000). + rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale + self.rotary_embedding = DeepSeekV4RotaryEmbedding( head_dim=self.config.head_dim, - partial_rotary_factor=1.0, - rope_theta=self.config.compressed_rope_max_timescale, - dtype=self.dtype, + partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim, + rope_theta=rope_theta, + fprop_dtype=self.dtype, ) if self.compress_ratio > 4: self.hca_compressor = DeepseekV4HCACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No self.csa_compressor = DeepseekV4CSACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -1047,7 +1048,7 @@ def __call__( # -> [batch, q_length, emb_dim] final_out = self.o_b_proj(grouped_flat) - return final_out + return final_out, None def compressed_attention( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..ab7673d1d4 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -850,6 +850,7 @@ def init_rotary_embedding(self): shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.is_qwen3_hybrid: rotary_embedding = PartialRotaryEmbedding( min_timescale=self.config.rope_min_timescale, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index b28b6dcb7a..0150c7b401 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -41,6 +41,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.models import ( deepseek, + deepseek4, deepseek_batchsplit, deepseek_batchsplit_fp8, gemma, @@ -467,6 +468,10 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK4: + return ( + [deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -632,6 +637,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK4, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, @@ -1061,6 +1067,17 @@ def __call__( previous_chunk, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4: + y = self._apply_deepseek4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1195,7 +1212,7 @@ def __call__( "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), } - if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4): layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: @@ -1423,6 +1440,97 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ): + """Applies DeepSeek V4 scanned decoder blocks. + + DeepSeek V4 has some number of prefix layers (defined by `first_num_hash_layers`) + that use static Hash Routing. The remaining layers alternate `compress_ratio=128` (HCA) + and `compress_ratio=4` (CSA) and are evaluated in a single `nn.scan` block. + + For DeepSeek4-Flash (43 hidden layers total): + - 3 Prefix layers (Indices 0, 1, 2) + - 40 Scanned layers: 20 perfectly repeating chunks of [128, 4] + """ + + cfg = self.config + mesh = self.mesh + + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot, + previous_chunk, + ) + + layer_call_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "decoder_input_tokens": decoder_input_tokens, + } + + # 1. Prefix Unrolling + # These layers use Hash Routing. + num_hash_layers = cfg.first_num_hash_layers + for layer_idx in range(num_hash_layers): + prefix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = prefix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + # 2. Chunked Scanning + # The remaining layers perfectly alternate HCA (128) and CSA (4). + num_remaining_layers = cfg.num_decoder_layers - num_hash_layers + num_full_blocks = num_remaining_layers // 2 + + if num_full_blocks > 0: + ScannableBlockToLinen = deepseek4.DeepSeek4ScannableBlockToLinen + policy = self.get_remat_policy() + RemattedDeepSeek4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0] + + y, _ = nn.scan( + RemattedDeepSeek4Block, + variable_axes={ + "params": cfg.param_scan_axis, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True, "dropout": cfg.enable_dropout}, + in_axes=(nn.broadcast,) * len(broadcast_args), + length=num_full_blocks, + metadata_params={ + nn.PARTITION_NAME: "layers", + "abstract_init": False, + }, + )(config=cfg, mesh=mesh, quant=self.quant, model_mode=model_mode, name="scanned_blocks",)(y, *broadcast_args) + + return y + def _apply_gemma4_small_layers( self, y, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..ad6b171f2f 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -1803,7 +1803,7 @@ def qwen3_omni_mrope_embedding_as_linen( ) -class DeepSeekV4RotaryEmbedding(nnx.Module): +class DeepSeekV4RotaryEmbedding(RotaryEmbedding): """DeepSeek-V4 partial rotary embedding with interleaved frequencies. DeepSeek-V4 uses an interleaved positional encoding where consecutive channels @@ -1822,12 +1822,23 @@ def __init__( head_dim: int, partial_rotary_factor: float = 64.0 / 512.0, rope_theta: float = 10000.0, - dtype: Any = jnp.float32, + fprop_dtype: Any = jnp.float32, + min_timescale: int = 10000, + max_timescale: int = 10000, + mesh: Any = None, + **kwargs, ): + super().__init__( + min_timescale=min_timescale, + max_timescale=max_timescale, + mesh=mesh, + fprop_dtype=fprop_dtype, + **kwargs, + ) self.head_dim = head_dim self.partial_rotary_factor = partial_rotary_factor self.rope_theta = rope_theta - self.dtype = dtype + self.fprop_dtype = fprop_dtype # Compute the partial rotary dimension (rope_head_dim) self.dim = int(head_dim * partial_rotary_factor) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..61b4b6db75 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -208,6 +208,10 @@ def calculate_load_balance_updates(top_k_indices, num_experts, rate): return output +class Tid2EidVar(nnx.Variable): + """Custom variable to hold tid2eid without trainable param overhead.""" + + class GateLogit(nnx.Module): """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" @@ -344,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. pre_bias_logits = output if self.use_bias: + # Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state + # management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient + # here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks. bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + output += jax.lax.stop_gradient(bias) return output, pre_bias_logits @@ -399,8 +406,11 @@ def __init__( # DeepSeek V4 Hash Routing if self.is_hash_routing: # Token-ID to Expert-ID lookup table for static routing - self.tid2eid = nnx.Variable( - jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.int32), + # Must be stored as float32 because MaxText passes the entire variable tree + # through jax.value_and_grad, which strictly requires all leaves to be inexact types + # (even if they receive no gradients). We cast to int32 dynamically during routing. + self.tid2eid = Tid2EidVar( + jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.float32), out_sharding=None, # Replicated across shards for local lookup ) else: @@ -665,7 +675,13 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): return top_k_weights, top_k_indices if self.is_hash_routing: - top_k_indices = self.tid2eid[input_ids] + if input_ids is None: + raise ValueError("input_ids cannot be None when is_hash_routing is True") + # Access the static routing table + tid2eid_int = self.tid2eid.value + # Cast the float32 array to int32 (JAX automatically assigns 0.0 gradients to integer casts) + tid2eid_int = tid2eid_int.astype(jnp.int32) + top_k_indices = tid2eid_int[input_ids] top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) # NOTE: deepseek2 has a different pattern elif self.config.model_name.startswith(("deepseek3", "deepseek4")): @@ -2150,7 +2166,6 @@ def dense_matmul( lb_loss = ( self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None ) - # TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001 else: lb_loss = None diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 95bd79eb9f..86b61c7480 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -38,7 +38,7 @@ import qwix from qwix._src.core import dot_general_qt from qwix._src.core import sparsity -from qwix._src.utils import flax_util +from qwix._src import flax_util import qwix.pallas as qpl # Params used to define mixed precision quantization configs diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..d3a72b31bf 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL, DecoderBlockType from maxtext.layers import attention_mla from maxtext.layers import initializers from maxtext.layers import linears @@ -138,37 +138,39 @@ def __init__( self.engram_layer_norm = None self.engram = None - self.self_attention = attention_mla.MLA( - config=self.config, - num_query_heads=self.config.num_query_heads, - num_kv_heads=self.config.num_kv_heads, - head_dim=self.config.head_dim, - max_target_length=self.config.max_target_length, - max_prefill_predict_length=self.config.max_prefill_predict_length, - attention_kernel=self.config.attention, - attention_type=self.config.attention_type, - inputs_q_shape=self.dummy_inputs_shape, - inputs_kv_shape=self.dummy_inputs_shape, - mesh=mesh, - dtype=self.config.dtype, - weight_dtype=self.config.weight_dtype, - dropout_rate=self.config.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(config), - q_lora_rank=self.config.q_lora_rank, - kv_lora_rank=self.config.kv_lora_rank, - qk_nope_head_dim=self.config.qk_nope_head_dim, - qk_rope_head_dim=self.config.qk_rope_head_dim, - v_head_dim=self.config.v_head_dim, - max_position_embeddings=self.config.max_position_embeddings, - original_max_position_embeddings=self.config.original_max_position_embeddings, - mscale=self.config.mscale, - rope_factor=self.config.rope_factor, - model_mode=model_mode, - rngs=rngs, - attn_logits_soft_cap=self.config.attn_logits_soft_cap, - ) + # DeepSeek V4 natively overrides this block with CompressedAttention. + if self.config.decoder_block != DecoderBlockType.DEEPSEEK4: + self.self_attention = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(self.config), + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) if self.is_mhc_enabled: @@ -333,7 +335,7 @@ def __init__( rngs=self.rngs, ) - def mlp_op(self, x, deterministic): + def mlp_op(self, x, deterministic, *args, **kwargs): mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) return self.with_logical_constraint(mlp) diff --git a/src/maxtext/models/deepseek4.py b/src/maxtext/models/deepseek4.py new file mode 100644 index 0000000000..12b0b83823 --- /dev/null +++ b/src/maxtext/models/deepseek4.py @@ -0,0 +1,274 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepSeek-V4 model definition.""" + +from typing import Optional + +from flax import nnx +import flax.linen as nn +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, AttentionType +from maxtext.common.common_types import HyperConnectionType +from maxtext.layers import attention_compressed +from maxtext.layers import initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.models import deepseek +from jax.ad_checkpoint import checkpoint_name + + +class DeepSeek4DecoderLayer(deepseek.DeepSeekGenericLayer): + """DeepSeek-V4 specific decoder layer. + + Note: V4 does not utilize purely dense layers in the initial transformer blocks. + Every layer is a Sparse MoE layer (which internally contains shared dense experts). + + Args: + config: Configuration for the model. + model_mode: The mode of the model (e.g. 'train', 'inference'). + mesh: JAX sharding mesh. + rngs: NNX Rngs. + quant: Optional AQT quantization config. + layer_idx: The index of the layer. + compress_ratio: DeepSeek V4 specific parameter defining the KV cache compression + ratio. Expected values are 0 (no compression, sliding window), 4 (CSA), or 128 (HCA). + is_hash_routing: DeepSeek V4 specific parameter defining if this layer uses + static deterministic hash routing (used in prefix layers). + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + compress_ratio: Optional[int] = None, + is_hash_routing: Optional[bool] = None, + ) -> None: + super().__init__( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + + # DeepSeek V4 applies Hash Routing to the first `config.first_num_hash_layers` layers. + # For the unscannable prefix layers, we can safely determine this using `layer_idx`. + # However, for layers inside `nn.scan` blocks, `layer_idx` is a dynamic JAX tracer + # and cannot be evaluated as a boolean condition. Since all scannable layers occur + # after the hash-routed prefix, the scannable block explicitly passes + # `is_hash_routing=False` to safely bypass this check. + if is_hash_routing is None: + is_hash_routing = layer_idx < config.first_num_hash_layers + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + is_hash_routing=is_hash_routing, + rngs=rngs, + ) + + if compress_ratio is None: + compress_ratio = config.compress_ratios[layer_idx] + + # Route to LOCAL_SLIDING if compression is disabled for this layer, + # otherwise default to the globally configured attention type (e.g., COMPRESSED). + layer_attention_type = ( + AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType(self.config.attention_type) + ) + + self.self_attention = attention_compressed.CompressedAttention( + config=self.config, + compress_ratio=compress_ratio, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=layer_attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + sliding_window_size=self.config.sliding_window_size, + q_lora_rank=self.config.q_lora_rank, + name=f"compressed_attention_layer_{layer_idx}", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=rngs, + ) + + # pylint: disable=arguments-differ + def mlp_op(self, inputs, deterministic, *args, **kwargs): + input_ids = kwargs.get("input_ids") + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp( + inputs=inputs, + input_ids=input_ids, + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + _, intermediate_inputs = self.self_attention_with_norm_op( + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + slot, + ) + + layer_output, metadata = self.mhc_mlp( + self.post_attention_norm_op, + self.mlp_op, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + deterministic=deterministic, + input_ids=decoder_input_tokens, + ) + load_balance_loss = metadata.get("load_balance_loss", None) + moe_bias_updates = metadata.get("moe_bias_updates", None) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) + + +class DeepSeek4ScannableBlock(nnx.Module): + """A scannable block containing exactly two DeepSeek V4 layers (HCA and CSA). + + DeepSeek V4 layers alternate `compress_ratio=128` (HCA) and `compress_ratio=4` (CSA) + throughout the middle of the network. This block encapsulates one full `[128, 4]` + cycle so it can be perfectly scanned using JAX `nn.scan`. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | quantizations.AqtQuantization = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + # Layer 0 in the block: HCA (compress_ratio=128) with Standard MoE (is_hash_routing=False) + self.layers_0 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=128, + is_hash_routing=False, + ) + + # Layer 1 in the block: CSA (compress_ratio=4) with Standard MoE (is_hash_routing=False) + self.layers_1 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=4, + is_hash_routing=False, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=None, + previous_chunk=None, + attention_metadata=None, + kv_cache=None, + ): + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + y = inputs + + y, _ = self.layers_0( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + y, _ = self.layers_1( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + return y, None + + +DeepSeek4LayerToLinen = nnx_wrappers.to_linen_class( + DeepSeek4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + +DeepSeek4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeek4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 9992d7674f..4200504927 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -238,6 +238,21 @@ def get_optimizer(config, learning_rate_schedule, model=None): lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)), ) + if getattr(config, "routed_bias", False): + import re + from flax import traverse_util + bias_regex = re.compile(".*gate.*bias.*") + # Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the + # Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias) + # routes them to a standard AdamW optimizer *without* exposing a weight decay mask. + # To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero, + # we apply a global optax.set_to_zero() mask here. + def bias_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params.keys()} + return traverse_util.unflatten_dict(mask) + base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn)) + return base_opt diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fd2cc7b56c..c0a0be3a04 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,7 +36,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding -from flax import linen as nn, nnx +from flax import linen as nn, nnx, traverse_util from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -278,12 +278,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. intermediate_outputs["logits"] = logits @@ -295,7 +289,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -421,9 +414,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -480,12 +473,39 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params: + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if len(param_prefix) >= len(prefix) and param_prefix[:len(prefix)] == prefix and param_path[-2:] == ("gate", "bias"): + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, param_path)) + + old_val = new_flat_params[param_path].value if hasattr(new_flat_params[param_path], "value") else new_flat_params[param_path] + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(old_val) + + new_val = old_val + jnp.array(update_val).transpose() + if hasattr(new_flat_params[param_path], "value"): + new_flat_params[param_path] = new_flat_params[param_path].replace(value=new_val) + else: + new_flat_params[param_path] = new_val + + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -506,9 +526,31 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + print("DEBUG_FLAT_INTERMEDIATES:", list(flat_intermediates.keys())) + for path, update in flat_intermediates.items(): + if path[-1] != "moe_bias_updates": + continue + target = new_state.model + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for key in prefix: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + target = None + break + if target is None: + continue + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + name_prefix = "-".join(map(str, prefix)) + bias_metrics[f"learning/moe_bias_before_norm_{name_prefix}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{name_prefix}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -521,6 +563,7 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index e3b3aadf2d..48caa91ef1 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -75,6 +75,7 @@ "deepseek2-16b": "deepseek-ai/DeepSeek-V2-Lite", "deepseek3-671b": "deepseek-ai/DeepSeek-V3", "deepseek3.2-671b": "deepseek-ai/DeepSeek-V3.2", + "deepseek4": "deepseek-ai/DeepSeek-V4-Flash", "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..9237f89f8b --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,112 @@ +import unittest +import jax +import jax.numpy as jnp +import optax +from flax.training import train_state +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train +class DeepSeekRoutedBiasTest(unittest.TestCase): + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ('data',)) + def _make_dummy_data(self, batch=1, seq=16): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + def _create_and_run_train_step(self, config_args): + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, 'nnx') else __import__('flax.nnx', fromlist=['Rngs']).Rngs(0) + import flax.nnx as nnx + from maxtext.common import train_state_nnx + rngs = nnx.Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) + + def _create_and_run_linen_train_step(self, config_args): + config = pyconfig.initialize(config_args) + model = models.transformer_as_linen(config, self.mesh, quant=None) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + rng = jax.random.PRNGKey(0) + variables = model.init(rng, data["inputs"], data["inputs_position"], data["inputs_segmentation"]) + ts = train_state.TrainState.create( + apply_fn=model.apply, + params=variables["params"], + tx=optax.sgd(0.01) + ) + new_state, metrics = pre_train.train_step( + model, config, state_mesh_shardings=None, params_shardings=None, state=ts, data=data, dropout_rng=jax.random.PRNGKey(0) + ) + return new_state, metrics + + def test_deepseek_v3_moe_routed_bias_linen(self): + """Proves that a DeepSeek V3 model with MoE layers successfully traverses the + Linen state tree and updates routed bias. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=0", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_linen_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertTrue(any(key.startswith("learning/moe_bias_before_norm") for key in metrics["scalar"])) + self.assertTrue(any(key.startswith("learning/moe_bias_update_norm") for key in metrics["scalar"])) +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 1da95a184e..0b75aa9ff4 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -57,13 +57,13 @@ # Tests # ============================================================================== -# HuggingFace reference: https://huggingface.co/deepseek-ai/DeepSeek-V4/blob/main/modeling_deepseek_v4.py +# HuggingFace reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py # pylint: disable=line-too-long from jax.experimental import mesh_utils from jax.sharding import Mesh from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.configs import pyconfig from maxtext.layers.attention_compressed import CompressedAttention -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding as MTRope + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4Attention from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4RotaryEmbedding as PTRope @@ -75,7 +75,7 @@ class DeepSeekV4RotaryEmbeddingTest(unittest.TestCase): def setUp(self): self.batch_size = 2 - self.seq_len = 16 + self.seq_len = 4096 self.head_dim = 128 self.num_heads = 4 self.main_rope_theta = 10000.0 @@ -408,6 +408,8 @@ def setUp(self): self.q_lora_rank = 32 self.o_groups = 2 self.o_lora_rank = 64 + self.qk_rope_head_dim = 64 + self.partial_rotary_factor = self.qk_rope_head_dim / self.head_dim self.rngs = nnx.Rngs(0) @@ -431,8 +433,12 @@ def setUp(self): layer_types=["sliding_attention"], num_hidden_layers=1, rope_parameters={ - "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": 1.0}, - "compress": {"rope_type": "default", "rope_theta": 160000.0, "partial_rotary_factor": 1.0}, + "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": self.partial_rotary_factor}, + "compress": { + "rope_type": "default", + "rope_theta": 160000.0, + "partial_rotary_factor": self.partial_rotary_factor, + }, }, sliding_window=2048, attention_dropout=0.0, @@ -524,9 +530,13 @@ def _run_e2e_test(self, layer_type, is_packed=False): "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], } + compress_ratio = compress_ratio_map[layer_type] + layer_attention_type = AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType.COMPRESSED + mt_attn = CompressedAttention( config=mt_config, - compress_ratio=compress_ratio_map[layer_type], + compress_ratio=compress_ratio, + attention_type=layer_attention_type, num_query_heads=self.num_heads, num_kv_heads=1, head_dim=self.head_dim, @@ -540,14 +550,6 @@ def _run_e2e_test(self, layer_type, is_packed=False): rngs=self.rngs, ) self.mt_attn = mt_attn - if layer_type == "sliding_attention": - rope_factor = self.pt_config.rope_parameters["main"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=10000.0) - else: - rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=160000.0) - - mt_attn.rotary_embedding = mt_rope # 3. Copy Weights self._copy_linear(mt_attn.wq_a, ref_attn.q_a_proj) @@ -652,8 +654,7 @@ def _run_e2e_test(self, layer_type, is_packed=False): print(f"top_k_indices mismatches: {num_mismatches}") # 6. Execute MaxText - - mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + mt_out, _ = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) # 7. Asserts if not is_packed: @@ -771,7 +772,7 @@ def setUp(self): "vocab_size": self.vocab_size, "first_num_hash_layers": 3, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, @@ -809,7 +810,7 @@ def test_hash_router(self): ) # Sync weights - mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy()) + mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy(), dtype=jnp.float32) mx_moe.gate.kernel.value = jnp.array(pt_router.weight.detach().numpy()).T hidden_states = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) @@ -910,7 +911,7 @@ def test_swiglu_clamp(self): "topk_routing_group": 1, "mlp_activations_limit": limit, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, diff --git a/tests/unit/metric_logger_test_coverage.py b/tests/unit/metric_logger_test_coverage.py new file mode 100644 index 0000000000..5567a7492c --- /dev/null +++ b/tests/unit/metric_logger_test_coverage.py @@ -0,0 +1,29 @@ +import unittest +from maxtext.common.metric_logger import MetricLogger +from maxtext.configs import pyconfig +from unittest import mock + +class MetricLoggerTest(unittest.TestCase): + def test_log_train_metrics_moe_lb_loss(self): + config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "run_name=test_run", "base_output_directory=/tmp/maxtext_output", "num_experts=2", "mtp_num_layers=0", "base_moe_mlp_dim=64", "base_mlp_dim=64"]) + + logger = MetricLogger(config, None) + metrics = { + "scalar": { + "learning/loss": 1.0, + "learning/lm_loss": 1.0, + "learning/total_weights": 1000, + "learning/moe_lb_loss": 0.000403, + "perf/step_time_seconds": 1.0, + "perf/per_device_tflops_per_sec": 1.0, + "perf/per_device_tokens_per_sec": 1.0, + } + } + with mock.patch("maxtext.common.metric_logger.max_logging.log") as mock_log: + logger._log_training_metrics(metrics, 1) + mock_log.assert_called() + called_args = mock_log.call_args[0][0] + self.assertIn("moe_lb_loss: 0.000403", called_args) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index b8eab1061e..4b9fe305eb 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -622,5 +622,44 @@ def __init__(self, rngs: nnx.Rngs): self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) +class TestGetOptimizerGlobalMask(unittest.TestCase): + """Tests that the global optimizer cleanly masks out the routed bias.""" + def test_routed_bias_global_mask(self): + config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "routed_bias=True", "opt_type=sgd"]) + # We define a dummy params dict containing a routed bias and a regular weight. + # The routed bias must be completely ignored by the optimizer. + params = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([1.0]), + "kernel": jnp.array([1.0]) + } + } + } + } + } + grads = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([0.5]), + "kernel": jnp.array([0.5]) + } + } + } + } + } + # We use sgd because it's simple to test updates, but the mask logic applies + # cleanly to any base optimizer returned by get_optimizer. + opt = optimizers.get_optimizer(config, learning_rate_schedule=0.1) + opt_state = opt.init(params) + updates, _ = opt.update(grads, opt_state, params) + # The routed bias update should be exactly 0.0 (masked by set_to_zero) + self.assertEqual(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["bias"].item(), 0.0) + # The kernel should receive the SGD gradient update (-0.1 * 0.5) + self.assertTrue(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["kernel"].item() < 0.0) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1975ad1abf..41557c8c3c 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -804,6 +804,26 @@ def test_deepseek32(self): ) ) + def test_deepseek4(self): + # test deepseek4 compile + compiled_trainstep_file = "/tmp/test_deepseek4.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek4-284b", + "per_device_batch_size=1", + "max_target_length=1024", + "attention=dot_product", + "dtype=bfloat16", + "weight_dtype=bfloat16", + ) + ) + @pytest.mark.cpu_only def test_indexer_dense_warmup(self): # test deepseek3.2 with sparse attention diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index ebeededbd7..b31bc4a5dc 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -61,8 +61,12 @@ class _Cfg: shard_mode: int = 0 # ShardMode.AUTO weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 + decoder_block: str = "default" +class _DummyDecoder(nnx.Module): + pass + class _TinyDecoder(nnx.Module): """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. @@ -73,6 +77,7 @@ class _TinyDecoder(nnx.Module): def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + self.decoder = _DummyDecoder() def __call__( self, @@ -125,7 +130,6 @@ def test_returns_loss_and_full_aux_dict(self): "total_weights", "moe_lb_loss", "indexer_loss", - "moe_bias_updates", "mtp_loss", ): self.assertIn(key, aux) @@ -194,6 +198,18 @@ def test_train_step_with_gradient_clipping(self): self.assertIsInstance(new_state, nnx.State) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + def test_train_step_deepseek_aux_loss(self): + cfg, ts = _build_state() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + cfg.decoder_block = "deepseek" + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + # The robust trainer logic will correctly traverse and NOT crash, ignoring the hardcoded path + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) class TestEvalStepNNX(unittest.TestCase): """Cover the NNX branch of eval_step (lines 568-570)."""