123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import asyncio
- import contextlib
- import contextvars
- import threading
- from typing import Any, Dict, Union
class _CVar:
    """Attribute-style storage utility for Local, backed by one ``ContextVar``.

    All attributes live in a single dict that is the value of the context
    variable. Writes never mutate that dict in place: they build a new dict
    and ``set`` it. A context copied from this one (e.g. via
    ``contextvars.copy_context().run`` or an ``asyncio`` task) keeps seeing
    the values it inherited. Changes it makes do not leak back into the
    context it was copied from.
    """

    def __init__(self) -> None:
        self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar(
            "asgiref.local"
        )

    def __getattr__(self, key):
        """Return the value stored under ``key`` in the current context.

        Raises:
            AttributeError: if ``key`` has not been set in this context.
        """
        # __getattr__ only runs when normal lookup fails; if ``_data`` itself
        # is missing (instance created without __init__), fail cleanly rather
        # than recursing into this method forever.
        if key == "_data":
            raise AttributeError(f"{self!r} object has no attribute {key!r}")
        storage_object = self._data.get({})
        try:
            return storage_object[key]
        except KeyError:
            raise AttributeError(f"{self!r} object has no attribute {key!r}") from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` for the current context only."""
        if key == "_data":
            return super().__setattr__(key, value)
        # Copy-on-write: the current dict may be shared with other contexts
        # copied from this one, so build a new dict instead of mutating it.
        storage_object = {**self._data.get({}), key: value}
        self._data.set(storage_object)

    def __delattr__(self, key: str) -> None:
        """Remove ``key`` from the current context's storage.

        Raises:
            AttributeError: if ``key`` is not set in this context.
        """
        storage_object = self._data.get({})
        if key not in storage_object:
            raise AttributeError(f"{self!r} object has no attribute {key!r}")
        # Copy-on-write, for the same reason as in __setattr__.
        storage_object = dict(storage_object)
        del storage_object[key]
        self._data.set(storage_object)
# Attribute names kept on the Local instance itself rather than in the
# per-thread / per-context storage it manages.
_LOCAL_INTERNAL_ATTRS = frozenset(
    ("_local", "_storage", "_thread_critical", "_thread_lock")
)


class Local:
    """Local storage for async tasks.

    This is a namespace object (similar to `threading.local`) where data is
    also local to the current async task (if there is one).

    In async threads, local means in the same sense as the `contextvars`
    module - i.e. a value set in an async frame will be visible:

    - to other async code `await`-ed from this frame.
    - to tasks spawned using `asyncio` utilities (`create_task`, `wait_for`,
      `gather` and probably others).
    - to code scheduled in a sync thread using `sync_to_async`

    In "sync" threads (a thread with no async event loop running), the
    data is thread-local, but additionally shared with async code executed
    via the `async_to_sync` utility, which schedules async code in a new thread
    and copies context across to that thread.

    If `thread_critical` is True, then the local will only be visible per-thread,
    behaving exactly like `threading.local` if the thread is sync, and as
    `contextvars` if the thread is async. This allows genuinely thread-sensitive
    code (such as DB handles) to be kept strictly to their initial thread and
    disable the sharing across `sync_to_async` and `async_to_sync` wrapped calls.

    Unlike plain `contextvars` objects, this utility is threadsafe.
    """

    def __init__(self, thread_critical: bool = False) -> None:
        self._thread_critical = thread_critical
        self._thread_lock = threading.RLock()

        self._storage: "Union[threading.local, _CVar]"

        if thread_critical:
            # Thread-local storage
            self._storage = threading.local()
        else:
            # Contextvar storage
            self._storage = _CVar()

    @contextlib.contextmanager
    def _lock_storage(self):
        """Yield the storage object that applies to the current caller.

        - ``thread_critical=False``: the shared ``_CVar``, held under
          ``self._thread_lock`` for the duration of the ``with`` block.
        - ``thread_critical=True`` with no running event loop: the
          ``threading.local`` itself.
        - ``thread_critical=True`` inside a running event loop: a ``_CVar``
          stored on the ``threading.local`` (created on first use per thread).
        """
        if not self._thread_critical:
            # Lock for thread_critical=False as other threads
            # can access the exact same storage object
            with self._thread_lock:
                yield self._storage
            return

        # Are we in an async or a sync thread? get_running_loop() raises
        # RuntimeError when no loop is running. The probe is fully handled
        # here, *before* yielding, so exceptions raised in the caller's
        # ``with`` body are not implicitly chained onto this RuntimeError.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            in_async_thread = False
        else:
            in_async_thread = True

        if not in_async_thread:
            # We are in a sync thread, the storage is
            # just the plain thread local (i.e, "global within
            # this thread" - it doesn't matter where you are
            # in a call stack you see the same storage)
            yield self._storage
            return

        # We are in an async thread - storage is still
        # local to this thread, but additionally should
        # behave like a context var (is only visible with
        # the same async call stack)
        # NOTE(review): the _CVar holder shares the threading.local's
        # namespace with attributes written in sync mode, so a user attribute
        # literally named "cvar" would collide with it.
        if not hasattr(self._storage, "cvar"):
            self._storage.cvar = _CVar()
        # self._storage is a thread local, so the members
        # can't be accessed in another thread (we don't
        # need any locks)
        yield self._storage.cvar

    def __getattr__(self, key):
        # Only called when normal lookup fails. A missing internal attribute
        # means __init__ never ran (e.g. ``Local.__new__(Local)`` or a
        # copy/unpickle probe); raise instead of recursing via _lock_storage.
        if key in _LOCAL_INTERNAL_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        with self._lock_storage() as storage:
            return getattr(storage, key)

    def __setattr__(self, key, value):
        if key in _LOCAL_INTERNAL_ATTRS:
            return super().__setattr__(key, value)
        with self._lock_storage() as storage:
            setattr(storage, key, value)

    def __delattr__(self, key):
        with self._lock_storage() as storage:
            delattr(storage, key)
|