Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions tools/convert_torch_dist_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,25 @@ def set_up_planner(
super().set_up_planner(state_dict, metadata, is_coordinator)


def get_expert_param(args, name, param):
# Model-name fragments for models whose HF format uses fused 3D expert tensors.
# Each entry is matched as a substring of the model name by _use_fused_experts.
# Matching models pass the [num_experts, ...] tensor through as-is; all other
# models are split into per-expert 2D tensors (existing behavior).
# MTP experts are excluded — Qwen3.5 MTP uses per-expert split format.
_FUSED_EXPERT_MODELS = ["qwen3_5moe"]


def _use_fused_experts(model_name, key_name):
    """Decide whether a parameter key takes the fused 3D expert passthrough.

    Args:
        model_name: Model identifier to test, or ``None`` when unknown.
        key_name: Name of the parameter key being converted.

    Returns:
        ``True`` only when ``model_name`` is not ``None``, contains at least
        one entry of ``_FUSED_EXPERT_MODELS`` as a substring, and ``key_name``
        does not contain ``"mtp"``; ``False`` otherwise.
    """
    if model_name is None:
        return False
    for fused_tag in _FUSED_EXPERT_MODELS:
        if fused_tag in model_name:
            # The model qualifies; only keys containing "mtp" are kept out of
            # the fused path. key_name is inspected only after a model match.
            return "mtp" not in key_name
    return False


def get_expert_param(args, name, param, model_name=None):
if ".experts." not in name:
yield name, param
return
Expand All @@ -72,15 +90,23 @@ def get_expert_param(args, name, param):
match = re.search(r"mlp.experts\.(.+)\.weight(\d+)", name)
if not match:
assert param.shape[0] == num_experts
for expert_id in range(num_experts):
expert_name = name.replace(".experts.experts.", ".experts.") + str(expert_id)
expert_param = param[expert_id]
yield expert_name, expert_param
if _use_fused_experts(model_name, name):
fixed_name = name.replace(".experts.experts.", ".experts.")
if fixed_name.endswith(".weight"):
fixed_name = fixed_name[: -len(".weight")]
yield fixed_name, param
else:
for expert_id in range(num_experts):
expert_name = name.replace(".experts.experts.", ".experts.") + str(expert_id)
expert_param = param[expert_id]
yield expert_name, expert_param
else:
if _use_fused_experts(model_name, name):
return
yield name, param


def get_layer_param(args, name, param):
def get_layer_param(args, name, param, model_name=None):
if ".layers." not in name:
yield name, param
return
Expand All @@ -92,15 +118,15 @@ def get_layer_param(args, name, param):
for layer_id in range(num_layers):
layer_name = name.replace(".layers.", f".layers.{layer_id}.")
layer_param = param[layer_id]
yield from get_expert_param(args, layer_name, layer_param)
yield from get_expert_param(args, layer_name, layer_param, model_name)
else:
yield from get_expert_param(args, name, param)
yield from get_expert_param(args, name, param, model_name)


def get_named_params(args, state_dict):
def get_named_params(args, state_dict, model_name=None):
for name, param in state_dict.items():
name = f"module.module.{name}"
yield from get_layer_param(args, name, param)
yield from get_layer_param(args, name, param, model_name)


def save_tensors(args, model_name, state_dict, output_dir, chunk_size, vocab_size=None, origin_hf_dir=None):
Expand All @@ -111,7 +137,7 @@ def save_tensors(args, model_name, state_dict, output_dir, chunk_size, vocab_siz
total_size = 0
modeltensors = [{}]
converted_names = set()
for name, param in get_named_params(args, state_dict):
for name, param in get_named_params(args, state_dict, model_name):
if vocab_size:
param = remove_padding(name, param, vocab_size)
converted_named_tensors = convert_to_hf(args, model_name, name, param)
Expand Down Expand Up @@ -218,6 +244,8 @@ def copy_assets(origin_hf_dir, output_dir):
hf_config = AutoConfig.from_pretrained(args.origin_hf_dir, trust_remote_code=True)
args.model_name = type(hf_config).__name__.lower()

print(f"model_name={args.model_name}, fused_experts={_use_fused_experts(args.model_name, '')}")

state_dict = {}
print(f"loading model from {args.input_dir}")
t = time.time()
Expand Down