diff --git a/app/__init__.py b/app/__init__.py index abadaa315..81e5c055a 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -29,6 +29,7 @@ from werkzeug.exceptions import HTTPException as WerkzeugHTTPException from werkzeug.local import LocalProxy from app.clients import NotificationProviderClients +from app.clients.cloudwatch.aws_cloudwatch import AwsCloudwatchClient from app.clients.document_download import DocumentDownloadClient from app.clients.email.aws_ses import AwsSesClient from app.clients.email.aws_ses_stub import AwsSesStubClient @@ -55,6 +56,7 @@ notify_celery = NotifyCelery() aws_ses_client = AwsSesClient() aws_ses_stub_client = AwsSesStubClient() aws_sns_client = AwsSnsClient() +aws_cloudwatch_client = AwsCloudwatchClient() encryption = Encryption() zendesk_client = ZendeskClient() redis_store = RedisClient() @@ -96,6 +98,7 @@ def create_app(application): aws_ses_stub_client.init_app( stub_url=application.config['SES_STUB_URL'] ) + aws_cloudwatch_client.init_app(application) # If a stub url is provided for SES, then use the stub client rather than the real SES boto client email_clients = [aws_ses_stub_client] if application.config['SES_STUB_URL'] else [aws_ses_client] notification_provider_clients.init_app( diff --git a/app/celery/provider_tasks.py b/app/celery/provider_tasks.py index a274635ce..01d826ba6 100644 --- a/app/celery/provider_tasks.py +++ b/app/celery/provider_tasks.py @@ -1,7 +1,11 @@ +from datetime import datetime, timedelta +from time import time +from zoneinfo import ZoneInfo + from flask import current_app from sqlalchemy.orm.exc import NoResultFound -from app import notify_celery +from app import aws_cloudwatch_client, notify_celery from app.clients.email import EmailClientNonRetryableException from app.clients.email.aws_ses import AwsSesClientThrottlingSendRateException from app.clients.sms import SmsClientResponseException @@ -10,17 +14,51 @@ from app.dao import notifications_dao from app.dao.notifications_dao import update_notification_status_by_id from app.delivery import send_to_providers from app.exceptions import NotificationTechnicalFailureException -from app.models import NOTIFICATION_TECHNICAL_FAILURE +from app.models import ( + NOTIFICATION_FAILED, + NOTIFICATION_SENT, + NOTIFICATION_TECHNICAL_FAILURE, +) + + +@notify_celery.task(bind=True, name="check_sms_delivery_receipt", max_retries=48, default_retry_delay=300) +def check_sms_delivery_receipt(self, message_id, notification_id, sent_at): + """ + This is called after deliver_sms to check the status of the message. This uses the same number of + retries and the same delay period as deliver_sms. In addition, this fires five minutes after + deliver_sms initially. So the idea is that most messages will succeed and show up in the logs quickly. + Other message will resolve successfully after a retry or to. A few will fail but it will take up to + 4 hours to know for sure. The call to check_sms will raise an exception if neither a success nor a + failure appears in the cloudwatch logs, so this should keep retrying until the log appears, or until + we run out of retries. + """ + status, provider_response = aws_cloudwatch_client.check_sms(message_id, notification_id, sent_at) + if status == 'success': + status = NOTIFICATION_SENT + else: + status = NOTIFICATION_FAILED + update_notification_status_by_id(notification_id, status, provider_response=provider_response) + current_app.logger.info(f"Updated notification {notification_id} with response '{provider_response}'") @notify_celery.task(bind=True, name="deliver_sms", max_retries=48, default_retry_delay=300) def deliver_sms(self, notification_id): try: + # Get the time we are doing the sending, to minimize the time period we need to check over for receipt + now = round(time() * 1000) current_app.logger.info("Start sending SMS for notification id: {}".format(notification_id)) notification = notifications_dao.get_notification_by_id(notification_id) if not notification: raise NoResultFound() - send_to_providers.send_sms_to_provider(notification) + message_id = send_to_providers.send_sms_to_provider(notification) + # We have to put it in the default US/Eastern timezone. From zones west of there, the delay + # will be ignored and it will fire immediately (although this probably only affects developer testing) + my_eta = datetime.now(ZoneInfo('US/Eastern')) + timedelta(seconds=300) + check_sms_delivery_receipt.apply_async( + [message_id, notification_id, now], + eta=my_eta, + queue=QueueNames.CHECK_SMS + ) except Exception as e: if isinstance(e, SmsClientResponseException): current_app.logger.warning( diff --git a/app/clients/cloudwatch/__init__.py b/app/clients/cloudwatch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/clients/cloudwatch/aws_cloudwatch.py b/app/clients/cloudwatch/aws_cloudwatch.py new file mode 100644 index 000000000..97de58219 --- /dev/null +++ b/app/clients/cloudwatch/aws_cloudwatch.py @@ -0,0 +1,89 @@ +import json +import re +import time + +from boto3 import client + +from app.clients import Client +from app.cloudfoundry_config import cloud_config + + +class AwsCloudwatchClient(Client): + """ + This client is responsible for retrieving sms delivery receipts from cloudwatch. + """ + + def init_app(self, current_app, *args, **kwargs): + self._client = client( + "logs", + region_name=cloud_config.sns_region, + aws_access_key_id=cloud_config.sns_access_key, + aws_secret_access_key=cloud_config.sns_secret_key + ) + super(Client, self).__init__(*args, **kwargs) + self.current_app = current_app + self._valid_sender_regex = re.compile(r"^\+?\d{5,14}$") + + @property + def name(self): + return 'cloudwatch' + + def _get_log(self, my_filter, log_group_name, sent_at): + + # Check all cloudwatch logs from the time the notification was sent (currently 5 minutes previously) until now + now = round(time.time() * 1000) + beginning = sent_at + next_token = None + all_log_events = [] + while True: + if next_token: + response = self._client.filter_log_events( + logGroupName=log_group_name, + filterPattern=my_filter, + nextToken=next_token, + startTime=beginning, + endTime=now + ) + else: + response = self._client.filter_log_events( + logGroupName=log_group_name, + filterPattern=my_filter, + startTime=beginning, + endTime=now + ) + log_events = response.get('events', []) + all_log_events.extend(log_events) + if len(log_events) > 0: + # We found it + break + next_token = response.get('nextToken') + if not next_token: + break + return all_log_events + + def check_sms(self, message_id, notification_id, created_at): + + # TODO this clumsy approach to getting the account number will be fixed as part of notify-api #258 + account_number = cloud_config.ses_domain_arn + account_number = account_number.replace('arn:aws:ses:us-west-2:', '') + account_number = account_number.split(":") + account_number = account_number[0] + + log_group_name = f'sns/us-west-2/{account_number}/DirectPublishToPhoneNumber' + filter_pattern = '{$.notification.messageId="XXXXX"}' + filter_pattern = filter_pattern.replace("XXXXX", message_id) + all_log_events = self._get_log(filter_pattern, log_group_name, created_at) + + if all_log_events and len(all_log_events) > 0: + event = all_log_events[0] + message = json.loads(event['message']) + return "success", message['delivery']['providerResponse'] + + log_group_name = f'sns/us-west-2/{account_number}/DirectPublishToPhoneNumber/Failure' + all_failed_events = self._get_log(filter_pattern, log_group_name, created_at) + if all_failed_events and len(all_failed_events) > 0: + event = all_failed_events[0] + message = json.loads(event['message']) + return "fail", message['delivery']['providerResponse'] + + raise Exception(f'No event found for message_id {message_id} notification_id {notification_id}') diff --git a/app/cloudfoundry_config.py b/app/cloudfoundry_config.py index 7fda0184d..62527c797 100644 --- a/app/cloudfoundry_config.py +++ b/app/cloudfoundry_config.py @@ -39,6 +39,15 @@ class CloudfoundryConfig: domain_arn = getenv('SES_DOMAIN_ARN', 'dev.notify.gov') return domain_arn.split('/')[-1] + # TODO remove this after notifications-api #258 + @property + def ses_domain_arn(self): + try: + domain_arn = self._ses_credentials('domain_arn') + except KeyError: + domain_arn = getenv('SES_DOMAIN_ARN', 'dev.notify.gov') + return domain_arn + @property def ses_region(self): try: diff --git a/app/config.py b/app/config.py index 466efcd76..d1762168b 100644 --- a/app/config.py +++ b/app/config.py @@ -13,6 +13,7 @@ class QueueNames(object): PRIORITY = 'priority-tasks' DATABASE = 'database-tasks' SEND_SMS = 'send-sms-tasks' + CHECK_SMS = 'check-sms_tasks' SEND_EMAIL = 'send-email-tasks' RESEARCH_MODE = 'research-mode-tasks' REPORTING = 'reporting-tasks' @@ -33,6 +34,7 @@ class QueueNames(object): QueueNames.PERIODIC, QueueNames.DATABASE, QueueNames.SEND_SMS, + QueueNames.CHECK_SMS, QueueNames.SEND_EMAIL, QueueNames.RESEARCH_MODE, QueueNames.REPORTING, diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index 72a96d22b..2c95dcc30 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -95,7 +95,7 @@ def _update_notification_status(notification, status, provider_response=None): @autocommit -def update_notification_status_by_id(notification_id, status, sent_by=None): +def update_notification_status_by_id(notification_id, status, sent_by=None, provider_response=None): notification = Notification.query.with_for_update().filter(Notification.id == notification_id).first() if not notification: @@ -121,6 +121,8 @@ def update_notification_status_by_id(notification_id, status, sent_by=None): and not country_records_delivery(notification.phone_prefix) ): return None + if provider_response: + notification.provider_response = provider_response if not notification.sent_by and sent_by: notification.sent_by = sent_by return _update_notification_status( diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index db331db43..380ec7b4d 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -38,7 +38,7 @@ from app.serialised_models import SerialisedService, SerialisedTemplate def send_sms_to_provider(notification): service = SerialisedService.from_id(notification.service_id) - + message_id = None if not service.active: technical_failure(notification=notification) return @@ -79,7 +79,7 @@ def send_sms_to_provider(notification): 'international': notification.international, } db.session.close() # no commit needed as no changes to objects have been made above - provider.send_sms(**send_sms_kwargs) + message_id = provider.send_sms(**send_sms_kwargs) except Exception as e: notification.billable_units = template.fragment_count dao_update_notification(notification) @@ -88,6 +88,7 @@ def send_sms_to_provider(notification): else: notification.billable_units = template.fragment_count update_notification_to_sending(notification, provider) + return message_id def send_email_to_provider(notification): @@ -98,7 +99,6 @@ def send_email_to_provider(notification): return if notification.status == 'created': provider = provider_to_use(EMAIL_TYPE, False) - template_dict = SerialisedTemplate.from_id_and_service_id( template_id=notification.template_id, service_id=service.id, version=notification.template_version ).__dict__ diff --git a/tests/app/celery/test_provider_tasks.py b/tests/app/celery/test_provider_tasks.py index 2f241bc24..d4a9070bf 100644 --- a/tests/app/celery/test_provider_tasks.py +++ b/tests/app/celery/test_provider_tasks.py @@ -23,6 +23,7 @@ def test_should_call_send_sms_to_provider_from_deliver_sms_task( sample_notification, mocker): mocker.patch('app.delivery.send_to_providers.send_sms_to_provider') + mocker.patch('app.celery.provider_tasks.check_sms_delivery_receipt') deliver_sms(sample_notification.id) app.delivery.send_to_providers.send_sms_to_provider.assert_called_with(sample_notification) diff --git a/tests/app/clients/test_aws_cloudwatch.py b/tests/app/clients/test_aws_cloudwatch.py new file mode 100644 index 000000000..5a54383b5 --- /dev/null +++ b/tests/app/clients/test_aws_cloudwatch.py @@ -0,0 +1,87 @@ +import pytest +from flask import current_app + +from app import aws_cloudwatch_client + + +def test_check_sms_no_event_error_condition(notify_api, mocker): + boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True) + # TODO + # we do this to get the AWS account number, and it seems like unit tests locally have + # access to the env variables but when we push the PR they do not. Is there a better way to get it? + mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"}) + message_id = 'aaa' + notification_id = 'bbb' + boto_mock.filter_log_events.return_value = [] + with notify_api.app_context(): + aws_cloudwatch_client.init_app(current_app) + with pytest.raises(Exception): + aws_cloudwatch_client.check_sms(message_id, notification_id) + + +def side_effect(filterPattern, logGroupName, startTime, endTime): + if "Failure" in logGroupName and 'fail' in filterPattern: + return { + "events": + [ + { + 'logStreamName': '89db9712-c6d1-49f9-be7c-4caa7ed9efb1', + 'message': '{"delivery":{"destination":"+1661","providerResponse":"Invalid phone number"}}', + 'eventId': '37535432778099870001723210579798865345508698025292922880' + } + ] + } + + elif 'succeed' in filterPattern: + return { + "events": + [ + { + 'logStreamName': '89db9712-c6d1-49f9-be7c-4caa7ed9efb1', + 'timestamp': 1683147017911, + 'message': '{"delivery":{"destination":"+1661","providerResponse":"Phone accepted msg"}}', + 'ingestionTime': 1683147018026, + 'eventId': '37535432778099870001723210579798865345508698025292922880' + } + ] + } + else: + return {"events": []} + + +def test_check_sms_success(notify_api, mocker): + aws_cloudwatch_client.init_app(current_app) + boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True) + boto_mock.filter_log_events.side_effect = side_effect + mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"}) + + message_id = 'succeed' + notification_id = 'ccc' + with notify_api.app_context(): + aws_cloudwatch_client.check_sms(message_id, notification_id, 1000000000000) + + # We check the 'success' log group first and if we find the message_id, we are done, so there is only 1 call + assert boto_mock.filter_log_events.call_count == 1 + mock_call = str(boto_mock.filter_log_events.mock_calls[0]) + assert 'Failure' not in mock_call + assert 'succeed' in mock_call + assert 'notification.messageId' in mock_call + + +def test_check_sms_failure(notify_api, mocker): + aws_cloudwatch_client.init_app(current_app) + boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True) + boto_mock.filter_log_events.side_effect = side_effect + mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"}) + + message_id = 'fail' + notification_id = 'bbb' + with notify_api.app_context(): + aws_cloudwatch_client.check_sms(message_id, notification_id, 1000000000000) + + # We check the 'success' log group and find nothing, so we then check the 'fail' log group -- two calls. + assert boto_mock.filter_log_events.call_count == 2 + mock_call = str(boto_mock.filter_log_events.mock_calls[1]) + assert 'Failure' in mock_call + assert 'fail' in mock_call + assert 'notification.messageId' in mock_call diff --git a/tests/app/test_config.py b/tests/app/test_config.py index fe2fef296..23d67aafa 100644 --- a/tests/app/test_config.py +++ b/tests/app/test_config.py @@ -4,12 +4,13 @@ from app.config import QueueNames def test_queue_names_all_queues_correct(): # Need to ensure that all_queues() only returns queue names used in API queues = QueueNames.all_queues() - assert len(queues) == 15 + assert len(queues) == 16 assert set([ QueueNames.PRIORITY, QueueNames.PERIODIC, QueueNames.DATABASE, QueueNames.SEND_SMS, + QueueNames.CHECK_SMS, QueueNames.SEND_EMAIL, QueueNames.RESEARCH_MODE, QueueNames.REPORTING,