diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index 3597bdbb7..57b890f39 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -1,10 +1,10 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import between +from sqlalchemy import between, select from sqlalchemy.exc import SQLAlchemyError -from app import notify_celery, zendesk_client +from app import db, notify_celery, zendesk_client from app.celery.tasks import ( get_recipient_csv_and_template_and_sender_id, process_incomplete_jobs, @@ -105,14 +105,28 @@ def check_job_status(): thirty_minutes_ago = utc_now() - timedelta(minutes=30) thirty_five_minutes_ago = utc_now() - timedelta(minutes=35) - incomplete_in_progress_jobs = Job.query.filter( - Job.job_status == JobStatus.IN_PROGRESS, - between(Job.processing_started, thirty_five_minutes_ago, thirty_minutes_ago), + incomplete_in_progress_jobs = ( + db.session.execute( + select(Job).where( + Job.job_status == JobStatus.IN_PROGRESS, + between( + Job.processing_started, thirty_five_minutes_ago, thirty_minutes_ago + ), + ) + ) + .scalars() + .all() ) - incomplete_pending_jobs = Job.query.filter( - Job.job_status == JobStatus.PENDING, - Job.scheduled_for.isnot(None), - between(Job.scheduled_for, thirty_five_minutes_ago, thirty_minutes_ago), + incomplete_pending_jobs = ( + db.session.execute( + select(Job).where( + Job.job_status == JobStatus.PENDING, + Job.scheduled_for.isnot(None), + between(Job.scheduled_for, thirty_five_minutes_ago, thirty_minutes_ago), + ) + ) + .scalars() + .all() ) jobs_not_complete_after_30_minutes = ( diff --git a/app/dao/service_callback_api_dao.py b/app/dao/service_callback_api_dao.py index 275299cfd..d65e341ef 100644 --- a/app/dao/service_callback_api_dao.py +++ b/app/dao/service_callback_api_dao.py @@ -1,9 +1,11 @@ +from sqlalchemy import select + from app import create_uuid, db from app.dao.dao_utils import autocommit, version_class from app.enums import CallbackType from app.models import ServiceCallbackApi from app.utils import utc_now -from sqlalchemy import select + @autocommit @version_class(ServiceCallbackApi) @@ -29,23 +31,41 @@ def reset_service_callback_api( def get_service_callback_api(service_callback_api_id, service_id): - return db.session.execute(select(ServiceCallbackApi).filter_by( - id=service_callback_api_id, service_id=service_id - )).scalars().first() + return ( + db.session.execute( + select(ServiceCallbackApi).filter_by( + id=service_callback_api_id, service_id=service_id + ) + ) + .scalars() + .first() + ) def get_service_delivery_status_callback_api_for_service(service_id): - return db.session.execute(select(ServiceCallbackApi).filter_by( - service_id=service_id, - callback_type=CallbackType.DELIVERY_STATUS, - )).scalars().first() + return ( + db.session.execute( + select(ServiceCallbackApi).filter_by( + service_id=service_id, + callback_type=CallbackType.DELIVERY_STATUS, + ) + ) + .scalars() + .first() + ) def get_service_complaint_callback_api_for_service(service_id): - return db.session.execute(select(ServiceCallbackApi).filter_by( - service_id=service_id, - callback_type=CallbackType.COMPLAINT, - )).scalars().first() + return ( + db.session.execute( + select(ServiceCallbackApi).filter_by( + service_id=service_id, + callback_type=CallbackType.COMPLAINT, + ) + ) + .scalars() + .first() + ) @autocommit diff --git a/app/service/rest.py b/app/service/rest.py index 7dd614058..11b2f4403 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -2,10 +2,12 @@ import itertools from datetime import datetime, timedelta from flask import Blueprint, current_app, jsonify, request +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound from werkzeug.datastructures import MultiDict +from app import db from app.aws.s3 import get_personalisation_from_s3, get_phone_number_from_s3 from app.config import QueueNames from app.dao import fact_notification_status_dao, notifications_dao @@ -419,14 +421,26 @@ def get_service_history(service_id): template_history_schema, ) - service_history = Service.get_history_model().query.filter_by(id=service_id).all() + service_history = ( + db.session.execute(select(Service.get_history_model()).filter_by(id=service_id)) + .scalars() + .all() + ) service_data = service_history_schema.dump(service_history, many=True) api_key_history = ( - ApiKey.get_history_model().query.filter_by(service_id=service_id).all() + db.session.execute( + select(ApiKey.get_history_model()).filter_by(service_id=service_id) + ) + .scalars() + .all() ) api_keys_data = api_key_history_schema.dump(api_key_history, many=True) - template_history = TemplateHistory.query.filter_by(service_id=service_id).all() + template_history = ( + db.session.execute(select(TemplateHistory).filter_by(service_id=service_id)) + .scalars() + .all() + ) template_data = template_history_schema.dump(template_history, many=True) data = { diff --git a/tests/app/dao/test_service_inbound_api_dao.py b/tests/app/dao/test_service_inbound_api_dao.py index 321b7d82e..03eb6d616 100644 --- a/tests/app/dao/test_service_inbound_api_dao.py +++ b/tests/app/dao/test_service_inbound_api_dao.py @@ -37,7 +37,10 @@ def test_save_service_inbound_api(sample_service): assert inbound_api.updated_at is None versioned = ( - ServiceInboundApi.get_history_model().query.filter_by(id=inbound_api.id).one() + db.session.execute(select(ServiceInboundApi.get_history_model())) + .filter_by(id=inbound_api.id) + .scalars() + .one() ) assert versioned.id == inbound_api.id assert versioned.service_id == sample_service.id @@ -90,8 +93,12 @@ def test_update_service_inbound_api(sample_service): assert updated.updated_at is not None versioned_results = ( - ServiceInboundApi.get_history_model() - .query.filter_by(id=saved_inbound_api.id) + db.session.execute( + select(ServiceInboundApi) + .get_history_model() + .filter_by(id=saved_inbound_api.id) + ) + .scalars() .all() ) assert len(versioned_results) == 2 diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index 20b0f7186..d08328ef7 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -5,9 +5,10 @@ from unittest.mock import ANY import pytest from flask import current_app from requests import HTTPError +from sqlalchemy import select import app -from app import aws_sns_client, notification_provider_clients +from app import aws_sns_client, db, notification_provider_clients from app.cloudfoundry_config import cloud_config from app.dao import notifications_dao from app.dao.provider_details_dao import get_provider_details_by_identifier @@ -108,7 +109,11 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist( international=False, ) - notification = Notification.query.filter_by(id=db_notification.id).one() + notification = ( + db.session.execute(select(Notification).filter_by(id=db_notification.id)) + .scalars() + .one() + ) assert notification.status == NotificationStatus.SENDING assert notification.sent_at <= utc_now() @@ -152,7 +157,11 @@ def test_should_send_personalised_template_to_correct_email_provider_and_persist in app.aws_ses_client.send_email.call_args[1]["html_body"] ) - notification = Notification.query.filter_by(id=db_notification.id).one() + notification = ( + db.session.execute(select(Notification).filter_by(id=db_notification.id)) + .scalars() + .one() + ) assert notification.status == NotificationStatus.SENDING assert notification.sent_at <= utc_now() assert notification.sent_by == "ses" diff --git a/tests/app/user/test_rest_verify.py b/tests/app/user/test_rest_verify.py index d32d923bf..5c6eb6f5e 100644 --- a/tests/app/user/test_rest_verify.py +++ b/tests/app/user/test_rest_verify.py @@ -20,7 +20,7 @@ from tests import create_admin_authorization_header @freeze_time("2016-01-01T12:00:00") def test_user_verify_sms_code(client, sample_sms_code): sample_sms_code.user.logged_in_at = utc_now() - timedelta(days=1) - assert not VerifyCode.query.first().code_used + assert not db.session.execute(select(VerifyCode)).scalars().first().code_used assert sample_sms_code.user.current_session_id is None data = json.dumps( {"code_type": sample_sms_code.code_type, "code": sample_sms_code.txt_code} @@ -32,14 +32,14 @@ def test_user_verify_sms_code(client, sample_sms_code): headers=[("Content-Type", "application/json"), auth_header], ) assert resp.status_code == 204 - assert VerifyCode.query.first().code_used + assert db.session.execute(select(VerifyCode)).scalars().first().code_used assert sample_sms_code.user.logged_in_at == utc_now() assert sample_sms_code.user.email_access_validated_at != utc_now() assert sample_sms_code.user.current_session_id is not None def test_user_verify_code_missing_code(client, sample_sms_code): - assert not VerifyCode.query.first().code_used + assert not db.session.execute(select(VerifyCode)).scalars().first().code_used data = json.dumps({"code_type": sample_sms_code.code_type}) auth_header = create_admin_authorization_header() resp = client.post( @@ -48,14 +48,14 @@ def test_user_verify_code_missing_code(client, sample_sms_code): headers=[("Content-Type", "application/json"), auth_header], ) assert resp.status_code == 400 - assert not VerifyCode.query.first().code_used - assert User.query.get(sample_sms_code.user.id).failed_login_count == 0 + assert not db.session.execute(select(VerifyCode)).scalars().first().code_used + assert db.session.get(User, sample_sms_code.user.id).failed_login_count == 0 def test_user_verify_code_bad_code_and_increments_failed_login_count( client, sample_sms_code ): - assert not VerifyCode.query.first().code_used + assert not db.session.execute(select(VerifyCode)).scalars().first().code_used data = json.dumps({"code_type": sample_sms_code.code_type, "code": "blah"}) auth_header = create_admin_authorization_header() resp = client.post( @@ -64,8 +64,8 @@ def test_user_verify_code_bad_code_and_increments_failed_login_count( headers=[("Content-Type", "application/json"), auth_header], ) assert resp.status_code == 404 - assert not VerifyCode.query.first().code_used - assert User.query.get(sample_sms_code.user.id).failed_login_count == 1 + assert not db.session.execute(select(VerifyCode)).scalars().first().code_used + assert db.session.get(User, sample_sms_code.user.id).failed_login_count == 1 @pytest.mark.parametrize( @@ -134,7 +134,7 @@ def test_user_verify_password(client, sample_user): headers=[("Content-Type", "application/json"), auth_header], ) assert resp.status_code == 204 - assert User.query.get(sample_user.id).logged_in_at == yesterday + assert db.session.get(User, sample_user.id).logged_in_at == yesterday def test_user_verify_password_invalid_password(client, sample_user): @@ -222,9 +222,9 @@ def test_send_user_sms_code(client, sample_user, sms_code_template, mocker): assert resp.status_code == 204 assert mocked.call_count == 1 - assert VerifyCode.query.one().check_code("11111") + assert db.session.execute(select(VerifyCode)).scalars().one().check_code("11111") - notification = Notification.query.one() + notification = db.session.execute(select(Notification)).one() assert notification.personalisation == {"verify_code": "11111"} assert notification.to == "1" assert str(notification.service_id) == current_app.config["NOTIFY_SERVICE_ID"] @@ -264,7 +264,7 @@ def test_send_user_code_for_sms_with_optional_to_field( assert resp.status_code == 204 assert mocked.call_count == 1 - notification = Notification.query.first() + notification = db.session.execute(select(Notification)).scalars().first() assert notification.to == "1" app.celery.provider_tasks.deliver_sms.apply_async.assert_called_once_with( ([str(notification.id)]), queue="notify-internal-tasks" @@ -346,7 +346,7 @@ def test_send_new_user_email_verification( ) notify_service = email_verification_template.service assert resp.status_code == 204 - notification = Notification.query.first() + notification = db.session.execute(select(Notification)).scalars().first() assert _get_verify_code_count() == 0 mocked.assert_called_once_with( ([str(notification.id)]), queue="notify-internal-tasks" @@ -487,7 +487,7 @@ def test_send_user_email_code( _data=data, _expected_status=204, ) - noti = Notification.query.one() + noti = db.session.execute(select(Notification)).scalars().one() assert ( noti.reply_to_text == email_2fa_code_template.service.get_default_reply_to_email_address() @@ -608,7 +608,7 @@ def test_send_user_2fa_code_sends_from_number_for_international_numbers( ) assert resp.status_code == 204 - notification = Notification.query.first() + notification = db.session.execute(select(Notification)).scalars().first() assert ( notification.reply_to_text == current_app.config["NOTIFY_INTERNATIONAL_SMS_SENDER"]