diff --git a/app/celery/broadcast_message_tasks.py b/app/celery/broadcast_message_tasks.py
index bd5d38b47..8a2b6376d 100644
--- a/app/celery/broadcast_message_tasks.py
+++ b/app/celery/broadcast_message_tasks.py
@@ -1,10 +1,13 @@
 import uuid
+from datetime import datetime
 
 from flask import current_app
 from notifications_utils.statsd_decorators import statsd
 from sqlalchemy.schema import Sequence
+from celery.exceptions import MaxRetriesExceededError
 
 from app import cbc_proxy_client, db, notify_celery
+from app.clients.cbc_proxy import CBCProxyFatalException, CBCProxyRetryableException
 from app.config import QueueNames
 from app.models import BroadcastEventMessageType, BroadcastProvider
 from app.dao.broadcast_message_dao import dao_get_broadcast_event_by_id, create_broadcast_provider_message
@@ -12,6 +15,45 @@ from app.dao.broadcast_message_dao import dao_get_broadcast_event_by_id, create_
 from app.utils import format_sequential_number
 
 
+def get_retry_delay(retry_count):
+    """
+    Given a count of retries so far, return a delay for the next one.
+    `retry_count` should be 0 the first time a task fails.
+    """
+    # TODO: replace with celery's built in exponential backoff
+
+    # 2 to the power of x. 1, 2, 4, 8, 16, 32, ...
+    delay = 2**retry_count
+    # never wait longer than 5 minutes
+    return min(delay, 300)
+
+
+def check_provider_message_should_retry(broadcast_provider_message):
+    """
+    Raise MaxRetriesExceededError if we should give up on sending this provider message.
+
+    We give up if the broadcast event's `transmitted_finishes_at` is already in the past, or if the event is
+    not the most recently sent event for its broadcast message. Returns None if it is still fine to retry.
+    """
+    this_event = broadcast_provider_message.broadcast_event
+
+    if this_event.transmitted_finishes_at < datetime.utcnow():
+        raise MaxRetriesExceededError(
+            f'Given up sending broadcast_event {this_event.id} ' +
+            f'to provider {broadcast_provider_message.provider}: ' +
+            f'The expiry time of {this_event.transmitted_finishes_at} has already passed'
+        )
+
+    newest_event = max(this_event.broadcast_message.events, key=lambda x: x.sent_at)
+
+    if this_event != newest_event:
+        raise MaxRetriesExceededError(
+            f'Given up sending broadcast_event {this_event.id} ' +
+            f'to provider {broadcast_provider_message.provider}: ' +
+            f'This event has been superceeded by {newest_event.message_type} broadcast_event {newest_event.id}'
+        )
+
+
 @notify_celery.task(name="send-broadcast-event")
 @statsd(namespace="tasks")
 def send_broadcast_event(broadcast_event_id):
@@ -27,9 +69,10 @@ def send_broadcast_event(broadcast_event_id):
         )
 
 
-@notify_celery.task(name="send-broadcast-provider-message")
+# max_retries=None: retry forever
+@notify_celery.task(bind=True, name="send-broadcast-provider-message", max_retries=None)
 @statsd(namespace="tasks")
-def send_broadcast_provider_message(broadcast_event_id, provider):
+def send_broadcast_provider_message(self, broadcast_event_id, provider):
     broadcast_event = dao_get_broadcast_event_by_id(broadcast_event_id)
     broadcast_provider_message = create_broadcast_provider_message(broadcast_event, provider)
 
@@ -54,41 +97,52 @@ def send_broadcast_provider_message(broadcast_event_id, provider):
 
     cbc_proxy_provider_client = cbc_proxy_client.get_proxy(provider)
 
-    if broadcast_event.message_type == BroadcastEventMessageType.ALERT:
-        cbc_proxy_provider_client.create_and_send_broadcast(
-            identifier=str(broadcast_provider_message.id),
-            message_number=formatted_message_number,
-            headline="GOV.UK Notify Broadcast",
-            description=broadcast_event.transmitted_content['body'],
-            areas=areas,
-            sent=broadcast_event.sent_at_as_cap_datetime_string,
-            expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string,
-            channel=channel
-        )
-    elif broadcast_event.message_type == BroadcastEventMessageType.UPDATE:
-        cbc_proxy_provider_client.update_and_send_broadcast(
-            identifier=str(broadcast_provider_message.id),
-            message_number=formatted_message_number,
-            headline="GOV.UK Notify Broadcast",
-            description=broadcast_event.transmitted_content['body'],
-            areas=areas,
-            previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider),
-            sent=broadcast_event.sent_at_as_cap_datetime_string,
-            expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string,
-            # We think an alert update should always go out on the same channel that created the alert
-            # We recognise there is a small risk with this code here that if the services channel was
-            # changed between an alert being sent out and then updated, then something might go wrong
-            # but we are relying on service channels changing almost never, and not mid incident
-            # We may consider in the future, changing this such that we store the channel a broadcast was
-            # sent on on the broadcast message itself and pick the value from there instead of the service
-            channel=channel
-        )
-    elif broadcast_event.message_type == BroadcastEventMessageType.CANCEL:
-        cbc_proxy_provider_client.cancel_broadcast(
-            identifier=str(broadcast_provider_message.id),
-            message_number=formatted_message_number,
-            previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider),
-            sent=broadcast_event.sent_at_as_cap_datetime_string,
+    try:
+        if broadcast_event.message_type == BroadcastEventMessageType.ALERT:
+            cbc_proxy_provider_client.create_and_send_broadcast(
+                identifier=str(broadcast_provider_message.id),
+                message_number=formatted_message_number,
+                headline="GOV.UK Notify Broadcast",
+                description=broadcast_event.transmitted_content['body'],
+                areas=areas,
+                sent=broadcast_event.sent_at_as_cap_datetime_string,
+                expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string,
+                channel=channel
+            )
+        elif broadcast_event.message_type == BroadcastEventMessageType.UPDATE:
+            cbc_proxy_provider_client.update_and_send_broadcast(
+                identifier=str(broadcast_provider_message.id),
+                message_number=formatted_message_number,
+                headline="GOV.UK Notify Broadcast",
+                description=broadcast_event.transmitted_content['body'],
+                areas=areas,
+                previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider),
+                sent=broadcast_event.sent_at_as_cap_datetime_string,
+                expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string,
+                # We think an alert update should always go out on the same channel that created the alert
+                # We recognise there is a small risk with this code here that if the services channel was
+                # changed between an alert being sent out and then updated, then something might go wrong
+                # but we are relying on service channels changing almost never, and not mid incident
+                # We may consider in the future, changing this such that we store the channel a broadcast was
+                # sent on on the broadcast message itself and pick the value from there instead of the service
+                channel=channel
+            )
+        elif broadcast_event.message_type == BroadcastEventMessageType.CANCEL:
+            cbc_proxy_provider_client.cancel_broadcast(
+                identifier=str(broadcast_provider_message.id),
+                message_number=formatted_message_number,
+                previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider),
+                sent=broadcast_event.sent_at_as_cap_datetime_string,
+            )
+    except CBCProxyRetryableException as exc:
+        # this will raise MaxRetriesExceededError if we no longer want to retry
+        # (because the event has expired or a newer event exists for the same broadcast message)
+        check_provider_message_should_retry(broadcast_provider_message)
+
+        self.retry(
+            exc=exc,
+            countdown=get_retry_delay(self.request.retries),
+            queue=QueueNames.BROADCASTS,
         )
 
 
diff --git a/app/clients/cbc_proxy.py b/app/clients/cbc_proxy.py
index 2686d07af..dd5a2a172 100644
--- a/app/clients/cbc_proxy.py
+++ b/app/clients/cbc_proxy.py
@@ -25,7 +25,11 @@ from app.utils import DATETIME_FORMAT, format_sequential_number
 # the preceeding Alert message in the previous_provider_messages field
 
 
-class CBCProxyException(Exception):
+class CBCProxyFatalException(Exception):
+    pass
+
+
+class CBCProxyRetryableException(Exception):
     pass
 
 
@@ -115,7 +119,9 @@ class CBCProxyClientBase(ABC):
         if not result:
             failover_result = self._invoke_lambda(self.failover_lambda_name, payload)
             if not failover_result:
-                raise CBCProxyException(f'Lambda failed for both {self.lambda_name} and {self.failover_lambda_name}')
+                raise CBCProxyRetryableException(
+                    f'Lambda failed for both {self.lambda_name} and {self.failover_lambda_name}'
+                )
 
         return result
 
diff --git a/app/models.py b/app/models.py
index 64b6e4575..0e2b6db61 100644
--- a/app/models.py
+++ b/app/models.py
@@ -2412,7 +2412,7 @@ class BroadcastEvent(db.Model):
 
     def get_earlier_provider_messages(self, provider):
         """
-        Get the previous message for a provider. These are differentper provider, as the identifiers are different.
+        Get the previous message for a provider. These are different per provider, as the identifiers are different.
         Return the full provider_message object rather than just an identifier, since the different providers expect
         reference to contain different things - let the cbc_proxy work out what information is relevant.
         """
diff --git a/tests/app/celery/test_broadcast_message_tasks.py b/tests/app/celery/test_broadcast_message_tasks.py
index ade0d238e..d2f3344ed 100644
--- a/tests/app/celery/test_broadcast_message_tasks.py
+++ b/tests/app/celery/test_broadcast_message_tasks.py
@@ -1,7 +1,9 @@
 import uuid
+from datetime import datetime
 from unittest.mock import call, ANY
 
 from freezegun import freeze_time
+from celery.exceptions import MaxRetriesExceededError
 import pytest
 
 from app.models import (
@@ -12,7 +14,14 @@ from app.models import (
     ServiceBroadcastProviderRestriction,
     ServiceBroadcastSettings,
 )
-from app.celery.broadcast_message_tasks import send_broadcast_event, send_broadcast_provider_message, trigger_link_test
+from app.clients.cbc_proxy import CBCProxyRetryableException
+from app.celery.broadcast_message_tasks import (
+    check_provider_message_should_retry,
+    get_retry_delay,
+    send_broadcast_event,
+    send_broadcast_provider_message,
+    trigger_link_test,
+)
 
 from tests.app.db import (
     create_template,
@@ -415,13 +424,11 @@ def test_send_broadcast_provider_message_errors(mocker, sample_service, provider
 
     mock_create_broadcast = mocker.patch(
         f'app.clients.cbc_proxy.CBCProxy{provider_capitalised}.create_and_send_broadcast',
-        side_effect=Exception('oh no'),
+        side_effect=CBCProxyRetryableException('oh no'),
     )
+    mock_retry = mocker.patch('app.celery.broadcast_message_tasks.send_broadcast_provider_message.retry')
 
-    with pytest.raises(Exception) as ex:
-        send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id))
-
-    assert ex.match('oh no')
+    send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id))
 
     mock_create_broadcast.assert_called_once_with(
         identifier=ANY,
@@ -439,6 +446,56 @@ def test_send_broadcast_provider_message_errors(mocker, sample_service, provider
         expires=event.transmitted_finishes_at_as_cap_datetime_string,
         channel="test"
     )
+    mock_retry.assert_called_once_with(
+        countdown=1,
+        exc=mock_create_broadcast.side_effect,
+        queue='broadcast-tasks'
+    )
+
+
+@pytest.mark.parametrize('num_retries, expected_countdown', [
+    (0, 1),
+    (5, 32),
+    (20, 300),
+])
+def test_send_broadcast_provider_message_delays_retry_exponentially(
+    mocker,
+    sample_service,
+    num_retries,
+    expected_countdown
+):
+    template = create_template(sample_service, BROADCAST_TYPE)
+
+    broadcast_message = create_broadcast_message(template, status=BroadcastStatusType.BROADCASTING)
+    event = create_broadcast_event(broadcast_message)
+
+    mock_create_broadcast = mocker.patch(
+        'app.clients.cbc_proxy.CBCProxyEE.create_and_send_broadcast',
+        side_effect=CBCProxyRetryableException('oh no'),
+    )
+    mock_retry = mocker.patch('app.celery.broadcast_message_tasks.send_broadcast_provider_message.retry')
+
+    # patch celery request context as shown here: https://stackoverflow.com/a/59870468
+    mock_celery_task_request_context = mocker.patch("celery.app.task.Task.request")
+    mock_celery_task_request_context.retries = num_retries
+
+    send_broadcast_provider_message(provider='ee', broadcast_event_id=str(event.id))
+
+    mock_create_broadcast.assert_called_once_with(
+        identifier=ANY,
+        message_number=mocker.ANY,
+        headline="GOV.UK Notify Broadcast",
+        description='this is an emergency broadcast message',
+        areas=[],
+        sent=event.sent_at_as_cap_datetime_string,
+        expires=event.transmitted_finishes_at_as_cap_datetime_string,
+        channel='test',
+    )
+    mock_retry.assert_called_once_with(
+        countdown=expected_countdown,
+        exc=mock_create_broadcast.side_effect,
+        queue='broadcast-tasks'
+    )
 
 
 @pytest.mark.parametrize("provider,provider_capitalised", [
@@ -471,3 +528,90 @@ def test_trigger_link_tests_invokes_cbc_proxy_client(
         assert len(mock_send_link_test.mock_calls[0][1][1]) == 8
     else:
         assert not mock_send_link_test.mock_calls[0][1][1]
+
+
+@pytest.mark.parametrize('retry_count, expected_delay', [
+    (0, 1),
+    (1, 2),
+    (2, 4),
+    (8, 256),
+    (9, 300),
+    (10, 300),
+    (1000, 300),
+])
+def test_get_retry_delay_has_capped_backoff(retry_count, expected_delay):
+    assert get_retry_delay(retry_count) == expected_delay
+
+
+@freeze_time('2021-01-01 12:00')
+def test_check_provider_message_should_retry_doesnt_raise_if_event_hasnt_expired_yet(sample_template):
+    broadcast_message = create_broadcast_message(sample_template)
+    current_event = create_broadcast_event(
+        broadcast_message,
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 1, 12, 1),
+    )
+    provider_message = create_broadcast_provider_message(current_event, 'ee')
+
+    check_provider_message_should_retry(provider_message)
+
+
+@freeze_time('2021-01-01 12:00')
+def test_check_provider_message_should_retry_raises_if_event_has_expired(sample_template):
+    broadcast_message = create_broadcast_message(sample_template)
+    current_event = create_broadcast_event(
+        broadcast_message,
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 1, 11, 59),
+    )
+    provider_message = create_broadcast_provider_message(current_event, 'ee')
+
+    with pytest.raises(MaxRetriesExceededError) as exc:
+        check_provider_message_should_retry(provider_message)
+    assert 'The expiry time of 2021-01-01 11:59:00 has already passed' in str(exc.value)
+
+
+@freeze_time('2021-01-01 12:00')
+def test_check_provider_message_should_retry_raises_if_a_newer_event_exists(sample_template):
+    broadcast_message = create_broadcast_message(sample_template)
+    # event approved at midnight
+    past_event = create_broadcast_event(
+        broadcast_message,
+        message_type='alert',
+        sent_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 2, 0, 0),
+    )
+    # event updated at 5am (this is the event we're currently trying to send)
+    current_event = create_broadcast_event(
+        broadcast_message,
+        message_type='update',
+        sent_at=datetime(2021, 1, 1, 5, 0),
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 2, 0, 0),
+    )
+    # event updated at 7am
+    future_event = create_broadcast_event(
+        broadcast_message,
+        message_type='update',
+        sent_at=datetime(2021, 1, 1, 7, 0),
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 2, 0, 0),
+    )
+    # event cancelled at 10am
+    futurest_event = create_broadcast_event(
+        broadcast_message,
+        message_type='cancel',
+        sent_at=datetime(2021, 1, 1, 10, 0),
+        transmitted_starts_at=datetime(2021, 1, 1, 0, 0),
+        transmitted_finishes_at=datetime(2021, 1, 2, 0, 0),
+    )
+
+    provider_message = create_broadcast_provider_message(current_event, 'ee')
+
+    # even though the task is going on until midnight tomorrow, we shouldn't send the update now, because the cancel
+    # message will be in the pipeline somewhere.
+    with pytest.raises(MaxRetriesExceededError) as exc:
+        check_provider_message_should_retry(provider_message)
+
+    assert f'This event has been superceeded by cancel broadcast_event {futurest_event.id}' in str(exc.value)
diff --git a/tests/app/clients/test_cbc_proxy.py b/tests/app/clients/test_cbc_proxy.py
index f6618d606..244652f5b 100644
--- a/tests/app/clients/test_cbc_proxy.py
+++ b/tests/app/clients/test_cbc_proxy.py
@@ -7,7 +7,7 @@ from unittest.mock import Mock, call
 import pytest
 
 from app.clients.cbc_proxy import (
-    CBCProxyClient, CBCProxyException, CBCProxyEE, CBCProxyCanary, CBCProxyVodafone, CBCProxyThree, CBCProxyO2
+    CBCProxyClient, CBCProxyRetryableException, CBCProxyEE, CBCProxyCanary, CBCProxyVodafone, CBCProxyThree, CBCProxyO2
 )
 from app.utils import DATETIME_FORMAT
 
@@ -433,7 +433,7 @@ def test_cbc_proxy_create_and_send_tries_failover_lambda_on_invoke_error_and_rai
         'StatusCode': 400,
     }
 
-    with pytest.raises(CBCProxyException) as e:
+    with pytest.raises(CBCProxyRetryableException) as e:
         cbc_proxy.create_and_send_broadcast(
             identifier='my-identifier',
             message_number='0000007b',
@@ -482,7 +482,7 @@ def test_cbc_proxy_create_and_send_tries_failover_lambda_on_function_error_and_r
         }
     }
 
-    with pytest.raises(CBCProxyException) as e:
+    with pytest.raises(CBCProxyRetryableException) as e:
         cbc_proxy.create_and_send_broadcast(
             identifier='my-identifier',
             message_number='0000007b',
diff --git a/tests/app/db.py b/tests/app/db.py
index d1124764d..b58e927cc 100644
--- a/tests/app/db.py
+++ b/tests/app/db.py
@@ -1049,7 +1049,7 @@ def create_broadcast_message(
         starts_at=starts_at,
         finishes_at=finishes_at,
         created_by_id=created_by.id if created_by else service.created_by_id,
-        areas=areas or {},
+        areas=areas or {'areas': [], 'simple_polygons': []},
         content=content,
         stubbed=stubbed
     )
@@ -1077,7 +1077,7 @@ def create_broadcast_event(
         transmitted_areas=transmitted_areas or broadcast_message.areas,
         transmitted_sender=transmitted_sender or 'www.notifications.service.gov.uk',
         transmitted_starts_at=transmitted_starts_at,
-        transmitted_finishes_at=transmitted_finishes_at or datetime.utcnow(),
+        transmitted_finishes_at=transmitted_finishes_at or datetime.utcnow() + timedelta(hours=24),
     )
     db.session.add(b_e)
     db.session.commit()
@@ -1105,4 +1105,4 @@ def create_broadcast_provider_message(
         broadcast_provider_message_id=broadcast_provider_message_id)
     db.session.add(provider_message_number)
     db.session.commit()
-    return provider_message, provider_message_number
+    return provider_message