mirror of
https://github.com/GSA/notifications-api.git
synced 2026-01-30 06:21:50 -05:00
Merge pull request #1382 from GSA/notify-api-1325
upgrade test queries to sqlalchemy 2.0
This commit is contained in:
@@ -12,7 +12,7 @@ from click_datetime import Datetime as click_dt
|
||||
from faker import Faker
|
||||
from flask import current_app, json
|
||||
from notifications_python_client.authentication import create_jwt_token
|
||||
from sqlalchemy import and_, text
|
||||
from sqlalchemy import and_, select, text, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
@@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix):
|
||||
if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]:
|
||||
current_app.logger.error("Can only be run in development")
|
||||
return
|
||||
|
||||
users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all()
|
||||
stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%"))
|
||||
users = db.session.execute(stmt).scalars().all()
|
||||
for usr in users:
|
||||
# Make sure the full email includes a uuid in it
|
||||
# Just in case someone decides to use a similar email address.
|
||||
@@ -338,9 +338,10 @@ def populate_organizations_from_file(file_name):
|
||||
email_branding = None
|
||||
email_branding_column = columns[5].strip()
|
||||
if len(email_branding_column) > 0:
|
||||
email_branding = EmailBranding.query.filter(
|
||||
stmt = select(EmailBranding).where(
|
||||
EmailBranding.name == email_branding_column
|
||||
).one()
|
||||
)
|
||||
email_branding = db.session.execute(stmt).scalars().one()
|
||||
data = {
|
||||
"name": columns[0],
|
||||
"active": True,
|
||||
@@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name):
|
||||
|
||||
@notify_command(name="associate-services-to-organizations")
|
||||
def associate_services_to_organizations():
|
||||
services = Service.get_history_model().query.filter_by(version=1).all()
|
||||
stmt = select(Service.get_history_model()).where(
|
||||
Service.get_history_model().version == 1
|
||||
)
|
||||
services = db.session.execute(stmt).scalars().all()
|
||||
|
||||
for s in services:
|
||||
created_by_user = User.query.filter_by(id=s.created_by_id).first()
|
||||
stmt = select(User).where(User.id == s.created_by_id)
|
||||
created_by_user = db.session.execute(stmt).scalars().first()
|
||||
organization = dao_get_organization_by_email_address(
|
||||
created_by_user.email_address
|
||||
)
|
||||
@@ -467,15 +472,16 @@ def populate_go_live(file_name):
|
||||
|
||||
@notify_command(name="fix-billable-units")
|
||||
def fix_billable_units():
|
||||
query = Notification.query.filter(
|
||||
stmt = select(Notification).where(
|
||||
Notification.notification_type == NotificationType.SMS,
|
||||
Notification.status != NotificationStatus.CREATED,
|
||||
Notification.sent_at == None, # noqa
|
||||
Notification.billable_units == 0,
|
||||
Notification.key_type != KeyType.TEST,
|
||||
)
|
||||
all = db.session.execute(stmt).scalars().all()
|
||||
|
||||
for notification in query.all():
|
||||
for notification in all:
|
||||
template_model = dao_get_template_by_id(
|
||||
notification.template_id, notification.template_version
|
||||
)
|
||||
@@ -490,9 +496,12 @@ def fix_billable_units():
|
||||
f"Updating notification: {notification.id} with {template.fragment_count} billable_units"
|
||||
)
|
||||
|
||||
Notification.query.filter(Notification.id == notification.id).update(
|
||||
{"billable_units": template.fragment_count}
|
||||
stmt = (
|
||||
update(Notification)
|
||||
.where(Notification.id == notification.id)
|
||||
.values({"billable_units": template.fragment_count})
|
||||
)
|
||||
db.session.execute(stmt)
|
||||
db.session.commit()
|
||||
current_app.logger.info("End fix_billable_units")
|
||||
|
||||
@@ -637,8 +646,9 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
|
||||
This is useful to ensure all services start the new year with the correct annual billing.
|
||||
"""
|
||||
if missing_services_only:
|
||||
active_services = (
|
||||
Service.query.filter(Service.active)
|
||||
stmt = (
|
||||
select(Service)
|
||||
.where(Service.active)
|
||||
.outerjoin(
|
||||
AnnualBilling,
|
||||
and_(
|
||||
@@ -647,10 +657,11 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
|
||||
),
|
||||
)
|
||||
.filter(AnnualBilling.id == None) # noqa
|
||||
.all()
|
||||
)
|
||||
active_services = db.session.execute(stmt).scalars().all()
|
||||
else:
|
||||
active_services = Service.query.filter(Service.active).all()
|
||||
stmt = select(Service).where(Service.active)
|
||||
active_services = db.session.execute(stmt).scalars().all()
|
||||
previous_year = year - 1
|
||||
services_with_zero_free_allowance = (
|
||||
db.session.query(AnnualBilling.service_id)
|
||||
@@ -750,7 +761,8 @@ def create_user_jwt(token):
|
||||
|
||||
|
||||
def _update_template(id, name, template_type, content, subject):
|
||||
template = Template.query.filter_by(id=id).first()
|
||||
stmt = select(Template).where(Template.id == id)
|
||||
template = db.session.execute(stmt).scalars().first()
|
||||
if not template:
|
||||
template = Template(id=id)
|
||||
template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
|
||||
@@ -761,7 +773,8 @@ def _update_template(id, name, template_type, content, subject):
|
||||
template.content = "\n".join(content)
|
||||
template.subject = subject
|
||||
|
||||
history = TemplateHistory.query.filter_by(id=id).first()
|
||||
stmt = select(TemplateHistory).where(TemplateHistory.id == id)
|
||||
history = db.session.execute(stmt).scalars().first()
|
||||
if not history:
|
||||
history = TemplateHistory(id=id)
|
||||
history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app import db
|
||||
from app.enums import InvitedUserStatus
|
||||
from app.models import InvitedUser
|
||||
@@ -12,30 +14,37 @@ def save_invited_user(invited_user):
|
||||
|
||||
|
||||
def get_invited_user_by_service_and_id(service_id, invited_user_id):
|
||||
return InvitedUser.query.filter(
|
||||
|
||||
stmt = select(InvitedUser).where(
|
||||
InvitedUser.service_id == service_id,
|
||||
InvitedUser.id == invited_user_id,
|
||||
).one()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def get_expired_invite_by_service_and_id(service_id, invited_user_id):
|
||||
return InvitedUser.query.filter(
|
||||
stmt = select(InvitedUser).where(
|
||||
InvitedUser.service_id == service_id,
|
||||
InvitedUser.id == invited_user_id,
|
||||
InvitedUser.status == InvitedUserStatus.EXPIRED,
|
||||
).one()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def get_invited_user_by_id(invited_user_id):
|
||||
return InvitedUser.query.filter(InvitedUser.id == invited_user_id).one()
|
||||
stmt = select(InvitedUser).where(InvitedUser.id == invited_user_id)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def get_expired_invited_users_for_service(service_id):
|
||||
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
|
||||
# TODO why does this return all invited users?
|
||||
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_invited_users_for_service(service_id):
|
||||
return InvitedUser.query.filter(InvitedUser.service_id == service_id).all()
|
||||
stmt = select(InvitedUser).where(InvitedUser.service_id == service_id)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def expire_invitations_created_more_than_two_days_ago():
|
||||
|
||||
@@ -192,6 +192,7 @@ def get_notifications_for_job(
|
||||
):
|
||||
if page_size is None:
|
||||
page_size = current_app.config["PAGE_SIZE"]
|
||||
|
||||
query = Notification.query.filter_by(service_id=service_id, job_id=job_id)
|
||||
query = _filter_query(query, filter_dict)
|
||||
return query.order_by(asc(Notification.job_row_number)).paginate(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy import desc, func, select
|
||||
|
||||
from app import db
|
||||
from app.dao.dao_utils import autocommit
|
||||
@@ -11,11 +11,12 @@ from app.utils import utc_now
|
||||
|
||||
|
||||
def get_provider_details_by_id(provider_details_id):
|
||||
return ProviderDetails.query.get(provider_details_id)
|
||||
return db.session.get(ProviderDetails, provider_details_id)
|
||||
|
||||
|
||||
def get_provider_details_by_identifier(identifier):
|
||||
return ProviderDetails.query.filter_by(identifier=identifier).one()
|
||||
stmt = select(ProviderDetails).where(ProviderDetails.identifier == identifier)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def get_alternative_sms_provider(identifier):
|
||||
@@ -25,12 +26,14 @@ def get_alternative_sms_provider(identifier):
|
||||
|
||||
|
||||
def dao_get_provider_versions(provider_id):
|
||||
return (
|
||||
ProviderDetailsHistory.query.filter_by(id=provider_id)
|
||||
stmt = (
|
||||
select(ProviderDetailsHistory)
|
||||
.where(ProviderDetailsHistory.id == provider_id)
|
||||
.order_by(desc(ProviderDetailsHistory.version))
|
||||
.limit(100) # limit results instead of adding pagination
|
||||
.all()
|
||||
.limit(100)
|
||||
)
|
||||
# limit results instead of adding pagination
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def _get_sms_providers_for_update(time_threshold):
|
||||
@@ -42,14 +45,15 @@ def _get_sms_providers_for_update(time_threshold):
|
||||
release the transaction in that case
|
||||
"""
|
||||
# get current priority of both providers
|
||||
q = (
|
||||
ProviderDetails.query.filter(
|
||||
stmt = (
|
||||
select(ProviderDetails)
|
||||
.where(
|
||||
ProviderDetails.notification_type == NotificationType.SMS,
|
||||
ProviderDetails.active,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
q = db.session.execute(stmt).scalars().all()
|
||||
|
||||
# if something updated recently, don't update again. If the updated_at is null, treat it as min time
|
||||
if any(
|
||||
@@ -72,7 +76,8 @@ def get_provider_details_by_notification_type(
|
||||
if supports_international:
|
||||
filters.append(ProviderDetails.supports_international == supports_international)
|
||||
|
||||
return ProviderDetails.query.filter(*filters).all()
|
||||
stmt = select(ProviderDetails).where(*filters)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
@autocommit
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app import db
|
||||
from app.dao.dao_utils import autocommit
|
||||
from app.models import ServiceDataRetention
|
||||
@@ -5,29 +7,31 @@ from app.utils import utc_now
|
||||
|
||||
|
||||
def fetch_service_data_retention_by_id(service_id, data_retention_id):
|
||||
data_retention = ServiceDataRetention.query.filter_by(
|
||||
service_id=service_id, id=data_retention_id
|
||||
).first()
|
||||
return data_retention
|
||||
stmt = select(ServiceDataRetention).where(
|
||||
ServiceDataRetention.service_id == service_id,
|
||||
ServiceDataRetention.id == data_retention_id,
|
||||
)
|
||||
return db.session.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
def fetch_service_data_retention(service_id):
|
||||
data_retention_list = (
|
||||
ServiceDataRetention.query.filter_by(service_id=service_id)
|
||||
stmt = (
|
||||
select(ServiceDataRetention)
|
||||
.where(ServiceDataRetention.service_id == service_id)
|
||||
.order_by(
|
||||
# in the order that models.notification_types are created (email, sms, letter)
|
||||
ServiceDataRetention.notification_type
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return data_retention_list
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def fetch_service_data_retention_by_notification_type(service_id, notification_type):
|
||||
data_retention_list = ServiceDataRetention.query.filter_by(
|
||||
service_id=service_id, notification_type=notification_type
|
||||
).first()
|
||||
return data_retention_list
|
||||
stmt = select(ServiceDataRetention).where(
|
||||
ServiceDataRetention.service_id == service_id,
|
||||
ServiceDataRetention.notification_type == notification_type,
|
||||
)
|
||||
return db.session.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
@autocommit
|
||||
@@ -46,16 +50,22 @@ def insert_service_data_retention(service_id, notification_type, days_of_retenti
|
||||
def update_service_data_retention(
|
||||
service_data_retention_id, service_id, days_of_retention
|
||||
):
|
||||
updated_count = ServiceDataRetention.query.filter(
|
||||
ServiceDataRetention.id == service_data_retention_id,
|
||||
ServiceDataRetention.service_id == service_id,
|
||||
).update({"days_of_retention": days_of_retention, "updated_at": utc_now()})
|
||||
return updated_count
|
||||
stmt = (
|
||||
update(ServiceDataRetention)
|
||||
.where(
|
||||
ServiceDataRetention.id == service_data_retention_id,
|
||||
ServiceDataRetention.service_id == service_id,
|
||||
)
|
||||
.values({"days_of_retention": days_of_retention, "updated_at": utc_now()})
|
||||
)
|
||||
result = db.session.execute(stmt)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def fetch_service_data_retention_for_all_services_by_notification_type(
|
||||
notification_type,
|
||||
):
|
||||
return ServiceDataRetention.query.filter(
|
||||
stmt = select(ServiceDataRetention).where(
|
||||
ServiceDataRetention.notification_type == notification_type
|
||||
).all()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from app import db
|
||||
from app.models import ServiceGuestList
|
||||
|
||||
|
||||
def dao_fetch_service_guest_list(service_id):
|
||||
return ServiceGuestList.query.filter(
|
||||
ServiceGuestList.service_id == service_id
|
||||
).all()
|
||||
stmt = select(ServiceGuestList).where(ServiceGuestList.service_id == service_id)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def dao_add_and_commit_guest_list_contacts(objs):
|
||||
@@ -14,6 +15,6 @@ def dao_add_and_commit_guest_list_contacts(objs):
|
||||
|
||||
|
||||
def dao_remove_service_guest_list(service_id):
|
||||
return ServiceGuestList.query.filter(
|
||||
ServiceGuestList.service_id == service_id
|
||||
).delete()
|
||||
stmt = delete(ServiceGuestList).where(ServiceGuestList.service_id == service_id)
|
||||
result = db.session.execute(stmt)
|
||||
return result.rowcount
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from app import db
|
||||
from app.dao.dao_utils import autocommit
|
||||
from app.models import WebauthnCredential
|
||||
|
||||
|
||||
def dao_get_webauthn_credential_by_user_and_id(user_id, webauthn_credential_id):
|
||||
return WebauthnCredential.query.filter(
|
||||
stmt = select(WebauthnCredential).where(
|
||||
WebauthnCredential.user_id == user_id,
|
||||
WebauthnCredential.id == webauthn_credential_id,
|
||||
).one()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
@autocommit
|
||||
|
||||
Reference in New Issue
Block a user