From b2ed52ffd4fda8004ec0f18d6c3a48e9accebe9a Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 07:41:26 -0700 Subject: [PATCH 01/26] upgrade test queries to sqlalchemy 2.0 --- tests/app/celery/test_tasks.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 593926c18..c2dff1eb3 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -8,9 +8,10 @@ import requests_mock from celery.exceptions import Retry from freezegun import freeze_time from requests import RequestException +from sqlalchemy import func, select from sqlalchemy.exc import SQLAlchemyError -from app import encryption +from app import db, encryption from app.celery import provider_tasks, tasks from app.celery.tasks import ( get_recipient_csv_and_template_and_sender_id, @@ -1166,7 +1167,13 @@ def test_process_incomplete_job_sms(mocker, sample_template): create_notification(sample_template, job, 0) create_notification(sample_template, job, 1) - assert Notification.query.filter(Notification.job_id == job.id).count() == 2 + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job.id) + ) + result = db.session.execute(stmt) + assert result.rowcount == 2 process_incomplete_job(str(job.id)) From 431d51dd631437d0a90c6e8f484ade4667080367 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 08:16:17 -0700 Subject: [PATCH 02/26] remove query.filters from test_tasks --- tests/app/celery/test_tasks.py | 56 +++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index c2dff1eb3..67dd5d4e7 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -1172,12 +1172,13 @@ def test_process_incomplete_job_sms(mocker, sample_template): .select_from(Notification) .where(Notification.job_id == job.id) ) - result = db.session.execute(stmt) - assert result.rowcount == 2 + count = db.session.execute(stmt).scalar() + assert count == 2 process_incomplete_job(str(job.id)) - completed_job = Job.query.filter(Job.id == job.id).one() + stmt = select(Job).where(Job.id == job.id) + completed_job = db.session.execute(stmt).scalars().one() assert completed_job.job_status == JobStatus.FINISHED @@ -1213,11 +1214,17 @@ def test_process_incomplete_job_with_notifications_all_sent(mocker, sample_templ create_notification(sample_template, job, 8) create_notification(sample_template, job, 9) - assert Notification.query.filter(Notification.job_id == job.id).count() == 10 + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job.id) + ) + assert db.session.execute(stmt).scalar() == 10 process_incomplete_job(str(job.id)) - completed_job = Job.query.filter(Job.id == job.id).one() + stmt = select(Job).where(Job.id == job.id) + completed_job = db.session.execute(stmt).scalars().one() assert completed_job.job_status == JobStatus.FINISHED @@ -1245,7 +1252,12 @@ def test_process_incomplete_jobs_sms(mocker, sample_template): create_notification(sample_template, job, 1) create_notification(sample_template, job, 2) - assert Notification.query.filter(Notification.job_id == job.id).count() == 3 + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job.id) + ) + assert db.session.execute(stmt).scalar() == 3 job2 = create_job( template=sample_template, @@ -1262,13 +1274,21 @@ def test_process_incomplete_jobs_sms(mocker, sample_template): create_notification(sample_template, job2, 3) create_notification(sample_template, job2, 4) - assert Notification.query.filter(Notification.job_id == job2.id).count() == 5 + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job2.id) + ) + + assert db.session.execute(stmt).scalar() == 5 jobs = [job.id, job2.id] process_incomplete_jobs(jobs) - completed_job = Job.query.filter(Job.id == job.id).one() - completed_job2 = Job.query.filter(Job.id == job2.id).one() + stmt = select(Job).where(Job.id == job.id) + completed_job = db.session.execute(stmt).scalars().one() + stmt = select(Job).where(Job.id == job2.id) + completed_job2 = db.session.execute(stmt).scalars().one() assert completed_job.job_status == JobStatus.FINISHED @@ -1294,12 +1314,16 @@ def test_process_incomplete_jobs_no_notifications_added(mocker, sample_template) processing_started=utc_now() - timedelta(minutes=31), job_status=JobStatus.ERROR, ) - - assert Notification.query.filter(Notification.job_id == job.id).count() == 0 + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job.id) + ) + assert db.session.execute(stmt).scalar() == 0 process_incomplete_job(job.id) - - completed_job = Job.query.filter(Job.id == job.id).one() + stmt = select(Job).where(Job.id == job.id) + completed_job = db.session.execute(stmt).scalars().one() assert completed_job.job_status == JobStatus.FINISHED @@ -1355,11 +1379,13 @@ def test_process_incomplete_job_email(mocker, sample_email_template): create_notification(sample_email_template, job, 0) create_notification(sample_email_template, job, 1) - assert Notification.query.filter(Notification.job_id == job.id).count() == 2 + stmt = select(Notification).where(Notification.job_id == job.id) + assert db.session.execute(stmt).scalar() == 2 process_incomplete_job(str(job.id)) - completed_job = Job.query.filter(Job.id == job.id).one() + stmt = select(Job).where(Job.id == job.id) + completed_job = db.session.execute(stmt).scalars().one() assert completed_job.job_status == JobStatus.FINISHED From b92253b782f53143dd7953c3ad1c84eed92a5756 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 08:31:15 -0700 Subject: [PATCH 03/26] remove query.filters from test_tasks --- tests/app/celery/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 67dd5d4e7..84ac83668 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -1379,7 +1379,7 @@ def test_process_incomplete_job_email(mocker, sample_email_template): create_notification(sample_email_template, job, 0) create_notification(sample_email_template, job, 1) - stmt = select(Notification).where(Notification.job_id == job.id) + stmt = select(func.count()).select_from(Notification).where(Notification.job_id == job.id) assert db.session.execute(stmt).scalar() == 2 process_incomplete_job(str(job.id)) From 36251ab95d84ed9190a234652a55fd6af2479286 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 08:54:48 -0700 Subject: [PATCH 04/26] fix more --- tests/app/celery/test_tasks.py | 6 +++++- tests/app/test_commands.py | 31 ++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 84ac83668..a0fd70584 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -1379,7 +1379,11 @@ def test_process_incomplete_job_email(mocker, sample_email_template): create_notification(sample_email_template, job, 0) create_notification(sample_email_template, job, 1) - stmt = select(func.count()).select_from(Notification).where(Notification.job_id == job.id) + stmt = ( + select(func.count()) + .select_from(Notification) + .where(Notification.job_id == job.id) + ) assert db.session.execute(stmt).scalar() == 2 process_incomplete_job(str(job.id)) diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index 690532da9..a3e09b687 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -3,7 +3,9 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock, mock_open import pytest +from sqlalchemy import select +from app import db from app.commands import ( _update_template, bulk_invite_user_to_service, @@ -115,7 +117,8 @@ def test_update_jobs_archived_flag(notify_db_session, notify_api): right_now = right_now.strftime("%Y-%m-%d") tomorrow = tomorrow.strftime("%Y-%m-%d") - archived_jobs = Job.query.filter(Job.archived is True).count() + stmt = select(Job).where(Job.archived is True) + archived_jobs = db.session.execute(stmt).scalar() assert archived_jobs == 0 notify_api.test_cli_runner().invoke( @@ -242,7 +245,8 @@ def test_create_test_user_command(notify_db_session, notify_api): assert User.query.count() == user_count + 1 # that user should be the one we added - user = User.query.filter_by(name="Fake Personson").first() + stmt = select(User).where(name="Fake Personson") + user = db.session.execute(stmt).first() assert user.email_address == "somebody@fake.gov" assert user.auth_type == AuthType.SMS assert user.state == "active" @@ -281,10 +285,11 @@ def test_populate_annual_billing_with_defaults( populate_annual_billing_with_defaults, ["-y", 2022] ) - results = AnnualBilling.query.filter( + stmt = select(AnnualBilling).where( AnnualBilling.financial_year_start == 2022, AnnualBilling.service_id == service.id, - ).all() + ) + results = db.session.execute(stmt).scalars().all() assert len(results) == 1 assert results[0].free_sms_fragment_limit == expected_allowance @@ -306,10 +311,11 @@ def test_populate_annual_billing_with_the_previous_years_allowance( populate_annual_billing_with_defaults, ["-y", 2022] ) - results = AnnualBilling.query.filter( + stmt = select(AnnualBilling).where( AnnualBilling.financial_year_start == 2022, AnnualBilling.service_id == service.id, - ).all() + ) + results = db.session.execute(stmt).scalars().all() assert len(results) == 1 assert results[0].free_sms_fragment_limit == expected_allowance @@ -318,10 +324,11 @@ def test_populate_annual_billing_with_the_previous_years_allowance( populate_annual_billing_with_the_previous_years_allowance, ["-y", 2023] ) - results = AnnualBilling.query.filter( + stmt = select(AnnualBilling).where( AnnualBilling.financial_year_start == 2023, AnnualBilling.service_id == service.id, - ).all() + ) + results = db.session.execute(stmt).scalars().all() assert len(results) == 1 assert results[0].free_sms_fragment_limit == expected_allowance @@ -355,10 +362,11 @@ def test_populate_annual_billing_with_defaults_sets_free_allowance_to_zero_if_pr populate_annual_billing_with_defaults, ["-y", 2022] ) - results = AnnualBilling.query.filter( + stmt = select(AnnualBilling).where( AnnualBilling.financial_year_start == 2022, AnnualBilling.service_id == service.id, - ).all() + ) + results = db.session.execute(stmt).scalars().all() assert len(results) == 1 assert results[0].free_sms_fragment_limit == 0 @@ -410,7 +418,8 @@ def test_create_service_command(notify_db_session, notify_api): assert Service.query.count() == service_count + 1 # that service should be the one we added - service = Service.query.filter_by(name="Fake Service").first() + stmt = select(Service).where(name="Fake Service") + service = db.session.execute(stmt).first() assert service.email_from == "somebody@fake.gov" assert service.restricted is False assert service.message_limit == 40000 From bcd5e206ea9b1eca0b5c3574cde2204e0357fe5e Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 09:08:58 -0700 Subject: [PATCH 05/26] fix more --- tests/app/test_commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index a3e09b687..f497485de 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -418,7 +418,7 @@ def test_create_service_command(notify_db_session, notify_api): assert Service.query.count() == service_count + 1 # that service should be the one we added - stmt = select(Service).where(name="Fake Service") + stmt = select(Service).where(Service.name == "Fake Service") service = db.session.execute(stmt).first() assert service.email_from == "somebody@fake.gov" assert service.restricted is False From f002a6c34135283c90b5d93339fe4ef6bbc014fe Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 09:20:26 -0700 Subject: [PATCH 06/26] fix more --- tests/app/test_commands.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index f497485de..7a8bc59f3 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -118,7 +118,7 @@ def test_update_jobs_archived_flag(notify_db_session, notify_api): tomorrow = tomorrow.strftime("%Y-%m-%d") stmt = select(Job).where(Job.archived is True) - archived_jobs = db.session.execute(stmt).scalar() + archived_jobs = db.session.execute(stmt).scalar() or 0 assert archived_jobs == 0 notify_api.test_cli_runner().invoke( @@ -245,7 +245,7 @@ def test_create_test_user_command(notify_db_session, notify_api): assert User.query.count() == user_count + 1 # that user should be the one we added - stmt = select(User).where(name="Fake Personson") + stmt = select(User).where(User.name == "Fake Personson") user = db.session.execute(stmt).first() assert user.email_address == "somebody@fake.gov" assert user.auth_type == AuthType.SMS @@ -419,7 +419,7 @@ def test_create_service_command(notify_db_session, notify_api): # that service should be the one we added stmt = select(Service).where(Service.name == "Fake Service") - service = db.session.execute(stmt).first() + service = db.session.execute(stmt).scalars().first() assert service.email_from == "somebody@fake.gov" assert service.restricted is False assert service.message_limit == 40000 From a2e6e06d3fe5876afa7516a645c35d273b59e42c Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 09:28:12 -0700 Subject: [PATCH 07/26] fix more --- tests/app/test_commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index 7a8bc59f3..361de9a40 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -246,7 +246,7 @@ def test_create_test_user_command(notify_db_session, notify_api): # that user should be the one we added stmt = select(User).where(User.name == "Fake Personson") - user = db.session.execute(stmt).first() + user = db.session.execute(stmt).scalars().first() assert user.email_address == "somebody@fake.gov" assert user.auth_type == AuthType.SMS assert user.state == "active" From 08c9bf54d1cfe65977ad446369de20bd335ac028 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 09:36:11 -0700 Subject: [PATCH 08/26] fix more --- tests/app/dao/test_provider_details_dao.py | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py index fd8f4a43d..624ce2bb5 100644 --- a/tests/app/dao/test_provider_details_dao.py +++ b/tests/app/dao/test_provider_details_dao.py @@ -2,8 +2,9 @@ from datetime import datetime, timedelta import pytest from freezegun import freeze_time +from sqlalchemy import delete, select -from app import notification_provider_clients +from app import db, notification_provider_clients from app.dao.provider_details_dao import ( _get_sms_providers_for_update, dao_get_provider_stats, @@ -65,17 +66,20 @@ def test_can_get_email_providers(notify_db_session): def test_should_not_error_if_any_provider_in_code_not_in_database( restore_provider_details, ): - ProviderDetails.query.filter_by(identifier="sns").delete() + stmt = delete(ProviderDetails).where(ProviderDetails.identifier == "sns") + db.session.execute(stmt) + db.session.commit() + # ProviderDetails.query.filter_by(identifier="sns").delete() assert notification_provider_clients.get_sms_client("sns") @freeze_time("2000-01-01T00:00:00") def test_update_adds_history(restore_provider_details): - ses = ProviderDetails.query.filter(ProviderDetails.identifier == "ses").one() - ses_history = ProviderDetailsHistory.query.filter( - ProviderDetailsHistory.id == ses.id - ).one() + stmt = select(ProviderDetails).where(ProviderDetails.identifier == "ses") + ses = db.session.execute(stmt).scalars().one() + stmt = select(ProviderDetailsHistory).where(ProviderDetailsHistory.id == ses.id) + ses_history = db.session.execute(stmt).scalars().one() assert ses.version == 1 assert ses_history.version == 1 @@ -88,11 +92,12 @@ def test_update_adds_history(restore_provider_details): assert not ses.active assert ses.updated_at == datetime(2000, 1, 1, 0, 0, 0) - ses_history = ( - ProviderDetailsHistory.query.filter(ProviderDetailsHistory.id == ses.id) + stmt = ( + select(ProviderDetailsHistory) + .where(ProviderDetailsHistory.id == ses.id) .order_by(ProviderDetailsHistory.version) - .all() ) + ses_history = db.session.execute(stmt).scalars().all() assert ses_history[0].active assert ses_history[0].version == 1 From 79ceddfee40be0f92a4911106911cefd160210b0 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 10:35:49 -0700 Subject: [PATCH 09/26] more fixes --- tests/app/dao/test_invited_user_dao.py | 70 +++++++++------------- tests/app/dao/test_provider_details_dao.py | 19 ++++-- tests/app/test_schemas.py | 13 ++-- 3 files changed, 49 insertions(+), 53 deletions(-) diff --git a/tests/app/dao/test_invited_user_dao.py b/tests/app/dao/test_invited_user_dao.py index da52e52e7..247a3dfda 100644 --- a/tests/app/dao/test_invited_user_dao.py +++ b/tests/app/dao/test_invited_user_dao.py @@ -2,6 +2,7 @@ import uuid from datetime import timedelta import pytest +from sqlalchemy import func, select from sqlalchemy.orm.exc import NoResultFound from app import db @@ -123,23 +124,17 @@ def test_should_delete_all_invitations_more_than_one_day_old( ): make_invitation(sample_user, sample_service, age=timedelta(hours=48)) make_invitation(sample_user, sample_service, age=timedelta(hours=48)) - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 2 - ) + stmt = select(InvitedUser).where(InvitedUser.status != InvitedUserStatus.EXPIRED) + result = db.session.execute(stmt).scalars().all() + assert len(result) == 2 expire_invitations_created_more_than_two_days_ago() - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 0 + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 def test_should_not_delete_invitations_less_than_two_days_old( @@ -160,35 +155,28 @@ def test_should_not_delete_invitations_less_than_two_days_old( email_address="expired@1.com", ) - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 2 + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 2 expire_invitations_created_more_than_two_days_ago() - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 1 - ) - assert ( - InvitedUser.query.filter(InvitedUser.status != InvitedUserStatus.EXPIRED) - .first() - .email_address - == "valid@2.com" - ) - assert ( - InvitedUser.query.filter(InvitedUser.status == InvitedUserStatus.EXPIRED) - .first() - .email_address - == "expired@1.com" + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 + stmt = select(InvitedUser).where(InvitedUser.status != InvitedUserStatus.EXPIRED) + invited_user = db.session.execute(stmt).scalars().first() + assert invited_user.email_address == "valid@2.com" + stmt = select(InvitedUser).where(InvitedUser.status == InvitedUserStatus.EXPIRED) + invited_user = db.session.execute(stmt).scalars().first() + + assert invited_user.email_address == "expired@1.com" def make_invitation(user, service, age=None, email_address="test@test.com"): diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py index 624ce2bb5..84c4b2238 100644 --- a/tests/app/dao/test_provider_details_dao.py +++ b/tests/app/dao/test_provider_details_dao.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta import pytest from freezegun import freeze_time -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from app import db, notification_provider_clients from app.dao.provider_details_dao import ( @@ -69,7 +69,6 @@ def test_should_not_error_if_any_provider_in_code_not_in_database( stmt = delete(ProviderDetails).where(ProviderDetails.identifier == "sns") db.session.execute(stmt) db.session.commit() - # ProviderDetails.query.filter_by(identifier="sns").delete() assert notification_provider_clients.get_sms_client("sns") @@ -135,9 +134,13 @@ def test_get_alternative_sms_provider_fails_if_unrecognised(): @freeze_time("2016-01-01 01:00") def test_get_sms_providers_for_update_returns_providers(restore_provider_details): - ProviderDetails.query.filter(ProviderDetails.identifier == "sns").update( - {"updated_at": None} + stmt = ( + update(ProviderDetails) + .where(ProviderDetails.identifier == "sns") + .values({"updated_at": None}) ) + db.session.execute(stmt) + db.session.commit() resp = _get_sms_providers_for_update(timedelta(hours=1)) @@ -149,9 +152,13 @@ def test_get_sms_providers_for_update_returns_nothing_if_recent_updates( restore_provider_details, ): fifty_nine_minutes_ago = datetime(2016, 1, 1, 0, 1) - ProviderDetails.query.filter(ProviderDetails.identifier == "sns").update( - {"updated_at": fifty_nine_minutes_ago} + stmt = ( + update(ProviderDetails) + .where(ProviderDetails.identifier == "sns") + .values({"updated_at": fifty_nine_minutes_ago}) ) + db.session.execute(stmt) + db.session.commit() resp = _get_sms_providers_for_update(timedelta(hours=1)) diff --git a/tests/app/test_schemas.py b/tests/app/test_schemas.py index 270c36a17..b71d2fef8 100644 --- a/tests/app/test_schemas.py +++ b/tests/app/test_schemas.py @@ -2,8 +2,9 @@ import datetime import pytest from marshmallow import ValidationError -from sqlalchemy import desc +from sqlalchemy import desc, select +from app import db from app.dao.provider_details_dao import ( dao_update_provider_details, get_provider_details_by_identifier, @@ -145,13 +146,13 @@ def test_provider_details_history_schema_returns_user_details( dao_update_provider_details(current_sms_provider) - current_sms_provider_in_history = ( - ProviderDetailsHistory.query.filter( - ProviderDetailsHistory.id == current_sms_provider.id - ) + stmt = ( + select(ProviderDetailsHistory) + .where(ProviderDetailsHistory.id == current_sms_provider.id) .order_by(desc(ProviderDetailsHistory.version)) - .first() ) + current_sms_provider_in_history = db.session.execute(stmt).scalars().first() + data = provider_details_schema.dump(current_sms_provider_in_history) assert sorted(data["created_by"].keys()) == sorted(["id", "email_address", "name"]) From 7ee741b91c7c189db912cf83d7f9e40197d69d8c Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 11:15:08 -0700 Subject: [PATCH 10/26] fix more tests --- .ds.baseline | 6 +-- tests/app/conftest.py | 29 +++++++++----- ...t_notification_dao_delete_notifications.py | 13 +++--- tests/app/db.py | 19 ++++++--- tests/app/service/test_rest.py | 40 +++++++++++++------ tests/app/template/test_rest.py | 5 ++- 6 files changed, 75 insertions(+), 37 deletions(-) diff --git a/.ds.baseline b/.ds.baseline index dd916c550..41a911ddd 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -267,7 +267,7 @@ "filename": "tests/app/db.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 87, + "line_number": 90, "is_secret": false } ], @@ -305,7 +305,7 @@ "filename": "tests/app/service/test_rest.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 1275, + "line_number": 1284, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-10-28T20:26:27Z" + "generated_at": "2024-10-30T18:15:03Z" } diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 25e9f3f08..38e2e80d2 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -6,6 +6,7 @@ import pytest import pytz import requests_mock from flask import current_app, url_for +from sqlalchemy import select from sqlalchemy.orm.session import make_transient from app import db @@ -100,9 +101,10 @@ def create_sample_notification( if job is None and api_key is None: # we didn't specify in test - lets create it - api_key = ApiKey.query.filter( + stmt = select(ApiKey).where( ApiKey.service == template.service, ApiKey.key_type == key_type - ).first() + ) + api_key = db.session.execute(stmt).scalars().first() if not api_key: api_key = create_api_key(template.service, key_type=key_type) @@ -227,7 +229,8 @@ def sample_service(sample_user): "email_from": email_from, "created_by": sample_user, } - service = Service.query.filter_by(name=service_name).first() + stmt = select(Service).where(Service.name == service_name) + service = db.session.execute(stmt).scalars().first() if not service: service = Service(**data) dao_create_service(service, sample_user, service_permissions=None) @@ -442,9 +445,10 @@ def sample_notification(notify_db_session): service = create_service(check_if_service_exists=True) template = create_template(service=service) - api_key = ApiKey.query.filter( + stmt = select(ApiKey).where( ApiKey.service == template.service, ApiKey.key_type == KeyType.NORMAL - ).first() + ) + api_key = db.session.execute(stmt).scalars().first() if not api_key: api_key = create_api_key(template.service, key_type=KeyType.NORMAL) @@ -595,9 +599,12 @@ def sample_user_service_permission(sample_user): permission = PermissionType.MANAGE_SETTINGS data = {"user": sample_user, "service": service, "permission": permission} - p_model = Permission.query.filter_by( - user=sample_user, service=service, permission=permission - ).first() + stmt = select(Permission).where( + Permission.user == sample_user, + Permission.service == service, + Permission.permission == permission, + ) + p_model = db.session.execute(stmt).scalars().first() if not p_model: p_model = Permission(**data) db.session.add(p_model) @@ -612,12 +619,14 @@ def fake_uuid(): @pytest.fixture(scope="function") def ses_provider(): - return ProviderDetails.query.filter_by(identifier="ses").one() + stmt = select(ProviderDetails).where(ProviderDetails.identifier == "ses") + return db.session.execute(stmt).scalars().one() @pytest.fixture(scope="function") def sns_provider(): - return ProviderDetails.query.filter_by(identifier="sns").one() + stmt = select(ProviderDetails).where(ProviderDetails.identifier == "sns") + return db.session.execute(stmt).scalars().one() @pytest.fixture(scope="function") diff --git a/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py b/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py index e22721216..45d8958d1 100644 --- a/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py +++ b/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py @@ -2,7 +2,9 @@ import uuid from datetime import datetime, timedelta from freezegun import freeze_time +from sqlalchemy import func, select +from app import db from app.dao.notifications_dao import ( insert_notification_history_delete_notifications, move_notifications_to_notification_history, @@ -172,12 +174,13 @@ def test_move_notifications_just_deletes_test_key_notifications(sample_template) assert Notification.query.count() == 0 assert NotificationHistory.query.count() == 2 - assert ( - NotificationHistory.query.filter( - NotificationHistory.key_type == KeyType.TEST - ).count() - == 0 + stmt = ( + select(func.count()) + .select_from(NotificationHistory) + .where(NotificationHistory.key_type == KeyType.TEST) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 @freeze_time("2020-03-20 14:00") diff --git a/tests/app/db.py b/tests/app/db.py index b62f99b4e..07b395295 100644 --- a/tests/app/db.py +++ b/tests/app/db.py @@ -2,6 +2,8 @@ import random import uuid from datetime import datetime, timedelta +from sqlalchemy import select + from app import db from app.dao import fact_processing_time_dao from app.dao.email_branding_dao import dao_create_email_branding @@ -90,7 +92,8 @@ def create_user( "state": state, "platform_admin": platform_admin, } - user = User.query.filter_by(email_address=email).first() + stmt = select(User).where(User.email_address == email) + user = db.session.execute(stmt).scalars().first() if not user: user = User(**data) save_model_user(user, validated_email_access=True) @@ -130,7 +133,8 @@ def create_service( billing_reference=None, ): if check_if_service_exists: - service = Service.query.filter_by(name=service_name).first() + stmt = select(Service).where(Service.name == service_name) + service = db.session.execute(stmt).scalars().first() if (not check_if_service_exists) or (check_if_service_exists and not service): service = Service( name=service_name, @@ -175,7 +179,8 @@ def create_service( def create_service_with_inbound_number(inbound_number="1234567", *args, **kwargs): service = create_service(*args, **kwargs) - sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).first() + stmt = select(ServiceSmsSender).where(ServiceSmsSender.service_id == service.id) + sms_sender = db.session.execute(stmt).scalars().first() inbound = create_inbound_number(number=inbound_number, service_id=service.id) update_existing_sms_sender_with_inbound_number( service_sms_sender=sms_sender, @@ -189,7 +194,8 @@ def create_service_with_inbound_number(inbound_number="1234567", *args, **kwargs def create_service_with_defined_sms_sender(sms_sender_value="1234567", *args, **kwargs): service = create_service(*args, **kwargs) - sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).first() + stmt = select(ServiceSmsSender).where(ServiceSmsSender.service_id == service.id) + sms_sender = db.session.execute(stmt).scalars().first() dao_update_service_sms_sender( service_id=service.id, service_sms_sender_id=sms_sender.id, @@ -286,9 +292,10 @@ def create_notification( if not one_off and (job is None and api_key is None): # we did not specify in test - lets create it - api_key = ApiKey.query.filter( + stmt = select(ApiKey).where( ApiKey.service == template.service, ApiKey.key_type == key_type - ).first() + ) + api_key = db.session.execute(stmt).scalars().first() if not api_key: api_key = create_api_key(template.service, key_type=key_type) diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index ecec87ec1..1a48014ec 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -6,8 +6,10 @@ from unittest.mock import ANY import pytest from flask import current_app, url_for from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import SQLAlchemyError +from app import db from app.dao.organization_dao import dao_add_service_to_organization from app.dao.service_sms_sender_dao import dao_get_sms_senders_by_service_id from app.dao.service_user_dao import dao_get_service_user @@ -424,9 +426,8 @@ def test_create_service( assert json_resp["data"]["name"] == "created service" - service_sms_senders = ServiceSmsSender.query.filter_by( - service_id=service_db.id - ).all() + stmt = select(ServiceSmsSender).where(ServiceSmsSender.service_id == service_db.id) + service_sms_senders = db.session.execute(stmt).scalars().all() assert len(service_sms_senders) == 1 assert service_sms_senders[0].sms_sender == current_app.config["FROM_NUMBER"] @@ -530,7 +531,13 @@ def test_create_service_should_raise_exception_and_not_create_service_if_annual_ annual_billing = AnnualBilling.query.all() assert len(annual_billing) == 0 - assert len(Service.query.filter(Service.name == "created service").all()) == 0 + stmt = ( + select(func.count()) + .select_from(Service) + .where(Service.name == "created service") + ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 def test_create_service_inherits_branding_from_organization( @@ -933,7 +940,8 @@ def test_update_service_flags_will_remove_service_permissions( assert resp.status_code == 200 assert ServicePermissionType.INTERNATIONAL_SMS not in result["data"]["permissions"] - permissions = ServicePermission.query.filter_by(service_id=service.id).all() + stmt = select(ServicePermission).where(ServicePermission.service_id == service.id) + permissions = db.session.execute(stmt).scalars().all() assert {p.permission for p in permissions} == { ServicePermissionType.SMS, ServicePermissionType.EMAIL, @@ -1004,9 +1012,10 @@ def test_add_service_permission_will_add_permission( headers=[("Content-Type", "application/json"), auth_header], ) - permissions = ServicePermission.query.filter_by( - service_id=service_with_no_permissions.id - ).all() + stmt = select(ServicePermission).where( + ServicePermission.service_id == service_with_no_permissions.id + ) + permissions = db.session.execute(stmt).scalars().all() assert resp.status_code == 200 assert [p.permission for p in permissions] == [permission_to_add] @@ -3318,8 +3327,13 @@ def test_add_service_sms_sender_when_it_is_an_inbound_number_inserts_new_sms_sen assert resp_json["inbound_number_id"] == str(inbound_number.id) assert resp_json["is_default"] - senders = ServiceSmsSender.query.filter_by(service_id=service.id).all() - assert len(senders) == 3 + stmt = ( + select(func.count()) + .select_from(ServiceSmsSender) + .where(ServiceSmsSender.service_id == service.id) + ) + senders = db.session.execute(stmt).scalar() or 0 + assert senders == 3 def test_add_service_sms_sender_switches_default(client, notify_db_session): @@ -3341,7 +3355,8 @@ def test_add_service_sms_sender_switches_default(client, notify_db_session): assert resp_json["sms_sender"] == "second" assert not resp_json["inbound_number_id"] assert resp_json["is_default"] - sms_senders = ServiceSmsSender.query.filter_by(sms_sender="first").first() + stmt = select(ServiceSmsSender).where(ServiceSmsSender.sms_sender == "first") + sms_senders = db.session.execute(stmt).scalars().first() assert not sms_senders.is_default @@ -3407,7 +3422,8 @@ def test_update_service_sms_sender_switches_default(client, notify_db_session): assert resp_json["sms_sender"] == "second" assert not resp_json["inbound_number_id"] assert resp_json["is_default"] - sms_senders = ServiceSmsSender.query.filter_by(sms_sender="first").first() + stmt = select(ServiceSmsSender).where(ServiceSmsSender.sms_sender == "first") + sms_senders = db.session.execute(stmt).scalars().first() assert not sms_senders.is_default diff --git a/tests/app/template/test_rest.py b/tests/app/template/test_rest.py index 45dfc24f9..d46627343 100644 --- a/tests/app/template/test_rest.py +++ b/tests/app/template/test_rest.py @@ -6,7 +6,9 @@ from datetime import datetime, timedelta import pytest from freezegun import freeze_time +from sqlalchemy import select +from app import db from app.dao.templates_dao import dao_get_template_by_id, dao_redact_template from app.enums import ServicePermissionType, TemplateProcessType, TemplateType from app.models import Template, TemplateHistory @@ -86,7 +88,8 @@ def test_create_a_new_template_for_a_service_adds_folder_relationship( data=data, ) assert response.status_code == 201 - template = Template.query.filter(Template.name == "my template").first() + stmt = select(Template).where(Template.name == "my template") + template = db.session.execute(stmt).scalars().first() assert template.folder == parent_folder From 9dfbd991d568405c07eec06d9f9a8b995d9b6a29 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 11:44:51 -0700 Subject: [PATCH 11/26] fix more tests --- app/commands.py | 59 ++++++++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/app/commands.py b/app/commands.py index 5580e7632..c88e2bcb3 100644 --- a/app/commands.py +++ b/app/commands.py @@ -12,7 +12,7 @@ from click_datetime import Datetime as click_dt from faker import Faker from flask import current_app, json from notifications_python_client.authentication import create_jwt_token -from sqlalchemy import and_, text +from sqlalchemy import and_, select, text, update from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound @@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix): if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]: current_app.logger.error("Can only be run in development") return - - users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all() + stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%")) + users = db.session.execute(stmt).scalars().all() for usr in users: # Make sure the full email includes a uuid in it # Just in case someone decides to use a similar email address. @@ -338,9 +338,10 @@ def populate_organizations_from_file(file_name): email_branding = None email_branding_column = columns[5].strip() if len(email_branding_column) > 0: - email_branding = EmailBranding.query.filter( + stmt = select(EmailBranding).where( EmailBranding.name == email_branding_column - ).one() + ) + email_branding = db.session.execute(stmt).scalars().one() data = { "name": columns[0], "active": True, @@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name): @notify_command(name="associate-services-to-organizations") def associate_services_to_organizations(): - services = Service.get_history_model().query.filter_by(version=1).all() + stmt = select(Service.get_history_model()).where( + Service.get_history_model().version == 1 + ) + services = db.session.execute(stmt).scalars().all() for s in services: - created_by_user = User.query.filter_by(id=s.created_by_id).first() + stmt = select(User).where(User.id == s.created_by_id) + created_by_user = db.session.execute(stmt).scalars().first() organization = dao_get_organization_by_email_address( created_by_user.email_address ) @@ -467,15 +472,16 @@ def populate_go_live(file_name): @notify_command(name="fix-billable-units") def fix_billable_units(): - query = Notification.query.filter( + stmt = select(Notification).where( Notification.notification_type == NotificationType.SMS, Notification.status != NotificationStatus.CREATED, Notification.sent_at == None, # noqa Notification.billable_units == 0, Notification.key_type != KeyType.TEST, ) + all = db.session.execute(stmt).scalars().all() - for notification in query.all(): + for notification in all: template_model = dao_get_template_by_id( notification.template_id, notification.template_version ) @@ -490,9 +496,12 @@ def fix_billable_units(): f"Updating notification: {notification.id} with {template.fragment_count} billable_units" ) - Notification.query.filter(Notification.id == notification.id).update( - {"billable_units": template.fragment_count} + stmt = ( + update(Notification) + .where(Notification.id == notification.id) + .values({"billable_units": template.fragment_count}) ) + db.session.execute(stmt) db.session.commit() current_app.logger.info("End fix_billable_units") @@ -637,8 +646,8 @@ def populate_annual_billing_with_defaults(year, missing_services_only): This is useful to ensure all services start the new year with the correct annual billing. """ if missing_services_only: - active_services = ( - Service.query.filter(Service.active) + stmt = ( + select(Service) .outerjoin( AnnualBilling, and_( @@ -646,20 +655,18 @@ def populate_annual_billing_with_defaults(year, missing_services_only): AnnualBilling.financial_year_start == year, ), ) - .filter(AnnualBilling.id == None) # noqa - .all() + .where(Service.active, AnnualBilling.id == None) # noqa ) + active_services = db.session.execute(stmt).scalars().all() else: - active_services = Service.query.filter(Service.active).all() + stmt = select(Service).where(Service.active) + active_services = db.session.execute(stmt).scalars().all() previous_year = year - 1 - services_with_zero_free_allowance = ( - db.session.query(AnnualBilling.service_id) - .filter( - AnnualBilling.financial_year_start == previous_year, - AnnualBilling.free_sms_fragment_limit == 0, - ) - .all() + stmt = select(AnnualBilling.id).where( + AnnualBilling.financial_year_start == previous_year, + AnnualBilling.free_sms_fragment_limit == 0, ) + services_with_zero_free_allowance = db.session.execute(stmt).scalars().all() for service in active_services: # If a service has free_sms_fragment_limit for the previous year @@ -750,7 +757,8 @@ def create_user_jwt(token): def _update_template(id, name, template_type, content, subject): - template = Template.query.filter_by(id=id).first() + stmt = select(Template).where(Template.id == id) + template = db.session.execute(stmt).scalars().first() if not template: template = Template(id=id) template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553" @@ -761,7 +769,8 @@ def _update_template(id, name, template_type, content, subject): template.content = "\n".join(content) template.subject = subject - history = TemplateHistory.query.filter_by(id=id).first() + stmt = select(TemplateHistory).where(TemplateHistory.id == id) + history = db.session.execute(stmt).scalars().first() if not history: history = TemplateHistory(id=id) history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553" From 395ddd2c47ff8c8ee5774db15a13eae20bd2b62c Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 12:17:59 -0700 Subject: [PATCH 12/26] fix more tests --- app/dao/service_data_retention_dao.py | 48 ++++++++++++++++----------- app/dao/service_guest_list_dao.py | 14 ++++---- app/dao/webauthn_credential_dao.py | 7 ++-- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/app/dao/service_data_retention_dao.py b/app/dao/service_data_retention_dao.py index b95ca5720..cd2c1fd4b 100644 --- a/app/dao/service_data_retention_dao.py +++ b/app/dao/service_data_retention_dao.py @@ -1,3 +1,5 @@ +from sqlalchemy import select, update + from app import db from app.dao.dao_utils import autocommit from app.models import ServiceDataRetention @@ -5,29 +7,31 @@ from app.utils import utc_now def fetch_service_data_retention_by_id(service_id, data_retention_id): - data_retention = ServiceDataRetention.query.filter_by( - service_id=service_id, id=data_retention_id - ).first() - return data_retention + stmt = select(ServiceDataRetention).where( + ServiceDataRetention.service_id == service_id, + ServiceDataRetention.id == data_retention_id, + ) + return db.session.execute(stmt).scalars().first() def fetch_service_data_retention(service_id): - data_retention_list = ( - ServiceDataRetention.query.filter_by(service_id=service_id) + stmt = ( + select(ServiceDataRetention) + .where(ServiceDataRetention.service_id == service_id) .order_by( # in the order that models.notification_types are created (email, sms, letter) ServiceDataRetention.notification_type ) - .all() ) - return data_retention_list + return db.session.execute(stmt).scalars().all() def fetch_service_data_retention_by_notification_type(service_id, notification_type): - data_retention_list = ServiceDataRetention.query.filter_by( - service_id=service_id, notification_type=notification_type - ).first() - return data_retention_list + stmt = select(ServiceDataRetention).where( + ServiceDataRetention.service_id == service_id, + ServiceDataRetention.notification_type == notification_type, + ) + return db.session.execute(stmt).scalars().first() @autocommit @@ -46,16 +50,22 @@ def insert_service_data_retention(service_id, notification_type, days_of_retenti def update_service_data_retention( service_data_retention_id, service_id, days_of_retention ): - updated_count = ServiceDataRetention.query.filter( - ServiceDataRetention.id == service_data_retention_id, - ServiceDataRetention.service_id == service_id, - ).update({"days_of_retention": days_of_retention, "updated_at": utc_now()}) - return updated_count + stmt = ( + update(ServiceDataRetention) + .where( + ServiceDataRetention.id == service_data_retention_id, + ServiceDataRetention.service_id == service_id, + ) + .values({"days_of_retention": days_of_retention, "updated_at": utc_now()}) + ) + result = db.session.execute(stmt) + return result.rowcount def fetch_service_data_retention_for_all_services_by_notification_type( notification_type, ): - return ServiceDataRetention.query.filter( + stmt = select(ServiceDataRetention).where( ServiceDataRetention.notification_type == notification_type - ).all() + ) + return db.session.execute(stmt).scalars().all() diff --git a/app/dao/service_guest_list_dao.py b/app/dao/service_guest_list_dao.py index acd39703c..59d381a8b 100644 --- a/app/dao/service_guest_list_dao.py +++ b/app/dao/service_guest_list_dao.py @@ -1,11 +1,12 @@ +from sqlalchemy import delete, select + from app import db from app.models import ServiceGuestList def dao_fetch_service_guest_list(service_id): - return ServiceGuestList.query.filter( - ServiceGuestList.service_id == service_id - ).all() + stmt = select(ServiceGuestList).where(ServiceGuestList.service_id == service_id) + return db.session.execute(stmt).scalars().all() def dao_add_and_commit_guest_list_contacts(objs): @@ -14,6 +15,7 @@ def dao_add_and_commit_guest_list_contacts(objs): def dao_remove_service_guest_list(service_id): - return ServiceGuestList.query.filter( - ServiceGuestList.service_id == service_id - ).delete() + stmt = delete(ServiceGuestList).where(ServiceGuestList.service_id == service_id) + result = db.session.execute(stmt) + db.session.commit() + return result.rowcount diff --git a/app/dao/webauthn_credential_dao.py b/app/dao/webauthn_credential_dao.py index b34d3c014..4c7a0c888 100644 --- a/app/dao/webauthn_credential_dao.py +++ b/app/dao/webauthn_credential_dao.py @@ -1,13 +1,16 @@ +from sqlalchemy import select + from app import db from app.dao.dao_utils import autocommit from app.models import WebauthnCredential def dao_get_webauthn_credential_by_user_and_id(user_id, webauthn_credential_id): - return WebauthnCredential.query.filter( + stmt = select(WebauthnCredential).where( WebauthnCredential.user_id == user_id, WebauthnCredential.id == webauthn_credential_id, - ).one() + ) + return db.session.execute(stmt).scalars().one() @autocommit From 03312bfdc47c9ac6194b4a05944bcf4642bd2d28 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 13:09:30 -0700 Subject: [PATCH 13/26] fix more tests --- tests/app/dao/test_service_guest_list_dao.py | 10 +++- .../test_process_notification.py | 54 ++++++++++++------- tests/app/test_commands.py | 36 ++++++++----- 3 files changed, 67 insertions(+), 33 deletions(-) diff --git a/tests/app/dao/test_service_guest_list_dao.py b/tests/app/dao/test_service_guest_list_dao.py index 870c78bd8..021f42319 100644 --- a/tests/app/dao/test_service_guest_list_dao.py +++ b/tests/app/dao/test_service_guest_list_dao.py @@ -1,5 +1,8 @@ import uuid +from sqlalchemy import func, select + +from app import db from app.dao.service_guest_list_dao import ( dao_add_and_commit_guest_list_contacts, dao_fetch_service_guest_list, @@ -27,7 +30,8 @@ def test_add_and_commit_guest_list_contacts_saves_data(sample_service): dao_add_and_commit_guest_list_contacts([guest_list]) - db_contents = ServiceGuestList.query.all() + stmt = select(ServiceGuestList) + db_contents = db.session.execute(stmt).scalars().all() assert len(db_contents) == 1 assert db_contents[0].id == guest_list.id @@ -60,4 +64,6 @@ def test_remove_service_guest_list_does_not_commit( # since dao_remove_service_guest_list doesn't commit, we can still rollback its changes notify_db_session.rollback() - assert ServiceGuestList.query.count() == 1 + stmt = select(func.count()).select_from(ServiceGuestList) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 diff --git a/tests/app/notifications/test_process_notification.py b/tests/app/notifications/test_process_notification.py index d7caf5bb1..9f393b440 100644 --- a/tests/app/notifications/test_process_notification.py +++ b/tests/app/notifications/test_process_notification.py @@ -5,8 +5,10 @@ from collections import namedtuple import pytest from boto3.exceptions import Boto3Error from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import SQLAlchemyError +from app import db from app.enums import KeyType, NotificationType, ServicePermissionType, TemplateType from app.errors import BadRequestError from app.models import Notification, NotificationHistory @@ -67,12 +69,22 @@ def test_create_content_for_notification_allows_additional_personalisation( ) +def _get_notification_query_count(): + stmt = select(func.count()).select_from(Notification) + return db.session.execute(stmt).scalar() or 0 + + +def _get_notification_history_query_count(): + stmt = select(func.count()).select_from(NotificationHistory) + return db.session.execute(stmt).scalar() or 0 + + @freeze_time("2016-01-01 11:09:00.061258") def test_persist_notification_creates_and_save_to_db( sample_template, sample_api_key, sample_job ): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 notification = persist_notification( template_id=sample_template.id, template_version=sample_template.version, @@ -114,8 +126,8 @@ def test_persist_notification_creates_and_save_to_db( def test_persist_notification_throws_exception_when_missing_template(sample_api_key): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 with pytest.raises(SQLAlchemyError): persist_notification( template_id=None, @@ -127,14 +139,14 @@ def test_persist_notification_throws_exception_when_missing_template(sample_api_ api_key_id=sample_api_key.id, key_type=sample_api_key.key_type, ) - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 @freeze_time("2016-01-01 11:09:00.061258") def test_persist_notification_with_optionals(sample_job, sample_api_key): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 n_id = uuid.uuid4() created_at = datetime.datetime(2016, 11, 11, 16, 8, 18) persist_notification( @@ -153,9 +165,10 @@ def test_persist_notification_with_optionals(sample_job, sample_api_key): notification_id=n_id, created_by_id=sample_job.created_by_id, ) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 - persisted_notification = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.id == n_id assert persisted_notification.job_id == sample_job.id assert persisted_notification.job_row_number == 10 @@ -267,8 +280,8 @@ def test_send_notification_to_queue_throws_exception_deletes_notification( queue="send-sms-tasks", ) - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 @pytest.mark.parametrize( @@ -349,7 +362,8 @@ def test_persist_notification_with_international_info_stores_correct_info( job_row_number=10, client_reference="ref from client", ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.international is expected_international assert persisted_notification.phone_prefix == expected_prefix @@ -372,7 +386,8 @@ def test_persist_notification_with_international_info_does_not_store_for_email( job_row_number=10, client_reference="ref from client", ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.international is False assert persisted_notification.phone_prefix is None @@ -404,7 +419,8 @@ def test_persist_sms_notification_stores_normalised_number( key_type=sample_api_key.key_type, job_id=sample_job.id, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.to == "1" assert persisted_notification.normalised_to == "1" @@ -428,7 +444,8 @@ def test_persist_email_notification_stores_normalised_email( key_type=sample_api_key.key_type, job_id=sample_job.id, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.to == "1" assert persisted_notification.normalised_to == "1" @@ -449,6 +466,7 @@ def test_persist_notification_with_billable_units_stores_correct_info(mocker): key_type=KeyType.NORMAL, billable_units=3, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.billable_units == 3 diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index 361de9a40..a273efe8d 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock, mock_open import pytest -from sqlalchemy import select +from sqlalchemy import func, select from app import db from app.commands import ( @@ -56,8 +56,13 @@ from tests.app.db import ( ) +def _get_user_query_count(): + stmt = select(func.count()).select_from(User) + return db.session.execute(stmt).scalar() or 0 + + def test_purge_functional_test_data(notify_db_session, notify_api): - orig_user_count = User.query.count() + orig_user_count = _get_user_query_count() notify_api.test_cli_runner().invoke( create_test_user, @@ -73,16 +78,16 @@ def test_purge_functional_test_data(notify_db_session, notify_api): ], ) - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == orig_user_count + 1 notify_api.test_cli_runner().invoke(purge_functional_test_data, ["-u", "somebody"]) # if the email address has a uuid, it is test data so it should be purged and there should be # zero users. Otherwise, it is real data so there should be one user. - assert User.query.count() == orig_user_count + assert _get_user_query_count() == orig_user_count def test_purge_functional_test_data_bad_mobile(notify_db_session, notify_api): - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == 0 # run the command command_response = notify_api.test_cli_runner().invoke( @@ -101,7 +106,7 @@ def test_purge_functional_test_data_bad_mobile(notify_db_session, notify_api): # The bad mobile phone number results in a bad parameter error, # leading to a system exit 2 and no entry made in db assert "SystemExit(2)" in str(command_response) - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == 0 @@ -136,8 +141,13 @@ def test_update_jobs_archived_flag(notify_db_session, notify_api): assert job.archived is True +def _get_organization_query_count(): + stmt = select(Organization) + return db.session.execute(stmt).scalar() or 0 + + def test_populate_organizations_from_file(notify_db_session, notify_api): - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 0 file_name = "./tests/app/orgs1.csv" @@ -152,7 +162,7 @@ def test_populate_organizations_from_file(notify_db_session, notify_api): os.remove(file_name) print(f"command_response = {command_response}") - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 @@ -161,10 +171,10 @@ def test_populate_organization_agreement_details_from_file( ): file_name = "./tests/app/orgs.csv" - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 0 create_organization() - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 org = Organization.query.one() @@ -183,7 +193,7 @@ def test_populate_organization_agreement_details_from_file( ) print(f"command_response = {command_response}") - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 org = Organization.query.one() assert org.agreement_signed_on_behalf_of_name == "bob" @@ -224,7 +234,7 @@ def test_bulk_invite_user_to_service( def test_create_test_user_command(notify_db_session, notify_api): # number of users before adding ours - user_count = User.query.count() + user_count = _get_user_query_count() # run the command notify_api.test_cli_runner().invoke( @@ -242,7 +252,7 @@ def test_create_test_user_command(notify_db_session, notify_api): ) # there should be one more user - assert User.query.count() == user_count + 1 + assert _get_user_query_count() == user_count + 1 # that user should be the one we added stmt = select(User).where(User.name == "Fake Personson") From 50a9730f0b312f4adf737a76a0f4fd73cba10fc3 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 13:25:16 -0700 Subject: [PATCH 14/26] fix more tests --- app/dao/service_guest_list_dao.py | 1 - tests/app/test_commands.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/app/dao/service_guest_list_dao.py b/app/dao/service_guest_list_dao.py index 59d381a8b..8e128a213 100644 --- a/app/dao/service_guest_list_dao.py +++ b/app/dao/service_guest_list_dao.py @@ -17,5 +17,4 @@ def dao_add_and_commit_guest_list_contacts(objs): def dao_remove_service_guest_list(service_id): stmt = delete(ServiceGuestList).where(ServiceGuestList.service_id == service_id) result = db.session.execute(stmt) - db.session.commit() return result.rowcount diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index a273efe8d..b4d9033b1 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -142,7 +142,7 @@ def test_update_jobs_archived_flag(notify_db_session, notify_api): def _get_organization_query_count(): - stmt = select(Organization) + stmt = select(func.count()).select_from(Organization) return db.session.execute(stmt).scalar() or 0 From 4cb360ecb21de5d3e0bd8c91bf645dbd8fe61214 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 14:01:04 -0700 Subject: [PATCH 15/26] fix more --- tests/app/celery/test_tasks.py | 21 +++++++++++++-------- tests/app/dao/test_invited_user_dao.py | 17 +++++++++++------ tests/app/service/test_api_key_endpoints.py | 17 ++++++++++++----- tests/app/user/test_rest_verify.py | 12 +++++++++--- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index a0fd70584..5720b15f9 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -531,7 +531,12 @@ def test_should_not_save_sms_if_restricted_service_and_invalid_number( encryption.encrypt(notification), ) assert provider_tasks.deliver_sms.apply_async.called is False - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 + + +def _get_notification_query_count(): + stmt = select(func.count()).select_from(Notification) + return db.session.execute(stmt).scalar() or 0 def test_should_not_save_email_if_restricted_service_and_invalid_email_address( @@ -553,7 +558,7 @@ def test_should_not_save_email_if_restricted_service_and_invalid_email_address( encryption.encrypt(notification), ) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_should_save_sms_template_to_and_persist_with_job_id(sample_job, mocker): @@ -593,7 +598,7 @@ def test_should_save_sms_template_to_and_persist_with_job_id(sample_job, mocker) def test_should_not_save_sms_if_team_key_and_recipient_not_in_team( notify_db_session, mocker ): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 user = create_user(mobile_number="2028675309") service = create_service(user=user, restricted=True) template = create_template(service=service) @@ -611,7 +616,7 @@ def test_should_not_save_sms_if_team_key_and_recipient_not_in_team( encryption.encrypt(notification), ) assert provider_tasks.deliver_sms.apply_async.called is False - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_should_use_email_template_and_persist( @@ -836,7 +841,7 @@ def test_save_sms_should_go_to_retry_queue_if_database_errors(sample_template, m assert provider_tasks.deliver_sms.apply_async.called is False tasks.save_sms.retry.assert_called_with(exc=expected_exception, queue="retry-tasks") - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_save_email_should_go_to_retry_queue_if_database_errors( @@ -866,7 +871,7 @@ def test_save_email_should_go_to_retry_queue_if_database_errors( exc=expected_exception, queue="retry-tasks" ) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_save_email_does_not_send_duplicate_and_does_not_put_in_retry_queue( @@ -888,7 +893,7 @@ def test_save_email_does_not_send_duplicate_and_does_not_put_in_retry_queue( notification_id, encryption.encrypt(json), ) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 assert not deliver_email.called assert not retry.called @@ -912,7 +917,7 @@ def test_save_sms_does_not_send_duplicate_and_does_not_put_in_retry_queue( notification_id, encryption.encrypt(json), ) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 assert not deliver_sms.called assert not retry.called diff --git a/tests/app/dao/test_invited_user_dao.py b/tests/app/dao/test_invited_user_dao.py index 247a3dfda..44fc23572 100644 --- a/tests/app/dao/test_invited_user_dao.py +++ b/tests/app/dao/test_invited_user_dao.py @@ -19,8 +19,13 @@ from app.utils import utc_now from tests.app.db import create_invited_user +def _get_invited_user_count(): + stmt = select(func.count()).select_from(InvitedUser) + return db.session.execute(stmt).scalar() or 0 + + def test_create_invited_user(notify_db_session, sample_service): - assert InvitedUser.query.count() == 0 + assert _get_invited_user_count() == 0 email_address = "invited_user@service.gov.uk" invite_from = sample_service.users[0] @@ -35,7 +40,7 @@ def test_create_invited_user(notify_db_session, sample_service): invited_user = InvitedUser(**data) save_invited_user(invited_user) - assert InvitedUser.query.count() == 1 + assert _get_invited_user_count() == 1 assert invited_user.email_address == email_address assert invited_user.from_user == invite_from permissions = invited_user.get_permissions() @@ -48,7 +53,7 @@ def test_create_invited_user(notify_db_session, sample_service): def test_create_invited_user_sets_default_folder_permissions_of_empty_list( sample_service, ): - assert InvitedUser.query.count() == 0 + assert _get_invited_user_count() == 0 invite_from = sample_service.users[0] data = { @@ -61,7 +66,7 @@ def test_create_invited_user_sets_default_folder_permissions_of_empty_list( invited_user = InvitedUser(**data) save_invited_user(invited_user) - assert InvitedUser.query.count() == 1 + assert _get_invited_user_count() == 1 assert invited_user.folder_permissions == [] @@ -109,12 +114,12 @@ def test_get_invited_users_for_service_that_has_no_invites( def test_save_invited_user_sets_status_to_cancelled( notify_db_session, sample_invited_user ): - assert InvitedUser.query.count() == 1 + assert _get_invited_user_count() == 1 saved = InvitedUser.query.get(sample_invited_user.id) assert saved.status == InvitedUserStatus.PENDING saved.status = InvitedUserStatus.CANCELLED save_invited_user(saved) - assert InvitedUser.query.count() == 1 + assert _get_invited_user_count() == 1 cancelled_invited_user = InvitedUser.query.get(sample_invited_user.id) assert cancelled_invited_user.status == InvitedUserStatus.CANCELLED diff --git a/tests/app/service/test_api_key_endpoints.py b/tests/app/service/test_api_key_endpoints.py index 8ca0e374d..09a964b3c 100644 --- a/tests/app/service/test_api_key_endpoints.py +++ b/tests/app/service/test_api_key_endpoints.py @@ -1,7 +1,9 @@ import json from flask import url_for +from sqlalchemy import func, select +from app import db from app.dao.api_key_dao import expire_api_key from app.enums import KeyType from app.models import ApiKey @@ -60,10 +62,15 @@ def test_create_api_key_without_key_type_rejects(client, sample_service): assert json_resp["message"] == {"key_type": ["Missing data for required field."]} +def _get_api_key_count(): + stmt = select(func.count()).select_from(ApiKey) + return db.session.execute(stmt).scalar() or 0 + + def test_revoke_should_expire_api_key_for_service(notify_api, sample_api_key): with notify_api.test_request_context(): with notify_api.test_client() as client: - assert ApiKey.query.count() == 1 + assert _get_api_key_count() == 1 auth_header = create_admin_authorization_header() response = client.post( url_for( @@ -83,7 +90,7 @@ def test_api_key_should_create_multiple_new_api_key_for_service( ): with notify_api.test_request_context(): with notify_api.test_client() as client: - assert ApiKey.query.count() == 0 + assert _get_api_key_count() == 0 data = { "name": "some secret name", "created_by": str(sample_service.created_by.id), @@ -96,7 +103,7 @@ def test_api_key_should_create_multiple_new_api_key_for_service( headers=[("Content-Type", "application/json"), auth_header], ) assert response.status_code == 201 - assert ApiKey.query.count() == 1 + assert _get_api_key_count() == 1 data["name"] = "another secret name" auth_header = create_admin_authorization_header() @@ -109,7 +116,7 @@ def test_api_key_should_create_multiple_new_api_key_for_service( assert json.loads(response.get_data(as_text=True)) != json.loads( response2.get_data(as_text=True) ) - assert ApiKey.query.count() == 2 + assert _get_api_key_count() == 2 def test_get_api_keys_should_return_all_keys_for_service(notify_api, sample_api_key): @@ -130,7 +137,7 @@ def test_get_api_keys_should_return_all_keys_for_service(notify_api, sample_api_ service_id=one_to_expire.service_id, api_key_id=one_to_expire.id ) - assert ApiKey.query.count() == 4 + assert _get_api_key_count() == 4 auth_header = create_admin_authorization_header() response = client.get( diff --git a/tests/app/user/test_rest_verify.py b/tests/app/user/test_rest_verify.py index ff74f6b57..d32d923bf 100644 --- a/tests/app/user/test_rest_verify.py +++ b/tests/app/user/test_rest_verify.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta import pytest from flask import current_app, url_for from freezegun import freeze_time +from sqlalchemy import func, select import app.celery.tasks from app import db @@ -295,7 +296,7 @@ def test_send_sms_code_returns_204_when_too_many_codes_already_created( ) db.session.add(verify_code) db.session.commit() - assert VerifyCode.query.count() == 5 + assert _get_verify_code_count() == 5 auth_header = create_admin_authorization_header() resp = client.post( url_for( @@ -307,7 +308,12 @@ def test_send_sms_code_returns_204_when_too_many_codes_already_created( headers=[("Content-Type", "application/json"), auth_header], ) assert resp.status_code == 204 - assert VerifyCode.query.count() == 5 + assert _get_verify_code_count() == 5 + + +def _get_verify_code_count(): + stmt = select(func.count()).select_from(VerifyCode) + return db.session.execute(stmt).scalar() or 0 @pytest.mark.parametrize( @@ -341,7 +347,7 @@ def test_send_new_user_email_verification( notify_service = email_verification_template.service assert resp.status_code == 204 notification = Notification.query.first() - assert VerifyCode.query.count() == 0 + assert _get_verify_code_count() == 0 mocked.assert_called_once_with( ([str(notification.id)]), queue="notify-internal-tasks" ) From dd2921406f49de2cbc62c3a66674992331b543f1 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 14:22:08 -0700 Subject: [PATCH 16/26] fix more --- app/commands.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/commands.py b/app/commands.py index c88e2bcb3..b625df464 100644 --- a/app/commands.py +++ b/app/commands.py @@ -655,7 +655,8 @@ def populate_annual_billing_with_defaults(year, missing_services_only): AnnualBilling.financial_year_start == year, ), ) - .where(Service.active, AnnualBilling.id == None) # noqa + .where(Service.active) + .where(AnnualBilling.id == None) # noqa ) active_services = db.session.execute(stmt).scalars().all() else: From 873079609a9f910e46b906df2b64481ebd54b133 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 14:52:32 -0700 Subject: [PATCH 17/26] revert test to 1.4 --- tests/app/test_commands.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index b4d9033b1..4acb9eedf 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -372,11 +372,10 @@ def test_populate_annual_billing_with_defaults_sets_free_allowance_to_zero_if_pr populate_annual_billing_with_defaults, ["-y", 2022] ) - stmt = select(AnnualBilling).where( + results = AnnualBilling.query.filter( AnnualBilling.financial_year_start == 2022, AnnualBilling.service_id == service.id, - ) - results = db.session.execute(stmt).scalars().all() + ).all() assert len(results) == 1 assert results[0].free_sms_fragment_limit == 0 From 9bdc29b003560e06ea7d78d506ab656bdfe20ded Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 15:04:30 -0700 Subject: [PATCH 18/26] revert test to 1.4 --- app/commands.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/app/commands.py b/app/commands.py index b625df464..4893b0056 100644 --- a/app/commands.py +++ b/app/commands.py @@ -646,8 +646,8 @@ def populate_annual_billing_with_defaults(year, missing_services_only): This is useful to ensure all services start the new year with the correct annual billing. """ if missing_services_only: - stmt = ( - select(Service) + active_services = ( + Service.query.filter(Service.active) .outerjoin( AnnualBilling, and_( @@ -655,19 +655,20 @@ def populate_annual_billing_with_defaults(year, missing_services_only): AnnualBilling.financial_year_start == year, ), ) - .where(Service.active) - .where(AnnualBilling.id == None) # noqa + .filter(AnnualBilling.id == None) # noqa + .all() ) - active_services = db.session.execute(stmt).scalars().all() else: - stmt = select(Service).where(Service.active) - active_services = db.session.execute(stmt).scalars().all() + active_services = Service.query.filter(Service.active).all() previous_year = year - 1 - stmt = select(AnnualBilling.id).where( - AnnualBilling.financial_year_start == previous_year, - AnnualBilling.free_sms_fragment_limit == 0, + services_with_zero_free_allowance = ( + db.session.query(AnnualBilling.service_id) + .filter( + AnnualBilling.financial_year_start == previous_year, + AnnualBilling.free_sms_fragment_limit == 0, + ) + .all() ) - services_with_zero_free_allowance = db.session.execute(stmt).scalars().all() for service in active_services: # If a service has free_sms_fragment_limit for the previous year From 331cc413892536cdf2282415f851a996c52351d9 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 09:17:49 -0700 Subject: [PATCH 19/26] fix more --- app/dao/invited_user_dao.py | 23 ++++++++++++++++------- app/dao/notifications_dao.py | 1 + app/dao/provider_details_dao.py | 7 ++++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/app/dao/invited_user_dao.py b/app/dao/invited_user_dao.py index a342f504d..631086f30 100644 --- a/app/dao/invited_user_dao.py +++ b/app/dao/invited_user_dao.py @@ -1,5 +1,7 @@ from datetime import timedelta +from sqlalchemy import select + from app import db from app.enums import InvitedUserStatus from app.models import InvitedUser @@ -12,30 +14,37 @@ def save_invited_user(invited_user): def get_invited_user_by_service_and_id(service_id, invited_user_id): - return InvitedUser.query.filter( + + stmt = select(InvitedUser).where( InvitedUser.service_id == service_id, InvitedUser.id == invited_user_id, - ).one() + ) + return db.session.execute(stmt).scalars().one() def get_expired_invite_by_service_and_id(service_id, invited_user_id): - return InvitedUser.query.filter( + stmt = select(InvitedUser).where( InvitedUser.service_id == service_id, InvitedUser.id == invited_user_id, InvitedUser.status == InvitedUserStatus.EXPIRED, - ).one() + ) + return db.session.execute(stmt).scalars().all() def get_invited_user_by_id(invited_user_id): - return InvitedUser.query.filter(InvitedUser.id == invited_user_id).one() + stmt = select(InvitedUser).where(InvitedUser.id == invited_user_id) + return db.session.execute(stmt).scalars().one() def get_expired_invited_users_for_service(service_id): - return InvitedUser.query.filter(InvitedUser.service_id == service_id).all() + # TODO why does this return all invited users? + stmt = select(InvitedUser).where(InvitedUser.service_id == service_id) + return db.session.execute(stmt).scalars().all() def get_invited_users_for_service(service_id): - return InvitedUser.query.filter(InvitedUser.service_id == service_id).all() + stmt = select(InvitedUser).where(InvitedUser.service_id == service_id) + return db.session.execute(stmt).scalars().all() def expire_invitations_created_more_than_two_days_ago(): diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index 1d07473c1..cbde45d30 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -192,6 +192,7 @@ def get_notifications_for_job( ): if page_size is None: page_size = current_app.config["PAGE_SIZE"] + query = Notification.query.filter_by(service_id=service_id, job_id=job_id) query = _filter_query(query, filter_dict) return query.order_by(asc(Notification.job_row_number)).paginate( diff --git a/app/dao/provider_details_dao.py b/app/dao/provider_details_dao.py index b0ab48d09..fca8ebd73 100644 --- a/app/dao/provider_details_dao.py +++ b/app/dao/provider_details_dao.py @@ -1,7 +1,7 @@ from datetime import datetime from flask import current_app -from sqlalchemy import desc, func +from sqlalchemy import desc, func, select from app import db from app.dao.dao_utils import autocommit @@ -11,11 +11,12 @@ from app.utils import utc_now def get_provider_details_by_id(provider_details_id): - return ProviderDetails.query.get(provider_details_id) + return db.session.get(ProviderDetails, provider_details_id) def get_provider_details_by_identifier(identifier): - return ProviderDetails.query.filter_by(identifier=identifier).one() + stmt = select(ProviderDetails).where(identifier=identifier) + return db.session.execute(stmt).scalars().one() def get_alternative_sms_provider(identifier): From d28c1f7d85646a095d79ec63515f6b00e41197a1 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 09:37:48 -0700 Subject: [PATCH 20/26] fix more --- app/dao/provider_details_dao.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/app/dao/provider_details_dao.py b/app/dao/provider_details_dao.py index fca8ebd73..54f36b372 100644 --- a/app/dao/provider_details_dao.py +++ b/app/dao/provider_details_dao.py @@ -15,7 +15,7 @@ def get_provider_details_by_id(provider_details_id): def get_provider_details_by_identifier(identifier): - stmt = select(ProviderDetails).where(identifier=identifier) + stmt = select(ProviderDetails).where(ProviderDetails.identifier == identifier) return db.session.execute(stmt).scalars().one() @@ -26,12 +26,13 @@ def get_alternative_sms_provider(identifier): def dao_get_provider_versions(provider_id): - return ( - ProviderDetailsHistory.query.filter_by(id=provider_id) + stmt = ( + select(ProviderDetailsHistory) + .where(ProviderDetailsHistory.id == provider_id) .order_by(desc(ProviderDetailsHistory.version)) - .limit(100) # limit results instead of adding pagination - .all() ) + # limit results instead of adding pagination + return db.session.execute(stmt).limit(100).scalars().all() def _get_sms_providers_for_update(time_threshold): @@ -43,14 +44,15 @@ def _get_sms_providers_for_update(time_threshold): release the transaction in that case """ # get current priority of both providers - q = ( - ProviderDetails.query.filter( + stmt = ( + select(ProviderDetails) + .where( ProviderDetails.notification_type == NotificationType.SMS, ProviderDetails.active, ) .with_for_update() - .all() ) + q = db.session.execute(stmt).scalars().all() # if something updated recently, don't update again. If the updated_at is null, treat it as min time if any( @@ -73,7 +75,8 @@ def get_provider_details_by_notification_type( if supports_international: filters.append(ProviderDetails.supports_international == supports_international) - return ProviderDetails.query.filter(*filters).all() + stmt = select(ProviderDetails).where(*filters) + return db.session.execute(stmt).scalars().all() @autocommit From 6b6bc2b4e76f1c7e63ac3d92d40d953d73c3c326 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 10:03:15 -0700 Subject: [PATCH 21/26] fix more --- app/dao/invited_user_dao.py | 2 +- app/dao/provider_details_dao.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/app/dao/invited_user_dao.py b/app/dao/invited_user_dao.py index 631086f30..49f953e26 100644 --- a/app/dao/invited_user_dao.py +++ b/app/dao/invited_user_dao.py @@ -28,7 +28,7 @@ def get_expired_invite_by_service_and_id(service_id, invited_user_id): InvitedUser.id == invited_user_id, InvitedUser.status == InvitedUserStatus.EXPIRED, ) - return db.session.execute(stmt).scalars().all() + return db.session.execute(stmt).scalars().one() def get_invited_user_by_id(invited_user_id): diff --git a/app/dao/provider_details_dao.py b/app/dao/provider_details_dao.py index 54f36b372..1b094273b 100644 --- a/app/dao/provider_details_dao.py +++ b/app/dao/provider_details_dao.py @@ -30,9 +30,10 @@ def dao_get_provider_versions(provider_id): select(ProviderDetailsHistory) .where(ProviderDetailsHistory.id == provider_id) .order_by(desc(ProviderDetailsHistory.version)) + .limit(100) ) # limit results instead of adding pagination - return db.session.execute(stmt).limit(100).scalars().all() + return db.session.execute(stmt).scalars().all() def _get_sms_providers_for_update(time_threshold): From bc7180185b73263b27870ec319ba834f17ccdbd6 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 11:32:27 -0700 Subject: [PATCH 22/26] fix more --- .ds.baseline | 6 +++--- app/celery/scheduled_tasks.py | 11 ++++++---- app/commands.py | 10 ++++++---- ...t_notification_dao_delete_notifications.py | 13 +++++++++--- tests/app/dao/test_events_dao.py | 11 ++++++++-- tests/app/service/test_sender.py | 10 ++++++++-- .../test_template_folder_rest.py | 10 ++++++++-- tests/app/test_commands.py | 11 ++++++---- tests/app/user/test_rest.py | 20 +++++++++++++++---- 9 files changed, 74 insertions(+), 28 deletions(-) diff --git a/.ds.baseline b/.ds.baseline index 41a911ddd..eb730eedd 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -341,7 +341,7 @@ "filename": "tests/app/user/test_rest.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 106, + "line_number": 108, "is_secret": false }, { @@ -349,7 +349,7 @@ "filename": "tests/app/user/test_rest.py", "hashed_secret": "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", "is_verified": false, - "line_number": 810, + "line_number": 822, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-10-30T18:15:03Z" + "generated_at": "2024-10-31T18:32:23Z" } diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index 3597bdbb7..504b77f56 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -1,10 +1,10 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import between +from sqlalchemy import between, select from sqlalchemy.exc import SQLAlchemyError -from app import notify_celery, zendesk_client +from app import db, notify_celery, zendesk_client from app.celery.tasks import ( get_recipient_csv_and_template_and_sender_id, process_incomplete_jobs, @@ -105,15 +105,18 @@ def check_job_status(): thirty_minutes_ago = utc_now() - timedelta(minutes=30) thirty_five_minutes_ago = utc_now() - timedelta(minutes=35) - incomplete_in_progress_jobs = Job.query.filter( + stmt = select(Job).where( Job.job_status == JobStatus.IN_PROGRESS, between(Job.processing_started, thirty_five_minutes_ago, thirty_minutes_ago), ) - incomplete_pending_jobs = Job.query.filter( + incomplete_in_progress_jobs = db.session.execute(stmt).scalars().all() + + stmt = select(Job).where( Job.job_status == JobStatus.PENDING, Job.scheduled_for.isnot(None), between(Job.scheduled_for, thirty_five_minutes_ago, thirty_minutes_ago), ) + incomplete_pending_jobs = db.session.execute(stmt).scalars().all() jobs_not_complete_after_30_minutes = ( incomplete_in_progress_jobs.union(incomplete_pending_jobs) diff --git a/app/commands.py b/app/commands.py index 4893b0056..79bd3192d 100644 --- a/app/commands.py +++ b/app/commands.py @@ -646,8 +646,9 @@ def populate_annual_billing_with_defaults(year, missing_services_only): This is useful to ensure all services start the new year with the correct annual billing. """ if missing_services_only: - active_services = ( - Service.query.filter(Service.active) + stmt = ( + select(Service) + .where(Service.active) .outerjoin( AnnualBilling, and_( @@ -656,10 +657,11 @@ def populate_annual_billing_with_defaults(year, missing_services_only): ), ) .filter(AnnualBilling.id == None) # noqa - .all() ) + active_services = db.session.execute(stmt).scalars().all() else: - active_services = Service.query.filter(Service.active).all() + stmt = select(Service).where(Service.active) + active_services = db.session.execute(stmt).scalars().all() previous_year = year - 1 services_with_zero_free_allowance = ( db.session.query(AnnualBilling.service_id) diff --git a/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py b/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py index 45d8958d1..fbe365e00 100644 --- a/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py +++ b/tests/app/dao/notification_dao/test_notification_dao_delete_notifications.py @@ -42,12 +42,17 @@ def test_move_notifications_does_nothing_if_notification_history_row_already_exi 1, ) - assert Notification.query.count() == 0 + assert _get_notification_count() == 0 history = NotificationHistory.query.all() assert len(history) == 1 assert history[0].status == NotificationStatus.DELIVERED +def _get_notification_count(): + stmt = select(func.count()).select_from(Notification) + return db.session.execute(stmt).scalar() or 0 + + def test_move_notifications_only_moves_notifications_older_than_provided_timestamp( sample_template, ): @@ -172,8 +177,10 @@ def test_move_notifications_just_deletes_test_key_notifications(sample_template) assert result == 2 - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 2 + assert _get_notification_count() == 0 + stmt = select(func.count()).select_from(NotificationHistory) + count = db.session.execute(stmt).scalar() or 0 + assert count == 2 stmt = ( select(func.count()) .select_from(NotificationHistory) diff --git a/tests/app/dao/test_events_dao.py b/tests/app/dao/test_events_dao.py index 2647aafcb..60c977af6 100644 --- a/tests/app/dao/test_events_dao.py +++ b/tests/app/dao/test_events_dao.py @@ -1,9 +1,14 @@ +from sqlalchemy import func, select + +from app import db from app.dao.events_dao import dao_create_event from app.models import Event def test_create_event(notify_db_session): - assert Event.query.count() == 0 + stmt = select(func.count()).select_from(Event) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 data = { "event_type": "sucessful_login", "data": {"something": "random", "in_fact": "could be anything"}, @@ -12,6 +17,8 @@ def test_create_event(notify_db_session): event = Event(**data) dao_create_event(event) - assert Event.query.count() == 1 + stmt = select(func.count()).select_from(Event) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 event_from_db = Event.query.first() assert event == event_from_db diff --git a/tests/app/service/test_sender.py b/tests/app/service/test_sender.py index caae265c8..4b9c10ee1 100644 --- a/tests/app/service/test_sender.py +++ b/tests/app/service/test_sender.py @@ -1,6 +1,8 @@ import pytest from flask import current_app +from sqlalchemy import func, select +from app import db from app.dao.services_dao import dao_add_user_to_service from app.enums import NotificationType, TemplateType from app.models import Notification @@ -23,7 +25,9 @@ def test_send_notification_to_service_users_persists_notifications_correctly( notification = Notification.query.one() - assert Notification.query.count() == 1 + stmt = select(func.count()).select_from(Notification) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 assert notification.to == "1" assert str(notification.service_id) == current_app.config["NOTIFY_SERVICE_ID"] assert notification.template.id == template.id @@ -89,4 +93,6 @@ def test_send_notification_to_service_users_sends_to_active_users_only( send_notification_to_service_users(service_id=service.id, template_id=template.id) - assert Notification.query.count() == 2 + stmt = select(func.count()).select_from(Notification) + count = db.session.execute(stmt).scalar() or 0 + assert count == 2 diff --git a/tests/app/template_folder/test_template_folder_rest.py b/tests/app/template_folder/test_template_folder_rest.py index 6461ad3df..3bd2b4ee9 100644 --- a/tests/app/template_folder/test_template_folder_rest.py +++ b/tests/app/template_folder/test_template_folder_rest.py @@ -1,7 +1,9 @@ import uuid import pytest +from sqlalchemy import func, select +from app import db from app.dao.service_user_dao import dao_get_service_user from app.models import TemplateFolder from tests.app.db import ( @@ -286,7 +288,9 @@ def test_delete_template_folder_fails_if_folder_has_subfolders( assert resp == {"result": "error", "message": "Folder is not empty"} - assert TemplateFolder.query.count() == 2 + stmt = select(func.count()).select_from(TemplateFolder) + count = db.session.execute(stmt).scalar() or 0 + assert count == 2 def test_delete_template_folder_fails_if_folder_contains_templates( @@ -304,7 +308,9 @@ def test_delete_template_folder_fails_if_folder_contains_templates( assert resp == {"result": "error", "message": "Folder is not empty"} - assert TemplateFolder.query.count() == 1 + stmt = select(func.count()).select_from(TemplateFolder) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 @pytest.mark.parametrize( diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index 4acb9eedf..e4a27c0e2 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -414,17 +414,20 @@ def test_create_service_command(notify_db_session, notify_api): user = User.query.first() - service_count = Service.query.count() + stmt = select(func.count()).select_from(Service) + service_count = db.session.execute(stmt).scalar() or 0 # run the command - result = notify_api.test_cli_runner().invoke( + notify_api.test_cli_runner().invoke( create_new_service, ["-e", "somebody@fake.gov", "-n", "Fake Service", "-c", user.id], ) - print(result) # there should be one more service - assert Service.query.count() == service_count + 1 + + stmt = select(func.count()).select_from(Service) + count = db.session.execute(stmt).scalar() or 0 + assert count == service_count + 1 # that service should be the one we added stmt = select(Service).where(Service.name == "Fake Service") diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index 4e064ca8e..399b708c6 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -6,7 +6,9 @@ from unittest import mock import pytest from flask import current_app from freezegun import freeze_time +from sqlalchemy import func, select +from app import db from app.dao.service_user_dao import dao_get_service_user, dao_update_service_user from app.enums import AuthType, KeyType, NotificationType, PermissionType from app.models import Notification, Permission, User @@ -153,12 +155,17 @@ def test_post_user_missing_attribute_email(admin_request, notify_db_session): } json_resp = admin_request.post("user.create_user", _data=data, _expected_status=400) - assert User.query.count() == 0 + assert _get_user_count() == 0 assert {"email_address": ["Missing data for required field."]} == json_resp[ "message" ] +def _get_user_count(): + stmt = select(func.count()).select_from(User) + return db.session.execute(stmt).scalar() or 0 + + def test_create_user_missing_attribute_password(admin_request, notify_db_session): """ Tests POST endpoint '/' missing attribute password. @@ -174,7 +181,7 @@ def test_create_user_missing_attribute_password(admin_request, notify_db_session "permissions": {}, } json_resp = admin_request.post("user.create_user", _data=data, _expected_status=400) - assert User.query.count() == 0 + assert _get_user_count() == 0 assert {"password": ["Missing data for required field."]} == json_resp["message"] @@ -512,8 +519,13 @@ def test_set_user_permissions_remove_old(admin_request, sample_user, sample_serv _expected_status=204, ) - query = Permission.query.filter_by(user=sample_user) - assert query.count() == 1 + query = ( + select(func.count()) + .select_from(Permission) + .where(Permission.user == sample_user) + ) + count = db.session.execute(query).scalar() or 0 + assert count == 1 assert query.first().permission == PermissionType.MANAGE_SETTINGS From c33c7a5058640deb9ade13299b5608952b7cd5bb Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 11:48:08 -0700 Subject: [PATCH 23/26] fix more --- .ds.baseline | 4 ++-- app/celery/scheduled_tasks.py | 15 +++++++-------- tests/app/user/test_rest.py | 14 +++++++++----- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.ds.baseline b/.ds.baseline index eb730eedd..bcae9f1e8 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -349,7 +349,7 @@ "filename": "tests/app/user/test_rest.py", "hashed_secret": "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", "is_verified": false, - "line_number": 822, + "line_number": 826, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-10-31T18:32:23Z" + "generated_at": "2024-10-31T18:48:03Z" } diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index 504b77f56..86d801f88 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -105,24 +105,23 @@ def check_job_status(): thirty_minutes_ago = utc_now() - timedelta(minutes=30) thirty_five_minutes_ago = utc_now() - timedelta(minutes=35) - stmt = select(Job).where( + incomplete_in_progress_jobs = select(Job).where( Job.job_status == JobStatus.IN_PROGRESS, between(Job.processing_started, thirty_five_minutes_ago, thirty_minutes_ago), ) - incomplete_in_progress_jobs = db.session.execute(stmt).scalars().all() + # incomplete_in_progress_jobs = db.session.execute(stmt).scalars().all() - stmt = select(Job).where( + incomplete_pending_jobs = select(Job).where( Job.job_status == JobStatus.PENDING, Job.scheduled_for.isnot(None), between(Job.scheduled_for, thirty_five_minutes_ago, thirty_minutes_ago), ) - incomplete_pending_jobs = db.session.execute(stmt).scalars().all() + # incomplete_pending_jobs = db.session.execute(stmt).scalars().all() - jobs_not_complete_after_30_minutes = ( - incomplete_in_progress_jobs.union(incomplete_pending_jobs) - .order_by(Job.processing_started, Job.scheduled_for) - .all() + stmt = incomplete_in_progress_jobs.union(incomplete_pending_jobs).order_by( + Job.processing_started, Job.scheduled_for ) + jobs_not_complete_after_30_minutes = db.session.execute(stmt).scalars().all() # temporarily mark them as ERROR so that they don't get picked up by future check_job_status tasks # if they haven't been re-processed in time. diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index 399b708c6..f1ea5041b 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -336,7 +336,8 @@ def test_post_user_attribute_with_updated_by_sends_notification_to_international _data=update_dict, ) - notification = Notification.query.first() + stmt = select(Notification) + notification = db.session.execute(stmt).scalars().first() assert ( notification.reply_to_text == current_app.config["NOTIFY_INTERNATIONAL_SMS_SENDER"] @@ -526,7 +527,9 @@ def test_set_user_permissions_remove_old(admin_request, sample_user, sample_serv ) count = db.session.execute(query).scalar() or 0 assert count == 1 - assert query.first().permission == PermissionType.MANAGE_SETTINGS + query = select(Permission).where(Permission.user == sample_user) + first_permission = db.session.execute(query).scalars().first() + assert first_permission.permission == PermissionType.MANAGE_SETTINGS def test_set_user_folder_permissions(admin_request, sample_user, sample_service): @@ -658,7 +661,8 @@ def test_send_already_registered_email( _expected_status=204, ) - notification = Notification.query.first() + stmt = select(Notification) + notification = db.session.execute(stmt).scalars().first() mocked.assert_called_once_with( ([str(notification.id)]), queue="notify-internal-tasks" ) @@ -696,8 +700,8 @@ def test_send_user_confirm_new_email_returns_204( _data=data, _expected_status=204, ) - - notification = Notification.query.first() + stmt = select(Notification) + notification = db.session.execute(stmt).scalars().first() mocked.assert_called_once_with( ([str(notification.id)]), queue="notify-internal-tasks" ) From 3dd21705b88943369c0042c3994e39d6cd56b108 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 12:00:42 -0700 Subject: [PATCH 24/26] fix more --- app/celery/scheduled_tasks.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index 86d801f88..3597bdbb7 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -1,10 +1,10 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import between, select +from sqlalchemy import between from sqlalchemy.exc import SQLAlchemyError -from app import db, notify_celery, zendesk_client +from app import notify_celery, zendesk_client from app.celery.tasks import ( get_recipient_csv_and_template_and_sender_id, process_incomplete_jobs, @@ -105,23 +105,21 @@ def check_job_status(): thirty_minutes_ago = utc_now() - timedelta(minutes=30) thirty_five_minutes_ago = utc_now() - timedelta(minutes=35) - incomplete_in_progress_jobs = select(Job).where( + incomplete_in_progress_jobs = Job.query.filter( Job.job_status == JobStatus.IN_PROGRESS, between(Job.processing_started, thirty_five_minutes_ago, thirty_minutes_ago), ) - # incomplete_in_progress_jobs = db.session.execute(stmt).scalars().all() - - incomplete_pending_jobs = select(Job).where( + incomplete_pending_jobs = Job.query.filter( Job.job_status == JobStatus.PENDING, Job.scheduled_for.isnot(None), between(Job.scheduled_for, thirty_five_minutes_ago, thirty_minutes_ago), ) - # incomplete_pending_jobs = db.session.execute(stmt).scalars().all() - stmt = incomplete_in_progress_jobs.union(incomplete_pending_jobs).order_by( - Job.processing_started, Job.scheduled_for + jobs_not_complete_after_30_minutes = ( + incomplete_in_progress_jobs.union(incomplete_pending_jobs) + .order_by(Job.processing_started, Job.scheduled_for) + .all() ) - jobs_not_complete_after_30_minutes = db.session.execute(stmt).scalars().all() # temporarily mark them as ERROR so that they don't get picked up by future check_job_status tasks # if they haven't been re-processed in time. From 78ac1ee094004b3cef75b598f81f540fd78a8158 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 14:25:35 -0700 Subject: [PATCH 25/26] fix more --- .ds.baseline | 6 ++--- tests/app/celery/test_reporting_tasks.py | 23 ++++++++++++------- .../test_receive_notification.py | 11 +++++++-- .../test_send_notification.py | 11 +++++++-- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/.ds.baseline b/.ds.baseline index bcae9f1e8..8aaa131c5 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -277,7 +277,7 @@ "filename": "tests/app/notifications/test_receive_notification.py", "hashed_secret": "913a73b565c8e2c8ed94497580f619397709b8b6", "is_verified": false, - "line_number": 24, + "line_number": 26, "is_secret": false }, { @@ -285,7 +285,7 @@ "filename": "tests/app/notifications/test_receive_notification.py", "hashed_secret": "d70eab08607a4d05faa2d0d6647206599e9abc65", "is_verified": false, - "line_number": 54, + "line_number": 56, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-10-31T18:48:03Z" + "generated_at": "2024-10-31T21:25:32Z" } diff --git a/tests/app/celery/test_reporting_tasks.py b/tests/app/celery/test_reporting_tasks.py index a32f68fc3..124038d48 100644 --- a/tests/app/celery/test_reporting_tasks.py +++ b/tests/app/celery/test_reporting_tasks.py @@ -4,7 +4,9 @@ from uuid import UUID import pytest from freezegun import freeze_time +from sqlalchemy import select +from app import db from app.celery.reporting_tasks import ( create_nightly_billing, create_nightly_billing_for_day, @@ -132,11 +134,11 @@ def test_create_nightly_billing_for_day_checks_history( status=NotificationStatus.DELIVERED, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 1 record = records[0] @@ -144,6 +146,11 @@ def test_create_nightly_billing_for_day_checks_history( assert record.notifications_sent == 2 +def _get_fact_billing_records(): + stmt = select(FactBilling) + return db.session.execute(stmt).scalars().all() + + @pytest.mark.parametrize( "second_rate, records_num, billable_units, multiplier", [(1.0, 1, 2, [1]), (2.0, 2, 1, [1, 2])], @@ -181,7 +188,7 @@ def test_create_nightly_billing_for_day_sms_rate_multiplier( billable_units=1, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) @@ -221,7 +228,7 @@ def test_create_nightly_billing_for_day_different_templates( billable_units=0, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) @@ -265,7 +272,7 @@ def test_create_nightly_billing_for_day_same_sent_by( billable_units=1, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) @@ -296,11 +303,11 @@ def test_create_nightly_billing_for_day_null_sent_by_sms( billable_units=1, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 1 record = records[0] @@ -384,7 +391,7 @@ def test_create_nightly_billing_for_day_update_when_record_exists( billable_units=1, ) - records = FactBilling.query.all() + records = _get_fact_billing_records() assert len(records) == 0 create_nightly_billing_for_day("2018-01-14") diff --git a/tests/app/notifications/test_receive_notification.py b/tests/app/notifications/test_receive_notification.py index c95088803..e13b8d82e 100644 --- a/tests/app/notifications/test_receive_notification.py +++ b/tests/app/notifications/test_receive_notification.py @@ -4,7 +4,9 @@ from unittest import mock import pytest from flask import json +from sqlalchemy import func, select +from app import db from app.enums import ServicePermissionType from app.models import InboundSms from app.notifications.receive_notifications import ( @@ -99,7 +101,9 @@ def test_receive_notification_from_sns_without_permissions_does_not_persist( parsed_response = json.loads(response.get_data(as_text=True)) assert parsed_response["result"] == "success" - assert InboundSms.query.count() == 0 + stmt = select(func.count()).select_from(InboundSms) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 assert mocked.called is False @@ -285,7 +289,10 @@ def test_receive_notification_error_if_not_single_matching_service( # we still return 'RECEIVED' to MMG assert response.status_code == 200 assert response.get_data(as_text=True) == "RECEIVED" - assert InboundSms.query.count() == 0 + + stmt = select(func.count()).select_from(InboundSms) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 @pytest.mark.skip(reason="Need to implement inbound SNS tests. Body here from MMG") diff --git a/tests/app/service/send_notification/test_send_notification.py b/tests/app/service/send_notification/test_send_notification.py index dcd6cc8e7..fd37f7592 100644 --- a/tests/app/service/send_notification/test_send_notification.py +++ b/tests/app/service/send_notification/test_send_notification.py @@ -5,8 +5,10 @@ import pytest from flask import current_app, json from freezegun import freeze_time from notifications_python_client.authentication import create_jwt_token +from sqlalchemy import func, select import app +from app import db from app.dao import notifications_dao from app.dao.api_key_dao import save_model_api_key from app.dao.services_dao import dao_update_service @@ -883,7 +885,9 @@ def test_should_not_persist_notification_or_send_email_if_simulated_email( assert response.status_code == 201 apply_async.assert_not_called() - assert Notification.query.count() == 0 + stmt = select(func.count()).select_from(Notification) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 @pytest.mark.parametrize("to_sms", ["+14254147755", "+14254147167"]) @@ -906,7 +910,10 @@ def test_should_not_persist_notification_or_send_sms_if_simulated_number( assert response.status_code == 201 apply_async.assert_not_called() - assert Notification.query.count() == 0 + + stmt = select(func.count()).select_from(Notification) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 @pytest.mark.parametrize("key_type", [KeyType.NORMAL, KeyType.TEAM]) From c5b227403ea5b3c40be5e7758d9bfff1c4bdc070 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 31 Oct 2024 14:38:07 -0700 Subject: [PATCH 26/26] fix more --- tests/app/organization/test_rest.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/app/organization/test_rest.py b/tests/app/organization/test_rest.py index a9d7db135..1d521ca9c 100644 --- a/tests/app/organization/test_rest.py +++ b/tests/app/organization/test_rest.py @@ -4,8 +4,10 @@ from unittest.mock import Mock import pytest from flask import current_app from freezegun import freeze_time +from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from app import db from app.dao.organization_dao import ( dao_add_service_to_organization, dao_add_user_to_organization, @@ -175,7 +177,7 @@ def test_post_create_organization(admin_request, notify_db_session): "organization.create_organization", _data=data, _expected_status=201 ) - organizations = Organization.query.all() + organizations = _get_organizations() assert data["name"] == response["name"] assert data["active"] == response["active"] @@ -186,6 +188,11 @@ def test_post_create_organization(admin_request, notify_db_session): assert organizations[0].email_branding_id is None +def _get_organizations(): + stmt = select(Organization) + return db.session.execute(stmt).scalars().all() + + @pytest.mark.parametrize("org_type", ["nhs_central", "nhs_local", "nhs_gp"]) @pytest.mark.skip(reason="Update for TTS") def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs( @@ -201,7 +208,7 @@ def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs( "organization.create_organization", _data=data, _expected_status=201 ) - organizations = Organization.query.all() + organizations = _get_organizations() assert len(organizations) == 1 assert organizations[0].email_branding_id == uuid.UUID( @@ -212,7 +219,7 @@ def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs( def test_post_create_organization_existing_name_raises_400( admin_request, sample_organization ): - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 data = { @@ -225,14 +232,14 @@ def test_post_create_organization_existing_name_raises_400( "organization.create_organization", _data=data, _expected_status=400 ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 assert response["message"] == "Organization name already exists" def test_post_create_organization_works(admin_request, sample_organization): - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 data = { @@ -245,7 +252,7 @@ def test_post_create_organization_works(admin_request, sample_organization): "organization.create_organization", _data=data, _expected_status=201 ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 2 @@ -310,7 +317,7 @@ def test_post_update_organization_updates_fields( _expected_status=204, ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 assert organization[0].id == org.id @@ -343,7 +350,7 @@ def test_post_update_organization_updates_domains( _expected_status=204, ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 assert [domain.domain for domain in organization[0].domains] == domain_list @@ -383,7 +390,7 @@ def test_post_update_organization_to_nhs_type_updates_branding_if_none_present( _expected_status=204, ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 assert organization[0].id == org.id @@ -413,7 +420,7 @@ def test_post_update_organization_to_nhs_type_does_not_update_branding_if_defaul _expected_status=204, ) - organization = Organization.query.all() + organization = _get_organizations() assert len(organization) == 1 assert organization[0].id == org.id @@ -471,7 +478,7 @@ def test_post_update_organization_gives_404_status_if_org_does_not_exist( _expected_status=404, ) - organization = Organization.query.all() + organization = _get_organizations() assert not organization