diff --git a/README.md b/README.md index 463fe62..dc3c8a5 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ pip install sqlalchemy-dlock ```python from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy_dlock.asyncio import create_async_sadlock + from sqlalchemy_dlock import create_async_sadlock key = 'user/001' diff --git a/db.docker-compose.yml b/docker-compose.database.yml similarity index 100% rename from db.docker-compose.yml rename to docker-compose.database.yml diff --git a/docs/README.rst b/docs/README.rst index 35033d6..e3f9d09 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -19,7 +19,7 @@ How to build docs .. code:: sh - sphinx-apidoc -o docs/apidocs -eMTf src + rm -fr docs/apidocs/* && sphinx-apidoc -o docs/apidocs -eMTf src #. Build HTML documentation: @@ -27,6 +27,12 @@ How to build docs make -C docs html + or rebuild it: + + .. code:: sh + + make -C docs clean html + The built static web site is output to ``docs/_build/html``, we can serve it: .. code:: sh diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.factory.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.factory.rst deleted file mode 100644 index 5314dd8..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.factory.rst +++ /dev/null @@ -1,7 +0,0 @@ -sqlalchemy\_dlock.asyncio.factory module -======================================== - -.. automodule:: sqlalchemy_dlock.asyncio.factory - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.base.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.lock.base.rst deleted file mode 100644 index c72ea98..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.base.rst +++ /dev/null @@ -1,7 +0,0 @@ -sqlalchemy\_dlock.asyncio.lock.base module -========================================== - -.. automodule:: sqlalchemy_dlock.asyncio.lock.base - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.mysql.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.lock.mysql.rst deleted file mode 100644 index b3cf279..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.mysql.rst +++ /dev/null @@ -1,7 +0,0 @@ -sqlalchemy\_dlock.asyncio.lock.mysql module -=========================================== - -.. automodule:: sqlalchemy_dlock.asyncio.lock.mysql - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.postgresql.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.lock.postgresql.rst deleted file mode 100644 index e1d61b6..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.postgresql.rst +++ /dev/null @@ -1,7 +0,0 @@ -sqlalchemy\_dlock.asyncio.lock.postgresql module -================================================ - -.. automodule:: sqlalchemy_dlock.asyncio.lock.postgresql - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.lock.rst deleted file mode 100644 index 6003c30..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.lock.rst +++ /dev/null @@ -1,17 +0,0 @@ -sqlalchemy\_dlock.asyncio.lock package -====================================== - -.. automodule:: sqlalchemy_dlock.asyncio.lock - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - sqlalchemy_dlock.asyncio.lock.base - sqlalchemy_dlock.asyncio.lock.mysql - sqlalchemy_dlock.asyncio.lock.postgresql diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.rst deleted file mode 100644 index 3fcfe09..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.rst +++ /dev/null @@ -1,24 +0,0 @@ -sqlalchemy\_dlock.asyncio package -================================= - -.. automodule:: sqlalchemy_dlock.asyncio - :members: - :undoc-members: - :show-inheritance: - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - sqlalchemy_dlock.asyncio.lock - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - sqlalchemy_dlock.asyncio.factory - sqlalchemy_dlock.asyncio.types diff --git a/docs/apidocs/sqlalchemy_dlock.asyncio.types.rst b/docs/apidocs/sqlalchemy_dlock.asyncio.types.rst deleted file mode 100644 index 61d0976..0000000 --- a/docs/apidocs/sqlalchemy_dlock.asyncio.types.rst +++ /dev/null @@ -1,7 +0,0 @@ -sqlalchemy\_dlock.asyncio.types module -====================================== - -.. automodule:: sqlalchemy_dlock.asyncio.types - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.registry.rst b/docs/apidocs/sqlalchemy_dlock.registry.rst new file mode 100644 index 0000000..b03e6a2 --- /dev/null +++ b/docs/apidocs/sqlalchemy_dlock.registry.rst @@ -0,0 +1,7 @@ +sqlalchemy\_dlock.registry module +================================= + +.. automodule:: sqlalchemy_dlock.registry + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/apidocs/sqlalchemy_dlock.rst b/docs/apidocs/sqlalchemy_dlock.rst index e43ddb6..1c1cb53 100644 --- a/docs/apidocs/sqlalchemy_dlock.rst +++ b/docs/apidocs/sqlalchemy_dlock.rst @@ -12,7 +12,6 @@ Subpackages .. toctree:: :maxdepth: 4 - sqlalchemy_dlock.asyncio sqlalchemy_dlock.lock sqlalchemy_dlock.statement @@ -24,5 +23,6 @@ Submodules sqlalchemy_dlock.exceptions sqlalchemy_dlock.factory + sqlalchemy_dlock.registry sqlalchemy_dlock.types sqlalchemy_dlock.utils diff --git a/src/sqlalchemy_dlock/__init__.py b/src/sqlalchemy_dlock/__init__.py index 7b18b36..a198278 100644 --- a/src/sqlalchemy_dlock/__init__.py +++ b/src/sqlalchemy_dlock/__init__.py @@ -4,4 +4,5 @@ from ._version import __version__, __version_tuple__ from .exceptions import SqlAlchemyDLockBaseException, SqlAlchemyDLockDatabaseError -from .factory import create_sadlock +from .factory import create_async_sadlock, create_sadlock +from .lock import BaseAsyncSadLock, BaseSadLock diff --git a/src/sqlalchemy_dlock/asyncio/__init__.py b/src/sqlalchemy_dlock/asyncio/__init__.py deleted file mode 100644 index 16c3db4..0000000 --- a/src/sqlalchemy_dlock/asyncio/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .factory import * # noqa: F403 diff --git a/src/sqlalchemy_dlock/asyncio/factory.py b/src/sqlalchemy_dlock/asyncio/factory.py deleted file mode 100644 index 67ca04f..0000000 --- a/src/sqlalchemy_dlock/asyncio/factory.py +++ /dev/null @@ -1,31 +0,0 @@ -from importlib import import_module -from string import Template -from typing import Any, Mapping, Type, Union - -from sqlalchemy.engine import Connection -from sqlalchemy.ext.asyncio import AsyncConnection - -from .lock.base import BaseAsyncSadLock -from .types import TAsyncConnectionOrSession - -__all__ = ["create_async_sadlock"] - - -def create_async_sadlock( - connection_or_session: TAsyncConnectionOrSession, key, contextual_timeout: Union[float, int, None] = None, **kwargs -) -> BaseAsyncSadLock: - if isinstance(connection_or_session, AsyncConnection): - engine = connection_or_session.sync_engine - else: - bind = connection_or_session.get_bind() - if isinstance(bind, Connection): - engine = bind.engine - else: - engine = bind - conf: Mapping[str, Any] = getattr(import_module(".registry", __package__), "REGISTRY")[engine.name] - package: Union[str, None] = conf.get("package") - if package: - package = Template(package).safe_substitute(package=__package__) - mod = import_module(conf["module"], package) - clz: Type[BaseAsyncSadLock] = getattr(mod, conf["class"]) - return clz(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs) diff --git a/src/sqlalchemy_dlock/asyncio/lock/__init__.py b/src/sqlalchemy_dlock/asyncio/lock/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sqlalchemy_dlock/asyncio/lock/base.py b/src/sqlalchemy_dlock/asyncio/lock/base.py deleted file mode 100644 index 677fac8..0000000 --- a/src/sqlalchemy_dlock/asyncio/lock/base.py +++ /dev/null @@ -1,69 +0,0 @@ -import sys -from typing import Generic, TypeVar, Union - -if sys.version_info >= (3, 11): # pragma: no cover - from typing import Self -else: # pragma: no cover - from typing_extensions import Self - -from ..types import TAsyncConnectionOrSession - -KT = TypeVar("KT") - - -class BaseAsyncSadLock(Generic[KT]): - def __init__( - self, - connection_or_session: TAsyncConnectionOrSession, - key: KT, - /, - contextual_timeout: Union[float, int, None] = None, - **kwargs, - ): - self._acquired = False - self._connection_or_session = connection_or_session - self._key = key - self._contextual_timeout = contextual_timeout - - async def __aenter__(self) -> Self: - if self._contextual_timeout is None: - await self.acquire() - elif not await self.acquire(timeout=self._contextual_timeout): - # the timeout period has elapsed and not acquired - raise TimeoutError() - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self.close() - - def __str__(self): - return "<{} {} key={} at 0x{:x}>".format( - "locked" if self._acquired else "unlocked", - self.__class__.__name__, - self._key, - id(self), - ) - - @property - def connection_or_session(self) -> TAsyncConnectionOrSession: - return self._connection_or_session - - @property - def key(self) -> KT: - return self._key - - @property - def locked(self) -> bool: - return self._acquired - - async def acquire( - self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs - ) -> bool: # pragma: no cover - raise NotImplementedError() - - async def release(self, *args, **kwargs) -> None: # pragma: no cover - raise NotImplementedError() - - async def close(self, *args, **kwargs) -> None: - if self._acquired: - await self.release(*args, **kwargs) diff --git a/src/sqlalchemy_dlock/asyncio/lock/mysql.py b/src/sqlalchemy_dlock/asyncio/lock/mysql.py deleted file mode 100644 index dd17135..0000000 --- a/src/sqlalchemy_dlock/asyncio/lock/mysql.py +++ /dev/null @@ -1,68 +0,0 @@ -import sys -from typing import Union - -if sys.version_info < (3, 12): # pragma: no cover - from typing_extensions import override -else: # pragma: no cover - from typing import override - -from ...exceptions import SqlAlchemyDLockDatabaseError -from ...lock.mysql import MysqlSadLockMixin -from ...statement.mysql import LOCK, UNLOCK -from ..types import TAsyncConnectionOrSession -from .base import BaseAsyncSadLock - - -class MysqlAsyncSadLock(MysqlSadLockMixin, BaseAsyncSadLock[str]): - @override - def __init__(self, connection_or_session: TAsyncConnectionOrSession, key, **kwargs): - MysqlSadLockMixin.__init__(self, key=key, **kwargs) - BaseAsyncSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs) - - @override - async def acquire(self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs) -> bool: - if self._acquired: - raise ValueError("invoked on a locked lock") - if block: - # None: set the timeout period to infinite. - if timeout is None: - timeout = -1 - # negative value for `timeout` are equivalent to a `timeout` of zero - elif timeout < 0: - timeout = 0 - else: - timeout = 0 - stmt = LOCK.params(str=self.key, timeout=timeout) - ret_val = (await self.connection_or_session.execute(stmt)).scalar_one() - if ret_val == 1: - self._acquired = True - elif ret_val == 0: - pass # 直到超时也没有成功锁定 - elif ret_val is None: # pragma: no cover - raise SqlAlchemyDLockDatabaseError(f"An error occurred while attempting to obtain the lock {self.key!r}") - else: # pragma: no cover - raise SqlAlchemyDLockDatabaseError(f"GET_LOCK({self.key!r}, {timeout}) returns {ret_val}") - return self._acquired - - @override - async def release(self): - if not self._acquired: - raise ValueError("invoked on an unlocked lock") - stmt = UNLOCK.params(str=self.key) - ret_val = (await self.connection_or_session.execute(stmt)).scalar_one() - if ret_val == 1: - self._acquired = False - elif ret_val == 0: # pragma: no cover - self._acquired = False - raise SqlAlchemyDLockDatabaseError( - f"The named lock {self.key!r} was not established by this thread, " "and the lock is not released." - ) - elif ret_val is None: # pragma: no cover - self._acquired = False - raise SqlAlchemyDLockDatabaseError( - f"The named lock {self.key!r} did not exist, " - "was never obtained by a call to GET_LOCK(), " - "or has previously been released." - ) - else: # pragma: no cover - raise SqlAlchemyDLockDatabaseError(f"RELEASE_LOCK({self.key!r}) returns {ret_val}") diff --git a/src/sqlalchemy_dlock/asyncio/lock/postgresql.py b/src/sqlalchemy_dlock/asyncio/lock/postgresql.py deleted file mode 100644 index 4580624..0000000 --- a/src/sqlalchemy_dlock/asyncio/lock/postgresql.py +++ /dev/null @@ -1,91 +0,0 @@ -import asyncio -import sys -from time import time -from typing import Union -from warnings import catch_warnings, warn - -if sys.version_info < (3, 12): # pragma: no cover - from typing_extensions import override -else: # pragma: no cover - from typing import override - -from ...exceptions import SqlAlchemyDLockDatabaseError -from ...lock.postgresql import PostgresqlSadLockMixin -from ...statement.postgresql import SLEEP_INTERVAL_DEFAULT, SLEEP_INTERVAL_MIN -from ..types import TAsyncConnectionOrSession -from .base import BaseAsyncSadLock - - -class PostgresqlAsyncSadLock(PostgresqlSadLockMixin, BaseAsyncSadLock[int]): - @override - def __init__(self, connection_or_session: TAsyncConnectionOrSession, key, **kwargs): - PostgresqlSadLockMixin.__init__(self, key=key, **kwargs) - BaseAsyncSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs) - - @override - async def acquire( - self, - block: bool = True, - timeout: Union[float, int, None] = None, - interval: Union[float, int, None] = None, - *args, - **kwargs, - ) -> bool: - if self._acquired: - raise ValueError("invoked on a locked lock") - if block: - if timeout is None: - # None: set the timeout period to infinite. - _ = (await self.connection_or_session.execute(self._stmt_lock)).all() - self._acquired = True - else: - # negative value for `timeout` are equivalent to a `timeout` of zero. - if timeout < 0: - timeout = 0 - interval = SLEEP_INTERVAL_DEFAULT if interval is None else interval - if interval < SLEEP_INTERVAL_MIN: # pragma: no cover - raise ValueError("interval too small") - ts_begin = time() - while True: - ret_val = (await self.connection_or_session.execute(self._stmt_try_lock)).scalar_one() - if ret_val: # succeed - self._acquired = True - break - if time() - ts_begin > timeout: # expired - break - await asyncio.sleep(interval) - else: - # This will either obtain the lock immediately and return true, - # or return false without waiting if the lock cannot be acquired immediately. - ret_val = (await self.connection_or_session.execute(self._stmt_try_lock)).scalar_one() - self._acquired = bool(ret_val) - # - return self._acquired - - @override - async def release(self): - if not self._acquired: - raise ValueError("invoked on an unlocked lock") - if self._stmt_unlock is None: - warn( - "PostgreSQL transaction level advisory locks are held until the current transaction ends; " - "there is no provision for manual release.", - RuntimeWarning, - ) - return - ret_val = (await self.connection_or_session.execute(self._stmt_unlock)).scalar_one() - if ret_val: - self._acquired = False - else: # pragma: no cover - self._acquired = False - raise SqlAlchemyDLockDatabaseError(f"The advisory lock {self.key!r} was not held.") - - @override - async def close(self): - if self._acquired: - if sys.version_info < (3, 11): - with catch_warnings(): - return await self.release() - else: - with catch_warnings(category=RuntimeWarning): - return await self.release() diff --git a/src/sqlalchemy_dlock/asyncio/registry.py b/src/sqlalchemy_dlock/asyncio/registry.py deleted file mode 100644 index b3e71e6..0000000 --- a/src/sqlalchemy_dlock/asyncio/registry.py +++ /dev/null @@ -1,12 +0,0 @@ -REGISTRY = { - "mysql": { - "module": ".lock.mysql", - "package": "${package}", # module name relative to the package - "class": "MysqlAsyncSadLock", - }, - "postgresql": { - "module": ".lock.postgresql", - "package": "${package}", # module name relative to the package - "class": "PostgresqlAsyncSadLock", - }, -} diff --git a/src/sqlalchemy_dlock/asyncio/types.py b/src/sqlalchemy_dlock/asyncio/types.py deleted file mode 100644 index d92c0ee..0000000 --- a/src/sqlalchemy_dlock/asyncio/types.py +++ /dev/null @@ -1,13 +0,0 @@ -import sys -from typing import Union - -if sys.version_info < (3, 10): # pragma: no cover - from typing_extensions import TypeAlias -else: # pragma: no cover - from typing import TypeAlias - -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_scoped_session - -__all__ = ["TAsyncConnectionOrSession"] - -TAsyncConnectionOrSession: TypeAlias = Union[AsyncConnection, AsyncSession, async_scoped_session] diff --git a/src/sqlalchemy_dlock/factory.py b/src/sqlalchemy_dlock/factory.py index 5d9503d..af6a8d5 100644 --- a/src/sqlalchemy_dlock/factory.py +++ b/src/sqlalchemy_dlock/factory.py @@ -1,17 +1,17 @@ -from importlib import import_module -from string import Template -from typing import Any, Mapping, Type, Union +from typing import Union from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import AsyncConnection -from .lock.base import BaseSadLock -from .types import TConnectionOrSession +from .lock.base import BaseAsyncSadLock, BaseSadLock +from .types import AsyncConnectionOrSessionT, ConnectionOrSessionT +from .utils import find_lock_class -__all__ = ["create_sadlock"] +__all__ = ["create_sadlock", "create_async_sadlock"] def create_sadlock( - connection_or_session: TConnectionOrSession, key, /, contextual_timeout: Union[float, int, None] = None, **kwargs + connection_or_session: ConnectionOrSessionT, key, /, contextual_timeout: Union[float, int, None] = None, **kwargs ) -> BaseSadLock: """Create a database distributed lock object @@ -48,10 +48,22 @@ def create_sadlock( else: engine = bind - conf: Mapping[str, Any] = getattr(import_module(".registry", __package__), "REGISTRY")[engine.name] - package: Union[str, None] = conf.get("package") - if package: - package = Template(package).safe_substitute(package=__package__) - mod = import_module(conf["module"], package) - clz: Type[BaseSadLock] = getattr(mod, conf["class"]) - return clz(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs) + class_ = find_lock_class(engine.name) + return class_(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs) + + +def create_async_sadlock( + connection_or_session: AsyncConnectionOrSessionT, key, /, contextual_timeout: Union[float, int, None] = None, **kwargs +) -> BaseAsyncSadLock: + """AsyncIO version of :func:`create_sadlock`""" + if isinstance(connection_or_session, AsyncConnection): + engine = connection_or_session.engine + else: + bind = connection_or_session.get_bind() + if isinstance(bind, Connection): + engine = bind.engine + else: + engine = bind + + class_ = find_lock_class(engine.name, True) + return class_(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs) diff --git a/src/sqlalchemy_dlock/lock/__init__.py b/src/sqlalchemy_dlock/lock/__init__.py index e69de29..69ac59d 100644 --- a/src/sqlalchemy_dlock/lock/__init__.py +++ b/src/sqlalchemy_dlock/lock/__init__.py @@ -0,0 +1 @@ +from .base import BaseAsyncSadLock, BaseSadLock diff --git a/src/sqlalchemy_dlock/lock/base.py b/src/sqlalchemy_dlock/lock/base.py index e237cfc..0947158 100644 --- a/src/sqlalchemy_dlock/lock/base.py +++ b/src/sqlalchemy_dlock/lock/base.py @@ -7,7 +7,7 @@ else: # pragma: no cover from typing_extensions import Self -from ..types import TConnectionOrSession +from ..types import AsyncConnectionOrSessionT, ConnectionOrSessionT KT = TypeVar("KT") @@ -42,7 +42,7 @@ class BaseSadLock(Generic[KT], local): def __init__( self, - connection_or_session: TConnectionOrSession, + connection_or_session: ConnectionOrSessionT, key: KT, /, contextual_timeout: Union[float, int, None] = None, @@ -100,7 +100,7 @@ def __str__(self) -> str: ) @property - def connection_or_session(self) -> TConnectionOrSession: + def connection_or_session(self) -> ConnectionOrSessionT: """Connection or Session object SQL locking functions will be invoked on it It returns ``connection_or_session`` parameter of the class's constructor. @@ -190,3 +190,61 @@ def close(self, *args, **kwargs) -> None: """ # noqa: E501 if self._acquired: self.release(*args, **kwargs) + + +class BaseAsyncSadLock(Generic[KT], local): + def __init__( + self, + connection_or_session: AsyncConnectionOrSessionT, + key: KT, + /, + contextual_timeout: Union[float, int, None] = None, + **kwargs, + ): + self._acquired = False + self._connection_or_session = connection_or_session + self._key = key + self._contextual_timeout = contextual_timeout + + async def __aenter__(self) -> Self: + if self._contextual_timeout is None: + await self.acquire() + elif not await self.acquire(timeout=self._contextual_timeout): + # the timeout period has elapsed and not acquired + raise TimeoutError() + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self.close() + + def __str__(self): + return "<{} {} key={} at 0x{:x}>".format( + "locked" if self._acquired else "unlocked", + self.__class__.__name__, + self._key, + id(self), + ) + + @property + def connection_or_session(self) -> AsyncConnectionOrSessionT: + return self._connection_or_session + + @property + def key(self) -> KT: + return self._key + + @property + def locked(self) -> bool: + return self._acquired + + async def acquire( + self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs + ) -> bool: # pragma: no cover + raise NotImplementedError() + + async def release(self, *args, **kwargs) -> None: # pragma: no cover + raise NotImplementedError() + + async def close(self, *args, **kwargs) -> None: + if self._acquired: + await self.release(*args, **kwargs) diff --git a/src/sqlalchemy_dlock/lock/mysql.py b/src/sqlalchemy_dlock/lock/mysql.py index f5e51fe..a41392e 100644 --- a/src/sqlalchemy_dlock/lock/mysql.py +++ b/src/sqlalchemy_dlock/lock/mysql.py @@ -8,20 +8,20 @@ from ..exceptions import SqlAlchemyDLockDatabaseError from ..statement.mysql import LOCK, UNLOCK -from ..types import TConnectionOrSession +from ..types import AsyncConnectionOrSessionT, ConnectionOrSessionT from ..utils import to_str_key -from .base import BaseSadLock +from .base import BaseAsyncSadLock, BaseSadLock MYSQL_LOCK_NAME_MAX_LENGTH = 64 -TKey = TypeVar("TKey", bound=Any) +KT = TypeVar("KT", bound=Any) class MysqlSadLockMixin: """A Mix-in class for MySQL named lock""" - def __init__(self, *, key: TKey, convert: Optional[Callable[[TKey], str]] = None, **kwargs): + def __init__(self, *, key: KT, convert: Optional[Callable[[KT], str]] = None, **kwargs): """ Args: key: MySQL named lock requires the key given by string. @@ -77,7 +77,7 @@ class MysqlSadLock(MysqlSadLockMixin, BaseSadLock[str]): """ # noqa: E501 @override - def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs): + def __init__(self, connection_or_session: ConnectionOrSessionT, key, **kwargs): """ Args: connection_or_session: :attr:`.BaseSadLock.connection_or_session` @@ -85,7 +85,7 @@ def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs): **kwargs: other named parameters pass to :class:`.BaseSadLock` and :class:`.MysqlSadLockMixin` """ MysqlSadLockMixin.__init__(self, key=key, **kwargs) - BaseSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs) + BaseSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs) @override def acquire(self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs) -> bool: @@ -134,3 +134,58 @@ def release(self): ) else: # pragma: no cover raise SqlAlchemyDLockDatabaseError(f"RELEASE_LOCK({self.key!r}) returns {ret_val}") + + +class MysqlAsyncSadLock(MysqlSadLockMixin, BaseAsyncSadLock[str]): + @override + def __init__(self, connection_or_session: AsyncConnectionOrSessionT, key, **kwargs): + MysqlSadLockMixin.__init__(self, key=key, **kwargs) + BaseAsyncSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs) + + @override + async def acquire(self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs) -> bool: + if self._acquired: + raise ValueError("invoked on a locked lock") + if block: + # None: set the timeout period to infinite. + if timeout is None: + timeout = -1 + # negative value for `timeout` are equivalent to a `timeout` of zero + elif timeout < 0: + timeout = 0 + else: + timeout = 0 + stmt = LOCK.params(str=self.key, timeout=timeout) + ret_val = (await self.connection_or_session.execute(stmt)).scalar_one() + if ret_val == 1: + self._acquired = True + elif ret_val == 0: + pass # 直到超时也没有成功锁定 + elif ret_val is None: # pragma: no cover + raise SqlAlchemyDLockDatabaseError(f"An error occurred while attempting to obtain the lock {self.key!r}") + else: # pragma: no cover + raise SqlAlchemyDLockDatabaseError(f"GET_LOCK({self.key!r}, {timeout}) returns {ret_val}") + return self._acquired + + @override + async def release(self): + if not self._acquired: + raise ValueError("invoked on an unlocked lock") + stmt = UNLOCK.params(str=self.key) + ret_val = (await self.connection_or_session.execute(stmt)).scalar_one() + if ret_val == 1: + self._acquired = False + elif ret_val == 0: # pragma: no cover + self._acquired = False + raise SqlAlchemyDLockDatabaseError( + f"The named lock {self.key!r} was not established by this thread, " "and the lock is not released." + ) + elif ret_val is None: # pragma: no cover + self._acquired = False + raise SqlAlchemyDLockDatabaseError( + f"The named lock {self.key!r} did not exist, " + "was never obtained by a call to GET_LOCK(), " + "or has previously been released." + ) + else: # pragma: no cover + raise SqlAlchemyDLockDatabaseError(f"RELEASE_LOCK({self.key!r}) returns {ret_val}") diff --git a/src/sqlalchemy_dlock/lock/postgresql.py b/src/sqlalchemy_dlock/lock/postgresql.py index fea12c7..9019519 100644 --- a/src/sqlalchemy_dlock/lock/postgresql.py +++ b/src/sqlalchemy_dlock/lock/postgresql.py @@ -1,3 +1,4 @@ +import asyncio import sys from time import sleep, time from typing import Any, Callable, Optional, TypeVar, Union @@ -23,18 +24,18 @@ UNLOCK, UNLOCK_SHARED, ) -from ..types import TConnectionOrSession +from ..types import AsyncConnectionOrSessionT, ConnectionOrSessionT from ..utils import ensure_int64, to_int64_key -from .base import BaseSadLock +from .base import BaseAsyncSadLock, BaseSadLock -TKey = TypeVar("TKey", bound=Any) +KT = TypeVar("KT", bound=Any) class PostgresqlSadLockMixin: """A Mix-in class for PostgreSQL advisory lock""" def __init__( - self, *, key: TKey, shared: bool = False, xact: bool = False, convert: Optional[Callable[[TKey], int]] = None, **kwargs + self, *, key: KT, shared: bool = False, xact: bool = False, convert: Optional[Callable[[KT], int]] = None, **kwargs ): """ Args: @@ -108,7 +109,7 @@ class PostgresqlSadLock(PostgresqlSadLockMixin, BaseSadLock[int]): """ @override - def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs): + def __init__(self, connection_or_session: ConnectionOrSessionT, key, **kwargs): """ Args: connection_or_session: see :attr:`.BaseSadLock.connection_or_session` @@ -202,3 +203,78 @@ def close(self): else: with catch_warnings(category=RuntimeWarning): return self.release() + + +class PostgresqlAsyncSadLock(PostgresqlSadLockMixin, BaseAsyncSadLock[int]): + @override + def __init__(self, connection_or_session: AsyncConnectionOrSessionT, key, **kwargs): + PostgresqlSadLockMixin.__init__(self, key=key, **kwargs) + BaseAsyncSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs) + + @override + async def acquire( + self, + block: bool = True, + timeout: Union[float, int, None] = None, + interval: Union[float, int, None] = None, + *args, + **kwargs, + ) -> bool: + if self._acquired: + raise ValueError("invoked on a locked lock") + if block: + if timeout is None: + # None: set the timeout period to infinite. + _ = (await self.connection_or_session.execute(self._stmt_lock)).all() + self._acquired = True + else: + # negative value for `timeout` are equivalent to a `timeout` of zero. + if timeout < 0: + timeout = 0 + interval = SLEEP_INTERVAL_DEFAULT if interval is None else interval + if interval < SLEEP_INTERVAL_MIN: # pragma: no cover + raise ValueError("interval too small") + ts_begin = time() + while True: + ret_val = (await self.connection_or_session.execute(self._stmt_try_lock)).scalar_one() + if ret_val: # succeed + self._acquired = True + break + if time() - ts_begin > timeout: # expired + break + await asyncio.sleep(interval) + else: + # This will either obtain the lock immediately and return true, + # or return false without waiting if the lock cannot be acquired immediately. + ret_val = (await self.connection_or_session.execute(self._stmt_try_lock)).scalar_one() + self._acquired = bool(ret_val) + # + return self._acquired + + @override + async def release(self): + if not self._acquired: + raise ValueError("invoked on an unlocked lock") + if self._stmt_unlock is None: + warn( + "PostgreSQL transaction level advisory locks are held until the current transaction ends; " + "there is no provision for manual release.", + RuntimeWarning, + ) + return + ret_val = (await self.connection_or_session.execute(self._stmt_unlock)).scalar_one() + if ret_val: + self._acquired = False + else: # pragma: no cover + self._acquired = False + raise SqlAlchemyDLockDatabaseError(f"The advisory lock {self.key!r} was not held.") + + @override + async def close(self): + if self._acquired: + if sys.version_info < (3, 11): + with catch_warnings(): + return await self.release() + else: + with catch_warnings(category=RuntimeWarning): + return await self.release() diff --git a/src/sqlalchemy_dlock/registry.py b/src/sqlalchemy_dlock/registry.py index 8e85830..70dc492 100644 --- a/src/sqlalchemy_dlock/registry.py +++ b/src/sqlalchemy_dlock/registry.py @@ -10,3 +10,16 @@ "class": "PostgresqlSadLock", }, } + +ASYNCIO_REGISTRY = { + "mysql": { + "module": ".lock.mysql", + "package": "${package}", # module name relative to the package + "class": "MysqlAsyncSadLock", + }, + "postgresql": { + "module": ".lock.postgresql", + "package": "${package}", # module name relative to the package + "class": "PostgresqlAsyncSadLock", + }, +} diff --git a/src/sqlalchemy_dlock/types.py b/src/sqlalchemy_dlock/types.py index 0a6ad22..94e39c7 100644 --- a/src/sqlalchemy_dlock/types.py +++ b/src/sqlalchemy_dlock/types.py @@ -7,8 +7,10 @@ from typing import TypeAlias from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_scoped_session from sqlalchemy.orm import Session, scoped_session -__all__ = ["TConnectionOrSession"] +__all__ = ["ConnectionOrSessionT", "AsyncConnectionOrSessionT"] -TConnectionOrSession: TypeAlias = Union[Connection, Session, scoped_session] +ConnectionOrSessionT: TypeAlias = Union[Connection, Session, scoped_session] +AsyncConnectionOrSessionT: TypeAlias = Union[AsyncConnection, AsyncSession, async_scoped_session] diff --git a/src/sqlalchemy_dlock/utils.py b/src/sqlalchemy_dlock/utils.py index 4e8038c..2261010 100644 --- a/src/sqlalchemy_dlock/utils.py +++ b/src/sqlalchemy_dlock/utils.py @@ -1,19 +1,19 @@ from __future__ import annotations -import re +from functools import lru_cache from hashlib import blake2b +from importlib import import_module from io import BytesIO +from string import Template from sys import byteorder from typing import TYPE_CHECKING, Union +from . import registry + if TYPE_CHECKING: # pragma: no cover from _typeshed import ReadableBuffer -def safe_name(s: str) -> str: - return re.sub(r"[^A-Za-z0-9_]+", "_", s).strip().lower() - - def to_int64_key(k: Union[int, str, ReadableBuffer]) -> int: if isinstance(k, int): return ensure_int64(k) @@ -53,3 +53,15 @@ def ensure_int64(i: int) -> int: if i < -0x8000_0000_0000_0000: raise OverflowError("int too small") return i + + +@lru_cache +def find_lock_class(engine_name, is_asyncio=False): + reg = registry.ASYNCIO_REGISTRY if is_asyncio else registry.REGISTRY + conf = reg[engine_name] + package = conf.get("package") + if package: + package = Template(package).safe_substitute(package=__package__) + module = import_module(conf["module"], package) + class_ = getattr(module, conf["class"]) + return class_ diff --git a/tests/asyncio/test_basic.py b/tests/asyncio/test_basic.py index 7c812a7..1d7f9ab 100644 --- a/tests/asyncio/test_basic.py +++ b/tests/asyncio/test_basic.py @@ -5,7 +5,7 @@ from unittest import IsolatedAsyncioTestCase from uuid import uuid4 -from sqlalchemy_dlock.asyncio import create_async_sadlock +from sqlalchemy_dlock import create_async_sadlock from .engines import create_engines, dispose_engines, get_engines diff --git a/tests/asyncio/test_concurrency.py b/tests/asyncio/test_concurrency.py index bf119ea..ac456a1 100644 --- a/tests/asyncio/test_concurrency.py +++ b/tests/asyncio/test_concurrency.py @@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy_dlock.asyncio import create_async_sadlock +from sqlalchemy_dlock import create_async_sadlock from .engines import create_engines, dispose_engines, get_engines diff --git a/tests/asyncio/test_key_convert.py b/tests/asyncio/test_key_convert.py index 9da66d0..cbe5321 100644 --- a/tests/asyncio/test_key_convert.py +++ b/tests/asyncio/test_key_convert.py @@ -4,7 +4,7 @@ from uuid import uuid4 from zlib import crc32 -from sqlalchemy_dlock.asyncio import create_async_sadlock +from sqlalchemy_dlock import create_async_sadlock from sqlalchemy_dlock.lock.mysql import MYSQL_LOCK_NAME_MAX_LENGTH from .engines import create_engines, dispose_engines, get_engines diff --git a/tests/asyncio/test_pg.py b/tests/asyncio/test_pg.py index 332e802..ace755b 100644 --- a/tests/asyncio/test_pg.py +++ b/tests/asyncio/test_pg.py @@ -3,7 +3,7 @@ from unittest import IsolatedAsyncioTestCase, skipIf from uuid import uuid4 -from sqlalchemy_dlock.asyncio import create_async_sadlock +from sqlalchemy_dlock import create_async_sadlock from .engines import create_engines, dispose_engines, get_engines diff --git a/tests/asyncio/test_session.py b/tests/asyncio/test_session.py index cc8de01..3455dbf 100644 --- a/tests/asyncio/test_session.py +++ b/tests/asyncio/test_session.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy_dlock.asyncio import create_async_sadlock +from sqlalchemy_dlock import create_async_sadlock from .engines import create_engines, dispose_engines, get_engines