aidial_sdk/utils/_cancel_scope.py (48 lines of code) (raw):

import asyncio
from asyncio import exceptions
from typing import Optional, Set


class CancelScope:
    """
    Async context manager that enforces cancellation of all tasks created
    within its scope when either:
    1. the parent task has been cancelled or has thrown an exception or
    2. any of the tasks created within the scope has thrown an exception.

    Leaving the ``async with`` block waits until every task created through
    :meth:`create_task` has finished.
    """

    def __init__(self) -> None:
        # Tasks created through create_task() that have not finished yet.
        self._tasks: Set[asyncio.Task] = set()
        # Future awaited by __aexit__; resolved once the last task finishes.
        self._on_completed_fut: Optional[asyncio.Future] = None
        # Set once the pending tasks have been asked to cancel, so that the
        # cancellation request is issued only one time.
        self._cancelling: bool = False

    async def __aenter__(self) -> "CancelScope":
        """Enter the scope and return it so tasks can be created on it."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """
        Wait for all tasks of the scope to finish.

        If the body of the ``async with`` block raised, or the waiting itself
        is cancelled, the pending tasks are cancelled first. Any
        ``CancelledError`` observed (from the body or while waiting) is
        re-raised after all tasks are done; other exceptions raised by the
        body propagate normally because this method returns ``None``.
        """
        cancelled_error = (
            exc if isinstance(exc, exceptions.CancelledError) else None
        )

        # If the parent task has thrown an exception, cancel all the tasks
        if exc_type is not None:
            self._cancel_tasks()

        while self._tasks:
            if self._on_completed_fut is None:
                self._on_completed_fut = asyncio.Future()

            # If the parent task was cancelled, cancel all the tasks
            try:
                await self._on_completed_fut
            except exceptions.CancelledError as ex:
                cancelled_error = ex
                self._cancel_tasks()

            # A cancelled wait leaves the future unusable; drop it so the
            # next loop iteration (if tasks remain) creates a fresh one.
            self._on_completed_fut = None

        if cancelled_error:
            raise cancelled_error

    def create_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task tracked by this scope and return it."""
        task = asyncio.create_task(coro)
        task.add_done_callback(self._on_task_done)
        self._tasks.add(task)
        return task

    def _cancel_tasks(self) -> None:
        """Request cancellation of every unfinished task (at most once)."""
        if not self._cancelling:
            self._cancelling = True
            for t in self._tasks:
                if not t.done():
                    t.cancel()

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Done-callback: untrack ``task`` and react to its outcome."""
        self._tasks.discard(task)

        # Wake up __aexit__ once the last tracked task has finished.
        if (
            self._on_completed_fut is not None
            and not self._on_completed_fut.done()
            and not self._tasks
        ):
            self._on_completed_fut.set_result(True)

        # If any of the tasks has thrown an exception, cancel all the tasks.
        # Task.exception() raises CancelledError for a cancelled task, which
        # would escape this callback and be reported by the event loop as
        # "Exception in callback", so cancelled tasks are skipped explicitly.
        if not task.cancelled() and task.exception() is not None:
            self._cancel_tasks()