Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions scripts/convert_olmo_to_hf_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"""
Sample usage:
```
python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
--input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path
python scripts/convert_olmo_to_hf_new.py \
--input_dir /path/to/downloaded/olmo/weights --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
Expand Down Expand Up @@ -100,8 +100,8 @@ def write_model(
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
# Unsharded
# TODO: Layernorm stuff
# TODO: multi query attention
# Multi-query attention is handled by setting num_key_value_heads appropriately
# and splitting the attention projection weights accordingly
fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
Expand All @@ -121,6 +121,28 @@ def write_model(
f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
}

# Add LayerNorm weights
if f"transformer.blocks.{layer_i}.attn_norm.weight" in loaded:
state_dict[f"model.layers.{layer_i}.input_layernorm.weight"] = loaded[
f"transformer.blocks.{layer_i}.attn_norm.weight"
]

if f"transformer.blocks.{layer_i}.ff_norm.weight" in loaded:
state_dict[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = loaded[
f"transformer.blocks.{layer_i}.ff_norm.weight"
]

# Handle Q and K norm weights for attention layer norm if present
if f"transformer.blocks.{layer_i}.q_norm.weight" in loaded:
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = loaded[
f"transformer.blocks.{layer_i}.q_norm.weight"
]

if f"transformer.blocks.{layer_i}.k_norm.weight" in loaded:
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = loaded[
f"transformer.blocks.{layer_i}.k_norm.weight"
]

state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq

for k, v in state_dict.items():
Expand All @@ -131,13 +153,22 @@ def write_model(
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"

# Unsharded
# TODO: Deal with weight-tying
# Weight-tying is handled by using the appropriate weights for lm_head based on the weight_tying configuration
state_dict = {
"model.embed_tokens.weight": loaded["transformer.wte.weight"],
"lm_head.weight": loaded["transformer.ff_out.weight"]
if "transformer.ff_out.weight" in loaded
else loaded["transformer.wte.weight"],
}

# Add final layer norm if present
if "transformer.ln_f.weight" in loaded:
state_dict["model.norm.weight"] = loaded["transformer.ln_f.weight"]

# Handle weight-tying (lm_head weight may come from embedding or a separate layer)
if "transformer.ff_out.weight" in loaded and not olmo_config.get("weight_tying", True):
# Not using weight tying, use the ff_out weight
state_dict["lm_head.weight"] = loaded["transformer.ff_out.weight"]
else:
# Using weight tying or ff_out not present, use embedding weight
state_dict["lm_head.weight"] = loaded["transformer.wte.weight"]

for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
Expand Down