Getting sessionize done right.

Signed-off-by: Cliff Hill <xlorep@darkhelm.org>
This commit is contained in:
2024-05-13 18:08:49 -04:00
parent c0242fc87b
commit 7f691fcdde
4 changed files with 3191 additions and 1077 deletions

View File

@@ -1,4 +1,5 @@
"""Command-line interface."""
import click

View File

@@ -1,33 +1,209 @@
"""Contains all SQL code here."""
from collections import Counter
from datetime import datetime
from typing import Any
from typing import AsyncGenerator
import collections
import collections.abc
import datetime
import functools
import inspect
import typing
from aiologger import Logger
import aiologger
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext
import sqlalchemy.orm
from playlist import enums
from playlist import env
from playlist import models
from sqlalchemy import func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import delete
from sqlalchemy.future import select
from sqlalchemy.orm import registry
from sqlalchemy.orm import sessionmaker
# Initialize asynchronous logger
logger = Logger.with_default_handlers(name="sql_logger")
logger = aiologger.Logger.with_default_handlers(name="sql_logger")
engine = create_async_engine(env.DATABASE_URL, echo=True)
engine = sqlalchemy.ext.asyncio.create_async_engine(env.DATABASE_URL, echo=True)
mapper_registry = sqlalchemy.orm.registry()
async_session_maker = sqlalchemy.orm.sessionmaker(
engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False
)
mapper_registry = registry()
class Sessionizable(typing.Protocol):
"""Protocol defining the signature of a sessionizable function."""
async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
def __call__(
self,
*args: tuple[typing.Any, ...],
session: sqlalchemy.ext.asyncio.AsyncSession = None,
**kwargs: dict[str, typing.Any],
) -> (collections.abc.AsyncGenerator, collections.abc.Coroutine):
"""The signature for a sessionizable function.
Args:
args: The positional argument for the function.
kwargs: The keyword arguments for the function (except for session).
Keyword Args:
session: The database session. If set to None, this will be filled with a new session.
Returns:
(collections.abc.AsyncGenerator, collections.abc.Coroutine): The async generator or
coroutine that would be created by this function.
"""
...
async def _validate_session_arg(name: str, kwargs) -> None:
try:
if not (
kwargs["session"] is None
or isinstance(kwargs["session"], sqlalchemy.ext.asyncio.AsyncSession)
):
msg = f"{name} has an invalid session of type {type(kwargs['session'])}."
await logger.error(msg)
raise TypeError(msg)
except KeyError as e:
msg = f"{name} does not contain a session keyword argument."
await logger.error(msg)
raise UnboundLocalError(msg) from e
def _coro_db_logger(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
try:
return await func(*args, **kwargs)
except sqlalchemy.exc.SQLAlchemyError as e:
await logger.exception(f"Database error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
except Exception as e:
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
def _async_gen_db_logger(func: Sessionizable) -> Sessionizable:
@functools.wraps(func)
async def wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
try:
async for item in func(*args, **kwargs):
yield item
except sqlalchemy.exc.SQLAlchemyError as e:
await logger.exception(f"Database error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
except Exception as e:
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
await kwargs["session"].rollback()
raise
def sessionize(func: Sessionizable) -> Sessionizable:
"""Decorator that ensures a database session is available to an async function or generator.
This decorator automatically injects a `session` of type sqlalchemy.ext.asyncio.AsyncSession
into the decorated function if one is not provided. If the `session` keyword argument is
missing or set to None, a new session is created using the async_session_maker and passed to
the function. If a session is already provided when the function is called, it uses the
existing session.
The decorator is intended for use with coroutine functions and async generator functions that
are expected to interact with a database using SQLAlchemy's asynchronous session management.
Features:
- Automatically manages database sessions to ensure efficient and correct usage of database
connections.
- Logs any SQLAlchemy-specific exceptions or other exceptions that occur during the execution
of the decorated function, handling rollbacks if necessary.
- Ensures that functions without a proper 'session' keyword argument are not decorated, raising
an UnboundLocalError.
Args:
func (Sessionizable): A coroutine function or an async generator function that accepts a
'session' keyword argument.
Returns:
Sessionizable: A wrapped version of the input function that manages a database session.
Raises:
TypeError: If `func` is neither a coroutine function nor an async generator function.
UnboundLocalError: If `func` lacks a 'session' keyword argument.
Example:
@sessionize
async def fetch_data(session: Optional[AsyncSession] = None):
# Function body using session
...
Note:
- The function must include a 'session' parameter in its signature, which can be None by
default.
- This decorator is only applicable to functions intended to perform database operations
asynchronously.
See Also:
- Sessionizable: The Protocol for sessionizable functions.
- AsyncSession: SQLAlchemy class used for asynchronous session management.
"""
ret: Sessionizable
match func:
case func if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def _coro_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
await _validate_session_arg(func.__name__, kwargs)
logged_func = _coro_db_logger(func)
if kwargs["session"] is None:
async with async_session_maker as kwargs["session"]:
return await logged_func(*args, **kwargs)
else:
return await logged_func(*args, **kwargs)
ret = _coro_wrapper
case func if inspect.isasyncgenfunction(func):
@functools.wraps(func)
async def _asyncgen_wrapper[
**P
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
"""Wrap a sessionized async generator function to inject the session if needed."""
await _validate_session_arg(func.__name__, kwargs)
logged_func = _async_gen_db_logger(func)
if kwargs["session"] is None:
async with async_session_maker as kwargs["session"]:
async for element in logged_func(*args, **kwargs):
yield element
else:
async for element in logged_func(*args, **kwargs):
yield element
ret = _asyncgen_wrapper
case _:
msg = (
f"Decorated function {func.__name__} is not an async generator or coroutine"
"function."
)
logger.error(msg)
raise TypeError(msg)
return ret
async def init_db() -> None:
@@ -50,141 +226,144 @@ async def drop_db() -> None:
await logger.error(f"Failed to drop tables: {e}")
@sessionize
async def insert_or_update_track(
session: AsyncSession, track_data: dict[str, Any]
track_data: dict[str, typing.Any],
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
"""Insert a new track or update an existing one in the database asynchronously."""
try:
async with session.begin():
genre = await session.get(models.Genre, track_data.get("genre_id"))
if not genre:
genre = models.Genre(name=track_data["genre_name"])
session.add(genre)
await session.flush() # Ensures 'genre' is persisted and has an 'id'
"""Insert a new track or update an existing one in the database asynchronously.
track = await session.get(models.Track, track_data.get("id"))
if track:
for key, value in track_data.items():
setattr(track, key, value)
await logger.info(f"Updated track: {track.title}")
else:
track = models.Track(**track_data)
session.add(track)
await logger.info(f"Inserted new track: {track.title}")
Args:
track_data: The track data to write to the database.
await session.commit()
except SQLAlchemyError as e:
await logger.error(f"Database error in insert_or_update_track: {e}")
await session.rollback()
except Exception as e:
await logger.error(f"Unexpected error in insert_or_update_track: {e}")
await session.rollback()
Keyword Args:
session: The database session.
""" # noqa: D417
async with session.begin():
genre = await session.get(models.Genre, track_data.get("genre_id"))
if not genre:
genre = models.Genre(name=track_data["genre_name"])
session.add(genre)
await session.flush() # Ensures 'genre' is persisted and has an 'id'
track = await session.get(models.Track, track_data.get("id"))
if track:
for key, value in track_data.items():
setattr(track, key, value)
await logger.info(f"Updated track: {track.title}")
else:
track = models.Track(**track_data)
session.add(track)
await logger.info(f"Inserted new track: {track.title}")
await session.commit()
async def cleanup_unused_genres(session: AsyncSession) -> None:
"""Remove genres that are no longer used by any tracks."""
try:
async with session.begin():
stmt = (
select(models.Genre)
.outerjoin(models.Track)
.filter(models.Track.id is None)
)
result = await session.execute(stmt)
unused_genres = result.scalars().all()
@sessionize
async def cleanup_unused_genres(
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> None:
"""Remove genres that are no longer used by any tracks.
for genre in unused_genres:
await session.delete(genre)
Keyword Args:
session: The database session.
"""
async with session.begin():
stmt = (
sqlalchemy.select(models.Genre)
.outerjoin(models.Track)
.filter(models.Track.id is None)
)
result = await session.execute(stmt)
unused_genres = result.scalars().all()
await session.commit()
await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
except SQLAlchemyError as e:
await logger.error(f"Database error in cleanup_unused_genres: {e}")
await session.rollback()
except Exception as e:
await logger.error(f"Unexpected error in cleanup_unused_genres: {e}")
await session.rollback()
for genre in unused_genres:
await session.delete(genre)
await session.commit()
await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
@sessionize
async def set_timestamp(
session: AsyncSession, event: enums.Event, operation: enums.TimedEvent
) -> datetime:
event: enums.Event,
operation: enums.TimedEvent,
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> datetime.datetime:
"""Mark the start or end timestamps in the statistics model.
Args:
session (AsyncSession): The database session.
event (enums.Event): Identifies what event is being affected.
operation (enums.EventTime): Identifies if it is the start or end of the event.
event: Identifies what event is being affected.
operation: Identifies if it is the start or end of the event.
Keyword Args:
session: The database session.
Returns:
datetime: The datetime of the last update.
datetime.datetime: The datetime.datetime of the last update.
"""
try:
async with session.begin():
stats = await session.get(models.Statistics, 1)
stats_obj = stats.scalars().first()
async with session.begin():
stats = await session.get(models.Statistics, 1)
stats_obj = stats.scalars().first()
match [event, operation]:
case [enums.Event.UPDATER, enums.EventTime.START]:
stats_obj.start_update = datetime.now()
log_message = "Updater start timestamp updated."
case [enums.Event.UPDATER, enums.EventTime.END]:
stats_obj.end_update = datetime.now()
log_message = "Updater end timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.START]:
stats_obj.start_playlist_gen = datetime.now()
log_message = "Generator start timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.END]:
stats_obj.end_playlist_gen = datetime.now()
log_message = "Generator end timestamp updated."
case _:
raise ValueError(
f"Invalid event {event} or operation {operation} specified."
)
match [event, operation]:
case [enums.Event.UPDATER, enums.EventTime.START]:
stats_obj.start_update = datetime.datetime.now()
log_message = "Updater start timestamp updated."
case [enums.Event.UPDATER, enums.EventTime.END]:
stats_obj.end_update = datetime.datetime.now()
log_message = "Updater end timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.START]:
stats_obj.start_playlist_gen = datetime.datetime.now()
log_message = "Generator start timestamp updated."
case [enums.Event.GENERATOR, enums.EventTime.END]:
stats_obj.end_playlist_gen = datetime.datetime.now()
log_message = "Generator end timestamp updated."
case _:
raise ValueError(
f"Invalid event {event} or operation {operation} specified."
)
match event:
case enums.Event.UPDATER:
last_updated = stats_obj.end_update
case enums.Event.GENERATOR:
last_updated = stats_obj.end_playlist_gen
case _:
raise ValueError(f"Invalid event {event} specified.")
match event:
case enums.Event.UPDATER:
last_updated = stats_obj.end_update
case enums.Event.GENERATOR:
last_updated = stats_obj.end_playlist_gen
case _:
raise ValueError(f"Invalid event {event} specified.")
await session.commit()
await logger.info(log_message)
await session.commit()
await logger.info(log_message)
return last_updated
except SQLAlchemyError as e:
await logger.error(f"Database error during updater timestamp update: {e}")
await session.rollback()
except Exception as e:
await logger.error(f"Unexpected error during updater timestamp update: {e}")
await session.rollback()
return last_updated
async def remove_holiday_by_id(session: AsyncSession, holiday_id: int) -> None:
@sessionize
async def remove_holiday_by_id(
holiday_id: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> None:
"""Remove a holiday from the database by ID.
Captures and logs database-specific errors and general exceptions.
"""
try:
holiday = await session.get(models.Holiday, holiday_id)
if holiday:
await session.delete(holiday)
await session.commit()
else:
await logger.warning(f"No holiday found with ID: {holiday_id}")
except SQLAlchemyError as e:
await logger.error(f"SQLAlchemyError while removing holiday: {e}")
await session.rollback()
except Exception as e:
await logger.error(f"Unexpected error while removing holiday: {e}")
await session.rollback()
Args:
holiday_id: The holiday ID to remove.
Keyword Args:
session: The database session.
""" # noqa: D417
holiday = await session.get(models.Holiday, holiday_id)
if holiday:
await session.delete(holiday)
await session.commit()
else:
await logger.warning(f"No holiday found with ID: {holiday_id}")
@sessionize
async def update_statistics(
session: AsyncSession,
new_average_duration: float,
new_podcast_length: float,
new_max_playtime: float,
@@ -201,228 +380,219 @@ async def update_statistics(
new_max_regular_general_base: int,
new_max_holiday_favorite_base: int,
new_max_holiday_general_base: int,
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
"""Update the statistics model.
Args:
session (AsyncSession): The database session.
new_average_duration (float): The new average duration of a track by geometric mean
new_average_duration: The new average duration of a track by geometric mean
of all track durations as a float.
new_podcast_length (float): The new total podcast length of all current episodes as a float.
new_max_playtime (float): The new max playtime as a float.
new_category_count (int): The new count of categories as an int.
new_active_holidays (int): The new active holidays as an int.
new_max_tracks (int): The new max tracks per day as an int.
new_max_regular (int): The new max regular tracks as an int.
new_max_holiday (int): The new max holiday tracks as an int.
new_max_regular_favorite (int): The new max regular favorite tracks as an int.
new_max_regular_general (int): The new max regular general tracks as an int.
new_max_holiday_favorite (int): The new max holiday favorite tracks as an int.
new_max_holiday_general (int): The new max holiday general tracks as an int.
new_max_regular_favorite_base (int): The new max regular favorite base unit as an int.
new_max_regular_general_base (int): The new max regular general base unit as an int.
new_max_holiday_favorite_base (int): The new max holiday favorite base unit as an int.
new_max_holiday_general_base (int): The new max holiday general base unit as an int.
"""
try:
async with session.begin():
# Assuming there's only one statistics record, or you might need to handle this differently
stats = await session.get(models.Statistics, 1)
if stats:
stats.average_track_length = new_average_duration
stats.total_podcast_length = new_podcast_length
stats.max_playtime = new_max_playtime
stats.category_count = new_category_count
stats.active_holiday_count = new_active_holidays
stats.max_tracks_per_day = new_max_tracks
stats.max_regular_tracks = new_max_regular
stats.max_holiday_tracks = new_max_holiday
stats.max_regular_favorite_tracks = new_max_regular_favorite
stats.max_regular_general_tracks = new_max_regular_general
stats.max_holiday_favorite_tracks = new_max_holiday_favorite
stats.max_holiday_general_tracks = new_max_holiday_general
stats.max_regular_favorite_base = new_max_regular_favorite_base
stats.max_regular_general_base = new_max_regular_general_base
stats.max_holiday_favorite_base = new_max_holiday_favorite_base
stats.max_holiday_general_base = new_max_holiday_general_base
await session.commit()
else:
# If no statistics entry exists, create one
new_stats = models.Statistics(
average_track_length=new_average_duration,
total_podcast_length=new_podcast_length,
max_playtime=new_max_playtime,
active_holiday_count=new_active_holidays,
max_tracks_per_day=new_max_tracks,
max_regular_tracks=new_max_regular,
max_holiday_tracks=new_max_holiday,
max_regular_favorite_tracks=new_max_regular_favorite,
max_regular_general_tracks=new_max_regular_general,
max_holiday_favorite_tracks=new_max_holiday_favorite,
max_holiday_general_tracks=new_max_holiday_general,
max_regular_favorite_base=new_max_regular_favorite_base,
max_regular_general_base=new_max_regular_general_base,
max_holiday_favorite_base=new_max_holiday_favorite_base,
max_holiday_general_base=new_max_holiday_general_base,
)
session.add(new_stats)
await session.commit()
except Exception as e:
print(f"Failed to update the statistics model: {e}")
await session.rollback()
new_podcast_length: The new total podcast length of all current episodes as a float.
new_max_playtime: The new max playtime as a float.
new_category_count: The new count of categories as an int.
new_active_holidays: The new active holidays as an int.
new_max_tracks: The new max tracks per day as an int.
new_max_regular: The new max regular tracks as an int.
new_max_holiday: The new max holiday tracks as an int.
new_max_regular_favorite: The new max regular favorite tracks as an int.
new_max_regular_general: The new max regular general tracks as an int.
new_max_holiday_favorite: The new max holiday favorite tracks as an int.
new_max_holiday_general: The new max holiday general tracks as an int.
new_max_regular_favorite_base: The new max regular favorite base unit as an int.
new_max_regular_general_base: The new max regular general base unit as an int.
new_max_holiday_favorite_base: The new max holiday favorite base unit as an int.
new_max_holiday_general_base: The new max holiday general base unit as an int.
Keyword Args:
session: The database session.
""" # noqa: D417
async with session.begin():
# Assuming there's only one statistics record, or you might need to handle this differently
stats = await session.get(models.Statistics, 1)
if stats:
stats.average_track_length = new_average_duration
stats.total_podcast_length = new_podcast_length
stats.max_playtime = new_max_playtime
stats.category_count = new_category_count
stats.active_holiday_count = new_active_holidays
stats.max_tracks_per_day = new_max_tracks
stats.max_regular_tracks = new_max_regular
stats.max_holiday_tracks = new_max_holiday
stats.max_regular_favorite_tracks = new_max_regular_favorite
stats.max_regular_general_tracks = new_max_regular_general
stats.max_holiday_favorite_tracks = new_max_holiday_favorite
stats.max_holiday_general_tracks = new_max_holiday_general
stats.max_regular_favorite_base = new_max_regular_favorite_base
stats.max_regular_general_base = new_max_regular_general_base
stats.max_holiday_favorite_base = new_max_holiday_favorite_base
stats.max_holiday_general_base = new_max_holiday_general_base
await session.commit()
else:
# If no statistics entry exists, create one
new_stats = models.Statistics(
average_track_length=new_average_duration,
total_podcast_length=new_podcast_length,
max_playtime=new_max_playtime,
active_holiday_count=new_active_holidays,
max_tracks_per_day=new_max_tracks,
max_regular_tracks=new_max_regular,
max_holiday_tracks=new_max_holiday,
max_regular_favorite_tracks=new_max_regular_favorite,
max_regular_general_tracks=new_max_regular_general,
max_holiday_favorite_tracks=new_max_holiday_favorite,
max_holiday_general_tracks=new_max_holiday_general,
max_regular_favorite_base=new_max_regular_favorite_base,
max_regular_general_base=new_max_regular_general_base,
max_holiday_favorite_base=new_max_holiday_favorite_base,
max_holiday_general_base=new_max_holiday_general_base,
)
session.add(new_stats)
await session.commit()
@sessionize
async def fetch_tracks_for_sublist(
session: AsyncSession,
is_category: bool,
type_id: int,
is_favorite: bool,
sublist: enums.Sublist,
limit: int,
playlist_entries: set[int],
) -> AsyncGenerator[models.Track, None]:
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> collections.abc.AsyncGenerator[models.Track, None]:
"""Fetch tracks for a given sublist type within a category.
Args:
session (AsyncSession): The database session.
is_category (bool): True if a category, false if a holiday.
type_id (int): Category or holiday ID for track filtering.
is_favorite (bool): Flag indicating if only favorite tracks should be fetched.
sublist (enums.Sublist): Type of sublist to fetch.
limit (int): Number of tracks to fetch.
playlist_entries (set[int]): The set of track IDs already in the playlist.
is_category: True if a category, false if a holiday.
type_id: Category or holiday ID for track filtering.
is_favorite: Flag indicating if only favorite tracks should be fetched.
sublist: Type of sublist to fetch.
limit: Number of tracks to fetch.
playlist_entries: The set of track IDs already in the playlist.
Keyword Args:
session: The database session.
Returns:
AsyncGenerator[models.Track, None]: Generates the tracks fitting the criteria.
collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the
criteria.
"""
try:
query = select(models.Track)
if is_category:
query = query.where(models.Track.category_id == type_id)
else:
query = query.join(models.Track.holidays).where(
models.Holiday.id == type_id
)
query = sqlalchemy.select(models.Track)
if is_category:
query = query.where(models.Track.category_id == type_id)
else:
query = query.join(models.Track.holidays).where(models.Holiday.id == type_id)
query = query.filter(models.Track.id.not_in(playlist_entries))
query = query.filter(models.Track.id.not_in(playlist_entries))
if is_favorite:
query = query.filter(models.Track.rating == 5)
if is_favorite:
query = query.filter(models.Track.rating == 5)
match sublist:
case enums.Sublist.LEAST_RECENTLY_PLAYED:
query = query.order_by(models.Track.last_played.asc())
case enums.Sublist.LEAST_OFTEN_PLAYED:
query = query.order_by(models.Track.play_count.asc())
case enums.Sublist.LEAST_RECENTLY_ADDED:
query = query.order_by(models.Track.date_added.asc())
case enums.Sublist.RANDOM:
query = query.order_by(func.random())
case _:
raise ValueError(f"Unknown sublist type: {sublist}")
match sublist:
case enums.Sublist.LEAST_RECENTLY_PLAYED:
query = query.order_by(models.Track.last_played.asc())
case enums.Sublist.LEAST_OFTEN_PLAYED:
query = query.order_by(models.Track.play_count.asc())
case enums.Sublist.LEAST_RECENTLY_ADDED:
query = query.order_by(models.Track.date_added.asc())
case enums.Sublist.RANDOM:
query = query.order_by(sqlalchemy.func.random())
case _:
raise ValueError(f"Unknown sublist type: {sublist}")
query = query.limit(limit)
result = await session.execute(query)
for track in result.scalars():
yield track
except SQLAlchemyError as e:
await logger.error(f"Database error occurred while fetching tracks: {e}")
raise
except Exception as e:
await logger.error(f"Unexpected error occurred while fetching tracks: {e}")
raise
query = query.limit(limit)
result = await session.execute(query)
for track in result.scalars():
yield track
@sessionize
async def insert_into_playlist( # noqa: C901
session: AsyncSession,
is_category: bool,
type_id: int,
track_id: int,
sublist: enums.Sublist,
is_favorite: bool,
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
"""Insert a track into the playlist.
Args:
session (AsyncSession): The database session.
is_category (bool): True is a category, false is a holiday.
type_id (int): Category or holiday ID associated with the playlist.
track_id (int): Track ID to be added to the playlist.
sublist (enums.Sublist): Sublist type under which the track is added.
is_favorite (bool): Indicates if the track is added as a favorite.
"""
try:
if is_category:
playlist_track = models.PlaylistTrack(
category_id=type_id,
track_id=track_id,
sublist_type=sublist,
is_favorite=is_favorite,
)
else:
playlist_track = models.PlaylistTrack(
holiday_id=type_id,
track_id=track_id,
sublist_type=sublist,
is_favorite=is_favorite,
)
session.add(playlist_track)
await session.commit()
except SQLAlchemyError as e:
await logger.error(
f"Database error occurred while inserting into playlist: {e}"
is_category: True is a category, false is a holiday.
type_id: Category or holiday ID associated with the playlist.
track_id: Track ID to be added to the playlist.
sublist: Sublist type under which the track is added.
is_favorite: Indicates if the track is added as a favorite.
Keyword Args:
session: The database session.
""" # noqa: D417
if is_category:
playlist_track = models.PlaylistTrack(
category_id=type_id,
track_id=track_id,
sublist_type=sublist,
is_favorite=is_favorite,
)
await session.rollback()
raise
except Exception as e:
await logger.error(
f"Unexpected error occurred while inserting into playlist: {e}"
else:
playlist_track = models.PlaylistTrack(
holiday_id=type_id,
track_id=track_id,
sublist_type=sublist,
is_favorite=is_favorite,
)
await session.rollback()
raise
session.add(playlist_track)
await session.commit()
@sessionize
async def gen_all_category_ids(
session: AsyncSession,
) -> AsyncGenerator[int, None, None]:
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> collections.abc.AsyncGenerator[int, None, None]:
"""Generate the current category ids from the database.
Args:
session (AsyncSession): The database session.
Keyword Args:
session: The database session.
Yields:
int: The category id.
"""
query = select(models.Category)
query = sqlalchemy.select(models.Category)
result = session.execute(query)
for row in result.scalars():
yield row.id
@sessionize
async def gen_active_holiday_ids(
session: AsyncSession,
) -> AsyncGenerator[int, None, None]:
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> collections.abc.AsyncGenerator[int, None, None]:
"""Generate the current active holiday ids from the database.
Args:
session (AsyncSession): The database session.
Keyword Args:
session: The database session.
Yields:
int: The holiday id.
"""
query = select(models.Holiday).where(models.Holiday.is_active is True)
query = sqlalchemy.select(models.Holiday).where(models.Holiday.is_active is True)
result = session.execute(query)
for row in result.scalars():
yield row.id
@sessionize
async def get_existing_playlist_track_info(
session: AsyncSession,
*,
session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]:
"""Get the existing playlist track information.
Args:
session (AsyncSession): The database session.
Keyword Args:
session: The database session.
Returns:
tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: Two pieces:
@@ -431,7 +601,9 @@ async def get_existing_playlist_track_info(
and sublist; the value is the count of that unique key combination.
"""
# TODO: Set up user-specific query here.
query = select(models.PlaylistTrack).where(models.PlaylistTrack.episode_id is None)
query = sqlalchemy.select(models.PlaylistTrack).where(
models.PlaylistTrack.episode_id is None
)
result = session.execute(query)
ids = set()
data = []
@@ -444,15 +616,18 @@ async def get_existing_playlist_track_info(
is_favorite = row.is_favorite
sublist = row.sublist
data.append((is_category, type_id, is_favorite, sublist))
data_counts = Counter(data)
data_counts = collections.Counter(data)
return ids, data_counts
async def get_statistics(session: AsyncSession) -> models.Statistics:
@sessionize
async def get_statistics(
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> models.Statistics:
"""Get the statistics object for the database.
Args:
session (AsyncSession): The database session.
Keyword Args:
session: The database session.
Returns:
models.Statistics: The statistics object used for playlist processing.
@@ -462,41 +637,33 @@ async def get_statistics(session: AsyncSession) -> models.Statistics:
return stats_obj
async def remove_track_from_playlist(session: AsyncSession, server_id: str):
@sessionize
async def remove_track_from_playlist(
server_id: str, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
):
"""Directly removes all entries from PlaylistTrack that are associated with a given server_id.
Args:
session (AsyncSession): The current database session.
server_id (str): The unique server ID of the track to remove from the playlist.
server_id: The unique server ID of the track to remove from the playlist.
Keyword Args:
session: The current database session.
Returns:
int: Number of rows affected by the delete operation.
"""
try:
# Delete directly using a join on Track where server_id matches
delete_query = delete(models.PlaylistTrack).where(
models.PlaylistTrack.track_id
== select(models.Track.id)
.where(models.Track.server_id == server_id)
.scalar_subquery()
)
result = await session.execute(delete_query)
await session.commit()
# Delete directly using a join on Track where server_id matches
delete_query = sqlalchemy.delete(models.PlaylistTrack).where(
models.PlaylistTrack.track_id
== sqlalchemy.select(models.Track.id)
.where(models.Track.server_id == server_id)
.scalar_subquery()
)
result = await session.execute(delete_query)
await session.commit()
affected_rows = result.rowcount
logger.info(
f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
)
return affected_rows
except SQLAlchemyError as e:
logger.error(
f"SQLAlchemy error occurred while removing tracks from playlist: {e}"
)
await session.rollback()
raise
except Exception as e:
logger.error(
f"Unexpected error occurred while removing tracks from playlist: {e}"
)
await session.rollback()
raise
affected_rows = result.rowcount
logger.info(
f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
)
return affected_rows

3473
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -58,6 +58,9 @@ myst-parser = {version = ">=0.16.1"}
[tool.poetry.scripts]
plex-playlist = "playlist.__main__:main"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
[tool.coverage.paths]
source = ["src", "*/site-packages"]
tests = ["tests", "*/tests"]