merge from main

This commit is contained in:
Kenneth Kehl
2024-10-28 13:26:37 -07:00
29 changed files with 1535 additions and 718 deletions

View File

@@ -1,7 +1,7 @@
from datetime import timedelta
from flask import current_app
from sqlalchemy import asc, desc, or_, select, text, union
from sqlalchemy import asc, delete, desc, func, or_, select, text, union, update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql import functions
@@ -109,11 +109,12 @@ def _update_notification_status(
def update_notification_status_by_id(
notification_id, status, sent_by=None, provider_response=None, carrier=None
):
notification = (
Notification.query.with_for_update()
stmt = (
select(Notification)
.with_for_update()
.filter(Notification.id == notification_id)
.first()
)
notification = db.session.execute(stmt).scalars().first()
if not notification:
current_app.logger.info(
@@ -156,9 +157,8 @@ def update_notification_status_by_id(
@autocommit
def update_notification_status_by_reference(reference, status):
# this is used to update emails
notification = Notification.query.filter(
Notification.reference == reference
).first()
stmt = select(Notification).filter(Notification.reference == reference)
notification = db.session.execute(stmt).scalars().first()
if not notification:
current_app.logger.error(
@@ -200,19 +200,20 @@ def get_notifications_for_job(
def dao_get_notification_count_for_job_id(*, job_id):
return Notification.query.filter_by(job_id=job_id).count()
stmt = select(func.count(Notification.id)).filter_by(job_id=job_id)
return db.session.execute(stmt).scalar()
def dao_get_notification_count_for_service(*, service_id):
notification_count = Notification.query.filter_by(service_id=service_id).count()
return notification_count
stmt = select(func.count(Notification.id)).filter_by(service_id=service_id)
return db.session.execute(stmt).scalar()
def dao_get_failed_notification_count():
failed_count = Notification.query.filter_by(
stmt = select(func.count(Notification.id)).filter_by(
status=NotificationStatus.FAILED
).count()
return failed_count
)
return db.session.execute(stmt).scalar()
def get_notification_with_personalisation(service_id, notification_id, key_type):
@@ -220,11 +221,12 @@ def get_notification_with_personalisation(service_id, notification_id, key_type)
if key_type:
filter_dict["key_type"] = key_type
return (
Notification.query.filter_by(**filter_dict)
stmt = (
select(Notification)
.filter_by(**filter_dict)
.options(joinedload(Notification.template))
.one()
)
return db.session.execute(stmt).scalars().one()
def get_notification_by_id(notification_id, service_id=None, _raise=False):
@@ -233,9 +235,13 @@ def get_notification_by_id(notification_id, service_id=None, _raise=False):
if service_id:
filters.append(Notification.service_id == service_id)
query = Notification.query.filter(*filters)
stmt = select(Notification).filter(*filters)
return query.one() if _raise else query.first()
return (
db.session.execute(stmt).scalars().one()
if _raise
else db.session.execute(stmt).scalars().first()
)
def get_notifications_for_service(
@@ -415,12 +421,13 @@ def move_notifications_to_notification_history(
deleted += delete_count_per_call
# Deleting test Notifications, test notifications are not persisted to NotificationHistory
Notification.query.filter(
stmt = delete(Notification).filter(
Notification.notification_type == notification_type,
Notification.service_id == service_id,
Notification.created_at < timestamp_to_delete_backwards_from,
Notification.key_type == KeyType.TEST,
).delete(synchronize_session=False)
)
db.session.execute(stmt)
db.session.commit()
return deleted
@@ -442,8 +449,9 @@ def dao_timeout_notifications(cutoff_time, limit=100000):
current_statuses = [NotificationStatus.SENDING, NotificationStatus.PENDING]
new_status = NotificationStatus.TEMPORARY_FAILURE
notifications = (
Notification.query.filter(
stmt = (
select(Notification)
.filter(
Notification.created_at < cutoff_time,
Notification.status.in_(current_statuses),
Notification.notification_type.in_(
@@ -451,14 +459,15 @@ def dao_timeout_notifications(cutoff_time, limit=100000):
),
)
.limit(limit)
.all()
)
notifications = db.session.execute(stmt).scalars().all()
Notification.query.filter(
Notification.id.in_([n.id for n in notifications]),
).update(
{"status": new_status, "updated_at": updated_at}, synchronize_session=False
stmt = (
update(Notification)
.filter(Notification.id.in_([n.id for n in notifications]))
.values({"status": new_status, "updated_at": updated_at})
)
db.session.execute(stmt)
db.session.commit()
return notifications
@@ -466,15 +475,23 @@ def dao_timeout_notifications(cutoff_time, limit=100000):
@autocommit
def dao_update_notifications_by_reference(references, update_dict):
updated_count = Notification.query.filter(
Notification.reference.in_(references)
).update(update_dict, synchronize_session=False)
stmt = (
update(Notification)
.filter(Notification.reference.in_(references))
.values(update_dict)
)
result = db.session.execute(stmt)
updated_count = result.rowcount
updated_history_count = 0
if updated_count != len(references):
updated_history_count = NotificationHistory.query.filter(
NotificationHistory.reference.in_(references)
).update(update_dict, synchronize_session=False)
stmt = (
update(NotificationHistory)
.filter(NotificationHistory.reference.in_(references))
.values(update_dict)
)
result = db.session.execute(stmt)
updated_history_count = result.rowcount
return updated_count, updated_history_count
@@ -541,18 +558,21 @@ def dao_get_notifications_by_recipient_or_reference(
def dao_get_notification_by_reference(reference):
return Notification.query.filter(Notification.reference == reference).one()
stmt = select(Notification).filter(Notification.reference == reference)
return db.session.execute(stmt).scalars().one()
def dao_get_notification_history_by_reference(reference):
try:
# This try except is necessary because in test keys and research mode does not create notification history.
# Otherwise we could just search for the NotificationHistory object
return Notification.query.filter(Notification.reference == reference).one()
stmt = select(Notification).filter(Notification.reference == reference)
return db.session.execute(stmt).scalars().one()
except NoResultFound:
return NotificationHistory.query.filter(
stmt = select(NotificationHistory).filter(
NotificationHistory.reference == reference
).one()
)
return db.session.execute(stmt).scalars().one()
def dao_get_notifications_processing_time_stats(start_date, end_date):
@@ -590,11 +610,12 @@ def dao_get_notifications_processing_time_stats(start_date, end_date):
def dao_get_last_notification_added_for_job_id(job_id):
last_notification_added = (
Notification.query.filter(Notification.job_id == job_id)
stmt = (
select(Notification)
.filter(Notification.job_id == job_id)
.order_by(Notification.job_row_number.desc())
.first()
)
last_notification_added = db.session.execute(stmt).scalars().first()
return last_notification_added
@@ -602,11 +623,12 @@ def dao_get_last_notification_added_for_job_id(job_id):
def notifications_not_yet_sent(should_be_sending_after_seconds, notification_type):
older_than_date = utc_now() - timedelta(seconds=should_be_sending_after_seconds)
notifications = Notification.query.filter(
stmt = select(Notification).filter(
Notification.created_at <= older_than_date,
Notification.notification_type == notification_type,
Notification.status == NotificationStatus.CREATED,
).all()
)
notifications = db.session.execute(stmt).scalars().all()
return notifications

View File

@@ -1,3 +1,4 @@
from sqlalchemy import delete, select, update
from sqlalchemy.sql.expression import func
from app import db
@@ -6,55 +7,57 @@ from app.models import Domain, Organization, Service, User
def dao_get_organizations():
return Organization.query.order_by(
stmt = select(Organization).order_by(
Organization.active.desc(), Organization.name.asc()
).all()
)
return db.session.execute(stmt).scalars().all()
def dao_count_organizations_with_live_services():
return (
db.session.query(Organization.id)
stmt = (
select(func.count(func.distinct(Organization.id)))
.join(Organization.services)
.filter(
Service.active.is_(True),
Service.restricted.is_(False),
Service.count_as_live.is_(True),
)
.distinct()
.count()
)
return db.session.execute(stmt).scalar() or 0
def dao_get_organization_services(organization_id):
return Organization.query.filter_by(id=organization_id).one().services
stmt = select(Organization).filter_by(id=organization_id)
return db.session.execute(stmt).scalars().one().services
def dao_get_organization_live_services(organization_id):
return Service.query.filter_by(
organization_id=organization_id, restricted=False
).all()
stmt = select(Service).filter_by(organization_id=organization_id, restricted=False)
return db.session.execute(stmt).scalars().all()
def dao_get_organization_by_id(organization_id):
return Organization.query.filter_by(id=organization_id).one()
stmt = select(Organization).filter_by(id=organization_id)
return db.session.execute(stmt).scalars().one()
def dao_get_organization_by_email_address(email_address):
email_address = email_address.lower().replace(".gsi.gov.uk", ".gov.uk")
for domain in Domain.query.order_by(func.char_length(Domain.domain).desc()).all():
stmt = select(Domain).order_by(func.char_length(Domain.domain).desc())
domains = db.session.execute(stmt).scalars().all()
for domain in domains:
if email_address.endswith(
"@{}".format(domain.domain)
) or email_address.endswith(".{}".format(domain.domain)):
return Organization.query.filter_by(id=domain.organization_id).one()
stmt = select(Organization).filter_by(id=domain.organization_id)
return db.session.execute(stmt).scalars().one()
return None
def dao_get_organization_by_service_id(service_id):
return (
Organization.query.join(Organization.services).filter_by(id=service_id).first()
)
stmt = select(Organization).join(Organization.services).filter_by(id=service_id)
return db.session.execute(stmt).scalars().first()
@autocommit
@@ -65,10 +68,14 @@ def dao_create_organization(organization):
@autocommit
def dao_update_organization(organization_id, **kwargs):
domains = kwargs.pop("domains", None)
num_updated = Organization.query.filter_by(id=organization_id).update(kwargs)
stmt = (
update(Organization).where(Organization.id == organization_id).values(**kwargs)
)
num_updated = db.session.execute(stmt).rowcount
if isinstance(domains, list):
Domain.query.filter_by(organization_id=organization_id).delete()
stmt = delete(Domain).filter_by(organization_id=organization_id)
db.session.execute(stmt)
db.session.bulk_save_objects(
[
Domain(domain=domain.lower(), organization_id=organization_id)
@@ -76,7 +83,7 @@ def dao_update_organization(organization_id, **kwargs):
]
)
organization = Organization.query.get(organization_id)
organization = db.session.get(Organization, organization_id)
if "organization_type" in kwargs:
_update_organization_services(
organization, "organization_type", only_where_none=False
@@ -101,7 +108,8 @@ def _update_organization_services(organization, attribute, only_where_none=True)
@autocommit
@version_class(Service)
def dao_add_service_to_organization(service, organization_id):
organization = Organization.query.filter_by(id=organization_id).one()
stmt = select(Organization).filter_by(id=organization_id)
organization = db.session.execute(stmt).scalars().one()
service.organization_id = organization_id
service.organization_type = organization.organization_type
@@ -122,7 +130,8 @@ def dao_get_users_for_organization(organization_id):
@autocommit
def dao_add_user_to_organization(organization_id, user_id):
organization = dao_get_organization_by_id(organization_id)
user = User.query.filter_by(id=user_id).one()
stmt = select(User).filter_by(id=user_id)
user = db.session.execute(stmt).scalars().one()
user.organizations.append(organization)
db.session.add(organization)
return user

View File

@@ -1,16 +1,20 @@
from sqlalchemy import select
from app import db
from app.dao.dao_utils import autocommit
from app.models import TemplateFolder
def dao_get_template_folder_by_id_and_service_id(template_folder_id, service_id):
return TemplateFolder.query.filter(
stmt = select(TemplateFolder).filter(
TemplateFolder.id == template_folder_id, TemplateFolder.service_id == service_id
).one()
)
return db.session.execute(stmt).scalars().one()
def dao_get_valid_template_folders_by_id(folder_ids):
return TemplateFolder.query.filter(TemplateFolder.id.in_(folder_ids)).all()
stmt = select(TemplateFolder).filter(TemplateFolder.id.in_(folder_ids))
return db.session.execute(stmt).scalars().all()
@autocommit

View File

@@ -1,6 +1,6 @@
import uuid
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from app import db
from app.dao.dao_utils import VersionOptions, autocommit, version_class
@@ -46,24 +46,29 @@ def dao_redact_template(template, user_id):
def dao_get_template_by_id_and_service_id(template_id, service_id, version=None):
if version is not None:
return TemplateHistory.query.filter_by(
stmt = select(TemplateHistory).filter_by(
id=template_id, hidden=False, service_id=service_id, version=version
).one()
return Template.query.filter_by(
)
return db.session.execute(stmt).scalars().one()
stmt = select(Template).filter_by(
id=template_id, hidden=False, service_id=service_id
).one()
)
return db.session.execute(stmt).scalars().one()
def dao_get_template_by_id(template_id, version=None):
if version is not None:
return TemplateHistory.query.filter_by(id=template_id, version=version).one()
return Template.query.filter_by(id=template_id).one()
stmt = select(TemplateHistory).filter_by(id=template_id, version=version)
return db.session.execute(stmt).scalars().one()
stmt = select(Template).filter_by(id=template_id)
return db.session.execute(stmt).scalars().one()
def dao_get_all_templates_for_service(service_id, template_type=None):
if template_type is not None:
return (
Template.query.filter_by(
stmt = (
select(Template)
.filter_by(
service_id=service_id,
template_type=template_type,
hidden=False,
@@ -73,26 +78,27 @@ def dao_get_all_templates_for_service(service_id, template_type=None):
asc(Template.name),
asc(Template.template_type),
)
.all()
)
return (
Template.query.filter_by(service_id=service_id, hidden=False, archived=False)
return db.session.execute(stmt).scalars().all()
stmt = (
select(Template)
.filter_by(service_id=service_id, hidden=False, archived=False)
.order_by(
asc(Template.name),
asc(Template.template_type),
)
.all()
)
return db.session.execute(stmt).scalars().all()
def dao_get_template_versions(service_id, template_id):
return (
TemplateHistory.query.filter_by(
stmt = (
select(TemplateHistory)
.filter_by(
service_id=service_id,
id=template_id,
hidden=False,
)
.order_by(desc(TemplateHistory.version))
.all()
)
return db.session.execute(stmt).scalars().all()

View File

@@ -4,7 +4,7 @@ from secrets import randbelow
import sqlalchemy
from flask import current_app
from sqlalchemy import func, text
from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import joinedload
from app import db
@@ -37,8 +37,8 @@ def get_login_gov_user(login_uuid, email_address):
login.gov uuids are. Eventually the code that checks by email address
should be removed.
"""
user = User.query.filter_by(login_uuid=login_uuid).first()
stmt = select(User).filter_by(login_uuid=login_uuid)
user = db.session.execute(stmt).scalars().first()
if user:
if user.email_address != email_address:
try:
@@ -54,7 +54,8 @@ def get_login_gov_user(login_uuid, email_address):
return user
# Remove this 1 July 2025, all users should have login.gov uuids by now
user = User.query.filter(User.email_address.ilike(email_address)).first()
stmt = select(User).filter(User.email_address.ilike(email_address))
user = db.session.execute(stmt).scalars().first()
if user:
save_user_attribute(user, {"login_uuid": login_uuid})
@@ -102,24 +103,27 @@ def create_user_code(user, code, code_type):
def get_user_code(user, code, code_type):
# Get the most recent codes to try and reduce the
# time searching for the correct code.
codes = VerifyCode.query.filter_by(user=user, code_type=code_type).order_by(
VerifyCode.created_at.desc()
stmt = (
select(VerifyCode)
.filter_by(user=user, code_type=code_type)
.order_by(VerifyCode.created_at.desc())
)
codes = db.session.execute(stmt).scalars().all()
return next((x for x in codes if x.check_code(code)), None)
def delete_codes_older_created_more_than_a_day_ago():
deleted = (
db.session.query(VerifyCode)
.filter(VerifyCode.created_at < utc_now() - timedelta(hours=24))
.delete()
stmt = delete(VerifyCode).filter(
VerifyCode.created_at < utc_now() - timedelta(hours=24)
)
deleted = db.session.execute(stmt)
db.session.commit()
return deleted
def use_user_code(id):
verify_code = VerifyCode.query.get(id)
verify_code = db.session.get(VerifyCode, id)
verify_code.code_used = True
db.session.add(verify_code)
db.session.commit()
@@ -131,36 +135,42 @@ def delete_model_user(user):
def delete_user_verify_codes(user):
VerifyCode.query.filter_by(user=user).delete()
stmt = delete(VerifyCode).filter_by(user=user)
db.session.execute(stmt)
db.session.commit()
def count_user_verify_codes(user):
query = VerifyCode.query.filter(
stmt = select(func.count(VerifyCode.id)).filter(
VerifyCode.user == user,
VerifyCode.expiry_datetime > utc_now(),
VerifyCode.code_used.is_(False),
)
return query.count()
result = db.session.execute(stmt).scalar()
return result or 0
def get_user_by_id(user_id=None):
if user_id:
return User.query.filter_by(id=user_id).one()
return User.query.filter_by().all()
stmt = select(User).filter_by(id=user_id)
return db.session.execute(stmt).scalars().one()
return get_users()
def get_users():
return User.query.all()
stmt = select(User)
return db.session.execute(stmt).scalars().all()
def get_user_by_email(email):
return User.query.filter(func.lower(User.email_address) == func.lower(email)).one()
stmt = select(User).filter(func.lower(User.email_address) == func.lower(email))
return db.session.execute(stmt).scalars().one()
def get_users_by_partial_email(email):
email = escape_special_characters(email)
return User.query.filter(User.email_address.ilike("%{}%".format(email))).all()
stmt = select(User).filter(User.email_address.ilike("%{}%".format(email)))
return db.session.execute(stmt).scalars().all()
def increment_failed_login_count(user):
@@ -188,16 +198,17 @@ def get_user_and_accounts(user_id):
# TODO: With sqlalchemy 2.0 change as below because of the breaking change
# at User.organizations.services, we need to verify that the below subqueryload
# that we have put is functionally doing the same thing as before
return (
User.query.filter(User.id == user_id)
stmt = (
select(User)
.filter(User.id == user_id)
.options(
# eagerly load the user's services and organizations, and also the service's org and vice versa
# (so we can see if the user knows about it)
joinedload(User.services).joinedload(Service.organization),
joinedload(User.organizations).subqueryload(Organization.services),
)
.one()
)
return db.session.execute(stmt).scalars().unique().one()
@autocommit