From 79ceddfee40be0f92a4911106911cefd160210b0 Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Wed, 30 Oct 2024 10:35:49 -0700 Subject: [PATCH] more fixes --- tests/app/dao/test_invited_user_dao.py | 70 +++++++++------------- tests/app/dao/test_provider_details_dao.py | 19 ++++-- tests/app/test_schemas.py | 13 ++-- 3 files changed, 49 insertions(+), 53 deletions(-) diff --git a/tests/app/dao/test_invited_user_dao.py b/tests/app/dao/test_invited_user_dao.py index da52e52e7..247a3dfda 100644 --- a/tests/app/dao/test_invited_user_dao.py +++ b/tests/app/dao/test_invited_user_dao.py @@ -2,6 +2,7 @@ import uuid from datetime import timedelta import pytest +from sqlalchemy import func, select from sqlalchemy.orm.exc import NoResultFound from app import db @@ -123,23 +124,17 @@ def test_should_delete_all_invitations_more_than_one_day_old( ): make_invitation(sample_user, sample_service, age=timedelta(hours=48)) make_invitation(sample_user, sample_service, age=timedelta(hours=48)) - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 2 - ) + stmt = select(InvitedUser).where(InvitedUser.status != InvitedUserStatus.EXPIRED) + result = db.session.execute(stmt).scalars().all() + assert len(result) == 2 expire_invitations_created_more_than_two_days_ago() - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 0 + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 0 def test_should_not_delete_invitations_less_than_two_days_old( @@ -160,35 +155,28 @@ def test_should_not_delete_invitations_less_than_two_days_old( email_address="expired@1.com", ) - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 2 + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 2 expire_invitations_created_more_than_two_days_ago() - assert ( - len( - InvitedUser.query.filter( - InvitedUser.status != InvitedUserStatus.EXPIRED - ).all() - ) - == 1 - ) - assert ( - InvitedUser.query.filter(InvitedUser.status != InvitedUserStatus.EXPIRED) - .first() - .email_address - == "valid@2.com" - ) - assert ( - InvitedUser.query.filter(InvitedUser.status == InvitedUserStatus.EXPIRED) - .first() - .email_address - == "expired@1.com" + stmt = ( + select(func.count()) + .select_from(InvitedUser) + .where(InvitedUser.status != InvitedUserStatus.EXPIRED) ) + count = db.session.execute(stmt).scalar() or 0 + assert count == 1 + stmt = select(InvitedUser).where(InvitedUser.status != InvitedUserStatus.EXPIRED) + invited_user = db.session.execute(stmt).scalars().first() + assert invited_user.email_address == "valid@2.com" + stmt = select(InvitedUser).where(InvitedUser.status == InvitedUserStatus.EXPIRED) + invited_user = db.session.execute(stmt).scalars().first() + + assert invited_user.email_address == "expired@1.com" def make_invitation(user, service, age=None, email_address="test@test.com"): diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py index 624ce2bb5..84c4b2238 100644 --- a/tests/app/dao/test_provider_details_dao.py +++ b/tests/app/dao/test_provider_details_dao.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta import pytest from freezegun import freeze_time -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from app import db, notification_provider_clients from app.dao.provider_details_dao import ( @@ -69,7 +69,6 @@ def test_should_not_error_if_any_provider_in_code_not_in_database( stmt = delete(ProviderDetails).where(ProviderDetails.identifier == "sns") db.session.execute(stmt) db.session.commit() - # ProviderDetails.query.filter_by(identifier="sns").delete() assert notification_provider_clients.get_sms_client("sns") @@ -135,9 +134,13 @@ def test_get_alternative_sms_provider_fails_if_unrecognised(): @freeze_time("2016-01-01 01:00") def test_get_sms_providers_for_update_returns_providers(restore_provider_details): - ProviderDetails.query.filter(ProviderDetails.identifier == "sns").update( - {"updated_at": None} + stmt = ( + update(ProviderDetails) + .where(ProviderDetails.identifier == "sns") + .values({"updated_at": None}) ) + db.session.execute(stmt) + db.session.commit() resp = _get_sms_providers_for_update(timedelta(hours=1)) @@ -149,9 +152,13 @@ def test_get_sms_providers_for_update_returns_nothing_if_recent_updates( restore_provider_details, ): fifty_nine_minutes_ago = datetime(2016, 1, 1, 0, 1) - ProviderDetails.query.filter(ProviderDetails.identifier == "sns").update( - {"updated_at": fifty_nine_minutes_ago} + stmt = ( + update(ProviderDetails) + .where(ProviderDetails.identifier == "sns") + .values({"updated_at": fifty_nine_minutes_ago}) ) + db.session.execute(stmt) + db.session.commit() resp = _get_sms_providers_for_update(timedelta(hours=1)) diff --git a/tests/app/test_schemas.py b/tests/app/test_schemas.py index 270c36a17..b71d2fef8 100644 --- a/tests/app/test_schemas.py +++ b/tests/app/test_schemas.py @@ -2,8 +2,9 @@ import datetime import pytest from marshmallow import ValidationError -from sqlalchemy import desc +from sqlalchemy import desc, select +from app import db from app.dao.provider_details_dao import ( dao_update_provider_details, get_provider_details_by_identifier, @@ -145,13 +146,13 @@ def test_provider_details_history_schema_returns_user_details( dao_update_provider_details(current_sms_provider) - current_sms_provider_in_history = ( - ProviderDetailsHistory.query.filter( - ProviderDetailsHistory.id == current_sms_provider.id - ) + stmt = ( + select(ProviderDetailsHistory) + .where(ProviderDetailsHistory.id == current_sms_provider.id) .order_by(desc(ProviderDetailsHistory.version)) - .first() ) + current_sms_provider_in_history = db.session.execute(stmt).scalars().first() + data = provider_details_schema.dump(current_sms_provider_in_history) assert sorted(data["created_by"].keys()) == sorted(["id", "email_address", "name"])