diff --git a/backend/src/playlist/sql.py b/backend/src/playlist/sql.py index a664b00..c8c081b 100644 --- a/backend/src/playlist/sql.py +++ b/backend/src/playlist/sql.py @@ -1,27 +1,28 @@ """Contains all SQL code here.""" import collections -import collections.abc import contextlib -import datetime import functools import inspect -import typing +from collections.abc import AsyncGenerator +from collections.abc import Coroutine +from datetime import datetime +from typing import Any +from typing import Protocol import aiologger import sqlalchemy -import sqlalchemy.exc -import sqlalchemy.ext import sqlalchemy.orm from playlist import enums from playlist import env from playlist import models +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine # Initialize asynchronous logger -logger = aiologger.Logger.with_default_handlers(name="sql_logger") +logger = aiologger.Logger.with_default_handlers(name=__name__) # Create an asynchronous SQLAlchemy engine _engine = create_async_engine(env.DATABASE_URL, echo=True) @@ -32,15 +33,15 @@ _async_session_maker = sqlalchemy.orm.sessionmaker( ) -class Sessionizable(typing.Protocol): +class Sessionizable(Protocol): """Protocol defining the signature of a sessionizable function.""" def __call__( self, - *args: tuple[typing.Any, ...], + *args: tuple[Any, ...], session: AsyncSession | None = None, - **kwargs: dict[str, typing.Any], - ) -> (collections.abc.AsyncGenerator, collections.abc.Coroutine): + **kwargs: dict[str, Any], + ) -> (AsyncGenerator, Coroutine): """The signature for a sessionizable function. Args: @@ -51,7 +52,7 @@ class Sessionizable(typing.Protocol): session: The database session. If set to None, this will be filled with a new session. Returns: - (collections.abc.AsyncGenerator, collections.abc.Coroutine): The async generator or + (AsyncGenerator, Coroutine): The async generator or coroutine that would be created by this function. """ ... @@ -60,7 +61,7 @@ class Sessionizable(typing.Protocol): @contextlib.asynccontextmanager async def _db_logger( func_name: str, session: AsyncSession -) -> collections.abc.AsyncGenerator[None, None]: +) -> AsyncGenerator[None, None]: """Asynchronous context manager for logging database errors and handling rollbacks. This context manager logs SQLAlchemy-specific exceptions and other exceptions that occur during @@ -80,7 +81,7 @@ async def _db_logger( try: yield - except sqlalchemy.exc.SQLAlchemyError as e: + except SQLAlchemyError as e: await logger.exception(f"Database error in {func_name}: {e}") await session.rollback() raise @@ -168,7 +169,7 @@ def sessionize(func: Sessionizable) -> Sessionizable: @functools.wraps(func) async def _coro_wrapper[ **P - ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine: + ](*args: P.args, **kwargs: P.kwargs) -> Coroutine: """Wrap a sessionized coroutine function to inject the session if needed. This wrapper function manages the database session for coroutine functions. @@ -190,7 +191,7 @@ def sessionize(func: Sessionizable) -> Sessionizable: @functools.wraps(func) async def _asyncgen_wrapper[ **P - ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator: + ](*args: P.args, **kwargs: P.kwargs) -> AsyncGenerator: """Wrap a sessionized async generator function to inject the session if needed. This wrapper function manages the database session for async generator functions. @@ -210,10 +211,7 @@ def sessionize(func: Sessionizable) -> Sessionizable: ret = _asyncgen_wrapper case _: - msg = ( - f"Decorated function {func.__name__} is not an async generator or coroutine" - "function." - ) + msg = f"{func.__name__} is not an async generator or coroutine function." logger.error(msg) raise TypeError(msg) @@ -227,7 +225,7 @@ async def init_db() -> None: await conn.run_sync(models.mapper_registry.metadata.create_all) await logger.info("Database tables created.") except Exception as e: - await logger.error(f"Failed to create tables: {e}") + await logger.exception(f"Failed to create tables: {e}") async def drop_db() -> None: @@ -237,12 +235,12 @@ async def drop_db() -> None: await conn.run_sync(models.mapper_registry.metadata.drop_all) await logger.info("Database tables dropped.") except Exception as e: - await logger.error(f"Failed to drop tables: {e}") + await logger.exception(f"Failed to drop tables: {e}") @sessionize async def insert_or_update_track( - track_data: dict[str, typing.Any], + track_data: dict[str, Any], *, session: AsyncSession | None = None, ) -> None: @@ -303,7 +301,7 @@ async def set_timestamp( operation: enums.TimedEvent, *, session: AsyncSession | None = None, -) -> datetime.datetime: +) -> datetime: """Mark the start or end timestamps in the statistics model. Args: @@ -314,7 +312,7 @@ async def set_timestamp( session: The database session. Returns: - datetime.datetime: The datetime.datetime of the last update. + datetime: The datetime of the last update. """ async with session.begin(): stats = await session.get(models.Statistics, 1) @@ -322,16 +320,16 @@ async def set_timestamp( match [event, operation]: case [enums.Event.UPDATER, enums.EventTime.START]: - stats_obj.start_update = datetime.datetime.now() + stats_obj.start_update = datetime.now() log_message = "Updater start timestamp updated." case [enums.Event.UPDATER, enums.EventTime.END]: - stats_obj.end_update = datetime.datetime.now() + stats_obj.end_update = datetime.now() log_message = "Updater end timestamp updated." case [enums.Event.GENERATOR, enums.EventTime.START]: - stats_obj.start_playlist_gen = datetime.datetime.now() + stats_obj.start_playlist_gen = datetime.now() log_message = "Generator start timestamp updated." case [enums.Event.GENERATOR, enums.EventTime.END]: - stats_obj.end_playlist_gen = datetime.datetime.now() + stats_obj.end_playlist_gen = datetime.now() log_message = "Generator end timestamp updated." case _: raise ValueError( @@ -473,7 +471,7 @@ async def fetch_tracks_for_sublist( playlist_entries: set[int], *, session: AsyncSession | None = None, -) -> collections.abc.AsyncGenerator[models.Track, None]: +) -> AsyncGenerator[models.Track, None]: """Fetch tracks for a given sublist type within a category. Args: @@ -488,7 +486,7 @@ async def fetch_tracks_for_sublist( session: The database session. Returns: - collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the + AsyncGenerator[models.Track, None]: Generates the tracks fitting the criteria. """ query = sqlalchemy.select(models.Track) @@ -563,7 +561,7 @@ async def insert_into_playlist( # noqa: C901 @sessionize async def gen_all_category_ids( *, session: AsyncSession | None = None -) -> collections.abc.AsyncGenerator[int, None, None]: +) -> AsyncGenerator[int, None, None]: """Generate the current category ids from the database. Keyword Args: @@ -581,7 +579,7 @@ async def gen_all_category_ids( @sessionize async def gen_active_holiday_ids( *, session: AsyncSession | None = None -) -> collections.abc.AsyncGenerator[int, None, None]: +) -> AsyncGenerator[int, None, None]: """Generate the current active holiday ids from the database. Keyword Args: