diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index bafc69ed..1bbdd31e 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -32,6 +32,7 @@ Please keep the lists sorted alphabetically. * Bikram Pandit * Eric Vollenweider * Fabian Jenelten +* Lingshang Kong * Lorenzo Terenzi * Marko Bjelonic * Matthijs van der Boon diff --git a/README.md b/README.md index c8bae884..ecdd63b3 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ The package supports the following logging frameworks which can be configured th * Tensorboard: https://www.tensorflow.org/tensorboard/ * Weights & Biases: https://wandb.ai/site * Neptune: https://docs.neptune.ai/ +* Swanlab: https://docs.swanlab.cn/en/ For a demo configuration of PPO, please check the [example_config.yaml](config/example_config.yaml) file. diff --git a/rsl_rl/utils/logger.py b/rsl_rl/utils/logger.py index 2dc29231..a1b77e09 100644 --- a/rsl_rl/utils/logger.py +++ b/rsl_rl/utils/logger.py @@ -64,7 +64,7 @@ def __init__( self._store_code_state() # Log configuration - if self.writer and not self.disable_logs and self.logger_type in ["wandb", "neptune"]: + if self.writer and not self.disable_logs and self.logger_type in ["wandb", "neptune", "swanlab"]: self.writer.store_config(env_cfg, self.cfg) def process_env_step( @@ -251,6 +251,10 @@ def _prepare_logging_writer(self) -> None: from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) + elif self.logger_type == "swanlab": + from rsl_rl.utils.swanlab_utils import SwanlabSummaryWriter + + self.writer = SwanlabSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) else: raise ValueError("Logger type not found. Please choose 'wandb', 'neptune', or 'tensorboard'.") else: diff --git a/rsl_rl/utils/swanlab_utils.py b/rsl_rl/utils/swanlab_utils.py new file mode 100644 index 00000000..00046f82 --- /dev/null +++ b/rsl_rl/utils/swanlab_utils.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +from dataclasses import asdict +from torch.utils.tensorboard import SummaryWriter + +try: + import swanlab +except ModuleNotFoundError: + raise ModuleNotFoundError("swanlab package is required to log to Swanlab.") from None + + +class SwanlabSummaryWriter(SummaryWriter): + """Summary writer for Swanlab.""" + + def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: + super().__init__(log_dir, flush_secs) + + # Get the run name + run_name = os.path.split(log_dir)[-1] + + # Get swanlab project and entity + try: + project = cfg["swanlab_project"] + except KeyError: + raise KeyError("Please specify swanlab_project in the runner config, e.g. legged_gym.") from None + try: + entity = os.environ["SWANLAB_USERNAME"] + except KeyError: + entity = None + + # Initialize swanlab + swanlab.init(project=project, entity=entity, name=run_name) + swanlab.config.update({"log_dir": log_dir}) + + def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None: + swanlab.config.update({"runner_cfg": train_cfg}) + swanlab.config.update({"policy_cfg": train_cfg["policy"]}) + swanlab.config.update({"alg_cfg": train_cfg["algorithm"]}) + try: + swanlab.config.update({"env_cfg": env_cfg.to_dict()}) + except Exception: + swanlab.config.update({"env_cfg": asdict(env_cfg)}) + + def add_scalar( + self, + tag: str, + scalar_value: float, + global_step: int | None = None, + walltime: float | None = None, + new_style: bool = False, + ) -> None: + super().add_scalar( + tag, + scalar_value, + global_step=global_step, + walltime=walltime, + new_style=new_style, + ) + swanlab.log({tag: scalar_value}, step=global_step) + + def stop(self) -> None: + swanlab.finish() + + def save_model(self, model_path: str, it: int) -> None: + raise NotImplementedError("Swanlab does not support saving files currently.") + + def save_file(self, path: str) -> None: + raise NotImplementedError("Swanlab does not support saving files currently.")