From 03312bfdc47c9ac6194b4a05944bcf4642bd2d28 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 13:09:30 -0700 Subject: [PATCH] fix more tests --- tests/app/dao/test_service_guest_list_dao.py | 10 +++- .../test_process_notification.py | 54 ++++++++++++------- tests/app/test_commands.py | 36 ++++++++----- 3 files changed, 67 insertions(+), 33 deletions(-) diff --git a/tests/app/dao/test_service_guest_list_dao.py b/tests/app/dao/test_service_guest_list_dao.py index 870c78bd8..021f42319 100644 --- a/tests/app/dao/test_service_guest_list_dao.py +++ b/tests/app/dao/test_service_guest_list_dao.py @@ -1,5 +1,8 @@ import uuid +from sqlalchemy import func, select + +from app import db from app.dao.service_guest_list_dao import ( dao_add_and_commit_guest_list_contacts, dao_fetch_service_guest_list, @@ -27,7 +30,8 @@ def test_add_and_commit_guest_list_contacts_saves_data(sample_service): dao_add_and_commit_guest_list_contacts([guest_list]) - db_contents = ServiceGuestList.query.all() + stmt = select(ServiceGuestList) + db_contents = db.session.execute(stmt).scalars().all() assert len(db_contents) == 1 assert db_contents[0].id == guest_list.id @@ -60,4 +64,6 @@ def test_remove_service_guest_list_does_not_commit( # since dao_remove_service_guest_list doesn't commit, we can still rollback its changes notify_db_session.rollback() - assert ServiceGuestList.query.count() == 1 + stmt = select(func.count()).select_from(ServiceGuestList) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 diff --git a/tests/app/notifications/test_process_notification.py b/tests/app/notifications/test_process_notification.py index d7caf5bb1..9f393b440 100644 --- a/tests/app/notifications/test_process_notification.py +++ b/tests/app/notifications/test_process_notification.py @@ -5,8 +5,10 @@ from collections import namedtuple import pytest from boto3.exceptions import Boto3Error from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import SQLAlchemyError +from app import db from app.enums import KeyType, NotificationType, ServicePermissionType, TemplateType from app.errors import BadRequestError from app.models import Notification, NotificationHistory @@ -67,12 +69,22 @@ def test_create_content_for_notification_allows_additional_personalisation( ) +def _get_notification_query_count(): + stmt = select(func.count()).select_from(Notification) + return db.session.execute(stmt).scalar() or 0 + + +def _get_notification_history_query_count(): + stmt = select(func.count()).select_from(NotificationHistory) + return db.session.execute(stmt).scalar() or 0 + + @freeze_time("2016-01-01 11:09:00.061258") def test_persist_notification_creates_and_save_to_db( sample_template, sample_api_key, sample_job ): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 notification = persist_notification( template_id=sample_template.id, template_version=sample_template.version, @@ -114,8 +126,8 @@ def test_persist_notification_creates_and_save_to_db( def test_persist_notification_throws_exception_when_missing_template(sample_api_key): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 with pytest.raises(SQLAlchemyError): persist_notification( template_id=None, @@ -127,14 +139,14 @@ def test_persist_notification_throws_exception_when_missing_template(sample_api_ api_key_id=sample_api_key.id, key_type=sample_api_key.key_type, ) - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 @freeze_time("2016-01-01 11:09:00.061258") def test_persist_notification_with_optionals(sample_job, sample_api_key): - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 n_id = uuid.uuid4() created_at = datetime.datetime(2016, 11, 11, 16, 8, 18) persist_notification( @@ -153,9 +165,10 @@ def test_persist_notification_with_optionals(sample_job, sample_api_key): notification_id=n_id, created_by_id=sample_job.created_by_id, ) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 - persisted_notification = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.id == n_id assert persisted_notification.job_id == sample_job.id assert persisted_notification.job_row_number == 10 @@ -267,8 +280,8 @@ def test_send_notification_to_queue_throws_exception_deletes_notification( queue="send-sms-tasks", ) - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 0 + assert _get_notification_history_query_count() == 0 @pytest.mark.parametrize( @@ -349,7 +362,8 @@ def test_persist_notification_with_international_info_stores_correct_info( job_row_number=10, client_reference="ref from client", ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.international is expected_international assert persisted_notification.phone_prefix == expected_prefix @@ -372,7 +386,8 @@ def test_persist_notification_with_international_info_does_not_store_for_email( job_row_number=10, client_reference="ref from client", ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.international is False assert persisted_notification.phone_prefix is None @@ -404,7 +419,8 @@ def test_persist_sms_notification_stores_normalised_number( key_type=sample_api_key.key_type, job_id=sample_job.id, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.to == "1" assert persisted_notification.normalised_to == "1" @@ -428,7 +444,8 @@ def test_persist_email_notification_stores_normalised_email( key_type=sample_api_key.key_type, job_id=sample_job.id, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.to == "1" assert persisted_notification.normalised_to == "1" @@ -449,6 +466,7 @@ def test_persist_notification_with_billable_units_stores_correct_info(mocker): key_type=KeyType.NORMAL, billable_units=3, ) - persisted_notification = Notification.query.all()[0] + stmt = select(Notification) + persisted_notification = db.session.execute(stmt).scalars().all()[0] assert persisted_notification.billable_units == 3 diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index 361de9a40..a273efe8d 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock, mock_open import pytest -from sqlalchemy import select +from sqlalchemy import func, select from app import db from app.commands import ( @@ -56,8 +56,13 @@ from tests.app.db import ( ) +def _get_user_query_count(): + stmt = select(func.count()).select_from(User) + return db.session.execute(stmt).scalar() or 0 + + def test_purge_functional_test_data(notify_db_session, notify_api): - orig_user_count = User.query.count() + orig_user_count = _get_user_query_count() notify_api.test_cli_runner().invoke( create_test_user, @@ -73,16 +78,16 @@ def test_purge_functional_test_data(notify_db_session, notify_api): ], ) - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == orig_user_count + 1 notify_api.test_cli_runner().invoke(purge_functional_test_data, ["-u", "somebody"]) # if the email address has a uuid, it is test data so it should be purged and there should be # zero users. Otherwise, it is real data so there should be one user. - assert User.query.count() == orig_user_count + assert _get_user_query_count() == orig_user_count def test_purge_functional_test_data_bad_mobile(notify_db_session, notify_api): - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == 0 # run the command command_response = notify_api.test_cli_runner().invoke( @@ -101,7 +106,7 @@ def test_purge_functional_test_data_bad_mobile(notify_db_session, notify_api): # The bad mobile phone number results in a bad parameter error, # leading to a system exit 2 and no entry made in db assert "SystemExit(2)" in str(command_response) - user_count = User.query.count() + user_count = _get_user_query_count() assert user_count == 0 @@ -136,8 +141,13 @@ def test_update_jobs_archived_flag(notify_db_session, notify_api): assert job.archived is True +def _get_organization_query_count(): + stmt = select(Organization) + return db.session.execute(stmt).scalar() or 0 + + def test_populate_organizations_from_file(notify_db_session, notify_api): - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 0 file_name = "./tests/app/orgs1.csv" @@ -152,7 +162,7 @@ def test_populate_organizations_from_file(notify_db_session, notify_api): os.remove(file_name) print(f"command_response = {command_response}") - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 @@ -161,10 +171,10 @@ def test_populate_organization_agreement_details_from_file( ): file_name = "./tests/app/orgs.csv" - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 0 create_organization() - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 org = Organization.query.one() @@ -183,7 +193,7 @@ def test_populate_organization_agreement_details_from_file( ) print(f"command_response = {command_response}") - org_count = Organization.query.count() + org_count = _get_organization_query_count() assert org_count == 1 org = Organization.query.one() assert org.agreement_signed_on_behalf_of_name == "bob" @@ -224,7 +234,7 @@ def test_bulk_invite_user_to_service( def test_create_test_user_command(notify_db_session, notify_api): # number of users before adding ours - user_count = User.query.count() + user_count = _get_user_query_count() # run the command notify_api.test_cli_runner().invoke( @@ -242,7 +252,7 @@ def test_create_test_user_command(notify_db_session, notify_api): ) # there should be one more user - assert User.query.count() == user_count + 1 + assert _get_user_query_count() == user_count + 1 # that user should be the one we added stmt = select(User).where(User.name == "Fake Personson")