diff --git a/app/__init__.py b/app/__init__.py index 9fad21e30..1707be71c 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -94,7 +94,7 @@ migrate = None notify_celery = NotifyCelery() aws_ses_client = None aws_ses_stub_client = None -aws_sns_client = AwsSnsClient() +aws_sns_client = None aws_cloudwatch_client = None encryption = Encryption() zendesk_client = None @@ -135,6 +135,15 @@ def get_aws_ses_client(): return aws_ses_client +def get_aws_sns_client(): + global aws_sns_client + if os.environ.get("NOTIFY_ENVIRONMENT") == "test": + return AwsSnsClient() + if aws_ses_client is None: + raise RuntimeError(f"Celery not initialized aws_sns_client: {aws_sns_client}") + return aws_sns_client + + def get_document_download_client(): global document_download_client # Our unit tests mock anyway @@ -148,7 +157,7 @@ def get_document_download_client(): def create_app(application): - global zendesk_client, migrate, document_download_client, aws_ses_client, aws_ses_stub_client + global zendesk_client, migrate, document_download_client, aws_ses_client, aws_ses_stub_client, aws_sns_client from app.config import configs notify_environment = os.environ["NOTIFY_ENVIRONMENT"] @@ -166,7 +175,6 @@ def create_app(application): request_helper.init_app(application) db.init_app(application) logging.init_app(application) - aws_sns_client.init_app(application) # start lazy initialization for gevent migrate = Migrate() @@ -182,6 +190,8 @@ def create_app(application): aws_ses_client.init_app() aws_ses_stub_client = AwsSesStubClient() aws_ses_stub_client.init_app(stub_url=application.config["SES_STUB_URL"]) + aws_sns_client = AwsSnsClient() + aws_sns_client.init_app(application) # end lazy initialization diff --git a/tests/app/clients/test_aws_sns.py b/tests/app/clients/test_aws_sns.py index 09c623f18..5f1e8af44 100644 --- a/tests/app/clients/test_aws_sns.py +++ b/tests/app/clients/test_aws_sns.py @@ -1,6 +1,8 @@ import pytest -from app import aws_sns_client +from app import get_aws_sns_client + +aws_sns_client = get_aws_sns_client() def test_send_sms_successful_returns_aws_sns_response(notify_api, mocker): diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index ba0837f8e..ac564f7e2 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -8,7 +8,7 @@ from requests import HTTPError from sqlalchemy import select import app -from app import aws_sns_client, db, notification_provider_clients +from app import db, notification_provider_clients from app.cloudfoundry_config import cloud_config from app.dao import notifications_dao from app.dao.provider_details_dao import get_provider_details_by_identifier @@ -88,7 +88,9 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist( reply_to_text=sample_sms_template_with_html.service.get_default_sms_sender(), ) - mocker.patch("app.aws_sns_client.send_sms") + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") mock_s3.return_value = "2028675309" @@ -101,7 +103,7 @@ def test_should_send_personalised_template_to_correct_sms_provider_and_persist( send_to_providers.send_sms_to_provider(db_notification) - aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( to="2028675309", content="Sample service: Hello Jo\nHere is some HTML & entities", reference=str(db_notification.id), @@ -224,7 +226,10 @@ def test_should_not_send_sms_message_when_service_is_inactive_notification_is_in sample_service, sample_notification, mocker ): sample_service.active = False - send_mock = mocker.patch("app.aws_sns_client.send_sms", return_value="reference") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") mock_phone.return_value = "15555555555" @@ -237,7 +242,7 @@ def test_should_not_send_sms_message_when_service_is_inactive_notification_is_in with pytest.raises(NotificationTechnicalFailureException) as e: send_to_providers.send_sms_to_provider(sample_notification) assert str(sample_notification.id) in str(e.value) - send_mock.assert_not_called() + send_mock.send_sms.assert_not_called() assert ( db.session.get(Notification, sample_notification.id).status == NotificationStatus.TECHNICAL_FAILURE @@ -266,7 +271,9 @@ def test_send_sms_should_use_template_version_from_notification_not_latest( ) mock_s3_p.return_value = {"ignore": "ignore"} - mocker.patch("app.aws_sns_client.send_sms") + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) version_on_notification = sample_template.version expected_template_id = sample_template.id @@ -283,7 +290,7 @@ def test_send_sms_should_use_template_version_from_notification_not_latest( send_to_providers.send_sms_to_provider(db_notification) - aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( to="2028675309", content="Sample service: This is a template:\nwith a newline", reference=str(db_notification.id), @@ -333,7 +340,10 @@ def test_should_not_send_to_provider_when_status_is_not_created( template=sample_template, status=NotificationStatus.SENDING, ) - mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) response_mock = mocker.patch("app.delivery.send_to_providers.send_sms_response") mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") @@ -346,7 +356,7 @@ def test_should_not_send_to_provider_when_status_is_not_created( send_to_providers.send_sms_to_provider(notification) - app.aws_sns_client.send_sms.assert_not_called() + send_mock.send_sms.assert_not_called() response_mock.assert_not_called() @@ -371,7 +381,9 @@ def test_should_send_sms_with_downgraded_content(notify_db_session, mocker): db_notification.personalisation = {"misc": placeholder} db_notification.reply_to_text = "testing" - mocker.patch("app.aws_sns_client.send_sms") + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") mock_phone.return_value = "15555555555" @@ -383,7 +395,7 @@ def test_should_send_sms_with_downgraded_content(notify_db_session, mocker): send_to_providers.send_sms_to_provider(db_notification) - aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( to=ANY, content=gsm_message, reference=ANY, sender=ANY, international=False ) @@ -393,7 +405,10 @@ def test_send_sms_should_use_service_sms_sender( ): mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) - mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) mocker.patch("app.delivery.send_to_providers.update_notification_message_id") sms_sender = create_service_sms_sender( @@ -415,7 +430,7 @@ def test_send_sms_should_use_service_sms_sender( db_notification, ) - app.aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( to=ANY, content=ANY, reference=ANY, @@ -689,7 +704,10 @@ def test_should_update_billable_units_and_status_according_to_research_mode_and_ "app.delivery.send_to_providers.update_notification_message_id", return_value=None, ) - mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + mocker.patch("app.aws_sns_client", mock_client) mocker.patch( "app.delivery.send_to_providers.send_sms_response", side_effect=__update_notification(notification, research_mode, expected_status), @@ -718,7 +736,10 @@ def test_should_set_notification_billable_units_and_reduces_provider_priority_if sample_notification, mocker, ): - mocker.patch("app.aws_sns_client.send_sms", side_effect=Exception()) + + mock_client = MagicMock() + mock_client.send_sms.side_effect = Exception() + mocker.patch("app.aws_sns_client", mock_client) sample_notification.billable_units = 0 assert sample_notification.sent_by is None @@ -748,7 +769,10 @@ def test_should_send_sms_to_international_providers( ): mocker.patch("app.delivery.send_to_providers._get_verify_code", return_value=None) - mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) notification_international = create_notification( template=sample_template, @@ -776,7 +800,7 @@ def test_should_send_sms_to_international_providers( send_to_providers.send_sms_to_provider(notification_international) - aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( to="601117224412", content=ANY, reference=str(notification_international.id), @@ -805,7 +829,10 @@ def test_should_handle_sms_sender_and_prefix_message( ): mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) - mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) mocker.patch( "app.delivery.send_to_providers.update_notification_message_id", @@ -827,7 +854,7 @@ def test_should_handle_sms_sender_and_prefix_message( send_to_providers.send_sms_to_provider(notification) - aws_sns_client.send_sms.assert_called_once_with( + send_mock.send_sms.assert_called_once_with( content=expected_content, sender=expected_sender, to=ANY, @@ -882,7 +909,10 @@ def test_send_sms_to_provider_should_use_normalised_to(mocker, client, sample_te "app.delivery.send_to_providers.update_notification_message_id", return_value=None, ) - send_mock = mocker.patch("app.aws_sns_client.send_sms") + + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) notification = create_notification( template=sample_template, to_field="+12028675309", @@ -975,7 +1005,9 @@ def test_send_sms_to_provider_should_return_template_if_found_in_redis( ) mock_get_service = mocker.patch("app.dao.services_dao.dao_fetch_service_by_id") - send_mock = mocker.patch("app.aws_sns_client.send_sms") + mock_client = MagicMock() + mock_client.send_sms.return_value = "reference" + send_mock = mocker.patch("app.aws_sns_client", mock_client) notification = create_notification( template=sample_template, to_field="+447700900855",