Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions amp_rsl_rl/utils/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
import os
import warnings
import wandb
from rsl_rl.utils.wandb_utils import WandbSummaryWriter as RslWandbSummaryWriter
from torch.utils.tensorboard import SummaryWriter

class WandbSummaryWriter(RslWandbSummaryWriter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
SummaryWriter.__init__(self, log_dir, flush_secs)

# Get the run name
run_name = os.path.split(log_dir)[-1]

# Thanks to https://github.com/leggedrobotics/rsl_rl/pull/80/
try:
project = cfg['wandb_kwargs']["project"]
except KeyError:
raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None

try:
entity = cfg['wandb_kwargs']["entity"]
except KeyError:
entity = None
warnings.warn("wandb_entity not specified in the runner config.")

try:
group = cfg['wandb_kwargs']["group"]
except KeyError:
warnings.warn("wandb_group not specified in the runner config. Using default group.")

# Initialize wandb
wandb.init(
project=project,
entity=entity,
name=run_name,
group=group,
notes=cfg['wandb_kwargs']['notes'],
)

# Add log directory to wandb
wandb.config.update({"log_dir": log_dir})

self.video_files = []

# To save video files to wandb explicitly
Expand Down