Merge pull request #1382 from GSA/notify-api-1325

upgrade test queries to sqlalchemy 2.0
This commit is contained in:
Kenneth Kehl
2024-11-13 08:35:20 -08:00
committed by GitHub
30 changed files with 542 additions and 286 deletions

View File

@@ -12,7 +12,7 @@ from click_datetime import Datetime as click_dt
from faker import Faker
from flask import current_app, json
from notifications_python_client.authentication import create_jwt_token
from sqlalchemy import and_, text
from sqlalchemy import and_, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
@@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix):
if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]:
current_app.logger.error("Can only be run in development")
return
users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all()
stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%"))
users = db.session.execute(stmt).scalars().all()
for usr in users:
# Make sure the full email includes a uuid in it
# Just in case someone decides to use a similar email address.
@@ -338,9 +338,10 @@ def populate_organizations_from_file(file_name):
email_branding = None
email_branding_column = columns[5].strip()
if len(email_branding_column) > 0:
email_branding = EmailBranding.query.filter(
stmt = select(EmailBranding).where(
EmailBranding.name == email_branding_column
).one()
)
email_branding = db.session.execute(stmt).scalars().one()
data = {
"name": columns[0],
"active": True,
@@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name):
@notify_command(name="associate-services-to-organizations")
def associate_services_to_organizations():
services = Service.get_history_model().query.filter_by(version=1).all()
stmt = select(Service.get_history_model()).where(
Service.get_history_model().version == 1
)
services = db.session.execute(stmt).scalars().all()
for s in services:
created_by_user = User.query.filter_by(id=s.created_by_id).first()
stmt = select(User).where(User.id == s.created_by_id)
created_by_user = db.session.execute(stmt).scalars().first()
organization = dao_get_organization_by_email_address(
created_by_user.email_address
)
@@ -467,15 +472,16 @@ def populate_go_live(file_name):
@notify_command(name="fix-billable-units")
def fix_billable_units():
query = Notification.query.filter(
stmt = select(Notification).where(
Notification.notification_type == NotificationType.SMS,
Notification.status != NotificationStatus.CREATED,
Notification.sent_at == None, # noqa
Notification.billable_units == 0,
Notification.key_type != KeyType.TEST,
)
all = db.session.execute(stmt).scalars().all()
for notification in query.all():
for notification in all:
template_model = dao_get_template_by_id(
notification.template_id, notification.template_version
)
@@ -490,9 +496,12 @@ def fix_billable_units():
f"Updating notification: {notification.id} with {template.fragment_count} billable_units"
)
Notification.query.filter(Notification.id == notification.id).update(
{"billable_units": template.fragment_count}
stmt = (
update(Notification)
.where(Notification.id == notification.id)
.values({"billable_units": template.fragment_count})
)
db.session.execute(stmt)
db.session.commit()
current_app.logger.info("End fix_billable_units")
@@ -637,8 +646,9 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
This is useful to ensure all services start the new year with the correct annual billing.
"""
if missing_services_only:
active_services = (
Service.query.filter(Service.active)
stmt = (
select(Service)
.where(Service.active)
.outerjoin(
AnnualBilling,
and_(
@@ -647,10 +657,11 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
),
)
.filter(AnnualBilling.id == None) # noqa
.all()
)
active_services = db.session.execute(stmt).scalars().all()
else:
active_services = Service.query.filter(Service.active).all()
stmt = select(Service).where(Service.active)
active_services = db.session.execute(stmt).scalars().all()
previous_year = year - 1
services_with_zero_free_allowance = (
db.session.query(AnnualBilling.service_id)
@@ -750,7 +761,8 @@ def create_user_jwt(token):
def _update_template(id, name, template_type, content, subject):
template = Template.query.filter_by(id=id).first()
stmt = select(Template).where(Template.id == id)
template = db.session.execute(stmt).scalars().first()
if not template:
template = Template(id=id)
template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
@@ -761,7 +773,8 @@ def _update_template(id, name, template_type, content, subject):
template.content = "\n".join(content)
template.subject = subject
history = TemplateHistory.query.filter_by(id=id).first()
stmt = select(TemplateHistory).where(TemplateHistory.id == id)
history = db.session.execute(stmt).scalars().first()
if not history:
history = TemplateHistory(id=id)
history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"

View File

@@ -1,5 +1,7 @@
from datetime import timedelta
from sqlalchemy import select
from app import db
from app.enums import InvitedUserStatus
from app.models import InvitedUser
@@ -12,30 +14,37 @@ def save_invited_user(invited_user):
def get_invited_user_by_service_and_id(service_id, invited_user_id):
return InvitedUser.query.filter(
stmt = select(InvitedUser).where(
InvitedUser.service_id == service_id,
InvitedUser.id == invited_user_id,
).one()
)
return db.session.execute(stmt).scalars().one()
def get_expired_invite_by_service_and_id(service_id, invited_user_id):
return InvitedUser.query.filter(
stmt = select(InvitedUser).where(
InvitedUser.service_id == service_id,
InvitedUser.id == invited_user_id,
InvitedUser.status == InvitedUserStatus.EXPIRED,
).one()
)
return db.session.execute(stmt).scalars().one()
def get_invited_user_by_id(invited_user_id):
return InvitedUser.query.filter(InvitedUser.id == invited_user_id).one()
stmt = select(InvitedUser).where(InvitedUser.id == invited_user_id)
return db.session.execute(stmt).scalars().one()
def get_expired_invited_users_for_service(service_id):
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
# TODO why does this return all invited users?
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
return db.session.execute(stmt).scalars().all()
def get_invited_users_for_service(service_id):
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
return db.session.execute(stmt).scalars().all()
def expire_invitations_created_more_than_two_days_ago():

View File

@@ -192,6 +192,7 @@ def get_notifications_for_job(
):
if page_size is None:
page_size = current_app.config["PAGE_SIZE"]
query = Notification.query.filter_by(service_id=service_id, job_id=job_id)
query = _filter_query(query, filter_dict)
return query.order_by(asc(Notification.job_row_number)).paginate(

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from flask import current_app
from sqlalchemy import desc, func
from sqlalchemy import desc, func, select
from app import db
from app.dao.dao_utils import autocommit
@@ -11,11 +11,12 @@ from app.utils import utc_now
def get_provider_details_by_id(provider_details_id):
return ProviderDetails.query.get(provider_details_id)
return db.session.get(ProviderDetails, provider_details_id)
def get_provider_details_by_identifier(identifier):
return ProviderDetails.query.filter_by(identifier=identifier).one()
stmt = select(ProviderDetails).where(ProviderDetails.identifier == identifier)
return db.session.execute(stmt).scalars().one()
def get_alternative_sms_provider(identifier):
@@ -25,12 +26,14 @@ def get_alternative_sms_provider(identifier):
def dao_get_provider_versions(provider_id):
return (
ProviderDetailsHistory.query.filter_by(id=provider_id)
stmt = (
select(ProviderDetailsHistory)
.where(ProviderDetailsHistory.id == provider_id)
.order_by(desc(ProviderDetailsHistory.version))
.limit(100) # limit results instead of adding pagination
.all()
.limit(100)
)
# limit results instead of adding pagination
return db.session.execute(stmt).scalars().all()
def _get_sms_providers_for_update(time_threshold):
@@ -42,14 +45,15 @@ def _get_sms_providers_for_update(time_threshold):
release the transaction in that case
"""
# get current priority of both providers
q = (
ProviderDetails.query.filter(
stmt = (
select(ProviderDetails)
.where(
ProviderDetails.notification_type == NotificationType.SMS,
ProviderDetails.active,
)
.with_for_update()
.all()
)
q = db.session.execute(stmt).scalars().all()
# if something updated recently, don't update again. If the updated_at is null, treat it as min time
if any(
@@ -72,7 +76,8 @@ def get_provider_details_by_notification_type(
if supports_international:
filters.append(ProviderDetails.supports_international == supports_international)
return ProviderDetails.query.filter(*filters).all()
stmt = select(ProviderDetails).where(*filters)
return db.session.execute(stmt).scalars().all()
@autocommit

View File

@@ -1,3 +1,5 @@
from sqlalchemy import select, update
from app import db
from app.dao.dao_utils import autocommit
from app.models import ServiceDataRetention
@@ -5,29 +7,31 @@ from app.utils import utc_now
def fetch_service_data_retention_by_id(service_id, data_retention_id):
data_retention = ServiceDataRetention.query.filter_by(
service_id=service_id, id=data_retention_id
).first()
return data_retention
stmt = select(ServiceDataRetention).where(
ServiceDataRetention.service_id == service_id,
ServiceDataRetention.id == data_retention_id,
)
return db.session.execute(stmt).scalars().first()
def fetch_service_data_retention(service_id):
data_retention_list = (
ServiceDataRetention.query.filter_by(service_id=service_id)
stmt = (
select(ServiceDataRetention)
.where(ServiceDataRetention.service_id == service_id)
.order_by(
# in the order that models.notification_types are created (email, sms, letter)
ServiceDataRetention.notification_type
)
.all()
)
return data_retention_list
return db.session.execute(stmt).scalars().all()
def fetch_service_data_retention_by_notification_type(service_id, notification_type):
data_retention_list = ServiceDataRetention.query.filter_by(
service_id=service_id, notification_type=notification_type
).first()
return data_retention_list
stmt = select(ServiceDataRetention).where(
ServiceDataRetention.service_id == service_id,
ServiceDataRetention.notification_type == notification_type,
)
return db.session.execute(stmt).scalars().first()
@autocommit
@@ -46,16 +50,22 @@ def insert_service_data_retention(service_id, notification_type, days_of_retenti
def update_service_data_retention(
service_data_retention_id, service_id, days_of_retention
):
updated_count = ServiceDataRetention.query.filter(
ServiceDataRetention.id == service_data_retention_id,
ServiceDataRetention.service_id == service_id,
).update({"days_of_retention": days_of_retention, "updated_at": utc_now()})
return updated_count
stmt = (
update(ServiceDataRetention)
.where(
ServiceDataRetention.id == service_data_retention_id,
ServiceDataRetention.service_id == service_id,
)
.values({"days_of_retention": days_of_retention, "updated_at": utc_now()})
)
result = db.session.execute(stmt)
return result.rowcount
def fetch_service_data_retention_for_all_services_by_notification_type(
notification_type,
):
return ServiceDataRetention.query.filter(
stmt = select(ServiceDataRetention).where(
ServiceDataRetention.notification_type == notification_type
).all()
)
return db.session.execute(stmt).scalars().all()

View File

@@ -1,11 +1,12 @@
from sqlalchemy import delete, select
from app import db
from app.models import ServiceGuestList
def dao_fetch_service_guest_list(service_id):
return ServiceGuestList.query.filter(
ServiceGuestList.service_id == service_id
).all()
stmt = select(ServiceGuestList).where(ServiceGuestList.service_id == service_id)
return db.session.execute(stmt).scalars().all()
def dao_add_and_commit_guest_list_contacts(objs):
@@ -14,6 +15,6 @@ def dao_add_and_commit_guest_list_contacts(objs):
def dao_remove_service_guest_list(service_id):
return ServiceGuestList.query.filter(
ServiceGuestList.service_id == service_id
).delete()
stmt = delete(ServiceGuestList).where(ServiceGuestList.service_id == service_id)
result = db.session.execute(stmt)
return result.rowcount

View File

@@ -1,13 +1,16 @@
from sqlalchemy import select
from app import db
from app.dao.dao_utils import autocommit
from app.models import WebauthnCredential
def dao_get_webauthn_credential_by_user_and_id(user_id, webauthn_credential_id):
return WebauthnCredential.query.filter(
stmt = select(WebauthnCredential).where(
WebauthnCredential.user_id == user_id,
WebauthnCredential.id == webauthn_credential_id,
).one()
)
return db.session.execute(stmt).scalars().one()
@autocommit