Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added support for safetensors in `hf_olmo` conversion script.

## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17

### Added
Expand Down Expand Up @@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Swapped in correct flan data mix.
- Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
- Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout
- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout


## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
Expand Down
15 changes: 11 additions & 4 deletions hf_olmo/convert_olmo_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
from olmo import ModelConfig, Tokenizer, TrainConfig
from olmo.checkpoint import build_sharded_checkpointer
from olmo.safetensors_util import safetensors_file_to_state_dict
from olmo.util import _get_s3_client

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,10 +68,14 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
# For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
# So, we explicitly store the model with the expected prefix.

old_model_path = os.path.join(checkpoint_dir, "model.pt")
new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
if os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.pt")):
state_dict = torch.load(old_model_path, map_location="cpu")
elif os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.safetensors")):
state_dict = safetensors_file_to_state_dict(old_model_path, map_location="cpu")
else:
raise ValueError(f"No model found in {checkpoint_dir}")

state_dict = torch.load(old_model_path, map_location="cpu")
new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")

# this takes care of the case where the model was saved with a different prefix,
# typically due to unsharding.
Expand Down Expand Up @@ -233,7 +238,9 @@ def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str):


def maybe_unshard(checkpoint_dir: str):
if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
if os.path.exists(os.path.join(checkpoint_dir, "model.pt")) or os.path.exists(
os.path.join(checkpoint_dir, "model.safetensors")
):
return

print(f"Unsharding {checkpoint_dir}...")
Expand Down