diff --git a/.ds.baseline b/.ds.baseline index 37199f01f..c8d59174d 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-10-11T19:26:50Z" + "generated_at": "2024-10-14T17:46:47Z" } diff --git a/app/dao/organization_dao.py b/app/dao/organization_dao.py index 9e44bcdd5..668ac6c25 100644 --- a/app/dao/organization_dao.py +++ b/app/dao/organization_dao.py @@ -1,3 +1,4 @@ +from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import func from app import db @@ -6,55 +7,57 @@ from app.models import Domain, Organization, Service, User def dao_get_organizations(): - return Organization.query.order_by( + stmt = select(Organization).order_by( Organization.active.desc(), Organization.name.asc() - ).all() + ) + return db.session.execute(stmt).scalars().all() def dao_count_organizations_with_live_services(): - return ( - db.session.query(Organization.id) + stmt = ( + select(func.count(func.distinct(Organization.id))) .join(Organization.services) .filter( Service.active.is_(True), Service.restricted.is_(False), Service.count_as_live.is_(True), ) - .distinct() - .count() ) + return db.session.execute(stmt).scalar() or 0 def dao_get_organization_services(organization_id): - return Organization.query.filter_by(id=organization_id).one().services + stmt = select(Organization).filter_by(id=organization_id) + return db.session.execute(stmt).scalars().one().services def dao_get_organization_live_services(organization_id): - return Service.query.filter_by( - organization_id=organization_id, restricted=False - ).all() + stmt = select(Service).filter_by(organization_id=organization_id, restricted=False) + return db.session.execute(stmt).scalars().all() def dao_get_organization_by_id(organization_id): - return Organization.query.filter_by(id=organization_id).one() + stmt = select(Organization).filter_by(id=organization_id) + return db.session.execute(stmt).scalars().one() def dao_get_organization_by_email_address(email_address): email_address = email_address.lower().replace(".gsi.gov.uk", ".gov.uk") - - for domain in Domain.query.order_by(func.char_length(Domain.domain).desc()).all(): + stmt = select(Domain).order_by(func.char_length(Domain.domain).desc()) + domains = db.session.execute(stmt).scalars().all() + for domain in domains: if email_address.endswith( "@{}".format(domain.domain) ) or email_address.endswith(".{}".format(domain.domain)): - return Organization.query.filter_by(id=domain.organization_id).one() + stmt = select(Organization).filter_by(id=domain.organization_id) + return db.session.execute(stmt).scalars().one() return None def dao_get_organization_by_service_id(service_id): - return ( - Organization.query.join(Organization.services).filter_by(id=service_id).first() - ) + stmt = select(Organization).join(Organization.services).filter_by(id=service_id) + return db.session.execute(stmt).scalars().first() @autocommit @@ -65,10 +68,14 @@ def dao_create_organization(organization): @autocommit def dao_update_organization(organization_id, **kwargs): domains = kwargs.pop("domains", None) - num_updated = Organization.query.filter_by(id=organization_id).update(kwargs) + stmt = ( + update(Organization).where(Organization.id == organization_id).values(**kwargs) + ) + num_updated = db.session.execute(stmt).rowcount if isinstance(domains, list): - Domain.query.filter_by(organization_id=organization_id).delete() + stmt = delete(Domain).filter_by(organization_id=organization_id) + db.session.execute(stmt) db.session.bulk_save_objects( [ Domain(domain=domain.lower(), organization_id=organization_id) @@ -76,7 +83,7 @@ def dao_update_organization(organization_id, **kwargs): ] ) - organization = Organization.query.get(organization_id) + organization = db.session.get(Organization, organization_id) if "organization_type" in kwargs: _update_organization_services( organization, "organization_type", only_where_none=False @@ -101,7 +108,8 @@ def _update_organization_services(organization, attribute, only_where_none=True) @autocommit @version_class(Service) def dao_add_service_to_organization(service, organization_id): - organization = Organization.query.filter_by(id=organization_id).one() + stmt = select(Organization).filter_by(id=organization_id) + organization = db.session.execute(stmt).scalars().one() service.organization_id = organization_id service.organization_type = organization.organization_type @@ -122,7 +130,8 @@ def dao_get_users_for_organization(organization_id): @autocommit def dao_add_user_to_organization(organization_id, user_id): organization = dao_get_organization_by_id(organization_id) - user = User.query.filter_by(id=user_id).one() + stmt = select(User).filter_by(id=user_id) + user = db.session.execute(stmt).scalars().one() user.organizations.append(organization) db.session.add(organization) return user diff --git a/app/dao/template_folder_dao.py b/app/dao/template_folder_dao.py index ae1224179..269f407e0 100644 --- a/app/dao/template_folder_dao.py +++ b/app/dao/template_folder_dao.py @@ -1,16 +1,20 @@ +from sqlalchemy import select + from app import db from app.dao.dao_utils import autocommit from app.models import TemplateFolder def dao_get_template_folder_by_id_and_service_id(template_folder_id, service_id): - return TemplateFolder.query.filter( + stmt = select(TemplateFolder).filter( TemplateFolder.id == template_folder_id, TemplateFolder.service_id == service_id - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_valid_template_folders_by_id(folder_ids): - return TemplateFolder.query.filter(TemplateFolder.id.in_(folder_ids)).all() + stmt = select(TemplateFolder).filter(TemplateFolder.id.in_(folder_ids)) + return db.session.execute(stmt).scalars().all() @autocommit diff --git a/app/dao/templates_dao.py b/app/dao/templates_dao.py index 55d4363d6..7c5d7459e 100644 --- a/app/dao/templates_dao.py +++ b/app/dao/templates_dao.py @@ -1,6 +1,6 @@ import uuid -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select from app import db from app.dao.dao_utils import VersionOptions, autocommit, version_class @@ -46,24 +46,29 @@ def dao_redact_template(template, user_id): def dao_get_template_by_id_and_service_id(template_id, service_id, version=None): if version is not None: - return TemplateHistory.query.filter_by( + stmt = select(TemplateHistory).filter_by( id=template_id, hidden=False, service_id=service_id, version=version - ).one() - return Template.query.filter_by( + ) + return db.session.execute(stmt).scalars().one() + stmt = select(Template).filter_by( id=template_id, hidden=False, service_id=service_id - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_template_by_id(template_id, version=None): if version is not None: - return TemplateHistory.query.filter_by(id=template_id, version=version).one() - return Template.query.filter_by(id=template_id).one() + stmt = select(TemplateHistory).filter_by(id=template_id, version=version) + return db.session.execute(stmt).scalars().one() + stmt = select(Template).filter_by(id=template_id) + return db.session.execute(stmt).scalars().one() def dao_get_all_templates_for_service(service_id, template_type=None): if template_type is not None: - return ( - Template.query.filter_by( + stmt = ( + select(Template) + .filter_by( service_id=service_id, template_type=template_type, hidden=False, @@ -73,26 +78,27 @@ def dao_get_all_templates_for_service(service_id, template_type=None): asc(Template.name), asc(Template.template_type), ) - .all() ) - - return ( - Template.query.filter_by(service_id=service_id, hidden=False, archived=False) + return db.session.execute(stmt).scalars().all() + stmt = ( + select(Template) + .filter_by(service_id=service_id, hidden=False, archived=False) .order_by( asc(Template.name), asc(Template.template_type), ) - .all() ) + return db.session.execute(stmt).scalars().all() def dao_get_template_versions(service_id, template_id): - return ( - TemplateHistory.query.filter_by( + stmt = ( + select(TemplateHistory) + .filter_by( service_id=service_id, id=template_id, hidden=False, ) .order_by(desc(TemplateHistory.version)) - .all() ) + return db.session.execute(stmt).scalars().all() diff --git a/tests/app/dao/test_organization_dao.py b/tests/app/dao/test_organization_dao.py index edffdd1d4..fb2e01d85 100644 --- a/tests/app/dao/test_organization_dao.py +++ b/tests/app/dao/test_organization_dao.py @@ -1,6 +1,7 @@ import uuid import pytest +from sqlalchemy import select from sqlalchemy.exc import IntegrityError, SQLAlchemyError from app import db @@ -57,7 +58,8 @@ def test_get_organization_by_id_gets_correct_organization(notify_db_session): def test_update_organization(notify_db_session): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() user = create_user() email_branding = create_email_branding() @@ -78,7 +80,8 @@ def test_update_organization(notify_db_session): dao_update_organization(organization.id, **data) - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() for attribute, value in data.items(): assert getattr(organization, attribute) == value @@ -102,7 +105,8 @@ def test_update_organization_domains_lowercases( ): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() # Seed some domains dao_update_organization(organization.id, domains=["123", "456"]) @@ -121,7 +125,8 @@ def test_update_organization_domains_lowercases_integrity_error( ): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() # Seed some domains dao_update_organization(organization.id, domains=["123", "456"]) @@ -175,11 +180,11 @@ def test_update_organization_updates_the_service_org_type_if_org_type_is_provide assert sample_organization.organization_type == OrganizationType.FEDERAL assert sample_service.organization_type == OrganizationType.FEDERAL + stmt = select(Service.get_history_model()).filter_by( + id=sample_service.id, version=2 + ) assert ( - Service.get_history_model() - .query.filter_by(id=sample_service.id, version=2) - .one() - .organization_type + db.session.execute(stmt).scalars().one().organization_type == OrganizationType.FEDERAL ) @@ -229,11 +234,11 @@ def test_add_service_to_organization(sample_service, sample_organization): assert sample_organization.services[0].id == sample_service.id assert sample_service.organization_type == sample_organization.organization_type + stmt = select(Service.get_history_model()).filter_by( + id=sample_service.id, version=2 + ) assert ( - Service.get_history_model() - .query.filter_by(id=sample_service.id, version=2) - .one() - .organization_type + db.session.execute(stmt).scalars().one().organization_type == sample_organization.organization_type ) assert sample_service.organization_id == sample_organization.id diff --git a/tests/app/dao/test_template_folder_dao.py b/tests/app/dao/test_template_folder_dao.py index 17b03e5df..2a872e775 100644 --- a/tests/app/dao/test_template_folder_dao.py +++ b/tests/app/dao/test_template_folder_dao.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from app import db from app.dao.service_user_dao import dao_get_service_user from app.dao.template_folder_dao import ( @@ -17,5 +19,5 @@ def test_dao_delete_template_folder_deletes_user_folder_permissions( dao_update_template_folder(folder) dao_delete_template_folder(folder) - - assert db.session.query(user_folder_permissions).all() == [] + stmt = select(user_folder_permissions) + assert db.session.execute(stmt).scalars().all() == [] diff --git a/tests/app/dao/test_templates_dao.py b/tests/app/dao/test_templates_dao.py index bfe0e59d1..734a29c0a 100644 --- a/tests/app/dao/test_templates_dao.py +++ b/tests/app/dao/test_templates_dao.py @@ -2,8 +2,10 @@ from datetime import datetime import pytest from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.orm.exc import NoResultFound +from app import db from app.dao.templates_dao import ( dao_create_template, dao_get_all_templates_for_service, @@ -17,6 +19,16 @@ from app.models import Template, TemplateHistory, TemplateRedacted from tests.app.db import create_template +def template_query_count(): + stmt = select(func.count()).select_from(Template) + return db.session.execute(stmt).scalar() or 0 + + +def template_history_query_count(): + stmt = select(func.count()).select_from(TemplateHistory) + return db.session.execute(stmt).scalar() or 0 + + @pytest.mark.parametrize( "template_type, subject", [ @@ -37,7 +49,7 @@ def test_create_template(sample_service, sample_user, template_type, subject): template = Template(**data) dao_create_template(template) - assert Template.query.count() == 1 + assert template_query_count() == 1 assert len(dao_get_all_templates_for_service(sample_service.id)) == 1 assert ( dao_get_all_templates_for_service(sample_service.id)[0].name @@ -50,11 +62,13 @@ def test_create_template(sample_service, sample_user, template_type, subject): def test_create_template_creates_redact_entry(sample_service): - assert TemplateRedacted.query.count() == 0 + stmt = select(func.count()).select_from(TemplateRedacted) + assert db.session.execute(stmt).scalar() == 0 template = create_template(sample_service) - redacted = TemplateRedacted.query.one() + stmt = select(TemplateRedacted) + redacted = db.session.execute(stmt).scalars().one() assert redacted.template_id == template.id assert redacted.redact_personalisation is False assert redacted.updated_by_id == sample_service.created_by_id @@ -79,7 +93,8 @@ def test_update_template(sample_service, sample_user): def test_redact_template(sample_template): - redacted = TemplateRedacted.query.one() + stmt = select(TemplateRedacted) + redacted = db.session.execute(stmt).scalars().one() assert redacted.template_id == sample_template.id assert redacted.redact_personalisation is False @@ -96,7 +111,7 @@ def test_get_all_templates_for_service(service_factory): service_1 = service_factory.get("service 1", email_from="service.1") service_2 = service_factory.get("service 2", email_from="service.2") - assert Template.query.count() == 2 + assert template_query_count() == 2 assert len(dao_get_all_templates_for_service(service_1.id)) == 1 assert len(dao_get_all_templates_for_service(service_2.id)) == 1 @@ -119,7 +134,7 @@ def test_get_all_templates_for_service(service_factory): content="Template content", ) - assert Template.query.count() == 5 + assert template_query_count() == 5 assert len(dao_get_all_templates_for_service(service_1.id)) == 3 assert len(dao_get_all_templates_for_service(service_2.id)) == 2 @@ -144,7 +159,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service): service=sample_service, ) - assert Template.query.count() == 3 + assert template_query_count() == 3 assert ( dao_get_all_templates_for_service(sample_service.id)[0].name == "Sample Template 1" @@ -171,7 +186,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service): def test_get_all_returns_empty_list_if_no_templates(sample_service): - assert Template.query.count() == 0 + assert template_query_count() == 0 assert len(dao_get_all_templates_for_service(sample_service.id)) == 0 @@ -257,8 +272,8 @@ def test_get_template_by_id_and_service_returns_none_if_no_template( def test_create_template_creates_a_history_record_with_current_data( sample_service, sample_user ): - assert Template.query.count() == 0 - assert TemplateHistory.query.count() == 0 + assert template_query_count() == 0 + assert template_history_query_count() == 0 data = { "name": "Sample Template", "template_type": TemplateType.EMAIL, @@ -270,10 +285,12 @@ def test_create_template_creates_a_history_record_with_current_data( template = Template(**data) dao_create_template(template) - assert Template.query.count() == 1 + assert template_query_count() == 1 - template_from_db = Template.query.first() - template_history = TemplateHistory.query.first() + stmt = select(Template) + template_from_db = db.session.execute(stmt).scalars().first() + stmt = select(TemplateHistory) + template_history = db.session.execute(stmt).scalars().first() assert template_from_db.id == template_history.id assert template_from_db.name == template_history.name @@ -286,8 +303,8 @@ def test_create_template_creates_a_history_record_with_current_data( def test_update_template_creates_a_history_record_with_current_data( sample_service, sample_user ): - assert Template.query.count() == 0 - assert TemplateHistory.query.count() == 0 + assert template_query_count() == 0 + assert template_history_query_count() == 0 data = { "name": "Sample Template", "template_type": TemplateType.EMAIL, @@ -301,22 +318,26 @@ def test_update_template_creates_a_history_record_with_current_data( created = dao_get_all_templates_for_service(sample_service.id)[0] assert created.name == "Sample Template" - assert Template.query.count() == 1 - assert Template.query.first().version == 1 - assert TemplateHistory.query.count() == 1 + assert template_query_count() == 1 + stmt = select(Template) + assert db.session.execute(stmt).scalars().first().version == 1 + assert template_history_query_count() == 1 created.name = "new name" dao_update_template(created) - assert Template.query.count() == 1 - assert TemplateHistory.query.count() == 2 + assert template_query_count() == 1 + assert template_history_query_count() == 2 - template_from_db = Template.query.first() + stmt = select(Template) + template_from_db = db.session.execute(stmt).scalars().first() assert template_from_db.version == 2 - assert TemplateHistory.query.filter_by(name="Sample Template").one().version == 1 - assert TemplateHistory.query.filter_by(name="new name").one().version == 2 + stmt = select(TemplateHistory).filter_by(name="Sample Template") + assert db.session.execute(stmt).scalars().one().version == 1 + stmt = select(TemplateHistory).filter_by(name="new name") + assert db.session.execute(stmt).scalars().one().version == 2 def test_get_template_history_version(sample_user, sample_service, sample_template):