Files
notifications-api/migrations/versions/0418_user_state_enum.py
Kenneth Kehl efa97a7f79 fix migration
2025-08-28 13:02:00 -07:00

157 lines
3.7 KiB
Python

"""
Revision ID: 0418_user_state_enum
Revises: 0417_change_total_message_limit
Create Date: 2025-08-28 12:34:32.857422
"""
revision = "0418_user_state_enum"
down_revision = "0417_change_total_message_limit"
from contextlib import contextmanager
from enum import Enum
from typing import Iterator, TypedDict
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from app.enums import (
UserState,
)
class EnumValues(TypedDict, total=True):
    """Definition of one PostgreSQL enum type managed by this migration.

    The keys match the keyword parameters of ``enum_create`` and
    ``enum_drop``, so an entry can be splatted straight into either.
    """

    # Enum labels, in the order they are declared on the database type.
    values: list[str]
    # Name of the database-level enum type.
    name: str
# Maps each Python enum *class* (not a member) to the database enum type this
# revision creates for it. The labels are deliberately spelled out instead of
# being derived from ``UserState``: a migration should keep producing the same
# schema even if ``app.enums`` changes later.
_enum_params: dict[type[Enum], EnumValues] = {
    UserState: {
        "values": ["active", "inactive", "pending"],
        "name": "user_states",
    },
}
def enum_create(values: list[str], name: str) -> None:
    """Create the PostgreSQL enum type ``name`` with the labels ``values``.

    The DDL is issued on the migration's connection (``op.get_bind()``).

    Args:
        values: enum labels, in declaration order.
        name: name of the database enum type to create.
    """
    postgresql.ENUM(*values, name=name).create(op.get_bind())
def enum_drop(values: list[str], name: str) -> None:
    """Drop the PostgreSQL enum type ``name``.

    The labels are accepted so that an ``EnumValues`` entry can be splatted
    into this function exactly as it is into ``enum_create``.

    Args:
        values: enum labels the type was declared with.
        name: name of the database enum type to drop.
    """
    postgresql.ENUM(*values, name=name).drop(op.get_bind())
def enum_using(column_name: str, enum: type[Enum]) -> str:
    """Build a ``USING`` expression that casts ``column_name`` to an enum type.

    The column value is cast to ``text`` first, then to the database enum
    type registered for ``enum`` in ``_enum_params``.

    Args:
        column_name: column being converted. It is interpolated verbatim, so
            it must be a trusted identifier and never user input.
        enum: enum class (e.g. ``UserState``) registered in ``_enum_params``.

    Returns:
        The cast expression, e.g. ``"state::text::user_states"``.

    Raises:
        KeyError: if ``enum`` has no entry in ``_enum_params``.
    """
    type_name = _enum_params[enum]["name"]
    return f"{column_name}::text::{type_name}"
def enum_type(enum: type[Enum]) -> sa.Enum:
    """Build the SQLAlchemy ``Enum`` type for the database enum of ``enum``.

    The type is built from the string labels stored in ``_enum_params``,
    not from the Python enum class. This keeps the schema produced by this
    migration fixed even if ``app.enums`` changes.

    Args:
        enum: enum class (e.g. ``UserState``) registered in ``_enum_params``.

    Returns:
        An ``sa.Enum`` carrying the registered labels and type name.

    Raises:
        KeyError: if ``enum`` has no entry in ``_enum_params``.
    """
    params = _enum_params[enum]
    # No ``values_callable``: SQLAlchemy only consults it when ``sa.Enum`` is
    # given an enum *class*. Here it receives plain strings, so it was dead.
    return sa.Enum(*params["values"], name=params["name"])
@contextmanager
def view_handler() -> Iterator[None]:
    """Drop ``notifications_all_time_view`` for the duration of a ``with`` block.

    On entry the view is dropped. After the body completes normally, the view
    is recreated as the UNION of the listed columns from ``notifications``
    and ``notification_history``.

    NOTE(review): there is no try/finally, so if the body raises, the view is
    not recreated here. Presumably the migration's transaction rolls back the
    DROP; confirm.

    NOTE(review): the view reads only ``notifications`` and
    ``notification_history``, not ``users``. Confirm that dropping and
    recreating it is actually required for the ``users.state`` change.
    """
    recreate_sql = """
CREATE VIEW notifications_all_time_view AS
(
SELECT
id,
job_id,
job_row_number,
service_id,
template_id,
template_version,
api_key_id,
key_type,
billable_units,
notification_type,
created_at,
sent_at,
sent_by,
updated_at,
notification_status,
reference,
client_reference,
international,
phone_prefix,
rate_multiplier,
created_by_id,
document_download_count
FROM notifications
) UNION
(
SELECT
id,
job_id,
job_row_number,
service_id,
template_id,
template_version,
api_key_id,
key_type,
billable_units,
notification_type,
created_at,
sent_at,
sent_by,
updated_at,
notification_status,
reference,
client_reference,
international,
phone_prefix,
rate_multiplier,
created_by_id,
document_download_count
FROM notification_history
)
"""
    op.execute("DROP VIEW notifications_all_time_view")
    yield
    op.execute(recreate_sql)
def upgrade():
    """Convert ``users.state`` from VARCHAR(255) to the ``user_states`` enum.

    Runs inside ``view_handler``, so ``notifications_all_time_view`` is
    dropped first and recreated afterwards.
    """
    with view_handler():
        # Create the enum types before any column is cast to them.
        for enum_def in _enum_params.values():
            enum_create(**enum_def)

        state_enum = enum_type(UserState)
        # Existing string values are cast via text to the new enum type.
        cast_expr = enum_using("state", UserState)
        op.alter_column(
            "users",
            "state",
            existing_type=sa.VARCHAR(length=255),
            type_=state_enum,
            existing_nullable=True,
            postgresql_using=cast_expr,
        )
def downgrade():
    """Revert ``users.state`` to VARCHAR(255) and drop the enum types.

    Mirrors ``upgrade``: it runs inside ``view_handler``, converts the column
    first, and only then drops the enum types created by the upgrade.
    """
    with view_handler():
        # The column must stop using the enum before the type can be dropped.
        # No USING clause is given, so PostgreSQL applies its default
        # (assignment-cast) conversion from the enum back to VARCHAR.
        op.alter_column(
            "users",
            "state",
            existing_type=enum_type(UserState),
            type_=sa.VARCHAR(length=255),
            existing_nullable=True,
        )
        for enum_def in _enum_params.values():
            enum_drop(**enum_def)