diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 2e4db5c73..acc8ff2aa 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -744,16 +744,28 @@ def install_router_patches(self, model_chunks: list[Any]) -> None: topk = int(getattr(module, "topk")) original_routing = module.routing + def _prepare_native_target_for_bound_router( + _controller: MoeRoutingReplayController = self, + _router_key: str = router_key, + ) -> None: + _controller._prepare_native_target_for_router(_router_key) + + prepare_native_target = torch.compiler.disable( + _prepare_native_target_for_bound_router + ) + def _routing_with_replay_target( router_module: Any, *args: Any, - _controller: MoeRoutingReplayController = self, - _router_key: str = router_key, _original_routing: Any = original_routing, + _prepare_native_target: Any = prepare_native_target, **kwargs: Any, ) -> Any: del router_module - _controller._prepare_native_target_for_router(_router_key) + # Target selection mutates Python replay cursors and Megatron's + # RouterReplay state; keep it out of Dynamo while preserving + # compiled routing compute below. + _prepare_native_target() return _original_routing(*args, **kwargs) module.routing = types.MethodType(_routing_with_replay_target, module)