Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,11 @@ class TrainConfig(BaseConfig):
Weights & Biases configuration.
"""

log_to_tensorbard: Optional[str] = None
Comment thread
lyuwen marked this conversation as resolved.
Outdated
"""
Path to tensorbard log output directory.
"""

speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
"""
Speed monitor configuration.
Expand Down
56 changes: 56 additions & 0 deletions olmo/tensorbard_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os.path as osp
from torch.utils.tensorboard import SummaryWriter


# create a new class inheriting from SummaryWriter
class NewSummaryWriter(SummaryWriter):

def __init__(self, log_dir=None, comment="", **kwargs):
super().__init__(log_dir, comment, **kwargs)


# create a new function that will take dictionary as input
# and uses built-in add_scalar() function
# that function combines all plots into one subgroup by a tag
def add_scalar_dict(self, dictionary, global_step, tag=None):
for name, val in dictionary.items():
if tag is not None:
name = osp.join(tag, name)
self.add_scalar(name, val, global_step)


writer = None
Comment thread
lyuwen marked this conversation as resolved.
Outdated


def init(log_dir=None):
global writer
writer = NewSummaryWriter(log_dir=log_dir)


def log(dictionary, global_step, tag=None):
global writer
if writer is None:
return
writer.add_scalar_dict(dictionary, global_step, tag)


def write_args_to_tensorboard(args, iteration, prefix=""):
"""Write arguments to tensorboard."""
global writer
if writer:
if prefix:
prefix = f"{prefix}."
for arg in args.keys():
arg_text = f"{prefix}{arg}"
if isinstance(args[arg], dict):
write_args_to_tensorboard(args[arg], iteration, prefix=arg_text)
else:
writer.add_text(arg_text, str(args[arg]), global_step=iteration)


def finish():
global writer
if writer is None:
return
writer.close()

9 changes: 9 additions & 0 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
synchronize_value,
)
from .util import upload
import olmo.tensorbard_logger as tblogger

__all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]

Expand Down Expand Up @@ -1126,6 +1127,8 @@ def fit(self):
eval_metrics = self.eval()
if wandb.run is not None:
wandb.log(eval_metrics, step=self.global_step)
if tblogger.writer is not None:
tblogger.log(eval_metrics, global_step=self.global_step)

# Set model to 'train' mode.
self.dist_model.train()
Expand All @@ -1141,6 +1144,8 @@ def fit(self):
self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
if wandb.run is not None:
wandb.log(sys_metrics, step=0)
if tblogger.writer is not None:
tblogger.log(sys_metrics, global_step=0)

# Python Profiler stuff
if self.cfg.python_profiling:
Expand Down Expand Up @@ -1251,6 +1256,8 @@ def on_trace_ready(p):
and self.global_step % self.cfg.wandb.log_interval == 0
):
wandb.log(metrics, step=self.global_step)
if tblogger.writer is not None:
tblogger.log(metrics, global_step=self.global_step)

# Check if/when run should be canceled.
if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
Expand Down Expand Up @@ -1317,6 +1324,8 @@ def on_trace_ready(p):
# Log metrics to W&B.
if wandb.run is not None:
wandb.log(eval_metrics, step=self.global_step)
if tblogger.writer is not None:
tblogger.log(eval_metrics, global_step=self.global_step)

# Reset speed monitor so that we don't count the time taken to run evaluations.
speed_monitor.reset()
Expand Down
6 changes: 6 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
log_extra_field,
prepare_cli_environment,
)
import olmo.tensorbard_logger as tblogger

log = logging.getLogger("train")

Expand Down Expand Up @@ -122,6 +123,11 @@ def main(cfg: TrainConfig) -> None:
tags=cfg.wandb.tags,
config=cfg.asdict(exclude=["wandb"]),
)
if cfg.log_to_tensorbard:
log_dir = Path(cfg.log_to_tensorbard)
log_dir.mkdir(parents=True, exist_ok=True)
tblogger.init(log_dir=log_dir)
tblogger.write_args_to_tensorboard(cfg.asdict())

barrier()

Expand Down