Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions rdagent/log/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,12 @@ def add_duration(self, duration: timedelta) -> None:
def is_timeout(self) -> bool:
if self.started and self.target_time is not None:
self.update_remain_time()
if datetime.now() > self.target_time:
return True
return self._remain_time_duration == timedelta(0)
return False

def update_remain_time(self) -> None:
if self.started and self.target_time is not None:
self._remain_time_duration = self.target_time - datetime.now()
self._remain_time_duration = max(self.target_time - datetime.now(), timedelta(0))
return None

def remain_time(self) -> timedelta | None:
Expand Down
4 changes: 2 additions & 2 deletions rdagent/scenarios/data_science/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,10 @@ def record(self, prev_out: dict[str, Any]):
) # backup when upper code line is killed when running
self.timer.add_duration(datetime.now() - start_archive_datetime)

def _check_exit_conditions_on_step(self, loop_id: Optional[int] = None, step_id: Optional[int] = None):
async def _check_exit_conditions_on_step(self, loop_id: Optional[int] = None, step_id: Optional[int] = None):
if step_id not in [self.steps.index("running"), self.steps.index("feedback")]:
# Skip the exit check for the "running" and "feedback" steps, since they are very likely to finish soon.
super()._check_exit_conditions_on_step(loop_id=loop_id, step_id=step_id)
await super()._check_exit_conditions_on_step(loop_id=loop_id, step_id=step_id)

@classmethod
def load(
Expand Down
19 changes: 11 additions & 8 deletions rdagent/utils/workflow/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(self) -> None:
self.step_n: Optional[int] = None # remain step count

self.semaphores: dict[str, asyncio.Semaphore] = {}
self._step_n_lock: asyncio.Lock = asyncio.Lock()

def get_unfinished_loop_cnt(self, next_loop: int) -> int:
n = 0
Expand Down Expand Up @@ -168,19 +169,20 @@ def close_pbar(self) -> None:
self._pbar.close()
del self._pbar

def _check_exit_conditions_on_step(self, loop_id: Optional[int] = None, step_id: Optional[int] = None) -> None:
async def _check_exit_conditions_on_step(self, loop_id: Optional[int] = None, step_id: Optional[int] = None) -> None:
"""Check if the loop should continue or terminate.

Raises
------
LoopTerminationError
When conditions indicate that the loop should terminate
"""
# Check step count limitation
if self.step_n is not None:
if self.step_n <= 0:
raise self.LoopTerminationError("Step count reached")
self.step_n -= 1
# Check step count limitation — guarded by lock to prevent race under parallel steps
async with self._step_n_lock:
if self.step_n is not None:
if self.step_n <= 0:
raise self.LoopTerminationError("Step count reached")
self.step_n -= 1

# Check timer timeout
if self.timer.started:
Expand Down Expand Up @@ -305,7 +307,7 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
# it has been executed successfully
self.dump(self.session_folder / f"{li}" / f"{si}_{name}")

self._check_exit_conditions_on_step(loop_id=li, step_id=si)
await self._check_exit_conditions_on_step(loop_id=li, step_id=si)
else:
logger.warning(f"Step forward {si} of loop {li} is skipped.")

Expand Down Expand Up @@ -528,14 +530,15 @@ def load(
def __getstate__(self) -> dict[str, Any]:
res = {}
for k, v in self.__dict__.items():
if k not in ["queue", "semaphores", "_pbar"]:
if k not in ["queue", "semaphores", "_pbar", "_step_n_lock"]:
res[k] = v
return res

def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)
self.queue = asyncio.Queue()
self.semaphores = {}
self._step_n_lock = asyncio.Lock()


def kill_subprocesses() -> None:
Expand Down
4 changes: 2 additions & 2 deletions rdagent/utils/workflow/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def log_workflow_state(self) -> None:
if self.loop_base.timer.started:
remain_time = self.loop_base.timer.remain_time()
assert remain_time is not None
mlflow.log_metric("remain_time", remain_time.total_seconds())
mlflow.log_metric("remain_time", max(remain_time.total_seconds(), 0.0))
mlflow.log_metric(
"remain_percent",
remain_time / self.loop_base.timer.all_duration * 100,
max(remain_time / self.loop_base.timer.all_duration * 100, 0.0),
)

# Keep only the log_workflow_state method as it's the primary entry point now
Expand Down
Loading