Files
plex-playlist-old/backend/src/playlist/sql.py
2024-05-13 18:08:49 -04:00

670 lines
24 KiB
Python

"""Contains all SQL code here."""
import collections
import collections.abc
import datetime
import functools
import inspect
import typing

import aiologger
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext
import sqlalchemy.ext.asyncio
import sqlalchemy.orm

from playlist import enums
from playlist import env
from playlist import models
# Asynchronous logger shared by every database helper in this module.
logger = aiologger.Logger.with_default_handlers(name="sql_logger")

# ORM registry; its metadata is what init_db/drop_db create and drop.
mapper_registry = sqlalchemy.orm.registry()

# Async engine for the configured database URL. echo=True asks SQLAlchemy to log the
# statements it emits.
engine = sqlalchemy.ext.asyncio.create_async_engine(
    env.DATABASE_URL,
    echo=True,
)

# Factory producing AsyncSession objects bound to the engine. expire_on_commit=False keeps
# already-loaded attributes readable after a commit.
async_session_maker = sqlalchemy.orm.sessionmaker(
    bind=engine,
    class_=sqlalchemy.ext.asyncio.AsyncSession,
    expire_on_commit=False,
)
class Sessionizable(typing.Protocol):
    """Protocol defining the signature of a sessionizable function.

    A sessionizable function is a coroutine function or an async generator function that
    receives its database session through a ``session`` keyword argument.
    """

    def __call__(
        self,
        *args: typing.Any,
        session: sqlalchemy.ext.asyncio.AsyncSession | None = None,
        **kwargs: typing.Any,
    ) -> collections.abc.AsyncGenerator | collections.abc.Coroutine:
        """The signature for a sessionizable function.

        Args:
            args: The positional arguments for the function.
            kwargs: The keyword arguments for the function (except for session).

        Keyword Args:
            session: The database session. If set to None, this will be filled with a new
                session.

        Returns:
            collections.abc.AsyncGenerator | collections.abc.Coroutine: The async generator or
                coroutine that would be created by this function.
        """
        ...
async def _validate_session_arg(name: str, kwargs) -> None:
    """Check that ``kwargs`` carries a usable ``session`` entry.

    Args:
        name: Name of the decorated function, used in the log and error messages.
        kwargs: The keyword arguments the decorated function was called with.

    Raises:
        TypeError: If ``kwargs["session"]`` is neither None nor an AsyncSession.
        UnboundLocalError: If ``kwargs`` has no ``session`` key.
    """
    try:
        candidate = kwargs["session"]
    except KeyError as missing:
        msg = f"{name} does not contain a session keyword argument."
        await logger.error(msg)
        raise UnboundLocalError(msg) from missing

    # None is acceptable: the caller is asking for a session to be created.
    if candidate is None or isinstance(candidate, sqlalchemy.ext.asyncio.AsyncSession):
        return

    msg = f"{name} has an invalid session of type {type(candidate)}."
    await logger.error(msg)
    raise TypeError(msg)
def _coro_db_logger(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
try:
return await func(*args, **kwargs)
except sqlalchemy.exc.SQLAlchemyError as e:
await logger.exception(f"Database error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
except Exception as e:
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
def _async_gen_db_logger(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
try:
async for item in func(*args, **kwargs):
yield item
except sqlalchemy.exc.SQLAlchemyError as e:
await logger.exception(f"Database error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
except Exception as e:
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
def sessionize(func: Sessionizable) -> Sessionizable:
"""Decorator that ensures a database session is available to an async function or generator.
This decorator automatically injects a `session` of type sqlalchemy.ext.asyncio.AsyncSession
into the decorated function if one is not provided. If the `session` keyword argument is
missing or set to None, a new session is created using the async_session_maker and passed to
the function. If a session is already provided when the function is called, it uses the
existing session.
The decorator is intended for use with coroutine functions and async generator functions that
are expected to interact with a database using SQLAlchemy's asynchronous session management.
Features:
- Automatically manages database sessions to ensure efficient and correct usage of database
connections.
- Logs any SQLAlchemy-specific exceptions or other exceptions that occur during the execution
of the decorated function, handling rollbacks if necessary.
- Ensures that functions without a proper 'session' keyword argument are not decorated, raising
an UnboundLocalError.
Args:
func (Sessionizable): A coroutine function or an async generator function that accepts a
'session' keyword argument.
Returns:
Sessionizable: A wrapped version of the input function that manages a database session.
Raises:
TypeError: If `func` is neither a coroutine function nor an async generator function.
UnboundLocalError: If `func` lacks a 'session' keyword argument.
Example:
@sessionize
async def fetch_data(session: Optional[AsyncSession] = None):
# Function body using session
...
Note:
- The function must include a 'session' parameter in its signature, which can be None by
default.
- This decorator is only applicable to functions intended to perform database operations
asynchronously.
See Also:
- Sessionizable: The Protocol for sessionizable functions.
- AsyncSession: SQLAlchemy class used for asynchronous session management.
"""
ret: Sessionizable
match func:
case func if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def _coro_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
await _validate_session_arg(func.__name__, kwargs)
logged_func = _coro_db_logger(func)
if kwargs["session"] is None:
async with async_session_maker as kwargs["session"]:
return await logged_func(*args, **kwargs)
else:
return await logged_func(*args, **kwargs)
ret = _coro_wrapper
case func if inspect.isasyncgenfunction(func):
@functools.wraps(func)
async def _asyncgen_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
"""Wrap a sessionized async generator function to inject the session if needed."""
await _validate_session_arg(func.__name__, kwargs)
logged_func = _async_gen_db_logger(func)
if kwargs["session"] is None:
async with async_session_maker as kwargs["session"]:
async for element in logged_func(*args, **kwargs):
yield element
else:
async for element in logged_func(*args, **kwargs):
yield element
ret = _asyncgen_wrapper
case _:
msg = (
f"Decorated function {func.__name__} is not an async generator or coroutine"
"function."
)
logger.error(msg)
raise TypeError(msg)
return ret
async def init_db() -> None:
    """Create every table registered on ``mapper_registry``.

    Any failure is reported through the module logger and is not re-raised.
    """
    try:
        async with engine.begin() as connection:
            # create_all is a synchronous API, so it is driven through run_sync.
            await connection.run_sync(mapper_registry.metadata.create_all)
        await logger.info("Database tables created.")
    except Exception as error:
        await logger.error(f"Failed to create tables: {error}")
async def drop_db() -> None:
    """Drop every table registered on ``mapper_registry``.

    Intended for clean-slate testing or teardown. Any failure is reported through the module
    logger and is not re-raised.
    """
    try:
        async with engine.begin() as connection:
            # drop_all is a synchronous API, so it is driven through run_sync.
            await connection.run_sync(mapper_registry.metadata.drop_all)
        await logger.info("Database tables dropped.")
    except Exception as error:
        await logger.error(f"Failed to drop tables: {error}")
@sessionize
async def insert_or_update_track(
    track_data: dict[str, typing.Any],
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
    """Insert a new track, or update the stored one, from a dict of track fields.

    The genre is looked up by ``track_data["genre_id"]`` and created from
    ``track_data["genre_name"]`` when it is not found. The track is looked up by
    ``track_data["id"]``; an existing track has every key of ``track_data`` assigned onto it,
    otherwise a new track is built from ``track_data``.

    Args:
        track_data: The track data to write to the database.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    async with session.begin():
        genre_row = await session.get(models.Genre, track_data.get("genre_id"))
        if not genre_row:
            genre_row = models.Genre(name=track_data["genre_name"])
            session.add(genre_row)
            # Flush so the pending genre is written and receives its id.
            # NOTE(review): that id is never copied onto the track here -- confirm whether the
            # new genre is meant to be linked to the track.
            await session.flush()

        stored_track = await session.get(models.Track, track_data.get("id"))
        if stored_track:
            for field, new_value in track_data.items():
                setattr(stored_track, field, new_value)
            await logger.info(f"Updated track: {stored_track.title}")
        else:
            # NOTE(review): every key of track_data (including "genre_name", if present) is
            # passed to the Track constructor -- verify against the model.
            created_track = models.Track(**track_data)
            session.add(created_track)
            await logger.info(f"Inserted new track: {created_track.title}")
        await session.commit()
@sessionize
async def cleanup_unused_genres(
    *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None
) -> None:
    """Remove genres that are no longer used by any tracks.

    Genres are outer-joined to tracks; a genre with no matching track row comes back with a
    NULL track id and is deleted.

    Keyword Args:
        session: The database session.
    """
    async with session.begin():
        stmt = (
            sqlalchemy.select(models.Genre)
            .outerjoin(models.Track)
            # ``is_(None)`` renders ``IS NULL``. A plain ``is None`` would be evaluated by
            # Python to False and the query would never match anything.
            .where(models.Track.id.is_(None))
        )
        result = await session.execute(stmt)
        unused_genres = result.scalars().all()
        for genre in unused_genres:
            await session.delete(genre)
    # Leaving the ``session.begin()`` block commits, so no explicit commit is needed.
    await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
@sessionize
async def set_timestamp(
event: enums.Event,
operation: enums.TimedEvent,
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> datetime.datetime:
"""Mark the start or end timestamps in the statistics model.
Args:
event: Identifies what event is being affected.
operation: Identifies if it is the start or end of the event.
Keyword Args:
session: The database session.
Returns:
datetime.datetime: The datetime.datetime of the last update.
"""
async with session.begin():
stats = await session.get(models.Statistics, 1)
stats_obj = stats.scalars().first()
match [event, operation]:
case [enums.Event.UPDATER, enums.EventTime.START]:
stats_obj.start_update = datetime.datetime.now()
log_message = "Updater start timestamp updated."
case [enums.Event.UPDATER, enums.EventTime.END]:
stats_obj.end_update = datetime.datetime.now()
log_message = "Updater end timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.START]:
stats_obj.start_playlist_gen = datetime.datetime.now()
log_message = "Generator start timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.END]:
stats_obj.end_playlist_gen = datetime.datetime.now()
log_message = "Generator end timestamp updated."
case _:
raise ValueError(
f"Invalid event {event} or operation {operation} specified."
)
match event:
case enums.Event.UPDATER:
last_updated = stats_obj.end_update
case enums.Event.GENERATOR:
last_updated = stats_obj.end_playlist_gen
case _:
raise ValueError(f"Invalid event {event} specified.")
await session.commit()
await logger.info(log_message)
return last_updated
@sessionize
async def remove_holiday_by_id(
    holiday_id: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> None:
    """Delete the holiday with the given ID, if it exists.

    When no holiday has that ID a warning is logged and nothing is changed.

    Args:
        holiday_id: The holiday ID to remove.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    target = await session.get(models.Holiday, holiday_id)
    if not target:
        await logger.warning(f"No holiday found with ID: {holiday_id}")
        return
    await session.delete(target)
    await session.commit()
@sessionize
async def update_statistics(
    new_average_duration: float,
    new_podcast_length: float,
    new_max_playtime: float,
    new_category_count: int,
    new_active_holidays: int,
    new_max_tracks: int,
    new_max_regular: int,
    new_max_holiday: int,
    new_max_regular_favorite: int,
    new_max_regular_general: int,
    new_max_holiday_favorite: int,
    new_max_holiday_general: int,
    new_max_regular_favorite_base: int,
    new_max_regular_general_base: int,
    new_max_holiday_favorite_base: int,
    new_max_holiday_general_base: int,
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession | None = None,
) -> None:
    """Update the statistics model, creating the row when it does not exist yet.

    Args:
        new_average_duration: The new average duration of a track by geometric mean
            of all track durations as a float.
        new_podcast_length: The new total podcast length of all current episodes as a float.
        new_max_playtime: The new max playtime as a float.
        new_category_count: The new count of categories as an int.
        new_active_holidays: The new active holidays as an int.
        new_max_tracks: The new max tracks per day as an int.
        new_max_regular: The new max regular tracks as an int.
        new_max_holiday: The new max holiday tracks as an int.
        new_max_regular_favorite: The new max regular favorite tracks as an int.
        new_max_regular_general: The new max regular general tracks as an int.
        new_max_holiday_favorite: The new max holiday favorite tracks as an int.
        new_max_holiday_general: The new max holiday general tracks as an int.
        new_max_regular_favorite_base: The new max regular favorite base unit as an int.
        new_max_regular_general_base: The new max regular general base unit as an int.
        new_max_holiday_favorite_base: The new max holiday favorite base unit as an int.
        new_max_holiday_general_base: The new max holiday general base unit as an int.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    # One mapping feeds both the update and the insert path so the two cannot drift apart
    # (the insert path previously left out category_count).
    values = {
        "average_track_length": new_average_duration,
        "total_podcast_length": new_podcast_length,
        "max_playtime": new_max_playtime,
        "category_count": new_category_count,
        "active_holiday_count": new_active_holidays,
        "max_tracks_per_day": new_max_tracks,
        "max_regular_tracks": new_max_regular,
        "max_holiday_tracks": new_max_holiday,
        "max_regular_favorite_tracks": new_max_regular_favorite,
        "max_regular_general_tracks": new_max_regular_general,
        "max_holiday_favorite_tracks": new_max_holiday_favorite,
        "max_holiday_general_tracks": new_max_holiday_general,
        "max_regular_favorite_base": new_max_regular_favorite_base,
        "max_regular_general_base": new_max_regular_general_base,
        "max_holiday_favorite_base": new_max_holiday_favorite_base,
        "max_holiday_general_base": new_max_holiday_general_base,
    }
    async with session.begin():
        # Assuming there's only one statistics record, stored under primary key 1.
        stats = await session.get(models.Statistics, 1)
        if stats is None:
            # If no statistics entry exists, create one.
            # NOTE(review): the new row's id is left to the database -- confirm it ends up
            # as 1, which is the key every lookup in this module uses.
            session.add(models.Statistics(**values))
        else:
            for attribute, value in values.items():
                setattr(stats, attribute, value)
    # Leaving the ``session.begin()`` block commits, so no explicit commit is needed.
@sessionize
async def fetch_tracks_for_sublist(
is_category: bool,
type_id: int,
is_favorite: bool,
sublist: enums.Sublist,
limit: int,
playlist_entries: set[int],
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> collections.abc.AsyncGenerator[models.Track, None]:
"""Fetch tracks for a given sublist type within a category.
Args:
is_category: True if a category, false if a holiday.
type_id: Category or holiday ID for track filtering.
is_favorite: Flag indicating if only favorite tracks should be fetched.
sublist: Type of sublist to fetch.
limit: Number of tracks to fetch.
playlist_entries: The set of track IDs already in the playlist.
Keyword Args:
session: The database session.
Returns:
collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the
criteria.
"""
query = sqlalchemy.select(models.Track)
if is_category:
query = query.where(models.Track.category_id == type_id)
else:
query = query.join(models.Track.holidays).where(models.Holiday.id == type_id)
query = query.filter(models.Track.id.not_in(playlist_entries))
if is_favorite:
query = query.filter(models.Track.rating == 5)
match sublist:
case enums.Sublist.LEAST_RECENTLY_PLAYED:
query = query.order_by(models.Track.last_played.asc())
case enums.Sublist.LEAST_OFTEN_PLAYED:
query = query.order_by(models.Track.play_count.asc())
case enums.Sublist.LEAST_RECENTLY_ADDED:
query = query.order_by(models.Track.date_added.asc())
case enums.Sublist.RANDOM:
query = query.order_by(sqlalchemy.func.random())
case _:
raise ValueError(f"Unknown sublist type: {sublist}")
query = query.limit(limit)
result = await session.execute(query)
for track in result.scalars():
yield track
@sessionize
async def insert_into_playlist(  # noqa: C901
    is_category: bool,
    type_id: int,
    track_id: int,
    sublist: enums.Sublist,
    is_favorite: bool,
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
    """Add one track to the playlist and commit.

    Args:
        is_category: True is a category, false is a holiday.
        type_id: Category or holiday ID associated with the playlist.
        track_id: Track ID to be added to the playlist.
        sublist: Sublist type under which the track is added.
        is_favorite: Indicates if the track is added as a favorite.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    # ``type_id`` is stored in the column matching the kind of owner.
    owner_column = "category_id" if is_category else "holiday_id"
    entry = models.PlaylistTrack(
        **{owner_column: type_id},
        track_id=track_id,
        sublist_type=sublist,
        is_favorite=is_favorite,
    )
    session.add(entry)
    await session.commit()
@sessionize
async def gen_all_category_ids(
    *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None
) -> collections.abc.AsyncGenerator[int, None]:
    """Generate the current category ids from the database.

    Keyword Args:
        session: The database session.

    Yields:
        int: The category id.
    """
    # Only the id is yielded, so select that column instead of loading whole rows.
    query = sqlalchemy.select(models.Category.id)
    # AsyncSession.execute is a coroutine and must be awaited to obtain the Result.
    result = await session.execute(query)
    for category_id in result.scalars():
        yield category_id
@sessionize
async def gen_active_holiday_ids(
    *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None
) -> collections.abc.AsyncGenerator[int, None]:
    """Generate the current active holiday ids from the database.

    Keyword Args:
        session: The database session.

    Yields:
        int: The holiday id.
    """
    # ``is_(True)`` builds a SQL comparison. A plain ``is True`` would be evaluated by Python
    # to False and the query would never match anything.
    query = sqlalchemy.select(models.Holiday.id).where(models.Holiday.is_active.is_(True))
    # AsyncSession.execute is a coroutine and must be awaited to obtain the Result.
    result = await session.execute(query)
    for holiday_id in result.scalars():
        yield holiday_id
@sessionize
async def get_existing_playlist_track_info(
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession | None = None,
) -> tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]:
    """Get the existing playlist track information.

    Only playlist rows without an episode id are considered.

    Keyword Args:
        session: The database session.

    Returns:
        tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: Two pieces:
            First is a set of the ids of all of the tracks in the playlist. Second is a
            dict (a collections.Counter) with the key being the combination of is_category,
            type_id, is_favorite, and sublist; the value is the count of that unique key
            combination.
    """
    # TODO: Set up user-specific query here.
    # ``is_(None)`` renders ``IS NULL``. A plain ``is None`` would be evaluated by Python to
    # False and the query would never match anything.
    query = sqlalchemy.select(models.PlaylistTrack).where(
        models.PlaylistTrack.episode_id.is_(None)
    )
    # AsyncSession.execute is a coroutine and must be awaited to obtain the Result.
    result = await session.execute(query)

    ids: set[int] = set()
    data_counts: collections.Counter = collections.Counter()
    for row in result.scalars():
        ids.add(row.track_id)
        is_category = row.category_id is not None
        type_id = row.category_id if is_category else row.holiday_id
        # NOTE(review): rows are created elsewhere in this module with ``sublist_type=...``
        # but read here through ``row.sublist`` -- confirm the attribute name on the model.
        data_counts[(is_category, type_id, row.is_favorite, row.sublist)] += 1
    return ids, data_counts
@sessionize
async def get_statistics(
    *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None
) -> models.Statistics | None:
    """Get the statistics object for the database.

    Keyword Args:
        session: The database session.

    Returns:
        models.Statistics | None: The statistics object stored under primary key 1, or None
            when that row does not exist.
    """
    # session.get returns the mapped instance itself (or None), not a Result object, so there
    # is nothing to call ``.scalars()`` on.
    return await session.get(models.Statistics, 1)
@sessionize
async def remove_track_from_playlist(
    server_id: str, *, session: sqlalchemy.ext.asyncio.AsyncSession | None = None
) -> int:
    """Directly removes all entries from PlaylistTrack that are associated with a given server_id.

    Args:
        server_id: The unique server ID of the track to remove from the playlist.

    Keyword Args:
        session: The current database session.

    Returns:
        int: Number of rows affected by the delete operation.
    """
    # ``IN (subquery)`` keeps working even if more than one track shares the server_id,
    # where an equality against a scalar subquery would not.
    matching_track_ids = sqlalchemy.select(models.Track.id).where(
        models.Track.server_id == server_id
    )
    delete_query = sqlalchemy.delete(models.PlaylistTrack).where(
        models.PlaylistTrack.track_id.in_(matching_track_ids)
    )
    result = await session.execute(delete_query)
    await session.commit()
    affected_rows = result.rowcount
    # Awaited like every other log call in this module so the message is not dropped.
    await logger.info(
        f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
    )
    return affected_rows