2 changes: 1 addition & 1 deletion .dev.commit
@@ -1 +1 @@
-c72c4599012297cfbd1d57e006b544478b6bbf78
+c0c4fdc45e9f0f8047e29f2cb8669613169752c3
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 207 files
10 changes: 9 additions & 1 deletion src/megatron/bridge/training/initialize.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import datetime
+import inspect
 import os
 import time
 import warnings
@@ -676,14 +677,20 @@ def _initialize_distributed(
     if parallel_state.model_parallel_is_initialized():
         print("model parallel is already initialized")
     else:
+        # Guard for main/dev branch submodule compat: hybrid_context_parallel was added in the dev branch.
+        # TODO: remove guard once the addition lands in main and Bridge pins the new main commit.
+        _init_mp_params = set(inspect.signature(parallel_state.initialize_model_parallel).parameters)
+        _optional_kwargs = {}
+        if "hybrid_context_parallel" in _init_mp_params:
+            _optional_kwargs["hybrid_context_parallel"] = model_config.hybrid_context_parallel
+
         parallel_state.initialize_model_parallel(
             tensor_model_parallel_size=model_config.tensor_model_parallel_size,
             pipeline_model_parallel_size=model_config.pipeline_model_parallel_size,
             virtual_pipeline_model_parallel_size=model_config.virtual_pipeline_model_parallel_size,
             pipeline_model_parallel_comm_backend=model_config.pipeline_model_parallel_comm_backend,
             context_parallel_size=model_config.context_parallel_size,
             hierarchical_context_parallel_sizes=model_config.hierarchical_context_parallel_sizes,
-            hybrid_context_parallel=model_config.hybrid_context_parallel,
             expert_model_parallel_size=model_config.expert_model_parallel_size,
             num_distributed_optimizer_instances=num_distributed_optimizer_instances,
             expert_tensor_parallel_size=model_config.expert_tensor_parallel_size,
@@ -696,6 +703,7 @@ def _initialize_distributed(
             use_sharp=dist_config.use_sharp,
             high_priority_stream_groups=dist_config.high_priority_stream_groups,
             sharp_enabled_group=dist_config.sharp_enabled_group,
+            **_optional_kwargs,
         )
         if get_rank_safe() == 0:
             print(
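The change above keeps Bridge working against both submodule pins: passing hybrid_context_parallel unconditionally would raise TypeError on the older initialize_model_parallel signature, so the caller inspects the installed signature and forwards the kwarg only when it is accepted. Below is a minimal standalone sketch of that pattern; the names initialize, call_with_compat, and new_flag are hypothetical stand-ins, not the real Megatron-LM API.

    # Sketch: forward only the kwargs the installed callee actually accepts.
    import inspect


    def initialize(tensor_parallel_size: int = 1) -> None:
        # Stand-in for an "old" API version that predates the new keyword.
        print(f"initialized with tp={tensor_parallel_size}")


    def call_with_compat(fn, **requested_kwargs):
        # Keep only the kwargs present in fn's current signature; silently
        # drop the rest so old and new versions are both callable.
        accepted = set(inspect.signature(fn).parameters)
        kwargs = {k: v for k, v in requested_kwargs.items() if k in accepted}
        fn(**kwargs)


    # Works whether or not `initialize` knows about `new_flag`:
    call_with_compat(initialize, tensor_parallel_size=2, new_flag=True)

One caveat of this sketch: if the callee takes **kwargs, inspect.signature will not list arbitrary keyword names, so the filter would drop kwargs the function could actually absorb; the PR avoids this because initialize_model_parallel declares its parameters explicitly.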
14 changes: 9 additions & 5 deletions src/megatron/bridge/training/state.py
@@ -20,20 +20,23 @@
 from typing import Any, Optional

 import torch
-from megatron.core.energy_monitor import EnergyMonitor
 from megatron.core.timers import Timers
 from megatron.core.utils import StragglerDetector
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.tensorboard.writer import SummaryWriter


-# TODO: Remove try/except once `get_async_strategy` lands in mcore dev.
-# The function was added to mcore main but has not yet been merged into dev.
+# TODO: Remove try/except guards once these land in mcore dev.
 try:
     from megatron.core.dist_checkpointing.strategies.torch import get_async_strategy
 except ImportError:
     get_async_strategy = None  # type: ignore[assignment]

+try:
+    from megatron.core.energy_monitor import EnergyMonitor
+except ImportError:
+    EnergyMonitor = None  # type: ignore[assignment]
+
 from megatron.bridge.training.config import ConfigContainer
 from megatron.bridge.training.nvrx_straggler import NVRxStragglerDetectionManager
 from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer
@@ -144,7 +147,7 @@ def __init__(self) -> None:
         self._async_calls_queue: Optional[Any] = None
         self._nvrx_straggler_manager: Optional[NVRxStragglerDetectionManager] = None
         self._nvrx_straggler_created: bool = False
-        self._energy_monitor: Optional[EnergyMonitor] = None
+        self._energy_monitor: Optional[Any] = None
         self._energy_monitor_created: bool = False

     @property
@@ -440,13 +443,14 @@ def nvrx_straggler_manager(self) -> Optional[NVRxStragglerDetectionManager]:
         return self._nvrx_straggler_manager

     @property
-    def energy_monitor(self) -> Optional[EnergyMonitor]:
+    def energy_monitor(self) -> Optional[Any]:
         """The EnergyMonitor instance for tracking energy consumption."""
         if (
             not self._energy_monitor_created
             and self._energy_monitor is None
             and self.cfg is not None
             and self.cfg.logger.log_energy
+            and EnergyMonitor is not None
         ):
             self._energy_monitor = EnergyMonitor()
             self._energy_monitor_created = True
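This file applies the same compat strategy at import time: EnergyMonitor has not yet landed in the mcore dev branch, so the module-level import is wrapped in try/except, the annotations widen from Optional[EnergyMonitor] to Optional[Any], and the lazy property instantiates the monitor only when the import succeeded. Below is a standalone sketch of the pattern; some_optional_dep, FakeMonitor, and AppState are hypothetical names used purely for illustration.

    # Sketch: optional import resolved once at module load, consumed lazily.
    from typing import Any, Optional

    try:
        from some_optional_dep import FakeMonitor  # hypothetical optional dependency
    except ImportError:
        FakeMonitor = None  # type: ignore[assignment]


    class AppState:
        def __init__(self) -> None:
            self._monitor: Optional[Any] = None  # Any, since the class may be absent
            self._monitor_created: bool = False

        @property
        def monitor(self) -> Optional[Any]:
            # Create lazily, and only if the optional dependency imported.
            if not self._monitor_created and FakeMonitor is not None:
                self._monitor = FakeMonitor()
                self._monitor_created = True
            return self._monitor


    state = AppState()
    print(state.monitor)  # None when some_optional_dep is not installed

Widening the annotations to Optional[Any] is what lets the module type-check when the guarded name is bound to None; the runtime `is not None` check then makes the feature degrade to a no-op rather than crash on the older submodule.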