diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py
index 77cd174b15..09012230f8 100644
--- a/tests/worker/tpu_worker_test.py
+++ b/tests/worker/tpu_worker_test.py
@@ -186,6 +186,82 @@ def test_init_with_profiler_on_other_ranks(self, mock_vllm_config):
                            distributed_init_method="test_method")
         assert worker.profile_dir is None
 
+    def test_init_with_gcs_profiler_on_rank_zero(self, mock_vllm_config):
+        """Tests that a GCS profile dir is set directly without os.makedirs."""
+        mock_vllm_config.profiler_config.profiler = "torch"
+        mock_vllm_config.profiler_config.torch_profiler_dir = (
+            "gs://my-bucket/profiles"
+        )
+        with patch("tpu_inference.worker.tpu_worker.os.makedirs") as mock_mkdirs:
+            worker = TPUWorker(
+                vllm_config=mock_vllm_config,
+                local_rank=0,
+                rank=0,
+                distributed_init_method="test_method",
+            )
+        assert worker.profile_dir == "gs://my-bucket/profiles"
+        mock_mkdirs.assert_not_called()
+
+    @patch("tpu_inference.worker.tpu_worker.os.makedirs")
+    @patch("tpu_inference.tpu_info.get_num_cores_per_chip", return_value=2)
+    def test_init_pp_with_gcs_path_skips_makedirs(
+        self, mock_get_cores, mock_mkdirs, mock_vllm_config
+    ):
+        """Tests that os.makedirs is NOT called for GCS paths in PP mode."""
+        mock_vllm_config.profiler_config.profiler = "torch"
+        mock_vllm_config.profiler_config.torch_profiler_dir = (
+            "gs://my-bucket/profiles"
+        )
+        mock_vllm_config.parallel_config.pipeline_parallel_size = 2
+        worker = TPUWorker(
+            vllm_config=mock_vllm_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method="test_method",
+        )
+        assert worker.profile_dir == (
+            "gs://my-bucket/profiles/pprank_0_ppworldsize_2"
+        )
+        mock_mkdirs.assert_not_called()
+
+    @patch("tpu_inference.worker.tpu_worker.os.makedirs")
+    @patch("tpu_inference.tpu_info.get_num_cores_per_chip", return_value=2)
+    def test_init_pp_with_local_path_calls_makedirs(
+        self, mock_get_cores, mock_mkdirs, mock_vllm_config
+    ):
+        """Tests that os.makedirs IS called for local paths in PP mode."""
+        mock_vllm_config.profiler_config.profiler = "torch"
+        mock_vllm_config.profiler_config.torch_profiler_dir = "/tmp/profiles"
+        mock_vllm_config.parallel_config.pipeline_parallel_size = 2
+        worker = TPUWorker(
+            vllm_config=mock_vllm_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method="test_method",
+        )
+        assert worker.profile_dir == "/tmp/profiles/pprank_0_ppworldsize_2"
+        mock_mkdirs.assert_called_once_with(
+            "/tmp/profiles/pprank_0_ppworldsize_2", exist_ok=True
+        )
+
+    @patch("tpu_inference.worker.tpu_worker.jax")
+    @patch.dict("os.environ", {"PYTHON_TRACER_LEVEL": "1"}, clear=True)
+    def test_profile_start_with_gcs_path(self, mock_jax, mock_vllm_config):
+        """Tests that a GCS path is passed directly to jax.profiler.start_trace."""
+        worker = TPUWorker(
+            vllm_config=mock_vllm_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method="test",
+        )
+        worker.profile_dir = "gs://my-bucket/profiles"
+
+        worker.profile(is_start=True)
+
+        mock_jax.profiler.start_trace.assert_called_once()
+        args, _ = mock_jax.profiler.start_trace.call_args
+        assert args[0] == "gs://my-bucket/profiles"
+
     #
     # --- Device and Cache Initialization Tests ---
     #
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index 53228ba546..d8f5eca00d 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -157,7 +157,8 @@ def __init__(
                     profiler_config.torch_profiler_dir,
                     f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
                 )
-                os.makedirs(self.profile_dir, exist_ok=True)
+                if not self.profile_dir.startswith("gs://"):
+                    os.makedirs(self.profile_dir, exist_ok=True)
 
             use_jax_profiler_server = envs.USE_JAX_PROFILER_SERVER
             # Only one instance of profiler is allowed