diff --git a/tools/convert_torch_dist_to_hf.py b/tools/convert_torch_dist_to_hf.py index 8049d77437..89095eb40f 100644 --- a/tools/convert_torch_dist_to_hf.py +++ b/tools/convert_torch_dist_to_hf.py @@ -63,7 +63,25 @@ def set_up_planner( super().set_up_planner(state_dict, metadata, is_coordinator) -def get_expert_param(args, name, param): +# Models whose HF format uses fused 3D expert tensors. +# These models pass the [num_experts, ...] tensor through as-is. +# All other models split into per-expert 2D tensors (existing behavior). +# MTP experts are excluded — Qwen3.5 MTP uses per-expert split format. +_FUSED_EXPERT_MODELS = ["qwen3_5moe"] + + +def _use_fused_experts(model_name, key_name): + """Whether this key should use fused 3D expert passthrough.""" + if model_name is None: + return False + if not any(m in model_name for m in _FUSED_EXPERT_MODELS): + return False + if "mtp" in key_name: + return False + return True + + +def get_expert_param(args, name, param, model_name=None): if ".experts." not in name: yield name, param return @@ -72,15 +90,23 @@ def get_expert_param(args, name, param): match = re.search(r"mlp.experts\.(.+)\.weight(\d+)", name) if not match: assert param.shape[0] == num_experts - for expert_id in range(num_experts): - expert_name = name.replace(".experts.experts.", ".experts.") + str(expert_id) - expert_param = param[expert_id] - yield expert_name, expert_param + if _use_fused_experts(model_name, name): + fixed_name = name.replace(".experts.experts.", ".experts.") + if fixed_name.endswith(".weight"): + fixed_name = fixed_name[: -len(".weight")] + yield fixed_name, param + else: + for expert_id in range(num_experts): + expert_name = name.replace(".experts.experts.", ".experts.") + str(expert_id) + expert_param = param[expert_id] + yield expert_name, expert_param else: + if _use_fused_experts(model_name, name): + return yield name, param -def get_layer_param(args, name, param): +def get_layer_param(args, name, param, model_name=None): if ".layers." not in name: yield name, param return @@ -92,15 +118,15 @@ def get_layer_param(args, name, param): for layer_id in range(num_layers): layer_name = name.replace(".layers.", f".layers.{layer_id}.") layer_param = param[layer_id] - yield from get_expert_param(args, layer_name, layer_param) + yield from get_expert_param(args, layer_name, layer_param, model_name) else: - yield from get_expert_param(args, name, param) + yield from get_expert_param(args, name, param, model_name) -def get_named_params(args, state_dict): +def get_named_params(args, state_dict, model_name=None): for name, param in state_dict.items(): name = f"module.module.{name}" - yield from get_layer_param(args, name, param) + yield from get_layer_param(args, name, param, model_name) def save_tensors(args, model_name, state_dict, output_dir, chunk_size, vocab_size=None, origin_hf_dir=None): @@ -111,7 +137,7 @@ def save_tensors(args, model_name, state_dict, output_dir, chunk_size, vocab_siz total_size = 0 modeltensors = [{}] converted_names = set() - for name, param in get_named_params(args, state_dict): + for name, param in get_named_params(args, state_dict, model_name): if vocab_size: param = remove_padding(name, param, vocab_size) converted_named_tensors = convert_to_hf(args, model_name, name, param) @@ -218,6 +244,8 @@ def copy_assets(origin_hf_dir, output_dir): hf_config = AutoConfig.from_pretrained(args.origin_hf_dir, trust_remote_code=True) args.model_name = type(hf_config).__name__.lower() + print(f"model_name={args.model_name}, fused_experts={_use_fused_experts(args.model_name, '')}") + state_dict = {} print(f"loading model from {args.input_dir}") t = time.time()