Some refactoring.
Some checks failed
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (macos-latest, 3.10, tests) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, docs-build) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, mypy) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, pre-commit) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, safety) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, tests) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, typeguard) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.10, xdoctest) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.7, mypy) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.7, tests) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.8, mypy) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.8, tests) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.9, mypy) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (ubuntu-latest, 3.9, tests) (push) Has been cancelled
Tests / ${{ matrix.session }} ${{ matrix.python }} / ${{ matrix.os }} (windows-latest, 3.10, tests) (push) Has been cancelled
Tests / coverage (push) Has been cancelled

Signed-off-by: Cliff Hill <xlorep@darkhelm.org>
This commit is contained in:
2024-05-14 18:16:53 -04:00
parent 6c72ad951c
commit f881227b26

View File

@@ -1,27 +1,28 @@
"""Contains all SQL code here."""
import collections
import collections.abc
import contextlib
import datetime
import functools
import inspect
import typing
from collections.abc import AsyncGenerator
from collections.abc import Coroutine
from datetime import datetime
from typing import Any
from typing import Protocol
import aiologger
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext
import sqlalchemy.orm
from playlist import enums
from playlist import env
from playlist import models
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
# Initialize asynchronous logger
logger = aiologger.Logger.with_default_handlers(name="sql_logger")
logger = aiologger.Logger.with_default_handlers(name=__name__)
# Create an asynchronous SQLAlchemy engine
_engine = create_async_engine(env.DATABASE_URL, echo=True)
@@ -32,15 +33,15 @@ _async_session_maker = sqlalchemy.orm.sessionmaker(
)
class Sessionizable(typing.Protocol):
class Sessionizable(Protocol):
"""Protocol defining the signature of a sessionizable function."""
def __call__(
self,
*args: tuple[typing.Any, ...],
*args: tuple[Any, ...],
session: AsyncSession | None = None,
**kwargs: dict[str, typing.Any],
) -> (collections.abc.AsyncGenerator, collections.abc.Coroutine):
**kwargs: dict[str, Any],
) -> (AsyncGenerator, Coroutine):
"""The signature for a sessionizable function.
Args:
@@ -51,7 +52,7 @@ class Sessionizable(typing.Protocol):
session: The database session. If set to None, this will be filled with a new session.
Returns:
(collections.abc.AsyncGenerator, collections.abc.Coroutine): The async generator or
(AsyncGenerator, Coroutine): The async generator or
coroutine that would be created by this function.
"""
...
@@ -60,7 +61,7 @@ class Sessionizable(typing.Protocol):
@contextlib.asynccontextmanager
async def _db_logger(
func_name: str, session: AsyncSession
) -> collections.abc.AsyncGenerator[None, None]:
) -> AsyncGenerator[None, None]:
"""Asynchronous context manager for logging database errors and handling rollbacks.
This context manager logs SQLAlchemy-specific exceptions and other exceptions that occur during
@@ -80,7 +81,7 @@ async def _db_logger(
try:
yield
except sqlalchemy.exc.SQLAlchemyError as e:
except SQLAlchemyError as e:
await logger.exception(f"Database error in {func_name}: {e}")
await session.rollback()
raise
@@ -168,7 +169,7 @@ def sessionize(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def _coro_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
](*args: P.args, **kwargs: P.kwargs) -> Coroutine:
"""Wrap a sessionized coroutine function to inject the session if needed.
This wrapper function manages the database session for coroutine functions.
@@ -190,7 +191,7 @@ def sessionize(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def _asyncgen_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
](*args: P.args, **kwargs: P.kwargs) -> AsyncGenerator:
"""Wrap a sessionized async generator function to inject the session if needed.
This wrapper function manages the database session for async generator functions.
@@ -210,10 +211,7 @@ def sessionize(func: Sessionizable) -> Sessionizable:
ret = _asyncgen_wrapper
case _:
msg = (
f"Decorated function {func.__name__} is not an async generator or coroutine"
"function."
)
msg = f"{func.__name__} is not an async generator or coroutine function."
logger.error(msg)
raise TypeError(msg)
@@ -227,7 +225,7 @@ async def init_db() -> None:
await conn.run_sync(models.mapper_registry.metadata.create_all)
await logger.info("Database tables created.")
except Exception as e:
await logger.error(f"Failed to create tables: {e}")
await logger.exception(f"Failed to create tables: {e}")
async def drop_db() -> None:
@@ -237,12 +235,12 @@ async def drop_db() -> None:
await conn.run_sync(models.mapper_registry.metadata.drop_all)
await logger.info("Database tables dropped.")
except Exception as e:
await logger.error(f"Failed to drop tables: {e}")
await logger.exception(f"Failed to drop tables: {e}")
@sessionize
async def insert_or_update_track(
track_data: dict[str, typing.Any],
track_data: dict[str, Any],
*,
session: AsyncSession | None = None,
) -> None:
@@ -303,7 +301,7 @@ async def set_timestamp(
operation: enums.TimedEvent,
*,
session: AsyncSession | None = None,
) -> datetime.datetime:
) -> datetime:
"""Mark the start or end timestamps in the statistics model.
Args:
@@ -314,7 +312,7 @@ async def set_timestamp(
session: The database session.
Returns:
datetime.datetime: The datetime.datetime of the last update.
datetime: The datetime of the last update.
"""
async with session.begin():
stats = await session.get(models.Statistics, 1)
@@ -322,16 +320,16 @@ async def set_timestamp(
match [event, operation]:
case [enums.Event.UPDATER, enums.EventTime.START]:
stats_obj.start_update = datetime.datetime.now()
stats_obj.start_update = datetime.now()
log_message = "Updater start timestamp updated."
case [enums.Event.UPDATER, enums.EventTime.END]:
stats_obj.end_update = datetime.datetime.now()
stats_obj.end_update = datetime.now()
log_message = "Updater end timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.START]:
stats_obj.start_playlist_gen = datetime.datetime.now()
stats_obj.start_playlist_gen = datetime.now()
log_message = "Generator start timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.END]:
stats_obj.end_playlist_gen = datetime.datetime.now()
stats_obj.end_playlist_gen = datetime.now()
log_message = "Generator end timestamp updated."
case _:
raise ValueError(
@@ -473,7 +471,7 @@ async def fetch_tracks_for_sublist(
playlist_entries: set[int],
*,
session: AsyncSession | None = None,
) -> collections.abc.AsyncGenerator[models.Track, None]:
) -> AsyncGenerator[models.Track, None]:
"""Fetch tracks for a given sublist type within a category.
Args:
@@ -488,7 +486,7 @@ async def fetch_tracks_for_sublist(
session: The database session.
Returns:
collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the
AsyncGenerator[models.Track, None]: Generates the tracks fitting the
criteria.
"""
query = sqlalchemy.select(models.Track)
@@ -563,7 +561,7 @@ async def insert_into_playlist( # noqa: C901
@sessionize
async def gen_all_category_ids(
*, session: AsyncSession | None = None
) -> collections.abc.AsyncGenerator[int, None, None]:
) -> AsyncGenerator[int, None, None]:
"""Generate the current category ids from the database.
Keyword Args:
@@ -581,7 +579,7 @@ async def gen_all_category_ids(
@sessionize
async def gen_active_holiday_ids(
*, session: AsyncSession | None = None
) -> collections.abc.AsyncGenerator[int, None, None]:
) -> AsyncGenerator[int, None, None]:
"""Generate the current active holiday ids from the database.
Keyword Args: