merge from main

Author: Kenneth Kehl
Date:   2023-05-10 09:58:03 -07:00
11 changed files with 240 additions and 8 deletions

View File

@@ -29,6 +29,7 @@ from werkzeug.exceptions import HTTPException as WerkzeugHTTPException
from werkzeug.local import LocalProxy
from app.clients import NotificationProviderClients
+from app.clients.cloudwatch.aws_cloudwatch import AwsCloudwatchClient
from app.clients.document_download import DocumentDownloadClient
from app.clients.email.aws_ses import AwsSesClient
from app.clients.email.aws_ses_stub import AwsSesStubClient
@@ -55,6 +56,7 @@ notify_celery = NotifyCelery()
aws_ses_client = AwsSesClient()
aws_ses_stub_client = AwsSesStubClient()
aws_sns_client = AwsSnsClient()
+aws_cloudwatch_client = AwsCloudwatchClient()
encryption = Encryption()
zendesk_client = ZendeskClient()
redis_store = RedisClient()
@@ -96,6 +98,7 @@ def create_app(application):
    aws_ses_stub_client.init_app(
        stub_url=application.config['SES_STUB_URL']
    )
+    aws_cloudwatch_client.init_app(application)
    # If a stub url is provided for SES, then use the stub client rather than the real SES boto client
    email_clients = [aws_ses_stub_client] if application.config['SES_STUB_URL'] else [aws_ses_client]
    notification_provider_clients.init_app(

View File

@@ -1,7 +1,11 @@
+from datetime import datetime, timedelta
+from time import time
+from zoneinfo import ZoneInfo
from flask import current_app
from sqlalchemy.orm.exc import NoResultFound
-from app import notify_celery
+from app import aws_cloudwatch_client, notify_celery
from app.clients.email import EmailClientNonRetryableException
from app.clients.email.aws_ses import AwsSesClientThrottlingSendRateException
from app.clients.sms import SmsClientResponseException
@@ -10,17 +14,51 @@ from app.dao import notifications_dao
from app.dao.notifications_dao import update_notification_status_by_id
from app.delivery import send_to_providers
from app.exceptions import NotificationTechnicalFailureException
-from app.models import NOTIFICATION_TECHNICAL_FAILURE
+from app.models import (
+    NOTIFICATION_FAILED,
+    NOTIFICATION_SENT,
+    NOTIFICATION_TECHNICAL_FAILURE,
+)

+@notify_celery.task(bind=True, name="check_sms_delivery_receipt", max_retries=48, default_retry_delay=300)
+def check_sms_delivery_receipt(self, message_id, notification_id, sent_at):
+    """
+    This is called after deliver_sms to check the status of the message. It uses the same number of
+    retries and the same delay period as deliver_sms, and it first fires five minutes after
+    deliver_sms. The idea is that most messages will succeed and show up in the logs quickly,
+    others will resolve successfully after a retry or two, and a few will fail, although it can take
+    up to 4 hours to know for sure. check_sms raises an exception if neither a success nor a failure
+    appears in the CloudWatch logs, so this task keeps retrying until the log appears or until
+    we run out of retries.
+    """
+    status, provider_response = aws_cloudwatch_client.check_sms(message_id, notification_id, sent_at)
+    if status == 'success':
+        status = NOTIFICATION_SENT
+    else:
+        status = NOTIFICATION_FAILED
+    update_notification_status_by_id(notification_id, status, provider_response=provider_response)
+    current_app.logger.info(f"Updated notification {notification_id} with response '{provider_response}'")

@notify_celery.task(bind=True, name="deliver_sms", max_retries=48, default_retry_delay=300)
def deliver_sms(self, notification_id):
    try:
+        # Get the time we are doing the sending, to minimize the time period we need to check over for receipt
+        now = round(time() * 1000)
        current_app.logger.info("Start sending SMS for notification id: {}".format(notification_id))
        notification = notifications_dao.get_notification_by_id(notification_id)
        if not notification:
            raise NoResultFound()
-        send_to_providers.send_sms_to_provider(notification)
+        message_id = send_to_providers.send_sms_to_provider(notification)
+        # We have to put it in the default US/Eastern timezone. From zones west of there, the delay
+        # will be ignored and it will fire immediately (although this probably only affects developer testing)
+        my_eta = datetime.now(ZoneInfo('US/Eastern')) + timedelta(seconds=300)
+        check_sms_delivery_receipt.apply_async(
+            [message_id, notification_id, now],
+            eta=my_eta,
+            queue=QueueNames.CHECK_SMS
+        )
    except Exception as e:
        if isinstance(e, SmsClientResponseException):
            current_app.logger.warning(
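
Reviewer note: the scheduling handoff above is the standard Celery ETA pattern (enqueue a follow-up task to run at a later wall-clock time on its own queue). The sketch below illustrates only that pattern and is not this repository's code; the app name, broker URL, and queue string are assumptions made for the example.

    from datetime import datetime, timedelta
    from zoneinfo import ZoneInfo

    from celery import Celery

    # Placeholder in-memory broker; a real deployment would use its own broker URL.
    celery_app = Celery("receipt_sketch", broker="memory://")


    @celery_app.task(bind=True, max_retries=48, default_retry_delay=300)
    def check_receipt(self, message_id, notification_id, sent_at_ms):
        # A real implementation would query the delivery logs here. This sketch just
        # retries every default_retry_delay seconds until max_retries is exhausted,
        # mimicking "keep checking until the receipt shows up or we give up".
        raise self.retry(exc=RuntimeError("delivery receipt not visible yet"))


    def schedule_receipt_check(message_id, notification_id, sent_at_ms):
        # Same idea as deliver_sms above: run the first check roughly five minutes after the send.
        eta = datetime.now(ZoneInfo("US/Eastern")) + timedelta(seconds=300)
        check_receipt.apply_async(
            [message_id, notification_id, sent_at_ms],
            eta=eta,
            queue="check-sms-tasks",  # assumed queue name for the sketch
        )

Celery treats eta as an absolute timestamp, which is presumably why the diff pins the datetime to an explicit timezone before adding the five-minute delay.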

View File

View File

@@ -0,0 +1,89 @@
+import json
+import re
+import time
+
+from boto3 import client
+
+from app.clients import Client
+from app.cloudfoundry_config import cloud_config
+
+
+class AwsCloudwatchClient(Client):
+    """
+    This client is responsible for retrieving sms delivery receipts from cloudwatch.
+    """
+
+    def init_app(self, current_app, *args, **kwargs):
+        self._client = client(
+            "logs",
+            region_name=cloud_config.sns_region,
+            aws_access_key_id=cloud_config.sns_access_key,
+            aws_secret_access_key=cloud_config.sns_secret_key
+        )
+        super(Client, self).__init__(*args, **kwargs)
+        self.current_app = current_app
+        self._valid_sender_regex = re.compile(r"^\+?\d{5,14}$")
+
+    @property
+    def name(self):
+        return 'cloudwatch'
+
+    def _get_log(self, my_filter, log_group_name, sent_at):
+        # Check all cloudwatch logs from the time the notification was sent (currently 5 minutes previously) until now
+        now = round(time.time() * 1000)
+        beginning = sent_at
+        next_token = None
+        all_log_events = []
+        while True:
+            if next_token:
+                response = self._client.filter_log_events(
+                    logGroupName=log_group_name,
+                    filterPattern=my_filter,
+                    nextToken=next_token,
+                    startTime=beginning,
+                    endTime=now
+                )
+            else:
+                response = self._client.filter_log_events(
+                    logGroupName=log_group_name,
+                    filterPattern=my_filter,
+                    startTime=beginning,
+                    endTime=now
+                )
+            log_events = response.get('events', [])
+            all_log_events.extend(log_events)
+            if len(log_events) > 0:
+                # We found it
+                break
+            next_token = response.get('nextToken')
+            if not next_token:
+                break
+        return all_log_events
+
+    def check_sms(self, message_id, notification_id, created_at):
+        # TODO this clumsy approach to getting the account number will be fixed as part of notify-api #258
+        account_number = cloud_config.ses_domain_arn
+        account_number = account_number.replace('arn:aws:ses:us-west-2:', '')
+        account_number = account_number.split(":")
+        account_number = account_number[0]
+
+        log_group_name = f'sns/us-west-2/{account_number}/DirectPublishToPhoneNumber'
+        filter_pattern = '{$.notification.messageId="XXXXX"}'
+        filter_pattern = filter_pattern.replace("XXXXX", message_id)
+        all_log_events = self._get_log(filter_pattern, log_group_name, created_at)
+        if all_log_events and len(all_log_events) > 0:
+            event = all_log_events[0]
+            message = json.loads(event['message'])
+            return "success", message['delivery']['providerResponse']
+
+        log_group_name = f'sns/us-west-2/{account_number}/DirectPublishToPhoneNumber/Failure'
+        all_failed_events = self._get_log(filter_pattern, log_group_name, created_at)
+        if all_failed_events and len(all_failed_events) > 0:
+            event = all_failed_events[0]
+            message = json.loads(event['message'])
+            return "fail", message['delivery']['providerResponse']
+
+        raise Exception(f'No event found for message_id {message_id} notification_id {notification_id}')
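
Reviewer note: for anyone unfamiliar with CloudWatch Logs filter patterns, the standalone sketch below shows the same boto3 call shape the client relies on: a JSON filter on notification.messageId plus nextToken pagination. The log group name, message id, and time window are made-up values, and credentials are assumed to come from boto3's default chain rather than this app's config.

    import json
    import time

    import boto3

    # Made-up values for illustration only.
    LOG_GROUP = "sns/us-west-2/123456789012/DirectPublishToPhoneNumber"
    MESSAGE_ID = "0a1b2c3d-example"

    logs = boto3.client("logs", region_name="us-west-2")

    now_ms = round(time.time() * 1000)
    kwargs = {
        "logGroupName": LOG_GROUP,
        # JSON filter pattern: match events whose notification.messageId equals MESSAGE_ID.
        "filterPattern": f'{{$.notification.messageId="{MESSAGE_ID}"}}',
        "startTime": now_ms - 5 * 60 * 1000,  # a five-minute lookback window for the example
        "endTime": now_ms,
    }

    events = []
    while True:
        response = logs.filter_log_events(**kwargs)
        events.extend(response.get("events", []))
        token = response.get("nextToken")
        if events or not token:
            break
        kwargs["nextToken"] = token

    for event in events:
        body = json.loads(event["message"])
        print(body["delivery"]["providerResponse"])

The early exit once any page matches mirrors _get_log above: as soon as an event for the message id is found there is no need to keep paginating.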

View File

@@ -39,6 +39,15 @@ class CloudfoundryConfig:
            domain_arn = getenv('SES_DOMAIN_ARN', 'dev.notify.gov')
        return domain_arn.split('/')[-1]

+    # TODO remove this after notifications-api #258
+    @property
+    def ses_domain_arn(self):
+        try:
+            domain_arn = self._ses_credentials('domain_arn')
+        except KeyError:
+            domain_arn = getenv('SES_DOMAIN_ARN', 'dev.notify.gov')
+        return domain_arn
+
    @property
    def ses_region(self):
        try:
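
Background for the account-number slicing in check_sms further up: an SES identity ARN has the general form arn:aws:ses:<region>:<account-id>:identity/<domain>, so stripping the region prefix and splitting on ':' leaves the account id. A tiny illustration with a made-up ARN:

    # Made-up ARN, shown only to illustrate the slicing that check_sms performs
    # on the value returned by the ses_domain_arn property above.
    arn = "arn:aws:ses:us-west-2:123456789012:identity/example.com"
    account_id = arn.replace("arn:aws:ses:us-west-2:", "").split(":")[0]
    assert account_id == "123456789012"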

View File

@@ -13,6 +13,7 @@ class QueueNames(object):
    PRIORITY = 'priority-tasks'
    DATABASE = 'database-tasks'
    SEND_SMS = 'send-sms-tasks'
+    CHECK_SMS = 'check-sms_tasks'
    SEND_EMAIL = 'send-email-tasks'
    RESEARCH_MODE = 'research-mode-tasks'
    REPORTING = 'reporting-tasks'
@@ -33,6 +34,7 @@ class QueueNames(object):
            QueueNames.PERIODIC,
            QueueNames.DATABASE,
            QueueNames.SEND_SMS,
+            QueueNames.CHECK_SMS,
            QueueNames.SEND_EMAIL,
            QueueNames.RESEARCH_MODE,
            QueueNames.REPORTING,

View File

@@ -95,7 +95,7 @@ def _update_notification_status(notification, status, provider_response=None):
@autocommit
-def update_notification_status_by_id(notification_id, status, sent_by=None):
+def update_notification_status_by_id(notification_id, status, sent_by=None, provider_response=None):
    notification = Notification.query.with_for_update().filter(Notification.id == notification_id).first()

    if not notification:
@@ -121,6 +121,8 @@ def update_notification_status_by_id(notification_id, status, sent_by=None):
        and not country_records_delivery(notification.phone_prefix)
    ):
        return None

+    if provider_response:
+        notification.provider_response = provider_response
    if not notification.sent_by and sent_by:
        notification.sent_by = sent_by

    return _update_notification_status(

View File

@@ -38,7 +38,7 @@ from app.serialised_models import SerialisedService, SerialisedTemplate
def send_sms_to_provider(notification):
    service = SerialisedService.from_id(notification.service_id)
+    message_id = None
    if not service.active:
        technical_failure(notification=notification)
        return
@@ -79,7 +79,7 @@ def send_sms_to_provider(notification):
            'international': notification.international,
        }

        db.session.close()  # no commit needed as no changes to objects have been made above
-        provider.send_sms(**send_sms_kwargs)
+        message_id = provider.send_sms(**send_sms_kwargs)
    except Exception as e:
@@ -88,6 +88,7 @@ def send_sms_to_provider(notification):
    else:
        notification.billable_units = template.fragment_count
        update_notification_to_sending(notification, provider)
+        return message_id


def send_email_to_provider(notification):
@@ -98,7 +99,6 @@ def send_email_to_provider(notification):
        return

    if notification.status == 'created':
        provider = provider_to_use(EMAIL_TYPE, False)
        template_dict = SerialisedTemplate.from_id_and_service_id(
            template_id=notification.template_id, service_id=service.id, version=notification.template_version
        ).__dict__
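
Reviewer note: the message_id that send_sms_to_provider now returns comes from the SMS provider. With an SNS-backed provider it would be the MessageId field of the publish response, which also appears to be the value the delivery-status log events key on as notification.messageId. The snippet below is a generic boto3 illustration under that assumption, not this repo's provider wrapper, and the phone number is made up.

    import boto3

    sns = boto3.client("sns", region_name="us-west-2")

    # Direct publish to a phone number; the returned MessageId is the id that a
    # receipt-checking task would later look for in the delivery-status log group.
    response = sns.publish(PhoneNumber="+15555550100", Message="hello")
    message_id = response["MessageId"]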

View File

@@ -23,6 +23,7 @@ def test_should_call_send_sms_to_provider_from_deliver_sms_task(
        sample_notification,
        mocker):
    mocker.patch('app.delivery.send_to_providers.send_sms_to_provider')
+    mocker.patch('app.celery.provider_tasks.check_sms_delivery_receipt')
    deliver_sms(sample_notification.id)

    app.delivery.send_to_providers.send_sms_to_provider.assert_called_with(sample_notification)

View File

@@ -0,0 +1,87 @@
+import pytest
+from flask import current_app
+
+from app import aws_cloudwatch_client
+
+
+def test_check_sms_no_event_error_condition(notify_api, mocker):
+    boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True)
+    # TODO
+    # We do this to get the AWS account number. Unit tests run locally seem to have
+    # access to the env variables, but when we push the PR they do not. Is there a better way to get it?
+    mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"})
+    message_id = 'aaa'
+    notification_id = 'bbb'
+    boto_mock.filter_log_events.return_value = []
+    with notify_api.app_context():
+        aws_cloudwatch_client.init_app(current_app)
+        with pytest.raises(Exception):
+            aws_cloudwatch_client.check_sms(message_id, notification_id, 1000000000000)
+
+
+def side_effect(filterPattern, logGroupName, startTime, endTime):
+    if "Failure" in logGroupName and 'fail' in filterPattern:
+        return {
+            "events": [
+                {
+                    'logStreamName': '89db9712-c6d1-49f9-be7c-4caa7ed9efb1',
+                    'message': '{"delivery":{"destination":"+1661","providerResponse":"Invalid phone number"}}',
+                    'eventId': '37535432778099870001723210579798865345508698025292922880'
+                }
+            ]
+        }
+    elif 'succeed' in filterPattern:
+        return {
+            "events": [
+                {
+                    'logStreamName': '89db9712-c6d1-49f9-be7c-4caa7ed9efb1',
+                    'timestamp': 1683147017911,
+                    'message': '{"delivery":{"destination":"+1661","providerResponse":"Phone accepted msg"}}',
+                    'ingestionTime': 1683147018026,
+                    'eventId': '37535432778099870001723210579798865345508698025292922880'
+                }
+            ]
+        }
+    else:
+        return {"events": []}
+
+
+def test_check_sms_success(notify_api, mocker):
+    aws_cloudwatch_client.init_app(current_app)
+    boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True)
+    boto_mock.filter_log_events.side_effect = side_effect
+    mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"})
+    message_id = 'succeed'
+    notification_id = 'ccc'
+    with notify_api.app_context():
+        aws_cloudwatch_client.check_sms(message_id, notification_id, 1000000000000)
+    # We check the 'success' log group first and if we find the message_id, we are done, so there is only 1 call
+    assert boto_mock.filter_log_events.call_count == 1
+    mock_call = str(boto_mock.filter_log_events.mock_calls[0])
+    assert 'Failure' not in mock_call
+    assert 'succeed' in mock_call
+    assert 'notification.messageId' in mock_call
+
+
+def test_check_sms_failure(notify_api, mocker):
+    aws_cloudwatch_client.init_app(current_app)
+    boto_mock = mocker.patch.object(aws_cloudwatch_client, '_client', create=True)
+    boto_mock.filter_log_events.side_effect = side_effect
+    mocker.patch.dict('os.environ', {"SES_DOMAIN_ARN": "1111:"})
+    message_id = 'fail'
+    notification_id = 'bbb'
+    with notify_api.app_context():
+        aws_cloudwatch_client.check_sms(message_id, notification_id, 1000000000000)
+    # We check the 'success' log group and find nothing, so we then check the 'fail' log group -- two calls.
+    assert boto_mock.filter_log_events.call_count == 2
+    mock_call = str(boto_mock.filter_log_events.mock_calls[1])
+    assert 'Failure' in mock_call
+    assert 'fail' in mock_call
+    assert 'notification.messageId' in mock_call

View File

@@ -4,12 +4,13 @@ from app.config import QueueNames
def test_queue_names_all_queues_correct():
    # Need to ensure that all_queues() only returns queue names used in API
    queues = QueueNames.all_queues()
-    assert len(queues) == 15
+    assert len(queues) == 16
    assert set([
        QueueNames.PRIORITY,
        QueueNames.PERIODIC,
        QueueNames.DATABASE,
        QueueNames.SEND_SMS,
+        QueueNames.CHECK_SMS,
        QueueNames.SEND_EMAIL,
        QueueNames.RESEARCH_MODE,
        QueueNames.REPORTING,