diff --git a/olmo/config.py b/olmo/config.py index 6da7dc03d..811809ffc 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -1200,6 +1200,11 @@ class TrainConfig(BaseConfig): Weights & Biases configuration. """ + tensorboard_path: Optional[str] = None + """ + Path to tensorbard log output directory. + """ + speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig) """ Speed monitor configuration. diff --git a/olmo/tensorbard_logger.py b/olmo/tensorbard_logger.py new file mode 100644 index 000000000..d6708932c --- /dev/null +++ b/olmo/tensorbard_logger.py @@ -0,0 +1,43 @@ +import logging +import os.path as osp + +logger = logging.getLogger(__name__) + +try: + from torch.utils.tensorboard import SummaryWriter + HAS_TENSORBOARD = True +except ImportError: + HAS_TENSORBOARD = False + + +# create a new class inheriting from SummaryWriter +class TBNewSummaryWriter(SummaryWriter): + + def __init__(self, log_dir=None, comment="", **kwargs): + super().__init__(log_dir, comment, **kwargs) + + + # create a new function that will take dictionary as input + # and uses built-in add_scalar() function + # that function combines all plots into one subgroup by a tag + def add_scalar_dict(self, dictionary, global_step, tag=None): + for name, val in dictionary.items(): + if tag is not None: + name = osp.join(tag, name) + self.add_scalar(name, val, global_step) + + + def log(self, dictionary, global_step, tag=None): + self.add_scalar_dict(dictionary, global_step, tag) + + + def write_args_to_tensorboard(self, args, iteration, prefix=""): + """Write arguments to tensorboard.""" + if prefix: + prefix = f"{prefix}." + for arg in args.keys(): + arg_text = f"{prefix}{arg}" + if isinstance(args[arg], dict): + self.write_args_to_tensorboard(args[arg], iteration, prefix=arg_text) + else: + self.add_text(arg_text, str(args[arg]), global_step=iteration) diff --git a/olmo/train.py b/olmo/train.py index a4f9919c8..5891f7bee 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -58,6 +58,7 @@ synchronize_value, ) from .util import upload +from olmo.tensorbard_logger import TBNewSummaryWriter, HAS_TENSORBOARD __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"] @@ -240,6 +241,15 @@ def __post_init__(self): self.loss_fn = fused_loss_fn else: raise NameError("`fused_loss_fn` is not defined. Please ensure that `flash_attn` is installed.") + self.logger = None + if self.cfg.tensorboard_path: + if HAS_TENSORBOARD: + log_dir = Path(self.cfg.tensorboard_path) + log_dir.mkdir(parents=True, exist_ok=True) + self.logger = TBNewSummaryWriter(log_dir=log_dir) + self.logger.write_args_to_tensorboard(self.cfg.asdict()) + else: + logger.warn("Failed to import tensorbard writer, will not write tensorbard logs.") @property def dataset(self) -> IterableDataset: @@ -1126,6 +1136,8 @@ def fit(self): eval_metrics = self.eval() if wandb.run is not None: wandb.log(eval_metrics, step=self.global_step) + if self.logger is not None: + self.logger.log(eval_metrics, global_step=self.global_step) # Set model to 'train' mode. self.dist_model.train() @@ -1141,6 +1153,8 @@ def fit(self): self.log_metrics_to_console("Pre-train system metrics", sys_metrics) if wandb.run is not None: wandb.log(sys_metrics, step=0) + if self.logger is not None: + self.logger.log(sys_metrics, global_step=0) # Python Profiler stuff if self.cfg.python_profiling: @@ -1251,6 +1265,8 @@ def on_trace_ready(p): and self.global_step % self.cfg.wandb.log_interval == 0 ): wandb.log(metrics, step=self.global_step) + if self.logger is not None: + self.logger.log(metrics, global_step=self.global_step) # Check if/when run should be canceled. if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0: @@ -1317,6 +1333,8 @@ def on_trace_ready(p): # Log metrics to W&B. if wandb.run is not None: wandb.log(eval_metrics, step=self.global_step) + if self.logger is not None: + self.logger.log(eval_metrics, global_step=self.global_step) # Reset speed monitor so that we don't count the time taken to run evaluations. speed_monitor.reset() @@ -1387,6 +1405,8 @@ def close(self, exit_code: int = 0) -> None: gc.disable() if wandb.run is not None: wandb.finish(exit_code=exit_code, quiet=True) + if self.logger is not None: + self.logger.close() def __enter__(self) -> Trainer: return self