--- a +++ b/HomoAug/utils/ray_tools.py @@ -0,0 +1,90 @@ +from asyncio import Event +from typing import Tuple + +import ray +from ray.actor import ActorHandle +from tqdm import tqdm + + +def split_list(_list, n): + chunk_size = (len(_list) - 1) // n + 1 + chunks = [_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)] + return chunks + + +@ray.remote +class ProgressBarActor: + counter: int + delta: int + event: Event + + def __init__(self) -> None: + self.counter = 0 + self.delta = 0 + self.event = Event() + + def update(self, num_items_completed: int) -> None: + """Updates the ProgressBar with the incremental + number of items that were just completed. + """ + self.counter += num_items_completed + self.delta += num_items_completed + self.event.set() + + async def wait_for_update(self) -> Tuple[int, int]: + """Blocking call. + + Waits until somebody calls `update`, then returns a tuple of + the number of updates since the last call to + `wait_for_update`, and the total number of completed items. + """ + await self.event.wait() + self.event.clear() + saved_delta = self.delta + self.delta = 0 + return saved_delta, self.counter + + def get_counter(self) -> int: + """ + Returns the total number of complete items. + """ + return self.counter + + +class ProgressBar: + progress_actor: ActorHandle + total: int + description: str + pbar: tqdm + + def __init__(self, total: int, description: str = ""): + # Ray actors don't seem to play nice with mypy, generating + # a spurious warning for the following line, + # which we need to suppress. The code is fine. + self.progress_actor = ProgressBarActor.remote() # type: ignore + self.total = total + print("Total:", total) + self.description = description + + @property + def actor(self) -> ActorHandle: + """Returns a reference to the remote `ProgressBarActor`. + + When you complete tasks, call `update` on the actor. + """ + return self.progress_actor + + def print_until_done(self) -> None: + """Blocking call. + + Do this after starting a series of remote Ray tasks, to which you've + passed the actor handle. Each of them calls `update` on the actor. + When the progress meter reaches 100%, this method returns. + """ + pbar = tqdm(desc=self.description, total=self.total, ncols=80) + while True: + delta, counter = ray.get(self.actor.wait_for_update.remote()) + pbar.update(delta) + if counter >= self.total: + pbar.close() + return \ No newline at end of file