670 lines
24 KiB
Python
670 lines
24 KiB
Python
"""Contains all SQL code here."""
|
|
|
|
import collections
import collections.abc
import datetime
import functools
import inspect
import typing

import aiologger
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext
import sqlalchemy.ext.asyncio
import sqlalchemy.orm

from playlist import enums
from playlist import env
from playlist import models
|
|
|
|
|
|
# Asynchronous logger shared by every helper in this module.
logger = aiologger.Logger.with_default_handlers(name="sql_logger")

# Registry whose metadata is used by init_db/drop_db to create and drop the tables.
mapper_registry = sqlalchemy.orm.registry()

# Async engine for the configured database URL; echo=True is passed through to SQLAlchemy.
engine = sqlalchemy.ext.asyncio.create_async_engine(env.DATABASE_URL, echo=True)

# Session factory bound to the engine; sessions are AsyncSession instances created with
# expire_on_commit=False.
async_session_maker = sqlalchemy.orm.sessionmaker(
    bind=engine,
    class_=sqlalchemy.ext.asyncio.AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
class Sessionizable(typing.Protocol):
    """Protocol describing a callable that takes its database session as a keyword argument."""

    def __call__(
        self,
        *args: typing.Any,
        session: sqlalchemy.ext.asyncio.AsyncSession | None = None,
        **kwargs: typing.Any,
    ) -> collections.abc.AsyncGenerator | collections.abc.Coroutine:
        """Describe the call signature of a sessionizable function.

        Args:
            args: The positional arguments for the function.
            kwargs: The keyword arguments for the function (except for session).

        Keyword Args:
            session: The database session. If set to None, this will be filled with a new
                session.

        Returns:
            collections.abc.AsyncGenerator | collections.abc.Coroutine: The async generator
            or coroutine that would be created by this function.
        """
        ...
|
|
|
|
|
|
async def _validate_session_arg(name: str, kwargs) -> None:
    """Check that ``kwargs`` carries a usable ``session`` entry.

    Args:
        name: Name of the function being validated; only used in the error messages.
        kwargs: The keyword arguments the function was called with.

    Raises:
        UnboundLocalError: If ``kwargs`` has no ``session`` key.
        TypeError: If the ``session`` value is neither None nor an AsyncSession.
    """
    try:
        candidate = kwargs["session"]
    except KeyError as missing:
        problem = f"{name} does not contain a session keyword argument."
        await logger.error(problem)
        raise UnboundLocalError(problem) from missing

    # None is allowed; any other value has to be an AsyncSession.
    if candidate is not None and not isinstance(
        candidate, sqlalchemy.ext.asyncio.AsyncSession
    ):
        problem = f"{name} has an invalid session of type {type(candidate)}."
        await logger.error(problem)
        raise TypeError(problem)
|
|
|
|
|
|
def _coro_db_logger(func: Sessionizable) -> Sessionizable:
|
|
@functools.wraps(func)
|
|
async def wrapper[
|
|
**P
|
|
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
|
try:
|
|
return await func(*args, **kwargs)
|
|
|
|
except sqlalchemy.exc.SQLAlchemyError as e:
|
|
await logger.exception(f"Database error in {func.__name__}: {e}")
|
|
await kwargs["session"].rollback()
|
|
raise
|
|
|
|
except Exception as e:
|
|
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
|
|
await kwargs["session"].rollback()
|
|
raise
|
|
|
|
|
|
def _async_gen_db_logger(func: Sessionizable) -> Sessionizable:
|
|
@functools.wraps(func)
|
|
async def wrapper[
|
|
**P
|
|
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
|
try:
|
|
async for item in func(*args, **kwargs):
|
|
yield item
|
|
|
|
except sqlalchemy.exc.SQLAlchemyError as e:
|
|
await logger.exception(f"Database error in {func.__name__}: {e}")
|
|
await kwargs["session"].rollback()
|
|
raise
|
|
|
|
except Exception as e:
|
|
await logger.exception(f"Unexpected error in {func.__name__}: {e}")
|
|
await kwargs["session"].rollback()
|
|
raise
|
|
|
|
|
|
def sessionize(func: Sessionizable) -> Sessionizable:
|
|
"""Decorator that ensures a database session is available to an async function or generator.
|
|
|
|
This decorator automatically injects a `session` of type sqlalchemy.ext.asyncio.AsyncSession
|
|
into the decorated function if one is not provided. If the `session` keyword argument is
|
|
missing or set to None, a new session is created using the async_session_maker and passed to
|
|
the function. If a session is already provided when the function is called, it uses the
|
|
existing session.
|
|
|
|
The decorator is intended for use with coroutine functions and async generator functions that
|
|
are expected to interact with a database using SQLAlchemy's asynchronous session management.
|
|
|
|
Features:
|
|
- Automatically manages database sessions to ensure efficient and correct usage of database
|
|
connections.
|
|
- Logs any SQLAlchemy-specific exceptions or other exceptions that occur during the execution
|
|
of the decorated function, handling rollbacks if necessary.
|
|
- Ensures that functions without a proper 'session' keyword argument are not decorated, raising
|
|
an UnboundLocalError.
|
|
|
|
Args:
|
|
func (Sessionizable): A coroutine function or an async generator function that accepts a
|
|
'session' keyword argument.
|
|
|
|
Returns:
|
|
Sessionizable: A wrapped version of the input function that manages a database session.
|
|
|
|
Raises:
|
|
TypeError: If `func` is neither a coroutine function nor an async generator function.
|
|
UnboundLocalError: If `func` lacks a 'session' keyword argument.
|
|
|
|
Example:
|
|
@sessionize
|
|
async def fetch_data(session: Optional[AsyncSession] = None):
|
|
# Function body using session
|
|
...
|
|
|
|
Note:
|
|
- The function must include a 'session' parameter in its signature, which can be None by
|
|
default.
|
|
- This decorator is only applicable to functions intended to perform database operations
|
|
asynchronously.
|
|
|
|
See Also:
|
|
- Sessionizable: The Protocol for sessionizable functions.
|
|
- AsyncSession: SQLAlchemy class used for asynchronous session management.
|
|
"""
|
|
ret: Sessionizable
|
|
match func:
|
|
case func if inspect.iscoroutinefunction(func):
|
|
|
|
@functools.wraps(func)
|
|
async def _coro_wrapper[
|
|
**P
|
|
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.Coroutine:
|
|
await _validate_session_arg(func.__name__, kwargs)
|
|
logged_func = _coro_db_logger(func)
|
|
|
|
if kwargs["session"] is None:
|
|
async with async_session_maker as kwargs["session"]:
|
|
return await logged_func(*args, **kwargs)
|
|
else:
|
|
return await logged_func(*args, **kwargs)
|
|
|
|
ret = _coro_wrapper
|
|
|
|
case func if inspect.isasyncgenfunction(func):
|
|
|
|
@functools.wraps(func)
|
|
async def _asyncgen_wrapper[
|
|
**P
|
|
](*args: P.args, **kwargs: P.kwargs) -> collections.abc.AsyncGenerator:
|
|
"""Wrap a sessionized async generator function to inject the session if needed."""
|
|
await _validate_session_arg(func.__name__, kwargs)
|
|
|
|
logged_func = _async_gen_db_logger(func)
|
|
if kwargs["session"] is None:
|
|
async with async_session_maker as kwargs["session"]:
|
|
async for element in logged_func(*args, **kwargs):
|
|
yield element
|
|
else:
|
|
async for element in logged_func(*args, **kwargs):
|
|
yield element
|
|
|
|
ret = _asyncgen_wrapper
|
|
|
|
case _:
|
|
msg = (
|
|
f"Decorated function {func.__name__} is not an async generator or coroutine"
|
|
"function."
|
|
)
|
|
logger.error(msg)
|
|
raise TypeError(msg)
|
|
|
|
return ret
|
|
|
|
|
|
async def init_db() -> None:
    """Create the tables described by ``mapper_registry.metadata``.

    Any exception is logged as an error and not re-raised.
    """
    try:
        async with engine.begin() as connection:
            # create_all is synchronous, so it is dispatched through run_sync.
            await connection.run_sync(mapper_registry.metadata.create_all)
        await logger.info("Database tables created.")
    except Exception as exc:
        await logger.error(f"Failed to create tables: {exc}")
|
|
|
|
|
|
async def drop_db() -> None:
    """Drop the tables described by ``mapper_registry.metadata``.

    Intended for clean-slate testing or teardown. Any exception is logged as an error and
    not re-raised.
    """
    try:
        async with engine.begin() as connection:
            # drop_all is synchronous, so it is dispatched through run_sync.
            await connection.run_sync(mapper_registry.metadata.drop_all)
        await logger.info("Database tables dropped.")
    except Exception as exc:
        await logger.error(f"Failed to drop tables: {exc}")
|
|
|
|
|
|
@sessionize
async def insert_or_update_track(
    track_data: dict[str, typing.Any],
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
    """Write a track to the database, updating it when a row with the same id exists.

    The genre named by ``genre_id`` is looked up first and created from ``genre_name`` when
    the lookup finds nothing.

    Args:
        track_data: The track data to write to the database. ``id`` and ``genre_id`` are
            read with ``dict.get``; ``genre_name`` must be present when the genre is missing.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    async with session.begin():
        genre_row = await session.get(models.Genre, track_data.get("genre_id"))
        if not genre_row:
            genre_row = models.Genre(name=track_data["genre_name"])
            session.add(genre_row)
            # Flush so the new genre is sent to the database before the track is handled.
            # NOTE(review): the new genre's id is never copied into track_data -- confirm
            # whether the track is supposed to be linked to it here.
            await session.flush()

        existing = await session.get(models.Track, track_data.get("id"))
        if not existing:
            created = models.Track(**track_data)
            session.add(created)
            await logger.info(f"Inserted new track: {created.title}")
        else:
            # Copy every supplied value onto the stored track.
            for column, new_value in track_data.items():
                setattr(existing, column, new_value)
            await logger.info(f"Updated track: {existing.title}")

        await session.commit()
|
|
|
|
|
|
@sessionize
async def cleanup_unused_genres(
    *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> None:
    """Remove genres that are no longer used by any tracks.

    Genres are outer-joined to tracks; a genre whose joined track id is NULL has no track
    and is deleted.

    Keyword Args:
        session: The database session.
    """
    async with session.begin():
        stmt = (
            sqlalchemy.select(models.Genre)
            .outerjoin(models.Track)
            # ``is_(None)`` renders ``IS NULL``. A plain ``is None`` is evaluated by Python
            # to False before SQLAlchemy ever sees it, so no genre would match.
            .filter(models.Track.id.is_(None))
        )
        result = await session.execute(stmt)
        unused_genres = result.scalars().all()

        for genre in unused_genres:
            await session.delete(genre)
        # No explicit commit: leaving the ``session.begin()`` block commits the transaction.

    await logger.info(f"Cleaned up {len(unused_genres)} unused genres.")
|
|
|
|
|
|
@sessionize
|
|
async def set_timestamp(
|
|
event: enums.Event,
|
|
operation: enums.TimedEvent,
|
|
*,
|
|
session: sqlalchemy.ext.asyncio.AsyncSession = None,
|
|
) -> datetime.datetime:
|
|
"""Mark the start or end timestamps in the statistics model.
|
|
|
|
Args:
|
|
event: Identifies what event is being affected.
|
|
operation: Identifies if it is the start or end of the event.
|
|
|
|
Keyword Args:
|
|
session: The database session.
|
|
|
|
Returns:
|
|
datetime.datetime: The datetime.datetime of the last update.
|
|
"""
|
|
async with session.begin():
|
|
stats = await session.get(models.Statistics, 1)
|
|
stats_obj = stats.scalars().first()
|
|
|
|
match [event, operation]:
|
|
case [enums.Event.UPDATER, enums.EventTime.START]:
|
|
stats_obj.start_update = datetime.datetime.now()
|
|
log_message = "Updater start timestamp updated."
|
|
case [enums.Event.UPDATER, enums.EventTime.END]:
|
|
stats_obj.end_update = datetime.datetime.now()
|
|
log_message = "Updater end timestamp updated."
|
|
case [enums.Event.GENERATOR, enums.EventTime.START]:
|
|
stats_obj.start_playlist_gen = datetime.datetime.now()
|
|
log_message = "Generator start timestamp updated."
|
|
case [enums.Event.GENERATOR, enums.EventTime.END]:
|
|
stats_obj.end_playlist_gen = datetime.datetime.now()
|
|
log_message = "Generator end timestamp updated."
|
|
case _:
|
|
raise ValueError(
|
|
f"Invalid event {event} or operation {operation} specified."
|
|
)
|
|
|
|
match event:
|
|
case enums.Event.UPDATER:
|
|
last_updated = stats_obj.end_update
|
|
case enums.Event.GENERATOR:
|
|
last_updated = stats_obj.end_playlist_gen
|
|
case _:
|
|
raise ValueError(f"Invalid event {event} specified.")
|
|
|
|
await session.commit()
|
|
await logger.info(log_message)
|
|
|
|
return last_updated
|
|
|
|
|
|
@sessionize
async def remove_holiday_by_id(
    holiday_id: int, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> None:
    """Delete the holiday with the given ID, if it exists.

    When no holiday has that ID, a warning is logged and nothing is deleted.

    Args:
        holiday_id: The holiday ID to remove.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    target = await session.get(models.Holiday, holiday_id)
    if not target:
        await logger.warning(f"No holiday found with ID: {holiday_id}")
        return

    await session.delete(target)
    await session.commit()
|
|
|
|
|
|
@sessionize
async def update_statistics(
    new_average_duration: float,
    new_podcast_length: float,
    new_max_playtime: float,
    new_category_count: int,
    new_active_holidays: int,
    new_max_tracks: int,
    new_max_regular: int,
    new_max_holiday: int,
    new_max_regular_favorite: int,
    new_max_regular_general: int,
    new_max_holiday_favorite: int,
    new_max_holiday_general: int,
    new_max_regular_favorite_base: int,
    new_max_regular_general_base: int,
    new_max_holiday_favorite_base: int,
    new_max_holiday_general_base: int,
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
    """Update the statistics model, creating the record when it does not exist yet.

    Args:
        new_average_duration: The new average duration of a track by geometric mean
            of all track durations as a float.
        new_podcast_length: The new total podcast length of all current episodes as a float.
        new_max_playtime: The new max playtime as a float.
        new_category_count: The new count of categories as an int.
        new_active_holidays: The new active holidays as an int.
        new_max_tracks: The new max tracks per day as an int.
        new_max_regular: The new max regular tracks as an int.
        new_max_holiday: The new max holiday tracks as an int.
        new_max_regular_favorite: The new max regular favorite tracks as an int.
        new_max_regular_general: The new max regular general tracks as an int.
        new_max_holiday_favorite: The new max holiday favorite tracks as an int.
        new_max_holiday_general: The new max holiday general tracks as an int.
        new_max_regular_favorite_base: The new max regular favorite base unit as an int.
        new_max_regular_general_base: The new max regular general base unit as an int.
        new_max_holiday_favorite_base: The new max holiday favorite base unit as an int.
        new_max_holiday_general_base: The new max holiday general base unit as an int.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    # One mapping feeds both the update and the create path so the two cannot drift apart
    # (the create path used to leave out category_count).
    values = {
        "average_track_length": new_average_duration,
        "total_podcast_length": new_podcast_length,
        "max_playtime": new_max_playtime,
        "category_count": new_category_count,
        "active_holiday_count": new_active_holidays,
        "max_tracks_per_day": new_max_tracks,
        "max_regular_tracks": new_max_regular,
        "max_holiday_tracks": new_max_holiday,
        "max_regular_favorite_tracks": new_max_regular_favorite,
        "max_regular_general_tracks": new_max_regular_general,
        "max_holiday_favorite_tracks": new_max_holiday_favorite,
        "max_holiday_general_tracks": new_max_holiday_general,
        "max_regular_favorite_base": new_max_regular_favorite_base,
        "max_regular_general_base": new_max_regular_general_base,
        "max_holiday_favorite_base": new_max_holiday_favorite_base,
        "max_holiday_general_base": new_max_holiday_general_base,
    }

    async with session.begin():
        # Assuming there's only one statistics record, or you might need to handle this
        # differently.
        stats = await session.get(models.Statistics, 1)
        if stats is None:
            # NOTE(review): the new row is not given an explicit id although it is looked
            # up by primary key 1 -- confirm the key is generated as expected.
            session.add(models.Statistics(**values))
        else:
            for field, value in values.items():
                setattr(stats, field, value)
        # No explicit commit: leaving the ``session.begin()`` block commits the transaction.
|
|
|
|
|
|
@sessionize
async def fetch_tracks_for_sublist(
    is_category: bool,
    type_id: int,
    is_favorite: bool,
    sublist: enums.Sublist,
    limit: int,
    playlist_entries: set[int],
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> collections.abc.AsyncGenerator[models.Track, None]:
    """Yield the tracks of one category or holiday that qualify for a sublist.

    Args:
        is_category: True if a category, false if a holiday.
        type_id: Category or holiday ID for track filtering.
        is_favorite: When true, only tracks with a rating of 5 are fetched.
        sublist: Type of sublist to fetch; it selects the ordering of the results.
        limit: Number of tracks to fetch.
        playlist_entries: The set of track IDs already in the playlist; these are excluded.

    Keyword Args:
        session: The database session.

    Yields:
        models.Track: The tracks fitting the criteria.

    Raises:
        ValueError: If ``sublist`` is not one of the handled sublist types.
    """
    stmt = sqlalchemy.select(models.Track)

    # Restrict to the requested owner: a category directly, a holiday through the join.
    if is_category:
        stmt = stmt.where(models.Track.category_id == type_id)
    else:
        stmt = stmt.join(models.Track.holidays).where(models.Holiday.id == type_id)

    stmt = stmt.filter(models.Track.id.not_in(playlist_entries))
    if is_favorite:
        stmt = stmt.filter(models.Track.rating == 5)

    # Pick the ordering that defines the sublist.
    if sublist == enums.Sublist.LEAST_RECENTLY_PLAYED:
        ordering = models.Track.last_played.asc()
    elif sublist == enums.Sublist.LEAST_OFTEN_PLAYED:
        ordering = models.Track.play_count.asc()
    elif sublist == enums.Sublist.LEAST_RECENTLY_ADDED:
        ordering = models.Track.date_added.asc()
    elif sublist == enums.Sublist.RANDOM:
        ordering = sqlalchemy.func.random()
    else:
        raise ValueError(f"Unknown sublist type: {sublist}")

    stmt = stmt.order_by(ordering).limit(limit)
    rows = await session.execute(stmt)
    for found_track in rows.scalars():
        yield found_track
|
|
|
|
|
|
@sessionize
async def insert_into_playlist(  # noqa: C901
    is_category: bool,
    type_id: int,
    track_id: int,
    sublist: enums.Sublist,
    is_favorite: bool,
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> None:
    """Add one track to the playlist and commit.

    Args:
        is_category: True is a category, false is a holiday.
        type_id: Category or holiday ID associated with the playlist.
        track_id: Track ID to be added to the playlist.
        sublist: Sublist type under which the track is added.
        is_favorite: Indicates if the track is added as a favorite.

    Keyword Args:
        session: The database session.
    """  # noqa: D417
    # The owner id is stored in category_id or holiday_id depending on is_category.
    owner_column = "category_id" if is_category else "holiday_id"
    entry = models.PlaylistTrack(
        **{owner_column: type_id},
        track_id=track_id,
        sublist_type=sublist,
        is_favorite=is_favorite,
    )
    session.add(entry)
    await session.commit()
|
|
|
|
|
|
@sessionize
async def gen_all_category_ids(
    *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> collections.abc.AsyncGenerator[int, None]:
    """Generate the current category ids from the database.

    Keyword Args:
        session: The database session.

    Yields:
        int: The category id.
    """
    query = sqlalchemy.select(models.Category)
    # ``execute`` is a coroutine on AsyncSession; it must be awaited to obtain the result.
    result = await session.execute(query)
    for row in result.scalars():
        yield row.id
|
|
|
|
|
|
@sessionize
async def gen_active_holiday_ids(
    *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> collections.abc.AsyncGenerator[int, None]:
    """Generate the current active holiday ids from the database.

    Keyword Args:
        session: The database session.

    Yields:
        int: The holiday id.
    """
    # ``is_(True)`` builds a SQL comparison. A plain ``is True`` is evaluated by Python to
    # False before SQLAlchemy ever sees it, so no holiday would match.
    query = sqlalchemy.select(models.Holiday).where(models.Holiday.is_active.is_(True))
    # ``execute`` is a coroutine on AsyncSession; it must be awaited to obtain the result.
    result = await session.execute(query)
    for row in result.scalars():
        yield row.id
|
|
|
|
|
|
@sessionize
async def get_existing_playlist_track_info(
    *,
    session: sqlalchemy.ext.asyncio.AsyncSession = None,
) -> tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]:
    """Get the existing playlist track information.

    Only playlist entries without an episode id are considered.

    Keyword Args:
        session: The database session.

    Returns:
        tuple[set[int], dict[tuple[bool, int, bool, enums.Sublist], int]]: Two pieces:
            First is a set of the ids of all of the tracks in the playlist. Second is a
            dict (a ``collections.Counter``) with the key being the combination of
            is_category, type_id, is_favorite, and sublist; the value is the count of that
            unique key combination.
    """
    # TODO: Set up user-specific query here.
    # ``is_(None)`` renders ``IS NULL``. A plain ``is None`` is evaluated by Python to False
    # before SQLAlchemy ever sees it, so no entry would match.
    query = sqlalchemy.select(models.PlaylistTrack).where(
        models.PlaylistTrack.episode_id.is_(None)
    )
    # ``execute`` is a coroutine on AsyncSession; it must be awaited to obtain the result.
    result = await session.execute(query)

    ids: set[int] = set()
    data_counts: collections.Counter = collections.Counter()
    for row in result.scalars():
        ids.add(row.track_id)
        is_category = row.category_id is not None
        type_id = row.category_id if is_category else row.holiday_id
        # insert_into_playlist stores the sublist under ``sublist_type``, so read the same
        # attribute here. NOTE(review): confirm against the PlaylistTrack model.
        data_counts[(is_category, type_id, row.is_favorite, row.sublist_type)] += 1
    return ids, data_counts
|
|
|
|
|
|
@sessionize
async def get_statistics(
    *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> models.Statistics | None:
    """Get the statistics object for the database.

    Keyword Args:
        session: The database session.

    Returns:
        models.Statistics | None: The statistics object with primary key 1 used for playlist
        processing, or None when that record does not exist.
    """
    # ``session.get`` already returns the mapped object (or None); there is no result set
    # to unwrap with ``scalars()``.
    return await session.get(models.Statistics, 1)
|
|
|
|
|
|
@sessionize
async def remove_track_from_playlist(
    server_id: str, *, session: sqlalchemy.ext.asyncio.AsyncSession = None
) -> int:
    """Directly removes all entries from PlaylistTrack that are associated with a given server_id.

    Args:
        server_id: The unique server ID of the track to remove from the playlist.

    Keyword Args:
        session: The current database session.

    Returns:
        int: Number of rows affected by the delete operation.
    """
    # Delete the playlist entries whose track_id equals the id of the track with this
    # server_id, resolved through a scalar subquery.
    delete_query = sqlalchemy.delete(models.PlaylistTrack).where(
        models.PlaylistTrack.track_id
        == sqlalchemy.select(models.Track.id)
        .where(models.Track.server_id == server_id)
        .scalar_subquery()
    )
    result = await session.execute(delete_query)
    await session.commit()

    affected_rows = result.rowcount
    # The logger is asynchronous; like every other log call in this module it is awaited.
    await logger.info(
        f"Removed {affected_rows} entries from PlaylistTrack for server_id {server_id}."
    )
    return affected_rows
|