diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index b43b7d63e..befc707b9 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -125,6 +125,7 @@ def fit( log_every_n_steps: int = 1, gradient_clip_val: Optional[float] = None, distribution_strategy: Optional[str] = "ddp", + fast_dev_run: Optional[int] = None, **trainer_kwargs: Any, ) -> None: """Fit `StandardModel` using `pytorch_lightning.Trainer`.""" @@ -162,6 +163,7 @@ def fit( log_every_n_steps=log_every_n_steps, gradient_clip_val=gradient_clip_val, distribution_strategy=distribution_strategy, + fast_dev_run=fast_dev_run, **trainer_kwargs, ) @@ -174,16 +176,17 @@ def fit( pass # Load weights from best-fit model after training if possible - if has_early_stopping & has_model_checkpoint: - for callback in callbacks: - if isinstance(callback, ModelCheckpoint): - checkpoint_callback = callback - self.load_state_dict( - torch.load( - checkpoint_callback.best_model_path, weights_only=False - )["state_dict"] - ) - self.info("Best-fit weights from EarlyStopping loaded.") + if fast_dev_run is None: + if has_early_stopping & has_model_checkpoint: + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + checkpoint_callback = callback + self.load_state_dict( + torch.load( + checkpoint_callback.best_model_path, weights_only=False + )["state_dict"] + ) + self.info("Best-fit weights from EarlyStopping loaded.") def _print_callbacks(self, callbacks: List[Callback]) -> None: callback_names = []