diff --git a/gemma/gm/ckpts/_checkpoint.py b/gemma/gm/ckpts/_checkpoint.py index 9449b9a6..ce71643a 100644 --- a/gemma/gm/ckpts/_checkpoint.py +++ b/gemma/gm/ckpts/_checkpoint.py @@ -227,9 +227,7 @@ def load_params( if sharding is not None and params is not None: raise ValueError('`sharding` and `params` are mutually exclusive.') - ckpt = ocp.StandardCheckpointer() - - metadata, path = _get_metadata_and_path(ckpt, path) + metadata, path, ckpt = _get_metadata_and_path(path) metadata = _CheckpointTree.shape_dtype_struct_like(tree=metadata) @@ -488,14 +486,20 @@ def _release_memory(x): def _get_metadata_and_path( - ckpt: ocp.StandardCheckpointer, path: epath.PathLike, ): """Returns the metadata of the checkpoint.""" path = epath.Path(path) + ckpt = ocp.StandardCheckpointer() metadata = ckpt.metadata(path) + if metadata.item_metadata is None: + pytree_ckpt = ocp.PyTreeCheckpointer() + pytree_metadata = pytree_ckpt.metadata(path) + if pytree_metadata.item_metadata is not None: + return pytree_metadata.item_metadata.tree, path, pytree_ckpt + # Kauldron checkpoints structure is different, so the params are contained # in a sub-directory if ( @@ -509,7 +513,7 @@ def _get_metadata_and_path( raise ValueError(f'No item metadata found in {path}') metadata = metadata.item_metadata.tree # Normalize metadata - return metadata, path + return metadata, path, ckpt def _as_shape_dtype_struct(tree): diff --git a/gemma/gm/utils/_dtype_params.py b/gemma/gm/utils/_dtype_params.py index e8b71565..44ffad19 100644 --- a/gemma/gm/utils/_dtype_params.py +++ b/gemma/gm/utils/_dtype_params.py @@ -72,16 +72,18 @@ def decorated( self: nn.Module, name: str, init_fn, # : Callable[..., Any], - shape: tuple[int, ...], - dtype: _DType | None = None, + *init_args, **kwargs, ): if _should_replace_dtype(module=self, stack=_dtypes_stack): - del dtype # The dtype is overwritten by the contextmanager state = _dtypes_stack.stack[-1] - return param(self, name, init_fn, shape, **kwargs, dtype=state.dtype) + if len(init_args) >= 2: + init_args = (init_args[0], state.dtype) + init_args[2:] + if 'dtype' in kwargs or len(init_args) < 2: + kwargs['dtype'] = state.dtype + return param(self, name, init_fn, *init_args, **kwargs) else: - return param(self, name, init_fn, shape, dtype, **kwargs) + return param(self, name, init_fn, *init_args, **kwargs) nn.Module.param = decorated