Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum):
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
OLMO3 = "olmo3"
DEEPSEEK4 = "deepseek4"


class AttentionType(enum.Enum):
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step):

if self.config.num_experts > 1:
moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0)
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}")
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}")

if self.config.mtp_num_layers > 0:
mtp_loss = scalars.get("learning/mtp_loss", 0.0)
Expand Down
68 changes: 68 additions & 0 deletions src/maxtext/configs/models/deepseek4-284b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
vocab_size: 129280
head_dim: 512

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 512

# Note: Layers (0, 1, 2) are prefix layers.
# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 3 prefix [0,0,4] + 40 scanned.
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 256
num_experts_per_tok: 6
mlp_activations_limit: 10
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
# The DeepSeek4 decoder block is validated to run only with dot_product
# attention, so pin the kernel explicitly (as the tiny config does) instead of
# relying on the inherited default, which would fail config validation.
attention: 'dot_product'
attention_type: 'compressed'
q_lora_rank: 1024
o_groups: 8
o_lora_rank: 1024
sliding_window_size: 128

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 1048576
69 changes: 69 additions & 0 deletions src/maxtext/configs/models/deepseek4-tiny.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny model config for DeepSeek V4 for CPU execution and testing

base_emb_dim: 64
base_num_query_heads: 4
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 64
base_moe_mlp_dim: 64
vocab_size: 129280
head_dim: 32
qk_rope_head_dim: 32

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 32
indexer_n_heads: 4
indexer_topk: 16

# Note: Layers (0,1) are not compressed.
# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4].
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 16
num_experts_per_tok: 4
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
attention: 'dot_product'
attention_type: 'compressed'
q_lora_rank: 16
o_groups: 4
o_lora_rank: 16
sliding_window_size: 32

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 4096
9 changes: 6 additions & 3 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek4",
"deepseek4-284b",
"deepseek4-tiny",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
Expand Down Expand Up @@ -553,7 +554,7 @@ class Attention(BaseModel):
"autoselected",
description="The attention algorithm to use (dot_product, flash, etc).",
)
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field(
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field(
"global", description="The variant of attention to use."
)
share_kv_projections: bool = Field(
Expand Down Expand Up @@ -2925,6 +2926,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.")
if self.moba and self.attention not in ("dot_product"):
raise ValueError("MoBA is only supported with dot_product attention.")
if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product":
raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.")
if self.use_indexer:
if self.q_lora_rank == 0:
raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
Expand Down Expand Up @@ -2986,7 +2989,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4):
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts:
raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.")
Expand Down
37 changes: 19 additions & 18 deletions src/maxtext/layers/attention_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,24 +680,23 @@ def __init__(
rngs: Optional[nnx.Rngs] = None,
**kwargs,
):
"""Initializes the CompressedAttention layer.
"""Inherits all standard Attention hyperparameters and selectively instantiates
an underlying HCA or CSA compressor based on the provided `compress_ratio`.

Inherits all standard Attention hyperparameters and selectively instantiates
an underlying HCA or CSA compressor based on the provided `layer_type`.
Highlights of DeepSeek-V4 attention integration:
- Shared-KV: The layer supports decoupling Q and KV heads for heavy compression.
- MQA: Multi-Query Attention used alongside heavy KV compression.
- 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x).
- Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed.

Args:
(See maxtext.layers.attentions.Attention for standard attention arguments)
q_lora_rank: The rank for the LoRA projection in the compressed query.
compress_ratio: The compression ratio for the compressor.
compress_ratio: The compression ratio (0, 4, or 128) for the compressor.
"""
"""Initializes the Compressed Attention module."""
self.q_lora_rank = q_lora_rank
self.compress_ratio = compress_ratio

# Determine the correct underlying attention type based on the compress_ratio
if self.compress_ratio == 0:
attention_type = AttentionType.LOCAL_SLIDING

super().__init__(
config=config,
num_query_heads=num_query_heads,
Expand Down Expand Up @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
rngs=self.rngs,
)

# DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens.
# We must instantiate a dedicated rotary embedding for the compressors
self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding(
# Override the base rotary embedding with the correct theta for this layer.
# CSA / HCA layers use compressed_rope_max_timescale (160000).
# Sliding window prefix layers use rope_max_timescale (10000).
rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale
self.rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=self.config.head_dim,
partial_rotary_factor=1.0,
rope_theta=self.config.compressed_rope_max_timescale,
dtype=self.dtype,
partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim,
rope_theta=rope_theta,
fprop_dtype=self.dtype,
)

if self.compress_ratio > 4:
self.hca_compressor = DeepseekV4HCACompressor(
config=self.config,
compress_ratio=self.compress_ratio,
rotary_embedding=self.compress_rotary_embedding,
rotary_embedding=self.rotary_embedding,
kernel_init=self.kernel_init,
quant=self.quant,
model_mode=self.model_mode,
Expand All @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
self.csa_compressor = DeepseekV4CSACompressor(
config=self.config,
compress_ratio=self.compress_ratio,
rotary_embedding=self.compress_rotary_embedding,
rotary_embedding=self.rotary_embedding,
kernel_init=self.kernel_init,
quant=self.quant,
model_mode=self.model_mode,
Expand Down Expand Up @@ -1047,7 +1048,7 @@ def __call__(
# -> [batch, q_length, emb_dim]
final_out = self.o_b_proj(grouped_flat)

return final_out
return final_out, None


def compressed_attention(
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ def init_rotary_embedding(self):
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

elif self.is_qwen3_hybrid:
rotary_embedding = PartialRotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
Expand Down
Loading