Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion tests/test_continue_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,19 @@ def test_continue_train():
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))

command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
# Continue training from the best model
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
run_cli(command_continue)

assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))

# Continue training from the last checkpoint
for best in glob.glob(os.path.join(continue_path, "best_model*")):
os.remove(best)
run_cli(command_continue)

# Continue training from a specific checkpoint
restore_path = os.path.join(continue_path, "checkpoint_5.pth")
command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
run_cli(command_continue)
shutil.rmtree(continue_path)
5 changes: 4 additions & 1 deletion trainer/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def save_best_model(
save_func=None,
**kwargs,
):
if current_loss < best_loss:
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
):
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
logger.info(" > BEST MODEL : %s", checkpoint_path)
Expand Down
15 changes: 11 additions & 4 deletions trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __init__( # pylint: disable=dangerous-default-value
self.epochs_done = 0
self.restore_step = 0
self.restore_epoch = 0
self.best_loss = float("inf")
self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
self.train_loader = None
self.test_loader = None
self.eval_loader = None
Expand Down Expand Up @@ -1724,8 +1724,15 @@ def _restore_best_loss(self):
logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
ch = load_fsspec(self.args.restore_path, map_location="cpu")
if "model_loss" in ch:
self.best_loss = ch["model_loss"]
logger.info(" > Starting with loaded last best loss %f", self.best_loss)
if isinstance(ch["model_loss"], dict):
self.best_loss = ch["model_loss"]
# For backwards-compatibility:
elif isinstance(ch["model_loss"], float):
if self.config.run_eval:
self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
else:
self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
logger.info(" > Starting with loaded last best loss %s", self.best_loss)

def test(self, model=None, test_samples=None) -> None:
"""Run evaluation steps on the test data split. You can either provide the model and the test samples
Expand Down Expand Up @@ -1907,7 +1914,7 @@ def save_best_model(self) -> None:

# save the model and update the best_loss
self.best_loss = save_best_model(
eval_loss if eval_loss else train_loss,
{"train_loss": train_loss, "eval_loss": eval_loss},
self.best_loss,
self.config,
self.model,
Expand Down