Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions gemma/gm/ckpts/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,7 @@ def load_params(
if sharding is not None and params is not None:
raise ValueError('`sharding` and `params` are mutually exclusive.')

ckpt = ocp.StandardCheckpointer()

metadata, path = _get_metadata_and_path(ckpt, path)
metadata, path, ckpt = _get_metadata_and_path(path)

metadata = _CheckpointTree.shape_dtype_struct_like(tree=metadata)

Expand Down Expand Up @@ -488,14 +486,20 @@ def _release_memory(x):


def _get_metadata_and_path(
ckpt: ocp.StandardCheckpointer,
path: epath.PathLike,
):
"""Returns the metadata of the checkpoint."""
path = epath.Path(path)

ckpt = ocp.StandardCheckpointer()
metadata = ckpt.metadata(path)

if metadata.item_metadata is None:
pytree_ckpt = ocp.PyTreeCheckpointer()
pytree_metadata = pytree_ckpt.metadata(path)
if pytree_metadata.item_metadata is not None:
return pytree_metadata.item_metadata.tree, path, pytree_ckpt

# Kauldron checkpoints structure is different, so the params are contained
# in a sub-directory
if (
Expand All @@ -509,7 +513,7 @@ def _get_metadata_and_path(
raise ValueError(f'No item metadata found in {path}')

metadata = metadata.item_metadata.tree # Normalize metadata
return metadata, path
return metadata, path, ckpt


def _as_shape_dtype_struct(tree):
Expand Down
12 changes: 7 additions & 5 deletions gemma/gm/utils/_dtype_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,18 @@ def decorated(
self: nn.Module,
name: str,
init_fn, # : Callable[..., Any],
shape: tuple[int, ...],
dtype: _DType | None = None,
*init_args,
**kwargs,
):
if _should_replace_dtype(module=self, stack=_dtypes_stack):
del dtype # The dtype is overwritten by the contextmanager
state = _dtypes_stack.stack[-1]
return param(self, name, init_fn, shape, **kwargs, dtype=state.dtype)
if len(init_args) >= 2:
init_args = (init_args[0], state.dtype) + init_args[2:]
if 'dtype' in kwargs or len(init_args) < 2:
kwargs['dtype'] = state.dtype
return param(self, name, init_fn, *init_args, **kwargs)
else:
return param(self, name, init_fn, shape, dtype, **kwargs)
return param(self, name, init_fn, *init_args, **kwargs)

nn.Module.param = decorated

Expand Down
Loading