"""Contains all SQL code here.""" import collections import collections.abc import datetime import functools import inspect import typing import aiologger import sqlalchemy import sqlalchemy.exc import sqlalchemy.ext import sqlalchemy.orm from playlist import enums from playlist import env from playlist import models # Initialize asynchronous logger logger = aiologger.Logger.with_default_handlers(name="sql_logger") engine = sqlalchemy.ext.asyncio.create_async_engine(env.DATABASE_URL, echo=True) mapper_registry = sqlalchemy.orm.registry() async_session_maker = sqlalchemy.orm.sessionmaker( engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False ) class Sessionizable(typing.Protocol): """Protocol defining the signature of a sessionizable function.""" def __call__( self, *args: tuple[typing.Any, ...], session: sqlalchemy.ext.asyncio.AsyncSession = None, **kwargs: dict[str, typing.Any], ) -> (collections.abc.AsyncGenerator, collections.abc.Coroutine): """The signature for a sessionizable function. Args: args: The positional argument for the function. kwargs: The keyword arguments for the function (except for session). Keyword Args: session: The database session. If set to None, this will be filled with a new session. Returns: (collections.abc.AsyncGenerator, collections.abc.Coroutine): The async generator or coroutine that would be created by this function. """ ... async def _validate_session_arg(name: str, kwargs) -> None: try: if not ( kwargs["session"] is None or isinstance(kwargs["session"], sqlalchemy.ext.asyncio.AsyncSession) ): msg = f"{name} has an invalid session of type {type(kwargs['session'])}." await logger.error(msg) raise TypeError(msg) except KeyError as e: msg = f"{name} does not contain a session keyword argument." await logger.error(msg) raise UnboundLocalError(msg) from e def _coro_db_logger(func: Sessionizable) -> Sessionizable: @functools.wraps(func) async def wrapper[ **P ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine: try: return await func(*args, **kwargs) except sqlalchemy.exc.SQLAlchemyError as e: await logger.exception(f"Database error in {func.__name__}: {e}") await kwargs["session"].rollback() raise except Exception as e: await logger.exception(f"Unexpected error in {func.__name__}: {e}") await kwargs["session"].rollback() raise def _async_gen_db_logger(func: Sessionizable) -> Sessionizable: @functools.wraps(func) async def wrapper[ **P ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine: try: async for item in func(*args, **kwargs): yield item except sqlalchemy.exc.SQLAlchemyError as e: await logger.exception(f"Database error in {func.__name__}: {e}") await kwargs["session"].rollback() raise except Exception as e: await logger.exception(f"Unexpected error in {func.__name__}: {e}") await kwargs["session"].rollback() raise def sessionize(func: Sessionizable) -> Sessionizable: """Decorator that ensures a database session is available to an async function or generator. This decorator automatically injects a `session` of type sqlalchemy.ext.asyncio.AsyncSession into the decorated function if one is not provided. If the `session` keyword argument is missing or set to None, a new session is created using the async_session_maker and passed to the function. If a session is already provided when the function is called, it uses the existing session. The decorator is intended for use with coroutine functions and async generator functions that are expected to interact with a database using SQLAlchemy's asynchronous session management. Features: - Automatically manages database sessions to ensure efficient and correct usage of database connections. - Logs any SQLAlchemy-specific exceptions or other exceptions that occur during the execution of the decorated function, handling rollbacks if necessary. - Ensures that functions without a proper 'session' keyword argument are not decorated, raising an UnboundLocalError. Args: func (Sessionizable): A coroutine function or an async generator function that accepts a 'session' keyword argument. Returns: Sessionizable: A wrapped version of the input function that manages a database session. Raises: TypeError: If `func` is neither a coroutine function nor an async generator function. UnboundLocalError: If `func` lacks a 'session' keyword argument. Example: @sessionize async def fetch_data(session: Optional[AsyncSession] = None): # Function body using session ... Note: - The function must include a 'session' parameter in its signature, which can be None by default. - This decorator is only applicable to functions intended to perform database operations asynchronously. See Also: - Sessionizable: The Protocol for sessionizable functions. - AsyncSession: SQLAlchemy class used for asynchronous session management. """ ret: Sessionizable match func: case func if inspect.iscoroutinefunction(func): @functools.wraps(func) async def _coro_wrapper[ **P ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine: await _validate_session_arg(func.__name__, kwargs) logged_func = _coro_db_logger(func) if kwargs["session"] is None: async with async_session_maker as kwargs["session"]: return await logged_func(*args, **kwargs) else: return await logged_func(*args, **kwargs) ret = _coro_wrapper case func if inspect.isasyncgenfunction(func): @functools.wraps(func) async def _asyncgen_wrapper[ **P ](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator: """Wrap a sessionized async generator function to inject the session if needed.""" await _validate_session_arg(func.__name__, kwargs) logged_func = _async_gen_db_logger(func) if kwargs["session"] is None: async with async_session_maker as kwargs["session"]: async for element in logged_func(*args, **kwargs): yield element else: async for element in logged_func(*args, **kwargs): yield element ret = _asyncgen_wrapper case _: msg = ( f"Decorated function {func.__name__} is not an async generator or coroutine" "function." ) logger.error(msg) raise TypeError(msg) return ret async def init_db() -> None: """Create all tables asynchronously.""" try: async with engine.begin() as conn: await conn.run_sync(mapper_registry.metadata.create_all) await logger.info("Database tables created.") except Exception as e: await logger.error(f"Failed to create tables: {e}") async def drop_db() -> None: """Drop all tables asynchronously for clean slate testing or teardown.""" try: async with engine.begin() as conn: await conn.run_sync(mapper_registry.metadata.drop_all) await logger.info("Database tables dropped.") except Exception as e: await logger.error(f"Failed to drop tables: {e}") @sessionize async def insert_or_update_track( track_data: dict[str, typing.Any], *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> None: """Insert a new track or update an existing one in the database asynchronously. Args: track_data: The track data to write to the database. Keyword Args: session: The database session. """ # noqa: D417 async with session.begin(): genre = await session.get(models.Genre, track_data.get("genre_id")) if not genre: genre = models.Genre(name=track_data["genre_name"]) session.add(genre) await session.flush() # Ensures 'genre' is persisted and has an 'id' track = await session.get(models.Track, track_data.get("id")) if track: for key, value in track_data.items(): setattr(track, key, value) await logger.info(f"Updated track: {track.title}") else: track = models.Track(**track_data) session.add(track) await logger.info(f"Inserted new track: {track.title}") await session.commit() @sessionize async def cleanup_unused_genres( *, session: sqlalchemy.ext.asyncio.AsyncSession = None ) -> None: """Remove genres that are no longer used by any tracks. Keyword Args: session: The database session. """ async with session.begin(): stmt = ( sqlalchemy.select(models.Genre) .outerjoin(models.Track) .filter(models.Track.id is None) ) result = await session.execute(stmt) unused_genres = result.scalars().all() for genre in unused_genres: await session.delete(genre) await session.commit() await logger.info(f"Cleaned up {len(unused_genres)} unused genres.") @sessionize async def set_timestamp( event: enums.Event, operation: enums.TimedEvent, *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> datetime.datetime: """Mark the start or end timestamps in the statistics model. Args: event: Identifies what event is being affected. operation: Identifies if it is the start or end of the event. Keyword Args: session: The database session. Returns: datetime.datetime: The datetime.datetime of the last update. """ async with session.begin(): stats = await session.get(models.Statistics, 1) stats_obj = stats.scalars().first() match [event, operation]: case [enums.Event.UPDATER, enums.EventTime.START]: stats_obj.start_update = datetime.datetime.now() log_message = "Updater start timestamp updated." case [enums.Event.UPDATER, enums.EventTime.END]: stats_obj.end_update = datetime.datetime.now() log_message = "Updater end timestamp updated." case [enums.Event.GENERATOR, enums.EventTime.START]: stats_obj.start_playlist_gen = datetime.datetime.now() log_message = "Generator start timestamp updated." case [enums.Event.GENERATOR, enums.EventTime.END]: stats_obj.end_playlist_gen = datetime.datetime.now() log_message = "Generator end timestamp updated." case _: raise ValueError( f"Invalid event {event} or operation {operation} specified." ) match event: case enums.Event.UPDATER: last_updated = stats_obj.end_update case enums.Event.GENERATOR: last_updated = stats_obj.end_playlist_gen case _: raise ValueError(f"Invalid event {event} specified.") await session.commit() await logger.info(log_message) return last_updated @sessionize async def remove_holiday_by_id( holiday_id: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None ) -> None: """Remove a holiday from the database by ID. Captures and logs database-specific errors and general exceptions. Args: holiday_id: The holiday ID to remove. Keyword Args: session: The database session. """ # noqa: D417 holiday = await session.get(models.Holiday, holiday_id) if holiday: await session.delete(holiday) await session.commit() else: await logger.warning(f"No holiday found with ID: {holiday_id}") @sessionize async def update_statistics( new_average_duration: float, new_podcast_length: float, new_max_playtime: float, new_category_count: int, new_active_holidays: int, new_max_tracks: int, new_max_regular: int, new_max_holiday: int, new_max_regular_favorite: int, new_max_regular_general: int, new_max_holiday_favorite: int, new_max_holiday_general: int, new_max_regular_favorite_base: int, new_max_regular_general_base: int, new_max_holiday_favorite_base: int, new_max_holiday_general_base: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> None: """Update the statistics model. Args: new_average_duration: The new average duration of a track by geometric mean of all track durations as a float. new_podcast_length: The new total podcast length of all current episodes as a float. new_max_playtime: The new max playtime as a float. new_category_count: The new count of categories as an int. new_active_holidays: The new active holidays as an int. new_max_tracks: The new max tracks per day as an int. new_max_regular: The new max regular tracks as an int. new_max_holiday: The new max holiday tracks as an int. new_max_regular_favorite: The new max regular favorite tracks as an int. new_max_regular_general: The new max regular general tracks as an int. new_max_holiday_favorite: The new max holiday favorite tracks as an int. new_max_holiday_general: The new max holiday general tracks as an int. new_max_regular_favorite_base: The new max regular favorite base unit as an int. new_max_regular_general_base: The new max regular general base unit as an int. new_max_holiday_favorite_base: The new max holiday favorite base unit as an int. new_max_holiday_general_base: The new max holiday general base unit as an int. Keyword Args: session: The database session. """ # noqa: D417 async with session.begin(): # Assuming there's only one statistics record, or you might need to handle this differently stats = await session.get(models.Statistics, 1) if stats: stats.average_track_length = new_average_duration stats.total_podcast_length = new_podcast_length stats.max_playtime = new_max_playtime stats.category_count = new_category_count stats.active_holiday_count = new_active_holidays stats.max_tracks_per_day = new_max_tracks stats.max_regular_tracks = new_max_regular stats.max_holiday_tracks = new_max_holiday stats.max_regular_favorite_tracks = new_max_regular_favorite stats.max_regular_general_tracks = new_max_regular_general stats.max_holiday_favorite_tracks = new_max_holiday_favorite stats.max_holiday_general_tracks = new_max_holiday_general stats.max_regular_favorite_base = new_max_regular_favorite_base stats.max_regular_general_base = new_max_regular_general_base stats.max_holiday_favorite_base = new_max_holiday_favorite_base stats.max_holiday_general_base = new_max_holiday_general_base await session.commit() else: # If no statistics entry exists, create one new_stats = models.Statistics( average_track_length=new_average_duration, total_podcast_length=new_podcast_length, max_playtime=new_max_playtime, active_holiday_count=new_active_holidays, max_tracks_per_day=new_max_tracks, max_regular_tracks=new_max_regular, max_holiday_tracks=new_max_holiday, max_regular_favorite_tracks=new_max_regular_favorite, max_regular_general_tracks=new_max_regular_general, max_holiday_favorite_tracks=new_max_holiday_favorite, max_holiday_general_tracks=new_max_holiday_general, max_regular_favorite_base=new_max_regular_favorite_base, max_regular_general_base=new_max_regular_general_base, max_holiday_favorite_base=new_max_holiday_favorite_base, max_holiday_general_base=new_max_holiday_general_base, ) session.add(new_stats) await session.commit() @sessionize async def fetch_tracks_for_sublist( is_category: bool, type_id: int, is_favorite: bool, sublist: enums.Sublist, limit: int, playlist_entries: set[int], *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> collections.abc.AsyncGenerator[models.Track, None]: """Fetch tracks for a given sublist type within a category. Args: is_category: True if a category, false if a holiday. type_id: Category or holiday ID for track filtering. is_favorite: Flag indicating if only favorite tracks should be fetched. sublist: Type of sublist to fetch. limit: Number of tracks to fetch. playlist_entries: The set of track IDs already in the playlist. Keyword Args: session: The database session. Returns: collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the criteria. """ query = sqlalchemy.select(models.Track) if is_category: query = query.where(models.Track.category_id == type_id) else: query = query.join(models.Track.holidays).where(models.Holiday.id == type_id) query = query.filter(models.Track.id.not_in(playlist_entries)) if is_favorite: query = query.filter(models.Track.rating == 5) match sublist: case enums.Sublist.LEAST_RECENTLY_PLAYED: query = query.order_by(models.Track.last_played.asc()) case enums.Sublist.LEAST_OFTEN_PLAYED: query = query.order_by(models.Track.play_count.asc()) case enums.Sublist.LEAST_RECENTLY_ADDED: query = query.order_by(models.Track.date_added.asc()) case enums.Sublist.RANDOM: query = query.order_by(sqlalchemy.func.random()) case _: raise ValueError(f"Unknown sublist type: {sublist}") query = query.limit(limit) result = await session.execute(query) for track in result.scalars(): yield track @sessionize async def insert_into_playlist( # noqa: C901 is_category: bool, type_id: int, track_id: int, sublist: enums.Sublist, is_favorite: bool, *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> None: """Insert a track into the playlist. Args: is_category: True is a category, false is a holiday. type_id: Category or holiday ID associated with the playlist. track_id: Track ID to be added to the playlist. sublist: Sublist type under which the track is added. is_favorite: Indicates if the track is added as a favorite. Keyword Args: session: The database session. """ # noqa: D417 if is_category: playlist_track = models.PlaylistTrack( category_id=type_id, track_id=track_id, sublist_type=sublist, is_favorite=is_favorite, ) else: playlist_track = models.PlaylistTrack( holiday_id=type_id, track_id=track_id, sublist_type=sublist, is_favorite=is_favorite, ) session.add(playlist_track) await session.commit() @sessionize async def gen_all_category_ids( *, session: sqlalchemy.ext.asyncio.AsyncSession = None ) -> collections.abc.AsyncGenerator[int, None, None]: """Generate the current category ids from the database. Keyword Args: session: The database session. Yields: int: The category id. """ query = sqlalchemy.select(models.Category) result = session.execute(query) for row in result.scalars(): yield row.id @sessionize async def gen_active_holiday_ids( *, session: sqlalchemy.ext.asyncio.AsyncSession = None ) -> collections.abc.AsyncGenerator[int, None, None]: """Generate the current active holiday ids from the database. Keyword Args: session: The database session. Yields: int: The holiday id. """ query = sqlalchemy.select(models.Holiday).where(models.Holiday.is_active is True) result = session.execute(query) for row in result.scalars(): yield row.id @sessionize async def get_existing_playlist_track_info( *, session: sqlalchemy.ext.asyncio.AsyncSession = None, ) -> tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: """Get the existing playlist track information. Keyword Args: session: The database session. Returns: tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: Two pieces: First is a set of the ids of all of the tracks in the playlist. Second is a dict with the key being the combination of is_category, type_id, is_favorite, and sublist; the value is the count of that unique key combination. """ # TODO: Set up user-specific query here. query = sqlalchemy.select(models.PlaylistTrack).where( models.PlaylistTrack.episode_id is None ) result = session.execute(query) ids = set() data = [] for row in result.scalars(): if row.episode_id is not None: continue ids.append(row.track_id) is_category = row.category_id is not None type_id = row.category_id if is_category else row.holiday_id is_favorite = row.is_favorite sublist = row.sublist data.append((is_category, type_id, is_favorite, sublist)) data_counts = collections.Counter(data) return ids, data_counts @sessionize async def get_statistics( *, session: sqlalchemy.ext.asyncio.AsyncSession = None ) -> models.Statistics: """Get the statistics object for the database. Keyword Args: session: The database session. Returns: models.Statistics: The statistics object used for playlist processing. """ stats = await session.get(models.Statistics, 1) stats_obj = stats.scalars().first() return stats_obj @sessionize async def remove_track_from_playlist( server_id: str, *, session: sqlalchemy.ext.asyncio.AsyncSession = None ): """Directly removes all entries from PlaylistTrack that are associated with a given server_id. Args: server_id: The unique server ID of the track to remove from the playlist. Keyword Args: session: The current database session. Returns: int: Number of rows affected by the delete operation. """ # Delete directly using a join on Track where server_id matches delete_query = sqlalchemy.delete(models.PlaylistTrack).where( models.PlaylistTrack.track_id == sqlalchemy.select(models.Track.id) .where(models.Track.server_id == server_id) .scalar_subquery() ) result = await session.execute(delete_query) await session.commit() affected_rows = result.rowcount logger.info( f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}." ) return affected_rows