sync.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. import asyncio
  2. import asyncio.coroutines
  3. import contextvars
  4. import functools
  5. import inspect
  6. import os
  7. import sys
  8. import threading
  9. import warnings
  10. import weakref
  11. from concurrent.futures import Future, ThreadPoolExecutor
  12. from typing import (
  13. TYPE_CHECKING,
  14. Any,
  15. Awaitable,
  16. Callable,
  17. Coroutine,
  18. Dict,
  19. Generic,
  20. List,
  21. Optional,
  22. TypeVar,
  23. Union,
  24. overload,
  25. )
  26. from .current_thread_executor import CurrentThreadExecutor
  27. from .local import Local
  28. if sys.version_info >= (3, 10):
  29. from typing import ParamSpec
  30. else:
  31. from typing_extensions import ParamSpec
  32. if TYPE_CHECKING:
  33. # This is not available to import at runtime
  34. from _typeshed import OptExcInfo
  35. _F = TypeVar("_F", bound=Callable[..., Any])
  36. _P = ParamSpec("_P")
  37. _R = TypeVar("_R")
  38. def _restore_context(context: contextvars.Context) -> None:
  39. # Check for changes in contextvars, and set them to the current
  40. # context for downstream consumers
  41. for cvar in context:
  42. cvalue = context.get(cvar)
  43. try:
  44. if cvar.get() != cvalue:
  45. cvar.set(cvalue)
  46. except LookupError:
  47. cvar.set(cvalue)
  48. # Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for
  49. # inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker.
  50. # The latter is replaced with the inspect.markcoroutinefunction decorator.
  51. # Until 3.12 is the minimum supported Python version, provide a shim.
  52. if hasattr(inspect, "markcoroutinefunction"):
  53. iscoroutinefunction = inspect.iscoroutinefunction
  54. markcoroutinefunction: Callable[[_F], _F] = inspect.markcoroutinefunction
  55. else:
  56. iscoroutinefunction = asyncio.iscoroutinefunction # type: ignore[assignment]
  57. def markcoroutinefunction(func: _F) -> _F:
  58. func._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore
  59. return func
  60. class ThreadSensitiveContext:
  61. """Async context manager to manage context for thread sensitive mode
  62. This context manager controls which thread pool executor is used when in
  63. thread sensitive mode. By default, a single thread pool executor is shared
  64. within a process.
  65. The ThreadSensitiveContext() context manager may be used to specify a
  66. thread pool per context.
  67. This context manager is re-entrant, so only the outer-most call to
  68. ThreadSensitiveContext will set the context.
  69. Usage:
  70. >>> import time
  71. >>> async with ThreadSensitiveContext():
  72. ... await sync_to_async(time.sleep, 1)()
  73. """
  74. def __init__(self):
  75. self.token = None
  76. async def __aenter__(self):
  77. try:
  78. SyncToAsync.thread_sensitive_context.get()
  79. except LookupError:
  80. self.token = SyncToAsync.thread_sensitive_context.set(self)
  81. return self
  82. async def __aexit__(self, exc, value, tb):
  83. if not self.token:
  84. return
  85. executor = SyncToAsync.context_to_thread_executor.pop(self, None)
  86. if executor:
  87. executor.shutdown()
  88. SyncToAsync.thread_sensitive_context.reset(self.token)
  89. class AsyncToSync(Generic[_P, _R]):
  90. """
  91. Utility class which turns an awaitable that only works on the thread with
  92. the event loop into a synchronous callable that works in a subthread.
  93. If the call stack contains an async loop, the code runs there.
  94. Otherwise, the code runs in a new loop in a new thread.
  95. Either way, this thread then pauses and waits to run any thread_sensitive
  96. code called from further down the call stack using SyncToAsync, before
  97. finally exiting once the async task returns.
  98. """
  99. # Keeps a reference to the CurrentThreadExecutor in local context, so that
  100. # any sync_to_async inside the wrapped code can find it.
  101. executors: "Local" = Local()
  102. # When we can't find a CurrentThreadExecutor from the context, such as
  103. # inside create_task, we'll look it up here from the running event loop.
  104. loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}
  105. def __init__(
  106. self,
  107. awaitable: Union[
  108. Callable[_P, Coroutine[Any, Any, _R]],
  109. Callable[_P, Awaitable[_R]],
  110. ],
  111. force_new_loop: bool = False,
  112. ):
  113. if not callable(awaitable) or (
  114. not iscoroutinefunction(awaitable)
  115. and not iscoroutinefunction(getattr(awaitable, "__call__", awaitable))
  116. ):
  117. # Python does not have very reliable detection of async functions
  118. # (lots of false negatives) so this is just a warning.
  119. warnings.warn(
  120. "async_to_sync was passed a non-async-marked callable", stacklevel=2
  121. )
  122. self.awaitable = awaitable
  123. try:
  124. self.__self__ = self.awaitable.__self__ # type: ignore[union-attr]
  125. except AttributeError:
  126. pass
  127. self.force_new_loop = force_new_loop
  128. self.main_event_loop = None
  129. try:
  130. self.main_event_loop = asyncio.get_running_loop()
  131. except RuntimeError:
  132. # There's no event loop in this thread.
  133. pass
  134. def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
  135. __traceback_hide__ = True # noqa: F841
  136. if not self.force_new_loop and not self.main_event_loop:
  137. # There's no event loop in this thread. Look for the threadlocal if
  138. # we're inside SyncToAsync
  139. main_event_loop_pid = getattr(
  140. SyncToAsync.threadlocal, "main_event_loop_pid", None
  141. )
  142. # We make sure the parent loop is from the same process - if
  143. # they've forked, this is not going to be valid any more (#194)
  144. if main_event_loop_pid and main_event_loop_pid == os.getpid():
  145. self.main_event_loop = getattr(
  146. SyncToAsync.threadlocal, "main_event_loop", None
  147. )
  148. # You can't call AsyncToSync from a thread with a running event loop
  149. try:
  150. event_loop = asyncio.get_running_loop()
  151. except RuntimeError:
  152. pass
  153. else:
  154. if event_loop.is_running():
  155. raise RuntimeError(
  156. "You cannot use AsyncToSync in the same thread as an async event loop - "
  157. "just await the async function directly."
  158. )
  159. # Make a future for the return information
  160. call_result: "Future[_R]" = Future()
  161. # Make a CurrentThreadExecutor we'll use to idle in this thread - we
  162. # need one for every sync frame, even if there's one above us in the
  163. # same thread.
  164. old_executor = getattr(self.executors, "current", None)
  165. current_executor = CurrentThreadExecutor()
  166. self.executors.current = current_executor
  167. # Wrapping context in list so it can be reassigned from within
  168. # `main_wrap`.
  169. context = [contextvars.copy_context()]
  170. # Get task context so that parent task knows which task to propagate
  171. # an asyncio.CancelledError to.
  172. task_context = getattr(SyncToAsync.threadlocal, "task_context", None)
  173. loop = None
  174. # Use call_soon_threadsafe to schedule a synchronous callback on the
  175. # main event loop's thread if it's there, otherwise make a new loop
  176. # in this thread.
  177. try:
  178. awaitable = self.main_wrap(
  179. call_result,
  180. sys.exc_info(),
  181. task_context,
  182. context,
  183. *args,
  184. **kwargs,
  185. )
  186. if not (self.main_event_loop and self.main_event_loop.is_running()):
  187. # Make our own event loop - in a new thread - and run inside that.
  188. loop = asyncio.new_event_loop()
  189. self.loop_thread_executors[loop] = current_executor
  190. loop_executor = ThreadPoolExecutor(max_workers=1)
  191. loop_future = loop_executor.submit(
  192. self._run_event_loop, loop, awaitable
  193. )
  194. if current_executor:
  195. # Run the CurrentThreadExecutor until the future is done
  196. current_executor.run_until_future(loop_future)
  197. # Wait for future and/or allow for exception propagation
  198. loop_future.result()
  199. else:
  200. # Call it inside the existing loop
  201. self.main_event_loop.call_soon_threadsafe(
  202. self.main_event_loop.create_task, awaitable
  203. )
  204. if current_executor:
  205. # Run the CurrentThreadExecutor until the future is done
  206. current_executor.run_until_future(call_result)
  207. finally:
  208. # Clean up any executor we were running
  209. if loop is not None:
  210. del self.loop_thread_executors[loop]
  211. _restore_context(context[0])
  212. # Restore old current thread executor state
  213. self.executors.current = old_executor
  214. # Wait for results from the future.
  215. return call_result.result()
  216. def _run_event_loop(self, loop, coro):
  217. """
  218. Runs the given event loop (designed to be called in a thread).
  219. """
  220. asyncio.set_event_loop(loop)
  221. try:
  222. loop.run_until_complete(coro)
  223. finally:
  224. try:
  225. # mimic asyncio.run() behavior
  226. # cancel unexhausted async generators
  227. tasks = asyncio.all_tasks(loop)
  228. for task in tasks:
  229. task.cancel()
  230. async def gather():
  231. await asyncio.gather(*tasks, return_exceptions=True)
  232. loop.run_until_complete(gather())
  233. for task in tasks:
  234. if task.cancelled():
  235. continue
  236. if task.exception() is not None:
  237. loop.call_exception_handler(
  238. {
  239. "message": "unhandled exception during loop shutdown",
  240. "exception": task.exception(),
  241. "task": task,
  242. }
  243. )
  244. if hasattr(loop, "shutdown_asyncgens"):
  245. loop.run_until_complete(loop.shutdown_asyncgens())
  246. finally:
  247. loop.close()
  248. asyncio.set_event_loop(self.main_event_loop)
  249. def __get__(self, parent: Any, objtype: Any) -> Callable[_P, _R]:
  250. """
  251. Include self for methods
  252. """
  253. func = functools.partial(self.__call__, parent)
  254. return functools.update_wrapper(func, self.awaitable)
  255. async def main_wrap(
  256. self,
  257. call_result: "Future[_R]",
  258. exc_info: "OptExcInfo",
  259. task_context: "Optional[List[asyncio.Task[Any]]]",
  260. context: List[contextvars.Context],
  261. *args: _P.args,
  262. **kwargs: _P.kwargs,
  263. ) -> None:
  264. """
  265. Wraps the awaitable with something that puts the result into the
  266. result/exception future.
  267. """
  268. __traceback_hide__ = True # noqa: F841
  269. if context is not None:
  270. _restore_context(context[0])
  271. current_task = asyncio.current_task()
  272. if current_task is not None and task_context is not None:
  273. task_context.append(current_task)
  274. try:
  275. # If we have an exception, run the function inside the except block
  276. # after raising it so exc_info is correctly populated.
  277. if exc_info[1]:
  278. try:
  279. raise exc_info[1]
  280. except BaseException:
  281. result = await self.awaitable(*args, **kwargs)
  282. else:
  283. result = await self.awaitable(*args, **kwargs)
  284. except BaseException as e:
  285. call_result.set_exception(e)
  286. else:
  287. call_result.set_result(result)
  288. finally:
  289. if current_task is not None and task_context is not None:
  290. task_context.remove(current_task)
  291. context[0] = contextvars.copy_context()
  292. class SyncToAsync(Generic[_P, _R]):
  293. """
  294. Utility class which turns a synchronous callable into an awaitable that
  295. runs in a threadpool. It also sets a threadlocal inside the thread so
  296. calls to AsyncToSync can escape it.
  297. If thread_sensitive is passed, the code will run in the same thread as any
  298. outer code. This is needed for underlying Python code that is not
  299. threadsafe (for example, code which handles SQLite database connections).
  300. If the outermost program is async (i.e. SyncToAsync is outermost), then
  301. this will be a dedicated single sub-thread that all sync code runs in,
  302. one after the other. If the outermost program is sync (i.e. AsyncToSync is
  303. outermost), this will just be the main thread. This is achieved by idling
  304. with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
  305. rather than just blocking.
  306. If executor is passed in, that will be used instead of the loop's default executor.
  307. In order to pass in an executor, thread_sensitive must be set to False, otherwise
  308. a TypeError will be raised.
  309. """
  310. # Storage for main event loop references
  311. threadlocal = threading.local()
  312. # Single-thread executor for thread-sensitive code
  313. single_thread_executor = ThreadPoolExecutor(max_workers=1)
  314. # Maintain a contextvar for the current execution context. Optionally used
  315. # for thread sensitive mode.
  316. thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
  317. contextvars.ContextVar("thread_sensitive_context")
  318. )
  319. # Contextvar that is used to detect if the single thread executor
  320. # would be awaited on while already being used in the same context
  321. deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
  322. "deadlock_context"
  323. )
  324. # Maintaining a weak reference to the context ensures that thread pools are
  325. # erased once the context goes out of scope. This terminates the thread pool.
  326. context_to_thread_executor: "weakref.WeakKeyDictionary[ThreadSensitiveContext, ThreadPoolExecutor]" = (
  327. weakref.WeakKeyDictionary()
  328. )
  329. def __init__(
  330. self,
  331. func: Callable[_P, _R],
  332. thread_sensitive: bool = True,
  333. executor: Optional["ThreadPoolExecutor"] = None,
  334. ) -> None:
  335. if (
  336. not callable(func)
  337. or iscoroutinefunction(func)
  338. or iscoroutinefunction(getattr(func, "__call__", func))
  339. ):
  340. raise TypeError("sync_to_async can only be applied to sync functions.")
  341. self.func = func
  342. functools.update_wrapper(self, func)
  343. self._thread_sensitive = thread_sensitive
  344. markcoroutinefunction(self)
  345. if thread_sensitive and executor is not None:
  346. raise TypeError("executor must not be set when thread_sensitive is True")
  347. self._executor = executor
  348. try:
  349. self.__self__ = func.__self__ # type: ignore
  350. except AttributeError:
  351. pass
  352. async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
  353. __traceback_hide__ = True # noqa: F841
  354. loop = asyncio.get_running_loop()
  355. # Work out what thread to run the code in
  356. if self._thread_sensitive:
  357. current_thread_executor = getattr(AsyncToSync.executors, "current", None)
  358. if current_thread_executor:
  359. # If we have a parent sync thread above somewhere, use that
  360. executor = current_thread_executor
  361. elif self.thread_sensitive_context.get(None):
  362. # If we have a way of retrieving the current context, attempt
  363. # to use a per-context thread pool executor
  364. thread_sensitive_context = self.thread_sensitive_context.get()
  365. if thread_sensitive_context in self.context_to_thread_executor:
  366. # Re-use thread executor in current context
  367. executor = self.context_to_thread_executor[thread_sensitive_context]
  368. else:
  369. # Create new thread executor in current context
  370. executor = ThreadPoolExecutor(max_workers=1)
  371. self.context_to_thread_executor[thread_sensitive_context] = executor
  372. elif loop in AsyncToSync.loop_thread_executors:
  373. # Re-use thread executor for running loop
  374. executor = AsyncToSync.loop_thread_executors[loop]
  375. elif self.deadlock_context.get(False):
  376. raise RuntimeError(
  377. "Single thread executor already being used, would deadlock"
  378. )
  379. else:
  380. # Otherwise, we run it in a fixed single thread
  381. executor = self.single_thread_executor
  382. self.deadlock_context.set(True)
  383. else:
  384. # Use the passed in executor, or the loop's default if it is None
  385. executor = self._executor
  386. context = contextvars.copy_context()
  387. child = functools.partial(self.func, *args, **kwargs)
  388. func = context.run
  389. task_context: List[asyncio.Task[Any]] = []
  390. # Run the code in the right thread
  391. exec_coro = loop.run_in_executor(
  392. executor,
  393. functools.partial(
  394. self.thread_handler,
  395. loop,
  396. sys.exc_info(),
  397. task_context,
  398. func,
  399. child,
  400. ),
  401. )
  402. ret: _R
  403. try:
  404. ret = await asyncio.shield(exec_coro)
  405. except asyncio.CancelledError:
  406. cancel_parent = True
  407. try:
  408. task = task_context[0]
  409. task.cancel()
  410. try:
  411. await task
  412. cancel_parent = False
  413. except asyncio.CancelledError:
  414. pass
  415. except IndexError:
  416. pass
  417. if exec_coro.done():
  418. raise
  419. if cancel_parent:
  420. exec_coro.cancel()
  421. ret = await exec_coro
  422. finally:
  423. _restore_context(context)
  424. self.deadlock_context.set(False)
  425. return ret
  426. def __get__(
  427. self, parent: Any, objtype: Any
  428. ) -> Callable[_P, Coroutine[Any, Any, _R]]:
  429. """
  430. Include self for methods
  431. """
  432. func = functools.partial(self.__call__, parent)
  433. return functools.update_wrapper(func, self.func)
  434. def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
  435. """
  436. Wraps the sync application with exception handling.
  437. """
  438. __traceback_hide__ = True # noqa: F841
  439. # Set the threadlocal for AsyncToSync
  440. self.threadlocal.main_event_loop = loop
  441. self.threadlocal.main_event_loop_pid = os.getpid()
  442. self.threadlocal.task_context = task_context
  443. # Run the function
  444. # If we have an exception, run the function inside the except block
  445. # after raising it so exc_info is correctly populated.
  446. if exc_info[1]:
  447. try:
  448. raise exc_info[1]
  449. except BaseException:
  450. return func(*args, **kwargs)
  451. else:
  452. return func(*args, **kwargs)
  453. @overload
  454. def async_to_sync(
  455. *,
  456. force_new_loop: bool = False,
  457. ) -> Callable[
  458. [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
  459. Callable[_P, _R],
  460. ]:
  461. ...
  462. @overload
  463. def async_to_sync(
  464. awaitable: Union[
  465. Callable[_P, Coroutine[Any, Any, _R]],
  466. Callable[_P, Awaitable[_R]],
  467. ],
  468. *,
  469. force_new_loop: bool = False,
  470. ) -> Callable[_P, _R]:
  471. ...
  472. def async_to_sync(
  473. awaitable: Optional[
  474. Union[
  475. Callable[_P, Coroutine[Any, Any, _R]],
  476. Callable[_P, Awaitable[_R]],
  477. ]
  478. ] = None,
  479. *,
  480. force_new_loop: bool = False,
  481. ) -> Union[
  482. Callable[
  483. [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
  484. Callable[_P, _R],
  485. ],
  486. Callable[_P, _R],
  487. ]:
  488. if awaitable is None:
  489. return lambda f: AsyncToSync(
  490. f,
  491. force_new_loop=force_new_loop,
  492. )
  493. return AsyncToSync(
  494. awaitable,
  495. force_new_loop=force_new_loop,
  496. )
  497. @overload
  498. def sync_to_async(
  499. *,
  500. thread_sensitive: bool = True,
  501. executor: Optional["ThreadPoolExecutor"] = None,
  502. ) -> Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]:
  503. ...
  504. @overload
  505. def sync_to_async(
  506. func: Callable[_P, _R],
  507. *,
  508. thread_sensitive: bool = True,
  509. executor: Optional["ThreadPoolExecutor"] = None,
  510. ) -> Callable[_P, Coroutine[Any, Any, _R]]:
  511. ...
  512. def sync_to_async(
  513. func: Optional[Callable[_P, _R]] = None,
  514. *,
  515. thread_sensitive: bool = True,
  516. executor: Optional["ThreadPoolExecutor"] = None,
  517. ) -> Union[
  518. Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]],
  519. Callable[_P, Coroutine[Any, Any, _R]],
  520. ]:
  521. if func is None:
  522. return lambda f: SyncToAsync(
  523. f,
  524. thread_sensitive=thread_sensitive,
  525. executor=executor,
  526. )
  527. return SyncToAsync(
  528. func,
  529. thread_sensitive=thread_sensitive,
  530. executor=executor,
  531. )