diff --git a/.ds.baseline b/.ds.baseline index a5bd1bd9e..148542232 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -341,7 +341,7 @@ "filename": "tests/app/user/test_rest.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 108, + "line_number": 110, "is_secret": false }, { @@ -349,7 +349,7 @@ "filename": "tests/app/user/test_rest.py", "hashed_secret": "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", "is_verified": false, - "line_number": 850, + "line_number": 858, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-11-15T16:25:55Z" + "generated_at": "2024-11-15T18:43:06Z" } diff --git a/tests/app/celery/test_reporting_tasks.py b/tests/app/celery/test_reporting_tasks.py index 0761e6103..6a5001713 100644 --- a/tests/app/celery/test_reporting_tasks.py +++ b/tests/app/celery/test_reporting_tasks.py @@ -192,7 +192,11 @@ def test_create_nightly_billing_for_day_sms_rate_multiplier( assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) - records = FactBilling.query.order_by("rate_multiplier").all() + records = ( + db.session.execute(select(FactBilling).order_by("rate_multiplier")) + .scalars() + .all() + ) assert len(records) == records_num for i, record in enumerate(records): @@ -232,7 +236,11 @@ def test_create_nightly_billing_for_day_different_templates( assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) - records = FactBilling.query.order_by("rate_multiplier").all() + records = ( + db.session.execute(select(FactBilling).order_by("rate_multiplier")) + .query() + .all() + ) assert len(records) == 2 multiplier = [0, 1] billable_units = [0, 1] @@ -276,7 +284,11 @@ def test_create_nightly_billing_for_day_same_sent_by( assert len(records) == 0 create_nightly_billing_for_day(str(yesterday.date())) - records = FactBilling.query.order_by("rate_multiplier").all() + records = ( + db.session.execute(select(FactBilling).order_by("rate_multiplier")) + .scalars() + .all() + ) assert len(records) == 1 for _, record in enumerate(records): @@ -371,7 +383,11 @@ def test_create_nightly_billing_for_day_use_BST( assert count == 0 create_nightly_billing_for_day("2018-03-25") - records = FactBilling.query.order_by(FactBilling.local_date).all() + records = ( + db.session.execute(select(FactBilling).order_by(FactBilling.local_date)) + .scalars() + .all() + ) assert len(records) == 1 assert records[0].local_date == date(2018, 3, 25) @@ -398,7 +414,11 @@ def test_create_nightly_billing_for_day_update_when_record_exists( assert len(records) == 0 create_nightly_billing_for_day("2018-01-14") - records = FactBilling.query.order_by(FactBilling.local_date).all() + records = ( + db.session.execute(select(FactBilling).order_by(FactBilling.local_date)) + .scalars() + .all() + ) assert len(records) == 1 assert records[0].local_date == date(2018, 1, 14) @@ -477,10 +497,16 @@ def test_create_nightly_notification_status_for_service_and_day(notify_db_sessio NotificationType.EMAIL, ) - new_fact_data = FactNotificationStatus.query.order_by( - FactNotificationStatus.notification_type, - FactNotificationStatus.notification_status, - ).all() + new_fact_data = ( + db.session.execute( + select(FactNotificationStatus).order_by( + FactNotificationStatus.notification_type, + FactNotificationStatus.notification_status, + ) + ) + .scalars() + .all() + ) assert len(new_fact_data) == 4 @@ -555,9 +581,15 @@ def test_create_nightly_notification_status_for_service_and_day_overwrites_old_d NotificationType.SMS, ) - updated_fact_data = FactNotificationStatus.query.order_by( - FactNotificationStatus.notification_status - ).all() + updated_fact_data = ( + db.session.execute( + select(FactNotificationStatus).order_by( + FactNotificationStatus.notification_status + ) + ) + .scalars() + .all() + ) assert len(updated_fact_data) == 2 assert updated_fact_data[0].notification_count == 1 @@ -600,9 +632,13 @@ def test_create_nightly_notification_status_for_service_and_day_respects_bst( NotificationType.SMS, ) - noti_status = FactNotificationStatus.query.order_by( - FactNotificationStatus.local_date - ).all() + noti_status = ( + db.session.execute( + select(FactNotificationStatus).order_by(FactNotificationStatus.local_date) + ) + .scalars() + .all() + ) assert len(noti_status) == 1 assert noti_status[0].local_date == date(2019, 4, 1) diff --git a/tests/app/provider_details/test_rest.py b/tests/app/provider_details/test_rest.py index a5780fcb6..0d64bf297 100644 --- a/tests/app/provider_details/test_rest.py +++ b/tests/app/provider_details/test_rest.py @@ -1,7 +1,9 @@ import pytest from flask import json from freezegun import freeze_time +from sqlalchemy import select +from app import db from app.models import ProviderDetails, ProviderDetailsHistory from tests import create_admin_authorization_header from tests.app.db import create_ft_billing @@ -53,7 +55,7 @@ def test_get_provider_contains_correct_fields(client, sample_template): def test_should_be_able_to_update_status(client, restore_provider_details): - provider = ProviderDetails.query.first() + provider = db.session.execute(select(ProviderDetails)).scalars().first() update_resp_1 = client.post( "/provider-details/{}".format(provider.id), @@ -76,7 +78,7 @@ def test_should_be_able_to_update_status(client, restore_provider_details): def test_should_not_be_able_to_update_disallowed_fields( client, restore_provider_details, field, value ): - provider = ProviderDetails.query.first() + provider = db.session.execute(select(ProviderDetails)).scalars().first() resp = client.post( "/provider-details/{}".format(provider.id), @@ -94,7 +96,7 @@ def test_should_not_be_able_to_update_disallowed_fields( def test_get_provider_versions_contains_correct_fields(client, notify_db_session): - provider = ProviderDetailsHistory.query.first() + provider = db.session.execute(select(ProviderDetailsHistory)).scalars().first() response = client.get( "/provider-details/{}/versions".format(provider.id), headers=[create_admin_authorization_header()], @@ -117,7 +119,7 @@ def test_get_provider_versions_contains_correct_fields(client, notify_db_session def test_update_provider_should_store_user_id( client, restore_provider_details, sample_user ): - provider = ProviderDetails.query.first() + provider = db.session.execute(select(ProviderDetails)).scalars().first() update_resp_1 = client.post( "/provider-details/{}".format(provider.id), diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index 054f2c6b1..bd62bc640 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -6,7 +6,7 @@ from unittest import mock import pytest from flask import current_app from freezegun import freeze_time -from sqlalchemy import func, select +from sqlalchemy import delete, func, select from app import db from app.dao.service_user_dao import dao_get_service_user, dao_update_service_user @@ -101,7 +101,9 @@ def test_post_user(admin_request, notify_db_session): """ Tests POST endpoint '/' to create a user. """ - User.query.delete() + db.session.execute(delete(User)) + db.session.commit() + data = { "name": "Test User", "email_address": "user@digital.fake.gov", @@ -129,7 +131,9 @@ def test_post_user(admin_request, notify_db_session): def test_post_user_without_auth_type(admin_request, notify_db_session): - User.query.delete() + + db.session.execute(delete(User)) + db.session.commit() data = { "name": "Test User", "email_address": "user@digital.fake.gov", @@ -155,7 +159,9 @@ def test_post_user_missing_attribute_email(admin_request, notify_db_session): """ Tests POST endpoint '/' missing attribute email. """ - User.query.delete() + + db.session.execute(delete(User)) + db.session.commit() data = { "name": "Test User", "password": "password", @@ -182,7 +188,9 @@ def test_create_user_missing_attribute_password(admin_request, notify_db_session """ Tests POST endpoint '/' missing attribute password. """ - User.query.delete() + + db.session.execute(delete(User)) + db.session.commit() data = { "name": "Test User", "email_address": "user@digital.fake.gov",