diff --git a/brax/training/checkpoint.py b/brax/training/checkpoint.py index cc64d02a..be332d25 100644 --- a/brax/training/checkpoint.py +++ b/brax/training/checkpoint.py @@ -225,6 +225,8 @@ def load_config( if init_fn_name not in loaded_dict['network_factory_kwargs']: continue init_fn_name_ = loaded_dict['network_factory_kwargs'][init_fn_name] + if init_fn_name_ is None: + continue loaded_dict['network_factory_kwargs'][init_fn_name] = ( networks.KERNEL_INITIALIZER[init_fn_name_] )