From bc4d4c9735d7ab400c0fdbbb1c010106218af1cf Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Fri, 15 Nov 2024 13:42:27 -0800 Subject: [PATCH] more --- app/dao/email_branding_dao.py | 16 ++++- app/dao/fact_billing_dao.py | 72 ++++++++++--------- .../celery/test_process_ses_receipts_tasks.py | 7 +- tests/app/dao/test_invited_user_dao.py | 4 +- tests/app/delivery/test_send_to_providers.py | 4 +- .../test_process_notification.py | 4 +- .../test_receive_notification.py | 4 +- tests/app/service/test_api_key_endpoints.py | 10 ++- 8 files changed, 70 insertions(+), 51 deletions(-) diff --git a/app/dao/email_branding_dao.py b/app/dao/email_branding_dao.py index 1dedd78a8..61dc2a46b 100644 --- a/app/dao/email_branding_dao.py +++ b/app/dao/email_branding_dao.py @@ -1,18 +1,28 @@ +from sqlalchemy import select + from app import db from app.dao.dao_utils import autocommit from app.models import EmailBranding def dao_get_email_branding_options(): - return EmailBranding.query.all() + return db.session.execute(select(EmailBranding)).scalars().all() def dao_get_email_branding_by_id(email_branding_id): - return EmailBranding.query.filter_by(id=email_branding_id).one() + return ( + db.session.execute(select(EmailBranding).filter_by(id=email_branding_id)) + .scalars() + .one() + ) def dao_get_email_branding_by_name(email_branding_name): - return EmailBranding.query.filter_by(name=email_branding_name).first() + return ( + db.session.execute(select(EmailBranding).filter_by(name=email_branding_name)) + .scalars() + .first() + ) @autocommit diff --git a/app/dao/fact_billing_dao.py b/app/dao/fact_billing_dao.py index 132f62bf2..0371ae8e5 100644 --- a/app/dao/fact_billing_dao.py +++ b/app/dao/fact_billing_dao.py @@ -65,7 +65,7 @@ def fetch_sms_free_allowance_remainder_until_date(end_date): def fetch_sms_billing_for_all_services(start_date, end_date): # ASSUMPTION: AnnualBilling has been populated for year. - allowance_left_at_start_date_query = fetch_sms_free_allowance_remainder_until_date( + allowance_left_at_start_date_querie = fetch_sms_free_allowance_remainder_until_date( start_date ).subquery() @@ -76,14 +76,14 @@ def fetch_sms_billing_for_all_services(start_date, end_date): # subtract sms_billable_units units accrued since report's start date to get up-to-date # allowance remainder sms_allowance_left = func.greatest( - allowance_left_at_start_date_query.c.sms_remainder - sms_billable_units, 0 + allowance_left_at_start_date_querie.c.sms_remainder - sms_billable_units, 0 ) # billable units here are for period between start date and end date only, so to see # how many are chargeable, we need to see how much free allowance was used up in the # period up until report's start date and then do a subtraction chargeable_sms = func.greatest( - sms_billable_units - allowance_left_at_start_date_query.c.sms_remainder, 0 + sms_billable_units - allowance_left_at_start_date_querie.c.sms_remainder, 0 ) sms_cost = chargeable_sms * FactBilling.rate @@ -93,7 +93,7 @@ def fetch_sms_billing_for_all_services(start_date, end_date): Organization.id.label("organization_id"), Service.name.label("service_name"), Service.id.label("service_id"), - allowance_left_at_start_date_query.c.free_sms_fragment_limit, + allowance_left_at_start_date_querie.c.free_sms_fragment_limit, FactBilling.rate.label("sms_rate"), sms_allowance_left.label("sms_remainder"), sms_billable_units.label("sms_billable_units"), @@ -102,8 +102,8 @@ def fetch_sms_billing_for_all_services(start_date, end_date): ) .select_from(Service) .outerjoin( - allowance_left_at_start_date_query, - Service.id == allowance_left_at_start_date_query.c.service_id, + allowance_left_at_start_date_querie, + Service.id == allowance_left_at_start_date_querie.c.service_id, ) .outerjoin(Service.organization) .join( @@ -120,8 +120,8 @@ def fetch_sms_billing_for_all_services(start_date, end_date): Organization.id, Service.id, Service.name, - allowance_left_at_start_date_query.c.free_sms_fragment_limit, - allowance_left_at_start_date_query.c.sms_remainder, + allowance_left_at_start_date_querie.c.free_sms_fragment_limit, + allowance_left_at_start_date_querie.c.sms_remainder, FactBilling.rate, ) .order_by(Organization.name, Service.name) @@ -151,15 +151,15 @@ def fetch_billing_totals_for_year(service_id, year): union( *[ select( - query.c.notification_type.label("notification_type"), - query.c.rate.label("rate"), - func.sum(query.c.notifications_sent).label("notifications_sent"), - func.sum(query.c.chargeable_units).label("chargeable_units"), - func.sum(query.c.cost).label("cost"), - func.sum(query.c.free_allowance_used).label("free_allowance_used"), - func.sum(query.c.charged_units).label("charged_units"), - ).group_by(query.c.rate, query.c.notification_type) - for query in [ + querie.c.notification_type.label("notification_type"), + querie.c.rate.label("rate"), + func.sum(querie.c.notifications_sent).label("notifications_sent"), + func.sum(querie.c.chargeable_units).label("chargeable_units"), + func.sum(querie.c.cost).label("cost"), + func.sum(querie.c.free_allowance_used).label("free_allowance_used"), + func.sum(querie.c.charged_units).label("charged_units"), + ).group_by(querie.c.rate, querie.c.notification_type) + for querie in [ query_service_sms_usage_for_year(service_id, year).subquery(), query_service_email_usage_for_year(service_id, year).subquery(), ] @@ -206,22 +206,22 @@ def fetch_monthly_billing_for_year(service_id, year): union( *[ select( - query.c.rate.label("rate"), - query.c.notification_type.label("notification_type"), - func.date_trunc("month", query.c.local_date) + querie.c.rate.label("rate"), + querie.c.notification_type.label("notification_type"), + func.date_trunc("month", querie.c.local_date) .cast(Date) .label("month"), - func.sum(query.c.notifications_sent).label("notifications_sent"), - func.sum(query.c.chargeable_units).label("chargeable_units"), - func.sum(query.c.cost).label("cost"), - func.sum(query.c.free_allowance_used).label("free_allowance_used"), - func.sum(query.c.charged_units).label("charged_units"), + func.sum(querie.c.notifications_sent).label("notifications_sent"), + func.sum(querie.c.chargeable_units).label("chargeable_units"), + func.sum(querie.c.cost).label("cost"), + func.sum(querie.c.free_allowance_used).label("free_allowance_used"), + func.sum(querie.c.charged_units).label("charged_units"), ).group_by( - query.c.rate, - query.c.notification_type, + querie.c.rate, + querie.c.notification_type, "month", ) - for query in [ + for querie in [ query_service_sms_usage_for_year(service_id, year).subquery(), query_service_email_usage_for_year(service_id, year).subquery(), ] @@ -371,9 +371,9 @@ def fetch_billing_data_for_day(process_day, service_id=None, check_permissions=F ) transit_data = [] if not service_id: - services = Service.query.all() + services = db.session.execute(select(Service)).scalars().all() else: - services = [Service.query.get(service_id)] + services = [db.session.get(Service, service_id)] for service in services: for notification_type in (NotificationType.SMS, NotificationType.EMAIL): @@ -586,12 +586,12 @@ def fetch_email_usage_for_organization(organization_id, start_date, end_date): def fetch_sms_billing_for_organization(organization_id, financial_year): # ASSUMPTION: AnnualBilling has been populated for year. - ft_billing_subquery = query_organization_sms_usage_for_year( + ft_billing_subquerie = query_organization_sms_usage_for_year( organization_id, financial_year ).subquery() sms_billable_units = func.sum( - func.coalesce(ft_billing_subquery.c.chargeable_units, 0) + func.coalesce(ft_billing_subquerie.c.chargeable_units, 0) ) # subtract sms_billable_units units accrued since report's start date to get up-to-date @@ -600,8 +600,8 @@ def fetch_sms_billing_for_organization(organization_id, financial_year): AnnualBilling.free_sms_fragment_limit - sms_billable_units, 0 ) - chargeable_sms = func.sum(ft_billing_subquery.c.charged_units) - sms_cost = func.sum(ft_billing_subquery.c.cost) + chargeable_sms = func.sum(ft_billing_subquerie.c.charged_units) + sms_cost = func.sum(ft_billing_subquerie.c.cost) query = ( select( @@ -622,7 +622,9 @@ def fetch_sms_billing_for_organization(organization_id, financial_year): AnnualBilling.financial_year_start == financial_year, ), ) - .outerjoin(ft_billing_subquery, Service.id == ft_billing_subquery.c.service_id) + .outerjoin( + ft_billing_subquerie, Service.id == ft_billing_subquerie.c.service_id + ) .filter( Service.organization_id == organization_id, Service.restricted.is_(False) ) diff --git a/tests/app/celery/test_process_ses_receipts_tasks.py b/tests/app/celery/test_process_ses_receipts_tasks.py index 226394eeb..77dfc68a4 100644 --- a/tests/app/celery/test_process_ses_receipts_tasks.py +++ b/tests/app/celery/test_process_ses_receipts_tasks.py @@ -2,8 +2,9 @@ import json from unittest.mock import ANY from freezegun import freeze_time +from sqlalchemy import select -from app import encryption +from app import db, encryption from app.celery.process_ses_receipts_tasks import ( process_ses_results, remove_emails_from_bounce, @@ -168,7 +169,7 @@ def test_process_ses_results_in_complaint(sample_email_template, mocker): ) process_ses_results(response=ses_complaint_callback()) assert mocked.call_count == 0 - complaints = Complaint.query.all() + complaints = db.session.execute(select(Complaint)).scalars().all() assert len(complaints) == 1 assert complaints[0].notification_id == notification.id @@ -420,7 +421,7 @@ def test_ses_callback_should_send_on_complaint_to_user_callback_api( assert send_mock.call_count == 1 assert encryption.decrypt(send_mock.call_args[0][0][0]) == { "complaint_date": "2018-06-05T13:59:58.000000Z", - "complaint_id": str(Complaint.query.one().id), + "complaint_id": str(db.session.execute(select(Complaint)).scalars().one().id), "notification_id": str(notification.id), "reference": None, "service_callback_api_bearer_token": "some_super_secret", diff --git a/tests/app/dao/test_invited_user_dao.py b/tests/app/dao/test_invited_user_dao.py index 44fc23572..656dec568 100644 --- a/tests/app/dao/test_invited_user_dao.py +++ b/tests/app/dao/test_invited_user_dao.py @@ -115,12 +115,12 @@ def test_save_invited_user_sets_status_to_cancelled( notify_db_session, sample_invited_user ): assert _get_invited_user_count() == 1 - saved = InvitedUser.query.get(sample_invited_user.id) + saved = db.session.get(InvitedUser, sample_invited_user.id) assert saved.status == InvitedUserStatus.PENDING saved.status = InvitedUserStatus.CANCELLED save_invited_user(saved) assert _get_invited_user_count() == 1 - cancelled_invited_user = InvitedUser.query.get(sample_invited_user.id) + cancelled_invited_user = db.session.get(InvitedUser, sample_invited_user.id) assert cancelled_invited_user.status == InvitedUserStatus.CANCELLED diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index d08328ef7..88569bcd4 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -197,7 +197,7 @@ def test_should_not_send_email_message_when_service_is_inactive_notifcation_is_i assert str(sample_notification.id) in str(e.value) send_mock.assert_not_called() assert ( - Notification.query.get(sample_notification.id).status + db.session.get(Notification, sample_notification.id).status == NotificationStatus.TECHNICAL_FAILURE ) @@ -221,7 +221,7 @@ def test_should_not_send_sms_message_when_service_is_inactive_notification_is_in assert str(sample_notification.id) in str(e.value) send_mock.assert_not_called() assert ( - Notification.query.get(sample_notification.id).status + db.session.get(Notification, sample_notification.id).status == NotificationStatus.TECHNICAL_FAILURE ) diff --git a/tests/app/notifications/test_process_notification.py b/tests/app/notifications/test_process_notification.py index 9f393b440..6bdcf0122 100644 --- a/tests/app/notifications/test_process_notification.py +++ b/tests/app/notifications/test_process_notification.py @@ -100,9 +100,9 @@ def test_persist_notification_creates_and_save_to_db( reply_to_text=sample_template.service.get_default_sms_sender(), ) - assert Notification.query.get(notification.id) is not None + assert db.session.get(Notification, notification.id) is not None - notification_from_db = Notification.query.one() + notification_from_db = db.session.execute(select(Notification)).scalars().one() assert notification_from_db.id == notification.id assert notification_from_db.template_id == notification.template_id diff --git a/tests/app/notifications/test_receive_notification.py b/tests/app/notifications/test_receive_notification.py index e13b8d82e..9bc9d35f6 100644 --- a/tests/app/notifications/test_receive_notification.py +++ b/tests/app/notifications/test_receive_notification.py @@ -64,7 +64,7 @@ def test_receive_notification_returns_received_to_sns( prom_counter_labels_mock.assert_called_once_with("sns") prom_counter_labels_mock.return_value.inc.assert_called_once_with() - inbound_sms_id = InboundSms.query.all()[0].id + inbound_sms_id = db.session.execute(select(InboundSms)).scalars().all()[0].id mocked.assert_called_once_with( [str(inbound_sms_id), str(sample_service_full_permissions.id)], queue="notify-internal-tasks", @@ -136,7 +136,7 @@ def test_receive_notification_without_permissions_does_not_create_inbound_even_w response = sns_post(client, data) assert response.status_code == 200 - assert len(InboundSms.query.all()) == 0 + assert len(db.session.execute(select(InboundSms)).scalars().all()) == 0 assert mocked_has_permissions.called mocked_send_inbound_sms.assert_not_called() diff --git a/tests/app/service/test_api_key_endpoints.py b/tests/app/service/test_api_key_endpoints.py index 09a964b3c..f5a8af007 100644 --- a/tests/app/service/test_api_key_endpoints.py +++ b/tests/app/service/test_api_key_endpoints.py @@ -27,7 +27,13 @@ def test_api_key_should_create_new_api_key_for_service(notify_api, sample_servic ) assert response.status_code == 201 assert "data" in json.loads(response.get_data(as_text=True)) - saved_api_key = ApiKey.query.filter_by(service_id=sample_service.id).first() + saved_api_key = ( + db.session.execute( + select(ApiKey).filter_by(service_id=sample_service.id) + ) + .scalars() + .first() + ) assert saved_api_key.service_id == sample_service.id assert saved_api_key.name == "some secret name" @@ -81,7 +87,7 @@ def test_revoke_should_expire_api_key_for_service(notify_api, sample_api_key): headers=[auth_header], ) assert response.status_code == 202 - api_keys_for_service = ApiKey.query.get(sample_api_key.id) + api_keys_for_service = db.session.get(ApiKey, sample_api_key.id) assert api_keys_for_service.expiry_date is not None