@@ -23,8 +23,10 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
# Initialize asynchronous logger
|
||||
logger = aiologger.Logger.with_default_handlers(name="sql_logger")
|
||||
|
||||
# Create an asynchronous SQLAlchemy engine
|
||||
_engine = create_async_engine(env.DATABASE_URL, echo=True)
|
||||
|
||||
# Create a session maker for asynchronous sessions
|
||||
_async_session_maker = sqlalchemy.orm.sessionmaker(
|
||||
_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
@@ -42,7 +44,7 @@ class Sessionizable(typing.Protocol):
|
||||
"""The signature for a sessionizable function.
|
||||
|
||||
Args:
|
||||
args: The positional argument for the function.
|
||||
args: The positional arguments for the function.
|
||||
kwargs: The keyword arguments for the function (except for session).
|
||||
|
||||
Keyword Args:
|
||||
@@ -59,6 +61,22 @@ class Sessionizable(typing.Protocol):
|
||||
async def _db_logger(
|
||||
func_name: str, session: AsyncSession
|
||||
) -> collections.abc.AsyncGenerator[None, None]:
|
||||
"""Asynchronous context manager for logging database errors and handling rollbacks.
|
||||
|
||||
This context manager logs SQLAlchemy-specific exceptions and other exceptions that occur during
|
||||
the execution of a function. It also handles rolling back the session in case of an error.
|
||||
|
||||
Args:
|
||||
func_name (str): The name of the function being executed.
|
||||
session (AsyncSession): The database session used in the function.
|
||||
|
||||
Yields:
|
||||
None: This context manager does not yield any value.
|
||||
|
||||
Raises:
|
||||
SQLAlchemyError: Re-raises any SQLAlchemy-specific exceptions after logging and rollback.
|
||||
Exception: Re-raises any other exceptions after logging and rollback.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
|
||||
@@ -74,6 +92,17 @@ async def _db_logger(
|
||||
|
||||
|
||||
def _validate_signature(func: Sessionizable) -> None:
|
||||
"""Validate that the given function has a 'session' keyword argument.
|
||||
|
||||
This function checks if the 'session' parameter is present in the function's signature.
|
||||
If the 'session' parameter is missing, it raises an UnboundLocalError.
|
||||
|
||||
Args:
|
||||
func (Sessionizable): The function to validate.
|
||||
|
||||
Raises:
|
||||
UnboundLocalError: If the function does not have a 'session' parameter.
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
try:
|
||||
sig.parameters["session"]
|
||||
@@ -140,7 +169,10 @@ def sessionize(func: Sessionizable) -> Sessionizable:
|
||||
async def _coro_wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
||||
"""Wrap a sessionized coroutine function to inject the session if needed."""
|
||||
"""Wrap a sessionized coroutine function to inject the session if needed.
|
||||
|
||||
This wrapper function manages the database session for coroutine functions.
|
||||
"""
|
||||
if kwargs["session"] is None:
|
||||
async with (
|
||||
_async_session_maker() as kwargs["session"],
|
||||
@@ -159,7 +191,10 @@ def sessionize(func: Sessionizable) -> Sessionizable:
|
||||
async def _asyncgen_wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
|
||||
"""Wrap a sessionized async generator function to inject the session if needed."""
|
||||
"""Wrap a sessionized async generator function to inject the session if needed.
|
||||
|
||||
This wrapper function manages the database session for async generator functions.
|
||||
"""
|
||||
if kwargs["session"] is None:
|
||||
async with (
|
||||
_async_session_maker() as kwargs["session"],
|
||||
|
||||
Reference in New Issue
Block a user