Getting sessionize done right.
Signed-off-by: Cliff Hill <xlorep@darkhelm.org>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Command-line interface."""
|
||||
|
||||
import click
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,209 @@
|
||||
"""Contains all SQL code here."""
|
||||
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
import collections
|
||||
import collections.abc
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from aiologger import Logger
|
||||
import aiologger
|
||||
import sqlalchemy
|
||||
import sqlalchemy.exc
|
||||
import sqlalchemy.ext
|
||||
import sqlalchemy.orm
|
||||
from playlist import enums
|
||||
from playlist import env
|
||||
from playlist import models
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.future import delete
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import registry
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
# Initialize asynchronous logger
|
||||
logger = Logger.with_default_handlers(name="sql_logger")
|
||||
logger = aiologger.Logger.with_default_handlers(name="sql_logger")
|
||||
|
||||
engine = create_async_engine(env.DATABASE_URL, echo=True)
|
||||
engine = sqlalchemy.ext.asyncio.create_async_engine(env.DATABASE_URL, echo=True)
|
||||
|
||||
mapper_registry = sqlalchemy.orm.registry()
|
||||
|
||||
async_session_maker = sqlalchemy.orm.sessionmaker(
|
||||
engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
mapper_registry = registry()
|
||||
class Sessionizable(typing.Protocol):
|
||||
"""Protocol defining the signature of a sessionizable function."""
|
||||
|
||||
async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
def __call__(
|
||||
self,
|
||||
*args: tuple[typing.Any, ...],
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
**kwargs: dict[str, typing.Any],
|
||||
) -> (collections.abc.AsyncGenerator, collections.abc.Coroutine):
|
||||
"""The signature for a sessionizable function.
|
||||
|
||||
Args:
|
||||
args: The positional argument for the function.
|
||||
kwargs: The keyword arguments for the function (except for session).
|
||||
|
||||
Keyword Args:
|
||||
session: The database session. If set to None, this will be filled with a new session.
|
||||
|
||||
Returns:
|
||||
(collections.abc.AsyncGenerator, collections.abc.Coroutine): The async generator or
|
||||
coroutine that would be created by this function.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
async def _validate_session_arg(name: str, kwargs) -> None:
|
||||
try:
|
||||
if not (
|
||||
kwargs["session"] is None
|
||||
or isinstance(kwargs["session"], sqlalchemy.ext.asyncio.AsyncSession)
|
||||
):
|
||||
msg = f"{name} has an invalid session of type {type(kwargs['session'])}."
|
||||
await logger.error(msg)
|
||||
raise TypeError(msg)
|
||||
|
||||
except KeyError as e:
|
||||
msg = f"{name} does not contain a session keyword argument."
|
||||
await logger.error(msg)
|
||||
raise UnboundLocalError(msg) from e
|
||||
|
||||
|
||||
def _coro_db_logger(func: Sessionizable) -> Sessionizable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
await logger.exception(f"Database error in {func.__name__}: {e}")
|
||||
await kwargs["session"].rollback()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
|
||||
await kwargs["session"].rollback()
|
||||
raise
|
||||
|
||||
|
||||
def _async_gen_db_logger(func: Sessionizable) -> Sessionizable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
||||
try:
|
||||
async for item in func(*args, **kwargs):
|
||||
yield item
|
||||
|
||||
except sqlalchemy.exc.SQLAlchemyError as e:
|
||||
await logger.exception(f"Database error in {func.__name__}: {e}")
|
||||
await kwargs["session"].rollback()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
|
||||
await kwargs["session"].rollback()
|
||||
raise
|
||||
|
||||
|
||||
def sessionize(func: Sessionizable) -> Sessionizable:
|
||||
"""Decorator that ensures a database session is available to an async function or generator.
|
||||
|
||||
This decorator automatically injects a `session` of type sqlalchemy.ext.asyncio.AsyncSession
|
||||
into the decorated function if one is not provided. If the `session` keyword argument is
|
||||
missing or set to None, a new session is created using the async_session_maker and passed to
|
||||
the function. If a session is already provided when the function is called, it uses the
|
||||
existing session.
|
||||
|
||||
The decorator is intended for use with coroutine functions and async generator functions that
|
||||
are expected to interact with a database using SQLAlchemy's asynchronous session management.
|
||||
|
||||
Features:
|
||||
- Automatically manages database sessions to ensure efficient and correct usage of database
|
||||
connections.
|
||||
- Logs any SQLAlchemy-specific exceptions or other exceptions that occur during the execution
|
||||
of the decorated function, handling rollbacks if necessary.
|
||||
- Ensures that functions without a proper 'session' keyword argument are not decorated, raising
|
||||
an UnboundLocalError.
|
||||
|
||||
Args:
|
||||
func (Sessionizable): A coroutine function or an async generator function that accepts a
|
||||
'session' keyword argument.
|
||||
|
||||
Returns:
|
||||
Sessionizable: A wrapped version of the input function that manages a database session.
|
||||
|
||||
Raises:
|
||||
TypeError: If `func` is neither a coroutine function nor an async generator function.
|
||||
UnboundLocalError: If `func` lacks a 'session' keyword argument.
|
||||
|
||||
Example:
|
||||
@sessionize
|
||||
async def fetch_data(session: Optional[AsyncSession] = None):
|
||||
# Function body using session
|
||||
...
|
||||
|
||||
Note:
|
||||
- The function must include a 'session' parameter in its signature, which can be None by
|
||||
default.
|
||||
- This decorator is only applicable to functions intended to perform database operations
|
||||
asynchronously.
|
||||
|
||||
See Also:
|
||||
- Sessionizable: The Protocol for sessionizable functions.
|
||||
- AsyncSession: SQLAlchemy class used for asynchronous session management.
|
||||
"""
|
||||
ret: Sessionizable
|
||||
match func:
|
||||
case func if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def _coro_wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
||||
await _validate_session_arg(func.__name__, kwargs)
|
||||
logged_func = _coro_db_logger(func)
|
||||
|
||||
if kwargs["session"] is None:
|
||||
async with async_session_maker as kwargs["session"]:
|
||||
return await logged_func(*args, **kwargs)
|
||||
else:
|
||||
return await logged_func(*args, **kwargs)
|
||||
|
||||
ret = _coro_wrapper
|
||||
|
||||
case func if inspect.isasyncgenfunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def _asyncgen_wrapper[
|
||||
**P
|
||||
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
|
||||
"""Wrap a sessionized async generator function to inject the session if needed."""
|
||||
await _validate_session_arg(func.__name__, kwargs)
|
||||
|
||||
logged_func = _async_gen_db_logger(func)
|
||||
if kwargs["session"] is None:
|
||||
async with async_session_maker as kwargs["session"]:
|
||||
async for element in logged_func(*args, **kwargs):
|
||||
yield element
|
||||
else:
|
||||
async for element in logged_func(*args, **kwargs):
|
||||
yield element
|
||||
|
||||
ret = _asyncgen_wrapper
|
||||
|
||||
case _:
|
||||
msg = (
|
||||
f"Decorated function {func.__name__} is not an async generator or coroutine"
|
||||
"function."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise TypeError(msg)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
@@ -50,141 +226,144 @@ async def drop_db() -> None:
|
||||
await logger.error(f"Failed to drop tables: {e}")
|
||||
|
||||
|
||||
@sessionize
|
||||
async def insert_or_update_track(
|
||||
session: AsyncSession, track_data: dict[str, Any]
|
||||
track_data: dict[str, typing.Any],
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> None:
|
||||
"""Insert a new track or update an existing one in the database asynchronously."""
|
||||
try:
|
||||
async with session.begin():
|
||||
genre = await session.get(models.Genre, track_data.get("genre_id"))
|
||||
if not genre:
|
||||
genre = models.Genre(name=track_data["genre_name"])
|
||||
session.add(genre)
|
||||
await session.flush() # Ensures 'genre' is persisted and has an 'id'
|
||||
"""Insert a new track or update an existing one in the database asynchronously.
|
||||
|
||||
track = await session.get(models.Track, track_data.get("id"))
|
||||
if track:
|
||||
for key, value in track_data.items():
|
||||
setattr(track, key, value)
|
||||
await logger.info(f"Updated track: {track.title}")
|
||||
else:
|
||||
track = models.Track(**track_data)
|
||||
session.add(track)
|
||||
await logger.info(f"Inserted new track: {track.title}")
|
||||
Args:
|
||||
track_data: The track data to write to the database.
|
||||
|
||||
await session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(f"Database error in insert_or_update_track: {e}")
|
||||
await session.rollback()
|
||||
except Exception as e:
|
||||
await logger.error(f"Unexpected error in insert_or_update_track: {e}")
|
||||
await session.rollback()
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
""" # noqa: D417
|
||||
async with session.begin():
|
||||
genre = await session.get(models.Genre, track_data.get("genre_id"))
|
||||
if not genre:
|
||||
genre = models.Genre(name=track_data["genre_name"])
|
||||
session.add(genre)
|
||||
await session.flush() # Ensures 'genre' is persisted and has an 'id'
|
||||
|
||||
track = await session.get(models.Track, track_data.get("id"))
|
||||
if track:
|
||||
for key, value in track_data.items():
|
||||
setattr(track, key, value)
|
||||
await logger.info(f"Updated track: {track.title}")
|
||||
else:
|
||||
track = models.Track(**track_data)
|
||||
session.add(track)
|
||||
await logger.info(f"Inserted new track: {track.title}")
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def cleanup_unused_genres(session: AsyncSession) -> None:
|
||||
"""Remove genres that are no longer used by any tracks."""
|
||||
try:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
select(models.Genre)
|
||||
.outerjoin(models.Track)
|
||||
.filter(models.Track.id is None)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
unused_genres = result.scalars().all()
|
||||
@sessionize
|
||||
async def cleanup_unused_genres(
|
||||
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
) -> None:
|
||||
"""Remove genres that are no longer used by any tracks.
|
||||
|
||||
for genre in unused_genres:
|
||||
await session.delete(genre)
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
"""
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
sqlalchemy.select(models.Genre)
|
||||
.outerjoin(models.Track)
|
||||
.filter(models.Track.id is None)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
unused_genres = result.scalars().all()
|
||||
|
||||
await session.commit()
|
||||
await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(f"Database error in cleanup_unused_genres: {e}")
|
||||
await session.rollback()
|
||||
except Exception as e:
|
||||
await logger.error(f"Unexpected error in cleanup_unused_genres: {e}")
|
||||
await session.rollback()
|
||||
for genre in unused_genres:
|
||||
await session.delete(genre)
|
||||
|
||||
await session.commit()
|
||||
await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
|
||||
|
||||
|
||||
@sessionize
|
||||
async def set_timestamp(
|
||||
session: AsyncSession, event: enums.Event, operation: enums.TimedEvent
|
||||
) -> datetime:
|
||||
event: enums.Event,
|
||||
operation: enums.TimedEvent,
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> datetime.datetime:
|
||||
"""Mark the start or end timestamps in the statistics model.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
event (enums.Event): Identifies what event is being affected.
|
||||
operation (enums.EventTime): Identifies if it is the start or end of the event.
|
||||
event: Identifies what event is being affected.
|
||||
operation: Identifies if it is the start or end of the event.
|
||||
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
datetime: The datetime of the last update.
|
||||
datetime.datetime: The datetime.datetime of the last update.
|
||||
"""
|
||||
try:
|
||||
async with session.begin():
|
||||
stats = await session.get(models.Statistics, 1)
|
||||
stats_obj = stats.scalars().first()
|
||||
async with session.begin():
|
||||
stats = await session.get(models.Statistics, 1)
|
||||
stats_obj = stats.scalars().first()
|
||||
|
||||
match [event, operation]:
|
||||
case [enums.Event.UPDATER, enums.EventTime.START]:
|
||||
stats_obj.start_update = datetime.now()
|
||||
log_message = "Updater start timestamp updated."
|
||||
case [enums.Event.UPDATER, enums.EventTime.END]:
|
||||
stats_obj.end_update = datetime.now()
|
||||
log_message = "Updater end timestamp updated."
|
||||
case [enums.Event.GENERATOR, enums.EventTime.START]:
|
||||
stats_obj.start_playlist_gen = datetime.now()
|
||||
log_message = "Generator start timestamp updated."
|
||||
case [enums.Event.GENERATOR, enums.EventTime.END]:
|
||||
stats_obj.end_playlist_gen = datetime.now()
|
||||
log_message = "Generator end timestamp updated."
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid event {event} or operation {operation} specified."
|
||||
)
|
||||
match [event, operation]:
|
||||
case [enums.Event.UPDATER, enums.EventTime.START]:
|
||||
stats_obj.start_update = datetime.datetime.now()
|
||||
log_message = "Updater start timestamp updated."
|
||||
case [enums.Event.UPDATER, enums.EventTime.END]:
|
||||
stats_obj.end_update = datetime.datetime.now()
|
||||
log_message = "Updater end timestamp updated."
|
||||
case [enums.Event.GENERATOR, enums.EventTime.START]:
|
||||
stats_obj.start_playlist_gen = datetime.datetime.now()
|
||||
log_message = "Generator start timestamp updated."
|
||||
case [enums.Event.GENERATOR, enums.EventTime.END]:
|
||||
stats_obj.end_playlist_gen = datetime.datetime.now()
|
||||
log_message = "Generator end timestamp updated."
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid event {event} or operation {operation} specified."
|
||||
)
|
||||
|
||||
match event:
|
||||
case enums.Event.UPDATER:
|
||||
last_updated = stats_obj.end_update
|
||||
case enums.Event.GENERATOR:
|
||||
last_updated = stats_obj.end_playlist_gen
|
||||
case _:
|
||||
raise ValueError(f"Invalid event {event} specified.")
|
||||
match event:
|
||||
case enums.Event.UPDATER:
|
||||
last_updated = stats_obj.end_update
|
||||
case enums.Event.GENERATOR:
|
||||
last_updated = stats_obj.end_playlist_gen
|
||||
case _:
|
||||
raise ValueError(f"Invalid event {event} specified.")
|
||||
|
||||
await session.commit()
|
||||
await logger.info(log_message)
|
||||
await session.commit()
|
||||
await logger.info(log_message)
|
||||
|
||||
return last_updated
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(f"Database error during updater timestamp update: {e}")
|
||||
await session.rollback()
|
||||
except Exception as e:
|
||||
await logger.error(f"Unexpected error during updater timestamp update: {e}")
|
||||
await session.rollback()
|
||||
return last_updated
|
||||
|
||||
|
||||
async def remove_holiday_by_id(session: AsyncSession, holiday_id: int) -> None:
|
||||
@sessionize
|
||||
async def remove_holiday_by_id(
|
||||
holiday_id: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
) -> None:
|
||||
"""Remove a holiday from the database by ID.
|
||||
|
||||
Captures and logs database-specific errors and general exceptions.
|
||||
"""
|
||||
try:
|
||||
holiday = await session.get(models.Holiday, holiday_id)
|
||||
if holiday:
|
||||
await session.delete(holiday)
|
||||
await session.commit()
|
||||
else:
|
||||
await logger.warning(f"No holiday found with ID: {holiday_id}")
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(f"SQLAlchemyError while removing holiday: {e}")
|
||||
await session.rollback()
|
||||
except Exception as e:
|
||||
await logger.error(f"Unexpected error while removing holiday: {e}")
|
||||
await session.rollback()
|
||||
|
||||
Args:
|
||||
holiday_id: The holiday ID to remove.
|
||||
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
""" # noqa: D417
|
||||
holiday = await session.get(models.Holiday, holiday_id)
|
||||
if holiday:
|
||||
await session.delete(holiday)
|
||||
await session.commit()
|
||||
else:
|
||||
await logger.warning(f"No holiday found with ID: {holiday_id}")
|
||||
|
||||
|
||||
@sessionize
|
||||
async def update_statistics(
|
||||
session: AsyncSession,
|
||||
new_average_duration: float,
|
||||
new_podcast_length: float,
|
||||
new_max_playtime: float,
|
||||
@@ -201,228 +380,219 @@ async def update_statistics(
|
||||
new_max_regular_general_base: int,
|
||||
new_max_holiday_favorite_base: int,
|
||||
new_max_holiday_general_base: int,
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> None:
|
||||
"""Update the statistics model.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
new_average_duration (float): The new average duration of a track by geometric mean
|
||||
new_average_duration: The new average duration of a track by geometric mean
|
||||
of all track durations as a float.
|
||||
new_podcast_length (float): The new total podcast length of all current episodes as a float.
|
||||
new_max_playtime (float): The new max playtime as a float.
|
||||
new_category_count (int): The new count of categories as an int.
|
||||
new_active_holidays (int): The new active holidays as an int.
|
||||
new_max_tracks (int): The new max tracks per day as an int.
|
||||
new_max_regular (int): The new max regular tracks as an int.
|
||||
new_max_holiday (int): The new max holiday tracks as an int.
|
||||
new_max_regular_favorite (int): The new max regular favorite tracks as an int.
|
||||
new_max_regular_general (int): The new max regular general tracks as an int.
|
||||
new_max_holiday_favorite (int): The new max holiday favorite tracks as an int.
|
||||
new_max_holiday_general (int): The new max holiday general tracks as an int.
|
||||
new_max_regular_favorite_base (int): The new max regular favorite base unit as an int.
|
||||
new_max_regular_general_base (int): The new max regular general base unit as an int.
|
||||
new_max_holiday_favorite_base (int): The new max holiday favorite base unit as an int.
|
||||
new_max_holiday_general_base (int): The new max holiday general base unit as an int.
|
||||
"""
|
||||
try:
|
||||
async with session.begin():
|
||||
# Assuming there's only one statistics record, or you might need to handle this differently
|
||||
stats = await session.get(models.Statistics, 1)
|
||||
if stats:
|
||||
stats.average_track_length = new_average_duration
|
||||
stats.total_podcast_length = new_podcast_length
|
||||
stats.max_playtime = new_max_playtime
|
||||
stats.category_count = new_category_count
|
||||
stats.active_holiday_count = new_active_holidays
|
||||
stats.max_tracks_per_day = new_max_tracks
|
||||
stats.max_regular_tracks = new_max_regular
|
||||
stats.max_holiday_tracks = new_max_holiday
|
||||
stats.max_regular_favorite_tracks = new_max_regular_favorite
|
||||
stats.max_regular_general_tracks = new_max_regular_general
|
||||
stats.max_holiday_favorite_tracks = new_max_holiday_favorite
|
||||
stats.max_holiday_general_tracks = new_max_holiday_general
|
||||
stats.max_regular_favorite_base = new_max_regular_favorite_base
|
||||
stats.max_regular_general_base = new_max_regular_general_base
|
||||
stats.max_holiday_favorite_base = new_max_holiday_favorite_base
|
||||
stats.max_holiday_general_base = new_max_holiday_general_base
|
||||
await session.commit()
|
||||
else:
|
||||
# If no statistics entry exists, create one
|
||||
new_stats = models.Statistics(
|
||||
average_track_length=new_average_duration,
|
||||
total_podcast_length=new_podcast_length,
|
||||
max_playtime=new_max_playtime,
|
||||
active_holiday_count=new_active_holidays,
|
||||
max_tracks_per_day=new_max_tracks,
|
||||
max_regular_tracks=new_max_regular,
|
||||
max_holiday_tracks=new_max_holiday,
|
||||
max_regular_favorite_tracks=new_max_regular_favorite,
|
||||
max_regular_general_tracks=new_max_regular_general,
|
||||
max_holiday_favorite_tracks=new_max_holiday_favorite,
|
||||
max_holiday_general_tracks=new_max_holiday_general,
|
||||
max_regular_favorite_base=new_max_regular_favorite_base,
|
||||
max_regular_general_base=new_max_regular_general_base,
|
||||
max_holiday_favorite_base=new_max_holiday_favorite_base,
|
||||
max_holiday_general_base=new_max_holiday_general_base,
|
||||
)
|
||||
session.add(new_stats)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
print(f"Failed to update the statistics model: {e}")
|
||||
await session.rollback()
|
||||
new_podcast_length: The new total podcast length of all current episodes as a float.
|
||||
new_max_playtime: The new max playtime as a float.
|
||||
new_category_count: The new count of categories as an int.
|
||||
new_active_holidays: The new active holidays as an int.
|
||||
new_max_tracks: The new max tracks per day as an int.
|
||||
new_max_regular: The new max regular tracks as an int.
|
||||
new_max_holiday: The new max holiday tracks as an int.
|
||||
new_max_regular_favorite: The new max regular favorite tracks as an int.
|
||||
new_max_regular_general: The new max regular general tracks as an int.
|
||||
new_max_holiday_favorite: The new max holiday favorite tracks as an int.
|
||||
new_max_holiday_general: The new max holiday general tracks as an int.
|
||||
new_max_regular_favorite_base: The new max regular favorite base unit as an int.
|
||||
new_max_regular_general_base: The new max regular general base unit as an int.
|
||||
new_max_holiday_favorite_base: The new max holiday favorite base unit as an int.
|
||||
new_max_holiday_general_base: The new max holiday general base unit as an int.
|
||||
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
""" # noqa: D417
|
||||
async with session.begin():
|
||||
# Assuming there's only one statistics record, or you might need to handle this differently
|
||||
stats = await session.get(models.Statistics, 1)
|
||||
if stats:
|
||||
stats.average_track_length = new_average_duration
|
||||
stats.total_podcast_length = new_podcast_length
|
||||
stats.max_playtime = new_max_playtime
|
||||
stats.category_count = new_category_count
|
||||
stats.active_holiday_count = new_active_holidays
|
||||
stats.max_tracks_per_day = new_max_tracks
|
||||
stats.max_regular_tracks = new_max_regular
|
||||
stats.max_holiday_tracks = new_max_holiday
|
||||
stats.max_regular_favorite_tracks = new_max_regular_favorite
|
||||
stats.max_regular_general_tracks = new_max_regular_general
|
||||
stats.max_holiday_favorite_tracks = new_max_holiday_favorite
|
||||
stats.max_holiday_general_tracks = new_max_holiday_general
|
||||
stats.max_regular_favorite_base = new_max_regular_favorite_base
|
||||
stats.max_regular_general_base = new_max_regular_general_base
|
||||
stats.max_holiday_favorite_base = new_max_holiday_favorite_base
|
||||
stats.max_holiday_general_base = new_max_holiday_general_base
|
||||
await session.commit()
|
||||
else:
|
||||
# If no statistics entry exists, create one
|
||||
new_stats = models.Statistics(
|
||||
average_track_length=new_average_duration,
|
||||
total_podcast_length=new_podcast_length,
|
||||
max_playtime=new_max_playtime,
|
||||
active_holiday_count=new_active_holidays,
|
||||
max_tracks_per_day=new_max_tracks,
|
||||
max_regular_tracks=new_max_regular,
|
||||
max_holiday_tracks=new_max_holiday,
|
||||
max_regular_favorite_tracks=new_max_regular_favorite,
|
||||
max_regular_general_tracks=new_max_regular_general,
|
||||
max_holiday_favorite_tracks=new_max_holiday_favorite,
|
||||
max_holiday_general_tracks=new_max_holiday_general,
|
||||
max_regular_favorite_base=new_max_regular_favorite_base,
|
||||
max_regular_general_base=new_max_regular_general_base,
|
||||
max_holiday_favorite_base=new_max_holiday_favorite_base,
|
||||
max_holiday_general_base=new_max_holiday_general_base,
|
||||
)
|
||||
session.add(new_stats)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@sessionize
|
||||
async def fetch_tracks_for_sublist(
|
||||
session: AsyncSession,
|
||||
is_category: bool,
|
||||
type_id: int,
|
||||
is_favorite: bool,
|
||||
sublist: enums.Sublist,
|
||||
limit: int,
|
||||
playlist_entries: set[int],
|
||||
) -> AsyncGenerator[models.Track, None]:
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> collections.abc.AsyncGenerator[models.Track, None]:
|
||||
"""Fetch tracks for a given sublist type within a category.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
is_category (bool): True if a category, false if a holiday.
|
||||
type_id (int): Category or holiday ID for track filtering.
|
||||
is_favorite (bool): Flag indicating if only favorite tracks should be fetched.
|
||||
sublist (enums.Sublist): Type of sublist to fetch.
|
||||
limit (int): Number of tracks to fetch.
|
||||
playlist_entries (set[int]): The set of track IDs already in the playlist.
|
||||
is_category: True if a category, false if a holiday.
|
||||
type_id: Category or holiday ID for track filtering.
|
||||
is_favorite: Flag indicating if only favorite tracks should be fetched.
|
||||
sublist: Type of sublist to fetch.
|
||||
limit: Number of tracks to fetch.
|
||||
playlist_entries: The set of track IDs already in the playlist.
|
||||
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[models.Track, None]: Generates the tracks fitting the criteria.
|
||||
collections.abc.AsyncGenerator[models.Track, None]: Generates the tracks fitting the
|
||||
criteria.
|
||||
"""
|
||||
try:
|
||||
query = select(models.Track)
|
||||
if is_category:
|
||||
query = query.where(models.Track.category_id == type_id)
|
||||
else:
|
||||
query = query.join(models.Track.holidays).where(
|
||||
models.Holiday.id == type_id
|
||||
)
|
||||
query = sqlalchemy.select(models.Track)
|
||||
if is_category:
|
||||
query = query.where(models.Track.category_id == type_id)
|
||||
else:
|
||||
query = query.join(models.Track.holidays).where(models.Holiday.id == type_id)
|
||||
|
||||
query = query.filter(models.Track.id.not_in(playlist_entries))
|
||||
query = query.filter(models.Track.id.not_in(playlist_entries))
|
||||
|
||||
if is_favorite:
|
||||
query = query.filter(models.Track.rating == 5)
|
||||
if is_favorite:
|
||||
query = query.filter(models.Track.rating == 5)
|
||||
|
||||
match sublist:
|
||||
case enums.Sublist.LEAST_RECENTLY_PLAYED:
|
||||
query = query.order_by(models.Track.last_played.asc())
|
||||
case enums.Sublist.LEAST_OFTEN_PLAYED:
|
||||
query = query.order_by(models.Track.play_count.asc())
|
||||
case enums.Sublist.LEAST_RECENTLY_ADDED:
|
||||
query = query.order_by(models.Track.date_added.asc())
|
||||
case enums.Sublist.RANDOM:
|
||||
query = query.order_by(func.random())
|
||||
case _:
|
||||
raise ValueError(f"Unknown sublist type: {sublist}")
|
||||
match sublist:
|
||||
case enums.Sublist.LEAST_RECENTLY_PLAYED:
|
||||
query = query.order_by(models.Track.last_played.asc())
|
||||
case enums.Sublist.LEAST_OFTEN_PLAYED:
|
||||
query = query.order_by(models.Track.play_count.asc())
|
||||
case enums.Sublist.LEAST_RECENTLY_ADDED:
|
||||
query = query.order_by(models.Track.date_added.asc())
|
||||
case enums.Sublist.RANDOM:
|
||||
query = query.order_by(sqlalchemy.func.random())
|
||||
case _:
|
||||
raise ValueError(f"Unknown sublist type: {sublist}")
|
||||
|
||||
query = query.limit(limit)
|
||||
result = await session.execute(query)
|
||||
for track in result.scalars():
|
||||
yield track
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(f"Database error occurred while fetching tracks: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
await logger.error(f"Unexpected error occurred while fetching tracks: {e}")
|
||||
raise
|
||||
query = query.limit(limit)
|
||||
result = await session.execute(query)
|
||||
for track in result.scalars():
|
||||
yield track
|
||||
|
||||
|
||||
@sessionize
|
||||
async def insert_into_playlist( # noqa: C901
|
||||
session: AsyncSession,
|
||||
is_category: bool,
|
||||
type_id: int,
|
||||
track_id: int,
|
||||
sublist: enums.Sublist,
|
||||
is_favorite: bool,
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> None:
|
||||
"""Insert a track into the playlist.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
is_category (bool): True is a category, false is a holiday.
|
||||
type_id (int): Category or holiday ID associated with the playlist.
|
||||
track_id (int): Track ID to be added to the playlist.
|
||||
sublist (enums.Sublist): Sublist type under which the track is added.
|
||||
is_favorite (bool): Indicates if the track is added as a favorite.
|
||||
"""
|
||||
try:
|
||||
if is_category:
|
||||
playlist_track = models.PlaylistTrack(
|
||||
category_id=type_id,
|
||||
track_id=track_id,
|
||||
sublist_type=sublist,
|
||||
is_favorite=is_favorite,
|
||||
)
|
||||
else:
|
||||
playlist_track = models.PlaylistTrack(
|
||||
holiday_id=type_id,
|
||||
track_id=track_id,
|
||||
sublist_type=sublist,
|
||||
is_favorite=is_favorite,
|
||||
)
|
||||
session.add(playlist_track)
|
||||
await session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
await logger.error(
|
||||
f"Database error occurred while inserting into playlist: {e}"
|
||||
is_category: True is a category, false is a holiday.
|
||||
type_id: Category or holiday ID associated with the playlist.
|
||||
track_id: Track ID to be added to the playlist.
|
||||
sublist: Sublist type under which the track is added.
|
||||
is_favorite: Indicates if the track is added as a favorite.
|
||||
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
""" # noqa: D417
|
||||
if is_category:
|
||||
playlist_track = models.PlaylistTrack(
|
||||
category_id=type_id,
|
||||
track_id=track_id,
|
||||
sublist_type=sublist,
|
||||
is_favorite=is_favorite,
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await logger.error(
|
||||
f"Unexpected error occurred while inserting into playlist: {e}"
|
||||
else:
|
||||
playlist_track = models.PlaylistTrack(
|
||||
holiday_id=type_id,
|
||||
track_id=track_id,
|
||||
sublist_type=sublist,
|
||||
is_favorite=is_favorite,
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
session.add(playlist_track)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@sessionize
|
||||
async def gen_all_category_ids(
|
||||
session: AsyncSession,
|
||||
) -> AsyncGenerator[int, None, None]:
|
||||
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
) -> collections.abc.AsyncGenerator[int, None, None]:
|
||||
"""Generate the current category ids from the database.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Yields:
|
||||
int: The category id.
|
||||
"""
|
||||
query = select(models.Category)
|
||||
query = sqlalchemy.select(models.Category)
|
||||
result = session.execute(query)
|
||||
for row in result.scalars():
|
||||
yield row.id
|
||||
|
||||
|
||||
@sessionize
|
||||
async def gen_active_holiday_ids(
|
||||
session: AsyncSession,
|
||||
) -> AsyncGenerator[int, None, None]:
|
||||
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
) -> collections.abc.AsyncGenerator[int, None, None]:
|
||||
"""Generate the current active holiday ids from the database.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Yields:
|
||||
int: The holiday id.
|
||||
"""
|
||||
query = select(models.Holiday).where(models.Holiday.is_active is True)
|
||||
query = sqlalchemy.select(models.Holiday).where(models.Holiday.is_active is True)
|
||||
result = session.execute(query)
|
||||
for row in result.scalars():
|
||||
yield row.id
|
||||
|
||||
|
||||
@sessionize
|
||||
async def get_existing_playlist_track_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
||||
) -> tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]:
|
||||
"""Get the existing playlist track information.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: Two pieces:
|
||||
@@ -431,7 +601,9 @@ async def get_existing_playlist_track_info(
|
||||
and sublist; the value is the count of that unique key combination.
|
||||
"""
|
||||
# TODO: Set up user-specific query here.
|
||||
query = select(models.PlaylistTrack).where(models.PlaylistTrack.episode_id is None)
|
||||
query = sqlalchemy.select(models.PlaylistTrack).where(
|
||||
models.PlaylistTrack.episode_id is None
|
||||
)
|
||||
result = session.execute(query)
|
||||
ids = set()
|
||||
data = []
|
||||
@@ -444,15 +616,18 @@ async def get_existing_playlist_track_info(
|
||||
is_favorite = row.is_favorite
|
||||
sublist = row.sublist
|
||||
data.append((is_category, type_id, is_favorite, sublist))
|
||||
data_counts = Counter(data)
|
||||
data_counts = collections.Counter(data)
|
||||
return ids, data_counts
|
||||
|
||||
|
||||
async def get_statistics(session: AsyncSession) -> models.Statistics:
|
||||
@sessionize
|
||||
async def get_statistics(
|
||||
*, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
) -> models.Statistics:
|
||||
"""Get the statistics object for the database.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The database session.
|
||||
Keyword Args:
|
||||
session: The database session.
|
||||
|
||||
Returns:
|
||||
models.Statistics: The statistics object used for playlist processing.
|
||||
@@ -462,41 +637,33 @@ async def get_statistics(session: AsyncSession) -> models.Statistics:
|
||||
return stats_obj
|
||||
|
||||
|
||||
async def remove_track_from_playlist(session: AsyncSession, server_id: str):
|
||||
@sessionize
|
||||
async def remove_track_from_playlist(
|
||||
server_id: str, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
|
||||
):
|
||||
"""Directly removes all entries from PlaylistTrack that are associated with a given server_id.
|
||||
|
||||
Args:
|
||||
session (AsyncSession): The current database session.
|
||||
server_id (str): The unique server ID of the track to remove from the playlist.
|
||||
server_id: The unique server ID of the track to remove from the playlist.
|
||||
|
||||
Keyword Args:
|
||||
session: The current database session.
|
||||
|
||||
Returns:
|
||||
int: Number of rows affected by the delete operation.
|
||||
"""
|
||||
try:
|
||||
# Delete directly using a join on Track where server_id matches
|
||||
delete_query = delete(models.PlaylistTrack).where(
|
||||
models.PlaylistTrack.track_id
|
||||
== select(models.Track.id)
|
||||
.where(models.Track.server_id == server_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
result = await session.execute(delete_query)
|
||||
await session.commit()
|
||||
# Delete directly using a join on Track where server_id matches
|
||||
delete_query = sqlalchemy.delete(models.PlaylistTrack).where(
|
||||
models.PlaylistTrack.track_id
|
||||
== sqlalchemy.select(models.Track.id)
|
||||
.where(models.Track.server_id == server_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
result = await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
affected_rows = result.rowcount
|
||||
logger.info(
|
||||
f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
|
||||
)
|
||||
return affected_rows
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(
|
||||
f"SQLAlchemy error occurred while removing tracks from playlist: {e}"
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error occurred while removing tracks from playlist: {e}"
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
affected_rows = result.rowcount
|
||||
logger.info(
|
||||
f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
|
||||
)
|
||||
return affected_rows
|
||||
|
||||
3473
poetry.lock
generated
3473
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -58,6 +58,9 @@ myst-parser = {version = ">=0.16.1"}
|
||||
[tool.poetry.scripts]
|
||||
plex-playlist = "playlist.__main__:main"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyter = "^1.0.0"
|
||||
|
||||
[tool.coverage.paths]
|
||||
source = ["src", "*/site-packages"]
|
||||
tests = ["tests", "*/tests"]
|
||||
|
||||
Reference in New Issue
Block a user